diff --git a/binarycpython/tests/main.py b/binarycpython/tests/main.py index 01d8358dac2bc2fdab84fe895a66dd653e79794c..8954d32e08b78707d6b3cfa765f3d886d5c3fdf2 100755 --- a/binarycpython/tests/main.py +++ b/binarycpython/tests/main.py @@ -34,7 +34,8 @@ from binarycpython.tests.test_dicts import ( test_count_keys_recursive, test_keys_to_floats, test_recursive_change_key_to_float, - test_recursive_change_key_to_string + test_recursive_change_key_to_string, + test_multiply_float_values ) from binarycpython.tests.test_ensemble import ( test_binaryc_json_serializer, diff --git a/binarycpython/tests/test_dicts.py b/binarycpython/tests/test_dicts.py index 3146c6b2e1356ace8a75952605a3b1179056a77a..1f4854d916e823e3c0bee730910a9d897cc552a7 100644 --- a/binarycpython/tests/test_dicts.py +++ b/binarycpython/tests/test_dicts.py @@ -3,7 +3,6 @@ Unittests for dicts module TODO: _nested_set TODO: _nested_get -TODO: subtract_dicts TODO: update_dicts """ @@ -30,7 +29,8 @@ from binarycpython.utils.dicts import ( count_keys_recursive, recursive_change_key_to_float, recursive_change_key_to_string, - multiply_float_values + multiply_float_values, + subtract_dicts ) TMP_DIR = temp_dir("tests", "test_dicts") @@ -543,6 +543,122 @@ class test_multiply_float_values(unittest.TestCase): input_2 = {1: 2.2, '2': {'a': 2, 'b': 10, 'c': 0.5, 'd': dummy('david')}} _ = multiply_float_values(input_2, 2) +class test_subtract_dicts(unittest.TestCase): + """ + Unittests for function subtract_dicts + """ + + def test_empty(self): + with Capturing() as output: + self._test_empty() + + def _test_empty(self): + """ + Test subtract_dicts with an empty dict + """ + + input_dict = { + "int": 1, + "float": 1.2, + "dict": {"int": 1, "float": 1.2}, + } + dict_2 = {} + output_dict = subtract_dicts(input_dict, dict_2) + self.assertTrue(output_dict == input_dict) + + def test_unequal_types(self): + with Capturing() as output: + self._test_unequal_types() + + def _test_unequal_types(self): + """ + Test subtract_dicts with unequal types: should raise valueError + """ + + dict_1 = {"input": 10} + dict_2 = {"input": "hello"} + + self.assertRaises(ValueError, subtract_dicts, dict_1, dict_2) + + def test_ints(self): + with Capturing() as _: + self._test_ints() + + def _test_ints(self): + """ + Test subtract_dicts with ints + """ + + dict_1 = {"int": 2} + dict_2 = {"int": 1} + output_dict = subtract_dicts(dict_1, dict_2) + + self.assertTrue(isinstance(output_dict["int"], int)) + self.assertEqual(output_dict["int"], 1) + + def test_floats(self): + with Capturing() as output: + self._test_floats() + + def _test_floats(self): + """ + Test subtract_dicts with floats + """ + + dict_1 = {"float": 4.5} + dict_2 = {"float": 4.6} + output_dict = subtract_dicts(dict_1, dict_2) + + self.assertTrue(isinstance(output_dict["float"], float)) + self.assertAlmostEqual(output_dict["float"], -0.1, 2) + + def test_zero_result(self): + with Capturing() as output: + self._test_zero_result() + + def _test_zero_result(self): + """ + Test subtract_dicts resulting in a 0 value. which should be removed + """ + + dict_1 = {"a": 4, 'b': 0} + dict_2 = {"a": 4, 'c': 0} + output_dict = subtract_dicts(dict_1, dict_2) + + self.assertIsInstance(output_dict, dict) + self.assertFalse(output_dict) + + def test_lists(self): + with Capturing() as output: + self._test_lists() + + def _test_lists(self): + """ + Test merging dict with lists + """ + + dict_1 = {"list": [1, 2]} + dict_2 = {"list": [3, 4]} + + self.assertRaises(ValueError, subtract_dicts, dict_1, dict_2) + + def test_dicts(self): + with Capturing() as _: + self._test_dicts() + + def _test_dicts(self): + """ + Test merging dict with dicts + """ + + dict_1 = {"dict": {"a": 1, "b": 1}} + dict_2 = {"dict": {"a": 2, "c": 2}} + output_dict = subtract_dicts(dict_1, dict_2) + + self.assertTrue(isinstance(output_dict["dict"], dict)) + self.assertEqual( + output_dict["dict"], {"a": -1, "b": 1, "c": -2} + ) if __name__ == "__main__": unittest.main() diff --git a/binarycpython/utils/dicts.py b/binarycpython/utils/dicts.py index 1671e15c28426e90d4d3579e09dd38f7b6368b3c..eee25c39e9af3fb97c3de331381a584d8839b175 100644 --- a/binarycpython/utils/dicts.py +++ b/binarycpython/utils/dicts.py @@ -3,14 +3,14 @@ Module containing functions that binary_c-python uses to modify dictionaries. """ import collections +from typing import Union import astropy.units as u import numpy as np -from typing import Union - +# Define all numerical types +ALLOWED_NUMERICAL_TYPES = Union[int, float, complex, np.number] -NUMERIC = Union[int, float, complex, np.number] def keys_to_floats(input_dict: dict) -> dict: """ @@ -234,9 +234,6 @@ def subtract_dicts(dict_1: dict, dict_2: dict) -> dict: # Set up new dict new_dict = {} - # Define allowed numerical types - ALLOWED_NUMERICAL_TYPES = (float, int, np.float64) - # keys_1 = dict_1.keys() keys_2 = dict_2.keys() @@ -670,7 +667,7 @@ def update_dicts(dict_1: dict, dict_2: dict) -> dict: return new_dict -def multiply_values_dict(input_dict: dict, factor: NUMERIC): +def multiply_values_dict(input_dict: dict, factor: ALLOWED_NUMERICAL_TYPES): """ Function that goes over dictionary recursively and multiplies the value if possible by a factor