import math

import numpy as np
import matplotlib.pyplot as plt

from binarycpython.utils.distribution_functions import (
    three_part_powerlaw,
    Kroupa2001,
    Arenou2010_binary_fraction,
    raghavan2010_binary_fraction,
    imf_scalo1998,
    imf_scalo1986,
    imf_tinsley1980,
    imf_scalo1998,
    imf_chabrier2003,
    flatsections,
    duquennoy1991,
    sana12,
)
from binarycpython.utils.useful_funcs import calc_sep_from_period

################################################
# Example script that plots the probability distributions provided by
# binarycpython.utils.distribution_functions.
# TODO:

################################################
# mass distribution plots
################################################
# mass_values = np.arange(0.11, 80, .1)

# kroupa_probability = [Kroupa2001(mass) for mass in mass_values]
# scalo1986 = [imf_scalo1986(mass) for mass in mass_values]
# tinsley1980 = [imf_tinsley1980(mass) for mass in mass_values]
# scalo1998 = [imf_scalo1998(mass) for mass in mass_values]
# chabrier2003 = [imf_chabrier2003(mass) for mass in mass_values]

# plt.plot(mass_values, kroupa_probability, label='Kroupa')
# plt.plot(mass_values, scalo1986, label='scalo1986')
# plt.plot(mass_values, tinsley1980, label='tinsley1980')
# plt.plot(mass_values, scalo1998, label='scalo1998')
# plt.plot(mass_values, chabrier2003, label='chabrier2003')

# plt.title('Probability distribution for mass of primary')
# plt.ylabel(r'Probability')
# plt.xlabel(r'Mass (M$_{\odot}$)')
# plt.yscale('log')
# plt.xscale('log')
# plt.grid()
# plt.legend()
# plt.show()

################################################
# Binary fraction distributions
################################################
# arenou_binary_distibution = [Arenou2010_binary_fraction(mass) for mass in mass_values]
# raghavan2010_binary_distribution = [raghavan2010_binary_fraction(mass) for mass in mass_values ]

# plt.plot(mass_values, arenou_binary_distibution, label='arenou 2010')
# plt.plot(mass_values, raghavan2010_binary_distribution, label='Raghavan 2010')
# plt.title('Binary fractions distributions')
# plt.ylabel(r'Binary fraction')
# plt.xlabel(r'Mass (M$_{\odot}$)')
# # plt.yscale('log')
# plt.xscale('log')
# plt.grid()
# plt.legend()
# plt.show()


################################################
# Mass ratio distributions
################################################

# mass_ratios = np.arange(0, 1, .01)
# example_mass = 2
# flat_dist = [flatsections(q, opts=[{'min':0.1/example_mass, 'max':0.8, 'height':1}, {'min': 0.8, 'max':1.0, 'height': 1.0}]) for q in mass_ratios]

# plt.plot(mass_ratios, flat_dist, label='Flat')
# plt.title('Mass ratio distributions')
# plt.ylabel(r'Probability')
# plt.xlabel(r'Mass ratio (q = $\frac{M2}{M1}$) ')
# plt.grid()
# plt.legend()
# plt.show()

################################################
# Period distributions
################################################
# TODO: fix this

# Grid of log10(orbital period) values: starts at -2, spacing 0.1,
# stop value 12 (np.arange does not include the stop value).
logperiod_values = np.arange(-2, 12, 0.1)

# Duquennoy & Mayor (1991) distribution evaluated at every grid point.
duquennoy1991_distribution = list(map(duquennoy1991, logperiod_values))

# Sana12 distributions
# A fixed 10 Msun primary paired with three secondaries (5, 1 and 10 Msun),
# i.e. mass ratios q = m2 / m1 of 0.5, 0.1 and 1. All three curves share the
# same period bounds and power-law index, so they are built by one helper
# instead of three copy-pasted comprehensions.
m1 = 10
period_min = 10 ** 0.15  # lower period bound passed to sana12 (log10 P = 0.15)
period_max = 10 ** 5.5  # upper period bound passed to sana12 (log10 P = 5.5)
sana12_slope = -0.55  # power-law index passed as the last argument of sana12


def _sana12_period_distribution(m_primary, m_secondary, logperiods, p_min, p_max, slope):
    """Evaluate ``sana12`` for one pair of masses on a grid of log10 periods.

    The separation bounds and log10 period bounds depend only on the masses
    and on ``p_min``/``p_max``, so they are computed once here instead of
    once per grid point.

    Args:
        m_primary: primary mass, passed to ``sana12`` as its first argument.
        m_secondary: secondary mass, passed to ``sana12`` as its second argument.
        logperiods: iterable of log10(orbital period) values to evaluate.
        p_min: lower orbital-period bound (linear, same units as the grid).
        p_max: upper orbital-period bound (linear, same units as the grid).
        slope: power-law index forwarded to ``sana12``.

    Returns:
        list with one ``sana12`` value per entry of ``logperiods``.
    """
    sep_min = calc_sep_from_period(m_primary, m_secondary, p_min)
    sep_max = calc_sep_from_period(m_primary, m_secondary, p_max)
    logp_min = math.log10(p_min)
    logp_max = math.log10(p_max)

    periods = [10 ** logper for logper in logperiods]
    return [
        sana12(
            m_primary,
            m_secondary,
            calc_sep_from_period(m_primary, m_secondary, period),
            period,
            sep_min,
            sep_max,
            logp_min,
            logp_max,
            slope,
        )
        for period in periods
    ]


sana12_distribution_q05 = _sana12_period_distribution(
    m1, 5, logperiod_values, period_min, period_max, sana12_slope
)
sana12_distribution_q01 = _sana12_period_distribution(
    m1, 1, logperiod_values, period_min, period_max, sana12_slope
)
sana12_distribution_q1 = _sana12_period_distribution(
    m1, 10, logperiod_values, period_min, period_max, sana12_slope
)


# Overlay every period distribution on a shared log10(period) axis; each
# entry pairs the evaluated curve with its legend label.
for curve, curve_label in (
    (duquennoy1991_distribution, "Duquennoy & Mayor 1991"),
    (sana12_distribution_q05, "Sana 12 (q=0.5)"),
    (sana12_distribution_q01, "Sana 12 (q=0.1)"),
    (sana12_distribution_q1, "Sana 12 (q=1)"),
):
    plt.plot(logperiod_values, curve, label=curve_label)

plt.title("Period distributions")
plt.ylabel(r"Probability")
plt.xlabel(r"Log10(orbital period)")
plt.grid()
plt.legend()
plt.show()


################################################
# Sampling part of a distribution and calculating the probability ratio
################################################

# TODO: show the difference between sampling over the full range and sampling a smaller range initially, then compensating for it.