From afa13bc842489ebc274f96b045b1ce5217112c9e Mon Sep 17 00:00:00 2001
From: Robert Izzard <r.izzard@surrey.ac.uk>
Date: Thu, 21 Oct 2021 10:43:28 +0100
Subject: [PATCH] add msgpack load/save ensemble support

---
 binarycpython/utils/functions.py             | 149 +++++++++++++++----
 binarycpython/utils/grid.py                  | 110 +++++++-------
 binarycpython/utils/grid_options_defaults.py |   1 -
 requirements.txt                             |   1 +
 setup.py                                     |   3 +-
 5 files changed, 185 insertions(+), 79 deletions(-)

diff --git a/binarycpython/utils/functions.py b/binarycpython/utils/functions.py
index 619354e78..041da2ea0 100644
--- a/binarycpython/utils/functions.py
+++ b/binarycpython/utils/functions.py
@@ -24,6 +24,7 @@ import humanize
 import inspect
 from io import StringIO
 import json
+import msgpack
 import numpy as np
 import os
 import psutil
@@ -34,10 +35,12 @@ import sys
 import subprocess
 import tempfile
 import time
-from tqdm import tqdm
 import types
 from typing import Union, Any
 
+import simplejson
+#import orjson
+
 ########################################################
 # Unsorted
 ########################################################
@@ -2255,7 +2258,50 @@ class BinaryCEncoder(json.JSONEncoder):
         # Let the base class default method raise the TypeError
         return json.JSONEncoder.default(self, o)
 
-def load_ensemble(filename,convert_float_keys=True,select_keys=None):
+def open_ensemble(filename):
+    """
+    Function to open an ensemble at filename for reading and decompression if required.
+    """
+    compression = ensemble_compression(filename)
+    if ensemble_file_type(filename) == 'msgpack':
+        flags = 'rb'
+    else:
+        flags = 'rt'
+    if compression == 'bzip2':
+        file_object = bz2.open(filename,flags)
+    elif compression == 'gzip':
+        file_object = gzip.open(filename,flags)
+    else:
+        file_object = open(filename,flags)
+    return file_object
+
+def ensemble_compression(filename):
+    """
+    Return the compression type of the ensemble file, based on its filename extension.
+    """
+    if filename.endswith('.bz2'):
+        return 'bzip2'
+    elif filename.endswith('.gz'):
+        return 'gzip'
+    else:
+        return None
+
+def ensemble_file_type(filename):
+    """
+    Returns the file type of an ensemble file.
+    """
+    if '.json' in filename:
+        filetype = 'JSON'
+    elif '.msgpack' in filename:
+        filetype = 'msgpack'
+    else:
+        filetype = None
+    return filetype
+
+def load_ensemble(filename,
+                  convert_float_keys=True,
+                  select_keys=None,
+                  timing=False):
     """
     Function to load an ensemeble file, even if it is compressed,
     and return its contents to as a Python dictionary.
@@ -2264,27 +2310,52 @@ def load_ensemble(filename,convert_float_keys=True,select_keys=None):
         convert_float_keys : if True, converts strings to floats.
         select_keys : a list of keys to be selected from the ensemble.
     """
-    if(filename.endswith('.bz2')):
-        jfile = bz2.open(filename,'rt')
-    elif(filename.endswith('.gz')):
-        jfile = gzip.open(filename,'rt')
-    else:
-        jfile = open(filename,'rt')
-
 
+    # open the file
 
     # load with some info to the terminal
-    print("Loading JSON...")
-    _loaded = False
-    def _hook(obj):
-        nonlocal _loaded
-        if _loaded == False:
-            _loaded = True
-            print("\nLoaded JSON data, now putting in a dictionary")
-        return obj
+    print("Loading ensemble...")
+
+    # open the ensemble and get the file type
+    file_object = open_ensemble(filename)
+    filetype = ensemble_file_type(filename)
+
+    if not filetype or not file_object:
+        print("Unknown filetype : your ensemble should be saved either as JSON or msgpack data.")
+        sys.exit()
+
     with Halo(text='Loading', interval=250, spinner='moon',color='yellow'):
-        data = json.load(jfile,
-                         object_hook=_hook)
+        tstart = time.time()
+        _loaded = False
+        def _hook(obj):
+            nonlocal _loaded
+            if _loaded == False:
+                _loaded = True
+                print("\nLoaded {} data, now putting in a dictionary".format(filetype),
+                      flush=True)
+            return obj
+
+        if filetype == 'JSON':
+            # orjson promises to be fast, but it doesn't seem to be
+            # and fails on "Infinity"... oops
+            #data = orjson.loads(file_object.read())
+
+            # simplejson is faster than standard json and "just works"
+            # on the big Moe set in 37s
+            data = simplejson.load(file_object,
+                                   object_hook=_hook)
+
+            # standard json module
+            # on the big Moe set takes 42s
+            #data = json.load(file_object,
+            #                 object_hook=_hook)
+        elif filetype == 'msgpack':
+            data = msgpack.load(file_object,
+                                object_hook=_hook)
+
+        if timing:
+            print("\n\nTook {} s to load the data\n\n".format(time.time() - tstart),
+                  flush=True)
 
     # strip non-selected keys, if a list is given in select_keys
     if select_keys:
