From 9709eb68ccc58456d1d596ccc6108c80464e70d8 Mon Sep 17 00:00:00 2001
From: dh00601 <dh00601@surrey.ac.uk>
Date: Sat, 8 Jan 2022 19:15:26 +0000
Subject: [PATCH] working on dict tests

---
 binarycpython/tests/test_dicts.py | 138 ++++++++++++++++++++++++++++--
 binarycpython/utils/dicts.py      |  19 +++-
 2 files changed, 146 insertions(+), 11 deletions(-)

diff --git a/binarycpython/tests/test_dicts.py b/binarycpython/tests/test_dicts.py
index ce9ab0b0a..204464ce4 100644
--- a/binarycpython/tests/test_dicts.py
+++ b/binarycpython/tests/test_dicts.py
@@ -12,15 +12,11 @@ TODO: subtract_dicts
 TODO: count_keys_recursive
 TODO: update_dicts
 TODO: multiply_values_dict
-TODO: custom_sort_dict
-TODO: filter_dict
-TODO: filter_dict_through_values
-TODO: prepare_dict
-TODO: normalize_dict
 """
 
 import os
 import unittest
+from collections import OrderedDict
 
 from binarycpython.utils.functions import (
     temp_dir,
@@ -30,7 +26,12 @@ from binarycpython.utils.dicts import (
     merge_dicts,
     set_opts,
     AutoVivificationDict,
-    inspect_dict
+    inspect_dict,
+    normalize_dict,
+    filter_dict,
+    filter_dict_through_values,
+    prepare_dict,
+    custom_sort_dict
 )
 
 TMP_DIR = temp_dir("tests", "test_dicts")
@@ -283,7 +284,130 @@ class test_inspect_dict(unittest.TestCase):
             "function": os.path.isfile,
             "dict": {"int": 1, "float": 1.2},
         }
-        output_dict = inspect_dict(input_dict, print_structure=True)
+        _ = inspect_dict(input_dict, print_structure=True)
+
+
+class test_custom_sort_dict(unittest.TestCase):
+    """
+    Unittests for function custom_sort_dict
+    """
+
+    def test_custom_sort_dict(self):
+        with Capturing() as output:
+            self._test_custom_sort_dict()
+
+    def _test_custom_sort_dict(self):
+        """
+        Test custom_sort_dict
+        """
+
+        input_dict = {'2': 1, '1': {2: 1, 1: 10}, -1: 20, 4: -1}
+
+        #
+        output_1 = custom_sort_dict(input_dict)
+
+        desired_output_1 = OrderedDict([(-1, 20),
+             (4, -1),
+             ('1', OrderedDict([(1, 10), (2, 1)])),
+             ('2', 1)]
+        )
+
+        #
+        self.assertEqual(output_1, desired_output_1)
+
+class test_filter_dict(unittest.TestCase):
+    """
+    Unittests for function filter_dict
+    """
+
+    def test_filter_dict(self):
+        with Capturing() as output:
+            self._test_filter_dict()
+
+    def _test_filter_dict(self):
+        """
+        Test filter_dict
+        """
+
+        dict_1 = {'a': 10}
+        input_1 = ['a']
+
+        res_1 = filter_dict(dict_1, input_1)
+
+        self.assertIsInstance(res_1, dict)
+        self.assertFalse(res_1)
+
+
+class test_filter_dict_through_values(unittest.TestCase):
+    """
+    Unittests for function filter_dict_through_values
+    """
+
+    def test_filter_dict_through_values(self):
+        with Capturing() as output:
+            self._test_filter_dict_through_values()
+
+    def _test_filter_dict_through_values(self):
+        """
+        Test filter_dict_through_values
+        """
+
+        dict_1 = {'a': 10}
+        input_1 = [10]
+
+        res_1 = filter_dict_through_values(dict_1, input_1)
+
+        self.assertIsInstance(res_1, dict)
+        self.assertFalse(res_1)
+
+class test_prepare_dict(unittest.TestCase):
+    """
+    Unittests for function prepare_dict
+    """
+
+    def test_prepare_dict(self):
+        with Capturing() as output:
+            self._test_prepare_dict()
+
+    def _test_prepare_dict(self):
+        """
+        Test prepare_dict
+        """
+
+        global_dict = {}
+
+        # Call function to make sure the nested key contains an empty dict to store stuff in
+        input_1 = ['a', 'b']
+        prepare_dict(global_dict, input_1)
+
+        #
+        self.assertIsNotNone(global_dict.get('a', None))
+        self.assertIsNotNone(global_dict['a'].get('b', None))
+        self.assertIsInstance(global_dict['a']['b'], dict)
+        self.assertFalse(global_dict['a']['b'])
+
+class test_normalize_dict(unittest.TestCase):
+    """
+    Unittests for function normalize_dict
+    """
+
+    def test_normalize_dict(self):
+        with Capturing() as output:
+            self._test_normalize_dict()
+
+    def _test_normalize_dict(self):
+        """
+        Test normalize_dict
+        """
+
+        input_1 = {'a': 10, 'b': 20, 'c': 4}
+
+        res_1 = normalize_dict(input_1)
+
+        self.assertEqual(sum(list(res_1.values())), 1.0)
+
+
+
 
 if __name__ == "__main__":
     unittest.main()
\ No newline at end of file
diff --git a/binarycpython/utils/dicts.py b/binarycpython/utils/dicts.py
index f0687f620..8452faf43 100644
--- a/binarycpython/utils/dicts.py
+++ b/binarycpython/utils/dicts.py
@@ -671,6 +671,9 @@ def custom_sort_dict(input_dict):
     This is done until all the keys are sorted.
 
     All objects other than dictionary types are directly return as they are
+
+    Args:
+        input_dict: object which will be sorted (and returned as a new object) if its a dictionary, otherwise it will be returned without change.
     """
 
     # If the new input is a dictionary, then try to sort it
@@ -800,13 +803,21 @@ def set_opts(opts: dict, newopts: dict) -> dict:
     return opts
 
 
-def normalize_dict(result_dict):
+def normalize_dict(result_dict: dict) -> dict:
     """
-    Function to normalise a dictionary
+    Function to normalise a dictionary by summing all the values and dividing each term by the total. Designed for dictionary containing only positive values.
+
+    Args:
+        result_dict: dictionary where values should be positive number objects
+
+    Returns:
+        normalized_dict: dictionary where the values are normalised to sum to 1
     """
 
+    normalized_dict = {}
+
     sum_result = sum(list(result_dict.values()))
     for key in result_dict.keys():
-        result_dict[key] = result_dict[key] / sum_result
+        normalized_dict[key] = result_dict[key] / sum_result
 
-    return result_dict
+    return normalized_dict
-- 
GitLab