"""
Tests for the persistent_data memory handling and the ensemble output of binary_c_python_api.
"""

import os
import sys
import time
import json
import textwrap
import binary_c_python_api

from binarycpython.utils.functions import binarycDecoder, temp_dir

# Base directory returned by temp_dir(); the tests below write their JSON files into
# its "test" subdirectory, which is created here (at import time) if it is missing.
TMP_DIR = temp_dir()
_TEST_OUTPUT_DIR = os.path.join(TMP_DIR, "test")
os.makedirs(_TEST_OUTPUT_DIR, exist_ok=True)


####
def return_argstring(m1=15.0, m2=14.0, separation=0, orbital_period=453000000000,
                     eccentricity=0.0, metallicity=0.02, max_evolution_time=15000,
                     defer_ensemble=0, ensemble_filters_off=1,
                     ensemble_filter='SUPERNOVAE', probability=0.1):
    """
    Build a binary_c argument string for the tests in this module.

    The numeric system parameters (m1 .. max_evolution_time) are rendered with the
    ``g`` format specifier, so e.g. ``15.0`` becomes ``15`` and ``453000000000``
    becomes ``4.53e+11``. The ensemble is always switched on (``ensemble 1``).

    Args:
        m1: value for ``M_1`` (the callers in this module use Msun).
        m2: value for ``M_2`` (the callers in this module use Msun).
        separation: value for ``separation``.
        orbital_period: value for ``orbital_period``.
        eccentricity: value for ``eccentricity``.
        metallicity: value for ``metallicity``.
        max_evolution_time: value for ``max_evolution_time``.
        defer_ensemble: value for ``ensemble_defer`` (the tests use 0 or 1).
        ensemble_filters_off: value for ``ensemble_filters_off`` (the tests use 0 or 1).
        ensemble_filter: name ``X`` of the filter switched on via ``ensemble_filter_X 1``.
        probability: value for ``probability``. It is formatted without a format spec
            rather than with ``g``, so no significant digits are dropped. The default 0.1
            reproduces the value that used to be hard-coded.

    Returns:
        str: the argument string, starting with ``binary_c``.
    """
    # Adjacent string literals are joined at compile time, which avoids the fragile
    # backslash line continuation inside a single literal.
    argstring_template = (
        "binary_c M_1 {m1:g} M_2 {m2:g} separation {separation:g} "
        "orbital_period {orbital_period:g} eccentricity {eccentricity:g} "
        "metallicity {metallicity:g} max_evolution_time {max_evolution_time:g} "
        "ensemble 1 ensemble_defer {defer_ensemble} "
        "ensemble_filters_off {ensemble_filters_off} "
        "ensemble_filter_{ensemble_filter} 1 probability {probability}"
    )

    return argstring_template.format(
        m1=m1,
        m2=m2,
        separation=separation,
        orbital_period=orbital_period,
        eccentricity=eccentricity,
        metallicity=metallicity,
        max_evolution_time=max_evolution_time,
        defer_ensemble=defer_ensemble,
        ensemble_filters_off=ensemble_filters_off,
        ensemble_filter=ensemble_filter,
        probability=probability,
    )

def test_return_persistent_data_memaddr():
    """
    Request a persistent_data memory address from binary_c and print it.

    This is a smoke test. It only checks that the call does not raise; the returned
    value is printed but not asserted on.
    """
    output = binary_c_python_api.return_persistent_data_memaddr()

    # NOTE(review): the address obtained here is never released (other tests use
    # free_persistent_data_memaddr_and_return_json_output for that) - confirm that
    # leaving it allocated is intended.
    print("function: test_return_persistent_data_memaddr")
    print("Binary_c output:")
    print(textwrap.indent(str(output), "\t"))

