datasets

Useful tools for your datasets.

class torchutils.datasets.RunningStat(dims=3)

Calculate and track statistics of data.

Calculates mean, standard deviation and variance. Uses Welford’s algorithm for computing the statistics. See Knuth TAOCP Vol 2, 3rd edition, page 232.
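As a rough illustration of the Welford update mentioned above, here is a minimal pure-Python sketch. The class name `WelfordStat` and its attributes are hypothetical and are not the torchutils implementation. It also assumes the sample variance (dividing by n - 1), which the documentation does not specify.

```python
import math


class WelfordStat:
    """Hypothetical per-dimension running statistics (illustration only)."""

    def __init__(self, dims=3):
        self.dims = dims
        self.count = 0
        self.mean = [0.0] * dims
        # Running sum of squared deviations from the current mean.
        self.m2 = [0.0] * dims

    def update(self, point):
        """Fold a single data point of length `dims` into the running stats."""
        if len(point) != self.dims:
            raise ValueError("expected %d values, got %d" % (self.dims, len(point)))
        self.count += 1
        for i, x in enumerate(point):
            delta = x - self.mean[i]
            self.mean[i] += delta / self.count
            # Uses the already-updated mean, as in Welford's recurrence.
            self.m2[i] += delta * (x - self.mean[i])

    @property
    def var(self):
        # Assumption: sample variance (n - 1 in the denominator).
        if self.count < 2:
            return [0.0] * self.dims
        return [m / (self.count - 1) for m in self.m2]

    @property
    def std(self):
        return [math.sqrt(v) for v in self.var]


stat = WelfordStat(dims=2)
for point in [(1.0, 10.0), (2.0, 20.0), (3.0, 30.0)]:
    stat.update(point)
print("Mean:", stat.mean)
print("Var:", stat.var)
```

Because only the count, the mean and the sum of squared deviations are stored, the data never has to be held in memory all at once. This single-pass form also avoids subtracting two large, nearly equal sums, which matters for data like the example output here (mean near 10000, standard deviation near 1).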

Parameters

dims (int) – Number of dimensions to track statistics for, e.g. 3 for the channels of RGB images. (default: 3)

Example:

import torchutils as tu

# define your dataset and dataloader; each sample is assumed to have shape (C,)
sample, _ = loader.dataset[0]
running_stat = tu.RunningStat(dims=sample.size(-1))
for data, _ in loader:
    # each batch must have shape (N, C)
    running_stat.update(data)
print('Mean:', running_stat.mean)
print('Std:', running_stat.std)

Out:

Mean: tensor([10000.0029,  9999.9941, 10000.0137])
Std: tensor([1.0037, 1.0009, 0.9997])

property mean

Mean of data seen till now.

Type

torch.Tensor

property num_data_points

Number of data points seen till now.

Type

int

reset()

Reset stats tracker.

Returns

Returns nothing.

Return type

None

property std

Standard deviation of data seen till now.

Type

torch.Tensor

update(data)

Update running stats tracker.

Parameters

data (torch.Tensor) – Input data. Must have shape (N, dims).

Returns

Current statistics of all data seen so far.

Return type

dict {"mean" -> Mean, "std" -> Standard deviation, "var" -> Variance}
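Since `update` takes a whole batch of shape (N, dims), one way such an update could work is to compute the batch's own statistics and merge them into the running state with the pairwise formula of Chan et al. The sketch below is an assumption about how this might be done, not the library's code. The helpers `batch_stats`, `merge` and `as_dict` are hypothetical, and the n - 1 denominator is again an assumption.

```python
import math


def batch_stats(batch):
    """Return (count, mean, m2) per dimension for rows shaped (N, dims)."""
    n = len(batch)
    dims = len(batch[0])
    mean = [sum(row[i] for row in batch) / n for i in range(dims)]
    m2 = [sum((row[i] - mean[i]) ** 2 for row in batch) for i in range(dims)]
    return n, mean, m2


def merge(state, batch):
    """Fold one batch into the running (count, mean, m2) state."""
    n_a, mean_a, m2_a = state
    n_b, mean_b, m2_b = batch_stats(batch)
    n = n_a + n_b
    mean, m2 = [], []
    for i in range(len(mean_b)):
        delta = mean_b[i] - mean_a[i]
        mean.append(mean_a[i] + delta * n_b / n)
        m2.append(m2_a[i] + m2_b[i] + delta * delta * n_a * n_b / n)
    return n, mean, m2


def as_dict(state):
    """Build a dict with the same keys as the documented return value."""
    n, mean, m2 = state
    # Assumption: sample variance (n - 1 in the denominator).
    var = [m / (n - 1) if n > 1 else 0.0 for m in m2]
    return {"mean": mean, "std": [math.sqrt(v) for v in var], "var": var}


state = (0, [0.0, 0.0], [0.0, 0.0])
batches = [
    [[1.0, 10.0], [2.0, 20.0]],
    [[3.0, 30.0], [4.0, 40.0]],
]
for batch in batches:
    state = merge(state, batch)
stats = as_dict(state)
print("Mean:", stats["mean"])
print("Var:", stats["var"])
```

Merging per batch gives the same result as feeding the points one at a time, but needs only one state update per batch.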

property var

Variance of data seen till now.

Type

torch.Tensor

torchutils.datasets.get_dataset_stats(loader, verbose=False)

Get statistics of dataset.

Calculates mean, standard deviation and variance. Supports data of shape (N,C) and (N,C,H,W).
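The documentation does not say how (N,C,H,W) data is reduced. One plausible reading is that every pixel is treated as a separate C-dimensional sample, so the statistics are per channel. The sketch below illustrates that reading with nested lists. The helper `nchw_to_rows` is hypothetical and this behaviour is an assumption, not confirmed library behaviour.

```python
def nchw_to_rows(batch):
    """Flatten nested lists shaped (N, C, H, W) into rows shaped (N*H*W, C)."""
    rows = []
    for image in batch:
        channels = len(image)
        height = len(image[0])
        width = len(image[0][0])
        for y in range(height):
            for x in range(width):
                # One row per pixel, holding that pixel's value in each channel.
                rows.append([image[c][y][x] for c in range(channels)])
    return rows


# One image with 2 channels and size 1x2:
# channel 0 holds [1, 3], channel 1 holds [10, 30].
batch = [[[[1.0, 3.0]], [[10.0, 30.0]]]]
rows = nchw_to_rows(batch)
channel_means = [sum(column) / len(rows) for column in zip(*rows)]
print("Rows:", rows)
print("Per-channel mean:", channel_means)
```

With PyTorch tensors, the same rearrangement could be written as `data.permute(0, 2, 3, 1).reshape(-1, data.size(1))`, which moves the channel axis last before flattening.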

Parameters
  • loader (torch.utils.data.DataLoader) – PyTorch dataloader.

  • verbose (bool) – If True, print progress while the statistics are computed. (default: False)

Returns

Stats of entire dataset.

Return type

dict {"mean" -> Mean, "std" -> Standard deviation, "var" -> Variance}

Example:

import torch
import torchutils as tu

# define your dataset and dataloader
dataset = MyDataset()
trainloader = torch.utils.data.DataLoader(dataset, batch_size=1,
                                          num_workers=1,
                                          shuffle=False)
stats = tu.get_dataset_stats(trainloader, verbose=True)
print('Mean:', stats['mean'])
print('Std:', stats['std'])

Out:

Calculating dataset stats...
Batch 100/100
Mean: tensor([10000.0098,  9999.9795,  9999.9893])
Std: tensor([0.9969, 1.0003, 0.9972])