#include "../binary_c.h"

/*
 * File-local debugging printer. By default it expands to nothing.
 * To enable it, uncomment the first definition and comment out the
 * second. The debugging variant is wrapped in do { } while(0) so that
 * it expands to a single statement (safe in an unbraced if/else).
 */
//#define eprint(...) do { fprintf(stdout,__VA_ARGS__); fflush(stdout); } while(0)
#define eprint(...) /* do nothing */

/*
 * Wrapper that evolves the system in time, starting at time t,
 * and performs:
 *
 * 1) stellar mass and angular momentum updates
 * 2) stellar evolution
 * 3) binary evolution
 *
 * leaving the stars at time t+dt.
 *
 * Note: events such as supernovae are handled after the
 *       evolution has converged.
 */


/*
 * Scale the solver timesteps (model.dtm and model.dt) by the factor F
 * relative to the values saved on entry, and restore them again.
 *
 * Both macros capture the variables stardata, dtm_in and dt_in from
 * the calling scope (see evolution_solver_loop). They are wrapped in
 * do { } while(0) so that each expands to exactly one statement and
 * can safely be used in an unbraced if/else.
 */
#define Modulate_solver_timestep(F) do {                \
        stardata->model.dtm = dtm_in * (F);             \
        stardata->model.dt = dt_in * (F);               \
    } while(0)
#define Restore_solver_timestep do {            \
        stardata->model.dtm = dtm_in;           \
        stardata->model.dt = dt_in;             \
    } while(0)

/*
 * Take one evolution_step() for the current solver stage and yield its
 * return value. If the system is no longer evolving, or the timestep
 * has already been rejected, the step is skipped and the current value
 * of the caller's local variable retval is yielded instead.
 *
 * This is a GNU statement expression that captures stardata,
 * system_type and retval from the calling scope.
 *
 *   update_time_arg : passed to evolution_step() as update_time
 *                     (e.g. UPDATE_TIME or DO_NOT_UPDATE_TIME).
 */
#define Take_evolution_step(update_time_arg)                            \
    ({                                                                  \
        int evstep_retval;                                              \
        if(stardata->evolving == TRUE &&                                \
           stardata->model.reject == REJECT_NONE)                       \
        {                                                               \
            evstep_retval = evolution_step(stardata,                    \
                                           system_type,                 \
                                           (update_time_arg));          \
            eprint("retval from step %d\n",evstep_retval);              \
        }                                                               \
        else                                                            \
        {                                                               \
            evstep_retval = retval;                                     \
            eprint("rejected : retval from prev %d\n",evstep_retval);   \
        }                                                               \
        (evstep_retval);                                                \
    })

/*
 * Forward declarations of the file-local (static) functions defined
 * below. evolution_step() must be declared here because the
 * Take_evolution_step macro calls it from evolution_solver_loop(),
 * which is defined first.
 */
static int evolution_solver_loop(struct stardata_t * RESTRICT const stardata,
                                 Evolution_system_type system_type);
static int evolution_step(struct stardata_t * RESTRICT const stardata,
                          Evolution_system_type system_type,
                          Boolean update_time);

/************************************************************/

/*
 * Public entry point: advance the system by one timestep using the
 * solver selected in stardata->preferences->solver.
 *
 * All of the work is delegated to evolution_solver_loop(). Its result
 * is handed straight back to the caller: -STOP if the system has
 * finished evolving, otherwise the status of the last step taken.
 */
int evolution(struct stardata_t * RESTRICT const stardata,
              Evolution_system_type system_type)
{
    const int status = evolution_solver_loop(stardata, system_type);
    return status;
}

/************************************************************/

/*
 * Advance the system by one timestep using the solver selected in
 * stardata->preferences->solver:
 *
 *   SOLVER_FORWARD_EULER       : a single step over dt
 *   SOLVER_RK2                 : two successive steps over dt/2
 *   SOLVER_RK4                 : four stages, each using dt/2
 *   SOLVER_PREDICTOR_CORRECTOR : one predictor step over dt, followed
 *                                by maxk-1 corrector steps over dt/2
 *                                that do not advance the time
 *
 * On entry, model.dtm and model.dt are saved (dtm is also copied to
 * model.true_dtm). They are restored after the multi-stage solvers have
 * scaled them. After the solver has run, the time-explicit algorithms
 * and the rejection tests are applied.
 *
 * Returns -STOP if the system is not evolving, or its time is already
 * exhausted, on entry. Otherwise returns the value of the last step
 * actually taken: a stage skipped by Take_evolution_step keeps the
 * previous value.
 */
