spyrit.misc.statistics.cov_2

spyrit.misc.statistics.cov_2(dataloader: DataLoader, mean: array, device: device, n_loop: int = 1)

Computes the 2D covariance matrix across batches and channels.

Args:

dataloader (torch.utils.data.DataLoader): Dataloader. The fetched data are Torch tensors with shape (B, C, N, N).

mean (np.array): Mean image with shape (N, N).

device (torch.device): Device on which the computation is performed.

n_loop (int, optional): Number of passes over the image database. Defaults to 1. Setting n_loop > 1 is relevant for dataloaders with random transforms.

Returns:

cov (np.array): Covariance matrix with shape (N*N, N*N).
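The computation described above can be illustrated with a minimal NumPy sketch. This is not spyrit's implementation. It rests on several assumptions: the data arrive as a plain re-iterable of arrays with shape (B, C, N, N) rather than as a DataLoader yielding (inputs, labels) pairs; every (N, N) channel image counts as one sample; the accumulated sum of outer products is divided by (number of samples - 1); and the `device` argument is dropped because everything runs on the CPU. The name `cov_2_sketch` is hypothetical.

```python
import numpy as np


def cov_2_sketch(batches, mean, n_loop=1):
    """Hypothetical NumPy sketch of a 2D covariance estimate (not spyrit code).

    Assumptions: `batches` is a re-iterable of arrays with shape (B, C, N, N);
    each (N, N) channel image is one sample; normalization is by (samples - 1).
    """
    mean_vec = np.asarray(mean, dtype=np.float64).reshape(-1)  # mean image as (N*N,)
    dim = mean_vec.size
    cov = np.zeros((dim, dim))
    n_samples = 0
    for _ in range(n_loop):  # extra passes only differ when batches are randomly transformed
        for batch in batches:
            x = np.asarray(batch, dtype=np.float64)
            x = x.reshape(-1, dim) - mean_vec  # flatten to (B*C, N*N) and center
            cov += x.T @ x                     # accumulate sum of outer products
            n_samples += x.shape[0]
    return cov / (n_samples - 1)  # assumed unbiased normalization


rng = np.random.default_rng(0)
data = [rng.standard_normal((4, 1, 3, 3)) for _ in range(5)]
mean = np.mean(np.concatenate(data), axis=(0, 1))  # mean image, shape (3, 3)
cov = cov_2_sketch(data, mean)
print(cov.shape)  # (9, 9)
```

Here five batches of shape (4, 1, 3, 3) yield a (9, 9) matrix, matching the documented (N*N, N*N) output shape. Because the mean image is the sample mean of the same 20 images, the result agrees with `np.cov` applied to the flattened images.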