spyrit.core.torch.reindex

spyrit.core.torch.reindex(values: tensor, indices: tensor, axis: str = 'rows', inverse_permutation: bool = False) -> tensor

Reorders (permutes) a tensor along a specified axis using the indices tensor.

The indices tensor contains the new indices of the elements in the values tensor. values[0] will be placed at the index indices[0], values[1] at indices[1], and so on.

Setting inverse_permutation to True reverts the permutation: in this case, the element at index indices[0] is placed at index 0, the element at index indices[1] is placed at index 1, and so on.
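The two directions described above can be sketched on a plain Python list. This is only an illustration of the indexing rule as stated in the text; `permute` is a hypothetical helper, not part of spyrit, and it makes no claim about how `reindex` is implemented internally.

```python
def permute(values, indices, inverse=False):
    """Hypothetical pure-Python stand-in for the 1D case; not spyrit's code."""
    out = [None] * len(values)
    for i, target in enumerate(indices):
        if inverse:
            # Inverse: the element at index indices[i] is placed at index i.
            out[i] = values[target]
        else:
            # Forward: values[i] is placed at index indices[i].
            out[target] = values[i]
    return out


values = [10, 20, 30]
indices = [2, 0, 1]

forward = permute(values, indices)
restored = permute(forward, indices, inverse=True)

print(forward)   # [20, 30, 10]
print(restored)  # [10, 20, 30]
```

Applying the inverse with the same indices undoes the forward step, which is the round trip the description refers to.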

Args:

values (torch.tensor): The tensor to sort. Can be 1D, 2D, or any multi-dimensional batch of 2D tensors.

indices (torch.tensor): Tensor containing the new indices of the elements contained in values.

axis (str, optional): The axis to sort along. Must be either ‘rows’ or ‘cols’. If values is 1D, axis is not used. Default is ‘rows’.

inverse_permutation (bool, optional): Whether to apply the inverse permutation. Default is False.

Raises:

ValueError: If axis is not ‘rows’ or ‘cols’.

Returns:

torch.tensor: The tensor reordered by the given indices along the specified axis.

Example:
>>> import torch
>>> from spyrit.core.torch import reindex
>>> values = torch.tensor([[10, 20, 30], [100, 200, 300]])
>>> indices = torch.tensor([2, 0, 1])
>>> out = reindex(values, indices, axis="cols", inverse_permutation=False)
>>> out
tensor([[ 20,  30,  10],
        [200, 300, 100]])
>>> reindex(out, indices, axis="cols", inverse_permutation=True)
tensor([[ 10,  20,  30],
        [100, 200, 300]])
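
To contrast the two axis values, here is a minimal list-of-lists sketch. It assumes that axis='rows' moves whole rows in the same way that axis='cols' moves whole columns in the example above; that reading is an inference from the documentation, not something confirmed against the source. `permute_2d` is a hypothetical name and does not use spyrit or torch.

```python
def permute_2d(matrix, indices, axis="rows"):
    """Hypothetical list-of-lists illustration; not spyrit's implementation."""
    if axis not in ("rows", "cols"):
        # Mirrors the documented ValueError for an invalid axis.
        raise ValueError("axis must be 'rows' or 'cols'")

    if axis == "rows":
        # Assumed behavior: row i is placed at row index indices[i].
        out = [None] * len(matrix)
        for i, target in enumerate(indices):
            out[target] = list(matrix[i])
        return out

    # axis == "cols": within every row, column j is placed at index indices[j].
    out = []
    for row in matrix:
        new_row = [None] * len(row)
        for j, target in enumerate(indices):
            new_row[target] = row[j]
        out.append(new_row)
    return out


matrix = [[10, 20, 30], [100, 200, 300]]

by_cols = permute_2d(matrix, [2, 0, 1], axis="cols")
by_rows = permute_2d(matrix, [1, 0], axis="rows")

print(by_cols)  # [[20, 30, 10], [200, 300, 100]]
print(by_rows)  # [[100, 200, 300], [10, 20, 30]]
```

Note that the length of the indices must match the size of the axis being permuted: three entries for the columns here, two for the rows.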