""" Unittests for the c-bindings """ import os import sys import time import json import textwrap import unittest import numpy as np from binarycpython import _binary_c_bindings from binarycpython.utils.functions import ( binarycDecoder, temp_dir, inspect_dict, merge_dicts, handle_ensemble_string_to_json, verbose_print, extract_ensemble_json_from_string, is_capsule, Capturing, ) # https://docs.python.org/3/library/unittest.html TMP_DIR = temp_dir() os.makedirs(os.path.join(TMP_DIR, "test"), exist_ok=True) #### some useful functions def return_argstring( m1=15.0, m2=14.0, separation=0, orbital_period=453000000000, eccentricity=0.0, metallicity=0.02, max_evolution_time=15000, defer_ensemble=0, ensemble_filters_off=1, ensemble_filter="SUPERNOVAE", ): """ Function to make a argstring that we can use in these tests """ # Make the argstrings argstring_template = "binary_c M_1 {0:g} M_2 {1:g} separation {2:g} orbital_period {3:g} \ eccentricity {4:g} metallicity {5:g} max_evolution_time {6:g} ensemble 1 ensemble_defer {7} \ ensemble_filters_off {8} ensemble_filter_{9} 1 probability 0.1" argstring = argstring_template.format( m1, m2, separation, orbital_period, eccentricity, metallicity, max_evolution_time, defer_ensemble, ensemble_filters_off, ensemble_filter, ) return argstring ####################################################################################################################################################### ### General run_system test ####################################################################################################################################################### class test_run_system(unittest.TestCase): """ Unit test for run_system """ def test_output(self): with Capturing() as output: self._test_output() def _test_output(self): """ General test if run_system works """ print(self.id()) m1 = 15.0 # Msun m2 = 14.0 # Msun separation = 0 # 0 = ignored, use period orbital_period = 4530.0 # days eccentricity = 0.0 metallicity = 0.02 max_evolution_time = 15000 argstring = "binary_c M_1 {0:g} M_2 {1:g} separation {2:g} orbital_period {3:g} eccentricity {4:g} metallicity {5:g} max_evolution_time {6:g} ".format( m1, m2, separation, orbital_period, eccentricity, metallicity, max_evolution_time, ) output = _binary_c_bindings.run_system(argstring=argstring) self.assertIn( "SINGLE_STAR_LIFETIME", output, msg="Output didn't contain SINGLE_STAR_LIFETIME", ) ####################################################################################################################################################### ### memaddr test ####################################################################################################################################################### # TODO: Make some assertion tests in c class test_return_store_memaddr(unittest.TestCase): """ Unit test for return_store_memaddr """ def test_return_store_memaddr(self): with Capturing() as output: self._test_return_store_memaddr() def _test_return_store_memaddr(self): """ Test to see if the memory adress is returned properly """ output = _binary_c_bindings.return_store_memaddr() # print("function: test_return_store") # print("store memory adress:") # print(textwrap.indent(str(output), "\t")) self.assertTrue(is_capsule(output)) # self.assertNotEqual( # output, 0, "memory adress seems not to have a correct value" # ) # TODO: check if we can built in some signal for how successful this was. _ = _binary_c_bindings.free_store_memaddr(output) ####################################################################################################################################################### ### ensemble tests ####################################################################################################################################################### class TestEnsemble(unittest.TestCase): """ Unittests for handling the ensemble outputs and adding those """ def __init__(self, *args, **kwargs): """ init function """ super(TestEnsemble, self).__init__(*args, **kwargs) def test_return_persistent_data_memaddr(self): with Capturing() as output: self._test_return_persistent_data_memaddr() def _test_return_persistent_data_memaddr(self): """ Test case to check if the memory adress has been created succesfully """ print(self.id()) output = _binary_c_bindings.return_persistent_data_memaddr() self.assertTrue(is_capsule(output), msg="Object must be an integer") # self.assertNotEqual( # output, 0, "memory adress seems not to have a correct value" # ) def test_minimal_ensemble_output(self): with Capturing() as output: self._test_minimal_ensemble_output() def _test_minimal_ensemble_output(self): """ test_case to check if the ensemble output is correctly output """ print(self.id()) m1 = 2 # Msun m2 = 0.1 # Msun # Direct output commands argstring_1 = return_argstring( m1=m1, m2=m2, ensemble_filter="STELLAR_TYPE_COUNTS", defer_ensemble=0, # no defer to memory location. just output it ) output_1 = _binary_c_bindings.run_system(argstring=argstring_1) # Check if the ENSEMBLE_JSON is uberhaubt in the output self.assertIn("ENSEMBLE_JSON", output_1) test_json = extract_ensemble_json_from_string(output_1) self.assertIn("number_counts", test_json) self.assertNotEqual(test_json["number_counts"], {}) def test_minimal_ensemble_output_defer(self): with Capturing() as output: self._test_minimal_ensemble_output_defer() def _test_minimal_ensemble_output_defer(self): """ test_case to check if the ensemble output is correctly output, by using defer command and freeing+outputting """ print(self.id()) m1 = 2 # Msun m2 = 0.1 # Msun persistent_data_memaddr = _binary_c_bindings.return_persistent_data_memaddr() # Direct output commands argstring_1 = return_argstring( m1=m1, m2=m2, orbital_period=1000000000, ensemble_filter="STELLAR_TYPE_COUNTS", defer_ensemble=1, # no defer to memory location. just output it ) output_1 = _binary_c_bindings.run_system( argstring=argstring_1, persistent_data_memaddr=persistent_data_memaddr ) # self.assertNotIn("ENSEMBLE_JSON", output_1) # free memory and output the stuff. raw_json_output = ( _binary_c_bindings.free_persistent_data_memaddr_and_return_json_output( persistent_data_memaddr ) ) ensemble_json_output = extract_ensemble_json_from_string(raw_json_output) self.assertIn("number_counts", ensemble_json_output) self.assertNotEqual(ensemble_json_output["number_counts"], {}) def test_add_ensembles_direct(self): with Capturing() as output: self._test_add_ensembles_direct() def _test_add_ensembles_direct(self): """ test_case to check if adding the ensemble outputs works. Many things should be caught by tests in the merge_dict test, but still good to test a bit here """ print(self.id()) m1 = 2 # Msun m2 = 0.1 # Msun # Direct output commands argstring_1 = return_argstring( m1=m1, m2=m2, orbital_period=1000000000, ensemble_filter="STELLAR_TYPE_COUNTS", # no defer to memory location. just output it ) argstring_2 = return_argstring( m1=10, m2=m2, orbital_period=1000000000, ensemble_filter="STELLAR_TYPE_COUNTS", # no defer to memory location. just output it ) # output_1 = _binary_c_bindings.run_system(argstring=argstring_1) output_2 = _binary_c_bindings.run_system(argstring=argstring_2) # output_json_1 = extract_ensemble_json_from_string(output_1) output_json_2 = extract_ensemble_json_from_string(output_2) # merged_dict = merge_dicts(output_json_1, output_json_2) self.assertIn("number_counts", merged_dict) self.assertIn("stellar_type", merged_dict["number_counts"]) for key in output_json_1["number_counts"]["stellar_type"]["0"]: self.assertIn(key, merged_dict["number_counts"]["stellar_type"]["0"]) for key in output_json_2["number_counts"]["stellar_type"]["0"]: self.assertIn(key, merged_dict["number_counts"]["stellar_type"]["0"]) # compare stuff: self.assertLess( np.abs( output_json_1["number_counts"]["stellar_type"]["0"]["CHeB"] + output_json_2["number_counts"]["stellar_type"]["0"]["CHeB"] - merged_dict["number_counts"]["stellar_type"]["0"]["CHeB"] ), 1e-10, ) self.assertLess( np.abs( output_json_1["number_counts"]["stellar_type"]["0"]["MS"] + output_json_2["number_counts"]["stellar_type"]["0"]["MS"] - merged_dict["number_counts"]["stellar_type"]["0"]["MS"] ), 1e-10, ) def test_compare_added_systems_with_double_deferred_systems(self): with Capturing() as output: self._test_compare_added_systems_with_double_deferred_systems() def _test_compare_added_systems_with_double_deferred_systems(self): """ test to run 2 systems without deferring, and merging them manually. Then run 2 systems with defer and then output them. """ print(self.id()) m1 = 2 # Msun m2 = 0.1 # Msun # Direct output commands argstring_1 = return_argstring( m1=m1, m2=m2, orbital_period=1000000000, ensemble_filter="STELLAR_TYPE_COUNTS", # no defer to memory location. just output it ) argstring_2 = return_argstring( m1=10, m2=m2, orbital_period=1000000000, ensemble_filter="STELLAR_TYPE_COUNTS", # no defer to memory location. just output it ) # output_1 = _binary_c_bindings.run_system(argstring=argstring_1) output_2 = _binary_c_bindings.run_system(argstring=argstring_2) # output_json_1 = extract_ensemble_json_from_string(output_1) output_json_2 = extract_ensemble_json_from_string(output_2) # merged_dict = merge_dicts(output_json_1, output_json_2) ############################### # Deferred setup persistent_data_memaddr = _binary_c_bindings.return_persistent_data_memaddr() argstring_1_deferred = return_argstring( m1=m1, m2=m2, orbital_period=1000000000, ensemble_filter="STELLAR_TYPE_COUNTS", defer_ensemble=1, # no defer to memory location. just output it ) argstring_2_deferred = return_argstring( m1=10, m2=m2, orbital_period=1000000000, ensemble_filter="STELLAR_TYPE_COUNTS", defer_ensemble=1, # no defer to memory location. just output it ) # run _ = _binary_c_bindings.run_system( argstring=argstring_1_deferred, persistent_data_memaddr=persistent_data_memaddr, ) _ = _binary_c_bindings.run_system( argstring=argstring_2_deferred, persistent_data_memaddr=persistent_data_memaddr, ) # output raw_json_output = ( _binary_c_bindings.free_persistent_data_memaddr_and_return_json_output( persistent_data_memaddr ) ) ensemble_json_output = extract_ensemble_json_from_string(raw_json_output) # CHeck all keys are present for key in merged_dict["number_counts"]["stellar_type"]["0"]: self.assertIn( key, ensemble_json_output["number_counts"]["stellar_type"]["0"] ) # Check if they are of the same value for key in merged_dict["number_counts"]["stellar_type"]["0"]: self.assertLess( np.abs( merged_dict["number_counts"]["stellar_type"]["0"][key] - ensemble_json_output["number_counts"]["stellar_type"]["0"][key] ), 1e-10, ) def test_combine_with_empty_json(self): with Capturing() as output: self._test_combine_with_empty_json() def _test_combine_with_empty_json(self): """ Test for merging with an empty dict """ print(self.id()) m1 = 2 # Msun m2 = 0.1 # Msun argstring_1 = return_argstring( m1=m1, m2=m2, orbital_period=1000000000, ensemble_filter="STELLAR_TYPE_COUNTS", # no defer to memory location. just output it ) output_1 = _binary_c_bindings.run_system(argstring=argstring_1) output_json_1 = extract_ensemble_json_from_string(output_1) assert_message = "combining output json with empty dict should give same result as initial json" self.assertEqual(merge_dicts(output_json_1, {}), output_json_1, assert_message) ############# # def test_full_ensemble_output(self): # with Capturing() as output: # self._test_full_ensemble_output() def _test_full_ensemble_output(self): """ Function to just output the whole ensemble TODO: put this one back """ print(self.id()) m1 = 2 # Msun m2 = 0.1 # Msun argstring_1 = return_argstring( m1=m1, m2=m2, orbital_period=1000000000, ensemble_filter=0 ) argstring_1 = return_argstring(defer_ensemble=0, ensemble_filters_off=0) output_1 = _binary_c_bindings.run_system(argstring=argstring_1) # output_json_1 = extract_ensemble_json_from_string(output_1) keys = json_1.keys() # assert statements: self.assertIn("number_counts", keys) self.assertIn("HRD", keys) self.assertIn("HRD(t)", keys) self.assertIn("Xyield", keys) self.assertIn("distributions", keys) self.assertIn("scalars", keys) ####################################################################################################################################################### ### ensemble tests ####################################################################################################################################################### if __name__ == "__main__": unittest.main()