Source code for torch_radon

import numpy as np
import torch
import scipy.stats
import abc
import torch.nn.functional as F
import warnings

try:
    import torch_radon_cuda
    from torch_radon_cuda import RaysCfg
except Exception as e:
    print("Importing exception")

from .differentiable_functions import RadonForward, RadonBackprojection
from .utils import normalize_shape
from .filtering import FourierFilters

__version__ = "1.0.0"


class BaseRadon(abc.ABC):
    def __init__(self, angles, rays_cfg):
        self.rays_cfg = rays_cfg

        if not isinstance(angles, torch.Tensor):
            angles = torch.FloatTensor(angles)

        # change sign to conform to Astra and Scikit
        self.angles = -angles

        # caches used to avoid reallocation of resources
        self.tex_cache = torch_radon_cuda.TextureCache(8)
        self.fourier_filters = FourierFilters()

        seed = np.random.get_state()[1][0]
        self.noise_generator = torch_radon_cuda.RadonNoiseGenerator(seed)

    def _move_parameters_to_device(self, device):
        if device != self.angles.device:
            self.angles = self.angles.to(device)

    def _check_input(self, x, square=False):
        if not x.is_contiguous():
            x = x.contiguous()

        if square:
            assert x.size(1) == x.size(2), f"Input images must be square, got shape ({x.size(1)}, {x.size(2)})."

        if x.dtype == torch.float16:
            assert x.size(
                0) % 4 == 0, f"Batch size must be multiple of 4 when using half precision. Got batch size {x.size(0)}"

        return x

    @normalize_shape(2)
    def forward(self, x):
        r"""Radon forward projection.

        :param x: PyTorch GPU tensor with shape :math:`(d_1, \dots, d_n, r, r)` where :math:`r` is the :attr:`resolution`
            given to the constructor of this class.
        :returns: PyTorch GPU tensor containing sinograms. Has shape :math:`(d_1, \dots, d_n, len(angles), det\_count)`.
        """
        x = self._check_input(x, square=True)
        self._move_parameters_to_device(x.device)

        return RadonForward.apply(x, self.angles, self.tex_cache, self.rays_cfg)

    @normalize_shape(2)
    def backprojection(self, sinogram):
        r"""Radon backward projection.

        :param sinogram: PyTorch GPU tensor containing sinograms with shape  :math:`(d_1, \dots, d_n, len(angles), det\_count)`.
        :returns: PyTorch GPU tensor with shape :math:`(d_1, \dots, d_n, r, r)` where :math:`r` is the :attr:`resolution`
            given to the constructor of this class.
        """
        sinogram = self._check_input(sinogram)
        self._move_parameters_to_device(sinogram.device)

        return RadonBackprojection.apply(sinogram, self.angles, self.tex_cache, self.rays_cfg)

    @normalize_shape(2)
    def filter_sinogram(self, sinogram, filter_name="ramp"):
        # if not self.clip_to_circle:
        #     warnings.warn("Filtered Backprojection with clip_to_circle=True will not produce optimal results."
        #                   "To avoid this specify clip_to_circle=False inside Radon constructor.")

        # Pad sinogram to improve accuracy
        size = sinogram.size(2)
        n_angles = sinogram.size(1)

        padded_size = max(64, int(2 ** np.ceil(np.log2(2 * size))))
        pad = padded_size - size

        padded_sinogram = F.pad(sinogram.float(), (0, pad, 0, 0))
        # TODO should be possible to use onesided=True saving memory and time
        sino_fft = torch.rfft(padded_sinogram, 1, normalized=True, onesided=False)

        # get filter and apply
        f = self.fourier_filters.get(padded_size, filter_name, sinogram.device)
        filtered_sino_fft = sino_fft * f

        # Inverse fft
        filtered_sinogram = torch.irfft(filtered_sino_fft, 1, normalized=True, onesided=False)

        # pad removal and rescaling
        filtered_sinogram = filtered_sinogram[:, :, :-pad] * (np.pi / (2 * n_angles))

        return filtered_sinogram.to(dtype=sinogram.dtype)

    def backward(self, sinogram):
        r"""Same as backprojection"""
        return self.backprojection(sinogram)

    @normalize_shape(2)
    def add_noise(self, x, signal, density_normalization=1.0, approximate=False):
        # print("WARN Radon.add_noise is deprecated")

        torch_radon_cuda.add_noise(x, self.noise_generator, signal, density_normalization, approximate)
        return x

    @normalize_shape(2)
    def emulate_readings(self, x, signal, density_normalization=1.0):
        return torch_radon_cuda.emulate_sensor_readings(x, self.noise_generator, signal, density_normalization)

    @normalize_shape(2)
    def emulate_readings_new(self, x, signal, normal_std, k, bins):
        return torch_radon_cuda.emulate_readings_new(x, self.noise_generator, signal, normal_std, k, bins)

    @normalize_shape(2)
    def readings_lookup(self, sensor_readings, lookup_table):
        return torch_radon_cuda.readings_lookup(sensor_readings, lookup_table)

    def set_seed(self, seed=-1):
        if seed < 0:
            seed = np.random.get_state()[1][0]

        self.noise_generator.set_seed(seed)

    def __del__(self):
        self.noise_generator.free()


[docs]class Radon(BaseRadon): r""" | .. image:: https://raw.githubusercontent.com/matteo-ronchetti/torch-radon/ master/pictures/parallelbeam.svg?sanitize=true :align: center :width: 400px | Class that implements Radon projection for the Parallel Beam geometry. :param resolution: The resolution of the input images. :param angles: Array containing the list of measuring angles. Can be a Numpy array or a PyTorch tensor. :param det_count: Number of rays that will be projected. By default it is = :attr:`resolution` :param det_spacing: Distance between two contiguous rays. :param clip_to_circle: If True both forward and backward projection will be restricted to pixels inside the circle (highlighted in cyan). .. note:: Currently only support resolutions which are multiples of 16. """ def __init__(self, resolution: int, angles, det_count=-1, det_spacing=1.0, clip_to_circle=False): if det_count <= 0: det_count = resolution rays_cfg = RaysCfg(resolution, resolution, det_count, det_spacing, len(angles), clip_to_circle) super().__init__(angles, rays_cfg) self.det_count = det_count self.det_spacing = det_spacing
[docs]class RadonFanbeam(BaseRadon): r""" | .. image:: https://raw.githubusercontent.com/matteo-ronchetti/torch-radon/ master/pictures/fanbeam.svg?sanitize=true :align: center :width: 400px | Class that implements Radon projection for the Fanbeam geometry. :param resolution: The resolution of the input images. :param angles: Array containing the list of measuring angles. Can be a Numpy array or a PyTorch tensor. :param source_distance: Distance between the source of rays and the center of the image. :param det_distance: Distance between the detector plane and the center of the image. By default it is = :attr:`source_distance`. :param det_count: Number of rays that will be projected. By default it is = :attr:`resolution`. :param det_spacing: Distance between two contiguous rays. :param clip_to_circle: If True both forward and backward projection will be restricted to pixels inside the circle (highlighted in cyan). .. note:: Currently only support resolutions which are multiples of 16. """ def __init__(self, resolution: int, angles, source_distance: float, det_distance: float = -1, det_count: int = -1, det_spacing: float = -1, clip_to_circle=False): if det_count <= 0: det_count = resolution if det_distance < 0: det_distance = source_distance det_spacing = 2.0 if det_spacing < 0: det_spacing = (source_distance + det_distance) / source_distance rays_cfg = RaysCfg(resolution, resolution, det_count, det_spacing, len(angles), clip_to_circle, source_distance, det_distance) super().__init__(angles, rays_cfg) self.source_distance = source_distance self.det_distance = det_distance self.det_count = det_count self.det_spacing = det_spacing
def compute_lookup_table(sinogram, signal, normal_std, bins=4096, eps=0.01, eps_prob=0.99, eps_k=0.01, verbose=False): s = sinogram.view(-1) device = s.device eps = np.quantile(sinogram.cpu().numpy(), eps) + eps_k # Compute readings normalization value if verbose: print("Computing readings normalization value") k = 0 for i in range(1, 5000): a, b = torch_radon_cuda.compute_ab(s, signal, eps, bins * i) if verbose: print(a, b) if a >= (a + b) * eps_prob: k = bins * i break print("Readings normalization value = ", k // bins) # Compute weights for Gaussian error scale = k // bins weights = [] for i in range(0, 64): t = scipy.stats.norm.cdf((scale - i - 0.5) / normal_std) - scipy.stats.norm.cdf((- i - 0.5) / normal_std) if t < 0.005: break weights.append(t) weights = weights[scale:][::-1] + weights weights = np.array(weights) border_w = np.asarray([scipy.stats.norm.cdf((-x - 0.5) / normal_std) for x in range(scale)]) border_w = torch.FloatTensor(border_w).to(device) log_factorial = np.arange(k + len(weights)) log_factorial[0] = 1 log_factorial = np.cumsum(np.log(log_factorial).astype(np.float64)).astype(np.float32) log_factorial = torch.Tensor(log_factorial).to(device) weights = torch.FloatTensor(weights).to(device) lookup, lookup_var = torch_radon_cuda.compute_lookup_table(s, weights, signal, bins, scale, log_factorial, border_w) return lookup, lookup_var, scale class ReadingsLookup: def __init__(self, radon, bins=4096, mu=None, sigma=None, ks=None, signals=None, normal_stds=None): self.radon = radon self.bins = bins self.mu = [] if mu is None else mu self.sigma = [] if sigma is None else sigma self.ks = [] if ks is None else ks self.signals = [] if signals is None else signals self.normal_stds = [] if normal_stds is None else normal_stds self._mu = None self._sigma = None self._ks = None self._signals = None self._normal_stds = None self._need_repacking = True def repack(self, device): self._mu = torch.FloatTensor(self.mu).to(device) self._sigma = torch.FloatTensor(self.sigma).to(device) self._ks = torch.IntTensor(self.ks).to(device) self._signals = torch.FloatTensor(self.signals).to(device) self._normal_stds = torch.FloatTensor(self.normal_stds).to(device) @staticmethod def from_file(path, radon): obj = np.load(path) bins = int(obj["bins"]) return ReadingsLookup(radon, bins, list(obj["mu"]), list(obj["sigma"]), list(obj["ks"]), list(obj["signals"]), list(obj["normal_stds"])) def save(self, path): self.repack("cpu") np.savez(path, mu=self._mu, sigma=self._sigma, ks=self._ks, signals=self._signals, normal_stds=self._normal_stds, bins=self.bins) def add_lookup_table(self, sinogram, signal, normal_std, eps=0.01, eps_prob=0.99, eps_k=0.01, verbose=True): lookup, lookup_var, k = compute_lookup_table(sinogram, signal, normal_std, self.bins, eps, eps_prob, eps_k, verbose) self.mu.append(lookup.cpu().numpy()) self.sigma.append(lookup_var.cpu().numpy()) self.ks.append(k) self.signals.append(signal) self.normal_stds.append(normal_std) self._need_repacking = True @normalize_shape(2) def emulate_readings(self, sinogram, level): if self._need_repacking or self._mu.device != sinogram.device: self.repack(sinogram.device) if isinstance(level, torch.Tensor): return torch_radon_cuda.emulate_readings_multilevel(sinogram, self.radon.noise_generator, self._signals, self._normal_stds, self._ks, level, self.bins) else: return torch_radon_cuda.emulate_readings_new(sinogram, self.radon.noise_generator, self.signals[level], self.normal_stds[level], self.ks[level], self.bins) @normalize_shape(2) def lookup(self, readings, level): if self._need_repacking or self._mu.device != readings.device: self.repack(readings.device) if isinstance(level, torch.Tensor): mu = torch_radon_cuda.readings_lookup_multilevel(readings, self._mu, level) sigma = torch_radon_cuda.readings_lookup_multilevel(readings, self._sigma, level) else: mu = torch_radon_cuda.readings_lookup(readings, self._mu[level]) sigma = torch_radon_cuda.readings_lookup(readings, self._sigma[level]) return mu, sigma def random_levels(self, size, device): return torch.randint(0, len(self.mu), (size,), device=device, dtype=torch.int32)