integration.pytorch
load_model
comet_ml.integration.pytorch.load_model(MODEL_URI: str,
map_location: Any = None, pickle_module: Optional[Module] = None,
**torch_load_args) -> ModelStateDict
Loads a model's state_dict from an Experiment, from the Model Registry, or from disk, given a URI. It returns a PyTorch state_dict, which you then need to load into your model. The model is loaded with torch.load.
Here is an example of loading a model from the Model Registry for inference:
from comet_ml.integration.pytorch import load_model
import torch.nn as nn
class TheModelClass(nn.Module):
    def __init__(self):
        super(TheModelClass, self).__init__()
        ...
    def forward(self, x):
        ...
        return x
# Initialize model
model = TheModelClass()
# Load the model state dict from the Comet Model Registry
model.load_state_dict(load_model("registry://WORKSPACE/TheModel:1.2.4"))
model.eval()
prediction = model(...)
Here is an example of loading a model from an Experiment to resume training:
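from comet_ml.integration.pytorch import load_model
import torch.optim as optim
# Initialize model (TheModelClass as defined in the previous example)
model = TheModelClass()
# Example optimizer; create it the same way as when the checkpoint was saved
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
# Load the checkpoint dict from a Comet Experiment
checkpoint = load_model("experiment://e1098c4e1e764ff89881b868e4c70f5/TheModel")
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
model.train()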
Args:
- MODEL_URI: string (required), a URI string defining the model location. Possible options are:
  - file://data/my-model
  - file:///path/to/my-model
  - registry://workspace/registry_name (loads the latest version)
  - registry://workspace/registry_name:version
  - experiment://experiment_key/model_name
  - experiment://workspace/project_name/experiment_name/model_name
- map_location: (optional) passed to torch.load (see torch.load documentation)
- pickle_module: (optional) passed to torch.load (see torch.load documentation)
- torch_load_args: (optional) additional keyword arguments passed to torch.load (see torch.load documentation)
Returns: the model's state dict, or whichever object was logged (for example, a checkpoint dict as in the resume-training example above)
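To make the URI forms above concrete, the following sketch splits a URI into its components. parse_model_uri and ModelURI are hypothetical names written for illustration only; this is not how comet_ml resolves URIs internally, and it checks nothing beyond the shapes listed above.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ModelURI:
    """Components of a model URI (hypothetical structure for illustration)."""
    scheme: str
    parts: List[str]
    version: Optional[str] = None


def parse_model_uri(uri: str) -> ModelURI:
    """Split a model URI into its scheme and path components.

    Hypothetical helper mirroring the URI forms listed above; it is not the
    parser comet_ml uses.
    """
    scheme, sep, rest = uri.partition("://")
    if not sep:
        raise ValueError(f"missing '://' in {uri!r}")
    if scheme == "file":
        # file://data/my-model keeps "data/my-model"; file:///path/... keeps "/path/..."
        return ModelURI(scheme, [rest])
    if scheme == "registry":
        # registry://workspace/registry_name[:version]; no version means the latest one
        workspace, _, name = rest.partition("/")
        name, _, version = name.partition(":")
        return ModelURI(scheme, [workspace, name], version or None)
    if scheme == "experiment":
        # experiment://experiment_key/model_name or
        # experiment://workspace/project_name/experiment_name/model_name
        parts = rest.split("/")
        if len(parts) not in (2, 4):
            raise ValueError(f"unexpected experiment URI {uri!r}")
        return ModelURI(scheme, parts)
    raise ValueError(f"unsupported scheme {scheme!r}")
```

For example, parse_model_uri("registry://WORKSPACE/TheModel:1.2.4") yields scheme "registry", parts ["WORKSPACE", "TheModel"] and version "1.2.4".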
log_model
comet_ml.integration.pytorch.log_model(experiment, model, model_name,
metadata=None, pickle_module=None, **torch_save_args)
Logs a PyTorch model to an Experiment. The model is saved with torch.save and stored as an Experiment Model.
The model parameter can be either an instance of torch.nn.Module or any object supported by torch.save; see the PyTorch tutorial on saving and loading models for more details.
Here is an example of logging a model for inference:
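import comet_ml
import torch.nn as nn
from comet_ml.integration.pytorch import log_model
# Create an Experiment (assumes your Comet API key is already configured)
experiment = comet_ml.Experiment()
class TheModelClass(nn.Module):
    def __init__(self):
        super(TheModelClass, self).__init__()
        ...
    def forward(self, x):
        ...
        return x
# Initialize model
model = TheModelClass()
# Train model (train() stands for your own training loop)
train(model)
# Save the model for inference
log_model(experiment, model, model_name="TheModel")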
Here is an example of logging a checkpoint to resume training later:
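# epoch, loss, model and optimizer come from your training loop
model_checkpoint = {
    "epoch": epoch,
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "loss": loss,
    # ... any other values you want to restore later
}
log_model(experiment, model_checkpoint, model_name="TheModel")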
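log_model stores the object with torch.save and load_model reads it back with torch.load, so a checkpoint dict like the one above comes back from load_model as the same dict. The sketch below performs that round trip locally with an in-memory buffer instead of Comet; the small nn.Linear model, the SGD settings, and the epoch and loss values are placeholders.

```python
import io

import torch
import torch.nn as nn
import torch.optim as optim

# Placeholder model and optimizer standing in for TheModelClass and your own optimizer
model = nn.Linear(4, 2)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

model_checkpoint = {
    "epoch": 5,
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "loss": 0.25,
}

# Serialize with torch.save, the same function log_model uses
buffer = io.BytesIO()
torch.save(model_checkpoint, buffer)

# Deserialize with torch.load; map_location="cpu" remaps any GPU storages to the CPU
buffer.seek(0)
checkpoint = torch.load(buffer, map_location="cpu")

# Restore into fresh instances, mirroring the resume-training example
restored = nn.Linear(4, 2)
restored.load_state_dict(checkpoint["model_state_dict"])
restored_optimizer = optim.SGD(restored.parameters(), lr=0.001, momentum=0.9)
restored_optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
```

Keeping only primitive values and state dicts in the checkpoint also keeps it loadable by newer torch.load versions that default to weights_only=True.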
Args:
- experiment: Experiment (required), the Experiment instance to log the model to
- model: model's state dict or torch.nn.Module (required), the model to log
- model_name: string (required), the name of the model
- metadata: dict (optional), additional data to attach to the model. Must be a JSON-encodable dict
- pickle_module: (optional) passed to torch.save (see torch.save documentation)
- torch_save_args: (optional) additional keyword arguments passed to torch.save (see torch.save documentation)
Returns: None
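Because metadata must be JSON-encodable, it can help to validate it before calling log_model. check_metadata below is a hypothetical helper, not part of comet_ml; it simply runs json.dumps with allow_nan=False so that sets, arbitrary objects, and NaN or infinite floats are rejected up front.

```python
import json
from typing import Any, Dict


def check_metadata(metadata: Dict[str, Any]) -> Dict[str, Any]:
    """Return metadata unchanged if it is a JSON-encodable dict, else raise TypeError.

    Hypothetical helper for illustration; it is not part of comet_ml.
    """
    if not isinstance(metadata, dict):
        raise TypeError("metadata must be a dict")
    try:
        # allow_nan=False also rejects NaN/Infinity, which are not valid JSON
        json.dumps(metadata, allow_nan=False)
    except (TypeError, ValueError) as exc:
        raise TypeError(f"metadata is not JSON-encodable: {exc}") from exc
    return metadata
```

For example, you could pass metadata=check_metadata({"lr": 0.001, "epochs": 10}) to log_model.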
watch
comet_ml.integration.pytorch.watch(model: torch.nn.Module,
log_step_interval: int = 1000) -> None
Enables automatic logging of each layer's parameters and gradients in the given PyTorch module. These will be logged as histograms. Note that an Experiment must be created before calling this function.
Args:
- model: torch.nn.Module, an instance of torch.nn.Module
- log_step_interval: int (optional), determines how often layers are logged (default is every 1000 steps).
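As a rough picture of what watch records, the sketch below computes a histogram of each parameter and, once gradients exist, of each gradient, but only on steps that are multiples of log_step_interval. It is a plain-PyTorch stand-in, not comet_ml's implementation: the name collect_histograms, the bin count, and the step % log_step_interval == 0 rule are assumptions, and nothing is sent to Comet.

```python
import torch
import torch.nn as nn


def collect_histograms(model: nn.Module, step: int, log_step_interval: int = 1000, bins: int = 10) -> dict:
    """Hypothetical stand-in for what watch() logs; not comet_ml code.

    Returns {"<param>/weights": counts, "<param>/gradients": counts} on logged
    steps and an empty dict on all other steps.
    """
    if step % log_step_interval != 0:
        return {}
    histograms = {}
    for name, param in model.named_parameters():
        # torch.histc with the default min=max=0 uses the tensor's own min and max as the range
        histograms[f"{name}/weights"] = torch.histc(param.detach(), bins=bins)
        if param.grad is not None:
            histograms[f"{name}/gradients"] = torch.histc(param.grad.detach(), bins=bins)
    return histograms


# Tiny training loop that records on steps 0, 2 and 4
model = nn.Linear(3, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
logged_steps = []
for step in range(5):
    loss = model(torch.randn(4, 3)).pow(2).sum()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if collect_histograms(model, step, log_step_interval=2):
        logged_steps.append(step)
```

A larger interval trades histogram resolution over time for less logging overhead, which is presumably why the documented default is every 1000 steps.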