Skip to content

Handler

GeneralMLModel

Source code in Agent/modules/general_ml/handler.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
class GeneralMLModel:
    """
    Handler for "general ML" tasks backed by lazily loaded models.

    Models are loaded on first use via ``load_model`` and cached in
    ``avail_models`` (keyed by model name), so later tasks for the same
    model reuse the already-loaded instance.
    """

    def __init__(self):
        # Cache of loaded models: model name -> loaded model object.
        self.avail_models = {}

    def handle_task(self, task: Task) -> Task:
        """
        Handle the task: run the requested general model on the task's text.

        Args:
            task (Task): The task to handle. ``task.parameters`` is unpacked
                into ``GeneralMLParameters`` (fields ``text``,
                ``general_model_name`` and ``params`` are read).

        Returns:
            Task: The same task object, with ``result_status`` set to
            completed, the inference output stored under
            ``result_json.result_profile["result"]``, and the model
            init/infer timings merged into ``result_json.latency_profile``.

        Raises:
            ValueError: If the requested model name is not supported
                (raised by ``load_model`` / ``infer``).
        """
        TimeLogger.log_task(task, "start_general_ml")
        result_profile = {}
        latency_profile = {}
        general_ml_parameters = GeneralMLParameters(**task.parameters)
        text = general_ml_parameters.text
        general_model_name = general_ml_parameters.general_model_name
        params = general_ml_parameters.params
        if general_model_name not in self.avail_models:
            # A cache miss is the normal first-use path, not a failure, so it
            # is logged at info level rather than error level.
            logger.info(f"Model {general_model_name} not loaded yet, loading it")
            with time_tracker(
                "init", latency_profile, track_type=TrackType.MODEL.value
            ):
                # load_model raises for unsupported names, so only
                # successfully loaded models ever reach the cache.
                ml_model = self.load_model(general_model_name)
                self.avail_models[general_model_name] = ml_model
        else:
            ml_model = self.avail_models[general_model_name]

        with timer(logger, f"Model infer {general_model_name}"):
            with time_tracker(
                "infer", latency_profile, track_type=TrackType.MODEL.value
            ):
                res = self.infer(ml_model, general_model_name, text, params)
        result_profile["result"] = res

        task.result_status = ResultStatus.completed.value
        task.result_json.result_profile.update(result_profile)
        task.result_json.latency_profile.update(latency_profile)
        TimeLogger.log_task(task, "end_general_ml")
        return task

    @staticmethod
    def load_model(general_model_name: str):
        """
        Load the model identified by ``general_model_name``.

        Args:
            general_model_name (str): Model name. Only
                ``"sentence_transformer"`` is currently supported.

        Returns:
            A ``SentenceTransformer("all-MiniLM-L6-v2")`` instance for
            ``"sentence_transformer"``.

        Raises:
            ValueError: If ``general_model_name`` is not supported.
        """
        if general_model_name == "sentence_transformer":
            return SentenceTransformer("all-MiniLM-L6-v2")
        raise ValueError(f"Model {general_model_name} is not implemented")

    @staticmethod
    def infer(ml_model, general_model_name: str, text: str, params: dict):
        """
        Run inference with an already-loaded model.

        Args:
            ml_model: The loaded model object (as returned by ``load_model``).
            general_model_name (str): Name identifying which model
                ``ml_model`` is.
            text (str): Input text passed to the model.
            params (dict): Extra model parameters. They are currently unused
                for inference and are only logged when the model name is
                unsupported.

        Returns:
            list: For ``"sentence_transformer"``, the output of
            ``ml_model.encode(text)`` converted with ``.tolist()``.

        Raises:
            ValueError: If ``general_model_name`` is not supported.
        """
        if general_model_name == "sentence_transformer":
            result = ml_model.encode(text)
            return result.tolist()
        logger.info(params)
        raise ValueError(f"Model {general_model_name} is not implemented")

handle_task(task)

Handle the task.

Args: task (Task) – the task to handle.

Returns:

Task – the updated task.

Source code in Agent/modules/general_ml/handler.py
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
def handle_task(self, task: Task) -> Task:
    """
    Handle the task: run the requested general model on the task's text.

    Args:
        task (Task): The task to handle. ``task.parameters`` is unpacked
            into ``GeneralMLParameters`` (fields ``text``,
            ``general_model_name`` and ``params`` are read).

    Returns:
        Task: The same task object, with ``result_status`` set to completed,
        the inference output stored under
        ``result_json.result_profile["result"]``, and the model init/infer
        timings merged into ``result_json.latency_profile``.

    Raises:
        ValueError: If the requested model name is not supported
            (raised by ``load_model`` / ``infer``).
    """
    TimeLogger.log_task(task, "start_general_ml")
    result_profile = {}
    latency_profile = {}
    general_ml_parameters = GeneralMLParameters(**task.parameters)
    text = general_ml_parameters.text
    general_model_name = general_ml_parameters.general_model_name
    params = general_ml_parameters.params
    if general_model_name not in self.avail_models:
        # A cache miss is the normal first-use path, not a failure, so it is
        # logged at info level rather than error level.
        logger.info(f"Model {general_model_name} not loaded yet, loading it")
        with time_tracker(
            "init", latency_profile, track_type=TrackType.MODEL.value
        ):
            # load_model raises for unsupported names, so only successfully
            # loaded models ever reach the cache.
            ml_model = self.load_model(general_model_name)
            self.avail_models[general_model_name] = ml_model
    else:
        ml_model = self.avail_models[general_model_name]

    with timer(logger, f"Model infer {general_model_name}"):
        with time_tracker(
            "infer", latency_profile, track_type=TrackType.MODEL.value
        ):
            res = self.infer(ml_model, general_model_name, text, params)
    result_profile["result"] = res

    task.result_status = ResultStatus.completed.value
    task.result_json.result_profile.update(result_profile)
    task.result_json.latency_profile.update(latency_profile)
    TimeLogger.log_task(task, "end_general_ml")
    return task

infer(ml_model, general_model_name, text, params) staticmethod

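Infer the model.

Args: ml_model – the loaded general model; general_model_name (str) – model name; text (str) – input text; params (dict) – model params.

Returns: the model output as a list (for `sentence_transformer`, `ml_model.encode(text).tolist()`); raises `ValueError` for unsupported model names.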

Source code in Agent/modules/general_ml/handler.py
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
@staticmethod
def infer(ml_model, general_model_name: str, text: str, params: dict):
    """
    Run inference with an already-loaded model.

    Args:
        ml_model: The loaded model object (as returned by ``load_model``).
        general_model_name (str): Name identifying which model ``ml_model`` is.
        text (str): Input text passed to the model.
        params (dict): Extra model parameters. They are currently unused for
            inference and are only logged when the model name is unsupported.

    Returns:
        list: For ``"sentence_transformer"``, the output of
        ``ml_model.encode(text)`` converted with ``.tolist()``.

    Raises:
        ValueError: If ``general_model_name`` is not supported.
    """
    # Reject unsupported models up front; only sentence_transformer is wired up.
    if general_model_name != "sentence_transformer":
        logger.info(params)
        raise ValueError(f"Model {general_model_name} is not implemented")
    encoded = ml_model.encode(text)
    return encoded.tolist()

load_model(general_model_name) staticmethod

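Load the model.

Args: general_model_name (str) – model name.

Returns: the loaded model instance (`SentenceTransformer("all-MiniLM-L6-v2")` for `sentence_transformer`); raises `ValueError` for unsupported model names.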

Source code in Agent/modules/general_ml/handler.py
58
59
60
61
62
63
64
65
66
67
68
69
70
@staticmethod
def load_model(general_model_name: str):
    """
    Load the model identified by ``general_model_name``.

    Args:
        general_model_name (str): Model name. Only ``"sentence_transformer"``
            is currently supported.

    Returns:
        A ``SentenceTransformer("all-MiniLM-L6-v2")`` instance for
        ``"sentence_transformer"``.

    Raises:
        ValueError: If ``general_model_name`` is not supported.
    """
    # Guard clause: anything other than the single supported name fails fast.
    if general_model_name != "sentence_transformer":
        raise ValueError(f"Model {general_model_name} is not implemented")
    return SentenceTransformer("all-MiniLM-L6-v2")