classdef TV_min
    % TV_MIN  Box-constrained, TV-regularized least squares solved with FISTA.
    %
    % Approximately solves
    %
    %     min_X  |A(X) - B|^2 + 2*lambda*TV(X)   s.t.  bounds(1) <= X <= bounds(2)
    %
    % with the fast (FISTA-type) gradient-projection scheme of Beck & Teboulle,
    % see: http://www.math.tau.ac.il/~teboulle/papers/tlv.pdf
    %
    % The current guess of the solution is stored at different stages of each
    % iteration: Y is the image after the gradient (L2) step, Z is the image
    % after the TV (L1) proximal step and box projection, and X is the accepted
    % solution. See l2_iteration and l1_iteration for the details.
    %
    % See the comment on each property for the meaning of the parameters.
    %
    % Example:
    %
    % a = phantom(256);
    % b = radon(a);
    %
    % f = TV_min();
    %
    % % Define linear system A(X) = B with the initial guess X:
    % f.X = zeros(size(a));
    % f.B = b;
    % f.A = @(x) radon(x);
    % f.A_trans = @(x) iradon(x, [], 'linear', 'none', 1, 256);
    %
    % % Parameters of iterative reconstruction:
    % f.L = 2*max(max(iradon(radon(ones(size(a))), [], 'linear', 'none')));
    % f.lambda = 0.01;
    % f.bounds = [0, inf];
    % f.X_true = a;
    %
    % [x, history] = f.run_cpu();
    %
    % figure(1); imshow(x, []);
    % figure(2); plot(history.err);

    properties (Access = public)

        % Function handle for the forward operator A(x)
        A = @(x) radon(x);

        % Function handle for the transposed (adjoint) operator A_trans(b)
        A_trans = @(b) iradon(b, [], 'linear', 'none');

        % Lipschitz constant of the gradient of |A(X) - B|^2, i.e.
        % 2*max eigenvalue of A_trans*A. The gradient step size is 1/L.
        L = 2;

        % TV strength: the objective uses 2*lambda*TV(X).
        % lambda = 0 skips the TV step (only the box projection is applied).
        lambda = 0;

        % Solution guess X (the initial guess before run_cpu)
        X = [];

        % Observed data B
        B = [];

        % Number of outer (FISTA) iterations
        n_iter = 100;

        % Maximum number of inner iterations of the TV step (l1_iteration)
        tv_n_iter = 10;

        % Relative-change tolerance of the inner TV iterations
        tv_tol = 1e-3;

        % The inner TV loop stops after this many consecutive iterations
        % with relative change below tv_tol
        tv_n_stop = 5;

        % Function handle applied to Z after every iteration
        % (can zero some pixels each iteration, for instance). [] = none.
        filter = [];

        % Lower and upper bounds for X
        bounds = [-inf, inf];

        % When true, show the current X in figure 101 after each iteration
        display_it = 1;

        % Per-iteration history: struct with vectors L2, TV, err and upd
        history = [];

        % True solution, if known, used to compute history.err
        X_true = [];

        % Calculator backend (GPU/CPU). It provides norm, dx, dy, dx_, dy_,
        % projection and normalize.
        calc = CPU_mat;
    end

    properties (Access = protected)

        % Outer FISTA momentum parameter t_k
        tau = 1;

        % Extrapolated point / gradient-step result
        Y = [];

        % Result of the TV proximal step (candidate for the new X)
        Z = [];

        % Dual variables of the TV step, reused as warm start between calls
        TV_res_x = [];
        TV_res_y = [];

        % Normalization constants for history.err and history.upd
        err_norm = [];
        init_norm = [];
    end

    methods (Access = public)

        function this = TV_min()
            % TV_MIN  Construct a solver with default parameters and an empty history.
            this.history.L2 = [];
            this.history.TV = [];
            this.history.err = [];
            this.history.upd = [];
        end

        function [X, history] = run_cpu(this, X, n_iter)
            % RUN_CPU  Run the FISTA iterations and return the solution.
            %
            % Approximately solves
            %     min_X |A(X) - B|^2 + 2*lambda*TV(X),  bounds(1) <= X <= bounds(2)
            %
            % Inputs (both optional; pass [] to keep the stored value):
            %   X      - initial guess, overrides this.X
            %   n_iter - number of iterations, overrides this.n_iter
            %
            % Outputs:
            %   X       - final iterate
            %   history - this.history with one entry per iteration appended
            %             to each of the fields L2, TV, err and upd
            %
            % TV_min is a value class and the updated object is not returned,
            % so only X and history leave this method.

            % nargin counts 'this': X is argument 2, n_iter is argument 3.
            if nargin >= 2 && ~isempty(X)
                this.X = X;
            end

            if nargin >= 3 && ~isempty(n_iter)
                this.n_iter = n_iter;
            end

            if isempty(this.X)
                error('TV_min:noInitialGuess', ...
                    'An initial guess X must be set (this.X or run_cpu(X)).');
            end

            % Norm of the initial guess, used to normalize history.upd.
            % An all-zero initial guess would make every entry Inf/NaN, so
            % fall back to 1 (absolute update size) in that case.
            this.init_norm = this.calc.norm(this.X);
            if this.init_norm == 0
                this.init_norm = 1;
            end

            % Initial error norm, used to normalize history.err. The same
            % fallback applies when the initial guess already equals X_true.
            if ~isempty(this.X_true)
                this.err_norm = this.calc.norm(this.X - this.X_true);
                if this.err_norm == 0
                    this.err_norm = 1;
                end
            end

            % The FISTA extrapolated point starts at the initial guess.
            this.Y = this.X;

            % Momentum parameter t_1 = 1.
            this.tau = 1;

            % Dual variables of the TV step are re-initialized on first use.
            this.TV_res_x = [];
            this.TV_res_y = [];

            for i = 1:this.n_iter

                % Run one iteration:
                this = this.l2_iteration();

                % Update the display:
                if this.display_it
                    figure(101);
                    imshow(this.X, []); title('FISTA');
                    xlabel(sprintf('%i%%', round(i / this.n_iter * 100)));
                    refresh(101);
                end
            end

            X = this.X;
            history = this.history;
        end
    end

    methods (Access = protected)

        function this = l2_iteration(this)
            % L2_ITERATION  Perform one outer FISTA iteration.
            %
            % 1. Take a gradient step on |A(Y) - B|^2 from Y.
            % 2. Apply the TV proximal step, or only the box projection
            %    when lambda == 0, and the optional filter, giving Z.
            % 3. Record history, accept Z as the new X, and extrapolate Y.

            % Subproblem (1): gradient step on L2 = |A(Y) - B|^2. The gradient
            % is 2*A_trans(A(Y) - B) and the step size is 1/L.
            this.Y = this.Y - 2/this.L * this.A_trans( this.A(this.Y) - this.B );

            % Subproblem (2): proximal step for 2*lambda*TV(X) under the box constraint.
            if this.lambda > 0
                % l1_iteration stores the projected, lower-TV image in this.Z.
                this = this.l1_iteration();
            else
                % Without the TV term the proximal step reduces to projecting
                % Y onto the box [bounds(1), bounds(2)].
                this.Z = this.calc.projection(this.Y, this.bounds(1), this.bounds(2));
            end

            if ~isempty(this.filter)
                % Optional user filter applied after each iteration.
                this.Z = this.filter(this.Z);
            end

            % Objective terms at the new iterate Z.
            % NOTE(review): this.TV is not defined in this classdef. Presumably
            % it lives in a separate method file of an @TV_min class folder;
            % confirm.
            this.history.L2(end+1) = norm(this.A(this.Z) - this.B,'fro')^2;
            this.history.TV(end+1) = 2 * this.lambda * this.TV(this.Z);

            % Relative error of the new iterate Z, which becomes X below.
            if ~isempty(this.X_true)
                this.history.err(end+1) = this.calc.norm(this.Z - this.X_true) / this.err_norm;
            else
                this.history.err(end+1) = 0;
            end

            % Previous accepted iterate X_{k-1}:
            X_ = this.X;

            % New accepted iterate X_k (non-monotone FISTA: X_k = Z_k).
            % A 'monotonic' variant would choose X_k between Z_k and X_{k-1}
            % by objective value; it is not implemented/tested yet.
            this.X = this.Z;

            this.history.upd(end+1) = this.calc.norm(this.X - X_) / this.init_norm;

            % Momentum update: t_{k+1} = (1 + sqrt(1 + 4 t_k^2)) / 2.
            tau_ = this.tau;
            this.tau = (1 + sqrt(1+4*tau_^2))/2;

            % Extrapolation:
            %   Y = X_k + t_k/t_{k+1}*(Z_k - X_k) + (t_k - 1)/t_{k+1}*(X_k - X_{k-1}).
            % Here X_k == Z_k, so the middle term is zero. It is kept so that a
            % monotone variant only needs to change how X_k is chosen.
            this.Y = this.X + tau_ / this.tau * (this.Z - this.X) + (tau_ - 1) / this.tau * (this.X - X_);
        end

        function this = l1_iteration(this)
            % L1_ITERATION  Compute a lower-TV, box-projected image from this.Y.
            %
            % Runs up to tv_n_iter accelerated iterations on the TV dual
            % variables (this.TV_res_x, this.TV_res_y). Each iteration forms a
            % primal image Z and projects it onto [bounds(1), bounds(2)].
            % The result is stored in this.Z. The dual variables are kept,
            % so the next call starts warm from them.
            %
            % NOTE(review): the exact meaning of calc.dx/dy (presumably finite
            % differences), calc.dx_/dy_ (presumably their adjoints) and
            % calc.normalize (presumably projection onto the dual unit ball)
            % is defined by the calc backend; confirm there.

            % Initialize the dual variables on first use:
            if isempty(this.TV_res_x)
                this.TV_res_x = zeros(size(this.Y));
                this.TV_res_y = zeros(size(this.Y));
            end

            % Extrapolated dual variables (start at the current duals):
            final_TV_res_x = this.TV_res_x;
            final_TV_res_y = this.TV_res_y;

            % Inner momentum parameter, stop counter, iteration counter, and
            % the effective TV weight for a step of size 1/L.
            tau = 1;
            stop_count = 0;
            i = 0;
            la = 2 * this.lambda / this.L;

            % Start from a zero image (same size/type as Y):
            this.Z = this.Y .* 0;

            while (i < this.tv_n_iter) && (stop_count < this.tv_n_stop)
                i = i + 1;

                % Previous primal image, for the stopping criterion:
                Z_ = this.Z;

                % Primal image from the current dual, projected onto the box:
                this.Z = this.Y - la * (this.calc.dx_(final_TV_res_x) + this.calc.dy_(final_TV_res_y));
                this.Z = this.calc.projection(this.Z, this.bounds(1), this.bounds(2));

                % Keep the previous duals for the momentum step:
                TV_res_x_ = this.TV_res_x;
                TV_res_y_ = this.TV_res_y;

                % Dual gradient step with step size 1/(8*la):
                this.TV_res_x = final_TV_res_x - 1/(8*la) * this.calc.dx(this.Z);
                this.TV_res_y = final_TV_res_y - 1/(8*la) * this.calc.dy(this.Z);

                % Project the duals back onto the feasible set. The third
                % argument selects the normalization type; it can be changed
                % to switch TV type (see calc.normalize).
                [this.TV_res_x, this.TV_res_y] = this.calc.normalize(this.TV_res_x, this.TV_res_y, 1);

                % Momentum update and dual extrapolation:
                tau_ = tau;
                tau = (1 + sqrt(1 + 4*tau_^2)) / 2;

                final_TV_res_x = ...
                    this.TV_res_x + (tau_ - 1) / (tau) * (this.TV_res_x - TV_res_x_);

                final_TV_res_y = ...
                    this.TV_res_y + (tau_ - 1) / (tau) * (this.TV_res_y - TV_res_y_);

                % Stop criterion: relative change of Z. The denominator is
                % guarded so that an all-zero Z does not yield NaN.
                re = norm(this.Z - Z_, 'fro') / max(norm(this.Z, 'fro'), eps);

                if re < this.tv_tol
                    stop_count = stop_count + 1;
                else
                    stop_count = 0;
                end
            end
        end
    end

end