"""
Unit tests for the grid module

TODO: jobID
TODO: exit
TODO: _set_nprocesses
TODO: _pre_run_setup
TODO: clean
TODO: evolve
TODO: _evolve_population
TODO: _system_queue_filler
TODO: _evolve_population_grid
TODO: _evolve_system_mp
TODO: _parent_signal_handler
TODO: _child_signal_handler
TODO: _process_run_population_grid
TODO: _cleanup
TODO: _dry_run
TODO: _dry_run_source_file
TODO: _load_source_file
TODO: was_killed
TODO: _check_binary_c_error

TODO: Make sure the unit tests run (and pass) before the non-unit tests that cover functions like evolve
"""

import os
import sys
import json
import unittest

from binarycpython.utils.functions import (
    temp_dir,
    Capturing,
    bin_data,
)

from binarycpython.utils.grid import Population

# Scratch directory for files written by these tests, obtained from
# binarycpython's temp_dir helper
TMP_DIR = temp_dir("tests", "test_grid")
# Verbosity level passed to Population.set in most of the tests below
TEST_VERBOSITY = 1


def parse_function_test_grid_evolve_2_threads_with_custom_logging(self, output):
    """
    Simple parse function that appends all the output of a system to a file
    that is unique per population and per process.

    Args:
        self: the Population instance calling this parse function. Its
            custom_options["data_dir"], grid_options["_population_id"] and
            process_ID are used to build the output filename.
        output: raw output string of a single binary_c run.
    """

    # Directory in which the output file is stored
    data_dir = self.custom_options["data_dir"]

    # Output filename, unique per population id and per process id
    output_filename = os.path.join(
        data_dir,
        "test_grid_evolve_2_threads_with_custom_logging_outputfile_population_{}_thread_{}.dat".format(
            self.grid_options["_population_id"], self.process_ID
        ),
    )

    # Check directory, make if necessary
    os.makedirs(data_dir, exist_ok=True)

    # Append mode creates the file when it does not exist yet, so no separate
    # existence check (and its check-then-open race) is needed
    with open(output_filename, "a") as output_file:
        output_file.write(output + "\n")


def parse_function_adding_results(self, output):
    """
    Example parse function that bins the zero-age primary mass weighted by
    probability, and tracks the number of systems and their total probability.

    Lines starting with EXAMPLE_OUTPUT are expected to hold the columns
    ``time mass zams_mass probability stellar_type``.

    Args:
        self: the Population instance calling this parse function. Its
            grid_results["example"] must support ``+=`` on "count",
            "probability" and on keys of the "mass" mapping (presumably a
            nested defaultdict — TODO confirm).
        output: raw output string of a single binary_c run.

    Raises:
        ValueError: if an EXAMPLE_OUTPUT line does not contain exactly one
            value per column name.
    """

    parameters = ["time", "mass", "zams_mass", "probability", "stellar_type"]

    self.grid_results["example"]["count"] += 1

    # Values of the most recent EXAMPLE_OUTPUT line; stays None when the
    # system produced no such line
    values = None

    # Go over the output.
    for line in output.splitlines():
        split_line = line.split()

        # Empty lines have no header to check
        if not split_line:
            continue

        # Check the header and act accordingly
        if split_line[0] == "EXAMPLE_OUTPUT":
            values = split_line[1:]

            # Validate the column count before indexing into the values
            if len(parameters) != len(values):
                raise ValueError(
                    "Number of column names isnt equal to number of columns"
                )

            # Bin the mass probability
            self.grid_results["example"]["mass"][
                bin_data(float(values[2]), binwidth=0.5)
            ] += float(values[3])

    # Record the probability of this system. Beware: this is meant to be run
    # only once per system; it is a control quantity compared against the
    # total probability of the grid.
    if values is not None:
        self.grid_results["example"]["probability"] += float(values[3])


class test__setup(unittest.TestCase):
    """
    Unittests for _setup function
    """

    def test_setup(self):
        # Run inside Capturing (presumably collects printed output) to keep
        # the test output quiet
        with Capturing():
            self._test_setup()

    def _test_setup(self):
        """
        Unittests for function _setup.

        A freshly constructed Population should have the binary_c defaults
        loaded and empty option containers.
        """
        test_pop = Population()

        self.assertIn("orbital_period", test_pop.defaults)
        self.assertIn("metallicity", test_pop.defaults)
        self.assertNotIn("help_all", test_pop.cleaned_up_defaults)
        self.assertEqual(test_pop.bse_options, {})
        self.assertEqual(test_pop.custom_options, {})
        self.assertEqual(test_pop.argline_dict, {})
        self.assertEqual(test_pop.persistent_data_memory_dict, {})
        self.assertIsNone(test_pop.grid_options["parse_function"])
        self.assertIsInstance(test_pop.grid_options["_main_pid"], int)


class test_set(unittest.TestCase):
    """
    Unittests for set function
    """

    def test_set(self):
        with Capturing():
            self._test_set()

    def _test_set(self):
        """
        Unittests for function set.

        Sets a mix of grid, binary_c and custom options and checks that each
        one is stored in the expected options dict.
        """

        test_pop = Population()
        test_pop.set(num_cores=2, verbosity=TEST_VERBOSITY)
        test_pop.set(M_1=10)
        test_pop.set(data_dir="/tmp/binary_c_python")
        test_pop.set(ensemble_filter_SUPERNOVAE=1, ensemble_dt=1000)

        # data_dir is expected to end up in custom_options
        self.assertIn("data_dir", test_pop.custom_options)
        self.assertEqual(test_pop.custom_options["data_dir"], "/tmp/binary_c_python")

        # binary_c parameters are expected in bse_options
        self.assertEqual(test_pop.bse_options["M_1"], 10)
        self.assertEqual(test_pop.bse_options["ensemble_filter_SUPERNOVAE"], 1)

        # grid parameters are expected in grid_options
        self.assertEqual(test_pop.grid_options["num_cores"], 2)


class test_cmdline(unittest.TestCase):
    """
    Unittests for cmdline function
    """

    def test_cmdline(self):
        with Capturing():
            self._test_cmdline()

    def _test_cmdline(self):
        """
        Unittests for function parse_cmdline.

        Replaces sys.argv with dummy arguments, parses them and checks that
        each value lands in the expected options dict with the expected type.
        sys.argv is always restored afterwards, even when an assertion fails.
        """

        # Keep a copy of the real sys.argv so it can be restored afterwards
        prev_sysargv = sys.argv.copy()

        # make a dummy cmdline arg input
        sys.argv = [
            "script",
            "metallicity=0.0002",
            "num_cores=2",
            "data_dir=/tmp/binary_c_python",
        ]

        try:
            # Set up population
            test_pop = Population()
            test_pop.set(data_dir="/tmp", verbosity=TEST_VERBOSITY)

            # parse arguments
            test_pop.parse_cmdline()

            # metallicity is kept as the string given on the command line
            self.assertIsInstance(test_pop.bse_options["metallicity"], str)
            self.assertEqual(test_pop.bse_options["metallicity"], "0.0002")

            # num_cores is converted to an int
            self.assertIsInstance(test_pop.grid_options["num_cores"], int)
            self.assertEqual(test_pop.grid_options["num_cores"], 2)

            # data_dir from the command line replaces the value set above
            self.assertIsInstance(test_pop.custom_options["data_dir"], str)
            self.assertEqual(
                test_pop.custom_options["data_dir"], "/tmp/binary_c_python"
            )
        finally:
            # Restore sys.argv unconditionally so later tests do not see the
            # dummy arguments
            sys.argv = prev_sysargv


class test__return_argline(unittest.TestCase):
    """
    Unittests for _return_argline function
    """

    def test__return_argline(self):
        with Capturing():
            self._test__return_argline()

    def _test__return_argline(self):
        """
        Unittests for the function _return_argline.

        Checks the argline built from the population's bse_options, and the
        argline built from an explicitly passed dict.
        """

        # Set up population
        test_pop = Population()
        test_pop.set(metallicity=0.02, verbosity=TEST_VERBOSITY)
        test_pop.set(M_1=10)

        # argline built from the options set above
        argline = test_pop._return_argline()
        self.assertEqual(argline, "binary_c M_1 10 metallicity 0.02")

        # custom dict
        argline2 = test_pop._return_argline(
            {"example_parameter1": 10, "example_parameter2": "hello"}
        )
        self.assertEqual(
            argline2, "binary_c example_parameter1 10 example_parameter2 hello"
        )


class test_return_population_settings(unittest.TestCase):
    """
    Unittests for return_population_settings function
    """

    def test_return_population_settings(self):
        with Capturing():
            self._test_return_population_settings()

    def _test_return_population_settings(self):
        """
        Unittests for the function return_population_settings.

        Checks that bse, grid and custom options set on the population are
        all present in the returned settings dict.
        """

        test_pop = Population()
        test_pop.set(metallicity=0.02, verbosity=TEST_VERBOSITY)
        test_pop.set(M_1=10)
        test_pop.set(num_cores=2)
        test_pop.set(data_dir="/tmp")

        population_settings = test_pop.return_population_settings()

        self.assertIn("bse_options", population_settings)
        self.assertEqual(population_settings["bse_options"]["metallicity"], 0.02)
        self.assertEqual(population_settings["bse_options"]["M_1"], 10)

        self.assertIn("grid_options", population_settings)
        self.assertEqual(population_settings["grid_options"]["num_cores"], 2)

        self.assertIn("custom_options", population_settings)
        self.assertEqual(population_settings["custom_options"]["data_dir"], "/tmp")


class test_return_binary_c_defaults(unittest.TestCase):
    """
    Tests for Population.return_binary_c_defaults
    """

    def test_return_binary_c_defaults(self):
        with Capturing() as captured_output:
            self._test_return_binary_c_defaults()

    def _test_return_binary_c_defaults(self):
        """
        The defaults returned by return_binary_c_defaults should include the
        probability, phasevol and metallicity parameters.
        """

        population = Population()
        defaults = population.return_binary_c_defaults()

        for expected_key in ("probability", "phasevol", "metallicity"):
            self.assertIn(expected_key, defaults)


class test_return_all_info(unittest.TestCase):
    """
    Tests for Population.return_all_info
    """

    def test_return_all_info(self):
        with Capturing() as captured_output:
            self._test_return_all_info()

    def _test_return_all_info(self):
        """
        Light check of return_all_info: every expected section must be
        present and none of them may be an empty dict.
        """

        population = Population()
        info = population.return_all_info()

        expected_sections = (
            "population_settings",
            "binary_c_defaults",
            "binary_c_version_info",
            "binary_c_help_all",
        )

        # First check presence of every section, then that none is empty
        for section in expected_sections:
            self.assertIn(section, info)
        for section in expected_sections:
            self.assertNotEqual(info[section], {})


class test_export_all_info(unittest.TestCase):
    """
    Unittests for export_all_info function
    """

    def test_export_all_info(self):
        with Capturing():
            self._test_export_all_info()

    def _check_exported_info_file(self, settings_filename):
        """
        Check that settings_filename exists and holds a JSON object with the
        four expected info sections, none of which is an empty dict.
        """

        self.assertTrue(os.path.isfile(settings_filename))
        with open(settings_filename, "r") as f:
            all_info = json.load(f)

        sections = (
            "population_settings",
            "binary_c_defaults",
            "binary_c_version_info",
            "binary_c_help_all",
        )

        for section in sections:
            self.assertIn(section, all_info)
        for section in sections:
            self.assertNotEqual(all_info[section], {})

    def _test_export_all_info(self):
        """
        Unittests for the function export_all_info.

        Exports into the data_dir, exports to an explicit filename, and
        checks that a filename with the wrong extension is rejected.
        """

        test_pop = Population()

        test_pop.set(metallicity=0.02, verbosity=TEST_VERBOSITY)
        test_pop.set(M_1=10)
        test_pop.set(num_cores=2)
        test_pop.set(data_dir=TMP_DIR)

        # Export into the data_dir
        settings_filename = test_pop.export_all_info(use_datadir=True)
        self._check_exported_info_file(settings_filename)

        # Export to a custom filename
        settings_filename = test_pop.export_all_info(
            use_datadir=False,
            outfile=os.path.join(TMP_DIR, "example_settings.json"),
        )
        self._check_exported_info_file(settings_filename)

        # An outfile with a .txt extension should raise a ValueError
        self.assertRaises(
            ValueError,
            test_pop.export_all_info,
            use_datadir=False,
            outfile=os.path.join(TMP_DIR, "example_settings.txt"),
        )


class test__cleanup_defaults(unittest.TestCase):
    """
    Tests for Population._cleanup_defaults
    """

    def test__cleanup_defaults(self):
        with Capturing() as captured_output:
            self._test__cleanup_defaults()

    def _test__cleanup_defaults(self):
        """
        The cleaned-up defaults must no longer contain the help_all entry.
        """

        population = Population()
        self.assertNotIn("help_all", population._cleanup_defaults())


class test__increment_probtot(unittest.TestCase):
    """
    Tests for Population._increment_probtot
    """

    def test__increment_probtot(self):
        with Capturing() as captured_output:
            self._test__increment_probtot()

    def _test__increment_probtot(self):
        """
        A single increment on a fresh population should leave the running
        total probability equal to the increment.
        """

        population = Population()
        increment = 0.5
        population._increment_probtot(increment)

        total_probability = population.grid_options["_probtot"]
        self.assertEqual(total_probability, increment)


class test__increment_count(unittest.TestCase):
    """
    Tests for Population._increment_count
    """

    def test__increment_count(self):
        with Capturing() as captured_output:
            self._test__increment_count()

    def _test__increment_count(self):
        """
        One call to _increment_count on a fresh population should bring the
        system counter to 1.
        """

        population = Population()
        population._increment_count()

        system_count = population.grid_options["_count"]
        self.assertEqual(system_count, 1)


class test__dict_from_line_source_file(unittest.TestCase):
    """
    Unittests for _dict_from_line_source_file function
    """

    def test__dict_from_line_source_file(self):
        with Capturing():
            self._test__dict_from_line_source_file()

    def _test__dict_from_line_source_file(self):
        """
        Unittests for the function _dict_from_line_source_file.

        Writes a one-line source file, reads it back and checks that the
        line is converted into a dict with numeric values.
        """

        source_file = os.path.join(TMP_DIR, "example_source_file.txt")

        # Write a source file with a single argline
        with open(source_file, "w") as f:
            f.write("binary_c M_1 10 metallicity 0.02\n")

        test_pop = Population()

        # Read it back line by line and check the parsed values
        with open(source_file, "r") as f:
            for line in f:
                argdict = test_pop._dict_from_line_source_file(line)

                self.assertEqual(argdict["M_1"], 10)
                self.assertEqual(argdict["metallicity"], 0.02)


class test_evolve_single(unittest.TestCase):
    """
    Unittests for evolve_single function
    """

    def test_evolve_single(self):
        with Capturing():
            self._test_evolve_single()

    def _test_evolve_single(self):
        """
        Unittests for the function evolve_single.

        Evolves the same system twice: once with hand-written C logging code
        and once with an auto-logging dict. Each run's output must contain
        more than one line and the corresponding custom logging header.
        """

        CUSTOM_LOGGING_STRING_MASSES = """
        Printf("TEST_CUSTOM_LOGGING_1 %30.12e %g %g %g %g\\n",
            //
            stardata->model.time, // 1

            // masses
            stardata->common.zero_age.mass[0], //
            stardata->common.zero_age.mass[1], //

            stardata->star[0].mass,
            stardata->star[1].mass
            );
        """

        # System settings shared by both populations. The key was previously
        # misspelled as "metallicty", so the metallicity was never applied.
        system_settings = {
            "M_1": 10,
            "M_2": 5,
            "orbital_period": 100000,
            "metallicity": 0.02,
            "max_evolution_time": 15000,
            "verbosity": TEST_VERBOSITY,
        }

        # Custom logging via hand-written C code
        test_pop = Population()
        test_pop.set(**system_settings)
        test_pop.set(C_logging_code=CUSTOM_LOGGING_STRING_MASSES)

        output = test_pop.evolve_single()

        self.assertGreater(len(output.splitlines()), 1)
        self.assertIn("TEST_CUSTOM_LOGGING_1", output)

        # Custom logging via an auto-logging dict
        custom_logging_dict = {"TEST_CUSTOM_LOGGING_2": ["star[0].mass", "model.time"]}
        test_pop_2 = Population()
        test_pop_2.set(**system_settings)
        test_pop_2.set(C_auto_logging=custom_logging_dict)

        output_2 = test_pop_2.evolve_single()

        self.assertGreater(len(output_2.splitlines()), 1)
        self.assertIn("TEST_CUSTOM_LOGGING_2", output_2)






########
# Some tests that are not really -unit- tests
class test_resultdict(unittest.TestCase):
    """
    Tests that results added by a parse function are properly combined over
    a small population run.
    """

    def test_adding_results(self):
        """
        Function to test whether the results are properly added and combined.

        Evolves a grid of 10 primary masses on 3 cores with
        parse_function_adding_results. It then checks that the count and
        total probability collected in grid_results agree with the grid
        analytics, and that the binned mass probabilities match the
        reference values.
        """

        # Create custom logging statement
        custom_logging_statement = """
        if (stardata->model.time < stardata->model.max_evolution_time)
        {
            Printf("EXAMPLE_OUTPUT %30.16e %g %g %30.12e %d\\n",
                //
                stardata->model.time, // 1
                stardata->star[0].mass, // 2
                stardata->common.zero_age.mass[0], // 3
                stardata->model.probability, // 4
                stardata->star[0].stellar_type // 5
          );
        };
        /* Kill the simulation to save time */
        stardata->model.max_evolution_time = stardata->model.time - stardata->model.dtm;
        """

        example_pop = Population()
        example_pop.set(verbosity=0)
        example_pop.set(
            max_evolution_time=15000,  # bse_options
            # grid_options
            num_cores=3,
            tmp_dir=TMP_DIR,
            # Custom options
            data_dir=os.path.join(TMP_DIR, "test_resultdict"),  # custom_options
            C_logging_code=custom_logging_statement,
            parse_function=parse_function_adding_results,
        )

        # Add grid variables
        resolution = {"M_1": 10}

        # Mass
        example_pop.add_grid_variable(
            name="lnm1",
            longname="Primary mass",
            valuerange=[2, 150],
            samplerfunc="const(math.log(2), math.log(150), {})".format(
                resolution["M_1"]
            ),
            precode="M_1=math.exp(lnm1)",
            probdist="three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 150, -1.3, -2.3, -2.3)*M_1",
            dphasevol="dlnm1",
            parameter_name="M_1",
            condition="",  # Impose a condition on this grid variable. Mostly for a check for yourself
        )

        ## Executing a population
        ## This uses the values generated by the grid_variables
        analytics = example_pop.evolve()

        # total probability
        grid_prob = analytics["total_probability"]
        result_dict_prob = example_pop.grid_results["example"]["probability"]

        # amt systems
        grid_count = analytics["total_count"]
        result_dict_count = example_pop.grid_results["example"]["count"]

        # Check if the total probability matches
        self.assertAlmostEqual(
            grid_prob,
            result_dict_prob,
            places=12,
            msg="Total probability from grid {} and from result dict {} are not equal".format(
                grid_prob, result_dict_prob
            ),
        )

        # Check if the total count matches
        self.assertEqual(
            grid_count,
            result_dict_count,
            msg="Total count from grid {} and from result dict {} are not equal".format(
                grid_count, result_dict_count
            ),
        )

        # Check if the structure is what we expect. Note: this depends on the probability calculation. if that changes we need to recalibrate this
        test_case_dict = {
            2.25: 0.01895481306515,
            3.75: 0.01081338190204,
            5.75: 0.006168841009268,
            9.25: 0.003519213484031,
            13.75: 0.002007648361756,
            21.25: 0.001145327489437,
            33.25: 0.0006533888518775,
            50.75: 0.0003727466560393,
            78.25: 0.000212645301782,
            120.75: 0.0001213103421247,
        }

        result_mass_dict = dict(example_pop.grid_results["example"]["mass"])

        # The set of mass bins must match exactly
        self.assertEqual(set(test_case_dict), set(result_mass_dict))

        # Binned probabilities are computed floats, so compare them with the
        # same tolerance as the total probability above instead of exactly
        for mass_bin, expected_probability in test_case_dict.items():
            self.assertAlmostEqual(
                expected_probability,
                result_mass_dict[mass_bin],
                places=12,
                msg="Probability in mass bin {} differs".format(mass_bin),
            )




# class test_grid_evolve(unittest.TestCase):
#     """
#     Unittests for function Population.evolve()
#     """

#     def test_grid_evolve_1_thread(self):
#         with Capturing() as output:
#             self._test_grid_evolve_1_thread()

#     def _test_grid_evolve_1_thread(self):
#         """
#         Unittests to see if 1 thread does all the systems
#         """

#         test_pop_evolve_1_thread = Population()
#         test_pop_evolve_1_thread.set(
#             num_cores=1, M_2=1, orbital_period=100000, verbosity=TEST_VERBOSITY
#         )

#         resolution = {"M_1": 10}

#         test_pop_evolve_1_thread.add_grid_variable(
#             name="lnm1",
#             longname="Primary mass",
#             valuerange=[1, 100],
#             samplerfunc="const(math.log(1), math.log(100), {})".format(
#                 resolution["M_1"]
#             ),
#             precode="M_1=math.exp(lnm1)",
#             probdist="three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1",
#             dphasevol="dlnm1",
#             parameter_name="M_1",
#             condition="",  # Impose a condition on this grid variable. Mostly for a check for yourself
#         )

#         analytics = test_pop_evolve_1_thread.evolve()
#         self.assertLess(
#             np.abs(analytics["total_probability"] - 0.10820655287892997),
#             1e-10,
#             msg=analytics["total_probability"],
#         )
#         self.assertTrue(analytics["total_count"] == 10)

#     def test_grid_evolve_2_threads(self):
#         with Capturing() as output:
#             self._test_grid_evolve_2_threads()

#     def _test_grid_evolve_2_threads(self):
#         """
#         Unittests to see if multiple threads handle the all the systems correctly
#         """

#         test_pop = Population()
#         test_pop.set(
#             num_cores=2, M_2=1, orbital_period=100000, verbosity=TEST_VERBOSITY
#         )

#         resolution = {"M_1": 10}

#         test_pop.add_grid_variable(
#             name="lnm1",
#             longname="Primary mass",
#             valuerange=[1, 100],
#             samplerfunc="const(math.log(1), math.log(100), {})".format(
#                 resolution["M_1"]
#             ),
#             precode="M_1=math.exp(lnm1)",
#             probdist="three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1",
#             dphasevol="dlnm1",
#             parameter_name="M_1",
#             condition="",  # Impose a condition on this grid variable. Mostly for a check for yourself
#         )

#         analytics = test_pop.evolve()
#         self.assertLess(
#             np.abs(analytics["total_probability"] - 0.10820655287892997),
#             1e-10,
#             msg=analytics["total_probability"],
#         )  #
#         self.assertTrue(analytics["total_count"] == 10)

#     def test_grid_evolve_2_threads_with_custom_logging(self):
#         with Capturing() as output:
#             self._test_grid_evolve_2_threads_with_custom_logging()

#     def _test_grid_evolve_2_threads_with_custom_logging(self):
#         """
#         Unittests to see if multiple threads do the custom logging correctly
#         """

#         data_dir_value = os.path.join(TMP_DIR, "grid_tests")
#         num_cores_value = 2
#         custom_logging_string = 'Printf("MY_STELLAR_DATA_TEST_EXAMPLE %g %g %g %g\\n",((double)stardata->model.time),((double)stardata->star[0].mass),((double)stardata->model.probability),((double)stardata->model.dt));'

#         test_pop = Population()

#         test_pop.set(
#             num_cores=num_cores_value,
#             verbosity=TEST_VERBOSITY,
#             M_2=1,
#             orbital_period=100000,
#             data_dir=data_dir_value,
#             C_logging_code=custom_logging_string,  # input it like this.
#             parse_function=parse_function_test_grid_evolve_2_threads_with_custom_logging,
#         )
#         test_pop.set(ensemble=0)
#         resolution = {"M_1": 2}

#         test_pop.add_grid_variable(
#             name="lnm1",
#             longname="Primary mass",
#             valuerange=[1, 100],
#             samplerfunc="const(math.log(1), math.log(100), {})".format(
#                 resolution["M_1"]
#             ),
#             precode="M_1=math.exp(lnm1)",
#             probdist="three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1",
#             dphasevol="dlnm1",
#             parameter_name="M_1",
#             condition="",  # Impose a condition on this grid variable. Mostly for a check for yourself
#         )

#         analytics = test_pop.evolve()
#         output_names = [
#             os.path.join(
#                 data_dir_value,
#                 "test_grid_evolve_2_threads_with_custom_logging_outputfile_population_{}_thread_{}.dat".format(
#                     analytics["population_name"], thread_id
#                 ),
#             )
#             for thread_id in range(num_cores_value)
#         ]

#         for output_name in output_names:
#             self.assertTrue(os.path.isfile(output_name))

#             with open(output_name, "r") as f:
#                 output_string = f.read()

#             self.assertIn("MY_STELLAR_DATA_TEST_EXAMPLE", output_string)

#             remove_file(output_name)

#     def test_grid_evolve_with_condition_error(self):
#         with Capturing() as output:
#             self._test_grid_evolve_with_condition_error()

#     def _test_grid_evolve_with_condition_error(self):
#         """
#         Unittests to see if the threads catch the errors correctly.
#         """

#         test_pop = Population()
#         test_pop.set(
#             num_cores=2, M_2=1, orbital_period=100000, verbosity=TEST_VERBOSITY
#         )

#         # Set the amt of failed systems that each thread will log
#         test_pop.set(failed_systems_threshold=4)

#         CUSTOM_LOGGING_STRING_WITH_EXIT = """
# Exit_binary_c(BINARY_C_NORMAL_EXIT, "testing exits. This is part of the testing, don't worry");
# Printf("TEST_CUSTOM_LOGGING_1 %30.12e %g %g %g %g\\n",
#     //
#     stardata->model.time, // 1

#     // masses
#     stardata->common.zero_age.mass[0], //
#     stardata->common.zero_age.mass[1], //

#     stardata->star[0].mass,
#     stardata->star[1].mass
# );
#         """

#         test_pop.set(C_logging_code=CUSTOM_LOGGING_STRING_WITH_EXIT)

#         resolution = {"M_1": 10}
#         test_pop.add_grid_variable(
#             name="lnm1",
#             longname="Primary mass",
#             valuerange=[1, 100],
#             samplerfunc="const(math.log(1), math.log(100), {})".format(
#                 resolution["M_1"]
#             ),
#             precode="M_1=math.exp(lnm1)",
#             probdist="three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1",
#             dphasevol="dlnm1",
#             parameter_name="M_1",
#             condition="",  # Impose a condition on this grid variable. Mostly for a check for yourself
#         )

#         analytics = test_pop.evolve()
#         self.assertLess(
#             np.abs(analytics["total_probability"] - 0.10820655287892997),
#             1e-10,
#             msg=analytics["total_probability"],
#         )  #
#         self.assertEqual(analytics["failed_systems_error_codes"], [0])
#         self.assertTrue(analytics["total_count"] == 10)
#         self.assertTrue(analytics["failed_count"] == 10)
#         self.assertTrue(analytics["errors_found"] == True)
#         self.assertTrue(analytics["errors_exceeded"] == True)

#         # test to see if 1 thread does all the systems

#         test_pop = Population()
#         test_pop.set(
#             num_cores=2, M_2=1, orbital_period=100000, verbosity=TEST_VERBOSITY
#         )
#         test_pop.set(failed_systems_threshold=4)
#         test_pop.set(C_logging_code=CUSTOM_LOGGING_STRING_WITH_EXIT)

#         resolution = {"M_1": 10, "q": 2}

#         test_pop.add_grid_variable(
#             name="lnm1",
#             longname="Primary mass",
#             valuerange=[1, 100],
#             samplerfunc="const(math.log(1), math.log(100), {})".format(
#                 resolution["M_1"]
#             ),
#             precode="M_1=math.exp(lnm1)",
#             probdist="three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1",
#             dphasevol="dlnm1",
#             parameter_name="M_1",
#             condition="",  # Impose a condition on this grid variable. Mostly for a check for yourself
#         )

#         test_pop.add_grid_variable(
#             name="q",
#             longname="Mass ratio",
#             valuerange=["0.1/M_1", 1],
#             samplerfunc="const(0.1/M_1, 1, {})".format(resolution["q"]),
#             probdist="flatsections(q, [{'min': 0.1/M_1, 'max': 1.0, 'height': 1}])",
#             dphasevol="dq",
#             precode="M_2 = q * M_1",
#             parameter_name="M_2",
#             # condition="M_1 in dir()",  # Impose a condition on this grid variable. Mostly for a check for yourself
#             condition="'random_var' in dir()",  # This will raise an error because random_var is not defined.
#         )

#         # TODO: why should it raise this error? It should probably raise a valueerror when the limit is exceeded right?
#         # DEcided to turn it off for now because there is not raise VAlueError in that chain of functions.
#         # NOTE: Found out why this test was here. It is to do with the condition random_var in dir(), but I changed the behaviour from raising an error to continue. This has to do with the moe&distefano code that will loop over several multiplicities
#         # TODO: make sure the continue behaviour is what we actually want.

#         # self.assertRaises(ValueError, test_pop.evolve)

#     def test_grid_evolve_no_grid_variables(self):
#         with Capturing() as output:
#             self._test_grid_evolve_no_grid_variables()

#     def _test_grid_evolve_no_grid_variables(self):
#         """
#         Unittests to see if errors are raised if there are no grid variables
#         """

#         test_pop = Population()
#         test_pop.set(
#             num_cores=1, M_2=1, orbital_period=100000, verbosity=TEST_VERBOSITY
#         )

#         resolution = {"M_1": 10}
#         self.assertRaises(ValueError, test_pop.evolve)

#     def test_grid_evolve_2_threads_with_ensemble_direct_output(self):
#         with Capturing() as output:
#             self._test_grid_evolve_2_threads_with_ensemble_direct_output()

#     def _test_grid_evolve_2_threads_with_ensemble_direct_output(self):
#         """
#         Unittests to see if multiple threads output the ensemble information to files correctly
#         """

#         data_dir_value = TMP_DIR
#         num_cores_value = 2

#         test_pop = Population()
#         test_pop.set(
#             num_cores=num_cores_value,
#             verbosity=TEST_VERBOSITY,
#             M_2=1,
#             orbital_period=100000,
#             ensemble=1,
#             ensemble_defer=1,
#             ensemble_filters_off=1,
#             ensemble_filter_STELLAR_TYPE_COUNTS=1,
#             ensemble_dt=1000,
#         )
#         test_pop.set(
#             data_dir=TMP_DIR,
#             ensemble_output_name="ensemble_output.json",
#             combine_ensemble_with_thread_joining=False,
#         )

#         resolution = {"M_1": 10}

#         test_pop.add_grid_variable(
#             name="lnm1",
#             longname="Primary mass",
#             valuerange=[1, 100],
#             samplerfunc="const(math.log(1), math.log(100), {})".format(
#                 resolution["M_1"]
#             ),
#             precode="M_1=math.exp(lnm1)",
#             probdist="three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1",
#             dphasevol="dlnm1",
#             parameter_name="M_1",
#             condition="",  # Impose a condition on this grid variable. Mostly for a check for yourself
#         )

#         analytics = test_pop.evolve()
#         output_names = [
#             os.path.join(
#                 data_dir_value,
#                 "ensemble_output_{}_{}.json".format(
#                     analytics["population_name"], thread_id
#                 ),
#             )
#             for thread_id in range(num_cores_value)
#         ]

#         for output_name in output_names:
#             self.assertTrue(os.path.isfile(output_name))

#             with open(output_name, "r") as f:
#                 file_content = f.read()

#                 ensemble_json = json.loads(file_content)

#                 self.assertTrue(isinstance(ensemble_json, dict))
#                 self.assertNotEqual(ensemble_json, {})

#                 self.assertIn("number_counts", ensemble_json)
#                 self.assertNotEqual(ensemble_json["number_counts"], {})

#     def test_grid_evolve_2_threads_with_ensemble_combining(self):
#         with Capturing() as output:
#             self._test_grid_evolve_2_threads_with_ensemble_combining()

#     def _test_grid_evolve_2_threads_with_ensemble_combining(self):
#         """
#         Unittests to see if multiple threads correctly combine the ensemble data and store it in the grid
#         """

#         data_dir_value = TMP_DIR
#         num_cores_value = 2

#         test_pop = Population()
#         test_pop.set(
#             num_cores=num_cores_value,
#             verbosity=TEST_VERBOSITY,
#             M_2=1,
#             orbital_period=100000,
#             ensemble=1,
#             ensemble_defer=1,
#             ensemble_filters_off=1,
#             ensemble_filter_STELLAR_TYPE_COUNTS=1,
#             ensemble_dt=1000,
#         )
#         test_pop.set(
#             data_dir=TMP_DIR,
#             combine_ensemble_with_thread_joining=True,
#             ensemble_output_name="ensemble_output.json",
#         )

#         resolution = {"M_1": 10}

#         test_pop.add_grid_variable(
#             name="lnm1",
#             longname="Primary mass",
#             valuerange=[1, 100],
#             samplerfunc="const(math.log(1), math.log(100), {})".format(
#                 resolution["M_1"]
#             ),
#             precode="M_1=math.exp(lnm1)",
#             probdist="three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1",
#             dphasevol="dlnm1",
#             parameter_name="M_1",
#             condition="",  # Impose a condition on this grid variable. Mostly for a check for yourself
#         )

#         analytics = test_pop.evolve()

#         self.assertTrue(isinstance(test_pop.grid_ensemble_results["ensemble"], dict))
#         self.assertNotEqual(test_pop.grid_ensemble_results["ensemble"], {})

#         self.assertIn("number_counts", test_pop.grid_ensemble_results["ensemble"])
#         self.assertNotEqual(
#             test_pop.grid_ensemble_results["ensemble"]["number_counts"], {}
#         )

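#     # NOTE(review): the helper below uses `merge_dicts` and `np`, which the imports at the top of this file do not appear to provide; import them before re-enabling this test.
#     def test_grid_evolve_2_threads_with_ensemble_comparing_two_methods(self):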
#         with Capturing() as output:
#             self._test_grid_evolve_2_threads_with_ensemble_comparing_two_methods()

#     def _test_grid_evolve_2_threads_with_ensemble_comparing_two_methods(self):
#         """
#         Unittests to compare two ways of handling ensemble data: combining it in the object during thread joining, versus writing it to per-thread files and merging those afterwards. Both methods must produce the same result.
#         """

#         data_dir_value = TMP_DIR
#         num_cores_value = 2

#         # First
#         test_pop_1 = Population()
#         test_pop_1.set(
#             num_cores=num_cores_value,
#             verbosity=TEST_VERBOSITY,
#             M_2=1,
#             orbital_period=100000,
#             ensemble=1,
#             ensemble_defer=1,
#             ensemble_filters_off=1,
#             ensemble_filter_STELLAR_TYPE_COUNTS=1,
#             ensemble_dt=1000,
#         )
#         test_pop_1.set(
#             data_dir=TMP_DIR,
#             combine_ensemble_with_thread_joining=True,
#             ensemble_output_name="ensemble_output.json",
#         )

#         resolution = {"M_1": 10}

#         test_pop_1.add_grid_variable(
#             name="lnm1",
#             longname="Primary mass",
#             valuerange=[1, 100],
#             samplerfunc="const(math.log(1), math.log(100), {})".format(
#                 resolution["M_1"]
#             ),
#             precode="M_1=math.exp(lnm1)",
#             probdist="three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1",
#             dphasevol="dlnm1",
#             parameter_name="M_1",
#             condition="",  # Impose a condition on this grid variable. Mostly for a check for yourself
#         )

#         analytics_1 = test_pop_1.evolve()
#         ensemble_output_1 = test_pop_1.grid_ensemble_results

#         # second
#         test_pop_2 = Population()
#         test_pop_2.set(
#             num_cores=num_cores_value,
#             verbosity=TEST_VERBOSITY,
#             M_2=1,
#             orbital_period=100000,
#             ensemble=1,
#             ensemble_defer=1,
#             ensemble_filters_off=1,
#             ensemble_filter_STELLAR_TYPE_COUNTS=1,
#             ensemble_dt=1000,
#         )
#         test_pop_2.set(
#             data_dir=TMP_DIR,
#             ensemble_output_name="ensemble_output.json",
#             combine_ensemble_with_thread_joining=False,
#         )

#         resolution = {"M_1": 10}

#         test_pop_2.add_grid_variable(
#             name="lnm1",
#             longname="Primary mass",
#             valuerange=[1, 100],
#             samplerfunc="const(math.log(1), math.log(100), {})".format(
#                 resolution["M_1"]
#             ),
#             precode="M_1=math.exp(lnm1)",
#             probdist="three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1",
#             dphasevol="dlnm1",
#             parameter_name="M_1",
#             condition="",  # Impose a condition on this grid variable. Mostly for a check for yourself
#         )

#         analytics_2 = test_pop_2.evolve()
#         output_names_2 = [
#             os.path.join(
#                 data_dir_value,
#                 "ensemble_output_{}_{}.json".format(
#                     analytics_2["population_name"], thread_id
#                 ),
#             )
#             for thread_id in range(num_cores_value)
#         ]
#         ensemble_output_2 = {}

#         for output_name in output_names_2:
#             self.assertTrue(os.path.isfile(output_name))

#             with open(output_name, "r") as f:
#                 file_content = f.read()

#                 ensemble_json = json.loads(file_content)

#                 ensemble_output_2 = merge_dicts(ensemble_output_2, ensemble_json)

#         for key in ensemble_output_1["ensemble"]["number_counts"]["stellar_type"]["0"]:
#             self.assertIn(key, ensemble_output_2["number_counts"]["stellar_type"]["0"])

#             # compare values
#             self.assertLess(
#                 np.abs(
#                     ensemble_output_1["ensemble"]["number_counts"]["stellar_type"]["0"][
#                         key
#                     ]
#                     - ensemble_output_2["number_counts"]["stellar_type"]["0"][key]
#                 ),
#                 1e-8,
#             )

if __name__ == "__main__":
    # When executed as a script, hand control to unittest's command-line
    # runner, which collects and runs the TestCase classes in this module.
    unittest.main(module="__main__")