From f90c861372ce3f97424c618d460e8847fb3bd344 Mon Sep 17 00:00:00 2001
From: Robert Izzard <r.izzard@surrey.ac.uk>
Date: Wed, 10 Nov 2021 21:28:07 +0000
Subject: [PATCH] code clean up

---
 binarycpython/utils/grid.py | 150 +++++++++++++++++++++---------------
 1 file changed, 87 insertions(+), 63 deletions(-)

diff --git a/binarycpython/utils/grid.py b/binarycpython/utils/grid.py
index e8412d0f8..4ab3e5496 100644
--- a/binarycpython/utils/grid.py
+++ b/binarycpython/utils/grid.py
@@ -242,8 +242,12 @@ class Population:
         return jobID
 
     def exit(self,code=None,message=True,flush=True,stacktrace=False):
-        # wrapper for sys.exit() to return the correct exit code
-
+        """
+        Exit function: use this to exit from a Population object.
+        Really it's just a wrapper for sys.exit() to return the correct exit code,
+        but also to post a message (if message is True, default is True)
+        and perhaps a stacktrace (if stacktrace is True, default is False).
+        """
         # if we've been killed, set exit code to 1
         if self.grid_options['exit_code']==0 and self.grid_options['_killed']:
             self.grid_options['exit_code'] = 1
@@ -534,6 +538,13 @@ class Population:
         self,
         name: str,
     ) -> None:
+        """
+        Function to delete a grid variable with the given name.
+
+        Args:
+            name:
+                name of the grid variable to be deleted.
+        """
         try:
             del self.grid_options["_grid_variables"][name]
             verbose_print(
@@ -693,7 +704,6 @@ class Population:
             "parameter_name": parameter_name,
             "longname": longname,
             "valuerange": valuerange,
-            # "resolution": 0,
             "samplerfunc": samplerfunc,
             "precode": precode,
             "postcode": postcode,
@@ -1077,11 +1087,14 @@ class Population:
     # Evolution functions
     ###################################################
 
-    def _set_ncores(self):
-        # If num_cores <= 0, set automatically
-        #
-        # if num_cores is 0, we use as many as we have available
+    def _set_nprocesses(self):
+        """
+        Function to set the number of processes used in multiprocessing.
 
+        If grid_options['num_cores'] <= 0, set automatically
+
+        If grid_options['num_cores'] is 0, we use as many as we have available
+        """
         # backwards compatibility
         if "amt_cores" in self.grid_options:
             self.grid_options["num_processes"] = self.grid_options["amt_cores"]
@@ -1106,8 +1119,8 @@ class Population:
         self.grid_results = AutoVivificationDict()
         self.grid_ensemble_results = {}
 
-        # set number of cores we want to use
-        self._set_ncores()
+        # set number of processes/cores we want to use
+        self._set_nprocesses()
 
         # Reset the process ID (should not have a value initially, but can't hurt if it does)
         self.process_ID = 0
@@ -1202,12 +1215,7 @@ class Population:
         if self.grid_options['slurm'] == 1 and \
            self.grid_options['slurm_restart_dir'] and \
            self.grid_options['slurm_jobarrayindex'] != None:
-            f = open(os.path.join(self.grid_options['slurm_restart_dir'],'jobid'))
-            oldjobid = f.read().strip()
-            f.close()
-            if not oldjobid:
-                print("Error: could not find jobid in {}".format(self.grid_options['slurm_restart_dir']))
-                self.exit(code=1)
+            oldjobid = self.slurm_jobid_from_dir(self.grid_options['slurm_restart_dir'])
 
             print("Restart from dir {} which was jobid {}, we are jobarrayindex {}".format(
                 self.grid_options['slurm_restart_dir'],
@@ -1263,10 +1271,6 @@ class Population:
 
         Returns an dictionary containing the analytics of the run
         """
-#        signal.signal(signal.SIGTERM,
-#                      functools.partial(self._parent_signal_handler,{'where':'evolve'}))
-#        signal.signal(signal.SIGINT,
-#                      functools.partial(self._parent_signal_handler,{'where':'evolve'}))
 
         # Just to make sure we don't have stuff from a previous run hanging around
         self._pre_run_setup()
@@ -1282,6 +1286,7 @@ class Population:
         elif self.grid_options["slurm"] == 1:
             # Slurm setup grid
             self.slurm_grid()
+
             # and then exit
             print("Slurm jobs launched : exiting")
             self.exit(code=0)
@@ -1289,7 +1294,7 @@ class Population:
             # Execute population evolution subroutines
             self._evolve_population()
 
-        print("do analytics")
+        print("Do analytics")
 
         if self.grid_options['do_analytics']:
             # Put all interesting stuff in a variable and output that afterwards, as analytics of the run.
@@ -1330,8 +1335,8 @@ class Population:
             # we must save a snapshot, not the population object
             self.grid_options['start_at'] = self.grid_options["_count"]
             self.save_snapshot()
-            code = 1 if self.was_killed() else 0
-            self.exit(code=code)
+            exitcode = 1 if self.was_killed() else 0
+            self.exit(code=exitcode)
 
         # Save object to a pickle file
         elif self.grid_options['save_population_object']:
@@ -1503,7 +1508,7 @@ class Population:
         This will generate the systems until it is full, and then keeps trying to fill it.
         Will have to play with the size of this.
 
-        This function is called in the parent process.
+        This function is called as part of the parent process.
         """
 
         stream_logger = self._get_stream_logger()
@@ -1607,7 +1612,6 @@ class Population:
         We read out the information in the result queue and store them in the grid object
         """
 
-
         # Set process name
         setproctitle.setproctitle("binarycpython parent process")
 
@@ -1656,10 +1660,10 @@ class Population:
                                   num_processes=self.grid_options["num_processes"])
 
         # Join the processes
-        print("Do join")
+        print("Do join...")
         for p in processes:
             p.join()
-        print("Joined")
+        print("Joined.")
 
         keylist = ["_failed_count",
                    "_failed_prob",
@@ -1690,6 +1694,8 @@ class Population:
                 except Exception as e:
                     print("Tried to set combined_output_dict key",x,"from preloaded_popuation, but this failed:",e)
             print("Pre-loaded data from {} stars".format(combined_output_dict["_count"]))
+            self.preloaded_population = None
+            gc.collect()
         else:
             # new empty combined output
             combined_output_dict = OrderedDict()
@@ -1811,11 +1817,6 @@ class Population:
         Signal handling function for the parent process.
         """
 
-        if 'queue' in signal_data:
-            q = signal_data['queue']
-        else:
-            q = None
-
         # this function is called by both queues when they
         # catch a signal
         sigstring = signal.Signals(signum).name
@@ -1848,19 +1849,24 @@ class Population:
             self.custom_options['save_snapshot'] = True
             self.grid_options['_killed'] = True
 
-            return
         else:
             # what to do?
-            return
+            pass # an else block containing only a comment is a syntax error
+        return
 
 
     def _child_signal_handler(self,signal_data,signum,frame):
+        """
+        Signal handler for child processes.
+        """
         sigstring = signal.Signals(signum).name
 
         if sigstring in self.signal_count:
             self.signal_count[sigstring] += 1
         else:
             self.signal_count[sigstring] = 1
+
+        # if we receive the same signal more than three times, exit
         if self.signal_count[sigstring] > 3:
             print("caught > 3 times : exit")
             self.exit(code=2)
@@ -1872,6 +1878,8 @@ class Population:
             ','.join(signal_data.keys())
         ))
 
+        # SIGINT should stop the queue: this is
+        # what Slurm sends to end a process
         if signum == signal.SIGINT:
             self.grid_options['stop_queue'] = True
             self.grid_options['_killed'] = True
@@ -1960,9 +1968,7 @@ class Population:
             0  # counter for the actual amt of systems this thread ran
         )
         zero_prob_stars_skipped = 0
