function [dt_macro_av,dt_micro_av,err,num_macro_steps,num_micro_steps] = multirate_matlab6(TOL,PLOTTING,MR,CONSTANT_STEPS,QUAD_PLOT,EQN,dt_initial,delta1)
% MULTIRATE_MATLAB6  Integrate a 1-D method-of-lines test problem from t = 0
% to t = T with an embedded Runge-Kutta macro-integrator (rk_onestep). In
% multirate mode (MR = 1) each macro step is followed by a re-integration of
% the component window [r,s] with smaller microsteps (micro_integrate). The
% final solution is compared with an exact solution at T.
%
% Outputs:
%   dt_macro_av     - T/num_macro_steps
%   dt_micro_av     - T/num_micro_steps (Inf when no microsteps were taken,
%                     e.g. in single-rate mode MR = 0)
%   err             - max-norm error of the final y against the exact
%                     solution at T; only components r:s are compared when
%                     CONSTANT_STEPS = 1
%   num_macro_steps - number of accepted macro steps
%   num_micro_steps - microsteps summed over all accepted macro steps

% January 2014, version 4: This version has the interpolation bug fixed!
% May 2014, version 5: This version has constant stepsize mode included.
%
% March 2015, version 6: Cleaned up code. Made compatible with
% the code "generate_results4paper.m"


% Grid and problem parameters, shared with the RHS functions f and f_interp.
global h N
global a
global gamma epsilon
global aa cc dd x

% TOL : in adaptive mode a macrostep is accepted when the largest error
%       estimate (outside [r,s] in multirate mode) is below TOL
%
% PLOTTING :"1" means plot snapshots with pauses; "0" switches this off
%
% MR: set MR=1 for multirate mode, MR=0 for single rate
%
% CONSTANT_STEPS: "1" means dt_{macro} and dt_{micro} are CONSTANT
% throughout the course of integration and m (number of macro steps
% to micro steps) is also fixed. "0" means the code chooses dt_{macro}
% and dt_{micro}. dt_initial determines dt_macro.
%
% QUAD_PLOT: % set to 1 to create a four panel summary of the evolution
             % and set t1, t2, t3, t4 below for the times to plot.
             % set to 0 otherwise.
% EQN (=1,2 or 3) determines equation to solve - either
% transport, reaction-diffusion or advection-diffusion


% other notes:
% ============
% delta1 : flag components with error > max(errors)*delta1.
%          Smaller delta1 --> flag more aggressively, larger macro steps
% microintegrator accepts a microstep of size dt_micro when its error
% estimate is below dt_micro*TOL, so the accumulated microstep error over a
% macrostep of size dt is bounded by roughly dt*TOL


% Problem setup: grid x (N points, spacing h), initial data y0, final time
% T, and the snapshot times t1..t4 used by QUAD_PLOT.
N = 401;
a = 1; epsilon = 0.01; gamma = 100;
if EQN == 1 % Solve transport equation u_t + a*u_x = 0
    L = 20;
    x = linspace(-L,L,N);
    h = x(2)-x(1);
    y0 = exp(-x.^2);
    T = 7;
    t1 = 0.01; t2 = 1.5; t3 = 3.0; t4 = 6.0;
elseif EQN == 2 % solve fisher equation u_t = epsilon u_xx + gamma*u^2*(1-u)
    lambda = 0.5*sqrt(2*gamma/epsilon);
    L = 3.5;
    x = linspace(0,L,N);
    h = x(2)-x(1);
    y0 = 1./( 1+exp(lambda*(x-1)) );
    T = 0.5;
    t1 = 0.01; t2 = 0.2; t3 = 0.3; t4 = 0.4;
elseif EQN == 3 
    % solve advection-diffusion with forcing [4th order]
    % u_t + au_x = d u_xx - cu + g(x,t)
    % -L < x < L
    aa = 5; dd = 0.01; cc = 100; L = 1;
    x = linspace(-L,L,N);
    h = x(2)-x(1);
    y0 = zeros(1,N);
    T = 0.8;
    t1 = 0.01; t2 = 0.06; t3 = 0.3; t4 = 0.7;
