import unittest
import numpy as np
from binarycpython.utils.plot_functions import *
import matplotlib.pyplot as plt

# class test_(unittest.TestCase):
#     """
#     Unittests for function
#     """

#     def test_1(self):
#         pass

class test_color_by_index(unittest.TestCase):
    """
    Unittests for function color_by_index
    """

    def test_1(self):
        colors = ['red', 'white', 'blue']

        color = color_by_index([1,2,3], 1, colors)
        self.assertTrue(color=='blue')


class test_plot_system(unittest.TestCase):
    """
    Unittests for function
    """

    def test_mass_evolution_plot(self):
        plot_type = 'mass_evolution'
        show_plot = False
        output_fig_1 = plot_system(plot_type, show_plot=show_plot, M_1=1, metallicity=0.002, M_2=0.1, separation=0, orbital_period=100000000000)

        fig, ax = plt.subplots(nrows=1)
        self.assertTrue(type(output_fig_1)==fig.__class__)

        # with stellar types
        # plot_type = 'mass_evolution'
        # show_plot = False
        # show_stellar_types = True
        # output_fig_2 = plot_system(plot_type, show_plot=show_plot, show_stellar_types=show_stellar_types, M_1=1, metallicity=0.002, M_2=0.1, separation=0, orbital_period=100000000000)

        # fig, ax = plt.subplots(nrows=1)
        # self.assertTrue(type(output_fig_2)==fig.__class__)

        # # show plot
        # show_plot = True
        # output_fig_2 = plot_system(plot_type, show_plot=show_plot, M_1=1, metallicity=0.002, M_2=0.1, separation=0, orbital_period=100000000000)

    def test_orbit_evolution_plot(self):
        plot_type = 'orbit_evolution'
        show_plot = False
        output_fig_1 = plot_system(plot_type, show_plot=show_plot, M_1=1, metallicity=0.002, M_2=0.1, separation=0, orbital_period=100000000000)

        fig, ax = plt.subplots(nrows=1)
        self.assertTrue(type(output_fig_1)==fig.__class__)

        # with stellar types
        # plot_type = 'orbit_evolution'
        # show_plot = False
        # show_stellar_types = True
        # output_fig_2 = plot_system(plot_type, show_plot=show_plot, show_stellar_types=show_stellar_types, M_1=1, metallicity=0.002, M_2=0.1, separation=0, orbital_period=100000000000)

        # fig, ax = plt.subplots(nrows=1)
        # self.assertTrue(type(output_fig_2)==fig.__class__)

        # # show plot
        # show_plot = True
        # output_fig_2 = plot_system(plot_type, show_plot=show_plot, M_1=1, metallicity=0.002, M_2=0.1, separation=0, orbital_period=100000000000)

    def test_hr_diagram_plot(self):
        plot_type = 'hr_diagram'
        show_plot = False
        output_fig_1 = plot_system(plot_type, show_plot=show_plot, M_1=1, metallicity=0.002, M_2=0.1, separation=0, orbital_period=100000000000)

        fig, ax = plt.subplots(nrows=1)
        self.assertTrue(type(output_fig_1)==fig.__class__)

        # with stellar types
        # plot_type = 'hr_diagram'
        # show_plot = False
        # show_stellar_types = True
        # output_fig_2 = plot_system(plot_type, show_plot=show_plot, show_stellar_types=show_stellar_types, M_1=1, metallicity=0.002, M_2=0.1, separation=0, orbital_period=100000000000)

        # fig, ax = plt.subplots(nrows=1)
        # self.assertTrue(type(output_fig_2)==fig.__class__)

        # # show plot
        # show_plot = True
        # output_fig_2 = plot_system(plot_type, show_plot=show_plot, M_1=1, metallicity=0.002, M_2=0.1, separation=0, orbital_period=100000000000)

    def test_unknown_plottype(self):
        plot_type = 'random'
        self.assertRaises(ValueError, plot_system, plot_type)

if __name__ == "__main__":
    unittest.main()