Skip to content
Snippets Groups Projects
Commit 3fbea66f authored by David Hendriks's avatar David Hendriks
Browse files

black formatting

parent 08681f34
No related branches found
No related tags found
No related merge requests found
...@@ -30,18 +30,16 @@ class TestDistributions(unittest.TestCase): ...@@ -30,18 +30,16 @@ class TestDistributions(unittest.TestCase):
Unittest for function set_opts Unittest for function set_opts
""" """
default_dict = {'m1': 2, 'm2': 3} default_dict = {"m1": 2, "m2": 3}
output_dict_1 = set_opts(default_dict, {}) output_dict_1 = set_opts(default_dict, {})
self.assertTrue(output_dict_1==default_dict) self.assertTrue(output_dict_1 == default_dict)
new_opts = {"m1": 10}
new_opts = {'m1': 10}
output_dict_2 = set_opts(default_dict, new_opts) output_dict_2 = set_opts(default_dict, new_opts)
updated_dict = default_dict.copy() updated_dict = default_dict.copy()
updated_dict['m1'] = 10 updated_dict["m1"] = 10
self.assertTrue(output_dict_2==updated_dict)
self.assertTrue(output_dict_2 == updated_dict)
def test_flat(self): def test_flat(self):
""" """
...@@ -69,12 +67,14 @@ class TestDistributions(unittest.TestCase): ...@@ -69,12 +67,14 @@ class TestDistributions(unittest.TestCase):
""" """
output_1 = const(min_bound=0, max_bound=2) output_1 = const(min_bound=0, max_bound=2)
self.assertEqual(output_1, 0.5, msg="Value should be 0.5, but is {}".format(output_1)) self.assertEqual(
output_1, 0.5, msg="Value should be 0.5, but is {}".format(output_1)
)
output_2 = const(min_bound=0, max_bound=2, val=3) output_2 = const(min_bound=0, max_bound=2, val=3)
self.assertEqual(output_2, 0, msg="Value should be 0, but is {}".format(output_2)) self.assertEqual(
output_2, 0, msg="Value should be 0, but is {}".format(output_2)
)
def test_powerlaw(self): def test_powerlaw(self):
""" """
...@@ -127,8 +127,10 @@ class TestDistributions(unittest.TestCase): ...@@ -127,8 +127,10 @@ class TestDistributions(unittest.TestCase):
# Extra test: # Extra test:
# M < M0 # M < M0
self.assertTrue(three_part_powerlaw(0.05, 0.08, 0.1, 1, 300, -1.3, -2.3, -2.3)==0, msg="Probability should be zero as M < M0") self.assertTrue(
three_part_powerlaw(0.05, 0.08, 0.1, 1, 300, -1.3, -2.3, -2.3) == 0,
msg="Probability should be zero as M < M0",
)
def test_Kroupa2001(self): def test_Kroupa2001(self):
""" """
...@@ -152,8 +154,10 @@ class TestDistributions(unittest.TestCase): ...@@ -152,8 +154,10 @@ class TestDistributions(unittest.TestCase):
for i in range(len(python_results)): for i in range(len(python_results)):
self.assertLess(np.abs(python_results[i] - perl_results[i]), self.tolerance) self.assertLess(np.abs(python_results[i] - perl_results[i]), self.tolerance)
# Extra tests: # Extra tests:
self.assertEqual(Kroupa2001(10, newopts={'mmax': 300}), three_part_powerlaw(10, 0.1, 0.5, 1, 300, -1.3, -2.3, -2.3)) self.assertEqual(
Kroupa2001(10, newopts={"mmax": 300}),
three_part_powerlaw(10, 0.1, 0.5, 1, 300, -1.3, -2.3, -2.3),
)
def test_ktg93(self): def test_ktg93(self):
""" """
...@@ -177,8 +181,10 @@ class TestDistributions(unittest.TestCase): ...@@ -177,8 +181,10 @@ class TestDistributions(unittest.TestCase):
for i in range(len(python_results)): for i in range(len(python_results)):
self.assertLess(np.abs(python_results[i] - perl_results[i]), self.tolerance) self.assertLess(np.abs(python_results[i] - perl_results[i]), self.tolerance)
# extra test: # extra test:
self.assertEqual(ktg93(10, newopts={'mmax': 300}), three_part_powerlaw(10, 0.1, 0.5, 1, 300, -1.3, -2.2, -2.7)) self.assertEqual(
ktg93(10, newopts={"mmax": 300}),
three_part_powerlaw(10, 0.1, 0.5, 1, 300, -1.3, -2.2, -2.7),
)
def test_imf_tinsley1980(self): def test_imf_tinsley1980(self):
""" """
...@@ -186,7 +192,10 @@ class TestDistributions(unittest.TestCase): ...@@ -186,7 +192,10 @@ class TestDistributions(unittest.TestCase):
""" """
m = 1.2 m = 1.2
self.assertEqual(imf_tinsley1980(m), three_part_powerlaw(m, 0.1, 2.0, 10.0, 80.0, -2.0, -2.3, -3.3)) self.assertEqual(
imf_tinsley1980(m),
three_part_powerlaw(m, 0.1, 2.0, 10.0, 80.0, -2.0, -2.3, -3.3),
)
def test_imf_scalo1986(self): def test_imf_scalo1986(self):
""" """
...@@ -194,8 +203,10 @@ class TestDistributions(unittest.TestCase): ...@@ -194,8 +203,10 @@ class TestDistributions(unittest.TestCase):
""" """
m = 1.2 m = 1.2
self.assertEqual(imf_scalo1986(m), three_part_powerlaw(m, 0.1, 1.0, 2.0, 80.0, -2.35, -2.35, -2.70)) self.assertEqual(
imf_scalo1986(m),
three_part_powerlaw(m, 0.1, 1.0, 2.0, 80.0, -2.35, -2.35, -2.70),
)
def test_imf_scalo1998(self): def test_imf_scalo1998(self):
""" """
...@@ -203,8 +214,10 @@ class TestDistributions(unittest.TestCase): ...@@ -203,8 +214,10 @@ class TestDistributions(unittest.TestCase):
""" """
m = 1.2 m = 1.2
self.assertEqual(imf_scalo1998(m), three_part_powerlaw(m, 0.1, 1.0, 10.0, 80.0, -1.2, -2.7, -2.3)) self.assertEqual(
imf_scalo1998(m),
three_part_powerlaw(m, 0.1, 1.0, 10.0, 80.0, -1.2, -2.7, -2.3),
)
def test_imf_chabrier2003(self): def test_imf_chabrier2003(self):
""" """
...@@ -216,11 +229,19 @@ class TestDistributions(unittest.TestCase): ...@@ -216,11 +229,19 @@ class TestDistributions(unittest.TestCase):
# for m=0.5 # for m=0.5
m = 0.5 m = 0.5
self.assertLess(np.abs(imf_chabrier2003(m)-0.581457346702825), self.tolerance, msg="Difference is bigger than the tolerance") self.assertLess(
np.abs(imf_chabrier2003(m) - 0.581457346702825),
self.tolerance,
msg="Difference is bigger than the tolerance",
)
# For m = 2 # For m = 2
m = 2 m = 2
self.assertLess(np.abs(imf_chabrier2003(m)-0.581457346702825), self.tolerance, msg="Difference is bigger than the tolerance") self.assertLess(
np.abs(imf_chabrier2003(m) - 0.581457346702825),
self.tolerance,
msg="Difference is bigger than the tolerance",
)
def test_duquennoy1991(self): def test_duquennoy1991(self):
""" """
...@@ -229,7 +250,6 @@ class TestDistributions(unittest.TestCase): ...@@ -229,7 +250,6 @@ class TestDistributions(unittest.TestCase):
self.assertEqual(duquennoy1991(4.2), gaussian(4.2, 4.8, 2.3, -2, 12)) self.assertEqual(duquennoy1991(4.2), gaussian(4.2, 4.8, 2.3, -2, 12))
def test_gaussian(self): def test_gaussian(self):
""" """
unittest for three_part_power_law unittest for three_part_power_law
...@@ -253,7 +273,10 @@ class TestDistributions(unittest.TestCase): ...@@ -253,7 +273,10 @@ class TestDistributions(unittest.TestCase):
self.assertLess(np.abs(python_results[i] - perl_results[i]), self.tolerance) self.assertLess(np.abs(python_results[i] - perl_results[i]), self.tolerance)
# Extra test: # Extra test:
self.assertTrue(gaussian(15, 4.8, 2.3, -2.0, 12.0)==0, msg="Probability should be 0 because the input period is out of bounds") self.assertTrue(
gaussian(15, 4.8, 2.3, -2.0, 12.0) == 0,
msg="Probability should be 0 because the input period is out of bounds",
)
def test_Arenou2010_binary_fraction(self): def test_Arenou2010_binary_fraction(self):
""" """
......
This diff is collapsed.
"""
Test cases for the grid
Tasks:
TODO: write tests for load_from_sourcefile
"""
import os
import sys import sys
import json
import unittest import unittest
import tempfile import tempfile
import datetime
from binarycpython.utils.grid import Population from binarycpython.utils.grid import Population
from binarycpython.utils.functions import temp_dir
binary_c_temp_dir = temp_dir()
# class test_(unittest.TestCase):
# """
# Unittests for function
# """
# def test_1(self):
# pass
# def test_(self):
# """
# Unittests for the function
# """
class test_Population(unittest.TestCase): class test_Population(unittest.TestCase):
""" """
...@@ -12,44 +40,315 @@ class test_Population(unittest.TestCase): ...@@ -12,44 +40,315 @@ class test_Population(unittest.TestCase):
def test_setup(self): def test_setup(self):
test_pop = Population() test_pop = Population()
self.assertTrue('orbital_period' in test_pop.defaults) self.assertTrue("orbital_period" in test_pop.defaults)
self.assertTrue('metallicity' in test_pop.defaults) self.assertTrue("metallicity" in test_pop.defaults)
self.assertNotIn('help_all', test_pop.cleaned_up_defaults) self.assertNotIn("help_all", test_pop.cleaned_up_defaults)
self.assertEqual(test_pop.bse_options, {}) self.assertEqual(test_pop.bse_options, {})
self.assertEqual(test_pop.custom_options, {}) self.assertEqual(test_pop.custom_options, {})
self.assertEqual(test_pop.argline_dict, {}) self.assertEqual(test_pop.argline_dict, {})
self.assertEqual(test_pop.persistent_data_memory_dict, {}) self.assertEqual(test_pop.persistent_data_memory_dict, {})
self.assertTrue(test_pop.grid_options['parse_function']==None) self.assertTrue(test_pop.grid_options["parse_function"] == None)
self.assertTrue(isinstance(test_pop.grid_options['_main_pid'], int)) self.assertTrue(isinstance(test_pop.grid_options["_main_pid"], int))
def test_set(self): def test_set(self):
test_pop = Population() test_pop = Population()
test_pop.set(amt_cores=2) test_pop.set(amt_cores=2)
test_pop.set(M_1=10) test_pop.set(M_1=10)
test_pop.set(data_dir='/tmp/binary_c_python') test_pop.set(data_dir="/tmp/binary_c_python")
test_pop.set(ensemble_filter_SUPERNOVAE=1) test_pop.set(ensemble_filter_SUPERNOVAE=1)
self.assertIn('data_dir', test_pop.custom_options) self.assertIn("data_dir", test_pop.custom_options)
self.assertEqual(test_pop.custom_options['data_dir'], '/tmp/binary_c_python') self.assertEqual(test_pop.custom_options["data_dir"], "/tmp/binary_c_python")
#
self.assertTrue(test_pop.bse_options['M_1']==10)
self.assertTrue(test_pop.bse_options['ensemble_filter_SUPERNOVAE']==1)
# #
self.assertTrue(test_pop.grid_options['amt_cores']==2) self.assertTrue(test_pop.bse_options["M_1"] == 10)
self.assertTrue(test_pop.bse_options["ensemble_filter_SUPERNOVAE"] == 1)
#
self.assertTrue(test_pop.grid_options["amt_cores"] == 2)
def test_cmdline(self): def test_cmdline(self):
cmdline_arg = '--cmdline \"metallicity=0.0002\"' # copy old sys.argv values
sys.argv = ['script', '--cmdline', "metallicity=0.0002"] prev_sysargv = sys.argv.copy()
# make a dummy cmdline arg input
sys.argv = [
"script",
"--cmdline",
"metallicity=0.0002 amt_cores=2 data_dir=/tmp/binary_c_python",
]
# Set up population
test_pop = Population() test_pop = Population()
test_pop.set(data_dir="/tmp")
# parse arguments
test_pop.parse_cmdline() test_pop.parse_cmdline()
print(test_pop.bse_options)
self.assertTrue(test_pop.bse_options['metallicity']==0.0002) # metallicity
self.assertTrue(isinstance(test_pop.bse_options["metallicity"], str))
self.assertTrue(test_pop.bse_options["metallicity"] == "0.0002")
# Amt cores
self.assertTrue(isinstance(test_pop.grid_options["amt_cores"], int))
self.assertTrue(test_pop.grid_options["amt_cores"] == 2)
# datadir
self.assertTrue(isinstance(test_pop.custom_options["data_dir"], str))
self.assertTrue(test_pop.custom_options["data_dir"] == "/tmp/binary_c_python")
# put back the other args if they exist
sys.argv = prev_sysargv.copy()
def test__return_argline(self):
"""
Unittests for the function _return_argline
"""
# Set up population
test_pop = Population()
test_pop.set(metallicity=0.02)
test_pop.set(M_1=10)
argline = test_pop._return_argline()
self.assertTrue(argline == "binary_c M_1 10 metallicity 0.02")
# custom dict
argline2 = test_pop._return_argline(
{"example_parameter1": 10, "example_parameter2": "hello"}
)
self.assertTrue(
argline2 == "binary_c example_parameter1 10 example_parameter2 hello"
)
def test_add_grid_variable(self):
"""
Unittests for the function add_grid_variable
TODO: Should I test more here?
"""
test_pop = Population()
resolution = {"M_1": 10, "q": 10}
test_pop.add_grid_variable(
name="lnm1",
longname="Primary mass",
valuerange=[1, 100],
resolution="{}".format(resolution["M_1"]),
spacingfunc="const(math.log(1), math.log(100), {})".format(
resolution["M_1"]
),
precode="M_1=math.exp(lnm1)",
probdist="three_part_powerlaw(M_1, 0.1, 0.5, 1.0, 100, -1.3, -2.3, -2.3)*M_1",
dphasevol="dlnm1",
parameter_name="M_1",
condition="", # Impose a condition on this grid variable. Mostly for a check for yourself
)
test_pop.add_grid_variable(
name="q",
longname="Mass ratio",
valuerange=["0.1/M_1", 1],
resolution="{}".format(resolution["q"]),
spacingfunc="const(0.1/M_1, 1, {})".format(resolution["q"]),
probdist="flatsections(q, [{'min': 0.1/M_1, 'max': 1.0, 'height': 1}])",
dphasevol="dq",
precode="M_2 = q * M_1",
parameter_name="M_2",
condition="", # Impose a condition on this grid variable. Mostly for a check for yourself
)
self.assertIn("q", test_pop.grid_options["_grid_variables"])
self.assertIn("lnm1", test_pop.grid_options["_grid_variables"])
self.assertEqual(len(test_pop.grid_options["_grid_variables"]), 2)
def test_return_population_settings(self):
"""
Unittests for the function return_population_settings
"""
test_pop = Population()
test_pop.set(metallicity=0.02)
test_pop.set(M_1=10)
test_pop.set(amt_cores=2)
test_pop.set(data_dir="/tmp")
population_settings = test_pop.return_population_settings()
self.assertIn("bse_options", population_settings)
self.assertTrue(population_settings["bse_options"]["metallicity"] == 0.02)
self.assertTrue(population_settings["bse_options"]["M_1"] == 10)
self.assertIn("grid_options", population_settings)
self.assertTrue(population_settings["grid_options"]["amt_cores"] == 2)
self.assertIn("custom_options", population_settings)
self.assertTrue(population_settings["custom_options"]["data_dir"] == "/tmp")
def test__return_binary_c_version_info(self):
"""
Unittests for the function _return_binary_c_version_info
"""
test_pop = Population()
binary_c_version_info = test_pop._return_binary_c_version_info(parsed=True)
self.assertTrue(isinstance(binary_c_version_info, dict))
self.assertIn("isotopes", binary_c_version_info)
self.assertIn("argpairs", binary_c_version_info)
self.assertIn("ensembles", binary_c_version_info)
self.assertIn("macros", binary_c_version_info)
self.assertIn("dt_limits", binary_c_version_info)
self.assertIn("nucleosynthesis_sources", binary_c_version_info)
self.assertIn("miscellaneous", binary_c_version_info)
self.assertIsNotNone(binary_c_version_info["isotopes"])
self.assertIsNotNone(binary_c_version_info["argpairs"])
self.assertIsNotNone(binary_c_version_info["ensembles"])
self.assertIsNotNone(binary_c_version_info["macros"])
self.assertIsNotNone(binary_c_version_info["dt_limits"])
self.assertIsNotNone(binary_c_version_info["nucleosynthesis_sources"])
self.assertIsNotNone(binary_c_version_info["miscellaneous"])
def test__return_binary_c_defaults(self):
"""
Unittests for the function _return_binary_c_defaults
"""
test_pop = Population()
binary_c_defaults = test_pop._return_binary_c_defaults()
self.assertIn("probability", binary_c_defaults)
self.assertIn("phasevol", binary_c_defaults)
self.assertIn("metallicity", binary_c_defaults)
def test_return_all_info(self):
"""
Unittests for the function return_all_info
Not going to do too much tests here, just check if they are not empty
"""
test_pop = Population()
all_info = test_pop.return_all_info()
self.assertIn("population_settings", all_info)
self.assertIn("binary_c_defaults", all_info)
self.assertIn("binary_c_version_info", all_info)
self.assertIn("binary_c_help_all", all_info)
self.assertNotEqual(all_info["population_settings"], {})
self.assertNotEqual(all_info["binary_c_defaults"], {})
self.assertNotEqual(all_info["binary_c_version_info"], {})
self.assertNotEqual(all_info["binary_c_help_all"], {})
def test_export_all_info(self):
"""
Unittests for the function export_all_info
"""
test_pop = Population()
test_pop.set(metallicity=0.02)
test_pop.set(M_1=10)
test_pop.set(amt_cores=2)
test_pop.set(data_dir=binary_c_temp_dir)
# datadir
settings_filename = test_pop.export_all_info(use_datadir=True)
self.assertTrue(os.path.isfile(settings_filename))
with open(settings_filename, "r") as f:
all_info = json.loads(f.read())
#
self.assertIn("population_settings", all_info)
self.assertIn("binary_c_defaults", all_info)
self.assertIn("binary_c_version_info", all_info)
self.assertIn("binary_c_help_all", all_info)
#
self.assertNotEqual(all_info["population_settings"], {})
self.assertNotEqual(all_info["binary_c_defaults"], {})
self.assertNotEqual(all_info["binary_c_version_info"], {})
self.assertNotEqual(all_info["binary_c_help_all"], {})
# custom name
# datadir
settings_filename = test_pop.export_all_info(
use_datadir=False,
outfile=os.path.join(binary_c_temp_dir, "example_settings.json"),
)
self.assertTrue(os.path.isfile(settings_filename))
with open(settings_filename, "r") as f:
all_info = json.loads(f.read())
#
self.assertIn("population_settings", all_info)
self.assertIn("binary_c_defaults", all_info)
self.assertIn("binary_c_version_info", all_info)
self.assertIn("binary_c_help_all", all_info)
#
self.assertNotEqual(all_info["population_settings"], {})
self.assertNotEqual(all_info["binary_c_defaults"], {})
self.assertNotEqual(all_info["binary_c_version_info"], {})
self.assertNotEqual(all_info["binary_c_help_all"], {})
# wrong filename
self.assertRaises(
ValueError,
test_pop.export_all_info,
use_datadir=False,
outfile=os.path.join(binary_c_temp_dir, "example_settings.txt"),
)
def test__cleanup_defaults(self):
"""
Unittests for the function _cleanup_defaults
"""
test_pop = Population()
cleaned_up_defaults = test_pop._cleanup_defaults()
self.assertNotIn("help_all", cleaned_up_defaults)
def test__increment_probtot(self):
"""
Unittests for the function _increment_probtot
"""
test_pop = Population()
test_pop._increment_probtot(0.5)
self.assertEqual(test_pop.grid_options["_probtot"], 0.5)
def test__increment_count(self):
"""
Unittests for the function _increment_probtot
"""
test_pop = Population()
test_pop._increment_count()
self.assertEqual(test_pop.grid_options["_count"], 1)
def test__dict_from_line_source_file(self):
"""
Unittests for the function _dict_from_line_source_file
"""
source_file = os.path.join(binary_c_temp_dir, "example_source_file.txt")
# write
with open(source_file, "w") as f:
f.write("binary_c M_1 10 metallicity 0.02\n")
test_pop = Population()
# readout
with open(source_file, "r") as f:
for line in f.readlines():
argdict = test_pop._dict_from_line_source_file(line)
self.assertTrue(argdict["M_1"] == 10)
self.assertTrue(argdict["metallicity"] == 0.02)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -47,27 +47,27 @@ class test_grid_options_defaults(unittest.TestCase): ...@@ -47,27 +47,27 @@ class test_grid_options_defaults(unittest.TestCase):
output_1 = grid_options_description_checker(print_info=True) output_1 = grid_options_description_checker(print_info=True)
self.assertTrue(isinstance(output_1, int)) self.assertTrue(isinstance(output_1, int))
self.assertTrue(output_1>0) self.assertTrue(output_1 > 0)
def test_write_grid_options_to_rst_file(self): def test_write_grid_options_to_rst_file(self):
""" """
Unit tests for the grid_options_description_checker function Unit tests for the grid_options_description_checker function
""" """
input_1 = os.path.join(binary_c_temp_dir, "test_write_grid_options_to_rst_file_1.txt") input_1 = os.path.join(
binary_c_temp_dir, "test_write_grid_options_to_rst_file_1.txt"
)
output_1 = write_grid_options_to_rst_file(input_1) output_1 = write_grid_options_to_rst_file(input_1)
self.assertIsNone(output_1) self.assertIsNone(output_1)
input_2 = os.path.join(
input_2 = os.path.join(binary_c_temp_dir, "test_write_grid_options_to_rst_file_2.rst") binary_c_temp_dir, "test_write_grid_options_to_rst_file_2.rst"
)
output_2 = write_grid_options_to_rst_file(input_2) output_2 = write_grid_options_to_rst_file(input_2)
self.assertTrue(os.path.isfile(input_2)) self.assertTrue(os.path.isfile(input_2))
write_grid_options_to_rst_file write_grid_options_to_rst_file
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -11,16 +11,17 @@ import matplotlib.pyplot as plt ...@@ -11,16 +11,17 @@ import matplotlib.pyplot as plt
# def test_1(self): # def test_1(self):
# pass # pass
class test_color_by_index(unittest.TestCase): class test_color_by_index(unittest.TestCase):
""" """
Unittests for function color_by_index Unittests for function color_by_index
""" """
def test_1(self): def test_1(self):
colors = ['red', 'white', 'blue'] colors = ["red", "white", "blue"]
color = color_by_index([1,2,3], 1, colors) color = color_by_index([1, 2, 3], 1, colors)
self.assertTrue(color=='blue') self.assertTrue(color == "blue")
class test_plot_system(unittest.TestCase): class test_plot_system(unittest.TestCase):
...@@ -29,12 +30,20 @@ class test_plot_system(unittest.TestCase): ...@@ -29,12 +30,20 @@ class test_plot_system(unittest.TestCase):
""" """
def test_mass_evolution_plot(self): def test_mass_evolution_plot(self):
plot_type = 'mass_evolution' plot_type = "mass_evolution"
show_plot = False show_plot = False
output_fig_1 = plot_system(plot_type, show_plot=show_plot, M_1=1, metallicity=0.002, M_2=0.1, separation=0, orbital_period=100000000000) output_fig_1 = plot_system(
plot_type,
show_plot=show_plot,
M_1=1,
metallicity=0.002,
M_2=0.1,
separation=0,
orbital_period=100000000000,
)
fig, ax = plt.subplots(nrows=1) fig, ax = plt.subplots(nrows=1)
self.assertTrue(type(output_fig_1)==fig.__class__) self.assertTrue(type(output_fig_1) == fig.__class__)
# with stellar types # with stellar types
# plot_type = 'mass_evolution' # plot_type = 'mass_evolution'
...@@ -50,12 +59,20 @@ class test_plot_system(unittest.TestCase): ...@@ -50,12 +59,20 @@ class test_plot_system(unittest.TestCase):
# output_fig_2 = plot_system(plot_type, show_plot=show_plot, M_1=1, metallicity=0.002, M_2=0.1, separation=0, orbital_period=100000000000) # output_fig_2 = plot_system(plot_type, show_plot=show_plot, M_1=1, metallicity=0.002, M_2=0.1, separation=0, orbital_period=100000000000)
def test_orbit_evolution_plot(self): def test_orbit_evolution_plot(self):
plot_type = 'orbit_evolution' plot_type = "orbit_evolution"
show_plot = False show_plot = False
output_fig_1 = plot_system(plot_type, show_plot=show_plot, M_1=1, metallicity=0.002, M_2=0.1, separation=0, orbital_period=100000000000) output_fig_1 = plot_system(
plot_type,
show_plot=show_plot,
M_1=1,
metallicity=0.002,
M_2=0.1,
separation=0,
orbital_period=100000000000,
)
fig, ax = plt.subplots(nrows=1) fig, ax = plt.subplots(nrows=1)
self.assertTrue(type(output_fig_1)==fig.__class__) self.assertTrue(type(output_fig_1) == fig.__class__)
# with stellar types # with stellar types
# plot_type = 'orbit_evolution' # plot_type = 'orbit_evolution'
...@@ -71,12 +88,20 @@ class test_plot_system(unittest.TestCase): ...@@ -71,12 +88,20 @@ class test_plot_system(unittest.TestCase):
# output_fig_2 = plot_system(plot_type, show_plot=show_plot, M_1=1, metallicity=0.002, M_2=0.1, separation=0, orbital_period=100000000000) # output_fig_2 = plot_system(plot_type, show_plot=show_plot, M_1=1, metallicity=0.002, M_2=0.1, separation=0, orbital_period=100000000000)
def test_hr_diagram_plot(self): def test_hr_diagram_plot(self):
plot_type = 'hr_diagram' plot_type = "hr_diagram"
show_plot = False show_plot = False
output_fig_1 = plot_system(plot_type, show_plot=show_plot, M_1=1, metallicity=0.002, M_2=0.1, separation=0, orbital_period=100000000000) output_fig_1 = plot_system(
plot_type,
show_plot=show_plot,
M_1=1,
metallicity=0.002,
M_2=0.1,
separation=0,
orbital_period=100000000000,
)
fig, ax = plt.subplots(nrows=1) fig, ax = plt.subplots(nrows=1)
self.assertTrue(type(output_fig_1)==fig.__class__) self.assertTrue(type(output_fig_1) == fig.__class__)
# with stellar types # with stellar types
# plot_type = 'hr_diagram' # plot_type = 'hr_diagram'
...@@ -92,8 +117,9 @@ class test_plot_system(unittest.TestCase): ...@@ -92,8 +117,9 @@ class test_plot_system(unittest.TestCase):
# output_fig_2 = plot_system(plot_type, show_plot=show_plot, M_1=1, metallicity=0.002, M_2=0.1, separation=0, orbital_period=100000000000) # output_fig_2 = plot_system(plot_type, show_plot=show_plot, M_1=1, metallicity=0.002, M_2=0.1, separation=0, orbital_period=100000000000)
def test_unknown_plottype(self): def test_unknown_plottype(self):
plot_type = 'random' plot_type = "random"
self.assertRaises(ValueError, plot_system, plot_type) self.assertRaises(ValueError, plot_system, plot_type)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
import unittest import unittest
from binarycpython.utils.stellar_types import * from binarycpython.utils.stellar_types import *
\ No newline at end of file
...@@ -10,6 +10,7 @@ from binarycpython.utils.useful_funcs import * ...@@ -10,6 +10,7 @@ from binarycpython.utils.useful_funcs import *
# def test_1(self): # def test_1(self):
# pass # pass
class test_calc_period_from_sep(unittest.TestCase): class test_calc_period_from_sep(unittest.TestCase):
""" """
Unittests for function calc_period_from_sep Unittests for function calc_period_from_sep
...@@ -21,6 +22,7 @@ class test_calc_period_from_sep(unittest.TestCase): ...@@ -21,6 +22,7 @@ class test_calc_period_from_sep(unittest.TestCase):
output_1 = calc_period_from_sep(1, 1, 1) output_1 = calc_period_from_sep(1, 1, 1)
self.assertEqual(output_1, 0.08188845248066838) self.assertEqual(output_1, 0.08188845248066838)
class test_calc_sep_from_period(unittest.TestCase): class test_calc_sep_from_period(unittest.TestCase):
""" """
Unittests for function calc_sep_from_period Unittests for function calc_sep_from_period
...@@ -33,6 +35,7 @@ class test_calc_sep_from_period(unittest.TestCase): ...@@ -33,6 +35,7 @@ class test_calc_sep_from_period(unittest.TestCase):
output_1 = calc_sep_from_period(1, 1, 1) output_1 = calc_sep_from_period(1, 1, 1)
self.assertEqual(output_1, 5.302958446503317) self.assertEqual(output_1, 5.302958446503317)
class test_roche_lobe(unittest.TestCase): class test_roche_lobe(unittest.TestCase):
""" """
Unittests for function roche_lobe Unittests for function roche_lobe
...@@ -42,10 +45,11 @@ class test_roche_lobe(unittest.TestCase): ...@@ -42,10 +45,11 @@ class test_roche_lobe(unittest.TestCase):
mass_donor = 2 mass_donor = 2
mass_accretor = 1 mass_accretor = 1
output_1 = roche_lobe(mass_accretor/mass_donor) output_1 = roche_lobe(mass_accretor / mass_donor)
print(output_1) print(output_1)
self.assertLess(np.abs(output_1-0.3207881203346875), 1e-10) self.assertLess(np.abs(output_1 - 0.3207881203346875), 1e-10)
class test_ragb(unittest.TestCase): class test_ragb(unittest.TestCase):
""" """
...@@ -58,6 +62,7 @@ class test_ragb(unittest.TestCase): ...@@ -58,6 +62,7 @@ class test_ragb(unittest.TestCase):
self.assertEqual(output, 820) self.assertEqual(output, 820)
class test_rzams(unittest.TestCase): class test_rzams(unittest.TestCase):
""" """
Unittests for function rzams Unittests for function rzams
...@@ -68,19 +73,20 @@ class test_rzams(unittest.TestCase): ...@@ -68,19 +73,20 @@ class test_rzams(unittest.TestCase):
metallicity = 0.02 metallicity = 0.02
output_1 = rzams(mass, metallicity) output_1 = rzams(mass, metallicity)
self.assertLess(np.abs(output_1-0.458757762074762), 1e-7) self.assertLess(np.abs(output_1 - 0.458757762074762), 1e-7)
mass = 12.5 mass = 12.5
metallicity = 0.01241 metallicity = 0.01241
output_2 = rzams(mass, metallicity) output_2 = rzams(mass, metallicity)
self.assertLess(np.abs(output_2-4.20884329861741), 1e-7) self.assertLess(np.abs(output_2 - 4.20884329861741), 1e-7)
mass = 149 mass = 149
metallicity = 0.001241 metallicity = 0.001241
output_3 = rzams(mass, metallicity) output_3 = rzams(mass, metallicity)
self.assertLess(np.abs(output_3-12.8209978916491), 1e-7) self.assertLess(np.abs(output_3 - 12.8209978916491), 1e-7)
class test_zams_collission(unittest.TestCase): class test_zams_collission(unittest.TestCase):
""" """
...@@ -94,12 +100,17 @@ class test_zams_collission(unittest.TestCase): ...@@ -94,12 +100,17 @@ class test_zams_collission(unittest.TestCase):
eccentricity = 0 eccentricity = 0
metallicity = 0.02 metallicity = 0.02
output_collision_1 = zams_collision(mass1, mass2, sep, eccentricity, metallicity) output_collision_1 = zams_collision(
self.assertTrue(output_collision_1==0) mass1, mass2, sep, eccentricity, metallicity
)
self.assertTrue(output_collision_1 == 0)
sep = 1 sep = 1
output_collision_2 = zams_collision(mass1, mass2, sep, eccentricity, metallicity) output_collision_2 = zams_collision(
self.assertTrue(output_collision_2==1) mass1, mass2, sep, eccentricity, metallicity
)
self.assertTrue(output_collision_2 == 1)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -73,6 +73,7 @@ def remove_file(file: str, verbosity: int = 0) -> None: ...@@ -73,6 +73,7 @@ def remove_file(file: str, verbosity: int = 0) -> None:
else: else:
verbose_print("File/directory {} doesn't exist. Can't remove it.", verbosity, 1) verbose_print("File/directory {} doesn't exist. Can't remove it.", verbosity, 1)
def temp_dir() -> str: def temp_dir() -> str:
""" """
Function to return the path the custom logging library shared object Function to return the path the custom logging library shared object
...@@ -190,6 +191,8 @@ def parse_binary_c_version_info(version_info_string: str) -> dict: ...@@ -190,6 +191,8 @@ def parse_binary_c_version_info(version_info_string: str) -> dict:
""" """
Function that parses the binary_c version info. Long function with a lot of branches Function that parses the binary_c version info. Long function with a lot of branches
TODO: fix this function. stuff is missing: isotopes, macros, nucleosynthesis_sources
Args: Args:
version_info_string: raw output of version_info call to binary_c version_info_string: raw output of version_info call to binary_c
......
...@@ -17,6 +17,7 @@ Tasks: ...@@ -17,6 +17,7 @@ Tasks:
- TODO: add functionality to return the ensemble_list - TODO: add functionality to return the ensemble_list
- TODO: consider spreading the functions over more files. - TODO: consider spreading the functions over more files.
- TODO: type the private functions - TODO: type the private functions
- TODO: fix the correct object types for the default values of the bse_options
""" """
import os import os
...@@ -189,9 +190,16 @@ class Population: ...@@ -189,9 +190,16 @@ class Population:
Function to handle settings values via the command line. Function to handle settings values via the command line.
Best to be called after all the .set(..) lines, and just before the .evolve() is called Best to be called after all the .set(..) lines, and just before the .evolve() is called
If you input any known parameter (i.e. contained in grid_options, defaults/bse_options or custom_options),
this function will attempt to convert the input from string (because everything is string) to the type of
the value that option had before.
The values of the bse_options are initially all strings, but after user input they can change to ints.
The value of any new parameter (which will go to custom_options) will be a string.
Tasks: Tasks:
- TODO: remove the need for --cmdline - TODO: remove the need for --cmdline
- TODO: fix that the input is converted to the correct type (i.e. type of the default value)
""" """
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
...@@ -243,12 +251,24 @@ class Population: ...@@ -243,12 +251,24 @@ class Population:
# (attempt to) convert # (attempt to) convert
if old_value_found: if old_value_found:
try: try:
verbose_print("Converting type of {} from {} to {}".format(parameter, type(value), type(old_value)), self.grid_options['verbosity'], 1) verbose_print(
"Converting type of {} from {} to {}".format(
parameter, type(value), type(old_value)
),
self.grid_options["verbosity"],
1,
)
value = type(old_value)(value) value = type(old_value)(value)
verbose_print("Success!", self.grid_options['verbosity'], 1) verbose_print("Success!", self.grid_options["verbosity"], 1)
except ValueError as e: except ValueError as e:
verbose_print("Tried to convert the given parameter {}/value {} to its correct type {} (from old value {}). But that wasn't possible.".format(parameter, value, type(old_value), old_value), self.grid_options['verbosity'], 0) verbose_print(
"Tried to convert the given parameter {}/value {} to its correct type {} (from old value {}). But that wasn't possible.".format(
parameter, value, type(old_value), old_value
),
self.grid_options["verbosity"],
0,
)
# Add to dict # Add to dict
cmdline_dict[parameter] = value cmdline_dict[parameter] = value
...@@ -440,7 +460,7 @@ class Population: ...@@ -440,7 +460,7 @@ class Population:
include_binary_c_defaults: bool = True, include_binary_c_defaults: bool = True,
include_binary_c_version_info: bool = True, include_binary_c_version_info: bool = True,
include_binary_c_help_all: bool = True, include_binary_c_help_all: bool = True,
) -> None: ) -> Union[str, None]:
""" """
Function that exports the all_info to a json file Function that exports the all_info to a json file
...@@ -502,6 +522,7 @@ class Population: ...@@ -502,6 +522,7 @@ class Population:
default=binaryc_json_serializer, default=binaryc_json_serializer,
) )
) )
return settings_fullname
else: else:
verbose_print( verbose_print(
"Writing settings to {}".format(outfile), "Writing settings to {}".format(outfile),
...@@ -515,12 +536,14 @@ class Population: ...@@ -515,12 +536,14 @@ class Population:
0, 0,
) )
raise ValueError raise ValueError
with open(outfile, "w") as file: with open(outfile, "w") as file:
file.write( file.write(
json.dumps( json.dumps(
all_info_cleaned, indent=4, default=binaryc_json_serializer all_info_cleaned, indent=4, default=binaryc_json_serializer
) )
) )
return outfile
def _set_custom_logging(self): def _set_custom_logging(self):
""" """
...@@ -713,10 +736,12 @@ class Population: ...@@ -713,10 +736,12 @@ class Population:
# Log and print some information # Log and print some information
verbose_print( verbose_print(
"Population-{} finished! It took a total of {}s to run {} systems on {} cores".format( "Population-{} finished! It took a total of {}s to run {} systems on {} cores".format(
self.grid_options["_population_id"], self.grid_options["_population_id"],
self.grid_options['_end_time_evolution']-self.grid_options['_start_time_evolution'], self.grid_options["_end_time_evolution"]
self.grid_options['_total_starcount'], - self.grid_options["_start_time_evolution"],
self.grid_options['amt_cores']), self.grid_options["_total_starcount"],
self.grid_options["amt_cores"],
),
self.grid_options["verbosity"], self.grid_options["verbosity"],
0, 0,
) )
...@@ -744,7 +769,11 @@ class Population: ...@@ -744,7 +769,11 @@ class Population:
0, 0,
) )
else: else:
verbose_print("There were no errors found in this run.", self.grid_options["verbosity"], 0) verbose_print(
"There were no errors found in this run.",
self.grid_options["verbosity"],
0,
)
## ##
# Clean up code: remove files, unset values. # Clean up code: remove files, unset values.
...@@ -752,7 +781,8 @@ class Population: ...@@ -752,7 +781,8 @@ class Population:
def _process_run_population(self, ID): def _process_run_population(self, ID):
""" """
Function that loops over the whole generator, but only runs systems that fit to: if (localcounter+ID) % self.grid_options["amt_cores"] == 0 Function that loops over the whole generator, but only runs
systems that fit to: if (localcounter+ID) % self.grid_options["amt_cores"] == 0
That way with 4 processes, process 1 runs sytem 0, 4, 8... process 2 runs system 1, 5, 9..., etc That way with 4 processes, process 1 runs sytem 0, 4, 8... process 2 runs system 1, 5, 9..., etc
...@@ -771,9 +801,15 @@ class Population: ...@@ -771,9 +801,15 @@ class Population:
# Set up local variables # Set up local variables
running = True running = True
localcounter = 0 # global counter for the whole loop. (need to be ticked every loop) localcounter = (
probability_of_systems_run = 0 # counter for the probability of the actual systems this tread ran 0 # global counter for the whole loop. (need to be ticked every loop)
number_of_systems_run = 0 # counter for the actual amt of systems this thread ran )
probability_of_systems_run = (
0 # counter for the probability of the actual systems this tread ran
)
number_of_systems_run = (
0 # counter for the actual amt of systems this thread ran
)
verbose_print( verbose_print(
"Process {} started".format(ID), self.grid_options["verbosity"], 0 "Process {} started".format(ID), self.grid_options["verbosity"], 0
...@@ -803,9 +839,9 @@ class Population: ...@@ -803,9 +839,9 @@ class Population:
self._evolve_system_mp(full_system_dict) self._evolve_system_mp(full_system_dict)
# TODO: fix the 'repeat' and 'weight' tracking here # TODO: fix the 'repeat' and 'weight' tracking here
# Keep track of systems: # Keep track of systems:
probability_of_systems_run += full_system_dict['probability'] probability_of_systems_run += full_system_dict["probability"]
number_of_systems_run += 1 number_of_systems_run += 1
except StopIteration: except StopIteration:
running = False running = False
...@@ -813,7 +849,6 @@ class Population: ...@@ -813,7 +849,6 @@ class Population:
# Has to be here because this one is used for the (localcounter+ID) % (self..) # Has to be here because this one is used for the (localcounter+ID) % (self..)
localcounter += 1 localcounter += 1
# Return a set of results and errors # Return a set of results and errors
output_dict = { output_dict = {
"results": self.grid_options["results"], "results": self.grid_options["results"],
...@@ -824,10 +859,18 @@ class Population: ...@@ -824,10 +859,18 @@ class Population:
], ],
"_errors_exceeded": self.grid_options["_errors_exceeded"], "_errors_exceeded": self.grid_options["_errors_exceeded"],
"_errors_found": self.grid_options["_errors_found"], "_errors_found": self.grid_options["_errors_found"],
"_probtot": self.grid_options["_probtot"],
"_count": self.grid_options["_count"],
} }
verbose_print( verbose_print(
"Process {}: generator done. Ran {} systems with a total probability of {}. This thread had {} failing systems with a total probability of {}".format(ID, number_of_systems_run, probability_of_systems_run, self.grid_options["_failed_count"], self.grid_options["_failed_prob"]), "Process {}: generator done. Ran {} systems with a total probability of {}. This thread had {} failing systems with a total probability of {}".format(
ID,
number_of_systems_run,
probability_of_systems_run,
self.grid_options["_failed_count"],
self.grid_options["_failed_prob"],
),
self.grid_options["verbosity"], self.grid_options["verbosity"],
0, 0,
) )
...@@ -885,9 +928,9 @@ class Population: ...@@ -885,9 +928,9 @@ class Population:
for output_dict in result: for output_dict in result:
combined_output_dict = merge_dicts(combined_output_dict, output_dict) combined_output_dict = merge_dicts(combined_output_dict, output_dict)
# Put the values back as object properties
print(combined_output_dict) print(combined_output_dict)
# Put the values back as object properties
self.grid_options["results"] = combined_output_dict["results"] self.grid_options["results"] = combined_output_dict["results"]
self.grid_options["_failed_count"] = combined_output_dict["_failed_count"] self.grid_options["_failed_count"] = combined_output_dict["_failed_count"]
self.grid_options["_failed_prob"] = combined_output_dict["_failed_prob"] self.grid_options["_failed_prob"] = combined_output_dict["_failed_prob"]
...@@ -896,6 +939,8 @@ class Population: ...@@ -896,6 +939,8 @@ class Population:
) )
self.grid_options["_errors_exceeded"] = combined_output_dict["_errors_exceeded"] self.grid_options["_errors_exceeded"] = combined_output_dict["_errors_exceeded"]
self.grid_options["_errors_found"] = combined_output_dict["_errors_found"] self.grid_options["_errors_found"] = combined_output_dict["_errors_found"]
self.grid_options["_probtot"] = combined_output_dict["_probtot"]
self.grid_options["_count"] = combined_output_dict["_count"]
def _evolve_population_lin(self): def _evolve_population_lin(self):
""" """
...@@ -1001,14 +1046,12 @@ class Population: ...@@ -1001,14 +1046,12 @@ class Population:
Tasks: Tasks:
TODO: Make other kinds of populations possible. i.e, read out type of grid, TODO: Make other kinds of populations possible. i.e, read out type of grid,
and set up accordingly and set up accordingly
TODO: make this function more general. Have it explicitly set the system_generator TODO: make this function more general. Have it explicitly set the system_generator
function function
""" """
if not self.grid_options["parse_function"]: if not self.grid_options["parse_function"]:
print("Error: No parse function set. Aborting run") print("Warning: No parse function set. Make sure you intended to do this.")
raise ValueError
####################### #######################
### Custom logging code: ### Custom logging code:
...@@ -1132,6 +1175,7 @@ class Population: ...@@ -1132,6 +1175,7 @@ class Population:
self.grid_options["_failed_prob"] = 0 self.grid_options["_failed_prob"] = 0
self.grid_options["_errors_found"] = False self.grid_options["_errors_found"] = False
self.grid_options["_errors_exceeded"] = False self.grid_options["_errors_exceeded"] = False
self.grid_options["_failed_systems_error_codes"] = []
# Remove files # Remove files
# TODO: remove files # TODO: remove files
...@@ -1165,7 +1209,7 @@ class Population: ...@@ -1165,7 +1209,7 @@ class Population:
# TODO: add sensible description to this function. # TODO: add sensible description to this function.
# TODO: Check whether all the probability and phasevol values are correct. # TODO: Check whether all the probability and phasevol values are correct.
# TODO: import only the necessary packages/functions # TODO: import only the necessary packages/functions
Results in a generated file that contains a system_generator function. Results in a generated file that contains a system_generator function.
""" """
...@@ -1531,10 +1575,7 @@ class Population: ...@@ -1531,10 +1575,7 @@ class Population:
# #
# code_string += indent * (depth + 1) + "\n" # code_string += indent * (depth + 1) + "\n"
code_string += indent * (depth + 1) + "#" * 40 + "\n" code_string += indent * (depth + 1) + "#" * 40 + "\n"
code_string += ( code_string += indent * (depth + 1) + "if print_results:\n"
indent * (depth + 1)
+ "if print_results:\n"
)
code_string += ( code_string += (
indent * (depth + 2) indent * (depth + 2)
+ "print('Grid has handled {} stars'.format(_total_starcount))\n" + "print('Grid has handled {} stars'.format(_total_starcount))\n"
...@@ -1719,7 +1760,7 @@ class Population: ...@@ -1719,7 +1760,7 @@ class Population:
verbose_print("Source file loaded", self.grid_options["verbosity"], 1) verbose_print("Source file loaded", self.grid_options["verbosity"], 1)
def _dict_from_line_source_file(self): def _dict_from_line_source_file(self, line):
""" """
Function that creates a dict from a binary_c argline Function that creates a dict from a binary_c argline
""" """
......
...@@ -66,7 +66,7 @@ def roche_lobe(q: Union[int, float]) -> Union[int, float]: ...@@ -66,7 +66,7 @@ def roche_lobe(q: Union[int, float]) -> Union[int, float]:
# TODO: check whether the logs are correct # TODO: check whether the logs are correct
Args: Args:
q: mass ratio of the binary (secondary/primary). If you input: q = mass_accretor/mass_donor, you will get the rochelobe radius of the accretor. And vice versa for the donor. q: mass ratio of the binary (secondary/primary). If you input: q = mass_accretor/mass_donor, you will get the rochelobe radius of the accretor. And vice versa for the donor.
Returns: Returns:
Roche lobe radius in units of the separation Roche lobe radius in units of the separation
......
...@@ -50,13 +50,15 @@ extensions = [ ...@@ -50,13 +50,15 @@ extensions = [
"hawkmoth", "hawkmoth",
"m2r2", "m2r2",
"sphinx_rtd_theme", "sphinx_rtd_theme",
"sphinx_autodoc_typehints", # https://mypy.readthedocs.io/en/stable/cheat_sheet_py3.html "sphinx_autodoc_typehints", # https://mypy.readthedocs.io/en/stable/cheat_sheet_py3.html
] ]
# Napoleon settings # Napoleon settings
napoleon_google_docstring = True # https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html napoleon_google_docstring = (
napoleon_numpy_docstring = False True # https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html
)
napoleon_numpy_docstring = False
napoleon_include_init_with_doc = False napoleon_include_init_with_doc = False
napoleon_include_private_with_doc = False napoleon_include_private_with_doc = False
napoleon_include_special_with_doc = True napoleon_include_special_with_doc = True
...@@ -97,6 +99,7 @@ import m2r2 ...@@ -97,6 +99,7 @@ import m2r2
current_m2r2_setup = m2r2.setup current_m2r2_setup = m2r2.setup
def patched_m2r2_setup(app): def patched_m2r2_setup(app):
try: try:
return current_m2r2_setup(app) return current_m2r2_setup(app)
...@@ -104,13 +107,18 @@ def patched_m2r2_setup(app): ...@@ -104,13 +107,18 @@ def patched_m2r2_setup(app):
app.add_source_suffix(".md", "markdown") app.add_source_suffix(".md", "markdown")
app.add_source_parser(m2r2.M2RParser) app.add_source_parser(m2r2.M2RParser)
return dict( return dict(
version=m2r2.__version__, parallel_read_safe=True, parallel_write_safe=True, version=m2r2.__version__,
parallel_read_safe=True,
parallel_write_safe=True,
) )
m2r2.setup = patched_m2r2_setup m2r2.setup = patched_m2r2_setup
# Generate some custom documentations for this version of binarycpython and binary_c # Generate some custom documentations for this version of binarycpython and binary_c
from binarycpython.utils.functions import write_binary_c_parameter_descriptions_to_rst_file from binarycpython.utils.functions import (
write_binary_c_parameter_descriptions_to_rst_file,
)
from binarycpython.utils.grid_options_defaults import write_grid_options_to_rst_file from binarycpython.utils.grid_options_defaults import write_grid_options_to_rst_file
print("Generating binary_c_parameters.rst") print("Generating binary_c_parameters.rst")
......
...@@ -74,8 +74,12 @@ def run_example_binary_with_run_system(): ...@@ -74,8 +74,12 @@ def run_example_binary_with_run_system():
# print(output) # print(output)
# Catch results that start with a given header. (Mind that binary_c has to be configured to print them if your not using a custom logging function) # Catch results that start with a given header. (Mind that binary_c has to be configured to print them if your not using a custom logging function)
result_example_header_1 = example_parse_output(output, selected_header="example_header_1") result_example_header_1 = example_parse_output(
result_example_header_2 = example_parse_output(output, selected_header="example_header_2") output, selected_header="example_header_1"
)
result_example_header_2 = example_parse_output(
output, selected_header="example_header_2"
)
# print(result_example_header_1) # print(result_example_header_1)
......
...@@ -37,6 +37,7 @@ def check_version(installed_binary_c_version, required_binary_c_versions): ...@@ -37,6 +37,7 @@ def check_version(installed_binary_c_version, required_binary_c_versions):
) )
assert installed_binary_c_version in required_binary_c_versions, message assert installed_binary_c_version in required_binary_c_versions, message
def execute_make(): def execute_make():
""" """
Function to execute the makefile. Function to execute the makefile.
...@@ -147,14 +148,11 @@ API_h = os.path.join(BINARY_C_DIR, "src", "API", "binary_c_API.h") ...@@ -147,14 +148,11 @@ API_h = os.path.join(BINARY_C_DIR, "src", "API", "binary_c_API.h")
############################################################ ############################################################
# Setting all directories and LIBRARIES to their final values # Setting all directories and LIBRARIES to their final values
############################################################ ############################################################
INCLUDE_DIRS = ( INCLUDE_DIRS = [
[ os.path.join(BINARY_C_DIR, "src"),
os.path.join(BINARY_C_DIR, "src"), os.path.join(BINARY_C_DIR, "src", "API"),
os.path.join(BINARY_C_DIR, "src", "API"), "include",
"include", ] + BINARY_C_INCDIRS
]
+ BINARY_C_INCDIRS
)
if GSL_DIR: if GSL_DIR:
INCLUDE_DIRS += [os.path.join(GSL_DIR, "include")] INCLUDE_DIRS += [os.path.join(GSL_DIR, "include")]
...@@ -226,6 +224,7 @@ class CustomBuildCommand(distutils.command.build.build): ...@@ -226,6 +224,7 @@ class CustomBuildCommand(distutils.command.build.build):
# Run the original build command # Run the original build command
distutils.command.build.build.run(self) distutils.command.build.build.run(self)
setup( setup(
name="binarycpython", name="binarycpython",
version="0.2.8", version="0.2.8",
...@@ -237,7 +236,7 @@ setup( ...@@ -237,7 +236,7 @@ setup(
author_email="davidhendriks93@gmail.com", author_email="davidhendriks93@gmail.com",
long_description=readme(), long_description=readme(),
# long_description="hello", # long_description="hello",
long_description_content_type='text/markdown', long_description_content_type="text/markdown",
url="https://gitlab.eps.surrey.ac.uk/ri0005/binary_c-python", url="https://gitlab.eps.surrey.ac.uk/ri0005/binary_c-python",
license="gpl", license="gpl",
keywords=[ keywords=[
...@@ -252,7 +251,15 @@ setup( ...@@ -252,7 +251,15 @@ setup(
"binarycpython.core", "binarycpython.core",
"binarycpython.tests", "binarycpython.tests",
], ],
install_requires=["numpy", "pytest", "h5py", "pathos", "pandas", "astropy", "matplotlib"], install_requires=[
"numpy",
"pytest",
"h5py",
"pathos",
"pandas",
"astropy",
"matplotlib",
],
include_package_data=True, include_package_data=True,
ext_modules=[BINARY_C_PYTHON_API_MODULE], # binary_c must be loaded ext_modules=[BINARY_C_PYTHON_API_MODULE], # binary_c must be loaded
classifiers=[ classifiers=[
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment