r"""
02. Pseudoinverse solution from linear measurements
===================================================
.. _tuto_pseudoinverse_linear:

This tutorial shows how to simulate measurements and perform image reconstruction.
The measurement operator is chosen as a Hadamard matrix with positive coefficients.
Note that this matrix can be replaced by any desired matrix.

.. image:: ../fig/tuto2.png
   :width: 600
   :align: center
   :alt: Reconstruction architecture sketch

These tutorials load image samples from `/images/`.
"""

# %%
# Load a batch of images
# -----------------------------------------------------------------------------

###############################################################################
# Images :math:`x` used for training are expected to have values in [-1,1]. The
# images are normalized using the :func:`transform_gray_norm` function.

import os

import torch
import torchvision

import spyrit.core.torch as spytorch
from spyrit.misc.disp import imagesc
from spyrit.misc.statistics import transform_gray_norm

# sphinx_gallery_thumbnail_path = 'fig/tuto2.png'

h = 64  # side length of the square (h x h) images
i = 1  # index of the image to reconstruct; change it to pick another one
spyritPath = os.getcwd()
imgs_path = os.path.join(spyritPath, "images/")

# Grayscale + normalization transform producing image tensors in [-1, 1]
transform = transform_gray_norm(img_size=h)

# ImageFolder expects class sub-folders inside imgs_path, e.g. 'images/test/'
dataset = torchvision.datasets.ImageFolder(root=imgs_path, transform=transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=7)

# Grab the first batch; the class labels are not needed
x, _ = next(iter(dataloader))
print("Shape of input images:", x.shape)

# Keep only image i (batch dimension preserved) as a detached copy
x = x[i : i + 1].detach().clone()
print("Shape of selected image:", x.shape)
b, c, h, w = x.shape

# plot
imagesc(x[0, 0], r"$x$ in [-1, 1]")

# %%
# Define a measurement operator
# -----------------------------------------------------------------------------
# .. _hadamard_positive:

###############################################################################
# We consider the case where the measurement matrix is the positive
# component of a Hadamard matrix, which is often used in single-pixel imaging.
# First, we compute the full Hadamard matrix that computes the 2D transform of
# an image of size :attr:`h`, and we take its positive part.

# Full 2D Walsh-Hadamard matrix for h x h images, negative entries set to zero
F = spytorch.walsh2_matrix(h).clamp(min=0)

###############################################################################
# .. _low_frequency:
#
# Next, we subsample the rows of the measurement matrix to simulate an
# accelerated acquisition. For this, we use the
# :func:`spyrit.core.torch.sort_by_significance` function
# that returns the input matrix with its rows sorted in decreasing order of
# significance according to a given array. The array is a sampling map that
# indicates the location of the most significant coefficients in the
# transformed domain.
#
# To keep the low-frequency Hadamard coefficients, we choose a sampling map
# with ones in the top left corner and zeros elsewhere.

import math

und = 4  # undersampling factor
M = h**2 // und  # number of measurements kept (h**2 / und)

# Low-frequency sampling map: ones in the top-left M_xy x M_xy corner, with
# M_xy = ceil(sqrt(M)). Integer arithmetic keeps M_xy exact for any M, whereas
# a float square root may round. When M is not a perfect square the corner
# holds more than M ones; only the first M sorted rows are kept afterwards.
Sampling_map = torch.zeros(h, h)
M_xy = math.isqrt(M)
if M_xy * M_xy < M:  # M is not a perfect square: round up
    M_xy += 1
Sampling_map[:M_xy, :M_xy] = 1

imagesc(Sampling_map, "low-frequency sampling map")

###############################################################################
# After permutation of the full Hadamard matrix, we keep only its first
# :attr:`M` rows, i.e., the most significant ones.

# Sort the rows by significance w.r.t. Sampling_map, then keep the first M
F = spytorch.sort_by_significance(F, Sampling_map, "rows", False)
H = F[:M]

print("Shape of the measurement matrix:", H.shape)

###############################################################################
# Then, we instantiate a :class:`spyrit.core.meas.Linear` measurement operator

from spyrit.core.meas import Linear

# pinv=True requests the pseudo-inverse of H at construction time; it is
# later accessed (and temporarily removed) as the meas_op.H_pinv attribute.
meas_op = Linear(H, pinv=True)

# %%
# Noiseless case
# -----------------------------------------------------------------------------

###############################################################################
# In the noiseless case, we consider the :class:`spyrit.core.noise.NoNoise` noise
# operator

