spyrit.misc.walsh_hadamard.walsh_torch
- spyrit.misc.walsh_hadamard.walsh_torch(x, H=None)
Return 1D Walsh-ordered Hadamard transform of a signal
- Args:
x (torch.tensor): Input signals with shape (*, n).
H (torch.tensor, optional): 1D Walsh-ordered Hadamard matrix with shape (*, m).
- Returns:
torch.tensor: Hadamard transformed signals with shape (*, m).
- Note:
Providing the input argument H leads to much faster computation when multiple Hadamard transforms are repeated (see Example 2).
- Example 1:
Sequency-ordered (i.e., Walsh) Hadamard transform
>>> import torch
>>> import spyrit.misc.walsh_hadamard as wh
>>> x = torch.tensor([1.0, 3.0, 0.0, -1.0, 7.0, 5.0, 1.0, -2.0])
>>> y = wh.fwht_torch(x)
>>> print(y)
>>> y = wh.walsh_torch(x)
>>> print(y)
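To make "sequency-ordered" concrete, the following is a minimal sketch in plain Python, not spyrit's implementation. It assumes an unnormalized ±1 Hadamard matrix built by Sylvester's doubling construction, with rows then sorted by their number of sign changes (the sequency). The helper names `sylvester_hadamard`, `sign_changes`, `walsh_matrix_sketch` and `walsh_transform_sketch` are hypothetical, and the scaling and sign conventions may differ from those used by `walsh_torch`.

```python
def sylvester_hadamard(n):
    """Return the n x n natural-ordered Hadamard matrix as nested lists."""
    if n < 1 or n & (n - 1):
        raise ValueError("n must be a power of two")
    H = [[1]]
    while len(H) < n:
        # Sylvester doubling: [[H, H], [H, -H]]
        H = [row + row for row in H] + [row + [-v for v in row] for row in H]
    return H


def sign_changes(row):
    """Count the sign changes along a row (its sequency)."""
    return sum(a != b for a, b in zip(row, row[1:]))


def walsh_matrix_sketch(n):
    """Reorder the Hadamard rows by increasing sequency (Walsh order)."""
    return sorted(sylvester_hadamard(n), key=sign_changes)


def walsh_transform_sketch(x):
    """Apply the unnormalized Walsh-ordered matrix to a signal x."""
    W = walsh_matrix_sketch(len(x))
    return [sum(w * v for w, v in zip(row, x)) for row in W]


x = [1.0, 3.0, 0.0, -1.0, 7.0, 5.0, 1.0, -2.0]
y = walsh_transform_sketch(x)
print(y)
```

In this sketch, row k of the matrix has exactly k sign changes, so the first coefficient is the plain sum of the signal and later coefficients correspond to increasingly oscillating patterns.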
- Example 2:
Fast vs regular: Computation times for 5 batches of 512 signals of length 2**10
>>> import timeit
>>> import torch
>>> import spyrit.misc.walsh_hadamard as wh
>>> x = torch.rand(5, 1, 512, 2**10)
>>> t = timeit.timeit(lambda: wh.fwht_torch(x), number=200)
>>> print(f"Fast Hadamard transform (200x): {t:.4f} seconds")
>>> t = timeit.timeit(lambda: torch.from_numpy(wh.walsh_matrix(x.shape[-1]).astype('float32')), number=200)
>>> print(f"Construction of Hadamard matrix (200x): {t:.4f} seconds")
>>> H = torch.from_numpy(wh.walsh_matrix(x.shape[-1]).astype('float32'))
>>> t = timeit.timeit(lambda: wh.walsh_torch(x, H), number=200)
>>> print(f"Matrix-vector products (200x): {t:.4f} seconds")
- Example 3:
CPU vs GPU: Computation times for 5 batches of 512 signals of length 2**10
>>> import timeit
>>> import torch
>>> import spyrit.misc.walsh_hadamard as wh
>>> x = torch.rand(5, 1, 512, 2**10)
>>> H = torch.tensor(wh.walsh_matrix(x.shape[-1]), dtype=torch.float32)
>>> t = timeit.timeit(lambda: wh.walsh_torch(x, H), number=200)
>>> print(f"Fast Hadamard transform pytorch CPU (200x): {t:.4f} seconds")
>>> x = x.to(torch.device('cuda:0'))
>>> H = H.to(torch.device('cuda:0'))
>>> t = timeit.timeit(lambda: wh.walsh_torch(x, H), number=200)
>>> print(f"Fast Hadamard transform pytorch GPU (200x): {t:.4f} seconds")
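CUDA kernels are launched asynchronously, so a GPU timing can under-report the work done unless the device is synchronized before the clock stops. The sketch below shows one way to account for this. The helper `time_on_device` is hypothetical, the matrix `H` is a random placeholder rather than a Hadamard matrix, and the code falls back to the CPU when no GPU is available.

```python
import timeit

import torch


def time_on_device(fn, device, number=10):
    """Time fn() on device, waiting for queued CUDA kernels to finish."""

    def run():
        fn()
        if device.type == "cuda":
            # Block until all kernels queued on this device have completed
            torch.cuda.synchronize(device)

    run()  # warm-up call, excluded from the measurement
    return timeit.timeit(run, number=number)


# Use the GPU when present, otherwise fall back to the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
x = torch.rand(4, 64, device=device)
H = torch.rand(64, 64, device=device)  # placeholder, not a Hadamard matrix
t = time_on_device(lambda: x @ H, device)
print(f"Matrix product on {device} (10x): {t:.4f} seconds")
```

The same wrapper could be applied to the `wh.walsh_torch(x, H)` call from the example above, once `x` and `H` live on the chosen device.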