# Source code for torch_radon.solvers

import torch
from .utils import _normalize_shape, _unnormalize_shape


def normalize(x):
    """
    Scale every batch element of ``x`` to unit Euclidean norm.

    The first dimension is treated as the batch dimension; all remaining
    dimensions are flattened to compute one norm per batch element.

    The input tensor is left untouched: the division is done out-of-place, and
    ``reshape`` is used so that non-contiguous inputs are accepted as well.

    :param x: Tensor whose first dimension is the batch dimension.
    :return: ``(normalized, max_norm)`` where ``normalized`` has the same shape
        as ``x`` and ``max_norm`` is the largest per-element norm of the
        *original* ``x`` as a Python float.
    """
    size = x.size()
    flat = x.reshape(size[0], -1)
    norm = torch.norm(flat, dim=1)
    # Out-of-place division: an in-place `/=` on a view would silently
    # overwrite the caller's tensor.
    flat = flat / norm.view(-1, 1)
    return flat.reshape(size), torch.max(norm).item()


class Landweber:
    """
    Class that implements Landweber iteration to solve
    :math:`\\min_{x \\in C} \\|Ax-y\\|_2^2`
    (see `Wikipedia page <https://en.wikipedia.org/wiki/Landweber_iteration>`_).

    The iteration used is
    :math:`x_{n+1} = \\mathcal{P}_C(x_n - \\alpha A^T (A x_n - y))`
    where :math:`\\mathcal{P}_C` is the projection onto :math:`C`.

    :param operator: Instance of a class that implements products :math:`A x`
        (``operator.forward(x)``) and :math:`A^T y` (``operator.backward(y)``).
    :param projection: Function that implements :math:`\\mathcal{P}_C(\\cdot)`,
        if not specified no projection is used.
    :param grad: If true gradient will be enabled, more memory will be used but
        it will be possible to backpropagate.
    """

    def __init__(self, operator, projection=None, grad=False):
        self.operator = operator
        # The parentheses matter: without them the conditional expression
        # becomes the *body* of the lambda, and calling ``self.projection(x)``
        # with a user-supplied projection would return the projection function
        # itself instead of the projected tensor.
        self.projection = (lambda x: x) if projection is None else projection
        self.grad = grad

    def estimate_alpha(self, img_size, device, n_iter=50, batch_size=8):
        """
        Use power iteration on :math:`A^T A` to estimate the maximum step size
        that still guarantees convergence.

        .. note::
            Because this computation is not exact it is advised to use a value
            of alpha lower than the one estimated by this method (for example
            multiplying the estimate by 0.95).

        :param img_size: Size of the image (a square
            ``img_size x img_size`` image is used)
        :param device: Device on which the random start vectors are created
        :param n_iter: Number of iterations, must be at least 1
        :param batch_size: Number of vectors used in the power iteration.
        :return: Estimated value for alpha
        :raises ValueError: If ``n_iter`` is smaller than 1.
        """
        if n_iter < 1:
            # With no iteration there is no norm to derive an estimate from.
            raise ValueError("n_iter must be at least 1")

        with torch.no_grad():
            x = torch.randn((batch_size, img_size, img_size), device=device)
            x, _ = normalize(x)
            for _ in range(n_iter):
                next_x = self.operator.backward(self.operator.forward(x))
                # v is the largest norm of A^T A x over the batch; since x has
                # unit norm it approximates the top eigenvalue of A^T A.
                x, v = normalize(next_x)

        return 2.0 / v

    def run(self, x_zero, y, alpha, iterations=100, callback=None):
        """
        Execute Landweber iterations.

        :param x_zero: Initial solution guess used as a starting point for the
            iteration (it is cloned, not modified)
        :param y: Value of y in :math:`\\min_{x \\in C} \\|Ax-y\\|_2^2`
        :param alpha: Step size, can be estimated using :attr:`estimate_alpha`
        :param iterations: Number of iterations
        :param callback: Optional function that will be called at each
            iteration with :math:`x_n` as argument. Values returned by
            ``callback`` will be stored in a list and returned together with
            the computed solution
        :return: If ``callback`` is specified returns ``x, values`` where ``x``
            is the solution computed by the Landweber iteration and ``values``
            is the list of values returned by ``callback`` at each iteration.
            If ``callback`` is not specified returns only ``x``
        """
        res = []
        with torch.set_grad_enabled(self.grad):
            x = x_zero.clone()
            for _ in range(iterations):
                residual = self.operator.forward(x) - y
                x = self.projection(x - alpha * self.operator.backward(residual))
                if callback is not None:
                    res.append(callback(x))

        if callback is not None:
            return x, res
        return x
