Skip to content
Snippets Groups Projects
Commit a6bc167a authored by David Hendriks's avatar David Hendriks
Browse files

added tests for functions and distributions

parent 2d50457d
No related branches found
No related tags found
No related merge requests found
......@@ -25,6 +25,24 @@ class TestDistributions(unittest.TestCase):
self.tolerance = 1e-5
def test_setopts(self):
"""
Unittest for function set_opts
"""
default_dict = {'m1': 2, 'm2': 3}
output_dict_1 = set_opts(default_dict, {})
self.assertTrue(output_dict_1==default_dict)
new_opts = {'m1': 10}
output_dict_2 = set_opts(default_dict, new_opts)
updated_dict = default_dict.copy()
updated_dict['m1'] = 10
self.assertTrue(output_dict_2==updated_dict)
def test_flat(self):
"""
Unittest for the function flat
......@@ -205,7 +223,6 @@ class TestDistributions(unittest.TestCase):
self.assertLess(np.abs(imf_chabrier(m)-0.581457346702825), self.tolerance, msg="Difference is bigger than the tolerance")
def test_duquennoy1991(self):
"""
Unittest for function duquennoy1991
......
import unittest
import tempfile
from binarycpython.utils.functions import *
from binarycpython.utils.custom_logging_functions import binary_c_log_code
from binarycpython.utils.run_system_wrapper import run_system
binary_c_temp_dir = temp_dir()
#############################
# Script that contains unit tests for functions from the binarycpython.utils.functions file
# class test_(unittest.TestCase):
# """
# Unittests for function
# """
class test_get_help_super(unittest.TestCase):
# def test_1(self):
# pass
class dummy():
def __init__(self, name):
self.name = name
def __str__(self):
return self.name
class test_verbose_print(unittest.TestCase):
"""
Unit test for get_help_super
Unittests for verbose_print
"""
def test_all_output(self):
def test_print(self):
verbose_print('test1', 1, 0)
def test_not_print(self):
verbose_print('test1', 0, 1)
class test_remove_file(unittest.TestCase):
"""
Unittests for remove_file
"""
def test_remove_file(self):
with open(os.path.join(binary_c_temp_dir, 'test_remove_file_file.txt'), 'w') as f:
f.write('test')
remove_file(os.path.join(binary_c_temp_dir, 'test_remove_file_file.txt'))
def test_remove_nonexisting_file(self):
file = os.path.join(binary_c_temp_dir, 'test_remove_nonexistingfile_file.txt')
remove_file(file)
class test_temp_dir(unittest.TestCase):
"""
Unittests for temp_dir
"""
def test_create_temp_dir(self):
binary_c_temp_dir = temp_dir()
general_temp_dir = tempfile.gettempdir()
self.assertTrue(os.path.isdir(os.path.join(general_temp_dir, 'binary_c_python')))
self.assertTrue(os.path.join(general_temp_dir, 'binary_c_python'))==binary_c_temp_dir
class test_create_hdf5(unittest.TestCase):
"""
Unittests for create_hdf5
"""
def test_1(self):
testdir = os.path.join(binary_c_temp_dir, 'test_create_hdf5')
os.makedirs(testdir, exist_ok=True)
# Create dummy settings file:
settings_dict = {'settings_1': 1, 'settings_2': [1,2]}
with open(os.path.join(testdir, 'example_settings.json'), 'w') as f:
f.write(json.dumps(settings_dict))
with open(os.path.join(testdir, 'data1.dat'), 'w') as f:
f.write("time mass\n")
f.write("1 10")
create_hdf5(testdir, 'testhdf5.hdf5')
file = h5py.File(os.path.join(testdir, 'testhdf5.hdf5'), 'r')
self.assertIn(b'time', file.get('data/data1_header')[()])
self.assertIn(b'mass', file.get('data/data1_header')[()])
self.assertIn('settings_1', json.loads(file.get('settings/used_settings')[()]))
self.assertIn('settings_2', json.loads(file.get('settings/used_settings')[()]))
class test_return_binary_c_version_info(unittest.TestCase):
"""
Unittests for return_binary_c_version_info
"""
def test_not_parsed(self):
version_info = return_binary_c_version_info()
self.assertTrue(isinstance(version_info, str))
self.assertIn('Build', version_info)
self.assertIn('REIMERS_ETA_DEFAULT', version_info)
self.assertIn('SIGMA_THOMPSON', version_info)
def test_parsed(self):
# also tests the parse_version_info indirectly
version_info_parsed = return_binary_c_version_info(parsed=True)
self.assertTrue(isinstance(version_info_parsed, dict))
self.assertIn('isotopes', version_info_parsed.keys())
self.assertIn('argpairs', version_info_parsed.keys())
self.assertIn('ensembles', version_info_parsed.keys())
self.assertIn('macros', version_info_parsed.keys())
self.assertIn('elements', version_info_parsed.keys())
self.assertIn('dt_limits', version_info_parsed.keys())
self.assertIn('nucleosynthesis_sources', version_info_parsed.keys())
self.assertIn('miscellaneous', version_info_parsed.keys())
class test_parse_binary_c_version_info(unittest.TestCase):
"""
Unittests for function parse_binary_c_version_info
"""
def test_1(self):
info = return_binary_c_version_info()
parsed_info = parse_binary_c_version_info(info)
self.assertIn('isotopes', parsed_info.keys())
self.assertNotEqual(parsed_info['isotopes'], {})
self.assertIn('argpairs', parsed_info.keys())
self.assertNotEqual(parsed_info['argpairs'], {})
self.assertIn('ensembles', parsed_info.keys())
self.assertNotEqual(parsed_info['ensembles'], {})
self.assertIn('macros', parsed_info.keys())
self.assertNotEqual(parsed_info['macros'], {})
self.assertIn('elements', parsed_info.keys())
self.assertNotEqual(parsed_info['elements'], {})
self.assertIn('dt_limits', parsed_info.keys())
self.assertNotEqual(parsed_info['dt_limits'], {})
self.assertIn('nucleosynthesis_sources', parsed_info.keys())
self.assertNotEqual(parsed_info['nucleosynthesis_sources'], {})
self.assertIn('miscellaneous', parsed_info.keys())
self.assertNotEqual(parsed_info['miscellaneous'], {})
class test_output_lines(unittest.TestCase):
"""
Unittests for function output_lines
"""
def test_1(self):
example_text = "hallo\ntest\n123"
output_1 = output_lines(example_text)
self.assertTrue(isinstance(output_1, list))
self.assertIn('hallo', output_1)
self.assertIn('test', output_1)
self.assertIn('123', output_1)
class test_example_parse_output(unittest.TestCase):
"""
Unittests for function example_parse_output
"""
def test_normal_output(self):
# generate logging lines. Here you can choose whatever you want to have logged, and with what header
# You can also decide to `write` your own logging_line, which allows you to write a more complex logging statement with conditionals.
logging_line = 'Printf("MY_STELLAR_DATA time=%g mass=%g\\n", stardata->model.time, stardata->star[0].mass)'
# Generate entire shared lib code around logging lines
custom_logging_code = binary_c_log_code(logging_line)
# Run system. all arguments can be given as optional arguments. the custom_logging_code is one of them and will be processed automatically.
output = run_system(
M_1=1,
metallicity=0.002,
M_2=0.1,
separation=0,
orbital_period=100000000000,
custom_logging_code=custom_logging_code,
)
parsed_output = example_parse_output(output, "MY_STELLAR_DATA")
self.assertIn('time', parsed_output)
self.assertIn('mass', parsed_output)
self.assertTrue(isinstance(parsed_output['time'], list))
self.assertTrue(len(parsed_output['time'])>0)
def test_mismatch_output(self):
# generate logging lines. Here you can choose whatever you want to have logged, and with what header
# You can also decide to `write` your own logging_line, which allows you to write a more complex logging statement with conditionals.
logging_line = 'Printf("MY_STELLAR_DATA time=%g mass=%g\\n", stardata->model.time, stardata->star[0].mass)'
# Generate entire shared lib code around logging lines
custom_logging_code = binary_c_log_code(logging_line)
# Run system. all arguments can be given as optional arguments. the custom_logging_code is one of them and will be processed automatically.
output = run_system(
M_1=1,
metallicity=0.002,
M_2=0.1,
separation=0,
orbital_period=100000000000,
custom_logging_code=custom_logging_code,
)
parsed_output = example_parse_output(output, "MY_STELLAR_DATA_MISMATCH")
self.assertIsNone(parsed_output)
class test_get_defaults(unittest.TestCase):
"""
Unittests for function get_defaults
"""
def test_no_filter(self):
output_1 = get_defaults()
self.assertTrue(isinstance(output_1, dict))
self.assertIn('colour_log', output_1.keys())
self.assertIn('M_1', output_1.keys())
self.assertIn('list_args', output_1.keys())
self.assertIn('use_fixed_timestep_%d', output_1.keys())
def test_filter(self):
# Also tests the filter_arg_dict indirectly
output_1 = get_defaults(filter_values=True)
self.assertTrue(isinstance(output_1, dict))
self.assertIn('colour_log', output_1.keys())
self.assertIn('M_1', output_1.keys())
self.assertNotIn('list_args', output_1.keys())
self.assertNotIn('use_fixed_timestep_%d', output_1.keys())
class test_get_arg_keys(unittest.TestCase):
"""
Unittests for function
"""
def test_1(self):
output_1 = get_arg_keys()
self.assertTrue(isinstance(output_1, list))
self.assertIn('colour_log', output_1)
self.assertIn('M_1', output_1)
self.assertIn('list_args', output_1)
self.assertIn('use_fixed_timestep_%d', output_1)
class test_create_arg_string(unittest.TestCase):
"""
Unittests for function create_arg_string
"""
def test_default(self):
input_dict = {'separation': 40000, 'M_1': 10}
argstring = create_arg_string(input_dict)
self.assertEqual(argstring, 'separation 40000 M_1 10')
def test_sort(self):
input_dict = {'M_1': 10, 'separation': 40000}
argstring = create_arg_string(input_dict, sort=True)
self.assertEqual(argstring, 'M_1 10 separation 40000')
def test_sort(self):
input_dict = {'M_1': 10, 'separation': 40000, 'list_args': "NULL"}
argstring = create_arg_string(input_dict, filter_values=True)
self.assertEqual(argstring, 'M_1 10 separation 40000')
class test_get_help(unittest.TestCase):
"""
Unit tests for function get_help
"""
def test_input_normal(self):
"""
Function to test the get_help_super function
Function to test the get_help function
"""
get_help_super_output = get_help_super()
get_help_super_keys = get_help_super_output.keys()
self.assertEqual(
get_help("M_1", print_help=False)["parameter_name"],
"M_1",
msg="get_help('M_1') should return the correct parameter name",
)
self.assertIn("stars", get_help_super_keys, "missing section")
self.assertIn("binary", get_help_super_keys, "missing section")
self.assertIn("nucsyn", get_help_super_keys, "missing section")
self.assertIn("output", get_help_super_keys, "missing section")
self.assertIn("i/o", get_help_super_keys, "missing section")
self.assertIn("algorithms", get_help_super_keys, "missing section")
self.assertIn("misc", get_help_super_keys, "missing section")
def test_no_input(self):
output = get_help()
self.assertIsNone(output)
def test_wrong_input(self):
output = get_help("kaasblokjes")
self.assertIsNone(output)
# def test_print(self):
# output = get_help("M_1", print_help=True)
class test_get_help_all(unittest.TestCase):
"""
......@@ -49,24 +315,178 @@ class test_get_help_all(unittest.TestCase):
self.assertIn("misc", get_help_all_keys, "missing section")
class test_get_help(unittest.TestCase):
def test_input(self):
# def test_print(self):
# # test if stuff is printed
# get_help_all(print_help=True)
class test_get_help_super(unittest.TestCase):
"""
Unit test for get_help_super
"""
def test_all_output(self):
"""
Function to test the get_help function
Function to test the get_help_super function
"""
self.assertEqual(
get_help("M_1", print_help=False)["parameter_name"],
"M_1",
msg="get_help('M_1') should return the correct parameter name",
)
get_help_super_output = get_help_super()
get_help_super_keys = get_help_super_output.keys()
self.assertIn("stars", get_help_super_keys, "missing section")
self.assertIn("binary", get_help_super_keys, "missing section")
self.assertIn("nucsyn", get_help_super_keys, "missing section")
self.assertIn("output", get_help_super_keys, "missing section")
self.assertIn("i/o", get_help_super_keys, "missing section")
self.assertIn("algorithms", get_help_super_keys, "missing section")
self.assertIn("misc", get_help_super_keys, "missing section")
# def test_print(self):
# # test to see if stuff is printed.
# get_help_super(print_help=True)
class test_make_build_text(unittest.TestCase):
"""
Unittests for function
"""
def test_output(self):
build_text = make_build_text()
# Remove the things
build_text = build_text.replace("**binary_c git branch**:", ";")
build_text = build_text.replace("**binary_c git revision**:", ";")
build_text = build_text.replace("**Built on**:", ";")
# Split up
split_text = build_text.split(";")
# Check whether the contents are actually there
self.assertNotEqual(split_text[1].strip(), "second")
self.assertNotEqual(split_text[2].strip(), "second")
self.assertNotEqual(split_text[3].strip(), "second")
class test_write_binary_c_parameter_descriptions_to_rst_file(unittest.TestCase):
"""
Unittests for function write_binary_c_parameter_descriptions_to_rst_file
"""
def test_bad_outputname(self):
output_name = os.path.join(binary_c_temp_dir, 'test_write_binary_c_parameter_descriptions_to_rst_file_test_1.txt')
output_1 = write_binary_c_parameter_descriptions_to_rst_file(output_name)
self.assertIsNone(output_1)
def test_checkfile(self):
output_name = os.path.join(binary_c_temp_dir, 'test_write_binary_c_parameter_descriptions_to_rst_file_test_1.rst')
output_1 = write_binary_c_parameter_descriptions_to_rst_file(output_name)
self.assertTrue(os.path.isfile(output_name))
class test_inspect_dict(unittest.TestCase):
"""
Unittests for function
"""
def test_compare_dict(self):
input_dict = {'int': 1, 'float': 1.2, 'list': [1,2,3], 'function': os.path.isfile, 'dict': {'int': 1, 'float': 1.2}}
output_dict = inspect_dict(input_dict)
compare_dict = {'int': int, 'float': float, 'list': list, 'function': os.path.isfile.__class__, 'dict': {'int': int, 'float': float}}
self.assertTrue(compare_dict==output_dict)
def test_compare_dict(self):
input_dict = {'int': 1, 'float': 1.2, 'list': [1,2,3], 'function': os.path.isfile, 'dict': {'int': 1, 'float': 1.2}}
output_dict = inspect_dict(input_dict, print_structure=True)
class test_merge_dicts(unittest.TestCase):
"""
Unittests for function
"""
def test_empty(self):
input_dict = {'int': 1, 'float': 1.2, 'list': [1,2,3], 'function': os.path.isfile, 'dict': {'int': 1, 'float': 1.2}}
dict_2 = {}
output_dict = merge_dicts(input_dict, dict_2)
self.assertTrue(output_dict==input_dict)
def test_unequal_types(self):
dict_1 = {"input": 10}
dict_2 = {"input": 'hello'}
self.assertRaises(ValueError, merge_dicts, dict_1, dict_2)
def test_bools(self):
dict_1 = {'bool': True}
dict_2 = {'bool': False}
output_dict = merge_dicts(dict_1, dict_2)
self.assertTrue(isinstance(output_dict['bool'], bool))
self.assertTrue(output_dict['bool'])
def test_ints(self):
dict_1 = {'int': 2}
dict_2 = {'int': 1}
output_dict = merge_dicts(dict_1, dict_2)
self.assertTrue(isinstance(output_dict['int'], int))
self.assertEqual(output_dict['int'], 3)
def test_floats(self):
dict_1 = {'float': 4.5}
dict_2 = {'float': 4.6}
output_dict = merge_dicts(dict_1, dict_2)
self.assertTrue(isinstance(output_dict['float'], float))
self.assertEqual(output_dict['float'], 9.1)
def test_lists(self):
dict_1 = {'list': [1,2]}
dict_2 = {'list': [3,4]}
output_dict = merge_dicts(dict_1, dict_2)
self.assertTrue(isinstance(output_dict['list'], list))
self.assertEqual(output_dict['list'], [1,2,3,4])
def test_dicts(self):
dict_1 = {'dict': {'same': 1, 'other_1': 2.0}}
dict_2 = {'dict': {'same': 2, 'other_2': [4.0]}}
output_dict = merge_dicts(dict_1, dict_2)
self.assertTrue(isinstance(output_dict['dict'], dict))
self.assertEqual(output_dict['dict'], {'same': 3, 'other_1': 2.0, 'other_2': [4.0]})
def test_unsupported(self):
dict_1 = {'new': dummy('david')}
dict_2 = {'new': dummy('gio')}
# output_dict = merge_dicts(dict_1, dict_2)
self.assertRaises(ValueError, merge_dicts, dict_1, dict_2)
class test_binaryc_json_serializer(unittest.TestCase):
"""
Unittests for function binaryc_json_serializer
"""
def test_not_function(self):
stringo = "hello"
output = binaryc_json_serializer(stringo)
self.assertTrue(stringo==output)
def test_function(self):
string_of_function = str(os.path.isfile)
output = binaryc_json_serializer(os.path.isfile)
self.assertTrue(string_of_function==output)
class test_handle_ensemble_string_to_json(unittest.TestCase):
"""
Unittests for function handle_ensemble_string_to_json
"""
def all():
test_get_help()
test_get_help_all()
test_get_help_super()
def test_1(self):
string_of_function = str(os.path.isfile)
input_string = '{"ding": 10, "list_example": [1,2,3]}'
output_dict = handle_ensemble_string_to_json(input_string)
self.assertTrue(isinstance(output_dict, dict))
self.assertTrue(output_dict['ding']==10)
self.assertTrue(output_dict['list_example']==[1,2,3])
if __name__ == "__main__":
unittest.main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment