Source code for matrice.models

import sys
from datetime import datetime, timedelta

import requests
import matplotlib.pyplot as plt
import seaborn as sns

from matrice.dataset import Dataset
from matrice.utils import handle_response

class Model:
    """
    The `Model` class provides methods for interacting with models in a project,
    including fetching summaries, listing models, and performing evaluations.

    Parameters
    ----------
    session : Session
        A session object containing the project ID and RPC client.
    model_id : str, optional
        The unique identifier for the model (default is None).
    model_name : str, optional
        The name of the model (default is an empty string).

    Example
    -------
    >>> session = Session(project_id="project123")
    >>> model = Model(session, model_id="model789")
    """
    def __init__(self, session, model_id=None, model_name=""):
        self.session = session
        self.project_id = session.project_id
        self.last_refresh_time = datetime.now()
        assert model_id or model_name, "Either model_id or model_name must be provided"
        self.model_id = model_id
        # Keep both attributes: get_details() and refresh() look up
        # self.model_name, while the API payloads use the train name.
        self.model_name = model_name
        self.model_train_name = model_name
        self.rpc = session.rpc

        # Model train details
        details_response, err, msg = self.get_details()
        self.details = details_response
        self.model_id = details_response.get("_id")
        self.dataset_id = details_response.get("_idDataset")
        self.model_info_id = details_response.get("_idModelInfo")
        self.experiment_id = details_response.get("_idExperiment")
        self.action_status_id = details_response.get("_idActionStatus")
        self.model_family_id = details_response.get("_idModelFamily")
        self.project_id = details_response.get("_idProject")
        self.user_name = details_response.get("userName")
        self.user_id = details_response.get("_idUser")
        self.model_key = details_response.get("modelKey")
        self.model_family_name = details_response.get("modelName")
        self.auto_ml = details_response.get("autoML")
        self.params_millions = details_response.get("paramsMillions")
        self.tuning_type = details_response.get("tuningType")
        self.training_framework = details_response.get("trainingFramework")
        self.model_checkpoint = details_response.get("modelCheckpoint")
        self.checkpoint_type = details_response.get("checkpointType")
        self.primary_metric = details_response.get("primaryMetric")
        self.test_score = details_response.get("testScore")
        self.val_score = details_response.get("valScore")
        self.dataset_name = details_response.get("datasetName")
        self.status = details_response.get("status")
        self.dataset_version = details_response.get("datasetVersion")
        self.model_inputs = details_response.get("modelInputs", [])
        self.model_outputs = details_response.get("modelOutputs", [])
        self.target_runtime = details_response.get("targetRuntime", [])
        self.action_config = details_response.get("actionConfig", {})
        self.model_config = details_response.get("modelConfig", {})
        self.model_train_name = details_response.get("modelTrainName")
        self.val_split_result = details_response.get("valSplitResult", [])
        self.test_split_result = details_response.get("testSplitResult", [])
        self.index_to_cat = details_response.get("indexToCat", {})
        self.best_epoch = details_response.get("bestEpoch")
        self.cloud_path = details_response.get("cloudPath")
        self.experiment_name = details_response.get("experimentName")
        self.created_at = details_response.get("createdAt")
        self.updated_at = details_response.get("updatedAt")
        self.architecture = details_response.get("modelKey")
    def refresh(self):
        """
        Refresh the instance by reinstantiating it with the previous values.
        """
        # Allow a refresh only once every two minutes
        if datetime.now() - self.last_refresh_time < timedelta(minutes=2):
            raise Exception(
                "Refresh can only be called two minutes after the last refresh."
            )

        # Reinitialize the instance from its current identifiers
        init_params = {
            "session": self.session,
            "model_id": self.model_id,
            "model_name": self.model_name,
        }
        self.__init__(**init_params)

        # Update the last refresh time
        self.last_refresh_time = datetime.now()
    def get_details(self):
        """
        Get model details based on the provided ID or name.

        Returns
        -------
        tuple
            A tuple containing the model details, error message, and status message.

        Example
        -------
        >>> details, error, message = model.get_details()
        >>> if error:
        >>>     print(f"Error: {error}")
        >>> else:
        >>>     print(f"Model details: {details}")
        """
        if self.model_id:
            try:
                return self._get_model_train_by_id()
            except Exception as e:
                print(f"Error retrieving model train by id: {e}")
        elif self.model_name:
            try:
                return self._get_model_train_by_name()
            except Exception as e:
                print(f"Error retrieving model train by name: {e}")
        else:
            raise ValueError(
                "At least one of 'model_id' or 'model_name' must be provided."
            )
    def rename(self, name):
        """
        Update the name of the trained model.

        Parameters
        ----------
        name : str
            The new name for the trained model.

        Returns
        -------
        tuple
            A tuple with the update result, error message, and status message.

        Example
        -------
        >>> result, error, message = model.rename("NewModelName")
        >>> if error:
        >>>     print(f"Error: {error}")
        >>> else:
        >>>     print(f"Model name updated: {result}")
        """
        if self.model_id is None:
            print("Set Model Id for model object")
            sys.exit(0)

        path = f"/v1/model/{self.model_id}/update_modelTrain_name"
        headers = {"Content-Type": "application/json"}
        model_payload = {"modelTrainId": self.model_id, "name": name}
        resp = self.rpc.put(path=path, headers=headers, payload=model_payload)

        return handle_response(
            resp,
            "Model train name updated successfully",
            "Could not update model train name",
        )
    def delete(self):
        """
        Delete the trained model.

        Returns
        -------
        tuple
            A tuple with the deletion result, error message, and status message.

        Example
        -------
        >>> result, error, message = model.delete()
        >>> if error:
        >>>     print(f"Error: {error}")
        >>> else:
        >>>     print(f"Model deleted: {result}")
        """
        if self.model_id is None:
            print("Set Model Id for model object")
            sys.exit(0)

        path = f"/v1/model/delete_model_train/{self.model_id}"
        resp = self.rpc.delete(path=path)

        return handle_response(
            resp, "Model train deleted successfully", "Could not delete model train"
        )
    def get_prediction(self, image_path):
        """
        Test a trained model on a given image.

        Parameters
        ----------
        image_path : str
            The path to the image for testing.

        Returns
        -------
        tuple
            A tuple consisting of (result, error, message) with the test results.

        Example
        -------
        >>> result, error, message = model.get_prediction("/path/to/test_image.jpg")
        >>> print(result)
        {'test_result': 'success', 'confidence': 0.85}
        """
        url = f"/v1/model_prediction/model_test/{self.model_id}?projectId={self.project_id}"

        # Open the image in a context manager so the file handle is released
        with open(image_path, "rb") as image_file:
            files = {"image": image_file}
            resp = self.rpc.post(url, files=files)

        success_message = "Model test completed successfully"
        error_message = "An error occurred while testing the model."
        return handle_response(resp, success_message, error_message)
    def get_eval_result(self, dataset_id, dataset_version, split_type):
        """
        Fetch the evaluation result of a trained model using a specific dataset
        version and split type.

        Parameters
        ----------
        dataset_id : str
            The ID of the dataset.
        dataset_version : str
            The version of the dataset.
        split_type : str
            The type of split used for the evaluation.

        Returns
        -------
        tuple
            A tuple with the evaluation result, error message, and status message.

        Example
        -------
        >>> eval_result, error, message = model.get_eval_result("dataset123", "v1.0", "train")
        >>> if error:
        >>>     print(f"Error: {error}")
        >>> else:
        >>>     print(f"Evaluation result: {eval_result}")
        """
        D = Dataset(self.session, dataset_id)
        dataset_info, _, _ = D.get_processed_versions()
        if dataset_info is None:
            print("No datasets found")
            sys.exit(0)

        # Verify that the requested dataset version has been processed
        version_exists = False
        for data_info in dataset_info:
            if dataset_id == data_info["_id"]:
                if dataset_version in data_info["processedVersions"]:
                    version_exists = True
                    break

        if not version_exists:
            print(
                "Dataset or dataset version does not exist. "
                "Cannot use this dataset version to get/create an eval."
            )
            sys.exit(0)

        if self.model_id is None:
            print("Model Id is required for this operation")
            sys.exit(0)

        path = "/v1/model/get_eval_result"
        headers = {"Content-Type": "application/json"}
        model_payload = {
            "_idDataset": dataset_id,
            "_idModel": self.model_id,
            "datasetVersion": dataset_version,
            "splitType": split_type,
        }
        resp = self.rpc.post(path=path, headers=headers, payload=model_payload)

        return handle_response(
            resp,
            "Eval result fetched successfully",
            "An error occurred while fetching Eval result",
        )
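    # Sketch of an eval-result record as consumed by plot_eval_results() below.
    # The field names are inferred from the DataFrame columns used there, not
    # from API documentation:
    #
    #   {"splitType": "val", "metricName": "acc@1", "metricValue": 0.91,
    #    "category": "all"}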
    def plot_eval_results(self):
        """
        Plot the evaluation results for the model.

        Example
        -------
        >>> model.plot_eval_results()
        """
        import pandas as pd

        eval_result = self.get_eval_result(
            dataset_id=self.dataset_id,
            dataset_version=self.dataset_version,
            split_type=["train", "val", "test"],
        )[0]
        df = pd.DataFrame(eval_result)

        # Set up the figure
        plt.figure(figsize=(14, 12))

        # List of unique metrics
        metrics = df["metricName"].unique()
        num_metrics = len(metrics)

        # Loop through each metric and create a horizontal bar plot
        for i, metric in enumerate(metrics, 1):
            plt.subplot((num_metrics + 1) // 2, 2, i)

            # Filter data for the current metric
            metric_data = df[df["metricName"] == metric]

            # Create horizontal bar plot
            sns.barplot(
                data=metric_data,
                x="metricValue",
                y="splitType",
                hue="category",
                orient="h",
            )

            # Set titles and labels
            plt.xlabel(metric)
            plt.xlim(0, 1)  # Assuming metric values are between 0 and 1
            plt.legend(title="Category")

        plt.tight_layout()
        plt.show()
    def _get_model_train_by_id(self):
        """
        Fetch details of a specific trained model by its ID.

        Returns
        -------
        tuple
            A tuple with the model training data, error message, and status message.

        Example
        -------
        >>> model_data, error, message = model._get_model_train_by_id()
        >>> if error:
        >>>     print(f"Error: {error}")
        >>> else:
        >>>     print(f"Model data: {model_data}")
        """
        path = f"/v1/model/model_train/{self.model_id}"
        resp = self.rpc.get(path=path)

        return handle_response(
            resp,
            "Model train by ID fetched successfully",
            "Could not fetch model train by ID",
        )

    def _get_model_train_by_name(self):
        """
        Fetch details of a specific trained model by its name.

        Returns
        -------
        tuple
            A tuple with the model training data, error message, and status message.

        Example
        -------
        >>> model_data, error, message = model._get_model_train_by_name()
        >>> if error:
        >>>     print(f"Error: {error}")
        >>> else:
        >>>     print(f"Model data: {model_data}")
        """
        if self.model_name == "":
            print(
                "Model name not set for this model train. Cannot perform the "
                "operation for a model without a model name."
            )
            sys.exit(0)

        path = f"/v1/model/model_train/get_model_train_by_name?modelTrainName={self.model_name}"
        resp = self.rpc.get(path=path)

        return handle_response(
            resp,
            "Model train by name fetched successfully",
            "Could not fetch model train by name",
        )
    def add_evaluation(
        self,
        dataset_id,
        dataset_version,
        split_types,
        is_pruned=False,
        is_gpu_required=False,
    ):
        """
        Add a new model evaluation using specified parameters.

        Parameters
        ----------
        dataset_id : str
            The ID of the dataset.
        dataset_version : str
            The version of the dataset.
        split_types : list
            The split types used in the evaluation.
        is_pruned : bool, optional
            Whether the model is pruned (default is False).
        is_gpu_required : bool, optional
            Whether the model requires a GPU (default is False).

        Returns
        -------
        tuple
            A tuple with the evaluation result, error message, and status message.

        Example
        -------
        >>> result, error, message = model.add_evaluation(
        >>>     dataset_id="dataset123",
        >>>     dataset_version="v1.0",
        >>>     split_types=["train", "val"],
        >>> )
        """
        if self.model_id is None:
            print("Set Model Id for model object")
            sys.exit(0)

        model_by_id_resp, _, _ = self._get_model_train_by_id()

        path = "/v1/model/add_model_eval"
        headers = {"Content-Type": "application/json"}
        model_payload = {
            "_idModel": self.model_id,
            "_idProject": self.project_id,
            "isOptimized": False,
            "isPruned": is_pruned,
            "runtimeFramework": "Pytorch",
            "_idDataset": dataset_id,
            "_idExperiment": model_by_id_resp["_idExperiment"],
            "datasetVersion": dataset_version,
            "gpuRequired": is_gpu_required,
            "splitTypes": split_types,
            "modelType": "trained",
            "exportFormat": None,
        }
        resp = self.rpc.post(path=path, headers=headers, payload=model_payload)

        return handle_response(
            resp,
            "Model eval added successfully",
            "An error occurred while adding model eval",
        )
    def get_download_path(self):
        """
        Get the download path for the trained model. The `modelType` is fixed
        to "trained" in the request payload.

        Returns
        -------
        tuple
            A tuple with the download path, error message, and status message.

        Example
        -------
        >>> download_path, error, message = model.get_download_path()
        >>> if error:
        >>>     print(f"Error: {error}")
        >>> else:
        >>>     print(f"Download path: {download_path}")
        """
        if self.model_id is None:
            print(
                "Model id not set for this model. Cannot perform the operation "
                "for a model without a model id."
            )
            sys.exit(0)

        path = "/v1/model/get_model_download_path"
        headers = {"Content-Type": "application/json"}
        model_payload = {
            "modelID": self.model_id,
            "modelType": "trained",
            "expiryTimeInMinutes": 15,
        }
        resp = self.rpc.post(path=path, headers=headers, payload=model_payload)

        return handle_response(
            resp,
            "Model download path fetched successfully and it will expire in 15 mins",
            "An error occurred while downloading the model",
        )
    def download_model(self, file_name):
        """
        Download the trained model to a local file. The `modelType` is fixed
        to "trained" in the request payload.

        Parameters
        ----------
        file_name : str
            The name of the file to save the downloaded model.

        Returns
        -------
        str
            The file name on success, or an empty string on failure.

        Example
        -------
        >>> file_name = model.download_model("model.pth")
        """
        # Request a presigned URL for the trained model artifact
        presigned_url = self.rpc.post(
            path="/v1/model/get_model_download_path",
            payload={
                "modelID": self.model_id,
                "modelType": "trained",
                "expiryTimeInMinutes": 59,
            },
        )["data"]

        response = requests.get(presigned_url)
        if response.status_code == 200:
            with open(file_name, "wb") as file:
                file.write(response.content)
            print("Model downloaded successfully")
            return file_name
        else:
            print(f"Model download failed with status code: {response.status_code}")
            return ""
    def get_model_training_logs(self):
        """
        Fetch training logs for the specified model.

        This method retrieves the logs of the training epochs for a model,
        including both training and validation metrics such as losses and
        accuracy.

        Returns
        -------
        tuple
            A tuple containing:
            - A dictionary with the response from the RPC call.
            - An error message if the request fails.
            - A success message if the request succeeds.

        Example
        -------
        >>> response, error, message = model_logging.get_model_training_logs()
        >>> if error:
        >>>     print(f"Error: {error}")
        >>> else:
        >>>     print(f"Success: {message}")
        """
        path = f"/v1/model_logging/model/{self.model_id}/train_epoch_logs"
        resp = self.rpc.get(path=path)

        if resp.get("success"):
            error = None
            message = "Successfully fetched model logs."
        else:
            error = resp.get("message")
            message = "Failed to fetch model logs."

        return resp, error, message
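    # Sketch of the per-epoch log record that the two plotting helpers below
    # assume. The schema is inferred from the parsing code, not from API docs:
    #
    #   {
    #       "epoch": 1,
    #       "epochDetails": [
    #           {"splitType": "train", "metricName": "loss", "metricValue": 0.42},
    #           {"splitType": "val", "metricName": "acc@1", "metricValue": 0.78},
    #       ],
    #   }
    #
    # Metric names containing "loss" are routed to plot_epochs_losses();
    # everything else is handled by plot_epochs_metrics().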
    def plot_epochs_losses(self):
        """
        Plot training and validation losses over epochs.

        This method generates two subplots: one for the training losses and
        one for the validation losses, displaying how these metrics evolve
        over the epochs.

        Returns
        -------
        None

        Example
        -------
        >>> model_logging.plot_epochs_losses()
        """
        resp, error, message = self.get_model_training_logs()
        training_logs = resp["data"]

        epochs = []
        metrics = {"train": {}, "val": {}}

        for epoch_data in training_logs:
            epochs.append(epoch_data["epoch"])
            for detail in epoch_data["epochDetails"]:
                metric_name = detail["metricName"]
                metric_value = detail["metricValue"]
                split_type = detail["splitType"]
                if "loss" in metric_name:
                    # Each split maps metric names to per-epoch value lists
                    if split_type not in metrics:
                        metrics[split_type] = {}
                    if metric_name not in metrics[split_type]:
                        metrics[split_type][metric_name] = []
                    metrics[split_type][metric_name].append(metric_value)

        # Set plot style
        sns.set(style="whitegrid")

        # Create figure and axes
        fig, axs = plt.subplots(2, 1, figsize=(12, 18))

        # Plot training and validation losses on their respective axes
        for split_type, split_metrics in metrics.items():
            for metric_name in split_metrics.keys():
                if split_type == "train":
                    axs[0].plot(
                        epochs,
                        split_metrics[metric_name],
                        label=f"{split_type} {metric_name}",
                    )
                elif split_type == "val":
                    axs[1].plot(
                        epochs,
                        split_metrics[metric_name],
                        label=f"{split_type} {metric_name}",
                    )

        # Set labels and titles
        axs[0].set_xlabel("Epoch", fontsize=14)
        axs[0].set_ylabel("Loss", fontsize=14)
        axs[0].legend(fontsize=12)
        axs[0].set_title("Training Losses over Epochs", fontsize=16)
        axs[0].grid(True)

        axs[1].set_xlabel("Epoch", fontsize=14)
        axs[1].set_ylabel("Loss", fontsize=14)
        axs[1].legend(fontsize=12)
        axs[1].set_title("Validation Losses over Epochs", fontsize=16)
        axs[1].grid(True)

        plt.tight_layout()
        plt.show()
    def plot_epochs_metrics(self):
        """
        Plot training and validation metrics (excluding losses) over epochs.

        This method generates subplots for each non-loss metric, such as
        accuracy, showing how these metrics change during training epochs for
        both training and validation splits.

        Returns
        -------
        None

        Example
        -------
        >>> model_logging.plot_epochs_metrics()
        """
        resp, error, message = self.get_model_training_logs()
        training_logs = resp["data"]

        epochs = []
        metrics = {"train": {}, "val": {}}
        metrics_names = set()

        # Parse metrics data from logs
        for epoch_data in training_logs:
            epochs.append(epoch_data["epoch"])
            for detail in epoch_data["epochDetails"]:
                metric_name = detail["metricName"]
                metric_value = detail["metricValue"]
                split_type = detail["splitType"]
                if "loss" not in metric_name:
                    # Each split maps metric names to per-epoch value lists
                    if split_type not in metrics:
                        metrics[split_type] = {}
                    if metric_name not in metrics[split_type]:
                        metrics[split_type][metric_name] = []
                    metrics[split_type][metric_name].append(metric_value)
                    metrics_names.add(metric_name)

        metrics_names = list(metrics_names)
        num_graphs = len(metrics_names)

        # Set plot style
        sns.set(style="whitegrid")

        fig, axs = plt.subplots(num_graphs, 1, figsize=(12, 18))
        if num_graphs == 1:
            axs = [axs]  # plt.subplots returns a bare Axes when there is one row

        # Plot each metric
        for metric_index, metric_name in enumerate(metrics_names):
            for split_type, split_metrics in metrics.items():
                if metric_name in split_metrics:
                    axs[metric_index].plot(
                        epochs,
                        split_metrics[metric_name],
                        label=f"{split_type} {metric_name}",
                    )
            axs[metric_index].set_xlabel("Epoch", fontsize=14)
            axs[metric_index].set_ylabel(metric_name, fontsize=14)
            axs[metric_index].legend(fontsize=12)
            axs[metric_index].set_title(f"{metric_name} over Epochs", fontsize=16)
            axs[metric_index].grid(True)

        plt.tight_layout()
        plt.show()
    def model_test(self, model_type="trained"):
        """
        Fetch information about the deployment server for this model.

        Parameters
        ----------
        model_type : str, optional
            The type of model (e.g., 'trained', 'exported'). Default is "trained".

        Returns
        -------
        tuple
            A tuple containing three elements:
            - API response (dict): The raw response from the API.
            - error_message (str or None): Error message if an error occurred, None otherwise.
            - status_message (str): A status message indicating success or failure.

        Examples
        --------
        >>> resp, err, msg = model.model_test("trained")
        >>> if err:
        >>>     print(f"Error: {err}")
        >>> else:
        >>>     print(f"Deployment server details: {resp}")
        """
        path = f"/v1/deployment/get_deploy_server/{self.model_id}/{model_type}"
        resp = self.rpc.get(path=path)

        return handle_response(
            resp,
            "Deployment server fetched successfully",
            "An error occurred while trying to fetch deployment server.",
        )
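
# Minimal end-to-end usage sketch. The IDs below are hypothetical, and the
# `Session` import path is assumed; adjust both before running.
if __name__ == "__main__":
    from matrice.session import Session  # import path assumed

    session = Session(project_id="project123")  # hypothetical project ID
    model = Model(session, model_id="model789")  # hypothetical model ID

    # Queue an evaluation on the dataset version the model was trained on
    result, err, msg = model.add_evaluation(
        dataset_id=model.dataset_id,
        dataset_version=model.dataset_version,
        split_types=["val", "test"],
    )
    if err:
        print(f"Error: {err}")
    else:
        model.plot_eval_results()

    # Download the trained weights to a local file
    model.download_model("model.pth")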