datasets
Useful tools for your datasets.
- class torchutils.datasets.RunningStat(dims=3)
Calculate and track statistics of data.
Calculates the mean, standard deviation, and variance. The statistics are computed with Welford's algorithm, which updates them incrementally in a single pass without storing the data. See Knuth, TAOCP Vol. 2, 3rd edition, page 232.
- Parameters
dims (int) – Number of dimensions (channels) to track, for example 3 for RGB images. (default: 3)
Example:
import torchutils as tu

# define your dataset and dataloader
sample, _ = loader.dataset[0]
# a single sample has shape (C,), so dim 0 is the number of channels
running_stat = tu.RunningStat(dims=sample.size(0))

for batch_idx, (data, _) in enumerate(loader):
    # data must be (N, C)
    running_stat.update(data)

print('Mean:', running_stat.mean)
print('Std:', running_stat.std)
Out:
Mean: tensor([10000.0029, 9999.9941, 10000.0137])
Std: tensor([1.0037, 1.0009, 0.9997])
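To illustrate the Welford update mentioned above, here is a minimal pure-Python sketch. It is not the torchutils implementation: the class name `WelfordStat`, the list-based storage, and the choice of population variance (dividing by n) are assumptions made for illustration, and the library may use a different convention.

```python
import math


class WelfordStat:
    """Hypothetical stand-in for a running-statistics tracker (not torchutils code)."""

    def __init__(self, dims=3):
        self.dims = dims
        self.reset()

    def reset(self):
        """Forget everything seen so far."""
        self.n = 0
        self.mean = [0.0] * self.dims
        self.m2 = [0.0] * self.dims  # running sum of squared deviations from the mean

    def update(self, rows):
        """Fold in rows shaped (N, dims), one data point at a time."""
        for row in rows:
            if len(row) != self.dims:
                raise ValueError("each row must have exactly `dims` values")
            self.n += 1
            for c, x in enumerate(row):
                delta = x - self.mean[c]
                self.mean[c] += delta / self.n
                # Uses the deviation from the old mean and from the new mean.
                self.m2[c] += delta * (x - self.mean[c])
        return {"mean": list(self.mean), "std": self.std, "var": self.var}

    @property
    def var(self):
        # Assumption: population variance (divide by n), zero before any data.
        if self.n == 0:
            return [0.0] * self.dims
        return [m / self.n for m in self.m2]

    @property
    def std(self):
        return [math.sqrt(v) for v in self.var]


stat = WelfordStat(dims=2)
stat.update([[1.0, 10.0], [2.0, 20.0]])
stats = stat.update([[3.0, 30.0]])
print(stats["mean"], stats["var"])
```

Because only the count, the mean, and the sum of squared deviations are kept, memory use does not grow with the dataset, and the update avoids the cancellation that the naive sum-of-squares formula suffers when the mean is large relative to the spread, as in the output above.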
- property mean
Mean of data seen so far.
- Type
torch.Tensor
- property num_data_points
Number of data points seen so far.
- Type
int
- reset()
Reset the stats tracker.
- Returns
Returns nothing.
- Return type
None
- property std
Standard deviation of data seen so far.
- Type
torch.Tensor
- update(data)
Update the running stats tracker.
- Parameters
data (torch.Tensor) – Input data of shape (N, dims).
- Returns
Current stats over all data seen so far.
- Return type
dict {“mean” -> Mean, “std” -> Standard deviation, “var” -> Variance}
- property var
Variance of data seen so far.
- Type
torch.Tensor
- torchutils.datasets.get_dataset_stats(loader, verbose=False)
Get statistics of dataset.
Calculates mean, standard deviation and variance. Supports data of shape (N,C) and (N,C,H,W).
- Parameters
loader (torch.utils.data.DataLoader) – PyTorch dataloader.
verbose (bool) – Enable/disable print statements.
- Returns
Stats of entire dataset.
- Return type
dict {“mean” -> Mean, “std” -> Standard deviation, “var” -> Variance}
Example:
import torch
import torchutils as tu

# define your dataset and dataloader
dataset = MyDataset()
trainloader = torch.utils.data.DataLoader(dataset, batch_size=1,
                                          num_workers=1,
                                          shuffle=False)

stats = tu.get_dataset_stats(trainloader, verbose=True)

print('Mean:', stats['mean'])
print('Std:', stats['std'])
Out:
Calculating dataset stats...
Batch 100/100
Mean: tensor([10000.0098, 9999.9795, 9999.9893])
Std: tensor([0.9969, 1.0003, 0.9972])
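Per-channel statistics for image batches of shape (N,C,H,W) can be obtained by treating every pixel as one data point, which turns the batch into rows of shape (N*H*W, C). Whether get_dataset_stats does exactly this internally is an assumption. The sketch below uses nested Python lists and a hypothetical helper `flatten_image_batch` to show the idea; with tensors, the equivalent reshaping is `data.permute(0, 2, 3, 1).reshape(-1, data.size(1))`.

```python
def flatten_image_batch(batch):
    """Convert a nested list shaped (N, C, H, W) into rows shaped (N*H*W, C)."""
    rows = []
    for image in batch:  # image is shaped (C, H, W)
        channels = len(image)
        height, width = len(image[0]), len(image[0][0])
        for y in range(height):
            for x in range(width):
                # Gather the value of every channel at pixel (y, x).
                rows.append([image[c][y][x] for c in range(channels)])
    return rows


# One image with 2 channels and 1x2 pixels:
# channel 0 holds [1, 2], channel 1 holds [10, 20].
batch = [[[[1.0, 2.0]], [[10.0, 20.0]]]]
rows = flatten_image_batch(batch)

# Per-channel mean over all pixels of all images.
channel_means = [sum(col) / len(rows) for col in zip(*rows)]
print(rows)
print(channel_means)
```

After flattening, each row has one value per channel, which matches the (N, dims) layout that RunningStat.update expects.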