from spyrit.core.noise import NoNoise

noise = NoNoise(meas_op)
y = noise(x)  # simulated raw measurements, without noise
print("Shape of raw measurements:", y.shape)

###############################################################################
# To display the subsampled measurement vector as an image in the transformed
# domain, we use the :func:`spyrit.core.torch.meas2img` function

# plot
# Place the measurement vector back onto the grid of the sampling map
y_plot = spytorch.meas2img(y, Sampling_map)
print("Shape of the raw measurement image:", y_plot.shape)
imagesc(y_plot[0, 0], "Raw measurements (no noise)")


###############################################################################
# We now compute and plot the preprocessed measurements corresponding to an
# image in [-1,1]. For details on the preprocessing, see :ref:`Tutorial 1 <sphx_glr_gallery_tuto_01_acquisition_operators.py>`.
#
# .. note::
#
#       Using :class:`spyrit.core.prep.DirectPoisson` with :math:`\alpha = 1`
#       compensates for the image normalisation performed by
#       :class:`spyrit.core.noise.NoNoise`.

from spyrit.core.prep import DirectPoisson

# alpha = 1 "undoes" the normalization applied by the NoNoise operator
prep = DirectPoisson(1.0, meas_op)
m = prep(y)
print("Shape of the preprocessed measurements:", m.shape)

# Display the preprocessed measurements in the transformed domain
m_plot = spytorch.meas2img(m, Sampling_map)
print("Shape of the preprocessed measurement image:", m_plot.shape)
imagesc(m_plot[0, 0], "Preprocessed measurements (no noise)")

# %%
# Pseudo inverse
# -----------------------------------------------------------------------------

###############################################################################
# There are two ways to perform the pseudo inverse reconstruction from the
# preprocessed measurements :attr:`m`. The first consists of explicitly computing
# the pseudo inverse of the measurement matrix :attr:`H` and applying it to the
# measurements. The second obtains the pseudo inverse solution as a
# least-squares solution computed with :func:`torch.linalg.lstsq`.
# The choice is made automatically: if the measurement operator has a pseudo-inverse
# already computed, it is used; otherwise, the least-squares solution is used.
#
# .. note::
#  Generally, the second method is preferred because it is faster and more
#  numerically stable. However, if you will use the pseudo inverse multiple
#  times, it becomes more efficient to compute it explicitly.
#
# First way: explicit computation of the pseudo inverse
# We can use the :class:`spyrit.core.recon.PseudoInverse` class to perform the
# pseudo inverse reconstruction from the preprocessed measurements :attr:`m`.

from spyrit.core.recon import PseudoInverse

# Pseudo-inverse reconstruction operator
# Pseudo-inverse reconstruction operator
recon_op = PseudoInverse()

# Reconstruction from the *preprocessed* measurements m (not the raw y).
# meas_op already holds a pseudo-inverse (pinv=True), so that explicit
# pseudo-inverse is what gets applied to m here.
x_rec1 = recon_op(m, meas_op)
print("Shape of the explicit pseudo-inverse reconstructed image:", x_rec1.shape)

###############################################################################
# Second way: least-squares solution
# The code is very similar to the previous case, but we need to make sure the
# measurement operator has no pseudo-inverse computed. We can also specify
# regularization parameters for the least-squares solution when calling
# `recon_op`. In our case, the pseudo-inverse was computed at initialization
# of the meas_op object (``pinv=True``), so we temporarily delete it and
# restore it afterwards.

print(f"Pseudo-inverse computed: {hasattr(meas_op, 'H_pinv')}")
temp = meas_op.H_pinv  # save the pseudo-inverse
del meas_op.H_pinv  # delete it so that the least-squares solution is used
print(f"Pseudo-inverse computed: {hasattr(meas_op, 'H_pinv')}")

try:
    # Reconstruction; reg and eta are regularization parameters passed on to
    # the least-squares solution
    x_rec2 = recon_op(m, meas_op, reg="rcond", eta=1e-6)
finally:
    # Restore the pseudo-inverse even if the reconstruction fails, so that
    # meas_op is left unchanged for the rest of the tutorial
    meas_op.H_pinv = temp
print("Shape of the least-squares reconstructed image:", x_rec2.shape)

##############################################################################
# .. note::
#   This choice is also offered for dynamic measurement operators which are
#   explained in :ref:`Tutorial 9 <sphx_glr_gallery_tuto_09_dynamic.py>`.

# plot side by side
import matplotlib.pyplot as plt
from spyrit.misc.disp import add_colorbar

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))

