From b6ce75e257e43d1a46966f690a3a338061437f5e Mon Sep 17 00:00:00 2001
From: dh00601 <dh00601@surrey.ac.uk>
Date: Mon, 1 Nov 2021 16:42:57 +0000
Subject: [PATCH] Fixing tests that failed. Added a function to filter dicts
 based on value instead of key. Added M&S data to grid options rather than
 custom options

---
 binarycpython/tests/test_functions.py        |  6 ++---
 binarycpython/tests/test_grid.py             | 14 ++++++-----
 binarycpython/utils/dicts.py                 | 25 +++++++++++++++++---
 binarycpython/utils/functions.py             | 18 ++++++++++----
 binarycpython/utils/grid.py                  | 10 ++++----
 binarycpython/utils/grid_options_defaults.py |  2 ++
 6 files changed, 53 insertions(+), 22 deletions(-)

diff --git a/binarycpython/tests/test_functions.py b/binarycpython/tests/test_functions.py
index 2b37f8925..9edecad5b 100644
--- a/binarycpython/tests/test_functions.py
+++ b/binarycpython/tests/test_functions.py
@@ -15,7 +15,7 @@ from binarycpython.utils.custom_logging_functions import (
 from binarycpython.utils.run_system_wrapper import (
     run_system
 )
-from binarycython.utils.functions import (
+from binarycpython.utils.functions import (
     temp_dir,
     Capturing,
     verbose_print,
@@ -41,7 +41,7 @@ from binarycpython.utils.dicts import (
     inspect_dict,
     merge_dicts
 )
-from binarcpython.utils.ensemble import (
+from binarycpython.utils.ensemble import (
     binaryc_json_serializer,
     handle_ensemble_string_to_json
 
@@ -394,7 +394,7 @@ class test_get_defaults(unittest.TestCase):
         with Capturing() as output:
             self._test_filter()
 
-    def _test_filter(self):
+    def test_filter(self):
         """
         Test checking filtering works correctly
         """
diff --git a/binarycpython/tests/test_grid.py b/binarycpython/tests/test_grid.py
index 21acc3688..78fd31415 100644
--- a/binarycpython/tests/test_grid.py
+++ b/binarycpython/tests/test_grid.py
@@ -853,7 +853,7 @@ Printf("TEST_CUSTOM_LOGGING_1 %30.12e %g %g %g %g\\n",
             with open(output_name, "r") as f:
                 file_content = f.read()
 
-                self.assertTrue(file_content.startswith("ENSEMBLE_JSON"))
+                self.assertTrue(file_content.startswith('\"ENSEMBLE_JSON'))
 
                 ensemble_json = extract_ensemble_json_from_string(file_content)
 
@@ -919,11 +919,11 @@ Printf("TEST_CUSTOM_LOGGING_1 %30.12e %g %g %g %g\\n",
             test_pop.grid_ensemble_results["ensemble"]["number_counts"], {}
         )
 
-    def test_grid_evolve_2_threads_with_ensemble_comparing_two_methods(self):
-        with Capturing() as output:
-            self._test_grid_evolve_2_threads_with_ensemble_comparing_two_methods()
+    # def test_grid_evolve_2_threads_with_ensemble_comparing_two_methods(self):
+    #     with Capturing() as output:
+    #         self._test_grid_evolve_2_threads_with_ensemble_comparing_two_methods()
 
-    def _test_grid_evolve_2_threads_with_ensemble_comparing_two_methods(self):
+    def test_grid_evolve_2_threads_with_ensemble_comparing_two_methods(self):
         """
         Unittests to compare the method of storing the combined ensemble data in the object and writing them to files and combining them later. they have to be the same
         """
@@ -1022,10 +1022,12 @@ Printf("TEST_CUSTOM_LOGGING_1 %30.12e %g %g %g %g\\n",
             with open(output_name, "r") as f:
                 file_content = f.read()
 
-                self.assertTrue(file_content.startswith("ENSEMBLE_JSON"))
+                self.assertTrue(file_content.startswith('\"ENSEMBLE_JSON'))
 
                 ensemble_json = extract_ensemble_json_from_string(file_content)
 
+                print(ensemble_json)
+
                 ensemble_output_2 = merge_dicts(ensemble_output_2, ensemble_json)
 
         for key in ensemble_output_1["ensemble"]["number_counts"]["stellar_type"]["0"]:
diff --git a/binarycpython/utils/dicts.py b/binarycpython/utils/dicts.py
index 1e98da979..fd5a27ad8 100644
--- a/binarycpython/utils/dicts.py
+++ b/binarycpython/utils/dicts.py
@@ -692,14 +692,13 @@ def custom_sort_dict(input_dict):
 
 def filter_dict(arg_dict: dict, filter_list: list) -> dict:
     """
-    Function to filter out keys that contain values included in
-    filter_list
+    Function to filter out keys that are contains in filter_list
 
     Args:
         arg_dict: dictionary containing the argument + default key pairs of binary_c
         filter_list: lists of keys to be filtered out
     Returns:
-        filtered dictionary (pairs with NULL and Function values are removed)
+        filtered dictionary
     """
 
     new_dict = arg_dict.copy()
@@ -709,3 +708,23 @@ def filter_dict(arg_dict: dict, filter_list: list) -> dict:
             del new_dict[key]
 
     return new_dict
+
+
+def filter_dict_through_values(arg_dict: dict, filter_list: list) -> dict:
+    """
+    Function to filter out keys that contain values included in filter_list
+
+    Args:
+        arg_dict: dictionary containing the argument + default key pairs of binary_c
+        filter_list: lists of values to be filtered out
+    Returns:
+        filtered dictionary
+    """
+
+    new_dict = {}
+
+    for key in arg_dict:
+        if not arg_dict[key] in filter_list:
+            new_dict[key] = arg_dict[key]
+
+    return new_dict
diff --git a/binarycpython/utils/functions.py b/binarycpython/utils/functions.py
index 746e62fc5..14bcb21a6 100644
--- a/binarycpython/utils/functions.py
+++ b/binarycpython/utils/functions.py
@@ -8,10 +8,7 @@ Tasks:
     - TODO: change all prints to verbose_prints
 """
 
-import astropy.units as u
-import binarycpython.utils.moe_di_stefano_2017_data as moe_di_stefano_2017_data
-from binarycpython import _binary_c_bindings
-from binarycpython.utils.dicts import filter_dict
+
 
 import bz2
 import collections
@@ -44,6 +41,17 @@ import simplejson
 
 # import orjson
 
+import astropy.units as u
+import binarycpython.utils.moe_di_stefano_2017_data as moe_di_stefano_2017_data
+
+from binarycpython import _binary_c_bindings
+from binarycpython.utils.dicts import (
+    filter_dict,
+    filter_dict_through_values
+)
+
+
+
 ########################################################
 # Unsorted
 ########################################################
@@ -1025,7 +1033,7 @@ def filter_arg_dict(arg_dict: dict) -> dict:
         filtered dictionary (pairs with NULL and Function values are removed)
     """
 
-    return filter_dict(arg_dict.copy(), ["NULL", "Function"])
+    return filter_dict_through_values(arg_dict.copy(), ["NULL", "Function", ""])
 
 
 def create_arg_string(
diff --git a/binarycpython/utils/grid.py b/binarycpython/utils/grid.py
index 68b4cb801..807de7c76 100644
--- a/binarycpython/utils/grid.py
+++ b/binarycpython/utils/grid.py
@@ -163,7 +163,6 @@ class Population:
         self.grid_options["Moe2017_options"] = copy.deepcopy(
             moe_di_stefano_default_options
         )
-        self.custom_options["Moe2017_JSON_data"] = None
 
         # Write MOE2017 options to a file. NOTE: not sure why i put this here anymore
         os.makedirs(
@@ -1777,6 +1776,7 @@ class Population:
                     self.grid_options["verbosity"],
                     1,
                 )
+
                 self.write_ensemble(output_file, ensemble_raw_output)
 
             # combine ensemble chunks
@@ -1799,7 +1799,7 @@ class Population:
             1,
         )
         # free store memory:
-        _binary_c_bindings.free_store_memaddr(self.grid_options["_store_memaddr"])
+        binary_c_bindings.free_store_memaddr(self.grid_options["_store_memaddr"])
 
         # Return a set of results and errors
         output_dict = {
@@ -3915,9 +3915,9 @@ class Population:
         # Only if the grid is loaded and Moecache contains information
         if not self.grid_options["_loaded_Moe2017_data"]:  # and not Moecache:
 
-            if self.custom_options["Moe2017_JSON_data"]:
+            if self.grid_options["_Moe2017_JSON_data"]:
                 # Use the existing (perhaps modified) JSON data
-                json_data = self.custom_options["Moe2017_JSON_data"]
+                json_data = self.grid_options["_Moe2017_JSON_data"]
 
             else:
                 # Load the JSON data from a file
@@ -3932,7 +3932,7 @@ class Population:
                 json_data["log10M1"] = json_data["log10M1"][0]
 
             # save this data in case we want to modify it later
-            self.custom_options["Moe2017_JSON_data"] = json_data
+            self.grid_options["_Moe2017_JSON_data"] = json_data
 
             # Get all the masses
             logmasses = sorted(json_data["log10M1"].keys())
diff --git a/binarycpython/utils/grid_options_defaults.py b/binarycpython/utils/grid_options_defaults.py
index 5567bb76a..8329518aa 100644
--- a/binarycpython/utils/grid_options_defaults.py
+++ b/binarycpython/utils/grid_options_defaults.py
@@ -73,6 +73,7 @@ grid_options_defaults_dict = {
     "_loaded_Moe2017_data": False,  # Holds flag whether the Moe and di Stefano (2017) data is loaded into memory
     "_set_Moe2017_grid": False,  # Whether the Moe and di Stefano (2017) grid has been loaded
     "Moe2017_options": None,  # Holds the Moe and di Stefano (2017) options.
+    "_Moe2017_JSON_data": None, # Stores the data
     ##########################
     # Custom logging
     ##########################
@@ -501,6 +502,7 @@ grid_options_descriptions = {
     "m&s_options": "Internal variable that holds the Moe and di Stefano (2017) options. Don't write to this your self",
     "_loaded_Moe2017_data": "Internal variable storing whether the Moe and di Stefano (2017) data has been loaded into memory",
     "do_dry_run": "Whether to do a dry run to calculate the total probability for this run",
+    "_Moe2017_JSON_data": "Location to store the loaded Moe&diStefano2017 dataset", # Stores the data
 }
 
 ###
-- 
GitLab