diff --git a/binarycpython/tests/test_distributions.py b/binarycpython/tests/test_distributions.py index c06fa2e11e2850f3b3f6b2a2ec2482fde0330c55..ac6e6ac9b2622c224f00dc8a29d0cb14e1e86eb6 100644 --- a/binarycpython/tests/test_distributions.py +++ b/binarycpython/tests/test_distributions.py @@ -30,18 +30,16 @@ class TestDistributions(unittest.TestCase): Unittest for function set_opts """ - default_dict = {'m1': 2, 'm2': 3} + default_dict = {"m1": 2, "m2": 3} output_dict_1 = set_opts(default_dict, {}) - self.assertTrue(output_dict_1==default_dict) + self.assertTrue(output_dict_1 == default_dict) - - new_opts = {'m1': 10} + new_opts = {"m1": 10} output_dict_2 = set_opts(default_dict, new_opts) updated_dict = default_dict.copy() - updated_dict['m1'] = 10 - - self.assertTrue(output_dict_2==updated_dict) + updated_dict["m1"] = 10 + self.assertTrue(output_dict_2 == updated_dict) def test_flat(self): """ @@ -69,12 +67,14 @@ class TestDistributions(unittest.TestCase): """ output_1 = const(min_bound=0, max_bound=2) - self.assertEqual(output_1, 0.5, msg="Value should be 0.5, but is {}".format(output_1)) - + self.assertEqual( + output_1, 0.5, msg="Value should be 0.5, but is {}".format(output_1) + ) output_2 = const(min_bound=0, max_bound=2, val=3) - self.assertEqual(output_2, 0, msg="Value should be 0, but is {}".format(output_2)) - + self.assertEqual( + output_2, 0, msg="Value should be 0, but is {}".format(output_2) + ) def test_powerlaw(self): """ @@ -127,8 +127,10 @@ class TestDistributions(unittest.TestCase): # Extra test: # M < M0 - self.assertTrue(three_part_powerlaw(0.05, 0.08, 0.1, 1, 300, -1.3, -2.3, -2.3)==0, msg="Probability should be zero as M < M0") - + self.assertTrue( + three_part_powerlaw(0.05, 0.08, 0.1, 1, 300, -1.3, -2.3, -2.3) == 0, + msg="Probability should be zero as M < M0", + ) def test_Kroupa2001(self): """ @@ -152,8 +154,10 @@ class TestDistributions(unittest.TestCase): for i in range(len(python_results)): self.assertLess(np.abs(python_results[i] - perl_results[i]), self.tolerance) # Extra tests: - self.assertEqual(Kroupa2001(10, newopts={'mmax': 300}), three_part_powerlaw(10, 0.1, 0.5, 1, 300, -1.3, -2.3, -2.3)) - + self.assertEqual( + Kroupa2001(10, newopts={"mmax": 300}), + three_part_powerlaw(10, 0.1, 0.5, 1, 300, -1.3, -2.3, -2.3), + ) def test_ktg93(self): """ @@ -177,8 +181,10 @@ class TestDistributions(unittest.TestCase): for i in range(len(python_results)): self.assertLess(np.abs(python_results[i] - perl_results[i]), self.tolerance) # extra test: - self.assertEqual(ktg93(10, newopts={'mmax': 300}), three_part_powerlaw(10, 0.1, 0.5, 1, 300, -1.3, -2.2, -2.7)) - + self.assertEqual( + ktg93(10, newopts={"mmax": 300}), + three_part_powerlaw(10, 0.1, 0.5, 1, 300, -1.3, -2.2, -2.7), + ) def test_imf_tinsley1980(self): """ @@ -186,7 +192,10 @@ class TestDistributions(unittest.TestCase): """ m = 1.2 - self.assertEqual(imf_tinsley1980(m), three_part_powerlaw(m, 0.1, 2.0, 10.0, 80.0, -2.0, -2.3, -3.3)) + self.assertEqual( + imf_tinsley1980(m), + three_part_powerlaw(m, 0.1, 2.0, 10.0, 80.0, -2.0, -2.3, -3.3), + ) def test_imf_scalo1986(self): """ @@ -194,8 +203,10 @@ class TestDistributions(unittest.TestCase): """ m = 1.2 - self.assertEqual(imf_scalo1986(m), three_part_powerlaw(m, 0.1, 1.0, 2.0, 80.0, -2.35, -2.35, -2.70)) - + self.assertEqual( + imf_scalo1986(m), + three_part_powerlaw(m, 0.1, 1.0, 2.0, 80.0, -2.35, -2.35, -2.70), + ) def test_imf_scalo1998(self): """ @@ -203,8 +214,10 @@ class TestDistributions(unittest.TestCase): """ m = 1.2 - self.assertEqual(imf_scalo1998(m), three_part_powerlaw(m, 0.1, 1.0, 10.0, 80.0, -1.2, -2.7, -2.3)) - + self.assertEqual( + imf_scalo1998(m), + three_part_powerlaw(m, 0.1, 1.0, 10.0, 80.0, -1.2, -2.7, -2.3), + ) def test_imf_chabrier2003(self): """ @@ -216,11 +229,19 @@ class TestDistributions(unittest.TestCase): # for m=0.5 m = 0.5 - self.assertLess(np.abs(imf_chabrier2003(m)-0.581457346702825), self.tolerance, msg="Difference is bigger than the tolerance") + self.assertLess( + np.abs(imf_chabrier2003(m) - 0.581457346702825), + self.tolerance, + msg="Difference is bigger than the tolerance", + ) # For m = 2 m = 2 - self.assertLess(np.abs(imf_chabrier2003(m)-0.581457346702825), self.tolerance, msg="Difference is bigger than the tolerance") + self.assertLess( + np.abs(imf_chabrier2003(m) - 0.581457346702825), + self.tolerance, + msg="Difference is bigger than the tolerance", + ) def test_duquennoy1991(self): """ @@ -229,7 +250,6 @@ class TestDistributions(unittest.TestCase): self.assertEqual(duquennoy1991(4.2), gaussian(4.2, 4.8, 2.3, -2, 12)) - def test_gaussian(self): """ unittest for three_part_power_law @@ -253,7 +273,10 @@ class TestDistributions(unittest.TestCase): self.assertLess(np.abs(python_results[i] - perl_results[i]), self.tolerance) # Extra test: - self.assertTrue(gaussian(15, 4.8, 2.3, -2.0, 12.0)==0, msg="Probability should be 0 because the input period is out of bounds") + self.assertTrue( + gaussian(15, 4.8, 2.3, -2.0, 12.0) == 0, + msg="Probability should be 0 because the input period is out of bounds", + ) def test_Arenou2010_binary_fraction(self): """ diff --git a/binarycpython/tests/test_functions.py b/binarycpython/tests/test_functions.py index a108a71ccb02baa2f993df9457c61d5fd615cf26..83a5dd7e1384cabde3499bfc833c3c076460c0e2 100644 --- a/binarycpython/tests/test_functions.py +++ b/binarycpython/tests/test_functions.py @@ -3,6 +3,7 @@ import tempfile from binarycpython.utils.functions import * from binarycpython.utils.custom_logging_functions import binary_c_log_code from binarycpython.utils.run_system_wrapper import run_system + binary_c_temp_dir = temp_dir() ############################# @@ -16,23 +17,26 @@ binary_c_temp_dir = temp_dir() # def test_1(self): # pass -class dummy(): + +class dummy: def __init__(self, name): self.name = name def __str__(self): return self.name + class test_verbose_print(unittest.TestCase): """ Unittests for verbose_print """ def test_print(self): - verbose_print('test1', 1, 0) + verbose_print("test1", 1, 0) def test_not_print(self): - verbose_print('test1', 0, 1) + verbose_print("test1", 0, 1) + class test_remove_file(unittest.TestCase): """ @@ -40,16 +44,19 @@ class test_remove_file(unittest.TestCase): """ def test_remove_file(self): - with open(os.path.join(binary_c_temp_dir, 'test_remove_file_file.txt'), 'w') as f: - f.write('test') + with open( + os.path.join(binary_c_temp_dir, "test_remove_file_file.txt"), "w" + ) as f: + f.write("test") - remove_file(os.path.join(binary_c_temp_dir, 'test_remove_file_file.txt')) + remove_file(os.path.join(binary_c_temp_dir, "test_remove_file_file.txt")) def test_remove_nonexisting_file(self): - file = os.path.join(binary_c_temp_dir, 'test_remove_nonexistingfile_file.txt') + file = os.path.join(binary_c_temp_dir, "test_remove_nonexistingfile_file.txt") remove_file(file) + class test_temp_dir(unittest.TestCase): """ Unittests for temp_dir @@ -59,8 +66,13 @@ class test_temp_dir(unittest.TestCase): binary_c_temp_dir = temp_dir() general_temp_dir = tempfile.gettempdir() - self.assertTrue(os.path.isdir(os.path.join(general_temp_dir, 'binary_c_python'))) - self.assertTrue(os.path.join(general_temp_dir, 'binary_c_python'))==binary_c_temp_dir + self.assertTrue( + os.path.isdir(os.path.join(general_temp_dir, "binary_c_python")) + ) + self.assertTrue( + os.path.join(general_temp_dir, "binary_c_python") + ) == binary_c_temp_dir + class test_create_hdf5(unittest.TestCase): """ @@ -68,27 +80,28 @@ class test_create_hdf5(unittest.TestCase): """ def test_1(self): - testdir = os.path.join(binary_c_temp_dir, 'test_create_hdf5') + testdir = os.path.join(binary_c_temp_dir, "test_create_hdf5") os.makedirs(testdir, exist_ok=True) # Create dummy settings file: - settings_dict = {'settings_1': 1, 'settings_2': [1,2]} + settings_dict = {"settings_1": 1, "settings_2": [1, 2]} - with open(os.path.join(testdir, 'example_settings.json'), 'w') as f: + with open(os.path.join(testdir, "example_settings.json"), "w") as f: f.write(json.dumps(settings_dict)) - with open(os.path.join(testdir, 'data1.dat'), 'w') as f: + with open(os.path.join(testdir, "data1.dat"), "w") as f: f.write("time mass\n") f.write("1 10") - create_hdf5(testdir, 'testhdf5.hdf5') - file = h5py.File(os.path.join(testdir, 'testhdf5.hdf5'), 'r') + create_hdf5(testdir, "testhdf5.hdf5") + file = h5py.File(os.path.join(testdir, "testhdf5.hdf5"), "r") + + self.assertIn(b"time", file.get("data/data1_header")[()]) + self.assertIn(b"mass", file.get("data/data1_header")[()]) - self.assertIn(b'time', file.get('data/data1_header')[()]) - self.assertIn(b'mass', file.get('data/data1_header')[()]) + self.assertIn("settings_1", json.loads(file.get("settings/used_settings")[()])) + self.assertIn("settings_2", json.loads(file.get("settings/used_settings")[()])) - self.assertIn('settings_1', json.loads(file.get('settings/used_settings')[()])) - self.assertIn('settings_2', json.loads(file.get('settings/used_settings')[()])) class test_return_binary_c_version_info(unittest.TestCase): """ @@ -99,23 +112,24 @@ class test_return_binary_c_version_info(unittest.TestCase): version_info = return_binary_c_version_info() self.assertTrue(isinstance(version_info, str)) - self.assertIn('Build', version_info) - self.assertIn('REIMERS_ETA_DEFAULT', version_info) - self.assertIn('SIGMA_THOMPSON', version_info) + self.assertIn("Build", version_info) + self.assertIn("REIMERS_ETA_DEFAULT", version_info) + self.assertIn("SIGMA_THOMPSON", version_info) def test_parsed(self): # also tests the parse_version_info indirectly version_info_parsed = return_binary_c_version_info(parsed=True) self.assertTrue(isinstance(version_info_parsed, dict)) - self.assertIn('isotopes', version_info_parsed.keys()) - self.assertIn('argpairs', version_info_parsed.keys()) - self.assertIn('ensembles', version_info_parsed.keys()) - self.assertIn('macros', version_info_parsed.keys()) - self.assertIn('elements', version_info_parsed.keys()) - self.assertIn('dt_limits', version_info_parsed.keys()) - self.assertIn('nucleosynthesis_sources', version_info_parsed.keys()) - self.assertIn('miscellaneous', version_info_parsed.keys()) + self.assertIn("isotopes", version_info_parsed.keys()) + self.assertIn("argpairs", version_info_parsed.keys()) + self.assertIn("ensembles", version_info_parsed.keys()) + self.assertIn("macros", version_info_parsed.keys()) + self.assertIn("elements", version_info_parsed.keys()) + self.assertIn("dt_limits", version_info_parsed.keys()) + self.assertIn("nucleosynthesis_sources", version_info_parsed.keys()) + self.assertIn("miscellaneous", version_info_parsed.keys()) + class test_parse_binary_c_version_info(unittest.TestCase): """ @@ -126,22 +140,23 @@ class test_parse_binary_c_version_info(unittest.TestCase): info = return_binary_c_version_info() parsed_info = parse_binary_c_version_info(info) - self.assertIn('isotopes', parsed_info.keys()) - self.assertNotEqual(parsed_info['isotopes'], {}) - self.assertIn('argpairs', parsed_info.keys()) - self.assertNotEqual(parsed_info['argpairs'], {}) - self.assertIn('ensembles', parsed_info.keys()) - self.assertNotEqual(parsed_info['ensembles'], {}) - self.assertIn('macros', parsed_info.keys()) - self.assertNotEqual(parsed_info['macros'], {}) - self.assertIn('elements', parsed_info.keys()) - self.assertNotEqual(parsed_info['elements'], {}) - self.assertIn('dt_limits', parsed_info.keys()) - self.assertNotEqual(parsed_info['dt_limits'], {}) - self.assertIn('nucleosynthesis_sources', parsed_info.keys()) - self.assertNotEqual(parsed_info['nucleosynthesis_sources'], {}) - self.assertIn('miscellaneous', parsed_info.keys()) - self.assertNotEqual(parsed_info['miscellaneous'], {}) + self.assertIn("isotopes", parsed_info.keys()) + self.assertNotEqual(parsed_info["isotopes"], {}) + self.assertIn("argpairs", parsed_info.keys()) + self.assertNotEqual(parsed_info["argpairs"], {}) + self.assertIn("ensembles", parsed_info.keys()) + self.assertNotEqual(parsed_info["ensembles"], {}) + self.assertIn("macros", parsed_info.keys()) + self.assertNotEqual(parsed_info["macros"], {}) + self.assertIn("elements", parsed_info.keys()) + self.assertNotEqual(parsed_info["elements"], {}) + self.assertIn("dt_limits", parsed_info.keys()) + self.assertNotEqual(parsed_info["dt_limits"], {}) + self.assertIn("nucleosynthesis_sources", parsed_info.keys()) + self.assertNotEqual(parsed_info["nucleosynthesis_sources"], {}) + self.assertIn("miscellaneous", parsed_info.keys()) + self.assertNotEqual(parsed_info["miscellaneous"], {}) + class test_output_lines(unittest.TestCase): """ @@ -153,9 +168,10 @@ class test_output_lines(unittest.TestCase): output_1 = output_lines(example_text) self.assertTrue(isinstance(output_1, list)) - self.assertIn('hallo', output_1) - self.assertIn('test', output_1) - self.assertIn('123', output_1) + self.assertIn("hallo", output_1) + self.assertIn("test", output_1) + self.assertIn("123", output_1) + class test_example_parse_output(unittest.TestCase): """ @@ -182,10 +198,10 @@ class test_example_parse_output(unittest.TestCase): parsed_output = example_parse_output(output, "MY_STELLAR_DATA") - self.assertIn('time', parsed_output) - self.assertIn('mass', parsed_output) - self.assertTrue(isinstance(parsed_output['time'], list)) - self.assertTrue(len(parsed_output['time'])>0) + self.assertIn("time", parsed_output) + self.assertIn("mass", parsed_output) + self.assertTrue(isinstance(parsed_output["time"], list)) + self.assertTrue(len(parsed_output["time"]) > 0) def test_mismatch_output(self): # generate logging lines. Here you can choose whatever you want to have logged, and with what header @@ -208,6 +224,7 @@ class test_example_parse_output(unittest.TestCase): parsed_output = example_parse_output(output, "MY_STELLAR_DATA_MISMATCH") self.assertIsNone(parsed_output) + class test_get_defaults(unittest.TestCase): """ Unittests for function get_defaults @@ -217,20 +234,21 @@ class test_get_defaults(unittest.TestCase): output_1 = get_defaults() self.assertTrue(isinstance(output_1, dict)) - self.assertIn('colour_log', output_1.keys()) - self.assertIn('M_1', output_1.keys()) - self.assertIn('list_args', output_1.keys()) - self.assertIn('use_fixed_timestep_%d', output_1.keys()) + self.assertIn("colour_log", output_1.keys()) + self.assertIn("M_1", output_1.keys()) + self.assertIn("list_args", output_1.keys()) + self.assertIn("use_fixed_timestep_%d", output_1.keys()) def test_filter(self): # Also tests the filter_arg_dict indirectly output_1 = get_defaults(filter_values=True) self.assertTrue(isinstance(output_1, dict)) - self.assertIn('colour_log', output_1.keys()) - self.assertIn('M_1', output_1.keys()) - self.assertNotIn('list_args', output_1.keys()) - self.assertNotIn('use_fixed_timestep_%d', output_1.keys()) + self.assertIn("colour_log", output_1.keys()) + self.assertIn("M_1", output_1.keys()) + self.assertNotIn("list_args", output_1.keys()) + self.assertNotIn("use_fixed_timestep_%d", output_1.keys()) + class test_get_arg_keys(unittest.TestCase): """ @@ -241,10 +259,11 @@ class test_get_arg_keys(unittest.TestCase): output_1 = get_arg_keys() self.assertTrue(isinstance(output_1, list)) - self.assertIn('colour_log', output_1) - self.assertIn('M_1', output_1) - self.assertIn('list_args', output_1) - self.assertIn('use_fixed_timestep_%d', output_1) + self.assertIn("colour_log", output_1) + self.assertIn("M_1", output_1) + self.assertIn("list_args", output_1) + self.assertIn("use_fixed_timestep_%d", output_1) + class test_create_arg_string(unittest.TestCase): """ @@ -252,19 +271,20 @@ class test_create_arg_string(unittest.TestCase): """ def test_default(self): - input_dict = {'separation': 40000, 'M_1': 10} + input_dict = {"separation": 40000, "M_1": 10} argstring = create_arg_string(input_dict) - self.assertEqual(argstring, 'separation 40000 M_1 10') + self.assertEqual(argstring, "separation 40000 M_1 10") def test_sort(self): - input_dict = {'M_1': 10, 'separation': 40000} + input_dict = {"M_1": 10, "separation": 40000} argstring = create_arg_string(input_dict, sort=True) - self.assertEqual(argstring, 'M_1 10 separation 40000') + self.assertEqual(argstring, "M_1 10 separation 40000") def test_sort(self): - input_dict = {'M_1': 10, 'separation': 40000, 'list_args': "NULL"} + input_dict = {"M_1": 10, "separation": 40000, "list_args": "NULL"} argstring = create_arg_string(input_dict, filter_values=True) - self.assertEqual(argstring, 'M_1 10 separation 40000') + self.assertEqual(argstring, "M_1 10 separation 40000") + class test_get_help(unittest.TestCase): """ @@ -293,6 +313,7 @@ class test_get_help(unittest.TestCase): # def test_print(self): # output = get_help("M_1", print_help=True) + class test_get_help_all(unittest.TestCase): """ Unit test for get_help_all @@ -314,11 +335,11 @@ class test_get_help_all(unittest.TestCase): self.assertIn("algorithms", get_help_all_keys, "missing section") self.assertIn("misc", get_help_all_keys, "missing section") - # def test_print(self): # # test if stuff is printed # get_help_all(print_help=True) + class test_get_help_super(unittest.TestCase): """ Unit test for get_help_super @@ -341,9 +362,10 @@ class test_get_help_super(unittest.TestCase): self.assertIn("misc", get_help_super_keys, "missing section") # def test_print(self): - # # test to see if stuff is printed. + # # test to see if stuff is printed. # get_help_super(print_help=True) + class test_make_build_text(unittest.TestCase): """ Unittests for function @@ -365,100 +387,136 @@ class test_make_build_text(unittest.TestCase): self.assertNotEqual(split_text[2].strip(), "second") self.assertNotEqual(split_text[3].strip(), "second") + class test_write_binary_c_parameter_descriptions_to_rst_file(unittest.TestCase): """ Unittests for function write_binary_c_parameter_descriptions_to_rst_file """ def test_bad_outputname(self): - output_name = os.path.join(binary_c_temp_dir, 'test_write_binary_c_parameter_descriptions_to_rst_file_test_1.txt') + output_name = os.path.join( + binary_c_temp_dir, + "test_write_binary_c_parameter_descriptions_to_rst_file_test_1.txt", + ) output_1 = write_binary_c_parameter_descriptions_to_rst_file(output_name) self.assertIsNone(output_1) def test_checkfile(self): - output_name = os.path.join(binary_c_temp_dir, 'test_write_binary_c_parameter_descriptions_to_rst_file_test_1.rst') + output_name = os.path.join( + binary_c_temp_dir, + "test_write_binary_c_parameter_descriptions_to_rst_file_test_1.rst", + ) output_1 = write_binary_c_parameter_descriptions_to_rst_file(output_name) self.assertTrue(os.path.isfile(output_name)) + class test_inspect_dict(unittest.TestCase): """ Unittests for function """ def test_compare_dict(self): - input_dict = {'int': 1, 'float': 1.2, 'list': [1,2,3], 'function': os.path.isfile, 'dict': {'int': 1, 'float': 1.2}} + input_dict = { + "int": 1, + "float": 1.2, + "list": [1, 2, 3], + "function": os.path.isfile, + "dict": {"int": 1, "float": 1.2}, + } output_dict = inspect_dict(input_dict) - compare_dict = {'int': int, 'float': float, 'list': list, 'function': os.path.isfile.__class__, 'dict': {'int': int, 'float': float}} - self.assertTrue(compare_dict==output_dict) + compare_dict = { + "int": int, + "float": float, + "list": list, + "function": os.path.isfile.__class__, + "dict": {"int": int, "float": float}, + } + self.assertTrue(compare_dict == output_dict) def test_compare_dict(self): - input_dict = {'int': 1, 'float': 1.2, 'list': [1,2,3], 'function': os.path.isfile, 'dict': {'int': 1, 'float': 1.2}} + input_dict = { + "int": 1, + "float": 1.2, + "list": [1, 2, 3], + "function": os.path.isfile, + "dict": {"int": 1, "float": 1.2}, + } output_dict = inspect_dict(input_dict, print_structure=True) + class test_merge_dicts(unittest.TestCase): """ Unittests for function """ def test_empty(self): - input_dict = {'int': 1, 'float': 1.2, 'list': [1,2,3], 'function': os.path.isfile, 'dict': {'int': 1, 'float': 1.2}} + input_dict = { + "int": 1, + "float": 1.2, + "list": [1, 2, 3], + "function": os.path.isfile, + "dict": {"int": 1, "float": 1.2}, + } dict_2 = {} output_dict = merge_dicts(input_dict, dict_2) - self.assertTrue(output_dict==input_dict) + self.assertTrue(output_dict == input_dict) def test_unequal_types(self): dict_1 = {"input": 10} - dict_2 = {"input": 'hello'} + dict_2 = {"input": "hello"} self.assertRaises(ValueError, merge_dicts, dict_1, dict_2) def test_bools(self): - dict_1 = {'bool': True} - dict_2 = {'bool': False} + dict_1 = {"bool": True} + dict_2 = {"bool": False} output_dict = merge_dicts(dict_1, dict_2) - self.assertTrue(isinstance(output_dict['bool'], bool)) - self.assertTrue(output_dict['bool']) + self.assertTrue(isinstance(output_dict["bool"], bool)) + self.assertTrue(output_dict["bool"]) def test_ints(self): - dict_1 = {'int': 2} - dict_2 = {'int': 1} + dict_1 = {"int": 2} + dict_2 = {"int": 1} output_dict = merge_dicts(dict_1, dict_2) - self.assertTrue(isinstance(output_dict['int'], int)) - self.assertEqual(output_dict['int'], 3) + self.assertTrue(isinstance(output_dict["int"], int)) + self.assertEqual(output_dict["int"], 3) def test_floats(self): - dict_1 = {'float': 4.5} - dict_2 = {'float': 4.6} + dict_1 = {"float": 4.5} + dict_2 = {"float": 4.6} output_dict = merge_dicts(dict_1, dict_2) - self.assertTrue(isinstance(output_dict['float'], float)) - self.assertEqual(output_dict['float'], 9.1) + self.assertTrue(isinstance(output_dict["float"], float)) + self.assertEqual(output_dict["float"], 9.1) def test_lists(self): - dict_1 = {'list': [1,2]} - dict_2 = {'list': [3,4]} + dict_1 = {"list": [1, 2]} + dict_2 = {"list": [3, 4]} output_dict = merge_dicts(dict_1, dict_2) - self.assertTrue(isinstance(output_dict['list'], list)) - self.assertEqual(output_dict['list'], [1,2,3,4]) + self.assertTrue(isinstance(output_dict["list"], list)) + self.assertEqual(output_dict["list"], [1, 2, 3, 4]) def test_dicts(self): - dict_1 = {'dict': {'same': 1, 'other_1': 2.0}} - dict_2 = {'dict': {'same': 2, 'other_2': [4.0]}} + dict_1 = {"dict": {"same": 1, "other_1": 2.0}} + dict_2 = {"dict": {"same": 2, "other_2": [4.0]}} output_dict = merge_dicts(dict_1, dict_2) - self.assertTrue(isinstance(output_dict['dict'], dict)) - self.assertEqual(output_dict['dict'], {'same': 3, 'other_1': 2.0, 'other_2': [4.0]}) + self.assertTrue(isinstance(output_dict["dict"], dict)) + self.assertEqual( + output_dict["dict"], {"same": 3, "other_1": 2.0, "other_2": [4.0]} + ) def test_unsupported(self): - dict_1 = {'new': dummy('david')} - dict_2 = {'new': dummy('gio')} + dict_1 = {"new": dummy("david")} + dict_2 = {"new": dummy("gio")} # output_dict = merge_dicts(dict_1, dict_2) self.assertRaises(ValueError, merge_dicts, dict_1, dict_2) + class test_binaryc_json_serializer(unittest.TestCase): """ Unittests for function binaryc_json_serializer @@ -467,12 +525,13 @@ class test_binaryc_json_serializer(unittest.TestCase): def test_not_function(self): stringo = "hello" output = binaryc_json_serializer(stringo) - self.assertTrue(stringo==output) + self.assertTrue(stringo == output) def test_function(self): string_of_function = str(os.path.isfile) output = binaryc_json_serializer(os.path.isfile) - self.assertTrue(string_of_function==output) + self.assertTrue(string_of_function == output) + class test_handle_ensemble_string_to_json(unittest.TestCase): """ @@ -485,8 +544,9 @@ class test_handle_ensemble_string_to_json(unittest.TestCase): output_dict = handle_ensemble_string_to_json(input_string) self.assertTrue(isinstance(output_dict, dict)) - self.assertTrue(output_dict['ding']==10) - self.assertTrue(output_dict['list_example']==[1,2,3]) + self.assertTrue(output_dict["ding"] == 10) + self.assertTrue(output_dict["list_example"] == [1, 2, 3]) + if __name__ == "__main__": unittest.main() diff --git a/binarycpython/tests/test_grid.py b/binarycpython/tests/test_grid.py index 2a397fc231688045bafc29bbaf947b5cf80dac33..ff44027ae77aebe60f9bb38a656fceb4e3103171 100644 --- a/binarycpython/tests/test_grid.py +++ b/binarycpython/tests/test_grid.py @@ -1,8 +1,36 @@ +""" +Test cases for the grid + +Tasks: + TODO: write tests for load_from_sourcefile +""" + +import os import sys +import json import unittest import tempfile +import datetime from binarycpython.utils.grid import Population +from binarycpython.utils.functions import temp_dir + +binary_c_temp_dir = temp_dir() + + +# class test_(unittest.TestCase): +# """ +# Unittests for function +# """ + +# def test_1(self): +# pass + +# def test_(self): +# """ +# Unittests for the function +# """ + class test_Population(unittest.TestCase): """ @@ -12,44 +40,315 @@ class test_Population(unittest.TestCase): def test_setup(self): test_pop = Population() - self.assertTrue('orbital_period' in test_pop.defaults) - self.assertTrue('metallicity' in test_pop.defaults) - self.assertNotIn('help_all', test_pop.cleaned_up_defaults) + self.assertTrue("orbital_period" in test_pop.defaults) + self.assertTrue("metallicity" in test_pop.defaults) + self.assertNotIn("help_all", test_pop.cleaned_up_defaults) self.assertEqual(test_pop.bse_options, {}) self.assertEqual(test_pop.custom_options, {}) self.assertEqual(test_pop.argline_dict, {}) self.assertEqual(test_pop.persistent_data_memory_dict, {}) - self.assertTrue(test_pop.grid_options['parse_function']==None) - self.assertTrue(isinstance(test_pop.grid_options['_main_pid'], int)) + self.assertTrue(test_pop.grid_options["parse_function"] == None) + self.assertTrue(isinstance(test_pop.grid_options["_main_pid"], int)) def test_set(self): test_pop = Population() test_pop.set(amt_cores=2) test_pop.set(M_1=10) - test_pop.set(data_dir='/tmp/binary_c_python') + test_pop.set(data_dir="/tmp/binary_c_python") test_pop.set(ensemble_filter_SUPERNOVAE=1) - self.assertIn('data_dir', test_pop.custom_options) - self.assertEqual(test_pop.custom_options['data_dir'], '/tmp/binary_c_python') - - # - self.assertTrue(test_pop.bse_options['M_1']==10) - self.assertTrue(test_pop.bse_options['ensemble_filter_SUPERNOVAE']==1) + self.assertIn("data_dir", test_pop.custom_options) + self.assertEqual(test_pop.custom_options["data_dir"], "/tmp/binary_c_python") # - self.assertTrue(test_pop.grid_options['amt_cores']==2) + self.assertTrue(test_pop.bse_options["M_1"] == 10) + self.assertTrue(test_pop.bse_options["ensemble_filter_SUPERNOVAE"] == 1) + # + self.assertTrue(test_pop.grid_options["amt_cores"] == 2) def test_cmdline(self): - cmdline_arg = '--cmdline \"metallicity=0.0002\"' - sys.argv = ['script', '--cmdline', "metallicity=0.0002"] + # copy old sys.argv values + prev_sysargv = sys.argv.copy() + + # make a dummy cmdline arg input + sys.argv = [ + "script", + "--cmdline", + "metallicity=0.0002 amt_cores=2 data_dir=/tmp/binary_c_python", + ] + + # Set up population test_pop = Population() + test_pop.set(data_dir="/tmp") + + # parse arguments test_pop.parse_cmdline() - print(test_pop.bse_options) - self.assertTrue(test_pop.bse_options['metallicity']==0.0002) + # metallicity + self.assertTrue(isinstance(test_pop.bse_options["metallicity"], str)) + self.assertTrue(test_pop.bse_options["metallicity"] == "0.0002") + + # Amt cores + self.assertTrue(isinstance(test_pop.grid_options["amt_cores"], int)) + self.assertTrue(test_pop.grid_options["amt_cores"] == 2) + + # datadir + self.assertTrue(isinstance(test_pop.custom_options["data_dir"], str)) + self.assertTrue(test_pop.custom_options["data_dir"] == "/tmp/binary_c_python") + + # put back the other args if they exist + sys.argv = prev_sysargv.copy() + + def test__return_argline(self): + """ + Unittests for the function _return_argline + """ + + # Set up population + test_pop = Population() + test_pop.set(metallicity=0.02) + test_pop.set(M_1=10) + + argline = test_pop._return_argline() + self.assertTrue(argline == "binary_c M_1 10 metallicity 0.02") + + # custom dict + argline2 = test_pop._return_argline( + {"example_parameter1": 10, "example_parameter2": "hello"} + ) + self.assertTrue( + argline2 == "binary_c example_parameter1 10 example_parameter2 hello" + ) + + def test_add_grid_variable(self): + """ + Unittests for the function add_grid_variable + + TODO: Should I test more here? + """ + + test_pop = Population() + + resolution = {"M_1": 10, "q": 10} + + test_pop.add_grid_variable( + name="lnm1", + longname="Primary mass", + valuerange=[1, 100], + resolution="{}".format(resolution["M_1"]), + spacingfunc="const(math.log(1), math.log(100), {})".format( + resolution["M_1"] + ), + precode="M_1=math.exp(lnm1)", + probdist="three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1", + dphasevol="dlnm1", + parameter_name="M_1", + condition="", # Impose a condition on this grid variable. Mostly for a check for yourself + ) + + test_pop.add_grid_variable( + name="q", + longname="Mass ratio", + valuerange=["0.1/M_1", 1], + resolution="{}".format(resolution["q"]), + spacingfunc="const(0.1/M_1, 1, {})".format(resolution["q"]), + probdist="flatsections(q, [{'min': 0.1/M_1, 'max': 1.0, 'height': 1}])", + dphasevol="dq", + precode="M_2 = q * M_1", + parameter_name="M_2", + condition="", # Impose a condition on this grid variable. Mostly for a check for yourself + ) + + self.assertIn("q", test_pop.grid_options["_grid_variables"]) + self.assertIn("lnm1", test_pop.grid_options["_grid_variables"]) + self.assertEqual(len(test_pop.grid_options["_grid_variables"]), 2) + + def test_return_population_settings(self): + """ + Unittests for the function return_population_settings + """ + + test_pop = Population() + test_pop.set(metallicity=0.02) + test_pop.set(M_1=10) + test_pop.set(amt_cores=2) + test_pop.set(data_dir="/tmp") + + population_settings = test_pop.return_population_settings() + + self.assertIn("bse_options", population_settings) + self.assertTrue(population_settings["bse_options"]["metallicity"] == 0.02) + self.assertTrue(population_settings["bse_options"]["M_1"] == 10) + + self.assertIn("grid_options", population_settings) + self.assertTrue(population_settings["grid_options"]["amt_cores"] == 2) + + self.assertIn("custom_options", population_settings) + self.assertTrue(population_settings["custom_options"]["data_dir"] == "/tmp") + + def test__return_binary_c_version_info(self): + """ + Unittests for the function _return_binary_c_version_info + """ + + test_pop = Population() + binary_c_version_info = test_pop._return_binary_c_version_info(parsed=True) + + self.assertTrue(isinstance(binary_c_version_info, dict)) + self.assertIn("isotopes", binary_c_version_info) + self.assertIn("argpairs", binary_c_version_info) + self.assertIn("ensembles", binary_c_version_info) + self.assertIn("macros", binary_c_version_info) + self.assertIn("dt_limits", binary_c_version_info) + self.assertIn("nucleosynthesis_sources", binary_c_version_info) + self.assertIn("miscellaneous", binary_c_version_info) + + self.assertIsNotNone(binary_c_version_info["isotopes"]) + self.assertIsNotNone(binary_c_version_info["argpairs"]) + self.assertIsNotNone(binary_c_version_info["ensembles"]) + self.assertIsNotNone(binary_c_version_info["macros"]) + self.assertIsNotNone(binary_c_version_info["dt_limits"]) + self.assertIsNotNone(binary_c_version_info["nucleosynthesis_sources"]) + self.assertIsNotNone(binary_c_version_info["miscellaneous"]) + + def test__return_binary_c_defaults(self): + """ + Unittests for the function _return_binary_c_defaults + """ + + test_pop = Population() + binary_c_defaults = test_pop._return_binary_c_defaults() + self.assertIn("probability", binary_c_defaults) + self.assertIn("phasevol", binary_c_defaults) + self.assertIn("metallicity", binary_c_defaults) + + def test_return_all_info(self): + """ + Unittests for the function return_all_info + Not going to do too much tests here, just check if they are not empty + """ + + test_pop = Population() + all_info = test_pop.return_all_info() + + self.assertIn("population_settings", all_info) + self.assertIn("binary_c_defaults", all_info) + self.assertIn("binary_c_version_info", all_info) + self.assertIn("binary_c_help_all", all_info) + + self.assertNotEqual(all_info["population_settings"], {}) + self.assertNotEqual(all_info["binary_c_defaults"], {}) + self.assertNotEqual(all_info["binary_c_version_info"], {}) + self.assertNotEqual(all_info["binary_c_help_all"], {}) + + def test_export_all_info(self): + """ + Unittests for the function export_all_info + """ + + test_pop = Population() + + test_pop.set(metallicity=0.02) + test_pop.set(M_1=10) + test_pop.set(amt_cores=2) + test_pop.set(data_dir=binary_c_temp_dir) + + # datadir + settings_filename = test_pop.export_all_info(use_datadir=True) + self.assertTrue(os.path.isfile(settings_filename)) + with open(settings_filename, "r") as f: + all_info = json.loads(f.read()) + + # + self.assertIn("population_settings", all_info) + self.assertIn("binary_c_defaults", all_info) + self.assertIn("binary_c_version_info", all_info) + self.assertIn("binary_c_help_all", all_info) + + # + self.assertNotEqual(all_info["population_settings"], {}) + self.assertNotEqual(all_info["binary_c_defaults"], {}) + self.assertNotEqual(all_info["binary_c_version_info"], {}) + self.assertNotEqual(all_info["binary_c_help_all"], {}) + + # custom name + # datadir + settings_filename = test_pop.export_all_info( + use_datadir=False, + outfile=os.path.join(binary_c_temp_dir, "example_settings.json"), + ) + self.assertTrue(os.path.isfile(settings_filename)) + with open(settings_filename, "r") as f: + all_info = json.loads(f.read()) + + # + self.assertIn("population_settings", all_info) + self.assertIn("binary_c_defaults", all_info) + self.assertIn("binary_c_version_info", all_info) + self.assertIn("binary_c_help_all", all_info) + + # + self.assertNotEqual(all_info["population_settings"], {}) + self.assertNotEqual(all_info["binary_c_defaults"], {}) + self.assertNotEqual(all_info["binary_c_version_info"], {}) + self.assertNotEqual(all_info["binary_c_help_all"], {}) + + # wrong filename + self.assertRaises( + ValueError, + test_pop.export_all_info, + use_datadir=False, + outfile=os.path.join(binary_c_temp_dir, "example_settings.txt"), + ) + + def test__cleanup_defaults(self): + """ + Unittests for the function _cleanup_defaults + """ + + test_pop = Population() + cleaned_up_defaults = test_pop._cleanup_defaults() + self.assertNotIn("help_all", cleaned_up_defaults) + + def test__increment_probtot(self): + """ + Unittests for the function _increment_probtot + """ + + test_pop = Population() + test_pop._increment_probtot(0.5) + self.assertEqual(test_pop.grid_options["_probtot"], 0.5) + + def test__increment_count(self): + """ + Unittests for the function _increment_probtot + """ + + test_pop = Population() + test_pop._increment_count() + self.assertEqual(test_pop.grid_options["_count"], 1) + + def test__dict_from_line_source_file(self): + """ + Unittests for the function _dict_from_line_source_file + """ + + source_file = os.path.join(binary_c_temp_dir, "example_source_file.txt") + + # write + with open(source_file, "w") as f: + f.write("binary_c M_1 10 metallicity 0.02\n") + + test_pop = Population() + # readout + with open(source_file, "r") as f: + for line in f.readlines(): + argdict = test_pop._dict_from_line_source_file(line) + self.assertTrue(argdict["M_1"] == 10) + self.assertTrue(argdict["metallicity"] == 0.02) if __name__ == "__main__": diff --git a/binarycpython/tests/test_grid_options_defaults.py b/binarycpython/tests/test_grid_options_defaults.py index 9692f0da43f9e27e54ef56ad434fc7280662536b..53723636c32132b07892a733fbdbd2ebb3d40de1 100644 --- a/binarycpython/tests/test_grid_options_defaults.py +++ b/binarycpython/tests/test_grid_options_defaults.py @@ -47,27 +47,27 @@ class test_grid_options_defaults(unittest.TestCase): output_1 = grid_options_description_checker(print_info=True) self.assertTrue(isinstance(output_1, int)) - self.assertTrue(output_1>0) - - + self.assertTrue(output_1 > 0) def test_write_grid_options_to_rst_file(self): """ Unit tests for the grid_options_description_checker function """ - input_1 = os.path.join(binary_c_temp_dir, "test_write_grid_options_to_rst_file_1.txt") + input_1 = os.path.join( + binary_c_temp_dir, "test_write_grid_options_to_rst_file_1.txt" + ) output_1 = write_grid_options_to_rst_file(input_1) self.assertIsNone(output_1) - - input_2 = os.path.join(binary_c_temp_dir, "test_write_grid_options_to_rst_file_2.rst") + input_2 = os.path.join( + binary_c_temp_dir, "test_write_grid_options_to_rst_file_2.rst" + ) output_2 = write_grid_options_to_rst_file(input_2) self.assertTrue(os.path.isfile(input_2)) - write_grid_options_to_rst_file if __name__ == "__main__": unittest.main() diff --git a/binarycpython/tests/test_plot_functions.py b/binarycpython/tests/test_plot_functions.py index 41955cfe334ac2ed3baa0c74222de59921ca313e..4b01c1b448f1a6819c5fd25a3850fbed94088017 100644 --- a/binarycpython/tests/test_plot_functions.py +++ b/binarycpython/tests/test_plot_functions.py @@ -11,16 +11,17 @@ import matplotlib.pyplot as plt # def test_1(self): # pass + class test_color_by_index(unittest.TestCase): """ Unittests for function color_by_index """ def test_1(self): - colors = ['red', 'white', 'blue'] + colors = ["red", "white", "blue"] - color = color_by_index([1,2,3], 1, colors) - self.assertTrue(color=='blue') + color = color_by_index([1, 2, 3], 1, colors) + self.assertTrue(color == "blue") class test_plot_system(unittest.TestCase): @@ -29,12 +30,20 @@ class test_plot_system(unittest.TestCase): """ def test_mass_evolution_plot(self): - plot_type = 'mass_evolution' + plot_type = "mass_evolution" show_plot = False - output_fig_1 = plot_system(plot_type, show_plot=show_plot, M_1=1, metallicity=0.002, M_2=0.1, separation=0, orbital_period=100000000000) + output_fig_1 = plot_system( + plot_type, + show_plot=show_plot, + M_1=1, + metallicity=0.002, + M_2=0.1, + separation=0, + orbital_period=100000000000, + ) fig, ax = plt.subplots(nrows=1) - self.assertTrue(type(output_fig_1)==fig.__class__) + self.assertTrue(type(output_fig_1) == fig.__class__) # with stellar types # plot_type = 'mass_evolution' @@ -50,12 +59,20 @@ class test_plot_system(unittest.TestCase): # output_fig_2 = plot_system(plot_type, show_plot=show_plot, M_1=1, metallicity=0.002, M_2=0.1, separation=0, orbital_period=100000000000) def test_orbit_evolution_plot(self): - plot_type = 'orbit_evolution' + plot_type = "orbit_evolution" show_plot = False - output_fig_1 = plot_system(plot_type, show_plot=show_plot, M_1=1, metallicity=0.002, M_2=0.1, separation=0, orbital_period=100000000000) + output_fig_1 = plot_system( + plot_type, + show_plot=show_plot, + M_1=1, + metallicity=0.002, + M_2=0.1, + separation=0, + orbital_period=100000000000, + ) fig, ax = plt.subplots(nrows=1) - self.assertTrue(type(output_fig_1)==fig.__class__) + self.assertTrue(type(output_fig_1) == fig.__class__) # with stellar types # plot_type = 'orbit_evolution' @@ -71,12 +88,20 @@ class test_plot_system(unittest.TestCase): # output_fig_2 = plot_system(plot_type, show_plot=show_plot, M_1=1, metallicity=0.002, M_2=0.1, separation=0, orbital_period=100000000000) def test_hr_diagram_plot(self): - plot_type = 'hr_diagram' + plot_type = "hr_diagram" show_plot = False - output_fig_1 = plot_system(plot_type, show_plot=show_plot, M_1=1, metallicity=0.002, M_2=0.1, separation=0, orbital_period=100000000000) + output_fig_1 = plot_system( + plot_type, + show_plot=show_plot, + M_1=1, + metallicity=0.002, + M_2=0.1, + separation=0, + orbital_period=100000000000, + ) fig, ax = plt.subplots(nrows=1) - self.assertTrue(type(output_fig_1)==fig.__class__) + self.assertTrue(type(output_fig_1) == fig.__class__) # with stellar types # plot_type = 'hr_diagram' @@ -92,8 +117,9 @@ class test_plot_system(unittest.TestCase): # output_fig_2 = plot_system(plot_type, show_plot=show_plot, M_1=1, metallicity=0.002, M_2=0.1, separation=0, orbital_period=100000000000) def test_unknown_plottype(self): - plot_type = 'random' + plot_type = "random" self.assertRaises(ValueError, plot_system, plot_type) + if __name__ == "__main__": unittest.main() diff --git a/binarycpython/tests/test_stellar_types.py b/binarycpython/tests/test_stellar_types.py index 6bc033c6a395e7cc979d4093795aa9744c0b2b68..0b86a5a70e10e463e443f278bd5e938328ec8aca 100644 --- a/binarycpython/tests/test_stellar_types.py +++ b/binarycpython/tests/test_stellar_types.py @@ -1,3 +1,3 @@ import unittest -from binarycpython.utils.stellar_types import * \ No newline at end of file +from binarycpython.utils.stellar_types import * diff --git a/binarycpython/tests/test_useful_funcs.py b/binarycpython/tests/test_useful_funcs.py index 57b4224f97e79d3ea481ed2ae54dd75bfb7c44a7..b861f595ce8c05122996b186f9453dc0efbfc9df 100644 --- a/binarycpython/tests/test_useful_funcs.py +++ b/binarycpython/tests/test_useful_funcs.py @@ -10,6 +10,7 @@ from binarycpython.utils.useful_funcs import * # def test_1(self): # pass + class test_calc_period_from_sep(unittest.TestCase): """ Unittests for function calc_period_from_sep @@ -21,6 +22,7 @@ class test_calc_period_from_sep(unittest.TestCase): output_1 = calc_period_from_sep(1, 1, 1) self.assertEqual(output_1, 0.08188845248066838) + class test_calc_sep_from_period(unittest.TestCase): """ Unittests for function calc_sep_from_period @@ -33,6 +35,7 @@ class test_calc_sep_from_period(unittest.TestCase): output_1 = calc_sep_from_period(1, 1, 1) self.assertEqual(output_1, 5.302958446503317) + class test_roche_lobe(unittest.TestCase): """ Unittests for function roche_lobe @@ -42,10 +45,11 @@ class test_roche_lobe(unittest.TestCase): mass_donor = 2 mass_accretor = 1 - output_1 = roche_lobe(mass_accretor/mass_donor) + output_1 = roche_lobe(mass_accretor / mass_donor) print(output_1) - self.assertLess(np.abs(output_1-0.3207881203346875), 1e-10) + self.assertLess(np.abs(output_1 - 0.3207881203346875), 1e-10) + class test_ragb(unittest.TestCase): """ @@ -58,6 +62,7 @@ class test_ragb(unittest.TestCase): self.assertEqual(output, 820) + class test_rzams(unittest.TestCase): """ Unittests for function rzams @@ -68,19 +73,20 @@ class test_rzams(unittest.TestCase): metallicity = 0.02 output_1 = rzams(mass, metallicity) - self.assertLess(np.abs(output_1-0.458757762074762), 1e-7) + self.assertLess(np.abs(output_1 - 0.458757762074762), 1e-7) mass = 12.5 metallicity = 0.01241 output_2 = rzams(mass, metallicity) - self.assertLess(np.abs(output_2-4.20884329861741), 1e-7) + self.assertLess(np.abs(output_2 - 4.20884329861741), 1e-7) mass = 149 metallicity = 0.001241 output_3 = rzams(mass, metallicity) - self.assertLess(np.abs(output_3-12.8209978916491), 1e-7) + self.assertLess(np.abs(output_3 - 12.8209978916491), 1e-7) + class test_zams_collission(unittest.TestCase): """ @@ -94,12 +100,17 @@ class test_zams_collission(unittest.TestCase): eccentricity = 0 metallicity = 0.02 - output_collision_1 = zams_collision(mass1, mass2, sep, eccentricity, metallicity) - self.assertTrue(output_collision_1==0) + output_collision_1 = zams_collision( + mass1, mass2, sep, eccentricity, metallicity + ) + self.assertTrue(output_collision_1 == 0) sep = 1 - output_collision_2 = zams_collision(mass1, mass2, sep, eccentricity, metallicity) - self.assertTrue(output_collision_2==1) + output_collision_2 = zams_collision( + mass1, mass2, sep, eccentricity, metallicity + ) + self.assertTrue(output_collision_2 == 1) + if __name__ == "__main__": unittest.main() diff --git a/binarycpython/utils/functions.py b/binarycpython/utils/functions.py index 2d4bb8fb20c8c06c7440230639b83d1e2a706c2c..8e10644b21dc24638665b248416e7d412ee54081 100644 --- a/binarycpython/utils/functions.py +++ b/binarycpython/utils/functions.py @@ -73,6 +73,7 @@ def remove_file(file: str, verbosity: int = 0) -> None: else: verbose_print("File/directory {} doesn't exist. Can't remove it.", verbosity, 1) + def temp_dir() -> str: """ Function to return the path the custom logging library shared object @@ -190,6 +191,8 @@ def parse_binary_c_version_info(version_info_string: str) -> dict: """ Function that parses the binary_c version info. Long function with a lot of branches + TODO: fix this function. stuff is missing: isotopes, macros, nucleosynthesis_sources + Args: version_info_string: raw output of version_info call to binary_c diff --git a/binarycpython/utils/grid.py b/binarycpython/utils/grid.py index 0f0e55abe3ded62f979d291f81524678f5ae16b0..e21aed4e93d6c27acd4f33a1ed9f1c4f20ee5506 100644 --- a/binarycpython/utils/grid.py +++ b/binarycpython/utils/grid.py @@ -17,6 +17,7 @@ Tasks: - TODO: add functionality to return the ensemble_list - TODO: consider spreading the functions over more files. - TODO: type the private functions + - TODO: fix the correct object types for the default values of the bse_options """ import os @@ -189,9 +190,16 @@ class Population: Function to handle settings values via the command line. Best to be called after all the .set(..) lines, and just before the .evolve() is called + If you input any known parameter (i.e. contained in grid_options, defaults/bse_options or custom_options), + this function will attempt to convert the input from string (because everything is string) to the type of + the value that option had before. + + The values of the bse_options are initially all strings, but after user input they can change to ints. + + The value of any new parameter (which will go to custom_options) will be a string. + Tasks: - TODO: remove the need for --cmdline - - TODO: fix that the input is converted to the correct type (i.e. type of the default value) """ parser = argparse.ArgumentParser() @@ -243,12 +251,24 @@ class Population: # (attempt to) convert if old_value_found: try: - verbose_print("Converting type of {} from {} to {}".format(parameter, type(value), type(old_value)), self.grid_options['verbosity'], 1) + verbose_print( + "Converting type of {} from {} to {}".format( + parameter, type(value), type(old_value) + ), + self.grid_options["verbosity"], + 1, + ) value = type(old_value)(value) - verbose_print("Success!", self.grid_options['verbosity'], 1) + verbose_print("Success!", self.grid_options["verbosity"], 1) except ValueError as e: - verbose_print("Tried to convert the given parameter {}/value {} to its correct type {} (from old value {}). But that wasn't possible.".format(parameter, value, type(old_value), old_value), self.grid_options['verbosity'], 0) + verbose_print( + "Tried to convert the given parameter {}/value {} to its correct type {} (from old value {}). But that wasn't possible.".format( + parameter, value, type(old_value), old_value + ), + self.grid_options["verbosity"], + 0, + ) # Add to dict cmdline_dict[parameter] = value @@ -440,7 +460,7 @@ class Population: include_binary_c_defaults: bool = True, include_binary_c_version_info: bool = True, include_binary_c_help_all: bool = True, - ) -> None: + ) -> Union[str, None]: """ Function that exports the all_info to a json file @@ -502,6 +522,7 @@ class Population: default=binaryc_json_serializer, ) ) + return settings_fullname else: verbose_print( "Writing settings to {}".format(outfile), @@ -515,12 +536,14 @@ class Population: 0, ) raise ValueError + with open(outfile, "w") as file: file.write( json.dumps( all_info_cleaned, indent=4, default=binaryc_json_serializer ) ) + return outfile def _set_custom_logging(self): """ @@ -713,10 +736,12 @@ class Population: # Log and print some information verbose_print( "Population-{} finished! It took a total of {}s to run {} systems on {} cores".format( - self.grid_options["_population_id"], - self.grid_options['_end_time_evolution']-self.grid_options['_start_time_evolution'], - self.grid_options['_total_starcount'], - self.grid_options['amt_cores']), + self.grid_options["_population_id"], + self.grid_options["_end_time_evolution"] + - self.grid_options["_start_time_evolution"], + self.grid_options["_total_starcount"], + self.grid_options["amt_cores"], + ), self.grid_options["verbosity"], 0, ) @@ -744,7 +769,11 @@ class Population: 0, ) else: - verbose_print("There were no errors found in this run.", self.grid_options["verbosity"], 0) + verbose_print( + "There were no errors found in this run.", + self.grid_options["verbosity"], + 0, + ) ## # Clean up code: remove files, unset values. @@ -752,7 +781,8 @@ class Population: def _process_run_population(self, ID): """ - Function that loops over the whole generator, but only runs systems that fit to: if (localcounter+ID) % self.grid_options["amt_cores"] == 0 + Function that loops over the whole generator, but only runs + systems that fit to: if (localcounter+ID) % self.grid_options["amt_cores"] == 0 That way with 4 processes, process 1 runs sytem 0, 4, 8... process 2 runs system 1, 5, 9..., etc @@ -771,9 +801,15 @@ class Population: # Set up local variables running = True - localcounter = 0 # global counter for the whole loop. (need to be ticked every loop) - probability_of_systems_run = 0 # counter for the probability of the actual systems this tread ran - number_of_systems_run = 0 # counter for the actual amt of systems this thread ran + localcounter = ( + 0 # global counter for the whole loop. (need to be ticked every loop) + ) + probability_of_systems_run = ( + 0 # counter for the probability of the actual systems this tread ran + ) + number_of_systems_run = ( + 0 # counter for the actual amt of systems this thread ran + ) verbose_print( "Process {} started".format(ID), self.grid_options["verbosity"], 0 @@ -803,9 +839,9 @@ class Population: self._evolve_system_mp(full_system_dict) # TODO: fix the 'repeat' and 'weight' tracking here - # Keep track of systems: - probability_of_systems_run += full_system_dict['probability'] - number_of_systems_run += 1 + # Keep track of systems: + probability_of_systems_run += full_system_dict["probability"] + number_of_systems_run += 1 except StopIteration: running = False @@ -813,7 +849,6 @@ class Population: # Has to be here because this one is used for the (localcounter+ID) % (self..) localcounter += 1 - # Return a set of results and errors output_dict = { "results": self.grid_options["results"], @@ -824,10 +859,18 @@ class Population: ], "_errors_exceeded": self.grid_options["_errors_exceeded"], "_errors_found": self.grid_options["_errors_found"], + "_probtot": self.grid_options["_probtot"], + "_count": self.grid_options["_count"], } verbose_print( - "Process {}: generator done. Ran {} systems with a total probability of {}. This thread had {} failing systems with a total probability of {}".format(ID, number_of_systems_run, probability_of_systems_run, self.grid_options["_failed_count"], self.grid_options["_failed_prob"]), + "Process {}: generator done. Ran {} systems with a total probability of {}. This thread had {} failing systems with a total probability of {}".format( + ID, + number_of_systems_run, + probability_of_systems_run, + self.grid_options["_failed_count"], + self.grid_options["_failed_prob"], + ), self.grid_options["verbosity"], 0, ) @@ -885,9 +928,9 @@ class Population: for output_dict in result: combined_output_dict = merge_dicts(combined_output_dict, output_dict) - # Put the values back as object properties print(combined_output_dict) + # Put the values back as object properties self.grid_options["results"] = combined_output_dict["results"] self.grid_options["_failed_count"] = combined_output_dict["_failed_count"] self.grid_options["_failed_prob"] = combined_output_dict["_failed_prob"] @@ -896,6 +939,8 @@ class Population: ) self.grid_options["_errors_exceeded"] = combined_output_dict["_errors_exceeded"] self.grid_options["_errors_found"] = combined_output_dict["_errors_found"] + self.grid_options["_probtot"] = combined_output_dict["_probtot"] + self.grid_options["_count"] = combined_output_dict["_count"] def _evolve_population_lin(self): """ @@ -1001,14 +1046,12 @@ class Population: Tasks: TODO: Make other kinds of populations possible. i.e, read out type of grid, and set up accordingly - TODO: make this function more general. Have it explicitly set the system_generator function """ if not self.grid_options["parse_function"]: - print("Error: No parse function set. Aborting run") - raise ValueError + print("Warning: No parse function set. Make sure you intended to do this.") ####################### ### Custom logging code: @@ -1132,6 +1175,7 @@ class Population: self.grid_options["_failed_prob"] = 0 self.grid_options["_errors_found"] = False self.grid_options["_errors_exceeded"] = False + self.grid_options["_failed_systems_error_codes"] = [] # Remove files # TODO: remove files @@ -1165,7 +1209,7 @@ class Population: # TODO: add sensible description to this function. # TODO: Check whether all the probability and phasevol values are correct. # TODO: import only the necessary packages/functions - + Results in a generated file that contains a system_generator function. """ @@ -1531,10 +1575,7 @@ class Population: # # code_string += indent * (depth + 1) + "\n" code_string += indent * (depth + 1) + "#" * 40 + "\n" - code_string += ( - indent * (depth + 1) - + "if print_results:\n" - ) + code_string += indent * (depth + 1) + "if print_results:\n" code_string += ( indent * (depth + 2) + "print('Grid has handled {} stars'.format(_total_starcount))\n" @@ -1719,7 +1760,7 @@ class Population: verbose_print("Source file loaded", self.grid_options["verbosity"], 1) - def _dict_from_line_source_file(self): + def _dict_from_line_source_file(self, line): """ Function that creates a dict from a binary_c argline """ diff --git a/binarycpython/utils/useful_funcs.py b/binarycpython/utils/useful_funcs.py index 134c5219f95516ae284dd4f8737b885017765bf0..2b3b5001246f2fae36774d189df158aa3971bf1a 100644 --- a/binarycpython/utils/useful_funcs.py +++ b/binarycpython/utils/useful_funcs.py @@ -66,7 +66,7 @@ def roche_lobe(q: Union[int, float]) -> Union[int, float]: # TODO: check whether the logs are correct Args: - q: mass ratio of the binary (secondary/primary). If you input: q = mass_accretor/mass_donor, you will get the rochelobe radius of the accretor. And vice versa for the donor. + q: mass ratio of the binary (secondary/primary). If you input: q = mass_accretor/mass_donor, you will get the rochelobe radius of the accretor. And vice versa for the donor. Returns: Roche lobe radius in units of the separation diff --git a/docs/source/conf.py b/docs/source/conf.py index 118f1a9be674429d2dd8e962f376888a81acc54d..8f5768736b981db76340efd880517048c0373585 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -50,13 +50,15 @@ extensions = [ "hawkmoth", "m2r2", "sphinx_rtd_theme", - "sphinx_autodoc_typehints", # https://mypy.readthedocs.io/en/stable/cheat_sheet_py3.html + "sphinx_autodoc_typehints", # https://mypy.readthedocs.io/en/stable/cheat_sheet_py3.html ] # Napoleon settings -napoleon_google_docstring = True # https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html -napoleon_numpy_docstring = False +napoleon_google_docstring = ( + True # https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html +) +napoleon_numpy_docstring = False napoleon_include_init_with_doc = False napoleon_include_private_with_doc = False napoleon_include_special_with_doc = True @@ -97,6 +99,7 @@ import m2r2 current_m2r2_setup = m2r2.setup + def patched_m2r2_setup(app): try: return current_m2r2_setup(app) @@ -104,13 +107,18 @@ def patched_m2r2_setup(app): app.add_source_suffix(".md", "markdown") app.add_source_parser(m2r2.M2RParser) return dict( - version=m2r2.__version__, parallel_read_safe=True, parallel_write_safe=True, + version=m2r2.__version__, + parallel_read_safe=True, + parallel_write_safe=True, ) + m2r2.setup = patched_m2r2_setup # Generate some custom documentations for this version of binarycpython and binary_c -from binarycpython.utils.functions import write_binary_c_parameter_descriptions_to_rst_file +from binarycpython.utils.functions import ( + write_binary_c_parameter_descriptions_to_rst_file, +) from binarycpython.utils.grid_options_defaults import write_grid_options_to_rst_file print("Generating binary_c_parameters.rst") diff --git a/examples/examples.py b/examples/examples.py index 4ad9b15cdae1c8c8fecc368cd137075c65cdf0a0..21585612b79d6b7e3e57bacd74600d6804700ae5 100644 --- a/examples/examples.py +++ b/examples/examples.py @@ -74,8 +74,12 @@ def run_example_binary_with_run_system(): # print(output) # Catch results that start with a given header. (Mind that binary_c has to be configured to print them if your not using a custom logging function) - result_example_header_1 = example_parse_output(output, selected_header="example_header_1") - result_example_header_2 = example_parse_output(output, selected_header="example_header_2") + result_example_header_1 = example_parse_output( + output, selected_header="example_header_1" + ) + result_example_header_2 = example_parse_output( + output, selected_header="example_header_2" + ) # print(result_example_header_1) diff --git a/setup.py b/setup.py index 6d6fe080d6562bd8a23bcb127eaed1311c318861..aff51674f32e8e2eb6ac6931b0ef9ed9946a8ffe 100644 --- a/setup.py +++ b/setup.py @@ -37,6 +37,7 @@ def check_version(installed_binary_c_version, required_binary_c_versions): ) assert installed_binary_c_version in required_binary_c_versions, message + def execute_make(): """ Function to execute the makefile. @@ -147,14 +148,11 @@ API_h = os.path.join(BINARY_C_DIR, "src", "API", "binary_c_API.h") ############################################################ # Setting all directories and LIBRARIES to their final values ############################################################ -INCLUDE_DIRS = ( - [ - os.path.join(BINARY_C_DIR, "src"), - os.path.join(BINARY_C_DIR, "src", "API"), - "include", - ] - + BINARY_C_INCDIRS -) +INCLUDE_DIRS = [ + os.path.join(BINARY_C_DIR, "src"), + os.path.join(BINARY_C_DIR, "src", "API"), + "include", +] + BINARY_C_INCDIRS if GSL_DIR: INCLUDE_DIRS += [os.path.join(GSL_DIR, "include")] @@ -226,6 +224,7 @@ class CustomBuildCommand(distutils.command.build.build): # Run the original build command distutils.command.build.build.run(self) + setup( name="binarycpython", version="0.2.8", @@ -237,7 +236,7 @@ setup( author_email="davidhendriks93@gmail.com", long_description=readme(), # long_description="hello", - long_description_content_type='text/markdown', + long_description_content_type="text/markdown", url="https://gitlab.eps.surrey.ac.uk/ri0005/binary_c-python", license="gpl", keywords=[ @@ -252,7 +251,15 @@ setup( "binarycpython.core", "binarycpython.tests", ], - install_requires=["numpy", "pytest", "h5py", "pathos", "pandas", "astropy", "matplotlib"], + install_requires=[ + "numpy", + "pytest", + "h5py", + "pathos", + "pandas", + "astropy", + "matplotlib", + ], include_package_data=True, ext_modules=[BINARY_C_PYTHON_API_MODULE], # binary_c must be loaded classifiers=[