From 5e1a21aee795927489ce3c30e6e1a65634804906 Mon Sep 17 00:00:00 2001
From: Izzard <ri0005@orca.eps.surrey.ac.uk>
Date: Sun, 28 Nov 2021 09:40:10 +0000
Subject: [PATCH] clean up open() functions to use self.auto() which does
 compression for us based on the file extension

---
 binarycpython/utils/HPC.py                   |   6 +-
 binarycpython/utils/condor.py                |   8 +-
 binarycpython/utils/dataIO.py                | 102 ++++++++++++-------
 binarycpython/utils/ensemble.py              |   4 +-
 binarycpython/utils/grid.py                  |  23 +++--
 binarycpython/utils/grid_options_defaults.py |  15 +--
 binarycpython/utils/gridcode.py              |   2 +-
 binarycpython/utils/slurm.py                 |   6 +-
 binarycpython/utils/spacing_functions.py     |   4 +-
 9 files changed, 101 insertions(+), 69 deletions(-)

diff --git a/binarycpython/utils/HPC.py b/binarycpython/utils/HPC.py
index 81e87e933..96d5afde4 100644
--- a/binarycpython/utils/HPC.py
+++ b/binarycpython/utils/HPC.py
@@ -150,7 +150,7 @@ class HPC(condor,slurm):
         joinlist = self.HPC_joinlist(joinlist=joinlist)
         try:
             self.wait_for_unlock(joinlist)
-            f = open(joinlist,'r',encoding='utf-8')
+            f = self.open(joinlist,'r',encoding='utf-8')
             list = f.read().splitlines()
             f.close()
 
@@ -410,7 +410,7 @@ class HPC(condor,slurm):
         if not filename:
             return None
         file = os.path.join(dir,filename)
-        f = open(file,"r",encoding='utf-8')
+        f = self.open(file,"r",encoding='utf-8')
         if not f:
             print("Error: could not open {file} to read the HPC jobid of the directory {dir}".format(file=file,
                                                                                                      dir=dir))
