def subtract_dicts(dict_1: dict, dict_2: dict) -> dict:
    """
    Function to subtract two (possibly nested) dictionaries: dict_1 - dict_2.

    Only allows values to be either a dict or a numerical type
    (float, int or np.float64, including subclasses of those).

    For the overlapping keys (key name present in both dicts):
        - If both values are of numerical type (not necessarily the same one):
          subtract the value at dict 2 from the value at dict 1.
        - If both values are dictionaries: call this function with the subdicts.
        - Anything else raises a ValueError.

    For the unique keys:
        - If the key is from dict 1: adds the value to the new dict. Numerical values
          are taken as they are, dictionaries are deep-copied verbatim.
        - If the key is from dict 2: adds the negative of its value in case of numerical type.
          If the type is a dict, the result of subtract_dicts({}, dict_2[key]) will be set.

    If a computed result is 0, the key will be removed from the resulting dict.
    If a computed sub-dictionary ends up empty, it will be removed too.

    Neither of the input dictionaries is modified.

    Args:
        dict_1: first dictionary
        dict_2: second dictionary, subtracted from the first

    Returns:
        Subtracted dictionary

    Raises:
        ValueError: when a value is neither a dict nor a supported numerical type,
            or when the values of an overlapping key cannot be combined.
    """

    # Define allowed numerical types
    ALLOWED_NUMERICAL_TYPES = (float, int, np.float64)

    # Set up new dict
    new_dict = {}

    # Keys that are unique to dict 1. We iterate over the dict itself (instead of a set
    # of keys) so the key order of the result is reproducible between runs.
    for key, value in dict_1.items():
        if key in dict_2:
            continue

        # Numerical values are taken over directly, zeros are left out
        if isinstance(value, ALLOWED_NUMERICAL_TYPES):
            if value != 0:
                new_dict[key] = value

        # Dictionaries are deep-copied so the result shares no state with the input
        elif isinstance(value, dict):
            new_dict[key] = copy.deepcopy(value)

        else:
            msg = "Error: using unsupported type for key {}: {}".format(
                key, type(value)
            )
            print(msg)
            raise ValueError(msg)

    # Keys that are unique to dict 2: these enter the result negated
    for key, value in dict_2.items():
        if key in dict_1:
            continue

        if isinstance(value, ALLOWED_NUMERICAL_TYPES):
            if value != 0:
                new_dict[key] = -value

        # Negate a whole subdict by subtracting it from an empty dict
        elif isinstance(value, dict):
            negated_dict = subtract_dicts({}, value)

            # A subdict containing only zeros negates to an empty dict: leave it out
            if negated_dict:
                new_dict[key] = negated_dict

        else:
            msg = "Error: using unsupported type for key {}: {}".format(
                key, type(value)
            )
            print(msg)
            raise ValueError(msg)

    # Go over the common keys
    for key, value_1 in dict_1.items():
        if key not in dict_2:
            continue
        value_2 = dict_2[key]

        # Two numerical values can always be subtracted, even if their exact types differ
        if isinstance(value_1, ALLOWED_NUMERICAL_TYPES) and isinstance(
            value_2, ALLOWED_NUMERICAL_TYPES
        ):
            difference = value_1 - value_2

            # Leave out the entry if the value is 0
            if difference != 0:
                new_dict[key] = difference

        # Two dictionaries (including dict subclasses) are subtracted recursively
        elif isinstance(value_1, dict) and isinstance(value_2, dict):
            difference_dict = subtract_dicts(value_1, value_2)

            # Leave out the entry if it results in an empty dict
            if difference_dict:
                new_dict[key] = difference_dict

        # Values of different, incompatible types
        elif type(value_1) is not type(value_2):
            msg = "Error key: {} value: {} type: {} and key: {} value: {} type: {} are not of the same type and cannot be merged".format(
                key,
                value_1,
                type(value_1),
                key,
                value_2,
                type(value_2),
            )
            print(msg)
            raise ValueError(msg)

        # Values of the same, but unsupported, type
        else:
            msg = "Error: using unsupported type for key {}: {}".format(
                key, type(value_2)
            )
            print(msg)
            raise ValueError(msg)

    #
    return new_dict