"""
Module containing the predefined distribution functions

The user can use any of these distribution functions to
generate probability distributions for sampling populations

There are distributions for the following parameters:
    - mass
    - period
    - mass ratio
    - binary fraction

Tasks:
    - TODO: make some values globally available? Rob does this in his module; it saves recalculation, but it is not planned for now
    - TODO: make global constants stuff
    - TODO: add eccentricity distribution: thermal
    - TODO: Add SFH distributions depending on redshift
    - TODO: Add metallicity distributions depending on redshift
    - TODO: Add initial rotational velocity distributions
    - TODO: make an n-part power law that is general enough to replace the three-part and four-part versions
"""

import math
import numpy as np
from typing import Optional, Union
from binarycpython.utils.useful_funcs import calc_period_from_sep, calc_sep_from_period

###
# Probability distributions for population sampling, mostly ported from the
# Perl modules.

# Multiply a natural logarithm by this to obtain the base-10 logarithm.
LOG_LN_CONVERTER = 1.0 / math.log(10.0)

# Cache for precomputed distribution factors, keyed by distribution name
# (e.g. "Izzard2012", "poisson_cache").
distribution_constants = dict()


def prepare_dict(global_dict: dict, list_of_sub_keys: list) -> None:
    """
    Make sure a chain of nested sub-dictionaries exists inside global_dict.

    The dictionary stores values and factors for the distribution functions so that they
    don't have to be recalculated on every call. For list_of_sub_keys == [a, b] this
    guarantees that global_dict[a][b] exists afterwards, creating an empty dict at every
    level that is missing (or None). Levels that already exist are left untouched,
    including empty dicts, so references to them held elsewhere stay valid.

    Args:
        global_dict: globally accessible dictionary in which the factors are stored
        list_of_sub_keys: keys, outermost first, that must be present in global_dict
    """

    current_level = global_dict

    # Walk down one level per key. Only a missing/None entry is replaced: the previous
    # truthiness test also replaced existing-but-empty dicts (and other falsy values),
    # silently discarding them.
    for key in list_of_sub_keys:
        if current_level.get(key) is None:
            current_level[key] = {}
        current_level = current_level[key]


def set_opts(opts: dict, newopts: dict) -> dict:
    """
    Override the default values in opts with the values given in newopts.

    Only keys that already exist in opts are overridden; keys of newopts that opts does
    not know about are ignored. opts is modified in place and is also returned.

    # TODO: consider changing this to just a dict.update

    Args:
        opts: dictionary with default values
        newopts: dictionary with new values (None or empty means "no overrides")

    Returns:
        the updated opts dictionary
    """

    if not newopts:
        return opts

    for key, value in newopts.items():
        if key in opts:
            opts[key] = value

    return opts


def flat() -> float:
    """
    Flat (uniform) dummy distribution function.

    Returns:
        always 1.0
    """
    uniform_value = 1.0
    return uniform_value


def number(value: Union[int, float]) -> Union[int, float]:
    """
    Identity "distribution" function: hands back its argument unchanged.

    Args:
        value: the value to return

    Returns:
        value, unmodified
    """
    result = value
    return result


def const(
    min_bound: Union[int, float],
    max_bound: Union[int, float],
    val: Optional[float] = None,
) -> Union[int, float]:
    """
    Constant (uniform) distribution function between min=min_bound and max=max_bound.

    Args:
        min_bound: lower bound of the range (exclusive when val is given)
        max_bound: upper bound of the range (inclusive when val is given)
        val: optional position at which to evaluate the distribution. When given, it is
            checked to lie in (min_bound, max_bound]; if it does not, 0 is returned.

    Returns:
        1/(max_bound - min_bound), or 0 when val is given and lies outside
        (min_bound, max_bound]
    """

    # Compare against None explicitly: a truthiness test skipped the bounds check for
    # val == 0, which is a legitimate position to evaluate.
    if val is not None and not min_bound < val <= max_bound:
        print("out of bounds")
        return 0

    return 1.0 / (max_bound - min_bound)


def powerlaw_constant(
    min_val: Union[int, float], max_val: Union[int, float], k: Union[int, float]
) -> Union[int, float]:
    """
    Function that returns the constant to normalise a powerlaw x**k on [min_val, max_val]

    The constant C satisfies C * integral(x**k, min_val, max_val) == 1:
        C = (k + 1) / (max_val**(k + 1) - min_val**(k + 1))   for k != -1
        C = 1 / ln(max_val / min_val)                          for k == -1

    Args:
        min_val: lower bound of the range (must be > 0 when k == -1)
        max_val: upper bound of the range
        k: powerlaw slope

    Returns:
        constant to normalize the given powerlaw between the min_val and max_val range
    """

    if k == -1:
        # The general formula degenerates to 0/0 here; the integral of 1/x is a log.
        return 1.0 / math.log(max_val / min_val)

    k1 = k + 1.0
    powerlaw_const = k1 / (max_val ** k1 - min_val ** k1)
    return powerlaw_const


def powerlaw(
    min_val: Union[int, float],
    max_val: Union[int, float],
    k: Union[int, float],
    x: Union[int, float],
) -> Union[int, float]:
    """
    Normalised single power law with index k, evaluated at x, on [min_val, max_val].

    Args:
        min_val: lower bound of the powerlaw
        max_val: upper bound of the powerlaw
        k: slope of the power law (k == -1 is rejected)
        x: position at which we want to evaluate

    Returns:
        `probability` at x; 0 (with a printed warning) when x is outside [min_val, max_val]

    Raises:
        ValueError: if k == -1
    """

    # This function does not support the logarithmic (k == -1) case
    if k == -1:
        print("wrong value for k")
        raise ValueError

    if x < min_val or x > max_val:
        print("input value is out of bounds!")
        return 0

    return powerlaw_constant(min_val, max_val, k) * (x ** k)


def _powerlaw_segment_integral(
    lower: Union[int, float], upper: Union[int, float], slope: Union[int, float]
) -> float:
    """Return the integral of x**slope from lower to upper, including the slope == -1 case."""
    if slope == -1:
        return math.log(upper / lower)
    return (1.0 / (1.0 + slope)) * (upper ** (1.0 + slope) - lower ** (1.0 + slope))


def calculate_constants_three_part_powerlaw(
    m0: Union[int, float],
    m1: Union[int, float],
    m2: Union[int, float],
    m_max: Union[int, float],
    p1: Union[int, float],
    p2: Union[int, float],
    p3: Union[int, float],
) -> list:
    """
    Function to calculate the constants for a three-part powerlaw

    The distribution is A0 * m**p1 on (m0, m1], A1 * m**p2 on (m1, m2] and A2 * m**p3 on
    (m2, m_max]. The constants make the segments join continuously at m1 and m2 and
    make the total integral over [m0, m_max] equal to one. Slopes of -1 are supported
    (their segment integral is logarithmic).

    Args:
        m0: lower bound mass
        m1: second boundary, between the first slope and the second slope
        m2: third boundary, between the second slope and the third slope
        m_max: upper bound mass
        p1: first slope
        p2: second slope
        p3: third slope

    Returns:
        list [A0, A1, A2] of normalisation constants
    """

    # Continuity factors relative to the middle segment:
    # A0 = A1 * m1**(p2 - p1) and A2 = A1 * m2**(p2 - p3)
    continuity_low = (m1 ** p2) * (m1 ** (-p1))
    continuity_high = (m2 ** p2) * (m2 ** (-p3))

    # Total area in units of the middle-segment constant A1
    area = (
        continuity_low * _powerlaw_segment_integral(m0, m1, p1)
        + _powerlaw_segment_integral(m1, m2, p2)
        + continuity_high * _powerlaw_segment_integral(m2, m_max, p3)
    )

    # The tiny offset (kept from the Perl implementation) avoids a division by zero
    # for a degenerate, zero-area range.
    middle_const = 1.0 / (area + 1e-50)

    return [
        middle_const * continuity_low,
        middle_const,
        middle_const * continuity_high,
    ]


