import math
import binarycpython.utils.useful_funcs

###
# File containing probability distributions
# Mostly copied from the perl modules

# TODO: cache some values in module-level globals, as Rob does in his Perl module? It would save recalculation, but it is not done yet.
# TODO: Add the stuff from the IMF file
# TODO: call all of these functions to check whether they work
# TODO: make global constants stuff
# TODO: make description of module submodule

# Multiplying a natural logarithm by this factor gives the base-10 logarithm:
# log10(x) = ln(x) / ln(10).
log_ln_converter = 1 / math.log(10)


def flat(parameter):
    """
    Dummy distribution function: ignores its argument and always returns 1.
    """
    del parameter  # accepted but never used
    return 1


def number(value):
    """
    Identity "distribution": hands back its input unchanged.
    """
    result = value
    return result


def powerlaw_constant(min_val, max_val, k):
    """
    Return the constant that normalises x**k to unit integral on
    [min_val, max_val].

    For k != -1 this is (k + 1) / (max_val**(k + 1) - min_val**(k + 1)).
    For k == -1 the integral of 1/x is ln(max_val / min_val), so the constant
    is 1 / ln(max_val / min_val).

    Args:
        min_val: lower bound of the range.
        max_val: upper bound of the range.
        k: power-law index.

    Returns:
        The normalisation constant (float).
    """
    if k == -1:
        # k + 1 would be zero here; use the logarithmic antiderivative instead.
        return 1.0 / math.log(max_val / min_val)

    k1 = k + 1.0
    return k1 / (max_val ** k1 - min_val ** k1)


def powerlaw(min_val, max_val, k, x):
    """
    Evaluate a single normalised power law with index k at x.

    The density is const * x**k on [min_val, max_val], with const from
    powerlaw_constant, and 0 outside that range.

    Args:
        min_val: lower bound of the range.
        max_val: upper bound of the range.
        k: power-law index; k == -1 is rejected.
        x: location at which to evaluate the density.

    Returns:
        The probability density at x, or 0 when x is out of bounds.

    Raises:
        ValueError: if k == -1.
    """
    # Handle faulty value: report it through the exception itself rather
    # than printing and raising a message-less ValueError.
    if k == -1:
        raise ValueError("wrong value for k: powerlaw does not support k == -1")

    if (x < min_val) or (x > max_val):
        print("input value is out of bounds!")
        return 0

    const = powerlaw_constant(min_val, max_val, k)
    return const * (x ** k)


def calculate_constants_three_part_powerlaw(m0, m1, m2, m_max, p1, p2, p3):
    """
    Compute the normalisation constants of a continuous three-part power law.

    The distribution is
        A0 * m**p1  for m0 < m <= m1
        A1 * m**p2  for m1 < m <= m2
        A2 * m**p3  for m2 < m <= m_max
    A0 and A2 are tied to A1 by continuity at m1 and m2. A1 is chosen so that
    the total integral over [m0, m_max] is 1. Any exponent may be -1, in which
    case that segment's integral is a logarithm.

    Returns:
        list [A0, A1, A2].
    """

    def _segment_integral(lower, upper, index):
        # Integral of m**index from lower to upper.
        if index == -1:
            return math.log(upper / lower)
        return (upper ** (1.0 + index) - lower ** (1.0 + index)) / (1.0 + index)

    # Continuity factors: A0 = A1 * c0 (at m1) and A2 = A1 * c2 (at m2).
    c0 = (m1 ** p2) * (m1 ** (-p1))
    c2 = (m2 ** p2) * (m2 ** (-p3))

    total = (
        c0 * _segment_integral(m0, m1, p1)
        + _segment_integral(m1, m2, p2)
        + c2 * _segment_integral(m2, m_max, p3)
    )
    # The tiny offset avoids a division by zero when the total integral is 0.
    a1 = 1.0 / (total + 1e-50)

    return [a1 * c0, a1, a1 * c2]


def three_part_powerlaw(M, M0, M1, M2, M_MAX, P1, P2, P3):
    """
    Generalized three-part power law, usually used for mass distributions.

    Returns the normalised probability density at M:
        A0 * M**P1  for M0 <= M <= M1
        A1 * M**P2  for M1 <  M <= M2
        A2 * M**P3  for M2 <  M <= M_MAX
        0           otherwise
    where A0, A1, A2 come from calculate_constants_three_part_powerlaw.
    """

    # TODO: add check on whether the values exist

    if M < M0:
        return 0  # Below M0

    # Select the segment first; M == M0 belongs to the first segment.
    if M <= M1:
        index, exponent = 0, P1  # Between M0 and M1
    elif M <= M2:
        index, exponent = 1, P2  # Between M1 and M2
    elif M <= M_MAX:
        index, exponent = 2, P3  # Between M2 and M_MAX
    else:
        return 0  # Above M_MAX

    # Only compute the normalisation once M is known to be inside the support.
    constants = calculate_constants_three_part_powerlaw(
        M0, M1, M2, M_MAX, P1, P2, P3
    )
    return constants[index] * (M ** exponent)


