Skip to content
Snippets Groups Projects
Commit 56aef52e authored by David Hendriks's avatar David Hendriks
Browse files

added tests for binarycpython grid

parent ba10bbc2
No related branches found
No related tags found
No related merge requests found
...@@ -14,11 +14,32 @@ import datetime ...@@ -14,11 +14,32 @@ import datetime
import numpy as np import numpy as np
from binarycpython.utils.grid import Population from binarycpython.utils.grid import Population
from binarycpython.utils.functions import temp_dir from binarycpython.utils.functions import temp_dir, extract_ensemble_json_from_string, merge_dicts, remove_file
from binarycpython.utils.custom_logging_functions import binary_c_log_code from binarycpython.utils.custom_logging_functions import binary_c_log_code
binary_c_temp_dir = temp_dir() binary_c_temp_dir = temp_dir()
def parse_function_test_grid_evolve_2_threads_with_custom_logging(self, output):
"""
Simple parse function that directly appends all the output to a file
"""
# Get some information from the
data_dir = self.custom_options['data_dir']
# make outputfilename
output_filename = os.path.join(data_dir, "test_grid_evolve_2_threads_with_custom_logging_outputfile_population_{}_thread_{}.dat".format(self.grid_options['_population_id'], self.process_ID))
# Check directory, make if necessary
os.makedirs(data_dir, exist_ok=True)
if not os.path.exists(output_filename):
with open(output_filename, 'w') as first_f:
first_f.write(output+'\n')
else:
with open(output_filename, 'a') as first_f:
first_f.write(output+'\n')
# class test_(unittest.TestCase): # class test_(unittest.TestCase):
# """ # """
...@@ -34,561 +55,584 @@ binary_c_temp_dir = temp_dir() ...@@ -34,561 +55,584 @@ binary_c_temp_dir = temp_dir()
# """ # """
# class test_Population(unittest.TestCase): class test_Population(unittest.TestCase):
# """ """
# Unittests for function Unittests for function
# """ """
def test_setup(self):
test_pop = Population()
self.assertTrue("orbital_period" in test_pop.defaults)
self.assertTrue("metallicity" in test_pop.defaults)
self.assertNotIn("help_all", test_pop.cleaned_up_defaults)
self.assertEqual(test_pop.bse_options, {})
self.assertEqual(test_pop.custom_options, {})
self.assertEqual(test_pop.argline_dict, {})
self.assertEqual(test_pop.persistent_data_memory_dict, {})
self.assertTrue(test_pop.grid_options["parse_function"] == None)
self.assertTrue(isinstance(test_pop.grid_options["_main_pid"], int))
def test_set(self):
test_pop = Population()
test_pop.set(amt_cores=2)
test_pop.set(M_1=10)
test_pop.set(data_dir="/tmp/binary_c_python")
test_pop.set(ensemble_filter_SUPERNOVAE=1)
self.assertIn("data_dir", test_pop.custom_options)
self.assertEqual(test_pop.custom_options["data_dir"], "/tmp/binary_c_python")
#
self.assertTrue(test_pop.bse_options["M_1"] == 10)
self.assertTrue(test_pop.bse_options["ensemble_filter_SUPERNOVAE"] == 1)
#
self.assertTrue(test_pop.grid_options["amt_cores"] == 2)
def test_cmdline(self):
# copy old sys.argv values
prev_sysargv = sys.argv.copy()
# make a dummy cmdline arg input
sys.argv = [
"script",
"--cmdline",
"metallicity=0.0002 amt_cores=2 data_dir=/tmp/binary_c_python",
]
# Set up population
test_pop = Population()
test_pop.set(data_dir="/tmp")
# parse arguments
test_pop.parse_cmdline()
# metallicity
self.assertTrue(isinstance(test_pop.bse_options["metallicity"], str))
self.assertTrue(test_pop.bse_options["metallicity"] == "0.0002")
# Amt cores
self.assertTrue(isinstance(test_pop.grid_options["amt_cores"], int))
self.assertTrue(test_pop.grid_options["amt_cores"] == 2)
# datadir
self.assertTrue(isinstance(test_pop.custom_options["data_dir"], str))
self.assertTrue(test_pop.custom_options["data_dir"] == "/tmp/binary_c_python")
# put back the other args if they exist
sys.argv = prev_sysargv.copy()
def test__return_argline(self):
"""
Unittests for the function _return_argline
"""
# Set up population
test_pop = Population()
test_pop.set(metallicity=0.02)
test_pop.set(M_1=10)
argline = test_pop._return_argline()
self.assertTrue(argline == "binary_c M_1 10 metallicity 0.02")
# custom dict
argline2 = test_pop._return_argline(
{"example_parameter1": 10, "example_parameter2": "hello"}
)
self.assertTrue(
argline2 == "binary_c example_parameter1 10 example_parameter2 hello"
)
def test_add_grid_variable(self):
"""
Unittests for the function add_grid_variable
TODO: Should I test more here?
"""
test_pop = Population()
resolution = {"M_1": 10, "q": 10}
test_pop.add_grid_variable(
name="lnm1",
longname="Primary mass",
valuerange=[1, 100],
resolution="{}".format(resolution["M_1"]),
spacingfunc="const(math.log(1), math.log(100), {})".format(
resolution["M_1"]
),
precode="M_1=math.exp(lnm1)",
probdist="three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1",
dphasevol="dlnm1",
parameter_name="M_1",
condition="", # Impose a condition on this grid variable. Mostly for a check for yourself
)
test_pop.add_grid_variable(
name="q",
longname="Mass ratio",
valuerange=["0.1/M_1", 1],
resolution="{}".format(resolution["q"]),
spacingfunc="const(0.1/M_1, 1, {})".format(resolution["q"]),
probdist="flatsections(q, [{'min': 0.1/M_1, 'max': 1.0, 'height': 1}])",
dphasevol="dq",
precode="M_2 = q * M_1",
parameter_name="M_2",
condition="", # Impose a condition on this grid variable. Mostly for a check for yourself
)
self.assertIn("q", test_pop.grid_options["_grid_variables"])
self.assertIn("lnm1", test_pop.grid_options["_grid_variables"])
self.assertEqual(len(test_pop.grid_options["_grid_variables"]), 2)
def test_return_population_settings(self):
"""
Unittests for the function return_population_settings
"""
test_pop = Population()
test_pop.set(metallicity=0.02)
test_pop.set(M_1=10)
test_pop.set(amt_cores=2)
test_pop.set(data_dir="/tmp")
population_settings = test_pop.return_population_settings()
self.assertIn("bse_options", population_settings)
self.assertTrue(population_settings["bse_options"]["metallicity"] == 0.02)
self.assertTrue(population_settings["bse_options"]["M_1"] == 10)
self.assertIn("grid_options", population_settings)
self.assertTrue(population_settings["grid_options"]["amt_cores"] == 2)
self.assertIn("custom_options", population_settings)
self.assertTrue(population_settings["custom_options"]["data_dir"] == "/tmp")
def test__return_binary_c_version_info(self):
"""
Unittests for the function _return_binary_c_version_info
"""
test_pop = Population()
binary_c_version_info = test_pop._return_binary_c_version_info(parsed=True)
self.assertTrue(isinstance(binary_c_version_info, dict))
self.assertIn("isotopes", binary_c_version_info)
self.assertIn("argpairs", binary_c_version_info)
self.assertIn("ensembles", binary_c_version_info)
self.assertIn("macros", binary_c_version_info)
self.assertIn("dt_limits", binary_c_version_info)
self.assertIn("nucleosynthesis_sources", binary_c_version_info)
self.assertIn("miscellaneous", binary_c_version_info)
self.assertIsNotNone(binary_c_version_info["argpairs"])
self.assertIsNotNone(binary_c_version_info["ensembles"])
self.assertIsNotNone(binary_c_version_info["macros"])
self.assertIsNotNone(binary_c_version_info["dt_limits"])
self.assertIsNotNone(binary_c_version_info["miscellaneous"])
if binary_c_version_info['miscellaneous']['NUCSYN'] == 'on':
self.assertIsNotNone(binary_c_version_info["isotopes"])
self.assertIsNotNone(binary_c_version_info["nucleosynthesis_sources"])
def test__return_binary_c_defaults(self):
"""
Unittests for the function _return_binary_c_defaults
"""
test_pop = Population()
binary_c_defaults = test_pop._return_binary_c_defaults()
self.assertIn("probability", binary_c_defaults)
self.assertIn("phasevol", binary_c_defaults)
self.assertIn("metallicity", binary_c_defaults)
def test_return_all_info(self):
"""
Unittests for the function return_all_info
Not going to do too much tests here, just check if they are not empty
"""
test_pop = Population()
all_info = test_pop.return_all_info()
self.assertIn("population_settings", all_info)
self.assertIn("binary_c_defaults", all_info)
self.assertIn("binary_c_version_info", all_info)
self.assertIn("binary_c_help_all", all_info)
# def test_setup(self): self.assertNotEqual(all_info["population_settings"], {})
# test_pop = Population() self.assertNotEqual(all_info["binary_c_defaults"], {})
self.assertNotEqual(all_info["binary_c_version_info"], {})
# self.assertTrue("orbital_period" in test_pop.defaults) self.assertNotEqual(all_info["binary_c_help_all"], {})
# self.assertTrue("metallicity" in test_pop.defaults)
# self.assertNotIn("help_all", test_pop.cleaned_up_defaults) def test_export_all_info(self):
# self.assertEqual(test_pop.bse_options, {}) """
# self.assertEqual(test_pop.custom_options, {}) Unittests for the function export_all_info
# self.assertEqual(test_pop.argline_dict, {}) """
# self.assertEqual(test_pop.persistent_data_memory_dict, {})
# self.assertTrue(test_pop.grid_options["parse_function"] == None) test_pop = Population()
# self.assertTrue(isinstance(test_pop.grid_options["_main_pid"], int))
test_pop.set(metallicity=0.02)
# def test_set(self): test_pop.set(M_1=10)
# test_pop = Population() test_pop.set(amt_cores=2)
# test_pop.set(amt_cores=2) test_pop.set(data_dir=binary_c_temp_dir)
# test_pop.set(M_1=10)
# test_pop.set(data_dir="/tmp/binary_c_python") # datadir
# test_pop.set(ensemble_filter_SUPERNOVAE=1) settings_filename = test_pop.export_all_info(use_datadir=True)
self.assertTrue(os.path.isfile(settings_filename))
# self.assertIn("data_dir", test_pop.custom_options) with open(settings_filename, "r") as f:
# self.assertEqual(test_pop.custom_options["data_dir"], "/tmp/binary_c_python") all_info = json.loads(f.read())
# # #
# self.assertTrue(test_pop.bse_options["M_1"] == 10) self.assertIn("population_settings", all_info)
# self.assertTrue(test_pop.bse_options["ensemble_filter_SUPERNOVAE"] == 1) self.assertIn("binary_c_defaults", all_info)
self.assertIn("binary_c_version_info", all_info)
# # self.assertIn("binary_c_help_all", all_info)
# self.assertTrue(test_pop.grid_options["amt_cores"] == 2)
#
# def test_cmdline(self): self.assertNotEqual(all_info["population_settings"], {})
# # copy old sys.argv values self.assertNotEqual(all_info["binary_c_defaults"], {})
# prev_sysargv = sys.argv.copy() self.assertNotEqual(all_info["binary_c_version_info"], {})
self.assertNotEqual(all_info["binary_c_help_all"], {})
# # make a dummy cmdline arg input
# sys.argv = [ # custom name
# "script", # datadir
# "--cmdline", settings_filename = test_pop.export_all_info(
# "metallicity=0.0002 amt_cores=2 data_dir=/tmp/binary_c_python", use_datadir=False,
# ] outfile=os.path.join(binary_c_temp_dir, "example_settings.json"),
)
# # Set up population self.assertTrue(os.path.isfile(settings_filename))
# test_pop = Population() with open(settings_filename, "r") as f:
# test_pop.set(data_dir="/tmp") all_info = json.loads(f.read())
# # parse arguments #
# test_pop.parse_cmdline() self.assertIn("population_settings", all_info)
self.assertIn("binary_c_defaults", all_info)
# # metallicity self.assertIn("binary_c_version_info", all_info)
# self.assertTrue(isinstance(test_pop.bse_options["metallicity"], str)) self.assertIn("binary_c_help_all", all_info)
# self.assertTrue(test_pop.bse_options["metallicity"] == "0.0002")
#
# # Amt cores self.assertNotEqual(all_info["population_settings"], {})
# self.assertTrue(isinstance(test_pop.grid_options["amt_cores"], int)) self.assertNotEqual(all_info["binary_c_defaults"], {})
# self.assertTrue(test_pop.grid_options["amt_cores"] == 2) self.assertNotEqual(all_info["binary_c_version_info"], {})
self.assertNotEqual(all_info["binary_c_help_all"], {})
# # datadir
# self.assertTrue(isinstance(test_pop.custom_options["data_dir"], str)) # wrong filename
# self.assertTrue(test_pop.custom_options["data_dir"] == "/tmp/binary_c_python") self.assertRaises(
ValueError,
# # put back the other args if they exist test_pop.export_all_info,
# sys.argv = prev_sysargv.copy() use_datadir=False,
outfile=os.path.join(binary_c_temp_dir, "example_settings.txt"),
# def test__return_argline(self): )
# """
# Unittests for the function _return_argline def test__cleanup_defaults(self):
# """ """
Unittests for the function _cleanup_defaults
# # Set up population """
# test_pop = Population()
# test_pop.set(metallicity=0.02) test_pop = Population()
# test_pop.set(M_1=10) cleaned_up_defaults = test_pop._cleanup_defaults()
self.assertNotIn("help_all", cleaned_up_defaults)
# argline = test_pop._return_argline()
# self.assertTrue(argline == "binary_c M_1 10 metallicity 0.02") def test__increment_probtot(self):
"""
# # custom dict Unittests for the function _increment_probtot
# argline2 = test_pop._return_argline( """
# {"example_parameter1": 10, "example_parameter2": "hello"}
# ) test_pop = Population()
# self.assertTrue( test_pop._increment_probtot(0.5)
# argline2 == "binary_c example_parameter1 10 example_parameter2 hello" self.assertEqual(test_pop.grid_options["_probtot"], 0.5)
# )
def test__increment_count(self):
# def test_add_grid_variable(self): """
# """ Unittests for the function _increment_probtot
# Unittests for the function add_grid_variable """
# TODO: Should I test more here? test_pop = Population()
# """ test_pop._increment_count()
self.assertEqual(test_pop.grid_options["_count"], 1)
# test_pop = Population()
def test__dict_from_line_source_file(self):
# resolution = {"M_1": 10, "q": 10} """
Unittests for the function _dict_from_line_source_file
# test_pop.add_grid_variable( """
# name="lnm1",
# longname="Primary mass", source_file = os.path.join(binary_c_temp_dir, "example_source_file.txt")
# valuerange=[1, 100],
# resolution="{}".format(resolution["M_1"]), # write
# spacingfunc="const(math.log(1), math.log(100), {})".format( with open(source_file, "w") as f:
# resolution["M_1"] f.write("binary_c M_1 10 metallicity 0.02\n")
# ),
# precode="M_1=math.exp(lnm1)", test_pop = Population()
# probdist="three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1",
# dphasevol="dlnm1", # readout
# parameter_name="M_1", with open(source_file, "r") as f:
# condition="", # Impose a condition on this grid variable. Mostly for a check for yourself for line in f.readlines():
# ) argdict = test_pop._dict_from_line_source_file(line)
# test_pop.add_grid_variable( self.assertTrue(argdict["M_1"] == 10)
# name="q", self.assertTrue(argdict["metallicity"] == 0.02)
# longname="Mass ratio",
# valuerange=["0.1/M_1", 1], def test_evolve_single(self):
# resolution="{}".format(resolution["q"]), """
# spacingfunc="const(0.1/M_1, 1, {})".format(resolution["q"]), Unittests for the function evolve_single
# probdist="flatsections(q, [{'min': 0.1/M_1, 'max': 1.0, 'height': 1}])", """
# dphasevol="dq",
# precode="M_2 = q * M_1", CUSTOM_LOGGING_STRING_MASSES = """
# parameter_name="M_2", Printf("TEST_CUSTOM_LOGGING_1 %30.12e %g %g %g %g\\n",
# condition="", # Impose a condition on this grid variable. Mostly for a check for yourself //
# ) stardata->model.time, // 1
# self.assertIn("q", test_pop.grid_options["_grid_variables"])
# self.assertIn("lnm1", test_pop.grid_options["_grid_variables"])
# self.assertEqual(len(test_pop.grid_options["_grid_variables"]), 2)
# def test_return_population_settings(self):
# """
# Unittests for the function return_population_settings
# """
# test_pop = Population()
# test_pop.set(metallicity=0.02)
# test_pop.set(M_1=10)
# test_pop.set(amt_cores=2)
# test_pop.set(data_dir="/tmp")
# population_settings = test_pop.return_population_settings()
# self.assertIn("bse_options", population_settings)
# self.assertTrue(population_settings["bse_options"]["metallicity"] == 0.02)
# self.assertTrue(population_settings["bse_options"]["M_1"] == 10)
# self.assertIn("grid_options", population_settings)
# self.assertTrue(population_settings["grid_options"]["amt_cores"] == 2)
# self.assertIn("custom_options", population_settings)
# self.assertTrue(population_settings["custom_options"]["data_dir"] == "/tmp")
# def test__return_binary_c_version_info(self):
# """
# Unittests for the function _return_binary_c_version_info
# """
# test_pop = Population()
# binary_c_version_info = test_pop._return_binary_c_version_info(parsed=True)
# self.assertTrue(isinstance(binary_c_version_info, dict))
# self.assertIn("isotopes", binary_c_version_info)
# self.assertIn("argpairs", binary_c_version_info)
# self.assertIn("ensembles", binary_c_version_info)
# self.assertIn("macros", binary_c_version_info)
# self.assertIn("dt_limits", binary_c_version_info)
# self.assertIn("nucleosynthesis_sources", binary_c_version_info)
# self.assertIn("miscellaneous", binary_c_version_info)
# self.assertIsNotNone(binary_c_version_info["argpairs"])
# self.assertIsNotNone(binary_c_version_info["ensembles"])
# self.assertIsNotNone(binary_c_version_info["macros"])
# self.assertIsNotNone(binary_c_version_info["dt_limits"])
# self.assertIsNotNone(binary_c_version_info["miscellaneous"])
# if binary_c_version_info['miscellaneous']['NUCSYN'] == 'on':
# self.assertIsNotNone(binary_c_version_info["isotopes"])
# self.assertIsNotNone(binary_c_version_info["nucleosynthesis_sources"])
# def test__return_binary_c_defaults(self):
# """
# Unittests for the function _return_binary_c_defaults
# """
# test_pop = Population()
# binary_c_defaults = test_pop._return_binary_c_defaults()
# self.assertIn("probability", binary_c_defaults)
# self.assertIn("phasevol", binary_c_defaults)
# self.assertIn("metallicity", binary_c_defaults)
# def test_return_all_info(self):
# """
# Unittests for the function return_all_info
# Not going to do too much tests here, just check if they are not empty
# """
# test_pop = Population()
# all_info = test_pop.return_all_info()
# self.assertIn("population_settings", all_info)
# self.assertIn("binary_c_defaults", all_info)
# self.assertIn("binary_c_version_info", all_info)
# self.assertIn("binary_c_help_all", all_info)
# self.assertNotEqual(all_info["population_settings"], {})
# self.assertNotEqual(all_info["binary_c_defaults"], {})
# self.assertNotEqual(all_info["binary_c_version_info"], {})
# self.assertNotEqual(all_info["binary_c_help_all"], {})
# def test_export_all_info(self):
# """
# Unittests for the function export_all_info
# """
# test_pop = Population()
# test_pop.set(metallicity=0.02)
# test_pop.set(M_1=10)
# test_pop.set(amt_cores=2)
# test_pop.set(data_dir=binary_c_temp_dir)
# # datadir
# settings_filename = test_pop.export_all_info(use_datadir=True)
# self.assertTrue(os.path.isfile(settings_filename))
# with open(settings_filename, "r") as f:
# all_info = json.loads(f.read())
# #
# self.assertIn("population_settings", all_info)
# self.assertIn("binary_c_defaults", all_info)
# self.assertIn("binary_c_version_info", all_info)
# self.assertIn("binary_c_help_all", all_info)
# #
# self.assertNotEqual(all_info["population_settings"], {})
# self.assertNotEqual(all_info["binary_c_defaults"], {})
# self.assertNotEqual(all_info["binary_c_version_info"], {})
# self.assertNotEqual(all_info["binary_c_help_all"], {})
# # custom name
# # datadir
# settings_filename = test_pop.export_all_info(
# use_datadir=False,
# outfile=os.path.join(binary_c_temp_dir, "example_settings.json"),
# )
# self.assertTrue(os.path.isfile(settings_filename))
# with open(settings_filename, "r") as f:
# all_info = json.loads(f.read())
# #
# self.assertIn("population_settings", all_info)
# self.assertIn("binary_c_defaults", all_info)
# self.assertIn("binary_c_version_info", all_info)
# self.assertIn("binary_c_help_all", all_info)
# #
# self.assertNotEqual(all_info["population_settings"], {})
# self.assertNotEqual(all_info["binary_c_defaults"], {})
# self.assertNotEqual(all_info["binary_c_version_info"], {})
# self.assertNotEqual(all_info["binary_c_help_all"], {})
# # wrong filename
# self.assertRaises(
# ValueError,
# test_pop.export_all_info,
# use_datadir=False,
# outfile=os.path.join(binary_c_temp_dir, "example_settings.txt"),
# )
# def test__cleanup_defaults(self):
# """
# Unittests for the function _cleanup_defaults
# """
# test_pop = Population()
# cleaned_up_defaults = test_pop._cleanup_defaults()
# self.assertNotIn("help_all", cleaned_up_defaults)
# def test__increment_probtot(self):
# """
# Unittests for the function _increment_probtot
# """
# test_pop = Population()
# test_pop._increment_probtot(0.5)
# self.assertEqual(test_pop.grid_options["_probtot"], 0.5)
# def test__increment_count(self):
# """
# Unittests for the function _increment_probtot
# """
# test_pop = Population()
# test_pop._increment_count()
# self.assertEqual(test_pop.grid_options["_count"], 1)
# def test__dict_from_line_source_file(self):
# """
# Unittests for the function _dict_from_line_source_file
# """
# source_file = os.path.join(binary_c_temp_dir, "example_source_file.txt")
# # write
# with open(source_file, "w") as f:
# f.write("binary_c M_1 10 metallicity 0.02\n")
# test_pop = Population()
# # readout
# with open(source_file, "r") as f:
# for line in f.readlines():
# argdict = test_pop._dict_from_line_source_file(line)
# self.assertTrue(argdict["M_1"] == 10)
# self.assertTrue(argdict["metallicity"] == 0.02)
# def test_evolve_single(self):
# """
# Unittests for the function evolve_single
# """
# CUSTOM_LOGGING_STRING_MASSES = """
# Printf("TEST_CUSTOM_LOGGING_1 %30.12e %g %g %g %g\\n",
# //
# stardata->model.time, // 1
# // masses // masses
# stardata->common.zero_age.mass[0], // stardata->common.zero_age.mass[0], //
# stardata->common.zero_age.mass[1], // stardata->common.zero_age.mass[1], //
# stardata->star[0].mass, stardata->star[0].mass,
# stardata->star[1].mass stardata->star[1].mass
# ); );
# """ """
# test_pop = Population() test_pop = Population()
# test_pop.set(M_1=10, M_2=5, orbital_period=100000, metallicty=0.02, max_evolution_time = 15000) test_pop.set(M_1=10, M_2=5, orbital_period=100000, metallicty=0.02, max_evolution_time = 15000)
# test_pop.set(C_logging_code=CUSTOM_LOGGING_STRING_MASSES) test_pop.set(C_logging_code=CUSTOM_LOGGING_STRING_MASSES)
# output = test_pop.evolve_single() output = test_pop.evolve_single()
# # #
# self.assertTrue(len(output.splitlines())>1) self.assertTrue(len(output.splitlines())>1)
# self.assertIn('TEST_CUSTOM_LOGGING_1', output) self.assertIn('TEST_CUSTOM_LOGGING_1', output)
# # #
# custom_logging_dict = { custom_logging_dict = {
# 'TEST_CUSTOM_LOGGING_2': ['star[0].mass', 'model.time'] 'TEST_CUSTOM_LOGGING_2': ['star[0].mass', 'model.time']
# } }
# test_pop_2 = Population() test_pop_2 = Population()
# test_pop_2.set(M_1=10, M_2=5, orbital_period=100000, metallicty=0.02, max_evolution_time = 15000) test_pop_2.set(M_1=10, M_2=5, orbital_period=100000, metallicty=0.02, max_evolution_time = 15000)
# test_pop_2.set(C_auto_logging=custom_logging_dict) test_pop_2.set(C_auto_logging=custom_logging_dict)
# output_2 = test_pop_2.evolve_single() output_2 = test_pop_2.evolve_single()
# # #
# self.assertTrue(len(output_2.splitlines())>1) self.assertTrue(len(output_2.splitlines())>1)
# self.assertIn('TEST_CUSTOM_LOGGING_2', output_2) self.assertIn('TEST_CUSTOM_LOGGING_2', output_2)
class test_grid_evolve(unittest.TestCase): class test_grid_evolve(unittest.TestCase):
""" """
Unittests for function Population.evolve() Unittests for function Population.evolve()
""" """
# def test_grid_evolve_1_thread(self): def test_grid_evolve_1_thread(self):
# # test to see if 1 thread does all the systems # test to see if 1 thread does all the systems
# test_pop_evolve_1_thread = Population() test_pop_evolve_1_thread = Population()
# test_pop_evolve_1_thread.set(amt_cores=1, verbosity=1, M_2=1, orbital_period=100000) test_pop_evolve_1_thread.set(amt_cores=1, verbosity=1, M_2=1, orbital_period=100000)
# resolution = {"M_1": 10} resolution = {"M_1": 10}
# test_pop_evolve_1_thread.add_grid_variable( test_pop_evolve_1_thread.add_grid_variable(
# name="lnm1", name="lnm1",
# longname="Primary mass", longname="Primary mass",
# valuerange=[1, 100], valuerange=[1, 100],
# resolution="{}".format(resolution["M_1"]), resolution="{}".format(resolution["M_1"]),
# spacingfunc="const(math.log(1), math.log(100), {})".format( spacingfunc="const(math.log(1), math.log(100), {})".format(
# resolution["M_1"] resolution["M_1"]
# ), ),
# precode="M_1=math.exp(lnm1)", precode="M_1=math.exp(lnm1)",
# probdist="three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1", probdist="three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1",
# dphasevol="dlnm1", dphasevol="dlnm1",
# parameter_name="M_1", parameter_name="M_1",
# condition="", # Impose a condition on this grid variable. Mostly for a check for yourself condition="", # Impose a condition on this grid variable. Mostly for a check for yourself
# ) )
# analytics = test_pop_evolve_1_thread.evolve() analytics = test_pop_evolve_1_thread.evolve()
# self.assertLess(np.abs(analytics['total_probability']-0.1503788456014623), 1e-10) self.assertLess(np.abs(analytics['total_probability']-0.1503788456014623), 1e-10)
# self.assertTrue(analytics['total_count']==10) self.assertTrue(analytics['total_count']==10)
# def test_grid_evolve_2_threads(self): def test_grid_evolve_2_threads(self):
# # test to see if 1 thread does all the systems # test to see if 1 thread does all the systems
# test_pop = Population() test_pop = Population()
# test_pop.set(amt_cores=2, verbosity=1, M_2=1, orbital_period=100000) test_pop.set(amt_cores=2, verbosity=1, M_2=1, orbital_period=100000)
# resolution = {"M_1": 10} resolution = {"M_1": 10}
# test_pop.add_grid_variable( test_pop.add_grid_variable(
# name="lnm1", name="lnm1",
# longname="Primary mass", longname="Primary mass",
# valuerange=[1, 100], valuerange=[1, 100],
# resolution="{}".format(resolution["M_1"]), resolution="{}".format(resolution["M_1"]),
# spacingfunc="const(math.log(1), math.log(100), {})".format( spacingfunc="const(math.log(1), math.log(100), {})".format(
# resolution["M_1"] resolution["M_1"]
# ), ),
# precode="M_1=math.exp(lnm1)", precode="M_1=math.exp(lnm1)",
# probdist="three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1", probdist="three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1",
# dphasevol="dlnm1", dphasevol="dlnm1",
# parameter_name="M_1", parameter_name="M_1",
# condition="", # Impose a condition on this grid variable. Mostly for a check for yourself condition="", # Impose a condition on this grid variable. Mostly for a check for yourself
# ) )
# analytics = test_pop.evolve() analytics = test_pop.evolve()
# self.assertLess(np.abs(analytics['total_probability']-0.1503788456014623), 1e-10) # self.assertLess(np.abs(analytics['total_probability']-0.1503788456014623), 1e-10) #
# self.assertTrue(analytics['total_count']==10) self.assertTrue(analytics['total_count']==10)
# def test_grid_evolve_2_threads_with_custom_logging(self): def test_grid_evolve_1_threads_with_custom_logging(self):
# # test to see if 1 thread does all the systems # test to see if 1 thread does all the systems
# test_pop = Population() data_dir_value = os.path.join(binary_c_temp_dir, 'grid_tests')
# test_pop.set(amt_cores=2, verbosity=1, M_2=1, orbital_period=100000) amt_cores_value = 2
custom_logging_string = 'Printf("MY_STELLAR_DATA_TEST_EXAMPLE %g %g %g %g\\n",((double)stardata->model.time),((double)stardata->star[0].mass),((double)stardata->model.probability),((double)stardata->model.dt));'
# resolution = {"M_1": 10}
test_pop = Population()
# test_pop.add_grid_variable(
# name="lnm1", test_pop.set(amt_cores=amt_cores_value,
# longname="Primary mass", verbosity=1,
# valuerange=[1, 100], M_2=1,
# resolution="{}".format(resolution["M_1"]), orbital_period=100000,
# spacingfunc="const(math.log(1), math.log(100), {})".format( data_dir=data_dir_value,
# resolution["M_1"] C_logging_code=custom_logging_string, # input it like this.
# ), parse_function=parse_function_test_grid_evolve_2_threads_with_custom_logging)
# precode="M_1=math.exp(lnm1)", test_pop.set(ensemble=0)
# probdist="three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1", resolution = {"M_1": 2}
# dphasevol="dlnm1",
# parameter_name="M_1", test_pop.add_grid_variable(
# condition="", # Impose a condition on this grid variable. Mostly for a check for yourself name="lnm1",
# ) longname="Primary mass",
valuerange=[1, 100],
# analytics = test_pop.evolve() resolution="{}".format(resolution["M_1"]),
# self.assertLess(np.abs(analytics['total_probability']-0.1503788456014623), 1e-10) # spacingfunc="const(math.log(1), math.log(100), {})".format(
# self.assertTrue(analytics['total_count']==10) resolution["M_1"]
),
# def test_grid_evolve_with_condition_error(self): precode="M_1=math.exp(lnm1)",
# # Test to see if we can catch the errors correctly. probdist="three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1",
dphasevol="dlnm1",
# test_pop = Population() parameter_name="M_1",
# test_pop.set(amt_cores=2, verbosity=1, M_2=1, orbital_period=100000) condition="", # Impose a condition on this grid variable. Mostly for a check for yourself
)
# # Set the amt of failed systems that each thread will log
# test_pop.set(failed_systems_threshold=4) analytics = test_pop.evolve()
output_names = [os.path.join(data_dir_value, "test_grid_evolve_2_threads_with_custom_logging_outputfile_population_{}_thread_{}.dat".format(analytics['population_name'], thread_id)) for thread_id in range(amt_cores_value)]
# CUSTOM_LOGGING_STRING_WITH_EXIT = """
# Exit_binary_c(BINARY_C_NORMAL_EXIT, "testing exits"); for output_name in output_names:
# Printf("TEST_CUSTOM_LOGGING_1 %30.12e %g %g %g %g\\n", self.assertTrue(os.path.isfile(output_name))
# //
# stardata->model.time, // 1 with open(output_name, 'r') as f:
output_string = f.read()
self.assertIn("MY_STELLAR_DATA_TEST_EXAMPLE", output_string)
remove_file(output_name)
def test_grid_evolve_with_condition_error(self):
# Test to see if we can catch the errors correctly.
test_pop = Population()
test_pop.set(amt_cores=2, verbosity=1, M_2=1, orbital_period=100000)
# Set the amt of failed systems that each thread will log
test_pop.set(failed_systems_threshold=4)
CUSTOM_LOGGING_STRING_WITH_EXIT = """
Exit_binary_c(BINARY_C_NORMAL_EXIT, "testing exits");
Printf("TEST_CUSTOM_LOGGING_1 %30.12e %g %g %g %g\\n",
//
stardata->model.time, // 1
# // masses // masses
# stardata->common.zero_age.mass[0], // stardata->common.zero_age.mass[0], //
# stardata->common.zero_age.mass[1], // stardata->common.zero_age.mass[1], //
# stardata->star[0].mass, stardata->star[0].mass,
# stardata->star[1].mass stardata->star[1].mass
# ); );
# """ """
# test_pop.set(C_logging_code=CUSTOM_LOGGING_STRING_WITH_EXIT) test_pop.set(C_logging_code=CUSTOM_LOGGING_STRING_WITH_EXIT)
# resolution = {"M_1": 10} resolution = {"M_1": 10}
# test_pop.add_grid_variable( test_pop.add_grid_variable(
# name="lnm1", name="lnm1",
# longname="Primary mass", longname="Primary mass",
# valuerange=[1, 100], valuerange=[1, 100],
# resolution="{}".format(resolution["M_1"]), resolution="{}".format(resolution["M_1"]),
# spacingfunc="const(math.log(1), math.log(100), {})".format( spacingfunc="const(math.log(1), math.log(100), {})".format(
# resolution["M_1"] resolution["M_1"]
# ), ),
# precode="M_1=math.exp(lnm1)", precode="M_1=math.exp(lnm1)",
# probdist="three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1", probdist="three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1",
# dphasevol="dlnm1", dphasevol="dlnm1",
# parameter_name="M_1", parameter_name="M_1",
# condition="", # Impose a condition on this grid variable. Mostly for a check for yourself condition="", # Impose a condition on this grid variable. Mostly for a check for yourself
# ) )
# analytics = test_pop.evolve() analytics = test_pop.evolve()
# self.assertLess(np.abs(analytics['total_probability']-0.1503788456014623), 1e-10) # self.assertLess(np.abs(analytics['total_probability']-0.1503788456014623), 1e-10) #
# self.assertLess(np.abs(analytics['failed_prob']-0.1503788456014623), 1e-10) # self.assertLess(np.abs(analytics['failed_prob']-0.1503788456014623), 1e-10) #
# self.assertEqual(analytics['failed_systems_error_codes'], [0]) self.assertEqual(analytics['failed_systems_error_codes'], [0])
# self.assertTrue(analytics['total_count']==10) self.assertTrue(analytics['total_count']==10)
# self.assertTrue(analytics['failed_count']==10) self.assertTrue(analytics['failed_count']==10)
# self.assertTrue(analytics['errors_found']==True) self.assertTrue(analytics['errors_found']==True)
# self.assertTrue(analytics['errors_exceeded']==True) self.assertTrue(analytics['errors_exceeded']==True)
# # test to see if 1 thread does all the systems # test to see if 1 thread does all the systems
# test_pop = Population() test_pop = Population()
# test_pop.set(amt_cores=2, verbosity=1, M_2=1, orbital_period=100000) test_pop.set(amt_cores=2, verbosity=1, M_2=1, orbital_period=100000)
# resolution = {"M_1": 10, "q": 2} resolution = {"M_1": 10, "q": 2}
# test_pop.add_grid_variable( test_pop.add_grid_variable(
# name="lnm1", name="lnm1",
# longname="Primary mass", longname="Primary mass",
# valuerange=[1, 100], valuerange=[1, 100],
# resolution="{}".format(resolution["M_1"]), resolution="{}".format(resolution["M_1"]),
# spacingfunc="const(math.log(1), math.log(100), {})".format( spacingfunc="const(math.log(1), math.log(100), {})".format(
# resolution["M_1"] resolution["M_1"]
# ), ),
# precode="M_1=math.exp(lnm1)", precode="M_1=math.exp(lnm1)",
# probdist="three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1", probdist="three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1",
# dphasevol="dlnm1", dphasevol="dlnm1",
# parameter_name="M_1", parameter_name="M_1",
# condition="", # Impose a condition on this grid variable. Mostly for a check for yourself condition="", # Impose a condition on this grid variable. Mostly for a check for yourself
# ) )
# test_pop.add_grid_variable( test_pop.add_grid_variable(
# name="q", name="q",
# longname="Mass ratio", longname="Mass ratio",
# valuerange=["0.1/M_1", 1], valuerange=["0.1/M_1", 1],
# resolution="{}".format(resolution["q"]), resolution="{}".format(resolution["q"]),
# spacingfunc="const(0.1/M_1, 1, {})".format(resolution["q"]), spacingfunc="const(0.1/M_1, 1, {})".format(resolution["q"]),
# probdist="flatsections(q, [{'min': 0.1/M_1, 'max': 1.0, 'height': 1}])", probdist="flatsections(q, [{'min': 0.1/M_1, 'max': 1.0, 'height': 1}])",
# dphasevol="dq", dphasevol="dq",
# precode="M_2 = q * M_1", precode="M_2 = q * M_1",
# parameter_name="M_2", parameter_name="M_2",
# # condition="M_1 in dir()", # Impose a condition on this grid variable. Mostly for a check for yourself # condition="M_1 in dir()", # Impose a condition on this grid variable. Mostly for a check for yourself
# condition="'random_var' in dir()", # This will raise an error because random_var is not defined. condition="'random_var' in dir()", # This will raise an error because random_var is not defined.
# ) )
# self.assertRaises(ValueError, test_pop.evolve) self.assertRaises(ValueError, test_pop.evolve)
# def test_grid_evolve_no_grid_variables(self): def test_grid_evolve_no_grid_variables(self):
# # test to see if 1 thread does all the systems # test to see if 1 thread does all the systems
# test_pop = Population() test_pop = Population()
# test_pop.set(amt_cores=1, verbosity=1, M_2=1, orbital_period=100000) test_pop.set(amt_cores=1, verbosity=1, M_2=1, orbital_period=100000)
# resolution = {"M_1": 10} resolution = {"M_1": 10}
# self.assertRaises(ValueError, test_pop.evolve) self.assertRaises(ValueError, test_pop.evolve)
def test_grid_evolve_2_threads_with_ensemble(self): def test_grid_evolve_2_threads_with_ensemble_direct_output(self):
# test to see if 1 thread does all the systems # test to see if 1 thread does all the systems
data_dir_value = binary_c_temp_dir
amt_cores_value = 2
test_pop = Population() test_pop = Population()
test_pop.set(amt_cores=2, verbosity=1, M_2=1, orbital_period=100000, ensemble=1, ensemble_defer=1, ensemble_filters_off=1, ensemble_filter_STELLAR_TYPE_COUNTS=1) test_pop.set(amt_cores=amt_cores_value, verbosity=1, M_2=1, orbital_period=100000, ensemble=1, ensemble_defer=1, ensemble_filters_off=1, ensemble_filter_STELLAR_TYPE_COUNTS=1)
test_pop.set(data_dir=binary_c_temp_dir, ensemble_output_name="ensemble_output.json", combine_ensemble_with_thread_joining=False) test_pop.set(data_dir=binary_c_temp_dir, ensemble_output_name="ensemble_output.json", combine_ensemble_with_thread_joining=False)
resolution = {"M_1": 10} resolution = {"M_1": 10}
...@@ -609,11 +653,133 @@ class test_grid_evolve(unittest.TestCase): ...@@ -609,11 +653,133 @@ class test_grid_evolve(unittest.TestCase):
) )
analytics = test_pop.evolve() analytics = test_pop.evolve()
output_names = [os.path.join(data_dir_value, "ensemble_output_{}_{}.json".format(analytics['population_name'], thread_id)) for thread_id in range(amt_cores_value)]
for output_name in output_names:
self.assertTrue(os.path.isfile(output_name))
with open(output_name, 'r') as f:
file_content = f.read()
self.assertTrue(file_content.startswith("ENSEMBLE_JSON"))
ensemble_json = extract_ensemble_json_from_string(file_content)
self.assertTrue(isinstance(ensemble_json, dict))
self.assertNotEqual(ensemble_json, {})
self.assertIn("number_counts", ensemble_json)
self.assertNotEqual(ensemble_json["number_counts"], {})
def test_grid_evolve_2_threads_with_ensemble_combining(self):
# test to see if 1 thread does all the systems
data_dir_value = binary_c_temp_dir
amt_cores_value = 2
test_pop = Population()
test_pop.set(amt_cores=amt_cores_value, verbosity=1, M_2=1, orbital_period=100000, ensemble=1, ensemble_defer=1, ensemble_filters_off=1, ensemble_filter_STELLAR_TYPE_COUNTS=1)
test_pop.set(data_dir=binary_c_temp_dir, combine_ensemble_with_thread_joining=True, ensemble_output_name="ensemble_output.json")
resolution = {"M_1": 10}
test_pop.add_grid_variable(
name="lnm1",
longname="Primary mass",
valuerange=[1, 100],
resolution="{}".format(resolution["M_1"]),
spacingfunc="const(math.log(1), math.log(100), {})".format(
resolution["M_1"]
),
precode="M_1=math.exp(lnm1)",
probdist="three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1",
dphasevol="dlnm1",
parameter_name="M_1",
condition="", # Impose a condition on this grid variable. Mostly for a check for yourself
)
analytics = test_pop.evolve()
self.assertTrue(isinstance(test_pop.grid_options['ensemble_results'], dict))
self.assertNotEqual(test_pop.grid_options['ensemble_results'], {})
self.assertIn("number_counts", test_pop.grid_options['ensemble_results'])
self.assertNotEqual(test_pop.grid_options['ensemble_results']["number_counts"], {})
def test_grid_evolve_2_threads_with_ensemble_comparing_two_methods(self):
# test to see if 1 thread does all the systems
data_dir_value = binary_c_temp_dir
amt_cores_value = 2
# First
test_pop_1 = Population()
test_pop_1.set(amt_cores=amt_cores_value, verbosity=1, M_2=1, orbital_period=100000, ensemble=1, ensemble_defer=1, ensemble_filters_off=1, ensemble_filter_STELLAR_TYPE_COUNTS=1)
test_pop_1.set(data_dir=binary_c_temp_dir, combine_ensemble_with_thread_joining=True, ensemble_output_name="ensemble_output.json")
resolution = {"M_1": 10}
test_pop_1.add_grid_variable(
name="lnm1",
longname="Primary mass",
valuerange=[1, 100],
resolution="{}".format(resolution["M_1"]),
spacingfunc="const(math.log(1), math.log(100), {})".format(
resolution["M_1"]
),
precode="M_1=math.exp(lnm1)",
probdist="three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1",
dphasevol="dlnm1",
parameter_name="M_1",
condition="", # Impose a condition on this grid variable. Mostly for a check for yourself
)
analytics_1 = test_pop_1.evolve()
ensemble_output_1 = test_pop_1.grid_options['ensemble_results']
# second
test_pop_2 = Population()
test_pop_2.set(amt_cores=amt_cores_value, verbosity=1, M_2=1, orbital_period=100000, ensemble=1, ensemble_defer=1, ensemble_filters_off=1, ensemble_filter_STELLAR_TYPE_COUNTS=1)
test_pop_2.set(data_dir=binary_c_temp_dir, ensemble_output_name="ensemble_output.json", combine_ensemble_with_thread_joining=False)
resolution = {"M_1": 10}
test_pop_2.add_grid_variable(
name="lnm1",
longname="Primary mass",
valuerange=[1, 100],
resolution="{}".format(resolution["M_1"]),
spacingfunc="const(math.log(1), math.log(100), {})".format(
resolution["M_1"]
),
precode="M_1=math.exp(lnm1)",
probdist="three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1",
dphasevol="dlnm1",
parameter_name="M_1",
condition="", # Impose a condition on this grid variable. Mostly for a check for yourself
)
analytics_2 = test_pop_2.evolve()
output_names_2 = [os.path.join(data_dir_value, "ensemble_output_{}_{}.json".format(analytics_2['population_name'], thread_id)) for thread_id in range(amt_cores_value)]
ensemble_output_2 = {}
for output_name in output_names_2:
self.assertTrue(os.path.isfile(output_name))
with open(output_name, 'r') as f:
file_content = f.read()
self.assertTrue(file_content.startswith("ENSEMBLE_JSON"))
ensemble_json = extract_ensemble_json_from_string(file_content)
ensemble_output_2 = merge_dicts(ensemble_output_2, ensemble_json)
# self.assertLess(np.abs(analytics['total_probability']-0.1503788456014623), 1e-10) # for key in ensemble_output_1['number_counts']['stellar_type']['0']:
# self.assertTrue(analytics['total_count']==10) self.assertIn(key, ensemble_output_2['number_counts']['stellar_type']['0'])
# compare values
self.assertLess(np.abs(ensemble_output_1['number_counts']['stellar_type']['0'][key]-ensemble_output_2['number_counts']['stellar_type']['0'][key]), 1e-8)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment