Source code for matrice.actionTracker

import math
import os
import shutil
import sys
import tarfile
import traceback
import zipfile
from types import SimpleNamespace

import requests
import yaml
from bson import ObjectId
from pycocotools.coco import COCO

from matrice import rpc
from matrice.models import Model
from matrice.session import Session
from matrice.utils import log_error
from matrice.model_store import fetch_supported_runtimes_metrics

class __dotdict(dict):
    """
    A dictionary subclass that allows dot notation access to its attributes.

    This class enables both standard dictionary key access and dot notation access for easier manipulation
    of data attributes. It can be particularly useful for handling configuration parameters or other data
    structures where attributes are frequently accessed.

    Example
    -------
    >>> my_dict = __dotdict({'key': 'value'})
    >>> print(my_dict.key)  # Outputs: value
    >>> print(my_dict['key'])  # Outputs: value

    Parameters
    ----------
    initial_data : dict, optional
        An optional dictionary to initialize the `__dotdict`. If provided, the items will be added to the `__dotdict`.

    Attributes
    ----------
    None

    Methods
    -------
    __getattr__(key)
        Retrieves the value associated with the given key using dot notation.
    
    __setattr__(key, value)
        Sets the value for the given key using dot notation.
    
    __delattr__(key)
        Deletes the specified key from the dictionary using dot notation.

    Examples
    --------
    >>> my_dict = __dotdict({'name': 'Alice', 'age': 30})
    >>> print(my_dict.name)  # Outputs: Alice
    >>> my_dict.location = 'Wonderland'
    >>> print(my_dict['location'])  # Outputs: Wonderland
    >>> del my_dict.age
    >>> print(my_dict)  # Outputs: __dotdict({'name': 'Alice', 'location': 'Wonderland'})
    """

    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