@@ -582,7 +582,7 @@ class HPC(condor,slurm):
         and (if given) the string passed in.
         """
         try:
-            f = open(filename,'w',encoding='utf-8')
+            f = self.open(filename,'w',encoding='utf-8')
             if f:
                 job = self.HPC_jobID()
                 jobtype = self.HPC_job_type()
diff --git a/binarycpython/utils/condor.py b/binarycpython/utils/condor.py
index de7d4983b..4365fa941 100644
--- a/binarycpython/utils/condor.py
+++ b/binarycpython/utils/condor.py
@@ -84,7 +84,7 @@ class condor():
         idfile = os.path.join(dir,
                               "ClusterID")
         if not os.path.exists(idfile):
-            with open(idfile,"w",encoding='utf-8') as fClusterID:
+            with self.open(idfile,"w",encoding='utf-8') as fClusterID:
                 fClusterID.write("{ClusterID}\n".format(ClusterID=self.grid_options['condor_ClusterID']))
                 fClusterID.close()
                 self.NFS_flush_hack(idfile)
@@ -92,7 +92,7 @@ class condor():
         # save condor status
         file = self.condor_status_file(dir=dir)
         if file:
-            with open(file,'w',encoding='utf-8') as f:
+            with self.open(file,'w',encoding='utf-8') as f:
                 f.write(string)
                 f.close()
                 self.NFS_flush_hack(file)
@@ -263,13 +263,13 @@ class condor():
 
             # open the files
             try:
-                submit_script = open(submit_script_path,'w',encoding='utf-8')
+                submit_script = self.open(submit_script_path,'w',encoding='utf-8')
             except IOError:
                 print("Could not open Condor script at {path} for writing: please check you have set {condor_dir} correctly (it is currently {condor_dir} and can write to this directory.".format(
                     path=submit_script_path,
                     condor_dir = self.grid_options['condor_dir']))
             try:
-                job_script = open(job_script_path,'w',encoding='utf-8')
+                job_script = self.open(job_script_path,'w',encoding='utf-8')
             except IOError:
                 print("Could not open Condor script at {path} for writing: please check you have set {condor_dir} correctly (it is currently {condor_dir} and can write to this directory.".format(
                     path=job_script_path,
diff --git a/binarycpython/utils/dataIO.py b/binarycpython/utils/dataIO.py
index 6d625b99b..dd1d98a19 100644
--- a/binarycpython/utils/dataIO.py
+++ b/binarycpython/utils/dataIO.py
@@ -334,9 +334,6 @@ class dataIO():
         # get the file type
         file_type = ensemble_file_type(output_file)
 
-        # choose compression algorithm based on file extension
-        compression = ensemble_compression(output_file)
-
         # default to using grid_ensemble_results if no data is given
         if data is None:
             data = self.grid_ensemble_results
@@ -346,43 +343,29 @@ class dataIO():
                 "Unable to determine file type from ensemble filename {} : it should be .json or .msgpack."
             ).format(output_file)
             self.exit(code=1)
-        elif file_type == "JSON":
-            # JSON output
-            if compression == "gzip":
-                # gzip
-                f = gzip.open(output_file, "wt", encoding=encoding)
-            elif compression == "bzip2":
-                # bzip2
-                f = bz2.open(output_file, "wt", encoding=encoding)
-            else:
-                # raw output (not compressed)
-                f = open(output_file, "wt", encoding=encoding)
-            f.write(json.dumps(data,
-                               sort_keys=sort_keys,
-                               indent=indent,
-                               ensure_ascii=ensure_ascii))
-
-        elif file_type == "msgpack":
-            # msgpack output
-            if compression == "gzip":
-                f = gzip.open(output_file, "wb", encoding=encoding)
-            elif compression == "bzip2":
-                f = bz2.open(output_file, "wb", encoding=encoding)
-            else:
-                f = open(output_file, "wb", encoding=encoding)
-            msgpack.dump(data, f)
-        f.close()
+        else:
+            f = self.open(output_file, "wt", encoding=encoding)
+            if file_type == "JSON":
+                # JSON output
+                f.write(json.dumps(data,
+                                   sort_keys=sort_keys,
+                                   indent=indent,
+                                   ensure_ascii=ensure_ascii))
+            elif file_type == "msgpack":
+                # msgpack output
+                msgpack.dump(data, f)
+            f.close()
 
         print(
-            "Thread {thread}: Wrote ensemble results to file: {colour}{file}{reset} (file type {file_type}, compression {compression})".format(
+            "Thread {thread}: Wrote ensemble results to file: {colour}{file}{reset} (file type {file_type})".format(
                 thread=self.process_ID,
                 file=output_file,
                 colour=self.ANSI_colours["green"],
                 reset=self.ANSI_colours["reset"],
                 file_type=file_type,
-                compression=compression,
             )
         )
+        
     def write_binary_c_calls_to_file(
             self,
             output_dir: Union[str, None] = None,
@@ -466,7 +449,7 @@ class dataIO():
             print("Writing binary_c calls to {}".format(binary_c_calls_full_filename))
 
             # Write to file
-            with open(binary_c_calls_full_filename, "w", encoding=encoding) as file:
+            with self.open(binary_c_calls_full_filename, "w", encoding=encoding) as file:
                 # Get defaults and clean them, then overwrite them with the set values.
                 if include_defaults:
                     # TODO: make sure that the defaults here are cleaned up properly
@@ -504,7 +487,7 @@ class dataIO():
                 self.grid_options["status_dir"],
                 format_statment.format(ID),
             )
-            with open(
+            with self.open(
                     path,
                     "w",
                     encoding='utf-8'
@@ -641,10 +624,10 @@ class dataIO():
                     try:
                         if vb:
                             print("Try to open file at {}".format(filename))
-                        f = open(filename,
-                                 mode="w",
-                                 encoding=encoding,
-                                 **kwargs)
+                        f = self.open(filename,
+                                      mode="w",
+                                      encoding=encoding,
+                                      **kwargs)
                         if vb:
                             print("Return locked file {}, {}".format(f,lock))
                         return (f,lock)
@@ -678,3 +661,48 @@ class dataIO():
         dir = os.path.dirname(filename)
         os.scandir(dir)
         
+
+    def compression_type(self,filename):
+        """
+        Return the compression type of the ensemble file, based on its filename extension.
+            """
+        if filename.endswith(".bz2"):
+            return "bzip2"
+        elif filename.endswith(".gz"):
+            return "gzip"
+        else:
+            return None
+
+    def open(self,file, mode='r', buffering=- 1, encoding=None, errors=None, newline=None, closefd=True, opener=None, compression=None, compresslevel=None)
+        """
+        Wrapper for open() with automatic compression based on the file extension.
+        """
+        if compression is None:
+            compression = compression_type(file)
+        if compression:
+            if compresslevel is None:
+                compresslevel = 9
+            if compression is "bzip2":
+                file_object = bz2.open(file,
+                                       mode=mode,
+                                       compresslevel=compresslevel,
+                                       encoding=encoding,
+                                       error=errors,
+                                       newline=newline)
+            elif compression is "gzip":
+                file_object = gzip.open(file,
+                                        mode=mode,
+                                        compresslevel=compresslevel,
+                                        encoding=encoding,
+                                        error=errors,
+                                        newline=newline)
+        else:
+            file_object = open(file,
+                               mode=mode,
+                               buffering=buffering,
+                               encoding=encoding,
+                               errors=errors,
+                               newline=newline,
+                               closefd=closefd,
+                               opener=opener)                               
+        return file_object
diff --git a/binarycpython/utils/ensemble.py b/binarycpython/utils/ensemble.py
index a1d85aa9f..2a5d9db42 100644
--- a/binarycpython/utils/ensemble.py
+++ b/binarycpython/utils/ensemble.py
@@ -91,8 +91,8 @@ def open_ensemble(filename,encoding='utf-8'):
 
 def ensemble_compression(filename):
     """