def const(min_bound, max_bound, val=None):
    """
    Constant (uniform) probability density on [min_bound, max_bound].

    Args:
        min_bound: lower edge of the range.
        max_bound: upper edge of the range.
        val: optional location. If given and outside
            [min_bound, max_bound], 0 is returned.

    Returns:
        1 / (max_bound - min_bound) inside the range, 0 outside it.
    """
    # Compare against None so that val == 0 is still bounds-checked.
    if val is not None and not min_bound <= val <= max_bound:
        print("out of bounds")
        return 0
    # Width is max - min; the reversed order gave a negative density.
    return 1.0 / (max_bound - min_bound)


def set_opts(opts, newopts):
    """
    Return a copy of the default dict ``opts`` with values overridden by
    ``newopts``.

    Only keys already present in ``opts`` are taken from ``newopts``; unknown
    keys are ignored. ``opts`` itself is not modified.

    Args:
        opts: dict of default values.
        newopts: dict of overriding values, or None/empty for no overrides.

    Returns:
        A new dict holding the merged values.
    """
    merged = dict(opts)
    if newopts:
        merged.update(
            {key: value for key, value in newopts.items() if key in opts}
        )
    return merged


def gaussian(x, mean, sigma, gmin, gmax):
    """
    Gaussian probability density truncated to [gmin, gmax] and renormalised
    over that range. Used e.g. for Duquennoy + Mayor 1991.

    Args:
        x: location at which to evaluate the density.
        mean: mean of the underlying Gaussian.
        sigma: standard deviation of the underlying Gaussian.
        gmin: lower edge of the truncation range.
        gmax: upper edge of the truncation range.

    Returns:
        The density at x, or 0 if x lies outside [gmin, gmax].
    """
    if (x < gmin) or (x > gmax):
        return 0

    # gaussian_normalizing_const returns (an approximation of) the integral of
    # gaussian_func over [gmin, gmax]. Dividing by it, not multiplying, makes
    # the truncated density integrate to 1.
    # TODO: cache the normalisation for repeated (mean, sigma, gmin, gmax).
    normalisation = gaussian_normalizing_const(mean, sigma, gmin, gmax)
    return gaussian_func(x, mean, sigma) / normalisation


def gaussian_normalizing_const(mean, sigma, gmin, gmax):
    """
    Return the integral over [gmin, gmax] of the normalised Gaussian with the
    given mean and sigma (the density evaluated by gaussian_func).

    Dividing gaussian_func by this value renormalises the Gaussian to unit
    area on the truncated range. The integral is evaluated exactly with the
    complementary error function instead of a 1000-step Riemann sum.
    sigma is expected to be positive.
    """
    scale = sigma * math.sqrt(2.0)
    z_min = (gmin - mean) / scale
    z_max = (gmax - mean) / scale

    # Phi(b) - Phi(a), written with erfc. When the whole range lies in the
    # upper tail, use 1 - Phi differences so that no precision is lost to
    # cancellation between two values close to 1.
    if z_min > 0:
        return 0.5 * (math.erfc(z_min) - math.erfc(z_max))
    return 0.5 * (math.erfc(-z_max) - math.erfc(-z_min))


def gaussian_func(x, mean, sigma):
    """
    Normal probability density with the given mean and standard deviation,
    evaluated at x.
    """
    inv_sigma = 1.0 / (sigma)
    prefactor = 1.0 / math.sqrt(2.0 * math.pi)
    standardised = (x - mean) * inv_sigma
    exponent = -0.5 * standardised ** 2
    return prefactor * inv_sigma * math.exp(exponent)


#####
# Mass distributions
#####


def Kroupa2001(m, newopts=None):
    """
    Probability density of the Kroupa (2001) IMF at mass m.

    Evaluated with three_part_powerlaw using the defaults
        m0=0.1, m1=0.5, m2=1, mmax=100, p1=-1.3, p2=-2.3, p3=-2.3
    Any of these can be overridden through the optional dict newopts.
    Keys other than these seven are ignored.
    """
    params = {
        "m0": 0.1,
        "m1": 0.5,
        "m2": 1,
        "mmax": 100,
        "p1": -1.3,
        "p2": -2.3,
        "p3": -2.3,
    }
    if newopts:
        params.update(newopts)

    ordered_keys = ("m0", "m1", "m2", "mmax", "p1", "p2", "p3")
    return three_part_powerlaw(m, *[params[key] for key in ordered_keys])