-
         total_time_calling_binary_c = 0
-
         total_mass_run = 0
         total_probability_weighted_mass_run = 0
 
@@ -2166,18 +2172,13 @@ class Population:
                 print("Child: Stop queue at system {n}".format(n=number_of_systems_run))
                 break
 
-
         if self.grid_options['stop_queue']:
-            print("Child: FLUSH JOB QUEUE")
             # any remaining jobs should be ignored
             try:
                 while True:
                     job_queue.get_nowait()
             except queue.Empty:
                 pass
-            print("Child: FLUSHED JOB QUEUE")
-
-        print("Child : Q finished",flush=True)
 
         # Set status to finishing
         self.set_status("finishing")
@@ -2610,16 +2611,12 @@ class Population:
         """
 
         # Reset values
-        self.grid_options["_count"] = 0
-        self.grid_options["_probtot"] = 0
+        for x in ["_count","_probtot","_failed_count","_failed_prob","_total_mass_run","_total_probability_weighted_mass_run"]:
+            self.grid_options[x] = 0
+        for x in ["_errors_found","_errors_exceeded"]:
+            self.grid_options[x] = False
         self.grid_options["_system_generator"] = None
-        self.grid_options["_failed_count"] = 0
-        self.grid_options["_failed_prob"] = 0
-        self.grid_options["_errors_found"] = False
-        self.grid_options["_errors_exceeded"] = False
         self.grid_options["_failed_systems_error_codes"] = []
-        self.grid_options["_total_mass_run"] = 0
-        self.grid_options["_total_probability_weighted_mass_run"] = 0
 
         # Remove files
         # TODO: remove files
@@ -2637,6 +2634,9 @@ class Population:
     # a variable grid
     ###################################################
     def _gridcode_filename(self):
+        """
+        Returns a filename for the gridcode.
+        """
         if self.grid_options['slurm'] > 0:
             filename = os.path.join(
                 self.grid_options["tmp_dir"],
@@ -2675,14 +2675,12 @@ class Population:
         """
         return an indent block, with n extra blocks in it
         """
-
         return (self.indent_depth + n) * self.indent_string
 
     def _increment_indent_depth(self, delta):
         """
         increment the indent indent_depth by delta
         """
-
         self.indent_depth += delta
 
     def _generate_grid_code(self, dry_run=False):
@@ -2706,7 +2704,6 @@ class Population:
 
         Results in a generated file that contains a system_generator function.
         """
