diff --git a/binarycpython/tests/test_functions.py b/binarycpython/tests/test_functions.py index 4c15060dc6682aa76dd0b9f78930347ac6deb87b..eb02619d559fa990d44a29fe557fc81f5c798975 100644 --- a/binarycpython/tests/test_functions.py +++ b/binarycpython/tests/test_functions.py @@ -2,12 +2,15 @@ Unittests for the functions module """ +import os import unittest import tempfile -from binarycpython.utils.functions import * + from binarycpython.utils.custom_logging_functions import binary_c_log_code from binarycpython.utils.run_system_wrapper import run_system +from binarycpython.utils.functions import * + TMP_DIR = temp_dir("tests", "test_functions") class dummy: @@ -852,6 +855,67 @@ class test_handle_ensemble_string_to_json(unittest.TestCase): self.assertTrue(output_dict["ding"] == 10) self.assertTrue(output_dict["list_example"] == [1, 2, 3]) +class test_AutoVivicationDict(unittest.TestCase): + """ + Unittests for AutoVivicationDict + """ + + def test_add(self): + """ + Tests to see if the adding is done correctly + """ + + result_dict = AutoVivificationDict() + + result_dict['a']['b']['c'] += 10 + + self.assertEqual(result_dict['a']['b']['c'], 10) + result_dict['a']['b']['c'] += 10 + self.assertEqual(result_dict['a']['b']['c'], 20) + +class test_bin_data(unittest.TestCase): + """ + Unittests for bin_data + """ + + def test_positive_bin(self): + """ + Tests to see if the binning is done correctly for positive values + """ + + value = 0.6 + binwidth = 1 + + binned_value = bin_data(value, binwidth) + + self.assertEqual(binned_value, 0.5) + + def test_negative_bin(self): + """ + Tests to see if the binning is done correctly for negative values + """ + + value = -0.6 + binwidth = 1 + + binned_value = bin_data(value, binwidth) + + self.assertEqual(binned_value, -0.5) + + def test_zero_bin(self): + """ + Tests to see if the binning is done correctly + TODO: when the value is 0 then its now binned in the negative located bin. Decide whether we want that + """ + + value = 0 + binwidth = 1 + + binned_value = bin_data(value, binwidth) + + self.assertEqual(binned_value, -0.5) + + if __name__ == "__main__": unittest.main() diff --git a/binarycpython/tests/test_grid.py b/binarycpython/tests/test_grid.py index ce1bda06b48be6b565a12f2ca9391a7971171578..7c248f803d3e50b5ef82d8c6f83ae22dfe1b3076 100644 --- a/binarycpython/tests/test_grid.py +++ b/binarycpython/tests/test_grid.py @@ -19,6 +19,7 @@ from binarycpython.utils.functions import ( merge_dicts, remove_file, Capturing, + bin_data ) from binarycpython.utils.custom_logging_functions import binary_c_log_code @@ -1030,6 +1031,132 @@ Printf("TEST_CUSTOM_LOGGING_1 %30.12e %g %g %g %g\\n", 1e-8, ) +def parse_function_adding_results(self, output): + """ + Example parse function + """ + + seperator = " " + + parameters = ["time", "mass", "zams_mass", "probability", "stellar_type"] + + self.grid_results['example']['count'] += 1 + + # Go over the output. + for line in output.splitlines(): + headerline = line.split()[0] + + # CHeck the header and act accordingly + if headerline == "EXAMPLE_OUTPUT": + values = line.split()[1:] + + # Bin the mass probability + self.grid_results['example']['mass'][bin_data(float(values[2]), binwidth=0.5)] += float(values[3]) + + # + if not len(parameters) == len(values): + print("Amount of column names isnt equal to amount of columns") + raise ValueError + + # record the probability of this line (Beware, this is meant to only be run once for each system. its a controls quantity) + self.grid_results['example']['probability'] += float(values[3]) + +class test_resultdict(unittest.TestCase): + """ + Unittests for bin_data + """ + + def test_adding_results(self): + """ + Function to test whether the results are properly added and combined + """ + + # Create custom logging statement + custom_logging_statement = """ + if (stardata->model.time < stardata->model.max_evolution_time) + { + Printf("EXAMPLE_OUTPUT %30.16e %g %g %30.12e %d\\n", + // + stardata->model.time, // 1 + stardata->star[0].mass, // 2 + stardata->common.zero_age.mass[0], // 3 + stardata->model.probability, // 4 + stardata->star[0].stellar_type // 5 + ); + }; + /* Kill the simulation to save time */ + stardata->model.max_evolution_time = stardata->model.time - stardata->model.dtm; + """ + + example_pop = Population() + example_pop.set(verbosity=0) + example_pop.set( + max_evolution_time=15000, # bse_options + # grid_options + amt_cores=2, + tmp_dir=TMP_DIR, + + # Custom options + data_dir=os.path.join( + TMP_DIR, "test_resultdict" + ), # custom_options + C_logging_code=custom_logging_statement, + parse_function=parse_function_adding_results, + ) + + # Add grid variables + resolution = {"M_1": 10} + + # Mass + example_pop.add_grid_variable( + name="lnm1", + longname="Primary mass", + valuerange=[2, 150], + resolution="{}".format(resolution["M_1"]), + spacingfunc="const(math.log(2), math.log(150), {})".format(resolution["M_1"]), + precode="M_1=math.exp(lnm1)", + probdist="three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 150, -1.3, -2.3, -2.3)*M_1", + dphasevol="dlnm1", + parameter_name="M_1", + condition="", # Impose a condition on this grid variable. Mostly for a check for yourself + ) + + ## Executing a population + ## This uses the values generated by the grid_variables + analytics = example_pop.evolve() + + # + grid_prob = analytics['total_probability'] + result_dict_prob = example_pop.grid_results['example']['probability'] + + # amt systems + grid_count = analytics['total_count'] + result_dict_count = example_pop.grid_results['example']['count'] + + # Check if the total probability matches + self.assertAlmostEqual( + grid_prob, + result_dict_prob, + places=12, + msg="Total probability from grid {} and from result dict {} are not equal".format(grid_prob, result_dict_prob) + ) + + # Check if the total count matches + self.assertAlmostEqual( + grid_count, + result_dict_count, + places=12, + msg="Total count from grid {} and from result dict {} are not equal".format(grid_count, result_dict_count) + ) + + # Check if the structure is what we expect. Note: this depends on the probability calculation. if that changes we need to recalibrate this + test_case_dict = {2.25: 0.01895481306515, 3.75: 0.01081338190204, 5.75: 0.006168841009268, 9.25: 0.003519213484031, 13.75: 0.002007648361756, 21.25: 0.001145327489437, 33.25: 0.0006533888518775, 50.75: 0.0003727466560393, 78.25: 0.000212645301782, 120.75: 0.0001213103421247} + + self.assertEqual(test_case_dict, dict(example_pop.grid_results['example']['mass'])) + + + + if __name__ == "__main__": unittest.main() diff --git a/binarycpython/utils/functions.py b/binarycpython/utils/functions.py index 32acf5c0dad2d7745f218116068096674d297ddf..88e994cface147d9154ab4719a00c3d1e46f85ad 100644 --- a/binarycpython/utils/functions.py +++ b/binarycpython/utils/functions.py @@ -39,6 +39,35 @@ import py_rinterpolate ######################################################## # Unsorted ######################################################## +class AutoVivificationDict(dict): + """ + Implementation of perl's autovivification feature. + """ + + def __getitem__(self, item): + try: + return dict.__getitem__(self, item) + except KeyError: + value = self[item] = type(self)() + return value + + def __iadd__(self, other): + # if a value does not exist, assume it is 0.0 + try: + self += other + except: + self = other + return self + +def bin_data(value, binwidth): + """ + Function that bins the data + + Uses the absolute value of binwidth + """ + + return ((0.5 if value > 0.0 else -0.5) + int(value/abs(binwidth))) * abs(binwidth) + def convert_bytes(size): """ Function to return the size + a magnitude string diff --git a/binarycpython/utils/grid.py b/binarycpython/utils/grid.py index 85264ad4109465dbb8f8c1174e86edd53f0bc2bd..b33540a047016ed9519f505a5dd8ed7a177107e0 100644 --- a/binarycpython/utils/grid.py +++ b/binarycpython/utils/grid.py @@ -73,9 +73,9 @@ from binarycpython.utils.functions import ( recursive_change_key_to_string, multiply_values_dict, format_ensemble_results, + AutoVivificationDict, ) - # from binarycpython.utils.hpc_functions import ( # get_condor_version, # get_slurm_version, @@ -149,7 +149,8 @@ class Population: self.process_ID = 0 # Create location to store results. Users should write to this dictionary. - self.grid_results = {} + # The AutoVivificationDict allows for perls method of accessing possibly non-existant subdicts + self.grid_results = AutoVivificationDict() # Create location where ensemble results are written to self.grid_ensemble_results = {}