spyrit.core.torch

Contains PyTorch-based functions used in the spyrit.core modules.

The goal of this module is to provide functions that rely on PyTorch functionality and optimizations to perform the operations needed by the spyrit.core modules. It mirrors the most frequently used functions of spyrit.misc, but operates on PyTorch tensors instead of NumPy arrays.

Functions

Cov2Var(Cov[, out_shape])

Extracts the variance matrix from a covariance matrix.
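A minimal sketch of the idea, assuming the variances are the diagonal entries of the covariance matrix and that out_shape is the 2D image shape. The helper cov_to_var_sketch is hypothetical and is not spyrit's implementation.

```python
import torch

def cov_to_var_sketch(cov: torch.Tensor, out_shape: tuple) -> torch.Tensor:
    """Hypothetical helper: diagonal of an (N, N) covariance reshaped to out_shape."""
    # The variance of each pixel is the corresponding diagonal entry.
    var = torch.diagonal(cov, dim1=-2, dim2=-1)
    # Reshape the length-N vector, e.g. N = h * w -> (h, w).
    return var.reshape(out_shape)

# Covariance of a 2x2 image (4 pixels) with constant off-diagonal terms.
cov = 0.5 * torch.ones(4, 4) + torch.diag(torch.tensor([0.5, 1.5, 2.5, 3.5]))
var = cov_to_var_sketch(cov, (2, 2))  # var == [[1., 2.], [3., 4.]]
```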

Permutation_Matrix(sig)

Returns a permutation matrix based on the significance tensor.
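A possible construction, assuming (as for sort_by_significance) that elements are ordered by decreasing significance. The helper permutation_matrix_sketch is hypothetical, and the order of tied significance values is left unspecified here.

```python
import torch

def permutation_matrix_sketch(sig: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: P such that P @ v sorts v by decreasing significance."""
    # Flattened indices ordered from most to least significant.
    order = torch.argsort(sig.flatten(), descending=True)
    # Row i of the identity, permuted, selects the i-th most significant element.
    return torch.eye(order.numel())[order]

sig = torch.tensor([[0.2, 0.9], [0.5, 0.1]])
P = permutation_matrix_sketch(sig)
values = torch.tensor([10.0, 20.0, 30.0, 40.0])
sorted_values = P @ values  # [20., 30., 10., 40.]
```

Because P is a permutation matrix, its transpose undoes the sorting.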

assert_power_of_2(n[, raise_error])

Asserts that n is a power of 2.
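The usual bit trick behind such a check, as a sketch: the helper name is hypothetical, and raising ValueError when raise_error is set is an assumption about the behavior, not spyrit's documented exception type.

```python
def is_power_of_2_sketch(n: int, raise_error: bool = True) -> bool:
    """Hypothetical helper mirroring the documented check."""
    # A positive integer is a power of 2 iff exactly one bit is set,
    # i.e. clearing its lowest set bit with n & (n - 1) leaves 0.
    ok = n > 0 and (n & (n - 1)) == 0
    if not ok and raise_error:
        # The exception type is an assumption of this sketch.
        raise ValueError(f"{n} is not a power of 2")
    return ok

print(is_power_of_2_sketch(64))                     # True
print(is_power_of_2_sketch(48, raise_error=False))  # False
```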

center_crop(img, out_shape[, ...])

Crops the center of an image to the specified shape.
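A minimal slicing-based sketch of center cropping on the last two dimensions. The helper is hypothetical, it ignores the additional optional arguments of the real function, and rounding the offset down when the size difference is odd is an assumption.

```python
import torch

def center_crop_sketch(img: torch.Tensor, out_shape: tuple) -> torch.Tensor:
    """Hypothetical helper: crop the last two dimensions (h, w) around the center."""
    h, w = img.shape[-2:]
    oh, ow = out_shape
    # Offsets of the top-left corner of the crop (rounded down if odd).
    top = (h - oh) // 2
    left = (w - ow) // 2
    return img[..., top:top + oh, left:left + ow]

img = torch.arange(36.0).reshape(1, 6, 6)  # batch of one 6x6 image
crop = center_crop_sketch(img, (2, 2))     # [[[14., 15.], [20., 21.]]]
```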

center_pad(img, out_shape[, vectorized_in_shape])

Pads an image to the specified shape by centering it.
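A sketch of centered zero padding with torch.nn.functional.pad, which lists padding amounts starting from the last dimension. The helper is hypothetical; zero padding and putting the extra pixel on the bottom/right when the difference is odd are assumptions.

```python
import torch
import torch.nn.functional as F

def center_pad_sketch(img: torch.Tensor, out_shape: tuple) -> torch.Tensor:
    """Hypothetical helper: zero-pad the last two dimensions to out_shape."""
    h, w = img.shape[-2:]
    oh, ow = out_shape
    top = (oh - h) // 2
    left = (ow - w) // 2
    # F.pad order for the last two dims: (left, right, top, bottom).
    return F.pad(img, (left, ow - w - left, top, oh - h - top))

img = torch.ones(1, 2, 2)
padded = center_pad_sketch(img, (4, 5))  # ones occupy rows 1-2, columns 1-2
```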

finite_diff_mat(n[, boundary])

Creates a finite difference matrix of shape \((n^2,n^2)\) for a 2D image of shape \((n,n)\).
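One common way to build such matrices is with Kronecker products of a 1D forward-difference matrix. This sketch assumes row-major flattening, a zero-derivative (Neumann) boundary, and separate horizontal and vertical operators; it is dense for readability, and finite_diff_mat_sketch is a hypothetical helper that does not model the boundary argument.

```python
import torch

def diff_1d(n: int) -> torch.Tensor:
    # Forward difference (D u)[k] = u[k+1] - u[k]; the last row is zero,
    # which corresponds to a zero-derivative (Neumann) boundary.
    D = torch.diag(torch.ones(n - 1), diagonal=1) - torch.eye(n)
    D[-1, -1] = 0.0
    return D

def finite_diff_mat_sketch(n: int):
    """Hypothetical helper returning dense (n^2, n^2) difference matrices."""
    D, I = diff_1d(n), torch.eye(n)
    # Row-major flattening maps pixel (i, j) to index i * n + j.
    Dx = torch.kron(I, D)  # differences along each row (horizontal)
    Dy = torch.kron(D, I)  # differences along each column (vertical)
    return Dx, Dy

n = 4
Dx, Dy = finite_diff_mat_sketch(n)
img = torch.arange(n * n, dtype=torch.float32).reshape(n, n)  # img[i, j] = 4i + j
gx = (Dx @ img.flatten()).reshape(n, n)  # 1 everywhere except the last column (0)
gy = (Dy @ img.flatten()).reshape(n, n)  # 4 everywhere except the last row (0)
```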

fwht(x[, order, dim])

Returns the fast Walsh-Hadamard transform of x.
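The classic butterfly algorithm computes the transform in O(n log n) operations instead of the O(n^2) of a matrix product. This sketch is a hypothetical helper, unnormalized and in natural (Sylvester) order only; it does not model the order argument of the real function.

```python
import torch

def fwht_sketch(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: unnormalized natural-order FWHT along the last dim."""
    n = x.shape[-1]
    assert n > 0 and (n & (n - 1)) == 0, "last dimension must be a power of 2"
    y, h = x.clone(), 1
    while h < n:
        # View the signal as blocks of length 2h, each split into halves a and b.
        y = y.reshape(*x.shape[:-1], n // (2 * h), 2, h)
        a, b = y[..., 0, :], y[..., 1, :]
        # Butterfly step: (a, b) -> (a + b, a - b).
        y = torch.stack((a + b, a - b), dim=-2)
        h *= 2
    return y.reshape(x.shape)

x = torch.tensor([1.0, 0.0, 0.0, 0.0])
print(fwht_sketch(x))  # tensor([1., 1., 1., 1.])
```

The result equals multiplying by the Sylvester Hadamard matrix, so leading batch dimensions are transformed independently.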

fwht_2d(x[, order])

Returns the fast Walsh-Hadamard transform of a 2D tensor.
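Because the 2D transform is separable, it can be sketched as a 1D transform along the rows followed by one along the columns. The helpers below are hypothetical, unnormalized, and natural-order only; they do not model the order argument.

```python
import torch

def _fwht_last(x: torch.Tensor) -> torch.Tensor:
    # Natural-order butterfly FWHT along the last dimension (unnormalized).
    n = x.shape[-1]
    y, h = x.clone(), 1
    while h < n:
        y = y.reshape(*x.shape[:-1], n // (2 * h), 2, h)
        y = torch.stack((y[..., 0, :] + y[..., 1, :],
                         y[..., 0, :] - y[..., 1, :]), dim=-2)
        h *= 2
    return y.reshape(x.shape)

def fwht_2d_sketch(img: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: H_h @ img @ H_w^T computed with two 1D passes."""
    out = _fwht_last(img)  # transform each row
    # Transform each column by transposing, transforming, transposing back.
    return _fwht_last(out.transpose(-1, -2)).transpose(-1, -2)

imgs = torch.randn(2, 4, 8)  # batch of two 4x8 images
coeffs = fwht_2d_sketch(imgs)
```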

ifwht(x[, order, dim])

Returns the inverse fast Walsh-Hadamard transform of x.
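For an n x n Sylvester Hadamard matrix, H @ H = n * I, so the inverse of the unnormalized transform is the forward transform divided by n. This sketch assumes the forward transform is unnormalized; the helpers are hypothetical.

```python
import torch

def _fwht(x: torch.Tensor) -> torch.Tensor:
    # Natural-order butterfly FWHT along the last dimension (unnormalized).
    n = x.shape[-1]
    y, h = x.clone(), 1
    while h < n:
        y = y.reshape(*x.shape[:-1], n // (2 * h), 2, h)
        y = torch.stack((y[..., 0, :] + y[..., 1, :],
                         y[..., 0, :] - y[..., 1, :]), dim=-2)
        h *= 2
    return y.reshape(x.shape)

def ifwht_sketch(y: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: inverse of the unnormalized FWHT."""
    # H is symmetric and H @ H = n * I, hence H^{-1} = H / n.
    return _fwht(y) / y.shape[-1]

x = torch.randn(2, 16)
x_rec = ifwht_sketch(_fwht(x))  # recovers x up to rounding
```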

ifwht_2d(x[, order])

Returns the inverse fast Walsh-Hadamard transform of a 2D tensor.
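To make the inverse relation explicit, this sketch swaps the fast butterfly for explicit Sylvester matrices: if Y = H_h X H_w^T, then X = H_h^T Y H_w / (h w). It assumes an unnormalized forward transform in natural order; the helpers are hypothetical and far slower than a fast implementation.

```python
import torch

def sylvester_hadamard(n: int) -> torch.Tensor:
    # Natural-order Hadamard matrix, n a power of 2.
    H = torch.ones(1, 1)
    while H.shape[0] < n:
        H = torch.cat((torch.cat((H, H), 1), torch.cat((H, -H), 1)), 0)
    return H

def ifwht_2d_matrix_sketch(coeffs: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: inverse of Y = H_h @ X @ H_w.T."""
    h, w = coeffs.shape[-2:]
    Hh, Hw = sylvester_hadamard(h), sylvester_hadamard(w)
    # H @ H.T = n * I, so H^{-1} = H.T / n on each side.
    return Hh.T @ coeffs @ Hw / (h * w)

X = torch.randn(4, 8)
Y = sylvester_hadamard(4) @ X @ sylvester_hadamard(8).T  # forward 2D transform
X_rec = ifwht_2d_matrix_sketch(Y)
```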

meas2img(meas, Ord)

Returns measurement image from a single measurement tensor or from a batch of measurement tensors.
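A sketch of the idea under stated assumptions: the measurements are taken to be sorted by decreasing value of the order matrix Ord, and pixels without a measurement are filled with zeros. These conventions and the helper meas2img_sketch are hypothetical and may differ from spyrit's.

```python
import torch

def meas2img_sketch(meas: torch.Tensor, Ord: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: place (..., M) measurements into (..., N, N) images."""
    n_pix, M = Ord.numel(), meas.shape[-1]
    # Zero-fill the measurements that were not acquired.
    full = torch.zeros(*meas.shape[:-1], n_pix, dtype=meas.dtype)
    full[..., :M] = meas
    # The k-th measurement goes to the pixel with the k-th largest Ord value.
    idx = torch.argsort(Ord.flatten(), descending=True)
    img = torch.zeros_like(full)
    img[..., idx] = full
    return img.reshape(*meas.shape[:-1], *Ord.shape)

Ord = torch.tensor([[4.0, 3.0], [1.0, 2.0]])
meas = torch.tensor([10.0, 20.0, 30.0])  # three measurements out of four
img = meas2img_sketch(meas, Ord)         # [[10., 20.], [0., 30.]]
```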

mult_1d(H, x[, dim])

Multiplies a matrix with batches of 1D vectors.
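A minimal sketch, assuming H has shape (M, N) and the vectors of length N lie along dimension dim of x. The helper is hypothetical: it moves that dimension last, applies H to every vector with one matrix product, and moves it back.

```python
import torch

def mult_1d_sketch(H: torch.Tensor, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Hypothetical helper: apply H (M, N) to every length-N vector along `dim`."""
    x_last = x.movedim(dim, -1)   # (..., N)
    y = x_last @ H.T              # (..., N) @ (N, M) -> (..., M)
    return y.movedim(-1, dim)     # restore the original dimension order

H = torch.randn(3, 4)
x = torch.randn(2, 4, 5)          # vectors of length 4 along dim=1
y = mult_1d_sketch(H, x, dim=1)   # shape (2, 3, 5)
```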

mult_2d_separable(H, x)

Applies separable transform to batches of (2D) images.
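A sketch assuming the same matrix H acts on both image axes, i.e. Y = H X H^T for each image X. With row-major flattening this equals applying the Kronecker product kron(H, H) to the flattened image, which is why the separable form is much cheaper. The helper is hypothetical.

```python
import torch

def mult_2d_separable_sketch(H: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: H @ X @ H.T for every image X in the batch x."""
    # (M, N) @ (B, N, N) broadcasts to (B, M, N); then @ (N, M) -> (B, M, M).
    return H @ x @ H.T

B, M, N = 2, 3, 4
H = torch.randn(M, N)
x = torch.randn(B, N, N)
y = mult_2d_separable_sketch(H, x)
# Equivalent but costlier: an (M*M, N*N) matrix acting on flattened images.
y_kron = (x.reshape(B, N * N) @ torch.kron(H, H).T).reshape(B, M, M)
```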

neumann_boundary(img_shape)

Creates a finite difference matrix of shape \((h*w,h*w)\) for a 2D image of shape \((h,w)\).
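For a rectangular (h, w) image, a Neumann forward-difference operator can be sketched with Kronecker products and checked against torch.diff applied directly to the image. This assumes row-major flattening and a zero difference at the last row/column; the helpers are hypothetical and build dense matrices.

```python
import torch
import torch.nn.functional as F

def forward_diff(n: int) -> torch.Tensor:
    D = torch.diag(torch.ones(n - 1), diagonal=1) - torch.eye(n)
    D[-1, -1] = 0.0  # Neumann: zero difference at the last pixel
    return D

def neumann_diff_sketch(h: int, w: int):
    """Hypothetical helper: dense (h*w, h*w) horizontal and vertical operators."""
    Dx = torch.kron(torch.eye(h), forward_diff(w))  # along each row
    Dy = torch.kron(forward_diff(h), torch.eye(w))  # along each column
    return Dx, Dy

h, w = 3, 5
Dx, Dy = neumann_diff_sketch(h, w)
img = torch.rand(h, w)
gx = (Dx @ img.flatten()).reshape(h, w)
# Same result computed on the image: differences, then a zero last column.
gx_direct = F.pad(torch.diff(img, dim=1), (0, 1))
```

A constant image has a zero gradient under this boundary condition.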

regularized_lstsq(A, y, regularization, **kwargs)

Batched regularized least squares solution of a system of equations.
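The regularization and **kwargs arguments of the real function are not modeled here. This sketch assumes an L2 (Tikhonov) penalty with a scalar weight eta, solving the normal equations (A^T A + eta I) x = A^T y for every vector of the batch at once; the helper is hypothetical.

```python
import torch

def tikhonov_lstsq_sketch(A: torch.Tensor, y: torch.Tensor, eta: float) -> torch.Tensor:
    """Hypothetical helper: argmin_x ||A x - y||^2 + eta ||x||^2 for y of shape (..., M)."""
    n = A.shape[-1]
    # Normal equations of the penalized problem.
    lhs = A.T @ A + eta * torch.eye(n, dtype=A.dtype)
    rhs = y @ A  # row-wise A^T y, shape (..., N)
    # lhs broadcasts over the batch of right-hand sides.
    return torch.linalg.solve(lhs, rhs.unsqueeze(-1)).squeeze(-1)

torch.manual_seed(0)
A = torch.randn(6, 4, dtype=torch.float64)
x_true = torch.randn(3, 4, dtype=torch.float64)
y = x_true @ A.T  # three noiseless measurement vectors
x_hat = tikhonov_lstsq_sketch(A, y, eta=1e-8)
```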

regularized_pinv(A, regularization, **kwargs)

Returns a regularized pseudo-inverse of a tensor.
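One standard regularized pseudo-inverse is the Tikhonov form (A^T A + eta I)^{-1} A^T. This sketch assumes that choice (the regularization options of the real function are not modeled); the helper is hypothetical. As eta goes to 0 it approaches the Moore-Penrose pseudo-inverse, and larger eta damps every singular value s of the inverse to s / (s^2 + eta).

```python
import torch

def tikhonov_pinv_sketch(A: torch.Tensor, eta: float) -> torch.Tensor:
    """Hypothetical helper: (A^T A + eta I)^{-1} A^T."""
    At = A.transpose(-2, -1)
    n = A.shape[-1]
    return torch.linalg.solve(At @ A + eta * torch.eye(n, dtype=A.dtype), At)

torch.manual_seed(0)
A = torch.randn(5, 3, dtype=torch.float64)
P_small = tikhonov_pinv_sketch(A, 1e-10)  # close to torch.linalg.pinv(A)
P_large = tikhonov_pinv_sketch(A, 1.0)    # strongly damped
```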

reindex(values, indices[, axis, ...])

Sorts a tensor along a specified axis using the indices tensor.
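A sketch of reordering along an axis with torch.index_select. The helper is hypothetical and ignores the additional optional arguments of the real function. Reindexing with torch.argsort(indices) undoes the permutation.

```python
import torch

def reindex_sketch(values: torch.Tensor, indices: torch.Tensor, axis: int = -1) -> torch.Tensor:
    """Hypothetical helper: entries of `values` along `axis`, in the order of `indices`."""
    return torch.index_select(values, axis % values.ndim, indices)

values = torch.tensor([[1, 2, 3], [4, 5, 6]])
indices = torch.tensor([2, 0, 1])
out = reindex_sketch(values, indices)                     # [[3, 1, 2], [6, 4, 5]]
restored = reindex_sketch(out, torch.argsort(indices))   # back to `values`
```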

sequency_perm(X[, ind])

Permutes the last dimension of a tensor.
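A typical use is reordering Hadamard coefficients from natural order to sequency (Walsh) order, where the sequency of a row is its number of sign changes. This sketch assumes the default ind is that natural-to-sequency permutation; the helpers are hypothetical.

```python
import torch

def sylvester(n: int) -> torch.Tensor:
    H = torch.ones(1, 1)
    while H.shape[0] < n:
        H = torch.cat((torch.cat((H, H), 1), torch.cat((H, -H), 1)), 0)
    return H

def sequency_indices(n: int) -> torch.Tensor:
    # Sequency of each natural-order row = number of sign changes along it.
    H = sylvester(n)
    changes = (H[:, 1:] != H[:, :-1]).sum(dim=1)
    return torch.argsort(changes)  # sequencies are distinct: 0 .. n-1

def sequency_perm_sketch(X: torch.Tensor, ind: torch.Tensor = None) -> torch.Tensor:
    """Hypothetical helper: permute the last dimension of X."""
    if ind is None:
        ind = sequency_indices(X.shape[-1])
    return X[..., ind]

x = torch.randn(2, 8)
y_natural = x @ sylvester(8)             # natural-order Hadamard coefficients
y_walsh = sequency_perm_sketch(y_natural)  # same coefficients, sequency order
```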

sort_by_significance(values, sig[, axis, ...])

Returns a tensor sorted by decreasing significance of its elements as determined by the significance tensor.
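A minimal sketch, assuming the significance tensor is flattened and its number of elements matches the length of values along the sorted axis. The helper is hypothetical and leaves the order of tied significance values unspecified.

```python
import torch

def sort_by_significance_sketch(values: torch.Tensor, sig: torch.Tensor, axis: int = -1) -> torch.Tensor:
    """Hypothetical helper: reorder `values` along `axis` by decreasing `sig`."""
    order = torch.argsort(sig.flatten(), descending=True)
    return torch.index_select(values, axis % values.ndim, order)

sig = torch.tensor([[0.2, 0.9], [0.5, 0.1]])
values = torch.tensor([[10.0, 20.0, 30.0, 40.0], [1.0, 2.0, 3.0, 4.0]])
sorted_values = sort_by_significance_sketch(values, sig)
# [[20., 30., 10., 40.], [2., 3., 1., 4.]]
```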

spdiags(diagonals, offsets, shape)

Similar to torch.sparse.spdiags. Arguments are the same, except:
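For reference, this shows PyTorch's own torch.sparse.spdiags, which follows the SciPy convention: row k of diagonals is placed on the diagonal given by offsets[k], and entry j of that row lands in column j, so the first entries of super-diagonals are dropped. The differences in spyrit's variant are not illustrated here.

```python
import torch

# Main diagonal of -1 and first super-diagonal of +1: a 1D forward difference.
diagonals = torch.tensor([[-1.0, -1.0, -1.0, -1.0],
                          [ 1.0,  1.0,  1.0,  1.0]])
offsets = torch.tensor([0, 1])
D = torch.sparse.spdiags(diagonals, offsets, (4, 4))  # sparse COO tensor
dense = D.to_dense()
# [[-1., 1., 0., 0.],
#  [ 0., -1., 1., 0.],
#  [ 0., 0., -1., 1.],
#  [ 0., 0., 0., -1.]]
```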

walsh2_torch(img[, H])

Returns the 2D Walsh-ordered Hadamard transform of an image.
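A matrix-based sketch, assuming a square image and that the optional H is the 1D Walsh matrix applied on both sides (Y = H X H^T). The helpers are hypothetical; the Walsh matrix is obtained by sorting Sylvester rows by their number of sign changes.

```python
import torch

def walsh_1d(n: int) -> torch.Tensor:
    H = torch.ones(1, 1)
    while H.shape[0] < n:
        H = torch.cat((torch.cat((H, H), 1), torch.cat((H, -H), 1)), 0)
    # Walsh order = rows sorted by increasing number of sign changes.
    changes = (H[:, 1:] != H[:, :-1]).sum(dim=1)
    return H[torch.argsort(changes)]

def walsh2_sketch(img: torch.Tensor, H: torch.Tensor = None) -> torch.Tensor:
    """Hypothetical helper: 2D Walsh transform of a square image."""
    if H is None:
        H = walsh_1d(img.shape[-1])
    return H @ img @ H.T

img = torch.rand(4, 4)
Y = walsh2_sketch(img)
# The first row of H is all ones, so Y[0, 0] is the sum of the pixels.
```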

walsh_matrix(n)

Returns a 1D Walsh-ordered Hadamard matrix.
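One possible construction builds the sequency-ordered matrix directly by recursion: from row w_k of the n x n matrix, rows 2k and 2k+1 of the 2n x 2n matrix are [w_k, (-1)^k w_k] and [w_k, -(-1)^k w_k]. This is a sketch; spyrit may build the matrix differently, and walsh_matrix_sketch is hypothetical.

```python
import torch

def walsh_matrix_sketch(n: int) -> torch.Tensor:
    """Hypothetical helper: n x n Walsh (sequency-ordered) Hadamard matrix."""
    W = torch.ones(1, 1)
    while W.shape[0] < n:
        m = W.shape[0]
        # s[k] = (-1)^k as a column vector.
        s = (1 - 2 * (torch.arange(m) % 2)).to(W.dtype).unsqueeze(1)
        even = torch.cat((W, s * W), dim=1)   # rows 2k:   [w_k,  (-1)^k w_k]
        odd = torch.cat((W, -s * W), dim=1)   # rows 2k+1: [w_k, -(-1)^k w_k]
        # Interleave the two sets of rows.
        W = torch.stack((even, odd), dim=1).reshape(2 * m, 2 * m)
    return W

W4 = walsh_matrix_sketch(4)
# [[1,  1,  1,  1],
#  [1,  1, -1, -1],
#  [1, -1, -1,  1],
#  [1, -1,  1, -1]]
```

Row k has exactly k sign changes, and the rows are orthogonal (W @ W.T = n * I).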

walsh_matrix_2d(n)

Returns a 2D Walsh-ordered Hadamard matrix.
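A sketch assuming the 2D matrix is the Kronecker product of the 1D Walsh matrix with itself and acts on row-major flattened images. Under that assumption, applying it to a flattened image matches the separable form W @ X @ W.T. The helpers are hypothetical.

```python
import torch

def walsh_1d(n: int) -> torch.Tensor:
    H = torch.ones(1, 1)
    while H.shape[0] < n:
        H = torch.cat((torch.cat((H, H), 1), torch.cat((H, -H), 1)), 0)
    changes = (H[:, 1:] != H[:, :-1]).sum(dim=1)
    return H[torch.argsort(changes)]

def walsh_matrix_2d_sketch(n: int) -> torch.Tensor:
    """Hypothetical helper: (n^2, n^2) matrix acting on flattened n x n images."""
    W = walsh_1d(n)
    # Row-major vectorization: vec(W @ X @ W.T) = kron(W, W) @ vec(X).
    return torch.kron(W, W)

n = 4
W2 = walsh_matrix_2d_sketch(n)  # shape (16, 16)
img = torch.randn(n, n)
lhs = W2 @ img.flatten()
rhs = (walsh_1d(n) @ img @ walsh_1d(n).T).flatten()
```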