spyrit.core.torch.sort_by_significance

spyrit.core.torch.sort_by_significance(values: tensor, sig: tensor, axis: str = 'rows', inverse_permutation: bool = False, get_indices: bool = False) -> tensor

Returns a tensor whose entries (for 1D input), rows or columns are sorted by decreasing significance, as determined by the significance tensor sig.

The row (or column) of values with the highest significance is placed first, followed by the one with the second-highest significance, and so on. sig must contain exactly one entry per row (or column) of values, i.e. its length must equal the size of values along the specified axis.
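As a rough illustration of the ordering rule (a plain-PyTorch sketch, not spyrit's implementation), sorting the rows of a 2D tensor by decreasing significance amounts to indexing with the result of a descending argsort. The tensors below are made-up example data with distinct significance values, so ties do not arise; how sort_by_significance breaks ties is not covered here.

```python
import torch

# Significance of each of the 4 rows (distinct values, so there are no ties).
sig = torch.tensor([0.1, 0.9, 0.5, 0.3])
values = torch.arange(8.0).reshape(4, 2)  # rows: [0, 1], [2, 3], [4, 5], [6, 7]

# Row indices ordered from most to least significant.
order = torch.argsort(sig, descending=True)

# Reorder the rows of `values` accordingly.
sorted_rows = values[order]
print(sorted_rows)
```

Here order is [1, 2, 3, 0], so the row [2, 3] (significance 0.9) comes first and the row [0, 1] (significance 0.1) comes last.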

This function is equivalent to (but much faster than) the following code:

```python
import torch
from spyrit.core.torch import Permutation_Matrix, sort_by_significance

h = 64
values = torch.randn(2*h, h)
sig_rows = torch.randn(2*h)
sig_cols = torch.randn(h)

# 1
y1 = sort_by_significance(values, sig_rows, 'rows', False)
y2 = Permutation_Matrix(sig_rows) @ values
assert torch.allclose(y1, y2) # True

# 2
y1 = sort_by_significance(values, sig_rows, 'rows', True)
y2 = Permutation_Matrix(sig_rows).T @ values
assert torch.allclose(y1, y2) # True

# 3
y1 = sort_by_significance(values, sig_cols, 'cols', False)
y2 = values @ Permutation_Matrix(sig_cols)
assert torch.allclose(y1, y2) # True

# 4
y1 = sort_by_significance(values, sig_cols, 'cols', True)
y2 = values @ Permutation_Matrix(sig_cols).T
assert torch.allclose(y1, y2) # True
```
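The permutation-matrix equivalence for the row cases (#1 and #2) can be checked without spyrit. The sketch below builds its own matrix with a hypothetical helper permutation_matrix (an assumption standing in for Permutation_Matrix, whose definition is not shown on this page) and compares it with plain indexing. Indexing reorders the rows directly, whereas the matrix product multiplies by a dense n x n matrix, which is one plausible reason the indexing route is much faster.

```python
import torch

def permutation_matrix(sig: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: P such that P @ x orders x by decreasing sig."""
    n = sig.numel()
    order = torch.argsort(sig, descending=True)
    P = torch.zeros(n, n)
    # Row i of P picks the entry that should end up in position i.
    P[torch.arange(n), order] = 1.0
    return P

torch.manual_seed(0)
h = 8
values = torch.randn(2 * h, h)
sig_rows = torch.randn(2 * h)

order = torch.argsort(sig_rows, descending=True)
P = permutation_matrix(sig_rows)

# Case 1: sorting rows by indexing matches left-multiplication by P.
assert torch.allclose(values[order], P @ values)

# Case 2: the inverse permutation (argsort of the sorting indices) matches P.T.
inverse = torch.argsort(order)
assert torch.allclose(values[inverse], P.T @ values)
```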

Args:

values (torch.Tensor): Tensor to sort by significance. Can be 1D, 2D, or a multi-dimensional batch of 2D tensors.

sig (torch.Tensor): Significance tensor. Its length must equal the number of rows of values if axis is 'rows', or the number of columns if axis is 'cols'.

axis (str, optional): Axis along which to sort. Must be either 'rows' or 'cols'. Default is 'rows'.

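inverse_permutation (bool, optional): If True, the inverse of the sorting permutation is applied instead, which restores the original order of a tensor previously sorted with the same sig and axis. Default is False.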

get_indices (bool, optional): If True, the function will return the indices tensor used to sort the values tensor. Default is False.
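The arguments above can be summarized in a short sketch. sort_sketch is a hypothetical helper written only for illustration, not spyrit's code; it assumes that axis='rows' reorders the second-to-last dimension and axis='cols' the last one, so that a batch of 2D tensors is sorted matrix by matrix, and that inverse_permutation applies the inverse of the sorting permutation (obtained as the argsort of the sorting indices).

```python
import torch

def sort_sketch(values, sig, axis="rows", inverse_permutation=False):
    """Hypothetical re-implementation for illustration only (not spyrit's code)."""
    order = torch.argsort(sig, descending=True)
    if inverse_permutation:
        # The argsort of a permutation is its inverse permutation.
        order = torch.argsort(order)
    if values.ndim == 1:
        return values[order]           # 1D input: reorder the entries
    if axis == "rows":
        return values[..., order, :]   # reorder the second-to-last dimension
    if axis == "cols":
        return values[..., order]      # reorder the last dimension
    raise ValueError("axis must be 'rows' or 'cols'")

torch.manual_seed(0)
batch = torch.randn(3, 5, 4, 6)  # a (3, 5) batch of 4 x 6 matrices
sig_rows = torch.randn(4)
sig_cols = torch.randn(6)

for axis, sig in (("rows", sig_rows), ("cols", sig_cols)):
    y = sort_sketch(batch, sig, axis)
    # Applying the inverse permutation to the sorted tensor restores the input.
    back = sort_sketch(y, sig, axis, inverse_permutation=True)
    assert torch.equal(back, batch)
```

Sorting and un-sorting in this sketch only moves data, so the round trip is exact and torch.equal can be used instead of torch.allclose.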

Returns:

torch.Tensor or tuple of two torch.Tensor: values ordered by decreasing significance along the specified axis. If get_indices is True, a tuple (sorted tensor, indices) is returned, where indices is the tensor used to sort values.
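When get_indices is True, the returned indices can be reused to reorder companion data consistently or to undo the sort. The following plain-PyTorch sketch mimics that behaviour for the rows case with a hypothetical helper sort_with_indices; it is not the spyrit implementation, and the example tensors are made up.

```python
import torch

def sort_with_indices(values: torch.Tensor, sig: torch.Tensor):
    """Hypothetical helper: sort rows by decreasing sig, return (sorted, indices)."""
    indices = torch.argsort(sig, descending=True)
    return values[indices], indices

sig = torch.tensor([0.2, 0.7, 0.1])
values = torch.tensor([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]])
labels = torch.tensor([10, 20, 30])  # companion data, one label per row

sorted_values, idx = sort_with_indices(values, sig)

# The same indices reorder the companion tensor consistently.
sorted_labels = labels[idx]

# Scattering back through the indices undoes the sort.
restored = torch.empty_like(sorted_values)
restored[idx] = sorted_values
```

With these inputs idx is [1, 0, 2], sorted_labels is [20, 10, 30], and restored equals values.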