models
Query PyTorch model information.
- torchutils.models.get_model_param_count(model, trainable=None)
Count the total parameters in a PyTorch model.
- Parameters
model (nn.Module) – PyTorch model.
trainable (None or bool) – Which parameters to count: None for all parameters, True for trainable parameters only, or False for non-trainable parameters only. (default: None)
- Returns
Number of parameters in the model.
- Return type
int
Example:
import torchvision
import torchutils as tu

model = torchvision.models.alexnet()
total_params = tu.get_model_param_count(model)
print('Total model params: {:,}'.format(total_params))
Out:
Total model params: 61,100,840
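The `trainable` switch can be understood as a filter on each parameter's `requires_grad` flag. The following is a minimal sketch under that assumption, using a hypothetical `count_params` helper; it is not the torchutils implementation, only an illustration of the documented None/True/False semantics on a small model with one frozen weight matrix.

```python
import torch.nn as nn


def count_params(model: nn.Module, trainable=None) -> int:
    """Hypothetical helper mirroring the documented `trainable` semantics."""
    total = 0
    for p in model.parameters():
        # None counts every parameter; True/False keep only parameters whose
        # requires_grad flag matches the requested value.
        if trainable is None or p.requires_grad == trainable:
            total += p.numel()
    return total


# Linear(10, 5) has 55 parameters, Linear(5, 2) has 12; freeze the 10-element weight.
model = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 2))
model[2].weight.requires_grad_(False)
print(count_params(model), count_params(model, True), count_params(model, False))
# prints: 67 57 10
```

Note that the trainable and non-trainable counts always add up to the total, since every parameter either requires gradients or does not.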
- torchutils.models.get_model_flops(model, *input, unit='FLOP', **kwargs)
Count the total FLOPs of a PyTorch model for the given input(s).
- Parameters
model (nn.Module) – PyTorch model.
input (user dependent) – Input(s) for the model, with shape [N, *]. The input dtype and device must match the model's. For multi-input models, pass several comma-separated inputs.
unit (str) – Unit of the result: ‘FLOP’, ‘MFLOP’, or ‘GFLOP’. (default: ‘FLOP’)
**kwargs – Other keyword arguments passed to model.forward.
- Returns
Number of FLOPs, expressed in the requested unit.
- Return type
float
Example:
import torch
import torchvision
import torchutils as tu

model = torchvision.models.alexnet()
total_flops = tu.get_model_flops(model, torch.rand((1, 3, 224, 224)))
print('Total model FLOPs: {:,}'.format(total_flops))
Out:
Total model FLOPs: 773,304,664
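The documentation does not say which operations are counted. As a rough illustration only, here is a hypothetical `count_flops` that runs one forward pass with hooks attached and counts 2 FLOPs per multiply-accumulate in `nn.Conv2d` and `nn.Linear` layers, ignoring biases, activations, and pooling. That counting convention is an assumption and will not reproduce the AlexNet figure above. The `unit` scaling assumes MFLOP = 10^6 and GFLOP = 10^9 FLOPs, which is consistent with the "773,286,232 / 773.29 MFLOPs" line in the summary output further down this page.

```python
import torch
import torch.nn as nn

# Assumed decimal prefixes for the unit argument.
UNIT_SCALE = {"FLOP": 1.0, "MFLOP": 1e6, "GFLOP": 1e9}


def count_flops(model: nn.Module, *inputs, unit: str = "FLOP", **kwargs) -> float:
    """Hypothetical counter: 2 FLOPs per multiply-accumulate in Conv2d/Linear only."""
    total = 0

    def conv_hook(module, inp, out):
        nonlocal total
        # Each output element needs (in_channels / groups) * kh * kw MACs.
        kh, kw = module.kernel_size
        macs_per_output = (module.in_channels // module.groups) * kh * kw
        total += 2 * macs_per_output * out.numel()

    def linear_hook(module, inp, out):
        nonlocal total
        # Each output element needs in_features MACs.
        total += 2 * module.in_features * out.numel()

    handles = []
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            handles.append(m.register_forward_hook(conv_hook))
        elif isinstance(m, nn.Linear):
            handles.append(m.register_forward_hook(linear_hook))
    try:
        with torch.no_grad():
            model(*inputs, **kwargs)
    finally:
        # Always detach the hooks so the model is left unchanged.
        for h in handles:
            h.remove()
    return total / UNIT_SCALE[unit]


# Conv2d(3, 8, 3) on 8x8 input: 288 outputs * 27 MACs * 2 = 15,552 FLOPs.
# Linear(288, 10): 10 outputs * 288 MACs * 2 = 5,760 FLOPs.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Flatten(), nn.Linear(8 * 6 * 6, 10))
print(count_flops(model, torch.rand(1, 3, 8, 8)))  # prints: 21312.0
```

Real FLOP counters differ mainly in which layers they include and whether one multiply-accumulate counts as one or two FLOPs, so figures from different tools are rarely directly comparable.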
- torchutils.models.get_model_summary(model, *input, compact=False, **kwargs)
Print a summary of the model.
- Parameters
model (nn.Module) – PyTorch model.
input (user dependent) – Input(s) for the model, with shape [N, *]. The input dtype and device must match the model's. For multi-input models, pass several comma-separated inputs.
compact (bool) – If True, print a compact summary that lists only each layer and its output shape. (default: False)
**kwargs – Other keyword arguments passed to model.forward.
- Returns
Nothing; the summary is printed.
- Return type
None
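A layer-by-output-shape listing of this kind can be collected with forward hooks. The sketch below, a hypothetical `compact_summary`, registers a hook on every leaf module, records its name, class, and output shape during one forward pass, and prints them. It assumes each leaf module returns a single tensor and does not try to reproduce torchutils' exact layer naming, parameter totals, FLOPs, or size estimates.

```python
import torch
import torch.nn as nn


def compact_summary(model: nn.Module, *inputs, **kwargs) -> list:
    """Hypothetical sketch: record (name, class, output shape) for each leaf module."""
    rows = []
    handles = []

    def make_hook(name):
        def hook(module, inp, out):
            # Assumes the module returns a single tensor.
            rows.append((name, type(module).__name__, list(out.shape)))
        return hook

    for name, module in model.named_modules():
        if next(module.children(), None) is None:  # leaf modules only
            handles.append(module.register_forward_hook(make_hook(name)))
    try:
        with torch.no_grad():
            model(*inputs, **kwargs)
    finally:
        for h in handles:
            h.remove()

    print(f"{'Layer':<32}Output")
    for i, (name, cls, shape) in enumerate(rows):
        label = f"{i}_{name}.{cls}"
        print(f"{label:<32}{shape}")
    return rows


model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))
rows = compact_summary(model, torch.rand(1, 4))
```

Because hooks fire in execution order, the rows follow the order in which layers actually run, which for a plain `nn.Sequential` matches the definition order.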
Example:
import torch
import torchvision
import torchutils as tu

model = torchvision.models.alexnet()
tu.get_model_summary(model, torch.rand((1, 3, 224, 224)), compact=True)
Out:
===========================================
Layer                      Output
===========================================
0_features.Conv2d_0        [1, 64, 55, 55]
1_features.ReLU_1          [1, 64, 55, 55]
2_features.MaxPool2d_2     [1, 64, 27, 27]
3_features.Conv2d_3        [1, 192, 27, 27]
4_features.ReLU_4          [1, 192, 27, 27]
5_features.MaxPool2d_5     [1, 192, 13, 13]
6_features.Conv2d_6        [1, 384, 13, 13]
7_features.ReLU_7          [1, 384, 13, 13]
8_features.Conv2d_8        [1, 256, 13, 13]
9_features.ReLU_9          [1, 256, 13, 13]
10_features.Conv2d_10      [1, 256, 13, 13]
11_features.ReLU_11        [1, 256, 13, 13]
12_features.MaxPool2d_12   [1, 256, 6, 6]
13_classifier.Dropout_0    [1, 9216]
14_classifier.Linear_1     [1, 4096]
15_classifier.ReLU_2       [1, 4096]
16_classifier.Dropout_3    [1, 4096]
17_classifier.Linear_4     [1, 4096]
18_classifier.ReLU_5       [1, 4096]
19_classifier.Linear_6     [1, 1000]
===========================================
Total params: 61,100,840
Trainable params: 61,100,840
Non-trainable params: 0
Total FLOPs: 773,286,232 / 773.29 MFLOPs
-------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 8.31
Params size (MB): 233.08
Estimated Total Size (MB): 241.96
===========================================
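The size estimates at the end of this summary can be reproduced from the listed shapes with simple arithmetic, assuming 4 bytes per float32 value, 1 MB = 1024^2 bytes, and a forward/backward size equal to twice the total number of output elements (activations plus their gradients). This is the same convention the torchsummary package uses; whether torchutils computes it exactly this way is an assumption, but the sketch below matches every printed figure.

```python
import math

BYTES_PER_FLOAT32 = 4
MB = 1024 ** 2

input_shape = [1, 3, 224, 224]
# Output shapes of the 20 AlexNet layers, copied from the summary above.
output_shapes = [
    [1, 64, 55, 55], [1, 64, 55, 55], [1, 64, 27, 27],
    [1, 192, 27, 27], [1, 192, 27, 27], [1, 192, 13, 13],
    [1, 384, 13, 13], [1, 384, 13, 13],
    [1, 256, 13, 13], [1, 256, 13, 13],
    [1, 256, 13, 13], [1, 256, 13, 13], [1, 256, 6, 6],
    [1, 9216], [1, 4096], [1, 4096], [1, 4096], [1, 4096], [1, 4096], [1, 1000],
]
total_params = 61_100_840

input_mb = math.prod(input_shape) * BYTES_PER_FLOAT32 / MB
# Factor 2: each activation is stored once and its gradient once more.
activations = sum(math.prod(s) for s in output_shapes)
fwd_bwd_mb = 2 * activations * BYTES_PER_FLOAT32 / MB
params_mb = total_params * BYTES_PER_FLOAT32 / MB
total_mb = input_mb + fwd_bwd_mb + params_mb

print(f"{input_mb:.2f} {fwd_bwd_mb:.2f} {params_mb:.2f} {total_mb:.2f}")
# prints: 0.57 8.31 233.08 241.96
```

The parameter memory dominates here: AlexNet's roughly 61 million weights take about 233 MB, while the activations for a single 224x224 image need only about 8 MB even with gradients included.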