From 609d5580749e32df535711a0169cd61f2b8e41d0 Mon Sep 17 00:00:00 2001
From: Robert Izzard <r.izzard@surrey.ac.uk>
Date: Sun, 17 Oct 2021 10:23:14 +0100
Subject: [PATCH] add option to convert dict keys to floats when loading
 ensemble data : this is required for matplotlib/seaborn otherwise the plots
 are completely messed up

---
 binarycpython/utils/functions.py | 62 ++++++++++++++++++++------------
 1 file changed, 40 insertions(+), 22 deletions(-)

diff --git a/binarycpython/utils/functions.py b/binarycpython/utils/functions.py
index f70db74f9..8d7e27161 100644
--- a/binarycpython/utils/functions.py
+++ b/binarycpython/utils/functions.py
@@ -12,10 +12,7 @@ import astropy.units as u
 import binarycpython.utils.moe_di_stefano_2017_data as moe_di_stefano_2017_data
 from binarycpython import _binary_c_bindings
 import bz2
-from collections import (
-    defaultdict,
-    OrderedDict,
-)
+import collections
 from colorama import Fore, Back, Style
 import copy
 import datetime as dt
@@ -1149,7 +1146,7 @@ def example_parse_output(output: str, selected_header: str) -> dict:
     keys = value_dicts[0].keys()
 
     # Construct final dict.
-    final_values_dict = defaultdict(list)
+    final_values_dict = collections.defaultdict(list)
     for value_dict in value_dicts:
         for key in keys:
             final_values_dict[key].append(value_dict[key])
@@ -1724,7 +1721,7 @@ def inspect_dict(
             type(input_dict[key]) (except if the value is a dict)
     """
 
-    structure_dict = OrderedDict()  # TODO: check if this still works
+    structure_dict = collections.OrderedDict()  # TODO: check if this still works
 
     #
     for key, value in input_dict.items():
@@ -1749,7 +1746,7 @@ def count_keys_recursive(input_dict):
     local_count = 0
     for key in input_dict.keys():
         local_count += 1
-        if isinstance(input_dict[key], (dict, OrderedDict)):
+        if isinstance(input_dict[key], (dict, collections.OrderedDict)):
             local_count += count_keys_recursive(input_dict[key])
     return local_count
 
@@ -1779,7 +1776,7 @@ def merge_dicts(dict_1: dict, dict_2: dict) -> dict:
     """
 
     # Set up new dict
-    new_dict = OrderedDict()  # TODO: check if this still necessary
+    new_dict = collectsions.OrderedDict()  # TODO: check if this still necessary
 
     #
     keys_1 = dict_1.keys()
@@ -1824,9 +1821,9 @@ def merge_dicts(dict_1: dict, dict_2: dict) -> dict:
 
             # Exceptions: versions of dicts can be merged
             elif isinstance(
-                dict_1[key], (dict, OrderedDict, type(AutoVivificationDict))
+                dict_1[key], (dict, collections.OrderedDict, type(AutoVivificationDict))
             ) and isinstance(
-                dict_2[key], (dict, OrderedDict, type(AutoVivificationDict))
+                dict_2[key], (dict, collections.OrderedDict, type(AutoVivificationDict))
             ):
                 new_dict[key] = merge_dicts(dict_1[key], dict_2[key])
 
@@ -1996,7 +1993,7 @@ def multiply_values_dict(input_dict, factor):
 
     for key in input_dict:
         if not key == "general_info":
-            if isinstance(input_dict[key], (dict, OrderedDict)):
+            if isinstance(input_dict[key], (dict, collections.OrderedDict)):
                 input_dict[key] = multiply_values_dict(input_dict[key], factor)
             else:
                 if isinstance(input_dict[key], (int, float)):
@@ -2018,8 +2015,8 @@ def custom_sort_dict(input_dict):
     """
 
     # If the new input is a dictionary, then try to sort it
-    if isinstance(input_dict, (dict, OrderedDict)):
-        new_dict = OrderedDict()
+    if isinstance(input_dict, (dict, collections.OrderedDict)):
+        new_dict = collections.OrderedDict()
 
         keys = input_dict.keys()
 
@@ -2064,10 +2061,10 @@ def recursive_change_key_to_float(input_dict):
     Does not work with lists as values
     """
 
-    new_dict = OrderedDict()  # TODO: check if this still works
+    new_dict = collections.OrderedDict()  # TODO: check if this still works
 
     for key in input_dict:
-        if isinstance(input_dict[key], (dict, OrderedDict)):
+        if isinstance(input_dict[key], (dict, collections.OrderedDict)):
             try:
                 num_key = float(key)
                 new_dict[num_key] = recursive_change_key_to_float(
@@ -2092,10 +2089,10 @@ def recursive_change_key_to_string(input_dict):
     Function to recursively change the key back to a string but this time in a format that we decide
     """
 
-    new_dict = OrderedDict()  # TODO: check if this still works
+    new_dict = collections.OrderedDict()  # TODO: check if this still works
 
     for key in input_dict:
-        if isinstance(input_dict[key], (dict, OrderedDict)):
+        if isinstance(input_dict[key], (dict, collections.OrderedDict)):
             if isinstance(key, (int, float)):
                 string_key = "{:g}".format(key)
                 new_dict[string_key] = recursive_change_key_to_string(
@@ -2256,18 +2253,39 @@ class BinaryCEncoder(json.JSONEncoder):
         # Let the base class default method raise the TypeError
         return json.JSONEncoder.default(self, o)
 
-def load_ensemble(filename):
+def load_ensemble(filename,convert_float_keys=True):
     """
     Function to load an ensemeble file, even if it is compressed,
     and return its contents to as a Python dictionary.
     """
     if(filename.endswith('.bz2')):
-        jfile = bz2.open(filename)
+        jfile = bz2.open(filename,'rt')
     elif(filename.endswith('.gz')):
-        jfile = gzip.open(filename)
+        jfile = gzip.open(filename,'rt')
     else:
-        jfile = open(filename)
-    return json.load(jfile)
+        jfile = open(filename,'rt')
+    data = json.load(jfile)
+    if convert_float_keys == False:
+        return data
+    else:
+        # we need to convert keys to floats
+        def _to_float(json_data):
+            new_data = {}
+            for k,v in json_data.items():
+                if isinstance(v, list):
+                    v = [ _to_float(item) if isinstance(item, dict) else item for item in v ]
+                elif isinstance(v, collections.abc.Mapping):
+                    # dict, ordereddict, etc.
+                    v = _to_float(v)
+                try:
+                    f = float(k)
+                    new_data[f] = json_data[k]
+                except:
+                    new_data[k] = v
+            return new_data
+
+        data = _to_float(data)
+        return data
 
 def ensemble_setting(ensemble,parameter_name):
     """
-- 
GitLab