From 95f00b0dd2475b4b4f14e511e1cb1fc758350465 Mon Sep 17 00:00:00 2001
From: David Hendriks <davidhendriks93@gmail.com>
Date: Mon, 17 May 2021 15:36:32 +0100
Subject: [PATCH] Added subtract_dict function to the repo

---
 binarycpython/utils/functions.py | 134 +++++++++++++++++++++++++++++++
 1 file changed, 134 insertions(+)

diff --git a/binarycpython/utils/functions.py b/binarycpython/utils/functions.py
index d6ad6922a..7bf7e1b81 100644
--- a/binarycpython/utils/functions.py
+++ b/binarycpython/utils/functions.py
@@ -31,6 +31,140 @@ import binarycpython.utils.moe_distefano_data as moe_distefano_data
 ########################################################
 
 
+def subtract_dicts(dict_1: dict, dict_2: dict) -> dict:
+    """
+    Function to subtract two dictionaries. 
+
+    Only allows values to be either a dict or a numerical type
+
+    For the overlapping keys (key name present in both dicts):
+        When the keys are of the same type:
+            - If the types are of numerical type: subtract the value at dict 2 from dict 1.
+            - If the types are both dictionaries: call this function with the subdicts
+
+        WHen the keys are not of the same type:
+            - if the keys are all of numerical types
+
+    For the unique keys:
+        - if the key is from dict 1: adds the value to the new dict (be it numerical value or dict)
+        - If the key is from dict 2: Adds the negative of its value in case of numerical type.
+            if the type is a dict, the result of subtract_dicts({}, dict_2[key]) will be set
+
+    If the result is 0, the key will be removed from the resulting dict.
+    If that results in an empty dict, the dict will be removed too.
+
+    Args:
+        dict_1: first dictionary
+        dict_2: second dictionary
+
+    Returns:
+        Subtracted dictionary
+    """
+
+    # Set up new dict
+    new_dict = {}
+
+    # Define allowed numerical types
+    ALLOWED_NUMERICAL_TYPES = (float, int, np.float64)
+
+    #
+    keys_1 = dict_1.keys()
+    keys_2 = dict_2.keys()
+
+    # Find overlapping keys of both dicts
+    overlapping_keys = set(keys_1).intersection(set(keys_2))
+
+    # Find the keys that are unique
+    unique_to_dict_1 = set(keys_1).difference(set(keys_2))
+    unique_to_dict_2 = set(keys_2).difference(set(keys_1))
+
+    # Add the unique keys to the new dict
+    for key in unique_to_dict_1:
+        # If these items are numerical types
+        if isinstance(dict_1[key], ALLOWED_NUMERICAL_TYPES):
+            new_dict[key] = dict_1[key]
+            if new_dict[key] == 0:
+                del new_dict[key]
+
+        # Else, to be safe we should deepcopy them
+        elif isinstance(dict_1[key], dict):
+            copy_dict = copy.deepcopy(dict_1[key])
+            new_dict[key] = copy_dict
+        else:
+            msg = "Error: using unsupported type for key {}: {}".format(key, type(dict_1[key]))
+            print(msg)
+            raise ValueError(msg)
+
+    # Add the unique keys to the new dict
+    for key in unique_to_dict_2:
+        # If these items are numerical type, we should add the negative of the value
+        if isinstance(dict_2[key], ALLOWED_NUMERICAL_TYPES):
+            new_dict[key] = -dict_2[key]
+            if new_dict[key] == 0:
+                del new_dict[key]
+
+        # Else we should place the negative of that dictionary in the new place
+        elif isinstance(dict_2[key], dict):
+            new_dict[key] = subtract_dicts({}, dict_2[key])
+        else:
+            msg = "Error: using unsupported type for key {}: {}".format(key, type(dict_2[key]))
+            print(msg)
+            raise ValueError(msg)
+
+    # Go over the common keys:
+    for key in overlapping_keys:
+
+        # See whether the types are actually the same
+        if not type(dict_1[key]) is type(dict_2[key]):
+            # Exceptions:
+            if (type(dict_1[key]) in ALLOWED_NUMERICAL_TYPES) and (
+                type(dict_2[key]) in ALLOWED_NUMERICAL_TYPES
+            ):
+                # We can safely subtract the values since they are all numeric
+                new_dict[key] = dict_1[key] - dict_2[key]
+                if new_dict[key] == 0:
+                    del new_dict[key]
+
+            else:
+                print(
+                    "Error key: {} value: {} type: {} and key: {} value: {} type: {} are not of the same type and cannot be merged".format(
+                        key,
+                        dict_1[key],
+                        type(dict_1[key]),
+                        key,
+                        dict_2[key],
+                        type(dict_2[key]),
+                    )
+                )
+                raise ValueError
+
+        # This is where the keys are the same
+        else:
+            # If these items are numeric types
+            if isinstance(dict_1[key], ALLOWED_NUMERICAL_TYPES):
+                new_dict[key] = dict_1[key] - dict_2[key]
+
+                # Remove entry if the value is 0
+                if new_dict[key] == 0:
+                    del new_dict[key]
+
+            # Else, to be safe we should deepcopy them
+            elif isinstance(dict_1[key], dict):
+                new_dict[key] = subtract_dicts(dict_1[key], dict_2[key])
+
+                # Remove entry if it results in an empty dict
+                # TODO: write test to prevent empty dicts from showing up
+                if not new_dict[key]:
+                    del new_dict[key]
+            else:
+                msg = "Error: using unsupported type for key {}: {}".format(key, type(dict_2[key]))
+                print(msg)
+                raise ValueError(msg)
+
+    #
+    return new_dict
+
+
 def get_moe_distefano_dataset(options):
     """
     Function to get the default moe and Distefano dataset or accept a userinput.
-- 
GitLab