From f445490ef0e4feceb02f4c81287cd714d9166bdd Mon Sep 17 00:00:00 2001
From: Robert Izzard <r.izzard@surrey.ac.uk>
Date: Sat, 27 Nov 2021 14:01:18 +0000
Subject: [PATCH] a few cleanups and attempts to parallelize loops (that
 failed)

---
 binarycpython/utils/grid.py                  | 19 +++-
 binarycpython/utils/grid_logging.py          | 84 ++++++++++-------
 binarycpython/utils/grid_options_defaults.py |  3 +
 binarycpython/utils/gridcode.py              | 95 ++++++++++++++------
 binarycpython/utils/spacing_functions.py     | 20 ++++-
 5 files changed, 157 insertions(+), 64 deletions(-)

diff --git a/binarycpython/utils/grid.py b/binarycpython/utils/grid.py
index c604de20f..b4a1ea4f5 100644
--- a/binarycpython/utils/grid.py
+++ b/binarycpython/utils/grid.py
@@ -208,7 +208,8 @@ class Population(analytics,
         self.process_ID = 0
 
         # Create location to store results. Users should write to this dictionary.
-        # The AutoVivificationDict allows for perls method of accessing possibly non-existant subdicts
+        # The AutoVivificationDict allows for Perl-like addition of possibly
+        # non-existent subdicts.
         self.grid_results = AutoVivificationDict()
 
         # Create location where ensemble results are written to
@@ -835,7 +836,7 @@ class Population(analytics,
 
         if self.HPC_job():
             self.HPC_dump_status("HPC grid after analytics")
-        
+
         if self.custom_options['save_snapshot']:
             # we must save a snapshot, not the population object
             # ... also save the new starting point: this has to take into
@@ -850,7 +851,7 @@ class Population(analytics,
         # Save object to a pickle file
         elif self.grid_options['save_population_object']:
             self.save_population_object()
-            
+
         # if we're running an HPC grid, exit here
         # unless we're joining
         if self.HPC_job() and \
@@ -882,7 +883,8 @@ class Population(analytics,
         ############################################################
         # Prepare code/initialise grid.
         # set custom logging, set up store_memaddr, build grid code. dry run grid code.
-        self._setup()
+        if self._setup() is False:
+            return
 
         ############################################################
         # Evolve systems
@@ -1876,6 +1878,10 @@ class Population(analytics,
         Since we have different methods of running a population, this setup function
         will do different things depending on different settings
 
+        Returns:
+        True if we want to continue.
+        False if we should return to the original calling script.
+
         Tasks:
             TODO: Make other kinds of populations possible. i.e, read out type of grid,
                 and set up accordingly
@@ -1979,6 +1985,9 @@ class Population(analytics,
                 if self.grid_options["exit_after_dry_run"]:
                     print("Exiting after dry run {}".format(self.grid_options["exit_after_dry_run"]))
                     self.exit(code=0)
+                elif self.grid_options["return_after_dry_run"]:
+                    print("Returning after dry run {}".format(self.grid_options["return_after_dry_run"]))
+                    return False
 
             #######################
             # Reset values and prepare the grid function
@@ -2041,6 +2050,8 @@ class Population(analytics,
             "_probtot"
         ] = 0  # To make sure that the values are reset. TODO: fix this in a cleaner way
 
+        return True
+
     def _cleanup(self):
         """
         Function that handles all the cleaning up after the grid has been generated and/or run
diff --git a/binarycpython/utils/grid_logging.py b/binarycpython/utils/grid_logging.py
index c97a91484..b695da149 100644
--- a/binarycpython/utils/grid_logging.py
+++ b/binarycpython/utils/grid_logging.py
@@ -247,37 +247,59 @@ class grid_logging():
             system_string += "P=" + format_number(system_dict["orbital_period"])
 
         # do the print
-        self.verbose_print(
-            "{opening_colour}{system_number}/{total_starcount}{modulo} {pc_colour}{pc_complete:5.1f}% complete {time_colour}{hours:02d}:{minutes:02d}:{seconds:02d} {ETA_colour}ETA={ETA:7.1f}{units} tpr={tpr:2.2e} {ETF_colour}ETF={ETF} {mem_use_colour}mem:{mem_use:.1f}MB {system_string_colour}{system_string}{closing_colour}".format(
-                opening_colour=self.ANSI_colours["reset"]
-                + self.ANSI_colours["yellow on black"],
-                system_number=system_number,
-                total_starcount=self.grid_options["_total_starcount"],
-                modulo=modulo,
-                pc_colour=self.ANSI_colours["blue on black"],
-                pc_complete=(100.0 * system_number)
-                / (1.0 * self.grid_options["_total_starcount"])
-                if self.grid_options["_total_starcount"]
-                else -1,
-                time_colour=self.ANSI_colours["green on black"],
-                hours=localtime.tm_hour,
-                minutes=localtime.tm_min,
-                seconds=localtime.tm_sec,
-                ETA_colour=self.ANSI_colours["red on black"],
-                ETA=eta,
-                units=units,
-                tpr=tpr,
-                ETF_colour=self.ANSI_colours["blue"],
-                ETF=etf,
-                mem_use_colour=self.ANSI_colours["magenta"],
-                mem_use=total_mem_use,
-                system_string_colour=self.ANSI_colours["yellow"],
-                system_string=system_string,
-                closing_colour=self.ANSI_colours["reset"],
-            ),
-            self.grid_options["verbosity"],
-            1,
-        )
+        if self.grid_options["_total_starcount"] > 0:
+            self.verbose_print(
+                "{opening_colour}{system_number}/{total_starcount}{modulo} {pc_colour}{pc_complete:5.1f}% complete {time_colour}{hours:02d}:{minutes:02d}:{seconds:02d} {ETA_colour}ETA={ETA:7.1f}{units} tpr={tpr:2.2e} {ETF_colour}ETF={ETF} {mem_use_colour}mem:{mem_use:.1f}MB {system_string_colour}{system_string}{closing_colour}".format(
+                    opening_colour=self.ANSI_colours["reset"]
+                    + self.ANSI_colours["yellow on black"],
+                    system_number=system_number,
+                    total_starcount=self.grid_options["_total_starcount"],
+                    modulo=modulo,
+                    pc_colour=self.ANSI_colours["blue on black"],
+                    pc_complete=(100.0 * system_number)
+                    / (1.0 * self.grid_options["_total_starcount"])
+                    if self.grid_options["_total_starcount"]
+                    else -1,
+                    time_colour=self.ANSI_colours["green on black"],
+                    hours=localtime.tm_hour,
+                    minutes=localtime.tm_min,
+                    seconds=localtime.tm_sec,
+                    ETA_colour=self.ANSI_colours["red on black"],
+                    ETA=eta,
+                    units=units,
+                    tpr=tpr,
+                    ETF_colour=self.ANSI_colours["blue"],
+                    ETF=etf,
+                    mem_use_colour=self.ANSI_colours["magenta"],
+                    mem_use=total_mem_use,
+                    system_string_colour=self.ANSI_colours["yellow"],
+                    system_string=system_string,
+                    closing_colour=self.ANSI_colours["reset"],
+                ),
+                self.grid_options["verbosity"],
+                1,
+            )
+        else:
+            self.verbose_print(
+                "{opening_colour}{system_number}{modulo} {time_colour}{hours:02d}:{minutes:02d}:{seconds:02d} tpr={tpr:2.2e} {mem_use_colour}mem:{mem_use:.1f}MB {system_string_colour}{system_string}{closing_colour}".format(
+                    opening_colour=self.ANSI_colours["reset"]
+                    + self.ANSI_colours["yellow on black"],
+                    system_number=system_number,
+                    modulo=modulo,
+                    time_colour=self.ANSI_colours["green on black"],
+                    hours=localtime.tm_hour,
+                    minutes=localtime.tm_min,
+                    seconds=localtime.tm_sec,
+                    tpr=tpr,
+                    mem_use_colour=self.ANSI_colours["magenta"],
+                    mem_use=total_mem_use,
+                    system_string_colour=self.ANSI_colours["yellow"],
+                    system_string=system_string,
+                    closing_colour=self.ANSI_colours["reset"],
+                ),
+                self.grid_options["verbosity"],
+                1,
+            )
 
     def vb2print(self, system_dict, cmdline_string):
         print(
diff --git a/binarycpython/utils/grid_options_defaults.py b/binarycpython/utils/grid_options_defaults.py
index 9b242f985..c4fa23815 100644
--- a/binarycpython/utils/grid_options_defaults.py
+++ b/binarycpython/utils/grid_options_defaults.py
@@ -55,8 +55,10 @@ class grid_options_defaults():
             "_zero_prob_stars_skipped": 0,
             "ensemble_factor_in_probability_weighted_mass": False,  # Whether to multiply the ensemble results by 1/probability_weighted_mass
             "do_dry_run": True,  # Whether to do a dry run to calculate the total probability for this run
+            "dry_run_num_cores" : 1, # number of parallel processes for the dry run (outer loop)
             "dry_run_hook" : None, # Function hook for the dry run: this function is called, if not None, for every star in the dry run. Useful for checking initial distributions.
             "custom_generator": None,  # Place for the custom system generator
+            "return_after_dry_run" : False,# Return immediately after a dry run?
             "exit_after_dry_run": False,  # Exit after dry run?
             "print_stack_on_exit" : False, # print the stack trace on exit calls?
 
@@ -394,6 +396,7 @@ class grid_options_defaults():
             "_loaded_Moe2017_data": "Internal variable storing whether the Moe and di Stefano (2017) data has been loaded into memory",
             "do_dry_run": "Whether to do a dry run to calculate the total probability for this run",
             "dry_run_hook" : "Function hook to be called for every system in a dry run. The function is passed a dict of the system parameters. Does nothing if None (the default).",
+            "return_after_dry_run" : "If True, return immediately after a dry run (and don't run actual stars). Default is False.",
             "exit_after_dry_run": "If True, exits after a dry run. Default is False.",
             "print_stack_on_exit" : "If True, prints a stack trace when the population's exit method is called.",
             "_Moe2017_JSON_data": "Location to store the loaded Moe&diStefano2017 dataset",  # Stores the data
diff --git a/binarycpython/utils/gridcode.py b/binarycpython/utils/gridcode.py
index 2e14da981..8ec0c40ca 100644
--- a/binarycpython/utils/gridcode.py
+++ b/binarycpython/utils/gridcode.py
@@ -9,6 +9,7 @@ from typing import Union, Any
 
 
 _count = 0 # used for file symlinking (for testing only)
+_numba = False # activate experimental numba code?
 
 class gridcode():
 
@@ -104,6 +105,7 @@ class gridcode():
             "from binarycpython.utils.distribution_functions import *\n",
             "from binarycpython.utils.spacing_functions import *\n",
             "from binarycpython.utils.useful_funcs import *\n",
+            "import numba" if _numba else "",
             "\n\n",
             # Make the function
             "def grid_code(self, print_results=True):\n",
@@ -233,16 +235,44 @@ class gridcode():
                         gridtype=grid_variable['gridtype'],
                     )
                 )
-            self._add_code(
-                "for {name}_sample_number in range({start},len(sampled_values_{name})+{offset}):".format(
-                    name=grid_variable["name"],
-                    offset=offset,
-                    start=start
-                )
-                + "\n"
+
+            stop = "len(sampled_values_{name})+{offset}".format(
+                name=grid_variable["name"],
+                offset=offset
             )
 
-            self._increment_indent_depth(+1)
+            if _numba and grid_variable["dry_parallel"]:
+                # Parallel outer loop
+                self._add_code(
+                    "@numba.jit(parallel=True)\n"
+                    )
+                self._add_code(
+                    "def __parallel_func(phasevol,_total_starcount):\n"
+                    )
+                self._increment_indent_depth(+1)
+                self._add_code(
+                    "for {name}_sample_number in numba.prange({stop}):\n".format(
+                        name=grid_variable["name"],
+                        stop = stop,
+                    ))
+                self._increment_indent_depth(+1)
+                if start > 0:
+                    self._add_code(
+                        "if {name}_sample_number < {start}:\n".format(
+                            name=grid_variable["name"],
+                            start = start,
+                        ))
+                    self._add_code(
+                        "continue\n",indent=1
+                    )
+            else:
+                self._add_code(
+                    "for {name}_sample_number in range({start},{stop}):\n".format(
+                        name=grid_variable["name"],
+                        start = start,
+                        stop = stop,
+                    ))
+                self._increment_indent_depth(+1)
 
             # {}_this_index is this grid point's index
             # {}_prev_index and {}_next_index are the previous and next grid points,
@@ -538,6 +568,10 @@ class gridcode():
 
             self._increment_indent_depth(-2)
 
+            if _numba and grid_variable["dry_parallel"]:
+                self._add_code("__parallel_func(phasevol,_total_starcount)\n")
+                self._increment_indent_depth(-1)
+
             # Check the branchpoint part here. The branchpoint makes sure that we can construct
             # a grid with several multiplicities and still can make the system calls for each
             # multiplicity without reconstructing the grid each time
@@ -701,9 +735,12 @@ class gridcode():
 
         # If its a dry run, dont do anything with it
         else:
+            # run the hook function, only if given
             if self.grid_options['dry_run_hook']:
-                self._add_code("self.grid_options['dry_run_hook'](parameter_dict)\n",indent=1)
-            self._add_code("pass\n", indent=1)
+                self._add_code("self.grid_options['dry_run_hook'](self,parameter_dict)\n",indent=1)
+            else:
+                # or pass
+                self._add_code("pass\n", indent=1)
 
         self._add_code("#" * 40 + "\n")
 
@@ -842,22 +879,24 @@ class gridcode():
             raise ValueError(msg)
 
     def add_grid_variable(
-        self,
-        name: str,
-        parameter_name: str,
-        longname: str,
-        valuerange: Union[list, str],
-        samplerfunc: str,
-        probdist: str,
-        dphasevol: Union[str, int] = -1,
-        gridtype: str = "centred",
-        branchpoint: int = 0,
-        branchcode: Union[str, None] = None,
-        precode: Union[str, None] = None,
-        postcode: Union[str, None] = None,
-        topcode: Union[str, None] = None,
-        bottomcode: Union[str, None] = None,
-        condition: Union[str, None] = None,
+            self,
+            name: str,
+            parameter_name: str,
+            longname: str,
+            valuerange: Union[list, str],
+            samplerfunc: str,
+            probdist: str,
+            dphasevol: Union[str, int] = -1,
+            gridtype: str = "centred",
+            branchpoint: int = 0,
+            branchcode: Union[str, None] = None,
+            precode: Union[str, None] = None,
+            postcode: Union[str, None] = None,
+            topcode: Union[str, None] = None,
+            bottomcode: Union[str, None] = None,
+            condition: Union[str, None] = None,
+            index: Union[int, None] = None,
+            dry_parallel: Union[bool, None] = False,
     ) -> None:
         """
         Function to add grid variables to the grid_options.
@@ -937,6 +976,9 @@ class gridcode():
                 the lower edge of the value range) or 'centred'
                 (steps starting at lower edge + 0.5 * stepsize).
 
+            dry_parallel:
+                If True, try to parallelize this variable in dry runs.
+
             topcode:
                 Code added at the very top of the block.
 
@@ -968,6 +1010,7 @@ class gridcode():
             "topcode": topcode,
             "bottomcode": bottomcode,
             "grid_variable_number": len(self.grid_options["_grid_variables"]),
+            "dry_parallel": dry_parallel
         }
 
         # Check for gridtype input
diff --git a/binarycpython/utils/spacing_functions.py b/binarycpython/utils/spacing_functions.py
index e07fcdeb7..c1d368b8b 100644
--- a/binarycpython/utils/spacing_functions.py
+++ b/binarycpython/utils/spacing_functions.py
@@ -223,7 +223,7 @@ class spacing_functions():
 
         """
         print("Cache dir {}".format(self.grid_options['cache_dir']))
-        if cachedir == None:
+        if cachedir is not None:
             cachedir = self.grid_options['cache_dir'] + '/const_dt_cache'
             cache = diskcache.Cache(cachedir)
         else:
@@ -290,8 +290,21 @@ class spacing_functions():
                              showtable=showtable,
                              usecache=usecache)
 
+        # Choose the decorator factory for _const_dt. Both options are
+        # called as @__decorator(): cache.memoize() returns a decorator
+        # that memoizes to disc, and the dummy factory below returns one
+        # that hands the function back unchanged (no caching).
+        def __dummy_decorator(*_args, **_kwargs):
+            def passthrough(func):
+                # nothing to cache: use the function as it is
+                return func
+            return passthrough
         if cache:
-            eval('@cache.memoize()') # memoize to disc
+            __decorator = cache.memoize
+        else:
+            __decorator = __dummy_decorator
+
+        @__decorator()
         def _const_dt(cachedir=None,
                       num_cores=None,
                       bse_options_json=None, # JSON string
@@ -562,7 +575,8 @@ class spacing_functions():
                                       bse_options=self.bse_options,
                                       **kwargs
         )
-        cache.close()
+        if cache:
+            cache.close()
 
         if kwargs.get('showlist',True):
             print("const_dt mass list ({} masses)\n".format(len(mass_list)), mass_list)
-- 
GitLab