From 3fc53a63e5581ce32b998a82031901ad771d4a0f Mon Sep 17 00:00:00 2001
From: Robert Izzard <r.izzard@surrey.ac.uk>
Date: Tue, 9 Nov 2021 22:39:41 +0000
Subject: [PATCH] can now restart from a snapshot saved after a SIGINT

Doesn't work for Slurm yet, but small steps.

Also fixed a bug in keys_to_floats().
---
 binarycpython/utils/dicts.py                 |  25 +-
 binarycpython/utils/grid.py                  | 326 ++++++++++++-------
 binarycpython/utils/grid_options_defaults.py |   3 +
 3 files changed, 235 insertions(+), 119 deletions(-)

diff --git a/binarycpython/utils/dicts.py b/binarycpython/utils/dicts.py
index 60825dff2..36198a52d 100644
--- a/binarycpython/utils/dicts.py
+++ b/binarycpython/utils/dicts.py
@@ -28,24 +28,33 @@ def keys_to_floats(json_data):
     new_data = type(json_data)()
 
     for k, v in json_data.items():
+        # convert the key to a float, if we can,
+        # otherwise leave it as it is
+        try:
+            newkey = float(k)
+        except (ValueError, TypeError):
+            newkey = k
+
+        # act on value(s)
         if isinstance(v, list):
-            v = [
+            # list data
+            new_data[newkey] = [
                 keys_to_floats(item)
                 if isinstance(item, collections.abc.Mapping)
                 else item
                 for item in v
             ]
         elif isinstance(v, collections.abc.Mapping):
-            # dict, ordereddict, etc.
-            v = keys_to_floats(v)
-        try:
-            f = float(k)
-            new_data[f] = json_data[k]
-        except:
-            new_data[k] = v
+            # dict, ordereddict, etc. data
+            new_data[newkey] = keys_to_floats(v)
+        else:
+            # assume all other data are scalars
+            new_data[newkey] = v
+
     return new_data
 
+
 def recursive_change_key_to_float(input_dict):
     """
     Function to recursively change the key to float
diff --git a/binarycpython/utils/grid.py b/binarycpython/utils/grid.py
index 152a76610..18bd0dc84 100644
--- a/binarycpython/utils/grid.py
+++ b/binarycpython/utils/grid.py
@@ -26,6 +26,7 @@ import compress_pickle
 import copy
 import datasize
 import datetime
+import functools
 import json
 import gc
 import gzip
@@ -43,6 +44,7 @@ import py_rinterpolate
 import re
 import resource
 import setproctitle
+import signal
 import stat
 import strip_ansi
 import subprocess
@@ -104,9 +106,8 @@ from binarycpython.utils.dicts import (
     merge_dicts,
     multiply_float_values,
     multiply_values_dict,
-    recursive_change_key_to_float,
-    recursive_change_key_to_string,
     update_dicts,
+    keys_to_floats
 )
 
 # from binarycpython.utils.hpc_functions import (
@@ -153,6 +154,8 @@ class Population:
         self.special_params = [
             el for el in list(self.defaults.keys()) if el.endswith("%d")
         ]
+        self.preloaded_population = None
+        self.signal_count = {}
 
         # make the input dictionary
         self.bse_options = {}  # bse_options is just empty.
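A quick sketch (illustrative, not part of the patch) of what the rewritten keys_to_floats() above does: keys that parse as numbers become floats at every level of nesting, including for dicts inside lists, while non-numeric keys are left alone.

    from binarycpython.utils.dicts import keys_to_floats

    data = {
        "1": {"2.5": "x"},        # nested dict: both keys become floats
        "mass": [{"10": 1}, 42],  # dicts inside lists are converted too
    }
    print(keys_to_floats(data))
    # {1.0: {2.5: 'x'}, 'mass': [{10.0: 1}, 42]}
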
@@ -161,7 +164,10 @@ class Population:
         self.grid_options = copy.deepcopy(grid_options_defaults_dict)
 
         # Custom options
-        self.custom_options = {}
+        self.custom_options = {
+            'stop_queue' : False,
+            'save_snapshot' : False
+        }
 
         # grid code generation
         self.indent_depth = 0
@@ -224,6 +230,16 @@ class Population:
 
         # add metadata
         self.add_system_metadata()
 
+    def jobID(self):
+        # job ID
+        if self.grid_options['slurm'] > 0:
+            jobID = "{}.{}".format(self.grid_options['jobid'],
+                                   self.grid_options['slurm_jobarrayindex'])
+        else:
+            jobID = "{}".format(self.process_ID)
+        return jobID
+
     ###################################################
     # Argument functions
     ###################################################
@@ -1035,6 +1051,12 @@ class Population:
         # If num_cores <= 0, set automatically
         #
         # if num_cores is 0, we use as many as we have available
+
+        # backwards compatibility
+        if "amt_cores" in self.grid_options:
+            self.grid_options["num_processes"] = self.grid_options["amt_cores"]
+            self.grid_options["num_cores"] = self.grid_options["amt_cores"]
+
         if self.grid_options['num_cores'] == 0:
             # use all logical cores available to us
             self.grid_options['num_processes'] = max(1,psutil.cpu_count(logical=True))
@@ -1137,7 +1159,7 @@ class Population:
             sys.exit(1)
 
         # Make sure the subdirs of the tmp dir exist
-        subdirs = ['failed_systems','current_system','process_summary','runtime_systems']
+        subdirs = ['failed_systems','current_system','process_summary','runtime_systems','snapshots']
         for subdir in subdirs:
             path = os.path.join(self.grid_options["tmp_dir"], subdir)
             os.makedirs(path,exist_ok=True)
@@ -1220,7 +1242,7 @@ class Population:
             "zero_prob_stars_skipped": self.grid_options["_zero_prob_stars_skipped"],
         }
 
-        if 'metadata' in self.grid_ensemble_results:
+        if "metadata" in self.grid_ensemble_results:
             # Add analytics dict to the metadata too:
             self.grid_ensemble_results["metadata"].update(analytics_dict)
             self.add_system_metadata()
@@ -1231,8 +1253,14 @@ class Population:
         except:
             analytics_dict = {}  # should never happen
 
+        if self.custom_options['save_snapshot']:
+            # we must save a snapshot, not the population object
+            self.grid_options['start_at'] = self.grid_options["_count"]
+            self.save_snapshot()
+            sys.exit()
+
         # Save object to a pickle file
-        if self.grid_options['save_population_object']:
+        elif self.grid_options['save_population_object']:
             self.save_population_object()
 
         # if we're running a slurm grid, exit here
@@ -1310,6 +1338,8 @@ class Population:
                 )
             )
 
+
+
         # finished!
         self.grid_options["_end_time_evolution"] = time.time()
@@ -1321,7 +1351,7 @@ class Population:
         )
         string2 = "It took a total of {dtsecs} to run {starcount} systems on {ncores} cores\n = {totaldtsecs} of CPU time.\nMaximum memory use {memuse:.3f} MB".format(
             dtsecs=timedelta(dtsecs),
-            starcount=self.grid_options["_total_starcount"],
+            starcount=self.grid_options["_count"],  # not _total_starcount! we may have ended the run early...
            ncores=self.grid_options["num_processes"],
             totaldtsecs=timedelta(dtsecs * self.grid_options["num_processes"]),
             memuse=sum(self.shared_memory["max_memory_use_per_thread"]),
@@ -1404,16 +1434,24 @@ class Population:
         )
 
         # Continuously fill the queue
-
+        signal.signal(signal.SIGTERM,
+                      functools.partial(self._signal_handler))
+        signal.signal(signal.SIGINT,
+                      functools.partial(self._signal_handler))
 
         # start_at can be an expression : we should eval it
         # prior to running the loop
         self.grid_options['start_at'] = eval(str(self.grid_options['start_at']))
+        if self.grid_options['start_at'] > 0:
+            print("Starting at model {}".format(self.grid_options['start_at']))
 
         for system_number, system_dict in enumerate(generator):
+            if self.custom_options['stop_queue']:
+                print("QUEUE DETECTED STOP")
+
             # skip systems before start_at
-            if system_number < self.grid_options["start_at"]:
+            elif system_number < self.grid_options["start_at"]:
                 verbose_print("skip system {n} because < start_at = {start}".format(
                     n=system_number,
                     start=self.grid_options["start_at"]),
@@ -1469,6 +1507,8 @@ class Population:
         When all the systems have been put in the queue we pass a STOP signal
         that will make the processes wrap up.
 
+        We then merge in any previously loaded population (e.g. from a snapshot).
+
         We read out the information in the result queue and store them in the grid object
         """
 
@@ -1479,11 +1519,6 @@ class Population:
         manager = multiprocessing.Manager()
         job_queue = manager.Queue(maxsize=self.grid_options["max_queue_size"])
 
-        # backwards compatibility
-        if "amt_cores" in self.grid_options:
-            self.grid_options["num_processes"] = self.grid_options["amt_cores"]
-            self.grid_options["num_cores"] = self.grid_options["amt_cores"]
-
         result_queue = manager.Queue(maxsize=self.grid_options["num_processes"])
 
         # Create process instances
@@ -1507,13 +1542,47 @@ class Population:
         for p in processes:
             p.join()
 
+        keylist = ["_failed_count",
+                   "_failed_prob",
+                   "_errors_exceeded",
+                   "_errors_found",
+                   "_probtot",
+                   "_count",
+                   "_total_mass_run",
+                   "_total_probability_weighted_mass_run",
+                   "_zero_prob_stars_skipped"]
+        # todo: error codes
+
         # Handle the results by merging all the dictionaries. How that merging happens exactly is
         # described in the merge_dicts description.
-        combined_output_dict = OrderedDict()
+
+        if self.preloaded_population:
+            combined_output_dict = {
+                "ensemble_results" : keys_to_floats(self.preloaded_population.grid_ensemble_results),
+                "results": keys_to_floats(self.preloaded_population.grid_results)
+            }
+
+            for x in keylist:
+                try:
+                    combined_output_dict[x] = self.preloaded_population.grid_options[x]
+                except Exception as e:
+                    print("Failed to pre-load grid option",x,":",e)
+            print("Pre-loaded data from {} stars".format(combined_output_dict["_count"]))
+        else:
+            combined_output_dict = OrderedDict()
 
         sentinel = object()
         for output_dict in iter(result_queue.get, sentinel):
-            combined_output_dict = merge_dicts(combined_output_dict, output_dict)
+
+            # don't let Xinit be added more than once
+            if "ensemble_results" in combined_output_dict and \
+               "ensemble" in combined_output_dict["ensemble_results"] and \
+               "Xinit" in combined_output_dict["ensemble_results"]["ensemble"]:
+                del combined_output_dict["ensemble_results"]["ensemble"]["Xinit"]
+
+            # merge dicts
+            combined_output_dict = merge_dicts(combined_output_dict,
+                                               keys_to_floats(output_dict))
             if result_queue.empty():
                 break
@@ -1523,11 +1592,6 @@ class Population:
         )
         gc.collect()
 
-        # Take into account that we run this on multiple cores
-        combined_output_dict[
-            "_total_probability_weighted_mass_run"
-        ] = combined_output_dict["_total_probability_weighted_mass_run"]
-
         # Put the values back as object properties
         self.grid_results = combined_output_dict["results"]
@@ -1539,15 +1603,9 @@ class Population:
 
         # Add metadata
         self.grid_ensemble_results["metadata"] = {}
-        self.grid_ensemble_results["metadata"]["population_id"] = self.grid_options[
-            "_population_id"
-        ]
-        self.grid_ensemble_results["metadata"][
-            "total_probability_weighted_mass"
-        ] = combined_output_dict["_total_probability_weighted_mass_run"]
-        self.grid_ensemble_results["metadata"][
-            "factored_in_probability_weighted_mass"
-        ] = False
+        self.grid_ensemble_results["metadata"]["population_id"] = self.grid_options["_population_id"]
+        self.grid_ensemble_results["metadata"]["total_probability_weighted_mass"] = combined_output_dict["_total_probability_weighted_mass_run"]
+        self.grid_ensemble_results["metadata"]["factored_in_probability_weighted_mass"] = False
         if self.grid_options["ensemble_factor_in_probability_weighted_mass"]:
             multiply_values_dict(
                 self.grid_ensemble_results["ensemble"],
                 self.grid_ensemble_results["metadata"][
                     "total_probability_weighted_mass"
                 ],
             )
-            self.grid_ensemble_results["metadata"][
-                "factored_in_probability_weighted_mass"
-            ] = True
+            self.grid_ensemble_results["metadata"]["factored_in_probability_weighted_mass"] = True
 
         # Add settings of the populations
         all_info = self.return_all_info(
@@ -1573,22 +1629,10 @@ class Population:
         ##############################
         # Update grid options
 
-        self.grid_options["_failed_count"] = combined_output_dict["_failed_count"]
-        self.grid_options["_failed_prob"] = combined_output_dict["_failed_prob"]
-        self.grid_options["_failed_systems_error_codes"] = list(
-            set(combined_output_dict["_failed_systems_error_codes"])
-        )
-        self.grid_options["_errors_exceeded"] = combined_output_dict["_errors_exceeded"]
-        self.grid_options["_errors_found"] = combined_output_dict["_errors_found"]
-        self.grid_options["_probtot"] = combined_output_dict["_probtot"]
-        self.grid_options["_count"] = combined_output_dict["_count"]
-        self.grid_options["_total_mass_run"] = combined_output_dict["_total_mass_run"]
-        self.grid_options[
-            "_total_probability_weighted_mass_run"
-        ] = combined_output_dict["_total_probability_weighted_mass_run"]
self.grid_options["_zero_prob_stars_skipped"] = combined_output_dict[ - "_zero_prob_stars_skipped" - ] + for x in keylist: + self.grid_options[x] = combined_output_dict[x] + self.grid_options["_failed_systems_error_codes"] = list(set(combined_output_dict["_failed_systems_error_codes"])) + def _evolve_system_mp(self, full_system_dict): """ @@ -1628,6 +1672,35 @@ class Population: self.custom_options["parameter_dict"] = full_system_dict self.grid_options["parse_function"](self, out) + def _signal_handler(self,signum,frame): + """ + Signal handling function. + """ + sigstring = signal.Signals(signum).name + + if sigstring in self.signal_count: + self.signal_count[sigstring] += 1 + else: + self.signal_count[sigstring] = 1 + + # tell the user what has happened + print("Signal {} caught by process {} count {}".format(sigstring, + self.jobID(), + self.signal_count[sigstring])) + + if signum == signal.SIGINT: + self.custom_options['stop_queue'] = True + self.custom_options['save_snapshot'] = True + if self.signal_count[sigstring] > 3: + print("caught > 3 times : exit") + sys.exit() + + return + + else: + # what to do? + return + def _process_run_population_grid(self, job_queue, result_queue, ID): """ Worker process that gets items from the job_queue and runs those systems. @@ -1643,10 +1716,15 @@ class Population: # set start timer start_process_time = datetime.datetime.now() - # - self.process_ID = ( - ID # Store the ID as a object property again, lets see if that works. - ) + # set the process ID + self.process_ID = ( ID ) + print("Set process ID",self.process_ID) + + # set handler to catch SIGINT and SIGTERM and exit gracefully + signal.signal(signal.SIGTERM, + functools.partial(self._signal_handler)) + signal.signal(signal.SIGINT, + functools.partial(self._signal_handler)) stream_logger = self._get_stream_logger() if self.grid_options["verbosity"] >= _LOGGER_VERBOSITY_LEVEL: @@ -1656,7 +1734,6 @@ class Population: name = "binarycpython population thread {}".format(ID) name_proc = "binarycpython population process {}".format(ID) setproctitle.setproctitle(name_proc) - # setproctitle.setthreadtitle(name) # Set to starting up self.set_status("starting") @@ -1751,16 +1828,6 @@ class Population: ) raise ValueError(msg) - # self._print_info( - # i + 1, self.grid_options["_total_starcount"], full_system_dict - # ) - - # verbose_print( - # "Process {} is handling system {}".format(ID, system_number), - # self.grid_options["verbosity"], - # 1, - # ) - ###################### # Print status of runs # save the current time (used often) @@ -1906,6 +1973,10 @@ class Population: total_mass_system * full_system_dict.get("probability", 1) ) + if self.custom_options['stop_queue']: + print("Stop queue at system {n}".format(n=number_of_systems_run)) + break + # Set status to finishing self.set_status("finishing") @@ -1938,6 +2009,10 @@ class Population: self.grid_options["verbosity"], 1, ) + ensemble_output = None + else: + # convert ensemble_raw_output to a dictionary + ensemble_output = extract_ensemble_json_from_string(ensemble_raw_output) # save the ensemble chunk to a file if ( @@ -1959,8 +2034,6 @@ class Population: 1, ) - ensemble_output = extract_ensemble_json_from_string(ensemble_raw_output) - self.write_ensemble(output_file, ensemble_output) # combine ensemble chunks @@ -1970,10 +2043,7 @@ class Population: self.grid_options["verbosity"], 1, ) - - ensemble_json["ensemble"] = extract_ensemble_json_from_string( - ensemble_raw_output - ) # Load this into a dict so that we can combine it later + 
+                ensemble_json["ensemble"] = ensemble_output
 
         ##########################
         # Clean up and return
@@ -2077,15 +2147,12 @@ class Population:
             self.grid_options["verbosity"],
             1,
         )
+
         result_queue.put(output_dict)
 
         if self.grid_options["verbosity"] >= _LOGGER_VERBOSITY_LEVEL:
             stream_logger.debug(f"Process-{self.process_ID} is finished.")
 
-        # Don't do this : Clean up the interpolators if they exist
-
-        # TODO: make a cleanup function for the individual threads
-        # TODO: make sure this is necessary. Actually its probably not, because we have a centralised queue
         verbose_print(
             "process {} return ".format(ID),
             self.grid_options["verbosity"],
@@ -2161,6 +2228,9 @@ class Population:
         function
         """
 
+        # Check for restore
+        if self.grid_options['restore_from_snapshot_file']:
+            self.load_snapshot(self.grid_options['restore_from_snapshot_file'])
 
         # Check for parse function
         if not self.grid_options["parse_function"]:
@@ -2339,12 +2409,6 @@ class Population:
         self.grid_options["_total_mass_run"] = 0
         self.grid_options["_total_probability_weighted_mass_run"] = 0
 
-        # Xinit is overcounted
-        if 'Xinit' in self.grid_ensemble_results['ensemble']:
-            multiply_float_values(self.grid_ensemble_results['ensemble']['Xinit'],
-                                  1.0/float(self.grid_options['num_processes']))
-
-
         # Remove files
         # TODO: remove files
@@ -5017,7 +5081,7 @@ eccentricity3=0
         # make a list of directories, these contain the various slurm
         # output, status files, etc.
         dirs = []
-        for dir in ['stdout','stderr','results','status']:
+        for dir in ['stdout','stderr','results','status','snapshots']:
             dirs.append(self.slurmpath(dir))
 
         # make the directories: we do not allow these to already exist
@@ -5253,6 +5317,10 @@ eccentricity3=0
                 probtot=object.grid_options['_probtot'],
                 filename=filename))
 
+
+        # Some parts of the object cannot be pickled:
+        # remove them, and restore them after pickling
+
         # remove shared memory
         shared_memory = object.shared_memory
         object.shared_memory = None
@@ -5261,31 +5329,42 @@ eccentricity3=0
         system_generator = object.grid_options["_system_generator"]
         object.grid_options["_system_generator"] = None
 
+        # delete _store_memaddr
+        _store_memaddr = object.grid_options['_store_memaddr']
+        object.grid_options['_store_memaddr'] = None
+
+        # delete persistent_data_memory_dict
+        persistent_data_memory_dict = object.persistent_data_memory_dict
+        object.persistent_data_memory_dict = None
+
         # add metadata if it doesn't exist
-        if not 'metadata' in object.grid_ensemble_results:
-            object.grid_ensemble_results['metadata'] = {}
+        if not "metadata" in object.grid_ensemble_results:
+            object.grid_ensemble_results["metadata"] = {}
 
         # add datestamp
-        object.grid_ensemble_results['metadata']['save_population_time'] = datetime.datetime.now().strftime("%m/%d/%Y %H:%M:%S")
+        object.grid_ensemble_results["metadata"]['save_population_time'] = datetime.datetime.now().strftime("%m/%d/%Y %H:%M:%S")
 
         # add extra metadata
         object.add_system_metadata()
 
         # add max memory use
         try:
-            self.grid_ensemble_results['metadata']['max_memory_use'] = copy.deepcopy(sum(shared_memory["max_memory_use_per_thread"]))
+            self.grid_ensemble_results["metadata"]['max_memory_use'] = copy.deepcopy(sum(shared_memory["max_memory_use_per_thread"]))
         except Exception as e:
            print("save_population_object : Error: ",e)
           pass
 
         # dump pickle file
         compress_pickle.dump(object,
-                             filename)
+                             filename,
+                             pickler_method='dill')
 
         # restore data
         object.shared_memory = shared_memory
         object.grid_options["_system_generator"] = system_generator
-        del object.grid_ensemble_results["metadata"]['save_population_time']
+        object.grid_options['_store_memaddr'] = _store_memaddr
+        object.persistent_data_memory_dict = persistent_data_memory_dict
 
         # touch 'saved' file
         pathlib.Path(filename + '.saved').touch(exist_ok=True)
@@ -5298,12 +5377,15 @@ eccentricity3=0
             obj = None
         else:
             try:
-                obj = compress_pickle.load(filename)
+                obj = compress_pickle.load(filename,
+                                           pickler_method='dill')
             except Exception as e:
                 obj = None
 
         return obj
 
+
+
     def merge_populations(self,refpop,newpop):
         """
         merge newpop's results data into refpop's results data
@@ -5328,8 +5410,8 @@ eccentricity3=0
 
         # special cases
         try:
-            maxmem = max(refpop.grid_ensemble_results['metadata']['max_memory_use'],
-                         newpop.grid_ensemble_results['metadata']['max_memory_use'])
+            maxmem = max(refpop.grid_ensemble_results["metadata"]['max_memory_use'],
+                         newpop.grid_ensemble_results["metadata"]['max_memory_use'])
         except:
             maxmem = 0
 
@@ -5337,11 +5419,11 @@ eccentricity3=0
         # special cases:
         # copy the settings and Xinit: these should just be overridden
         try:
-            settings = copy.deepcopy(newpop.grid_ensemble_results['metadata']['settings'])
+            settings = copy.deepcopy(newpop.grid_ensemble_results["metadata"]['settings'])
         except:
             settings = None
         try:
-            Xinit = copy.deepcopy(newpop.grid_ensemble_results['ensemble']['Xinit'])
+            Xinit = copy.deepcopy(newpop.grid_ensemble_results["ensemble"]["Xinit"])
         except:
             Xinit = None
 
@@ -5351,11 +5433,11 @@ eccentricity3=0
 
         # set special cases
         try:
-            refpop.grid_ensemble_results['metadata']['max_memory_use'] = maxmem
+            refpop.grid_ensemble_results["metadata"]['max_memory_use'] = maxmem
             if settings:
-                refpop.grid_ensemble_results['metadata']['settings'] = settings
+                refpop.grid_ensemble_results["metadata"]['settings'] = settings
             if Xinit:
-                refpop.grid_ensemble_results['ensemble']['Xinit'] = Xinit
+                refpop.grid_ensemble_results["ensemble"]["Xinit"] = Xinit
         except:
             pass
 
@@ -5454,40 +5536,62 @@ eccentricity3=0
 
     def add_system_metadata(self):
         # add metadata if it doesn't exist
-        if not 'metadata' in self.grid_ensemble_results:
-            self.grid_ensemble_results['metadata'] = {}
+        if not "metadata" in self.grid_ensemble_results:
+            self.grid_ensemble_results["metadata"] = {}
 
         # add date
-        self.grid_ensemble_results['metadata']['date'] = datetime.datetime.now().strftime("%m/%d/%Y %H:%M:%S")
+        self.grid_ensemble_results["metadata"]['date'] = datetime.datetime.now().strftime("%m/%d/%Y %H:%M:%S")
 
         # add platform and build information
-        print("Try to write platform")
         try:
-            self.grid_ensemble_results['metadata']['platform'] = platform.platform()
-            self.grid_ensemble_results['metadata']['platform_uname'] = list(platform.uname())
-            self.grid_ensemble_results['metadata']['platform_machine'] = platform.machine()
-            self.grid_ensemble_results['metadata']['platform_node'] = platform.node()
-            self.grid_ensemble_results['metadata']['platform_release'] = platform.release()
-            self.grid_ensemble_results['metadata']['platform_version'] = platform.version()
-            self.grid_ensemble_results['metadata']['platform_processor'] = platform.processor()
-            self.grid_ensemble_results['metadata']['platform_python_build'] = ' '.join(platform.python_build())
-            self.grid_ensemble_results['metadata']['platform_python_version'] = platform.python_version()
+            self.grid_ensemble_results["metadata"]['platform'] = platform.platform()
+            self.grid_ensemble_results["metadata"]['platform_uname'] = list(platform.uname())
+            self.grid_ensemble_results["metadata"]['platform_machine'] = platform.machine()
+            self.grid_ensemble_results["metadata"]['platform_node'] = platform.node()
+            self.grid_ensemble_results["metadata"]['platform_release'] = platform.release()
+            self.grid_ensemble_results["metadata"]['platform_version'] = platform.version()
+            self.grid_ensemble_results["metadata"]['platform_processor'] = platform.processor()
+            self.grid_ensemble_results["metadata"]['platform_python_build'] = ' '.join(platform.python_build())
+            self.grid_ensemble_results["metadata"]['platform_python_version'] = platform.python_version()
         except Exception as e:
             print("platform call failed:",e)
             pass
 
         try:
-            self.grid_ensemble_results['metadata']['hostname'] = platform.uname()[1]
+            self.grid_ensemble_results["metadata"]['hostname'] = platform.uname()[1]
         except Exception as e:
             print("platform call failed:",e)
             pass
 
         try:
-            self.grid_ensemble_results['metadata']['duration'] = self.time_elapsed()
-            self.grid_ensemble_results['metadata']['CPU_time'] = self.CPU_time()
+            self.grid_ensemble_results["metadata"]['duration'] = self.time_elapsed()
+            self.grid_ensemble_results["metadata"]['CPU_time'] = self.CPU_time()
         except Exception as e:
             print("Failure to calculate time elapsed and/or CPU time consumed")
             pass
 
+        return
+
+    def load_snapshot(self,file):
+        newpop = self.load_population_object(file)
+        self.preloaded_population = newpop
+        self.grid_options['start_at'] = newpop.grid_options['start_at']
+        print("Loaded from snapshot at {file} : start at star {n}".format(
+            file=file,
+            n=self.grid_options['start_at']))
+        return
+
+    def save_snapshot(self):
+
+        if self.grid_options['slurm'] > 0:
+            file = os.path.join(self.grid_options['slurm_dir'],
+                                'snapshots',
+                                self.jobID() + '.gz')
+        else:
+            file = os.path.join(self.grid_options['tmp_dir'],
+                                'snapshot.gz')
+
+        self.save_population_object(object=self,
+                                    filename=file)
+        return
diff --git a/binarycpython/utils/grid_options_defaults.py b/binarycpython/utils/grid_options_defaults.py
index 83d207062..a564ba651 100644
--- a/binarycpython/utils/grid_options_defaults.py
+++ b/binarycpython/utils/grid_options_defaults.py
@@ -147,6 +147,9 @@ grid_options_defaults_dict = {
     "save_population_object" : None,  # filename to which we should save a pickled grid object as the final thing we do
     'joinlist' : None,
     'do_analytics' : True,  # if True, calculate analytics prior to return
+    'save_snapshots' : True,  # if True, save a snapshot on SIGINT
+    'restore_from_snapshot_file' : None,  # file to restore from
+    'restore_from_snapshot_dir' : None,  # dir to restore from
     ## Monte carlo type evolution
     # TODO: make MC options
     ## Evolution from source file
-- 
GitLab
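A usage sketch of the snapshot workflow this patch adds, assuming the usual Population().set(...)/evolve() driver pattern from binary_c-python; the snapshot path shown is illustrative (the non-Slurm default is <tmp_dir>/snapshot.gz):

    from binarycpython.utils.grid import Population

    # first run: configure and evolve as usual; on Ctrl-C (SIGINT) the queue
    # stops feeding systems, 'start_at' is set to the number of systems
    # already run, and a snapshot is pickled to <tmp_dir>/snapshot.gz
    pop = Population()
    pop.set(num_cores=4)
    pop.evolve()

    # later run: point at the snapshot file; evolve() pre-loads the saved
    # results and resumes the grid from 'start_at'
    pop2 = Population()
    pop2.set(num_cores=4,
             restore_from_snapshot_file="/tmp/binary_c_python/snapshot.gz")
    pop2.evolve()
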
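The handler's escalation logic — wind down gracefully on the first SIGINT, hard-exit once the same signal has been caught more than three times — is a reusable pattern. A minimal standalone sketch (class and attribute names here are hypothetical, not from the patch):

    import functools
    import signal
    import sys

    class Worker:
        def __init__(self):
            self.signal_count = {}   # per-signal-name catch counters
            self.stop_queue = False  # polled by the main work loop

        def _signal_handler(self, signum, frame):
            name = signal.Signals(signum).name
            self.signal_count[name] = self.signal_count.get(name, 0) + 1
            if signum == signal.SIGINT:
                self.stop_queue = True  # ask the loop to stop feeding work
                if self.signal_count[name] > 3:
                    sys.exit()          # repeated Ctrl-C: give up immediately

    worker = Worker()
    signal.signal(signal.SIGINT, functools.partial(worker._signal_handler))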