diff --git a/binarycpython/tests/main.py b/binarycpython/tests/main.py index 5a00d810a0d58dc0a916d0efc53bb887e82e4b5b..81fff0a97b1d3ccbc3edf634608f6479511517bf 100644 --- a/binarycpython/tests/main.py +++ b/binarycpython/tests/main.py @@ -12,6 +12,7 @@ from binarycpython.tests.test_run_system_wrapper import * from binarycpython.tests.test_spacing_functions import * from binarycpython.tests.test_useful_funcs import * from binarycpython.tests.test_grid_options_defaults import * +from binarycpython.tests.test_stellar_types import * if __name__ == "__main__": unittest.main() diff --git a/binarycpython/tests/test_distributions.py b/binarycpython/tests/test_distributions.py index bcb038cdf3595dd30e9c71fedbea8c4c18f9d383..0d993122b23aadd403f8ce44854f576b276d4da6 100644 --- a/binarycpython/tests/test_distributions.py +++ b/binarycpython/tests/test_distributions.py @@ -25,6 +25,57 @@ class TestDistributions(unittest.TestCase): self.tolerance = 1e-5 + def test_setopts(self): + """ + Unittest for function set_opts + """ + + default_dict = {'m1': 2, 'm2': 3} + output_dict_1 = set_opts(default_dict, {}) + self.assertTrue(output_dict_1==default_dict) + + + new_opts = {'m1': 10} + output_dict_2 = set_opts(default_dict, new_opts) + updated_dict = default_dict.copy() + updated_dict['m1'] = 10 + + self.assertTrue(output_dict_2==updated_dict) + + + def test_flat(self): + """ + Unittest for the function flat + """ + + output_1 = flat() + + self.assertTrue(isinstance(output_1, float)) + self.assertEqual(output_1, 1.0) + + def test_number(self): + """ + Unittest for function number + """ + + input_1 = 1.0 + output_1 = number(input_1) + + self.assertEqual(input_1, output_1) + + def test_const(self): + """ + Unittest for function const + """ + + output_1 = const(min_bound=0, max_bound=2) + self.assertEqual(output_1, 0.5, msg="Value should be 0.5, but is {}".format(output_1)) + + + output_2 = const(min_bound=0, max_bound=2, val=3) + self.assertEqual(output_2, 0, msg="Value should be 0, but is {}".format(output_2)) + + def test_powerlaw(self): """ unittest for the powerlaw test @@ -47,6 +98,9 @@ class TestDistributions(unittest.TestCase): for i in range(len(python_results)): self.assertLess(np.abs(python_results[i] - perl_results[i]), self.tolerance) + # extra test for k = -1 + self.assertRaises(ValueError, powerlaw, 1, 100, -1, 10) + def test_three_part_power_law(self): """ unittest for three_part_power_law @@ -71,6 +125,11 @@ class TestDistributions(unittest.TestCase): for i in range(len(python_results)): self.assertLess(np.abs(python_results[i] - perl_results[i]), self.tolerance) + # Extra test: + # M < M0 + self.assertTrue(three_part_powerlaw(0.05, 0.08, 0.1, 1, 300, -1.3, -2.3, -2.3)==0, msg="Probability should be zero as M < M0") + + def test_Kroupa2001(self): """ unittest for three_part_power_law @@ -92,6 +151,9 @@ class TestDistributions(unittest.TestCase): # GO over the results and check whether they are equal (within tolerance) for i in range(len(python_results)): self.assertLess(np.abs(python_results[i] - perl_results[i]), self.tolerance) + # Extra tests: + self.assertEqual(Kroupa2001(10, newopts={'mmax': 300}), three_part_powerlaw(10, 0.1, 0.5, 1, 300, -1.3, -2.3, -2.3)) + def test_ktg93(self): """ @@ -114,6 +176,60 @@ class TestDistributions(unittest.TestCase): # GO over the results and check whether they are equal (within tolerance) for i in range(len(python_results)): self.assertLess(np.abs(python_results[i] - perl_results[i]), self.tolerance) + # extra test: + self.assertEqual(ktg93(10, newopts={'mmax': 300}), three_part_powerlaw(10, 0.1, 0.5, 1, 300, -1.3, -2.2, -2.7)) + + + def test_imf_tinsley1980(self): + """ + Unittest for function imf_tinsley1980 + """ + + m = 1.2 + self.assertEqual(imf_tinsley1980(m), three_part_powerlaw(m, 0.1, 2.0, 10.0, 80.0, -2.0, -2.3, -3.3)) + + def test_imf_scalo1986(self): + """ + Unittest for function imf_scalo1986 + """ + + m = 1.2 + self.assertEqual(imf_scalo1986(m), three_part_powerlaw(m, 0.1, 1.0, 2.0, 80.0, -2.35, -2.35, -2.70)) + + + def test_imf_scalo1998(self): + """ + Unittest for function imf_scalo1986 + """ + + m = 1.2 + self.assertEqual(imf_scalo1998(m), three_part_powerlaw(m, 0.1, 1.0, 10.0, 80.0, -1.2, -2.7, -2.3)) + + + def test_imf_chabrier2003(self): + """ + Unittest for function imf_chabrier2003 + """ + + input_1 = 0 + self.assertRaises(ValueError, imf_chabrier2003, input_1) + + # for m=0.5 + m = 0.5 + self.assertLess(np.abs(imf_chabrier2003(m)-0.581457346702825), self.tolerance, msg="Difference is bigger than the tolerance") + + # For m = 2 + m = 2 + self.assertLess(np.abs(imf_chabrier2003(m)-0.581457346702825), self.tolerance, msg="Difference is bigger than the tolerance") + + + def test_duquennoy1991(self): + """ + Unittest for function duquennoy1991 + """ + + self.assertEqual(duquennoy1991(4.2), gaussian(4.2, 4.8, 2.3, -2, 12)) + def test_gaussian(self): """ @@ -137,6 +253,9 @@ class TestDistributions(unittest.TestCase): for i in range(len(python_results)): self.assertLess(np.abs(python_results[i] - perl_results[i]), self.tolerance) + # Extra test: + self.assertTrue(gaussian(15, 4.8, 2.3, -2.0, 12.0)==0, msg="Probability should be 0 because the input period is out of bounds") + def test_Arenou2010_binary_fraction(self): """ unittest for three_part_power_law diff --git a/binarycpython/tests/test_functions.py b/binarycpython/tests/test_functions.py index b83c2062156bde44a4d7872beef2b9c6a561cb92..a108a71ccb02baa2f993df9457c61d5fd615cf26 100644 --- a/binarycpython/tests/test_functions.py +++ b/binarycpython/tests/test_functions.py @@ -1,31 +1,297 @@ import unittest +import tempfile from binarycpython.utils.functions import * +from binarycpython.utils.custom_logging_functions import binary_c_log_code +from binarycpython.utils.run_system_wrapper import run_system +binary_c_temp_dir = temp_dir() ############################# # Script that contains unit tests for functions from the binarycpython.utils.functions file +# class test_(unittest.TestCase): +# """ +# Unittests for function +# """ -class test_get_help_super(unittest.TestCase): +# def test_1(self): +# pass + +class dummy(): + def __init__(self, name): + self.name = name + + def __str__(self): + return self.name + +class test_verbose_print(unittest.TestCase): """ - Unit test for get_help_super + Unittests for verbose_print """ - def test_all_output(self): + def test_print(self): + verbose_print('test1', 1, 0) + + def test_not_print(self): + verbose_print('test1', 0, 1) + +class test_remove_file(unittest.TestCase): + """ + Unittests for remove_file + """ + + def test_remove_file(self): + with open(os.path.join(binary_c_temp_dir, 'test_remove_file_file.txt'), 'w') as f: + f.write('test') + + remove_file(os.path.join(binary_c_temp_dir, 'test_remove_file_file.txt')) + + def test_remove_nonexisting_file(self): + file = os.path.join(binary_c_temp_dir, 'test_remove_nonexistingfile_file.txt') + + remove_file(file) + +class test_temp_dir(unittest.TestCase): + """ + Unittests for temp_dir + """ + + def test_create_temp_dir(self): + binary_c_temp_dir = temp_dir() + general_temp_dir = tempfile.gettempdir() + + self.assertTrue(os.path.isdir(os.path.join(general_temp_dir, 'binary_c_python'))) + self.assertTrue(os.path.join(general_temp_dir, 'binary_c_python'))==binary_c_temp_dir + +class test_create_hdf5(unittest.TestCase): + """ + Unittests for create_hdf5 + """ + + def test_1(self): + testdir = os.path.join(binary_c_temp_dir, 'test_create_hdf5') + os.makedirs(testdir, exist_ok=True) + + # Create dummy settings file: + settings_dict = {'settings_1': 1, 'settings_2': [1,2]} + + with open(os.path.join(testdir, 'example_settings.json'), 'w') as f: + f.write(json.dumps(settings_dict)) + + with open(os.path.join(testdir, 'data1.dat'), 'w') as f: + f.write("time mass\n") + f.write("1 10") + + create_hdf5(testdir, 'testhdf5.hdf5') + file = h5py.File(os.path.join(testdir, 'testhdf5.hdf5'), 'r') + + self.assertIn(b'time', file.get('data/data1_header')[()]) + self.assertIn(b'mass', file.get('data/data1_header')[()]) + + self.assertIn('settings_1', json.loads(file.get('settings/used_settings')[()])) + self.assertIn('settings_2', json.loads(file.get('settings/used_settings')[()])) + +class test_return_binary_c_version_info(unittest.TestCase): + """ + Unittests for return_binary_c_version_info + """ + + def test_not_parsed(self): + version_info = return_binary_c_version_info() + + self.assertTrue(isinstance(version_info, str)) + self.assertIn('Build', version_info) + self.assertIn('REIMERS_ETA_DEFAULT', version_info) + self.assertIn('SIGMA_THOMPSON', version_info) + + def test_parsed(self): + # also tests the parse_version_info indirectly + version_info_parsed = return_binary_c_version_info(parsed=True) + + self.assertTrue(isinstance(version_info_parsed, dict)) + self.assertIn('isotopes', version_info_parsed.keys()) + self.assertIn('argpairs', version_info_parsed.keys()) + self.assertIn('ensembles', version_info_parsed.keys()) + self.assertIn('macros', version_info_parsed.keys()) + self.assertIn('elements', version_info_parsed.keys()) + self.assertIn('dt_limits', version_info_parsed.keys()) + self.assertIn('nucleosynthesis_sources', version_info_parsed.keys()) + self.assertIn('miscellaneous', version_info_parsed.keys()) + +class test_parse_binary_c_version_info(unittest.TestCase): + """ + Unittests for function parse_binary_c_version_info + """ + + def test_1(self): + info = return_binary_c_version_info() + parsed_info = parse_binary_c_version_info(info) + + self.assertIn('isotopes', parsed_info.keys()) + self.assertNotEqual(parsed_info['isotopes'], {}) + self.assertIn('argpairs', parsed_info.keys()) + self.assertNotEqual(parsed_info['argpairs'], {}) + self.assertIn('ensembles', parsed_info.keys()) + self.assertNotEqual(parsed_info['ensembles'], {}) + self.assertIn('macros', parsed_info.keys()) + self.assertNotEqual(parsed_info['macros'], {}) + self.assertIn('elements', parsed_info.keys()) + self.assertNotEqual(parsed_info['elements'], {}) + self.assertIn('dt_limits', parsed_info.keys()) + self.assertNotEqual(parsed_info['dt_limits'], {}) + self.assertIn('nucleosynthesis_sources', parsed_info.keys()) + self.assertNotEqual(parsed_info['nucleosynthesis_sources'], {}) + self.assertIn('miscellaneous', parsed_info.keys()) + self.assertNotEqual(parsed_info['miscellaneous'], {}) + +class test_output_lines(unittest.TestCase): + """ + Unittests for function output_lines + """ + + def test_1(self): + example_text = "hallo\ntest\n123" + output_1 = output_lines(example_text) + + self.assertTrue(isinstance(output_1, list)) + self.assertIn('hallo', output_1) + self.assertIn('test', output_1) + self.assertIn('123', output_1) + +class test_example_parse_output(unittest.TestCase): + """ + Unittests for function example_parse_output + """ + + def test_normal_output(self): + # generate logging lines. Here you can choose whatever you want to have logged, and with what header + # You can also decide to `write` your own logging_line, which allows you to write a more complex logging statement with conditionals. + logging_line = 'Printf("MY_STELLAR_DATA time=%g mass=%g\\n", stardata->model.time, stardata->star[0].mass)' + + # Generate entire shared lib code around logging lines + custom_logging_code = binary_c_log_code(logging_line) + + # Run system. all arguments can be given as optional arguments. the custom_logging_code is one of them and will be processed automatically. + output = run_system( + M_1=1, + metallicity=0.002, + M_2=0.1, + separation=0, + orbital_period=100000000000, + custom_logging_code=custom_logging_code, + ) + + parsed_output = example_parse_output(output, "MY_STELLAR_DATA") + + self.assertIn('time', parsed_output) + self.assertIn('mass', parsed_output) + self.assertTrue(isinstance(parsed_output['time'], list)) + self.assertTrue(len(parsed_output['time'])>0) + + def test_mismatch_output(self): + # generate logging lines. Here you can choose whatever you want to have logged, and with what header + # You can also decide to `write` your own logging_line, which allows you to write a more complex logging statement with conditionals. + logging_line = 'Printf("MY_STELLAR_DATA time=%g mass=%g\\n", stardata->model.time, stardata->star[0].mass)' + + # Generate entire shared lib code around logging lines + custom_logging_code = binary_c_log_code(logging_line) + + # Run system. all arguments can be given as optional arguments. the custom_logging_code is one of them and will be processed automatically. + output = run_system( + M_1=1, + metallicity=0.002, + M_2=0.1, + separation=0, + orbital_period=100000000000, + custom_logging_code=custom_logging_code, + ) + + parsed_output = example_parse_output(output, "MY_STELLAR_DATA_MISMATCH") + self.assertIsNone(parsed_output) + +class test_get_defaults(unittest.TestCase): + """ + Unittests for function get_defaults + """ + + def test_no_filter(self): + output_1 = get_defaults() + + self.assertTrue(isinstance(output_1, dict)) + self.assertIn('colour_log', output_1.keys()) + self.assertIn('M_1', output_1.keys()) + self.assertIn('list_args', output_1.keys()) + self.assertIn('use_fixed_timestep_%d', output_1.keys()) + + def test_filter(self): + # Also tests the filter_arg_dict indirectly + output_1 = get_defaults(filter_values=True) + + self.assertTrue(isinstance(output_1, dict)) + self.assertIn('colour_log', output_1.keys()) + self.assertIn('M_1', output_1.keys()) + self.assertNotIn('list_args', output_1.keys()) + self.assertNotIn('use_fixed_timestep_%d', output_1.keys()) + +class test_get_arg_keys(unittest.TestCase): + """ + Unittests for function + """ + + def test_1(self): + output_1 = get_arg_keys() + + self.assertTrue(isinstance(output_1, list)) + self.assertIn('colour_log', output_1) + self.assertIn('M_1', output_1) + self.assertIn('list_args', output_1) + self.assertIn('use_fixed_timestep_%d', output_1) + +class test_create_arg_string(unittest.TestCase): + """ + Unittests for function create_arg_string + """ + + def test_default(self): + input_dict = {'separation': 40000, 'M_1': 10} + argstring = create_arg_string(input_dict) + self.assertEqual(argstring, 'separation 40000 M_1 10') + + def test_sort(self): + input_dict = {'M_1': 10, 'separation': 40000} + argstring = create_arg_string(input_dict, sort=True) + self.assertEqual(argstring, 'M_1 10 separation 40000') + + def test_sort(self): + input_dict = {'M_1': 10, 'separation': 40000, 'list_args': "NULL"} + argstring = create_arg_string(input_dict, filter_values=True) + self.assertEqual(argstring, 'M_1 10 separation 40000') + +class test_get_help(unittest.TestCase): + """ + Unit tests for function get_help + """ + + def test_input_normal(self): """ - Function to test the get_help_super function + Function to test the get_help function """ - get_help_super_output = get_help_super() - get_help_super_keys = get_help_super_output.keys() + self.assertEqual( + get_help("M_1", print_help=False)["parameter_name"], + "M_1", + msg="get_help('M_1') should return the correct parameter name", + ) - self.assertIn("stars", get_help_super_keys, "missing section") - self.assertIn("binary", get_help_super_keys, "missing section") - self.assertIn("nucsyn", get_help_super_keys, "missing section") - self.assertIn("output", get_help_super_keys, "missing section") - self.assertIn("i/o", get_help_super_keys, "missing section") - self.assertIn("algorithms", get_help_super_keys, "missing section") - self.assertIn("misc", get_help_super_keys, "missing section") + def test_no_input(self): + output = get_help() + self.assertIsNone(output) + + def test_wrong_input(self): + output = get_help("kaasblokjes") + self.assertIsNone(output) + # def test_print(self): + # output = get_help("M_1", print_help=True) class test_get_help_all(unittest.TestCase): """ @@ -49,24 +315,178 @@ class test_get_help_all(unittest.TestCase): self.assertIn("misc", get_help_all_keys, "missing section") -class test_get_help(unittest.TestCase): - def test_input(self): + # def test_print(self): + # # test if stuff is printed + # get_help_all(print_help=True) + +class test_get_help_super(unittest.TestCase): + """ + Unit test for get_help_super + """ + + def test_all_output(self): """ - Function to test the get_help function + Function to test the get_help_super function """ - self.assertEqual( - get_help("M_1", print_help=False)["parameter_name"], - "M_1", - msg="get_help('M_1') should return the correct parameter name", - ) + get_help_super_output = get_help_super() + get_help_super_keys = get_help_super_output.keys() + + self.assertIn("stars", get_help_super_keys, "missing section") + self.assertIn("binary", get_help_super_keys, "missing section") + self.assertIn("nucsyn", get_help_super_keys, "missing section") + self.assertIn("output", get_help_super_keys, "missing section") + self.assertIn("i/o", get_help_super_keys, "missing section") + self.assertIn("algorithms", get_help_super_keys, "missing section") + self.assertIn("misc", get_help_super_keys, "missing section") + + # def test_print(self): + # # test to see if stuff is printed. + # get_help_super(print_help=True) +class test_make_build_text(unittest.TestCase): + """ + Unittests for function + """ + + def test_output(self): + build_text = make_build_text() + + # Remove the things + build_text = build_text.replace("**binary_c git branch**:", ";") + build_text = build_text.replace("**binary_c git revision**:", ";") + build_text = build_text.replace("**Built on**:", ";") + + # Split up + split_text = build_text.split(";") + + # Check whether the contents are actually there + self.assertNotEqual(split_text[1].strip(), "second") + self.assertNotEqual(split_text[2].strip(), "second") + self.assertNotEqual(split_text[3].strip(), "second") + +class test_write_binary_c_parameter_descriptions_to_rst_file(unittest.TestCase): + """ + Unittests for function write_binary_c_parameter_descriptions_to_rst_file + """ + + def test_bad_outputname(self): + output_name = os.path.join(binary_c_temp_dir, 'test_write_binary_c_parameter_descriptions_to_rst_file_test_1.txt') + output_1 = write_binary_c_parameter_descriptions_to_rst_file(output_name) + self.assertIsNone(output_1) + + def test_checkfile(self): + output_name = os.path.join(binary_c_temp_dir, 'test_write_binary_c_parameter_descriptions_to_rst_file_test_1.rst') + output_1 = write_binary_c_parameter_descriptions_to_rst_file(output_name) + self.assertTrue(os.path.isfile(output_name)) + +class test_inspect_dict(unittest.TestCase): + """ + Unittests for function + """ + + def test_compare_dict(self): + input_dict = {'int': 1, 'float': 1.2, 'list': [1,2,3], 'function': os.path.isfile, 'dict': {'int': 1, 'float': 1.2}} + output_dict = inspect_dict(input_dict) + compare_dict = {'int': int, 'float': float, 'list': list, 'function': os.path.isfile.__class__, 'dict': {'int': int, 'float': float}} + self.assertTrue(compare_dict==output_dict) + + def test_compare_dict(self): + input_dict = {'int': 1, 'float': 1.2, 'list': [1,2,3], 'function': os.path.isfile, 'dict': {'int': 1, 'float': 1.2}} + output_dict = inspect_dict(input_dict, print_structure=True) + +class test_merge_dicts(unittest.TestCase): + """ + Unittests for function + """ + + def test_empty(self): + input_dict = {'int': 1, 'float': 1.2, 'list': [1,2,3], 'function': os.path.isfile, 'dict': {'int': 1, 'float': 1.2}} + dict_2 = {} + output_dict = merge_dicts(input_dict, dict_2) + self.assertTrue(output_dict==input_dict) + + def test_unequal_types(self): + dict_1 = {"input": 10} + dict_2 = {"input": 'hello'} + + self.assertRaises(ValueError, merge_dicts, dict_1, dict_2) + + def test_bools(self): + dict_1 = {'bool': True} + dict_2 = {'bool': False} + output_dict = merge_dicts(dict_1, dict_2) + + self.assertTrue(isinstance(output_dict['bool'], bool)) + self.assertTrue(output_dict['bool']) + + def test_ints(self): + dict_1 = {'int': 2} + dict_2 = {'int': 1} + output_dict = merge_dicts(dict_1, dict_2) + + self.assertTrue(isinstance(output_dict['int'], int)) + self.assertEqual(output_dict['int'], 3) + + def test_floats(self): + dict_1 = {'float': 4.5} + dict_2 = {'float': 4.6} + output_dict = merge_dicts(dict_1, dict_2) + + self.assertTrue(isinstance(output_dict['float'], float)) + self.assertEqual(output_dict['float'], 9.1) + + def test_lists(self): + dict_1 = {'list': [1,2]} + dict_2 = {'list': [3,4]} + output_dict = merge_dicts(dict_1, dict_2) + + self.assertTrue(isinstance(output_dict['list'], list)) + self.assertEqual(output_dict['list'], [1,2,3,4]) + + def test_dicts(self): + dict_1 = {'dict': {'same': 1, 'other_1': 2.0}} + dict_2 = {'dict': {'same': 2, 'other_2': [4.0]}} + output_dict = merge_dicts(dict_1, dict_2) + + self.assertTrue(isinstance(output_dict['dict'], dict)) + self.assertEqual(output_dict['dict'], {'same': 3, 'other_1': 2.0, 'other_2': [4.0]}) + + def test_unsupported(self): + dict_1 = {'new': dummy('david')} + dict_2 = {'new': dummy('gio')} + + # output_dict = merge_dicts(dict_1, dict_2) + self.assertRaises(ValueError, merge_dicts, dict_1, dict_2) + +class test_binaryc_json_serializer(unittest.TestCase): + """ + Unittests for function binaryc_json_serializer + """ + + def test_not_function(self): + stringo = "hello" + output = binaryc_json_serializer(stringo) + self.assertTrue(stringo==output) + + def test_function(self): + string_of_function = str(os.path.isfile) + output = binaryc_json_serializer(os.path.isfile) + self.assertTrue(string_of_function==output) + +class test_handle_ensemble_string_to_json(unittest.TestCase): + """ + Unittests for function handle_ensemble_string_to_json + """ -def all(): - test_get_help() - test_get_help_all() - test_get_help_super() + def test_1(self): + string_of_function = str(os.path.isfile) + input_string = '{"ding": 10, "list_example": [1,2,3]}' + output_dict = handle_ensemble_string_to_json(input_string) + self.assertTrue(isinstance(output_dict, dict)) + self.assertTrue(output_dict['ding']==10) + self.assertTrue(output_dict['list_example']==[1,2,3]) if __name__ == "__main__": unittest.main() diff --git a/binarycpython/tests/test_grid.py b/binarycpython/tests/test_grid.py index 10ab09362f38e54242d1e2cf47c4b37b2263bfc3..2a397fc231688045bafc29bbaf947b5cf80dac33 100644 --- a/binarycpython/tests/test_grid.py +++ b/binarycpython/tests/test_grid.py @@ -1 +1,56 @@ +import sys +import unittest +import tempfile + from binarycpython.utils.grid import Population + +class test_Population(unittest.TestCase): + """ + Unittests for function + """ + + def test_setup(self): + test_pop = Population() + + self.assertTrue('orbital_period' in test_pop.defaults) + self.assertTrue('metallicity' in test_pop.defaults) + self.assertNotIn('help_all', test_pop.cleaned_up_defaults) + self.assertEqual(test_pop.bse_options, {}) + self.assertEqual(test_pop.custom_options, {}) + self.assertEqual(test_pop.argline_dict, {}) + self.assertEqual(test_pop.persistent_data_memory_dict, {}) + self.assertTrue(test_pop.grid_options['parse_function']==None) + self.assertTrue(isinstance(test_pop.grid_options['_main_pid'], int)) + + def test_set(self): + test_pop = Population() + test_pop.set(amt_cores=2) + test_pop.set(M_1=10) + test_pop.set(data_dir='/tmp/binary_c_python') + test_pop.set(ensemble_filter_SUPERNOVAE=1) + + self.assertIn('data_dir', test_pop.custom_options) + self.assertEqual(test_pop.custom_options['data_dir'], '/tmp/binary_c_python') + + # + self.assertTrue(test_pop.bse_options['M_1']==10) + self.assertTrue(test_pop.bse_options['ensemble_filter_SUPERNOVAE']==1) + + # + self.assertTrue(test_pop.grid_options['amt_cores']==2) + + + def test_cmdline(self): + cmdline_arg = '--cmdline \"metallicity=0.0002\"' + sys.argv = ['script', '--cmdline', "metallicity=0.0002"] + test_pop = Population() + test_pop.parse_cmdline() + print(test_pop.bse_options) + + self.assertTrue(test_pop.bse_options['metallicity']==0.0002) + + + + +if __name__ == "__main__": + unittest.main() diff --git a/binarycpython/tests/test_plot_functions.py b/binarycpython/tests/test_plot_functions.py index 05887a43bb6e298c532385bf507a1ac244e42bbc..41955cfe334ac2ed3baa0c74222de59921ca313e 100644 --- a/binarycpython/tests/test_plot_functions.py +++ b/binarycpython/tests/test_plot_functions.py @@ -1 +1,99 @@ +import unittest +import numpy as np from binarycpython.utils.plot_functions import * +import matplotlib.pyplot as plt + +# class test_(unittest.TestCase): +# """ +# Unittests for function +# """ + +# def test_1(self): +# pass + +class test_color_by_index(unittest.TestCase): + """ + Unittests for function color_by_index + """ + + def test_1(self): + colors = ['red', 'white', 'blue'] + + color = color_by_index([1,2,3], 1, colors) + self.assertTrue(color=='blue') + + +class test_plot_system(unittest.TestCase): + """ + Unittests for function + """ + + def test_mass_evolution_plot(self): + plot_type = 'mass_evolution' + show_plot = False + output_fig_1 = plot_system(plot_type, show_plot=show_plot, M_1=1, metallicity=0.002, M_2=0.1, separation=0, orbital_period=100000000000) + + fig, ax = plt.subplots(nrows=1) + self.assertTrue(type(output_fig_1)==fig.__class__) + + # with stellar types + # plot_type = 'mass_evolution' + # show_plot = False + # show_stellar_types = True + # output_fig_2 = plot_system(plot_type, show_plot=show_plot, show_stellar_types=show_stellar_types, M_1=1, metallicity=0.002, M_2=0.1, separation=0, orbital_period=100000000000) + + # fig, ax = plt.subplots(nrows=1) + # self.assertTrue(type(output_fig_2)==fig.__class__) + + # # show plot + # show_plot = True + # output_fig_2 = plot_system(plot_type, show_plot=show_plot, M_1=1, metallicity=0.002, M_2=0.1, separation=0, orbital_period=100000000000) + + def test_orbit_evolution_plot(self): + plot_type = 'orbit_evolution' + show_plot = False + output_fig_1 = plot_system(plot_type, show_plot=show_plot, M_1=1, metallicity=0.002, M_2=0.1, separation=0, orbital_period=100000000000) + + fig, ax = plt.subplots(nrows=1) + self.assertTrue(type(output_fig_1)==fig.__class__) + + # with stellar types + # plot_type = 'orbit_evolution' + # show_plot = False + # show_stellar_types = True + # output_fig_2 = plot_system(plot_type, show_plot=show_plot, show_stellar_types=show_stellar_types, M_1=1, metallicity=0.002, M_2=0.1, separation=0, orbital_period=100000000000) + + # fig, ax = plt.subplots(nrows=1) + # self.assertTrue(type(output_fig_2)==fig.__class__) + + # # show plot + # show_plot = True + # output_fig_2 = plot_system(plot_type, show_plot=show_plot, M_1=1, metallicity=0.002, M_2=0.1, separation=0, orbital_period=100000000000) + + def test_hr_diagram_plot(self): + plot_type = 'hr_diagram' + show_plot = False + output_fig_1 = plot_system(plot_type, show_plot=show_plot, M_1=1, metallicity=0.002, M_2=0.1, separation=0, orbital_period=100000000000) + + fig, ax = plt.subplots(nrows=1) + self.assertTrue(type(output_fig_1)==fig.__class__) + + # with stellar types + # plot_type = 'hr_diagram' + # show_plot = False + # show_stellar_types = True + # output_fig_2 = plot_system(plot_type, show_plot=show_plot, show_stellar_types=show_stellar_types, M_1=1, metallicity=0.002, M_2=0.1, separation=0, orbital_period=100000000000) + + # fig, ax = plt.subplots(nrows=1) + # self.assertTrue(type(output_fig_2)==fig.__class__) + + # # show plot + # show_plot = True + # output_fig_2 = plot_system(plot_type, show_plot=show_plot, M_1=1, metallicity=0.002, M_2=0.1, separation=0, orbital_period=100000000000) + + def test_unknown_plottype(self): + plot_type = 'random' + self.assertRaises(ValueError, plot_system, plot_type) + +if __name__ == "__main__": + unittest.main() diff --git a/binarycpython/tests/test_stellar_types.py b/binarycpython/tests/test_stellar_types.py new file mode 100644 index 0000000000000000000000000000000000000000..6bc033c6a395e7cc979d4093795aa9744c0b2b68 --- /dev/null +++ b/binarycpython/tests/test_stellar_types.py @@ -0,0 +1,3 @@ +import unittest + +from binarycpython.utils.stellar_types import * \ No newline at end of file diff --git a/binarycpython/tests/test_useful_funcs.py b/binarycpython/tests/test_useful_funcs.py index 13fb1c5fc13bcfb029beb754baa68dda1e07a9d3..57b4224f97e79d3ea481ed2ae54dd75bfb7c44a7 100644 --- a/binarycpython/tests/test_useful_funcs.py +++ b/binarycpython/tests/test_useful_funcs.py @@ -1 +1,105 @@ +import unittest +import numpy as np from binarycpython.utils.useful_funcs import * + +# class test_(unittest.TestCase): +# """ +# Unittests for function +# """ + +# def test_1(self): +# pass + +class test_calc_period_from_sep(unittest.TestCase): + """ + Unittests for function calc_period_from_sep + + TODO: add tests comparing to .e.g astropy results + """ + + def test_1(self): + output_1 = calc_period_from_sep(1, 1, 1) + self.assertEqual(output_1, 0.08188845248066838) + +class test_calc_sep_from_period(unittest.TestCase): + """ + Unittests for function calc_sep_from_period + + TODO: add tests comparing to .e.g astropy results + """ + + def test_1(self): + + output_1 = calc_sep_from_period(1, 1, 1) + self.assertEqual(output_1, 5.302958446503317) + +class test_roche_lobe(unittest.TestCase): + """ + Unittests for function roche_lobe + """ + + def test_1(self): + mass_donor = 2 + mass_accretor = 1 + + output_1 = roche_lobe(mass_accretor/mass_donor) + print(output_1) + + self.assertLess(np.abs(output_1-0.3207881203346875), 1e-10) + +class test_ragb(unittest.TestCase): + """ + Unittests for function ragb + """ + + def test_1(self): + m = 20 + output = ragb(m, 0.02) + + self.assertEqual(output, 820) + +class test_rzams(unittest.TestCase): + """ + Unittests for function rzams + """ + + def test_1(self): + mass = 0.5 + metallicity = 0.02 + output_1 = rzams(mass, metallicity) + + self.assertLess(np.abs(output_1-0.458757762074762), 1e-7) + + mass = 12.5 + metallicity = 0.01241 + output_2 = rzams(mass, metallicity) + + self.assertLess(np.abs(output_2-4.20884329861741), 1e-7) + + mass = 149 + metallicity = 0.001241 + output_3 = rzams(mass, metallicity) + + self.assertLess(np.abs(output_3-12.8209978916491), 1e-7) + +class test_zams_collission(unittest.TestCase): + """ + Unittests for function zams_collission + """ + + def test_1(self): + mass1 = 1 + mass2 = 10 + sep = 10 + eccentricity = 0 + metallicity = 0.02 + + output_collision_1 = zams_collision(mass1, mass2, sep, eccentricity, metallicity) + self.assertTrue(output_collision_1==0) + + sep = 1 + output_collision_2 = zams_collision(mass1, mass2, sep, eccentricity, metallicity) + self.assertTrue(output_collision_2==1) + +if __name__ == "__main__": + unittest.main()