def ktg93(m, newopts=None):
    """
    Probability density of the KTG93 IMF (presumably Kroupa, Tout & Gilmore
    1993) at mass m.

    Evaluated with three_part_powerlaw using the defaults
        m0=0.1, m1=0.5, m2=1.0, mmax=80, p1=-1.3, p2=-2.2, p3=-2.7
    Any of these can be overridden through the optional dict newopts.
    Keys other than these seven are ignored.
    """
    # NOTE: the Perl original also returned an "uncertainties" hash (default,
    # low and high values per parameter) when called with m eq
    # 'uncertainties'. That mode is not ported.

    value_dict = {
        "m0": 0.1,
        "m1": 0.5,
        "m2": 1.0,
        "mmax": 80,
        "p1": -1.3,
        "p2": -2.2,
        "p3": -2.7,
    }
    if newopts:
        value_dict.update(newopts)

    # Pass every parameter in the order three_part_powerlaw expects.
    return three_part_powerlaw(
        m,
        value_dict["m0"],
        value_dict["m1"],
        value_dict["m2"],
        value_dict["mmax"],
        value_dict["p1"],
        value_dict["p2"],
        value_dict["p3"],
    )


# sub ktg93_lnspace
# {
#     # wrapper for KTG93 on a ln(m) grid
#     my $m=$_[0];
#     return ktg93(@_) * $m;
# }


def imf_tinsley1980(m):
    """
    Tinsley (1980) IMF at mass m: a three-part power law on 0.1-80 Msol with
    breaks at 2 and 10 Msol and slopes -2.0, -2.3 and -3.3.
    """
    mass_limits = (0.1, 2.0, 10.0, 80.0)
    slopes = (-2.0, -2.3, -3.3)
    return three_part_powerlaw(m, *(mass_limits + slopes))


def imf_scalo1986(m):
    """
    Scalo (1986) IMF at mass m: a three-part power law on 0.1-80 Msol with
    breaks at 1 and 2 Msol and slopes -2.35, -2.35 and -2.70.
    """
    mass_limits = (0.1, 1.0, 2.0, 80.0)
    slopes = (-2.35, -2.35, -2.70)
    return three_part_powerlaw(m, *(mass_limits + slopes))


def imf_scalo1998(m):
    """
    Scalo (1998) IMF at mass m: a three-part power law on 0.1-80 Msol with
    breaks at 1 and 10 Msol and slopes -1.2, -2.7 and -2.3.
    """
    mass_limits = (0.1, 1.0, 10.0, 80.0)
    slopes = (-1.2, -2.7, -2.3)
    return three_part_powerlaw(m, *(mass_limits + slopes))


def imf_chabrier2003(m):
    """
    IMF of Chabrier 2003 (PASP 115:763-795) as a probability density in m.

    Below 1 Msol the IMF is a lognormal in log10(m) (A=0.158, m_c=0.079,
    sigma=0.69). From 1 Msol upward it is the power law A * m**-1.3 with
    A=4.43e-2. The value is divided by m * ln(10), converting a density per
    log10(m) into a density per m, and by the normalisation constant
    0.1202462.

    Raises:
        ValueError: if m <= 0, where log10(m) is undefined.
    """
    chabrier_logmc = math.log10(0.079)
    chabrier_sigma2 = 0.69 * 0.69
    chabrier_a1 = 0.158
    chabrier_a2 = 4.43e-2
    chabrier_x = -1.3

    # m == 0 must be rejected too: it would otherwise divide by zero below.
    if m <= 0:
        raise ValueError(
            "below bounds: imf_chabrier2003 requires m > 0, got {}".format(m)
        )

    if m < 1.0:
        dm = math.log10(m) - chabrier_logmc
        p = chabrier_a1 * math.exp(-(dm ** 2) / (2.0 * chabrier_sigma2))
    else:
        p = chabrier_a2 * (m ** chabrier_x)

    return p / (0.1202462 * m * math.log(10))


########################################################################
# Binary fractions
########################################################################
def Arenou2010_binary_fraction(m):
    """
    Binary fraction as a function of primary mass m, from the Arenou (2010)
    fit: 0.8388 * tanh(0.688 * m + 0.079).

    Reference: GAIA-C2-SP-OPM-FA-054,
    www.rssd.esa.int/doc_fetch.php?id=2969346
    """
    argument = 0.688 * m + 0.079
    return 0.8388 * math.tanh(argument)


# print(Arenou2010_binary_fraction(0.4))


