Skip to content

Views

QueueTaskViewSet

Bases: ViewSet

A ViewSet for queuing AI tasks, handing them out to workers, recording their results, and registering workers.

Source code in API/orchestrator/views.py
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
class QueueTaskViewSet(viewsets.ViewSet):
    """
    A ViewSet for queuing AI tasks and serving them to workers.

    Endpoints:
        ai_task: queue a new task.
        task: hand one pending task to a worker and mark it as started.
        update_result: apply a partial update to an existing task.
        worker: register a worker, or refresh its addresses.
    """

    # This ensures that only authenticated users can access these endpoints
    permission_classes = [IsAuthenticated]

    @swagger_auto_schema(
        operation_summary="Queue an AI task",
        operation_description="This will include LLM, STT, and other AI tasks",
        request_body=TaskSerializer,
        responses={200: "Task queued successfully"},
    )
    @action(detail=False, methods=["post"], permission_classes=[IsAuthenticated])
    def ai_task(self, request):
        """
        Endpoint to queue tasks for the AI client side to run.

        The request body is validated with ``TaskSerializer``. When the body
        carries no ``track_id`` a new one is generated before the task is
        handed to ``ClusterManager.chain_next``.

        Returns:
            200 with the value returned by ``ClusterManager.chain_next`` under
            ``task_id``, or 400 when validation raises.
        """
        data = request.data
        serializer = TaskSerializer(data=data)

        try:
            serializer.is_valid(raise_exception=True)
        except Exception as e:
            logger.error(f"Error validating task request: {e}")
            return Response(
                {"error": f"Error validating task request: {e}"},
                status=status.HTTP_400_BAD_REQUEST,
            )

        # The track id is read from the raw request body, not validated_data.
        track_id = data.get("track_id", None)
        logger.info(f"Track ID: {track_id}")
        # if track_id is not provided, then we need to generate it
        if track_id is None:
            track_id = Task.init_track_id(CLUSTER_Q_ETE_CONVERSATION_NAME)
            logger.info(f"Generated track ID: {track_id}")
            serializer.validated_data["track_id"] = track_id

        # based on the track cluster name, determine what to do next
        task_id = ClusterManager.chain_next(
            track_id=track_id,
            current_component="init",
            next_component_params=serializer.validated_data["parameters"],
            name=data.get("name", None),
            user=request.user,
        )

        # NOTE(review): the message says "LLM task" for every task type; it is
        # kept unchanged because clients may depend on the exact text.
        return Response(
            {"message": "LLM task queued successfully", "task_id": task_id},
            status=status.HTTP_200_OK,
        )

    @swagger_auto_schema(
        operation_summary="Worker: Get Task",
        operation_description="Get the task",
        responses={200: "Task retrieved successfully"},
    )
    @action(
        detail=False,
        methods=["get"],
        permission_classes=[IsAuthenticated],
        url_path="task/(?P<task_name>.+)",
        url_name="task",
    )
    def task(self, request, task_name="all"):
        """
        Endpoint for a worker to fetch one pending task and mark it as started.

        Args:
            request: The incoming request; it is not read by this method.
            task_name: Task name captured from the URL. ``"all"`` matches
                pending tasks of any name.

        Returns:
            200 with the serialized task, or 404 when no pending task older
            than the cool-down window exists.
        """
        # Imported locally so this method does not depend on a module-level
        # import. timezone.now() is aware when USE_TZ is on and naive when it
        # is off, so it always matches what the DateTimeField expects.
        from django.utils import timezone

        cool_down_task = 10  # seconds a task must have existed before hand-out
        cool_down_time = timezone.now() - timedelta(seconds=cool_down_task)

        filters = {"result_status": "pending", "created_at__lte": cool_down_time}
        if task_name != "all":
            filters["task_name"] = task_name

        # .first() returns None instead of raising DoesNotExist, so a None
        # check is the only "not found" handling needed.
        task = Task.objects.filter(**filters).first()
        if task is None:
            return Response(
                {"error": f"No pending {task_name} tasks found"},
                status=status.HTTP_404_NOT_FOUND,
            )

        # NOTE(review): read-then-save is not atomic; two workers polling at
        # the same moment could receive the same task - confirm whether that
        # matters for this deployment.
        task.result_status = "started"
        task.save()
        task_serializer = TaskSerializer(task)
        logger.info(f"Task {task.id} retrieved successfully")
        return Response(data=task_serializer.data, status=status.HTTP_200_OK)

    # add an endpoint to update the task result
    @swagger_auto_schema(
        operation_summary="Worker: Result Update",
        operation_description="Update the task result",
        request_body=TaskSerializer,
        responses={200: "Task result updated successfully"},
    )
    @action(
        detail=True,
        methods=["post"],
        permission_classes=[IsAuthenticated],
        url_path="update_result",
        url_name="update_result",
    )
    def update_result(self, request, pk=None):
        """
        Endpoint to update the result of a task.

        Args:
            request: Request whose body holds the fields to update.
            pk: ID of the task to update.

        Returns:
            200 on success, 404 when no task has the given ID, and 400 when
            any exception is raised while validating or saving.
        """
        try:
            data = request.data
            task = Task.objects.filter(id=pk).first()
            if task is None:
                return Response(
                    {"error": f"Task with ID {pk} does not exist"},
                    status=status.HTTP_404_NOT_FOUND,
                )

            # partial=True lets the caller send only the fields that changed.
            serializer = TaskSerializer(data=data, instance=task, partial=True)
            serializer.is_valid(raise_exception=True)
            serializer.save()
            return Response(
                {"message": f"Task {task.id} updated successfully"},
                status=status.HTTP_200_OK,
            )
        except Exception as e:
            # Every failure, including non-validation errors, is reported as 400.
            logger.error(f"Error updating task result: {e}")
            logger.exception(e)
            return Response(
                {"error": f"Error updating task result: {e}"},
                status=status.HTTP_400_BAD_REQUEST,
            )

    @swagger_auto_schema(
        operation_summary="Worker: Register",
        operation_description="Register a worker",
        responses={200: "Worker registered or updated successfully"},
        request_body=TaskWorkerSerializer,
    )
    @action(detail=False, methods=["post"], permission_classes=[IsAuthenticated])
    def worker(self, request):
        """
        Endpoint to register a GPU worker.

        Creates a ``TaskWorker`` keyed by ``uuid``. When one already exists,
        only its ``mac_address`` and ``ip_address`` are overwritten.

        Returns:
            200 with a confirmation message. A validation failure propagates
            as the exception raised by ``is_valid(raise_exception=True)``.
        """
        data = request.data
        serializer = TaskWorkerSerializer(data=data)
        serializer.is_valid(raise_exception=True)

        # Values are taken from the raw request body after validation passes.
        uuid = data.get("uuid")
        mac_address = data.get("mac_address")
        ip_address = data.get("ip_address")
        task_name = data.get("task_name")

        worker, created = TaskWorker.objects.get_or_create(
            uuid=uuid,
            defaults={
                "mac_address": mac_address,
                "ip_address": ip_address,
                "task_name": task_name,
            },
        )
        if not created:
            # task_name is intentionally left as stored for an existing worker.
            worker.mac_address = mac_address
            worker.ip_address = ip_address
            worker.save()

        return Response(
            {"message": f"Worker {uuid} registered or updated successfully"},
            status=status.HTTP_200_OK,
        )

