"""Module for interacting with backend API to manage experiment."""
from matrice.utils import handle_response , get_summary
from matrice.models import Model
from datetime import datetime, timedelta
# [docs]
class Experiment:
    """Manage experiment-related operations within a project.

    Parameters
    ----------
    session : Session
        The session object that manages the connection to the server. It must
        expose ``project_id`` and ``rpc``.
    experiment_id : str, optional
        The ID of the experiment (default is an empty string).
    experiment_name : str, optional
        The name of the experiment (default is an empty string).

    Notes
    -----
    At least one of ``experiment_id`` or ``experiment_name`` must be given.

    Example
    -------
    >>> session = Session(account_number="account_number")
    >>> experiment = Experiment(session=session, experiment_id="exp123")
    """

    def __init__(self, session, experiment_id="", experiment_name=""):
        """
        Bind the experiment to a session and load its summary and details.

        The constructor stores the session handles, fetches the experiment
        summary (model counts) and then the experiment details. A failure of
        either fetch is reported on the console; the corresponding attributes
        are then left as ``None`` instead of being undefined.

        Parameters
        ----------
        session : object
            The session object providing ``project_id`` and the ``rpc`` client.
        experiment_id : str, optional
            The ID of the experiment. Defaults to an empty string.
        experiment_name : str, optional
            The name of the experiment. Defaults to an empty string.

        Attributes
        ----------
        project_id : str
            The project ID taken from the session.
        session : object
            The session object used to make API calls.
        rpc : object
            The RPC client taken from the session.
        last_refresh_time : datetime
            Time of construction or of the last successful ``refresh``.
        models_for_training : list
            Payloads queued by ``add_models_for_training``.
        error_model_count, queued_model_count, trained_model_count,
        training_model_count, total_models : int or None
            Values read from the summary response; ``None`` if it failed.
        best_model, best_model_test_score
            Populated by ``list_models``; ``None`` until then.
        experiment_id, experiment_name : str
            As provided, then overwritten from the details response.
        experiment_data : dict or None
            The details response as returned by ``get_details``.
        dataset_id, dataset_name, dataset_version, primary_metric,
        model_inputs, model_outputs, target_runtime
            Required fields of the details response.
        performance_tradeoff, matrice_compute, models_trained,
        base_model_storage_path, storage_cloud_credentials, running_instances,
        data_prep_input_formats, num_samples, num_classes, best_test_score,
        created_at, updated_at
            Optional fields of the details response.

        Raises
        ------
        AssertionError
            If neither ``experiment_id`` nor ``experiment_name`` is provided.
        KeyError
            If a successful details response lacks a required field.

        Example
        -------
        >>> session = Session(account_number="account_number")
        >>> exp = Experiment(session, experiment_id="exp123")
        >>> print(exp.experiment_id)
        """
        self.project_id = session.project_id
        self.session = session
        self.rpc = session.rpc
        self.last_refresh_time = datetime.now()
        self.models_for_training = []
        # Kept as an assert so the exception type seen by callers is unchanged.
        assert experiment_id or experiment_name, (
            "At least one of 'experiment_id' or 'experiment_name' must be provided."
        )
        self.experiment_id = experiment_id
        self.experiment_name = experiment_name
        self.best_model = None
        self.best_model_test_score = None

        self._load_summary()
        self._load_details()

    def _load_summary(self):
        """Fetch the experiments summary and store the model counts."""
        # Defined up front so the attributes exist even if the fetch fails.
        self.error_model_count = None
        self.queued_model_count = None
        self.trained_model_count = None
        self.training_model_count = None
        self.total_models = None

        summary_data, error, message = get_summary(
            self.session, self.project_id, service_name='experiments'
        )
        if not summary_data:
            # The response is falsy here, so report the message returned
            # alongside it rather than reading from the response itself.
            print(f"Error fetching summary: {message}")
            return

        model_count_by_status = summary_data.get("modelCountByStatus") or {}
        self.error_model_count = model_count_by_status.get("error", 0)
        self.queued_model_count = model_count_by_status.get("queued", 0)
        self.trained_model_count = model_count_by_status.get("trained", 0)
        self.training_model_count = model_count_by_status.get("training", 0)
        self.total_models = summary_data.get("total", 0)

    def _load_details(self):
        """Fetch the experiment details and copy them onto the instance."""
        # Reset first so a failed fetch leaves None rather than missing
        # attributes (or stale values when called again through refresh()).
        for attr_name in (
            "experiment_data", "dataset_id", "dataset_name", "dataset_version",
            "primary_metric", "model_inputs", "model_outputs", "target_runtime",
            "performance_tradeoff", "matrice_compute", "models_trained",
            "base_model_storage_path", "storage_cloud_credentials",
            "running_instances", "data_prep_input_formats", "num_samples",
            "num_classes", "best_test_score", "created_at", "updated_at",
        ):
            setattr(self, attr_name, None)

        resp, error, message = self.get_details()
        if error or not resp:
            print(f"Error fetching project info: {message}")
            return

        experiment_data = resp
        self.experiment_data = experiment_data
        # Required fields: a missing key raises KeyError.
        self.experiment_id = experiment_data["_id"]
        self.experiment_name = experiment_data["experimentName"]
        self.dataset_id = experiment_data["_idDataset"]
        self.dataset_name = experiment_data["datasetName"]
        self.dataset_version = experiment_data["datasetVersion"]
        self.primary_metric = experiment_data["primaryMetric"]
        self.model_inputs = experiment_data["modelInputs"]
        self.model_outputs = experiment_data["modelOutputs"]
        self.target_runtime = experiment_data["targetRuntime"]
        # Optional fields.
        self.performance_tradeoff = experiment_data.get("performanceTradeoff")
        self.matrice_compute = experiment_data.get("matriceCompute")
        self.models_trained = experiment_data.get("modelsTrained", [])
        self.base_model_storage_path = experiment_data.get("baseModelStoragePath")
        self.storage_cloud_credentials = experiment_data.get("storageCloudCredentials", [])
        self.running_instances = experiment_data.get("runningInstances", [])
        self.data_prep_input_formats = experiment_data.get("dataPrepInputFormats", [])
        self.num_samples = experiment_data.get("sampleCount")
        self.num_classes = experiment_data.get("classCount")
        self.best_test_score = experiment_data.get("bestTestScore")
        self.created_at = experiment_data.get("createdAt")
        self.updated_at = experiment_data.get("updatedAt")

    def refresh(self):
        """
        Refresh the instance by re-running ``__init__`` with its current values.

        Raises
        ------
        Exception
            If less than two minutes have passed since ``last_refresh_time``.
        """
        # Throttle: refuse to refresh more often than once every two minutes.
        if datetime.now() - self.last_refresh_time < timedelta(minutes=2):
            raise Exception("Refresh can only be called after two minutes since the last refresh.")
        # Capture the necessary parameters for reinitialization
        init_params = {
            'session': self.session,
            'experiment_id': self.experiment_id,
            'experiment_name': self.experiment_name,
        }
        # Reinitialize the instance with the captured parameters
        self.__init__(**init_params)
        # Update the last refresh time
        self.last_refresh_time = datetime.now()

    def get_details(self):
        """
        Retrieve details of the experiment by name or by ID.

        The experiment name is used when it is set; otherwise the experiment
        ID is used. An exception raised by the lookup is printed and reported
        through the returned tuple instead of propagating.

        Returns
        -------
        tuple
            ``(response, error, message)``. ``response`` is ``None`` and
            ``error`` is set when the lookup raised an exception.

        Raises
        ------
        ValueError
            If neither 'experiment_id' nor 'experiment_name' is set.

        Example
        -------
        >>> resp, err, msg = experiment.get_details()
        >>> if err:
        ...     print("Failed to retrieve experiment details.")
        ... else:
        ...     print("Experiment Details:", resp)
        """
        if self.experiment_name:
            fetch, label = self._get_experiment_by_name, "name"
        elif self.experiment_id:
            fetch, label = self._get_experiment_by_id, "id"
        else:
            raise ValueError(
                "At least one of 'experiment_id' or 'experiment_name' must be provided."
            )
        try:
            return fetch()
        except Exception as e:
            print(f"Error retrieving experiment by {label}: {e}")
            # Always hand back a 3-tuple so callers can unpack the result.
            return None, str(e), f"Error retrieving experiment by {label}"

    def _get_experiment_by_name(self):
        """
        Retrieve details of the experiment based on the experiment name.

        Returns
        -------
        tuple
            ``(response, error, message)`` as produced by ``handle_response``,
            or an error tuple with a ``None`` response when no experiment name
            is set (no request is sent in that case).
        """
        if self.experiment_name == "":
            reason = (
                "Experiment name not set for this experiment. Cannot perform the operation for experiment without experiment name"
            )
            print(reason)
            return None, reason, "Could not fetch experiment details"
        path = f"/v1/model/get_experiment_by_name?experimentName={self.experiment_name}"
        resp = self.rpc.get(path=path)
        return handle_response(
            resp,
            "Experiment Details Fetched successfully",
            "Could not fetch experiment details",
        )

    def _get_experiment_by_id(self):
        """
        Retrieve details of the experiment based on the experiment ID.

        Returns
        -------
        tuple
            ``(response, error, message)`` as produced by ``handle_response``,
            or an error tuple with a ``None`` response when no experiment ID
            is set (no request is sent in that case).
        """
        if self.experiment_id == "":
            reason = (
                "Experiment id not set for this experiment. Cannot perform the operation for experiment without experiment id"
            )
            print(reason)
            return None, reason, "Could not fetch experiment details"
        path = f"/v1/model/get_experiment/{self.experiment_id}"
        resp = self.rpc.get(path=path)
        return handle_response(
            resp,
            "Experiment Details Fetched successfully",
            "Could not fetch experiment details",
        )

    def add_models_for_training(self, models, model_configs, compute_alias=""):
        """
        Add models to the training queue for the experiment.

        Builds one payload per ``(model, model_config)`` pair, appends them to
        ``self.models_for_training``, posts the whole list to the backend and
        then clears the list.

        Parameters
        ----------
        models : ModelArch or list of ModelArch
            A single model instance or a list/tuple of model instances. Each
            must expose ``model_key``, ``model_info_id``, ``model_name`` and
            ``params_millions``.
        model_configs : dict or list of dict
            A single configuration or a list/tuple with one configuration per
            model. Each dictionary must include:
            - is_autoML
            - tuning_type
            - model_checkpoint
            - checkpoint_type
            - action_config
            - model_config
        compute_alias : str, optional
            Alias of the compute resource to use for training (default: "").

        Returns
        -------
        tuple
            ``(response, error, message)``. When the experiment details were
            never loaded, no request is sent and an error tuple is returned.

        Raises
        ------
        ValueError
            If the number of models and the number of configs differ.

        Example
        -------
        >>> model = ModelArch(session, model_key="resnet50")
        >>> config = {
        ...     "is_autoML": True,
        ...     "tuning_type": "auto",
        ...     "model_checkpoint": "predefined",
        ...     "checkpoint_type": "auto",
        ...     "action_config": {},
        ...     "model_config": {}
        ... }
        >>> resp, err, msg = experiment.add_models_for_training(model, config, "GPU-A100")
        >>> if err:
        ...     print(f"Error: {err}")
        ... else:
        ...     print(f"Success: {msg}")
        """
        models = list(models) if isinstance(models, (list, tuple)) else [models]
        if isinstance(model_configs, (list, tuple)):
            model_configs = list(model_configs)
        else:
            model_configs = [model_configs]

        # zip() would silently drop the unmatched tail; fail loudly instead.
        if len(models) != len(model_configs):
            raise ValueError(
                f"Got {len(models)} model(s) but {len(model_configs)} model config(s); "
                "provide exactly one config per model."
            )

        # The payload depends on fields copied from the experiment details.
        if getattr(self, "experiment_data", None) is None:
            reason = "Experiment details are not loaded; cannot add models for training"
            return None, reason, "Could not start training"

        for model, model_config in zip(models, model_configs):
            payload = {
                "modelKey": model.model_key,
                "autoML": model_config["is_autoML"],
                "tuningType": model_config["tuning_type"],
                "modelCheckpoint": model_config["model_checkpoint"],
                "checkpointType": model_config["checkpoint_type"],
                "_idModelInfo": model.model_info_id,
                "actionConfig": model_config["action_config"],
                "modelConfig": model_config["model_config"],
                "modelName": model.model_name,
                "paramsMillions": model.params_millions,
                "experimentName": self.experiment_name,
                "modelInputs": self.model_inputs,
                "modelOutputs": self.model_outputs,
                "targetRuntime": self.target_runtime,
                "_idDataset": self.dataset_id,
                "datasetVersion": self.dataset_version,
                "_idExperiment": self.experiment_id,
                "primaryMetric": self.primary_metric,
                "datasetName": self.dataset_name,
                "computeAlias": compute_alias,
            }
            self.models_for_training.append(payload)

        path = f"/v1/model/add_model_train_list?projectId={self.project_id}"
        headers = {"Content-Type": "application/json"}
        resp = self.rpc.post(
            path=path, headers=headers, payload=self.models_for_training
        )
        self.models_for_training = []
        return handle_response(resp, "Training started successfully", "Could not start training")

    def stop_training(self):
        """
        Stop the training process for the experiment.

        Sends a DELETE request to the ``restrict_experiment`` endpoint for
        this experiment's ID.

        Returns
        -------
        tuple
            ``(response, error, message)`` as produced by ``handle_response``.

        Example
        -------
        >>> resp, err, msg = experiment.stop_training()
        >>> if err:
        ...     print(f"Error: {err}")
        ... else:
        ...     print(f"Training stopped: {resp}")
        """
        path = f"/v1/model/restrict_experiment/{self.experiment_id}"
        resp = self.rpc.delete(path=path)
        return handle_response(
            resp,
            "Experiment restricted successfully",
            "Could not restricte the experiment",
        )

    def list_models(self):
        """
        Fetch and list all models associated with the current experiment.

        Also records the model with the highest ``testScore`` among the
        returned items. Items without a test score are listed but are never
        selected as the best model.

        Returns
        -------
        tuple
            A tuple containing five elements:
            - models (list): Model instances, one per returned item
            - status_list (list): the ``status`` value of each item
            - response (dict): Raw API response
            - error_message (str or None): set only on failure
            - status_message (str): success or failure message

        Notes
        -----
        On success this method updates two instance variables:
        - self.best_model: the Model with the highest test score, or None
        - self.best_model_test_score: that score, or None

        Example
        -------
        >>> models, status_list, response, err, msg = experiment.list_models()
        >>> if err:
        ...     print(f"Error: {err}")
        ... else:
        ...     for model, status in zip(models, status_list):
        ...         print(f"Status: {status}")
        """
        path = f"/v1/model/get_models_by_experiment_id/{self.experiment_id}"
        resp = self.rpc.get(path=path)
        if not (resp and resp.get("success", False) and "data" in resp):
            return [], [], resp, "Could not fetch experiment models", "Error fetching experiment models"

        models = []
        status_list = []
        # Must be initialised before the loop: they are read on the first
        # comparison and after the loop even when there are no items.
        best_model = None
        best_test_score = None
        for item in (resp["data"] or {}).get("items") or []:
            # NOTE(review): Model is built from the name only; confirm its
            # constructor does not also require a session.
            model = Model(model_name=item.get("modelName"))
            models.append(model)
            status_list.append(item.get("status"))
            test_score = item.get("testScore")
            if test_score is None:
                continue
            if best_test_score is None or test_score > best_test_score:
                best_test_score = test_score
                best_model = model
        self.best_model = best_model
        self.best_model_test_score = best_test_score
        return models, status_list, resp, None, "Experiment models fetched successfully"

    def get_best_model(self):
        """
        Return the best model recorded by the most recent ``list_models`` call.

        ``list_models`` must be called first to populate the best model data.

        Returns
        -------
        tuple
            A tuple containing four elements:
            - best_model (Model or None)
            - best_model_test_score (float or None)
            - error_message (str or None): set when no best model is recorded
            - status_message (str): success or failure message

        Example
        -------
        >>> best_model, test_score, err, msg = experiment.get_best_model()
        >>> if err:
        ...     print(f"Error: {err}")
        ... else:
        ...     print(f"Score: {test_score}")
        """
        if self.best_model is not None:
            return self.best_model, self.best_model_test_score, None, "Best model fetched successfully"
        return None, None, "No best model found", "Error fetching best model"