From 733597dc8a96caff1c2ed2f3a0ae893e2af29498 Mon Sep 17 00:00:00 2001
From: dh00601 <dh00601@surrey.ac.uk>
Date: Sat, 8 Jan 2022 21:54:44 +0000
Subject: [PATCH] added the grid evolve tests

---
 badges/test_coverage.svg                      |    6 +-
 binarycpython/tests/main.py                   |   10 +-
 binarycpython/tests/test_c_bindings.py        |   10 +-
 binarycpython/tests/test_dicts.py             |  113 +-
 binarycpython/tests/test_grid.py              | 1035 +++++++++--------
 binarycpython/utils/dicts.py                  |   13 +-
 binarycpython/utils/grid.py                   |   14 +-
 .../utils/population_extensions/gridcode.py   |   27 +-
 8 files changed, 667 insertions(+), 561 deletions(-)

diff --git a/badges/test_coverage.svg b/badges/test_coverage.svg
index d1f89653c..2fad913a4 100644
--- a/badges/test_coverage.svg
+++ b/badges/test_coverage.svg
@@ -9,13 +9,13 @@
     </mask>
     <g mask="url(#a)">
         <path fill="#555" d="M0 0h63v20H0z"/>
-        <path fill="#fe7d37" d="M63 0h36v20H63z"/>
+        <path fill="#dfb317" d="M63 0h36v20H63z"/>
         <path fill="url(#b)" d="M0 0h99v20H0z"/>
     </g>
     <g fill="#fff" text-anchor="middle" font-family="DejaVu Sans,Verdana,Geneva,sans-serif" font-size="11">
         <text x="31.5" y="15" fill="#010101" fill-opacity=".3">coverage</text>
         <text x="31.5" y="14">coverage</text>
-        <text x="80" y="15" fill="#010101" fill-opacity=".3">53%</text>
-        <text x="80" y="14">53%</text>
+        <text x="80" y="15" fill="#010101" fill-opacity=".3">62%</text>
+        <text x="80" y="14">62%</text>
     </g>
 </svg>