def cg(forward, x, y, callback=None, max_iter=500, tol=1e-5):
    """
    Implements Conjugate Gradient algorithm for solving
    :math:`\\min_x \\|Ax-y\\|_2^2`.

    .. note::
        For conjugate gradient to work the matrix :math:`A` must be symmetric
        positive definite. Otherwise use other solvers.

    :param forward: function that implements products :math:`A x`
        (``forward(x)``).
    :param x: Initial solution guess used as a starting point for the
        iteration. It is not modified by the iteration updates.
    :param y: Value of y in :math:`\\min_x \\|Ax-y\\|_2^2`
    :param callback: Optional function that will be called at each iteration
        with :math:`x_n` and the residual as arguments. Values returned by
        ``callback`` will be stored in a list and returned together with the
        computed solution.
    :param max_iter: Maximum number of iterations.
    :param tol: Algorithm is stopped when
        :math:`\\frac{\\| Ax_n - y \\|}{\\| y \\|} \\leq \\text{tol}`
    :return: If ``callback`` is specified returns ``x, values`` where ``x`` is
        the solution computed by the conjugate gradient iteration and
        ``values`` is the list of values returned by ``callback`` at each
        iteration. If ``callback`` is not specified returns only ``x``.
    """
    # normalize shape to be (batch, height, width)
    x, old_shape = _normalize_shape(x, 2)
    y, _ = _normalize_shape(y, 2)

    y_norm = torch.norm(y, dim=(1, 2))

    r = y - forward(x)
    p = r.clone()
    r_n = torch.sum(r ** 2, dim=(1, 2))

    values = []
    for _ in range(max_iter):
        Ap = forward(p)
        # Clamp the denominator so a vanishing p^T A p cannot divide by zero.
        alpha = (r_n / torch.sum(p * Ap, dim=(1, 2)).clamp_min(1e-8)).view(-1, 1, 1)

        # Out-of-place update (same as cgne): an in-place `+=` would overwrite
        # the caller's initial guess if x aliases it, and would hand the very
        # same tensor object to the callback at every iteration.
        x = x + alpha * p
        r_next = r - alpha * Ap

        if callback is not None:
            values.append(callback(x, r_next))

        r_next_n = torch.sum(r_next ** 2, dim=(1, 2))

        # NOTE(review): this takes the *smallest* relative residual over the
        # batch, so iteration stops as soon as one batch element has
        # converged; cgne uses torch.max instead. Confirm which is intended.
        rs = torch.min(torch.sqrt(r_next_n) / y_norm).item()
        if rs <= tol:
            break

        beta = (r_next_n / r_n).view(-1, 1, 1)
        r_n = r_next_n
        p = r_next + beta * p
        # r_next is a fresh tensor every iteration, no copy is needed.
        r = r_next

    x = _unnormalize_shape(x, old_shape)

    if callback is None:
        return x
    return x, values
def cgne(operator, x, y, callback=None, max_iter=5000, tol=1e-5):
    """
    Implements Conjugate Gradient on the Normal Equations, an algorithm for
    solving :math:`\\min_x \\|Ax-y\\|_2^2`.

    :param operator: Instance of a class that implements products :math:`A x`
        (``operator.forward(x)``) and :math:`A^T y`
        (``operator.backprojection(y)`` if available, otherwise
        ``operator.backward(y)``).
    :param x: Initial solution guess used as a starting point for the iteration
    :param y: Value of y in :math:`\\min_x \\|Ax-y\\|_2^2`
    :param callback: Optional function that will be called at each iteration
        with :math:`x_n` as argument. Values returned by ``callback`` will be
        stored in a list and returned together with the computed solution
    :param max_iter: Maximum number of iterations
    :param tol: Algorithm is stopped when
        :math:`\\frac{\\| s \\|}{\\| y \\|} \\leq \\text{tol}` for every element
        of the batch, where :math:`s = A^T (y - A x_n)`
    :return: If ``callback`` is specified returns ``x, values`` where ``x`` is
        the solution computed by the CGNE iteration and ``values`` is the list
        of values returned by ``callback`` at each iteration. If ``callback``
        is not specified returns only ``x``
    """
    # Resolve the adjoint product once. ``backprojection`` is kept as the first
    # choice so operators that already worked behave exactly as before, while
    # operators exposing only the documented ``backward`` are now accepted too.
    adjoint = getattr(operator, "backprojection", None)
    if adjoint is None:
        adjoint = operator.backward

    # normalize shape to be (batch, height, width)
    x, old_shape = _normalize_shape(x, 2)
    y, _ = _normalize_shape(y, 2)

    y_norm = torch.norm(y, dim=(1, 2))

    d = y - operator.forward(x)
    s = adjoint(d)
    p = s.clone()
    s_norm = torch.sum(s ** 2, dim=(1, 2))

    values = []
    for _ in range(max_iter):
        q = operator.forward(p)
        # Clamp the denominator so a vanishing ||A p||^2 cannot divide by zero.
        alpha_den = torch.sum(q ** 2, dim=(1, 2)).clamp_min(1e-8)
        alpha = (s_norm / alpha_den).view(-1, 1, 1)

        x = x + alpha * p
        d = d - alpha * q
        s_next = adjoint(d)
        s_next_norm = torch.sum(s_next ** 2, dim=(1, 2))

        if callback is not None:
            values.append(callback(x))

        # Stop once the largest relative gradient norm in the batch is small.
        sn = torch.max(torch.sqrt(s_next_norm) / y_norm).item()
        if sn <= tol:
            break

        beta = (s_next_norm / s_norm).view(-1, 1, 1)
        s = s_next
        s_norm = s_next_norm
        p = s + beta * p

    x = _unnormalize_shape(x, old_shape)

    if callback is None:
        return x
    return x, values