""" Unittests for plot_functions """ import unittest import numpy as np import matplotlib.pyplot as plt from binarycpython.utils.plot_functions import * from binarycpython.utils.functions import Capturing # class test_(unittest.TestCase): # """ # Unittests for function # """ # def test_1(self): # pass class test_color_by_index(unittest.TestCase): """ Unittests for function color_by_index """ def test_1(self): with Capturing() as output: self._test_1() def _test_1(self): """ First test """ colors = ["red", "white", "blue"] color = color_by_index([1, 2, 3], 1, colors) self.assertTrue(color == "blue") class test_plot_system(unittest.TestCase): """ Unittests for function """ def test_mass_evolution_plot(self): with Capturing() as output: self._test_mass_evolution_plot() def _test_mass_evolution_plot(self): """ Test for setting plot_type = "mass_evolution" """ plot_type = "mass_evolution" show_plot = False output_fig_1 = plot_system( plot_type, show_plot=show_plot, M_1=1, metallicity=0.002, M_2=0.1, separation=0, orbital_period=100000000000, ) fig, ax = plt.subplots(nrows=1) self.assertTrue(type(output_fig_1) == fig.__class__) # with stellar types # plot_type = 'mass_evolution' # show_plot = False # show_stellar_types = True # output_fig_2 = plot_system(plot_type, show_plot=show_plot, show_stellar_types=show_stellar_types, M_1=1, metallicity=0.002, M_2=0.1, separation=0, orbital_period=100000000000) # fig, ax = plt.subplots(nrows=1) # self.assertTrue(type(output_fig_2)==fig.__class__) # # show plot # show_plot = True # output_fig_2 = plot_system(plot_type, show_plot=show_plot, M_1=1, metallicity=0.002, M_2=0.1, separation=0, orbital_period=100000000000) def test_orbit_evolution_plot(self): with Capturing() as output: self._test_orbit_evolution_plot() def _test_orbit_evolution_plot(self): """ Test for setting plot_type = "orbit_evolution" """ plot_type = "orbit_evolution" show_plot = False output_fig_1 = plot_system( plot_type, show_plot=show_plot, M_1=1, metallicity=0.002, M_2=0.1, separation=0, orbital_period=100000000000, ) fig, ax = plt.subplots(nrows=1) self.assertTrue(type(output_fig_1) == fig.__class__) # with stellar types # plot_type = 'orbit_evolution' # show_plot = False # show_stellar_types = True # output_fig_2 = plot_system(plot_type, show_plot=show_plot, show_stellar_types=show_stellar_types, M_1=1, metallicity=0.002, M_2=0.1, separation=0, orbital_period=100000000000) # fig, ax = plt.subplots(nrows=1) # self.assertTrue(type(output_fig_2)==fig.__class__) # # show plot # show_plot = True # output_fig_2 = plot_system(plot_type, show_plot=show_plot, M_1=1, metallicity=0.002, M_2=0.1, separation=0, orbital_period=100000000000) def test_hr_diagram_plot(self): with Capturing() as output: self._test_hr_diagram_plot() def _test_hr_diagram_plot(self): """ Test for setting plot_type = "hr_diagram" """ plot_type = "hr_diagram" show_plot = False output_fig_1 = plot_system( plot_type, show_plot=show_plot, M_1=1, metallicity=0.002, M_2=0.1, separation=0, orbital_period=100000000000, ) fig, ax = plt.subplots(nrows=1) self.assertTrue(type(output_fig_1) == fig.__class__) # with stellar types # plot_type = 'hr_diagram' # show_plot = False # show_stellar_types = True # output_fig_2 = plot_system(plot_type, show_plot=show_plot, show_stellar_types=show_stellar_types, M_1=1, metallicity=0.002, M_2=0.1, separation=0, orbital_period=100000000000) # fig, ax = plt.subplots(nrows=1) # self.assertTrue(type(output_fig_2)==fig.__class__) # # show plot # show_plot = True # output_fig_2 = plot_system(plot_type, show_plot=show_plot, M_1=1, metallicity=0.002, M_2=0.1, separation=0, orbital_period=100000000000) def test_unknown_plottype(self): with Capturing() as output: self._test_unknown_plottype() def _test_unknown_plottype(self): """ Test for non-existant setting plot_type = "hr_diagram" """ plot_type = "random" self.assertRaises(ValueError, plot_system, plot_type) if __name__ == "__main__": unittest.main()