Source code for spyrit.core.meas

"""
Measurement operators, static and dynamic.

There are six classes contained in this module, each representing a different
type of measurement operator. Three of them are static, i.e. they are used to
simulate measurements of still images, and three are dynamic, i.e. they are used
to simulate measurements of moving objects, represented as a sequence of images.
 The inheritance tree is as follows::

      Linear          DynamicLinear
        |                   |
        V                   V
    LinearSplit     DynamicLinearSplit
        |                   |
        V                   V
    HadamSplit      DynamicHadamSplit

"""

import warnings

# import memory_profiler as mprof

import math
import torch
import torch.nn as nn

from spyrit.core.warp import DeformationField
import spyrit.core.torch as spytorch


# =============================================================================
# BASE CLASS - FOR INHERITANCE ONLY (INTERAL USE)
# =============================================================================
class _Base(nn.Module):

    def __init__(
        self,
        H_static: torch.tensor,
        Ord: torch.tensor = None,
        meas_shape: tuple = None,
    ) -> None:
        super().__init__()

        # store meas_shape and check it is correct
        if meas_shape is None:
            self._meas_shape = (
                int(math.sqrt(H_static.shape[1])),
                int(math.sqrt(H_static.shape[1])),
            )
        else:
            self._meas_shape = meas_shape
        if self._meas_shape[0] * self._meas_shape[1] != H_static.shape[1]:
            raise ValueError(
                f"The number of pixels in the measurement matrix H "
                + f"({H_static.shape[1]}) does not match the measurement shape "
                + f"{self._meas_shape}."
            )
        self._img_shape = self._meas_shape

        if Ord is not None:
            H_static, ind = spytorch.sort_by_significance(
                H_static, Ord, "rows", False, get_indices=True
            )
        else:
            ind = torch.arange(H_static.shape[0])
            Ord = torch.arange(H_static.shape[0], 0, -1)

        # convert H to float32 if it is not float64
        if H_static.dtype != torch.float64:
            H_static = H_static.to(torch.float32)

        # attributes for internal use
        self._param_H_static = nn.Parameter(H_static, requires_grad=False)
        # need to store M because H_static may be cropped (see HadamSplit)
        self._M = H_static.shape[0]

        self._param_Ord = nn.Parameter(Ord.to(torch.float32), requires_grad=False)
        self._indices = ind.to(torch.int32)
        self._device_tracker = nn.Parameter(torch.tensor(0.0), requires_grad=False)

    ### PROPERTIES ------
    @property
    def M(self) -> int:
        """Number of measurements (first dimension of H)"""
        return self._M

    @property
    def N(self) -> int:
        """Number of pixels in the image"""
        return self.img_h * self.img_w

    @property
    def h(self) -> int:
        """Measurement pattern height"""
        return self.meas_shape[0]

    @property
    def w(self) -> int:
        """Measurement pattern width"""
        return self.meas_shape[1]

    @property
    def meas_shape(self) -> tuple:
        """Shape of the measurement patterns (height, width). Note that
        `height * width = N`."""
        return self._meas_shape

    @property
    def img_shape(self) -> tuple:
        """Shape of the image (height, width)."""
        return self._img_shape

    @property
    def img_h(self) -> int:
        """Height of the image"""
        return self._img_shape[0]

    @property
    def img_w(self) -> int:
        """Width of the image"""
        return self._img_shape[1]

    @property
    def indices(self) -> torch.tensor:
        """Indices used to sort the rows of H"""
        return self._indices

    @property
    def Ord(self) -> torch.tensor:
        """Order matrix used to sort the rows of H"""
        return self._param_Ord.data

    @Ord.setter
    def Ord(self, value: torch.tensor) -> None:
        self._set_Ord(value)

    @property
    def H_static(self) -> torch.tensor:
        """Static measurement matrix H."""
        return self._param_H_static.data[: self.M, :]

    @property
    def P(self) -> torch.tensor:
        """Measurement matrix P with positive and negative components. Used in
        classes *Split and *HadamSplit."""
        return self._param_P.data[: 2 * self.M, :]

    @property
    def device(self) -> torch.device:
        return self._device_tracker.device

    ### -------------------

    def pinv(
        self, x: torch.tensor, reg: str = "rcond", eta: float = 1e-3, diff=False
    ) -> torch.tensor:
        r"""Computes the pseudo inverse solution :math:`y = H^\dagger x`.

        This method will compute the pseudo inverse solution using the
        measurement matrix pseudo-inverse :math:`H^\dagger` if it has been
        calculated and stored in the attribute :attr:`H_pinv`. If not, the
        pseudo inverse will be not be explicitly computed and the torch
        function :func:`torch.linalg.lstsq` will be used to solve the linear
        system.

        Args:
            :attr:`x` (torch.tensor): batch of measurement vectors. If x has
            more than 1 dimension, the pseudo inverse is applied to each
            image in the batch.

            :attr:`reg` (str, optional): Regularization method to use.
            Available options are 'rcond', 'L2' and 'H1'. 'rcond' uses the
            :attr:`rcond` parameter found in :func:`torch.linalg.lstsq`.
            This parameter must be specified if the pseudo inverse has not been
            computed. Defaults to None.

            :attr:`eta` (float, optional): Regularization parameter. Only
            relevant when :attr:`reg` is specified. Defaults to None.

            :attr:`diff` (bool, optional): Use only if a split operator is used
            and if the pseudo inverse has not been computed. Whether to use the
            difference of positive and negative patterns.
            The difference is applied to the measurements and to the dynamic
            measurement matrix. Defaults to False.

        Shape:
            :math:`x`: :math:`(*, M)` where * denotes the batch size and `M`
            the number of measurements.

            Output: :math:`(*, N)` where * denotes the batch size and `N`
            the number of pixels in the image.

        Example:
            >>> H = torch.randn([400, 1600])
            >>> meas_op = Linear(H, True)
            >>> x = torch.randn([10, 400])
            >>> y = meas_op.pinv(x)
            >>> print(y.shape)
            torch.Size([10, 1600])
        """
        # have we calculated the pseudo inverse ?
        if hasattr(self, "H_pinv"):
            ans = self._pinv_mult(x)

        else:

            if isinstance(self, Linear):
                H_to_inv = self.H_static
            elif type(self) == DynamicLinear:
                H_to_inv = self.H_dyn
            elif isinstance(self, DynamicLinearSplit):
                if diff:
                    x = x[..., ::2] - x[..., 1::2]
                    H_to_inv = self.H_dyn[::2, :] - self.H_dyn[1::2, :]
                else:
                    H_to_inv = self.H_dyn
            else:
                raise NotImplementedError(
                    "It seems you have instanciated a _Base element. This class "
                    + "Should not be called on its own."
                )

            # cast to dtype of x
            H_to_inv = H_to_inv.to(x.dtype)
            # devices are supposed to be the same, don't bother checking

            driver = "gelsd"
            if H_to_inv.device != torch.device("cpu"):
                H_to_inv = H_to_inv.cpu()
                x = x.cpu()
                # driver = 'gels'

            if reg == "rcond":
                A = H_to_inv.expand(*x.shape[:-1], *H_to_inv.shape)  # shape (*, M, N)
                B = x.unsqueeze(-1).to(A.dtype)  # shape (*, M, 1)
                ans = torch.linalg.lstsq(A, B, rcond=eta, driver=driver)
                ans = ans.solution.to(x.dtype).squeeze(-1)  # shape (*, N)

            elif reg == "L2":
                A = torch.matmul(H_to_inv.mT, H_to_inv) + eta * torch.eye(
                    H_to_inv.shape[1]
                )
                A = A.expand(*x.shape[:-1], *A.shape)
                B = torch.matmul(x.to(H_to_inv.dtype), H_to_inv)
                ans = torch.linalg.solve(A, B).to(x.dtype)

            elif reg == "H1":
                Dx, Dy = spytorch.neumann_boundary(self.img_shape)
                D2 = Dx.T @ Dx + Dy.T @ Dy

                A = torch.matmul(H_to_inv.mT, H_to_inv) + eta * D2
                A = A.expand(*x.shape[:-1], *A.shape)
                B = torch.matmul(x.to(H_to_inv.dtype), H_to_inv)
                ans = torch.linalg.solve(A, B).to(x.dtype)

            elif reg is None:
                raise ValueError(
                    "Regularization method not specified. Please compute "
                    + "the dynamic pseudo-inverse or specify a regularization "
                    + "method."
                )
            else:
                raise NotImplementedError(
                    f"Regularization method ({reg}) not implemented. Please "
                    + "use 'rcond', 'L2' or 'H1'."
                )

        # if we used bicubic b spline, convolve with the kernel
        if hasattr(self, "recon_mode") and self.recon_mode == "bicubic":
            kernel = torch.tensor([[1, 4, 1], [4, 16, 4], [1, 4, 1]]) / 36
            conv = nn.Conv2d(1, 1, kernel_size=3, padding=1, bias=False)
            conv.weight.data = kernel.reshape(1, 1, 3, 3).to(ans.dtype)

            ans = (
                conv(ans.reshape(-1, 1, self.img_h, self.img_w))
                .reshape(-1, self.img_h * self.img_w)
                .data
            )

        return ans.reshape(*ans.shape[:-1], *self.img_shape)

    def reindex(
        self, x: torch.tensor, axis: str = "rows", inverse_permutation: bool = False
    ) -> torch.tensor:
        """Sorts a tensor along a specified axis using the indices tensor. The
        indices tensor is contained in the attribute :attr:`self.indices`.

        The indices tensor contains the new indices of the elements in the values
        tensor. `values[0]` will be placed at the index `indices[0]`, `values[1]`
        at `indices[1]`, and so on.

        Using the inverse permutation allows to revert the permutation: in this
        case, it is the element at index `indices[0]` that will be placed at the
        index `0`, the element at index `indices[1]` that will be placed at the
        index `1`, and so on.

        .. note::
            See :func:`~spyrit.core.torch.reindex()` for more details.

        Args:
            values (torch.tensor): The tensor to sort. Can be 1D, 2D, or any
            multi-dimensional batch of 2D tensors.

            axis (str, optional): The axis to sort along. Must be either 'rows' or
            'cols'. If `values` is 1D, `axis` is not used. Default is 'rows'.

            inverse_permutation (bool, optional): Whether to apply the permutation
            inverse. Default is False.

        Raises:
            ValueError: If `axis` is not 'rows' or 'cols'.

        Returns:
            torch.tensor: The sorted tensor by the given indices along the
            specified axis.
        """
        return spytorch.reindex(x, self.indices.to(x.device), axis, inverse_permutation)

    def unvectorize(self, input: torch.tensor) -> torch.tensor:
        """Reshape a vectorized tensor to the measurement shape (heigth, width).

        Input:
            input (torch.tensor): A tensor of shape (*, N) where * denotes the
            batch size and :math:`N = hw` is the total number of pixels in the
            image.

        Output:
            torch.tensor: A tensor of shape (*, h, w) where * denotes the batch
            size and h, w the height and width of the image.
        """
        return input.reshape(*input.shape[:-1], *self.meas_shape)

    def vectorize(self, input: torch.tensor) -> torch.tensor:
        """Vectorize an image-shaped tensor.

        Input:
            input (torch.tensor): A tensor of shape (*, h, w) where * denotes the
            batch size and h, w the height and width of the image.

        Output:
            torch.tensor: A tensor of shape (*, N) where * denotes the batch size
            and :math:`N = hw` is the total number of pixels in the image.
        """
        return input.reshape(*input.shape[:-2], self.N)

    def _static_forward_with_op(
        self, x: torch.tensor, op: torch.tensor
    ) -> torch.tensor:
        return torch.einsum("mhw,...hw->...m", self.unvectorize(op).to(x.dtype), x)

    # @mprof.profile
    def _dynamic_forward_with_op(
        self, x: torch.tensor, op: torch.tensor
    ) -> torch.tensor:
        x = spytorch.center_crop(x, self.meas_shape)
        return torch.einsum("thw,...tchw->...ct", self.unvectorize(op).to(x.dtype), x)

    def _pinv_mult(self, y: torch.tensor) -> torch.tensor:
        """Uses the pre-calculated pseudo inverse to compute the solution.
        We assume that the pseudo inverse has been calculated and stored in the
        attribute :attr:`H_pinv`.
        """
        A = self.H_pinv.expand(*y.shape[:-1], *self.H_pinv.shape)
        B = y.unsqueeze(-1).to(A.dtype)
        ans = torch.matmul(A, B).to(y.dtype).squeeze(-1)
        return ans

    def _set_Ord(self, Ord: torch.tensor) -> None:
        """Set the order matrix used to sort the rows of H. This is used in
        the Ord.setter property. This method is defined for simplified
        inheritance. For internal use only."""
        # unsort the rows of H
        H_natural = self.reindex(self.H_static, "rows", inverse_permutation=True)
        # resort the rows of H ; store indices in self._indices
        H_resorted, self._indices = spytorch.sort_by_significance(
            H_natural, Ord, "rows", False, get_indices=True
        )
        # update values of H, Ord
        self._param_H_static.data = H_resorted.to(self.device)
        self._param_Ord.data = Ord.to(self.device)

    def _set_P(self, H_static: torch.tensor) -> None:
        """Set the positive and negative components of the measurement matrix
        P from the static measurement matrix H_static. For internal use only.
        Used in classes *Split and *HadamSplit."""
        H_pos = nn.functional.relu(H_static)
        H_neg = nn.functional.relu(-H_static)
        self._param_P = nn.Parameter(
            torch.cat([H_pos, H_neg], 1).reshape(
                2 * H_static.shape[0], H_static.shape[1]
            ),
            requires_grad=False,
        )

    def _build_pinv(self, tensor: torch.tensor, reg: str, eta: float) -> torch.tensor:

        if reg == "rcond":
            pinv = torch.linalg.pinv(tensor, atol=eta)

        elif reg == "L2":
            if tensor.shape[0] >= tensor.shape[1]:
                pinv = (
                    torch.linalg.inv(
                        tensor.T @ tensor + eta * torch.eye(tensor.shape[1])
                    )
                    @ tensor.T
                )
            else:
                pinv = tensor.T @ torch.linalg.inv(
                    tensor @ tensor.T + eta * torch.eye(tensor.shape[0])
                )

        elif reg == "H1":
            # Boundary condition matrices
            Dx, Dy = spytorch.neumann_boundary(self.img_shape)
            D2 = (Dx.T @ Dx + Dy.T @ Dy).to(tensor.device)
            pinv = torch.linalg.inv(tensor.T @ tensor + eta * D2) @ tensor.T

        else:
            raise NotImplementedError(
                f"Regularization method '{reg}' is not implemented. Please "
                + "choose either 'rcond', 'L2' or 'H1'."
            )
        return pinv.to(self.device)

    def _attributeslist(self) -> list:
        _list = [
            ("M", "self.M", _Base),
            ("N", "self.N", _Base),
            ("H.shape", "self.H_static.shape", _Base),
            ("meas_shape", "self._meas_shape", _Base),
            ("H_dyn", "hasattr(self, 'H_dyn')", DynamicLinear),
            ("img_shape", "self.img_shape", DynamicLinear),
            ("H_pinv", "hasattr(self, 'H_pinv')", _Base),
            ("P.shape", "self.P.shape", (LinearSplit, DynamicLinearSplit)),
        ]
        return _list

    def __repr__(self) -> str:
        s_begin = f"{self.__class__.__name__}(\n  "
        s_fill = "\n  ".join(
            [
                f"({k}): {eval(v)}"
                for k, v, t in self._attributeslist()
                if isinstance(self, t)
            ]
        )
        s_end = "\n)"
        return s_begin + s_fill + s_end


