visualize

General visualization APIs for PyTorch.

torchutils.visualize.plot_gradients(model, file_path, include_bias=False, plot_max=False, plot_type='line', ylim=(-1.0, -1.0))

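Plot the average gradient of each layer in the model and save the figure to file_path.

Useful for debugging the vanishing gradient problem. Call this function after loss.backward() and before optimizer.step(), so that the gradients from the current backward pass are available.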
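To make "average gradient per layer" concrete, here is a minimal sketch of how such statistics could be computed before plotting. It assumes the common gradient-flow recipe (mean and max of the absolute gradient values, with bias parameters recognised by name). `layer_gradient_stats` is a hypothetical helper, not part of torchutils, and it takes plain `(name, values)` pairs so that it runs without PyTorch.

```python
from typing import Iterable, List, Sequence, Tuple


def layer_gradient_stats(
    named_grads: Iterable[Tuple[str, Sequence[float]]],
    include_bias: bool = False,
) -> List[Tuple[str, float, float]]:
    """Return (name, mean |grad|, max |grad|) for each parameter.

    Hypothetical helper: `named_grads` mimics model.named_parameters(),
    but each entry holds a flat sequence of gradient values, not a tensor.
    """
    stats = []
    for name, grads in named_grads:
        # Assumption: bias parameters are identified by their name suffix.
        if not include_bias and name.endswith("bias"):
            continue
        # Skip parameters that received no gradient values.
        if not grads:
            continue
        abs_grads = [abs(g) for g in grads]
        mean_abs = sum(abs_grads) / len(abs_grads)
        stats.append((name, mean_abs, max(abs_grads)))
    return stats
```

With PyTorch, the input pairs could be built as `((n, p.grad.flatten().tolist()) for n, p in model.named_parameters() if p.grad is not None)`; the per-layer means (and maxima, for plot_max=True) are then what a line or bar chart would display.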

Parameters
  • model (nn.Module) – PyTorch model.

  • file_path (str) – Path, including the file name, where the plot is saved.

  • include_bias (bool) – Whether to include bias gradients in the plot. (default: False)

  • plot_max (bool) – Whether to also plot the maximum gradient of each layer. (default: False)

  • plot_type (str) – Type of plot. Must be one of ('line', 'bar'). (default: 'line')

  • ylim (tuple) – Y-axis limits (gradient values) of the plot, given as a tuple (low, high). Useful for zooming into regions of small gradients. If low or high is negative, the entire y-axis range is plotted. (default: (-1.0, -1.0))
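The plot_type and ylim rules above can be summarised as a small validation step. The sketch below is hypothetical (`resolve_plot_args` and `VALID_PLOT_TYPES` are illustrative names, not torchutils code) and assumes that a negative value in either slot of ylim means "use the full y-axis range".

```python
from typing import Optional, Tuple

VALID_PLOT_TYPES = ("line", "bar")


def resolve_plot_args(
    plot_type: str = "line",
    ylim: Tuple[float, float] = (-1.0, -1.0),
) -> Optional[Tuple[float, float]]:
    """Validate plot_type/ylim; return (low, high) limits, or None for the full range.

    Hypothetical helper sketching the documented semantics.
    """
    if plot_type not in VALID_PLOT_TYPES:
        raise ValueError(f"plot_type must be one of {VALID_PLOT_TYPES}, got {plot_type!r}")
    if not isinstance(ylim, tuple) or len(ylim) != 2:
        raise ValueError("ylim must be a tuple (low, high)")
    low, high = ylim
    # Assumption: a negative low or high disables clipping of the y-axis.
    if low < 0 or high < 0:
        return None
    return (low, high)
```

When limits are returned they could be applied with Matplotlib's `plt.ylim(low, high)`; returning None leaves Matplotlib's automatic scaling in place.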

Returns

Returns nothing.

Return type

None

Example:

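import os

import torch
import torchvision
import torchutils as tu

criterion = torch.nn.CrossEntropyLoss()
net = torchvision.models.alexnet(num_classes=10)

# Forward pass on a random 224x224 RGB image with a random target label.
out = net(torch.rand(1, 3, 224, 224))
ground_truth = torch.randint(0, 10, (1, ))
loss = criterion(out, ground_truth)

# Backward pass populates .grad for every parameter.
loss.backward()

# Ensure the output directory exists, then save the gradient plot.
os.makedirs('./grad_figures', exist_ok=True)
tu.plot_gradients(net, './grad_figures/grad_01.png', plot_type='line')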