def test_passing_persistent_data_to_run_system():
    """
    Test passing a persistent_data memory address to run_system, with ensemble_defer on.

    Steps:
    1. A reference system is run without persistent data.
    2. The same system is run with ``ensemble_defer 1`` into a persistent_data store.
    3. A second run without deferring uses that same store. The ensemble it prints
       should therefore contain the results of both runs, so it must differ from
       the reference ensemble.
    4. The reference run is repeated and must reproduce the reference ensemble.
    """

    def _ensemble_json(output):
        # Decode the first ENSEMBLE_JSON line of binary_c's output. Fail with a
        # readable message, rather than an IndexError, if no ensemble was printed.
        ensemble_lines = [line for line in output.splitlines() if line.startswith("ENSEMBLE_JSON")]
        assert ensemble_lines, "binary_c output did not contain an ENSEMBLE_JSON line"
        return json.loads(ensemble_lines[0][len("ENSEMBLE_JSON "):])

    # Make argstrings.
    # NOTE(review): argstring_1 and argstring_2 use identical system parameters, so the
    # "two systems" below are the same system evolved twice.
    argstring_1 = return_argstring(defer_ensemble=0)
    argstring_1_deferred = return_argstring(defer_ensemble=1)
    argstring_2 = return_argstring(defer_ensemble=0)

    # Memory address of the persistent_data store that accumulates the ensemble.
    # NOTE(review): this store is never freed in this test.
    persistent_data_memaddr = binary_c_python_api.return_persistent_data_memaddr()

    # Reference run without persistent data
    json_1 = _ensemble_json(binary_c_python_api.run_system(argstring=argstring_1))

    # Doing 2 systems in a row: the first defers its ensemble into the persistent
    # data, the second (not deferred) prints the accumulated ensemble.
    binary_c_python_api.run_system(
        argstring=argstring_1_deferred,
        persistent_data_memaddr=persistent_data_memaddr,
    )
    json_2 = _ensemble_json(binary_c_python_api.run_system(
        argstring=argstring_2,
        persistent_data_memaddr=persistent_data_memaddr,
    ))

    # Doing the reference system again
    json_1_again = _ensemble_json(binary_c_python_api.run_system(argstring=argstring_1))

    assert json_1 == json_1_again, "The system with the same initial settings did not give the same output"
    assert json_1 != json_2, "The output of the deferred two systems should not be the same as the first undeferred output"

def test_full_ensemble_output():
    """
    Run a single system that outputs the whole ensemble and check its top-level sections.

    The system is run with ``ensemble_filters_off 0``. The decoded ensemble is written to
    ``<TMP_DIR>/test/json_full_ensemble.json``. The decode time and the (shallow) size
    of the decoded dict are printed.
    """
    argstring_1 = return_argstring(defer_ensemble=0, ensemble_filters_off=0)
    output_1 = binary_c_python_api.run_system(argstring=argstring_1)
    ensemble_jsons_1 = [line for line in output_1.splitlines() if line.startswith("ENSEMBLE_JSON")]
    assert ensemble_jsons_1, "binary_c output did not contain an ENSEMBLE_JSON line"

    # perf_counter is monotonic and high-resolution, so it suits interval timing better
    # than the wall clock returned by time.time.
    start = time.perf_counter()
    json_1 = json.loads(ensemble_jsons_1[0][len("ENSEMBLE_JSON "):], cls=binarycDecoder)
    stop = time.perf_counter()

    with open(os.path.join(TMP_DIR, "test", "json_full_ensemble.json"), 'w', encoding="utf-8") as f:
        f.write(json.dumps(json_1, indent=4))

    print("took {}s to decode".format(stop - start))
    # sys.getsizeof is shallow: it measures only the top-level dict object, not the
    # nested data it references.
    print("Size of the json in memory (shallow, top-level dict only): {}".format(sys.getsizeof(json_1)))

    # The full ensemble must contain all of these top-level sections. Report every
    # missing section at once instead of stopping at the first one.
    expected_keys = ["number_counts", "HRD", "HRD(t)", "Xyield", "distributions", "scalars"]
    missing_keys = [key for key in expected_keys if key not in json_1]
    assert not missing_keys, "Full ensemble output is missing keys: {}".format(missing_keys)