-
         verbose_print("Generating grid code", self.grid_options["verbosity"], 1)
 
         total_grid_variables = len(self.grid_options["_grid_variables"])
@@ -3409,12 +3406,9 @@ class Population:
         """
         Function to go through the source_file and count the number of lines and the total probability
         """
-
         system_generator = self.grid_options["_system_generator"]
-
         total_starcount = 0
         total_probability = 0
-
         contains_probability = False
 
         for line in system_generator:
@@ -3466,7 +3460,6 @@ class Population:
         """
         Function that creates a dict from a binary_c arg line
         """
-
         if line.startswith("binary_c "):
             line = line.replace("binary_c ", "")
 
@@ -5231,6 +5224,7 @@ eccentricity3=0
             with open(file,'w') as f:
                 f.write(string)
                 f.close()
+        return
 
     def get_slurm_status(self,
                          jobid=None,
@@ -5426,6 +5420,8 @@ eccentricity3=0
                                                                                                                     jobarrayindex=jobarrayindex,
                                                                                                                     joinfile=joinfile),
                 "\n# run grid of stars and, if this returns 0, set status to finished\n",
+
+                # note: the next line ends in &&
                 "{grid_command} evolution_type=grid slurm_jobid={jobid} slurm_jobarrayindex={jobarrayindex} save_population_object={slurm_dir}/results/{jobid}.{jobarrayindex}.gz && echo -n \"finished\" > {slurm_dir}/status/{jobid}.{jobarrayindex} && \\\n".format(
                     slurm_dir=self.grid_options['slurm_dir'],
                     jobid=jobid,
@@ -5435,6 +5431,8 @@ eccentricity3=0
 
             if not self.grid_options['slurm_postpone_join']:
                 lines += [
+                    # the following line also ends in && so that if one fails, the rest
+                    # also fail
                     "echo && echo \"Checking if we can join...\" && echo && \\\n",
                     "{grid_command} slurm=2 evolution_type=join joinlist={joinfile} slurm_jobid={jobid} slurm_jobarrayindex={jobarrayindex}\n\n".format(
                         grid_command=grid_command,
@@ -5564,6 +5562,8 @@ eccentricity3=0
             # touch 'saved' file
             pathlib.Path(filename + '.saved').touch(exist_ok=True)
 
+        return
+
     def load_population_object(self,filename):
         """
         returns the Population object loaded from filename
@@ -5673,7 +5673,7 @@ eccentricity3=0
 
     def joinfiles(self):
         """
-        Function to load in the joinlist to an array
+        Function to load in the joinlist to an array and return it.
         """
         f = open(self.grid_options['joinlist'],'r')
         list = f.read().splitlines()
@@ -5681,18 +5681,21 @@ eccentricity3=0
         return list
 
     def join_from_files(self,newobj,joinfiles):
-        # merge the results from many object files
-        # into newobj
+        """
+        Merge the results from the list joinfiles into newobj.
+        """
         for file in joinfiles:
-            print("join data in",file)
+            print("Join data in",file)
             self.merge_populations_from_file(newobj,
                                              file)
         return newobj
 
     def can_join(self,joinfiles,joiningfile,vb=False):
-        # check the joinfiles to make sure they all exist
-        # and their .saved equivalents also exist
-        vb= True
+        """
+        Check the joinfiles to make sure they all exist
+        and their .saved equivalents also exist.
+        Pass vb=True to print why joining is not possible (for debugging).
+        """
         if os.path.exists(joiningfile):
             if vb:
                 print("cannot join: joiningfile exists at {}".format(joiningfile))
@@ -5731,7 +5734,10 @@ eccentricity3=0
         return True
 
     def add_system_metadata(self):
-
+        """
+        Add system's metadata to the grid_ensemble_results, and
+        add some system information to metadata.
+        """
         # add metadata if it doesn't exist
         if not "metadata" in self.grid_ensemble_results:
             self.grid_ensemble_results["metadata"] = {}
@@ -5765,6 +5771,7 @@ eccentricity3=0
         except Exception as e:
             print("Failure to calculate time elapsed:",e)
             pass
+
         try:
             self.grid_ensemble_results["metadata"]['CPU_time'] = self.CPU_time()
         except Exception as e:
@@ -5828,3 +5835,20 @@ eccentricity3=0
         except:
             pass
         return killed
+
+    def slurm_jobid_from_dir(self,dir):
+        """
+        Return the Slurm jobid stored in the 'jobid' file of the directory dir.
+        Calls self.exit(code=1) if the file cannot be read or holds no jobid.
+        """
+        jobid_file = os.path.join(dir,'jobid')
+        try:
+            with open(jobid_file,"r") as f:
+                oldjobid = f.read().strip()
+        except OSError:
+            print("Error: could not open {} to read the Slurm jobid of the directory {}".format(jobid_file,dir))
+            self.exit(code=1)
+        if not oldjobid:
+            print("Error: could not find jobid in {}".format(dir))
+            self.exit(code=1)
+        return oldjobid
-- 
GitLab