ai_task(request)

Endpoint to queue tasks for the AI client side to run.

Source code in API/orchestrator/views.py
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
@swagger_auto_schema(
    operation_summary="Queue an AI task",
    operation_description="This will include LLM, STT, and other AI tasks",
    request_body=TaskSerializer,
    responses={200: "Task queued successfully"},
)
@action(detail=False, methods=["post"], permission_classes=[IsAuthenticated])
def ai_task(self, request):
    """
    Queue a task for the AI client side to run.

    Validates the request body with ``TaskSerializer``, generates a track id
    when the body carries none, and passes the task on to
    ``ClusterManager.chain_next``.

    Returns:
        200 with the value returned by ``ClusterManager.chain_next`` under
        ``task_id``, or 400 when validation raises.
    """
    payload = request.data
    task_serializer = TaskSerializer(data=payload)

    try:
        task_serializer.is_valid(raise_exception=True)
    except Exception as exc:
        failure = f"Error validating task request: {exc}"
        logger.error(failure)
        return Response({"error": failure}, status=status.HTTP_400_BAD_REQUEST)

    # The track id comes from the raw body; one is created when it is absent.
    track_id = payload.get("track_id", None)
    logger.info(f"Track ID: {track_id}")
    if track_id is None:
        track_id = Task.init_track_id(CLUSTER_Q_ETE_CONVERSATION_NAME)
        logger.info(f"Generated track ID: {track_id}")
        task_serializer.validated_data["track_id"] = track_id

    # Start the chain from the "init" component with the validated parameters.
    chain_arguments = {
        "track_id": track_id,
        "current_component": "init",
        "next_component_params": task_serializer.validated_data["parameters"],
        "name": payload.get("name", None),
        "user": request.user,
    }
    queued_task_id = ClusterManager.chain_next(**chain_arguments)

    body = {"message": "LLM task queued successfully", "task_id": queued_task_id}
    return Response(body, status=status.HTTP_200_OK)

task(request, task_name='all')

Endpoint for a worker to fetch one pending AI task, which is then marked as started.

