diff --git a/binarycpython/tests/example.ipynb b/binarycpython/tests/example.ipynb deleted file mode 100644 index 57ce0148138a19e0c9cffd6872673c3f44d627d7..0000000000000000000000000000000000000000 --- a/binarycpython/tests/example.ipynb +++ /dev/null @@ -1,95 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "fd5b1a83-7212-4aca-b317-991bf289fba8", - "metadata": {}, - "outputs": [], - "source": [ - "def add(a, b):\n", - " return a + b" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "4d842395-3e17-48e8-b613-9856365e9796", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "11" - ] - }, - "execution_count": 2, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "add(5, 6)" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "f2afc967-a66a-4a47-bfc5-0d6c17826794", - "metadata": {}, - "outputs": [ - { - "ename": "ZeroDivisionError", - "evalue": "division by zero", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mZeroDivisionError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m<ipython-input-3-bc757c3fda29>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;36m1\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;31mZeroDivisionError\u001b[0m: division by zero" - ] - } - ], - "source": [ - "1 / 0" - ] - }, - { - "cell_type": "markdown", - "id": "8491b29d-375d-458f-8a46-fc822422d8f3", - "metadata": {}, - "source": [ - "hello" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "601a89e6-5ca6-4725-8834-5e975ba76726", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - 
"codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.6.4" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/binarycpython/tests/test_grid.py b/binarycpython/tests/test_grid.py index 6bfb9272b5a9db3e16e0c061e667a2ca0b21e28f..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 100644 --- a/binarycpython/tests/test_grid.py +++ b/binarycpython/tests/test_grid.py @@ -1,1184 +0,0 @@ -""" -Test cases for the grid - -Tasks: - TODO: write tests for load_from_sourcefile -""" - -import os -import sys -import json -import unittest -import numpy as np - -from binarycpython.utils.grid import Population - -from binarycpython.utils.functions import ( - temp_dir, - remove_file, - Capturing, - bin_data, -) - -from binarycpython.utils.ensemble import ( - extract_ensemble_json_from_string, -) -from binarycpython.utils.dicts import ( - merge_dicts, -) - -from binarycpython.utils.custom_logging_functions import binary_c_log_code - -TMP_DIR = temp_dir("tests", "test_grid") -TEST_VERBOSITY = 1 - - -def parse_function_test_grid_evolve_2_threads_with_custom_logging(self, output): - """ - Simple parse function that directly appends all the output to a file - """ - - # Get some information from the - data_dir = self.custom_options["data_dir"] - - # make outputfilename - output_filename = os.path.join( - data_dir, - "test_grid_evolve_2_threads_with_custom_logging_outputfile_population_{}_thread_{}.dat".format( - self.grid_options["_population_id"], self.process_ID - ), - ) - - # Check directory, make if necessary - os.makedirs(data_dir, exist_ok=True) - - if not os.path.exists(output_filename): - with open(output_filename, "w") as first_f: - first_f.write(output + "\n") - else: - with open(output_filename, "a") as first_f: - first_f.write(output + "\n") - - -# class test_(unittest.TestCase): -# """ -# Unittests for function 
-# """ - -# def test_1(self): -# pass - -# def test_(self): -# """ -# Unittests for the function -# """ - - -class test_Population(unittest.TestCase): - """ - Unittests for function - """ - - def test_setup(self): - with Capturing() as output: - self._test_setup() - - def _test_setup(self): - """ - Unittests for function _setup - """ - test_pop = Population() - - self.assertTrue("orbital_period" in test_pop.defaults) - self.assertTrue("metallicity" in test_pop.defaults) - self.assertNotIn("help_all", test_pop.cleaned_up_defaults) - self.assertEqual(test_pop.bse_options, {}) - self.assertEqual(test_pop.custom_options, {}) - self.assertEqual(test_pop.argline_dict, {}) - self.assertEqual(test_pop.persistent_data_memory_dict, {}) - self.assertTrue(test_pop.grid_options["parse_function"] == None) - self.assertTrue(isinstance(test_pop.grid_options["_main_pid"], int)) - - def test_set(self): - with Capturing() as output: - self._test_set() - - def _test_set(self): - """ - Unittests for function set - """ - - test_pop = Population() - test_pop.set(num_cores=2, verbosity=TEST_VERBOSITY) - test_pop.set(M_1=10) - test_pop.set(data_dir="/tmp/binary_c_python") - test_pop.set(ensemble_filter_SUPERNOVAE=1, ensemble_dt=1000) - - self.assertIn("data_dir", test_pop.custom_options) - self.assertEqual(test_pop.custom_options["data_dir"], "/tmp/binary_c_python") - - # - self.assertTrue(test_pop.bse_options["M_1"] == 10) - self.assertTrue(test_pop.bse_options["ensemble_filter_SUPERNOVAE"] == 1) - - # - self.assertTrue(test_pop.grid_options["num_cores"] == 2) - - def test_cmdline(self): - with Capturing() as output: - self._test_cmdline() - - def _test_cmdline(self): - """ - Unittests for function parse_cmdline - """ - - # copy old sys.argv values - prev_sysargv = sys.argv.copy() - - # make a dummy cmdline arg input - sys.argv = [ - "script", - "metallicity=0.0002", - "num_cores=2", - "data_dir=/tmp/binary_c_python", - ] - - # Set up population - test_pop = Population() - 
test_pop.set(data_dir="/tmp", verbosity=TEST_VERBOSITY) - - # parse arguments - test_pop.parse_cmdline() - - # metallicity - self.assertTrue(isinstance(test_pop.bse_options["metallicity"], str)) - self.assertTrue(test_pop.bse_options["metallicity"] == "0.0002") - - # Amt cores - self.assertTrue(isinstance(test_pop.grid_options["num_cores"], int)) - self.assertTrue(test_pop.grid_options["num_cores"] == 2) - - # datadir - self.assertTrue(isinstance(test_pop.custom_options["data_dir"], str)) - self.assertTrue(test_pop.custom_options["data_dir"] == "/tmp/binary_c_python") - - # put back the other args if they exist - sys.argv = prev_sysargv.copy() - - def test__return_argline(self): - with Capturing() as output: - self._test__return_argline() - - def _test__return_argline(self): - """ - Unittests for the function _return_argline - """ - - # Set up population - test_pop = Population() - test_pop.set(metallicity=0.02, verbosity=TEST_VERBOSITY) - test_pop.set(M_1=10) - - argline = test_pop._return_argline() - self.assertTrue(argline == "binary_c M_1 10 metallicity 0.02") - - # custom dict - argline2 = test_pop._return_argline( - {"example_parameter1": 10, "example_parameter2": "hello"} - ) - self.assertTrue( - argline2 == "binary_c example_parameter1 10 example_parameter2 hello" - ) - - def test_add_grid_variable(self): - with Capturing() as output: - self._test_add_grid_variable() - - def _test_add_grid_variable(self): - """ - Unittests for the function add_grid_variable - - TODO: Should I test more here? - """ - - test_pop = Population() - - resolution = {"M_1": 10, "q": 10} - - test_pop.add_grid_variable( - name="lnm1", - longname="Primary mass", - valuerange=[1, 100], - samplerfunc="const(math.log(1), math.log(100), {})".format( - resolution["M_1"] - ), - precode="M_1=math.exp(lnm1)", - probdist="three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1", - dphasevol="dlnm1", - parameter_name="M_1", - condition="", # Impose a condition on this grid variable. 
Mostly for a check for yourself - ) - - test_pop.add_grid_variable( - name="q", - longname="Mass ratio", - valuerange=["0.1/M_1", 1], - samplerfunc="const(0.1/M_1, 1, {})".format(resolution["q"]), - probdist="flatsections(q, [{'min': 0.1/M_1, 'max': 1.0, 'height': 1}])", - dphasevol="dq", - precode="M_2 = q * M_1", - parameter_name="M_2", - condition="", # Impose a condition on this grid variable. Mostly for a check for yourself - ) - - self.assertIn("q", test_pop.grid_options["_grid_variables"]) - self.assertIn("lnm1", test_pop.grid_options["_grid_variables"]) - self.assertEqual(len(test_pop.grid_options["_grid_variables"]), 2) - - def test_return_population_settings(self): - with Capturing() as output: - self._test_return_population_settings() - - def _test_return_population_settings(self): - """ - Unittests for the function return_population_settings - """ - - test_pop = Population() - test_pop.set(metallicity=0.02, verbosity=TEST_VERBOSITY) - test_pop.set(M_1=10) - test_pop.set(num_cores=2) - test_pop.set(data_dir="/tmp") - - population_settings = test_pop.return_population_settings() - - self.assertIn("bse_options", population_settings) - self.assertTrue(population_settings["bse_options"]["metallicity"] == 0.02) - self.assertTrue(population_settings["bse_options"]["M_1"] == 10) - - self.assertIn("grid_options", population_settings) - self.assertTrue(population_settings["grid_options"]["num_cores"] == 2) - - self.assertIn("custom_options", population_settings) - self.assertTrue(population_settings["custom_options"]["data_dir"] == "/tmp") - - def test_return_binary_c_version_info(self): - with Capturing() as output: - self._test_return_binary_c_version_info() - - def _test_return_binary_c_version_info(self): - """ - Unittests for the function return_binary_c_version_info - """ - - test_pop = Population() - binary_c_version_info = test_pop.return_binary_c_version_info(parsed=True) - - self.assertTrue(isinstance(binary_c_version_info, dict)) - 
self.assertIn("isotopes", binary_c_version_info) - self.assertIn("argpairs", binary_c_version_info) - self.assertIn("ensembles", binary_c_version_info) - self.assertIn("macros", binary_c_version_info) - self.assertIn("dt_limits", binary_c_version_info) - self.assertIn("nucleosynthesis_sources", binary_c_version_info) - self.assertIn("miscellaneous", binary_c_version_info) - - self.assertIsNotNone(binary_c_version_info["argpairs"]) - self.assertIsNotNone(binary_c_version_info["ensembles"]) - self.assertIsNotNone(binary_c_version_info["macros"]) - self.assertIsNotNone(binary_c_version_info["dt_limits"]) - self.assertIsNotNone(binary_c_version_info["miscellaneous"]) - - if binary_c_version_info["macros"]["NUCSYN"] == "on": - self.assertIsNotNone(binary_c_version_info["isotopes"]) - - if binary_c_version_info["macros"]["NUCSYN_ID_SOURCES"] == "on": - self.assertIsNotNone(binary_c_version_info["nucleosynthesis_sources"]) - - def test_return_binary_c_defaults(self): - with Capturing() as output: - self._test_return_binary_c_defaults() - - def _test_return_binary_c_defaults(self): - """ - Unittests for the function return_binary_c_defaults - """ - - test_pop = Population() - binary_c_defaults = test_pop.return_binary_c_defaults() - self.assertIn("probability", binary_c_defaults) - self.assertIn("phasevol", binary_c_defaults) - self.assertIn("metallicity", binary_c_defaults) - - def test_return_all_info(self): - with Capturing() as output: - self._test_return_all_info() - - def _test_return_all_info(self): - """ - Unittests for the function return_all_info - Not going to do too much tests here, just check if they are not empty - """ - - test_pop = Population() - all_info = test_pop.return_all_info() - - self.assertIn("population_settings", all_info) - self.assertIn("binary_c_defaults", all_info) - self.assertIn("binary_c_version_info", all_info) - self.assertIn("binary_c_help_all", all_info) - - self.assertNotEqual(all_info["population_settings"], {}) - 
self.assertNotEqual(all_info["binary_c_defaults"], {}) - self.assertNotEqual(all_info["binary_c_version_info"], {}) - self.assertNotEqual(all_info["binary_c_help_all"], {}) - - def test_export_all_info(self): - with Capturing() as output: - self._test_export_all_info() - - def _test_export_all_info(self): - """ - Unittests for the function export_all_info - """ - - test_pop = Population() - - test_pop.set(metallicity=0.02, verbosity=TEST_VERBOSITY) - test_pop.set(M_1=10) - test_pop.set(num_cores=2) - test_pop.set(data_dir=TMP_DIR) - - # datadir - settings_filename = test_pop.export_all_info(use_datadir=True) - self.assertTrue(os.path.isfile(settings_filename)) - with open(settings_filename, "r") as f: - all_info = json.loads(f.read()) - - # - self.assertIn("population_settings", all_info) - self.assertIn("binary_c_defaults", all_info) - self.assertIn("binary_c_version_info", all_info) - self.assertIn("binary_c_help_all", all_info) - - # - self.assertNotEqual(all_info["population_settings"], {}) - self.assertNotEqual(all_info["binary_c_defaults"], {}) - self.assertNotEqual(all_info["binary_c_version_info"], {}) - self.assertNotEqual(all_info["binary_c_help_all"], {}) - - # custom name - # datadir - settings_filename = test_pop.export_all_info( - use_datadir=False, - outfile=os.path.join(TMP_DIR, "example_settings.json"), - ) - self.assertTrue(os.path.isfile(settings_filename)) - with open(settings_filename, "r") as f: - all_info = json.loads(f.read()) - - # - self.assertIn("population_settings", all_info) - self.assertIn("binary_c_defaults", all_info) - self.assertIn("binary_c_version_info", all_info) - self.assertIn("binary_c_help_all", all_info) - - # - self.assertNotEqual(all_info["population_settings"], {}) - self.assertNotEqual(all_info["binary_c_defaults"], {}) - self.assertNotEqual(all_info["binary_c_version_info"], {}) - self.assertNotEqual(all_info["binary_c_help_all"], {}) - - # wrong filename - self.assertRaises( - ValueError, - test_pop.export_all_info, 
- use_datadir=False, - outfile=os.path.join(TMP_DIR, "example_settings.txt"), - ) - - def test__cleanup_defaults(self): - with Capturing() as output: - self._test__cleanup_defaults() - - def _test__cleanup_defaults(self): - """ - Unittests for the function _cleanup_defaults - """ - - test_pop = Population() - cleaned_up_defaults = test_pop._cleanup_defaults() - self.assertNotIn("help_all", cleaned_up_defaults) - - def test__increment_probtot(self): - with Capturing() as output: - self._test__increment_probtot() - - def _test__increment_probtot(self): - """ - Unittests for the function _increment_probtot - """ - - test_pop = Population() - test_pop._increment_probtot(0.5) - self.assertEqual(test_pop.grid_options["_probtot"], 0.5) - - def test__increment_count(self): - with Capturing() as output: - self._test__increment_count() - - def _test__increment_count(self): - """ - Unittests for the function _increment_probtot - """ - - test_pop = Population() - test_pop._increment_count() - self.assertEqual(test_pop.grid_options["_count"], 1) - - def test__dict_from_line_source_file(self): - with Capturing() as output: - self._test__dict_from_line_source_file() - - def _test__dict_from_line_source_file(self): - """ - Unittests for the function _dict_from_line_source_file - """ - - source_file = os.path.join(TMP_DIR, "example_source_file.txt") - - # write - with open(source_file, "w") as f: - f.write("binary_c M_1 10 metallicity 0.02\n") - - test_pop = Population() - - # readout - with open(source_file, "r") as f: - for line in f.readlines(): - argdict = test_pop._dict_from_line_source_file(line) - - self.assertTrue(argdict["M_1"] == 10) - self.assertTrue(argdict["metallicity"] == 0.02) - - def test_evolve_single(self): - with Capturing() as output: - self._test_evolve_single() - - def _test_evolve_single(self): - """ - Unittests for the function evolve_single - """ - - CUSTOM_LOGGING_STRING_MASSES = """ - Printf("TEST_CUSTOM_LOGGING_1 %30.12e %g %g %g %g\\n", - // - 
stardata->model.time, // 1 - - // masses - stardata->common.zero_age.mass[0], // - stardata->common.zero_age.mass[1], // - - stardata->star[0].mass, - stardata->star[1].mass - ); - """ - - test_pop = Population() - test_pop.set( - M_1=10, - M_2=5, - orbital_period=100000, - metallicty=0.02, - max_evolution_time=15000, - verbosity=TEST_VERBOSITY, - ) - - test_pop.set(C_logging_code=CUSTOM_LOGGING_STRING_MASSES) - - output = test_pop.evolve_single() - - # - self.assertTrue(len(output.splitlines()) > 1) - self.assertIn("TEST_CUSTOM_LOGGING_1", output) - - # - custom_logging_dict = {"TEST_CUSTOM_LOGGING_2": ["star[0].mass", "model.time"]} - test_pop_2 = Population() - test_pop_2.set( - M_1=10, - M_2=5, - orbital_period=100000, - metallicty=0.02, - max_evolution_time=15000, - verbosity=TEST_VERBOSITY, - ) - - test_pop_2.set(C_auto_logging=custom_logging_dict) - - output_2 = test_pop_2.evolve_single() - - # - self.assertTrue(len(output_2.splitlines()) > 1) - self.assertIn("TEST_CUSTOM_LOGGING_2", output_2) - - -class test_grid_evolve(unittest.TestCase): - """ - Unittests for function Population.evolve() - """ - - def test_grid_evolve_1_thread(self): - with Capturing() as output: - self._test_grid_evolve_1_thread() - - def _test_grid_evolve_1_thread(self): - """ - Unittests to see if 1 thread does all the systems - """ - - test_pop_evolve_1_thread = Population() - test_pop_evolve_1_thread.set( - num_cores=1, M_2=1, orbital_period=100000, verbosity=TEST_VERBOSITY - ) - - resolution = {"M_1": 10} - - test_pop_evolve_1_thread.add_grid_variable( - name="lnm1", - longname="Primary mass", - valuerange=[1, 100], - samplerfunc="const(math.log(1), math.log(100), {})".format( - resolution["M_1"] - ), - precode="M_1=math.exp(lnm1)", - probdist="three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1", - dphasevol="dlnm1", - parameter_name="M_1", - condition="", # Impose a condition on this grid variable. 
Mostly for a check for yourself - ) - - analytics = test_pop_evolve_1_thread.evolve() - self.assertLess( - np.abs(analytics["total_probability"] - 0.10820655287892997), - 1e-10, - msg=analytics["total_probability"], - ) - self.assertTrue(analytics["total_count"] == 10) - - def test_grid_evolve_2_threads(self): - with Capturing() as output: - self._test_grid_evolve_2_threads() - - def _test_grid_evolve_2_threads(self): - """ - Unittests to see if multiple threads handle the all the systems correctly - """ - - test_pop = Population() - test_pop.set( - num_cores=2, M_2=1, orbital_period=100000, verbosity=TEST_VERBOSITY - ) - - resolution = {"M_1": 10} - - test_pop.add_grid_variable( - name="lnm1", - longname="Primary mass", - valuerange=[1, 100], - samplerfunc="const(math.log(1), math.log(100), {})".format( - resolution["M_1"] - ), - precode="M_1=math.exp(lnm1)", - probdist="three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1", - dphasevol="dlnm1", - parameter_name="M_1", - condition="", # Impose a condition on this grid variable. 
Mostly for a check for yourself - ) - - analytics = test_pop.evolve() - self.assertLess( - np.abs(analytics["total_probability"] - 0.10820655287892997), - 1e-10, - msg=analytics["total_probability"], - ) # - self.assertTrue(analytics["total_count"] == 10) - - def test_grid_evolve_2_threads_with_custom_logging(self): - with Capturing() as output: - self._test_grid_evolve_2_threads_with_custom_logging() - - def _test_grid_evolve_2_threads_with_custom_logging(self): - """ - Unittests to see if multiple threads do the custom logging correctly - """ - - data_dir_value = os.path.join(TMP_DIR, "grid_tests") - num_cores_value = 2 - custom_logging_string = 'Printf("MY_STELLAR_DATA_TEST_EXAMPLE %g %g %g %g\\n",((double)stardata->model.time),((double)stardata->star[0].mass),((double)stardata->model.probability),((double)stardata->model.dt));' - - test_pop = Population() - - test_pop.set( - num_cores=num_cores_value, - verbosity=TEST_VERBOSITY, - M_2=1, - orbital_period=100000, - data_dir=data_dir_value, - C_logging_code=custom_logging_string, # input it like this. - parse_function=parse_function_test_grid_evolve_2_threads_with_custom_logging, - ) - test_pop.set(ensemble=0) - resolution = {"M_1": 2} - - test_pop.add_grid_variable( - name="lnm1", - longname="Primary mass", - valuerange=[1, 100], - samplerfunc="const(math.log(1), math.log(100), {})".format( - resolution["M_1"] - ), - precode="M_1=math.exp(lnm1)", - probdist="three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1", - dphasevol="dlnm1", - parameter_name="M_1", - condition="", # Impose a condition on this grid variable. 
Mostly for a check for yourself - ) - - analytics = test_pop.evolve() - output_names = [ - os.path.join( - data_dir_value, - "test_grid_evolve_2_threads_with_custom_logging_outputfile_population_{}_thread_{}.dat".format( - analytics["population_name"], thread_id - ), - ) - for thread_id in range(num_cores_value) - ] - - for output_name in output_names: - self.assertTrue(os.path.isfile(output_name)) - - with open(output_name, "r") as f: - output_string = f.read() - - self.assertIn("MY_STELLAR_DATA_TEST_EXAMPLE", output_string) - - remove_file(output_name) - - def test_grid_evolve_with_condition_error(self): - with Capturing() as output: - self._test_grid_evolve_with_condition_error() - - def _test_grid_evolve_with_condition_error(self): - """ - Unittests to see if the threads catch the errors correctly. - """ - - test_pop = Population() - test_pop.set( - num_cores=2, M_2=1, orbital_period=100000, verbosity=TEST_VERBOSITY - ) - - # Set the amt of failed systems that each thread will log - test_pop.set(failed_systems_threshold=4) - - CUSTOM_LOGGING_STRING_WITH_EXIT = """ -Exit_binary_c(BINARY_C_NORMAL_EXIT, "testing exits. This is part of the testing, don't worry"); -Printf("TEST_CUSTOM_LOGGING_1 %30.12e %g %g %g %g\\n", - // - stardata->model.time, // 1 - - // masses - stardata->common.zero_age.mass[0], // - stardata->common.zero_age.mass[1], // - - stardata->star[0].mass, - stardata->star[1].mass -); - """ - - test_pop.set(C_logging_code=CUSTOM_LOGGING_STRING_WITH_EXIT) - - resolution = {"M_1": 10} - test_pop.add_grid_variable( - name="lnm1", - longname="Primary mass", - valuerange=[1, 100], - samplerfunc="const(math.log(1), math.log(100), {})".format( - resolution["M_1"] - ), - precode="M_1=math.exp(lnm1)", - probdist="three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1", - dphasevol="dlnm1", - parameter_name="M_1", - condition="", # Impose a condition on this grid variable. 
Mostly for a check for yourself - ) - - analytics = test_pop.evolve() - self.assertLess( - np.abs(analytics["total_probability"] - 0.10820655287892997), - 1e-10, - msg=analytics["total_probability"], - ) # - self.assertEqual(analytics["failed_systems_error_codes"], [0]) - self.assertTrue(analytics["total_count"] == 10) - self.assertTrue(analytics["failed_count"] == 10) - self.assertTrue(analytics["errors_found"] == True) - self.assertTrue(analytics["errors_exceeded"] == True) - - # test to see if 1 thread does all the systems - - test_pop = Population() - test_pop.set( - num_cores=2, M_2=1, orbital_period=100000, verbosity=TEST_VERBOSITY - ) - test_pop.set(failed_systems_threshold=4) - test_pop.set(C_logging_code=CUSTOM_LOGGING_STRING_WITH_EXIT) - - resolution = {"M_1": 10, "q": 2} - - test_pop.add_grid_variable( - name="lnm1", - longname="Primary mass", - valuerange=[1, 100], - samplerfunc="const(math.log(1), math.log(100), {})".format( - resolution["M_1"] - ), - precode="M_1=math.exp(lnm1)", - probdist="three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1", - dphasevol="dlnm1", - parameter_name="M_1", - condition="", # Impose a condition on this grid variable. Mostly for a check for yourself - ) - - test_pop.add_grid_variable( - name="q", - longname="Mass ratio", - valuerange=["0.1/M_1", 1], - samplerfunc="const(0.1/M_1, 1, {})".format(resolution["q"]), - probdist="flatsections(q, [{'min': 0.1/M_1, 'max': 1.0, 'height': 1}])", - dphasevol="dq", - precode="M_2 = q * M_1", - parameter_name="M_2", - # condition="M_1 in dir()", # Impose a condition on this grid variable. Mostly for a check for yourself - condition="'random_var' in dir()", # This will raise an error because random_var is not defined. - ) - - # TODO: why should it raise this error? It should probably raise a valueerror when the limit is exceeded right? - # DEcided to turn it off for now because there is not raise VAlueError in that chain of functions. 
- # NOTE: Found out why this test was here. It is to do with the condition random_var in dir(), but I changed the behaviour from raising an error to continue. This has to do with the moe&distefano code that will loop over several multiplicities - # TODO: make sure the continue behaviour is what we actually want. - - # self.assertRaises(ValueError, test_pop.evolve) - - def test_grid_evolve_no_grid_variables(self): - with Capturing() as output: - self._test_grid_evolve_no_grid_variables() - - def _test_grid_evolve_no_grid_variables(self): - """ - Unittests to see if errors are raised if there are no grid variables - """ - - test_pop = Population() - test_pop.set( - num_cores=1, M_2=1, orbital_period=100000, verbosity=TEST_VERBOSITY - ) - - resolution = {"M_1": 10} - self.assertRaises(ValueError, test_pop.evolve) - - def test_grid_evolve_2_threads_with_ensemble_direct_output(self): - with Capturing() as output: - self._test_grid_evolve_2_threads_with_ensemble_direct_output() - - def _test_grid_evolve_2_threads_with_ensemble_direct_output(self): - """ - Unittests to see if multiple threads output the ensemble information to files correctly - """ - - data_dir_value = TMP_DIR - num_cores_value = 2 - - test_pop = Population() - test_pop.set( - num_cores=num_cores_value, - verbosity=TEST_VERBOSITY, - M_2=1, - orbital_period=100000, - ensemble=1, - ensemble_defer=1, - ensemble_filters_off=1, - ensemble_filter_STELLAR_TYPE_COUNTS=1, - ensemble_dt=1000, - ) - test_pop.set( - data_dir=TMP_DIR, - ensemble_output_name="ensemble_output.json", - combine_ensemble_with_thread_joining=False, - ) - - resolution = {"M_1": 10} - - test_pop.add_grid_variable( - name="lnm1", - longname="Primary mass", - valuerange=[1, 100], - samplerfunc="const(math.log(1), math.log(100), {})".format( - resolution["M_1"] - ), - precode="M_1=math.exp(lnm1)", - probdist="three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1", - dphasevol="dlnm1", - parameter_name="M_1", - condition="", # Impose 
a condition on this grid variable. Mostly for a check for yourself - ) - - analytics = test_pop.evolve() - output_names = [ - os.path.join( - data_dir_value, - "ensemble_output_{}_{}.json".format( - analytics["population_name"], thread_id - ), - ) - for thread_id in range(num_cores_value) - ] - - for output_name in output_names: - self.assertTrue(os.path.isfile(output_name)) - - with open(output_name, "r") as f: - file_content = f.read() - - ensemble_json = json.loads(file_content) - - self.assertTrue(isinstance(ensemble_json, dict)) - self.assertNotEqual(ensemble_json, {}) - - self.assertIn("number_counts", ensemble_json) - self.assertNotEqual(ensemble_json["number_counts"], {}) - - def test_grid_evolve_2_threads_with_ensemble_combining(self): - with Capturing() as output: - self._test_grid_evolve_2_threads_with_ensemble_combining() - - def _test_grid_evolve_2_threads_with_ensemble_combining(self): - """ - Unittests to see if multiple threads correclty combine the ensemble data and store them in the grid - """ - - data_dir_value = TMP_DIR - num_cores_value = 2 - - test_pop = Population() - test_pop.set( - num_cores=num_cores_value, - verbosity=TEST_VERBOSITY, - M_2=1, - orbital_period=100000, - ensemble=1, - ensemble_defer=1, - ensemble_filters_off=1, - ensemble_filter_STELLAR_TYPE_COUNTS=1, - ensemble_dt=1000, - ) - test_pop.set( - data_dir=TMP_DIR, - combine_ensemble_with_thread_joining=True, - ensemble_output_name="ensemble_output.json", - ) - - resolution = {"M_1": 10} - - test_pop.add_grid_variable( - name="lnm1", - longname="Primary mass", - valuerange=[1, 100], - samplerfunc="const(math.log(1), math.log(100), {})".format( - resolution["M_1"] - ), - precode="M_1=math.exp(lnm1)", - probdist="three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1", - dphasevol="dlnm1", - parameter_name="M_1", - condition="", # Impose a condition on this grid variable. 
Mostly for a check for yourself - ) - - analytics = test_pop.evolve() - - self.assertTrue(isinstance(test_pop.grid_ensemble_results["ensemble"], dict)) - self.assertNotEqual(test_pop.grid_ensemble_results["ensemble"], {}) - - self.assertIn("number_counts", test_pop.grid_ensemble_results["ensemble"]) - self.assertNotEqual( - test_pop.grid_ensemble_results["ensemble"]["number_counts"], {} - ) - - def test_grid_evolve_2_threads_with_ensemble_comparing_two_methods(self): - with Capturing() as output: - self._test_grid_evolve_2_threads_with_ensemble_comparing_two_methods() - - def _test_grid_evolve_2_threads_with_ensemble_comparing_two_methods(self): - """ - Unittests to compare the method of storing the combined ensemble data in the object and writing them to files and combining them later. they have to be the same - """ - - data_dir_value = TMP_DIR - num_cores_value = 2 - - # First - test_pop_1 = Population() - test_pop_1.set( - num_cores=num_cores_value, - verbosity=TEST_VERBOSITY, - M_2=1, - orbital_period=100000, - ensemble=1, - ensemble_defer=1, - ensemble_filters_off=1, - ensemble_filter_STELLAR_TYPE_COUNTS=1, - ensemble_dt=1000, - ) - test_pop_1.set( - data_dir=TMP_DIR, - combine_ensemble_with_thread_joining=True, - ensemble_output_name="ensemble_output.json", - ) - - resolution = {"M_1": 10} - - test_pop_1.add_grid_variable( - name="lnm1", - longname="Primary mass", - valuerange=[1, 100], - samplerfunc="const(math.log(1), math.log(100), {})".format( - resolution["M_1"] - ), - precode="M_1=math.exp(lnm1)", - probdist="three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1", - dphasevol="dlnm1", - parameter_name="M_1", - condition="", # Impose a condition on this grid variable. 
Mostly for a check for yourself - ) - - analytics_1 = test_pop_1.evolve() - ensemble_output_1 = test_pop_1.grid_ensemble_results - - # second - test_pop_2 = Population() - test_pop_2.set( - num_cores=num_cores_value, - verbosity=TEST_VERBOSITY, - M_2=1, - orbital_period=100000, - ensemble=1, - ensemble_defer=1, - ensemble_filters_off=1, - ensemble_filter_STELLAR_TYPE_COUNTS=1, - ensemble_dt=1000, - ) - test_pop_2.set( - data_dir=TMP_DIR, - ensemble_output_name="ensemble_output.json", - combine_ensemble_with_thread_joining=False, - ) - - resolution = {"M_1": 10} - - test_pop_2.add_grid_variable( - name="lnm1", - longname="Primary mass", - valuerange=[1, 100], - samplerfunc="const(math.log(1), math.log(100), {})".format( - resolution["M_1"] - ), - precode="M_1=math.exp(lnm1)", - probdist="three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1", - dphasevol="dlnm1", - parameter_name="M_1", - condition="", # Impose a condition on this grid variable. Mostly for a check for yourself - ) - - analytics_2 = test_pop_2.evolve() - output_names_2 = [ - os.path.join( - data_dir_value, - "ensemble_output_{}_{}.json".format( - analytics_2["population_name"], thread_id - ), - ) - for thread_id in range(num_cores_value) - ] - ensemble_output_2 = {} - - for output_name in output_names_2: - self.assertTrue(os.path.isfile(output_name)) - - with open(output_name, "r") as f: - file_content = f.read() - - ensemble_json = json.loads(file_content) - - ensemble_output_2 = merge_dicts(ensemble_output_2, ensemble_json) - - for key in ensemble_output_1["ensemble"]["number_counts"]["stellar_type"]["0"]: - self.assertIn(key, ensemble_output_2["number_counts"]["stellar_type"]["0"]) - - # compare values - self.assertLess( - np.abs( - ensemble_output_1["ensemble"]["number_counts"]["stellar_type"]["0"][ - key - ] - - ensemble_output_2["number_counts"]["stellar_type"]["0"][key] - ), - 1e-8, - ) - - -def parse_function_adding_results(self, output): - """ - Example parse function - """ - - 
seperator = " " - - parameters = ["time", "mass", "zams_mass", "probability", "stellar_type"] - - self.grid_results["example"]["count"] += 1 - - # Go over the output. - for line in output.splitlines(): - headerline = line.split()[0] - - # CHeck the header and act accordingly - if headerline == "EXAMPLE_OUTPUT": - values = line.split()[1:] - - # Bin the mass probability - self.grid_results["example"]["mass"][ - bin_data(float(values[2]), binwidth=0.5) - ] += float(values[3]) - - # - if not len(parameters) == len(values): - print("Number of column names isnt equal to number of columns") - raise ValueError - - # record the probability of this line (Beware, this is meant to only be run once for each system. its a controls quantity) - self.grid_results["example"]["probability"] += float(values[3]) - - -class test_resultdict(unittest.TestCase): - """ - Unittests for bin_data - """ - - def test_adding_results(self): - """ - Function to test whether the results are properly added and combined - """ - - # Create custom logging statement - custom_logging_statement = """ - if (stardata->model.time < stardata->model.max_evolution_time) - { - Printf("EXAMPLE_OUTPUT %30.16e %g %g %30.12e %d\\n", - // - stardata->model.time, // 1 - stardata->star[0].mass, // 2 - stardata->common.zero_age.mass[0], // 3 - stardata->model.probability, // 4 - stardata->star[0].stellar_type // 5 - ); - }; - /* Kill the simulation to save time */ - stardata->model.max_evolution_time = stardata->model.time - stardata->model.dtm; - """ - - example_pop = Population() - example_pop.set(verbosity=0) - example_pop.set( - max_evolution_time=15000, # bse_options - # grid_options - num_cores=3, - tmp_dir=TMP_DIR, - # Custom options - data_dir=os.path.join(TMP_DIR, "test_resultdict"), # custom_options - C_logging_code=custom_logging_statement, - parse_function=parse_function_adding_results, - ) - - # Add grid variables - resolution = {"M_1": 10} - - # Mass - example_pop.add_grid_variable( - name="lnm1", - 
longname="Primary mass", - valuerange=[2, 150], - samplerfunc="const(math.log(2), math.log(150), {})".format( - resolution["M_1"] - ), - precode="M_1=math.exp(lnm1)", - probdist="three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 150, -1.3, -2.3, -2.3)*M_1", - dphasevol="dlnm1", - parameter_name="M_1", - condition="", # Impose a condition on this grid variable. Mostly for a check for yourself - ) - - ## Executing a population - ## This uses the values generated by the grid_variables - analytics = example_pop.evolve() - - # - grid_prob = analytics["total_probability"] - result_dict_prob = example_pop.grid_results["example"]["probability"] - - # amt systems - grid_count = analytics["total_count"] - result_dict_count = example_pop.grid_results["example"]["count"] - - # Check if the total probability matches - self.assertAlmostEqual( - grid_prob, - result_dict_prob, - places=12, - msg="Total probability from grid {} and from result dict {} are not equal".format( - grid_prob, result_dict_prob - ), - ) - - # Check if the total count matches - self.assertEqual( - grid_count, - result_dict_count, - msg="Total count from grid {} and from result dict {} are not equal".format( - grid_count, result_dict_count - ), - ) - - # Check if the structure is what we expect. Note: this depends on the probability calculation. 
if that changes we need to recalibrate this - test_case_dict = { - 2.25: 0.01895481306515, - 3.75: 0.01081338190204, - 5.75: 0.006168841009268, - 9.25: 0.003519213484031, - 13.75: 0.002007648361756, - 21.25: 0.001145327489437, - 33.25: 0.0006533888518775, - 50.75: 0.0003727466560393, - 78.25: 0.000212645301782, - 120.75: 0.0001213103421247, - } - - self.assertEqual( - test_case_dict, dict(example_pop.grid_results["example"]["mass"]) - ) - - -if __name__ == "__main__": - unittest.main() diff --git a/binarycpython/tests/test_stellar_types.py b/binarycpython/tests/test_stellar_types.py index 84d3e5ee35edc912fdccadaf4c5c03956a54a1de..2fc5ceb44973cd342c3a7688b68e8ba6f52b5d79 100644 --- a/binarycpython/tests/test_stellar_types.py +++ b/binarycpython/tests/test_stellar_types.py @@ -1,3 +1,5 @@ """ Unittests for stellar_types module """ + +from binarycpython.utils.stellar_types import STELLAR_TYPE_DICT, STELLAR_TYPE_DICT_SHORT \ No newline at end of file diff --git a/binarycpython/tests/tests_population_extensions/test__analytics.py b/binarycpython/tests/tests_population_extensions/test__analytics.py new file mode 100644 index 0000000000000000000000000000000000000000..f5146f02bd8d24c133157db18ac366f522af10e5 --- /dev/null +++ b/binarycpython/tests/tests_population_extensions/test__analytics.py @@ -0,0 +1,8 @@ +""" +Unit classes for the _analytics module population extension + +TODO: make_analytics_dict +TODO: set_time +TODO: time_elapsed +TODO: CPU_time +""" diff --git a/binarycpython/tests/tests_population_extensions/test__cache.py b/binarycpython/tests/tests_population_extensions/test__cache.py new file mode 100644 index 0000000000000000000000000000000000000000..00e6f350ca4a37e9e7939206924ed96c8297f962 --- /dev/null +++ b/binarycpython/tests/tests_population_extensions/test__cache.py @@ -0,0 +1,8 @@ +""" +Unit classes for the _cache module population extension + +TODO: default_cache_dir +TODO: NullCache +TODO: setup_function_cache +TODO: test_caches +""" diff --git 
a/binarycpython/tests/tests_population_extensions/test__condor.py b/binarycpython/tests/tests_population_extensions/test__condor.py new file mode 100644 index 0000000000000000000000000000000000000000..835c8025cc30980a898e4c9492b0e57b18cb938b --- /dev/null +++ b/binarycpython/tests/tests_population_extensions/test__condor.py @@ -0,0 +1,15 @@ +""" +Unit classes for the _condor module population extension + +TODO: condorID +TODO: condorpath +TODO: condor_status_file +TODO: condor_check_requirements +TODO: condor_dirs +TODO: set_condor_status +TODO: get_condor_status +TODO: condor_outfile +TODO: make_condor_dirs +TODO: condor_grid +TODO: condor_queue_stats +""" \ No newline at end of file diff --git a/binarycpython/tests/tests_population_extensions/test__dataIO.py b/binarycpython/tests/tests_population_extensions/test__dataIO.py new file mode 100644 index 0000000000000000000000000000000000000000..98718fdfb10037a4a3d81e1e437812e3ee46b279 --- /dev/null +++ b/binarycpython/tests/tests_population_extensions/test__dataIO.py @@ -0,0 +1,22 @@ +""" +Unit classes for the _dataIO module population extension + +TODO: dir_ok +TODO: save_population_object +TODO: load_population_object +TODO: merge_populations +TODO: merge_populations_from_file +TODO: snapshot_filename +TODO: load_snapshot +TODO: save_snapshot +TODO: write_ensemble +TODO: write_binary_c_calls_to_file +TODO: set_status +TODO: locked_close +TODO: wait_for_unlock +TODO: locked_open_for_write +TODO: NFS_flush_hack +TODO: compression_type +TODO: open +TODO: NFSpath +""" \ No newline at end of file diff --git a/binarycpython/tests/tests_population_extensions/test__distribution_functions.py b/binarycpython/tests/tests_population_extensions/test__distribution_functions.py index 556bff92a50fdd357209ae6b83d514a7cff10996..8fd20789956def67161fddaa063cd83479064946 100644 --- a/binarycpython/tests/tests_population_extensions/test__distribution_functions.py +++ 
b/binarycpython/tests/tests_population_extensions/test__distribution_functions.py @@ -1,5 +1,42 @@ """ Module containing the unittests for the distribution functions. + +TODO: powerlaw_constant_nocache +TODO: powerlaw_constant +TODO: powerlaw +TODO: calculate_constants_three_part_powerlaw +TODO: three_part_powerlaw +TODO: gaussian_normalizing_const +TODO: gaussian_func +TODO: gaussian +TODO: Kroupa2001 +TODO: ktg93 +TODO: imf_tinsley1980 +TODO: imf_scalo1986 +TODO: imf_scalo1998 +TODO: imf_chabrier2003 +TODO: Arenou2010_binary_fraction +TODO: raghavan2010_binary_fraction +TODO: duquennoy1991 +TODO: sana12 +TODO: interpolate_in_mass_izzard2012 +TODO: Izzard2012_period_distribution +TODO: flatsections +TODO: cosmic_SFH_madau_dickinson2014 +TODO: poisson +TODO: _poisson +TODO: get_max_multiplicity +TODO: merge_multiplicities +TODO: Moe_di_Stefano_2017_multiplicity_fractions +TODO: build_q_table +TODO: powerlaw_extrapolation_q +TODO: linear_extrapolation_q +TODO: get_integration_constant_q +TODO: fill_data +TODO: calc_e_integral +TODO: calc_P_integral +TODO: calc_total_probdens +TODO: Moe_di_Stefano_2017_pdf """ import unittest @@ -62,16 +99,16 @@ class test_number(unittest.TestCase): self.assertEqual(input_1, output_1) -class test_const(unittest.TestCase): +class test_const_distribution(unittest.TestCase): """ Class for unit test of number """ - def test_const(self): + def test_const_distribution(self): with Capturing() as output: self._test_const() - def _test_const(self): + def _test_const_distribution(self): """ Unittest for function const """ @@ -130,698 +167,698 @@ class test_powerlaw(unittest.TestCase): # extra test for k = -1 self.assertRaises(ValueError, distribution_functions_pop.powerlaw, 1, 100, -1, 10) +#### +class test_three_part_power_law(unittest.TestCase): + """ + Class for unit test of three_part_power_law + """ + + def test_three_part_power_law(self): + with Capturing() as output: + self._test_three_part_power_law() + + def 
_test_three_part_power_law(self): + """ + unittest for three_part_power_law + """ + + distribution_functions_pop = Population() + + perl_results = [ + 10.0001044752901, + 2.03065220596677, + 0.0501192469795434, + 0.000251191267451594, + 9.88540897458207e-05, + 6.19974072148769e-06, + ] + python_results = [] + input_lists = [] + + for mass in MASS_LIST: + input_lists.append(mass) + python_results.append( + distribution_functions_pop.three_part_powerlaw(mass, 0.08, 0.1, 1, 300, -1.3, -2.3, -2.3) + ) + + # GO over the results and check whether they are equal (within tolerance) + for i in range(len(python_results)): + msg = "Error: Value perl: {} Value python: {} for mass, per: {}".format( + perl_results[i], python_results[i], str(input_lists[i]) + ) + self.assertLess( + np.abs(python_results[i] - perl_results[i]), TOLERANCE, msg=msg + ) + + # Extra test: + # M < M0 + self.assertTrue( + distribution_functions_pop.three_part_powerlaw(0.05, 0.08, 0.1, 1, 300, -1.3, -2.3, -2.3) == 0, + msg="Probability should be zero as M < M0", + ) + + +class test_Kroupa2001(unittest.TestCase): + """ + Class for unit test of Kroupa2001 + """ + + def test_Kroupa2001(self): + with Capturing() as output: + self._test_Kroupa2001() + + def _test_Kroupa2001(self): + """ + unittest for three_part_power_law + """ + + distribution_functions_pop = Population() + + perl_results = [ + 0, # perl value is actually 5.71196495365248 + 2.31977861075353, + 0.143138195684851, + 0.000717390363216896, + 0.000282322598503135, + 1.77061658757533e-05, + ] + python_results = [] + input_lists = [] + + for mass in MASS_LIST: + input_lists.append(mass) + python_results.append(distribution_functions_pop.Kroupa2001(mass)) + + # GO over the results and check whether they are equal (within tolerance) + for i in range(len(python_results)): + msg = "Error: Value perl: {} Value python: {} for mass: {}".format( + perl_results[i], python_results[i], str(input_lists[i]) + ) + self.assertLess( + np.abs(python_results[i] - 
perl_results[i]), TOLERANCE, msg=msg + ) + + # Extra tests: + self.assertEqual( + distribution_functions_pop.Kroupa2001(10, newopts={"mmax": 300}), + distribution_functions_pop.three_part_powerlaw(10, 0.1, 0.5, 1, 300, -1.3, -2.3, -2.3), + ) + +class TestDistributions(unittest.TestCase): + """ + Unittest class + + # https://stackoverflow.com/questions/17353213/init-for-unittest-testcase + """ + + def __init__(self, *args, **kwargs): + """ + init + """ + super(TestDistributions, self).__init__(*args, **kwargs) + + def test_ktg93(self): + with Capturing() as output: + self._test_ktg93() + + def _test_ktg93(self): + """ + unittest for three_part_power_law + """ + + perl_results = [ + 0, # perl value is actually 5.79767807698379 but that is not correct + 2.35458895566605, + 0.155713799148675, + 0.000310689875361984, + 0.000103963454405194, + 4.02817276824841e-06, + ] + python_results = [] + input_lists = [] + + for mass in self.mass_list: + input_lists.append(mass) + python_results.append(ktg93(mass)) + + # GO over the results and check whether they are equal (within tolerance) + for i in range(len(python_results)): + msg = "Error: Value perl: {} Value python: {} for mass: {}".format( + perl_results[i], python_results[i], str(input_lists[i]) + ) + self.assertLess( + np.abs(python_results[i] - perl_results[i]), self.tolerance, msg=msg + ) + + # extra test: + self.assertEqual( + ktg93(10, newopts={"mmax": 300}), + three_part_powerlaw(10, 0.1, 0.5, 1, 300, -1.3, -2.2, -2.7), + ) + + def test_imf_tinsley1980(self): + with Capturing() as output: + self._test_imf_tinsley1980() + + def _test_imf_tinsley1980(self): + """ + Unittest for function imf_tinsley1980 + """ + + m = 1.2 + self.assertEqual( + imf_tinsley1980(m), + three_part_powerlaw(m, 0.1, 2.0, 10.0, 80.0, -2.0, -2.3, -3.3), + ) + + def test_imf_scalo1986(self): + with Capturing() as output: + self._test_imf_scalo1986() + + def _test_imf_scalo1986(self): + """ + Unittest for function imf_scalo1986 + """ + + m = 1.2 + 
self.assertEqual( + imf_scalo1986(m), + three_part_powerlaw(m, 0.1, 1.0, 2.0, 80.0, -2.35, -2.35, -2.70), + ) + + def test_imf_scalo1998(self): + with Capturing() as output: + self._test_imf_scalo1998() + + def _test_imf_scalo1998(self): + """ + Unittest for function imf_scalo1986 + """ + + m = 1.2 + self.assertEqual( + imf_scalo1998(m), + three_part_powerlaw(m, 0.1, 1.0, 10.0, 80.0, -1.2, -2.7, -2.3), + ) + + def test_imf_chabrier2003(self): + with Capturing() as output: + self._test_imf_chabrier2003() + + def _test_imf_chabrier2003(self): + """ + Unittest for function imf_chabrier2003 + """ + + input_1 = 0 + self.assertRaises(ValueError, imf_chabrier2003, input_1) + + masses = [0.1, 0.2, 0.5, 1, 2, 10, 15, 50] + perl_results = [ + 5.64403964849588, + 2.40501495673496, + 0.581457346702825, + 0.159998782068074, + 0.0324898485372181, + 0.000801893469684309, + 0.000315578044662863, + 1.97918170035704e-05, + ] + python_results = [imf_chabrier2003(m) for m in masses] + + # GO over the results and check whether they are equal (within tolerance) + for i in range(len(python_results)): + msg = "Error: Value perl: {} Value python: {} for mass: {}".format( + perl_results[i], python_results[i], str(masses[i]) + ) + self.assertLess( + np.abs(python_results[i] - perl_results[i]), self.tolerance, msg=msg + ) -# class test_three_part_power_law(unittest.TestCase): -# """ -# Class for unit test of three_part_power_law -# """ - -# def test_three_part_power_law(self): -# with Capturing() as output: -# self._test_three_part_power_law() - -# def _test_three_part_power_law(self): -# """ -# unittest for three_part_power_law -# """ - -# distribution_functions_pop = Population() - -# perl_results = [ -# 10.0001044752901, -# 2.03065220596677, -# 0.0501192469795434, -# 0.000251191267451594, -# 9.88540897458207e-05, -# 6.19974072148769e-06, -# ] -# python_results = [] -# input_lists = [] - -# for mass in MASS_LIST: -# input_lists.append(mass) -# python_results.append( -# 
distribution_functions_pop.three_part_powerlaw(mass, 0.08, 0.1, 1, 300, -1.3, -2.3, -2.3) -# ) - -# # GO over the results and check whether they are equal (within tolerance) -# for i in range(len(python_results)): -# msg = "Error: Value perl: {} Value python: {} for mass, per: {}".format( -# perl_results[i], python_results[i], str(input_lists[i]) -# ) -# self.assertLess( -# np.abs(python_results[i] - perl_results[i]), TOLERANCE, msg=msg -# ) - -# # Extra test: -# # M < M0 -# self.assertTrue( -# distribution_functions_pop.three_part_powerlaw(0.05, 0.08, 0.1, 1, 300, -1.3, -2.3, -2.3) == 0, -# msg="Probability should be zero as M < M0", -# ) - - -# class test_Kroupa2001(unittest.TestCase): -# """ -# Class for unit test of Kroupa2001 -# """ - -# def test_Kroupa2001(self): -# with Capturing() as output: -# self._test_Kroupa2001() - -# def _test_Kroupa2001(self): -# """ -# unittest for three_part_power_law -# """ - -# distribution_functions_pop = Population() - -# perl_results = [ -# 0, # perl value is actually 5.71196495365248 -# 2.31977861075353, -# 0.143138195684851, -# 0.000717390363216896, -# 0.000282322598503135, -# 1.77061658757533e-05, -# ] -# python_results = [] -# input_lists = [] - -# for mass in MASS_LIST: -# input_lists.append(mass) -# python_results.append(distribution_functions_pop.Kroupa2001(mass)) - -# # GO over the results and check whether they are equal (within tolerance) -# for i in range(len(python_results)): -# msg = "Error: Value perl: {} Value python: {} for mass: {}".format( -# perl_results[i], python_results[i], str(input_lists[i]) -# ) -# self.assertLess( -# np.abs(python_results[i] - perl_results[i]), TOLERANCE, msg=msg -# ) - -# # Extra tests: -# self.assertEqual( -# distribution_functions_pop.Kroupa2001(10, newopts={"mmax": 300}), -# distribution_functions_pop.three_part_powerlaw(10, 0.1, 0.5, 1, 300, -1.3, -2.3, -2.3), -# ) - -# class TestDistributions(unittest.TestCase): -# """ -# Unittest class - -# # 
https://stackoverflow.com/questions/17353213/init-for-unittest-testcase -# """ - -# def __init__(self, *args, **kwargs): -# """ -# init -# """ -# super(TestDistributions, self).__init__(*args, **kwargs) - -# def test_ktg93(self): -# with Capturing() as output: -# self._test_ktg93() - -# def _test_ktg93(self): -# """ -# unittest for three_part_power_law -# """ - -# perl_results = [ -# 0, # perl value is actually 5.79767807698379 but that is not correct -# 2.35458895566605, -# 0.155713799148675, -# 0.000310689875361984, -# 0.000103963454405194, -# 4.02817276824841e-06, -# ] -# python_results = [] -# input_lists = [] - -# for mass in self.mass_list: -# input_lists.append(mass) -# python_results.append(ktg93(mass)) - -# # GO over the results and check whether they are equal (within tolerance) -# for i in range(len(python_results)): -# msg = "Error: Value perl: {} Value python: {} for mass: {}".format( -# perl_results[i], python_results[i], str(input_lists[i]) -# ) -# self.assertLess( -# np.abs(python_results[i] - perl_results[i]), self.tolerance, msg=msg -# ) - -# # extra test: -# self.assertEqual( -# ktg93(10, newopts={"mmax": 300}), -# three_part_powerlaw(10, 0.1, 0.5, 1, 300, -1.3, -2.2, -2.7), -# ) - -# def test_imf_tinsley1980(self): -# with Capturing() as output: -# self._test_imf_tinsley1980() - -# def _test_imf_tinsley1980(self): -# """ -# Unittest for function imf_tinsley1980 -# """ - -# m = 1.2 -# self.assertEqual( -# imf_tinsley1980(m), -# three_part_powerlaw(m, 0.1, 2.0, 10.0, 80.0, -2.0, -2.3, -3.3), -# ) - -# def test_imf_scalo1986(self): -# with Capturing() as output: -# self._test_imf_scalo1986() - -# def _test_imf_scalo1986(self): -# """ -# Unittest for function imf_scalo1986 -# """ - -# m = 1.2 -# self.assertEqual( -# imf_scalo1986(m), -# three_part_powerlaw(m, 0.1, 1.0, 2.0, 80.0, -2.35, -2.35, -2.70), -# ) - -# def test_imf_scalo1998(self): -# with Capturing() as output: -# self._test_imf_scalo1998() - -# def _test_imf_scalo1998(self): -# """ -# 
Unittest for function imf_scalo1986 -# """ - -# m = 1.2 -# self.assertEqual( -# imf_scalo1998(m), -# three_part_powerlaw(m, 0.1, 1.0, 10.0, 80.0, -1.2, -2.7, -2.3), -# ) - -# def test_imf_chabrier2003(self): -# with Capturing() as output: -# self._test_imf_chabrier2003() - -# def _test_imf_chabrier2003(self): -# """ -# Unittest for function imf_chabrier2003 -# """ - -# input_1 = 0 -# self.assertRaises(ValueError, imf_chabrier2003, input_1) - -# masses = [0.1, 0.2, 0.5, 1, 2, 10, 15, 50] -# perl_results = [ -# 5.64403964849588, -# 2.40501495673496, -# 0.581457346702825, -# 0.159998782068074, -# 0.0324898485372181, -# 0.000801893469684309, -# 0.000315578044662863, -# 1.97918170035704e-05, -# ] -# python_results = [imf_chabrier2003(m) for m in masses] - -# # GO over the results and check whether they are equal (within tolerance) -# for i in range(len(python_results)): -# msg = "Error: Value perl: {} Value python: {} for mass: {}".format( -# perl_results[i], python_results[i], str(masses[i]) -# ) -# self.assertLess( -# np.abs(python_results[i] - perl_results[i]), self.tolerance, msg=msg -# ) - -# def test_duquennoy1991(self): -# with Capturing() as output: -# self._test_duquennoy1991() - -# def _test_duquennoy1991(self): -# """ -# Unittest for function duquennoy1991 -# """ - -# self.assertEqual(duquennoy1991(4.2), gaussian(4.2, 4.8, 2.3, -2, 12)) - -# def test_gaussian(self): -# with Capturing() as output: -# self._test_gaussian() - -# def _test_gaussian(self): -# """ -# unittest for three_part_power_law -# """ - -# perl_results = [ -# 0.00218800520299544, -# 0.0121641269671571, -# 0.0657353455837751, -# 0.104951743573429, -# 0.16899534495487, -# 0.0134332780385336, -# ] -# python_results = [] -# input_lists = [] - -# for logper in self.logper_list: -# input_lists.append(logper) -# python_results.append(gaussian(logper, 4.8, 2.3, -2.0, 12.0)) - -# # GO over the results and check whether they are equal (within tolerance) -# for i in range(len(python_results)): -# msg = 
"Error: Value perl: {} Value python: {} for logper: {}".format( -# perl_results[i], python_results[i], str(input_lists[i]) -# ) -# self.assertLess( -# np.abs(python_results[i] - perl_results[i]), self.tolerance, msg=msg -# ) - -# # Extra test: -# self.assertTrue( -# gaussian(15, 4.8, 2.3, -2.0, 12.0) == 0, -# msg="Probability should be 0 because the input period is out of bounds", -# ) - -# def test_Arenou2010_binary_fraction(self): -# with Capturing() as output: -# self._test_Arenou2010_binary_fraction() - -# def _test_Arenou2010_binary_fraction(self): -# """ -# unittest for three_part_power_law -# """ - -# perl_results = [ -# 0.123079723518677, -# 0.178895136157746, -# 0.541178340047153, -# 0.838798485820276, -# 0.838799998443204, -# 0.8388, -# ] -# python_results = [] -# input_lists = [] - -# for mass in self.mass_list: -# input_lists.append(mass) -# python_results.append(Arenou2010_binary_fraction(mass)) - -# # GO over the results and check whether they are equal (within tolerance) -# for i in range(len(python_results)): -# msg = "Error: Value perl: {} Value python: {} for mass: {}".format( -# perl_results[i], python_results[i], str(input_lists[i]) -# ) -# self.assertLess( -# np.abs(python_results[i] - perl_results[i]), self.tolerance, msg=msg -# ) - -# def test_raghavan2010_binary_fraction(self): -# with Capturing() as output: -# self._test_raghavan2010_binary_fraction() - -# def _test_raghavan2010_binary_fraction(self): -# """ -# unittest for three_part_power_law -# """ - -# perl_results = [0.304872297931597, 0.334079955706623, 0.41024, 1, 1, 1] -# python_results = [] -# input_lists = [] - -# for mass in self.mass_list: -# input_lists.append(mass) -# python_results.append(raghavan2010_binary_fraction(mass)) - -# # GO over the results and check whether they are equal (within tolerance) -# for i in range(len(python_results)): -# msg = "Error: Value perl: {} Value python: {} for mass: {}".format( -# perl_results[i], python_results[i], str(input_lists[i]) -# ) -# 
self.assertLess( -# np.abs(python_results[i] - perl_results[i]), self.tolerance, msg=msg -# ) - -# def test_Izzard2012_period_distribution(self): -# with Capturing() as output: -# self._test_Izzard2012_period_distribution() - -# def _test_Izzard2012_period_distribution(self): -# """ -# unittest for three_part_power_law -# """ - -# perl_results = [ -# 0, -# 0.00941322840619318, -# 0.0575068231479569, -# 0.0963349886047932, -# 0.177058537292581, -# 0.0165713385659234, -# 0, -# 0.00941322840619318, -# 0.0575068231479569, -# 0.0963349886047932, -# 0.177058537292581, -# 0.0165713385659234, -# 0, -# 0.00941322840619318, -# 0.0575068231479569, -# 0.0963349886047932, -# 0.177058537292581, -# 0.0165713385659234, -# 0, -# 7.61631504133159e-09, -# 0.168028727846997, -# 0.130936282216512, -# 0.0559170865520968, -# 0.0100358604460285, -# 0, -# 2.08432736869149e-21, -# 0.18713622563288, -# 0.143151383185002, -# 0.0676299576972089, -# 0.0192427864870784, -# 0, -# 1.1130335685003e-24, -# 0.194272603987661, -# 0.14771508552257, -# 0.0713078479280884, -# 0.0221093965810181, -# ] -# python_results = [] -# input_lists = [] - -# for mass in self.mass_list: -# for per in self.per_list: -# input_lists.append([mass, per]) - -# python_results.append(Izzard2012_period_distribution(per, mass)) - -# # GO over the results and check whether they are equal (within tolerance) -# for i in range(len(python_results)): -# msg = "Error: Value perl: {} Value python: {} for mass, per: {}".format( -# perl_results[i], python_results[i], str(input_lists[i]) -# ) -# self.assertLess( -# np.abs(python_results[i] - perl_results[i]), self.tolerance, msg=msg -# ) - -# def test_flatsections(self): -# with Capturing() as output: -# self._test_flatsections() - -# def _test_flatsections(self): -# """ -# unittest for three_part_power_law -# """ - -# perl_results = [ -# 1.01010101010101, -# 1.01010101010101, -# 1.01010101010101, -# 1.01010101010101, -# 1.01010101010101, -# 1.01010101010101, -# ] -# python_results = [] 
-# input_lists = [] - -# for q in self.q_list: -# input_lists.append(q) -# python_results.append( -# flatsections(q, [{"min": 0.01, "max": 1.0, "height": 1.0}]) -# ) - -# # GO over the results and check whether they are equal (within tolerance) -# for i in range(len(python_results)): -# msg = "Error: Value perl: {} Value python: {} for q: {}".format( -# perl_results[i], python_results[i], str(input_lists[i]) -# ) -# self.assertLess( -# np.abs(python_results[i] - perl_results[i]), self.tolerance, msg=msg -# ) - -# def test_sana12(self): -# with Capturing() as output: -# self._test_sana12() - -# def _test_sana12(self): -# """ -# unittest for three_part_power_law -# """ - -# perl_results = [ -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 
0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, 
-# 0.481676471294883, -# 0.481676471294883, -# 0.131020615300798, -# 0.102503482445846, -# 0.0678037785559114, -# 0.066436408359805, -# 0.481676471294883, -# 0.481676471294883, -# 0.131020615300798, -# 0.102503482445846, -# 0.0678037785559114, -# 0.066436408359805, -# 0.481676471294883, -# 0.481676471294883, -# 0.131020615300798, -# 0.102503482445846, -# 0.0678037785559114, -# 0.066436408359805, -# 0.481676471294883, -# 0.481676471294883, -# 0.131020615300798, -# 0.102503482445846, -# 0.0678037785559114, -# 0.066436408359805, -# 0.481676471294883, -# 0.481676471294883, -# 0.131020615300798, -# 0.102503482445846, -# 0.0678037785559114, -# 0.066436408359805, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.121764808010258, -# 0.481676471294883, -# 0.481676471294883, -# 0.131020615300798, -# 0.102503482445846, -# 0.0678037785559114, -# 0.066436408359805, -# 0.481676471294883, -# 0.481676471294883, -# 0.131020615300798, -# 0.102503482445846, -# 0.0678037785559114, -# 0.066436408359805, -# 0.481676471294883, -# 0.481676471294883, -# 0.131020615300798, -# 0.102503482445846, -# 0.0678037785559114, -# 0.066436408359805, -# 0.481676471294883, -# 0.481676471294883, -# 0.131020615300798, -# 0.102503482445846, -# 0.0678037785559114, -# 0.066436408359805, -# 0.481676471294883, -# 0.481676471294883, -# 0.131020615300798, -# 0.102503482445846, -# 0.0678037785559114, -# 0.066436408359805, -# ] -# python_results = [] -# input_lists = [] - -# for mass in self.mass_list: -# for q in self.q_list: -# for per in self.per_list: -# mass_2 = mass * q - -# sep = calc_sep_from_period(mass, mass_2, per) -# sep_min = calc_sep_from_period(mass, mass_2, 10 ** 0.15) -# sep_max = calc_sep_from_period(mass, mass_2, 10 ** 5.5) - -# input_lists.append([mass, mass_2, per]) - -# python_results.append( -# sana12( -# mass, mass_2, sep, per, sep_min, sep_max, 0.15, 5.5, -0.55 -# ) -# ) - -# # GO over the results and check whether they are 
equal (within tolerance) -# for i in range(len(python_results)): -# msg = "Error: Value perl: {} Value python: {} for mass, mass2, per: {}".format( -# perl_results[i], python_results[i], str(input_lists[i]) -# ) -# self.assertLess( -# np.abs(python_results[i] - perl_results[i]), self.tolerance, msg=msg -# ) + def test_duquennoy1991(self): + with Capturing() as output: + self._test_duquennoy1991() + + def _test_duquennoy1991(self): + """ + Unittest for function duquennoy1991 + """ + + self.assertEqual(duquennoy1991(4.2), gaussian(4.2, 4.8, 2.3, -2, 12)) + + def test_gaussian(self): + with Capturing() as output: + self._test_gaussian() + + def _test_gaussian(self): + """ + unittest for three_part_power_law + """ + + perl_results = [ + 0.00218800520299544, + 0.0121641269671571, + 0.0657353455837751, + 0.104951743573429, + 0.16899534495487, + 0.0134332780385336, + ] + python_results = [] + input_lists = [] + + for logper in self.logper_list: + input_lists.append(logper) + python_results.append(gaussian(logper, 4.8, 2.3, -2.0, 12.0)) + + # GO over the results and check whether they are equal (within tolerance) + for i in range(len(python_results)): + msg = "Error: Value perl: {} Value python: {} for logper: {}".format( + perl_results[i], python_results[i], str(input_lists[i]) + ) + self.assertLess( + np.abs(python_results[i] - perl_results[i]), self.tolerance, msg=msg + ) + + # Extra test: + self.assertTrue( + gaussian(15, 4.8, 2.3, -2.0, 12.0) == 0, + msg="Probability should be 0 because the input period is out of bounds", + ) + + def test_Arenou2010_binary_fraction(self): + with Capturing() as output: + self._test_Arenou2010_binary_fraction() + + def _test_Arenou2010_binary_fraction(self): + """ + unittest for three_part_power_law + """ + + perl_results = [ + 0.123079723518677, + 0.178895136157746, + 0.541178340047153, + 0.838798485820276, + 0.838799998443204, + 0.8388, + ] + python_results = [] + input_lists = [] + + for mass in self.mass_list: + 
input_lists.append(mass) + python_results.append(Arenou2010_binary_fraction(mass)) + + # GO over the results and check whether they are equal (within tolerance) + for i in range(len(python_results)): + msg = "Error: Value perl: {} Value python: {} for mass: {}".format( + perl_results[i], python_results[i], str(input_lists[i]) + ) + self.assertLess( + np.abs(python_results[i] - perl_results[i]), self.tolerance, msg=msg + ) + + def test_raghavan2010_binary_fraction(self): + with Capturing() as output: + self._test_raghavan2010_binary_fraction() + + def _test_raghavan2010_binary_fraction(self): + """ + unittest for three_part_power_law + """ + + perl_results = [0.304872297931597, 0.334079955706623, 0.41024, 1, 1, 1] + python_results = [] + input_lists = [] + + for mass in self.mass_list: + input_lists.append(mass) + python_results.append(raghavan2010_binary_fraction(mass)) + + # GO over the results and check whether they are equal (within tolerance) + for i in range(len(python_results)): + msg = "Error: Value perl: {} Value python: {} for mass: {}".format( + perl_results[i], python_results[i], str(input_lists[i]) + ) + self.assertLess( + np.abs(python_results[i] - perl_results[i]), self.tolerance, msg=msg + ) + + def test_Izzard2012_period_distribution(self): + with Capturing() as output: + self._test_Izzard2012_period_distribution() + + def _test_Izzard2012_period_distribution(self): + """ + unittest for three_part_power_law + """ + + perl_results = [ + 0, + 0.00941322840619318, + 0.0575068231479569, + 0.0963349886047932, + 0.177058537292581, + 0.0165713385659234, + 0, + 0.00941322840619318, + 0.0575068231479569, + 0.0963349886047932, + 0.177058537292581, + 0.0165713385659234, + 0, + 0.00941322840619318, + 0.0575068231479569, + 0.0963349886047932, + 0.177058537292581, + 0.0165713385659234, + 0, + 7.61631504133159e-09, + 0.168028727846997, + 0.130936282216512, + 0.0559170865520968, + 0.0100358604460285, + 0, + 2.08432736869149e-21, + 0.18713622563288, + 
0.143151383185002, + 0.0676299576972089, + 0.0192427864870784, + 0, + 1.1130335685003e-24, + 0.194272603987661, + 0.14771508552257, + 0.0713078479280884, + 0.0221093965810181, + ] + python_results = [] + input_lists = [] + + for mass in self.mass_list: + for per in self.per_list: + input_lists.append([mass, per]) + + python_results.append(Izzard2012_period_distribution(per, mass)) + + # GO over the results and check whether they are equal (within tolerance) + for i in range(len(python_results)): + msg = "Error: Value perl: {} Value python: {} for mass, per: {}".format( + perl_results[i], python_results[i], str(input_lists[i]) + ) + self.assertLess( + np.abs(python_results[i] - perl_results[i]), self.tolerance, msg=msg + ) + + def test_flatsections(self): + with Capturing() as output: + self._test_flatsections() + + def _test_flatsections(self): + """ + unittest for three_part_power_law + """ + + perl_results = [ + 1.01010101010101, + 1.01010101010101, + 1.01010101010101, + 1.01010101010101, + 1.01010101010101, + 1.01010101010101, + ] + python_results = [] + input_lists = [] + + for q in self.q_list: + input_lists.append(q) + python_results.append( + flatsections(q, [{"min": 0.01, "max": 1.0, "height": 1.0}]) + ) + + # GO over the results and check whether they are equal (within tolerance) + for i in range(len(python_results)): + msg = "Error: Value perl: {} Value python: {} for q: {}".format( + perl_results[i], python_results[i], str(input_lists[i]) + ) + self.assertLess( + np.abs(python_results[i] - perl_results[i]), self.tolerance, msg=msg + ) + + def test_sana12(self): + with Capturing() as output: + self._test_sana12() + + def _test_sana12(self): + """ + unittest for three_part_power_law + """ + + perl_results = [ + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, 
+ 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 
0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.481676471294883, + 0.481676471294883, + 0.131020615300798, + 0.102503482445846, + 0.0678037785559114, + 0.066436408359805, + 0.481676471294883, + 0.481676471294883, + 0.131020615300798, + 0.102503482445846, + 0.0678037785559114, + 0.066436408359805, + 0.481676471294883, + 0.481676471294883, + 0.131020615300798, + 0.102503482445846, + 0.0678037785559114, + 0.066436408359805, + 0.481676471294883, + 0.481676471294883, + 0.131020615300798, + 0.102503482445846, + 0.0678037785559114, + 0.066436408359805, + 0.481676471294883, + 0.481676471294883, + 0.131020615300798, + 0.102503482445846, + 0.0678037785559114, + 0.066436408359805, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.121764808010258, + 0.481676471294883, + 0.481676471294883, + 0.131020615300798, + 0.102503482445846, + 0.0678037785559114, + 0.066436408359805, + 0.481676471294883, + 0.481676471294883, + 0.131020615300798, + 0.102503482445846, + 0.0678037785559114, + 0.066436408359805, + 0.481676471294883, + 0.481676471294883, + 0.131020615300798, + 0.102503482445846, 
+ 0.0678037785559114, + 0.066436408359805, + 0.481676471294883, + 0.481676471294883, + 0.131020615300798, + 0.102503482445846, + 0.0678037785559114, + 0.066436408359805, + 0.481676471294883, + 0.481676471294883, + 0.131020615300798, + 0.102503482445846, + 0.0678037785559114, + 0.066436408359805, + ] + python_results = [] + input_lists = [] + + for mass in self.mass_list: + for q in self.q_list: + for per in self.per_list: + mass_2 = mass * q + + sep = calc_sep_from_period(mass, mass_2, per) + sep_min = calc_sep_from_period(mass, mass_2, 10 ** 0.15) + sep_max = calc_sep_from_period(mass, mass_2, 10 ** 5.5) + + input_lists.append([mass, mass_2, per]) + + python_results.append( + sana12( + mass, mass_2, sep, per, sep_min, sep_max, 0.15, 5.5, -0.55 + ) + ) + + # GO over the results and check whether they are equal (within tolerance) + for i in range(len(python_results)): + msg = "Error: Value perl: {} Value python: {} for mass, mass2, per: {}".format( + perl_results[i], python_results[i], str(input_lists[i]) + ) + self.assertLess( + np.abs(python_results[i] - perl_results[i]), self.tolerance, msg=msg + ) if __name__ == "__main__": diff --git a/binarycpython/tests/tests_population_extensions/test__grid_logging.py b/binarycpython/tests/tests_population_extensions/test__grid_logging.py new file mode 100644 index 0000000000000000000000000000000000000000..7c2a2c4dff0ee51097ccababb250745fa92556c3 --- /dev/null +++ b/binarycpython/tests/tests_population_extensions/test__grid_logging.py @@ -0,0 +1,13 @@ +""" +Module containing the unittests for the grid_logging functions + +TODO: _set_custom_logging +TODO: _print_info +TODO: _set_loggers +TODO: vb1print +TODO: vb2print +TODO: verbose_print +TODO: _boxed +TODO: _get_stream_logger +TODO: _clean_up_custom_logging +""" diff --git a/binarycpython/tests/tests_population_extensions/test__gridcode.py b/binarycpython/tests/tests_population_extensions/test__gridcode.py new file mode 100644 index 
0000000000000000000000000000000000000000..526b2d090b70ef0b1de12e6c0a722d6b18036c64 --- /dev/null +++ b/binarycpython/tests/tests_population_extensions/test__gridcode.py @@ -0,0 +1,16 @@ +""" +Unittests for gridcode module + +TODO: _gridcode_filename +TODO: _add_code +TODO: _indent_block +TODO: _increment_indent_depth +TODO: _generate_grid_code +TODO: _write_gridcode_system_call +TODO: _load_grid_function +TODO: _last_grid_variable +TODO: update_grid_variable +TODO: delete_grid_variable +TODO: rename_grid_variable +TODO: add_grid_variable +""" \ No newline at end of file diff --git a/binarycpython/tests/tests_population_extensions/test__metadata.py b/binarycpython/tests/tests_population_extensions/test__metadata.py new file mode 100644 index 0000000000000000000000000000000000000000..5c6ef8a601d5fc42ecca11fe385230ac8de9881d --- /dev/null +++ b/binarycpython/tests/tests_population_extensions/test__metadata.py @@ -0,0 +1,7 @@ +""" +Unittests for metadata module + +TODO: add_system_metadata +TODO: add_ensemble_metadata +TODO: _metadata_keylist +""" diff --git a/binarycpython/tests/tests_population_extensions/test__slurm.py b/binarycpython/tests/tests_population_extensions/test__slurm.py new file mode 100644 index 0000000000000000000000000000000000000000..2d1ea37b1e09d76ff15055b111a6065911ddb1f2 --- /dev/null +++ b/binarycpython/tests/tests_population_extensions/test__slurm.py @@ -0,0 +1,15 @@ +""" +Unittests for slurm module + +TODO: slurmID +TODO: slurmpath +TODO: slurm_status_file +TODO: slurm_check_requirements +TODO: slurm_dirs +TODO: set_slurm_status +TODO: get_slurm_status +TODO: slurm_outfile +TODO: make_slurm_dirs +TODO: slurm_grid +TODO: slurm_queue_stats +""" diff --git a/binarycpython/tests/tests_population_extensions/test__spacing_functions.py b/binarycpython/tests/tests_population_extensions/test__spacing_functions.py index 51f6edc64b1309fe28e1c66b21bbf15683e0bbda..663f3b91b307061f87cd84d1233ff32eb37a60cd 100644 --- 
a/binarycpython/tests/tests_population_extensions/test__spacing_functions.py +++ b/binarycpython/tests/tests_population_extensions/test__spacing_functions.py @@ -1,5 +1,13 @@ """ Unittests for spacing_functions module + + +TODO: const_linear +TODO: const_int +TODO: const_ranges +TODO: peak_normalized_gaussian_func +TODO: gaussian_zoom +TODO: const_dt """ import unittest @@ -29,6 +37,5 @@ class test_spacing_functions(unittest.TestCase): msg="Output didn't contain SINGLE_STAR_LIFETIME", ) - if __name__ == "__main__": unittest.main() diff --git a/binarycpython/tests/tests_population_extensions/test__version_info.py b/binarycpython/tests/tests_population_extensions/test__version_info.py index 70729239aa0d1808a7df6826efa5b974058f33bb..e5331800e61775c6b634fd50031bb3c8df93f79b 100644 --- a/binarycpython/tests/tests_population_extensions/test__version_info.py +++ b/binarycpython/tests/tests_population_extensions/test__version_info.py @@ -1,5 +1,9 @@ """ Unit tests for the _version_info Population extension module + +TODO: return_binary_c_version_info +TODO: parse_binary_c_version_info +TODO: minimum_stellar_mass """ import os diff --git a/binarycpython/tests/tmp_functions.py b/binarycpython/tests/tmp_functions.py index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..6bfb9272b5a9db3e16e0c061e667a2ca0b21e28f 100644 --- a/binarycpython/tests/tmp_functions.py +++ b/binarycpython/tests/tmp_functions.py @@ -0,0 +1,1184 @@ +""" +Test cases for the grid + +Tasks: + TODO: write tests for load_from_sourcefile +""" + +import os +import sys +import json +import unittest +import numpy as np + +from binarycpython.utils.grid import Population + +from binarycpython.utils.functions import ( + temp_dir, + remove_file, + Capturing, + bin_data, +) + +from binarycpython.utils.ensemble import ( + extract_ensemble_json_from_string, +) +from binarycpython.utils.dicts import ( + merge_dicts, +) + +from binarycpython.utils.custom_logging_functions import binary_c_log_code + +TMP_DIR = temp_dir("tests", 
"test_grid") +TEST_VERBOSITY = 1 + + +def parse_function_test_grid_evolve_2_threads_with_custom_logging(self, output): + """ + Simple parse function that directly appends all the output to a file + """ + + # Get some information from the + data_dir = self.custom_options["data_dir"] + + # make outputfilename + output_filename = os.path.join( + data_dir, + "test_grid_evolve_2_threads_with_custom_logging_outputfile_population_{}_thread_{}.dat".format( + self.grid_options["_population_id"], self.process_ID + ), + ) + + # Check directory, make if necessary + os.makedirs(data_dir, exist_ok=True) + + if not os.path.exists(output_filename): + with open(output_filename, "w") as first_f: + first_f.write(output + "\n") + else: + with open(output_filename, "a") as first_f: + first_f.write(output + "\n") + + +# class test_(unittest.TestCase): +# """ +# Unittests for function +# """ + +# def test_1(self): +# pass + +# def test_(self): +# """ +# Unittests for the function +# """ + + +class test_Population(unittest.TestCase): + """ + Unittests for function + """ + + def test_setup(self): + with Capturing() as output: + self._test_setup() + + def _test_setup(self): + """ + Unittests for function _setup + """ + test_pop = Population() + + self.assertTrue("orbital_period" in test_pop.defaults) + self.assertTrue("metallicity" in test_pop.defaults) + self.assertNotIn("help_all", test_pop.cleaned_up_defaults) + self.assertEqual(test_pop.bse_options, {}) + self.assertEqual(test_pop.custom_options, {}) + self.assertEqual(test_pop.argline_dict, {}) + self.assertEqual(test_pop.persistent_data_memory_dict, {}) + self.assertTrue(test_pop.grid_options["parse_function"] == None) + self.assertTrue(isinstance(test_pop.grid_options["_main_pid"], int)) + + def test_set(self): + with Capturing() as output: + self._test_set() + + def _test_set(self): + """ + Unittests for function set + """ + + test_pop = Population() + test_pop.set(num_cores=2, verbosity=TEST_VERBOSITY) + test_pop.set(M_1=10) + 
test_pop.set(data_dir="/tmp/binary_c_python") + test_pop.set(ensemble_filter_SUPERNOVAE=1, ensemble_dt=1000) + + self.assertIn("data_dir", test_pop.custom_options) + self.assertEqual(test_pop.custom_options["data_dir"], "/tmp/binary_c_python") + + # + self.assertTrue(test_pop.bse_options["M_1"] == 10) + self.assertTrue(test_pop.bse_options["ensemble_filter_SUPERNOVAE"] == 1) + + # + self.assertTrue(test_pop.grid_options["num_cores"] == 2) + + def test_cmdline(self): + with Capturing() as output: + self._test_cmdline() + + def _test_cmdline(self): + """ + Unittests for function parse_cmdline + """ + + # copy old sys.argv values + prev_sysargv = sys.argv.copy() + + # make a dummy cmdline arg input + sys.argv = [ + "script", + "metallicity=0.0002", + "num_cores=2", + "data_dir=/tmp/binary_c_python", + ] + + # Set up population + test_pop = Population() + test_pop.set(data_dir="/tmp", verbosity=TEST_VERBOSITY) + + # parse arguments + test_pop.parse_cmdline() + + # metallicity + self.assertTrue(isinstance(test_pop.bse_options["metallicity"], str)) + self.assertTrue(test_pop.bse_options["metallicity"] == "0.0002") + + # Amt cores + self.assertTrue(isinstance(test_pop.grid_options["num_cores"], int)) + self.assertTrue(test_pop.grid_options["num_cores"] == 2) + + # datadir + self.assertTrue(isinstance(test_pop.custom_options["data_dir"], str)) + self.assertTrue(test_pop.custom_options["data_dir"] == "/tmp/binary_c_python") + + # put back the other args if they exist + sys.argv = prev_sysargv.copy() + + def test__return_argline(self): + with Capturing() as output: + self._test__return_argline() + + def _test__return_argline(self): + """ + Unittests for the function _return_argline + """ + + # Set up population + test_pop = Population() + test_pop.set(metallicity=0.02, verbosity=TEST_VERBOSITY) + test_pop.set(M_1=10) + + argline = test_pop._return_argline() + self.assertTrue(argline == "binary_c M_1 10 metallicity 0.02") + + # custom dict + argline2 = 
test_pop._return_argline( + {"example_parameter1": 10, "example_parameter2": "hello"} + ) + self.assertTrue( + argline2 == "binary_c example_parameter1 10 example_parameter2 hello" + ) + + def test_add_grid_variable(self): + with Capturing() as output: + self._test_add_grid_variable() + + def _test_add_grid_variable(self): + """ + Unittests for the function add_grid_variable + + TODO: Should I test more here? + """ + + test_pop = Population() + + resolution = {"M_1": 10, "q": 10} + + test_pop.add_grid_variable( + name="lnm1", + longname="Primary mass", + valuerange=[1, 100], + samplerfunc="const(math.log(1), math.log(100), {})".format( + resolution["M_1"] + ), + precode="M_1=math.exp(lnm1)", + probdist="three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1", + dphasevol="dlnm1", + parameter_name="M_1", + condition="", # Impose a condition on this grid variable. Mostly for a check for yourself + ) + + test_pop.add_grid_variable( + name="q", + longname="Mass ratio", + valuerange=["0.1/M_1", 1], + samplerfunc="const(0.1/M_1, 1, {})".format(resolution["q"]), + probdist="flatsections(q, [{'min': 0.1/M_1, 'max': 1.0, 'height': 1}])", + dphasevol="dq", + precode="M_2 = q * M_1", + parameter_name="M_2", + condition="", # Impose a condition on this grid variable. 
Mostly for a check for yourself + ) + + self.assertIn("q", test_pop.grid_options["_grid_variables"]) + self.assertIn("lnm1", test_pop.grid_options["_grid_variables"]) + self.assertEqual(len(test_pop.grid_options["_grid_variables"]), 2) + + def test_return_population_settings(self): + with Capturing() as output: + self._test_return_population_settings() + + def _test_return_population_settings(self): + """ + Unittests for the function return_population_settings + """ + + test_pop = Population() + test_pop.set(metallicity=0.02, verbosity=TEST_VERBOSITY) + test_pop.set(M_1=10) + test_pop.set(num_cores=2) + test_pop.set(data_dir="/tmp") + + population_settings = test_pop.return_population_settings() + + self.assertIn("bse_options", population_settings) + self.assertTrue(population_settings["bse_options"]["metallicity"] == 0.02) + self.assertTrue(population_settings["bse_options"]["M_1"] == 10) + + self.assertIn("grid_options", population_settings) + self.assertTrue(population_settings["grid_options"]["num_cores"] == 2) + + self.assertIn("custom_options", population_settings) + self.assertTrue(population_settings["custom_options"]["data_dir"] == "/tmp") + + def test_return_binary_c_version_info(self): + with Capturing() as output: + self._test_return_binary_c_version_info() + + def _test_return_binary_c_version_info(self): + """ + Unittests for the function return_binary_c_version_info + """ + + test_pop = Population() + binary_c_version_info = test_pop.return_binary_c_version_info(parsed=True) + + self.assertTrue(isinstance(binary_c_version_info, dict)) + self.assertIn("isotopes", binary_c_version_info) + self.assertIn("argpairs", binary_c_version_info) + self.assertIn("ensembles", binary_c_version_info) + self.assertIn("macros", binary_c_version_info) + self.assertIn("dt_limits", binary_c_version_info) + self.assertIn("nucleosynthesis_sources", binary_c_version_info) + self.assertIn("miscellaneous", binary_c_version_info) + + 
self.assertIsNotNone(binary_c_version_info["argpairs"]) + self.assertIsNotNone(binary_c_version_info["ensembles"]) + self.assertIsNotNone(binary_c_version_info["macros"]) + self.assertIsNotNone(binary_c_version_info["dt_limits"]) + self.assertIsNotNone(binary_c_version_info["miscellaneous"]) + + if binary_c_version_info["macros"]["NUCSYN"] == "on": + self.assertIsNotNone(binary_c_version_info["isotopes"]) + + if binary_c_version_info["macros"]["NUCSYN_ID_SOURCES"] == "on": + self.assertIsNotNone(binary_c_version_info["nucleosynthesis_sources"]) + + def test_return_binary_c_defaults(self): + with Capturing() as output: + self._test_return_binary_c_defaults() + + def _test_return_binary_c_defaults(self): + """ + Unittests for the function return_binary_c_defaults + """ + + test_pop = Population() + binary_c_defaults = test_pop.return_binary_c_defaults() + self.assertIn("probability", binary_c_defaults) + self.assertIn("phasevol", binary_c_defaults) + self.assertIn("metallicity", binary_c_defaults) + + def test_return_all_info(self): + with Capturing() as output: + self._test_return_all_info() + + def _test_return_all_info(self): + """ + Unittests for the function return_all_info + Not going to do too much tests here, just check if they are not empty + """ + + test_pop = Population() + all_info = test_pop.return_all_info() + + self.assertIn("population_settings", all_info) + self.assertIn("binary_c_defaults", all_info) + self.assertIn("binary_c_version_info", all_info) + self.assertIn("binary_c_help_all", all_info) + + self.assertNotEqual(all_info["population_settings"], {}) + self.assertNotEqual(all_info["binary_c_defaults"], {}) + self.assertNotEqual(all_info["binary_c_version_info"], {}) + self.assertNotEqual(all_info["binary_c_help_all"], {}) + + def test_export_all_info(self): + with Capturing() as output: + self._test_export_all_info() + + def _test_export_all_info(self): + """ + Unittests for the function export_all_info + """ + + test_pop = Population() + + 
test_pop.set(metallicity=0.02, verbosity=TEST_VERBOSITY) + test_pop.set(M_1=10) + test_pop.set(num_cores=2) + test_pop.set(data_dir=TMP_DIR) + + # datadir + settings_filename = test_pop.export_all_info(use_datadir=True) + self.assertTrue(os.path.isfile(settings_filename)) + with open(settings_filename, "r") as f: + all_info = json.loads(f.read()) + + # + self.assertIn("population_settings", all_info) + self.assertIn("binary_c_defaults", all_info) + self.assertIn("binary_c_version_info", all_info) + self.assertIn("binary_c_help_all", all_info) + + # + self.assertNotEqual(all_info["population_settings"], {}) + self.assertNotEqual(all_info["binary_c_defaults"], {}) + self.assertNotEqual(all_info["binary_c_version_info"], {}) + self.assertNotEqual(all_info["binary_c_help_all"], {}) + + # custom name + # datadir + settings_filename = test_pop.export_all_info( + use_datadir=False, + outfile=os.path.join(TMP_DIR, "example_settings.json"), + ) + self.assertTrue(os.path.isfile(settings_filename)) + with open(settings_filename, "r") as f: + all_info = json.loads(f.read()) + + # + self.assertIn("population_settings", all_info) + self.assertIn("binary_c_defaults", all_info) + self.assertIn("binary_c_version_info", all_info) + self.assertIn("binary_c_help_all", all_info) + + # + self.assertNotEqual(all_info["population_settings"], {}) + self.assertNotEqual(all_info["binary_c_defaults"], {}) + self.assertNotEqual(all_info["binary_c_version_info"], {}) + self.assertNotEqual(all_info["binary_c_help_all"], {}) + + # wrong filename + self.assertRaises( + ValueError, + test_pop.export_all_info, + use_datadir=False, + outfile=os.path.join(TMP_DIR, "example_settings.txt"), + ) + + def test__cleanup_defaults(self): + with Capturing() as output: + self._test__cleanup_defaults() + + def _test__cleanup_defaults(self): + """ + Unittests for the function _cleanup_defaults + """ + + test_pop = Population() + cleaned_up_defaults = test_pop._cleanup_defaults() + self.assertNotIn("help_all", 
cleaned_up_defaults) + + def test__increment_probtot(self): + with Capturing() as output: + self._test__increment_probtot() + + def _test__increment_probtot(self): + """ + Unittests for the function _increment_probtot + """ + + test_pop = Population() + test_pop._increment_probtot(0.5) + self.assertEqual(test_pop.grid_options["_probtot"], 0.5) + + def test__increment_count(self): + with Capturing() as output: + self._test__increment_count() + + def _test__increment_count(self): + """ + Unittests for the function _increment_probtot + """ + + test_pop = Population() + test_pop._increment_count() + self.assertEqual(test_pop.grid_options["_count"], 1) + + def test__dict_from_line_source_file(self): + with Capturing() as output: + self._test__dict_from_line_source_file() + + def _test__dict_from_line_source_file(self): + """ + Unittests for the function _dict_from_line_source_file + """ + + source_file = os.path.join(TMP_DIR, "example_source_file.txt") + + # write + with open(source_file, "w") as f: + f.write("binary_c M_1 10 metallicity 0.02\n") + + test_pop = Population() + + # readout + with open(source_file, "r") as f: + for line in f.readlines(): + argdict = test_pop._dict_from_line_source_file(line) + + self.assertTrue(argdict["M_1"] == 10) + self.assertTrue(argdict["metallicity"] == 0.02) + + def test_evolve_single(self): + with Capturing() as output: + self._test_evolve_single() + + def _test_evolve_single(self): + """ + Unittests for the function evolve_single + """ + + CUSTOM_LOGGING_STRING_MASSES = """ + Printf("TEST_CUSTOM_LOGGING_1 %30.12e %g %g %g %g\\n", + // + stardata->model.time, // 1 + + // masses + stardata->common.zero_age.mass[0], // + stardata->common.zero_age.mass[1], // + + stardata->star[0].mass, + stardata->star[1].mass + ); + """ + + test_pop = Population() + test_pop.set( + M_1=10, + M_2=5, + orbital_period=100000, + metallicity=0.02, + max_evolution_time=15000, + verbosity=TEST_VERBOSITY, + ) + 
test_pop.set(C_logging_code=CUSTOM_LOGGING_STRING_MASSES) + + output = test_pop.evolve_single() + + # + self.assertTrue(len(output.splitlines()) > 1) + self.assertIn("TEST_CUSTOM_LOGGING_1", output) + + # + custom_logging_dict = {"TEST_CUSTOM_LOGGING_2": ["star[0].mass", "model.time"]} + test_pop_2 = Population() + test_pop_2.set( + M_1=10, + M_2=5, + orbital_period=100000, + metallicity=0.02, + max_evolution_time=15000, + verbosity=TEST_VERBOSITY, + ) + + test_pop_2.set(C_auto_logging=custom_logging_dict) + + output_2 = test_pop_2.evolve_single() + + # + self.assertTrue(len(output_2.splitlines()) > 1) + self.assertIn("TEST_CUSTOM_LOGGING_2", output_2) + + +class test_grid_evolve(unittest.TestCase): + """ + Unittests for function Population.evolve() + """ + + def test_grid_evolve_1_thread(self): + with Capturing() as output: + self._test_grid_evolve_1_thread() + + def _test_grid_evolve_1_thread(self): + """ + Unittests to see if 1 thread does all the systems + """ + + test_pop_evolve_1_thread = Population() + test_pop_evolve_1_thread.set( + num_cores=1, M_2=1, orbital_period=100000, verbosity=TEST_VERBOSITY + ) + + resolution = {"M_1": 10} + + test_pop_evolve_1_thread.add_grid_variable( + name="lnm1", + longname="Primary mass", + valuerange=[1, 100], + samplerfunc="const(math.log(1), math.log(100), {})".format( + resolution["M_1"] + ), + precode="M_1=math.exp(lnm1)", + probdist="three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1", + dphasevol="dlnm1", + parameter_name="M_1", + condition="", # Impose a condition on this grid variable. 
Mostly for a check for yourself + ) + + analytics = test_pop_evolve_1_thread.evolve() + self.assertLess( + np.abs(analytics["total_probability"] - 0.10820655287892997), + 1e-10, + msg=analytics["total_probability"], + ) + self.assertTrue(analytics["total_count"] == 10) + + def test_grid_evolve_2_threads(self): + with Capturing() as output: + self._test_grid_evolve_2_threads() + + def _test_grid_evolve_2_threads(self): + """ + Unittests to see if multiple threads handle the all the systems correctly + """ + + test_pop = Population() + test_pop.set( + num_cores=2, M_2=1, orbital_period=100000, verbosity=TEST_VERBOSITY + ) + + resolution = {"M_1": 10} + + test_pop.add_grid_variable( + name="lnm1", + longname="Primary mass", + valuerange=[1, 100], + samplerfunc="const(math.log(1), math.log(100), {})".format( + resolution["M_1"] + ), + precode="M_1=math.exp(lnm1)", + probdist="three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1", + dphasevol="dlnm1", + parameter_name="M_1", + condition="", # Impose a condition on this grid variable. 
Mostly for a check for yourself + ) + + analytics = test_pop.evolve() + self.assertLess( + np.abs(analytics["total_probability"] - 0.10820655287892997), + 1e-10, + msg=analytics["total_probability"], + ) # + self.assertTrue(analytics["total_count"] == 10) + + def test_grid_evolve_2_threads_with_custom_logging(self): + with Capturing() as output: + self._test_grid_evolve_2_threads_with_custom_logging() + + def _test_grid_evolve_2_threads_with_custom_logging(self): + """ + Unittests to see if multiple threads do the custom logging correctly + """ + + data_dir_value = os.path.join(TMP_DIR, "grid_tests") + num_cores_value = 2 + custom_logging_string = 'Printf("MY_STELLAR_DATA_TEST_EXAMPLE %g %g %g %g\\n",((double)stardata->model.time),((double)stardata->star[0].mass),((double)stardata->model.probability),((double)stardata->model.dt));' + + test_pop = Population() + + test_pop.set( + num_cores=num_cores_value, + verbosity=TEST_VERBOSITY, + M_2=1, + orbital_period=100000, + data_dir=data_dir_value, + C_logging_code=custom_logging_string, # input it like this. + parse_function=parse_function_test_grid_evolve_2_threads_with_custom_logging, + ) + test_pop.set(ensemble=0) + resolution = {"M_1": 2} + + test_pop.add_grid_variable( + name="lnm1", + longname="Primary mass", + valuerange=[1, 100], + samplerfunc="const(math.log(1), math.log(100), {})".format( + resolution["M_1"] + ), + precode="M_1=math.exp(lnm1)", + probdist="three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1", + dphasevol="dlnm1", + parameter_name="M_1", + condition="", # Impose a condition on this grid variable. 
Mostly for a check for yourself + ) + + analytics = test_pop.evolve() + output_names = [ + os.path.join( + data_dir_value, + "test_grid_evolve_2_threads_with_custom_logging_outputfile_population_{}_thread_{}.dat".format( + analytics["population_name"], thread_id + ), + ) + for thread_id in range(num_cores_value) + ] + + for output_name in output_names: + self.assertTrue(os.path.isfile(output_name)) + + with open(output_name, "r") as f: + output_string = f.read() + + self.assertIn("MY_STELLAR_DATA_TEST_EXAMPLE", output_string) + + remove_file(output_name) + + def test_grid_evolve_with_condition_error(self): + with Capturing() as output: + self._test_grid_evolve_with_condition_error() + + def _test_grid_evolve_with_condition_error(self): + """ + Unittests to see if the threads catch the errors correctly. + """ + + test_pop = Population() + test_pop.set( + num_cores=2, M_2=1, orbital_period=100000, verbosity=TEST_VERBOSITY + ) + + # Set the amt of failed systems that each thread will log + test_pop.set(failed_systems_threshold=4) + + CUSTOM_LOGGING_STRING_WITH_EXIT = """ +Exit_binary_c(BINARY_C_NORMAL_EXIT, "testing exits. This is part of the testing, don't worry"); +Printf("TEST_CUSTOM_LOGGING_1 %30.12e %g %g %g %g\\n", + // + stardata->model.time, // 1 + + // masses + stardata->common.zero_age.mass[0], // + stardata->common.zero_age.mass[1], // + + stardata->star[0].mass, + stardata->star[1].mass +); + """ + + test_pop.set(C_logging_code=CUSTOM_LOGGING_STRING_WITH_EXIT) + + resolution = {"M_1": 10} + test_pop.add_grid_variable( + name="lnm1", + longname="Primary mass", + valuerange=[1, 100], + samplerfunc="const(math.log(1), math.log(100), {})".format( + resolution["M_1"] + ), + precode="M_1=math.exp(lnm1)", + probdist="three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1", + dphasevol="dlnm1", + parameter_name="M_1", + condition="", # Impose a condition on this grid variable. 
Mostly for a check for yourself + ) + + analytics = test_pop.evolve() + self.assertLess( + np.abs(analytics["total_probability"] - 0.10820655287892997), + 1e-10, + msg=analytics["total_probability"], + ) # + self.assertEqual(analytics["failed_systems_error_codes"], [0]) + self.assertTrue(analytics["total_count"] == 10) + self.assertTrue(analytics["failed_count"] == 10) + self.assertTrue(analytics["errors_found"] == True) + self.assertTrue(analytics["errors_exceeded"] == True) + + # test to see if 1 thread does all the systems + + test_pop = Population() + test_pop.set( + num_cores=2, M_2=1, orbital_period=100000, verbosity=TEST_VERBOSITY + ) + test_pop.set(failed_systems_threshold=4) + test_pop.set(C_logging_code=CUSTOM_LOGGING_STRING_WITH_EXIT) + + resolution = {"M_1": 10, "q": 2} + + test_pop.add_grid_variable( + name="lnm1", + longname="Primary mass", + valuerange=[1, 100], + samplerfunc="const(math.log(1), math.log(100), {})".format( + resolution["M_1"] + ), + precode="M_1=math.exp(lnm1)", + probdist="three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1", + dphasevol="dlnm1", + parameter_name="M_1", + condition="", # Impose a condition on this grid variable. Mostly for a check for yourself + ) + + test_pop.add_grid_variable( + name="q", + longname="Mass ratio", + valuerange=["0.1/M_1", 1], + samplerfunc="const(0.1/M_1, 1, {})".format(resolution["q"]), + probdist="flatsections(q, [{'min': 0.1/M_1, 'max': 1.0, 'height': 1}])", + dphasevol="dq", + precode="M_2 = q * M_1", + parameter_name="M_2", + # condition="M_1 in dir()", # Impose a condition on this grid variable. Mostly for a check for yourself + condition="'random_var' in dir()", # This will raise an error because random_var is not defined. + ) + + # TODO: why should it raise this error? It should probably raise a valueerror when the limit is exceeded right? + # DEcided to turn it off for now because there is not raise VAlueError in that chain of functions. 
+ # NOTE: Found out why this test was here. It is to do with the condition random_var in dir(), but I changed the behaviour from raising an error to continue. This has to do with the moe&distefano code that will loop over several multiplicities + # TODO: make sure the continue behaviour is what we actually want. + + # self.assertRaises(ValueError, test_pop.evolve) + + def test_grid_evolve_no_grid_variables(self): + with Capturing() as output: + self._test_grid_evolve_no_grid_variables() + + def _test_grid_evolve_no_grid_variables(self): + """ + Unittests to see if errors are raised if there are no grid variables + """ + + test_pop = Population() + test_pop.set( + num_cores=1, M_2=1, orbital_period=100000, verbosity=TEST_VERBOSITY + ) + + resolution = {"M_1": 10} + self.assertRaises(ValueError, test_pop.evolve) + + def test_grid_evolve_2_threads_with_ensemble_direct_output(self): + with Capturing() as output: + self._test_grid_evolve_2_threads_with_ensemble_direct_output() + + def _test_grid_evolve_2_threads_with_ensemble_direct_output(self): + """ + Unittests to see if multiple threads output the ensemble information to files correctly + """ + + data_dir_value = TMP_DIR + num_cores_value = 2 + + test_pop = Population() + test_pop.set( + num_cores=num_cores_value, + verbosity=TEST_VERBOSITY, + M_2=1, + orbital_period=100000, + ensemble=1, + ensemble_defer=1, + ensemble_filters_off=1, + ensemble_filter_STELLAR_TYPE_COUNTS=1, + ensemble_dt=1000, + ) + test_pop.set( + data_dir=TMP_DIR, + ensemble_output_name="ensemble_output.json", + combine_ensemble_with_thread_joining=False, + ) + + resolution = {"M_1": 10} + + test_pop.add_grid_variable( + name="lnm1", + longname="Primary mass", + valuerange=[1, 100], + samplerfunc="const(math.log(1), math.log(100), {})".format( + resolution["M_1"] + ), + precode="M_1=math.exp(lnm1)", + probdist="three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1", + dphasevol="dlnm1", + parameter_name="M_1", + condition="", # Impose 
a condition on this grid variable. Mostly for a check for yourself + ) + + analytics = test_pop.evolve() + output_names = [ + os.path.join( + data_dir_value, + "ensemble_output_{}_{}.json".format( + analytics["population_name"], thread_id + ), + ) + for thread_id in range(num_cores_value) + ] + + for output_name in output_names: + self.assertTrue(os.path.isfile(output_name)) + + with open(output_name, "r") as f: + file_content = f.read() + + ensemble_json = json.loads(file_content) + + self.assertTrue(isinstance(ensemble_json, dict)) + self.assertNotEqual(ensemble_json, {}) + + self.assertIn("number_counts", ensemble_json) + self.assertNotEqual(ensemble_json["number_counts"], {}) + + def test_grid_evolve_2_threads_with_ensemble_combining(self): + with Capturing() as output: + self._test_grid_evolve_2_threads_with_ensemble_combining() + + def _test_grid_evolve_2_threads_with_ensemble_combining(self): + """ + Unittests to see if multiple threads correclty combine the ensemble data and store them in the grid + """ + + data_dir_value = TMP_DIR + num_cores_value = 2 + + test_pop = Population() + test_pop.set( + num_cores=num_cores_value, + verbosity=TEST_VERBOSITY, + M_2=1, + orbital_period=100000, + ensemble=1, + ensemble_defer=1, + ensemble_filters_off=1, + ensemble_filter_STELLAR_TYPE_COUNTS=1, + ensemble_dt=1000, + ) + test_pop.set( + data_dir=TMP_DIR, + combine_ensemble_with_thread_joining=True, + ensemble_output_name="ensemble_output.json", + ) + + resolution = {"M_1": 10} + + test_pop.add_grid_variable( + name="lnm1", + longname="Primary mass", + valuerange=[1, 100], + samplerfunc="const(math.log(1), math.log(100), {})".format( + resolution["M_1"] + ), + precode="M_1=math.exp(lnm1)", + probdist="three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1", + dphasevol="dlnm1", + parameter_name="M_1", + condition="", # Impose a condition on this grid variable. 
Mostly for a check for yourself + ) + + analytics = test_pop.evolve() + + self.assertTrue(isinstance(test_pop.grid_ensemble_results["ensemble"], dict)) + self.assertNotEqual(test_pop.grid_ensemble_results["ensemble"], {}) + + self.assertIn("number_counts", test_pop.grid_ensemble_results["ensemble"]) + self.assertNotEqual( + test_pop.grid_ensemble_results["ensemble"]["number_counts"], {} + ) + + def test_grid_evolve_2_threads_with_ensemble_comparing_two_methods(self): + with Capturing() as output: + self._test_grid_evolve_2_threads_with_ensemble_comparing_two_methods() + + def _test_grid_evolve_2_threads_with_ensemble_comparing_two_methods(self): + """ + Unittests to compare the method of storing the combined ensemble data in the object and writing them to files and combining them later. they have to be the same + """ + + data_dir_value = TMP_DIR + num_cores_value = 2 + + # First + test_pop_1 = Population() + test_pop_1.set( + num_cores=num_cores_value, + verbosity=TEST_VERBOSITY, + M_2=1, + orbital_period=100000, + ensemble=1, + ensemble_defer=1, + ensemble_filters_off=1, + ensemble_filter_STELLAR_TYPE_COUNTS=1, + ensemble_dt=1000, + ) + test_pop_1.set( + data_dir=TMP_DIR, + combine_ensemble_with_thread_joining=True, + ensemble_output_name="ensemble_output.json", + ) + + resolution = {"M_1": 10} + + test_pop_1.add_grid_variable( + name="lnm1", + longname="Primary mass", + valuerange=[1, 100], + samplerfunc="const(math.log(1), math.log(100), {})".format( + resolution["M_1"] + ), + precode="M_1=math.exp(lnm1)", + probdist="three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1", + dphasevol="dlnm1", + parameter_name="M_1", + condition="", # Impose a condition on this grid variable. 
Mostly for a check for yourself + ) + + analytics_1 = test_pop_1.evolve() + ensemble_output_1 = test_pop_1.grid_ensemble_results + + # second + test_pop_2 = Population() + test_pop_2.set( + num_cores=num_cores_value, + verbosity=TEST_VERBOSITY, + M_2=1, + orbital_period=100000, + ensemble=1, + ensemble_defer=1, + ensemble_filters_off=1, + ensemble_filter_STELLAR_TYPE_COUNTS=1, + ensemble_dt=1000, + ) + test_pop_2.set( + data_dir=TMP_DIR, + ensemble_output_name="ensemble_output.json", + combine_ensemble_with_thread_joining=False, + ) + + resolution = {"M_1": 10} + + test_pop_2.add_grid_variable( + name="lnm1", + longname="Primary mass", + valuerange=[1, 100], + samplerfunc="const(math.log(1), math.log(100), {})".format( + resolution["M_1"] + ), + precode="M_1=math.exp(lnm1)", + probdist="three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1", + dphasevol="dlnm1", + parameter_name="M_1", + condition="", # Impose a condition on this grid variable. Mostly for a check for yourself + ) + + analytics_2 = test_pop_2.evolve() + output_names_2 = [ + os.path.join( + data_dir_value, + "ensemble_output_{}_{}.json".format( + analytics_2["population_name"], thread_id + ), + ) + for thread_id in range(num_cores_value) + ] + ensemble_output_2 = {} + + for output_name in output_names_2: + self.assertTrue(os.path.isfile(output_name)) + + with open(output_name, "r") as f: + file_content = f.read() + + ensemble_json = json.loads(file_content) + + ensemble_output_2 = merge_dicts(ensemble_output_2, ensemble_json) + + for key in ensemble_output_1["ensemble"]["number_counts"]["stellar_type"]["0"]: + self.assertIn(key, ensemble_output_2["number_counts"]["stellar_type"]["0"]) + + # compare values + self.assertLess( + np.abs( + ensemble_output_1["ensemble"]["number_counts"]["stellar_type"]["0"][ + key + ] + - ensemble_output_2["number_counts"]["stellar_type"]["0"][key] + ), + 1e-8, + ) + + +def parse_function_adding_results(self, output): + """ + Example parse function + """ + + 
def parse_function_adding_results(self, output):
    """
    Example parse function that accumulates results on the population object.

    Parses EXAMPLE_OUTPUT lines from the binary_c output, bins the mass
    probability into self.grid_results["example"]["mass"] and accumulates
    the total probability and system count.

    Raises:
        ValueError: if an EXAMPLE_OUTPUT line does not have the expected
            number of columns.
    """

    # Expected columns of an EXAMPLE_OUTPUT line (after the header word)
    parameters = ["time", "mass", "zams_mass", "probability", "stellar_type"]

    self.grid_results["example"]["count"] += 1

    # Go over the output.
    for line in output.splitlines():
        split_line = line.split()
        if not split_line:
            # Skip blank lines instead of crashing on [0]
            continue

        headerline = split_line[0]

        # Check the header and act accordingly
        if headerline == "EXAMPLE_OUTPUT":
            values = split_line[1:]

            # Validate the column count BEFORE indexing into values
            if len(parameters) != len(values):
                raise ValueError(
                    "Number of column names isnt equal to number of columns"
                )

            # Bin the mass probability (values[2] = zams_mass, values[3] = probability)
            self.grid_results["example"]["mass"][
                bin_data(float(values[2]), binwidth=0.5)
            ] += float(values[3])

            # record the probability of this line (Beware, this is meant to only be run once for each system. its a controls quantity)
            self.grid_results["example"]["probability"] += float(values[3])