static int evolution_solver_loop(struct stardata_t * RESTRICT const stardata,
                                 Evolution_system_type system_type)
{
    if(stardata->evolving == FALSE ||
       check_for_time_exhaustion(stardata,
                                 stardata->model.intpol) == TRUE)
    {
        eprint("evolution finished at start : return -STOP\n");
        return -STOP;
    }

    /*
     * Save the timesteps. The multi-stage solvers below scale them with
     * Modulate_solver_timestep() and put them back with
     * Restore_solver_timestep; both macros read dtm_in and dt_in.
     *
     * Stellar mass and angular momentum changes (e.g. wind loss,
     * wind accretion, RLOF) are applied within each evolution_step().
     */
    const double dtm_in = stardata->model.dtm;
    const double dt_in = stardata->model.dt;
    stardata->model.true_dtm = dtm_in;

    /*
     * We're at the top of the evolution loop,
     * so cannot (yet) have rejected the timestep.
     */
    stardata->model.reject = REJECT_NONE;

    eprint("in evolution() model %d at t = %30.12e, dt = %g, dtm = %g : solver_step %d (intermediate? %d)\n",
           stardata->model.model_number,
           stardata->model.time,
           stardata->model.dt,
           stardata->model.dtm,
           stardata->model.solver_step,
           stardata->model.intermediate_step
        );

    /* read by Take_evolution_step when a stage is skipped */
    int retval = 0;

    if(stardata->preferences->solver == SOLVER_FORWARD_EULER)
    {
        /*
         * Forward Euler has only one step
         */
        stardata->model.solver_step = 0;
        stardata->model.intermediate_step = FALSE;
        retval = Take_evolution_step(UPDATE_TIME);
    }
    else if(stardata->preferences->solver == SOLVER_RK2)
    {
        /*
         * First step from t to t + dt/2. New events are denied
         * during this intermediate step.
         */
        eprint("\n\nRK2 step 1 : dtm = %g\n\n",stardata->model.dtm);
        Boolean deny_was = stardata->model.deny_new_events;
        Set_event_denial(TRUE);
        Modulate_solver_timestep(0.5);
        stardata->model.solver_step = 0;
        stardata->model.intermediate_step = TRUE;
        retval = Take_evolution_step(UPDATE_TIME);

        /*
         * RK2 second step: apply the derivatives at time t + dt/2
         * to obtain the stars at t + dt
         */
        eprint("\n\nRK2 step 2 (reject %u)\n\n",stardata->model.reject);
        stardata->model.solver_step = 1;
        stardata->model.intermediate_step = FALSE;
        Set_event_denial(deny_was);
        retval = Take_evolution_step(UPDATE_TIME);

        eprint("return %d\n",retval);

        Restore_solver_timestep;
    }
    else if(stardata->preferences->solver == SOLVER_RK4)
    {
        /*
         * The way the equations are written, we always use half the
         * timestep, not the full timestep
         */
        Modulate_solver_timestep(0.5);
        Boolean deny_was = stardata->model.deny_new_events;
        Set_event_denial(TRUE);
        eprint("*** RK4 start ***\n");
        /*
         * Calculate k1' :
         *    structure x = x(t)
         *    derivatives f at t
         *    apply for dt/2
         *    x -> x + dt/2 * f = x + k1/2 = x + k1' = x1'
         *    t -> t + dt/2
         */
        stardata->model.solver_step = 0;
        stardata->model.intermediate_step = TRUE;
        eprint("RK4(0) call evolution\n");
        retval = Take_evolution_step(UPDATE_TIME);
        eprint("RK4(0) retval %d\n",retval);

        /*
         * Calculate k2:
         *    structure x = x(t) + k1' (set above)
         *    derivatives at t + dt/2
         *    apply for time 0
         *    x1' = x + k1' -> x + k2/2 = x + k2' = x2'
         *    t + dt/2 -> t + dt/2
         */
        double thalf = stardata->model.time;
        stardata->model.solver_step = 1;
        retval = Take_evolution_step(DO_NOT_UPDATE_TIME);
        stardata->model.time = thalf;
        eprint("RK4(1) retval %d\n",retval);

        /*
         * Calculate k3:
         *    structure x = x(t) + k2' (set above)
         *    derivatives at t + dt/2 = thalf
         *    apply for time dt/2
         *    x2' = x + k2' -> x + k3 = x + 2*k3' = x3'
         *    t + dt/2 -> t + dt
         */
        stardata->model.solver_step = 2;
        retval = Take_evolution_step(UPDATE_TIME);
        eprint("RK4(2) retval %d\n",retval);

        /*
         * Calculate k4:
         *    structure x = x(t) + k3' (see above)
         *    derivatives at t + dt
         *    apply for time 0
         *    x + k3 = x + 2*k3' -> x + k4 = x + 2*k4'
         *    t + dt -> t + dt
         */
        Set_event_denial(deny_was);
        stardata->model.solver_step = 3;
        stardata->model.intermediate_step = FALSE;
        retval = Take_evolution_step(DO_NOT_UPDATE_TIME);
        eprint("RK4(3) retval %d\n",retval);

        /*
         * Restore the full timestep
         */
        Restore_solver_timestep;
    }
    else if(stardata->preferences->solver == SOLVER_PREDICTOR_CORRECTOR)
    {
        /*
         * PEC method
         * https://en.wikipedia.org/wiki/Predictor%E2%80%93corrector_method
         *
         * Predictor: one step over the full timestep, advancing the time.
         * Correctors: maxk-1 further steps over dt/2 that do not
         * advance the time.
         *
         * NOTE(review): unlike RK2/RK4, new events are not denied during
         * the intermediate steps here, and intermediate_step is never
         * set. Confirm that this is intended.
         */
        const int maxk = 10;

        /* label the predictor as step 0, as the other solvers do */
        stardata->model.solver_step = 0;
        retval = Take_evolution_step(UPDATE_TIME);
        Modulate_solver_timestep(0.5);

        for(int i=1; i<maxk; i++)
        {
            stardata->model.solver_step = i;
            eprint("PEC%d\n",i);
            retval = Take_evolution_step(DO_NOT_UPDATE_TIME);
        }
        Restore_solver_timestep;
    }
    else
    {
        retval = 0;
        Exit_binary_c(BINARY_C_ALGORITHM_OUT_OF_RANGE,
                      "Unknown solver %d\n",stardata->preferences->solver);
    }

    /* go back to step 0 */
    stardata->model.solver_step = 0;

    eprint("DT RESTORE -> dtm = %g, dt = %g\n",
           stardata->model.dtm,
           stardata->model.dt
        );

    if(Is_zero(stardata->model.dt) && Is_zero(stardata->model.time))
    {
        /* reset in case of repeat */
        stardata->common.test = 0.0;
    }

    /*
     * Linear integration test
     */
#ifdef LINEAR_INTEGRATION_TEST

    double absdiff = Abs_diff(stardata->model.time*1e6,
                              stardata->common.test);

    eprint("TEST check t = %30.12e test = %30.12e : diff %g\n",
           stardata->model.time*1e6,
           stardata->common.test,
           absdiff
        );

    if(absdiff > 1e-3)
    {
        /*
         * eprint() is normally compiled out, so report the failure
         * unconditionally. Exit with a non-zero status so that a
         * failed test cannot be mistaken for success.
         */
        eprint("integration failed\n");
        fprintf(stderr,
                "linear integration test failed at t = %30.12e : diff %g\n",
                stardata->model.time*1e6,
                absdiff);
        fflush(NULL);
        _exit(1);
    }
#endif

    /*
     * Time-explicit algorithms
     */
    evolution_time_explicit(stardata);

    eprint("Devolution retval %d\n",retval);

    /*
     * rejection tests
     */
    evolution_rejection_tests(stardata);

    return retval;
}



