checkpoint
Save and load PyTorch checkpoints with ease.
torchutils.checkpoint.save_checkpoint(epoch, model_path, model, optimizer=None, scheduler=None, metric=0)

Save checkpoint.
- Parameters
epoch (int) – Epoch/iteration number.
model_path (str) – Path for saving the model.
model (nn.Module) – PyTorch model.
optimizer (optim.Optimizer) – PyTorch optimizer. (default: None)
scheduler (optim.lr_scheduler._LRScheduler) – PyTorch scheduler. (default: None)
metric (float) – Metric to add to checkpoint name, for example, validation accuracy. (default: 0)
- Returns
Returns nothing.
- Return type
None
Example:
import torchvision
import torchutils as tu
import torch.optim as optim

model = torchvision.models.alexnet()
optimizer = optim.Adam(model.parameters())
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, 0.1)
# change optimizer lr, just for load_checkpoint test
optimizer = tu.set_lr(optimizer, 0.1234)
# checkpoint saved as model_20190814-212442_e0_0.7531.pt
tu.save_checkpoint(epoch=0,
                   model_path='.',
                   model=model,
                   optimizer=optimizer,
                   scheduler=scheduler,
                   metric=0.7531)
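The saved file name encodes a timestamp, the epoch, and the metric. torchutils builds this name internally; the helper below is only a hypothetical illustration of the pattern seen in the example above (model_20190814-212442_e0_0.7531.pt), written with the standard library:

```python
import time

def checkpoint_name(epoch, metric, prefix='model'):
    # Hypothetical helper, not part of torchutils: reproduces the
    # '<prefix>_<YYYYmmdd-HHMMSS>_e<epoch>_<metric>.pt' pattern.
    stamp = time.strftime('%Y%m%d-%H%M%S')
    return f'{prefix}_{stamp}_e{epoch}_{metric:.4f}.pt'

print(checkpoint_name(0, 0.7531))
# e.g. model_20190814-212442_e0_0.7531.pt (timestamp varies)
```

Because the timestamp is part of the name, repeated saves never overwrite each other, and the metric in the name lets you spot the best checkpoint at a glance.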
torchutils.checkpoint.load_checkpoint(model_path, ckpt_name, model, optimizer=None, scheduler=None, device=None)

Load checkpoint.
- Parameters
model_path (str) – Path for loading the model.
ckpt_name (str) – Checkpoint file name.
model (nn.Module) – PyTorch model.
optimizer (optim.Optimizer) – PyTorch optimizer. (default: None)
scheduler (optim.lr_scheduler._LRScheduler) – PyTorch scheduler. (default: None)
device (str) – Device to map the checkpoint, “cpu” or “cuda”. (default: None)
- Returns
Start epoch/iteration number to continue training.
- Return type
int
Example:
import torchvision
import torchutils as tu
import torch.optim as optim

model = torchvision.models.alexnet()
optimizer = optim.Adam(model.parameters())
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, 0.1)
print('Original learning rate:', tu.get_lr(optimizer))
# load checkpoint model_20190814-212442_e0_0.7531.pt
start_epoch = tu.load_checkpoint(model_path='.',
                                 ckpt_name='model_20190814-212442_e0_0.7531.pt',
                                 model=model,
                                 optimizer=optimizer,
                                 scheduler=scheduler)
print('Checkpoint learning rate:', tu.get_lr(optimizer))
print('Start from epoch:', start_epoch)
Out:
Original learning rate: 0.001
Checkpoint learning rate: 0.1234
Start from epoch: 1
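A checkpoint of this kind is essentially a dictionary of state dicts plus the epoch counter. torchutils' exact file layout is not documented here, but a minimal sketch of the round trip these two functions wrap, using plain torch.save and torch.load, looks like this:

```python
import io
import torch
import torch.nn as nn
import torch.optim as optim

# A tiny model/optimizer pair standing in for the real training objects.
model = nn.Linear(4, 2)
optimizer = optim.Adam(model.parameters(), lr=0.1234)

# Save: bundle the state dicts and the epoch counter into one object.
buffer = io.BytesIO()  # a real run would pass a file path instead
torch.save({'epoch': 0,
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict()}, buffer)

# Load: restore the states into fresh objects, then resume at epoch + 1.
buffer.seek(0)
state = torch.load(buffer)
model2 = nn.Linear(4, 2)
optimizer2 = optim.Adam(model2.parameters())  # default lr, overwritten below
model2.load_state_dict(state['model'])
optimizer2.load_state_dict(state['optimizer'])
start_epoch = state['epoch'] + 1  # continue training from the next epoch
```

Note that the learning rate travels inside the optimizer state dict, which is why the example output above shows the restored 0.1234 rather than Adam's default 0.001.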