From 56aef52e041cd8f31a447dfd58bd54c1143351a7 Mon Sep 17 00:00:00 2001 From: David Hendriks <davidhendriks93@gmail.com> Date: Tue, 12 Jan 2021 15:17:56 +0000 Subject: [PATCH] added tests for binarycpython grid --- binarycpython/tests/test_grid.py | 1238 +++++++++++++++++------------- 1 file changed, 702 insertions(+), 536 deletions(-) diff --git a/binarycpython/tests/test_grid.py b/binarycpython/tests/test_grid.py index 848380e4b..04f9a8d56 100644 --- a/binarycpython/tests/test_grid.py +++ b/binarycpython/tests/test_grid.py @@ -14,11 +14,32 @@ import datetime import numpy as np from binarycpython.utils.grid import Population -from binarycpython.utils.functions import temp_dir +from binarycpython.utils.functions import temp_dir, extract_ensemble_json_from_string, merge_dicts, remove_file from binarycpython.utils.custom_logging_functions import binary_c_log_code binary_c_temp_dir = temp_dir() +def parse_function_test_grid_evolve_2_threads_with_custom_logging(self, output): + """ + Simple parse function that directly appends all the output to a file + """ + + # Get some information from the + data_dir = self.custom_options['data_dir'] + + # make outputfilename + output_filename = os.path.join(data_dir, "test_grid_evolve_2_threads_with_custom_logging_outputfile_population_{}_thread_{}.dat".format(self.grid_options['_population_id'], self.process_ID)) + + # Check directory, make if necessary + os.makedirs(data_dir, exist_ok=True) + + if not os.path.exists(output_filename): + with open(output_filename, 'w') as first_f: + first_f.write(output+'\n') + else: + with open(output_filename, 'a') as first_f: + first_f.write(output+'\n') + # class test_(unittest.TestCase): # """ @@ -34,561 +55,584 @@ binary_c_temp_dir = temp_dir() # """ -# class test_Population(unittest.TestCase): -# """ -# Unittests for function -# """ +class test_Population(unittest.TestCase): + """ + Unittests for function + """ + + def test_setup(self): + test_pop = Population() + + self.assertTrue("orbital_period" in test_pop.defaults) + self.assertTrue("metallicity" in test_pop.defaults) + self.assertNotIn("help_all", test_pop.cleaned_up_defaults) + self.assertEqual(test_pop.bse_options, {}) + self.assertEqual(test_pop.custom_options, {}) + self.assertEqual(test_pop.argline_dict, {}) + self.assertEqual(test_pop.persistent_data_memory_dict, {}) + self.assertTrue(test_pop.grid_options["parse_function"] == None) + self.assertTrue(isinstance(test_pop.grid_options["_main_pid"], int)) + + def test_set(self): + test_pop = Population() + test_pop.set(amt_cores=2) + test_pop.set(M_1=10) + test_pop.set(data_dir="/tmp/binary_c_python") + test_pop.set(ensemble_filter_SUPERNOVAE=1) + + self.assertIn("data_dir", test_pop.custom_options) + self.assertEqual(test_pop.custom_options["data_dir"], "/tmp/binary_c_python") + + # + self.assertTrue(test_pop.bse_options["M_1"] == 10) + self.assertTrue(test_pop.bse_options["ensemble_filter_SUPERNOVAE"] == 1) + + # + self.assertTrue(test_pop.grid_options["amt_cores"] == 2) + + def test_cmdline(self): + # copy old sys.argv values + prev_sysargv = sys.argv.copy() + + # make a dummy cmdline arg input + sys.argv = [ + "script", + "--cmdline", + "metallicity=0.0002 amt_cores=2 data_dir=/tmp/binary_c_python", + ] + + # Set up population + test_pop = Population() + test_pop.set(data_dir="/tmp") + + # parse arguments + test_pop.parse_cmdline() + + # metallicity + self.assertTrue(isinstance(test_pop.bse_options["metallicity"], str)) + self.assertTrue(test_pop.bse_options["metallicity"] == "0.0002") + + # Amt cores + self.assertTrue(isinstance(test_pop.grid_options["amt_cores"], int)) + self.assertTrue(test_pop.grid_options["amt_cores"] == 2) + + # datadir + self.assertTrue(isinstance(test_pop.custom_options["data_dir"], str)) + self.assertTrue(test_pop.custom_options["data_dir"] == "/tmp/binary_c_python") + + # put back the other args if they exist + sys.argv = prev_sysargv.copy() + + def test__return_argline(self): + """ + Unittests for the function _return_argline + """ + + # Set up population + test_pop = Population() + test_pop.set(metallicity=0.02) + test_pop.set(M_1=10) + + argline = test_pop._return_argline() + self.assertTrue(argline == "binary_c M_1 10 metallicity 0.02") + + # custom dict + argline2 = test_pop._return_argline( + {"example_parameter1": 10, "example_parameter2": "hello"} + ) + self.assertTrue( + argline2 == "binary_c example_parameter1 10 example_parameter2 hello" + ) + + def test_add_grid_variable(self): + """ + Unittests for the function add_grid_variable + + TODO: Should I test more here? + """ + + test_pop = Population() + + resolution = {"M_1": 10, "q": 10} + + test_pop.add_grid_variable( + name="lnm1", + longname="Primary mass", + valuerange=[1, 100], + resolution="{}".format(resolution["M_1"]), + spacingfunc="const(math.log(1), math.log(100), {})".format( + resolution["M_1"] + ), + precode="M_1=math.exp(lnm1)", + probdist="three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1", + dphasevol="dlnm1", + parameter_name="M_1", + condition="", # Impose a condition on this grid variable. Mostly for a check for yourself + ) + + test_pop.add_grid_variable( + name="q", + longname="Mass ratio", + valuerange=["0.1/M_1", 1], + resolution="{}".format(resolution["q"]), + spacingfunc="const(0.1/M_1, 1, {})".format(resolution["q"]), + probdist="flatsections(q, [{'min': 0.1/M_1, 'max': 1.0, 'height': 1}])", + dphasevol="dq", + precode="M_2 = q * M_1", + parameter_name="M_2", + condition="", # Impose a condition on this grid variable. Mostly for a check for yourself + ) + + self.assertIn("q", test_pop.grid_options["_grid_variables"]) + self.assertIn("lnm1", test_pop.grid_options["_grid_variables"]) + self.assertEqual(len(test_pop.grid_options["_grid_variables"]), 2) + + def test_return_population_settings(self): + """ + Unittests for the function return_population_settings + """ + + test_pop = Population() + test_pop.set(metallicity=0.02) + test_pop.set(M_1=10) + test_pop.set(amt_cores=2) + test_pop.set(data_dir="/tmp") + + population_settings = test_pop.return_population_settings() + + self.assertIn("bse_options", population_settings) + self.assertTrue(population_settings["bse_options"]["metallicity"] == 0.02) + self.assertTrue(population_settings["bse_options"]["M_1"] == 10) + + self.assertIn("grid_options", population_settings) + self.assertTrue(population_settings["grid_options"]["amt_cores"] == 2) + + self.assertIn("custom_options", population_settings) + self.assertTrue(population_settings["custom_options"]["data_dir"] == "/tmp") + + def test__return_binary_c_version_info(self): + """ + Unittests for the function _return_binary_c_version_info + """ + + test_pop = Population() + binary_c_version_info = test_pop._return_binary_c_version_info(parsed=True) + + self.assertTrue(isinstance(binary_c_version_info, dict)) + self.assertIn("isotopes", binary_c_version_info) + self.assertIn("argpairs", binary_c_version_info) + self.assertIn("ensembles", binary_c_version_info) + self.assertIn("macros", binary_c_version_info) + self.assertIn("dt_limits", binary_c_version_info) + self.assertIn("nucleosynthesis_sources", binary_c_version_info) + self.assertIn("miscellaneous", binary_c_version_info) + + self.assertIsNotNone(binary_c_version_info["argpairs"]) + self.assertIsNotNone(binary_c_version_info["ensembles"]) + self.assertIsNotNone(binary_c_version_info["macros"]) + self.assertIsNotNone(binary_c_version_info["dt_limits"]) + self.assertIsNotNone(binary_c_version_info["miscellaneous"]) + + if binary_c_version_info['miscellaneous']['NUCSYN'] == 'on': + self.assertIsNotNone(binary_c_version_info["isotopes"]) + self.assertIsNotNone(binary_c_version_info["nucleosynthesis_sources"]) + + def test__return_binary_c_defaults(self): + """ + Unittests for the function _return_binary_c_defaults + """ + + test_pop = Population() + binary_c_defaults = test_pop._return_binary_c_defaults() + self.assertIn("probability", binary_c_defaults) + self.assertIn("phasevol", binary_c_defaults) + self.assertIn("metallicity", binary_c_defaults) + + def test_return_all_info(self): + """ + Unittests for the function return_all_info + Not going to do too much tests here, just check if they are not empty + """ + + test_pop = Population() + all_info = test_pop.return_all_info() + + self.assertIn("population_settings", all_info) + self.assertIn("binary_c_defaults", all_info) + self.assertIn("binary_c_version_info", all_info) + self.assertIn("binary_c_help_all", all_info) -# def test_setup(self): -# test_pop = Population() - -# self.assertTrue("orbital_period" in test_pop.defaults) -# self.assertTrue("metallicity" in test_pop.defaults) -# self.assertNotIn("help_all", test_pop.cleaned_up_defaults) -# self.assertEqual(test_pop.bse_options, {}) -# self.assertEqual(test_pop.custom_options, {}) -# self.assertEqual(test_pop.argline_dict, {}) -# self.assertEqual(test_pop.persistent_data_memory_dict, {}) -# self.assertTrue(test_pop.grid_options["parse_function"] == None) -# self.assertTrue(isinstance(test_pop.grid_options["_main_pid"], int)) - -# def test_set(self): -# test_pop = Population() -# test_pop.set(amt_cores=2) -# test_pop.set(M_1=10) -# test_pop.set(data_dir="/tmp/binary_c_python") -# test_pop.set(ensemble_filter_SUPERNOVAE=1) - -# self.assertIn("data_dir", test_pop.custom_options) -# self.assertEqual(test_pop.custom_options["data_dir"], "/tmp/binary_c_python") - -# # -# self.assertTrue(test_pop.bse_options["M_1"] == 10) -# self.assertTrue(test_pop.bse_options["ensemble_filter_SUPERNOVAE"] == 1) - -# # -# self.assertTrue(test_pop.grid_options["amt_cores"] == 2) - -# def test_cmdline(self): -# # copy old sys.argv values -# prev_sysargv = sys.argv.copy() - -# # make a dummy cmdline arg input -# sys.argv = [ -# "script", -# "--cmdline", -# "metallicity=0.0002 amt_cores=2 data_dir=/tmp/binary_c_python", -# ] - -# # Set up population -# test_pop = Population() -# test_pop.set(data_dir="/tmp") - -# # parse arguments -# test_pop.parse_cmdline() - -# # metallicity -# self.assertTrue(isinstance(test_pop.bse_options["metallicity"], str)) -# self.assertTrue(test_pop.bse_options["metallicity"] == "0.0002") - -# # Amt cores -# self.assertTrue(isinstance(test_pop.grid_options["amt_cores"], int)) -# self.assertTrue(test_pop.grid_options["amt_cores"] == 2) - -# # datadir -# self.assertTrue(isinstance(test_pop.custom_options["data_dir"], str)) -# self.assertTrue(test_pop.custom_options["data_dir"] == "/tmp/binary_c_python") - -# # put back the other args if they exist -# sys.argv = prev_sysargv.copy() - -# def test__return_argline(self): -# """ -# Unittests for the function _return_argline -# """ - -# # Set up population -# test_pop = Population() -# test_pop.set(metallicity=0.02) -# test_pop.set(M_1=10) - -# argline = test_pop._return_argline() -# self.assertTrue(argline == "binary_c M_1 10 metallicity 0.02") - -# # custom dict -# argline2 = test_pop._return_argline( -# {"example_parameter1": 10, "example_parameter2": "hello"} -# ) -# self.assertTrue( -# argline2 == "binary_c example_parameter1 10 example_parameter2 hello" -# ) - -# def test_add_grid_variable(self): -# """ -# Unittests for the function add_grid_variable - -# TODO: Should I test more here? -# """ - -# test_pop = Population() - -# resolution = {"M_1": 10, "q": 10} - -# test_pop.add_grid_variable( -# name="lnm1", -# longname="Primary mass", -# valuerange=[1, 100], -# resolution="{}".format(resolution["M_1"]), -# spacingfunc="const(math.log(1), math.log(100), {})".format( -# resolution["M_1"] -# ), -# precode="M_1=math.exp(lnm1)", -# probdist="three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1", -# dphasevol="dlnm1", -# parameter_name="M_1", -# condition="", # Impose a condition on this grid variable. Mostly for a check for yourself -# ) - -# test_pop.add_grid_variable( -# name="q", -# longname="Mass ratio", -# valuerange=["0.1/M_1", 1], -# resolution="{}".format(resolution["q"]), -# spacingfunc="const(0.1/M_1, 1, {})".format(resolution["q"]), -# probdist="flatsections(q, [{'min': 0.1/M_1, 'max': 1.0, 'height': 1}])", -# dphasevol="dq", -# precode="M_2 = q * M_1", -# parameter_name="M_2", -# condition="", # Impose a condition on this grid variable. Mostly for a check for yourself -# ) - -# self.assertIn("q", test_pop.grid_options["_grid_variables"]) -# self.assertIn("lnm1", test_pop.grid_options["_grid_variables"]) -# self.assertEqual(len(test_pop.grid_options["_grid_variables"]), 2) - -# def test_return_population_settings(self): -# """ -# Unittests for the function return_population_settings -# """ - -# test_pop = Population() -# test_pop.set(metallicity=0.02) -# test_pop.set(M_1=10) -# test_pop.set(amt_cores=2) -# test_pop.set(data_dir="/tmp") - -# population_settings = test_pop.return_population_settings() - -# self.assertIn("bse_options", population_settings) -# self.assertTrue(population_settings["bse_options"]["metallicity"] == 0.02) -# self.assertTrue(population_settings["bse_options"]["M_1"] == 10) - -# self.assertIn("grid_options", population_settings) -# self.assertTrue(population_settings["grid_options"]["amt_cores"] == 2) - -# self.assertIn("custom_options", population_settings) -# self.assertTrue(population_settings["custom_options"]["data_dir"] == "/tmp") - -# def test__return_binary_c_version_info(self): -# """ -# Unittests for the function _return_binary_c_version_info -# """ - -# test_pop = Population() -# binary_c_version_info = test_pop._return_binary_c_version_info(parsed=True) - -# self.assertTrue(isinstance(binary_c_version_info, dict)) -# self.assertIn("isotopes", binary_c_version_info) -# self.assertIn("argpairs", binary_c_version_info) -# self.assertIn("ensembles", binary_c_version_info) -# self.assertIn("macros", binary_c_version_info) -# self.assertIn("dt_limits", binary_c_version_info) -# self.assertIn("nucleosynthesis_sources", binary_c_version_info) -# self.assertIn("miscellaneous", binary_c_version_info) - -# self.assertIsNotNone(binary_c_version_info["argpairs"]) -# self.assertIsNotNone(binary_c_version_info["ensembles"]) -# self.assertIsNotNone(binary_c_version_info["macros"]) -# self.assertIsNotNone(binary_c_version_info["dt_limits"]) -# self.assertIsNotNone(binary_c_version_info["miscellaneous"]) - -# if binary_c_version_info['miscellaneous']['NUCSYN'] == 'on': -# self.assertIsNotNone(binary_c_version_info["isotopes"]) -# self.assertIsNotNone(binary_c_version_info["nucleosynthesis_sources"]) - -# def test__return_binary_c_defaults(self): -# """ -# Unittests for the function _return_binary_c_defaults -# """ - -# test_pop = Population() -# binary_c_defaults = test_pop._return_binary_c_defaults() -# self.assertIn("probability", binary_c_defaults) -# self.assertIn("phasevol", binary_c_defaults) -# self.assertIn("metallicity", binary_c_defaults) - -# def test_return_all_info(self): -# """ -# Unittests for the function return_all_info -# Not going to do too much tests here, just check if they are not empty -# """ - -# test_pop = Population() -# all_info = test_pop.return_all_info() - -# self.assertIn("population_settings", all_info) -# self.assertIn("binary_c_defaults", all_info) -# self.assertIn("binary_c_version_info", all_info) -# self.assertIn("binary_c_help_all", all_info) - -# self.assertNotEqual(all_info["population_settings"], {}) -# self.assertNotEqual(all_info["binary_c_defaults"], {}) -# self.assertNotEqual(all_info["binary_c_version_info"], {}) -# self.assertNotEqual(all_info["binary_c_help_all"], {}) - -# def test_export_all_info(self): -# """ -# Unittests for the function export_all_info -# """ - -# test_pop = Population() - -# test_pop.set(metallicity=0.02) -# test_pop.set(M_1=10) -# test_pop.set(amt_cores=2) -# test_pop.set(data_dir=binary_c_temp_dir) - -# # datadir -# settings_filename = test_pop.export_all_info(use_datadir=True) -# self.assertTrue(os.path.isfile(settings_filename)) -# with open(settings_filename, "r") as f: -# all_info = json.loads(f.read()) - -# # -# self.assertIn("population_settings", all_info) -# self.assertIn("binary_c_defaults", all_info) -# self.assertIn("binary_c_version_info", all_info) -# self.assertIn("binary_c_help_all", all_info) - -# # -# self.assertNotEqual(all_info["population_settings"], {}) -# self.assertNotEqual(all_info["binary_c_defaults"], {}) -# self.assertNotEqual(all_info["binary_c_version_info"], {}) -# self.assertNotEqual(all_info["binary_c_help_all"], {}) - -# # custom name -# # datadir -# settings_filename = test_pop.export_all_info( -# use_datadir=False, -# outfile=os.path.join(binary_c_temp_dir, "example_settings.json"), -# ) -# self.assertTrue(os.path.isfile(settings_filename)) -# with open(settings_filename, "r") as f: -# all_info = json.loads(f.read()) - -# # -# self.assertIn("population_settings", all_info) -# self.assertIn("binary_c_defaults", all_info) -# self.assertIn("binary_c_version_info", all_info) -# self.assertIn("binary_c_help_all", all_info) - -# # -# self.assertNotEqual(all_info["population_settings"], {}) -# self.assertNotEqual(all_info["binary_c_defaults"], {}) -# self.assertNotEqual(all_info["binary_c_version_info"], {}) -# self.assertNotEqual(all_info["binary_c_help_all"], {}) - -# # wrong filename -# self.assertRaises( -# ValueError, -# test_pop.export_all_info, -# use_datadir=False, -# outfile=os.path.join(binary_c_temp_dir, "example_settings.txt"), -# ) - -# def test__cleanup_defaults(self): -# """ -# Unittests for the function _cleanup_defaults -# """ - -# test_pop = Population() -# cleaned_up_defaults = test_pop._cleanup_defaults() -# self.assertNotIn("help_all", cleaned_up_defaults) - -# def test__increment_probtot(self): -# """ -# Unittests for the function _increment_probtot -# """ - -# test_pop = Population() -# test_pop._increment_probtot(0.5) -# self.assertEqual(test_pop.grid_options["_probtot"], 0.5) - -# def test__increment_count(self): -# """ -# Unittests for the function _increment_probtot -# """ - -# test_pop = Population() -# test_pop._increment_count() -# self.assertEqual(test_pop.grid_options["_count"], 1) - -# def test__dict_from_line_source_file(self): -# """ -# Unittests for the function _dict_from_line_source_file -# """ - -# source_file = os.path.join(binary_c_temp_dir, "example_source_file.txt") - -# # write -# with open(source_file, "w") as f: -# f.write("binary_c M_1 10 metallicity 0.02\n") - -# test_pop = Population() - -# # readout -# with open(source_file, "r") as f: -# for line in f.readlines(): -# argdict = test_pop._dict_from_line_source_file(line) - -# self.assertTrue(argdict["M_1"] == 10) -# self.assertTrue(argdict["metallicity"] == 0.02) - -# def test_evolve_single(self): -# """ -# Unittests for the function evolve_single -# """ - -# CUSTOM_LOGGING_STRING_MASSES = """ -# Printf("TEST_CUSTOM_LOGGING_1 %30.12e %g %g %g %g\\n", -# // -# stardata->model.time, // 1 + self.assertNotEqual(all_info["population_settings"], {}) + self.assertNotEqual(all_info["binary_c_defaults"], {}) + self.assertNotEqual(all_info["binary_c_version_info"], {}) + self.assertNotEqual(all_info["binary_c_help_all"], {}) + + def test_export_all_info(self): + """ + Unittests for the function export_all_info + """ + + test_pop = Population() + + test_pop.set(metallicity=0.02) + test_pop.set(M_1=10) + test_pop.set(amt_cores=2) + test_pop.set(data_dir=binary_c_temp_dir) + + # datadir + settings_filename = test_pop.export_all_info(use_datadir=True) + self.assertTrue(os.path.isfile(settings_filename)) + with open(settings_filename, "r") as f: + all_info = json.loads(f.read()) + + # + self.assertIn("population_settings", all_info) + self.assertIn("binary_c_defaults", all_info) + self.assertIn("binary_c_version_info", all_info) + self.assertIn("binary_c_help_all", all_info) + + # + self.assertNotEqual(all_info["population_settings"], {}) + self.assertNotEqual(all_info["binary_c_defaults"], {}) + self.assertNotEqual(all_info["binary_c_version_info"], {}) + self.assertNotEqual(all_info["binary_c_help_all"], {}) + + # custom name + # datadir + settings_filename = test_pop.export_all_info( + use_datadir=False, + outfile=os.path.join(binary_c_temp_dir, "example_settings.json"), + ) + self.assertTrue(os.path.isfile(settings_filename)) + with open(settings_filename, "r") as f: + all_info = json.loads(f.read()) + + # + self.assertIn("population_settings", all_info) + self.assertIn("binary_c_defaults", all_info) + self.assertIn("binary_c_version_info", all_info) + self.assertIn("binary_c_help_all", all_info) + + # + self.assertNotEqual(all_info["population_settings"], {}) + self.assertNotEqual(all_info["binary_c_defaults"], {}) + self.assertNotEqual(all_info["binary_c_version_info"], {}) + self.assertNotEqual(all_info["binary_c_help_all"], {}) + + # wrong filename + self.assertRaises( + ValueError, + test_pop.export_all_info, + use_datadir=False, + outfile=os.path.join(binary_c_temp_dir, "example_settings.txt"), + ) + + def test__cleanup_defaults(self): + """ + Unittests for the function _cleanup_defaults + """ + + test_pop = Population() + cleaned_up_defaults = test_pop._cleanup_defaults() + self.assertNotIn("help_all", cleaned_up_defaults) + + def test__increment_probtot(self): + """ + Unittests for the function _increment_probtot + """ + + test_pop = Population() + test_pop._increment_probtot(0.5) + self.assertEqual(test_pop.grid_options["_probtot"], 0.5) + + def test__increment_count(self): + """ + Unittests for the function _increment_probtot + """ + + test_pop = Population() + test_pop._increment_count() + self.assertEqual(test_pop.grid_options["_count"], 1) + + def test__dict_from_line_source_file(self): + """ + Unittests for the function _dict_from_line_source_file + """ + + source_file = os.path.join(binary_c_temp_dir, "example_source_file.txt") + + # write + with open(source_file, "w") as f: + f.write("binary_c M_1 10 metallicity 0.02\n") + + test_pop = Population() + + # readout + with open(source_file, "r") as f: + for line in f.readlines(): + argdict = test_pop._dict_from_line_source_file(line) + + self.assertTrue(argdict["M_1"] == 10) + self.assertTrue(argdict["metallicity"] == 0.02) + + def test_evolve_single(self): + """ + Unittests for the function evolve_single + """ + + CUSTOM_LOGGING_STRING_MASSES = """ + Printf("TEST_CUSTOM_LOGGING_1 %30.12e %g %g %g %g\\n", + // + stardata->model.time, // 1 -# // masses -# stardata->common.zero_age.mass[0], // -# stardata->common.zero_age.mass[1], // + // masses + stardata->common.zero_age.mass[0], // + stardata->common.zero_age.mass[1], // -# stardata->star[0].mass, -# stardata->star[1].mass -# ); -# """ + stardata->star[0].mass, + stardata->star[1].mass + ); + """ -# test_pop = Population() -# test_pop.set(M_1=10, M_2=5, orbital_period=100000, metallicty=0.02, max_evolution_time = 15000) + test_pop = Population() + test_pop.set(M_1=10, M_2=5, orbital_period=100000, metallicty=0.02, max_evolution_time = 15000) -# test_pop.set(C_logging_code=CUSTOM_LOGGING_STRING_MASSES) + test_pop.set(C_logging_code=CUSTOM_LOGGING_STRING_MASSES) -# output = test_pop.evolve_single() + output = test_pop.evolve_single() -# # -# self.assertTrue(len(output.splitlines())>1) -# self.assertIn('TEST_CUSTOM_LOGGING_1', output) + # + self.assertTrue(len(output.splitlines())>1) + self.assertIn('TEST_CUSTOM_LOGGING_1', output) -# # -# custom_logging_dict = { -# 'TEST_CUSTOM_LOGGING_2': ['star[0].mass', 'model.time'] -# } -# test_pop_2 = Population() -# test_pop_2.set(M_1=10, M_2=5, orbital_period=100000, metallicty=0.02, max_evolution_time = 15000) + # + custom_logging_dict = { + 'TEST_CUSTOM_LOGGING_2': ['star[0].mass', 'model.time'] + } + test_pop_2 = Population() + test_pop_2.set(M_1=10, M_2=5, orbital_period=100000, metallicty=0.02, max_evolution_time = 15000) -# test_pop_2.set(C_auto_logging=custom_logging_dict) + test_pop_2.set(C_auto_logging=custom_logging_dict) -# output_2 = test_pop_2.evolve_single() + output_2 = test_pop_2.evolve_single() -# # -# self.assertTrue(len(output_2.splitlines())>1) -# self.assertIn('TEST_CUSTOM_LOGGING_2', output_2) + # + self.assertTrue(len(output_2.splitlines())>1) + self.assertIn('TEST_CUSTOM_LOGGING_2', output_2) class test_grid_evolve(unittest.TestCase): """ Unittests for function Population.evolve() """ - # def test_grid_evolve_1_thread(self): - # # test to see if 1 thread does all the systems - - # test_pop_evolve_1_thread = Population() - # test_pop_evolve_1_thread.set(amt_cores=1, verbosity=1, M_2=1, orbital_period=100000) - - # resolution = {"M_1": 10} - - # test_pop_evolve_1_thread.add_grid_variable( - # name="lnm1", - # longname="Primary mass", - # valuerange=[1, 100], - # resolution="{}".format(resolution["M_1"]), - # spacingfunc="const(math.log(1), math.log(100), {})".format( - # resolution["M_1"] - # ), - # precode="M_1=math.exp(lnm1)", - # probdist="three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1", - # dphasevol="dlnm1", - # parameter_name="M_1", - # condition="", # Impose a condition on this grid variable. Mostly for a check for yourself - # ) - - # analytics = test_pop_evolve_1_thread.evolve() - # self.assertLess(np.abs(analytics['total_probability']-0.1503788456014623), 1e-10) - # self.assertTrue(analytics['total_count']==10) - - # def test_grid_evolve_2_threads(self): - # # test to see if 1 thread does all the systems - - # test_pop = Population() - # test_pop.set(amt_cores=2, verbosity=1, M_2=1, orbital_period=100000) - - # resolution = {"M_1": 10} - - # test_pop.add_grid_variable( - # name="lnm1", - # longname="Primary mass", - # valuerange=[1, 100], - # resolution="{}".format(resolution["M_1"]), - # spacingfunc="const(math.log(1), math.log(100), {})".format( - # resolution["M_1"] - # ), - # precode="M_1=math.exp(lnm1)", - # probdist="three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1", - # dphasevol="dlnm1", - # parameter_name="M_1", - # condition="", # Impose a condition on this grid variable. Mostly for a check for yourself - # ) - - # analytics = test_pop.evolve() - # self.assertLess(np.abs(analytics['total_probability']-0.1503788456014623), 1e-10) # - # self.assertTrue(analytics['total_count']==10) - - # def test_grid_evolve_2_threads_with_custom_logging(self): - # # test to see if 1 thread does all the systems - - # test_pop = Population() - # test_pop.set(amt_cores=2, verbosity=1, M_2=1, orbital_period=100000) - - # resolution = {"M_1": 10} - - # test_pop.add_grid_variable( - # name="lnm1", - # longname="Primary mass", - # valuerange=[1, 100], - # resolution="{}".format(resolution["M_1"]), - # spacingfunc="const(math.log(1), math.log(100), {})".format( - # resolution["M_1"] - # ), - # precode="M_1=math.exp(lnm1)", - # probdist="three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1", - # dphasevol="dlnm1", - # parameter_name="M_1", - # condition="", # Impose a condition on this grid variable. Mostly for a check for yourself - # ) - - # analytics = test_pop.evolve() - # self.assertLess(np.abs(analytics['total_probability']-0.1503788456014623), 1e-10) # - # self.assertTrue(analytics['total_count']==10) - - # def test_grid_evolve_with_condition_error(self): - # # Test to see if we can catch the errors correctly. - - # test_pop = Population() - # test_pop.set(amt_cores=2, verbosity=1, M_2=1, orbital_period=100000) - - # # Set the amt of failed systems that each thread will log - # test_pop.set(failed_systems_threshold=4) - - # CUSTOM_LOGGING_STRING_WITH_EXIT = """ - # Exit_binary_c(BINARY_C_NORMAL_EXIT, "testing exits"); - # Printf("TEST_CUSTOM_LOGGING_1 %30.12e %g %g %g %g\\n", - # // - # stardata->model.time, // 1 + def test_grid_evolve_1_thread(self): + # test to see if 1 thread does all the systems + + test_pop_evolve_1_thread = Population() + test_pop_evolve_1_thread.set(amt_cores=1, verbosity=1, M_2=1, orbital_period=100000) + + resolution = {"M_1": 10} + + test_pop_evolve_1_thread.add_grid_variable( + name="lnm1", + longname="Primary mass", + valuerange=[1, 100], + resolution="{}".format(resolution["M_1"]), + spacingfunc="const(math.log(1), math.log(100), {})".format( + resolution["M_1"] + ), + precode="M_1=math.exp(lnm1)", + probdist="three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1", + dphasevol="dlnm1", + parameter_name="M_1", + condition="", # Impose a condition on this grid variable. Mostly for a check for yourself + ) + + analytics = test_pop_evolve_1_thread.evolve() + self.assertLess(np.abs(analytics['total_probability']-0.1503788456014623), 1e-10) + self.assertTrue(analytics['total_count']==10) + + def test_grid_evolve_2_threads(self): + # test to see if 1 thread does all the systems + + test_pop = Population() + test_pop.set(amt_cores=2, verbosity=1, M_2=1, orbital_period=100000) + + resolution = {"M_1": 10} + + test_pop.add_grid_variable( + name="lnm1", + longname="Primary mass", + valuerange=[1, 100], + resolution="{}".format(resolution["M_1"]), + spacingfunc="const(math.log(1), math.log(100), {})".format( + resolution["M_1"] + ), + precode="M_1=math.exp(lnm1)", + probdist="three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1", + dphasevol="dlnm1", + parameter_name="M_1", + condition="", # Impose a condition on this grid variable. Mostly for a check for yourself + ) + + analytics = test_pop.evolve() + self.assertLess(np.abs(analytics['total_probability']-0.1503788456014623), 1e-10) # + self.assertTrue(analytics['total_count']==10) + + def test_grid_evolve_1_threads_with_custom_logging(self): + # test to see if 1 thread does all the systems + + data_dir_value = os.path.join(binary_c_temp_dir, 'grid_tests') + amt_cores_value = 2 + custom_logging_string = 'Printf("MY_STELLAR_DATA_TEST_EXAMPLE %g %g %g %g\\n",((double)stardata->model.time),((double)stardata->star[0].mass),((double)stardata->model.probability),((double)stardata->model.dt));' + + test_pop = Population() + + test_pop.set(amt_cores=amt_cores_value, + verbosity=1, + M_2=1, + orbital_period=100000, + data_dir=data_dir_value, + C_logging_code=custom_logging_string, # input it like this. + parse_function=parse_function_test_grid_evolve_2_threads_with_custom_logging) + test_pop.set(ensemble=0) + resolution = {"M_1": 2} + + test_pop.add_grid_variable( + name="lnm1", + longname="Primary mass", + valuerange=[1, 100], + resolution="{}".format(resolution["M_1"]), + spacingfunc="const(math.log(1), math.log(100), {})".format( + resolution["M_1"] + ), + precode="M_1=math.exp(lnm1)", + probdist="three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1", + dphasevol="dlnm1", + parameter_name="M_1", + condition="", # Impose a condition on this grid variable. Mostly for a check for yourself + ) + + analytics = test_pop.evolve() + output_names = [os.path.join(data_dir_value, "test_grid_evolve_2_threads_with_custom_logging_outputfile_population_{}_thread_{}.dat".format(analytics['population_name'], thread_id)) for thread_id in range(amt_cores_value)] + + for output_name in output_names: + self.assertTrue(os.path.isfile(output_name)) + + with open(output_name, 'r') as f: + output_string = f.read() + + self.assertIn("MY_STELLAR_DATA_TEST_EXAMPLE", output_string) + + remove_file(output_name) + + def test_grid_evolve_with_condition_error(self): + # Test to see if we can catch the errors correctly. + + test_pop = Population() + test_pop.set(amt_cores=2, verbosity=1, M_2=1, orbital_period=100000) + + # Set the amt of failed systems that each thread will log + test_pop.set(failed_systems_threshold=4) + + CUSTOM_LOGGING_STRING_WITH_EXIT = """ + Exit_binary_c(BINARY_C_NORMAL_EXIT, "testing exits"); + Printf("TEST_CUSTOM_LOGGING_1 %30.12e %g %g %g %g\\n", + // + stardata->model.time, // 1 - # // masses - # stardata->common.zero_age.mass[0], // - # stardata->common.zero_age.mass[1], // - - # stardata->star[0].mass, - # stardata->star[1].mass - # ); - # """ - - # test_pop.set(C_logging_code=CUSTOM_LOGGING_STRING_WITH_EXIT) - - # resolution = {"M_1": 10} - # test_pop.add_grid_variable( - # name="lnm1", - # longname="Primary mass", - # valuerange=[1, 100], - # resolution="{}".format(resolution["M_1"]), - # spacingfunc="const(math.log(1), math.log(100), {})".format( - # resolution["M_1"] - # ), - # precode="M_1=math.exp(lnm1)", - # probdist="three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1", - # dphasevol="dlnm1", - # parameter_name="M_1", - # condition="", # Impose a condition on this grid variable. Mostly for a check for yourself - # ) - - # analytics = test_pop.evolve() - # self.assertLess(np.abs(analytics['total_probability']-0.1503788456014623), 1e-10) # - # self.assertLess(np.abs(analytics['failed_prob']-0.1503788456014623), 1e-10) # - # self.assertEqual(analytics['failed_systems_error_codes'], [0]) - # self.assertTrue(analytics['total_count']==10) - # self.assertTrue(analytics['failed_count']==10) - # self.assertTrue(analytics['errors_found']==True) - # self.assertTrue(analytics['errors_exceeded']==True) - - # # test to see if 1 thread does all the systems - - # test_pop = Population() - # test_pop.set(amt_cores=2, verbosity=1, M_2=1, orbital_period=100000) - - # resolution = {"M_1": 10, "q": 2} - - # test_pop.add_grid_variable( - # name="lnm1", - # longname="Primary mass", - # valuerange=[1, 100], - # resolution="{}".format(resolution["M_1"]), - # spacingfunc="const(math.log(1), math.log(100), {})".format( - # resolution["M_1"] - # ), - # precode="M_1=math.exp(lnm1)", - # probdist="three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1", - # dphasevol="dlnm1", - # parameter_name="M_1", - # condition="", # Impose a condition on this grid variable. Mostly for a check for yourself - # ) - - # test_pop.add_grid_variable( - # name="q", - # longname="Mass ratio", - # valuerange=["0.1/M_1", 1], - # resolution="{}".format(resolution["q"]), - # spacingfunc="const(0.1/M_1, 1, {})".format(resolution["q"]), - # probdist="flatsections(q, [{'min': 0.1/M_1, 'max': 1.0, 'height': 1}])", - # dphasevol="dq", - # precode="M_2 = q * M_1", - # parameter_name="M_2", - # # condition="M_1 in dir()", # Impose a condition on this grid variable. Mostly for a check for yourself - # condition="'random_var' in dir()", # This will raise an error because random_var is not defined. - # ) - - # self.assertRaises(ValueError, test_pop.evolve) - - # def test_grid_evolve_no_grid_variables(self): - # # test to see if 1 thread does all the systems - - # test_pop = Population() - # test_pop.set(amt_cores=1, verbosity=1, M_2=1, orbital_period=100000) - - # resolution = {"M_1": 10} - # self.assertRaises(ValueError, test_pop.evolve) - - def test_grid_evolve_2_threads_with_ensemble(self): + // masses + stardata->common.zero_age.mass[0], // + stardata->common.zero_age.mass[1], // + + stardata->star[0].mass, + stardata->star[1].mass + ); + """ + + test_pop.set(C_logging_code=CUSTOM_LOGGING_STRING_WITH_EXIT) + + resolution = {"M_1": 10} + test_pop.add_grid_variable( + name="lnm1", + longname="Primary mass", + valuerange=[1, 100], + resolution="{}".format(resolution["M_1"]), + spacingfunc="const(math.log(1), math.log(100), {})".format( + resolution["M_1"] + ), + precode="M_1=math.exp(lnm1)", + probdist="three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1", + dphasevol="dlnm1", + parameter_name="M_1", + condition="", # Impose a condition on this grid variable. Mostly for a check for yourself + ) + + analytics = test_pop.evolve() + self.assertLess(np.abs(analytics['total_probability']-0.1503788456014623), 1e-10) # + self.assertLess(np.abs(analytics['failed_prob']-0.1503788456014623), 1e-10) # + self.assertEqual(analytics['failed_systems_error_codes'], [0]) + self.assertTrue(analytics['total_count']==10) + self.assertTrue(analytics['failed_count']==10) + self.assertTrue(analytics['errors_found']==True) + self.assertTrue(analytics['errors_exceeded']==True) + + # test to see if 1 thread does all the systems + + test_pop = Population() + test_pop.set(amt_cores=2, verbosity=1, M_2=1, orbital_period=100000) + + resolution = {"M_1": 10, "q": 2} + + test_pop.add_grid_variable( + name="lnm1", + longname="Primary mass", + valuerange=[1, 100], + resolution="{}".format(resolution["M_1"]), + spacingfunc="const(math.log(1), math.log(100), {})".format( + resolution["M_1"] + ), + precode="M_1=math.exp(lnm1)", + probdist="three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1", + dphasevol="dlnm1", + parameter_name="M_1", + condition="", # Impose a condition on this grid variable. Mostly for a check for yourself + ) + + test_pop.add_grid_variable( + name="q", + longname="Mass ratio", + valuerange=["0.1/M_1", 1], + resolution="{}".format(resolution["q"]), + spacingfunc="const(0.1/M_1, 1, {})".format(resolution["q"]), + probdist="flatsections(q, [{'min': 0.1/M_1, 'max': 1.0, 'height': 1}])", + dphasevol="dq", + precode="M_2 = q * M_1", + parameter_name="M_2", + # condition="M_1 in dir()", # Impose a condition on this grid variable. Mostly for a check for yourself + condition="'random_var' in dir()", # This will raise an error because random_var is not defined. + ) + + self.assertRaises(ValueError, test_pop.evolve) + + def test_grid_evolve_no_grid_variables(self): + # test to see if 1 thread does all the systems + + test_pop = Population() + test_pop.set(amt_cores=1, verbosity=1, M_2=1, orbital_period=100000) + + resolution = {"M_1": 10} + self.assertRaises(ValueError, test_pop.evolve) + + def test_grid_evolve_2_threads_with_ensemble_direct_output(self): # test to see if 1 thread does all the systems + data_dir_value = binary_c_temp_dir + amt_cores_value = 2 + test_pop = Population() - test_pop.set(amt_cores=2, verbosity=1, M_2=1, orbital_period=100000, ensemble=1, ensemble_defer=1, ensemble_filters_off=1, ensemble_filter_STELLAR_TYPE_COUNTS=1) + test_pop.set(amt_cores=amt_cores_value, verbosity=1, M_2=1, orbital_period=100000, ensemble=1, ensemble_defer=1, ensemble_filters_off=1, ensemble_filter_STELLAR_TYPE_COUNTS=1) test_pop.set(data_dir=binary_c_temp_dir, ensemble_output_name="ensemble_output.json", combine_ensemble_with_thread_joining=False) resolution = {"M_1": 10} @@ -609,11 +653,133 @@ class test_grid_evolve(unittest.TestCase): ) analytics = test_pop.evolve() + output_names = [os.path.join(data_dir_value, "ensemble_output_{}_{}.json".format(analytics['population_name'], thread_id)) for thread_id in range(amt_cores_value)] + + for output_name in output_names: + self.assertTrue(os.path.isfile(output_name)) + + with open(output_name, 'r') as f: + file_content = f.read() + + self.assertTrue(file_content.startswith("ENSEMBLE_JSON")) + + ensemble_json = extract_ensemble_json_from_string(file_content) + + self.assertTrue(isinstance(ensemble_json, dict)) + self.assertNotEqual(ensemble_json, {}) + + self.assertIn("number_counts", ensemble_json) + self.assertNotEqual(ensemble_json["number_counts"], {}) + + def test_grid_evolve_2_threads_with_ensemble_combining(self): + # test to see if 1 thread does all the systems + + data_dir_value = binary_c_temp_dir + amt_cores_value = 2 + + test_pop = Population() + test_pop.set(amt_cores=amt_cores_value, verbosity=1, M_2=1, orbital_period=100000, ensemble=1, ensemble_defer=1, ensemble_filters_off=1, ensemble_filter_STELLAR_TYPE_COUNTS=1) + test_pop.set(data_dir=binary_c_temp_dir, combine_ensemble_with_thread_joining=True, ensemble_output_name="ensemble_output.json") + + resolution = {"M_1": 10} + + test_pop.add_grid_variable( + name="lnm1", + longname="Primary mass", + valuerange=[1, 100], + resolution="{}".format(resolution["M_1"]), + spacingfunc="const(math.log(1), math.log(100), {})".format( + resolution["M_1"] + ), + precode="M_1=math.exp(lnm1)", + probdist="three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1", + dphasevol="dlnm1", + parameter_name="M_1", + condition="", # Impose a condition on this grid variable. Mostly for a check for yourself + ) + + analytics = test_pop.evolve() + + self.assertTrue(isinstance(test_pop.grid_options['ensemble_results'], dict)) + self.assertNotEqual(test_pop.grid_options['ensemble_results'], {}) + + self.assertIn("number_counts", test_pop.grid_options['ensemble_results']) + self.assertNotEqual(test_pop.grid_options['ensemble_results']["number_counts"], {}) + + def test_grid_evolve_2_threads_with_ensemble_comparing_two_methods(self): + # test to see if 1 thread does all the systems + + data_dir_value = binary_c_temp_dir + amt_cores_value = 2 + + # First + test_pop_1 = Population() + test_pop_1.set(amt_cores=amt_cores_value, verbosity=1, M_2=1, orbital_period=100000, ensemble=1, ensemble_defer=1, ensemble_filters_off=1, ensemble_filter_STELLAR_TYPE_COUNTS=1) + test_pop_1.set(data_dir=binary_c_temp_dir, combine_ensemble_with_thread_joining=True, ensemble_output_name="ensemble_output.json") + + resolution = {"M_1": 10} + + test_pop_1.add_grid_variable( + name="lnm1", + longname="Primary mass", + valuerange=[1, 100], + resolution="{}".format(resolution["M_1"]), + spacingfunc="const(math.log(1), math.log(100), {})".format( + resolution["M_1"] + ), + precode="M_1=math.exp(lnm1)", + probdist="three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1", + dphasevol="dlnm1", + parameter_name="M_1", + condition="", # Impose a condition on this grid variable. Mostly for a check for yourself + ) + + analytics_1 = test_pop_1.evolve() + ensemble_output_1 = test_pop_1.grid_options['ensemble_results'] + + # second + test_pop_2 = Population() + test_pop_2.set(amt_cores=amt_cores_value, verbosity=1, M_2=1, orbital_period=100000, ensemble=1, ensemble_defer=1, ensemble_filters_off=1, ensemble_filter_STELLAR_TYPE_COUNTS=1) + test_pop_2.set(data_dir=binary_c_temp_dir, ensemble_output_name="ensemble_output.json", combine_ensemble_with_thread_joining=False) + + resolution = {"M_1": 10} + + test_pop_2.add_grid_variable( + name="lnm1", + longname="Primary mass", + valuerange=[1, 100], + resolution="{}".format(resolution["M_1"]), + spacingfunc="const(math.log(1), math.log(100), {})".format( + resolution["M_1"] + ), + precode="M_1=math.exp(lnm1)", + probdist="three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1", + dphasevol="dlnm1", + parameter_name="M_1", + condition="", # Impose a condition on this grid variable. Mostly for a check for yourself + ) + + analytics_2 = test_pop_2.evolve() + output_names_2 = [os.path.join(data_dir_value, "ensemble_output_{}_{}.json".format(analytics_2['population_name'], thread_id)) for thread_id in range(amt_cores_value)] + ensemble_output_2 = {} + + for output_name in output_names_2: + self.assertTrue(os.path.isfile(output_name)) + + with open(output_name, 'r') as f: + file_content = f.read() + + self.assertTrue(file_content.startswith("ENSEMBLE_JSON")) + + ensemble_json = extract_ensemble_json_from_string(file_content) + ensemble_output_2 = merge_dicts(ensemble_output_2, ensemble_json) - # self.assertLess(np.abs(analytics['total_probability']-0.1503788456014623), 1e-10) # - # self.assertTrue(analytics['total_count']==10) + for key in ensemble_output_1['number_counts']['stellar_type']['0']: + self.assertIn(key, ensemble_output_2['number_counts']['stellar_type']['0']) + # compare values + self.assertLess(np.abs(ensemble_output_1['number_counts']['stellar_type']['0'][key]-ensemble_output_2['number_counts']['stellar_type']['0'][key]), 1e-8) if __name__ == "__main__": unittest.main() -- GitLab