spyrit.core.torch.Permutation_Matrix

spyrit.core.torch.Permutation_Matrix(sig: tensor) → tensor

Returns a permutation matrix based on the significance tensor.

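The permutation matrix is the identity matrix with its rows reordered by decreasing significance: row i has a single 1, in the column given by the index of the i-th largest element of the significance tensor. Multiplying a vector (or a matrix, row-wise) on the left by this matrix therefore sorts its elements by decreasing significance. Multiplying a matrix on the right by the transpose reorders its columns in the same way.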
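The construction described above can be sketched with `torch.argsort` and `torch.eye`. This is a minimal illustration, not the spyrit source: it assumes the matrix is the identity with its rows selected in descending-argsort order, which reproduces the documented example. The name `permutation_matrix_sketch` is hypothetical, and the sketch makes no guarantee about how ties between equal significance values are ordered.

```python
import torch


def permutation_matrix_sketch(sig: torch.Tensor) -> torch.Tensor:
    """Hypothetical re-implementation (not the spyrit source): the identity
    matrix with its rows reordered by decreasing significance."""
    flat = sig.flatten()  # non-1D significance tensors are flattened first
    # order[i] is the index of the i-th most significant entry
    # (the relative order of equal values is not guaranteed)
    order = torch.argsort(flat, descending=True)
    # Select identity rows in that order, so that P[i, order[i]] = 1
    return torch.eye(flat.numel(), dtype=flat.dtype)[order]


sig = torch.tensor([0.1, 0.4, 0.2, 0.3])
P = permutation_matrix_sketch(sig)
print(P)        # same matrix as in the documented example
print(P @ sig)  # entries of sig in decreasing order: 0.4, 0.3, 0.2, 0.1
```

For large n, indexing directly with the argsort indices (`flat[order]`) avoids building an n × n matrix. The matrix form is convenient when the reordering has to be composed with other linear operators.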

Args:

sig (torch.tensor): Significance tensor. Its number of elements must equal the number of rows or columns of the tensor to be sorted. If it is not a 1D tensor, it is flattened first.

Returns:

torch.tensor: Permutation matrix of shape (n, n) based on the significance tensor, where n is the length of the significance tensor.

Example:
>>> import torch
>>> from spyrit.core.torch import Permutation_Matrix
>>> sig = torch.tensor([0.1, 0.4, 0.2, 0.3])
>>> Permutation_Matrix(sig)
tensor([[0., 1., 0., 0.],
        [0., 0., 0., 1.],
        [0., 0., 1., 0.],
        [1., 0., 0., 0.]])
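The sketch below shows how the matrix from the example above might be applied. It uses plain PyTorch with the matrix written out literally, so it runs without spyrit installed. The data tensors `x` and `M` are arbitrary illustrative values. Left multiplication reorders entries or rows, right multiplication by the transpose reorders columns, and because a permutation matrix is orthogonal, its transpose undoes the reordering.

```python
import torch

# Permutation matrix from the example above (sig = [0.1, 0.4, 0.2, 0.3]).
P = torch.tensor([[0., 1., 0., 0.],
                  [0., 0., 0., 1.],
                  [0., 0., 1., 0.],
                  [1., 0., 0., 0.]])

x = torch.tensor([10., 20., 30., 40.])  # one value per significance entry
y = P @ x            # reorder entries by decreasing significance: x[[1, 3, 2, 0]]
x_back = P.T @ y     # P is orthogonal, so its transpose restores the original order

M = torch.arange(16.0).reshape(4, 4)
rows_sorted = P @ M    # reorder the rows of M
cols_sorted = M @ P.T  # reorder the columns of M
```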