def test_adding_ensemble_output():
    """
    Run the same pair of systems in three ways and write each resulting ensemble to disk.

    1. Each system on its own, without deferring, giving two separate ensembles
       (``adding_json_1.json`` and ``adding_json_2.json``).
    2. Both systems deferred into one persistent_data store. The store is then freed,
       which returns the combined ensemble (``adding_json_deferred.json``).
    3. The first system deferred into a fresh store and the second run without
       deferring, so that the second run prints the combined ensemble
       (``adding_json_deferred_and_output.json``).

    All files are written to ``<TMP_DIR>/test``.

    NOTE(review): no assertion compares these ensembles yet (e.g. 1 + 2 against the
    deferred results). That comparison currently has to be done by inspecting the files.
    """
    m1 = 2  # Msun
    m2 = 0.1  # Msun

    def _ensemble_json(output):
        # Decode the first ENSEMBLE_JSON line of binary_c's output. Fail with a
        # readable message, rather than an IndexError, if no ensemble was printed.
        ensemble_lines = [line for line in output.splitlines() if line.startswith("ENSEMBLE_JSON")]
        assert ensemble_lines, "binary_c output did not contain an ENSEMBLE_JSON line"
        return json.loads(ensemble_lines[0][len("ENSEMBLE_JSON "):], cls=binarycDecoder)

    def _write_json(data, filename):
        # Dump ``data`` as indented JSON to <TMP_DIR>/test/<filename>.
        with open(os.path.join(TMP_DIR, "test", filename), 'w', encoding="utf-8") as f:
            f.write(json.dumps(data, indent=4))

    #############################################################################################
    # The 2 runs below use the ensemble but do not defer the output, so the results are
    # returned directly after each run.

    # Direct output commands
    argstring_1 = return_argstring(m1=m1, m2=m2, ensemble_filter="STELLAR_TYPE_COUNTS", defer_ensemble=0)
    argstring_2 = return_argstring(m1=m1+1, m2=m2, ensemble_filter="STELLAR_TYPE_COUNTS", defer_ensemble=0)

    # Get outputs
    output_1 = binary_c_python_api.run_system(argstring=argstring_1)
    output_2 = binary_c_python_api.run_system(argstring=argstring_2)

    json_1 = _ensemble_json(output_1)
    json_2 = _ensemble_json(output_2)

    _write_json(json_1, "adding_json_1.json")
    _write_json(json_2, "adding_json_2.json")

    print("Single runs done\n")

    #############################################################################################
    # The 2 runs below use the ensemble and both defer the output, so nothing is printed when
    # they finish. We then explicitly free the persistent_data memory, which returns the output.

    # Deferred commands
    argstring_1_deferred = return_argstring(m1=m1, m2=m2, ensemble_filter="STELLAR_TYPE_COUNTS", defer_ensemble=1)
    argstring_2_deferred = return_argstring(m1=m1+1, m2=m2, ensemble_filter="STELLAR_TYPE_COUNTS", defer_ensemble=1)

    # Get a memory location
    persistent_data_memaddr = binary_c_python_api.return_persistent_data_memaddr()

    # Run the systems and defer the output each time
    binary_c_python_api.run_system(
        argstring=argstring_1_deferred,
        persistent_data_memaddr=persistent_data_memaddr
    )
    binary_c_python_api.run_system(
        argstring=argstring_2_deferred,
        persistent_data_memaddr=persistent_data_memaddr
    )

    # Release the persistent_data memory address and get the JSON output back
    output_total_deferred = binary_c_python_api.free_persistent_data_memaddr_and_return_json_output(persistent_data_memaddr)

    json_deferred = _ensemble_json(output_total_deferred)
    _write_json(json_deferred, "adding_json_deferred.json")

    print("Double deferred done\n")

    #############################################################################################
    # The first run below defers its output to the memory. The second run combines its results
    # with that memory but does not defer, so it prints the combined results when it is done.

    # NOTE(review): this second persistent_data address is never freed in this test.
    persistent_data_memaddr_2 = binary_c_python_api.return_persistent_data_memaddr()

    binary_c_python_api.run_system(
        argstring=argstring_1_deferred,
        persistent_data_memaddr=persistent_data_memaddr_2
    )
    output_2_deferred_and_output = binary_c_python_api.run_system(
        argstring=argstring_2,
        persistent_data_memaddr=persistent_data_memaddr_2
    )

    json_deferred_and_output = _ensemble_json(output_2_deferred_and_output)
    _write_json(json_deferred_and_output, "adding_json_deferred_and_output.json")

    print("Single deferred done\n")

def test_free_and_json_output():
    """
    Test freeing a persistent_data memory address and getting the ensemble JSON back.

    A single system is evolved with its ensemble deferred into a persistent_data store.
    The store is then freed with ``free_persistent_data_memaddr_and_return_json_output``.
    Both the run's output and the output returned by freeing are printed.
    """
    m1 = 2  # Msun
    m2 = 0.1  # Msun

    # Build the argstring for the (m1, m2) system with its ensemble deferred
    argstring_1 = return_argstring(m1=m1, m2=m2, ensemble_filter="STELLAR_TYPE_COUNTS", defer_ensemble=1)

    # Get a memory address (no argument, matching every other call in this module)
    persistent_data_memaddr = binary_c_python_api.return_persistent_data_memaddr()

    # Evolve and defer output
    print("evolving")
    output_1_deferred = binary_c_python_api.run_system(argstring=argstring_1, persistent_data_memaddr=persistent_data_memaddr)
    print("Evolved")
    print("Output:")
    print(textwrap.indent(str(output_1_deferred), "\t"))

    # Free the memory address and collect the JSON output it returns
    print("freeing")
    json_output_by_freeing = binary_c_python_api.free_persistent_data_memaddr_and_return_json_output(persistent_data_memaddr)
    print("Freed")
    print("Output:")
    print(textwrap.indent(str(json_output_by_freeing), "\t"))

####
if __name__ == "__main__":
    # Tests run when this file is executed directly. Only the memory-address smoke
    # test is enabled; uncomment the other entries to run them too.
    selected_tests = [
        test_return_persistent_data_memaddr,
        # test_passing_persistent_data_to_run_system,
        # test_full_ensemble_output,
        # test_adding_ensemble_output,
        # test_free_and_json_output,
    ]
    for selected_test in selected_tests:
        selected_test()