# Identical layout for both panels: grayscale image, title, colorbar on the right
panels = (
    (ax1, x_rec1, "Explicit pseudo-inverse reconstruction"),
    (ax2, x_rec2, "Least-squares pseudo-inverse reconstruction"),
)
for ax, img, title in panels:
    im = ax.imshow(img[0, 0, :, :], cmap="gray")
    ax.set_title(title)
    add_colorbar(im, "right", size="20%")


# %%
# PinvNet Network
# -----------------------------------------------------------------------------

###############################################################################
# Alternatively, we can consider the :class:`spyrit.core.recon.PinvNet` class that reconstructs an
# image by computing the pseudoinverse solution, which is fed to a neural
# network denoiser. To compute the pseudoinverse solution only, the denoiser
# can be set to the identity operator.

###############################################################################
# .. image:: ../fig/pinvnet.png
#    :width: 400
#    :align: center
#    :alt: Sketch of the PinvNet architecture

from spyrit.core.recon import PinvNet

identity_denoiser = torch.nn.Identity()  # no denoising: output = pseudoinverse
pinv_net = PinvNet(noise, prep, denoi=identity_denoiser)

###############################################################################
# or equivalently
# Omitting `denoi` builds the same pseudoinverse-only network as the
# explicit torch.nn.Identity() denoiser version.
pinv_net = PinvNet(noise, prep)

###############################################################################
# Then, we reconstruct the image from the measurement vector :attr:`y` using the
# :func:`~spyrit.core.recon.PinvNet.reconstruct` method.

# Pseudoinverse solution followed by the (identity) denoiser
x_rec = pinv_net.reconstruct(y)
print(f"Shape of the PinvNet reconstructed image: {x_rec.shape}")
imagesc(x_rec[0, 0], "PinvNet reconstruction (no noise)", title_fontsize=20)

###############################################################################
# Alternatively, the measurement vector can be simulated using the
# :func:`~spyrit.core.recon.PinvNet.acquire` method

# Simulate the measurements with the network itself, then reconstruct them
y = pinv_net.acquire(x)
x_rec = pinv_net.reconstruct(y)
imagesc(x_rec[0, 0], "Another pseudoinverse reconstruction (no noise)")

###############################################################################
# Note that the full module :attr:`pinv_net` both simulates the (here
# noiseless) measurements and reconstructs them.

# One forward pass: measurement simulation followed by reconstruction
x_rec = pinv_net(x)
print("Ground-truth image x:", x.shape)
print("Reconstructed x_rec:", x_rec.shape)
imagesc(x_rec[0, 0], "One more pseudoinverse reconstruction (no noise)")

# %%
# Poisson-corrupted measurement
# -----------------------------------------------------------------------------

###############################################################################
# Here, we consider the :class:`spyrit.core.noise.Poisson` class
# together with a :class:`spyrit.core.prep.DirectPoisson`
# preprocessing operator (see :ref:`Tutorial 1 <sphx_glr_gallery_tuto_01_acquisition_operators.py>`).

alpha = 10  # maximum number of photons in the image

from spyrit.core.noise import Poisson
from spyrit.misc.disp import imagecomp

noise = Poisson(meas_op, alpha)
prep = DirectPoisson(alpha, meas_op)  # To undo the "Poisson" operator
pinv_net = PinvNet(noise, prep)

# Two forward passes on the same image, shown below as "Noise #1" and
# "Noise #2" (each call simulates its own noisy measurements)
x_rec_1 = pinv_net(x)
x_rec_2 = pinv_net(x)
print(f"Ground-truth image x: {x.shape}")
print(f"Reconstructed x_rec_1: {x_rec_1.shape}")
print(f"Reconstructed x_rec_2: {x_rec_2.shape}")

# plot: mask copies so that x_rec_1 / x_rec_2 themselves are not modified
# (indexing returns views, so in-place writes would alter the reconstructions)
x_plot_1 = x_rec_1[0, 0, :, :].clone()
x_plot_1[:2, :2] = 0.0  # hide the top-left "crazy pixels" that collect noise
x_plot_2 = x_rec_2[0, 0, :, :].clone()
x_plot_2[:2, :2] = 0.0  # hide the top-left "crazy pixels" that collect noise
imagecomp(x_plot_1, x_plot_2, "Pseudoinverse reconstruction", "Noise #1", "Noise #2")

###############################################################################
# As shown in the next tutorial, a denoising neural network can be trained to
# postprocess the pseudo inverse solution.
