"""
    Binary_c-python's data input-output (IO) functions
"""

import bz2
import compress_pickle
import copy
import datetime
import gzip
import json
import msgpack
import os
import pathlib
from typing import Union, Any

from binarycpython.utils.ensemble import (
    binaryc_json_serializer,
    ensemble_compression,
    ensemble_file_type,
    extract_ensemble_json_from_string,
    format_ensemble_results,
)
from binarycpython.utils.dicts import (
    merge_dicts,
)

class dataIO():

    def __init__(self, **kwargs):
        """
        Mixin constructor: intentionally a no-op.

        dataIO only contributes methods to the class that inherits from it,
        so any keyword arguments are accepted and ignored here.
        """

    def dir_ok(self, dir):
        """
        Check that the path `dir` exists and that the current process may
        both read from and write to it.

        Args:
            dir : path to test.

        Returns:
            True if the path exists and is readable and writable, False otherwise.

        Note:
            Only existence and permissions are checked; the path is not
            required to be a directory.
        """
        if not os.access(dir, os.F_OK):
            # nothing exists at this path
            return False
        readable_and_writable = os.access(dir, os.R_OK | os.W_OK)
        return readable_and_writable

    def save_population_object(self,object=None,filename=None,confirmation=True,compression='gzip'):
        """
        Save pickled Population object to file at filename or, if filename is None, whatever is set at self.grid_options['save_population_object']

        Args:
            object : the object to be saved to the file. If object is None, use self.
            filename : the name of the file to be saved. If not set, use self.grid_options['save_population_object'].
                       If that is also empty, nothing is saved.
            confirmation : if True, a file "filename.saved" is touched just after the dump, so we know it is finished.
            compression : unused, kept for backward compatibility: compression is
                          chosen from the filename (see below).

        Compression is performed according to the filename, as stated in the
        compress_pickle documentation at
        https://lucianopaz.github.io/compress_pickle/html/

        Shared memory, stored in the object.shared_memory dict, is not saved.
        Neither are grid_options['_system_generator'], grid_options['_store_memaddr']
        and persistent_data_memory_dict. These are detached before pickling and
        restored afterwards, even if pickling raises.
        """
        if object is None:
            # default to using self
            object = self

        if filename is None:
            # get filename from self
            filename = self.grid_options['save_population_object']

        if not filename:
            # no filename given or configured: nothing to do
            return

        print("Save population {id}, probtot {probtot} to pickle in {filename}".format(
            id=object.grid_options["_population_id"],
            probtot=object.grid_options['_probtot'],
            filename=filename))

        # Some parts of the object cannot be pickled: remember them, detach
        # them for the dump, and restore them in the finally clause below so
        # the object is left intact even when the dump fails.
        shared_memory = object.shared_memory
        system_generator = object.grid_options["_system_generator"]
        store_memaddr = object.grid_options['_store_memaddr']
        persistent_data_memory_dict = object.persistent_data_memory_dict

        object.shared_memory = None
        object.grid_options["_system_generator"] = None
        object.grid_options['_store_memaddr'] = None
        object.persistent_data_memory_dict = None

        try:
            # add metadata if it doesn't exist
            if "metadata" not in object.grid_ensemble_results:
                object.grid_ensemble_results["metadata"] = {}

            # add datestamp (removed again after the dump)
            object.grid_ensemble_results["metadata"]['save_population_time'] = datetime.datetime.now().strftime("%m/%d/%Y %H:%M:%S")

            # add extra metadata
            object.add_system_metadata()

            # add max memory use of the object being saved. This is best effort:
            # report and continue if shared_memory lacks the per-thread data.
            try:
                object.grid_ensemble_results["metadata"]['max_memory_use'] = sum(shared_memory["max_memory_use_per_thread"])
            except (KeyError, TypeError) as e:
                print("save_population_object : Error: ",e)

            # dump pickle file
            compress_pickle.dump(object,
                                 filename,
                                 pickler_method='dill')
        finally:
            # restore data
            object.shared_memory = shared_memory
            object.grid_options["_system_generator"] = system_generator
            object.grid_options['_store_memaddr'] = store_memaddr
            object.persistent_data_memory_dict = persistent_data_memory_dict
            # the datestamp only belongs in the saved copy
            metadata = object.grid_ensemble_results.get("metadata")
            if metadata is not None:
                metadata.pop('save_population_time', None)

        if confirmation:
            # touch 'saved' file so readers know the dump has finished;
            # str() allows filename to be a pathlib.Path
            pathlib.Path(str(filename) + '.saved').touch(exist_ok=True)

    def load_population_object(self,filename):
        """
        Load a Population object previously written by save_population_object().

        Args:
            filename : file to load. compress_pickle chooses the decompression
                       from the filename. If None, nothing is loaded.

        Returns:
            The loaded object, or None if filename is None or loading failed.
            A failure is reported with a printed message.

        Warning:
            This unpickles with dill, which can execute arbitrary code:
            only load files from a trusted source.
        """
        if filename is None:
            return None

        try:
            return compress_pickle.load(filename,
                                        pickler_method='dill')
        except Exception as e:
            # keep the "return None on failure" contract, but don't hide why
            print("load_population_object : failed to load {filename} : {error}".format(
                filename=filename,
                error=e))
            return None

    def merge_populations(self,refpop,newpop):
        """
        merge newpop's results data into refpop's results data

        grid_results and grid_ensemble_results are combined with merge_dicts,
        except for these entries of grid_ensemble_results:

            ["metadata"]["max_memory_use"] : set to the maximum of the values
                present in either population (0 if neither has one)
            ["metadata"]["settings"], ["ensemble"]["Xinit"] : replaced by
                newpop's value when newpop has a non-empty one

        grid_options['_probtot'] is summed and grid_options['_killed'] is OR-ed.

        Args:
            refpop : the original "reference" Population object to be added to
            newpop : Population object containing the new data

        Returns:
            nothing

        Errors while merging the results dicts are printed, not raised.
        A missing '_probtot' or '_killed' entry in either population's
        grid_options raises.

        Note:
            To keep the merged result, save refpop with save_population_object()
        """

        # combine data
        try:
            refpop.grid_results = merge_dicts(refpop.grid_results,
                                              newpop.grid_results)
        except Exception as e:
            print("Error merging grid_results:",e)

        # special case: max_memory_use is the maximum over both populations.
        # Use whichever values exist, so a value present in only one
        # population is kept rather than replaced by 0.
        memory_uses = []
        for pop in (refpop, newpop):
            try:
                value = pop.grid_ensemble_results["metadata"]['max_memory_use']
            except (AttributeError, KeyError, TypeError):
                continue
            if value is not None:
                memory_uses.append(value)
        maxmem = max(memory_uses, default=0)

        try:
            # special cases:
            # copy the settings and Xinit: these should just be overridden
            try:
                settings = copy.deepcopy(newpop.grid_ensemble_results["metadata"]['settings'])
            except (AttributeError, KeyError, TypeError):
                settings = None
            try:
                Xinit = copy.deepcopy(newpop.grid_ensemble_results["ensemble"]["Xinit"])
            except (AttributeError, KeyError, TypeError):
                Xinit = None

            # merge the ensemble dicts
            refpop.grid_ensemble_results = merge_dicts(refpop.grid_ensemble_results,
                                                       newpop.grid_ensemble_results)

            # set special cases
            try:
                refpop.grid_ensemble_results["metadata"]['max_memory_use'] = maxmem
                if settings:
                    refpop.grid_ensemble_results["metadata"]['settings'] = settings
                if Xinit:
                    refpop.grid_ensemble_results["ensemble"]["Xinit"] = Xinit
            except (KeyError, TypeError):
                # merged results lack a "metadata"/"ensemble" section: skip the overrides
                pass

        except Exception as e:
            print("Error merging grid_ensemble_results:",e)

        for key in ["_probtot"]:
            refpop.grid_options[key] += newpop.grid_options[key]

        refpop.grid_options['_killed'] |= newpop.grid_options['_killed']

        return

    def merge_populations_from_file(self,refpop,filename):
        """
        Wrapper for merge_populations so it can be done directly
        from a file.

        Args:
            refpop : the original "reference" Population object to be added to
            filename : file containing the Population object containing the new data

        If the file cannot be loaded, a message is printed and refpop is left
        unchanged. Exceptions raised while merging are printed, not raised.

        Note:
            The file should be saved using save_population_object()
        """

        newpop = self.load_population_object(filename)

        if newpop is None:
            # load_population_object() returns None on failure: there is
            # nothing to merge, so don't feed None into merge_populations
            print("merge_populations_from_file : could not load a population from {}, nothing merged".format(filename))
            return

        # merge with refpop
        try:
            self.merge_populations(refpop,
                                   newpop)
        except Exception as e:
            print("merge_populations gave exception",e)

        return

    def snapshot_filename(self):
        """
        Return the default snapshot path.

        When running under slurm (grid_options['slurm'] > 0) the snapshot is
        <slurm_dir>/snapshots/<jobID>.gz. Otherwise it is <tmp_dir>/snapshot.gz.
        """
        options = self.grid_options
        if options['slurm'] <= 0:
            return os.path.join(options['tmp_dir'], 'snapshot.gz')
        return os.path.join(options['slurm_dir'],
                            'snapshots',
                            self.jobID() + '.gz')

    def load_snapshot(self,file):
        """
        Load a snapshot from file and set it in the preloaded_population placeholder.

        Also copies the snapshot's grid_options['start_at'] into this object's
        grid_options.

        Args:
            file : snapshot file, e.g. as written by save_snapshot()

        Raises:
            ValueError : if the snapshot could not be loaded. Nothing on self
                         is modified in that case.
        """
        newpop = self.load_population_object(file)
        if newpop is None:
            # load_population_object() returns None on failure: fail clearly
            # here rather than with an AttributeError on None below
            raise ValueError("Unable to load snapshot from {file}".format(file=file))

        self.preloaded_population = newpop
        self.grid_options['start_at'] = newpop.grid_options['start_at']
        print("Loaded from snapshot at {file} : start at star {n}".format(
            file=file,
            n=self.grid_options['start_at']))
        return

    def save_snapshot(self,file=None):
        """
        Save the population object to a snapshot file, automatically choosing the filename if none is given.

        Args:
            file : snapshot filename. If None, self.snapshot_filename() is used.
        """

        if file is None:
            file = self.snapshot_filename()

        # number of stars counted so far; only used in the message below
        try:
            n = self.grid_options['_count']
        except (AttributeError, KeyError, TypeError):
            n = '?'

        print("Saving snapshot containing {} stars to {}".format(n,file))
        self.save_population_object(object=self,
                                    filename=file)

        return

    def write_ensemble(self, output_file, data=None, sort_keys=True, indent=4, encoding='utf-8', ensure_ascii=False):
        """
            write_ensemble : Write ensemble results to a file.

        Args:
            output_file : the output filename.

                          If the filename has an extension that we recognise,
                          e.g. .gz or .bz2, we compress the output appropriately.

                          The filename should contain .json or .msgpack, the two
                          currently-supported formats.

                          Usually you'll want to output to JSON, but we can
                          also output to msgpack.

            data :   the data dictionary to be converted and written to the file.
                     If not set, this defaults to self.grid_ensemble_results.

            sort_keys : if True, and output is to JSON, the keys will be sorted.
                        (default: True, passed to json.dumps)

            indent : number of space characters used in the JSON indent. (Default: 4,
                     passed to json.dumps)

            encoding : text encoding of JSON output, usually 'utf-8'. Ignored for
                       msgpack, which is written in binary mode.

            ensure_ascii : the ensure_ascii flag passed to json.dumps
                           (Default: False)

        If the file type cannot be determined, an error is printed and
        self.exit(code=1) is called.
        """

        # get the file type
        file_type = ensemble_file_type(output_file)

        # choose compression algorithm based on file extension
        compression = ensemble_compression(output_file)

        # default to using grid_ensemble_results if no data is given
        if data is None:
            data = self.grid_ensemble_results

        if file_type not in ("JSON", "msgpack"):
            print(
                "Unable to determine file type from ensemble filename {} : it should be .json or .msgpack.".format(output_file)
            )
            self.exit(code=1)
            return

        # gzip / bzip2 compressed, otherwise raw (uncompressed) output
        opener = {"gzip": gzip.open, "bzip2": bz2.open}.get(compression, open)

        if file_type == "JSON":
            # JSON output, in text mode
            with opener(output_file, "wt", encoding=encoding) as f:
                f.write(json.dumps(data,
                                   sort_keys=sort_keys,
                                   indent=indent,
                                   ensure_ascii=ensure_ascii))
        else:
            # msgpack output is binary: binary-mode open(), gzip.open() and
            # bz2.open() all reject an encoding argument
            with opener(output_file, "wb") as f:
                msgpack.dump(data, f)

        print(
            "Thread {thread}: Wrote ensemble results to file: {colour}{file}{reset} (file type {file_type}, compression {compression})".format(
                thread=self.process_ID,
                file=output_file,
                colour=self.ANSI_colours["green"],
                reset=self.ANSI_colours["reset"],
                file_type=file_type,
                compression=compression,
            )
        )

    def write_binary_c_calls_to_file(
            self,
            output_dir: Union[str, None] = None,
            output_filename: Union[str, None] = None,
            include_defaults: bool = False,
            encoding='utf-8'
    ) -> str:
        """
        Function that loops over the grid code and writes the generated parameters to a file.
        In the form of a command line call

        Only useful when you have a variable grid as system_generator. MC wouldn't be that useful

        Also, make sure that in this export there are the basic parameters
        like m1,m2,sep, orb-per, ecc, probability etc.

        On default this will write to the datadir, if it exists

        Tasks:
            - TODO: test this function
            - TODO: make sure the binary_c_python .. output file has a unique name

        Args:
            output_dir: (optional, default = None) directory where to write the file to. If custom_options['data_dir'] is present, then that one will be used first, and then the output_dir
            output_filename: (optional, default = None) filename of the output. If not set it will be called "binary_c_calls.txt"
            include_defaults: (optional, default = False) whether to include the defaults of binary_c in the lines that are written. Beware that this will result in very long lines, and it might be better to just export the binary_c defaults and keep them in a separate file.
            encoding: (optional, default = 'utf-8') encoding of the output file

        Returns:
            filename: filename that was used to write the calls to

        Raises:
            ValueError: if an ensemble is requested without ensemble_defer, if no
                grid variables are defined, if no grid function is available, or
                if neither custom_options['data_dir'] nor output_dir is set.
        """

        # Check if there is no compiled grid yet. If not, lets try to build it first.
        if not self.grid_options["_system_generator"]:

            ## check the settings: an ensemble run needs deferred output
            if (self.bse_options.get("ensemble", None) == 1
                    and self.bse_options.get("ensemble_defer", 0) != 1):
                message = "Error, if you want to run an ensemble in a population, the output needs to be deferred"
                print(message)
                raise ValueError(message)

            # Put in check
            if not self.grid_options["_grid_variables"]:
                print("Error: you haven't defined any grid variables! Aborting")
                raise ValueError("No grid variables defined")

            self._generate_grid_code(dry_run=False)
            self._load_grid_function()

        if not self.grid_options["_system_generator"]:
            print("Error. No grid function found!")
            raise ValueError("No grid function found")

        # an output dir configured in custom_options takes precedence over output_dir
        data_dir = self.custom_options.get("data_dir", None)
        if data_dir:
            binary_c_calls_output_dir = data_dir
        elif output_dir:
            binary_c_calls_output_dir = output_dir
        else:
            print(
                "Error. No data_dir configured and you gave no output_dir. Aborting"
            )
            raise ValueError("No data_dir configured and no output_dir given")

        binary_c_calls_filename = output_filename or "binary_c_calls.txt"

        binary_c_calls_full_filename = os.path.join(
            binary_c_calls_output_dir, binary_c_calls_filename
        )
        print("Writing binary_c calls to {}".format(binary_c_calls_full_filename))

        # Write to file
        with open(binary_c_calls_full_filename, "w", encoding=encoding) as file:
            # Get defaults and clean them, then overwrite them with the set values.
            if include_defaults:
                # TODO: make sure that the defaults here are cleaned up properly
                full_system_dict = self.cleaned_up_defaults.copy()
                full_system_dict.update(self.bse_options)
            else:
                full_system_dict = self.bse_options.copy()

            # The same dict is updated for every system, so a value set by an
            # earlier system stays in place until a later system overwrites it.
            for system in self.grid_options["_system_generator"](self):
                full_system_dict.update(system)
                file.write(self._return_argline(full_system_dict) + "\n")

        return binary_c_calls_full_filename

    def set_status(self,
                   string,
                   format_statment="process_{}.txt",
                   ID=None,
                   slurm=True,
                   condor=True):
        """
        function to set the status string in its appropriate file
        """

        if ID is None:
            ID = self.process_ID

        if self.grid_options['status_dir']:
            with open(
                    os.path.join(
                        self.grid_options["status_dir"],
                        format_statment.format(ID),
                    ),
                    "w",
                    encoding='utf-8'
            ) as f:
                f.write(string)
                f.close()

        # custom logging functions
        if slurm and self.grid_options['slurm'] >= 1:
            self.set_slurm_status(string)
#        if self.grid_options['condor']==1:
#            self.set_condor_status(string)