def three_part_powerlaw(
    m: Union[int, float],
    m0: Union[int, float],
    m1: Union[int, float],
    m2: Union[int, float],
    m_max: Union[int, float],
    p1: Union[int, float],
    p2: Union[int, float],
    p3: Union[int, float],
) -> Union[int, float]:
    """
    Generalized three-part power law, usually used for mass distributions

    The distribution is non-zero on [m0, m_max] (both ends inclusive). Previously
    m == m0 matched no segment and fell through to the "above m_max" branch, so it
    returned 0.

    Args:
        m: mass at which we want to evaluate the distribution.
        m0: lower bound mass
        m1: second boundary, between the first slope and the second slope
        m2: third boundary, between the second slope and the third slope
        m_max: upper bound mass
        p1: first slope
        p2: second slope
        p3: third slope

    Returns:
        'probability' at given mass m; 0 outside [m0, m_max]
    """

    # Written as a negated range test so that NaN masses also give 0
    if not m0 <= m <= m_max:
        return 0

    constants = calculate_constants_three_part_powerlaw(m0, m1, m2, m_max, p1, p2, p3)

    if m <= m1:
        return constants[0] * (m ** p1)  # Between m0 and m1
    if m <= m2:
        return constants[1] * (m ** p2)  # Between m1 and m2
    return constants[2] * (m ** p3)  # Between m2 and m_max


def gaussian_normalizing_const(
    mean: Union[int, float],
    sigma: Union[int, float],
    gmin: Union[int, float],
    gmax: Union[int, float],
) -> Union[int, float]:
    """
    Function to calculate the normalisation constant for the gaussian

    The returned constant C satisfies
        C * integral(gaussian_func(x, mean, sigma), gmin, gmax) == 1,
    so C * gaussian_func(x, mean, sigma) is a probability density on [gmin, gmax].
    That is how gaussian() uses it: as a multiplier. Previously the integral itself
    was returned, so the truncated Gaussian was scaled the wrong way. The integral
    is now evaluated exactly with the error function instead of a 1000-step Riemann sum.

    Args:
        mean: mean of the gaussian
        sigma: standard deviation of the gaussian
        gmin: lower bound of the range to calculate the probabilities in
        gmax: upper bound of the range to calculate the probabilities in

    Returns:
        normalisation constant for the gaussian distribution(mean, sigma) between gmin and gmax

    Raises:
        ValueError: if the gaussian has no probability mass in [gmin, gmax]
            (e.g. gmax <= gmin), so that no normalisation exists
    """

    # Probability mass of N(mean, sigma) inside [gmin, gmax]
    scale = sigma * math.sqrt(2.0)
    ptot = 0.5 * (math.erf((gmax - mean) / scale) - math.erf((gmin - mean) / scale))

    if not ptot > 0:
        raise ValueError(
            "gaussian_normalizing_const: gaussian(mean={}, sigma={}) has no probability "
            "mass in [{}, {}]".format(mean, sigma, gmin, gmax)
        )

    # TODO: Set value in global
    return 1.0 / ptot


def gaussian_func(
    x: Union[int, float], mean: Union[int, float], sigma: Union[int, float]
) -> Union[int, float]:
    """
    Evaluate an untruncated (unbounded) gaussian probability density at x.

    Args:
        x: location at which to evaluate the distribution
        mean: mean of the gaussian
        sigma: standard deviation of the gaussian

    Returns:
        exp(-(x - mean)**2 / (2 sigma**2)) / (sigma * sqrt(2 pi))
    """
    inv_sigma = 1.0 / (sigma)
    z_score = (x - mean) * inv_sigma
    amplitude = (1.0 / math.sqrt(2.0 * math.pi)) * inv_sigma
    return amplitude * math.exp(-0.5 * z_score ** 2)


def gaussian(
    x: Union[int, float],
    mean: Union[int, float],
    sigma: Union[int, float],
    gmin: Union[int, float],
    gmax: Union[int, float],
) -> Union[int, float]:
    """
    Gaussian distribution function truncated to [gmin, gmax], e.g. as used for
    Duquennoy + Mayor 1991.

    Args:
        x: location at which to evaluate the distribution
        mean: mean of the gaussian
        sigma: standard deviation of the gaussian
        gmin: lower bound of the range to calculate the probabilities in
        gmax: upper bound of the range to calculate the probabilities in

    Returns:
        0 outside [gmin, gmax]; otherwise gaussian_func(x, mean, sigma) multiplied by
        the factor gaussian_normalizing_const returns for [gmin, gmax]
    """

    if x < gmin or x > gmax:
        return 0

    # TODO: cache the normalisation in a global instead of recomputing it on each call
    scale_factor = gaussian_normalizing_const(mean, sigma, gmin, gmax)
    return scale_factor * gaussian_func(x, mean, sigma)


#####
# Mass distributions
#####


def Kroupa2001(m: Union[int, float], newopts: dict = None) -> Union[int, float]:
    """
    Kroupa (2001) IMF, evaluated through three_part_powerlaw.

    Default parameters (each can be overridden via newopts):
        {"m0": 0.1, "m1": 0.5, "m2": 1, "mmax": 100, "p1": -1.3, "p2": -2.3, "p3": -2.3}

    Args:
        m: mass to evaluate the distribution at
        newopts: optional dict whose entries replace the defaults above

    Returns:
        'probability' of distribution function evaluated at m
    """

    parameter_order = ("m0", "m1", "m2", "mmax", "p1", "p2", "p3")
    params = {
        "m0": 0.1,
        "m1": 0.5,
        "m2": 1,
        "mmax": 100,
        "p1": -1.3,
        "p2": -2.3,
        "p3": -2.3,
    }

    # Caller-supplied values take precedence over the defaults
    params.update(newopts or {})

    return three_part_powerlaw(m, *(params[name] for name in parameter_order))


def ktg93(m: Union[int, float], newopts: dict = None) -> Union[int, float]:
    """
    KTG93 IMF, evaluated through three_part_powerlaw.

    Default parameters (each can be overridden via newopts):
        {"m0": 0.1, "m1": 0.5, "m2": 1, "mmax": 80, "p1": -1.3, "p2": -2.2, "p3": -2.7}

    NOTE: the Perl version additionally answered m == 'uncertainties' with a table of
    parameter ranges; that mode has not been ported.

    Args:
        m: mass to evaluate the distribution at
        newopts: optional dict whose entries replace the defaults above

    Returns:
        'probability' of distribution function evaluated at m
    """

    params = dict(m0=0.1, m1=0.5, m2=1.0, mmax=80, p1=-1.3, p2=-2.2, p3=-2.7)
    if newopts:
        params.update(newopts)

    return three_part_powerlaw(
        m=m,
        m0=params["m0"],
        m1=params["m1"],
        m2=params["m2"],
        m_max=params["mmax"],
        p1=params["p1"],
        p2=params["p2"],
        p3=params["p3"],
    )


# sub ktg93_lnspace
# {
#     # wrapper for KTG93 on a ln(m) grid
#     my $m=$_[0];
#     return ktg93(@_) * $m;
# }


def imf_tinsley1980(m: Union[int, float]) -> Union[int, float]:
    """
    Tinsley 1980 IMF (defined up to 80 Msol), i.e.
    three_part_powerlaw(m, 0.1, 2.0, 10.0, 80.0, -2.0, -2.3, -3.3)

    Args:
        m: mass to evaluate the distribution at

    Returns:
        'probability' of distribution function evaluated at m
    """

    return three_part_powerlaw(
        m, m0=0.1, m1=2.0, m2=10.0, m_max=80.0, p1=-2.0, p2=-2.3, p3=-3.3
    )


def imf_scalo1986(m: Union[int, float]) -> Union[int, float]:
    """
    Scalo 1986 IMF (defined up to 80 Msol), i.e.
    three_part_powerlaw(m, 0.1, 1.0, 2.0, 80.0, -2.35, -2.35, -2.70)

    Args:
        m: mass to evaluate the distribution at

    Returns:
        'probability' of distribution function evaluated at m
    """
    return three_part_powerlaw(
        m, m0=0.1, m1=1.0, m2=2.0, m_max=80.0, p1=-2.35, p2=-2.35, p3=-2.70
    )


def imf_scalo1998(m: Union[int, float]) -> Union[int, float]:
    """
    Scalo 1998 IMF (defined up to 80 Msol), i.e.
    three_part_powerlaw(m, 0.1, 1.0, 10.0, 80.0, -1.2, -2.7, -2.3)

    Args:
        m: mass to evaluate the distribution at

    Returns:
        'probability' of distribution function evaluated at m
    """

    return three_part_powerlaw(
        m, m0=0.1, m1=1.0, m2=10.0, m_max=80.0, p1=-1.2, p2=-2.7, p3=-2.3
    )


def imf_chabrier2003(m: Union[int, float]) -> Union[int, float]:
    """
    Probability distribution function for IMF of Chabrier 2003 PASP 115:763-795

    Below 1 Msol this is a lognormal in log10(m) (characteristic mass 0.079, width 0.69,
    amplitude 0.158). From 1 Msol upwards it is a power law 4.43e-2 * m**-1.3. The result
    is divided by m * ln(10), which converts a per-log10(m) density into a per-m density,
    and by the constant 0.1202462 (presumably the normalisation over the intended mass
    range — confirm).

    Args:
        m: mass to evaluate the distribution at (must be > 0)

    Returns:
        'probability' of distribution function evaluated at m

    Raises:
        ValueError: if m <= 0
    """

    if m <= 0:
        raise ValueError("imf_chabrier2003: mass must be positive, got {}".format(m))

    chabrier_logmc = math.log10(0.079)
    chabrier_sigma2 = 0.69 * 0.69
    chabrier_a1 = 0.158
    chabrier_a2 = 4.43e-2
    chabrier_x = -1.3

    if m < 1.0:
        dm = math.log10(m) - chabrier_logmc
        prob = chabrier_a1 * math.exp(-(dm ** 2) / (2.0 * chabrier_sigma2))
    else:
        prob = chabrier_a2 * (m ** chabrier_x)

    return prob / (0.1202462 * m * math.log(10))


########################################################################
# Binary fractions
########################################################################


def Arenou2010_binary_fraction(m: Union[int, float]) -> Union[int, float]:
    """
    Binary fraction as a function of primary mass, from Arenou 2010:
    0.8388 * tanh(0.688 * m + 0.079)

    GAIA-C2-SP-OPM-FA-054
    www.rssd.esa.int/doc_fetch.php?id=2969346

    Args:
        m: mass to evaluate the distribution at

    Returns:
        binary fraction at m
    """

    tanh_argument = 0.688 * m + 0.079
    return 0.8388 * math.tanh(tanh_argument)


# print(Arenou2010_binary_fraction(0.4))


def raghavan2010_binary_fraction(m: Union[int, float]) -> Union[int, float]:
    """
    Fit to the Raghavan 2010 binary fraction as a function of
    spectral type (Fig 12). Valid for local stars (Z=Zsolar).

    The spectral type is converted to mass by use of the ZAMS
    effective temperatures from binary_c/BSE (at Z=0.02)
    and the new "long_spectral_type" function of binary_c
    (based on Jaschek+Jaschek's Teff-spectral type table).
    Rob then fitted the result.

    The fit is the larger of two power-law fits in m, capped at 1.

    Args:
        m: mass to evaluate the distribution at

    Returns:
        binary fraction at m
    """

    shallow_fit = (m ** 0.1) * (5.12310e-01) + (-1.02070e-01)
    steep_fit = (1.10450e00) * (m ** (4.93670e-01)) + (-6.95630e-01)
    return min(1.0, max(shallow_fit, steep_fit))


# print(raghavan2010_binary_fraction(2))

########################################################################
# Period distributions
########################################################################


def duquennoy1991(logper: Union[int, float]) -> Union[int, float]:
    """
    Period distribution from Duquennoy + Mayor 1991: a gaussian in log period with
    mean 4.8 and sigma 2.3, truncated to [-2, 12], i.e. gaussian(logper, 4.8, 2.3, -2, 12)

    Args:
        logper: logarithm of period to evaluate the distribution at

    Returns:
        'probability' at gaussian(logper, 4.8, 2.3, -2, 12)
    """
    return gaussian(logper, mean=4.8, sigma=2.3, gmin=-2, gmax=12)


def sana12(
    M1: Union[int, float],
    M2: Union[int, float],
    a: Union[int, float],
    P: Union[int, float],
    amin: Union[int, float],
    amax: Union[int, float],
    x0: Union[int, float],
    x1: Union[int, float],
    p: Union[int, float],
) -> Union[int, float]:
    """
    Distribution of initial orbital periods as found by Sana et al. (2012), with a
    flat-in-ln(a) fallback outside the regime that paper covers.

    * If M1 < 15 (no O-star primary) or the mass ratio q = M2/M1 < 0.1, the result is
      flat in ln(a): 1 / (ln(amax) - ln(amin)).
    * Otherwise dN/dlogP is piecewise in x = log10(P): constant below x0, proportional
      to x**p between x0 and x1, and constant above x1. The pieces join continuously
      and are normalised to 1 over [xmin, xmax], where xmin and xmax are log10 of the
      periods that calc_period_from_sep returns for amin and amax. The result is
      multiplied by 3/2 * LOG_LN_CONVERTER. Presumably this converts dN/dlog10(P) to
      dN/dln(a) using P ~ a**(3/2), so both branches share units — TODO confirm.

    The Perl code used x0 = 0.15 and x1 = 3.5; Sana's slope is p = -0.55.

    # TODO: Fix this function!
    NOTE(review):
      * the separation `a` is not used; only P enters the calculation.
      * values of a or P outside [amin, amax] are not given zero probability.
      * p == -1 makes p1 zero, so the normalisation divides by zero.
      * a non-positive x0 breaks x0 ** p (complex result or ZeroDivisionError).
      * units of P must match those of calc_period_from_sep — confirm.

    Args:
        M1: Mass of primary
        M2: Mass of secondary
        a: separation of binary (currently unused)
        P: period of binary
        amin: minimum separation of the distribution (lower bound of the range)
        amax: maximum separation of the distribution (upper bound of the range)
        x0: log10 of the period below which the distribution is flat
        x1: log10 of the period above which the distribution is flat
        p: slope of the distribution in log10(P) between x0 and x1

    Returns:
        'probability' of orbital period P given the other parameters
    """

    res = 0
    if (M1 < 15) or (M2 / M1 < 0.1):
        # Outside Sana's regime: flat in ln(a) over [amin, amax]
        res = 1.0 / (math.log(amax) - math.log(amin))
    else:
        p1 = 1.0 + p

        # For more details see the LyX document of binary_c for this distribution,
        # where the variables and normalizations are given.
        # Notation: x = log10(P), xmin = log10(Pmin), xmax = log10(Pmax), where Pmin and
        # Pmax are the periods corresponding to amin and amax.
        x = LOG_LN_CONVERTER * math.log(P)
        xmin = LOG_LN_CONVERTER * math.log(calc_period_from_sep(M1, M2, amin))
        xmax = LOG_LN_CONVERTER * math.log(calc_period_from_sep(M1, M2, amax))

        # print("M1 M2 amin amax P x xmin xmax")
        # print(M1, M2, amin, amax, P, x, xmin, xmax)
        # Perl defaults: my $x0 = 0.15;
        # my $x1 = 3.5;

        # A1 normalises the whole piecewise function over [xmin, xmax]:
        # flat part x0**p on [xmin, x0], power law on [x0, x1], flat part x1**p on [x1, xmax]
        A1 = 1.0 / (
            x0 ** p * (x0 - xmin) + (x1 ** p1 - x0 ** p1) / p1 + x1 ** p * (xmax - x1)
        )
        # Heights of the flat parts, chosen so the function is continuous at x0 and x1
        A0 = A1 * x0 ** p
        A2 = A1 * x1 ** p

        if x < x0:
            res = 3.0 / 2.0 * LOG_LN_CONVERTER * A0
        elif x > x1:
            res = 3.0 / 2.0 * LOG_LN_CONVERTER * A2
        else:
            res = 3.0 / 2.0 * LOG_LN_CONVERTER * A1 * x ** p

    return res


# print(sana12(10, 2, 10, 100, 1, 1000, math.log(10), math.log(1000), 6))


def interpolate_in_mass_izzard2012(
    M: Union[int, float],
    high: Union[int, float],
    low: Union[int, float],
    log_interpolation: bool = False,
) -> Union[int, float]:
    """
    Interpolate an Izzard 2012 fit parameter in mass.

    The fit parameters are anchored at M = 1.15 (value `low`) and M = 16.3 (value
    `high`). This returns the straight line through those two anchors evaluated at M,
    which extrapolates outside [1.15, 16.3]. The line is in M by default, or in
    log10(M) when log_interpolation is True.

    Args:
        M: mass (must be > 0 when log_interpolation is True)
        high: value of the parameter at M = 16.3
        low: value of the parameter at M = 1.15
        log_interpolation: interpolate in log10(M) instead of M. Defaults to False,
            the behaviour that was previously hard-coded.

    Returns:
        interpolated parameter value at M
    """

    low_mass = 1.15
    high_mass = 16.3

    if log_interpolation:
        return (high - low) / (math.log10(high_mass) - math.log10(low_mass)) * (
            math.log10(M) - math.log10(low_mass)
        ) + low

    return (high - low) / (high_mass - low_mass) * (M - low_mass) + low


def Izzard2012_period_distribution(
    P: Union[int, float],
    M1: Union[int, float],
    log10Pmin: Optional[Union[int, float]] = -1.0,
) -> Union[int, float]:
    """
    Period distribution which interpolates between Duquennoy and Mayor 1991 at low mass
    (G/K spectral type <~1.15Msun) and Sana et al 2012 at high mass (O spectral type
    >~16.3Msun).

    This gives dN/dlogP: DM/Raghavan's Gaussian in log10P at low mass, and Sana's power
    law (as a function of logP) at high mass. The distribution is normalised numerically
    to unity over log10Pmin <= log10(P) <= 10 and is zero outside that range. The
    normalisation constant is cached per (mass, log10Pmin) in
    distribution_constants["Izzard2012"].

    Args:
        P: period (must be > 0)
        M1: primary star mass; values outside [1.15, 16.3] are clamped to that range
        log10Pmin: minimum period in base log10. Values below -1 are raised to -1 and
            None means -1. The default of -1 is the value the previous implementation
            effectively always used.

    Returns:
        'probability' of interpolated distribution function at P and M1

    Raises:
        ValueError: if log10Pmin >= 10 (empty period range) or P <= 0
    """

    # The Perl original used `$log10Pmin //= -1.0` ("defined-or"). In Python `//=` is
    # floor division, which scrambled the caller's value. Handle "not given" explicitly,
    # then apply the lower limit of -1.
    if log10Pmin is None:
        log10Pmin = -1.0
    log10Pmin = max(-1.0, log10Pmin)
    if log10Pmin >= 10.0:
        raise ValueError(
            "Izzard2012_period_distribution: log10Pmin ({}) must be below 10".format(
                log10Pmin
            )
        )

    # Limit the mass used to the range covered by the fits
    M1 = max(1.15, min(16.3, M1))

    # Fit parameters for this (clamped) mass
    mu = interpolate_in_mass_izzard2012(M1, -17.8, 5.03)
    sigma = interpolate_in_mass_izzard2012(M1, 9.18, 2.28)
    K = interpolate_in_mass_izzard2012(M1, 6.93e-2, 0.0)
    nu = interpolate_in_mass_izzard2012(M1, 0.3, -1)

    def unnormalised(lP):
        # g goes to 0 for lP well below nu and to 1 well above it
        g = 1.0 / (1.0 + 1e-30 ** (lP - nu))
        lPmu = lP - mu
        return (math.exp(-lPmu * lPmu / (2.0 * sigma * sigma)) + K / max(0.1, lP)) * g

    # Normalisation per (mass, log10Pmin), computed once with a left Riemann sum over
    # log10P in [log10Pmin, 10). No recursion, so no placeholder value is needed.
    cache = distribution_constants.setdefault("Izzard2012", {}).setdefault(M1, {})
    if log10Pmin not in cache:
        n_steps = 200  # resolution of the normalisation integral
        dlP = (10.0 - log10Pmin) / n_steps
        integral = sum(
            dlP * unnormalised(log10Pmin + i * dlP) for i in range(n_steps)
        )
        cache[log10Pmin] = 1.0 / integral

    lP = math.log10(P)
    if (lP < log10Pmin) or (lP > 10.0):
        return 0

    return cache[log10Pmin] * unnormalised(lP)