class test_resultdict(unittest.TestCase):
    """
    Unittests for the result-dict accumulation done by the parse function.
    """

    def test_adding_results(self):
        """
        Function to test whether the results are properly added and combined.
        """

        # Create custom logging statement
        custom_logging_statement = """
        if (stardata->model.time < stardata->model.max_evolution_time)
        {
            Printf("EXAMPLE_OUTPUT %30.16e %g %g %30.12e %d\\n",
                //
                stardata->model.time, // 1
                stardata->star[0].mass, // 2
                stardata->common.zero_age.mass[0], // 3
                stardata->model.probability, // 4
                stardata->star[0].stellar_type // 5
          );
        };
        /* Kill the simulation to save time */
        stardata->model.max_evolution_time = stardata->model.time - stardata->model.dtm;
        """

        example_pop = Population()
        example_pop.set(verbosity=0)
        example_pop.set(
            max_evolution_time=15000,  # bse_options
            # grid_options
            num_cores=3,
            tmp_dir=TMP_DIR,
            # Custom options
            data_dir=os.path.join(TMP_DIR, "test_resultdict"),  # custom_options
            C_logging_code=custom_logging_statement,
            parse_function=parse_function_adding_results,
        )

        # Add grid variables
        resolution = {"M_1": 10}

        # Mass
        example_pop.add_grid_variable(
            name="lnm1",
            longname="Primary mass",
            valuerange=[2, 150],
            samplerfunc="const(math.log(2), math.log(150), {})".format(
                resolution["M_1"]
            ),
            precode="M_1=math.exp(lnm1)",
            probdist="three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 150, -1.3, -2.3, -2.3)*M_1",
            dphasevol="dlnm1",
            parameter_name="M_1",
            condition="",  # Impose a condition on this grid variable. Mostly for a check for yourself
        )

        ## Executing a population
        ## This uses the values generated by the grid_variables
        analytics = example_pop.evolve()

        # Total probability from the grid vs the result dict
        grid_prob = analytics["total_probability"]
        result_dict_prob = example_pop.grid_results["example"]["probability"]

        # amt systems
        grid_count = analytics["total_count"]
        result_dict_count = example_pop.grid_results["example"]["count"]

        # Check if the total probability matches
        self.assertAlmostEqual(
            grid_prob,
            result_dict_prob,
            places=12,
            msg="Total probability from grid {} and from result dict {} are not equal".format(
                grid_prob, result_dict_prob
            ),
        )

        # Check if the total count matches
        self.assertEqual(
            grid_count,
            result_dict_count,
            msg="Total count from grid {} and from result dict {} are not equal".format(
                grid_count, result_dict_count
            ),
        )

        # Check if the structure is what we expect. Note: this depends on the probability calculation. if that changes we need to recalibrate this
        test_case_dict = {
            2.25: 0.01895481306515,
            3.75: 0.01081338190204,
            5.75: 0.006168841009268,
            9.25: 0.003519213484031,
            13.75: 0.002007648361756,
            21.25: 0.001145327489437,
            33.25: 0.0006533888518775,
            50.75: 0.0003727466560393,
            78.25: 0.000212645301782,
            120.75: 0.0001213103421247,
        }

        self.assertEqual(
            test_case_dict, dict(example_pop.grid_results["example"]["mass"])
        )


if __name__ == "__main__":
    unittest.main()