def raghavan2010_binary_fraction(m):
    """
    Fit to the Raghavan 2010 binary fraction as a function of spectral type
    (Fig 12). Valid for local stars (Z=Zsolar).

    The spectral type is converted to mass using the ZAMS effective
    temperatures from binary_c/BSE (at Z=0.02) and the "long_spectral_type"
    function of binary_c (based on Jaschek+Jaschek's Teff-spectral type
    table). Rob then fitted the result with two curves. The larger of the two
    is used, capped at 1.
    """
    shallow_fit = 5.12310e-01 * m ** 0.1 - 1.02070e-01
    steep_fit = 1.10450e00 * m ** 4.93670e-01 - 6.95630e-01
    fraction = max(shallow_fit, steep_fit)
    return min(1.0, fraction)


# print(raghavan2010_binary_fraction(2))

########################################################################
# Period distributions
########################################################################


def duquennoy1991(x):
    """
    Period distribution from Duquennoy + Mayor 1991.

    This is a Gaussian in x with mean 4.8 and sigma 2.3, truncated to
    -2 <= x <= 12.

    Input:
        x: logperiod (presumably log10 of the period in days; TODO confirm)
    """
    mean_logp, sigma_logp = 4.8, 2.3
    lower, upper = -2, 12
    return gaussian(x, mean_logp, sigma_logp, lower, upper)


def sana12(M1, M2, a, P, amin, amax, x0, x1, p):
    """
    Distribution of initial orbital periods as found by Sana et al. (2012).

    The distribution is flat in ln(a) (equivalently ln(P)) when either
      * the primary is less massive than 15 Msun (no O-stars), or
      * the mass ratio q = M2/M1 < 0.1.
    For all other binaries dp/dlogP ~ (logP)**p between x0 and x1
    (default p = -0.55). Below x0 and above x1 it is flat in log P, and it is
    continuous at both break points.

    Args:
        M1: primary mass.
        M2: secondary mass.
        a: separation; not used by the calculation, kept for interface
            compatibility.
        P: orbital period, in the same units that
            binarycpython.utils.useful_funcs.calc_period_from_sep returns
            (presumably days; TODO confirm).
        amin, amax: separation range of the distribution.
        x0, x1: log10 of the period break points (the original notes use
            x0 = 0.15 and x1 = 3.5). xmin < x0 and x1 < xmax are assumed.
        p: power-law index between x0 and x1.

    Returns:
        Probability density per unit ln(a). In the non-flat branch the
        factor 3/2 converts d/dlnP into d/dln(a) (P ~ a**1.5), and
        log_ln_converter converts d/dlogP into d/dlnP.

    Example args: 10, 5, None, P, amin, amax, 0.15, 3.5, -0.55
    """
    if (M1 < 15) or (M2 / M1 < 0.1):
        return 1.0 / (math.log(amax) - math.log(amin))

    calc_period_from_sep = binarycpython.utils.useful_funcs.calc_period_from_sep
    p1 = 1.0 + p

    # Notation from the binary_c LyX document for this distribution:
    # x = log10(P), xmin = log10(Pmin), xmax = log10(Pmax).
    x = log_ln_converter * math.log(P)
    xmin = log_ln_converter * math.log(calc_period_from_sep(M1, M2, amin))
    xmax = log_ln_converter * math.log(calc_period_from_sep(M1, M2, amax))

    # A1 normalises the piecewise density to unit integral over [xmin, xmax].
    # A0 and A2 make it continuous at x0 and x1.
    A1 = 1.0 / (
        x0 ** p * (x0 - xmin) + (x1 ** p1 - x0 ** p1) / p1 + x1 ** p * (xmax - x1)
    )
    A0 = A1 * x0 ** p
    A2 = A1 * x1 ** p

    if x < x0:
        return 3.0 / 2.0 * log_ln_converter * A0
    if x > x1:
        return 3.0 / 2.0 * log_ln_converter * A2
    return 3.0 / 2.0 * log_ln_converter * A1 * x ** p


# print(sana12(10, 2, 10, 100, 1, 1000, math.log(10), math.log(1000), 6))

########################################################################
# Mass ratio distributions
########################################################################


def flatsections(x, opts):
    """
    Piecewise-constant ("flat sections") probability density evaluated at x.

    opts is a list of dicts, each with keys "min", "max" and "height"
    describing one flat section. Several sections allow different heights in
    different ranges. The result is the height of the section containing x,
    divided by the total area sum((max - min) * height) of all sections. If x
    lies in more than one section, the last matching one wins. If x lies in
    none, the result is 0.

    Raises ZeroDivisionError when the total area is zero (e.g. empty opts).
    """
    area = 0
    density = 0

    for section in opts:
        upper, lower, height = section["max"], section["min"], section["height"]
        area += (upper - lower) * height
        if lower <= x <= upper:
            density = height

    return density * (1.0 / area)


# print(flatsections(1, [{'min': 0, 'max': 2, 'height': 3}]))
