diff --git a/binarycpython/utils/functions.py b/binarycpython/utils/functions.py index 30ae3800186b6fb024a0360e172b6780a70d51c4..045e9c6c91978ebb8d56c552bb7d916426963fb0 100644 --- a/binarycpython/utils/functions.py +++ b/binarycpython/utils/functions.py @@ -17,12 +17,14 @@ import sys from io import StringIO from typing import Union, Any -from collections import defaultdict +from collections import ( + defaultdict, + OrderedDict, +) import h5py import numpy as np - from binarycpython import _binary_c_bindings import binarycpython.utils.moe_distefano_data as moe_distefano_data @@ -1411,7 +1413,7 @@ def inspect_dict( type(input_dict[key]) (except if the value is a dict) """ - structure_dict = {} + new_dict = OrderedDict() # TODO: check if this still works # for key, value in input_dict.items(): @@ -1453,7 +1455,7 @@ def merge_dicts(dict_1: dict, dict_2: dict) -> dict: """ # Set up new dict - new_dict = {} + new_dict = OrderedDict() # TODO: check if this still works # keys_1 = dict_1.keys() @@ -1548,7 +1550,6 @@ def merge_dicts(dict_1: dict, dict_2: dict) -> dict: # return new_dict - def update_dicts(dict_1: dict, dict_2: dict) -> dict: """ Function to update dict_1 with values of dict_2 in a recursive way. @@ -1571,7 +1572,7 @@ def update_dicts(dict_1: dict, dict_2: dict) -> dict: """ # Set up new dict - new_dict = {} + new_dict = OrderedDict() # TODO: check if this still works # keys_1 = dict_1.keys() @@ -1641,6 +1642,141 @@ def update_dicts(dict_1: dict, dict_2: dict) -> dict: return new_dict +def count_keys_recursive(input_dict): + """ + Function to count the total amount of keys in a dictionary + """ + + local_count = 0 + for key in input_dict.keys(): + local_count += 1 + if isinstance(input_dict[key], (dict, OrderedDict)): + local_count += count_keys_recursive(input_dict[key]) + return local_count + + +def recursive_change_key_to_float(input_dict): + """ + Function to recursively change the key to float + + This only works if the dict contains just sub-dicts or numbers/strings. + Does not work with lists as values + """ + + new_dict = OrderedDict() # TODO: check if this still works + + for key in input_dict: + if isinstance(input_dict[key], (dict, OrderedDict)): + try: + num_key = float(key) + new_dict[num_key] = recursive_change_key_to_float(copy.deepcopy(input_dict[key])) + except ValueError: + new_dict[key] = recursive_change_key_to_float(copy.deepcopy(input_dict[key])) + else: + try: + num_key = float(key) + new_dict[num_key] = input_dict[key] + except ValueError: + new_dict[key] = input_dict[key] + + return new_dict + +def recursive_change_key_to_string(input_dict): + """ + Function to recursively change the key back to a string but this time in a format that we decide + """ + + new_dict = OrderedDict() # TODO: check if this still works + + for key in input_dict: + if isinstance(input_dict[key], (dict, OrderedDict)): + if isinstance(key, (int, float)): + string_key = "{:g}".format(key) + new_dict[string_key] = recursive_change_key_to_string(copy.deepcopy(input_dict[key])) + else: + new_dict[key] = recursive_change_key_to_string(copy.deepcopy(input_dict[key])) + else: + if isinstance(key, (int, float)): + string_key = "{:g}".format(key) + new_dict[string_key] = input_dict[key] + else: + new_dict[key] = input_dict[key] + + return new_dict + + +# New method to create a ordered dictionary +def custom_sort_dict(input_dict): + """ + Returns a dictionary that is ordered, but can handle numbers better than normal OrderedDict + + When the keys of the current dictionary are of mixed type, we first find all the unique types. + Sort that list of type names. Then find the values that fit that type. + Sort those and append them to the sorted keys list. + This is done until all the keys are sorted. + + All objects other than dictionary types are directly return as they are + """ + + # If the new input is a dictionary, then try to sort it + if isinstance(input_dict, (dict, OrderedDict)): + new_dict = OrderedDict() + + keys = input_dict.keys() + + # Check if types are the same + all_types_keys = [] + for key in keys: + if not type(key) in all_types_keys: + all_types_keys.append(type(key)) + + # If there are multiple types, then we loop over them and do a piecewise sort + if len(all_types_keys) > 1: + msg = "Different types in the same dictionary key set" + + # Create a string repr of the type name to sort them afterwards + str_types = {repr(el):el for el in all_types_keys} + + # Set up sorted keys list + sorted_keys = [] + + for key_str_type in sorted(str_types.keys()): + cur_type = str_types[key_str_type] + + cur_list = [key for key in keys if isinstance(key, cur_type)] + cur_sorted_list = sorted(cur_list) + + sorted_keys = sorted_keys + cur_sorted_list + else: + sorted_keys = sorted(keys) + + for key in sorted_keys: + new_dict[key] = custom_sort_dict(copy.deepcopy(input_dict[key])) + + return new_dict + return input_dict + + +def multiply_values_dict(input_dict, factor): + """ + Function that goes over dictionary recursively and multiplies the value if possible by a factor + + If the key equals "general_info", the multiplication gets skipped + """ + + for key in input_dict: + if not key == 'general_info': + if isinstance(input_dict[key], (dict, OrderedDict)): + input_dict[key] = multiply_values_dict(input_dict[key], factor) + else: + if isinstance(input_dict[key], (int, float)): + input_dict[key] = input_dict[key] * factor + + return input_dict + + +##### + def extract_ensemble_json_from_string(binary_c_output: str) -> dict: """ Function to extract the ensemble_json information from a raw binary_c output string