models

Query PyTorch model information.

torchutils.models.get_model_param_count(model, trainable=None)

Count the total number of parameters in a PyTorch model.

Parameters
  • model (nn.Module) – PyTorch model.

  • trainable (None or bool) – If None, count all parameters; if True, count only trainable parameters; if False, count only non-trainable parameters. (default: None)

Returns

Number of parameters in the model.

Return type

int

Example:

import torchvision
import torchutils as tu

model = torchvision.models.alexnet()
total_params = tu.get_model_param_count(model)
print('Total model params: {:,}'.format(total_params))

Out:

Total model params: 61,100,840
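
The trainable flag makes it easy to see how many parameters are frozen. The sketch below assumes the documented None/True/False behavior: it freezes AlexNet's feature extractor and counts the trainable and non-trainable groups, which together should add up to the total shown above.

import torchvision
import torchutils as tu

model = torchvision.models.alexnet()

# freeze the convolutional feature extractor
for param in model.features.parameters():
    param.requires_grad = False

trainable_params = tu.get_model_param_count(model, trainable=True)
frozen_params = tu.get_model_param_count(model, trainable=False)

# trainable_params + frozen_params equals the total parameter count
print('Trainable params:     {:,}'.format(trainable_params))
print('Non-trainable params: {:,}'.format(frozen_params))
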
torchutils.models.get_model_flops(model, *input, unit='FLOP', **kwargs)

Count the total FLOPs of the PyTorch model for the given input(s).

Parameters
  • model (nn.Module) – PyTorch model.

  • input (user dependent) – Input(s) to the model, with shape [N, *]. The input dtype and device must match those of the model. Multiple inputs may be passed as separate, comma-separated arguments for multi-input models.

  • unit (str) – FLOPs unit. One of ‘FLOP’, ‘MFLOP’, or ‘GFLOP’. (default: ‘FLOP’)

  • **kwargs – Additional keyword arguments passed to the model's forward function.

Returns

Number of FLOPs in the specified unit.

Return type

float

Example:

import torch
import torchvision
import torchutils as tu

model = torchvision.models.alexnet()
total_flops = tu.get_model_flops(model, torch.rand((1, 3, 224, 224)))
print('Total model FLOPs: {:,}'.format(total_flops))

Out:

Total model FLOPs: 773,304,664
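
For readability, the same count can be requested in a larger unit. A minimal sketch, assuming the unit argument simply rescales the returned value (so AlexNet's roughly 773 million FLOPs comes back as about 0.77 when unit='GFLOP'):

import torch
import torchvision
import torchutils as tu

model = torchvision.models.alexnet()
# same model and input as above, but reported in GFLOPs
gflops = tu.get_model_flops(model, torch.rand((1, 3, 224, 224)), unit='GFLOP')
print('Total model GFLOPs: {:.2f}'.format(gflops))
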
torchutils.models.get_model_summary(model, *input, compact=False, **kwargs)

Print model summary.

Parameters
  • model (nn.Module) – PyTorch model.

  • input (user dependent) – Input(s) to the model, with shape [N, *]. The input dtype and device must match those of the model. Multiple inputs may be passed as separate, comma-separated arguments for multi-input models.

  • compact (bool) – If True, print a compact summary showing only layer names and output shapes. (default: False)

  • **kwargs – Additional keyword arguments passed to the model's forward function.

Returns

Returns nothing.

Return type

None

Example:

import torch
import torchvision
import torchutils as tu

model = torchvision.models.alexnet()
tu.get_model_summary(model, torch.rand((1, 3, 224, 224)), compact=True)

Out:

===========================================
Layer                                Output
===========================================
0_features.Conv2d_0         [1, 64, 55, 55]
1_features.ReLU_1           [1, 64, 55, 55]
2_features.MaxPool2d_2      [1, 64, 27, 27]
3_features.Conv2d_3        [1, 192, 27, 27]
4_features.ReLU_4          [1, 192, 27, 27]
5_features.MaxPool2d_5     [1, 192, 13, 13]
6_features.Conv2d_6        [1, 384, 13, 13]
7_features.ReLU_7          [1, 384, 13, 13]
8_features.Conv2d_8        [1, 256, 13, 13]
9_features.ReLU_9          [1, 256, 13, 13]
10_features.Conv2d_10      [1, 256, 13, 13]
11_features.ReLU_11        [1, 256, 13, 13]
12_features.MaxPool2d_12     [1, 256, 6, 6]
13_classifier.Dropout_0           [1, 9216]
14_classifier.Linear_1            [1, 4096]
15_classifier.ReLU_2              [1, 4096]
16_classifier.Dropout_3           [1, 4096]
17_classifier.Linear_4            [1, 4096]
18_classifier.ReLU_5              [1, 4096]
19_classifier.Linear_6            [1, 1000]
===========================================
Total params: 61,100,840
Trainable params: 61,100,840
Non-trainable params: 0
Total FLOPs: 773,286,232 / 773.29 MFLOPs
-------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 8.31
Params size (MB): 233.08
Estimated Total Size (MB): 241.96
===========================================
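
Because the input tensors and keyword arguments are forwarded to the model's forward function, multi-input models can be summarized by passing each input in order. A minimal sketch with a hypothetical two-input module (TwoInputNet is illustrative and not part of torchutils):

import torch
import torch.nn as nn
import torchutils as tu

class TwoInputNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.branch_a = nn.Linear(16, 8)
        self.branch_b = nn.Linear(32, 8)
        self.head = nn.Linear(16, 4)

    def forward(self, a, b):
        # concatenate both branch outputs before the final layer
        merged = torch.cat([self.branch_a(a), self.branch_b(b)], dim=1)
        return self.head(merged)

model = TwoInputNet()
# one tensor per forward argument, in the same order as forward(a, b)
tu.get_model_summary(model, torch.rand(1, 16), torch.rand(1, 32), compact=True)
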