diff --git a/binarycpython/tests/main.py b/binarycpython/tests/main.py
index 8954d32e0..89dad560c 100755
--- a/binarycpython/tests/main.py
+++ b/binarycpython/tests/main.py
@@ -35,7 +35,10 @@ from binarycpython.tests.test_dicts import (
     test_keys_to_floats,
     test_recursive_change_key_to_float,
     test_recursive_change_key_to_string,
-    test_multiply_float_values
+    test_multiply_float_values,
+    test_subtract_dicts,
+    test_update_dicts,
+    test__nested_get
 )
 from binarycpython.tests.test_ensemble import (
     test_binaryc_json_serializer,
@@ -56,7 +59,7 @@ from binarycpython.tests.test_functions import (
     test_get_help_super,
     test_make_build_text,
     test_write_binary_c_parameter_descriptions_to_rst_file,
-
+    test_bin_data
 )
 from binarycpython.tests.test_grid import (
     test__setup,
@@ -71,7 +74,8 @@ from binarycpython.tests.test_grid import (
     test__increment_probtot,
     test__increment_count,
     test__dict_from_line_source_file,
-    test_evolve_single
+    test_evolve_single,
+    test_grid_evolve
 )
 from binarycpython.tests.test_plot_functions import (
     test_color_by_index,
diff --git a/binarycpython/tests/test_c_bindings.py b/binarycpython/tests/test_c_bindings.py
index e51922448..692000670 100644
--- a/binarycpython/tests/test_c_bindings.py
+++ b/binarycpython/tests/test_c_bindings.py
@@ -437,14 +437,13 @@ class test_ensemble_functions(unittest.TestCase):
 
     #############
 
-    # def test_full_ensemble_output(self):
-    #     with Capturing() as output:
-    #         self._test_full_ensemble_output()
+    def test_full_ensemble_output(self):
+        with Capturing() as output:
+            self._test_full_ensemble_output()
 
     def _test_full_ensemble_output(self):
         """
         Function to just output the whole ensemble
-        TODO: put this one back
         """
         print(self.id())
 
@@ -461,13 +460,12 @@ class test_ensemble_functions(unittest.TestCase):
         #
         output_json_1 = extract_ensemble_json_from_string(output_1)
 
-        keys = json_1.keys()
+        keys = output_json_1.keys()
 
         # assert statements:
         self.assertIn("number_counts", keys)
         self.assertIn("HRD", keys)
         self.assertIn("HRD(t)", keys)
-        self.assertIn("Xyield", keys)
         self.assertIn("distributions", keys)
         self.assertIn("scalars", keys)
 
diff --git a/binarycpython/tests/test_dicts.py b/binarycpython/tests/test_dicts.py
index 1f4854d91..831b8bbad 100644
--- a/binarycpython/tests/test_dicts.py
+++ b/binarycpython/tests/test_dicts.py
@@ -2,8 +2,6 @@
 Unittests for dicts module
 
 TODO: _nested_set
-TODO: _nested_get
-TODO: update_dicts
 """
 
 import os
@@ -30,7 +28,10 @@ from binarycpython.utils.dicts import (
     recursive_change_key_to_float,
     recursive_change_key_to_string,
     multiply_float_values,
-    subtract_dicts
+    subtract_dicts,
+    update_dicts,
+    _nested_get,
+    _nested_set
 )
 
 TMP_DIR = temp_dir("tests", "test_dicts")
@@ -621,24 +622,24 @@ class test_subtract_dicts(unittest.TestCase):
         Test subtract_dicts resulting in a 0 value. which should be removed
         """
 
-        dict_1 = {"a": 4, 'b': 0}
-        dict_2 = {"a": 4, 'c': 0}
+        dict_1 = {"a": 4, 'b': 0, 'd': 1.0}
+        dict_2 = {"a": 4, 'c': 0, 'd': 1}
         output_dict = subtract_dicts(dict_1, dict_2)
 
         self.assertIsInstance(output_dict, dict)
         self.assertFalse(output_dict)
 
-    def test_lists(self):
+    def test_unsupported(self):
         with Capturing() as output:
-            self._test_lists()
+            self._test_unsupported()
 
-    def _test_lists(self):
+    def _test_unsupported(self):
         """
         Test merging dict with lists
         """
 
-        dict_1 = {"list": [1, 2]}
-        dict_2 = {"list": [3, 4]}
+        dict_1 = {"list": [1, 2], 'b': [1]}
+        dict_2 = {"list": [3, 4], 'c': [1]}
 
         self.assertRaises(ValueError, subtract_dicts, dict_1, dict_2)
 
@@ -660,5 +661,97 @@ class test_subtract_dicts(unittest.TestCase):
             output_dict["dict"], {"a": -1, "b": 1, "c": -2}
         )
 
+class test_update_dicts(unittest.TestCase):
+    """
+    Unittests for function update_dicts
+    """
+
+    def test_dicts(self):
+        with Capturing() as _:
+            self._test_dicts()
+
+    def _test_dicts(self):
+        """
+        Test update_dicts with dicts
+        """
+
+        dict_1 = {"dict": {"a": 1, "b": 1}}
+        dict_2 = {"dict": {"a": 2, "c": 2}}
+        output_dict = update_dicts(dict_1, dict_2)
+
+        self.assertTrue(isinstance(output_dict["dict"], dict))
+        self.assertEqual(
+            output_dict["dict"], {"a": 2, "b": 1, "c": 2}
+        )
+
+    def test_unsupported(self):
+        with Capturing() as output:
+            self._test_unsupported()
+
+    def _test_unsupported(self):
+        """
+        Test update_dicts with unsupported types
+        """
+
+        dict_1 = {"list": 2, 'b': [1]}
+        dict_2 = {"list": [3, 4], 'c': [1]}
+
+        self.assertRaises(ValueError, update_dicts, dict_1, dict_2)
+
+
+class test__nested_get(unittest.TestCase):
+    """
+    Unittests for function _nested_get
+    """
+
+    def test__nested_get(self):
+        with Capturing() as output:
+            self._test__nested_get()
+
+    def _test__nested_get(self):
+        """
+        Test _nested_get
+        """
+
+        input_1 = {'a': {'b': 2}}
+
+        output_1 = _nested_get(input_1, ['a'])
+        output_2 = _nested_get(input_1, ['a', 'b'])
+
+        self.assertEqual(output_1, {'b': 2})
+        self.assertEqual(output_2, 2)
+
+class test__nested_set(unittest.TestCase):
+    """
+    Unittests for function _nested_set
+    """
+
+    def test__nested_set(self):
+        with Capturing() as output:
+            self._test__nested_set()
+
+    def _test__nested_set(self):
+        """
+        Test _nested_set
+        """
+
+        #
+        input_1 = {'a': 0}
+        desired_output_1 = {'a': 2}
+        _nested_set(input_1, ['a'], 2)
+        self.assertEqual(input_1, desired_output_1)
+
+        #
+        input_2 = {'a': {'b': 0}}
+        desired_output_2 = {'a': {'b': 2}}
+        _nested_set(input_2, ['a', 'b'], 2)
+        self.assertEqual(input_2, desired_output_2)
+
+        #
+        input_3 = {'a': {'b': 0}}
+        desired_output_3 = {'a': {'b': 0, 'd': {'c': 10}}}
+        _nested_set(input_3, ['a', 'd', 'c'], 10)
+        self.assertEqual(input_3, desired_output_3)
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/binarycpython/tests/test_grid.py b/binarycpython/tests/test_grid.py
index 352f89e7a..efd5a2d52 100644
--- a/binarycpython/tests/test_grid.py
+++ b/binarycpython/tests/test_grid.py
@@ -29,10 +29,15 @@ import sys
 import json
 import gzip
 import unittest
+import numpy as np
 
 from binarycpython.utils.functions import (
     temp_dir,
     Capturing,
+    remove_file
+)
+from binarycpython.utils.dicts import (
+    merge_dicts,
 )
 
 from binarycpython.utils.grid import Population
@@ -601,11 +606,11 @@ class test_resultdict(unittest.TestCase):
             name="lnm1",
             longname="Primary mass",
             valuerange=[2, 150],
-            samplerfunc="const(math.log(2), math.log(150), {})".format(
+            samplerfunc="self.const_linear(math.log(2), math.log(150), {})".format(
                 resolution["M_1"]
             ),
             precode="M_1=math.exp(lnm1)",
-            probdist="three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 150, -1.3, -2.3, -2.3)*M_1",
+            probdist="self.three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 150, -1.3, -2.3, -2.3)*M_1",
             dphasevol="dlnm1",
             parameter_name="M_1",
             condition="",  # Impose a condition on this grid variable. Mostly for a check for yourself
@@ -663,519 +668,519 @@ class test_resultdict(unittest.TestCase):
 
 
 
-# class test_grid_evolve(unittest.TestCase):
-#     """
-#     Unittests for function Population.evolve()
-#     """
-
-#     def test_grid_evolve_1_thread(self):
-#         with Capturing() as output:
-#             self._test_grid_evolve_1_thread()
-
-#     def _test_grid_evolve_1_thread(self):
-#         """
-#         Unittests to see if 1 thread does all the systems
-#         """
-
-#         test_pop_evolve_1_thread = Population()
-#         test_pop_evolve_1_thread.set(
-#             num_cores=1, M_2=1, orbital_period=100000, verbosity=TEST_VERBOSITY
-#         )
-
-#         resolution = {"M_1": 10}
-
-#         test_pop_evolve_1_thread.add_grid_variable(
-#             name="lnm1",
-#             longname="Primary mass",
-#             valuerange=[1, 100],
-#             samplerfunc="const(math.log(1), math.log(100), {})".format(
-#                 resolution["M_1"]
-#             ),
-#             precode="M_1=math.exp(lnm1)",
-#             probdist="three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1",
-#             dphasevol="dlnm1",
-#             parameter_name="M_1",
-#             condition="",  # Impose a condition on this grid variable. Mostly for a check for yourself
-#         )
-
-#         analytics = test_pop_evolve_1_thread.evolve()
-#         self.assertLess(
-#             np.abs(analytics["total_probability"] - 0.10820655287892997),
-#             1e-10,
-#             msg=analytics["total_probability"],
-#         )
-#         self.assertTrue(analytics["total_count"] == 10)
-
-#     def test_grid_evolve_2_threads(self):
-#         with Capturing() as output:
-#             self._test_grid_evolve_2_threads()
-
-#     def _test_grid_evolve_2_threads(self):
-#         """
-#         Unittests to see if multiple threads handle the all the systems correctly
-#         """
-
-#         test_pop = Population()
-#         test_pop.set(
-#             num_cores=2, M_2=1, orbital_period=100000, verbosity=TEST_VERBOSITY
-#         )
-
-#         resolution = {"M_1": 10}
-
-#         test_pop.add_grid_variable(
-#             name="lnm1",
-#             longname="Primary mass",
-#             valuerange=[1, 100],
-#             samplerfunc="const(math.log(1), math.log(100), {})".format(
-#                 resolution["M_1"]
-#             ),
-#             precode="M_1=math.exp(lnm1)",
-#             probdist="three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1",
-#             dphasevol="dlnm1",
-#             parameter_name="M_1",
-#             condition="",  # Impose a condition on this grid variable. Mostly for a check for yourself
-#         )
-
-#         analytics = test_pop.evolve()
-#         self.assertLess(
-#             np.abs(analytics["total_probability"] - 0.10820655287892997),
-#             1e-10,
-#             msg=analytics["total_probability"],
-#         )  #
-#         self.assertTrue(analytics["total_count"] == 10)
-
-#     def test_grid_evolve_2_threads_with_custom_logging(self):
-#         with Capturing() as output:
-#             self._test_grid_evolve_2_threads_with_custom_logging()
-
-#     def _test_grid_evolve_2_threads_with_custom_logging(self):
-#         """
-#         Unittests to see if multiple threads do the custom logging correctly
-#         """
-
-#         data_dir_value = os.path.join(TMP_DIR, "grid_tests")
-#         num_cores_value = 2
-#         custom_logging_string = 'Printf("MY_STELLAR_DATA_TEST_EXAMPLE %g %g %g %g\\n",((double)stardata->model.time),((double)stardata->star[0].mass),((double)stardata->model.probability),((double)stardata->model.dt));'
-
-#         test_pop = Population()
-
-#         test_pop.set(
-#             num_cores=num_cores_value,
-#             verbosity=TEST_VERBOSITY,
-#             M_2=1,
-#             orbital_period=100000,
-#             data_dir=data_dir_value,
-#             C_logging_code=custom_logging_string,  # input it like this.
-#             parse_function=parse_function_test_grid_evolve_2_threads_with_custom_logging,
-#         )
-#         test_pop.set(ensemble=0)
-#         resolution = {"M_1": 2}
-
-#         test_pop.add_grid_variable(
-#             name="lnm1",
-#             longname="Primary mass",
-#             valuerange=[1, 100],
-#             samplerfunc="const(math.log(1), math.log(100), {})".format(
-#                 resolution["M_1"]
-#             ),
-#             precode="M_1=math.exp(lnm1)",
-#             probdist="three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1",
-#             dphasevol="dlnm1",
-#             parameter_name="M_1",
-#             condition="",  # Impose a condition on this grid variable. Mostly for a check for yourself
-#         )
-
-#         analytics = test_pop.evolve()
-#         output_names = [
-#             os.path.join(
-#                 data_dir_value,
-#                 "test_grid_evolve_2_threads_with_custom_logging_outputfile_population_{}_thread_{}.dat".format(
-#                     analytics["population_name"], thread_id
-#                 ),
-#             )
-#             for thread_id in range(num_cores_value)
-#         ]
-
-#         for output_name in output_names:
-#             self.assertTrue(os.path.isfile(output_name))
-
-#             with open(output_name, "r") as f:
-#                 output_string = f.read()
-
-#             self.assertIn("MY_STELLAR_DATA_TEST_EXAMPLE", output_string)
-
-#             remove_file(output_name)
-
-#     def test_grid_evolve_with_condition_error(self):
-#         with Capturing() as output:
-#             self._test_grid_evolve_with_condition_error()
-
-#     def _test_grid_evolve_with_condition_error(self):
-#         """
-#         Unittests to see if the threads catch the errors correctly.
-#         """
-
-#         test_pop = Population()
-#         test_pop.set(
-#             num_cores=2, M_2=1, orbital_period=100000, verbosity=TEST_VERBOSITY
-#         )
-
-#         # Set the amt of failed systems that each thread will log
-#         test_pop.set(failed_systems_threshold=4)
-
-#         CUSTOM_LOGGING_STRING_WITH_EXIT = """
-# Exit_binary_c(BINARY_C_NORMAL_EXIT, "testing exits. This is part of the testing, don't worry");
-# Printf("TEST_CUSTOM_LOGGING_1 %30.12e %g %g %g %g\\n",
-#     //
-#     stardata->model.time, // 1
-
-#     // masses
-#     stardata->common.zero_age.mass[0], //
-#     stardata->common.zero_age.mass[1], //
-
-#     stardata->star[0].mass,
-#     stardata->star[1].mass
-# );
-#         """
-
-#         test_pop.set(C_logging_code=CUSTOM_LOGGING_STRING_WITH_EXIT)
-
-#         resolution = {"M_1": 10}
-#         test_pop.add_grid_variable(
-#             name="lnm1",
-#             longname="Primary mass",
-#             valuerange=[1, 100],
-#             samplerfunc="const(math.log(1), math.log(100), {})".format(
-#                 resolution["M_1"]
-#             ),
-#             precode="M_1=math.exp(lnm1)",
-#             probdist="three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1",
-#             dphasevol="dlnm1",
-#             parameter_name="M_1",
-#             condition="",  # Impose a condition on this grid variable. Mostly for a check for yourself
-#         )
-
-#         analytics = test_pop.evolve()
-#         self.assertLess(
-#             np.abs(analytics["total_probability"] - 0.10820655287892997),
-#             1e-10,
-#             msg=analytics["total_probability"],
-#         )  #
-#         self.assertEqual(analytics["failed_systems_error_codes"], [0])
-#         self.assertTrue(analytics["total_count"] == 10)
-#         self.assertTrue(analytics["failed_count"] == 10)
-#         self.assertTrue(analytics["errors_found"] == True)
-#         self.assertTrue(analytics["errors_exceeded"] == True)
-
-#         # test to see if 1 thread does all the systems
-
-#         test_pop = Population()
-#         test_pop.set(
-#             num_cores=2, M_2=1, orbital_period=100000, verbosity=TEST_VERBOSITY
-#         )
-#         test_pop.set(failed_systems_threshold=4)
-#         test_pop.set(C_logging_code=CUSTOM_LOGGING_STRING_WITH_EXIT)
-
-#         resolution = {"M_1": 10, "q": 2}
-
-#         test_pop.add_grid_variable(
-#             name="lnm1",
-#             longname="Primary mass",
-#             valuerange=[1, 100],
-#             samplerfunc="const(math.log(1), math.log(100), {})".format(
-#                 resolution["M_1"]
-#             ),
-#             precode="M_1=math.exp(lnm1)",
-#             probdist="three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1",
-#             dphasevol="dlnm1",
-#             parameter_name="M_1",
-#             condition="",  # Impose a condition on this grid variable. Mostly for a check for yourself
-#         )
-
-#         test_pop.add_grid_variable(
-#             name="q",
-#             longname="Mass ratio",
-#             valuerange=["0.1/M_1", 1],
-#             samplerfunc="const(0.1/M_1, 1, {})".format(resolution["q"]),
-#             probdist="flatsections(q, [{'min': 0.1/M_1, 'max': 1.0, 'height': 1}])",
-#             dphasevol="dq",
-#             precode="M_2 = q * M_1",
-#             parameter_name="M_2",
-#             # condition="M_1 in dir()",  # Impose a condition on this grid variable. Mostly for a check for yourself
-#             condition="'random_var' in dir()",  # This will raise an error because random_var is not defined.
-#         )
-
-#         # TODO: why should it raise this error? It should probably raise a valueerror when the limit is exceeded right?
-#         # DEcided to turn it off for now because there is not raise VAlueError in that chain of functions.
-#         # NOTE: Found out why this test was here. It is to do with the condition random_var in dir(), but I changed the behaviour from raising an error to continue. This has to do with the moe&distefano code that will loop over several multiplicities
-#         # TODO: make sure the continue behaviour is what we actually want.
-
-#         # self.assertRaises(ValueError, test_pop.evolve)
-
-#     def test_grid_evolve_no_grid_variables(self):
-#         with Capturing() as output:
-#             self._test_grid_evolve_no_grid_variables()
-
-#     def _test_grid_evolve_no_grid_variables(self):
-#         """
-#         Unittests to see if errors are raised if there are no grid variables
-#         """
-
-#         test_pop = Population()
-#         test_pop.set(
-#             num_cores=1, M_2=1, orbital_period=100000, verbosity=TEST_VERBOSITY
-#         )
-
-#         resolution = {"M_1": 10}
-#         self.assertRaises(ValueError, test_pop.evolve)
-
-#     def test_grid_evolve_2_threads_with_ensemble_direct_output(self):
-#         with Capturing() as output:
-#             self._test_grid_evolve_2_threads_with_ensemble_direct_output()
-
-#     def _test_grid_evolve_2_threads_with_ensemble_direct_output(self):
-#         """
-#         Unittests to see if multiple threads output the ensemble information to files correctly
-#         """
-
-#         data_dir_value = TMP_DIR
-#         num_cores_value = 2
-
-#         test_pop = Population()
-#         test_pop.set(
-#             num_cores=num_cores_value,
-#             verbosity=TEST_VERBOSITY,
-#             M_2=1,
-#             orbital_period=100000,
-#             ensemble=1,
-#             ensemble_defer=1,
-#             ensemble_filters_off=1,
-#             ensemble_filter_STELLAR_TYPE_COUNTS=1,
-#             ensemble_dt=1000,
-#         )
-#         test_pop.set(
-#             data_dir=TMP_DIR,
-#             ensemble_output_name="ensemble_output.json",
-#             combine_ensemble_with_thread_joining=False,
-#         )
-
-#         resolution = {"M_1": 10}
-
-#         test_pop.add_grid_variable(
-#             name="lnm1",
-#             longname="Primary mass",
-#             valuerange=[1, 100],
-#             samplerfunc="const(math.log(1), math.log(100), {})".format(
-#                 resolution["M_1"]
-#             ),
-#             precode="M_1=math.exp(lnm1)",
-#             probdist="three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1",
-#             dphasevol="dlnm1",
-#             parameter_name="M_1",
-#             condition="",  # Impose a condition on this grid variable. Mostly for a check for yourself
-#         )
-
-#         analytics = test_pop.evolve()
-#         output_names = [
-#             os.path.join(
-#                 data_dir_value,
-#                 "ensemble_output_{}_{}.json".format(
-#                     analytics["population_name"], thread_id
-#                 ),
-#             )
-#             for thread_id in range(num_cores_value)
-#         ]
-
-#         for output_name in output_names:
-#             self.assertTrue(os.path.isfile(output_name))
-
-#             with open(output_name, "r") as f:
-#                 file_content = f.read()
-
-#                 ensemble_json = json.loads(file_content)
-
-#                 self.assertTrue(isinstance(ensemble_json, dict))
-#                 self.assertNotEqual(ensemble_json, {})
-
-#                 self.assertIn("number_counts", ensemble_json)
-#                 self.assertNotEqual(ensemble_json["number_counts"], {})
-
-#     def test_grid_evolve_2_threads_with_ensemble_combining(self):
-#         with Capturing() as output:
-#             self._test_grid_evolve_2_threads_with_ensemble_combining()
-
-#     def _test_grid_evolve_2_threads_with_ensemble_combining(self):
-#         """
-#         Unittests to see if multiple threads correclty combine the ensemble data and store them in the grid
-#         """
-
-#         data_dir_value = TMP_DIR
-#         num_cores_value = 2
-
-#         test_pop = Population()
-#         test_pop.set(
-#             num_cores=num_cores_value,
-#             verbosity=TEST_VERBOSITY,
-#             M_2=1,
-#             orbital_period=100000,
-#             ensemble=1,
-#             ensemble_defer=1,
-#             ensemble_filters_off=1,
-#             ensemble_filter_STELLAR_TYPE_COUNTS=1,
-#             ensemble_dt=1000,
-#         )
-#         test_pop.set(
-#             data_dir=TMP_DIR,
-#             combine_ensemble_with_thread_joining=True,
-#             ensemble_output_name="ensemble_output.json",
-#         )
-
-#         resolution = {"M_1": 10}
-
-#         test_pop.add_grid_variable(
-#             name="lnm1",
-#             longname="Primary mass",
-#             valuerange=[1, 100],
-#             samplerfunc="const(math.log(1), math.log(100), {})".format(
-#                 resolution["M_1"]
-#             ),
-#             precode="M_1=math.exp(lnm1)",
-#             probdist="three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1",
-#             dphasevol="dlnm1",
-#             parameter_name="M_1",
-#             condition="",  # Impose a condition on this grid variable. Mostly for a check for yourself
-#         )
-
-#         analytics = test_pop.evolve()
-
-#         self.assertTrue(isinstance(test_pop.grid_ensemble_results["ensemble"], dict))
-#         self.assertNotEqual(test_pop.grid_ensemble_results["ensemble"], {})
-
-#         self.assertIn("number_counts", test_pop.grid_ensemble_results["ensemble"])
-#         self.assertNotEqual(
-#             test_pop.grid_ensemble_results["ensemble"]["number_counts"], {}
-#         )
-
-#     def test_grid_evolve_2_threads_with_ensemble_comparing_two_methods(self):
-#         with Capturing() as output:
-#             self._test_grid_evolve_2_threads_with_ensemble_comparing_two_methods()
-
-#     def _test_grid_evolve_2_threads_with_ensemble_comparing_two_methods(self):
-#         """
-#         Unittests to compare the method of storing the combined ensemble data in the object and writing them to files and combining them later. they have to be the same
-#         """
-
-#         data_dir_value = TMP_DIR
-#         num_cores_value = 2
-
-#         # First
-#         test_pop_1 = Population()
-#         test_pop_1.set(
-#             num_cores=num_cores_value,
-#             verbosity=TEST_VERBOSITY,
-#             M_2=1,
-#             orbital_period=100000,
-#             ensemble=1,
-#             ensemble_defer=1,
-#             ensemble_filters_off=1,
-#             ensemble_filter_STELLAR_TYPE_COUNTS=1,
-#             ensemble_dt=1000,
-#         )
-#         test_pop_1.set(
-#             data_dir=TMP_DIR,
-#             combine_ensemble_with_thread_joining=True,
-#             ensemble_output_name="ensemble_output.json",
-#         )
-
-#         resolution = {"M_1": 10}
-
-#         test_pop_1.add_grid_variable(
-#             name="lnm1",
-#             longname="Primary mass",
-#             valuerange=[1, 100],
-#             samplerfunc="const(math.log(1), math.log(100), {})".format(
-#                 resolution["M_1"]
-#             ),
-#             precode="M_1=math.exp(lnm1)",
-#             probdist="three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1",
-#             dphasevol="dlnm1",
-#             parameter_name="M_1",
-#             condition="",  # Impose a condition on this grid variable. Mostly for a check for yourself
-#         )
-
-#         analytics_1 = test_pop_1.evolve()
-#         ensemble_output_1 = test_pop_1.grid_ensemble_results
-
-#         # second
-#         test_pop_2 = Population()
-#         test_pop_2.set(
-#             num_cores=num_cores_value,
-#             verbosity=TEST_VERBOSITY,
-#             M_2=1,
-#             orbital_period=100000,
-#             ensemble=1,
-#             ensemble_defer=1,
-#             ensemble_filters_off=1,
-#             ensemble_filter_STELLAR_TYPE_COUNTS=1,
-#             ensemble_dt=1000,
-#         )
-#         test_pop_2.set(
-#             data_dir=TMP_DIR,
-#             ensemble_output_name="ensemble_output.json",
-#             combine_ensemble_with_thread_joining=False,
-#         )
-
-#         resolution = {"M_1": 10}
-
-#         test_pop_2.add_grid_variable(
-#             name="lnm1",
-#             longname="Primary mass",
-#             valuerange=[1, 100],
-#             samplerfunc="const(math.log(1), math.log(100), {})".format(
-#                 resolution["M_1"]
-#             ),
-#             precode="M_1=math.exp(lnm1)",
-#             probdist="three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1",
-#             dphasevol="dlnm1",
-#             parameter_name="M_1",
-#             condition="",  # Impose a condition on this grid variable. Mostly for a check for yourself
-#         )
-
-#         analytics_2 = test_pop_2.evolve()
-#         output_names_2 = [
-#             os.path.join(
-#                 data_dir_value,
-#                 "ensemble_output_{}_{}.json".format(
-#                     analytics_2["population_name"], thread_id
-#                 ),
-#             )
-#             for thread_id in range(num_cores_value)
-#         ]
-#         ensemble_output_2 = {}
-
-#         for output_name in output_names_2:
-#             self.assertTrue(os.path.isfile(output_name))
-
-#             with open(output_name, "r") as f:
-#                 file_content = f.read()
-
-#                 ensemble_json = json.loads(file_content)
-
-#                 ensemble_output_2 = merge_dicts(ensemble_output_2, ensemble_json)
-
-#         for key in ensemble_output_1["ensemble"]["number_counts"]["stellar_type"]["0"]:
-#             self.assertIn(key, ensemble_output_2["number_counts"]["stellar_type"]["0"])
-
-#             # compare values
-#             self.assertLess(
-#                 np.abs(
-#                     ensemble_output_1["ensemble"]["number_counts"]["stellar_type"]["0"][
-#                         key
-#                     ]
-#                     - ensemble_output_2["number_counts"]["stellar_type"]["0"][key]
-#                 ),
-#                 1e-8,
-#             )
+class test_grid_evolve(unittest.TestCase):
+    """
+    Unittests for function Population.evolve()
+    """
+
+    def test_grid_evolve_1_thread(self):
+        with Capturing() as output:
+            self._test_grid_evolve_1_thread()
+
+    def _test_grid_evolve_1_thread(self):
+        """
+        Unittests to see if 1 thread does all the systems
+        """
+
+        test_pop_evolve_1_thread = Population()
+        test_pop_evolve_1_thread.set(
+            num_cores=1, M_2=1, orbital_period=100000, verbosity=TEST_VERBOSITY
+        )
+
+        resolution = {"M_1": 10}
+
+        test_pop_evolve_1_thread.add_grid_variable(
+            name="lnm1",
+            longname="Primary mass",
+            valuerange=[1, 100],
+            samplerfunc="self.const_linear(math.log(1), math.log(100), {})".format(
+                resolution["M_1"]
+            ),
+            precode="M_1=math.exp(lnm1)",
+            probdist="self.three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1",
+            dphasevol="dlnm1",
+            parameter_name="M_1",
+            condition="",  # Impose a condition on this grid variable. Mostly for a check for yourself
+        )
+
+        analytics = test_pop_evolve_1_thread.evolve()
+        self.assertLess(
+            np.abs(analytics["total_probability"] - 0.10820655287892997),
+            1e-10,
+            msg=analytics["total_probability"],
+        )
+        self.assertTrue(analytics["total_count"] == 10)
+
+    def test_grid_evolve_2_threads(self):
+        with Capturing() as output:
+            self._test_grid_evolve_2_threads()
+
+    def _test_grid_evolve_2_threads(self):
+        """
+        Unittests to see if multiple threads handle the all the systems correctly
+        """
+
+        test_pop = Population()
+        test_pop.set(
+            num_cores=2, M_2=1, orbital_period=100000, verbosity=TEST_VERBOSITY
+        )
+
+        resolution = {"M_1": 10}
+
+        test_pop.add_grid_variable(
+            name="lnm1",
+            longname="Primary mass",
+            valuerange=[1, 100],
+            samplerfunc="self.const_linear(math.log(1), math.log(100), {})".format(
+                resolution["M_1"]
+            ),
+            precode="M_1=math.exp(lnm1)",
+            probdist="self.three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1",
+            dphasevol="dlnm1",
+            parameter_name="M_1",
+            condition="",  # Impose a condition on this grid variable. Mostly for a check for yourself
+        )
+
+        analytics = test_pop.evolve()
+        self.assertLess(
+            np.abs(analytics["total_probability"] - 0.10820655287892997),
+            1e-10,
+            msg=analytics["total_probability"],
+        )  #
+        self.assertTrue(analytics["total_count"] == 10)
+
+    def test_grid_evolve_2_threads_with_custom_logging(self):
+        with Capturing() as output:
+            self._test_grid_evolve_2_threads_with_custom_logging()
+
+    def _test_grid_evolve_2_threads_with_custom_logging(self):
+        """
+        Unittests to see if multiple threads do the custom logging correctly
+        """
+
+        data_dir_value = os.path.join(TMP_DIR, "grid_tests")
+        num_cores_value = 2
+        custom_logging_string = 'Printf("MY_STELLAR_DATA_TEST_EXAMPLE %g %g %g %g\\n",((double)stardata->model.time),((double)stardata->star[0].mass),((double)stardata->model.probability),((double)stardata->model.dt));'
+
+        test_pop = Population()
+
+        test_pop.set(
+            num_cores=num_cores_value,
+            verbosity=TEST_VERBOSITY,
+            M_2=1,
+            orbital_period=100000,
+            data_dir=data_dir_value,
+            C_logging_code=custom_logging_string,  # input it like this.
+            parse_function=parse_function_test_grid_evolve_2_threads_with_custom_logging,
+        )
+        test_pop.set(ensemble=0)
+        resolution = {"M_1": 2}
+
+        test_pop.add_grid_variable(
+            name="lnm1",
+            longname="Primary mass",
+            valuerange=[1, 100],
+            samplerfunc="self.const_linear(math.log(1), math.log(100), {})".format(
+                resolution["M_1"]
+            ),
+            precode="M_1=math.exp(lnm1)",
+            probdist="self.three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1",
+            dphasevol="dlnm1",
+            parameter_name="M_1",
+            condition="",  # Impose a condition on this grid variable. Mostly for a check for yourself
+        )
+
+        analytics = test_pop.evolve()
+        output_names = [
+            os.path.join(
+                data_dir_value,
+                "test_grid_evolve_2_threads_with_custom_logging_outputfile_population_{}_thread_{}.dat".format(
+                    analytics["population_id"], thread_id
+                ),
+            )
+            for thread_id in range(num_cores_value)
+        ]
+
+        for output_name in output_names:
+            self.assertTrue(os.path.isfile(output_name))
+
+            with open(output_name, "r") as f:
+                output_string = f.read()
+
+            self.assertIn("MY_STELLAR_DATA_TEST_EXAMPLE", output_string)
+
+            remove_file(output_name)
+
+    def test_grid_evolve_with_condition_error(self):
+        with Capturing() as output:
+            self._test_grid_evolve_with_condition_error()
+
+    def _test_grid_evolve_with_condition_error(self):
+        """
+        Unittests to see if the threads catch the errors correctly.
+        """
+
+        test_pop = Population()
+        test_pop.set(
+            num_cores=2, M_2=1, orbital_period=100000, verbosity=TEST_VERBOSITY
+        )
+
+        # Set the amt of failed systems that each thread will log
+        test_pop.set(failed_systems_threshold=4)
+
+        CUSTOM_LOGGING_STRING_WITH_EXIT = """
+Exit_binary_c(BINARY_C_NORMAL_EXIT, "testing exits. This is part of the testing, don't worry");
+Printf("TEST_CUSTOM_LOGGING_1 %30.12e %g %g %g %g\\n",
+    //
+    stardata->model.time, // 1
+
+    // masses
+    stardata->common.zero_age.mass[0], //
+    stardata->common.zero_age.mass[1], //
+
+    stardata->star[0].mass,
+    stardata->star[1].mass
+);
+        """
+
+        test_pop.set(C_logging_code=CUSTOM_LOGGING_STRING_WITH_EXIT)
+
+        resolution = {"M_1": 10}
+        test_pop.add_grid_variable(
+            name="lnm1",
+            longname="Primary mass",
+            valuerange=[1, 100],
+            samplerfunc="self.const_linear(math.log(1), math.log(100), {})".format(
+                resolution["M_1"]
+            ),
+            precode="M_1=math.exp(lnm1)",
+            probdist="self.three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1",
+            dphasevol="dlnm1",
+            parameter_name="M_1",
+            condition="",  # Impose a condition on this grid variable. Mostly for a check for yourself
+        )
+
+        analytics = test_pop.evolve()
+        self.assertLess(
+            np.abs(analytics["total_probability"] - 0.10820655287892997),
+            1e-10,
+            msg=analytics["total_probability"],
+        )  #
+        self.assertEqual(analytics["failed_systems_error_codes"], [0])
+        self.assertTrue(analytics["total_count"] == 10)
+        self.assertTrue(analytics["failed_count"] == 10)
+        self.assertTrue(analytics["errors_found"] == True)
+        self.assertTrue(analytics["errors_exceeded"] == True)
+
+        # test to see if 1 thread does all the systems
+
+        test_pop = Population()
+        test_pop.set(
+            num_cores=2, M_2=1, orbital_period=100000, verbosity=TEST_VERBOSITY
+        )
+        test_pop.set(failed_systems_threshold=4)
+        test_pop.set(C_logging_code=CUSTOM_LOGGING_STRING_WITH_EXIT)
+
+        resolution = {"M_1": 10, "q": 2}
+
+        test_pop.add_grid_variable(
+            name="lnm1",
+            longname="Primary mass",
+            valuerange=[1, 100],
+            samplerfunc="self.const_linear(math.log(1), math.log(100), {})".format(
+                resolution["M_1"]
+            ),
+            precode="M_1=math.exp(lnm1)",
+            probdist="self.three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1",
+            dphasevol="dlnm1",
+            parameter_name="M_1",
+            condition="",  # Impose a condition on this grid variable. Mostly for a check for yourself
+        )
+
+        test_pop.add_grid_variable(
+            name="q",
+            longname="Mass ratio",
+            valuerange=["0.1/M_1", 1],
+            samplerfunc="self.const_linear(0.1/M_1, 1, {})".format(resolution["q"]),
+            probdist="self.flatsections(q, [{'min': 0.1/M_1, 'max': 1.0, 'height': 1}])",
+            dphasevol="dq",
+            precode="M_2 = q * M_1",
+            parameter_name="M_2",
+            # condition="M_1 in dir()",  # Impose a condition on this grid variable. Mostly for a check for yourself
+            condition="'random_var' in dir()",  # This will raise an error because random_var is not defined.
+        )
+
+        # TODO: why should it raise this error? It should probably raise a valueerror when the limit is exceeded right?
+        # DEcided to turn it off for now because there is not raise VAlueError in that chain of functions.
+        # NOTE: Found out why this test was here. It is to do with the condition random_var in dir(), but I changed the behaviour from raising an error to continue. This has to do with the moe&distefano code that will loop over several multiplicities
+        # TODO: make sure the continue behaviour is what we actually want.
+
+        # self.assertRaises(ValueError, test_pop.evolve)
+
+    def test_grid_evolve_no_grid_variables(self):
+        with Capturing() as output:
+            self._test_grid_evolve_no_grid_variables()
+
+    def _test_grid_evolve_no_grid_variables(self):
+        """
+        Unittests to see if errors are raised if there are no grid variables
+        """
+
+        test_pop = Population()
+        test_pop.set(
+            num_cores=1, M_2=1, orbital_period=100000, verbosity=TEST_VERBOSITY
+        )
+
+        resolution = {"M_1": 10}
+        self.assertRaises(ValueError, test_pop.evolve)
+
+    def test_grid_evolve_2_threads_with_ensemble_direct_output(self):
+        with Capturing() as output:
+            self._test_grid_evolve_2_threads_with_ensemble_direct_output()
+
+    def _test_grid_evolve_2_threads_with_ensemble_direct_output(self):
+        """
+        Unittests to see if multiple threads output the ensemble information to files correctly
+        """
+
+        data_dir_value = TMP_DIR
+        num_cores_value = 2
+
+        test_pop = Population()
+        test_pop.set(
+            num_cores=num_cores_value,
+            verbosity=TEST_VERBOSITY,
+            M_2=1,
+            orbital_period=100000,
+            ensemble=1,
+            ensemble_defer=1,
+            ensemble_filters_off=1,
+            ensemble_filter_STELLAR_TYPE_COUNTS=1,
+            ensemble_dt=1000,
+        )
+        test_pop.set(
+            data_dir=TMP_DIR,
+            ensemble_output_name="ensemble_output.json",
+            combine_ensemble_with_thread_joining=False,
+        )
+
+        resolution = {"M_1": 10}
+
+        test_pop.add_grid_variable(
+            name="lnm1",
+            longname="Primary mass",
+            valuerange=[1, 100],
+            samplerfunc="self.const_linear(math.log(1), math.log(100), {})".format(
+                resolution["M_1"]
+            ),
+            precode="M_1=math.exp(lnm1)",
+            probdist="self.three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1",
+            dphasevol="dlnm1",
+            parameter_name="M_1",
+            condition="",  # Impose a condition on this grid variable. Mostly for a check for yourself
+        )
+
+        analytics = test_pop.evolve()
+        output_names = [
+            os.path.join(
+                data_dir_value,
+                "ensemble_output_{}_{}.json".format(
+                    analytics["population_id"], thread_id
+                ),
+            )
+            for thread_id in range(num_cores_value)
+        ]
+
+        for output_name in output_names:
+            self.assertTrue(os.path.isfile(output_name))
+
+            with open(output_name, "r") as f:
+                file_content = f.read()
+
+                ensemble_json = json.loads(file_content)
+
+                self.assertTrue(isinstance(ensemble_json, dict))
+                self.assertNotEqual(ensemble_json, {})
+
+                self.assertIn("number_counts", ensemble_json)
+                self.assertNotEqual(ensemble_json["number_counts"], {})
+
+    def test_grid_evolve_2_threads_with_ensemble_combining(self):
+        with Capturing() as output:
+            self._test_grid_evolve_2_threads_with_ensemble_combining()
+
+    def _test_grid_evolve_2_threads_with_ensemble_combining(self):
+        """
+        Unittests to see if multiple threads correclty combine the ensemble data and store them in the grid
+        """
+
+        data_dir_value = TMP_DIR
+        num_cores_value = 2
+
+        test_pop = Population()
+        test_pop.set(
+            num_cores=num_cores_value,
+            verbosity=TEST_VERBOSITY,
+            M_2=1,
+            orbital_period=100000,
+            ensemble=1,
+            ensemble_defer=1,
+            ensemble_filters_off=1,
+            ensemble_filter_STELLAR_TYPE_COUNTS=1,
+            ensemble_dt=1000,
+        )
+        test_pop.set(
+            data_dir=TMP_DIR,
+            combine_ensemble_with_thread_joining=True,
+            ensemble_output_name="ensemble_output.json",
+        )
+
+        resolution = {"M_1": 10}
+
+        test_pop.add_grid_variable(
+            name="lnm1",
+            longname="Primary mass",
+            valuerange=[1, 100],
+            samplerfunc="self.const_linear(math.log(1), math.log(100), {})".format(
+                resolution["M_1"]
+            ),
+            precode="M_1=math.exp(lnm1)",
+            probdist="self.three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1",
+            dphasevol="dlnm1",
+            parameter_name="M_1",
+            condition="",  # Impose a condition on this grid variable. Mostly for a check for yourself
+        )
+
+        analytics = test_pop.evolve()
+
+        self.assertTrue(isinstance(test_pop.grid_ensemble_results["ensemble"], dict))
+        self.assertNotEqual(test_pop.grid_ensemble_results["ensemble"], {})
+
+        self.assertIn("number_counts", test_pop.grid_ensemble_results["ensemble"])
+        self.assertNotEqual(
+            test_pop.grid_ensemble_results["ensemble"]["number_counts"], {}
+        )
+
+    def test_grid_evolve_2_threads_with_ensemble_comparing_two_methods(self):
+        with Capturing() as output:
+            self._test_grid_evolve_2_threads_with_ensemble_comparing_two_methods()
+
+    def _test_grid_evolve_2_threads_with_ensemble_comparing_two_methods(self):
+        """
+        Unittests to compare the method of storing the combined ensemble data in the object and writing them to files and combining them later. they have to be the same
+        """
+
+        data_dir_value = TMP_DIR
+        num_cores_value = 2
+
+        # First
+        test_pop_1 = Population()
+        test_pop_1.set(
+            num_cores=num_cores_value,
+            verbosity=TEST_VERBOSITY,
+            M_2=1,
+            orbital_period=100000,
+            ensemble=1,
+            ensemble_defer=1,
+            ensemble_filters_off=1,
+            ensemble_filter_STELLAR_TYPE_COUNTS=1,
+            ensemble_dt=1000,
+        )
+        test_pop_1.set(
+            data_dir=TMP_DIR,
+            combine_ensemble_with_thread_joining=True,
+            ensemble_output_name="ensemble_output.json",
+        )
+
+        resolution = {"M_1": 10}
+
+        test_pop_1.add_grid_variable(
+            name="lnm1",
+            longname="Primary mass",
+            valuerange=[1, 100],
+            samplerfunc="self.const_linear(math.log(1), math.log(100), {})".format(
+                resolution["M_1"]
+            ),
+            precode="M_1=math.exp(lnm1)",
+            probdist="self.three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1",
+            dphasevol="dlnm1",
+            parameter_name="M_1",
+            condition="",  # Impose a condition on this grid variable. Mostly for a check for yourself
+        )
+
+        analytics_1 = test_pop_1.evolve()
+        ensemble_output_1 = test_pop_1.grid_ensemble_results
+
+        # second
+        test_pop_2 = Population()
+        test_pop_2.set(
+            num_cores=num_cores_value,
+            verbosity=TEST_VERBOSITY,
+            M_2=1,
+            orbital_period=100000,
+            ensemble=1,
+            ensemble_defer=1,
+            ensemble_filters_off=1,
+            ensemble_filter_STELLAR_TYPE_COUNTS=1,
+            ensemble_dt=1000,
+        )
+        test_pop_2.set(
+            data_dir=TMP_DIR,
+            ensemble_output_name="ensemble_output.json",
+            combine_ensemble_with_thread_joining=False,
+        )
+
+        resolution = {"M_1": 10}
+
+        test_pop_2.add_grid_variable(
+            name="lnm1",
+            longname="Primary mass",
+            valuerange=[1, 100],
+            samplerfunc="self.const_linear(math.log(1), math.log(100), {})".format(
+                resolution["M_1"]
+            ),
+            precode="M_1=math.exp(lnm1)",
+            probdist="self.three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1",
+            dphasevol="dlnm1",
+            parameter_name="M_1",
+            condition="",  # Impose a condition on this grid variable. Mostly for a check for yourself
+        )
+
+        analytics_2 = test_pop_2.evolve()
+        output_names_2 = [
+            os.path.join(
+                data_dir_value,
+                "ensemble_output_{}_{}.json".format(
+                    analytics_2["population_id"], thread_id
+                ),
+            )
+            for thread_id in range(num_cores_value)
+        ]
+        ensemble_output_2 = {}
+
+        for output_name in output_names_2:
+            self.assertTrue(os.path.isfile(output_name))
+
+            with open(output_name, "r") as f:
+                file_content = f.read()
+
+                ensemble_json = json.loads(file_content)
+
+                ensemble_output_2 = merge_dicts(ensemble_output_2, ensemble_json)
+
+        for key in ensemble_output_1["ensemble"]["number_counts"]["stellar_type"]["0"]:
+            self.assertIn(key, ensemble_output_2["number_counts"]["stellar_type"]["0"])
+
+            # compare values
+            self.assertLess(
+                np.abs(
+                    ensemble_output_1["ensemble"]["number_counts"]["stellar_type"]["0"][
+                        key
+                    ]
+                    - ensemble_output_2["number_counts"]["stellar_type"]["0"][key]
+                ),
+                1e-8,
+            )
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/binarycpython/utils/dicts.py b/binarycpython/utils/dicts.py
index eee25c39e..611af1606 100644
--- a/binarycpython/utils/dicts.py
+++ b/binarycpython/utils/dicts.py
@@ -9,8 +9,9 @@ import astropy.units as u
 import numpy as np
 
 # Define all numerical types
-ALLOWED_NUMERICAL_TYPES = Union[int, float, complex, np.number]
 
+ALLOWED_NUMERICAL_TYPES = (int, float, complex, np.number)
+UNION_ALLOWED_NUMERICAL_TYPES = Union[int, float, complex, np.number]
 
 def keys_to_floats(input_dict: dict) -> dict:
     """
@@ -483,8 +484,8 @@ def merge_dicts(dict_1: dict, dict_2: dict) -> dict:
         # If they keys are not the same, it depends on their type whether we still deal with them at all, or just raise an error
         if not type(dict_1[key]) is type(dict_2[key]):
             # Exceptions: numbers can be added
-            if isinstance(dict_1[key], (int, float, np.float64)) and isinstance(
-                dict_2[key], (int, float, np.float64)
+            if isinstance(dict_1[key], ALLOWED_NUMERICAL_TYPES) and isinstance(
+                dict_2[key], ALLOWED_NUMERICAL_TYPES
             ):
                 new_dict[key] = dict_1[key] + dict_2[key]
 
@@ -636,9 +637,7 @@ def update_dicts(dict_1: dict, dict_2: dict) -> dict:
         # See whether the types are actually the same
         if not type(dict_1[key]) is type(dict_2[key]):
             # Exceptions:
-            if (type(dict_1[key]) in [int, float]) and (
-                type(dict_2[key]) in [int, float]
-            ):
+            if isinstance(dict_1[key], ALLOWED_NUMERICAL_TYPES) and isinstance(dict_2[key], ALLOWED_NUMERICAL_TYPES):
                 new_dict[key] = dict_2[key]
 
             else:
@@ -667,7 +666,7 @@ def update_dicts(dict_1: dict, dict_2: dict) -> dict:
     return new_dict
 
 
-def multiply_values_dict(input_dict: dict, factor: ALLOWED_NUMERICAL_TYPES):
+def multiply_values_dict(input_dict: dict, factor: UNION_ALLOWED_NUMERICAL_TYPES):
     """
     Function that goes over dictionary recursively and multiplies the value if possible by a factor
 
diff --git a/binarycpython/utils/grid.py b/binarycpython/utils/grid.py
index 2f9adb08d..df3340849 100644
--- a/binarycpython/utils/grid.py
+++ b/binarycpython/utils/grid.py
@@ -1538,13 +1538,13 @@ class Population(
             ######################
             # Print status of runs
             # save the current time (used often)
-            now = time.time()
+            time_now = time.time()
 
             # update memory use stats every log_dt seconds (not every time, this is likely a bit expensive)
-            if now > next_mem_update_time:
+            if time_now > next_mem_update_time:
                 m = mem_use()
                 self.shared_memory["memory_use_per_thread"][ID] = m
-                next_mem_update_time = now + self.grid_options["log_dt"]
+                next_mem_update_time = time_now + self.grid_options["log_dt"]
                 if m > self.shared_memory["max_memory_use_per_thread"][ID]:
                     self.shared_memory["max_memory_use_per_thread"][ID] = m
 
@@ -1555,16 +1555,16 @@ class Population(
 
             # Check if we need to log info again
             # TODO: Check if we can put this functionality elsewhere
-            if now > next_log_time:
+            if time_now > next_log_time:
                 # we have exceeded the next log time : output and update timers
                 # Lock the threads. TODO: Do we need to release this?
                 lock = multiprocessing.Lock()
 
                 # Do the printing itself
-                self.vb1print(ID, now, system_number, system_dict)
+                self.vb1print(ID, time_now, system_number, system_dict)
 
                 # Set some values for next time
-                next_log_time = now + self.grid_options["log_dt"]
+                next_log_time = time_now + self.grid_options["log_dt"]
 
                 # print("PREV ",self.shared_memory["prev_log_time"])
                 # print("N LOG STATS",self.shared_memory["n_saved_log_stats"].value)
@@ -1582,7 +1582,7 @@ class Population(
                 ]
 
                 # set the current time and system number
-                self.shared_memory["prev_log_time"][0] = now
+                self.shared_memory["prev_log_time"][0] = time_now
                 self.shared_memory["prev_log_system_number"][0] = system_number
 
                 # increase the number of stats
diff --git a/binarycpython/utils/population_extensions/gridcode.py b/binarycpython/utils/population_extensions/gridcode.py
index 93f9cc8a4..cfbc301cd 100644
--- a/binarycpython/utils/population_extensions/gridcode.py
+++ b/binarycpython/utils/population_extensions/gridcode.py
@@ -107,8 +107,6 @@ class gridcode:
             "import math\n",
             "import numpy as np\n",
             "from collections import OrderedDict\n",
-            "from binarycpython.utils.distribution_functions import *\n",
-            "from binarycpython.utils.spacing_functions import *\n",
             "from binarycpython.utils.useful_funcs import *\n",
             "import numba" if _numba else "",
             "\n\n",
@@ -1002,7 +1000,7 @@ class gridcode:
                 This is evaluated as a parameter and you can use it throughout
                 the rest of the function
 
-                Examples:
+                Examples::
                     name = 'lnM_1'
 
             parameter_name:
@@ -1020,45 +1018,54 @@ class gridcode:
 
                 Examples:
                     longname = 'Primary mass'
+
             range:
                 Range of values to take. Does not get used really, the samplerfunc is used to
                 get the values from
 
-                Examples:
+                Examples::
                     range = [math.log(m_min), math.log(m_max)]
+
             samplerfunc:
                 Function returning a list or numpy array of samples spaced appropriately.
                 You can either use a real function, or a string representation of a function call.
 
-                Examples:
+                Examples::
                     samplerfunc = "self.const_linear(math.log(m_min), math.log(m_max), {})".format(resolution['M_1'])
 
             precode:
                 Extra room for some code. This code will be evaluated within the loop of the
                 sampling function (i.e. a value for lnM_1 is chosen already)
 
-                Examples:
+                Examples::
                     precode = 'M_1=math.exp(lnM_1);'
+
             postcode:
                 Code executed after the probability is calculated.
+
             probdist:
                 Function determining the probability that gets assigned to the sampled parameter
 
                 Examples:
-                    probdist = 'Kroupa2001(M_1)*M_1'
+                    probdist = 'self.Kroupa2001(M_1)*M_1'
+
             dphasevol:
                 part of the parameter space that the total probability is calculated with. Put to -1
                 if you want to ignore any dphasevol calculations and set the value to 1
-                Examples:
+
+                Examples::"
                     dphasevol = 'dlnM_1'
+
             condition:
                 condition that has to be met in order for the grid generation to continue
-                Examples:
+
+                Examples::
                     condition = 'self.grid_options['binary']==1'
+
             gridtype:
                 Method on how the value range is sampled. Can be either 'edge' (steps starting at
                 the lower edge of the value range) or 'centred'
-                (steps starting at lower edge + 0.5 * stepsize).
+                (steps starting at ``lower edge + 0.5 * stepsize``).
 
             dry_parallel:
                 If True, try to parallelize this variable in dry runs.
-- 
GitLab