########################################################################
# Mass ratio distributions
########################################################################


def flatsections(x: float, opts: list) -> Union[float, int]:
    """
    Piecewise-flat distribution made of one or more flat sections.

    Each section has value `height` on [min, max]. All sections are normalised together
    so that the total area sum((max - min) * height) equals one.

    Args:
        x: value at which to evaluate the distribution (e.g. a mass ratio)
        opts: list of sections, each a dict with keys "min" (lower bound), "max" (upper
            bound) and "height" (relative value on that section)

    Returns:
        normalised probability at x; 0 if x lies in no section. When sections overlap
        or share a boundary, the last matching section in opts wins.

    Raises:
        ZeroDivisionError: if the total area of the sections is zero (e.g. opts is empty)
    """

    total_area = 0
    height = 0

    for section in opts:
        total_area += (section["max"] - section["min"]) * section["height"]
        if section["min"] <= x <= section["max"]:
            height = section["height"]

    return height * (1.0 / total_area)


# print(flatsections(1, [{'min': 0, 'max': 2, 'height': 3}]))

########################################################################
# Eccentricity distributions
########################################################################

########################################################################
# Star formation histories
########################################################################


def cosmic_SFH_madau_dickinson2014(z):
    """
    Cosmic star formation history from Madau & Dickinson 2014
    (https://arxiv.org/pdf/1403.0007.pdf):
        0.015 (1 + z)**2.7 / (1 + ((1 + z) / 2.9)**5.6)

    Args:
        z: redshift

    Returns:
        Cosmic star formation rate in Solarmass year^-1 megaparsec^-3
    """

    one_plus_z = 1 + z
    return 0.015 * (one_plus_z ** 2.7) / (1 + ((one_plus_z / 2.9) ** 5.6))


########################################################################
# Metallicity distributions
########################################################################


########################################################################
# Moe & DiStefano 2017 functions
########################################################################

import py_rinterpolate

# Tasks
# TODO: check all the raise ValueErrors to make them more appropriate
# TODO: Put the json checking stuff in a different function
# TODO: make function to normalize dictionary of 1 layer deep
# TODO: make use of the @cached_property decorators to make use of cached calls
# TODO: Fix the automatic freeing functions in py_rinterpolate
# TODO: check very well which functions actually need to be part of the population object

# Module-level cache shared by the Moe & DiStefano 2017 functions
# (holds e.g. the multiplicity table and its interpolator)
Moecache = dict()


def poisson(lambda_val, n, nmax=None):
    """
    Poisson probability of n for mean lambda_val, optionally truncated and renormalised.

    If nmax is given (including 0), the distribution is truncated to n = 0..nmax and
    renormalised so those probabilities sum to one; n > nmax then has probability 0.
    Results are cached in distribution_constants["poisson_cache"].

    Args:
        lambda_val: mean of the Poisson distribution
        n: non-negative integer at which to evaluate (n can be zero)
        nmax: optional truncation; None means no truncation

    Returns:
        (truncated) Poisson probability of n
    """

    cache = distribution_constants.setdefault("poisson_cache", {})
    cachekey = "{} {} {}".format(lambda_val, n, nmax)

    # Membership test, so cached zeros are reused too
    if cachekey in cache:
        return cache[cachekey]

    if nmax is not None and n > nmax:
        # Outside the support of the truncated distribution
        p_val = 0.0
    else:
        p_val = _poisson(lambda_val, n)
        if nmax is not None:
            p_val = p_val / sum(_poisson(lambda_val, i) for i in range(nmax + 1))

    cache[cachekey] = p_val
    return p_val


def _poisson(lambda_val, n):
    """
    Function to return the poisson value
    """

    return (lambda_val ** n) * np.exp(-lambda_val) / (1.0 * math.factorial(n))


def Moe_de_Stefano_2017_multiplicity_fractions(options):
    """
    Multiplicity fractions for a given primary mass, from the Moe & Di Stefano 2017 data.

    Moecache["multiplicity_table"] must already be set. On first use an interpolator in
    log10(M1) is built from it and cached in Moecache["rinterpolator_multiplicity"].

    Args:
        options: dict with (at least) the keys
            "M1": primary mass,
            "multiplicity_model": "Poisson" or "data" (see below),
            "multiplicity_modulator": 4 factors for singles, binaries, triples and
                quadruples. A factor <= 0 zeroes that class in the "Poisson" model.
                "merge" normalisation uses the highest class with a factor > 0 as the
                maximum multiplicity.
            "normalize_multiplicities": "raw", "norm" or "merge" (see below).

    Models:
        "Poisson": interpolate the multiplicity (first interpolated value). The fractions
            come from a Poisson distribution in n = 0..3, truncated at 3, times the
            modulator. This is more accurate than interpolating the fractions directly.
        "data": use the fractions (interpolated values 1..3) directly, times the modulator.
            Quadruples are combined with triples in the data, so the quadruple fraction is 0.

    Normalisations:
        "raw": leave the fractions as computed.
        "norm": rescale so the fractions sum to 1.
        "merge": fold classes above the maximum multiplicity into the highest allowed
            class, then rescale so the fractions sum to 1.

    Returns:
        dict mapping 0, 1, 2, 3 (singles, binaries, triples, quadruples) to fractions.
        The dict is empty for an unrecognised multiplicity_model.
    """

    # TODO: add extrapolation; the table stops at log10(1.6e1), we can probably go a bit further

    result = {}

    # Set up the multiplicity interpolator once and keep it in the cache
    if not Moecache.get("rinterpolator_multiplicity", None):
        Moecache["rinterpolator_multiplicity"] = py_rinterpolate.Rinterpolate(
            table=Moecache["multiplicity_table"],  # Contains the table of data
            nparams=1,  # logM1
            ndata=4,  # The amount of datapoints (the parameters that we want to interpolate)
        )

    modulator = options["multiplicity_modulator"]

    if options["multiplicity_model"] == "Poisson":
        multiplicity = Moecache["rinterpolator_multiplicity"].interpolate(
            [np.log10(options["M1"])]
        )[0]

        for n in range(4):
            result[n] = (
                modulator[n] * poisson(multiplicity, n, 3) if modulator[n] > 0 else 0
            )

    elif options["multiplicity_model"] == "data":
        # Interpolate once (previously repeated for every class)
        interpolated = Moecache["rinterpolator_multiplicity"].interpolate(
            [np.log10(options["M1"])]
        )
        for n in range(3):
            result[n] = interpolated[n + 1] * modulator[n]
        result[3] = 0.0  # no quadruples

    normalisation = options["normalize_multiplicities"]

    if normalisation == "merge":
        # Highest multiplicity (1 = singles ... 4 = quadruples) that is switched on.
        # (A typo previously left this variable stale or undefined.)
        multiplicity = 0
        for n in range(4):
            if modulator[n] > 0:
                multiplicity = n + 1

        if multiplicity == 1:
            # Singles only: merge everything into singles. After the rescaling below
            # result[0] is 1 (de facto the same as "norm" with only singles left).
            result[0] = 1.0
            for n in range(1, 4):
                result[0] += result[n]
                result[n] = 0

        elif multiplicity == 2:
            # Singles and binaries only: triples and quads are treated as binaries
            for n in (2, 3):
                result[1] += result[n]
                result[n] = 0

        elif multiplicity == 3:
            # Singles, binaries and triples: quads are treated as triples
            result[2] += result[3]
            result[3] = 0

        elif multiplicity == 4:
            # All classes allowed: nothing to merge
            pass

        else:
            print(
                "Error: in the Moe distribution, we seem to want no stars at all (multiplicity == {}).\n".format(
                    multiplicity
                )
            )

    if normalisation in ("norm", "merge"):
        # Rescale so that all multiplicity fractions add up to 1
        total = sum(result.values())
        for key in result:
            result[key] = result[key] / total

    # "raw" (or any other value): use the predictions as they are
    return result

# @profile
def build_q_table(options, m, p):
    """
    Build, or reuse from Moecache, an interpolation table of dN/dq for one
    mass and orbital period.

    Moe's mass-ratio data are tabulated at q = 0.15, 0.25, ..., 0.95 (i.e. the
    bins 0.1-0.2, ..., 0.9-1) as a function of log10(mass) and log10(period),
    and are looked up through ``Moecache["rinterpolator_q"]``. That table is
    independent of any choice we make (e.g. our minimum mass implies a qmin
    that need not match the table), so here it is trimmed to [qmin, qmax],
    optionally extrapolated beyond its ends, and renormalized so that it
    integrates to 1 over 0 <= q <= 1.

    Args:
        options: dictionary of system properties. ``options[m]`` is the mass
            and ``options[p]`` the orbital period used for the look-up. Also
            reads ``options["ranges"]["M"][0]`` (used as qmin) and the optional
            keys ``"q_low_extrapolation_method"`` and
            ``"q_high_extrapolation_method"`` (None, "flat", "linear", "plaw2"
            or "nolowq").
        m: label of the mass in options, e.g. "M1" or "M1+M2".
        p: label of the period in options, e.g. "P" or "P2".

    Returns:
        None. The new interpolator (one parameter q, one data column dN/dq) is
        stored in ``Moecache["rinterpolator_q_given_<m>_log10<p>"]``. The mass
        and period it was built for are stored in
        ``Moecache["rinterpolator_q_metadata"]``, so a later call with the same
        values returns immediately.

    Raises:
        ValueError: if no q values could be tabulated, an unknown extrapolation
            method is requested, or interpolating the new table fails.
    """
    interpolator_key = "rinterpolator_q_given_{}_log10{}".format(m, p)

    # Reuse the cached table if it was built for exactly this mass and period.
    # The metadata only holds the labels of earlier calls (e.g. "M1" and "P"
    # but not yet "M1+M2"), so test membership before comparing values.
    metadata = Moecache.get("rinterpolator_q_metadata", None)
    if (
        metadata
        and interpolator_key in Moecache
        and m in metadata
        and p in metadata
        and metadata[m] == options[m]
        and metadata[p] == options[p]
    ):
        return

    # qmin is set by the minimum stellar mass: below this the companions are
    # planets.
    # TODO: this lower range must not be lower than Mmin. However, since q_min
    #       is used as a sample range for M2, it should use this value I think.
    #       Discuss.
    qmin = options["ranges"]["M"][0]

    # TODO: qmax should be maximum_mass_ratio_for_RLOF(options[m], options[p])
    #       (probably metallicity dependent too); for now allow q up to 1.
    qmax = 1

    qeps = 1e-8  # small number but such that qeps + 1 != 1
    if qeps + 1 == 1.0:
        print("qeps (= {}) +1 == 1. Make qeps larger".format(qeps))

    def moe_dn_dq(q):
        """Interpolate Moe's raw dN/dq at mass ratio q for this mass and period."""
        return Moecache["rinterpolator_q"].interpolate(
            [np.log10(options[m]), np.log10(options[p]), q]
        )[0]

    # qdata maps q -> dN/dq: the table we trim and extend
    qdata = {}

    if qmin >= qmax:
        # There may be NO binaries in this part of the parameter space: set up
        # a table that is zero everywhere (and cannot be renormalized).
        qdata = {0: 0, 1: 0}
        can_renormalize = False
    else:
        # qmin < qmax, so we will get something non-zero
        can_renormalize = True

        # whether we (may) need to extrapolate at the low and high ends
        require_extrapolation = {}

        if qmin >= 0.15:
            # qmin is inside Moe's table: keep the points from qmin at the low
            # end to qmax at the high end.
            require_extrapolation["low"] = False
            # TODO: shouldn't the extrapolation only happen if qmax > 0.95?
            require_extrapolation["high"] = True

            # the raw table has three parameters, so pass qmin as the q value
            qdata[qmin] = moe_dn_dq(qmin)
            for q in np.arange(0.15, 0.950001, 0.1):
                if qmin <= q <= qmax:
                    qdata[q] = moe_dn_dq(q)
        else:
            require_extrapolation["low"] = True
            require_extrapolation["high"] = True

            if qmax < 0.15:
                # qmax < 0.15 is off the edge of the table: tabulate two points
                # at q = 0.15 and 0.16 in case we want to extrapolate.
                for q in [0.15, 0.16]:
                    qdata[q] = moe_dn_dq(q)
            else:
                # qmin < 0.15 <= qmax: tabulate Moe's data from q = 0.15
                # (i.e. 0.1 to 0.2) up to min(0.95, qmax).
                for q in np.arange(0.15, min(0.950001, qmax + 0.0001), 0.1):
                    qdata[q] = moe_dn_dq(q)

            # just below qmin, we want nothing
            # NOTE(review): qmin < 0.15 on this branch, so this condition can
            # never be true; kept as in the original - confirm the intent.
            if qmin - 0.15 > qeps:
                qdata[qmin - qeps] = 0
                require_extrapolation["low"] = False

        # just above qmax, if qmax < 0.95, we want nothing
        if qmax < 0.95:
            qdata[qmax + qeps] = 0
            require_extrapolation["high"] = False

        # sorted list of the tabulated qs (fixed before any extrapolation)
        qs = sorted(qdata)

        if not qs:
            print("No qs found error")
            raise ValueError("No qs found error")

        if len(qs) == 1:
            # Only one q value: pretend there are two, with a flat
            # distribution up to q = 1.
            if qs[0] == 1.0:
                qdata[1.0 - 1e-6] = 1
                qdata[1.0] = 1
            else:
                qdata[1.0] = qdata[qs[0]]
        else:
            for pre in ["low", "high"]:
                if not require_extrapolation[pre]:
                    continue
                _extrapolate_q_edge(
                    qdata,
                    qs,
                    pre,
                    qmin if pre == "low" else qmax,
                    options.get("q_{}_extrapolation_method".format(pre), None),
                    qeps,
                )

    # regenerate the table, sorted in q, as [q, dN/dq] rows
    tmp_table = [[q, qdata[q]] for q in sorted(qdata)]

    # Free the interpolator previously built for these labels (if any) before
    # replacing it; it is not used after destroy().
    old_interpolator = Moecache.pop(interpolator_key, None)
    if old_interpolator:
        old_interpolator.destroy()

    # Make an interpolation table to contain our modified data
    q_interpolator = py_rinterpolate.Rinterpolate(
        table=tmp_table,  # rows of [q, dN/dq]
        nparams=1,  # q
        ndata=1,  # dN/dq
    )
    # TODO: build a check in here to see if the interpolator build was successful

    if can_renormalize:
        # now we integrate and renormalize (if the table is not all zero)
        dq = 1e-3  # resolution of the integration/renormalization
        q_grid = np.arange(0, 1 + 2e-6, dq)

        # integrate: the value of the integral is only meaningful up to a
        # factor which depends on dq
        integral = 0
        for q in q_grid:
            x = q_interpolator.interpolate([q])
            if len(x) == 0:
                print("Q interpolator table interpolation failed")
                print("tmp_table = {}".format(str(tmp_table)))
                print("q_data = {}".format(str(qdata)))
                raise ValueError("Q interpolator table interpolation failed")
            integral += x[0] * dq

        if integral > 0:
            # normalize to 1.0 by multiplying the data column by 1.0/integral
            q_interpolator.multiply_table_column(1, 1.0 / integral)

            # check the normalization: the error should be ~ machine precision
            check = sum(q_interpolator.interpolate([q])[0] * dq for q in q_grid)
            if abs(1.0 - check) > 1e-6:
                print("Error: > 1e-6 in q probability integral: {}".format(check))

    # set this new table, and the values it was built for, in the cache
    Moecache[interpolator_key] = q_interpolator
    if not Moecache.get("rinterpolator_q_metadata", None):
        Moecache["rinterpolator_q_metadata"] = {}
    Moecache["rinterpolator_q_metadata"][m] = options[m]
    Moecache["rinterpolator_q_metadata"][p] = options[p]