-    Return the compression type of the ensemble file, based on its filename extension.
-    """
+        Return the compression type of the ensemble file, based on its filename extension.
+            """
     if filename.endswith(".bz2"):
         return "bzip2"
     elif filename.endswith(".gz"):
diff --git a/binarycpython/utils/grid.py b/binarycpython/utils/grid.py
index b4a1ea4f5..4b1ec7011 100644
--- a/binarycpython/utils/grid.py
+++ b/binarycpython/utils/grid.py
@@ -170,7 +170,7 @@ class Population(analytics,
             os.path.join(self.grid_options["tmp_dir"],
                          "moe_distefano"), exist_ok=True
         )
-        with open(
+        with self.open(
             os.path.join(
                 os.path.join(self.grid_options["tmp_dir"],
                              "moe_distefano"),
@@ -611,16 +611,17 @@ class Population(analytics,
                         0
                     ]
 
-                settings_name = base_name + "_settings.json"
+                # save settings as gzipped JSON
+                settings_name = base_name + "_settings.json.gz"
 
                 # Check directory, make if necessary
                 os.makedirs(self.custom_options["data_dir"], exist_ok=True)
 
                 settings_fullname = os.path.join(
-                    self.custom_options["data_dir"], settings_name
+                    self.custom_options["data_dir"],
+                    settings_name
                 )
 
-
                 # open locked settings file, then output if we get the lock
                 (f,lock) = self.locked_open_for_write(settings_fullname)
 
@@ -657,7 +658,7 @@ class Population(analytics,
                 )
                 raise ValueError
 
-            with open(outfile, "w") as file:
+            with self.open(outfile, "w") as file:
                 json.dump(
                     all_info_cleaned,
                     file,
@@ -1529,7 +1530,7 @@ class Population(analytics,
             # that was on, we log each current system to a file (each thread has one).
             # Each new system overrides the previous
             if self.grid_options["log_args"]:
-                with open(
+                with self.open(
                         os.path.join(
                             self.grid_options["log_args_dir"],
                             "current_system",
@@ -1571,7 +1572,7 @@ class Population(analytics,
 
             # Debug line: logging all the lines
             if self.grid_options["log_runtime_systems"] == 1:
-                with open(
+                with self.open(
                         os.path.join(
                             self.grid_options["tmp_dir"],
                             "runtime_systems",
@@ -1779,7 +1780,7 @@ class Population(analytics,
             ],
             "zero_prob_stars_skipped": zero_prob_stars_skipped,
         }
-        with open(
+        with self.open(
                 os.path.join(
                     self.grid_options["tmp_dir"],
                     "process_summary",
@@ -2121,7 +2122,7 @@ class Population(analytics,
 
         # We can choose to perform a check on the source file, which checks if the lines start with 'binary_c'
         if check:
-            source_file_check_filehandle = open(
+            source_file_check_filehandle = self.open(
                 self.grid_options["source_file_filename"],
                 "r",
                 encoding='utf-8'
@@ -2138,7 +2139,7 @@ class Population(analytics,
                 )
                 raise ValueError
 
-        source_file_filehandle = open(self.grid_options["source_file_filename"],
+        source_file_filehandle = self.open(self.grid_options["source_file_filename"],
                                       "r",
                                       encoding='utf-8')
 
@@ -2281,7 +2282,7 @@ class Population(analytics,
                 else:
                     # Write arg lines to file
                     argstring = self._return_argline(system_dict)
-                    with open(
+                    with self.open(
                             os.path.join(
                                 self.grid_options["tmp_dir"],
                                 "failed_systems",
diff --git a/binarycpython/utils/grid_options_defaults.py b/binarycpython/utils/grid_options_defaults.py
index c4fa23815..742c2037c 100644
--- a/binarycpython/utils/grid_options_defaults.py
+++ b/binarycpython/utils/grid_options_defaults.py
@@ -200,7 +200,7 @@ class grid_options_defaults():
             ########################################
             "HPC_force_join" : 0, # if True, and the HPC variable ("slurm" or "condor") is 3, skip checking our own job and force the join
             "HPC_rebuild_joinlist": 0, # if True, ignore the joinlist we would usually use and rebuild it automatically
-
+            
             ########################################
             # Slurm stuff
             ########################################
@@ -407,7 +407,8 @@ class grid_options_defaults():
     # Grid options functions
 
     # Utility functions
-    def grid_options_help(option: str) -> dict:
+    def grid_options_help(self,
+                          option: str) -> dict:
         """
         Function that prints out the description of a grid option. Useful function for the user.
 
@@ -440,7 +441,8 @@ class grid_options_defaults():
                 return {option: grid_options_descriptions[option]}
 
 
-    def grid_options_description_checker(print_info: bool = True) -> int:
+    def grid_options_description_checker(self,
+                                         print_info: bool = True) -> int:
         """
         Function that checks which descriptions are missing
 
@@ -473,9 +475,10 @@ class grid_options_defaults():
         return len(undescribed_keys)
 
 
-    def write_grid_options_to_rst_file(output_file: str) -> None:
+    def write_grid_options_to_rst_file(self,
+                                       output_file: str) -> None:
         """
-        Function that writes the descriptions of the grid options to a rst file
+        Function that writes the descriptions of the grid options to an rst file
 
         Tasks:
             TODO: separate things into private and public options
@@ -500,7 +503,7 @@ class grid_options_defaults():
             print("Filename doesn't end with .rst, please provide a proper filename")
             return None
 
-        with open(output_file, "w") as f:
+        with self.open(output_file, "w") as f:
             print("Population grid code options", file=f)
             print("{}".format("=" * len("Population grid code options")), file=f)
             print(
diff --git a/binarycpython/utils/gridcode.py b/binarycpython/utils/gridcode.py
index 8ec0c40ca..7c9fb8207 100644
--- a/binarycpython/utils/gridcode.py
+++ b/binarycpython/utils/gridcode.py
@@ -644,7 +644,7 @@ class gridcode():
             1,
         )
 
-        with open(gridcode_filename, "w",encoding='utf-8') as file:
+        with self.open(gridcode_filename, "w",encoding='utf-8') as file:
             file.write(self.code_string)
 
         # perhaps create symlink
diff --git a/binarycpython/utils/slurm.py b/binarycpython/utils/slurm.py
index 5a6a15201..5dcd148b0 100644
--- a/binarycpython/utils/slurm.py
+++ b/binarycpython/utils/slurm.py
@@ -83,7 +83,7 @@ class slurm():
             dir = self.grid_options["slurm_dir"]
         idfile = os.path.join(dir,"jobid")
         if not os.path.exists(idfile):
-            with open(idfile,"w",encoding='utf-8') as fjobid:
+            with self.open(idfile,"w",encoding='utf-8') as fjobid:
                 fjobid.write("{jobid}\n".format(jobid=self.grid_options['slurm_jobid']))
                 fjobid.close()
                 self.NFS_flush_hach(idfile)
@@ -91,7 +91,7 @@ class slurm():
         # save slurm status
         file = self.slurm_status_file(dir=dir)
         if file:
-            with open(file,'w',encoding='utf-8') as f:
+            with self.open(file,'w',encoding='utf-8') as f:
                 f.write(string)
                 f.close()
                 self.NFS_fluch_hack(file)
@@ -254,7 +254,7 @@ class slurm():
             # make slurm script
             scriptpath = self.slurmpath('slurm_script')
             try:
-                script = open(scriptpath,'w',encoding='utf-8')
+                script = self.open(scriptpath,'w',encoding='utf-8')
             except IOError:
                 print("Could not open Slurm script at {path} for writing: please check you have set {slurm_dir} correctly (it is currently {slurm_dir} and can write to this directory.".format(path=scriptpath,
                                                                                                                                                                                                 slurm_dir = self.grid_options['slurm_dir']))
diff --git a/binarycpython/utils/spacing_functions.py b/binarycpython/utils/spacing_functions.py
index c1d368b8b..61f90e73c 100644
--- a/binarycpython/utils/spacing_functions.py
+++ b/binarycpython/utils/spacing_functions.py
@@ -295,7 +295,7 @@ class spacing_functions():
         # make it a wrapped function that just returns the
         # _const_dt function acting on its arguments
         def __dummy_decorator(func):
-            @wraps(func)
+            @functools.wraps(func)
             def wrapped(*args, **kwargs):
                 return func(*args, **kwargs)
             return wrapped
@@ -304,7 +304,7 @@ class spacing_functions():
         else:
             __decorator = __dummy_decorator
 
-        @__decorator()
+        @__decorator
         def _const_dt(cachedir=None,
                       num_cores=None,
                       bse_options_json=None, # JSON string
-- 
GitLab