spyrit.core.torch.reindex
- spyrit.core.torch.reindex(values: tensor, indices: tensor, axis: str = 'rows', inverse_permutation: bool = False) → tensor
Reorders (permutes) a tensor along a specified axis according to the indices tensor.
The indices tensor contains the new index of each element of the values tensor along the chosen axis: values[0] is placed at index indices[0], values[1] at index indices[1], and so on.
Setting inverse_permutation=True reverts this permutation: the element at index indices[0] is placed at index 0, the element at index indices[1] is placed at index 1, and so on.
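The two directions can be written with plain PyTorch indexing. The sketch below illustrates the semantics described above for a 1D tensor; it is not spyrit's implementation, and the helper names `forward_reindex` and `inverse_reindex` are hypothetical. The forward case scatters `values[i]` to position `indices[i]`, which is equivalent to gathering with `indices.argsort()`; the inverse case gathers with `indices` directly.

```python
import torch


def forward_reindex(values: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: out[indices[i]] = values[i] (scatter form).
    out = torch.empty_like(values)
    out[indices] = values
    return out


def inverse_reindex(values: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: out[i] = values[indices[i]] (gather form).
    return values[indices]


values = torch.tensor([10, 20, 30])
indices = torch.tensor([2, 0, 1])

out = forward_reindex(values, indices)
print(out)  # tensor([20, 30, 10])

# The scatter form is the same as gathering with argsort(indices).
assert torch.equal(out, values[indices.argsort()])

# Applying the inverse permutation restores the original order.
print(inverse_reindex(out, indices))  # tensor([10, 20, 30])
```

Gathering with `argsort` avoids allocating an output buffer and writing into it, which is why the forward direction is often expressed that way.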
- Args:
values (torch.Tensor): The tensor to reorder. Can be 1D, 2D, or any multi-dimensional batch of 2D tensors.
indices (torch.Tensor): 1D tensor containing the new index of each element (or row, or column) of values along the chosen axis.
axis (str, optional): The axis to reorder along. Must be either ‘rows’ or ‘cols’. If values is 1D, axis is not used. Default is ‘rows’.
inverse_permutation (bool, optional): Whether to apply the inverse permutation instead. Default is False.
- Raises:
ValueError: If axis is not ‘rows’ or ‘cols’.
- Returns:
torch.Tensor: The tensor reordered by indices along the specified axis.
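As a rough reading aid, the documented arguments, error, and return value can be combined into a small re-implementation. This is a hypothetical sketch named `reindex_sketch`, written only from the description above and not taken from spyrit's source. It assumes that ‘rows’ means the second-to-last dimension and ‘cols’ the last dimension of each 2D slice, and that the axis check is skipped for 1D input, since axis is unused there.

```python
import torch


def reindex_sketch(
    values: torch.Tensor,
    indices: torch.Tensor,
    axis: str = "rows",
    inverse_permutation: bool = False,
) -> torch.Tensor:
    """Hypothetical re-implementation of the documented behaviour (not spyrit's code)."""
    # Forward: element i moves to position indices[i], i.e. gather with argsort(indices).
    # Inverse: position i receives the element at indices[i], i.e. gather with indices.
    perm = indices if inverse_permutation else indices.argsort()
    if values.dim() == 1:
        return values[perm]  # axis is ignored for 1D input (assumption)
    if axis == "rows":
        return values[..., perm, :]  # permute rows of every trailing 2D slice
    if axis == "cols":
        return values[..., perm]  # permute columns of every trailing 2D slice
    raise ValueError(f"axis must be 'rows' or 'cols', got {axis!r}")


values = torch.tensor([[10, 20, 30], [100, 200, 300]])
indices = torch.tensor([2, 0, 1])

out = reindex_sketch(values, indices, axis="cols")
assert torch.equal(out, torch.tensor([[20, 30, 10], [200, 300, 100]]))
assert torch.equal(reindex_sketch(out, indices, axis="cols", inverse_permutation=True), values)

# Batched input of shape (2, 3, 3): the rows of every 3x3 slice move together.
batch = torch.arange(18).reshape(2, 3, 3)
rows = reindex_sketch(batch, indices, axis="rows")
assert torch.equal(rows[0, 2], batch[0, 0])  # row 0 moved to row indices[0] = 2
```

Using the `...` ellipsis lets the same indexing expression handle a single 2D tensor and any number of leading batch dimensions.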
- Example:
>>> import torch
>>> from spyrit.core.torch import reindex
>>> values = torch.tensor([[10, 20, 30], [100, 200, 300]])
>>> indices = torch.tensor([2, 0, 1])
>>> out = reindex(values, indices, axis="cols", inverse_permutation=False)
>>> out
tensor([[ 20,  30,  10],
        [200, 300, 100]])
>>> reindex(out, indices, axis="cols", inverse_permutation=True)
tensor([[ 10,  20,  30],
        [100, 200, 300]])