diff --git a/binarycpython/tests/test_c_bindings.py b/binarycpython/tests/test_c_bindings.py index 1f59f3a70d60799e41660694e2ced19d4cfd8cef..5796759832eedd7395647bcb726c4afb8a4c1bea 100644 --- a/binarycpython/tests/test_c_bindings.py +++ b/binarycpython/tests/test_c_bindings.py @@ -20,6 +20,8 @@ from binarycpython.utils.functions import ( handle_ensemble_string_to_json, verbose_print, extract_ensemble_json_from_string, + is_capsule, + Capturing ) # https://docs.python.org/3/library/unittest.html @@ -63,12 +65,6 @@ ensemble_filters_off {8} ensemble_filter_{9} 1 probability 0.1" return argstring - -def is_capsule(o): - t = type(o) - return t.__module__ == 'builtins' and t.__name__ == 'PyCapsule' - - ####################################################################################################################################################### ### General run_system test ####################################################################################################################################################### @@ -80,6 +76,10 @@ class test_run_system(unittest.TestCase): """ def test_output(self): + with Capturing() as output: + self._test_output() + + def _test_output(self): """ General test if run_system works """ @@ -123,10 +123,14 @@ class test_return_store_memaddr(unittest.TestCase): """ def test_return_store_memaddr(self): + with Capturing() as output: + self._test_return_store_memaddr() + + def _test_return_store_memaddr(self): """ Test to see if the memory adress is returned properly """ - print(self.id()) + output = _binary_c_bindings.return_store_memaddr() # print("function: test_return_store") @@ -159,6 +163,10 @@ class TestEnsemble(unittest.TestCase): super(TestEnsemble, self).__init__(*args, **kwargs) def test_return_persistent_data_memaddr(self): + with Capturing() as output: + self._test_return_persistent_data_memaddr() + + def _test_return_persistent_data_memaddr(self): """ Test case to check if the memory adress has been created succesfully """ @@ -172,6 +180,10 @@ class TestEnsemble(unittest.TestCase): # ) def test_minimal_ensemble_output(self): + with Capturing() as output: + self._test_minimal_ensemble_output() + + def _test_minimal_ensemble_output(self): """ test_case to check if the ensemble output is correctly output """ @@ -198,6 +210,10 @@ class TestEnsemble(unittest.TestCase): self.assertNotEqual(test_json["number_counts"], {}) def test_minimal_ensemble_output_defer(self): + with Capturing() as output: + self._test_minimal_ensemble_output_defer() + + def _test_minimal_ensemble_output_defer(self): """ test_case to check if the ensemble output is correctly output, by using defer command and freeing+outputting """ @@ -236,6 +252,10 @@ class TestEnsemble(unittest.TestCase): self.assertNotEqual(ensemble_json_output["number_counts"], {}) def test_add_ensembles_direct(self): + with Capturing() as output: + self._test_add_ensembles_direct() + + def _test_add_ensembles_direct(self): """ test_case to check if adding the ensemble outputs works. Many things should be caught by tests in the merge_dict test, but still good to test a bit here """ @@ -297,6 +317,10 @@ class TestEnsemble(unittest.TestCase): ) def test_compare_added_systems_with_double_deferred_systems(self): + with Capturing() as output: + self._test_compare_added_systems_with_double_deferred_systems() + + def _test_compare_added_systems_with_double_deferred_systems(self): """ test to run 2 systems without deferring, and merging them manually. Then run 2 systems with defer and then output them. """ @@ -384,6 +408,10 @@ class TestEnsemble(unittest.TestCase): ) def test_combine_with_empty_json(self): + with Capturing() as output: + self._test_combine_with_empty_json() + + def _test_combine_with_empty_json(self): """ Test for merging with an empty dict """ @@ -408,6 +436,11 @@ class TestEnsemble(unittest.TestCase): self.assertEqual(merge_dicts(output_json_1, {}), output_json_1, assert_message) ############# + + # def test_full_ensemble_output(self): + # with Capturing() as output: + # self._test_full_ensemble_output() + def _test_full_ensemble_output(self): """ Function to just output the whole ensemble diff --git a/binarycpython/tests/test_custom_logging.py b/binarycpython/tests/test_custom_logging.py index 4bab0a4c7ab179d4df2048cb1ca0e32398bf585c..05792732464a1672bd6d9304ae8d64084ab779d9 100644 --- a/binarycpython/tests/test_custom_logging.py +++ b/binarycpython/tests/test_custom_logging.py @@ -5,6 +5,7 @@ Unittests for the custom_logging module import unittest from binarycpython.utils.custom_logging_functions import * +from binarycpython.utils.functions import Capturing binary_c_temp_dir = temp_dir() @@ -15,6 +16,11 @@ class test_custom_logging(unittest.TestCase): """ def test_autogen_C_logging_code(self): + with Capturing() as output: + self._test_autogen_C_logging_code() + print("\n".join(output)) + + def _test_autogen_C_logging_code(self): """ Tests for the autogeneration of a print statement from a dictionary. and then checking if the output is correct """ @@ -43,6 +49,11 @@ class test_custom_logging(unittest.TestCase): self.assertEqual(output_3, None, msg="Output should be None") def test_binary_c_log_code(self): + with Capturing() as output: + self._test_binary_c_log_code() + print("\n".join(output)) + + def _test_binary_c_log_code(self): """ Test to see if passing a print statement to the function results in correct binary_c output """ @@ -61,6 +72,11 @@ class test_custom_logging(unittest.TestCase): ) def test_binary_c_write_log_code(self): + with Capturing() as output: + self._test_binary_c_write_log_code() + print("\n".join(output)) + + def _test_binary_c_write_log_code(self): """ Tests to see if writing the code to a file and reading that out again is the same """ @@ -85,6 +101,11 @@ class test_custom_logging(unittest.TestCase): self.assertEqual(repr(input_1), content_file, msg="Contents are not similar") def test_from_binary_c_config(self): + with Capturing() as output: + self._test_from_binary_c_config() + print("\n".join(output)) + + def _test_from_binary_c_config(self): """ Tests for interfacing with binary_c-config """ @@ -109,6 +130,11 @@ class test_custom_logging(unittest.TestCase): self.assertEqual(output_2, "2.1.7", msg="binary_c version doesnt match") def test_return_compilation_dict(self): + with Capturing() as output: + self._test_return_compilation_dict() + print("\n".join(output)) + + def _test_return_compilation_dict(self): """ Tests to see if the compilation dictionary contains the correct keys """ @@ -125,6 +151,11 @@ class test_custom_logging(unittest.TestCase): self.assertTrue("inc" in output) def test_create_and_load_logging_function(self): + with Capturing() as output: + self._test_create_and_load_logging_function() + print("\n".join(output)) + + def _test_create_and_load_logging_function(self): """ Tests checking the output of create_and_load_logging_function. Should return a valid memory int and a correct filename """ diff --git a/binarycpython/tests/test_distributions.py b/binarycpython/tests/test_distributions.py index 4c1204f09dec936dd18a570e096f9841f412a113..9a4bd344188e8176b37821d3c2d622d798c9b05a 100644 --- a/binarycpython/tests/test_distributions.py +++ b/binarycpython/tests/test_distributions.py @@ -6,7 +6,7 @@ import unittest from binarycpython.utils.distribution_functions import * from binarycpython.utils.useful_funcs import calc_sep_from_period - +from binarycpython.utils.functions import Capturing class TestDistributions(unittest.TestCase): """ @@ -29,6 +29,10 @@ class TestDistributions(unittest.TestCase): self.tolerance = 1e-5 def test_setopts(self): + with Capturing() as output: + self._test_setopts() + + def _test_setopts(self): """ Unittest for function set_opts """ @@ -45,6 +49,10 @@ class TestDistributions(unittest.TestCase): self.assertTrue(output_dict_2 == updated_dict) def test_flat(self): + with Capturing() as output: + self._test_flat() + + def _test_flat(self): """ Unittest for the function flat """ @@ -55,6 +63,10 @@ class TestDistributions(unittest.TestCase): self.assertEqual(output_1, 1.0) def test_number(self): + with Capturing() as output: + self._test_number() + + def _test_number(self): """ Unittest for function number """ @@ -65,6 +77,10 @@ class TestDistributions(unittest.TestCase): self.assertEqual(input_1, output_1) def test_const(self): + with Capturing() as output: + self._test_const() + + def _test_const(self): """ Unittest for function const """ @@ -80,6 +96,10 @@ class TestDistributions(unittest.TestCase): ) def test_powerlaw(self): + with Capturing() as output: + self._test_powerlaw() + + def _test_powerlaw(self): """ unittest for the powerlaw test """ @@ -108,6 +128,10 @@ class TestDistributions(unittest.TestCase): self.assertRaises(ValueError, powerlaw, 1, 100, -1, 10) def test_three_part_power_law(self): + with Capturing() as output: + self._test_three_part_power_law() + + def _test_three_part_power_law(self): """ unittest for three_part_power_law """ @@ -142,6 +166,10 @@ class TestDistributions(unittest.TestCase): ) def test_Kroupa2001(self): + with Capturing() as output: + self._test_Kroupa2001() + + def _test_Kroupa2001(self): """ unittest for three_part_power_law """ @@ -173,6 +201,10 @@ class TestDistributions(unittest.TestCase): ) def test_ktg93(self): + with Capturing() as output: + self._test_ktg93() + + def _test_ktg93(self): """ unittest for three_part_power_law """ @@ -204,6 +236,10 @@ class TestDistributions(unittest.TestCase): ) def test_imf_tinsley1980(self): + with Capturing() as output: + self._test_imf_tinsley1980() + + def _test_imf_tinsley1980(self): """ Unittest for function imf_tinsley1980 """ @@ -215,6 +251,10 @@ class TestDistributions(unittest.TestCase): ) def test_imf_scalo1986(self): + with Capturing() as output: + self._test_imf_scalo1986() + + def _test_imf_scalo1986(self): """ Unittest for function imf_scalo1986 """ @@ -226,6 +266,10 @@ class TestDistributions(unittest.TestCase): ) def test_imf_scalo1998(self): + with Capturing() as output: + self._test_imf_scalo1998() + + def _test_imf_scalo1998(self): """ Unittest for function imf_scalo1986 """ @@ -237,6 +281,10 @@ class TestDistributions(unittest.TestCase): ) def test_imf_chabrier2003(self): + with Capturing() as output: + self._test_imf_chabrier2003() + + def _test_imf_chabrier2003(self): """ Unittest for function imf_chabrier2003 """ @@ -261,6 +309,10 @@ class TestDistributions(unittest.TestCase): ) def test_duquennoy1991(self): + with Capturing() as output: + self._test_duquennoy1991() + + def _test_duquennoy1991(self): """ Unittest for function duquennoy1991 """ @@ -268,6 +320,10 @@ class TestDistributions(unittest.TestCase): self.assertEqual(duquennoy1991(4.2), gaussian(4.2, 4.8, 2.3, -2, 12)) def test_gaussian(self): + with Capturing() as output: + self._test_gaussian() + + def _test_gaussian(self): """ unittest for three_part_power_law """ @@ -299,6 +355,10 @@ class TestDistributions(unittest.TestCase): ) def test_Arenou2010_binary_fraction(self): + with Capturing() as output: + self._test_Arenou2010_binary_fraction() + + def _test_Arenou2010_binary_fraction(self): """ unittest for three_part_power_law """ @@ -324,6 +384,10 @@ class TestDistributions(unittest.TestCase): self.assertLess(np.abs(python_results[i] - perl_results[i]), self.tolerance, msg=msg) def test_raghavan2010_binary_fraction(self): + with Capturing() as output: + self._test_raghavan2010_binary_fraction() + + def _test_raghavan2010_binary_fraction(self): """ unittest for three_part_power_law """ @@ -342,6 +406,10 @@ class TestDistributions(unittest.TestCase): self.assertLess(np.abs(python_results[i] - perl_results[i]), self.tolerance, msg=msg) def test_Izzard2012_period_distribution(self): + with Capturing() as output: + self._test_Izzard2012_period_distribution() + + def _test_Izzard2012_period_distribution(self): """ unittest for three_part_power_law """ @@ -399,6 +467,10 @@ class TestDistributions(unittest.TestCase): self.assertLess(np.abs(python_results[i] - perl_results[i]), self.tolerance, msg=msg) def test_flatsections(self): + with Capturing() as output: + self._test_flatsections() + + def _test_flatsections(self): """ unittest for three_part_power_law """ @@ -426,6 +498,10 @@ class TestDistributions(unittest.TestCase): self.assertLess(np.abs(python_results[i] - perl_results[i]), self.tolerance, msg=msg) def test_sana12(self): + with Capturing() as output: + self._test_sana12() + + def _test_sana12(self): """ unittest for three_part_power_law """ diff --git a/binarycpython/tests/test_functions.py b/binarycpython/tests/test_functions.py index a3ad3d9013f4cd289541e493a15f7698374bd4ab..5f105d00d1ed12f1e12e72c9169d13c5720e1c72 100644 --- a/binarycpython/tests/test_functions.py +++ b/binarycpython/tests/test_functions.py @@ -45,13 +45,22 @@ class test_verbose_print(unittest.TestCase): Unittests for verbose_print """ + def test_print(self): + with Capturing() as output: + self._test_print() + + def _test_print(self): """ Tests whether something gets printed """ verbose_print("test1", 1, 0) def test_not_print(self): + with Capturing() as output: + self._test_not_print() + + def _test_not_print(self): """ Tests whether nothing gets printed. """ @@ -65,6 +74,10 @@ class test_remove_file(unittest.TestCase): """ def test_remove_file(self): + with Capturing() as output: + self._test_remove_file() + + def _test_remove_file(self): """ Test to remove a file """ @@ -77,6 +90,10 @@ class test_remove_file(unittest.TestCase): remove_file(os.path.join(binary_c_temp_dir, "test_remove_file_file.txt")) def test_remove_nonexisting_file(self): + with Capturing() as output: + self._test_remove_nonexisting_file() + + def _test_remove_nonexisting_file(self): """ Test to try to remove a nonexistant file """ @@ -92,6 +109,10 @@ class test_temp_dir(unittest.TestCase): """ def test_create_temp_dir(self): + with Capturing() as output: + self._test_create_temp_dir() + + def _test_create_temp_dir(self): """ Test making a temp directory and comparing that to what it should be """ @@ -113,6 +134,10 @@ class test_create_hdf5(unittest.TestCase): """ def test_1(self): + with Capturing() as output: + self._test_1() + + def _test_1(self): """ Test that creates files, packs them in a hdf5 file and checks the contents """ @@ -146,6 +171,10 @@ class test_return_binary_c_version_info(unittest.TestCase): """ def test_not_parsed(self): + with Capturing() as output: + self._test_not_parsed() + + def _test_not_parsed(self): """ Test for the raw version_info output """ @@ -158,6 +187,10 @@ class test_return_binary_c_version_info(unittest.TestCase): self.assertIn("SIGMA_THOMPSON", version_info) def test_parsed(self): + with Capturing() as output: + self._test_parsed() + + def _test_parsed(self): """ Test for the parssed version_info """ @@ -182,6 +215,10 @@ class test_parse_binary_c_version_info(unittest.TestCase): """ def test_1(self): + with Capturing() as output: + self._test_1() + + def _test_1(self): """ Test for the parsed versio info, more detailed """ @@ -215,6 +252,10 @@ class test_output_lines(unittest.TestCase): """ def test_1(self): + with Capturing() as output: + self._test_1() + + def _test_1(self): """ Test to check if the shape and contents of output_lines is correct """ @@ -234,6 +275,10 @@ class test_example_parse_output(unittest.TestCase): """ def test_normal_output(self): + with Capturing() as output: + self._test_normal_output() + + def _test_normal_output(self): """ Test checking if parsed output with a custom logging line works correctly """ @@ -263,6 +308,10 @@ class test_example_parse_output(unittest.TestCase): self.assertTrue(len(parsed_output["time"]) > 0) def test_mismatch_output(self): + with Capturing() as output: + self._test_mismatch_output() + + def _test_mismatch_output(self): """ Test checking if parsed output with a mismatching headerline doesnt have any contents """ @@ -294,6 +343,10 @@ class test_get_defaults(unittest.TestCase): """ def test_no_filter(self): + with Capturing() as output: + self._test_no_filter() + + def _test_no_filter(self): """ Test checking if the defaults without filtering contains non-filtered content """ @@ -307,6 +360,10 @@ class test_get_defaults(unittest.TestCase): self.assertIn("use_fixed_timestep_%d", output_1.keys()) def test_filter(self): + with Capturing() as output: + self._test_filter() + + def _test_filter(self): """ Test checking filtering works correctly """ @@ -326,7 +383,12 @@ class test_get_arg_keys(unittest.TestCase): Unittests for function get_arg_keys """ + def test_1(self): + with Capturing() as output: + self._test_1() + + def _test_1(self): """ Test checking if some of the keys are indeed in the list """ @@ -346,6 +408,10 @@ class test_create_arg_string(unittest.TestCase): """ def test_default(self): + with Capturing() as output: + self._test_default() + + def _test_default(self): """ Test checking if the argstring is correct """ @@ -355,6 +421,10 @@ class test_create_arg_string(unittest.TestCase): self.assertEqual(argstring, "separation 40000 M_1 10") def test_sort(self): + with Capturing() as output: + self._test_sort() + + def _test_sort(self): """ Test checking if the argstring with a different ordered dict is also in a differnt order """ @@ -364,6 +434,10 @@ class test_create_arg_string(unittest.TestCase): self.assertEqual(argstring, "M_1 10 separation 40000") def test_filtered(self): + with Capturing() as output: + self._test_filtered() + + def _test_filtered(self): """ Test if filtering works """ @@ -379,6 +453,10 @@ class test_get_help(unittest.TestCase): """ def test_input_normal(self): + with Capturing() as output: + self._test_input_normal() + + def _test_input_normal(self): """ Function to test the get_help function """ @@ -390,6 +468,10 @@ class test_get_help(unittest.TestCase): ) def test_no_input(self): + with Capturing() as output: + self._test_no_input() + + def _test_no_input(self): """ Test if the result is None if called without input """ @@ -398,6 +480,10 @@ class test_get_help(unittest.TestCase): self.assertIsNone(output) def test_wrong_input(self): + with Capturing() as output: + self._test_wrong_input() + + def _test_wrong_input(self): """ Test if the result is None if called with an unknown input """ @@ -415,6 +501,10 @@ class test_get_help_all(unittest.TestCase): """ def test_all_output(self): + with Capturing() as output: + self._test_all_output() + + def _test_all_output(self): """ Function to test the get_help_all function """ @@ -441,6 +531,10 @@ class test_get_help_super(unittest.TestCase): """ def test_all_output(self): + with Capturing() as output: + self._test_all_output() + + def _test_all_output(self): """ Function to test the get_help_super function """ @@ -467,6 +561,10 @@ class test_make_build_text(unittest.TestCase): """ def test_output(self): + with Capturing() as output: + self._test_output() + + def _test_output(self): """ Test checking the contents of the build_text """ @@ -493,6 +591,10 @@ class test_write_binary_c_parameter_descriptions_to_rst_file(unittest.TestCase): """ def test_bad_outputname(self): + with Capturing() as output: + self._test_bad_outputname() + + def _test_bad_outputname(self): """ Test checking if None is returned when a bad input name is provided """ @@ -505,6 +607,10 @@ class test_write_binary_c_parameter_descriptions_to_rst_file(unittest.TestCase): self.assertIsNone(output_1) def test_checkfile(self): + with Capturing() as output: + self._test_checkfile() + + def _test_checkfile(self): """ Test checking if the file is created correctly """ @@ -523,6 +629,10 @@ class test_inspect_dict(unittest.TestCase): """ def test_compare_dict(self): + with Capturing() as output: + self._test_compare_dict() + + def _test_compare_dict(self): """ Test checking if inspect_dict returns the correct structure by comparing it to known value """ @@ -545,6 +655,10 @@ class test_inspect_dict(unittest.TestCase): self.assertTrue(compare_dict == output_dict) def test_compare_dict_with_print(self): + with Capturing() as output: + self._test_compare_dict_with_print() + + def _test_compare_dict_with_print(self): """ Test checking output is printed """ @@ -565,6 +679,10 @@ class test_merge_dicts(unittest.TestCase): """ def test_empty(self): + with Capturing() as output: + self._test_empty() + + def _test_empty(self): """ Test merging an empty dict """ @@ -581,6 +699,10 @@ class test_merge_dicts(unittest.TestCase): self.assertTrue(output_dict == input_dict) def test_unequal_types(self): + with Capturing() as output: + self._test_unequal_types() + + def _test_unequal_types(self): """ Test merging unequal types: should raise valueError """ @@ -591,6 +713,10 @@ class test_merge_dicts(unittest.TestCase): self.assertRaises(ValueError, merge_dicts, dict_1, dict_2) def test_bools(self): + with Capturing() as output: + self._test_bools() + + def _test_bools(self): """ Test merging dict with booleans """ @@ -603,6 +729,10 @@ class test_merge_dicts(unittest.TestCase): self.assertTrue(output_dict["bool"]) def test_ints(self): + with Capturing() as output: + self._test_ints() + + def _test_ints(self): """ Test merging dict with ints """ @@ -615,6 +745,10 @@ class test_merge_dicts(unittest.TestCase): self.assertEqual(output_dict["int"], 3) def test_floats(self): + with Capturing() as output: + self._test_floats() + + def _test_floats(self): """ Test merging dict with floats """ @@ -626,6 +760,10 @@ class test_merge_dicts(unittest.TestCase): self.assertTrue(isinstance(output_dict["float"], float)) self.assertEqual(output_dict["float"], 9.1) + def test_lists(self): + with Capturing() as output: + self._test_lists() + def test_lists(self): """ Test merging dict with lists @@ -639,6 +777,10 @@ class test_merge_dicts(unittest.TestCase): self.assertEqual(output_dict["list"], [1, 2, 3, 4]) def test_dicts(self): + with Capturing() as output: + self._test_dicts() + + def _test_dicts(self): """ Test merging dict with dicts """ @@ -653,6 +795,10 @@ class test_merge_dicts(unittest.TestCase): ) def test_unsupported(self): + with Capturing() as output: + self._test_unsupported() + + def _test_unsupported(self): """ Test merging dict with unsupported types. should raise ValueError """ @@ -670,6 +816,10 @@ class test_binaryc_json_serializer(unittest.TestCase): """ def test_not_function(self): + with Capturing() as output: + self._test_not_function() + + def _test_not_function(self): """ Test passing an object that doesnt get turned in to a string """ @@ -679,6 +829,10 @@ class test_binaryc_json_serializer(unittest.TestCase): self.assertTrue(stringo == output) def test_function(self): + with Capturing() as output: + self._test_function() + + def _test_function(self): """ Test passing an object that gets turned in to a string: a function """ @@ -694,6 +848,10 @@ class test_handle_ensemble_string_to_json(unittest.TestCase): """ def test_1(self): + with Capturing() as output: + self._test_1() + + def _test_1(self): """ Test passing string representation of a dictionary. """ diff --git a/binarycpython/tests/test_grid.py b/binarycpython/tests/test_grid.py index 3eed2895cd321de603c20a25d598f182f9aa03ce..c846ccd7ecb8cb5c2b49774d4c06b558c29c6e42 100644 --- a/binarycpython/tests/test_grid.py +++ b/binarycpython/tests/test_grid.py @@ -19,6 +19,7 @@ from binarycpython.utils.functions import ( extract_ensemble_json_from_string, merge_dicts, remove_file, + Capturing, ) from binarycpython.utils.custom_logging_functions import binary_c_log_code @@ -72,6 +73,10 @@ class test_Population(unittest.TestCase): """ def test_setup(self): + with Capturing() as output: + self._test_setup() + + def _test_setup(self): """ Unittests for function _setup """ @@ -88,6 +93,10 @@ class test_Population(unittest.TestCase): self.assertTrue(isinstance(test_pop.grid_options["_main_pid"], int)) def test_set(self): + with Capturing() as output: + self._test_set() + + def _test_set(self): """ Unittests for function set """ @@ -109,6 +118,10 @@ class test_Population(unittest.TestCase): self.assertTrue(test_pop.grid_options["amt_cores"] == 2) def test_cmdline(self): + with Capturing() as output: + self._test_cmdline() + + def _test_cmdline(self): """ Unittests for function parse_cmdline """ @@ -146,6 +159,10 @@ class test_Population(unittest.TestCase): sys.argv = prev_sysargv.copy() def test__return_argline(self): + with Capturing() as output: + self._test__return_argline() + + def _test__return_argline(self): """ Unittests for the function _return_argline """ @@ -167,6 +184,10 @@ class test_Population(unittest.TestCase): ) def test_add_grid_variable(self): + with Capturing() as output: + self._test_add_grid_variable() + + def _test_add_grid_variable(self): """ Unittests for the function add_grid_variable @@ -210,6 +231,10 @@ class test_Population(unittest.TestCase): self.assertEqual(len(test_pop.grid_options["_grid_variables"]), 2) def test_return_population_settings(self): + with Capturing() as output: + self._test_return_population_settings() + + def _test_return_population_settings(self): """ Unittests for the function return_population_settings """ @@ -233,6 +258,10 @@ class test_Population(unittest.TestCase): self.assertTrue(population_settings["custom_options"]["data_dir"] == "/tmp") def test__return_binary_c_version_info(self): + with Capturing() as output: + self._test__return_binary_c_version_info() + + def _test__return_binary_c_version_info(self): """ Unittests for the function _return_binary_c_version_info """ @@ -260,6 +289,10 @@ class test_Population(unittest.TestCase): self.assertIsNotNone(binary_c_version_info["nucleosynthesis_sources"]) def test__return_binary_c_defaults(self): + with Capturing() as output: + self._test__return_binary_c_defaults() + + def _test__return_binary_c_defaults(self): """ Unittests for the function _return_binary_c_defaults """ @@ -271,6 +304,10 @@ class test_Population(unittest.TestCase): self.assertIn("metallicity", binary_c_defaults) def test_return_all_info(self): + with Capturing() as output: + self._test_return_all_info() + + def _test_return_all_info(self): """ Unittests for the function return_all_info Not going to do too much tests here, just check if they are not empty @@ -290,6 +327,10 @@ class test_Population(unittest.TestCase): self.assertNotEqual(all_info["binary_c_help_all"], {}) def test_export_all_info(self): + with Capturing() as output: + self._test_export_all_info() + + def _test_export_all_info(self): """ Unittests for the function export_all_info """ @@ -350,6 +391,10 @@ class test_Population(unittest.TestCase): ) def test__cleanup_defaults(self): + with Capturing() as output: + self._test__cleanup_defaults() + + def _test__cleanup_defaults(self): """ Unittests for the function _cleanup_defaults """ @@ -359,6 +404,10 @@ class test_Population(unittest.TestCase): self.assertNotIn("help_all", cleaned_up_defaults) def test__increment_probtot(self): + with Capturing() as output: + self._test__increment_probtot() + + def _test__increment_probtot(self): """ Unittests for the function _increment_probtot """ @@ -368,6 +417,10 @@ class test_Population(unittest.TestCase): self.assertEqual(test_pop.grid_options["_probtot"], 0.5) def test__increment_count(self): + with Capturing() as output: + self._test__increment_count() + + def _test__increment_count(self): """ Unittests for the function _increment_probtot """ @@ -377,6 +430,10 @@ class test_Population(unittest.TestCase): self.assertEqual(test_pop.grid_options["_count"], 1) def test__dict_from_line_source_file(self): + with Capturing() as output: + self._test__dict_from_line_source_file() + + def _test__dict_from_line_source_file(self): """ Unittests for the function _dict_from_line_source_file """ @@ -398,6 +455,10 @@ class test_Population(unittest.TestCase): self.assertTrue(argdict["metallicity"] == 0.02) def test_evolve_single(self): + with Capturing() as output: + self._test_evolve_single() + + def _test_evolve_single(self): """ Unittests for the function evolve_single """ @@ -459,6 +520,10 @@ class test_grid_evolve(unittest.TestCase): """ def test_grid_evolve_1_thread(self): + with Capturing() as output: + self._test_grid_evolve_1_thread() + + def _test_grid_evolve_1_thread(self): """ Unittests to see if 1 thread does all the systems """ @@ -492,6 +557,10 @@ class test_grid_evolve(unittest.TestCase): self.assertTrue(analytics["total_count"] == 10) def test_grid_evolve_2_threads(self): + with Capturing() as output: + self._test_grid_evolve_2_threads() + + def _test_grid_evolve_2_threads(self): """ Unittests to see if multiple threads handle the all the systems correctly """ @@ -523,6 +592,10 @@ class test_grid_evolve(unittest.TestCase): self.assertTrue(analytics["total_count"] == 10) def test_grid_evolve_2_threads_with_custom_logging(self): + with Capturing() as output: + self._test_grid_evolve_2_threads_with_custom_logging() + + def _test_grid_evolve_2_threads_with_custom_logging(self): """ Unittests to see if multiple threads do the custom logging correctly """ @@ -582,6 +655,10 @@ class test_grid_evolve(unittest.TestCase): remove_file(output_name) def test_grid_evolve_with_condition_error(self): + with Capturing() as output: + self._test_grid_evolve_with_condition_error() + + def _test_grid_evolve_with_condition_error(self): """ Unittests to see if the threads catch the errors correctly. """ @@ -593,7 +670,7 @@ class test_grid_evolve(unittest.TestCase): test_pop.set(failed_systems_threshold=4) CUSTOM_LOGGING_STRING_WITH_EXIT = """ - Exit_binary_c(BINARY_C_NORMAL_EXIT, "testing exits"); + Exit_binary_c(BINARY_C_NORMAL_EXIT, "testing exits. This is part of the testing, don't worry"); Printf("TEST_CUSTOM_LOGGING_1 %30.12e %g %g %g %g\\n", // stardata->model.time, // 1 @@ -675,6 +752,10 @@ class test_grid_evolve(unittest.TestCase): self.assertRaises(ValueError, test_pop.evolve) def test_grid_evolve_no_grid_variables(self): + with Capturing() as output: + self._test_grid_evolve_no_grid_variables() + + def _test_grid_evolve_no_grid_variables(self): """ Unittests to see if errors are raised if there are no grid variables """ @@ -686,6 +767,10 @@ class test_grid_evolve(unittest.TestCase): self.assertRaises(ValueError, test_pop.evolve) def test_grid_evolve_2_threads_with_ensemble_direct_output(self): + with Capturing() as output: + self._test_grid_evolve_2_threads_with_ensemble_direct_output() + + def _test_grid_evolve_2_threads_with_ensemble_direct_output(self): """ Unittests to see if multiple threads output the ensemble information to files correctly """ @@ -755,6 +840,10 @@ class test_grid_evolve(unittest.TestCase): self.assertNotEqual(ensemble_json["number_counts"], {}) def test_grid_evolve_2_threads_with_ensemble_combining(self): + with Capturing() as output: + self._test_grid_evolve_2_threads_with_ensemble_combining() + + def _test_grid_evolve_2_threads_with_ensemble_combining(self): """ Unittests to see if multiple threads correclty combine the ensemble data and store them in the grid """ @@ -807,6 +896,10 @@ class test_grid_evolve(unittest.TestCase): ) def test_grid_evolve_2_threads_with_ensemble_comparing_two_methods(self): + with Capturing() as output: + self._test_grid_evolve_2_threads_with_ensemble_comparing_two_methods() + + def _test_grid_evolve_2_threads_with_ensemble_comparing_two_methods(self): """ Unittests to compare the method of storing the combined ensemble data in the object and writing them to files and combining them later. they have to be the same """ diff --git a/binarycpython/tests/test_grid_options_defaults.py b/binarycpython/tests/test_grid_options_defaults.py index d344e32163e296fa9dee4149903971b152f2f2cf..bc941176181ac4526d7bb59f10a0c02600ef1e0e 100644 --- a/binarycpython/tests/test_grid_options_defaults.py +++ b/binarycpython/tests/test_grid_options_defaults.py @@ -5,6 +5,7 @@ Unittests for grid_options_defaults module import unittest from binarycpython.utils.grid_options_defaults import * +from binarycpython.utils.functions import * binary_c_temp_dir = temp_dir() @@ -15,6 +16,10 @@ class test_grid_options_defaults(unittest.TestCase): """ def test_grid_options_help(self): + with Capturing() as output: + self._test_grid_options_help() + + def _test_grid_options_help(self): """ Unit tests for the grid_options_help function """ @@ -44,6 +49,10 @@ class test_grid_options_defaults(unittest.TestCase): # self.assertEqual(result_3[input_3], "", msg="description should be empty") def test_grid_options_description_checker(self): + with Capturing() as output: + self._test_grid_options_description_checker() + + def _test_grid_options_description_checker(self): """ Unit tests for the grid_options_description_checker function """ @@ -54,6 +63,10 @@ class test_grid_options_defaults(unittest.TestCase): self.assertTrue(output_1 > 0) def test_write_grid_options_to_rst_file(self): + with Capturing() as output: + self._test_write_grid_options_to_rst_file() + + def _test_write_grid_options_to_rst_file(self): """ Unit tests for the grid_options_description_checker function """ diff --git a/binarycpython/tests/test_hpc_functions.py b/binarycpython/tests/test_hpc_functions.py index ec173924927700601d45d3c9c88949a5679e9149..ddba58b7b5a925d0a2dcaef565a9608b37559fb9 100644 --- a/binarycpython/tests/test_hpc_functions.py +++ b/binarycpython/tests/test_hpc_functions.py @@ -3,3 +3,5 @@ Unittests for hpc_functions module """ from binarycpython.utils.hpc_functions import * + +# TODO: write tests for hpc functions \ No newline at end of file diff --git a/binarycpython/tests/test_plot_functions.py b/binarycpython/tests/test_plot_functions.py index 30b813a62332e118a84624ed94ef793b2d404b96..c8e9a1304432c543a7a0e79b3ebfa3250694d7e5 100644 --- a/binarycpython/tests/test_plot_functions.py +++ b/binarycpython/tests/test_plot_functions.py @@ -4,9 +4,11 @@ Unittests for plot_functions import unittest import numpy as np -from binarycpython.utils.plot_functions import * import matplotlib.pyplot as plt +from binarycpython.utils.plot_functions import * +from binarycpython.utils.functions import Capturing + # class test_(unittest.TestCase): # """ # Unittests for function @@ -15,13 +17,16 @@ import matplotlib.pyplot as plt # def test_1(self): # pass - class test_color_by_index(unittest.TestCase): """ Unittests for function color_by_index """ def test_1(self): + with Capturing() as output: + self._test_1() + + def _test_1(self): """ First test """ @@ -38,6 +43,10 @@ class test_plot_system(unittest.TestCase): """ def test_mass_evolution_plot(self): + with Capturing() as output: + self._test_mass_evolution_plot() + + def _test_mass_evolution_plot(self): """ Test for setting plot_type = "mass_evolution" """ @@ -71,6 +80,10 @@ class test_plot_system(unittest.TestCase): # output_fig_2 = plot_system(plot_type, show_plot=show_plot, M_1=1, metallicity=0.002, M_2=0.1, separation=0, orbital_period=100000000000) def test_orbit_evolution_plot(self): + with Capturing() as output: + self._test_orbit_evolution_plot() + + def _test_orbit_evolution_plot(self): """ Test for setting plot_type = "orbit_evolution" """ @@ -104,6 +117,10 @@ class test_plot_system(unittest.TestCase): # output_fig_2 = plot_system(plot_type, show_plot=show_plot, M_1=1, metallicity=0.002, M_2=0.1, separation=0, orbital_period=100000000000) def test_hr_diagram_plot(self): + with Capturing() as output: + self._test_hr_diagram_plot() + + def _test_hr_diagram_plot(self): """ Test for setting plot_type = "hr_diagram" """ @@ -137,6 +154,10 @@ class test_plot_system(unittest.TestCase): # output_fig_2 = plot_system(plot_type, show_plot=show_plot, M_1=1, metallicity=0.002, M_2=0.1, separation=0, orbital_period=100000000000) def test_unknown_plottype(self): + with Capturing() as output: + self._test_unknown_plottype() + + def _test_unknown_plottype(self): """ Test for non-existant setting plot_type = "hr_diagram" """ diff --git a/binarycpython/tests/test_run_system_wrapper.py b/binarycpython/tests/test_run_system_wrapper.py index 97558a9607b2bfa180da5445a96505b7ed92515a..87f0a5e0d54b59bacf821643ab06372113e9518b 100644 --- a/binarycpython/tests/test_run_system_wrapper.py +++ b/binarycpython/tests/test_run_system_wrapper.py @@ -3,3 +3,6 @@ Unittests for run_system_wrapper """ from binarycpython.utils.run_system_wrapper import * +from binarycpython.utils.functions import * + +# TODO: write tests for run_system_wrapper \ No newline at end of file diff --git a/binarycpython/tests/test_spacing_functions.py b/binarycpython/tests/test_spacing_functions.py index 72b6f3a63acfa5563c4057893e3ff5b82ad91adb..cd50b86a177458aa2384f772bf68dde0fdc6d7ee 100644 --- a/binarycpython/tests/test_spacing_functions.py +++ b/binarycpython/tests/test_spacing_functions.py @@ -6,7 +6,7 @@ Unittests for spacing_functions module import unittest import numpy as np from binarycpython.utils.spacing_functions import * - +from binarycpython.utils.functions import * class test_spacing_functions(unittest.TestCase): """ @@ -14,6 +14,10 @@ class test_spacing_functions(unittest.TestCase): """ def test_const(self): + with Capturing() as output: + self._test_const() + + def _test_const(self): """ Unittest for function const """ diff --git a/binarycpython/tests/test_stellar_types.py b/binarycpython/tests/test_stellar_types.py index 7091211b5fa19af97716a067d6f3bab6a9c8a09d..9ed515f7a726aecceda6c51621606d95994f90c0 100644 --- a/binarycpython/tests/test_stellar_types.py +++ b/binarycpython/tests/test_stellar_types.py @@ -5,3 +5,4 @@ Unittests for stellar_types module import unittest from binarycpython.utils.stellar_types import * +from binarycpython.utils.functions import * diff --git a/binarycpython/tests/test_useful_funcs.py b/binarycpython/tests/test_useful_funcs.py index d7f77d1d9d7a5bcea63ff31d3f8d0f657e7db51e..64531c4b3bdefb2a2ebc208dd6d3833e3d90a3ca 100644 --- a/binarycpython/tests/test_useful_funcs.py +++ b/binarycpython/tests/test_useful_funcs.py @@ -4,7 +4,9 @@ Unittests for useful_funcs module import unittest import numpy as np + from binarycpython.utils.useful_funcs import * +from binarycpython.utils.functions import * # class test_(unittest.TestCase): # """ @@ -23,6 +25,10 @@ class test_calc_period_from_sep(unittest.TestCase): """ def test_1(self): + with Capturing() as output: + self._test_1() + + def _test_1(self): """ First test """ @@ -39,6 +45,10 @@ class test_calc_sep_from_period(unittest.TestCase): """ def test_1(self): + with Capturing() as output: + self._test_1() + + def _test_1(self): """ First test """ @@ -53,6 +63,10 @@ class test_roche_lobe(unittest.TestCase): """ def test_1(self): + with Capturing() as output: + self._test_1() + + def _test_1(self): """ First test """ @@ -72,6 +86,10 @@ class test_ragb(unittest.TestCase): """ def test_1(self): + with Capturing() as output: + self._test_1() + + def _test_1(self): """ First test """ @@ -88,6 +106,10 @@ class test_rzams(unittest.TestCase): """ def test_1(self): + with Capturing() as output: + self._test_1() + + def _test_1(self): """ First test """ @@ -117,6 +139,10 @@ class test_zams_collission(unittest.TestCase): """ def test_1(self): + with Capturing() as output: + self._test_1() + + def _test_1(self): """ First test """ diff --git a/binarycpython/utils/functions.py b/binarycpython/utils/functions.py index c73091b424fa0b0bf91df2fd4ab53ee0d3b18e9b..adad1a08ea636b3edd296ef8fc093b05839bb623 100644 --- a/binarycpython/utils/functions.py +++ b/binarycpython/utils/functions.py @@ -14,15 +14,34 @@ import tempfile import copy import inspect import ast -from typing import Union, Any -from collections import defaultdict +import sys import h5py import numpy as np +from io import StringIO +from typing import Union, Any +from collections import defaultdict + from binarycpython import _binary_c_bindings +def is_capsule(o): + t = type(o) + return t.__module__ == 'builtins' and t.__name__ == 'PyCapsule' + +class Capturing(list): + def __enter__(self): + self._stdout = sys.stdout + sys.stdout = self._stringio = StringIO() + return self + def __exit__(self, *args): + self.extend(self._stringio.getvalue().splitlines()) + del self._stringio # free up some memory + sys.stdout = self._stdout + + + ######################################################## # utility functions ########################################################