Source code in API/orchestrator/views.py
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
@swagger_auto_schema(
    operation_summary="Worker: Get Task",
    operation_description="Get the task",
    responses={200: "Task retrieved successfully"},
)
@action(
    detail=False,
    methods=["get"],
    permission_classes=[IsAuthenticated],
    url_path="task/(?P<task_name>.+)",
    url_name="task",
)
def task(self, request, task_name="all"):
    """
    Endpoint for a worker to fetch one pending task and mark it as started.

    Args:
        request: The incoming request; it is not read by this method.
        task_name: Task name captured from the URL. ``"all"`` matches pending
            tasks of any name.

    Returns:
        200 with the serialized task, or 404 when no pending task older than
        the cool-down window exists.
    """
    # Imported locally so this block does not depend on a module-level import.
    # timezone.now() is aware when USE_TZ is on and naive when it is off, so
    # it always matches what the DateTimeField expects.
    from django.utils import timezone

    cool_down_task = 10  # seconds a task must have existed before hand-out
    cool_down_time = timezone.now() - timedelta(seconds=cool_down_task)

    filters = {"result_status": "pending", "created_at__lte": cool_down_time}
    if task_name != "all":
        filters["task_name"] = task_name

    # .first() returns None instead of raising DoesNotExist, so a None check
    # is the only "not found" handling needed.
    task = Task.objects.filter(**filters).first()
    if task is None:
        return Response(
            {"error": f"No pending {task_name} tasks found"},
            status=status.HTTP_404_NOT_FOUND,
        )

    # NOTE(review): read-then-save is not atomic; two workers polling at the
    # same moment could receive the same task - confirm whether that matters.
    task.result_status = "started"
    task.save()
    task_serializer = TaskSerializer(task)
    logger.info(f"Task {task.id} retrieved successfully")
    return Response(data=task_serializer.data, status=status.HTTP_200_OK)

update_result(request, pk=None)

Endpoint to update the result of a task.

Source code in API/orchestrator/views.py
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
@swagger_auto_schema(
    operation_summary="Worker: Result Update",
    operation_description="Update the task result",
    request_body=TaskSerializer,
    responses={200: "Task result updated successfully"},
)
@action(
    detail=True,
    methods=["post"],
    permission_classes=[IsAuthenticated],
    url_path="update_result",
    url_name="update_result",
)
def update_result(self, request, pk=None):
    """
    Apply a partial update to the task identified by ``pk``.

    Args:
        request: Request whose body holds the fields to update.
        pk: ID of the task to update.

    Returns:
        200 on success, 404 when no task has the given ID, and 400 when any
        exception is raised while reading, validating or saving.
    """
    try:
        payload = request.data
        existing = Task.objects.filter(id=pk).first()
        if existing is not None:
            # partial=True lets the caller send only the fields that changed.
            updater = TaskSerializer(data=payload, instance=existing, partial=True)
            updater.is_valid(raise_exception=True)
            updater.save()
            success = {"message": f"Task {existing.id} updated successfully"}
            return Response(success, status=status.HTTP_200_OK)

        missing = {"error": f"Task with ID {pk} does not exist"}
        return Response(missing, status=status.HTTP_404_NOT_FOUND)
    except Exception as exc:
        # Every failure, including non-validation errors, is reported as 400.
        failure = f"Error updating task result: {exc}"
        logger.error(failure)
        logger.exception(exc)
        return Response({"error": failure}, status=status.HTTP_400_BAD_REQUEST)

worker(request)

Endpoint to register a GPU worker.

Source code in API/orchestrator/views.py
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
@swagger_auto_schema(
    operation_summary="Worker: Register",
    operation_description="Register a worker",
    responses={200: "Worker registered or updated successfully"},
    request_body=TaskWorkerSerializer,
)
@action(detail=False, methods=["post"], permission_classes=[IsAuthenticated])
def worker(self, request):
    """
    Register a GPU worker, or refresh the addresses of a known one.

    A ``TaskWorker`` is created for the given ``uuid`` when none exists.
    Otherwise only ``mac_address`` and ``ip_address`` are overwritten.

    Returns:
        200 with a confirmation message. A validation failure propagates as
        the exception raised by ``is_valid(raise_exception=True)``.
    """
    payload = request.data
    TaskWorkerSerializer(data=payload).is_valid(raise_exception=True)

    # Values are taken from the raw request body after validation passes.
    worker_uuid = payload.get("uuid")
    mac = payload.get("mac_address")
    ip = payload.get("ip_address")
    initial_fields = {
        "mac_address": mac,
        "ip_address": ip,
        "task_name": payload.get("task_name"),
    }

    record, is_new = TaskWorker.objects.get_or_create(
        uuid=worker_uuid, defaults=initial_fields
    )
    if is_new:
        pass  # the defaults above already populated the new row
    else:
        record.mac_address = mac
        record.ip_address = ip
        record.save()

    confirmation = {"message": f"Worker {worker_uuid} registered or updated successfully"}
    return Response(confirmation, status=status.HTTP_200_OK)