end
y = y0;
dt = dt_initial;
% y(round(N/2)) = 1;

t = 0;
dt_max = 1.0;   % upper bound on the adaptive macro step

% P: number of extra components added on each side of the flagged window
% (used only in adaptive mode, see below).
if MR == 1
    P = 10;
else
    P = 0;
end
safety = 0.95;

% History arrays; MATLAB grows them automatically if more than 4000
% iterations are needed.
flags = zeros(1,N);
err_vec = zeros(1,4000);
t_vec = zeros(1,4000);
num_micro_vec = zeros(1,4000);
num_micro_steps=0;
trynum=0;      % consecutive rejected macro steps
numsteps=0;    % accepted macro steps
% r=round(N/2)+1;s=round(N/2)-1;
num_macro_steps_counter = 1;   % loop iterations (accepted and rejected)

while t<T

    % Trial macro step on all components. K is the stage matrix, passed to
    % micro_integrate for interpolating the window's neighbours.
    [ynew,error,K] = rk_onestep(t,y,dt,[1 N],EQN);
        %%%%% start of multirate %%%
    if MR == 1
        % Flag components whose error estimate exceeds delta1*max(error);
        % [r,s] spans the first through the last flagged component.
        % NOTE(review): if nothing is flagged (e.g. error identically zero,
        % or delta1 >= 1) r and s are never assigned on the first pass.
        flags = error > delta1*max(error);
        for i=1:length(flags)
            if flags(i) == 1
                r = i;
                break;
            end
        end
        for j=length(flags):-1:1
            if flags(j) == 1
                s = j;
                break;
            end
        end

        if CONSTANT_STEPS == 0
            % Pad the window by P, keeping r >= 3 and s <= N-2 so the two
            % interpolated neighbours on each side exist.
            r = max(r-P,3);
            s = min(s+P,N-2);
        elseif CONSTANT_STEPS == 1
            if EQN == 1
                r = 186; s = 216; % fixed [r,s]; transport equation when N=401
            elseif EQN == 2
                r = 93; s = 139; % fixed [r,s]; reaction-diffusion equation when N=401
            elseif EQN == 3
                r = 201; s = 241; % fixed [r,s]; advection-diffusion equation when N=401
            end
        end

        % Overwrite ynew(r:s) with the micro-integrated values.
        [ynew,num_micro_steps] = micro_integrate(t,t+dt,y,ynew,[r s],K,TOL,CONSTANT_STEPS,EQN);
    %%% end of multirate %%%
        sprintf('t = %f',t)
    end
    
    % Macro error estimate. In multirate mode only components outside [r,s]
    % are used, since those inside were handled by the microintegrator.
    if MR == 1
        e_star = max( max(error(1:r-1)), max(error(s+1:end)) );
    else
        e_star = max(error); r = 1; s = N;
    end
    

    
    if CONSTANT_STEPS == 1
        % Every step is accepted; stop after round(T/dt) steps.
        t = t+dt; y = ynew; trynum = 0; numsteps = numsteps+1;
        num_micro_vec(num_macro_steps_counter) = num_micro_steps;
        %num_micro_vec = [num_micro_vec num_micro_steps];
        if numsteps == round(T/dt)
            t = T; % force exit of while loop after T/dt steps
        end
    elseif CONSTANT_STEPS == 0    
        % R < 1 exactly when e_star < TOL.
        % R = ( desired_error/max_error )^(1/5);
        R = (e_star/TOL)^(1/5);
        if R>1 % step failed
            sprintf('step failed');
            dt = dt*safety*max(0.2,1/R); 
            trynum = trynum+1;
            if trynum>10
                % NOTE(review): returning here leaves every output argument
                % unassigned, so a caller requesting outputs will error.
                sprintf('10 failed attempts!')
                return
            end
        elseif R<1 % step succeeded
            % (R == 1 exactly matches neither branch: the step is neither
            % accepted nor counted as a failure; dt is reduced below.)
            sprintf('step succeeded');    
            t = t+dt; y = ynew; trynum = 0; numsteps = numsteps+1;
            num_micro_vec(num_macro_steps_counter) = num_micro_steps;
            %num_micro_vec = [num_micro_vec num_micro_steps];
        end
    end 
                
    % Live plot of numerical vs exact solution, plus the flagged components.
    if PLOTTING == 1 && t>0
        figure(2); subplot(2,1,1);
        if EQN == 1
            y_exact = advection_exact(t,N,a,L,x,h,y0');
        elseif EQN == 2
            y_exact = fisher_exact(x,t,epsilon,gamma,N);
        elseif EQN == 3
            y_exact = parabolic_exact(t,L,aa,cc,dd,N);
        end
        plot(x,y,'b-',x,y_exact,'r-');
        legend('numerical','exact');
        err = norm(y'-y_exact,inf);
        % err_vec = [err_vec err];
        err_vec(num_macro_steps_counter) = err;
        % 
        %t_vec = [t_vec t];
        t_vec(num_macro_steps_counter) = t;
        tit = sprintf('t = %f, r = %d, s = %d, error = %2.4e, numsteps=%d',t,r,s,err,numsteps); 
        title(tit);
        subplot(2,1,2);
        spy(flags)
        sprintf('[r s] = [%d %d], s-r = %d, dt = %f   error = %2.4e   num_micro_steps = %d',r,s,s-r,dt,err,num_micro_steps)
        pause(0.1);
    end
    
    % QUAD_PLOT acts as a state machine 1 -> 2 -> 3 -> 4: the first time t
    % exceeds t1, t2, t3, t4 one solution panel and one flag panel are drawn.
    % NOTE(review): the panels plot ynew while the title error uses y; these
    % differ if the current iteration's trial step was rejected.
    if t>t1 && QUAD_PLOT == 1
        if EQN == 1
            y_exact = advection_exact(t,N,a,L,x,h,y0');
        elseif EQN == 2
            y_exact = fisher_exact(x,t,epsilon,gamma,N);
        elseif EQN == 3
            y_exact = parabolic_exact(t,L,aa,cc,dd,N);
        end
        err = norm(y'-y_exact,inf);
        %
        subplot(4,2,1)
        plot(x,ynew,'LineWidth',2);
        xlabel('$x$','Interpreter','Latex');
        ylabel('$y(x,t)$','Interpreter','Latex');
        tit = sprintf('t = %.2f, r = %d, s = %d, error = %.2e, numsteps=%d',t,r,s,err,numsteps); 
        title(tit);
        % axis([-L/2 L/2 0 1]);
        subplot(4,2,3)
        spy(flags,20)
        title('Flagged components');
        QUAD_PLOT = 2;
    elseif t>t2 && QUAD_PLOT == 2
        if EQN == 1
            y_exact = advection_exact(t,N,a,L,x,h,y0');
        elseif EQN == 2
            y_exact = fisher_exact(x,t,epsilon,gamma,N);
        elseif EQN == 3
            y_exact = parabolic_exact(t,L,aa,cc,dd,N);
        end
        err = norm(y'-y_exact,inf);
        %
        subplot(4,2,2)
        plot(x,ynew,'LineWidth',2);
        xlabel('$x$','Interpreter','Latex');
        ylabel('$y(x,t)$','Interpreter','Latex');
        tit = sprintf('t = %.2f, r = %d, s = %d, error = %.2e, numsteps=%d',t,r,s,err,numsteps); 
        title(tit);
        % axis([-L/2 L/2 0 1]);
        subplot(4,2,4)
        spy(flags,20)
        title('Flagged components');
        QUAD_PLOT=3;
    elseif t>t3 && QUAD_PLOT == 3
        if EQN == 1
            y_exact = advection_exact(t,N,a,L,x,h,y0');
        elseif EQN == 2
            y_exact = fisher_exact(x,t,epsilon,gamma,N);
        elseif EQN == 3
            y_exact = parabolic_exact(t,L,aa,cc,dd,N);
        end
        err = norm(y'-y_exact,inf);
        %
        subplot(4,2,5)
        plot(x,ynew,'LineWidth',2);
        xlabel('$x$','Interpreter','Latex');
        ylabel('$y(x,t)$','Interpreter','Latex');
        tit = sprintf('t = %.2f, r = %d, s = %d, error = %.2e, numsteps=%d',t,r,s,err,numsteps); 
        title(tit);
        % axis([-L/2 L/2 0 1]);
        subplot(4,2,7)
        spy(flags,20)
        title('Flagged components');
        QUAD_PLOT=4;
    elseif t>t4 && QUAD_PLOT == 4
        if EQN == 1
            y_exact = advection_exact(t,N,a,L,x,h,y0');
        elseif EQN == 2
            y_exact = fisher_exact(x,t,epsilon,gamma,N);
        elseif EQN == 3
            y_exact = parabolic_exact(t,L,aa,cc,dd,N);
        end
        err = norm(y'-y_exact,inf);
        %
        subplot(4,2,6)
        plot(x,ynew,'LineWidth',2);
        xlabel('$x$','Interpreter','Latex');
        ylabel('$y(x,t)$','Interpreter','Latex');
        tit = sprintf('t = %.2f, r = %d, s = %d, error = %.2e, numsteps=%d',t,r,s,err,numsteps); 
        title(tit);
        % axis([-L/2 L/2 0 1]);
        subplot(4,2,8)
        spy(flags,20)
        title('Flagged components')
    end
    
    % Adaptive mode: propose the next macro step (growth capped at 5x),
    % clipped so the step does not pass T and does not exceed dt_max.
    % NOTE(review): after a rejected step dt was already reduced in the
    % failure branch above; this line reduces it again by safety/R.
    if CONSTANT_STEPS == 0            
         dt = dt*safety*min(5,1/R);
         if t+dt>T
             dt = (T-t);
         end
         if dt>dt_max
             dt = dt_max;
         end        
    end
   num_macro_steps_counter = num_macro_steps_counter+1;
end



% figure(2); plot(t_vec,err_vec,'kd-')
% xlabel('t');
% ylabel('error');

% figure(3); loglog(num_micro_vec,err_vec,'kd'); 
% xlabel('number of microsteps');
% ylabel('error');

% Final comparison plot (uses the y_exact from the last plotting pass).
if PLOTTING == 1
    figure(1);
    plot(x,y,'bd',x,y_exact,'r-')
    legend('numerical','exact');
    % axis([6 15 -0.2 1.2]);
end

% Exact solution at the final time T and the reported max-norm error.
if EQN == 1
    y_exact = advection_exact(T,N,a,L,x,h,y0');
elseif EQN == 2
    y_exact = fisher_exact(x,T,epsilon,gamma,N);
elseif EQN == 3
    y_exact = parabolic_exact(T,L,aa,cc,dd,N);
end
if CONSTANT_STEPS == 1
    err = norm(y(r:s)'-y_exact(r:s),Inf);
else
    err = norm(y'-y_exact,Inf);
end

% Step statistics; micro steps are only recorded for accepted macro steps.
num_macro_steps = numsteps;
num_micro_steps = sum(num_micro_vec);
dt_macro_av = T/num_macro_steps;
dt_micro_av = T/num_micro_steps;
if PLOTTING == 1
    tit = sprintf('error = %2.4e, num macro steps=%d, num micro steps=%d',err,num_macro_steps,num_micro_steps); title(tit);
end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

function [y2,numsteps] = micro_integrate(t1,t2,y1,y2,cpt,K,desired_error,CONSTANT_STEPS,EQN)
% MICRO_INTEGRATE  Re-integrate the components cpt(1):cpt(2) across the
% macro step [t1,t2] with smaller (micro) steps.
%
% y(r) ... y(s) require integration.
% y(r-2), y(r-1), y(s+1), y(s+2) act as boundary values; they are evaluated
% inside f_interp by interpolating the macro step (y1, K), so r >= 3 and
% s <= length(y1)-2 are required.
%
% Inputs:
%   t1, t2         - start and end time of the macro step
%   y1             - solution at t1
%   y2             - macro-step solution at t2; only y2(r:s) is updated
%   cpt            - [r s], the window to re-integrate
%   K              - stage matrix of the macro step (from rk_onestep)
%   desired_error  - accumulated desired error over the whole macrostep:
%                    a microstep of size dt is accepted when
%                    max(error(r:s)) <= dt*desired_error
%   CONSTANT_STEPS - 1: exactly m0 equal microsteps; 0: adaptive microsteps
%   EQN            - equation selector forwarded to the RHS
% Outputs:
%   y2             - input y2 with y2(r:s) replaced by the micro-integrated
%                    values (returned unchanged after more than 10
%                    consecutive rejected microsteps)
%   numsteps       - number of accepted microsteps

r = cpt(1); s = cpt(2);
m0 = 10;              % microsteps per macro step (constant mode) / initial split
dt = (t2-t1)/m0;
safety = 0.9;

y_current = y1;
t = t1;
trynum = 0;           % consecutive rejected microsteps
numsteps = 0;

% figure; hold on;
% plot(t1,y1(r-2),'kd',t2,y2(r-2),'kd');
% plot(t1,y1(r-1),'rd',t2,y2(r-1),'rd');
% plot(t1,y1(s+1),'bd',t2,y2(s+1),'bd');
% plot(t1,y1(s+2),'md',t2,y2(s+2),'md');
% tt = linspace(t1,t2,100);
% plot(tt,Interpolate(tt,y1(r-2),y2(r-2),t1,t2,K(r-2,:)),'k-')
% plot(tt,Interpolate(tt,y1(r-1),y2(r-1),t1,t2,K(r-1,:)),'r-')
% plot(tt,Interpolate(tt,y1(s+1),y2(s+1),t1,t2,K(s+1,:)),'b-')
% plot(tt,Interpolate(tt,y1(s+2),y2(s+2),t1,t2,K(s+2,:)),'m-')

while t<t2
    [y_new,err_est,~] = rk_onestep_w_interp(t,y_current,dt,[r,s],y1,y2,t1,t2,K,EQN);

    if CONSTANT_STEPS == 1
        t = t+dt; numsteps = numsteps+1; y_current = y_new;
        if numsteps == m0
            t = t2; % exit while loop after m0 steps
        end
    elseif CONSTANT_STEPS == 0
        % Error per unit step relative to the tolerance.
        R = max(err_est(r:s))/(dt*desired_error);
        % Reject on R > 1 and also on NaN; previously R == 1 (or NaN)
        % matched neither branch and the loop never advanced.
        if R>1 || isnan(R) % step failed
            dt = safety*dt*max(0.1,R^-0.25); trynum = trynum+1;
            if trynum>10
                sprintf('10 failed attempts!')
                return
            end
        else % step succeeded (R <= 1)
            t = t+dt; numsteps = numsteps+1; y_current = y_new;
            trynum = 0;
            dt = safety*dt*min(5,R^-0.2);
            if t+dt>t2
                dt = (t2-t);   % land exactly on t2
            end
        end
    end
end
y2(r:s) = y_current(r:s);


function [ynew,error,K] = rk_onestep(t,y,dt,cpt,EQN)
% RK_ONESTEP  Take one embedded Runge-Kutta step (Cash-Karp 4(5) tableau)
% of size dt, evaluating the RHS f on the component window cpt = [r s].
%
% Outputs:
%   ynew  - solution advanced with the 4th-order weight row
%   error - componentwise |4th-order - 5th-order| difference
%   K     - stage matrix whose j-th column is the (transposed) increment
%           dt*f of stage j

% Tableau: stage nodes, lower-triangular coupling, and the two weight rows.
node = [0, 1/5, 3/10, 3/5, 1, 7/8];
B = zeros(6,5);
B(2,1)   = 1/5;
B(3,1:2) = [3/40, 9/40];
B(4,1:3) = [3/10, -9/10, 6/5];
B(5,1:4) = [-11/54, 5/2, -70/27, 35/27];
B(6,1:5) = [1631/55296, 175/512, 575/13824, 44275/110592, 253/4096];
w5 = [37/378, 0, 250/621, 125/594, 0, 512/1771];                    % 5th order
w4 = [2825/27648, 0, 18575/48384, 13525/55296, 277/14336, 1/4];     % 4th order

% Stage increments, each already multiplied by dt.
s1 = dt*f(t, y, cpt, EQN);
s2 = dt*f(t + node(2)*dt, y + B(2,1)*s1, cpt, EQN);
s3 = dt*f(t + node(3)*dt, y + B(3,1)*s1 + B(3,2)*s2, cpt, EQN);
s4 = dt*f(t + node(4)*dt, y + B(4,1)*s1 + B(4,2)*s2 + B(4,3)*s3, cpt, EQN);
s5 = dt*f(t + node(5)*dt, y + B(5,1)*s1 + B(5,2)*s2 + B(5,3)*s3 + B(5,4)*s4, cpt, EQN);
s6 = dt*f(t + node(6)*dt, y + B(6,1)*s1 + B(6,2)*s2 + B(6,3)*s3 + B(6,4)*s4 + B(6,5)*s5, cpt, EQN);
K = [s1', s2', s3', s4', s5', s6'];

y_fifth = y + w5(1)*s1 + w5(2)*s2 + w5(3)*s3 + w5(4)*s4 + w5(5)*s5 + w5(6)*s6;
ynew    = y + w4(1)*s1 + w4(2)*s2 + w4(3)*s3 + w4(4)*s4 + w4(5)*s5 + w4(6)*s6;

error = abs(ynew - y_fifth);
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

function [ynew,error,K] = rk_onestep_w_interp(t,y,dt,cpt,y1,y2,t1,t2,K,EQN)
% RK_ONESTEP_W_INTERP  Take one embedded Runge-Kutta step (Cash-Karp 4(5)
% tableau) of size dt on the window cpt = [r s], using f_interp as the RHS.
% f_interp returns derivatives for components r..s (zero elsewhere). It
% obtains the neighbouring values y(r-2), y(r-1), y(s+1), y(s+2) by
% interpolating the enclosing macro step from t1 to t2 (data y1, y2 and the
% macro-step stage matrix K).
%
% Outputs:
%   ynew  - solution advanced with the 4th-order weights
%   error - componentwise |4th-order - 5th-order| difference
%   K     - stage matrix of THIS step (columns k1..k6). It reuses the name
%           of the input macro-step K, which is only read before this
%           output is assigned.


a2 = 1/5; a3 = 3/10; a4 = 3/5; a5 = 1; a6 = 7/8;
b21 = 1/5;
b31 = 3/40; b32 = 9/40;
b41 = 3/10; b42 = -9/10; b43 = 6/5;
b51 = -11/54; b52 = 5/2; b53 = -70/27; b54 = 35/27;
b61 = 1631/55296; b62 = 175/512; b63 = 575/13824; b64 = 44275/110592; b65 = 253/4096;

c1 = 37/378;
c2 = 0;
c3 = 250/621;
c4 = 125/594;
c5 = 0;
c6 = 512/1771;

c1s = 2825/27648;
c2s = 0;
c3s = 18575/48384;
c4s = 13525/55296;
c5s = 277/14336;
c6s = 1/4;

% f_interp(t,y,components,y1,y2,t1,t2,K,EQN); the K passed here is the
% macro-step stage matrix used for the boundary interpolants.
k1 = dt*f_interp(t,y,cpt,y1,y2,t1,t2,K,EQN);
k2 = dt*f_interp(t+a2*dt,y+b21*k1,cpt,y1,y2,t1,t2,K,EQN);
k3 = dt*f_interp(t+a3*dt,y+b31*k1+b32*k2,cpt,y1,y2,t1,t2,K,EQN);
k4 = dt*f_interp(t+a4*dt,y+b41*k1+b42*k2+b43*k3,cpt,y1,y2,t1,t2,K,EQN);
k5 = dt*f_interp(t+a5*dt,y+b51*k1+b52*k2+b53*k3+b54*k4,cpt,y1,y2,t1,t2,K,EQN);
k6 = dt*f_interp(t+a6*dt,y+b61*k1+b62*k2+b63*k3+b64*k4+b65*k5,cpt,y1,y2,t1,t2,K,EQN);
K = [k1' k2' k3' k4' k5' k6'];   % overwrites the input K (no longer needed)

ynew_p = y + c1*k1 + c2*k2 + c3*k3 + c4*k4 + c5*k5 + c6*k6; % 5th order
ynew = y + c1s*k1 + c2s*k2 + c3s*k3 + c4s*k4 + c5s*k5 + c6s*k6; % 4th order

error = abs(ynew - ynew_p);

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%


function yinterp = Interpolate(t,y1,y2,t1,t2,k)
% INTERPOLATE  Cubic dense-output polynomial for one component over the
% macro step [t1,t2], built from its start value y1 and the stage
% increments k (only k(1), k(4) and k(5) are used; y2 is not used).
% t may be a scalar or an array of times.

theta = (t-t1)/(t2-t1);                          % normalised time in [0,1]
quad_coef  = -8/3 * k(1) + 25/6 * k(4) - 3/2 * k(5);
cubic_coef = 10/3*k(1) - 25/3 * k(4) + 5*k(5);

yinterp = y1 + theta*k(1) + 0.5*theta.^2*quad_coef + theta.^3/6*cubic_coef;



function ydot = f(t,y,components,EQN)
% F  Right-hand side of the semi-discretised PDE, evaluated on the window
% components = [r s] (1 <= r < s <= N). Entries not updated stay zero:
%   EQN = 1: first-order backward difference -a*u_x, updates r+1..s
%   EQN = 2: second-order diffusion plus reaction gamma*y^2*(1-y),
%            updates r+1..s-1
%   EQN = 3: fourth-order diffusion and advection stencils, decay -cc*y and
%            forcing g(x,t), updates r+2..s-2
% Grid and problem parameters are read from globals.

global h N 
global a 
global epsilon gamma 
global aa cc dd x

lo = components(1);
hi = components(2);
ydot = zeros(1,N);

switch EQN
    case 1
        j = lo+1:hi;
        ydot(j) = -(a/h)*(y(j) - y(j-1));
    case 2
        j = lo+1:hi-1;
        ydot(j) = (epsilon/h^2)*( y(j+1)-2*y(j)+y(j-1) )+ gamma*y(j).^2.*(1-y(j));
    case 3
        j = lo+2:hi-2;
        ydot(j) = dd/12/h^2 * ( -y(j-2) + 16*y(j-1) - 30*y(j) + 16*y(j+1) - y(j+2) ) + ...
            -aa/12/h * ( y(j-2) - 8*y(j-1) + 8*y(j+1) - y(j+2) ) + ...
            - cc*y(j) + g(x(j),t)';
end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

function ydot = f_interp(t,y,components,y1,y2,t1,t2,K,EQN)
% F_INTERP  Right-hand side used while microstepping the window
% components = [r s]. Only entries r..s are set; all others stay zero. The
% neighbours r-2, r-1, s+1, s+2 are not integrated. Their values at time t
% come from Interpolate, using the macro-step data y1 (at t1), y2 and the
% macro-step stage matrix K, so r >= 3 and s <= N-2 are required.
%   EQN = 1 uses the neighbour r-1; EQN = 2 uses r-1 and s+1;
%   EQN = 3 uses all four.

global h N 
global a 
global epsilon gamma 
global aa cc dd x

lo = components(1);
hi = components(2);
ydot = zeros(1,N);

% Interpolated neighbour values at indices [r-2, r-1, s+1, s+2].
ghost_idx = [lo-2, lo-1, hi+1, hi+2];
ghost = zeros(1,4);
for q = 1:4
    gi = ghost_idx(q);
    ghost(q) = Interpolate(t,y1(gi),y2(gi),t1,t2,K(gi,:));
end
yL2 = ghost(1); yL1 = ghost(2); yR1 = ghost(3); yR2 = ghost(4);

switch EQN
    case 1
        ydot(lo) = -(a/h)*(y(lo) - yL1);
        j = lo+1:hi;
        ydot(j) = -(a/h)*(y(j) - y(j-1));
    case 2
        ydot(lo) = (epsilon/h^2)*( y(lo+1) - 2*y(lo)+ yL1 )+ gamma*y(lo).^2.*(1-y(lo));
        j = lo+1:hi-1;
        ydot(j) = (epsilon/h^2)*( y(j+1)-2*y(j)+y(j-1) )+ gamma*y(j).^2.*(1-y(j));
        ydot(hi) = (epsilon/h^2)*( yR1-2*y(hi)+y(hi-1) )+ gamma*y(hi).^2.*(1-y(hi));
    case 3
        % Fourth-order stencils; near each end of the window the missing
        % neighbours are replaced by the interpolated ghost values.
        diff_coef = dd/12/h^2;
        adv_coef = -aa/12/h;

        ydot(lo) = diff_coef * ( -yL2 + 16*yL1 - 30*y(lo) + 16*y(lo+1) - y(lo+2) ) + ...
            adv_coef * ( yL2 - 8*yL1 + 8*y(lo+1) - y(lo+2) ) + ...
            - cc*y(lo) + g(x(lo),t)';

        ydot(lo+1) = diff_coef * ( -yL1 + 16*y(lo) - 30*y(lo+1) + 16*y(lo+2) - y(lo+3) ) + ...
            adv_coef * ( yL1 - 8*y(lo) + 8*y(lo+2) - y(lo+3) ) + ...
            - cc*y(lo+1) + g(x(lo+1),t)';

        j = lo+2:hi-2;
        ydot(j) = diff_coef * ( -y(j-2) + 16*y(j-1) - 30*y(j) + 16*y(j+1) - y(j+2) ) + ...
            adv_coef * ( y(j-2) - 8*y(j-1) + 8*y(j+1) - y(j+2) ) + ...
            - cc*y(j) + g(x(j),t)';

        ydot(hi-1) = diff_coef * ( -y(hi-3) + 16*y(hi-2) - 30*y(hi-1) + 16*y(hi) - yR1 ) + ...
            adv_coef * ( y(hi-3) - 8*y(hi-2) + 8*y(hi) - yR1 ) + ...
            - cc*y(hi-1) + g(x(hi-1),t)';

        ydot(hi) = diff_coef * ( -y(hi-2) + 16*y(hi-1) - 30*y(hi) + 16*yR1 - yR2 ) + ...
            adv_coef * ( y(hi-2) - 8*y(hi-1) + 8*yR1 - yR2 ) + ...
            - cc*y(hi) + g(x(hi),t)';
end


%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
