Source code for torch_radon.shearlet

from alpha_transform import AlphaShearletTransform
from alpha_transform.fourier_util import my_ifft_shift

from .utils import normalize_shape
import numpy as np
import torch
import os


class ShearletTransform:
    """
    Implementation of Alpha-Shearlet transform based on
    https://github.com/dedale-fet/alpha-transform/tree/master/alpha_transform.

    Once the shearlet spectrograms are computed all the computations are done on the
    device of the input tensors (spectrograms are moved there lazily).

    :param width: Width of the images
    :param height: Height of the images
    :param alphas: List of alpha coefficients that will be used to generate shearlets
    :param cache: If specified it should be a path to a directory that will be used to cache
        shearlet coefficients (spectrograms and scales) in order to avoid recomputing them at
        each instantiation of this class. The directory is created if it does not exist.

    .. note::
        Support both float and double precision.
    """

    # torch.rfft / torch.irfft were removed in PyTorch 1.8 in favour of the torch.fft module.
    # Keep the original code path where they still exist and use torch.fft otherwise.
    _LEGACY_FFT = hasattr(torch, "rfft")

    def __init__(self, width, height, alphas, cache=None):
        if cache is None:
            scales, shifted_spectrograms = self._compute_shearlets(width, height, alphas)
        else:
            os.makedirs(cache, exist_ok=True)
            # The spectrogram file name matches earlier versions so existing caches are reused.
            # Scales come from ``alpha_shearlet.indices``, which is not part of that cache, so
            # they are stored in a companion file.
            cache_prefix = f"{width}_{height}_{alphas}"
            spectrograms_file = os.path.join(cache, f"{cache_prefix}.npy")
            scales_file = os.path.join(cache, f"{cache_prefix}_scales.npy")
            if os.path.exists(spectrograms_file) and os.path.exists(scales_file):
                shifted_spectrograms = np.load(spectrograms_file)
                scales = np.load(scales_file)
            else:
                scales, shifted_spectrograms = self._compute_shearlets(width, height, alphas)
                np.save(spectrograms_file, shifted_spectrograms)
                np.save(scales_file, scales)

        self.scales = torch.FloatTensor(scales)
        self.shifted_spectrograms = torch.FloatTensor(shifted_spectrograms)
        self.shifted_spectrograms_d = torch.DoubleTensor(shifted_spectrograms)

    @staticmethod
    def _compute_shearlets(width, height, alphas):
        """Build the alpha-shearlet system; return ``(scales, shifted_spectrograms)`` as numpy arrays."""
        alpha_shearlet = AlphaShearletTransform(width, height, alphas, real=True, parseval=True)
        # The first filter gets scale 0; the others take the first entry of their index.
        scales = np.asarray([0] + [index[0] for index in alpha_shearlet.indices[1:]])
        shifted_spectrograms = np.asarray([my_ifft_shift(spec) for spec in alpha_shearlet.spectrograms])
        return scales, shifted_spectrograms

    def _move_parameters_to_device(self, device):
        """Move both precision variants of the spectrograms to ``device`` if they are elsewhere."""
        if device != self.shifted_spectrograms.device:
            self.shifted_spectrograms = self.shifted_spectrograms.to(device)
            self.shifted_spectrograms_d = self.shifted_spectrograms_d.to(device)

    def _spectrograms_for(self, dtype):
        """Return the double-precision spectrograms for float64 inputs, single precision otherwise."""
        if dtype == torch.float64:
            return self.shifted_spectrograms_d
        return self.shifted_spectrograms

    @normalize_shape(2)
    def forward(self, x):
        """
        Do shearlet transform of a batch of images.

        :param x: PyTorch tensor with shape :math:`(d_1, \\dots, d_n, h, w)`.
        :returns: PyTorch tensor containing shearlet coefficients.
            Has shape :math:`(d_1, \\dots, d_n, \\text{n_shearlets}, h, w)`.
        """
        self._move_parameters_to_device(x.device)
        spectrograms = self._spectrograms_for(x.dtype)

        if self._LEGACY_FFT:
            c = torch.rfft(x, 2, normalized=True, onesided=False)
            cs = torch.einsum("fij,bijc->bfijc", spectrograms, c)
            return torch.irfft(cs, 2, normalized=True, onesided=False)

        # Complex-tensor equivalent: (b, 1, h, w) * (f, h, w) -> (b, f, h, w).
        c = torch.fft.fft2(x, norm="ortho")
        cs = c.unsqueeze(1) * spectrograms
        # .real is a strided view; make it contiguous like irfft's output.
        return torch.fft.ifft2(cs, norm="ortho").real.contiguous()

    @normalize_shape(3)
    def backward(self, cs):
        """
        Do inverse shearlet transform.

        :param cs: PyTorch tensor containing shearlet coefficients,
            with shape :math:`(d_1, \\dots, d_n, \\text{n_shearlets}, h, w)`.
        :returns: PyTorch tensor containing reconstructed images.
            Has shape :math:`(d_1, \\dots, d_n, h, w)`.
        """
        # Without this, calling backward before forward (or on another device) mixes devices.
        self._move_parameters_to_device(cs.device)
        spectrograms = self._spectrograms_for(cs.dtype)

        if self._LEGACY_FFT:
            cs_fft = torch.rfft(cs, 2, normalized=True, onesided=False)
            res = torch.einsum("fij,bfijc->bijc", spectrograms, cs_fft)
            return torch.irfft(res, 2, normalized=True, onesided=False)

        # Complex-tensor equivalent: weight each shearlet band and sum over the band axis.
        cs_fft = torch.fft.fft2(cs, norm="ortho")
        res = (cs_fft * spectrograms).sum(dim=1)
        return torch.fft.ifft2(res, norm="ortho").real.contiguous()