@@ -2294,15 +2365,29 @@ def load_ensemble(filename,convert_float_keys=True,select_keys=None):
                 del data['ensemble'][key]
 
     # perhaps convert floats?
-    if convert_float_keys == False:
-        return data
-    else:
-        # we need to convert keys to floats
+    tstart = time.time()
+    if convert_float_keys:
+        # we need to convert keys to floats:
+        # this is ~ a factor 10 faster than David's
+        # recursive_change_key_to_float routine,
+        # probably because this version only does
+        # the float conversion, nothing else.
         def _to_float(json_data):
-            new_data = {}
+            # assumes nested dicts ...
+            #new_data = {}
+
+            # but this copies the variable type, but has some
+            # pointless copying
+            #new_data = copy.copy(json_data)
+            #new_data.clear()
+
+            # this adopts the type correctly *and* is fast
+            new_data = type(json_data)()
+
             for k,v in json_data.items():
                 if isinstance(v, list):
-                    v = [ _to_float(item) if isinstance(item, dict) else item for item in v ]
+                    v = [ _to_float(item) if isinstance(item, collections.abc.Mapping) \
+                          else item for item in v ]
                 elif isinstance(v, collections.abc.Mapping):
                     # dict, ordereddict, etc.
                     v = _to_float(v)
@@ -2313,8 +2398,18 @@ def load_ensemble(filename,convert_float_keys=True,select_keys=None):
                     new_data[k] = v
             return new_data
 
-        data = _to_float(data)
-        return data
+
+        # timings are for 100 iterations on the big Moe data set
+        #data = format_ensemble_results(data) # 213s
+        #data = recursive_change_key_to_float(data) # 61s
+        data = _to_float(data) # 6.94s
+
+    if timing:
+        print("\n\nTook {} s to convert floats\n\n".format(time.time() - tstart),
+              flush=True)
+
+    # return data
+    return data
 
 def ensemble_setting(ensemble,parameter_name):
     """
diff --git a/binarycpython/utils/grid.py b/binarycpython/utils/grid.py
index dee6f9693..7df7d8785 100644
--- a/binarycpython/utils/grid.py
+++ b/binarycpython/utils/grid.py
@@ -29,6 +29,7 @@ import gc
 import gzip
 import importlib.util
 import logging
+import msgpack
 import multiprocessing
 import os
 import py_rinterpolate
@@ -61,6 +62,7 @@ from binarycpython.utils.custom_logging_functions import (
     autogen_C_logging_code,
     binary_c_log_code,
     create_and_load_logging_function,
+
 )
 
 from binarycpython.utils.functions import (
@@ -87,7 +89,9 @@ from binarycpython.utils.functions import (
     ANSI_colours,
     check_if_in_shell,
     format_number,
-    timedelta
+    timedelta,
+    ensemble_file_type,
+    ensemble_compression
 )
 
 # from binarycpython.utils.hpc_functions import (
@@ -2051,7 +2055,7 @@ class Population:
                     0)
 
                 if self.grid_options["exit_after_dry_run"]:
-                    exit()
+                    sys.exit()
 
             #######################
             # Reset values and prepare the grid function
@@ -3198,7 +3202,7 @@ class Population:
     #             )
 
     #         verbose_print("all done!", self.grid_options["verbosity"], 0)
-    #         exit()
+    #         sys.exit()
 
     #     elif self.grid_options["slurm_command"] == "evolve":
     #         # Part to evolve the population.
@@ -3444,7 +3448,7 @@ class Population:
     #                 )
 
     #             verbose_print("all done!", self.grid_options["verbosity"], 0)
-    #             exit()
+    #             sys.exit()
 
     #         elif self.grid_options["condor_command"] == "evolve":
     #             # TODO: write this function
@@ -3469,7 +3473,7 @@ class Population:
     # Functions that aren't ordered yet
     ###################################################
 
-    def write_ensemble(self,output_file,json_data):
+    def write_ensemble(self,output_file,data=None,sort_keys=True,indent=4):
         """
         write_ensemble : Write ensemble results to a file.
 
@@ -3479,64 +3483,70 @@ class Population:
                       If the filename has an extension that we recognise,
                       e.g. .gz or .bz2, we compress the output appropriately.
 
-                      Note that if grid_options['compress_ensemble'] is set, the
-                      appropriate file extension is added if required and compression
-                      is performed.
+                      The filename should contain .json or .msgpack, the two
+                      currently-supported formats.
+
+                      Usually you'll want to output to JSON, but we can
+                      also output to msgpack.
+
+        data :   the data dictionary to be converted and written to the file.
+                 If not set, this defaults to self.grid_ensemble_results.
 