def _extrapolate_q_edge(qdata, qs, pre, qlimit, method, qeps):
    """
    Extend the q table beyond one end of its tabulated points.

    Args:
        qdata: dict mapping q -> dN/dq; modified in place.
        qs: sorted list of the tabulated q values (at least two entries).
        pre: "low" or "high": which end of the table to extend.
        qlimit: the q at which the distribution ends (qmin or qmax).
        method: None (only truncate), "flat", "linear", "plaw2" or "nolowq".
            "plaw2" and "nolowq" always add a point at q = 0.05, whichever
            end is being extended.
        qeps: small offset used to truncate the distribution just past qlimit.

    Raises:
        ValueError: if method is not one of the above.
    """
    sign = -1 if pre == "low" else 1

    # the two tabulated points nearest this end, outermost first
    i0, i1 = (0, 1) if pre == "low" else (len(qs) - 1, len(qs) - 2)
    q0, q1 = qs[i0], qs[i1]

    # truncate the distribution just beyond qlimit (kept within 0 <= q <= 1)
    qdata[max(0.0, min(1.0, qlimit + sign * qeps))] = 0

    if method is None:
        # no extrapolation: just interpolate between the tabulated points
        return

    if method == "flat":
        # use the end value and extrapolate it with zero slope
        qdata[qlimit] = qdata[q0]
    elif method == "linear":
        # linear extrapolation through the two outermost points
        dq = q1 - q0
        if dq == 0:
            qdata[qlimit] = qdata[q0]
        else:
            slope = (qdata[q1] - qdata[q0]) / dq
            intercept = qdata[q0] - slope * q0
            # a straight line can go negative: clamp at zero
            qdata[qlimit] = max(0.0, slope * qlimit + intercept)
    elif method == "plaw2":
        # power-law extrapolation down to q = 0.05, if possible
        newq = 0.05
        if qdata[q0] <= 0 or qdata[q1] <= 0:
            # a power law cannot pass through a non-positive point
            qdata[newq] = 0
        else:
            slope = (np.log10(qdata[q1]) - np.log10(qdata[q0])) / (
                np.log10(q1) - np.log10(q0)
            )
            intercept = np.log10(qdata[q0]) - slope * np.log10(q0)
            # the fit is log10(dN/dq) = slope * log10(q) + intercept
            qdata[newq] = 10.0 ** (slope * np.log10(newq) + intercept)
    elif method == "nolowq":
        qdata[0.05] = 0
    else:
        # TODO: consider porting the Perl "(log)?polyN" method, which fitted a
        #       degree-N polynomial in q (or log10 q) to the table with
        #       Algorithm::CurveFit and evaluated it at q = 0.05 (it was marked
        #       "NOT WORKING / TESTED" there).
        print("No other methods available")
        raise ValueError("Unknown q extrapolation method: {}".format(method))