# =============================================================================
[docs] class Linear(_Base): # ========================================================================= r""" Simulates linear measurements :math:`y = Hx`. Computes linear measurements from incoming images: :math:`y = Hx`, where :math:`H` is a given linear operator (matrix) and :math:`x` is a vectorized image or batch of images. The class is constructed from a matrix :math:`H` of shape :math:`(M,N)`, where :math:`N` represents the number of pixels in the image and :math:`M` the number of measurements. Args: :attr:`H` (:class:`torch.tensor`): measurement matrix (linear operator) with shape :math:`(M, N)`. Only real values are supported. :attr:`pinv` (bool): Whether to store the pseudo inverse of the measurement matrix :math:`H`. If `True`, the pseudo inverse is initialized as :math:`H^\dagger` and stored in the attribute :attr:`H_pinv`. It is alwats possible to compute and store the pseudo inverse later using the method :meth:`build_H_pinv`. Defaults to `False`. :attr:`rtol` (float, optional): Cutoff for small singular values (see :mod:`torch.linalg.pinv`). Only relevant when :attr:`pinv` is `True`. :attr:`Ord` (torch.tensor, optional): Order matrix used to reorder the rows of the measurement matrix :math:`H`. The first new row of :math:`H` will correspond to the highest value in :math:`Ord`. Must contain :math:`M` values. If some values repeat, the order is kept. Defaults to None. :attr:`meas_shape` (tuple, optional): Shape of the image :math:`x`. Must be a tuple of two integers representing the height and width of the image. If not specified, the image is suppposed to be a square. If not, an error is raised. Defaults to None. Attributes: :attr:`H` (torch.tensor): The learnable measurement matrix of shape :math:`(M, N)` initialized as :math:`H`. :attr:`H_static` (torch.tensor): alias for :attr:`H`. :attr:`H_pinv` (torch.tensor, optional): The learnable pseudo inverse measurement matrix :math:`H^\dagger` of shape :math:`(N, M)`. :attr:`M` (int): Number of measurements performed by the linear operator. :attr:`N` (int): Number of pixels in the image. :attr:`h` (int): Measurement pattern height. :attr:`w` (int): Measurement pattern width. :attr:`meas_shape` (tuple): Shape of the image :math:`x` (height, width). Is equal to `(self.h, self.w)`. :attr:`indices` (torch.tensor): Indices used to sort the rows of H. It is used by the method :meth:`reindex()`. :attr:`Ord` (torch.tensor): Order matrix used to sort the rows of H. It is used by :func:`~spyrit.core.torch.sort_by_significance()`. .. note:: If you know the pseudo inverse of :math:`H` and want to store it, it is best to initialize the class with :attr:`pinv` set to `False` and then call :meth:`build_H_pinv` to store the pseudo inverse. Example 1: >>> H = torch.rand([400, 1600]) >>> meas_op = Linear(H, pinv=False) >>> print(meas_op) Linear( (M): 400 (N): 1600 (H.shape): torch.Size([400, 1600]) (meas_shape): (40, 40) (H_pinv): False ) Example 2: >>> H = torch.rand([400, 1600]) >>> meas_op = Linear(H, True) >>> print(meas_op) Linear( (M): 400 (N): 1600 (H.shape): torch.Size([400, 1600]) (meas_shape): (40, 40) (H_pinv): True ) """ def __init__( self, H: torch.tensor, pinv: bool = False, rtol: float = None, Ord: torch.tensor = None, meas_shape: tuple = None, # (height, width) ): super().__init__(H, Ord, meas_shape) if pinv: self.build_H_pinv(reg="rcond", eta=rtol) @property def H(self) -> torch.tensor: return self.H_static @property def H_pinv(self) -> torch.tensor: return self._param_H_static_pinv.data @H_pinv.setter def H_pinv(self, value: torch.tensor) -> None: self._param_H_static_pinv = nn.Parameter( value.to(torch.float64), requires_grad=False ) @H_pinv.deleter def H_pinv(self) -> None: del self._param_H_static_pinv # Deprecated method - included for backwards compatibility but to remove
[docs] def get_H(self) -> torch.tensor: """Deprecated method. Use the attribute self.H instead.""" warnings.warn( "The method get_H() is deprecated and will be removed in a future " + "version. Please use the attribute self.H instead." ) return self.H
[docs] def build_H_pinv(self, reg: str = "rcond", eta: float = 1e-3) -> None: """Used to set the pseudo inverse of the measurement matrix :math:`H` using `torch.linalg.pinv`. The result is stored in the attribute :attr:`H_pinv`. Args: reg (str, optional): Regularization method to use. Available options are 'rcond', 'L2' and 'H1'. 'rcond' uses the :attr:`rcond` parameter found in :func:`torch.linalg.lstsq`. This parameter must be specified if the pseudo inverse has not been computed. Defaults to None. eta (float, optional): Regularization parameter (cutoff for small singular values, see :mod:`torch.linalg.pinv`). Defaults to None, in which case the default value of :mod:`torch.linalg.pinv` is used. Returns: None. The pseudo inverse is stored in the attribute :attr:`H_pinv`. """ self.H_pinv = self._build_pinv(self.H_static, reg, eta)
[docs] def forward(self, x: torch.tensor) -> torch.tensor: r"""Applies linear transform to incoming images: :math:`y = Hx`. This is equivalent to computing :math:`x \cdot H^T`. The input images must be unvectorized. Args: :math:`x` (torch.tensor): Batch of images of shape :math:`(*, h, w)`. `*` can have any number of dimensions, for instance `(b, c)` where `b` is the batch size and `c` the number of channels. `h` and `w` are the height and width of the images. Shape: :math:`x`: :math:`(*, h, w)` where * denotes the batch size and `N` the total number of pixels in the image. Output: :math:`(*, M)` where * denotes any number of dimensions and `M` the number of measurements. Example: >>> H = torch.randn([400, 1600]) >>> meas_op = Linear(H) >>> x = torch.randn([10, 40, 40]) >>> y = meas_op(x) >>> print(y.shape) torch.Size([10, 400]) """ # left multiplication with transpose is equivalent to right mult # return x @ self.H.T.to(x.dtype).to(x.device) return self._static_forward_with_op(x, self.H)
[docs] def adjoint(self, y: torch.tensor) -> torch.tensor: r"""Applies adjoint transform to incoming measurements :math:`x = H^{T}y` This brings back the measurements in the image domain, but is not equivalent to the inverse of the forward operator. Args: :math:`y` (torch.tensor): batch of measurement vectors of shape :math:`(*, M)` where * denotes any number of dimensions (e.g. `(b,c)` where `b` is the batch size and `c` the number of channels) and `M` the number of measurements. Output: torch.tensor: The adjoint of the input measurements, which are in the image domain. It has shape :math:`(*, h, w)` where * denotes any number of dimensions and `h`, `w` the height and width of the images. Shape: :math:`y`: :math:`(*, M)` Output: :math:`(*, h, w)` Example: >>> H = torch.randn([400, 1600]) >>> meas_op = Linear(H) >>> y = torch.randn([10, 400] >>> x = meas_op.adjoint(y) >>> print(x.shape) torch.Size([10, 40, 40]) """ # return x @ self.H.to(x.dtype).to(x.device) return torch.einsum("mhw,...m->...hw", self.unvectorize(self.H).to(y.dtype), y)
def _set_Ord(self, Ord: torch.tensor) -> None: """Set the order matrix used to sort the rows of H.""" super()._set_Ord(Ord) # delete self.H_pinv (self._param_H_static_pinv) try: del self._param_H_static_pinv warnings.warn( "The pseudo-inverse H_pinv has been deleted. Please call " + "build_H_pinv() to recompute it." ) except AttributeError: pass
# =============================================================================
[docs] class LinearSplit(Linear): # ========================================================================= r""" Simulates splitted measurements :math:`y = \begin{bmatrix}{H_{+}}\\{H_{-}}\end{bmatrix}x`. Computes linear measurements from incoming images: :math:`y = Px`, where :math:`P` is a linear operator (matrix) and :math:`x` is a vectorized image or batch of vectorized images. The matrix :math:`P` contains only positive values and is obtained by splitting a measurement matrix :math:`H` such that :math:`P` has a shape of :math:`(2M, N)` and `P[0::2, :] = H_{+}` and `P[1::2, :] = H_{-}`, where :math:`H_{+} = \max(0,H)` and :math:`H_{-} = \max(0,-H)`. The class is constructed from the :math:`M` by :math:`N` matrix :math:`H`, where :math:`N` represents the number of pixels in the image and :math:`M` the number of measurements. Therefore, the shape of :math:`P` is :math:`(2M, N)`. Args: :attr:`H` (:class:`torch.tensor`): measurement matrix (linear operator) with shape :math:`(M, N)`. Only real values are supported. :attr:`pinv` (bool): Whether to store the pseudo inverse of the measurement matrix :math:`H`. If `True`, the pseudo inverse is initialized as :math:`H^\dagger` and stored in the attribute :attr:`H_pinv`. It is alwats possible to compute and store the pseudo inverse later using the method :meth:`build_H_pinv`. Defaults to `False`. :attr:`rtol` (float, optional): Cutoff for small singular values (see :mod:`torch.linalg.pinv`). Only relevant when :attr:`pinv` is `True`. :attr:`Ord` (torch.tensor, optional): Order matrix used to reorder the rows of the measurement matrix :math:`H`. The first new row of :math:`H` will correspond to the highest value in :math:`Ord`. Must contain :math:`M` values. If some values repeat, the order is kept. Defaults to None. :attr:`meas_shape` (tuple, optional): Shape of the measurement patterns. Must be a tuple of two integers representing the height and width of the patterns. If not specified, the shape is suppposed to be a square image. If not, an error is raised. Defaults to None. Attributes: :attr:`H` (torch.tensor): The learnable measurement matrix of shape :math:`(M, N)` initialized as :math:`H`. :attr:`H_static` (torch.tensor): alias for :attr:`H`. :attr:`P` (torch.tensor): The splitted measurement matrix of shape :math:`(2M, N)`. :attr:`H_pinv` (torch.tensor, optional): The learnable pseudo inverse measurement matrix :math:`H^\dagger` of shape :math:`(N, M)`. :attr:`M` (int): Number of measurements performed by the linear operator. :attr:`N` (int): Number of pixels in the image. :attr:`h` (int): Measurement pattern height. :attr:`w` (int): Measurement pattern width. :attr:`meas_shape` (tuple): Shape of the measurement patterns (height, width). Is equal to `(self.h, self.w)`. :attr:`indices` (torch.tensor): Indices used to sort the rows of H. It is used by the method :meth:`reindex()`. :attr:`Ord` (torch.tensor): Order matrix used to sort the rows of H. It is used by :func:`~spyrit.core.torch.sort_by_significance()`. .. note:: If you know the pseudo inverse of :math:`H` and want to store it, it is best to initialize the class with :attr:`pinv` set to `False` and then call :meth:`build_H_pinv` to store the pseudo inverse. .. note:: :math:`H = H_{+} - H_{-}` Example: >>> H = torch.randn(400, 1600) >>> meas_op = LinearSplit(H, False) >>> print(meas_op) LinearSplit( (M): 400 (N): 1600 (H.shape): torch.Size([400, 1600]) (meas_shape): (40, 40) (H_pinv): False (P.shape): torch.Size([800, 1600]) ) """ def __init__( self, H: torch.tensor, pinv: bool = False, rtol: float = None, Ord: torch.tensor = None, meas_shape: tuple = None, # (height, width) ): super().__init__(H, pinv, rtol, Ord, meas_shape) self._set_P(self.H_static)
[docs] def forward(self, x: torch.tensor) -> torch.tensor: r"""Applies linear transform to incoming images: :math:`y = Px`. This is equivalent to computing :math:`x \cdot P^T`. The input images must be unvectorized. The matrix :math:`P` is obtained by splitting the measurement matrix :math:`H` such that :math:`P` has a shape of :math:`(2M, N)` and `P[0::2, :] = H_{+}` and `P[1::2, :] = H_{-}`, where :math:`H_{+} = \max(0,H)` and :math:`H_{-} = \max(0,-H)`. .. warning:: This method uses the splitted measurement matrix :math:`P` to compute the linear measurements from incoming images. If you want to apply the operator :math:`H` directly, use the method :meth:`forward_H`. Args: :math:`x` (torch.tensor): Batch of images of shape :math:`(*, h, w)`. `*` can have any number of dimensions, for instance `(b, c)` where `b` is the batch size and `c` the number of channels. `h` and `w` are the height and width of the images. Output: torch.tensor: The linear measurements of the input images. It has shape :math:`(*, 2M)` where * denotes any number of dimensions and `M` the number of measurements as defined by the parameter :attr:`M`, which is equal to the number of rows in the measurement matrix :math:`H` defined at initialization. Shape: :math:`x`: :math:`(*, N)` where * denotes the batch size and `N` the total number of pixels in the image. Output: :math:`(*, 2M)` where * denotes the batch size and `M` the number of measurements as defined by the parameter :attr:`M`, which is equal to the number of rows in the measurement matrix :math:`H` defined at initialization. Example: >>> H = torch.randn(400, 1600) >>> meas_op = LinearSplit(H) >>> x = torch.randn(10, 40, 40) >>> y = meas_op(x) >>> print(y.shape) torch.Size([10, 800]) """ # return x @ self.P.T.to(x.dtype) return self._static_forward_with_op(x, self.P)
[docs] def forward_H(self, x: torch.tensor) -> torch.tensor: r"""Applies linear transform to incoming images: :math:`m = Hx`. This is equivalent to computing :math:`x \cdot H^T`. The input images must be unvectorized. .. warning:: This method uses the measurement matrix :math:`H` to compute the linear measurements from incoming images. If you want to apply the splitted operator :math:`P`, use the method :meth:`forward`. Args: :attr:`x` (torch.tensor): Batch of images of shape :math:`(*, h, w)`. `*` can have any number of dimensions, for instance `(b, c)` where `b` is the batch size and `c` the number of channels. `h` and `w` are the height and width of the images. Output: torch.tensor: The linear measurements of the input images. It has shape :math:`(*, M)` where * denotes any number of dimensions and `M` the number of measurements. Shape: :attr:`x`: :math:`(*, h, w)` where * denotes the batch size and `h` and `w` are the height and width of the images. Output: :math:`(*, M)` where * denotes the batch size and `M` the number of measurements. Example: >>> H = torch.randn(400, 1600) >>> meas_op = LinearSplit(H) >>> x = torch.randn(10, 40, 40) >>> y = meas_op.forward_H(x) >>> print(y.shape) torch.Size([10, 400]) """ # call Linear.forward() method return super().forward(x)
def _set_Ord(self, Ord): super()._set_Ord(Ord) self._set_P(self.H_static)
# =============================================================================
[docs] class HadamSplit(LinearSplit): # ========================================================================= r""" Simulates splitted measurements :math:`y = \begin{bmatrix}{H_{+}}\\{H_{-}}\end{bmatrix}x` with :math:`H` a Hadamard matrix. Computes linear measurements from incoming images: :math:`y = Px`, where :math:`P` is a linear operator (matrix) and :math:`x` is a vectorized image or batch of vectorized images. The matrix :math:`P` contains only positive values and is obtained by splitting a Hadamard-based matrix :math:`H` such that :math:`P` has a shape of :math:`(2M, N)` and `P[0::2, :] = H_{+}` and `P[1::2, :] = H_{-}`, where :math:`H_{+} = \max(0,H)` and :math:`H_{-} = \max(0,-H)`. :math:`H` is obtained by selecting a re-ordered subsample of :math:`M` rows of a "full" Hadamard matrix :math:`F` with shape :math:`(N^2, N^2)`. :math:`N` must be a power of 2. Args: :attr:`M` (int): Number of measurements. It determines the size of the Hadamard matrix subsample :math:`H`. :attr:`h` (int): Measurement pattern height. The width is taken to be equal to the height, so the measurement pattern is square. The Hadamard matrix will have shape :math:`(h^2, h^2)`. :attr:`Ord` (torch.tensor, optional): Order matrix used to reorder the rows of the measurement matrix :math:`H`. The first new row of :math:`H` will correspond to the highest value in :math:`Ord`. Must contain :math:`M` values. If some values repeat, the order is kept. Defaults to None. Attributes: :attr:`H` (torch.tensor): The learnable measurement matrix of shape :math:`(M, N)`. :attr:`H_static` (torch.tensor): alias for :attr:`H`. :attr:`P` (torch.tensor): The splitted measurement matrix of shape :math:`(2M, N)`. :attr:`H_pinv` (torch.tensor, optional): The learnable pseudo inverse measurement matrix :math:`H^\dagger` of shape :math:`(N, M)`. :attr:`M` (int): Number of measurements performed by the linear operator. Is equal to the parameter :attr:`M`. :attr:`N` (int): Number of pixels in the image, is equal to :math:`h^2`. :attr:`h` (int): Measurement pattern height. :attr:`w` (int): Measurement pattern width. Is equal to :math:`h`. :attr:`meas_shape` (tuple): Shape of the measurement patterns (height, width). Is equal to `(self.h, self.h)`. :attr:`indices` (torch.tensor): Indices used to sort the rows of H. It is used by the method :meth:`reindex()`. :attr:`Ord` (torch.tensor): Order matrix used to sort the rows of H. It is used by :func:`~spyrit.core.torch.sort_by_significance()`. .. note:: The computation of a Hadamard transform :math:`Fx` benefits a fast algorithm, as well as the computation of inverse Hadamard transforms. .. note:: :math:`H = H_{+} - H_{-}` Example: >>> h = 32 >>> Ord = torch.randn(h, h) >>> meas_op = HadamSplit(400, h, Ord) >>> print(meas_op) HadamSplit( (M): 400 (N): 1024 (H.shape): torch.Size([400, 1024]) (meas_shape): (32, 32) (H_pinv): True (P.shape): torch.Size([800, 1024]) ) """ def __init__( self, M: int, h: int, Ord: torch.tensor = None, ): F = spytorch.walsh2_matrix(h) # we pass the whole F matrix to the constructor, but override the # calls self.H etc to only return the first M rows super().__init__(F, pinv=False, Ord=Ord, meas_shape=(h, h)) self._M = M @property def H_pinv(self) -> torch.tensor: return self._param_H_static_pinv.data / self.N @H_pinv.setter def H_pinv(self, value: torch.tensor) -> None: self._param_H_static_pinv = nn.Parameter( value.to(torch.float64), requires_grad=False ) @H_pinv.deleter def H_pinv(self) -> None: del self._param_H_static_pinv
[docs] def forward_H(self, x: torch.tensor) -> torch.tensor: r"""Optimized measurement simulation using the Fast Hadamard Transform. The 2D fast Walsh-ordered Walsh-Hadamard transform is applied to the incoming images :math:`x`. This is equivalent to computing :math:`x \cdot H^T`. Args: :math:`x` (torch.tensor): Batch of images of shape :math:`(*,h,w)`. `*` denotes any dimension, for instance `(b,c)` where `b` is the batch size and `c` the number of channels. `h` and `w` are the height and width of the images. Output: torch.tensor: The linear measurements of the input images. It has shape :math:`(*,M)` where * denotes any number of dimensions and `M` the number of measurements. Shape: :math:`x`: :math:`(*,h,w)` where * denotes any dimension, for instance `(b,c)` where `b` is the batch size and `c` the number of channels. `h` and `w` are the height and width of the images. Output: :math:`(*,M)` where * denotes denotes any number of dimensions and `M` the number of measurements. """ m = spytorch.fwht_2d(x) m_flat = self.vectorize(m) return self.reindex(m_flat, "cols", True)[..., : self.M]
[docs] def build_H_pinv(self): """Build the pseudo-inverse (inverse) of the Hadamard matrix H. This computes the pseudo-inverse of the Hadamard matrix H, and stores it in the attribute H_pinv. In the case of an invertible matrix, the pseudo-inverse is the inverse. Args: None. Returns: None. The pseudo-inverse is stored in the attribute H_pinv. """ # the division by self.N is done in the property so as to avoid # memory overconsumption self.H_pinv = self.H.T
[docs] def pinv(self, x, reg="rcond", eta=0.001, diff=False): return super().pinv(x, reg, eta, diff)
[docs] def inverse(self, y: torch.tensor) -> torch.tensor: r"""Inverse transform of Hadamard-domain images. It can be described as :math:`x = H_{had}^{-1}G y`, where :math:`y` is the input Hadamard-domain measurements, :math:`H_{had}^{-1}` is the inverse Hadamard transform, and :math:`G` is the reordering matrix. .. note:: For this inverse to work, the input vector must have the same number of measurements as there are pixels in the original image (:math:`M = N`), i.e. no subsampling is allowed. .. warning:: This method is deprecated and will be removed in a future version. Use self.pinv instead. Args: :math:`y`: batch of images in the Hadamard domain of shape :math:`(*,c,M)`. `*` denotes any size, `c` the number of channels, and `M` the number of measurements (with `M = N`). Output: :math:`x`: batch of images of shape :math:`(*,c,h,w)`. `*` denotes any size, `c` the number of channels, and `h`, `w` the height and width of the image (with `h \times w = N = M`). Shape: :math:`y`: :math:`(*, c, M)` with :math:`*` any size, :math:`c` the number of channels, and :math:`N` the number of measurements (with `M = N`). Output: math:`(*, c, h, w)` with :math:`h` and :math:`w` the height and width of the image. Example: >>> h = 32 >>> Ord = torch.randn(h, h) >>> meas_op = HadamSplit(400, h, Ord) >>> y = torch.randn(10, h**2) >>> x = meas_op.inverse(y) >>> print(x.shape) torch.Size([10, 32, 32]) """ # permutations y = self.reindex(y, "cols", False) y = self.unvectorize(y) # inverse of full transform x = 1 / self.N * spytorch.fwht_2d(y, True) return x
def _pinv_mult(self, y): """We use fast walsh-hadamard transform to compute the pseudo inverse. Args: y (torch.tensor): batch of images in the Hadamard domain of shape (*,M). * denotes any size, and M the number of measurements. Returns: torch.tensor: batch of images in the image domain of shape (*,N). """ # zero-pad the measurements until size N y_shape = y.shape y_new_shape = y_shape[:-1] + (self.N,) y_new = torch.zeros(y_new_shape, device=y.device, dtype=y.dtype) y_new[..., : y_shape[-1]] = y # unsort the measurements y_new = self.reindex(y_new, "cols", False) y_new = self.unvectorize(y_new) # inverse of full transform return 1 / self.N * spytorch.fwht_2d(y_new, True) def _set_Ord(self, Ord: torch.tensor) -> None: """Set the order matrix used to sort the rows of H.""" # get only the indices, as done in spyrit.core.torch.sort_by_significance self._indices = torch.argsort(-Ord.flatten(), stable=True).to(torch.int32) # update the Ord attribute self._param_Ord.data = Ord.to(self.device)
# =============================================================================
[docs] class DynamicLinear(_Base): # ========================================================================= r""" Simulates the measurement of a moving object :math:`y = H \cdot x(t)`. Computes linear measurements :math:`y` from incoming images: :math:`y = Hx`, where :math:`H` is a linear operator (matrix) and :math:`x` is a batch of vectorized images representing a motion picture. The class is constructed from a matrix :math:`H` of shape :math:`(M, N)`, where :math:`N` represents the number of pixels in the image and :math:`M` the number of measurements and the number of frames in the animated object. .. warning:: For each call, there must be **exactly** as many images in :math:`x` as there are measurements in the linear operator used to initialize the class. Args: :attr:`H` (torch.tensor): measurement matrix (linear operator) with shape :math:`(M, N)`. :attr:`Ord` (torch.tensor, optional): Order matrix used to reorder the rows of the measurement matrix :math:`H`. The first new row of :math:`H` will correspond to the highest value in :math:`Ord`. Must contain :math:`M` values. If some values repeat, the order is kept. Defaults to None. :attr:`meas_shape` (tuple, optional): Shape of the measurement patterns. Must be a tuple of two integers representing the height and width of the patterns. If not specified, the shape is suppposed to be a square image. If not, an error is raised. Defaults to None. :attr:`img_shape` (tuple, optional): Shape of the image. Must be a tuple of two integers representing the height and width of the image. If not specified, the shape is taken as equal to `meas_shape`. Setting this value is particularly useful when using an :ref:`extended field of view <_MICCAI24>`. Attributes: :attr:`H_static` (torch.nn.Parameter): The learnable measurement matrix of shape :math:`(M,N)` initialized as :math:`H`. Only real values are supported. :attr:`M` (int): Number of measurements performed by the linear operator. :attr:`N` (int): Number of pixels in the image. :attr:`h` (int): Measurement pattern height. :attr:`w` (int): Measurement pattern width. :attr:`meas_shape` (tuple): Shape of the measurement patterns (height, width). Is equal to `(self.h, self.w)`. :attr:`img_h` (int): Image height. :attr:`img_w` (int): Image width. :attr:`img_shape` (tuple): Shape of the image (height, width). Is equal to `(self.img_h, self.img_w)`. :attr:`H_dyn` (torch.tensor): Dynamic measurement matrix :math:`H`. Must be set using the method :meth:`build_H_dyn` before being accessed. :attr:`H` (torch.tensor): Alias for :attr:`H_dyn`. :attr:`H_dyn_pinv` (torch.tensor): Dynamic pseudo-inverse measurement matrix :math:`H_{dyn}^\dagger`. Must be set using the method :meth:`build_H_dyn_pinv` before being accessed. :attr:`H_pinv` (torch.tensor): Alias for :attr:`H_dyn_pinv`. .. warning:: The attributes :attr:`H` and :attr:`H_pinv` are used as aliases for :attr:`H_dyn` and :attr:`H_dyn_pinv`. If you want to access the static versions of the attributes, be sure to include the suffix `_static`. Example: >>> H_static = torch.rand([400, 1600]) >>> meas_op = DynamicLinear(H_static) >>> print(meas_op) DynamicLinear( (M): 400 (N): 1600 (H.shape): torch.Size([400, 1600]) (meas_shape): (40, 40) (H_dyn): False (img_shape): (40, 40) (H_pinv): False ) Reference: .. _MICCAI24: [MaBP24] (MICCAI 2024 paper #883) Thomas Maitre, Elie Bretin, Romain Phan, Nicolas Ducros, Michaël Sdika. Dynamic Single-Pixel Imaging on an Extended Field of View without Warping the Patterns. 2024. hal-04533981 """ # Class variable _measurement_mode = "static" def __init__( self, H: torch.tensor, Ord: torch.tensor = None, meas_shape: tuple = None, # (height, width) img_shape: tuple = None, # (height, width) ): super().__init__(H, Ord, meas_shape) if img_shape is not None: self._img_shape = img_shape if img_shape[0] < self.meas_shape[0] or img_shape[1] < self.meas_shape[1]: raise ValueError( "The image shape must be at least as large as the measurement " + f"shape. Got image shape {img_shape} and measurement shape " + f"{self.meas_shape}." ) # else, it is done in the _Base class __init__ (set to meas_shape) @property def H(self) -> torch.tensor: """Dynamic measurement matrix H. Equal to self.H_dyn.""" return self.H_dyn @property def H_dyn(self) -> torch.tensor: """Dynamic measurement matrix H.""" try: return self._param_H_dyn.data except AttributeError as e: raise AttributeError( "The dynamic measurement matrix H has not been set yet. " + "Please call build_H_dyn() before accessing the attribute " + "H_dyn (or H)." ) from e @H_dyn.setter def H_dyn(self, value: torch.tensor) -> None: self._param_H_dyn = nn.Parameter(value.to(torch.float64), requires_grad=False) try: del H_pinv except UnboundLocalError as e: if "H_pinv" in str(e): pass @property def recon_mode(self) -> str: """Interpolation mode used for reconstruction.""" return self._recon_mode @property def H_pinv(self) -> torch.tensor: """Dynamic pseudo-inverse H_pinv. Equal to self.H_dyn_pinv.""" return self.H_dyn_pinv @H_pinv.deleter def H_pinv(self) -> None: del self.H_dyn_pinv @property def H_dyn_pinv(self) -> torch.tensor: """Dynamic pseudo-inverse H_pinv.""" try: return self._param_H_dyn_pinv.data except AttributeError as e: raise AttributeError( "The dynamic pseudo-inverse H_pinv has not been set yet. " + "Please call build_H_dyn_pinv() before accessing the attribute " + "H_dyn_pinv (or H_pinv)." ) from e @H_dyn_pinv.setter def H_dyn_pinv(self, value: torch.tensor) -> None: self._param_H_dyn_pinv = nn.Parameter( value.to(torch.float64), requires_grad=False ) @H_dyn_pinv.deleter def H_dyn_pinv(self) -> None: try: del self._param_H_dyn_pinv except UnboundLocalError: pass # @mprof.profile
[docs] def build_H_dyn(self, motion: DeformationField, mode: str = "bilinear") -> None: """Build the dynamic measurement matrix `H_dyn`. Compute and store the dynamic measurement matrix `H_dyn` from the static measurement matrix `H_static` and the deformation field `motion`. The output is stored in the attribute `self.H_dyn`. This is done using the physical version explained in [MaBP24]_. Args: :attr:`motion` (DeformationField): Deformation field representing the motion of the image. :attr:`mode` (str): Interpolation mode. Can only be 'bilinear' for now. Bicubic interpolation will be available in a future release. Defaults to 'bilinear'. Returns: None. The dynamic measurement matrix is stored in the attribute `self.H_dyn`. References: .. _MaBP24: [MaBP24] (MICCAI 2024 paper #883) Thomas Maitre, Elie Bretin, Romain Phan, Nicolas Ducros, Michaël Sdika. Dynamic Single-Pixel Imaging on an Extended Field of View without Warping the Patterns. 2024. hal-04533981 """ if self.device != motion.device: raise RuntimeError( "The device of the motion and the measurement operator must be the same." ) # store the mode in attribute self._recon_mode = mode try: del self._param_H_dyn del self._param_H_dyn_pinv warnings.warn( "The dynamic measurement matrix pseudo-inverse H_pinv has " + "been deleted. Please call self.build_H_dyn_pinv() to " + "recompute it.", UserWarning, ) except AttributeError: pass n_frames = motion.n_frames # get deformation field from motion # scale from [-1;1] x [-1;1] to [0;width-1] x [0;height-1] scale_factor = (torch.tensor(self.img_shape) - 1).to(self.device) def_field = (motion.field + 1) / 2 * scale_factor # drawings of the kernels for bilinear and bicubic 'interpolation' # 00 point 01 # +------+--------+ # | | | # | | | # +------+--------+ point # | | | # +------+--------+ # 10 11 # 00 01 point 02 03 # +-----------+-----+-----+-----------+ # | | | | # | | | | | # | 11 | | 12 | # 10 +-----------+-----+-----+-----------+ 13 # | | | | | # + - - - - - + - - + - - + - - - - - + point # | | | | | # 20 +-----------+-----+-----+-----------+ 23 # | 21 | | | 22 | # | | | | # | | | | | # +-----------+-----+-----+-----------+ # 30 31 32 33 kernel_size = self._spline(torch.tensor([0]), mode).shape[1] kernel_width = kernel_size - 1 kernel_n_pts = kernel_size**2 # PART 1: SEPARATE THE INTEGER AND DECIMAL PARTS OF THE FIELD # _________________________________________________________________ # crop def_field to keep only measured area # moveaxis because crop expects (h,w) as last dimensions def_field = spytorch.center_crop( def_field.moveaxis(-1, 0), self.meas_shape ).moveaxis( 0, -1 ) # shape (n_frames, meas_h, meas_w, 2) # coordinate of top-left closest corner def_field_floor = def_field.floor().to(torch.int64) # shape (n_frames, meas_h, meas_w, 2) # compute decimal part in x y direction dx, dy = torch.split((def_field - def_field_floor), [1, 1], dim=-1) del def_field dx, dy = dx.squeeze(-1), dy.squeeze(-1) # dx.shape = dy.shape = (n_frames, meas_h, meas_w) # evaluate the spline at the decimal part dxy = torch.einsum( "iajk,ibjk->iabjk", self._spline(dy, mode), self._spline(dx, mode) ).reshape(n_frames, kernel_n_pts, self.h * self.w) # shape (n_frames, kernel_n_pts, meas_h*meas_w) del dx, dy # PART 2: FLATTEN THE INDICES # _________________________________________________________________ # we consider an expanded grid (img_h+k)x(img_w+k), where k is # (kernel_width). This allows each part of the (kernel_size^2)- # point grid to contribute to the interpolation. # get coordinate of point _00 def_field_00 = def_field_floor - (kernel_size // 2 - 1) del def_field_floor # shift the grid for phantom rows/columns def_field_00 += kernel_width # create a mask indicating if either of the 2 indices is out of bounds # (w,h) because the def_field is in (x,y) coordinates maxs = torch.tensor( [self.img_w + kernel_width, self.img_h + kernel_width], device=self.device ) mask = torch.logical_or( (def_field_00 < 0).any(dim=-1), (def_field_00 >= maxs).any(dim=-1) ) # shape (n_frames, meas_h, meas_w) # trash index receives all the out-of-bounds indices trash = (maxs[0] * maxs[1]).to(torch.int64).to(self.device) # if the indices are out of bounds, we put the trash index # otherwise we put the flattened index (y*w + x) flattened_indices = torch.where( mask, trash, def_field_00[..., 0] + def_field_00[..., 1] * (self.img_w + kernel_width), ).reshape(n_frames, self.h * self.w) del def_field_00, mask # PART 3: WARP H MATRIX WITH FLATTENED INDICES # _________________________________________________________________ # Build 4 submatrices with 4 weights for bilinear interpolation if isinstance(self, DynamicLinearSplit): meas_pattern = self.P else: meas_pattern = self.H_static meas_dxy = ( meas_pattern.reshape(n_frames, 1, self.h * self.w).to(dxy.dtype) * dxy ) del dxy, meas_pattern # shape (n_frames, kernel_size^2, meas_h*meas_w) # Create a larger H_dyn that will be folded meas_dxy_sorted = torch.zeros( ( n_frames, kernel_n_pts, (self.img_h + kernel_width) * (self.img_w + kernel_width) + 1, # +1 for trash ), dtype=meas_dxy.dtype, device=self.device, ) # add at flattened_indices the values of meas_dxy (~warping) meas_dxy_sorted.scatter_add_( 2, flattened_indices.unsqueeze(1).expand_as(meas_dxy), meas_dxy ) del flattened_indices, meas_dxy # drop last column (trash) meas_dxy_sorted = meas_dxy_sorted[:, :, :-1] self.meas_dxy_sorted = meas_dxy_sorted # PART 4: FOLD THE MATRIX # _________________________________________________________________ # define operator fold = nn.Fold( output_size=(self.img_h, self.img_w), kernel_size=(kernel_size, kernel_size), padding=kernel_width, ) H_dyn = fold(meas_dxy_sorted).reshape(n_frames, self.img_h * self.img_w) # store in _param_H_dyn self._param_H_dyn = nn.Parameter(H_dyn, requires_grad=False).to(self.device)
[docs] def build_H_dyn_pinv(self, reg: str = "rcond", eta: float = 1e-3) -> None: """Computes the pseudo-inverse of the dynamic measurement matrix `H_dyn` and stores it in the attribute `H_dyn_pinv`. This method supposes that the dynamic measurement matrix `H_dyn` has already been set using the method `build_H_dyn()`. An error will be raised if `H_dyn` has not been set yet. Args: :attr:`reg` (str): Regularization method. Can be either 'rcond', 'L2' or 'H1'. Defaults to 'rcond'. :attr:`eta` (float): Regularization parameter. Defaults to 1e-6. Raises: AttributeError: If the dynamic measurement matrix `H_dyn` has not been set yet. """ # later do with regularization parameter try: H_dyn = self.H_dyn.to(torch.float64) except AttributeError as e: raise AttributeError( "The dynamic measurement matrix H has not been set yet. " + "Please call build_H_dyn() before computing the pseudo-inverse." ) from e self.H_dyn_pinv = self._build_pinv(H_dyn, reg, eta)
[docs] def forward(self, x: torch.tensor) -> torch.tensor: r""" Simulates the measurement of a motion picture :math:`y = H \cdot x(t)`. The output :math:`y` is computed as :math:`y = Hx`, where :math:`H` is the measurement matrix and :math:`x` is a batch of images. Args: :math:`x`: Batch of images of shape :math:`(*, t, c, h, w)`. `*` denotes any dimension (e.g. the batch size), `t` the number of frames, `c` the number of channels, and `h`, `w` the height and width of the images. Output: :math:`y`: Linear measurements of the input images. It has shape :math:`(*, c, M)` where * denotes any number of dimensions, `c` the number of channels, and `M` the number of measurements. .. warning:: There must be **exactly** as many images as there are measurements in the linear operator used to initialize the class, i.e. `t = M`. Shape: :math:`x`: :math:`(*, t, c, h, w)`, where * denotes the batch size, `t` the number of frames, `c` the number of channels, and `h`, `w` the height and width of the images. :math:`output`: :math:`(*, c, M)`, where * denotes the batch size, `c` the number of channels, and `M` the number of measurements. Example: >>> x = torch.rand([10, 400, 3, 40, 40]) >>> H = torch.rand([400, 1600]) >>> meas_op = DynamicLinear(H) >>> y = meas_op(x) >>> print(y.shape) torch.Size([10, 3, 400]) """ return self._dynamic_forward_with_op(x, self.H_static)
[docs] def forward_H_dyn(self, x: torch.tensor) -> torch.tensor: """Simulates the acquisition of measurements using the dynamic measurement matrix H_dyn. This supposes the dynamic measurement matrix H_dyn has been set using the method build_H_dyn(). An error will be raised if H_dyn has not been set yet. Args: x (torch.tensor): still image of shape (*, h, w). * denotes any dimension. h and w are the height and width of the image. If h and w are larger than the measurement pattern, the image is center-cropped to the measurement pattern size. Returns: torch.tensor: Measurement of the input image. It has shape (*, M). """ x = spytorch.center_crop(x, self.meas_shape) return self._static_forward_with_op(x, self.H_dyn)
def _set_Ord(self, Ord: torch.tensor) -> None: """Set the order matrix used to sort the rows of H.""" super()._set_Ord(Ord) # delete self.H (self._param_H_dyn) try: del self._param_H_dyn warnings.warn( "The dynamic measurement matrix H has been deleted. " + "Please call build_H_dyn() to recompute it." ) except AttributeError: pass # delete self.H_pinv (self._param_H_dyn_pinv) try: del self._param_H_dyn_pinv warnings.warn( "The dynamic pseudo-inverse H_pinv has been deleted. " + "Please call build_H_dyn_pinv() to recompute it." ) except AttributeError: pass def _spline(self, dx, mode): """ Returns a 2D row-like tensor containing the values of dx evaluated at each B-spline (2 values for bilinear, 4 for bicubic). dx must be between 0 and 1. Shapes dx: (n_frames, self.h, self.w) out: (n_frames, {2,4}, self.h, self.w) """ if mode == "bilinear": ans = torch.stack((1 - dx, dx), dim=1) elif mode == "bicubic": ans = torch.stack( ( (1 - dx) ** 3 / 6, 2 / 3 - dx**2 * (2 - dx) / 2, 2 / 3 - (1 - dx) ** 2 * (1 + dx) / 2, dx**3 / 6, ), dim=1, ) elif mode == "schaum": ans = torch.stack( ( dx / 6 * (dx - 1) * (2 - dx), (1 - dx / 2) * (1 - dx**2), (1 + (dx - 1) / 2) * (1 - (dx - 1) ** 2), 1 / 6 * (dx + 1) * dx * (dx - 1), ), dim=1, ) else: raise NotImplementedError( f"The mode {mode} is invalid, please choose bilinear, " + "bicubic or schaum." ) return ans.to(self.device)
# =============================================================================
[docs] class DynamicLinearSplit(DynamicLinear): # ========================================================================= r""" Simulates the measurement of a moving object using a splitted operator :math:`y = \begin{bmatrix}{H_{+}}\\{H_{-}}\end{bmatrix} \cdot x(t)`. Computes linear measurements :math:`y` from incoming images: :math:`y = Px`, where :math:`P` is a linear operator (matrix) and :math:`x` is a batch of vectorized images representing a motion picture. The matrix :math:`P` contains only positive values and is obtained by splitting a measurement matrix :math:`H` such that :math:`P` has a shape of :math:`(2M, N)` and `P[0::2, :] = H_{+}` and `P[1::2, :] = H_{-}`, where :math:`H_{+} = \max(0,H)` and :math:`H_{-} = \max(0,-H)`. The class is constructed from the :math:`M` by :math:`N` matrix :math:`H`, where :math:`N` represents the number of pixels in the image and :math:`M` the number of measurements. Therefore, the shape of :math:`P` is :math:`(2M, N)`. Args: :attr:`H` (torch.tensor): measurement matrix (linear operator) with shape :math:`(M, N)` where :math:`M` is the number of measurements and :math:`N` the number of pixels in the image. Only real values are supported. :attr:`Ord` (torch.tensor, optional): Order matrix used to reorder the rows of the measurement matrix :math:`H`. The first new row of :math:`H` will correspond to the highest value in :math:`Ord`. Must contain :math:`M` values. If some values repeat, the order is kept. Defaults to None. :attr:`meas_shape` (tuple, optional): Shape of the measurement patterns. Must be a tuple of two integers representing the height and width of the patterns. If not specified, the shape is suppposed to be a square image. If not, an error is raised. Defaults to None. :attr:`img_shape` (tuple, optional): Shape of the image. Must be a tuple of two integers representing the height and width of the image. If not specified, the shape is taken as equal to `meas_shape`. Setting this value is particularly useful when using an :ref:`extended field of view <_MICCAI24>`. Attributes: :attr:`H_static` (torch.nn.Parameter): The learnable measurement matrix of shape :math:`(M,N)` initialized as :math:`H`. :attr:`P` (torch.nn.Parameter): The splitted measurement matrix of shape :math:`(2M, N)` such that `P[0::2, :] = H_{+}` and `P[1::2, :] = H_{-}`. :attr:`M` (int): Number of measurements performed by the linear operator. :attr:`N` (int): Number of pixels in the image. :attr:`h` (int): Measurement pattern height. :attr:`w` (int): Measurement pattern width. :attr:`meas_shape` (tuple): Shape of the measurement patterns (height, width). Is equal to `(self.h, self.w)`. :attr:`img_h` (int): Image height. :attr:`img_w` (int): Image width. :attr:`img_shape` (tuple): Shape of the image (height, width). Is equal to `(self.img_h, self.img_w)`. :attr:`H_dyn` (torch.tensor): Dynamic measurement matrix :math:`H`. Must be set using the method :meth:`build_H_dyn` before being accessed. :attr:`H` (torch.tensor): Alias for :attr:`H_dyn`. :attr:`H_dyn_pinv` (torch.tensor): Dynamic pseudo-inverse measurement matrix :math:`H_{dyn}^\dagger`. Must be set using the method :meth:`build_H_dyn_pinv` before being accessed. :attr:`H_pinv` (torch.tensor): Alias for :attr:`H_dyn_pinv`. .. warning:: For each call, there must be **exactly** as many images in :math:`x` as there are measurements in the linear operator :math:`P`. Example: >>> H = torch.rand([400,1600]) >>> meas_op = DynamicLinearSplit(H) >>> print(meas_op) DynamicLinearSplit( (M): 400 (N): 1600 (H.shape): torch.Size([400, 1600]) (meas_shape): (40, 40) (H_dyn): False (img_shape): (40, 40) (H_pinv): False (P.shape): torch.Size([800, 1600]) ) Reference: .. _MICCAI24: [MaBP24] (MICCAI 2024 paper #883) Thomas Maitre, Elie Bretin, Romain Phan, Nicolas Ducros, Michaël Sdika. Dynamic Single-Pixel Imaging on an Extended Field of View without Warping the Patterns. 2024. hal-04533981 """ def __init__( self, H: torch.tensor, Ord: torch.tensor = None, meas_shape: tuple = None, # (height, width) img_shape: tuple = None, # (height, width) ): # call constructor of DynamicLinear super().__init__(H, Ord, meas_shape, img_shape) self._set_P(self.H_static) @property # override _Base definition def operator(self) -> torch.tensor: return self.P
[docs] def forward(self, x: torch.tensor) -> torch.tensor: r""" Simulates the measurement of a motion picture :math:`y = P \cdot x(t)`. The output :math:`y` is computed as :math:`y = Px`, where :math:`P` is the measurement matrix and :math:`x` is a batch of images. The matrix :math:`P` contains only positive values and is obtained by splitting a measurement matrix :math:`H` such that :math:`P` has a shape of :math:`(2M, N)` and `P[0::2, :] = H_{+}` and `P[1::2, :] = H_{-}`, where :math:`H_{+} = \max(0,H)` and :math:`H_{-} = \max(0,-H)`. If you want to measure with the original matrix :math:`H`, use the method :meth:`forward_H`. Args: :attr:`x`: Batch of images of shape :math:`(*, t, c, h, w)` where * denotes any dimension (e.g. the batch size), :math:`t` the number of frames, :math:`c` the number of channels, and :math:`h`, :math:`w` the height and width of the images. Output: :math:`y`: Linear measurements of the input images. It has shape :math:`(*, c, 2M)` where * denotes any number of dimensions, :math:`c` the number of channels, and :math:`M` the number of measurements. .. important:: There must be as many images as there are measurements in the split linear operator, i.e. :math:`t = 2M`. Shape: :math:`x`: :math:`(*, t, c, h, w)` :math:`P` has a shape of :math:`(2M, N)` where :math:`M` is the number of measurements as defined by the first dimension of :math:`H` and :math:`N` is the number of pixels in the image. :math:`output`: :math:`(*, c, 2M)` or :math:`(*, c, t)` Example: >>> x = torch.rand([10, 800, 3, 40, 40]) >>> H = torch.rand([400, 1600]) >>> meas_op = DynamicLinearSplit(H) >>> y = meas_op(x) >>> print(y.shape) torch.Size([10, 3, 800]) """ return self._dynamic_forward_with_op(x, self.P)
[docs] def forward_H(self, x: torch.tensor) -> torch.tensor: r""" Simulates the measurement of a motion picture :math:`y = H \cdot x(t)`. The output :math:`y` is computed as :math:`y = Hx`, where :math:`H` is the measurement matrix and :math:`x` is a batch of images. The matrix :math:`H` can contain positive and negative values and is given by the user at initialization. If you want to measure with the splitted matrix :math:`P`, use the method :meth:`forward`. Args: :attr:`x`: Batch of images of shape :math:`(*, t, c, h, w)` where * denotes any dimension (e.g. the batch size), :math:`t` the number of frames, :math:`c` the number of channels, and :math:`h`, :math:`w` the height and width of the images. Output: :math:`y`: Linear measurements of the input images. It has shape :math:`(*, c, M)` where * denotes any number of dimensions, :math:`c` the number of channels, and :math:`M` the number of measurements. .. important:: There must be as many images as there are measurements in the original linear operator, i.e. :math:`t = M`. Shape: :math:`x`: :math:`(*, t, c, h, w)` :math:`H` has a shape of :math:`(M, N)` where :math:`M` is the number of measurements and :math:`N` is the number of pixels in the image. :math:`output`: :math:`(*, c, M)` Example: >>> x = torch.rand([10, 400, 3, 40, 40]) >>> H = torch.rand([400, 1600]) >>> meas_op = LinearDynamicSplit(H) >>> y = meas_op.forward_H(x) >>> print(y.shape) torch.Size([10, 3, 400]) """ return super().forward(x)
def _set_Ord(self, Ord: torch.tensor) -> None: """Set the order matrix used to sort the rows of H.""" super()._set_Ord(Ord) # update P self._set_P(self.H_static)
# =============================================================================
[docs] class DynamicHadamSplit(DynamicLinearSplit): # ========================================================================= r""" Simulates the measurement of a moving object using a splitted operator :math:`y = \begin{bmatrix}{H_{+}}\\{H_{-}}\end{bmatrix} \cdot x(t)` with :math:`H` a Hadamard matrix. Computes linear measurements from incoming images: :math:`y = Px`, where :math:`P` is a linear operator (matrix) with positive entries and :math:`x` is a batch of vectorized images representing a motion picture. The matrix :math:`P` contains only positive values and is obtained by splitting a Hadamard-based matrix :math:`H` such that :math:`P` has a shape of :math:`(2M, N)` and `P[0::2, :] = H_{+}` and `P[1::2, :] = H_{-}`, where :math:`H_{+} = \max(0,H)` and :math:`H_{-} = \max(0,-H)`. :math:`H` is obtained by selecting a re-ordered subsample of :math:`M` rows of a "full" Hadamard matrix :math:`F` with shape :math:`(N^2, N^2)`. :math:`N` must be a power of 2. Args: :attr:`M` (int): Number of measurements. If :math:`M < h^2`, the measurement matrix :math:`H` is cropped to :math:`M` rows. :attr:`h` (int): Measurement pattern height, must be a power of 2. The image is assumed to be square, so the number of pixels in the image is :math:`N = h^2`. :attr:`Ord` (torch.tensor, optional): Order matrix used to reorder the rows of the measurement matrix :math:`H`. The first new row of :math:`H` will correspond to the highest value in :math:`Ord`. Must contain :math:`M` values. If some values repeat, the order is kept. Defaults to None. :attr:`img_shape` (tuple, optional): Shape of the image. Must be a tuple of two integers representing the height and width of the image. If not specified, the shape is taken as equal to `meas_shape`. Setting this value is particularly useful when using an :ref:`extended field of view <_MICCAI24>`. Attributes: :attr:`H_static` (torch.nn.Parameter): The learnable measurement matrix of shape :math:`(M,N)` initialized as :math:`H`. :attr:`P` (torch.nn.Parameter): The splitted measurement matrix of shape :math:`(2M, N)` such that `P[0::2, :] = H_{+}` and `P[1::2, :] = H_{-}`. :attr:`M` (int): Number of measurements performed by the linear operator. :attr:`N` (int): Number of pixels in the image. :attr:`h` (int): Measurement pattern height. :attr:`w` (int): Measurement pattern width. :attr:`meas_shape` (tuple): Shape of the measurement patterns (height, width). Is equal to `(self.h, self.w)`. :attr:`img_h` (int): Image height. :attr:`img_w` (int): Image width. :attr:`img_shape` (tuple): Shape of the image (height, width). Is equal to `(self.img_h, self.img_w)`. :attr:`H_dyn` (torch.tensor): Dynamic measurement matrix :math:`H`. Must be set using the method :meth:`build_H_dyn` before being accessed. :attr:`H` (torch.tensor): Alias for :attr:`H_dyn`. :attr:`H_dyn_pinv` (torch.tensor): Dynamic pseudo-inverse measurement matrix :math:`H_{dyn}^\dagger`. Must be set using the method :meth:`build_H_dyn_pinv` before being accessed. :attr:`H_pinv` (torch.tensor): Alias for :attr:`H_dyn_pinv`. .. note:: The computation of a Hadamard transform :math:`Fx` benefits a fast algorithm, as well as the computation of inverse Hadamard transforms. .. note:: :math:`H = H_{+} - H_{-}` Example: >>> Ord = torch.rand([32,32]) >>> meas_op = HadamSplitDynamic(400, 32, Ord) >>> print(meas_op) DynamicHadamSplit( (M): 400 (N): 1024 (H.shape): torch.Size([400, 1024]) (meas_shape): (32, 32) (H_dyn): False (img_shape): (32, 32) (H_pinv): False (P.shape): torch.Size([800, 1024]) ) Reference: .. _MICCAI24: [MaBP24] (MICCAI 2024 paper #883) Thomas Maitre, Elie Bretin, Romain Phan, Nicolas Ducros, Michaël Sdika. Dynamic Single-Pixel Imaging on an Extended Field of View without Warping the Patterns. 2024. hal-04533981 """ def __init__( self, M: int, h: int, Ord: torch.tensor = None, img_shape: tuple = None, # (height, width) ): F = spytorch.walsh2_matrix(h) # empty = torch.empty(h**2, h**2) # just to get the shape # we pass the whole F matrix to the constructor super().__init__(F, Ord, (h, h), img_shape) self._M = M def _set_Ord(self, Ord: torch.tensor) -> None: """Set the order matrix used to sort the rows of H.""" # get only the indices, as done in spyrit.core.torch.sort_by_significance self._indices = torch.argsort(-Ord.flatten(), stable=True).to(torch.int32) # update the Ord attribute self._param_Ord.data = Ord.to(self.device)