-        json_data :   the JSON data to be written out. This can be raw
-                      unformatted JSON, or the output of JSON_dumps().
+        sort_keys : if True, and output is to JSON, the keys will be sorted.
+                    (default: True, passed to json.dumps)
+
+        indent : number of space characters used in the JSON indent. (Default: 4,
+                 passed to json.dumps)
         """
         # TODO: consider writing this in a formatted structure
 
-        # default to the compression algorithm specified
-        compression = self.grid_options['compress_ensemble']
-
-        # choose algorithm based on file extension if we're not given
-        # an algorithm
-        if not compression:
-            if output_file.endswith('.gz'):
-                compression = 'gzip'
-            elif output_file.endswith('.bz2'):
-                compression = 'bzip2'
-
-        # Write JSON ensemble data to file
-        if compression:
-            # write to a compressed file, adding the appropriate extension if
-            # required
+        # get the file type
+        file_type = ensemble_file_type(output_file)
+
+        # choose compression algorithm based on file extension
+        compression = ensemble_compression(output_file)
+
+        # default to using grid_ensemble_results if no data is given
+        if data is None:
+            data = self.grid_ensemble_results
+
+        if not file_type:
+            print("Unable to determine file type from ensemble filename {} : it should be .json or .msgpack.".format(output_file))
+            sys.exit()
+        elif file_type == 'JSON':
+            # JSON output
             if compression == 'gzip':
-                if output_file.endswith('.gz'):
-                    zipfile = output_file
-                else:
-                    zipfile = output_file + '.gz'
-                with gzip.open(zipfile, "wt") as f:
-                    f.write(json_data)
-                    f.close()
+                # gzip
+                f = gzip.open(output_file, "wt")
             elif compression == 'bzip2':
-                if output_file.endswith('.bz2'):
-                    zipfile = output_file
-                else:
-                    zipfile = output_file + '.bz2'
-                with bz2.open(zipfile, "wt") as f:
-                    f.write(json_data)
-                    f.close()
+                # bzip2
+                f = bz2.open(output_file, "wt")
             else:
-                print("You have asked me to compress the ensemble output using algorithm {algorithm}, but not given me a valid compression algorithm (use gzip or bzip2 in grid_options['compress_ensemble']) : I will write the data uncompressed.").format(algorithm=compression)
                 # raw output (not compressed)
-                with open(output_file, "wt") as f:
-                    f.write(json_data)
-                    f.close()
-        else:
-            # raw output (not compressed)
-            with open(output_file, "wt") as f:
-                f.write(json_data)
-                f.close()
+                f = open(output_file, "wt")
+            f.write(json.dumps(data,
+                               sort_keys=sort_keys,
+                               indent=indent))
+
+        elif file_type == 'msgpack':
+            # msgpack output
+            if compression == 'gzip':
+                f = gzip.open(output_file, "wb")
+            elif compression == 'bzip2':
+                f = bz2.open(output_file, "wb")
+            else:
+                f = open(output_file, "wb")
+            msgpack.dump(data,f)
+        f.close()
 
         print(
-            "Thread {thread}: Wrote ensemble results to file: {colour}{file}{reset}".format(
+            "Thread {thread}: Wrote ensemble results to file: {colour}{file}{reset} (file type {file_type}, compression {compression})".format(
                 thread=self.process_ID,
                 file=output_file,
                 colour=self.ANSI_colours['green'],
-                reset=self.ANSI_colours['reset']
+                reset=self.ANSI_colours['reset'],
+                file_type=file_type,
+                compression=compression,
             )
         )
 
diff --git a/binarycpython/utils/grid_options_defaults.py b/binarycpython/utils/grid_options_defaults.py
index 57f06c4e2..397e57923 100644
--- a/binarycpython/utils/grid_options_defaults.py
+++ b/binarycpython/utils/grid_options_defaults.py
@@ -36,7 +36,6 @@ grid_options_defaults_dict = {
     "_main_pid": -1,  # Placeholder for the main process id of the run.
     "save_ensemble_chunks" : True, # Force the ensemble chunk to be saved even if we are joining a thread (just in case the joining fails)
     "combine_ensemble_with_thread_joining": True,  # Flag on whether to combine everything and return it to the user or if false: write it to data_dir/ensemble_output_{population_id}_{thread_id}.json
-    "compress_ensemble":False, # compress the ensemble output?
     "_commandline_input": "",
     "log_runtime_systems": 0,  # whether to log the runtime of the systems (1 file per thread. stored in the tmp_dir)
     "_actually_evolve_system": True,  # Whether to actually evolve the systems of just act as if. for testing. used in _process_run_population_grid
diff --git a/requirements.txt b/requirements.txt
index bbae3d563..5ce17b531 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -22,6 +22,7 @@ jedi==0.15.1
 Jinja2==2.10.3
 kiwisolver==1.1.0
 lxml==4.5.0
+msgpack==1.0.2
 m2r==0.2.1
 MarkupSafe==1.1.1
 matplotlib==3.1.2
diff --git a/setup.py b/setup.py
index 5257ba540..820ff8a4c 100644
--- a/setup.py
+++ b/setup.py
@@ -263,7 +263,8 @@ setup(
         "colorama",
         "strip-ansi",
         "humanize",
-        "halo"
+        "halo",
+        "msgpack"
     ],
     include_package_data=True,
     ext_modules=[BINARY_C_PYTHON_API_MODULE],  # binary_c must be loaded
-- 
GitLab