diff --git a/binarycpython/tests/test_distributions.py b/binarycpython/tests/test_distributions.py index a78d8e70db1c9ff8608f71280ac1022bdbd848cb..779e872a77408b37a1c37fa942cdbf406523fac8 100644 --- a/binarycpython/tests/test_distributions.py +++ b/binarycpython/tests/test_distributions.py @@ -25,6 +25,24 @@ class TestDistributions(unittest.TestCase): self.tolerance = 1e-5 + def test_setopts(self): + """ + Unittest for function set_opts + """ + + default_dict = {'m1': 2, 'm2': 3} + output_dict_1 = set_opts(default_dict, {}) + self.assertTrue(output_dict_1==default_dict) + + + new_opts = {'m1': 10} + output_dict_2 = set_opts(default_dict, new_opts) + updated_dict = default_dict.copy() + updated_dict['m1'] = 10 + + self.assertTrue(output_dict_2==updated_dict) + + def test_flat(self): """ Unittest for the function flat @@ -205,7 +223,6 @@ class TestDistributions(unittest.TestCase): self.assertLess(np.abs(imf_chabrier(m)-0.581457346702825), self.tolerance, msg="Difference is bigger than the tolerance") - def test_duquennoy1991(self): """ Unittest for function duquennoy1991 diff --git a/binarycpython/tests/test_functions.py b/binarycpython/tests/test_functions.py index b83c2062156bde44a4d7872beef2b9c6a561cb92..a108a71ccb02baa2f993df9457c61d5fd615cf26 100644 --- a/binarycpython/tests/test_functions.py +++ b/binarycpython/tests/test_functions.py @@ -1,31 +1,297 @@ import unittest +import tempfile from binarycpython.utils.functions import * +from binarycpython.utils.custom_logging_functions import binary_c_log_code +from binarycpython.utils.run_system_wrapper import run_system +binary_c_temp_dir = temp_dir() ############################# # Script that contains unit tests for functions from the binarycpython.utils.functions file +# class test_(unittest.TestCase): +# """ +# Unittests for function +# """ -class test_get_help_super(unittest.TestCase): +# def test_1(self): +# pass + +class dummy(): + def __init__(self, name): + self.name = name + + def __str__(self): + return self.name + +class test_verbose_print(unittest.TestCase): """ - Unit test for get_help_super + Unittests for verbose_print """ - def test_all_output(self): + def test_print(self): + verbose_print('test1', 1, 0) + + def test_not_print(self): + verbose_print('test1', 0, 1) + +class test_remove_file(unittest.TestCase): + """ + Unittests for remove_file + """ + + def test_remove_file(self): + with open(os.path.join(binary_c_temp_dir, 'test_remove_file_file.txt'), 'w') as f: + f.write('test') + + remove_file(os.path.join(binary_c_temp_dir, 'test_remove_file_file.txt')) + + def test_remove_nonexisting_file(self): + file = os.path.join(binary_c_temp_dir, 'test_remove_nonexistingfile_file.txt') + + remove_file(file) + +class test_temp_dir(unittest.TestCase): + """ + Unittests for temp_dir + """ + + def test_create_temp_dir(self): + binary_c_temp_dir = temp_dir() + general_temp_dir = tempfile.gettempdir() + + self.assertTrue(os.path.isdir(os.path.join(general_temp_dir, 'binary_c_python'))) + self.assertTrue(os.path.join(general_temp_dir, 'binary_c_python'))==binary_c_temp_dir + +class test_create_hdf5(unittest.TestCase): + """ + Unittests for create_hdf5 + """ + + def test_1(self): + testdir = os.path.join(binary_c_temp_dir, 'test_create_hdf5') + os.makedirs(testdir, exist_ok=True) + + # Create dummy settings file: + settings_dict = {'settings_1': 1, 'settings_2': [1,2]} + + with open(os.path.join(testdir, 'example_settings.json'), 'w') as f: + f.write(json.dumps(settings_dict)) + + with open(os.path.join(testdir, 'data1.dat'), 'w') as f: + f.write("time mass\n") + f.write("1 10") + + create_hdf5(testdir, 'testhdf5.hdf5') + file = h5py.File(os.path.join(testdir, 'testhdf5.hdf5'), 'r') + + self.assertIn(b'time', file.get('data/data1_header')[()]) + self.assertIn(b'mass', file.get('data/data1_header')[()]) + + self.assertIn('settings_1', json.loads(file.get('settings/used_settings')[()])) + self.assertIn('settings_2', json.loads(file.get('settings/used_settings')[()])) + +class test_return_binary_c_version_info(unittest.TestCase): + """ + Unittests for return_binary_c_version_info + """ + + def test_not_parsed(self): + version_info = return_binary_c_version_info() + + self.assertTrue(isinstance(version_info, str)) + self.assertIn('Build', version_info) + self.assertIn('REIMERS_ETA_DEFAULT', version_info) + self.assertIn('SIGMA_THOMPSON', version_info) + + def test_parsed(self): + # also tests the parse_version_info indirectly + version_info_parsed = return_binary_c_version_info(parsed=True) + + self.assertTrue(isinstance(version_info_parsed, dict)) + self.assertIn('isotopes', version_info_parsed.keys()) + self.assertIn('argpairs', version_info_parsed.keys()) + self.assertIn('ensembles', version_info_parsed.keys()) + self.assertIn('macros', version_info_parsed.keys()) + self.assertIn('elements', version_info_parsed.keys()) + self.assertIn('dt_limits', version_info_parsed.keys()) + self.assertIn('nucleosynthesis_sources', version_info_parsed.keys()) + self.assertIn('miscellaneous', version_info_parsed.keys()) + +class test_parse_binary_c_version_info(unittest.TestCase): + """ + Unittests for function parse_binary_c_version_info + """ + + def test_1(self): + info = return_binary_c_version_info() + parsed_info = parse_binary_c_version_info(info) + + self.assertIn('isotopes', parsed_info.keys()) + self.assertNotEqual(parsed_info['isotopes'], {}) + self.assertIn('argpairs', parsed_info.keys()) + self.assertNotEqual(parsed_info['argpairs'], {}) + self.assertIn('ensembles', parsed_info.keys()) + self.assertNotEqual(parsed_info['ensembles'], {}) + self.assertIn('macros', parsed_info.keys()) + self.assertNotEqual(parsed_info['macros'], {}) + self.assertIn('elements', parsed_info.keys()) + self.assertNotEqual(parsed_info['elements'], {}) + self.assertIn('dt_limits', parsed_info.keys()) + self.assertNotEqual(parsed_info['dt_limits'], {}) + self.assertIn('nucleosynthesis_sources', parsed_info.keys()) + self.assertNotEqual(parsed_info['nucleosynthesis_sources'], {}) + self.assertIn('miscellaneous', parsed_info.keys()) + self.assertNotEqual(parsed_info['miscellaneous'], {}) + +class test_output_lines(unittest.TestCase): + """ + Unittests for function output_lines + """ + + def test_1(self): + example_text = "hallo\ntest\n123" + output_1 = output_lines(example_text) + + self.assertTrue(isinstance(output_1, list)) + self.assertIn('hallo', output_1) + self.assertIn('test', output_1) + self.assertIn('123', output_1) + +class test_example_parse_output(unittest.TestCase): + """ + Unittests for function example_parse_output + """ + + def test_normal_output(self): + # generate logging lines. Here you can choose whatever you want to have logged, and with what header + # You can also decide to `write` your own logging_line, which allows you to write a more complex logging statement with conditionals. + logging_line = 'Printf("MY_STELLAR_DATA time=%g mass=%g\\n", stardata->model.time, stardata->star[0].mass)' + + # Generate entire shared lib code around logging lines + custom_logging_code = binary_c_log_code(logging_line) + + # Run system. all arguments can be given as optional arguments. the custom_logging_code is one of them and will be processed automatically. + output = run_system( + M_1=1, + metallicity=0.002, + M_2=0.1, + separation=0, + orbital_period=100000000000, + custom_logging_code=custom_logging_code, + ) + + parsed_output = example_parse_output(output, "MY_STELLAR_DATA") + + self.assertIn('time', parsed_output) + self.assertIn('mass', parsed_output) + self.assertTrue(isinstance(parsed_output['time'], list)) + self.assertTrue(len(parsed_output['time'])>0) + + def test_mismatch_output(self): + # generate logging lines. Here you can choose whatever you want to have logged, and with what header + # You can also decide to `write` your own logging_line, which allows you to write a more complex logging statement with conditionals. + logging_line = 'Printf("MY_STELLAR_DATA time=%g mass=%g\\n", stardata->model.time, stardata->star[0].mass)' + + # Generate entire shared lib code around logging lines + custom_logging_code = binary_c_log_code(logging_line) + + # Run system. all arguments can be given as optional arguments. the custom_logging_code is one of them and will be processed automatically. + output = run_system( + M_1=1, + metallicity=0.002, + M_2=0.1, + separation=0, + orbital_period=100000000000, + custom_logging_code=custom_logging_code, + ) + + parsed_output = example_parse_output(output, "MY_STELLAR_DATA_MISMATCH") + self.assertIsNone(parsed_output) + +class test_get_defaults(unittest.TestCase): + """ + Unittests for function get_defaults + """ + + def test_no_filter(self): + output_1 = get_defaults() + + self.assertTrue(isinstance(output_1, dict)) + self.assertIn('colour_log', output_1.keys()) + self.assertIn('M_1', output_1.keys()) + self.assertIn('list_args', output_1.keys()) + self.assertIn('use_fixed_timestep_%d', output_1.keys()) + + def test_filter(self): + # Also tests the filter_arg_dict indirectly + output_1 = get_defaults(filter_values=True) + + self.assertTrue(isinstance(output_1, dict)) + self.assertIn('colour_log', output_1.keys()) + self.assertIn('M_1', output_1.keys()) + self.assertNotIn('list_args', output_1.keys()) + self.assertNotIn('use_fixed_timestep_%d', output_1.keys()) + +class test_get_arg_keys(unittest.TestCase): + """ + Unittests for function + """ + + def test_1(self): + output_1 = get_arg_keys() + + self.assertTrue(isinstance(output_1, list)) + self.assertIn('colour_log', output_1) + self.assertIn('M_1', output_1) + self.assertIn('list_args', output_1) + self.assertIn('use_fixed_timestep_%d', output_1) + +class test_create_arg_string(unittest.TestCase): + """ + Unittests for function create_arg_string + """ + + def test_default(self): + input_dict = {'separation': 40000, 'M_1': 10} + argstring = create_arg_string(input_dict) + self.assertEqual(argstring, 'separation 40000 M_1 10') + + def test_sort(self): + input_dict = {'M_1': 10, 'separation': 40000} + argstring = create_arg_string(input_dict, sort=True) + self.assertEqual(argstring, 'M_1 10 separation 40000') + + def test_sort(self): + input_dict = {'M_1': 10, 'separation': 40000, 'list_args': "NULL"} + argstring = create_arg_string(input_dict, filter_values=True) + self.assertEqual(argstring, 'M_1 10 separation 40000') + +class test_get_help(unittest.TestCase): + """ + Unit tests for function get_help + """ + + def test_input_normal(self): """ - Function to test the get_help_super function + Function to test the get_help function """ - get_help_super_output = get_help_super() - get_help_super_keys = get_help_super_output.keys() + self.assertEqual( + get_help("M_1", print_help=False)["parameter_name"], + "M_1", + msg="get_help('M_1') should return the correct parameter name", + ) - self.assertIn("stars", get_help_super_keys, "missing section") - self.assertIn("binary", get_help_super_keys, "missing section") - self.assertIn("nucsyn", get_help_super_keys, "missing section") - self.assertIn("output", get_help_super_keys, "missing section") - self.assertIn("i/o", get_help_super_keys, "missing section") - self.assertIn("algorithms", get_help_super_keys, "missing section") - self.assertIn("misc", get_help_super_keys, "missing section") + def test_no_input(self): + output = get_help() + self.assertIsNone(output) + + def test_wrong_input(self): + output = get_help("kaasblokjes") + self.assertIsNone(output) + # def test_print(self): + # output = get_help("M_1", print_help=True) class test_get_help_all(unittest.TestCase): """ @@ -49,24 +315,178 @@ class test_get_help_all(unittest.TestCase): self.assertIn("misc", get_help_all_keys, "missing section") -class test_get_help(unittest.TestCase): - def test_input(self): + # def test_print(self): + # # test if stuff is printed + # get_help_all(print_help=True) + +class test_get_help_super(unittest.TestCase): + """ + Unit test for get_help_super + """ + + def test_all_output(self): """ - Function to test the get_help function + Function to test the get_help_super function """ - self.assertEqual( - get_help("M_1", print_help=False)["parameter_name"], - "M_1", - msg="get_help('M_1') should return the correct parameter name", - ) + get_help_super_output = get_help_super() + get_help_super_keys = get_help_super_output.keys() + + self.assertIn("stars", get_help_super_keys, "missing section") + self.assertIn("binary", get_help_super_keys, "missing section") + self.assertIn("nucsyn", get_help_super_keys, "missing section") + self.assertIn("output", get_help_super_keys, "missing section") + self.assertIn("i/o", get_help_super_keys, "missing section") + self.assertIn("algorithms", get_help_super_keys, "missing section") + self.assertIn("misc", get_help_super_keys, "missing section") + + # def test_print(self): + # # test to see if stuff is printed. + # get_help_super(print_help=True) +class test_make_build_text(unittest.TestCase): + """ + Unittests for function + """ + + def test_output(self): + build_text = make_build_text() + + # Remove the things + build_text = build_text.replace("**binary_c git branch**:", ";") + build_text = build_text.replace("**binary_c git revision**:", ";") + build_text = build_text.replace("**Built on**:", ";") + + # Split up + split_text = build_text.split(";") + + # Check whether the contents are actually there + self.assertNotEqual(split_text[1].strip(), "second") + self.assertNotEqual(split_text[2].strip(), "second") + self.assertNotEqual(split_text[3].strip(), "second") + +class test_write_binary_c_parameter_descriptions_to_rst_file(unittest.TestCase): + """ + Unittests for function write_binary_c_parameter_descriptions_to_rst_file + """ + + def test_bad_outputname(self): + output_name = os.path.join(binary_c_temp_dir, 'test_write_binary_c_parameter_descriptions_to_rst_file_test_1.txt') + output_1 = write_binary_c_parameter_descriptions_to_rst_file(output_name) + self.assertIsNone(output_1) + + def test_checkfile(self): + output_name = os.path.join(binary_c_temp_dir, 'test_write_binary_c_parameter_descriptions_to_rst_file_test_1.rst') + output_1 = write_binary_c_parameter_descriptions_to_rst_file(output_name) + self.assertTrue(os.path.isfile(output_name)) + +class test_inspect_dict(unittest.TestCase): + """ + Unittests for function + """ + + def test_compare_dict(self): + input_dict = {'int': 1, 'float': 1.2, 'list': [1,2,3], 'function': os.path.isfile, 'dict': {'int': 1, 'float': 1.2}} + output_dict = inspect_dict(input_dict) + compare_dict = {'int': int, 'float': float, 'list': list, 'function': os.path.isfile.__class__, 'dict': {'int': int, 'float': float}} + self.assertTrue(compare_dict==output_dict) + + def test_compare_dict(self): + input_dict = {'int': 1, 'float': 1.2, 'list': [1,2,3], 'function': os.path.isfile, 'dict': {'int': 1, 'float': 1.2}} + output_dict = inspect_dict(input_dict, print_structure=True) + +class test_merge_dicts(unittest.TestCase): + """ + Unittests for function + """ + + def test_empty(self): + input_dict = {'int': 1, 'float': 1.2, 'list': [1,2,3], 'function': os.path.isfile, 'dict': {'int': 1, 'float': 1.2}} + dict_2 = {} + output_dict = merge_dicts(input_dict, dict_2) + self.assertTrue(output_dict==input_dict) + + def test_unequal_types(self): + dict_1 = {"input": 10} + dict_2 = {"input": 'hello'} + + self.assertRaises(ValueError, merge_dicts, dict_1, dict_2) + + def test_bools(self): + dict_1 = {'bool': True} + dict_2 = {'bool': False} + output_dict = merge_dicts(dict_1, dict_2) + + self.assertTrue(isinstance(output_dict['bool'], bool)) + self.assertTrue(output_dict['bool']) + + def test_ints(self): + dict_1 = {'int': 2} + dict_2 = {'int': 1} + output_dict = merge_dicts(dict_1, dict_2) + + self.assertTrue(isinstance(output_dict['int'], int)) + self.assertEqual(output_dict['int'], 3) + + def test_floats(self): + dict_1 = {'float': 4.5} + dict_2 = {'float': 4.6} + output_dict = merge_dicts(dict_1, dict_2) + + self.assertTrue(isinstance(output_dict['float'], float)) + self.assertEqual(output_dict['float'], 9.1) + + def test_lists(self): + dict_1 = {'list': [1,2]} + dict_2 = {'list': [3,4]} + output_dict = merge_dicts(dict_1, dict_2) + + self.assertTrue(isinstance(output_dict['list'], list)) + self.assertEqual(output_dict['list'], [1,2,3,4]) + + def test_dicts(self): + dict_1 = {'dict': {'same': 1, 'other_1': 2.0}} + dict_2 = {'dict': {'same': 2, 'other_2': [4.0]}} + output_dict = merge_dicts(dict_1, dict_2) + + self.assertTrue(isinstance(output_dict['dict'], dict)) + self.assertEqual(output_dict['dict'], {'same': 3, 'other_1': 2.0, 'other_2': [4.0]}) + + def test_unsupported(self): + dict_1 = {'new': dummy('david')} + dict_2 = {'new': dummy('gio')} + + # output_dict = merge_dicts(dict_1, dict_2) + self.assertRaises(ValueError, merge_dicts, dict_1, dict_2) + +class test_binaryc_json_serializer(unittest.TestCase): + """ + Unittests for function binaryc_json_serializer + """ + + def test_not_function(self): + stringo = "hello" + output = binaryc_json_serializer(stringo) + self.assertTrue(stringo==output) + + def test_function(self): + string_of_function = str(os.path.isfile) + output = binaryc_json_serializer(os.path.isfile) + self.assertTrue(string_of_function==output) + +class test_handle_ensemble_string_to_json(unittest.TestCase): + """ + Unittests for function handle_ensemble_string_to_json + """ -def all(): - test_get_help() - test_get_help_all() - test_get_help_super() + def test_1(self): + string_of_function = str(os.path.isfile) + input_string = '{"ding": 10, "list_example": [1,2,3]}' + output_dict = handle_ensemble_string_to_json(input_string) + self.assertTrue(isinstance(output_dict, dict)) + self.assertTrue(output_dict['ding']==10) + self.assertTrue(output_dict['list_example']==[1,2,3]) if __name__ == "__main__": unittest.main()