def Moe_de_Stefano_2017_pdf(options):
    """
    Moe and de Stefano probability density function for a stellar system.

    The density is the product of the following factors:
      - the multiplicity fraction;
      - the primary's IMF, as dN/dlnM1;
      - for multiple systems, the period and mass-ratio distributions of the
        inner binary;
      - for triples and quadruples, the period and mass-ratio distributions of
        the outer orbit.

    Args:
        options: dictionary describing the system:
            M1, M2, M3, M4 => masses (Msun) [M1 required, rest optional]
            P, P2, P3 => periods (days) [number: none=binary, 2=triple,
                3=quadruple]
            ecc, ecc2, ecc3 => eccentricities [numbering as for P above];
                ecc is taken as 0 (circular) when not given
            mmin => minimum allowed stellar mass (default 0.07)
            mmax => maximum allowed stellar mass (default 80.0)
            multiplicity => number of stars; if absent or falsy, it is 1 plus
                the number of companion masses M2, M3, M4 that are set
            multiplicity_modulator => sequence indexed by multiplicity - 1;
                a zero entry makes the density 0 for that multiplicity
            The keys "sep", "M1+M2" and (for hierarchical systems) "sep2" are
            written back into options.

    Returns:
        The product of all probability (density) factors, or 0 if the
        multiplicity modulator for this multiplicity is 0.

    Raises:
        ValueError: if the multiplicity is not 1, 2, 3 or 4.
    """
    prob = []  # individual probability (density) factors, multiplied at the end

    # use the given multiplicity, or count the companion masses that are set
    multiplicity = options.get("multiplicity", None)
    if not multiplicity:
        multiplicity = 1 + sum(
            1 for n in range(2, 5) if options.get("M{}".format(n), None)
        )

    # validate before the multiplicity is used as an index below
    if multiplicity not in range(1, 5):
        print("Unknown multiplicity {}\n".format(multiplicity))
        raise ValueError("Unknown multiplicity {}".format(multiplicity))

    # immediately return 0 if the multiplicity modulator is 0
    if options["multiplicity_modulator"][multiplicity - 1] == 0:
        print("_pdf ret 0 because of mult mod\n")
        return 0

    ############################################################
    # multiplicity fraction
    prob.append(Moe_de_Stefano_2017_multiplicity_fractions(options)[multiplicity - 1])

    ############################################################
    # always require an IMF for the primary star
    #
    # NB multiply by M1 to convert dN/dM to dN/dlnM
    # (dlnM = dM/M, so 1/dlnM = M/dM)
    # TODO: Create an n-part-powerlaw method that can have breakpoints and
    #       slopes. I'm using a three-part powerlaw now.
    prob.append(Kroupa2001(options["M1"]) * options["M1"])

    if multiplicity >= 2:
        # separation and total mass of the inner binary
        options["sep"] = calc_sep_from_period(
            options["M1"], options["M2"], options["P"]
        )
        options["M1+M2"] = options["M1"] + options["M2"]

        ############################################################
        # orbital period of the inner binary, from Moe's period table
        # interpolated at (log10 M1, log10 P)
        # TODO: test that the period distribution integrates to 1.0 (the Perl
        #       version had a disabled check caching this in "P_integrals")
        period_interpolator = _moe_interpolator(
            "rinterpolator_log10P", "period_distributions", 2, 2
        )
        prob.append(
            period_interpolator.interpolate(
                [np.log10(options["M1"]), np.log10(options["P"])]
            )[0]
        )

        ############################################################
        # mass ratio (0 < q = M2/M1 < qmax), from the q table built for
        # mass M1 and period P
        prob.append(
            _moe_q_probability(options, "M1", "P", options["M2"] / options["M1"])
        )

        # TODO add eccentricity calculations

        if multiplicity >= 3:
            ############################################################
            # orbital period 2 =
            #     orbital period of star 3 (multiplicity==3) or
            #     the star3+star4 binary (multiplicity==4)
            #
            # We assume the same period distribution for star 3 (or stars 3
            # and 4), but with a separation > 10*a*(1+e), where 10*a*(1+e) is
            # ten times the maximum apastron separation of stars 1 and 2.
            # TODO: Is this a correct assumption?
            min_sep2 = 10.0 * options["sep"] * (1.0 + options.get("ecc", 0.0))
            min_P2 = calc_period_from_sep(
                options["M1+M2"], options.get("mmin", 0.07), min_sep2
            )

            if options["P2"] < min_P2:
                # period is too short: system is not hierarchical
                prob.append(0)
            else:
                # period is long enough that the system is hierarchical, hence
                # the separation between the outer star and inner binary
                options["sep2"] = calc_sep_from_period(
                    options["M3"], options["M1+M2"], options["P2"]
                )

                prob.append(_moe_outer_period_probability(options, min_P2))

                ############################################################
                # mass ratio 2 = q2 = M3 / (M1+M2), from the q table built
                # for mass M1+M2 and period P2
                # TODO: Check if this is correct.
                prob.append(
                    _moe_q_probability(
                        options, "M1+M2", "P2", options["M3"] / options["M1+M2"]
                    )
                )

                # TODO: ecc2

                if multiplicity == 4:
                    # quadruple system.
                    # TODO: Ask Rob about the structure of the quadruple. Is
                    #       this only double binary quadruples?
                    # NOTE(review): placeholder. The intended treatment is an
                    # orbital period 3 for star 4 with
                    # sep3 < 0.2 * sep2 * (1 + ecc2), plus a mass ratio
                    # involving M4. Neither is implemented: as in the first
                    # draft, the P2 and q2 factors are applied a second time.
                    prob.append(_moe_outer_period_probability(options, min_P2))
                    prob.append(
                        _moe_q_probability(
                            options, "M1+M2", "P2", options["M3"] / options["M1+M2"]
                        )
                    )

                    # TODO: ecc3

    # the final probability density is the product of all the
    # probability density functions
    # TODO: Where do we take into account the stepsize? probdens is not a
    #       probability if we dont take a stepsize
    prob_dens = np.prod(prob)

    print_info = 1
    if print_info:
        if multiplicity == 1:
            print(
                "M1={} q=N/A log10P=N/A ({}): {} -> {}\n".format(
                    options["M1"], len(prob), str(prob), prob_dens
                )
            )
        elif multiplicity == 2:
            print(
                "M1={} q={} log10P={} ({}): {} -> {}\n".format(
                    options["M1"],
                    options["M2"] / options["M1"],
                    np.log10(options["P"]),
                    len(prob),
                    str(prob),
                    prob_dens,
                )
            )
        elif multiplicity == 3:
            print(
                "M1={} q={} log10P={} ({}): M3={} P2={} ecc2={} : {} -> {}".format(
                    options["M1"],
                    options["M2"] / options["M1"],
                    np.log10(options["P"]),
                    len(prob),
                    options["M3"],
                    options["P2"],
                    options.get("ecc2", None),
                    str(prob),
                    prob_dens,
                )
            )
        else:
            print(
                "M1={} q={} log10P={} ({}) : M3={} P2={} ecc2={} : M4={} P3={} ecc3={} : {} -> {}".format(
                    options["M1"],
                    options["M2"] / options["M1"],
                    np.log10(options["P"]),
                    len(prob),
                    options["M3"],
                    options["P2"],
                    options.get("ecc2", None),
                    options.get("M4", None),
                    options.get("P3", None),
                    options.get("ecc3", None),
                    str(prob),
                    prob_dens,
                )
            )
    return prob_dens


def _moe_interpolator(cache_key, table_key, nparams, ndata):
    """
    Return the interpolator Moecache[cache_key], building it first from the
    table Moecache[table_key] if it is not cached yet.
    """
    if not Moecache.get(cache_key, None):
        Moecache[cache_key] = py_rinterpolate.Rinterpolate(
            table=Moecache[table_key],  # Contains the table of data
            nparams=nparams,
            ndata=ndata,
        )
    return Moecache[cache_key]


def _moe_q_probability(options, m_label, p_label, q):
    """
    Return dN/dq at mass ratio q, taken from the q table that build_q_table
    makes for the mass options[m_label] and period options[p_label].
    """
    # Moe's raw q table: parameters log10(mass), log10(period), q
    _moe_interpolator("rinterpolator_q", "q_distributions", 3, 1)
    build_q_table(options, m_label, p_label)
    return Moecache[
        "rinterpolator_q_given_{}_log10{}".format(m_label, p_label)
    ].interpolate([q])[0]


def _moe_outer_period_probability(options, min_P2):
    """
    Return the period probability of the outer orbit (period options["P2"]).

    The probability is renormalized over log10(min_P2) <= log10(P2) < 10,
    because shorter periods are excluded. The normalization integral is cached
    in Moecache["P2_integrals"], keyed on options["M1+M2"]. Returns 0 if that
    integral is not positive.
    """
    p2_interpolator = _moe_interpolator(
        "rinterpolator_log10P2", "period_distributions", 2, 2
    )

    if not Moecache.get("P2_integrals", None):
        Moecache["P2_integrals"] = {}
    integrals = Moecache["P2_integrals"]

    if not integrals.get(options["M1+M2"], None):
        # normalize because min_P2 > 0, so not all the periods are included
        integral = 0
        dlogP2 = 1e-3
        for logP2 in np.arange(np.log10(min_P2), 10, dlogP2):
            # dp/dlogP2
            # NOTE(review): evaluated at log10(M1), while the probability below
            # uses log10(M1+M2), and the cache key (M1+M2) ignores min_P2.
            # Confirm which is intended.
            dp_dlogP2 = p2_interpolator.interpolate(
                [np.log10(options["M1"]), logP2]
            )[0]
            integral += dp_dlogP2 * dlogP2
        integrals[options["M1+M2"]] = integral

    # TODO: should this really be the log10P interpolator?
    p_val = _moe_interpolator(
        "rinterpolator_log10P", "period_distributions", 2, 2
    ).interpolate([np.log10(options["M1+M2"]), np.log10(options["P2"])])[0]

    integral = integrals[options["M1+M2"]]
    return p_val / integral if integral > 0 else 0