import sys
from matrice.utils import handle_response , get_summary
from datetime import datetime, timedelta
import requests
[docs]
class ExportedModel:
"""
A class to handle operations related to model export within a project.
The `ExportedModel` class facilitates managing model export processes,
including fetching summaries, listing available exported models, and performing
evaluation tasks on optimized inferences.
Parameters
----------
session : Session
An active session object that holds project information such as the project ID and RPC client.
model_export_id : str, optional
A unique identifier for the model export or inference optimization. Defaults to None.
model_export_name : str, optional
The name of the model export or inference optimization. Defaults to an empty string.
Attributes
----------
project_id : str
The project ID associated with the current session.
model_export_id : str or None
The unique identifier for the model export, provided at initialization or set later.
model_export_name : str
The name of the model export, provided at initialization or set later.
rpc : object
The RPC client used to make API requests.
Example
-------
>>> session = Session(account_number=account_number)
>>> exported_model = ExportedModel(session=session, model_export_id="12345", model_export_name="sample_export")
>>> print(exported_model.model_export_name) # Output: "sample_export"
"""
[docs]
def __init__(self, session, model_export_id=None, model_export_name=""):
self.session = session
self.project_id = session.project_id
self.last_refresh_time = datetime.now()
assert model_export_id or model_export_name, "At least one of 'model_export_id' or 'model_export_name' must be provided."
self.model_export_id = model_export_id
self.model_export_name = model_export_name
self.rpc = session.rpc
# Summary information
self.summary_response, self.err, self.msg= get_summary(self.session , self.project_id, service_name="exports")
if self.summary_response:
summary_data = self.summary_response # Store summary data
model_count_by_status = summary_data.get("modelCountByStatus", {})
self.error_model_count = model_count_by_status.get("error", 0)
self.exported_model_count = model_count_by_status.get("exported", 0)
self.exporting_model_count = model_count_by_status.get("exporting", 0)
self.queued_model_count = model_count_by_status.get("queued", 0)
self.total_models = summary_data.get("total", 0)
else:
print(f"Error fetching summary: {self.summary_response.get('message')}")
# Details of Exported Model
details_response , err , msg = self.get_details()
self.details = details_response
self.model_train_id = details_response.get("_idModelTrain")
self.model_train_name = details_response.get("modelTrainName")
self.experiment_name = details_response.get("experimentName")
self.dataset_name = details_response.get("datasetName")
self.model_name = details_response.get("modelName")
self.model_inputs = details_response.get("modelInputs", [])
self.model_info_id = details_response.get("_idModelInfo")
self.user_id = details_response.get("_idUser")
self.user_name = details_response.get("userName")
self.model_export_name = details_response.get("modelExportName")
self.model_outputs = details_response.get("modelOutputs", [])
self.export_format = details_response.get("exportFormat")
self.dataset_id = details_response.get("_idDataset")
self.project_id = details_response.get("_idProject")
self.action_id = details_response.get("_idAction")
self.dataset_version = details_response.get("datasetVersion")
self.gpu_required = details_response.get("gpuRequired")
self.action_config = details_response.get("actionConfig", {})
self.model_config = details_response.get("modelConfig", {})
self.val_split_results = details_response.get("valSplitResults", [])
self.test_split_results = details_response.get("testSplitStruct", [])
self.status = details_response.get("status")
self.cloud_path = details_response.get("cloudPath")
self.created_at = details_response.get("createdAt")
self.baseModel=details_response.get("modelTrainName")
self.architecture=details_response.get("modelName")
self.training_framework=details_response.get("trainingFramework")
self.lastUpdated=details_response.get("updatedAt")
[docs]
def refresh(self):
"""
Refresh the instance by reinstantiating it with the previous values.
"""
# Check if two minutes have passed since the last refresh
if datetime.now() - self.last_refresh_time < timedelta(minutes=2):
raise Exception("Refresh can only be called after two minutes since the last refresh.")
# Capture the current state
state = self.__dict__.copy()
init_params = {
'session': self.session,
'model_export_name': self.model_export_name,
}
# Reinitialize the instance
self.__init__(**init_params)
# Update the last refresh time
self.last_refresh_time = datetime.now()
[docs]
def get_details(self):
"""
Retrieve details of the model export based on the model export ID or name.
This method fetches details by ID if available; otherwise, it attempts
to fetch by name. Raises a ValueError if neither identifier is provided.
Returns
-------
tuple
A tuple containing the model export details, error message (if any), and a status message.
Raises
------
ValueError
If neither 'model_export_id' nor 'model_export_name' is provided.
Example
-------
>>> details, err, msg = exported_model.get_details()
>>> if err:
>>> print(f"Error: {err}")
>>> else:
>>> print(f"Model Export Details: {details}")
"""
id = self.model_export_id
name = self.model_export_name
if id:
try:
return self._get_model_export_by_id()
except Exception as e:
print(f"Error retrieving model_export by id: {e}")
elif name:
try:
return self._get_model_export_by_name()
except Exception as e:
print(f"Error retrieving model_export by name: {e}")
else:
raise ValueError(
"At least one of 'model_export_id' or 'model_export_name' must be provided."
)
def _get_model_export_by_id(self):
"""
Fetch details of a specific model export by its ID.
Returns
-------
tuple
A tuple containing:
- resp (dict): The API response object.
- error (str or None): Error message if the API call failed, otherwise None.
- message (str): Success or error message.
Example
-------
>>> details, err, msg = exported_model._get_model_export_by_id()
>>> if err:
>>> print(f"Error: {err}")
>>> else:
>>> print(f"Model Export Details: {details}")
"""
path = f"/v1/model/get_model_export_by_id?modelExportId={self.model_export_id}"
resp = self.rpc.get(path=path)
return handle_response(
resp,
"Model export by ID fetched successfully",
"Could not fetch model export by ID",
)
def _get_model_export_by_name(self):
"""
Fetch details of a specific model export by its name.
Returns
-------
tuple
A tuple containing:
- resp (dict): The API response object.
- error (str or None): Error message if the API call failed, otherwise None.
- message (str): Success or error message.
Example
-------
>>> details, err, msg = exported_model._get_model_export_by_name()
>>> if err:
>>> print(f"Error: {err}")
>>> else:
>>> print(f"Model Export Details: {details}")
"""
if self.model_export_name == "":
print(
"Model export name not set for this Model export. Cannot perform the operation for Model export without model export name."
)
sys.exit(0)
path = f"/v1/model/model_export/get_model_export_by_name?modelExportName={self.model_export_name}"
resp = self.rpc.get(path=path)
return handle_response(
resp,
"Model export by name fetched successfully",
"Could not fetch model export by name",
)
[docs]
def rename(self, updated_name):
"""
Update the name of a model export.
Parameters
----------
updated_name : str
The new name for the model export.
Returns
-------
tuple
A tuple containing:
- resp (dict): The API response object.
- error (str or None): Error message if the API call failed, otherwise None.
- message (str): Success or error message.
Example
-------
>>> result, err, msg = exported_model.rename("NewModelExportName")
>>> if err:
>>> print(f"Error: {err}")
>>> else:
>>> print(f"Model Export Name Updated: {result}")
"""
body = {
"modelExportId": self.model_export_id,
"name": updated_name,
}
headers = {"Content-Type": "application/json"}
path = f"/v1/model/{self.model_export_id}/update_modelExport_name"
resp = self.rpc.put(path=path, headers=headers, payload=body)
return handle_response(
resp,
f"Model export name updated to {updated_name}",
"Could not update the model export name",
)
[docs]
def delete(self):
"""
Delete a model export.
Returns
-------
tuple
A tuple containing:
- resp (dict): The API response object.
- error (str or None): Error message if the API call failed, otherwise None.
- message (str): Success or error message.
Example
-------
>>> result, err, msg = exported_model.delete()
>>> if err:
>>> print(f"Error: {err}")
>>> else:
>>> print(f"Model Export Deleted: {result}")
"""
path = f"/v1/model/model_export/{self.model_export_id}"
resp = self.rpc.delete(path=path)
return handle_response(
resp, f"Model export deleted", "Could not delete the model export"
)
[docs]
def add_evaluation(
self,
id_dataset,
dataset_version,
split_types,
is_gpu_required=True,
is_pruned=False,
):
"""
Add a new model evaluation using specified parameters.
Parameters
----------
is_pruned : bool
Whether the model is pruned.
id_dataset : str
The ID of the dataset used for evaluation.
id_experiment : str
The ID of the experiment associated with the model.
dataset_version : str
The version of the dataset.
is_gpu_required : bool
Whether the model requires GPU for inference.
split_types : list
A list of split types used in the evaluation.
Returns
-------
tuple
A tuple containing:
- resp (dict): The API response object.
- error (str or None): Error message if the API call failed, otherwise None.
- message (str): Success or error message.
Example
-------
>>> eval_result, err, msg = exported_model.add_evaluation(
is_pruned=False,
id_dataset="dataset123", id_experiment="experiment123", dataset_version="v1.0",
is_gpu_required=True, split_types=["train", "test"])
>>> if err:
>>> print(f"Error: {err}")
>>> else:
>>> print(f"Evaluation added: {eval_result}")
"""
model_info , err , msg = self.get_details()
runtime_framework = model_info["exportFormat"]
model_train_info , err , msg = self.get_trained_model()
path = "/v1/model/add_model_eval"
headers = {"Content-Type": "application/json"}
model_payload = {
"_idModel": self.model_export_id,
"_idProject": self.project_id,
"isOptimized": True,
"isPruned": is_pruned,
"runtimeFramework": runtime_framework,
"_idDataset": id_dataset,
"_idExperiment": model_train_info["_idExperiment"],
"datasetVersion": dataset_version,
"gpuRequired": is_gpu_required,
"splitTypes": split_types,
"modelType": "exported",
"computeAlias": "",
}
resp = self.rpc.post(path=path, headers=headers, payload=model_payload)
return handle_response(
resp,
"Model eval added successfully",
"An error occurred while adding model eval",
)
[docs]
def get_trained_model(self):
"""
Fetch details of a model training associated with a specific export ID.
Returns
-------
tuple
A tuple containing:
- resp (dict): The API response object.
- error (str or None): Error message if the API call failed, otherwise None.
- message (str): Success or error message.
Example
-------
>>> training_data, err, msg = exported_model.get_model_train_of_the_export()
>>> if err:
>>> print(f"Error: {err}")
>>> else:
>>> print(f"Model Training Data: {training_data}")
"""
path = f"/v1/model/get_model_train_by_export_id?exportId={self.model_export_id}"
resp = self.rpc.get(path=path)
return handle_response(
resp,
"Model train by export ID fetched successfully",
"Could not fetch model train by export ID",
)
[docs]
def get_evaluation_result(self, dataset_id, dataset_version, split_types):
"""
Fetch the evaluation result of a trained model using a specific dataset version and split type.
Parameters
----------
dataset_id : str
The ID of the dataset.
dataset_version : str
The version of the dataset.
split_type : list
The types of splits used for the evaluation.
Returns
-------
tuple
A tuple with the evaluation result, error message, and status message.
Example
-------
>>> eval_result, error, message = exported_model.get_evaluation_result("dataset123", "v1.0", ["train"])
>>> if error:
>>> print(f"Error: {error}")
>>> else:
>>> print(f"Evaluation result: {eval_result}")
"""
path = "/v1/model/get_eval_result"
headers = {"Content-Type": "application/json"}
model_payload = {
"_idDataset": dataset_id,
"_idModel": self.model_export_id,
"datasetVersion": dataset_version,
"splitType": split_types,
}
resp = self.rpc.post(path=path, headers=headers, payload=model_payload)
return handle_response(
resp,
"Eval result fetched successfully",
"An error occurred while fetching Eval result",
)
[docs]
def get_prediction(self, image_path):
"""
Tests a exported model for a given image.
Parameters:
-----------
image_path : str
The path to the image for testing.
Returns:
--------
tuple:
A tuple consisting of (result, error, message) with the test results.
Example:
--------
>>> result, error, message = exported_model.get_prediction("/path/to/test_image.jpg")
>>> print(result)
{'test_result': 'success', 'confidence': 0.85}
"""
files = {"image": open(image_path, "rb")}
# Model testing
url = f"/v1/model_prediction/model_test/{self.model_info_id}?projectId={self.project_id}"
resp = self.rpc.post(url, files=files)
success_message = "Model test completed successfully"
error_message = "An error occurred while testing the model."
return handle_response(resp, success_message, error_message)
[docs]
def get_download_path(self):
"""
Get the download path for the specified model type. There are 2 types of model types: trained and exported.
Parameters
----------
model_type : str
The type of the model to download.
Returns
-------
tuple
A tuple with the download path, error message, and status message.
Example
-------
>>> download_path, error, message = exported_model.get_model_download_path()
>>> if error:
>>> print(f"Error: {error}")
>>> else:
>>> print(f"Download path: {download_path}")
"""
if self.model_export_id is None:
print(
"Model id not set for this model. Cannot perform the operation for model without model id"
)
sys.exit(0)
path = "/v1/model/get_model_download_path"
headers = {"Content-Type": "application/json"}
model_payload = {
"modelID": self.model_export_id,
"modelType": "exported",
"expiryTimeInMinutes": 15,
}
resp = self.rpc.post(path=path, headers=headers, payload=model_payload)
return handle_response(
resp,
"Model download path fetched successfully and it will expire in 15 mins",
"An error occured while downloading the model",
)
[docs]
def download_model(self, file_name):
"""
Download the specified model type to a local file. There are 2 types of model types: trained and exported.
Parameters
----------
file_name : str
The name of the file to save the downloaded model.
model_type : str
The type of the model to download. Default is "trained".
Returns
-------
tuple
A tuple with the download status, error message, and status message.
Example
-------
>>> result, error, message = exported_model.download_model("model.pth")
>>> if error:
>>> print(f"Error: {error}")
>>> else:
>>> print(f"Model downloaded: {result}")
"""
presigned_url = self.rpc.post(
path=f"/v1/model/get_model_download_path",
payload={
"modelID": self.model_export_id,
"modelType": "exported",
"expiryTimeInMinutes": 59,
"exportFormat": self.export_format,
},
)["data"]
response = requests.get(presigned_url)
if response.status_code == 200:
with open(file_name, "wb") as file:
file.write(response.content)
print("Model downloaded successfully")
return file_name
else:
print(f"Model download failed with status code: {response.status_code}")
return ""