spyrit.misc.statistics.stat_2

spyrit.misc.statistics.stat_2(dataloader, device, root, n_loop: int = 1, ext='npy')

Computes and saves the 2D mean image and the covariance matrix of an image database.
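The estimation itself can be pictured as two passes over the data: one to get the mean image, one to accumulate centered outer products. The sketch below is a minimal, hypothetical illustration and not the spyrit implementation. It uses NumPy only. `mean_cov_from_batches` and `make_batches` are invented names, with `make_batches()` standing in for iterating over a DataLoader. It assumes every (channel) image is one sample (exact for C = 1), row-major flattening, and an unbiased 1/(M-1) normalization.

```python
import numpy as np


def mean_cov_from_batches(make_batches, n_loop=1):
    """Hypothetical two-pass estimator of the mean image and pixel covariance.

    make_batches() must return a fresh iterable of arrays shaped (B, C, N, N).
    Each channel image is treated as one sample (assumption, exact for C == 1).
    """
    # Pass 1: accumulate the sum of flattened images to get the mean image
    total, count, n = None, 0, None
    for _ in range(n_loop):
        for batch in make_batches():
            x = np.asarray(batch, dtype=np.float64)
            n = x.shape[-1]
            flat = x.reshape(-1, n * n)  # (B*C, N*N): one row per image
            s = flat.sum(axis=0)
            total = s if total is None else total + s
            count += flat.shape[0]
    mean = total / count  # flattened mean image, shape (N*N,)

    # Pass 2: accumulate centered outer products for the covariance
    cov = np.zeros((n * n, n * n))
    for _ in range(n_loop):
        for batch in make_batches():
            flat = np.asarray(batch, dtype=np.float64).reshape(-1, n * n) - mean
            cov += flat.T @ flat
    cov /= count - 1  # unbiased normalization (assumption)
    return mean.reshape(n, n), cov
```

Centering before the outer products avoids the cancellation error of the one-pass formula E[xxᵀ] − mmᵀ. The price is a second pass over the database.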

Args:

dataloader (torch.utils.data.DataLoader): Dataloader. The fetched data are PyTorch tensors with shape (B, C, N, N).

device (torch.device): Device on which the statistics are computed.

root (file, str, or pathlib.Path): Path where the covariance and mean are saved.

n_loop (int, optional): Number of passes over the image database. Defaults to 1. n_loop > 1 is relevant for dataloaders with random transforms, since each pass then yields different samples.

ext (str, optional): Extension of the saved files:

  • ‘npy’ for NumPy (default),

  • ‘pt’ for PyTorch,

  • any other value: no files are saved.
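The `ext` switch can be sketched as a simple dispatch. This is a hypothetical helper, not spyrit code. The file names `mean.npy`/`cov.npy` (and `.pt`) are illustrative, and the sketch assumes `root` is an existing directory. Only `np.save` and `torch.save` are real library calls. `torch` is imported lazily, so the NumPy path runs without PyTorch installed.

```python
from pathlib import Path

import numpy as np


def save_stats(mean, cov, root, ext="npy"):
    """Hypothetical saver mirroring the `ext` options; file names are illustrative."""
    root = Path(root)
    if ext == "npy":
        # NumPy binary format
        np.save(root / "mean.npy", mean)
        np.save(root / "cov.npy", cov)
    elif ext == "pt":
        import torch  # only needed for the PyTorch format

        torch.save(torch.from_numpy(mean), root / "mean.pt")
        torch.save(torch.from_numpy(cov), root / "cov.pt")
    # any other value: nothing is written to disk
```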

Returns:

mean (np.ndarray): Mean image with shape (N, N).

cov (np.ndarray): Covariance matrix with shape (N*N, N*N).
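The two return values are linked through pixel flattening. The per-pixel variance image is the diagonal of `cov` reshaped to (N, N), and entry `cov[i*N + j, k*N + l]` relates pixels (i, j) and (k, l). That indexing assumes row-major (C-order) flattening, which the signature does not state. The sketch below uses synthetic stand-ins for `mean` and `cov` rather than an actual call to `stat_2`. `pixel_cov` is a hypothetical helper.

```python
import numpy as np

N = 4
rng = np.random.default_rng(1)
samples = rng.normal(size=(100, N * N))     # synthetic flattened images
mean = samples.mean(axis=0).reshape(N, N)   # stand-in for the returned mean, (N, N)
cov = np.cov(samples, rowvar=False)         # stand-in for the returned cov, (N*N, N*N)

# Per-pixel variance image: the diagonal of cov, reshaped like the mean
var_img = np.diag(cov).reshape(N, N)


def pixel_cov(cov, n, p, q):
    """Hypothetical: covariance between pixels p=(i, j) and q=(k, l), row-major order."""
    return cov[p[0] * n + p[1], q[0] * n + q[1]]
```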