from matrice.utils import handle_response
import os
import json
from datetime import datetime, timedelta
def list_public_model_families(session, project_type="classification", page_size=10, page_num=0):
"""
Fetch public model families for a given project.
Parameters
----------
project_type : str, optional
The type of the project (default is "classification"; other available types are "detection" and "instance_segmentation").
page_size : int, optional
The number of model families to fetch per page (default is 10).
page_num : int, optional
The page number to fetch (default is 0).
Returns
-------
tuple
A tuple containing the response data, error (if any), and message.
Example
-------
>>> resp, error, message = list_public_model_families(session,"classification")
>>> if error:
>>> print(f"Error: {error}")
>>> else:
>>> print(f"Public model families: {resp}")
"""
path = f"/v1/model_store/list_public_model_families?projectType={project_type}&pageSize={page_size}&pageNum={page_num}"
resp = session.rpc.get(path=path)
return handle_response(
resp,
"Successfully fetched all public model families",
"An error occured while fetching the public model families",
)
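# Illustrative helper (a minimal sketch, not part of the SDK API): page through
# the public model families for a project type until an empty page is returned.
# It assumes `session` is an authenticated Session and that the response data is
# a list of model-family records.
def _example_collect_public_model_families(session, project_type="classification", page_size=25):
    families = []
    page_num = 0
    while True:
        resp, error, _ = list_public_model_families(
            session, project_type=project_type, page_size=page_size, page_num=page_num
        )
        if error or not resp:
            break
        families.extend(resp)
        page_num += 1
    return families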
def list_private_model_families(session, project_id=None, project_name=None, page_size=10, page_num=0):
"""
Fetch private model families for a given project.
Parameters
----------
project_id : str, optional
The ID of the project. Either project_id or project_name must be provided.
project_name : str, optional
The name of the project. Either project_id or project_name must be provided.
page_size : int, optional
The number of model families to fetch per page (default is 10).
page_num : int, optional
The page number to fetch (default is 0).
Returns
-------
tuple
A tuple containing the response data, error (if any), and message.
Example
-------
>>> resp, error, message = list_private_model_families(session,"66912342583678074789d")
>>> if error:
>>> print(f"Error: {error}")
>>> else:
>>> print(f"Private model families: {resp}")
"""
assert project_id is not None or project_name is not None
path = f"/v1/model_store/list_private_model_families?projectId={project_id}&pageSize={page_size}&pageNum={page_num}"
resp = session.rpc.get(path=path)
return handle_response(
resp,
"Successfully fetched all private model families",
"An error occured while fetching the private model families",
)
def list_public_model_archs(session, project_type="classification", page_size=10, page_num=0):
"""
Fetch public model architectures for a given project.
Parameters
----------
project_type : str, optional
The type of the project (default is "classification"; other available types are "detection" and "instance_segmentation").
page_size : int, optional
The number of model architectures to fetch per page (default is 10).
page_num : int, optional
The page number to fetch (default is 0).
Returns
-------
tuple
A tuple containing the response data, error (if any), and message.
Example
-------
>>> resp, error, message = list_public_model_archs(session,"classification")
>>> if error:
>>> print(f"Error: {error}")
>>> else:
>>> print(f"Public model architectures: {resp}")
"""
path = f"/v1/model_store/list_public_model_archs?projectType={project_type}&pageSize={page_size}&pageNum={page_num}"
resp = session.rpc.get(path=path)
return handle_response(
resp,
"Successfully fetched all public model architectures",
"An error occured while fetching the public model architectures",
)
def list_private_model_archs(session, project_id=None, project_name=None, page_size=10, page_num=0):
"""
Fetch private model architectures for a given project.
Parameters
----------
project_id : str, optional
The ID of the project. Either project_id or project_name must be provided.
project_name : str, optional
The name of the project. Either project_id or project_name must be provided.
page_size : int, optional
The number of model architectures to fetch per page (default is 10).
page_num : int, optional
The page number to fetch (default is 0).
Returns
-------
tuple
A tuple containing the response data, error (if any), and message.
Example
-------
>>> resp, error, message = list_private_model_archs(session,"66912342583678074789d")
>>> if error:
>>> print(f"Error: {error}")
>>> else:
>>> print(f"Private model architectures: {resp}")
"""
assert project_id is not None or project_name is not None
path = f"/v1/model_store/list_private_model_archs?projectId={project_id}&pageSize={page_size}&pageNum={page_num}"
resp = session.rpc.get(path=path)
return handle_response(
resp,
"Successfully fetched all private model architectures",
"An error occured while fetching the private model architectures",
)
def _get_all_models(session, project_id=None, project_name=None, project_type="classification"):
"""
Fetch all models for a given project.
Parameters
----------
project_id : str
The ID of the project.
project_type : str, optional
The type of the project (default is "classification"; other available types are "detection" and "instance_segmentation").
Returns
-------
tuple
A tuple containing the response data, error (if any), and message.
Example
-------
>>> resp, error, message = _get_all_models(session, "66912342583678074789d")
>>> if error:
>>> print(f"Error: {error}")
>>> else:
>>> print(f"All models: {resp}")
"""
path = f"/v1/model_store/get_all_models?projectId={project_id}&projectType={project_type}"
resp = session.rpc.get(path=path)
return handle_response(
resp,
"Successfully fetched all model infos",
"An error occured while fetching the model family",
)
def _get_all_model_families(session, project_id, project_name=None, project_type="classification"):
"""
Fetch all model families for a given project.
Parameters
----------
project_id : str
The ID of the project.
project_type : str, optional
The type of the project (default is "classification").
Returns
-------
tuple
A tuple containing the response data, error (if any), and message.
Example
-------
>>> resp, error, message = _get_all_model_families(session, "66912342583678074789d")
>>> if error:
>>> print(f"Error: {error}")
>>> else:
>>> print(f"All model families: {resp}")
"""
path = f"/v1/model_store/get_all_model_families?projectId={project_id}&projectType={project_type}"
resp = session.rpc.get(path=path)
return handle_response(
resp,
"Successfully fetched all model family",
"An error occured while fetching the model family",
)
def byom_status_summary(session, project_id, project_name=None):
"""
Fetch the BYOM (Bring Your Own Model) status summary for a given project.
Parameters
----------
project_id : str
The ID of the project.
project_name : str
The name of the project.
Returns
-------
tuple
A tuple containing the response data, error (if any), and message.
Example
-------
>>> resp, error, message = byom_status_summary(session,"66912342583678074789d")
>>> if error:
>>> print(f"Error: {error}")
>>> else:
>>> print(f"BYOM status summary: {resp}")
"""
path = f"/v1/model_store/byom_status_summary?projectId={project_id}"
resp = session.rpc.get(path=path)
return handle_response(
resp,
"Successfully fetched the BYOM status summary",
"An error occured while fetching the BYOM status summary",
)
# Check if the model family already exists publicly or within the project
def check_family_exists_by_name(session, family_name):
"""
Check if a model family exists by its name.
Parameters
----------
session : Session
The session object containing authentication information.
family_name : str
The name of the model family to check.
Returns
-------
bool
True if the model family exists, False otherwise.
Example
-------
>>> session = Session(account="your_account_number", access_key="your_access_key", secret_key="your_secret_key")
>>> family_name = "ResNet"
>>> exists = check_family_exists_by_name(session, family_name)
>>> if exists:
>>> print(f"The model family '{family_name}' exists.")
>>> else:
>>> print(f"The model family '{family_name}' does not exist.")
"""
path = f"/v1/model_store/check_family_exists_by_name?familyName={family_name}"
resp = session.rpc.get(path=path)
data, error, message = handle_response(
resp,
"Successfully checked model family existence",
"An error occurred while checking model family existence",
)
if error:
return False
return data.get("exists", False)
def fetch_supported_runtimes_metrics(session, project_id, model_inputs, model_outputs):
"""
Fetch supported runtimes and metrics for a given project.
Parameters
----------
project_id : str
The ID of the project.
model_inputs : list
List of model inputs.
model_outputs : list
List of model outputs.
Returns
-------
tuple
A tuple containing the response data, error (if any), and message.
Example
-------
>>> resp, error, message = fetch_supported_runtimes_metrics(session,["image"], ["classification"])
>>> if error:
>>> print(f"Error: {error}")
>>> else:
>>> print(f"Supported runtimes and metrics: {resp}")
"""
path = f"/v1/model_store/fetch_supported_runtimes_metrics?projectId={project_id}"
payload = {
"modelInputs": model_inputs,
"modelOutputs": model_outputs,
}
headers = {"Content-Type": "application/json"}
resp = session.rpc.post(path=path, headers=headers, payload=payload)
data, error, message = handle_response(
resp,
"Successfully fetched supported runtimes and metrics",
"An error occurred while fetching supported runtimes and metrics",
)
if error:
return data, error, message
runtimes = data
return runtimes, error, message
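# Illustrative helper (a minimal sketch, not part of the SDK API): look up the
# runtimes and metrics supported for a given input/output combination and return
# just the data, raising on error. The default input/output names are placeholders.
def _example_supported_runtimes(session, project_id, model_inputs=("image",), model_outputs=("classification",)):
    runtimes, error, message = fetch_supported_runtimes_metrics(
        session, project_id, list(model_inputs), list(model_outputs)
    )
    if error:
        raise RuntimeError(f"Could not fetch supported runtimes: {message}")
    return runtimes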
def get_automl_config(session, project_id, experiment_id, model_count, recommended_runtime, performance_tradeoff, tuning_type="auto"):
"""
Generate AutoML configurations for model training based on specified parameters.
This function fetches recommended model configurations from the backend and
processes them into a format suitable for model training. It calculates the
number of model variants based on hyperparameter combinations.
Parameters
----------
session : Session
Active session object for making API calls
project_id : str
Identifier for the project
experiment_id : str
Identifier for the experiment
model_count : int
Number of models to request configurations for
recommended_runtime : bool
Flag to indicate whether to only include models within recommended runtime
performance_tradeoff : float
Value indicating the trade-off between performance and resource usage
tuning_type : str, optional
Type of hyperparameter tuning strategy (default: "auto")
Returns
-------
tuple
A tuple containing three elements:
- model_archs (list): List of ModelArch instances for recommended models
- configs (list): List of configuration dictionaries for each model
Each config contains:
- is_autoML (bool): Set to True for AutoML
- tuning_type (str): Type of tuning strategy
- model_checkpoint (str): Checkpoint configuration
- checkpoint_type (str): Type of checkpoint
- action_config (dict): Raw configuration parameters
- model_config (dict): Processed configuration values
- model_counts (list): List of integers representing the number of
model variants for each model based on hyperparameter combinations
Example
-------
>>> session = Session()
>>> model_archs, configs, counts = get_automl_config(
... session=session,
... project_id="project123",
... experiment_id="exp456",
... model_count=5,
... recommended_runtime=True,
... performance_tradeoff=0.7
... )
>>> for arch, config, count in zip(model_archs, configs, counts):
... print(f"Model: {arch.model_key}, Variants: {count}")
... print(f"Config: {config}")
Notes
-----
The number of model variants (model_counts) is calculated by multiplying the
number of unique values for batch size, epochs, and learning rate for each model.
This represents the total number of training configurations that will be generated
for each model architecture.
"""
payload = {
"_idExperiment": experiment_id,
"_idProject": project_id,
"recommendedOnly": recommended_runtime,
"modelCount": model_count,
"performanceTradeoff": performance_tradeoff,
"searchType": tuning_type,
}
path = f"/v1/model_store/get_recommended_models/v2?projectId={project_id}"
headers = {"Content-Type": "application/json"}
resp = session.rpc.post(
path=path, headers=headers, payload=payload
)
model_archs = []
configs = []
model_counts = []
for model_data in resp.get("data", []):
model_key = model_data.get("modelKey")
model_family_name = model_data.get("modelFamilyName")
action_config_list = model_data.get("actionConfig", [])
print(model_key)
print(model_family_name)
# Extract the action_config and model_config from the response
action_config = {item["keyName"]: item for item in action_config_list}
model_config = {item["keyName"]: item["selectedValues"] for item in action_config_list}
# Calculate the total model count for this model
batch_size_count = len(model_config.get("batch", [1]))
epochs_count = len(model_config.get("epochs", [1]))
learning_rate_count = len(model_config.get("learning_rate", [1]))
model_specific_count = batch_size_count * epochs_count * learning_rate_count
print(model_specific_count)
model_counts.append(model_specific_count)
# Construct the new model configuration structure
config = {
'is_autoML': True,
'tuning_type': tuning_type,
'model_checkpoint': "auto",
'checkpoint_type': 'predefined',
'action_config': action_config,
'model_config': model_config
}
model_arch = ModelArch(session=session, model_family_name=model_family_name, model_key=model_key)
model_archs.append(model_arch)
configs.append(config)
return model_archs, configs, model_counts
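# Illustrative helper (a minimal sketch, not part of the SDK API): request an
# AutoML plan and summarise how many training variants each recommended
# architecture expands into. It assumes the project and experiment ids refer to
# existing resources accessible through `session`.
def _example_summarize_automl_plan(session, project_id, experiment_id, model_count=5):
    model_archs, configs, model_counts = get_automl_config(
        session,
        project_id=project_id,
        experiment_id=experiment_id,
        model_count=model_count,
        recommended_runtime=True,
        performance_tradeoff=0.5,
    )
    return [
        {
            "model_key": arch.model_key,
            "tuning_type": config["tuning_type"],
            "variants": count,
        }
        for arch, config, count in zip(model_archs, configs, model_counts)
    ]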
class ModelArch:
"""
A class to interact with model architectures through the model architecture API.
This class handles fetching and storing model architecture information, including
configuration parameters, export formats, and other model metadata.
Parameters
----------
session : Session
Active session object for making API calls
model_family_name : str
Name of the model family this architecture belongs to
model_key : str
Unique identifier key for the model architecture
Attributes
----------
account_number : str
Account number from the session
project_id : str
Project identifier from the session
model_family_name : str
Name of the model family
model_key : str
Model's unique identifier key
last_refresh_time : datetime
Timestamp of last data refresh
rpc : RPCClient
RPC client object from session for API calls
model_info_id : str or None
Model information unique identifier
model_name : str or None
Human readable name of the model
model_family_id : str or None
Unique identifier of the model family
params_millions : float or None
Number of parameters in millions
export_formats : list or None
List of supported export formats
model_config : dict or None
Default configuration parameters for model training
Notes
-----
Upon initialization, the class automatically fetches:
- Model information using _get_model_info()
- Training configuration using get_model_train_config()
- Export formats using get_export_formats()
Both model_family_name and model_key are required; initialization raises an
AssertionError if either is missing.
Example
-------
>>> session = Session()
>>> model = ModelArch(
... session=session,
... model_family_name="resnet",
... model_key="resnet50"
... )
>>> print(f"Model: {model.model_name}")
>>> print(f"Parameters: {model.params_millions}M")
>>> print(f"Export formats: {model.export_formats}")
Raises
------
AssertionError
If model_key or model_family_name is None
"""
def __init__(self, session, model_family_name, model_key):
self.session = session
self.account_number = session.account_number
assert model_key is not None, "Model key must be provided"
assert model_family_name is not None, "Model family name must be provided"
self.model_family_name = model_family_name
self.project_id = session.project_id
self.last_refresh_time = datetime.now()
self.rpc = session.rpc
self.model_key = model_key
self.model_info_id = None
model_info, error, message = self._get_model_info()
model_train_config, error, message = self.get_model_train_config()
export_formats, error, message = self.get_export_formats()
self.model_info_id = model_info["_id"]
self.model_name = model_info["modelName"]
self.model_key = model_info["modelKey"]
self.model_family_id = model_info["_idModelFamily"]
self.params_millions = model_info["paramsMillions"]
self.export_formats = export_formats
self.model_config = {
param["keyName"]: [param["defaultValue"]]
for param in model_train_config["actionConfig"]
}
def refresh(self):
"""
Refresh the instance by reinstantiating it with the previous values.
"""
# Check if two minutes have passed since the last refresh
if datetime.now() - self.last_refresh_time < timedelta(minutes=2):
raise Exception("Refresh can only be called after two minutes since the last refresh.")
# Capture the necessary state for reinitialization
init_params = {
'session': self.session,
'model_family_name': self.model_family_name,
'model_key': self.model_key
}
# Reinitialize the instance
self.__init__(**init_params)
# Update the last refresh time
self.last_refresh_time = datetime.now()
# Fetch model info
def _get_model_info(self):
"""
Fetch model information for this architecture.
The lookup uses the instance's model_info_id when available, otherwise the
model_key and model_family_name.
Returns
-------
tuple
A tuple containing the response data, error (if any), and message.
Example
-------
>>> session = Session(account="your_account_number", access_key="your_access_key", secret_key="your_secret_key")
>>> model_arch = ModelArch(session, model_family_name="resnet", model_key="resnet50")
>>> resp, error, message = model_arch._get_model_info()
>>> if error:
>>> print(f"Error: {error}")
>>> else:
>>> print(f"Model info: {resp}")
"""
if self.model_info_id is not None:
path = f"/v1/model_store/model_info/{self.model_info_id}"
else:
path = f"/v1/model_store/model_info_from_model_key_and_family/{self.model_key}/{self.model_family_name}"
resp = self.rpc.get(path=path)
return handle_response(
resp,
"Successfully fetched the model info",
"An error occured while fetching the model info",
)
def get_train_config(self, experiment_id, tuning_type="default", model_checkpoint="auto"):
"""
Fetch the training configuration for a given model.
This function retrieves the training configuration for a specified model and experiment.
It constructs the payload with the provided parameters and sends a POST request to fetch
the model parameters. The response is then processed to extract the action configuration
and model configuration, which are used to construct the final training configuration.
Parameters
----------
experiment_id : str
The ID of the experiment for which the training configuration is to be fetched.
tuning_type : str, optional
The type of parameter search to be used for tuning (default is "default").
model_checkpoint : str, optional
The model checkpoint to be used (default is "auto").
Returns
-------
tuple
A tuple containing the ModelArch instance and the training configuration dictionary.
Example
-------
>>> session = Session(account="your_account_number", access_key="your_access_key", secret_key="your_secret_key")
>>> model_arch = ModelArch(session, model_family_name="resnet", model_key="resnet50")
>>> experiment_id = "your_experiment_id"
>>> model_arch_instance, train_config = model_arch.get_train_config(experiment_id)
>>> print("ModelArch Instance:", model_arch_instance)
>>> print("Training Configuration:", train_config)
"""
payload = {
"modelCheckpoint": [model_checkpoint],
"paramsSearchType": tuning_type,
"_idExperiment": experiment_id,
"_idModelInfo": self.model_info_id,
}
path = f"/v1/model_store/get_model_params/v2?projectId={self.project_id}"
headers = {"Content-Type": "application/json"}
resp = self.rpc.post(
path=path, headers=headers, payload=payload
)
# Extract the action_config and model_config from the response
action_config = {item["keyName"]: item for item in resp["data"]}
model_config = {item["keyName"]: item["selectedValues"] for item in resp["data"]}
# Construct the new model configuration structure
config = {
'is_autoML': False,
'tuning_type': tuning_type,
'model_checkpoint': model_checkpoint,
'checkpoint_type': 'predefined',
'action_config': action_config,
'model_config': model_config
}
return ModelArch(session=self.session, model_family_name=self.model_family_name, model_key=self.model_key), config
def _get_default_train_config(self):
"""
Generates the default payload for training a model based on its configuration.
This method retrieves model information and training configuration using the model's `model_info_id`.
It constructs a payload that can be used for initiating model training with default parameters.
Returns
-------
list
A list containing a dictionary with default training payload settings.
Example
-------
>>> model_info_id = "model123"
>>> default_payload = model_arch.get_default_model_training_payload()
>>> print(default_payload)
[
{
'model_key': 'resnet50',
'is_autoML': False,
'tuning_type': 'manual',
'model_checkpoint': 'auto',
'checkpoint_type': 'predefined',
'params_millions': 25.6,
'model_name': 'ResNet-50',
'id_model_info': 'model123',
'action_config': {},
'model_config': {
'learning_rate': [0.001],
'batch_size': [32],
...
}
}
]
Detailed Description
--------------------
- The function first fetches model information (`model_info`) and training configuration (`model_train_config`)
using the helper methods `_get_model_info` and `get_model_train_config`.
- It then constructs a payload that contains details such as the model's key, name, tuning type,
and configuration parameters (e.g., learning rate, batch size) for training.
- The parameters for model training are set to their default values, which are fetched from the
model's configuration.
"""
model_info, _, _ = self._get_model_info()
model_train_config, _, _ = self.get_model_train_config()
model_training_payload = [
{
"model_key": model_info["modelKey"],
"is_autoML": False,
"tuning_type": "manual",
"model_checkpoint": "auto",
"checkpoint_type": "predefined",
"params_millions": model_info["paramsMillions"],
"model_name": model_info["modelName"],
"id_model_info": self.model_info_id,
"action_config": {},
"model_config": {
param["keyName"]: [param["defaultValue"]]
for param in model_train_config["actionConfig"]
},
}
]
return model_training_payload
def get_default_export_config(self, export_format):
"""
Retrieves the default configuration for exporting a model in a specified format.
This method fetches the export configuration for the given `model_info_id` and export format,
returning a dictionary of default export settings.
Parameters
----------
export_format : str
The format in which the model is to be exported (e.g., 'ONNX', 'TF SavedModel').
Returns
-------
dict
A dictionary containing default export configuration settings, where keys are parameter names and values are default values.
Example
-------
>>> export_format = "ONNX"
>>> default_export_config = model_arch.get_default_export_config(export_format)
>>> print(default_export_config)
{
'optimize': True,
'int8': False,
...
}
"""
model_export_config, err, msg = self.get_export_action_configs(
export_format
)
default_model_export_config = {
param["keyName"]: param["defaultValue"]
for param in model_export_config["actionConfig"]
}
return default_model_export_config
# To fetch model action config
def get_model_action_configs(self):
"""
Fetch the model action configuration for this architecture.
The lookup uses the instance's model_info_id when available, otherwise the
model_key and model_family_name.
Returns
-------
tuple
A tuple containing the response data, error (if any), and message.
Example
-------
>>> resp, error, message = model_arch.get_model_action_configs()
>>> if error:
>>> print(f"Error: {error}")
>>> else:
>>> print(f"Model action config: {resp}")
"""
if self.model_info_id is not None:
path = f"/v1/model_store/model_action_config/{self.model_info_id}"
else:
path = f"/v1/model_store/model_action_config_from_model_key_and_family/{self.model_key}/{self.model_family_name}"
#path = f"/v1/model_store/model_action_config/{model_action_config_id}"
resp = self.rpc.get(path=path)
return handle_response(
resp,
"Successfully fetched the model action config",
"An error occured while fetching the model action config",
)
def get_export_action_configs(self,export_format):
"""
Fetch action configuration for model export.
Parameters
----------
export_format : str
The export format.
Returns
-------
tuple
A tuple containing the response data, error (if any), and message.
Example
-------
>>> resp, error, message = model_arch.get_export_action_configs("ONNX")
>>> if error:
>>> print(f"Error: {error}")
>>> else:
>>> print(f"Action config for model export: {resp}")
"""
path = f"/v1/model_store/get_action_config_for_model_export?modelInfoId={self.model_info_id}&exportFormat={export_format}"
resp = self.rpc.get(path=path)
return handle_response(
resp,
"Successfully fetched all model family",
"An error occured while fetching the model family",
)
def get_model_train_config(self):
"""
Fetch the model training configuration for this architecture.
The lookup uses the instance's model_info_id when available, otherwise the
model_key and model_family_name.
Returns
-------
tuple
A tuple containing the response data, error (if any), and message.
Example
-------
>>> resp, error, message = model_arch.get_model_train_config()
>>> if error:
>>> print(f"Error: {error}")
>>> else:
>>> print(f"Model train config: {resp}")
"""
if self.model_info_id is not None:
path = f"/v1/model_store/get_train_config/{self.model_info_id}"
else:
path = f"/v1/model_store/model_train_config_from_model_key_and_family/{self.model_key}/{self.model_family_name}"
resp = self.rpc.get(path=path)
return handle_response(
resp,
"Successfully fetched model train config",
"An error occured while fetching model train config",
)
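# Illustrative helper (a minimal sketch, not part of the SDK API): build a
# ModelArch for a known family/key and collect its default training payload plus
# the default export settings for its first supported export format. The family
# name and model key defaults are placeholders.
def _example_inspect_model_arch(session, model_family_name="resnet", model_key="resnet50"):
    arch = ModelArch(session, model_family_name=model_family_name, model_key=model_key)
    default_train_payload = arch._get_default_train_config()
    export_defaults = None
    if arch.export_formats:
        export_defaults = arch.get_default_export_config(arch.export_formats[0])
    return arch, default_train_payload, export_defaults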
class ModelFamily:
"""
Class to interact with the model family API to get model configuration info and model-related info.
This class handles fetching and storing model family information, including model inputs, outputs,
supported runtimes, metrics, and other metadata.
Parameters
----------
session : Session
The session object containing authentication information.
model_family_id : str, optional
The ID of the model family to fetch.
model_family_name : str, optional
The name of the model family to fetch.
Attributes
----------
session : Session
The session object containing authentication information.
account_number : str
The account number from the session.
project_id : str
The project identifier from the session.
rpc : RPCClient
The RPC client object from the session for API calls.
model_family_id : str
The ID of the model family.
model_family_name : str
The name of the model family.
family_data : dict
The data of the model family fetched from the API.
model_inputs : list
List of model inputs.
model_outputs : list
List of model outputs.
model_keys : dict
Dictionary mapping model keys to model names.
description : str
Description of the model family.
training_framework : str
Training framework used for the model family.
supported_runtimes : list
List of supported runtimes.
benchmark_datasets : list
List of benchmark datasets.
supported_metrics : list
List of supported metrics.
input_format : str
Input format for the model family.
Methods
-------
__get_model_family()
Fetch a model family by its ID or name.
get_model_archs(model_name=None, model_key=None)
Fetch model information by model family or by name and key.
Example
-------
>>> session = Session(account_number="your_account_number", access_key="your_access_key", secret_key="your_secret_key")
>>> model_family = ModelFamily(session, model_family_name="resnet")
>>> print(f"Model Family: {model_family.model_family_name}")
>>> print(f"Model Inputs: {model_family.model_inputs}")
>>> print(f"Model Outputs: {model_family.model_outputs}")
>>> print(f"Supported Runtimes: {model_family.supported_runtimes}")
>>> print(f"Supported Metrics: {model_family.supported_metrics}")
Raises
------
AssertionError
If neither model_family_id nor model_family_name is provided.
"""
def __init__(self, session, model_family_id=None, model_family_name=None):
self.session = session
self.account_number = session.account_number
self.project_id = session.project_id
self.rpc = session.rpc
assert model_family_id is not None or model_family_name is not None , "Either model_family_id or model_family_name must be provided"
self.model_family_id = model_family_id
self.model_family_name = model_family_name
resp , error, message = self.__get_model_family()
if error:
print(f"Error: {error}")
return
family_data = resp
self.family_data = family_data
self.model_family_id = family_data["_id"]
self.model_family_name = family_data["modelFamily"]
self.model_inputs = family_data["modelInputs"]
self.model_outputs = family_data["modelOutputs"]
self.model_keys = {model["modelKey"]: model["modelName"] for model in family_data["models"]}
self.description = family_data["description"]
self.training_framework = family_data["trainingFramework"]
self.supported_runtimes = family_data["supportedRuntimes"]
self.benchmark_datasets = family_data["benchmarkDatasets"]
self.supported_metrics = family_data["supportedMetrics"]
self.input_format = family_data["dataProcessing"]["inputFormat"]
# To fetch a model family
def __get_model_family(self):
"""
Fetch a model family by its ID or name.
Returns
-------
tuple
A tuple containing the response data, error (if any), and message.
Example
-------
>>> session = Session(account_number="your_account_number", access_key="your_access_key", secret_key="your_secret_key")
>>> model_family = ModelFamily(session, model_family_name="resnet")
>>> resp, error, message = model_family.__get_model_family()
>>> if error:
>>> print(f"Error: {error}")
>>> else:
>>> print(f"Model family: {resp}")
"""
if self.model_family_id is not None:
path = f"/v1/model_store/model_family/{self.model_family_id}"
else:
path = f"/v1/model_store/model_family/{self.model_family_name}"
resp = self.rpc.get(path=path)
return handle_response(
resp,
"Successfully fetched the model family",
"An error occured while fetching the model family",
)
def get_model_archs(self, model_name=None, model_key=None):
"""
Fetch model architectures for this model family, optionally filtered by model name and key.
Parameters
----------
model_name : str, optional
The name of a specific model to fetch.
model_key : str, optional
The key of a specific model to fetch.
Returns
-------
tuple
A tuple containing the response data, error (if any), and message. When no filter is given,
the data is a dictionary mapping model keys to ModelArch instances.
Example
-------
>>> session = Session(account_number="your_account_number", access_key="your_access_key", secret_key="your_secret_key")
>>> model_family = ModelFamily(session, model_family_name="resnet")
>>> archs, error, message = model_family.get_model_archs()
>>> if error:
>>> print(f"Error: {error}")
>>> else:
>>> print(f"Model architectures: {archs}")
"""
if model_name and model_key:
path = f"/v1/model_store/get_model_info_by_name_and_key?modelName={model_name}&modelKey={model_key}"
elif self.model_family_id:
path = f"/v1/model_store/get_models_by_modelfamily?modelFamilyId={self.model_family_id}"
else:
path = f"/v1/model_store/get_models_by_modelfamily?modelFamilyName={self.model_family_name}"
resp = self.rpc.get(path=path)
data, error, message = handle_response(
resp,
"Successfully fetched model info",
"An error occurred while fetching model info",
)
if error:
return data, error, message
if isinstance(data, list): # Handle when data is a list of dictionaries
data_list = data
elif isinstance(data, dict): # Handle when data is a single dictionary
data_list = [data] # Convert it to a list with one dictionary
else:
error = "Data is not in the expected format. Expected a list or dictionary."
return None, error, message
# Now proceed with processing as a list
if model_name and model_key:
model_info_list = [{"model_key": item["modelKey"], "model_arch_instance": ModelArch(self.session, self.model_family_name, item["modelKey"])} for item in data_list]
return model_info_list, error, message
else:
model_archs = {item["modelKey"]: ModelArch(self.session, self.model_family_name, item["modelKey"]) for item in data_list}
return model_archs, error, message
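# Illustrative helper (a minimal sketch, not part of the SDK API): load a model
# family by name and map each of its model keys to the corresponding ModelArch
# name. The default family name is a placeholder; substitute one that exists in
# your model store.
def _example_list_family_archs(session, model_family_name="resnet"):
    family = ModelFamily(session, model_family_name=model_family_name)
    archs, error, _ = family.get_model_archs()
    if error:
        return {}
    return {key: arch.model_name for key, arch in archs.items()}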
class BYOM:
"""
A class to interact with the BYOM (Bring Your Own Model) API for managing model families, model information,
and model action configurations.
Attributes:
-----------
session : Session
A session object containing account information and RPC (Remote Procedure Call) details.
account_number : str
The account number associated with the session.
rpc : RPC
The RPC object used to make API calls.
Methods:
--------
delete_model_family()
Deletes the model family.
delete_model_info(model_info_id, model_key)
Deletes model information.
delete_action_config(model_action_config_id)
Deletes a model action configuration.
add_or_update_model_info(config)
Adds or updates model information.
add_action_config(...)
Adds a new model action configuration.
update_model_family(config)
Updates the model family.
update_action_config(...)
Updates a model action configuration.
validate_model_info(), validate_train_config(), validate_export_config(), validate_local_test_logs(file_path), validate_all()
Validate the BYOM configurations and test logs.
add_codebase(...), get_codebase_details(), download_codebase(local_path)
Manage the model family code base.
"""
def __init__(self, session, family_name=None, family_id=None, family_config=None):
"""
Initializes the BYOM class with a session object.
Parameters:
-----------
session : Session
A session object containing account information and RPC details.
family_name : str, optional
The name of the model family.
family_id : str, optional
The ID of an existing model family.
family_config : str or dict, optional
A model family config dictionary, or the path to a JSON file containing it. Used to create the family if it does not yet exist.
"""
self.session = session
self.account_number = session.account_number
self.rpc = session.rpc
self.project_id = session.project_id
self.family_name = family_name
self.family_id = family_id
self.family_config = None
assert family_name is not None or family_id is not None
if family_config is not None:
self.family_config = self.__load_config(family_config)
# Check existence by family_id
if self.family_id is not None:
self.model_family = ModelFamily(self.session, model_family_id=self.family_id)
# Check existence by family_name
elif self.family_name is not None:
if not check_family_exists_by_name(self.session, self.family_name):
if self.family_config:
self._add_model_family(self.family_config)
print("added the new model family:", self.family_id)
self.model_family = ModelFamily(self.session, family_name=self.family_name)
else:
raise ValueError("Either family_name, family_id, or family_config must be provided.")
def __load_config(self, config):
if isinstance(config, str) and os.path.isfile(config):
with open(config, 'r') as file:
return json.load(file)
elif isinstance(config, dict):
return config
else:
raise ValueError("Invalid family_config. Must be a dictionary or a valid file path.")
# To delete a model family
def delete_model_family(self):
"""
Deletes the model family associated with this BYOM instance.
The family is identified by the instance's family_id when available, otherwise by its family_name.
Returns:
--------
tuple
A tuple containing the API response, error message (or None if successful), and a status message.
"""
if self.family_id is not None:
path = f"/v1/model_store/model_family/{self.family_id}"
else:
path = f"/v1/model_store/model_family/{self.family_name}"
resp = self.rpc.delete(path=path)
return handle_response(
resp,
"Successfully deleted the model family",
"An error occured while deleting the model family",
)
# To delete model info
def delete_model_info(self, model_info_id , model_key):
"""
Deletes model information using its ID.
Parameters:
-----------
model_info_id : str
The ID of the model information to delete.
model_key : str
The key of the model, used when model_info_id is not provided.
Returns:
--------
tuple
A tuple containing the API response, error message (or None if successful), and a status message.
"""
if model_info_id is not None:
path = f"/v1/model_store/model_info/{model_info_id}"
else:
path = f"/v1/model_store/model_info/{model_key}"
resp = self.rpc.delete(path=path)
return handle_response(
resp,
"Successfully deleted the model family",
"An error occured while deleting the model family",
)
# To delete model action config
def delete_action_config(self, model_action_config_id, export_format=None):
"""
Deletes a model action configuration using its ID.
Parameters:
-----------
model_action_config_id : str
The ID of the model action configuration to delete.
Returns:
--------
tuple
A tuple containing the API response, error message (or None if successful), and a status message.
"""
path = f"/v1/model_store/model_action_config/{model_action_config_id}"
resp = self.rpc.delete(path=path)
return handle_response(
resp,
"Successfully deleted the model action config",
"An error occured while deleting the model action config",
)
def _add_model_family(self, config):
"""
Adds a new model family to the model store.
This function sends a POST request to add a new model family with the provided parameters.
Parameters:
-----------
config : str or dict
The path to the local JSON file containing the model config or the model config dictionary.
Returns:
--------
tuple
A tuple containing three elements:
- API response (dict): The raw response from the API.
- error_message (str or None): Error message if an error occurred, None otherwise.
- status_message (str): A status message indicating success or failure.
Raises:
-------
ValueError
If the config is neither a valid file path nor a dictionary.
"""
# Load the config from a file if a file path is provided
if isinstance(config, str) and os.path.isfile(config):
with open(config, 'r') as file:
config = json.load(file)
elif not isinstance(config, dict):
raise ValueError("Invalid config. Must be a dictionary or a valid file path.")
project_id = self.project_id
# Extract necessary parameters from the config
model_family = config.get("model_family")
model_inputs = config.get("model_inputs")
model_outputs = config.get("model_outputs")
models = config.get("models")
description = config.get("description")
training_framework = config.get("training_framework")
supported_runtimes = config.get("supported_runtimes")
benchmark_datasets = config.get("benchmark_datasets")
supported_metrics = config.get("supported_metrics")
input_format = config.get("input_format")
data_loader_class_definition = config.get("data_loader_class_definition")
data_loader_call_signature = config.get("data_loader_call_signature")
references = config.get("references")
is_private = config.get("is_private")
# Create the payload
model_store_payload = {
"modelFamily": model_family,
"modelInputs": model_inputs,
"modelOutputs": model_outputs,
"models": models,
"description": description,
"trainingFramework": training_framework,
"supportedRuntimes": supported_runtimes,
"benchmarkDatasets": benchmark_datasets,
"supportedMetrics": supported_metrics,
"dataProcessing": {
"inputFormat": input_format,
"dataLoaderClassDefinition": data_loader_class_definition,
"dataLoaderCallSignature": data_loader_call_signature,
},
"references": references,
"isPrivate": is_private,
"projectId": project_id,
}
# Send the POST request
path = "/v1/model_store/add_model_family"
headers = {"Content-Type": "application/json"}
resp = self.rpc.post(path=path, headers=headers, payload=model_store_payload)
return handle_response(
resp,
"New model family created",
"An error occurred while creating model family",
)
def add_or_update_model_info(self, config):
"""
Adds or updates information for a specific model in the model store.
This function sends a POST request to add or a PUT request to update information about a model with the provided parameters.
Parameters:
-----------
config : str or dict
The path to the local JSON file containing the model config or the model config dictionary.
Returns:
--------
tuple
A tuple containing three elements:
- API response (dict): The raw response from the API.
- error_message (str or None): Error message if an error occurred, None otherwise.
- status_message (str): A status message indicating success or failure.
Raises:
-------
ValueError
If the config is neither a valid file path nor a dictionary.
"""
# Load the config and id
config = self.__load_config(config)
model_family_id = self.model_family.model_family_id
# Extract necessary parameters from the config
model_key = config.get("modelKey")
model_name = config.get("modelName")
params_millions = config.get("paramsMillions")
recommended_runtime = config.get("recommendedRuntimes")
benchmark_results = config.get("benchmarkResults")
runtime_results = config.get("runtimeResults")
# Create the payload
model_store_payload = {
"modelKey": model_key,
"modelName": model_name,
"_idModelFamily": model_family_id,
"paramsMillions": params_millions,
"recommendedRuntimes": recommended_runtime,
"benchmarkResults": benchmark_results,
"runtimeResults": runtime_results,
}
# Determine if the model exists
if model_key in self.model_family.model_keys:
# Update model info
models, _, _ = _get_all_models(self.session, project_id=self.project_id)
model_info_id = next(model['_id'] for model in models if model['modelKey'] == model_key)
path = f"/v1/model_store/model_info/{model_info_id}"
success_message = "Model info updated successfully"
error_message = "An error occurred while updating model info"
resp = self.rpc.put(path=path, headers={"Content-Type": "application/json"}, payload=model_store_payload)
else:
# Add model info
path = "/v1/model_store/add_model_info"
success_message = "New model info created successfully"
error_message = "An error occurred while creating model info"
resp = self.rpc.post(path=path, headers={"Content-Type": "application/json"}, payload=model_store_payload)
return handle_response(resp, success_message, error_message)
# To add a new entry into model action config
def add_action_config(
self,
action_type,
action_config,
model_checkpoint,
export_format,
):
"""
Adds a new action configuration for a specific model in the model store.
This function sends a POST request to add a new action configuration for a model with the provided parameters.
Parameters:
-----------
action_type : str
The type of action (e.g., 'train_model', 'export_model').
action_config : dict
Configuration details for the action.
model_checkpoint : str
Path or identifier for the model checkpoint.
export_format : str
Format for exporting the model.
Returns:
--------
tuple
A tuple containing three elements:
- API response (dict): The raw response from the API.
- error_message (str or None): Error message if an error occurred, None otherwise.
- status_message (str): A status message indicating success or failure.
Raises:
-------
May raise exceptions related to network issues or API errors.
Notes:
------
This function uses the self.rpc.post method to send the request and
handle_response to process the response.
"""
path = "/v1/model_store/add_model_action_config"
model_store_payload = {
"_idModelFamily": self.family_id,
"actionType": action_type,
"actionConfig": action_config,
"modelCheckpoint": model_checkpoint,
"exportFormat": export_format,
}
headers = {"Content-Type": "application/json"}
resp = self.rpc.post(path=path, headers=headers, payload=model_store_payload)
return handle_response(
resp,
"New model action config created",
"An error occured while creating model action config",
)
def update_model_family(self, config):
"""
Updates an existing model family in the model store.
This function sends a PUT request to update a model family with the provided parameters.
Parameters:
-----------
config : str or dict
The path to the local JSON file containing the model config or the model config dictionary.
Returns:
--------
tuple
A tuple containing three elements:
- API response (dict): The raw response from the API.
- error_message (str or None): Error message if an error occurred, None otherwise.
- status_message (str): A status message indicating success or failure.
Raises:
-------
ValueError
If the config is neither a valid file path nor a dictionary.
"""
# Load the config
config = self.__load_config(config)
model_family_id = self.model_family["id"]
project_id = self.project_id
# Extract necessary parameters from the config
model_family = self.model_family.model_family_name
model_inputs = config.get("model_inputs")
model_outputs = config.get("model_outputs")
model_keys = config.get("model_keys")
description = config.get("description")
training_framework = config.get("training_framework")
supported_runtimes = config.get("supported_runtimes")
benchmark_datasets = config.get("benchmark_datasets")
supported_metrics = config.get("supported_metrics")
pruning_support = config.get("pruning_support")
code_repository = config.get("code_repository")
training_docker_container = config.get("training_docker_container")
input_format = config.get("input_format")
data_loader_class_definition = config.get("data_loader_class_definition")
data_loader_call_signature = config.get("data_loader_call_signature")
references = config.get("references")
is_private = config.get("is_private")
# Create the payload
model_store_payload = {
"modelFamily": model_family,
"modelInputs": model_inputs,
"modelOutputs": model_outputs,
"modelKeys": model_keys,
"description": description,
"trainingFramework": training_framework,
"supportedRuntimes": supported_runtimes,
"benchmarkDatasets": benchmark_datasets,
"supportedMetrics": supported_metrics,
"pruningSupport": pruning_support,
"codeRepository": code_repository,
"trainingDockerContainer": training_docker_container,
"dataProcessing": {
"inputFormat": input_format,
"dataLoaderClassDefinition": data_loader_class_definition,
"dataLoaderCallSignature": data_loader_call_signature,
},
"references": references,
"isPrivate": is_private,
"projectId": project_id,
}
# Send the PUT request
path = f"/v1/model_store/model_family/{model_family_id}"
headers = {"Content-Type": "application/json"}
resp = self.rpc.put(path=path, headers=headers, payload=model_store_payload)
return handle_response(
resp,
"Model family successfully updated",
"An error occurred while updating model family",
)
# To update model action config
def update_action_config(
self,
model_action_config_id,
model_info_id,
action_type,
action_config,
export_format,
model_checkpoint,
):
"""
Updates the action configuration for a specific model in the model store.
This function sends a PUT request to update model action configuration with the provided parameters.
Parameters
----------
model_action_config_id : str
The unique identifier of the model action config to update.
model_info_id : str
The identifier of the model info this action config belongs to.
action_type : str
The updated type of action (e.g., 'train_model', 'export_model').
action_config : dict
Updated configuration details for the action.
export_format : str
Updated format for exporting the model.
model_checkpoint : str
Updated path or identifier for the model checkpoint.
Returns
-------
tuple
A tuple containing three elements:
- API response (dict): The raw response from the API.
- error_message (str or None): Error message if an error occurred, None otherwise.
- status_message (str): A status message indicating success or failure.
Raises
------
May raise exceptions related to network issues or API errors.
Notes
-----
This function uses the self.rpc.put method to send the request and
handle_response to process the response.
Example
-------
>>> resp, error, message = model_store.update_action_config(
>>> model_action_config_id="12345",
>>> model_info_id="67890",
>>> action_type="train_model",
>>> action_config={"param1": "value1", "param2": "value2"},
>>> export_format="ONNX"
>>> )
>>> if error:
>>> print(f"Error: {error}")
>>> else:
>>> print(f"Action config updated: {resp}")
"""
path = f"/v1/model_store/model_action_config/{model_action_config_id}"
model_store_payload = {
"_idModelInfo": model_info_id,
"actionType": action_type,
"model_checkpoint": model_checkpoint,
"actionConfig": action_config,
"export_format": export_format,
}
headers = {"Content-Type": "application/json"}
resp = self.rpc.put(path=path, headers=headers, payload=model_store_payload)
return handle_response(
resp,
"Model action config successfully updated",
"An error occurred while updating the model action config",
)
def validate_model_info(self):
"""
Validate the model information for a given project.
Returns
-------
tuple
A tuple containing the response data, error (if any), and message.
Example
-------
>>> resp, error, message = model_store.validate_model_info()
>>> if error:
>>> print(f"Error: {error}")
>>> else:
>>> print(f"Validation result: {resp}")
"""
path = f"/v1/model_store/validate_model_info/{self.project_id}/{self.family_id}"
resp = self.rpc.post(path=path)
return handle_response(
resp,
"Successfully validated model info",
"An error occurred while validating the model information",
)
def validate_train_config(self):
"""
Validate the training configuration for a given project.
Returns
-------
tuple
A tuple containing the response data, error (if any), and message.
Example
-------
>>> resp, error, message = model_store.validate_train_config()
>>> if error:
>>> print(f"Error: {error}")
>>> else:
>>> print(f"Validation result: {resp}")
"""
path = f"/v1/model_store/validate_train_config/{self.project_id}/{self.family_id}"
resp = self.rpc.post(path=path)
return handle_response(
resp,
"Successfully validated training configuration",
"An error occurred while validating the training configuration",
)
def validate_export_config(self):
"""
Validate the export configuration for a given project.
Returns
-------
tuple
A tuple containing the response data, error (if any), and message.
Example
-------
>>> resp, error, message = model_store.validate_export_config()
>>> if error:
>>> print(f"Error: {error}")
>>> else:
>>> print(f"Validation result: {resp}")
"""
path = f"/v1/model_store/validate_export_config/{self.project_id}/{self.family_id}"
resp = self.rpc.post(path=path)
return handle_response(
resp,
"Successfully validated export configuration",
"An error occurred while validating the export configuration",
)
def validate_local_test_logs(self, file_path):
"""
Validate the local test logs for a given project.
Parameters
----------
file_path : str
The path to the test log file.
Returns
-------
tuple
A tuple containing the response data, error (if any), and message.
Example
-------
>>> resp, error, message = model_store.validate_local_test_logs("path/to/log/file")
>>> if error:
>>> print(f"Error: {error}")
>>> else:
>>> print(f"Validation result: {resp}")
"""
path = f"/v1/model_store/validate_test_logs/{self.project_id}/{self.family_id}"
with open(file_path, 'rb') as log_file:
files = {'jsonfile': log_file}
resp = self.rpc.post(path=path, files=files)
return handle_response(
resp,
"Successfully validated test logs",
"An error occurred while validating the test logs",
)
def validate_all(self, test_log_path=None):
"""
Validate all configurations and logs for the model family.
Parameters
----------
test_log_path : str, optional
Path to a local test log file; when provided, the test logs are validated as well.
Returns
-------
dict
A dictionary containing the results of all validations.
Example
-------
>>> results = model_store.validate_all()
>>> for key, (resp, error, message) in results.items():
>>> if error:
>>> print(f"{key} validation failed: {error}")
>>> else:
>>> print(f"{key} validation succeeded: {message}")
"""
results = {}
results['model_info'] = self.validate_model_info()
results['train_config'] = self.validate_train_config()
results['export_config'] = self.validate_export_config()
# Validate local test logs only when a log file path is supplied
if test_log_path is not None:
results['test_logs'] = self.validate_local_test_logs(test_log_path)
return results
def _get_code_upload_path(self):
"""
Get the user upload path for a given project.
Returns
-------
tuple
A tuple containing the response data, error (if any), and message.
Example
-------
>>> resp, error, message = model_store._get_code_upload_path()
>>> if error:
>>> print(f"Error: {error}")
>>> else:
>>> print(f"User upload path: {resp}")
"""
path = f"/v1/model_store/get_user_upload_path/{self.project_id}/{self.family_id}"
resp = self.rpc.get(path=path)
return handle_response(
resp,
"Successfully retrieved user upload path",
"An error occurred while retrieving the user upload path",
)
def get_validation_logs(self):
"""
Get the validation logs for a given model family.
Returns
-------
tuple
A tuple containing the response data, error (if any), and message.
Example
-------
>>> resp, error, message = model_store.get_validation_logs()
>>> if error:
>>> print(f"Error: {error}")
>>> else:
>>> print(f"Validation logs: {resp}")
"""
path = f"/v1/model_store/get_validation_logs/validation/all?{self.model_family}"
resp = self.rpc.get(path=path)
return handle_response(
resp,
"Successfully retrieved validation logs",
"An error occurred while retrieving the validation logs",
)
def add_codebase(self, code_path, project_type, matrice_version, training_framework, framework_version, license_content):
"""
Add model family code for a given project.
Parameters
----------
code_path : str
The cloud path where the model family code is stored.
project_type : str
The type of the project (e.g., "classification", "detection", "instance_segmentation").
matrice_version : str
The version of the matrice Python SDK.
training_framework : str
The training framework used by the model family code.
framework_version : str
The version of the training framework.
license_content : str
The license information for the model family code.
Returns
-------
tuple
A tuple containing the response data, error (if any), and message.
Example
-------
>>> resp, error, message = model_store.add_codebase(
>>> "cloud/path/to/code", "classification", "1.0.0", "PyTorch", "1.7.1", "MIT"
>>> )
>>> if error:
>>> print(f"Error: {error}")
>>> else:
>>> print(f"Model family code added: {resp}")
"""
path = f"/v1/model_store/get_validation_logs/validation/all?{self.model_family}"
# NOTE: the original payload only carries a framework (PyTorch) version field,
# so framework_version is mapped onto it; training_framework has no corresponding field here.
model_store_payload = {
"cloudPath": code_path,
"accountNumber": self.account_number,
"_idModelFamily": self.family_id,
"projectType": project_type,
"sdk_version": matrice_version,
"pytorch_version": framework_version,
"license": license_content,
}
headers = {"Content-Type": "application/json"}
resp = self.rpc.post(path=path, payload=model_store_payload, headers=headers)
return handle_response(
resp,
"Successfully added model family code",
"An error occurred while adding the model family code",
)
def get_codebase_details(self):
"""
Fetch user code base details for the specified model family.
Returns
-------
tuple
A tuple containing the response data, error (if any), and message.
Example
-------
>>> resp, error, message = model_store.get_codebase_details()
>>> if error:
>>> print(f"Error: {error}")
>>> else:
>>> print(f"User code details: {resp}")
"""
path = f"/v1/model_store/get_user_code_details/{self.model_family}"
resp = self.rpc.get(path=path)
return handle_response(
resp,
"Successfully fetched user code details",
"An error occurred while fetching user code details"
)
def download_codebase(self, local_path):
"""
Fetch the download path for the user code base of a model family.
Returns
-------
tuple
A tuple containing the response data, error (if any), and message.
Example
-------
>>> resp, error, message = model_store.download_codebase("path/to/save/code")
>>> if error:
>>> print(f"Error: {error}")
>>> else:
>>> print(f"Model code path: {resp}")
"""
path = f"/v1/model_store/get_user_code_download_path/{self.model_family}"
resp = self.rpc.get(path=path)
return handle_response(
resp,
"Successfully fetched model family code path",
"An error occurred while fetching model family code path"
)
def get_test_case_logs(self):
pass
def get_byom_status(self):
self.model_family.refresh()
return self.model_family.status
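# Illustrative helper (a minimal sketch, not part of the SDK API): run the
# individual BYOM validation steps for an existing family and collect the ones
# that failed. The family name and test-log path are supplied by the caller.
def _example_validate_byom_family(session, family_name, test_log_path):
    byom = BYOM(session, family_name=family_name)
    checks = {
        "model_info": byom.validate_model_info(),
        "train_config": byom.validate_train_config(),
        "export_config": byom.validate_export_config(),
        "test_logs": byom.validate_local_test_logs(test_log_path),
    }
    return {name: error for name, (_, error, _) in checks.items() if error}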