18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
class QueueTaskViewSet(viewsets.ViewSet):
    """
    A ViewSet for queuing AI tasks generally.

    Clients queue AI tasks (LLM, STT, ...) through ``ai_task``; workers claim
    pending tasks through ``task``, report results through ``update_result``
    and register themselves through ``worker``.
    """

    # This ensures that only authenticated users can access these endpoints
    permission_classes = [IsAuthenticated]

    # Tasks created less than this many seconds ago are not handed to workers.
    TASK_COOL_DOWN_SECONDS = 10
    # Upper bound on how many pending tasks one ``task`` request tries to claim
    # when concurrent workers keep claiming them first.
    _CLAIM_ATTEMPTS = 5

    @swagger_auto_schema(
        operation_summary="Queue an AI task",
        operation_description="This will include LLM, STT, and other AI tasks",
        request_body=TaskSerializer,
        responses={200: "Task queued successfully"},
    )
    @action(detail=False, methods=["post"], permission_classes=[IsAuthenticated])
    def ai_task(self, request):
        """
        Endpoint to queue tasks for AI Client side to run.

        Validates the payload with ``TaskSerializer``. If the client did not
        send a ``track_id``, one is generated with
        ``Task.init_track_id(CLUSTER_Q_ETE_CONVERSATION_NAME)``. The task is
        then handed to ``ClusterManager.chain_next`` as the "init" step of
        that track.

        Returns 200 with the ``task_id`` produced by ``chain_next``, or 400
        when the payload fails validation.
        """
        data = request.data
        serializer = TaskSerializer(data=data)
        if not serializer.is_valid():
            errors = serializer.errors
            logger.error(f"Error validating task request: {errors}")
            return Response(
                {"error": f"Error validating task request: {errors}"},
                status=status.HTTP_400_BAD_REQUEST,
            )

        track_id = data.get("track_id")
        logger.info(f"Track ID: {track_id}")
        # A missing (or empty) track_id means a new track must be started.
        if not track_id:
            track_id = Task.init_track_id(CLUSTER_Q_ETE_CONVERSATION_NAME)
            logger.info(f"Generated track ID: {track_id}")
            serializer.validated_data["track_id"] = track_id

        # Based on the track cluster name, ClusterManager decides what runs next.
        task_id = ClusterManager.chain_next(
            track_id=track_id,
            current_component="init",
            next_component_params=serializer.validated_data["parameters"],
            name=data.get("name"),
            user=request.user,
        )
        return Response(
            {"message": "LLM task queued successfully", "task_id": task_id},
            status=status.HTTP_200_OK,
        )

    @swagger_auto_schema(
        operation_summary="Worker: Get Task",
        operation_description="Get the task",
        responses={200: "Task retrieved successfully"},
    )
    @action(
        detail=False,
        methods=["get"],
        permission_classes=[IsAuthenticated],
        url_path="task/(?P<task_name>.+)",
        url_name="task",
    )
    def task(self, request, task_name="all"):
        """
        Endpoint for a worker to claim the next pending task.

        ``task_name`` restricts the search to one task type; ``"all"`` matches
        any type. Only tasks with ``result_status == "pending"`` that were
        created at least ``TASK_COOL_DOWN_SECONDS`` ago are eligible.

        On success the task is moved to ``"started"`` and returned serialized
        with 200. Returns 404 when no task could be claimed.
        """
        # NOTE(review): naive local time. If settings.USE_TZ is enabled,
        # django.utils.timezone.now() is probably the intended clock — confirm.
        cool_down_time = datetime.now() - timedelta(
            seconds=self.TASK_COOL_DOWN_SECONDS
        )
        pending = Task.objects.filter(
            result_status="pending", created_at__lte=cool_down_time
        )
        if task_name != "all":
            pending = pending.filter(task_name=task_name)

        for _ in range(self._CLAIM_ATTEMPTS):
            # .first() issues a fresh query each time, so tasks claimed by
            # other workers in the meantime are no longer returned.
            task = pending.first()
            if task is None:
                break
            # Conditional UPDATE acts as the claim: only one concurrent request
            # can flip this row from "pending" to "started" (row count 1).
            claimed = Task.objects.filter(
                pk=task.pk, result_status="pending"
            ).update(result_status="started")
            if not claimed:
                # Another worker claimed it first; try the next pending task.
                continue
            task.result_status = "started"
            # QuerySet.update() bypasses Model.save(); call save() as well so
            # any save()-time model hooks (e.g. signals) still run.
            task.save()
            logger.info(f"Task {task.id} retrieved successfully")
            return Response(
                data=TaskSerializer(task).data, status=status.HTTP_200_OK
            )

        return Response(
            {"error": f"No pending {task_name} tasks found"},
            status=status.HTTP_404_NOT_FOUND,
        )

    # add an endpoint to update the task result
    @swagger_auto_schema(
        operation_summary="Worker: Result Update",
        operation_description="Update the task result",
        request_body=TaskSerializer,
        responses={200: "Task result updated successfully"},
    )
    @action(
        detail=True,
        methods=["post"],
        permission_classes=[IsAuthenticated],
        url_path="update_result",
        url_name="update_result",
    )
    def update_result(self, request, pk=None):
        """
        Endpoint to update the result of a task.

        Partially updates the ``Task`` with id ``pk`` from the request body
        via ``TaskSerializer(partial=True)``.

        Returns 200 on success, 404 if the task does not exist, and 400 for
        any error raised while looking up, validating or saving the task.
        """
        try:
            task = Task.objects.filter(id=pk).first()
            if task is None:
                return Response(
                    {"error": f"Task with ID {pk} does not exist"},
                    status=status.HTTP_404_NOT_FOUND,
                )
            serializer = TaskSerializer(
                data=request.data, instance=task, partial=True
            )
            serializer.is_valid(raise_exception=True)
            serializer.save()
        except Exception as e:
            # Deliberately broad: every failure is reported to the worker as a
            # 400. logger.exception logs once, with the traceback.
            logger.exception(f"Error updating task result: {e}")
            return Response(
                {"error": f"Error updating task result: {e}"},
                status=status.HTTP_400_BAD_REQUEST,
            )
        return Response(
            {"message": f"Task {task.id} updated successfully"},
            status=status.HTTP_200_OK,
        )

    @swagger_auto_schema(
        operation_summary="Worker: Register",
        operation_description="Register a worker",
        responses={200: "Worker registered or updated successfully"},
        request_body=TaskWorkerSerializer,
    )
    @action(detail=False, methods=["post"], permission_classes=[IsAuthenticated])
    def worker(self, request):
        """
        Endpoint to register a GPU worker.

        Validates the payload with ``TaskWorkerSerializer``. It then creates
        the ``TaskWorker`` identified by ``uuid``, or updates its
        ``mac_address``, ``ip_address`` and ``task_name`` if it already
        exists. Always returns 200 after a successful registration.
        """
        data = request.data
        serializer = TaskWorkerSerializer(data=data)
        serializer.is_valid(raise_exception=True)
        worker_uuid = data.get("uuid")
        # update_or_create refreshes every field for a known worker, so a
        # worker re-registering for a different task_name takes effect.
        TaskWorker.objects.update_or_create(
            uuid=worker_uuid,
            defaults={
                "mac_address": data.get("mac_address"),
                "ip_address": data.get("ip_address"),
                "task_name": data.get("task_name"),
            },
        )
        return Response(
            {"message": f"Worker {worker_uuid} registered or updated successfully"},
            status=status.HTTP_200_OK,
        )