[docs] class LocalActionTracker: pass
[docs] class ActionTracker: """ Tracks and manages the status, actions, and related data of a model's lifecycle, including training, evaluation, and deployment processes. The `ActionTracker` is responsible for tracking various stages of an action (e.g., model training, evaluation, or deployment), logging details, fetching configuration parameters, downloading model checkpoints, and handling error logging. It interacts with the backend system to retrieve and update action statuses. Parameters ---------- action_id : str, optional The unique identifier of the action to be tracked. If not provided, the class will initialize without an active action. The `action_id` is typically linked to specific activities such as model training, evaluation, or deployment. Attributes ---------- rpc : RPCClient A Remote Procedure Call (RPC) client for interacting with the backend API. action_id : bson.ObjectId The ObjectId representing the action being tracked. This is used for retrieving action details from the backend. action_id_str : str The string representation of the `action_id`. action_doc : dict The detailed document containing information about the action, including its status, type, and related model details. action_type : str The type of action being tracked, such as 'model_train', 'model_eval', or 'deploy_add'. _idModel : bson.ObjectId The ObjectId of the model associated with the current action. _idModel_str : str The string representation of `_idModel`. session : Session A session object that manages the user session and ensures that API requests are authorized. Examples -------- >>> tracker = ActionTracker(action_id="60f5f5bfb5a1c2a123456789") >>> tracker.get_job_params() >>> tracker.update_status("training", "in_progress", "Model training started") >>> tracker.log_epoch_results(1, [{'loss': 0.25, 'accuracy': 0.92}]) """
[docs] def __init__(self, action_id=None): """ Initializes the ActionTracker instance and retrieves details related to the specified action ID. This constructor fetches the action document, which contains metadata about the action, including the model's ID. If no `action_id` is provided, the tracker is initialized without an action. Parameters ---------- action_id : str, optional The unique identifier of the action to track. If not provided, the instance is initialized without an action. Raises ------ ConnectionError If there is an error retrieving action details from the backend. SystemExit If there is a critical error during initialization, causing the system to terminate. Examples -------- >>> tracker = ActionTracker(action_id="60f5f5bfb5a1c2a123456789") >>> print(tracker.action_type) # Outputs the action type, e.g., "model_train" """ try: session = Session(account_number="", secret_key=os.environ['MATRICE_SECRET_ACCESS_KEY'], access_key=os.environ['MATRICE_ACCESS_KEY_ID']) self.rpc = session.rpc if action_id is not None: self.action_id = ObjectId(action_id) self.action_id_str = str(self.action_id) url = f"/v1/project/action/{self.action_id_str}/details" self.action_doc = self.rpc.get(url)["data"] self.action_details = self.action_doc["actionDetails"] self.action_type = self.action_doc["action"] job_params = self.get_job_params() # Will be updated if self.action_type in ("model_train", "model_eval"): self._idModel = self.action_doc["_idService"] self._idModel_str = str(self._idModel) elif self.action_type == "deploy_add": self._idModel = self.action_details["_idModelDeploy"] self._idModel_str = str(self._idModel) else: self._idModel = self.action_details["_idModel"] self._idModel_str = str(self._idModel) else: self.action_id = None print("ActionTracker initialized. but No action found") project_id = self.action_doc["_idProject"] try: session.update(project_id=project_id) self.session = session except Exception as e: print("update project error", e) try: self.checkpoint_path, self.pretrained = self.get_checkpoint_path(job_params) except Exception as e: print("get checkpoint error", e) except Exception as e: print("PAR", e) log_error(__file__, "__init__", str(e)) self.update_status("error", "error", "Initialization failed") sys.exit(1)
[docs] def get_checkpoint_path(self, model_config): """ Determines the checkpoint path for the model based on the configuration provided. This function checks if the model's checkpoint should be retrieved from a pre-trained source or a specific model ID. It also handles downloading the model if necessary. Parameters ---------- model_config : dict A dictionary containing the configuration parameters for the model, such as `checkpoint_type` and `model_checkpoint`. Returns ------- tuple A tuple containing: - The absolute path of the model checkpoint if found. - A boolean indicating whether the model is pre-trained. Raises ------ FileNotFoundError If the model checkpoint cannot be downloaded or located. ConnectionError If there is an issue communicating with the model's API. Examples -------- >>> config = {"checkpoint_type": "model_id", "model_checkpoint": "12345abcde"} >>> checkpoint_path, is_pretrained = tracker.get_checkpoint_path(config) >>> print(checkpoint_path, is_pretrained) """ try: checkpoint_type = model_config.get("checkpoint_type", "predefined") model_checkpoint = model_config.get("model_checkpoint", "auto") checkpoint_dir = "./checkpoints" os.makedirs(checkpoint_dir, exist_ok=True) if checkpoint_type == "model_id": if model_checkpoint.lower() not in ["", "none", "auto"]: model_save_path = os.path.abspath(f"{checkpoint_dir}/last.pt") return ( self._download_trained_model_checkpoint( model_save_path, model_checkpoint ), True, ) else: print( f"model_checkpoint {model_checkpoint} is one of [none, auto, ''] it should be a model id" ) return None, False elif checkpoint_type == "predefined": if model_checkpoint.lower() == "auto": return None, True elif model_checkpoint.lower() in ["none", ""]: return None, False else: print( f"model_checkpoint {model_checkpoint} not from [none, auto, '']" ) return None, False else: print( f"checkpoint_type {checkpoint_type} not from [model_id, predefined]" ) return None, False except Exception as e: log_error(__file__, "get_checkpoint_path", str(e)) print(f"Exception in get_checkpoint_path: {str(e)}") return None, False
def _download_trained_model_checkpoint(self, model_save_path, model_id): """ Downloads the model checkpoint based on the provided model ID and saves it to the specified path. This function utilizes the SDK to download a model checkpoint identified by `model_id` and save it to `model_save_path`. It handles successful downloads by returning the file path, and raises exceptions when the download fails. Parameters ---------- model_save_path : str The file path where the downloaded model checkpoint should be saved. model_id : str The unique identifier of the model whose checkpoint is to be downloaded. Returns ------- str The path where the downloaded model checkpoint is saved. Raises ------ Exception If the model checkpoint fails to download, an exception is raised, and the error is logged. Examples -------- >>> model_path = _download_trained_model_checkpoint('/models/checkpoint.pth', 'model_123') Download Successful >>> print(model_path) '/models/checkpoint.pth' Notes ----- - Ensure that the SDK session is properly initialized before calling this function. - If an exception occurs, the program will log the error and exit. """ try: model_sdk = Model(self.session, model_id) model_save_path = model_sdk.download_model(model_save_path) if model_save_path: print("Download Successful") return model_save_path else: print(f"Download failed") raise Exception(f"Failed to download model from presigned_url") except Exception as e: log_error(__file__, "download_trained_model_checkpoint", str(e)) print(f"Exception in download_trained_model_checkpoint: {str(e)}") sys.exit(1)
[docs] def get_job_params(self): """ Fetches the parameters for the job associated with the current action. This method retrieves the parameters required to perform a specific action, such as model training or evaluation. The parameters are returned as a dot-accessible dictionary (`__dotdict`) for convenience. Returns ------- __dotdict A dot-accessible dictionary containing the job parameters. Raises ------ KeyError If the job parameters cannot be found in the action document. SystemExit If the job parameters cannot be retrieved and the system needs to terminate. Examples -------- >>> job_params = tracker.get_job_params() >>> print(job_params.learning_rate) # Accessing parameters using dot notation """ try: self.jobParams = self.action_doc["jobParams"] return __dotdict(self.jobParams) except Exception as e: log_error(__file__, "get_job_params", str(e)) print(f"Exception in get_job_params: {str(e)}") self.update_status("error", "error", "Failed to get job parameters") sys.exit(1)
[docs] def update_status(self, stepCode, status, status_description): """ Updates the status of the tracked action in the backend system. This method allows changing the action's status, such as from "in progress" to "completed" or "error". It logs the provided message with the updated status. Parameters ---------- action_name : str The name of the action being tracked (e.g., "training", "evaluation"). status : str The new status to set for the action (e.g., "in_progress", "completed", "error"). message : str A message providing context about the status update. Returns ------- None Examples -------- >>> tracker.update_status("training", "completed", "Training completed successfully") """ try: print(status_description) url = f"/v1/project/action" payload = { "_id": self.action_id_str, "action": self.action_type, "serviceName": self.action_doc["serviceName"], "stepCode": stepCode, "status": status, "statusDescription": status_description, } self.rpc.put(path=url, payload=payload) except Exception as e: log_error(__file__, "update_status", str(e)) print(f"Exception in update_status: {str(e)}") if status == "error": sys.exit(1)
[docs] def log_epoch_results(self, epoch, epoch_result_list): """ Logs the results of an epoch during model training or evaluation. This method records various metrics (like loss and accuracy) for a specific epoch. It updates the action status and logs the results for tracking purposes. Parameters ---------- epoch : int The epoch number for which the results are being logged. results : list of dict A list of dictionaries containing the metric results for the epoch. Returns ------- None Raises ------ ValueError If the epoch number is invalid. Examples -------- >>> tracker.log_epoch_results(1, [{'loss': 0.25, 'accuracy': 0.92}]) """ try: epoch_result_list = self.round_metrics(epoch_result_list) model_log_payload = { "_idModel": self._idModel_str, "_idAction": self.action_id_str, "epoch": epoch, "epochDetails": epoch_result_list, } headers = {"Content-Type": "application/json"} path = f"/v1/model_logging/model/{self._idModel_str}/train_epoch_log" self.rpc.post(path=path, headers=headers, payload=model_log_payload) except Exception as e: log_error(__file__, "log_epoch_results", str(e)) print(f"Exception in log_epoch_results: {str(e)}") self.update_status("error", "error", "Failed to log epoch results") sys.exit(1)
[docs] def round_metrics(self, epoch_result_list): """Rounds the metrics in the epoch results to 4 decimal places. Parameters ---------- epoch_result_list : list A list of result dictionaries for the epoch. Each dictionary contains: - "metricValue" (float): The value of the metric to be rounded. Returns ------- list The updated list of epoch results with rounded metrics. Each metric value is rounded to four decimal places, with special handling for invalid values (NaN or infinity). Examples -------- >>> results = [{'metricValue': 0.123456}, {'metricValue': float('inf')}, {'metricValue': None}] >>> rounded_results = round_metrics(results) >>> print(rounded_results) [{'metricValue': 0.1235}, {'metricValue': 0}, {'metricValue': 0.0001}] """ for metric in epoch_result_list: if metric["metricValue"] is not None: # Check if the value is within JSON-compliant range if math.isinf(metric["metricValue"]) or math.isnan( metric["metricValue"] ): metric["metricValue"] = 0 else: metric["metricValue"] = round(metric["metricValue"], 4) if metric["metricValue"] == 0: metric["metricValue"] = 0.0001 return epoch_result_list
[docs] def upload_checkpoint(self, checkpoint_path, model_type="trained"): """Uploads a model checkpoint to the backend system. Parameters ---------- checkpoint_path : str The file path of the checkpoint to upload. This should point to a valid model checkpoint file. model_type : str, optional The type of the model ("trained" or "exported"). Defaults to "trained", which refers to a model that has been trained but not yet exported. Returns ------- bool True if the upload was successful, False otherwise. The function will log an error and exit if an exception occurs during the upload process. Examples -------- >>> success = upload_checkpoint("path/to/checkpoint.pth") >>> if success: >>> print("Checkpoint uploaded successfully!") >>> else: >>> print("Checkpoint upload failed.") """ try: if self.action_type == "model_export" and model_type == "exported": model_id = self.action_doc["_idService"] else: model_id = self._idModel_str presigned_url = self.rpc.get( path=f"/v1/model/get_model_upload_path", params={ "modelID": model_id, "modelType": model_type, "filePath": checkpoint_path.split("/")[-1], "expiryTimeInMinutes": 59, }, )["data"] with open(checkpoint_path, "rb") as file: response = requests.put(presigned_url, data=file) if response.status_code == 200: print("Upload Successful") return True else: print(f"Upload failed with status code: {response.status_code}") return False except Exception as e: log_error(__file__, "upload_checkpoint", str(e)) print(f"Exception in upload_checkpoint: {str(e)}") self.update_status("error", "error", "Checkpoint upload failed") sys.exit(1)
[docs] def download_model(self, model_path, model_type="trained"): """Downloads a model from the backend system. Parameters ---------- model_path : str The path to save the downloaded model. The file will be saved at this location after downloading. model_type : str, optional The type of the model ("trained" or "exported"). Defaults to "trained". Returns ------- bool True if the download was successful, False otherwise. The function will log an error and exit if an exception occurs during the download process. Examples -------- >>> success = download_model("path/to/save/model.pth") >>> if success: >>> print("Model downloaded successfully!") >>> else: >>> print("Model download failed.") """ try: model_id = self._idModel_str if model_type == "trained": presigned_url = self.rpc.post( path=f"/v1/model/get_model_download_path", payload={ "modelID": model_id, "modelType": model_type, "expiryTimeInMinutes": 59, }, )["data"] if model_type == "exported": presigned_url = self.rpc.post( path=f"/v1/model/get_model_download_path", payload={ "modelID": model_id, "modelType": model_type, "expiryTimeInMinutes": 59, "exportFormat": self.action_details["runtimeFramework"], }, )["data"] response = requests.get(presigned_url) if response.status_code == 200: with open(model_path, "wb") as file: file.write(response.content) print("Download Successful") return True else: print(f"Download failed with status code: {response.status_code}") return False except Exception as e: log_error(__file__, "download_model", str(e)) print(f"Exception in download_model: {str(e)}") self.update_status("error", "error", "Model download failed") sys.exit(1)
def _validate_eval_results(self, list_of_result_dicts): """ Validates the evaluation results to ensure all required splits, metrics, and categories are present. Parameters ---------- list_of_result_dicts : list A list of dictionaries containing the evaluation results. Each dictionary should include 'split', 'metric', 'value', and optionally 'category'. Raises ------ ValueError If any required split, metric, or category is missing. """ # Validate input data res_dict = {} for entry in list_of_result_dicts: split_type = entry.get("splitType") metric = entry.get("metricName") value = entry.get("metricValue") category = entry.get("category") if not split_type or not metric or value is None: raise ValueError(f"Missing fields in entry: {entry}") if split_type not in res_dict: res_dict[split_type] = {} if metric not in res_dict[split_type]: res_dict[split_type][metric] = {} res_dict[split_type][metric][category] = value # Instantiate ModelFamily to fetch supported metrics response, error, message = fetch_supported_runtimes_metrics(self.session , self.action_doc["_idProject"] ,self.action_details["model_inputs"], self.action_details["model_outputs"]) if error: raise ValueError("Failed to fetch supported runtimes and metrics.") model_metrics = response.get("supportedMetrics", []) categories = list(self.get_index_to_category().values()) + ["all"] # Validation logic messages = [] for split_type in ["val", "test"]: if split_type not in res_dict: messages.append(f"The evaluation metrics are not present for {split_type} split.") else: for metric in model_metrics: for category in categories: if category not in res_dict[split_type].get(metric, {}): messages.append(f"The metric '{metric}' is missing for the category '{category}' in '{split_type}' split.") if messages: raise ValueError("\n".join(messages))
[docs] def save_evaluation_results(self, list_of_result_dicts): """Saves the evaluation results for a model. Parameters ---------- list_of_result_dicts : list A list of dictionaries containing the evaluation results. Each dictionary should include relevant metrics and their values for the model's performance. Raises ------ Exception Logs an error and exits if an exception occurs during the saving process. Examples -------- >>> evaluation_results = [ >>> {"metric": "accuracy", "value": 0.95}, >>> {"metric": "loss", "value": 0.05}, >>> ] >>> save_evaluation_results(evaluation_results) """ try: self._validate_eval_results(list_of_result_dicts) except Exception as e: self.log_error(__file__, "validate_evaluation_results", str(e)) print(f"Exception in validate_evaluation_results: {str(e)}") try: url = '/v1/model/add_eval_results' Payload = { "_idModel": self._idModel, "_idDataset": self.action_details["_idDataset"], "_idProject": self.action_doc["_idProject"], "isOptimized": self.action_details.get("isOptimized", False), "runtimeFramework": self.action_details.get( "runtimeFramework", "Pytorch" ), "datasetVersion": self.action_details["datasetVersion"], "splitTypes": "", "evalResults": list_of_result_dicts, } self.rpc.post(path=url, payload=Payload) except Exception as e: self.log_error(__file__, "save_evaluation_results", str(e)) print(f"Exception in save_evaluation_results: {str(e)}") self.update_status("error", "error", "Failed to save evaluation results") sys.exit(1)
[docs] def add_index_to_category(self, indexToCat): """Adds an index-to-category mapping to the model. This function is used to establish a relationship between numerical indices and their corresponding categorical labels for the model. This mapping is essential for interpreting the model's output, particularly when the model is designed to classify input data into distinct categories. When to Use: ------------- - This function is typically called after the model has been trained but before deploying the model for inference. It ensures that the indices output by the model during predictions can be accurately translated to human-readable category labels. - It is also useful when there are changes in the class labels or when initializing a new model. Parameters ---------- indexToCat : dict A dictionary mapping integer indices to category names. For example, `{0: 'cat', 1: 'dog', 2: 'bird'}` indicates that index 0 corresponds to 'cat', index 1 to 'dog', and index 2 to 'bird'. Raises ------ Exception If an error occurs while trying to add the mapping, it logs the error details and exits the process. Examples -------- >>> index_mapping = {0: 'cat', 1: 'dog', 2: 'bird'} >>> add_index_to_category(index_mapping) """ try: url = f"/v1/model/{self._idModel}/update_index_to_cat" payload = {"indexToCat": indexToCat} self.rpc.put(path=url, payload=payload) except Exception as e: log_error(__file__, "add_index_to_category", str(e)) print(f"Exception in add_index_to_category: {str(e)}") self.update_status("error", "error", "Failed to add index to category") sys.exit(1)
[docs] def get_index_to_category(self, is_exported=False): """Fetches the index-to-category mapping for the model. This function retrieves the current mapping of indices to categories from the backend system. This is crucial for understanding the model's predictions, as it allows users to decode the model outputs back into meaningful category labels. When to Use: ------------- - This function is often called before making predictions with the model to ensure that the index-to-category mapping is up to date and correctly reflects the model's configuration. - It can also be used after exporting a model to validate that the expected mappings are correctly stored and accessible. Parameters ---------- is_exported : bool, optional A flag indicating whether to fetch the mapping for an exported model. Defaults to False. If True, the mapping is retrieved based on the export ID. Returns ------- dict The index-to-category mapping as a dictionary, where keys are indices and values are corresponding category names. Raises ------ Exception If an error occurs during the retrieval process, it logs the error details and exits the process. Examples -------- >>> mapping = get_index_to_category() >>> print(mapping) {0: 'cat', 1: 'dog', 2: 'bird'} >>> exported_mapping = get_index_to_category(is_exported=True) >>> print(exported_mapping) {0: 'cat', 1: 'dog'} """ try: url = "/v1/model/model_train/" + str(self._idModel_str) if is_exported: url = f"/v1/model/get_model_train_by_export_id?exportId={self._idModel_str}" modelTrain_doc = self.rpc.get(url)["data"] self.index_to_category = modelTrain_doc.get("indexToCat", {}) return self.index_to_category except Exception as e: log_error(__file__, "get_index_to_category", str(e)) print(f"Exception in get_index_to_category: {str(e)}") self.update_status("error", "error", "Failed to get index to category") sys.exit(1)
[docs] def get_model_train(self, is_exported=False): try: url = "/v1/model/model_train/" + str(self._idModel_str) if is_exported: url = f"/v1/model/get_model_train_by_export_id?exportId={self._idModel_str}" model_train_doc = self.rpc.get(url)['data'] return model_train_doc except Exception as e: log_error(__file__, 'get_model_train', str(e)) print(f"Exception in get_model_train: {str(e)}") self.update_status("error", "error", "Failed to get model train") sys.exit(1)