classdef TV_min
    % TV_MIN  Total-variation regularized, box-constrained least squares via FISTA.
    %
    % Minimizes  |A(X) - B|^2 + 2*lambda*TV(X)  with X clipped to
    % [bounds(1), bounds(2)] in every iteration. Based on the FISTA scheme
    % of Beck & Teboulle, see:
    %   http://www.math.tau.ac.il/~teboulle/papers/tlv.pdf
    %
    % The code keeps the current guess of the solution at different stages
    % of an iteration in three images:
    %   Y - the image after the gradient (L2 minimization) step,
    %   Z - the image after the TV (L1 minimization) step,
    %   X - the accepted solution of the iteration.
    %
    % See the property comments below for a description of all parameters.
    %
    % Example:
    %
    %   a = phantom(256);
    %   b = radon(a);
    %
    %   f = TV_min();
    %
    %   % Define linear system A(X) = B with the initial guess X:
    %   f.X = zeros(size(a));
    %   f.B = b;
    %   f.A = @(x) radon(x);
    %   f.A_trans = @(x) iradon(x, [], 'linear', 'none', 1, 256);
    %
    %   % Parameters of iterative reconstruction:
    %   f.L = 2*max(max(iradon(radon(ones(size(a))), [], 'linear', 'none')));
    %   f.lambda = 0.01;
    %   f.bounds = [0, inf];
    %   f.X_true = a;
    %
    %   [x, history] = f.run_cpu();
    %
    %   figure(1); imshow(x, []);
    %   figure(2); plot(history.err);

    properties (Access = public)
        % Function handle for the forward operator A(x)
        A = @(x) radon(x);
        % Function handle for the transposed operator A_trans(b)
        A_trans = @(b) iradon(b, [], 'linear', 'none');
        % Lipschitz constant L = 2*max( A_trans * A ); the gradient step is 2/L * A_trans(...)
        L = 2;
        % TV strength lambda (0 skips the TV step)
        lambda = 0;
        % Solution guess X (the initial guess before run_cpu)
        X = [];
        % Observed data B
        B = [];
        % Number of outer FISTA iterations
        n_iter = 100;
        % Maximum number of inner iterations of the TV (L1) step
        n_iter_tv = 10;
        % Relative change of Z below which an inner TV iteration counts as converged
        tol_tv = 1e-3;
        % Function handle applied to Z after each iteration (can zero some
        % pixels, for instance); [] disables it
        filter = [];
        % Lower and upper bounds for X
        bounds = [-inf, inf];
        % When true, show the current X in figure 101 during run_cpu
        display_it = 1;
        % Per-iteration history with fields L2, TV, err, upd
        history = [];
        % True solution, if known, used to calculate history.err
        X_true = [];
        % Calculator (GPU/CPU) providing norm, dx, dy, dx_, dy_, projection, normalize
        calc = CPU_mat;
    end

    properties (Access = protected)
        % Internal state of a reconstruction:
        tau = 1;          % momentum parameter of the outer FISTA loop
        Y = [];           % image after the gradient (L2) step
        Z = [];           % image after the TV (L1) step
        TV_res_x = [];    % TV residuals, reused between calls of l1_iteration
        TV_res_y = [];    % (reset to [] at the start of run_cpu)
        err_norm = [];    % normalization of history.err
        init_norm = [];   % normalization of history.upd
    end

    methods (Access = public)
        function this = TV_min()
            % Create an object with an empty history.
            this.history.L2 = [];
            this.history.TV = [];
            this.history.err = [];
            this.history.upd = [];
        end

        function [X, history] = run_cpu(this, X, n_iter)
            % RUN_CPU  Calculates X: min |A(X) - B|^2 + 2*lambda*TV(X),
            % subject to bounds(1) <= X <= bounds(2).
            %
            % Inputs (both optional):
            %   X      - initial guess, overrides this.X ([] keeps this.X)
            %   n_iter - number of iterations, overrides this.n_iter
            % Outputs:
            %   X       - the reconstructed image
            %   history - struct with per-iteration arrays L2, TV, err, upd
            %
            % TV_min is a value class, so the caller's object is not modified;
            % only X and history are returned.

            % nargin counts 'this', so X is the 2nd and n_iter the 3rd argument.
            if nargin >= 2 && ~isempty(X)
                this.X = X;
            end
            if nargin >= 3
                this.n_iter = n_iter;
            end

            % Norm of the initial guess, used to normalize history.upd.
            % A zero initial guess would make every ratio Inf/NaN, so fall
            % back to the absolute update in that case.
            this.init_norm = this.calc.norm(this.X);
            if this.init_norm == 0
                this.init_norm = 1;
            end

            % Initial error norm, used to normalize history.err (same guard).
            if ~isempty(this.X_true)
                this.err_norm = this.calc.norm(this.X - this.X_true);
                if this.err_norm == 0
                    this.err_norm = 1;
                end
            end

            % Y is the point the gradient (L2) step starts from.
            this.Y = this.X;
            % tau determines the momentum (extrapolation) step size.
            this.tau = 1;
            % Residuals of the TV step; l1_iteration re-initializes them to zeros.
            this.TV_res_x = [];
            this.TV_res_y = [];

            for i = 1:this.n_iter
                % Run one iteration:
                this = this.l2_iteration();

                % Update the display:
                if this.display_it
                    figure(101);
                    imshow(this.X, []);
                    title('FISTA');
                    xlabel(sprintf('%i%%', round(i / this.n_iter * 100)));
                    refresh(101);
                end
            end

            X = this.X;
            history = this.history;
        end
    end

    methods (Access = protected)
        function this = l2_iteration(this)
            % L2_ITERATION  A single outer iteration of FISTA minimization:
            % gradient step from Y, TV/box step into Z, optional filter,
            % history update, acceptance of Z as X and momentum update of Y.

            % Subproblem (1), minimization of L2 = |A(X) - B|^2:
            this.Y = this.Y - 2/this.L * this.A_trans(this.A(this.Y) - this.B);

            % Subproblem (2), minimization of L1 = 2 * lambda * |TV(X)| within bounds:
            if this.lambda > 0
                % l1_iteration calculates the image with lower TV (projected
                % onto bounds) and stores it in this.Z:
                this = this.l1_iteration();
            else
                % Without TV the step reduces to the projection onto the box,
                % so bounds are honoured for lambda == 0 as well.
                this.Z = this.calc.projection(this.Y, this.bounds(1), this.bounds(2));
            end

            if ~isempty(this.filter)
                % Can apply any extra filter to the solution after each iteration:
                this.Z = this.filter(this.Z);
            end

            % L2 misfit and TV term of the new iterate Z:
            this.history.L2(end+1) = norm(this.A(this.Z) - this.B, 'fro')^2;
            % NOTE(review): no TV method is declared in this classdef; it
            % presumably lives in a separate file of an @TV_min folder - confirm,
            % otherwise this line errors.
            this.history.TV(end+1) = 2 * this.lambda * this.TV(this.Z);

            % Relative error of the same iterate Z, if the true solution is known:
            if ~isempty(this.X_true)
                this.history.err(end+1) = this.calc.norm(this.Z - this.X_true) / this.err_norm;
            else
                this.history.err(end+1) = 0;
            end

            % X before this iteration:
            X_prev = this.X;
            % X after this iteration:
            this.X = this.Z;
            % Here the algorithm can be modified into the 'monotonic' version,
            % which was not tested yet. Something TODO in future...

            this.history.upd(end+1) = this.calc.norm(this.X - X_prev) / this.init_norm;

            % Updating tau (needed for the optimal step size):
            tau_prev = this.tau;
            this.tau = (1 + sqrt(1 + 4*tau_prev^2)) / 2;

            % Updating Y (needed for the next guess of X). Since X == Z here,
            % the first correction term is zero; presumably kept for the
            % monotonic variant.
            this.Y = this.X + tau_prev / this.tau * (this.Z - this.X) ...
                   + (tau_prev - 1) / this.tau * (this.X - X_prev);
        end

        function this = l1_iteration(this)
            % L1_ITERATION  Calculates an image with lower TV from this.Y and
            % stores it in this.Z (projected onto bounds). Uses this.TV_res_x
            % and this.TV_res_y from the previous call as a starting point.
            % Runs at most n_iter_tv inner iterations and stops earlier after
            % 5 consecutive iterations with relative change of Z below tol_tv.

            % Initialize TV residual:
            if isempty(this.TV_res_x)
                this.TV_res_x = zeros(size(this.Y));
                this.TV_res_y = zeros(size(this.Y));
            end

            % Extrapolated (modified) TV residual:
            final_TV_res_x = this.TV_res_x;
            final_TV_res_y = this.TV_res_y;

            % Internal variables of this function (tau is local, not this.tau):
            tau = 1;
            stop_count = 0;
            i = 0;
            la = 2 * this.lambda / this.L;

            % End result:
            this.Z = this.Y .* 0;

            while (i < this.n_iter_tv) && (stop_count < 5)
                i = i + 1;
                Z_prev = this.Z;

                % New Z from the current residuals, clipped to the bounds:
                this.Z = this.Y - la * (this.calc.dx_(final_TV_res_x) + this.calc.dy_(final_TV_res_y));
                this.Z = this.calc.projection(this.Z, this.bounds(1), this.bounds(2));

                % Taking a step towards minus of the gradient:
                TV_res_x_prev = this.TV_res_x;
                TV_res_y_prev = this.TV_res_y;
                this.TV_res_x = final_TV_res_x - 1/(8*la) * this.calc.dx(this.Z);
                this.TV_res_y = final_TV_res_y - 1/(8*la) * this.calc.dy(this.Z);
                % This part can be changed to anisotropic, now it's L1 type:
                [this.TV_res_x, this.TV_res_y] = this.calc.normalize(this.TV_res_x, this.TV_res_y, 1);

                % Updating residual and tau:
                tau_prev = tau;
                tau = (1 + sqrt(1 + 4*tau_prev^2)) / 2;
                final_TV_res_x = this.TV_res_x + (tau_prev - 1) / tau * (this.TV_res_x - TV_res_x_prev);
                final_TV_res_y = this.TV_res_y + (tau_prev - 1) / tau * (this.TV_res_y - TV_res_y_prev);

                % Stop criterion: relative change of Z.
                re = norm(this.Z - Z_prev, 'fro') / norm(this.Z, 'fro');
                if re < this.tol_tv
                    stop_count = stop_count + 1;
                else
                    stop_count = 0;
                end
            end
        end
    end
end