checkpoint

Save and load PyTorch checkpoints with ease.

torchutils.checkpoint.save_checkpoint(epoch, model_path, model, optimizer=None, scheduler=None, metric=0)

Save a checkpoint of the model and, optionally, the optimizer and scheduler state.

Parameters
  • epoch (int) – Epoch/iteration number.

  • model_path (str) – Directory in which to save the checkpoint; the file name is generated automatically.

  • model (nn.Module) – PyTorch model.

  • optimizer (optim.Optimizer) – PyTorch optimizer. (default: None)

  • scheduler (optim.lr_scheduler._LRScheduler) – PyTorch scheduler. (default: None)

  • metric (float) – Metric to add to checkpoint name, for example, validation accuracy. (default: 0)

Returns

Returns nothing.

Return type

None
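
The file name is generated by the function itself; judging from the example below, it appears to follow the pattern model_<timestamp>_e<epoch>_<metric>.pt.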

Example:

import torchvision
import torchutils as tu
import torch.optim as optim

model = torchvision.models.alexnet()
optimizer = optim.Adam(model.parameters())
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, 0.1)

# change the optimizer learning rate, just to verify that load_checkpoint restores it
optimizer = tu.set_lr(optimizer, 0.1234)

# checkpoint saved as model_20190814-212442_e0_0.7531.pt
tu.save_checkpoint(epoch=0, model_path='.', model=model,
                   optimizer=optimizer, scheduler=scheduler,
                   metric=0.7531)
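
The saved file is an ordinary PyTorch serialization (the .pt suffix suggests torch.save under the hood), so it can be inspected directly. A minimal sketch, assuming the file name produced above:

import torch

# load the raw checkpoint written by save_checkpoint above
ckpt = torch.load('model_20190814-212442_e0_0.7531.pt', map_location='cpu')
print(type(ckpt))
if isinstance(ckpt, dict):
    print(list(ckpt.keys()))
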
torchutils.checkpoint.load_checkpoint(model_path, ckpt_name, model, optimizer=None, scheduler=None, device=None)

Load a checkpoint and restore the model state and, if given, the optimizer and scheduler states.

Parameters
  • model_path (str) – Directory from which to load the checkpoint.

  • ckpt_name (str) – Checkpoint file name.

  • model (nn.Module) – PyTorch model.

  • optimizer (optim.Optimizer) – PyTorch optimizer. (default: None)

  • scheduler (optim.lr_scheduler._LRScheduler) – PyTorch scheduler. (default: None)

  • device (str) – Device to map the checkpoint, “cpu” or “cuda”. (default: None)

Returns

Start epoch/iteration number for resuming training (the saved epoch + 1, as the example output below shows).

Return type

int
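
If a checkpoint written on a GPU machine is loaded on a CPU-only one, the device parameter remaps the tensors. A minimal sketch, using only the parameters documented above and the file name from the save_checkpoint example:

import torchvision
import torchutils as tu

model = torchvision.models.alexnet()

# remap the checkpoint onto the CPU regardless of where it was saved
start_epoch = tu.load_checkpoint(model_path='.',
                                 ckpt_name='model_20190814-212442_e0_0.7531.pt',
                                 model=model, device='cpu')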

Example:

import torchvision
import torchutils as tu
import torch.optim as optim

model = torchvision.models.alexnet()
optimizer = optim.Adam(model.parameters())
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, 0.1)

print('Original learning rate:', tu.get_lr(optimizer))

# load checkpoint model_20190814-212442_e0_0.7531.pt
start_epoch = tu.load_checkpoint(model_path='.',
                                 ckpt_name='model_20190814-212442_e0_0.7531.pt',
                                 model=model, optimizer=optimizer,
                                 scheduler=scheduler)

print('Checkpoint learning rate:', tu.get_lr(optimizer))
print('Start from epoch:', start_epoch)

Out:

Original learning rate: 0.001
Checkpoint learning rate: 0.1234
Start from epoch: 1
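
A typical use of the returned value is to resume training where it stopped, continuing from the example above. A minimal sketch; train_one_epoch, validate, and num_epochs are hypothetical placeholders, not part of torchutils:

for epoch in range(start_epoch, num_epochs):
    train_one_epoch(model, optimizer)  # hypothetical training step
    val_acc = validate(model)          # hypothetical validation metric
    scheduler.step()
    # tag each checkpoint with the epoch and the validation accuracy
    tu.save_checkpoint(epoch=epoch, model_path='.', model=model,
                       optimizer=optimizer, scheduler=scheduler,
                       metric=val_acc)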