/*
 * Perform a single evolutionary step:
 *
 * 1) calculate and apply the derivatives of mass and orbital elements
 *    (mass_angmom_and_evolutionary_changes)
 * 2) if update_time is TRUE, advance the time from t to t+dt
 * 3) update the stellar structure (stellar_evolution)
 * 4) do binary stellar evolution (binary_star_evolution). This is
 *    called even when step 1 requested a stop.
 *
 * Returns -stop when mass_angmom_and_evolutionary_changes() returned a
 * non-zero value stop, and -STOP when the evolution time is exhausted.
 * Otherwise it returns the status from stellar_evolution(), which
 * binary_star_evolution() may also update through its pointer argument.
 */
static int evolution_step(struct stardata_t * RESTRICT const stardata,
                          Evolution_system_type system_type,
                          Boolean update_time)
{
    eprint("evolution step\n");

    /*
     * Calculate mass, angular momentum and evolutionary changes
     * over the timestep. This sets and applies the derivatives.
     */
    const int stop = mass_angmom_and_evolutionary_changes(stardata,system_type);
    int status = 0;

    if(stop == 0)
    {
        /*
         * When should we rejuv, when should we not?
         * He-star fail suggests we should, but
         * other systems suggest we should not.
         */
        struct model_t * const model = &stardata->model;

        if(system_type == GENERIC_SYSTEM_CALL)
        {
            /*
             * Code for a generic system
             */
            save_detached_stellar_types(stardata);
            model->tphys0 = model->time;
            model->dtm0 = model->dtm;

#ifdef BSE
            if(stardata->model.in_RLOF)
            {
                rejuvenate_MS_secondary_and_age_primary(stardata);
            }
#endif//BSE
        }
        else if(system_type == DETACHED_SYSTEM_CALL)
        {
            /*
             * Detached system
             */
            save_detached_stellar_types(stardata);
            if(model->intpol == 0)
            {
                model->tphys0 = model->time;
                model->dtm0 = model->dtm;
            }
        }
        else
        {
            /*
             * RLOFing system.
             *
             * Rejuvenate the secondary and
             * age the primary if they are on
             * the main sequence.
             */
#ifdef BSE
            rejuvenate_MS_secondary_and_age_primary(stardata);
#endif
        }

        /*
         * Set the time to t+dt
         */
        if(update_time == TRUE)
        {
            update_the_time(stardata);

            Dprint("Advance the time: intpol=%d time was %30.22e (dt=%30.22e dtm=%30.22e) now %30.22e (tphys0 %30.22e, maxt=%g)\n",
                   model->intpol,
                   model->time-model->dtm,
                   model->dt,
                   model->dtm,
                   model->time,
                   model->tphys0,
                   model->max_evolution_time);
        }
        else
        {
            Dprint("Do not advance the time (update_time == FALSE)\n");
        }

        Nancheck(model->dt);
        Nancheck(model->dtm);

        /*
         * Do stellar evolution over the timestep
         */
        status = stellar_evolution(stardata,system_type);
    }

    /*
     * Do binary stellar evolution over the timestep,
     * also checks for supernovae and updates the orbit
     * if required.
     */
    binary_star_evolution(stardata,&status);

    int retval;
    if(stop != 0)
    {
        /*
         * Return without updating the time:
         * there has probably been a common envelope
         * or a merger or something like that.
         */
        retval = -stop;
        eprint("retval forced to stop %d\n",retval);
    }
    else
    {
        /*
         * Update stellar spins (by conserving angular momenta)
         */
        calculate_spins(stardata);

        /*
         * Generic and detached systems share the same Roche-lobe
         * handling; RLOFing systems skip it.
         */
        if(system_type == GENERIC_SYSTEM_CALL ||
           system_type == DETACHED_SYSTEM_CALL)
        {
            if(stardata->model.sgl == FALSE)
            {
                /* in binaries, update the radius derivative */
                adjust_radius_derivative(stardata);
            }
            else
            {
                /* in single stars, make Roche lobes huge */
                make_roche_lobe_radii_huge(stardata);
            }

            if(system_type == DETACHED_SYSTEM_CALL &&
               stardata->model.supernova == TRUE)
            {
                stardata->model.dtm = 0.0;
            }
        }

        update_phase_start_times(stardata);

        Boolean exhausted = check_for_time_exhaustion(stardata,
                                                      stardata->model.intpol);

        retval = exhausted == TRUE ? -STOP : status;
        eprint("exhausted? (evolution) %d -> retval %d\n",
               exhausted,
               retval);
    }

    return retval;
}


