visualize
General visualization APIs for PyTorch.
torchutils.visualize.plot_gradients(model, file_path, include_bias=False, plot_max=False, plot_type='line', ylim=(-1.0, -1.0))
Plot (average) gradients for each layer in the model.
Useful for debugging the vanishing gradient problem. Call this function after loss.backward() and before optimizer.step(), when the gradients are populated and the weights have not yet been updated.
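To make "average gradients per layer" concrete, here is a minimal plain-PyTorch sketch of the kind of per-layer statistic such a plot displays. It assumes the average is the mean absolute gradient of each parameter tensor and that bias tensors can be recognized by names ending in "bias"; torchutils' actual computation may differ, and `layer_grad_stats` is a hypothetical helper, not part of torchutils.

```python
import torch
import torch.nn as nn


def layer_grad_stats(model: nn.Module, include_bias: bool = False):
    """Hypothetical helper: mean and max absolute gradient per parameter tensor.

    Assumes loss.backward() has already populated each parameter's .grad.
    """
    names, means, maxes = [], [], []
    for name, param in model.named_parameters():
        # Skip frozen parameters and parameters that received no gradient.
        if not param.requires_grad or param.grad is None:
            continue
        # Assumption: bias tensors are the ones whose names end in "bias".
        if not include_bias and name.endswith("bias"):
            continue
        grad = param.grad.detach().abs()
        names.append(name)
        means.append(grad.mean().item())  # "average" gradient of this layer
        maxes.append(grad.max().item())   # what a plot_max-style curve would show
    return names, means, maxes


# Usage: a tiny model, one backward pass, then print the statistics.
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
loss = nn.functional.cross_entropy(model(torch.rand(3, 4)), torch.tensor([0, 1, 0]))
loss.backward()
names, means, maxes = layer_grad_stats(model)
for n, m, mx in zip(names, means, maxes):
    print(f"{n}: mean={m:.2e} max={mx:.2e}")
```

Layers whose mean values are orders of magnitude smaller than those of later layers are the usual sign of vanishing gradients.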
Parameters
model (nn.Module) – PyTorch model.
file_path (str) – File path (including file name) where the plot is saved.
include_bias (bool) – Whether to include bias gradients in the plot. (default: False)
plot_max (bool) – Also plot the maximum gradients. (default: False)
plot_type (str) – Type of plot. Must be one of ('line', 'bar'). (default: 'line')
ylim (tuple) – Limits of the y-axis (gradient values) of the plot. Useful for zooming into regions of small gradients. Must be a tuple (low, high). A negative low/high value plots the entire y-range. (default: (-1.0, -1.0))
Returns
Nothing.
Return type
None
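Example:

    import os

    import torch
    import torchvision
    import torchutils as tu

    criterion = torch.nn.CrossEntropyLoss()
    net = torchvision.models.alexnet(num_classes=10)

    # Forward and backward pass on one random 224x224 RGB image.
    out = net(torch.rand(1, 3, 224, 224))
    ground_truth = torch.randint(0, 10, (1, ))
    loss = criterion(out, ground_truth)
    loss.backward()

    # Make sure the output directory exists before saving the figure.
    os.makedirs('./grad_figures', exist_ok=True)
    tu.plot_gradients(net, './grad_figures/grad_01.png', plot_type='line')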
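The same call order carries over to a full training loop. The sketch below, assuming a plain SGD loop, marks the one point where gradients are populated and the weights are still unchanged; `train`, `inspect_grads`, and `every` are hypothetical names used only for illustration, not part of torchutils.

```python
import torch
import torch.nn as nn


def train(model, batches, inspect_grads=None, every=1, lr=0.1):
    """Hypothetical training loop showing where a gradient check belongs."""
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    for step, (x, y) in enumerate(batches):
        optimizer.zero_grad()          # clear gradients left from the previous step
        loss = criterion(model(x), y)
        loss.backward()                # .grad is populated from here on
        if inspect_grads is not None and step % every == 0:
            # After backward() and before step(): gradients exist, and the
            # weights still match the ones that produced them.
            inspect_grads(step, model)
        optimizer.step()               # update the weights
    return model


# Usage: record which steps were inspected and whether gradients were present.
torch.manual_seed(0)
net = nn.Linear(4, 3)
batches = [(torch.rand(2, 4), torch.tensor([0, 2])) for _ in range(4)]
seen = []
train(
    net,
    batches,
    inspect_grads=lambda step, m: seen.append((step, m.weight.grad is not None)),
    every=2,
)
print(seen)  # [(0, True), (2, True)]
```

With torchutils installed and the output directory created, passing `inspect_grads=lambda step, net: tu.plot_gradients(net, f'./grad_figures/grad_{step:02d}.png')` would write one figure per inspected step. Inspecting only every few steps keeps the number of saved figures manageable.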