From f730ab48f8b0db902d9a387d72d859cd3bbf7834 Mon Sep 17 00:00:00 2001
From: Robert Izzard <r.izzard@surrey.ac.uk>
Date: Tue, 19 Oct 2021 22:19:45 +0100
Subject: [PATCH] add progress bar to ensemble loader

---
 binarycpython/utils/functions.py | 35 +++++++++++++++++++++++++++++---
 requirements.txt                 |  1 +
 setup.py                         |  3 ++-
 3 files changed, 35 insertions(+), 4 deletions(-)

diff --git a/binarycpython/utils/functions.py b/binarycpython/utils/functions.py
index b46b12ec1..619354e78 100644
--- a/binarycpython/utils/functions.py
+++ b/binarycpython/utils/functions.py
@@ -18,9 +18,10 @@ import copy
 import datetime as dt
 import gc
 import gzip
-import inspect
+from halo import Halo
 import h5py
 import humanize
+import inspect
 from io import StringIO
 import json
 import numpy as np
@@ -33,6 +34,7 @@ import sys
 import subprocess
 import tempfile
 import time
+from tqdm import tqdm
 import types
 from typing import Union, Any
 
@@ -2253,10 +2255,14 @@ class BinaryCEncoder(json.JSONEncoder):
         # Let the base class default method raise the TypeError
         return json.JSONEncoder.default(self, o)
 
-def load_ensemble(filename,convert_float_keys=True):
+def load_ensemble(filename,convert_float_keys=True,select_keys=None):
     """
     Function to load an ensemeble file, even if it is compressed,
     and return its contents to as a Python dictionary.
+
+    Args:
+        convert_float_keys : if True, converts strings to floats.
+        select_keys : a list of keys to be selected from the ensemble.
     """
     if(filename.endswith('.bz2')):
         jfile = bz2.open(filename,'rt')
@@ -2264,7 +2270,30 @@ def load_ensemble(filename,convert_float_keys=True):
         jfile = gzip.open(filename,'rt')
     else:
         jfile = open(filename,'rt')
-    data = json.load(jfile)
+
+
+
+    # load with some info to the terminal
+    print("Loading JSON...")
+    _loaded = False
+    def _hook(obj):
+        nonlocal _loaded
+        if _loaded == False:
+            _loaded = True
+            print("\nLoaded JSON data, now putting in a dictionary")
+        return obj
+    with Halo(text='Loading', interval=250, spinner='moon',color='yellow'):
+        data = json.load(jfile,
+                         object_hook=_hook)
+
+    # strip non-selected keys, if a list is given in select_keys
+    if select_keys:
+        keys = list(data['ensemble'].keys())
+        for key in keys:
+            if not key in select_keys:
+                del data['ensemble'][key]
+
+    # perhaps convert floats?
     if convert_float_keys == False:
         return data
     else:
diff --git a/requirements.txt b/requirements.txt
index d2cfebe3b..bbae3d563 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -10,6 +10,7 @@ cycler==0.10.0
 decorator==4.4.1
 dill==0.3.1.1
 docutils==0.15.2
+halo==0.0.31
 h5py==2.10.0
 hawkmoth==0.4
 humanize==3.12.0
diff --git a/setup.py b/setup.py
index f33a42870..5257ba540 100644
--- a/setup.py
+++ b/setup.py
@@ -262,7 +262,8 @@ setup(
         "psutil",
         "colorama",
         "strip-ansi",
-        "humanize"
+        "humanize",
+        "halo"
     ],
     include_package_data=True,
     ext_modules=[BINARY_C_PYTHON_API_MODULE],  # binary_c must be loaded
-- 
GitLab