"""
Unittests for the functions module
"""

import unittest
import tempfile
from binarycpython.utils.functions import *
from binarycpython.utils.custom_logging_functions import binary_c_log_code
from binarycpython.utils.run_system_wrapper import run_system

# Scratch directory shared by the tests in this module for any files they
# write. temp_dir is not defined in this file, so it comes from the star
# import of binarycpython.utils.functions. Judging by test_temp_dir it
# presumably resolves to <system tmp>/binary_c_python/tests/test_functions
# (TODO confirm in temp_dir itself).
TMP_DIR = temp_dir("tests", "test_functions")

class dummy:
    """
    Placeholder object of a type that merge_dicts is expected to reject.

    The merge_dicts tests put two of these under the same key and check that
    merging them raises a ValueError.
    """

    def __init__(self, name):
        """
        Remember the label this instance was created with.
        """
        label = name
        self.name = label

    def __str__(self):
        """
        Use the stored label as the string form of the object.
        """
        label = self.name
        return label


class test_verbose_print(unittest.TestCase):
    """
    Unittests for verbose_print
    """

    # Each public test runs its helper inside Capturing (imported from
    # binarycpython.utils.functions), which presumably collects stdout so the
    # message does not clutter the test log.

    def test_print(self):
        with Capturing():
            self._test_print()

    def _test_print(self):
        """
        Smoke test: call verbose_print with the levels (1, 0), the
        combination that is expected to produce output.
        """
        call_args = ("test1", 1, 0)
        verbose_print(*call_args)

    def test_not_print(self):
        with Capturing():
            self._test_not_print()

    def _test_not_print(self):
        """
        Smoke test: call verbose_print with the levels (0, 1), the
        combination that is expected to stay silent.
        """
        call_args = ("test1", 0, 1)
        verbose_print(*call_args)


class test_remove_file(unittest.TestCase):
    """
    Unittests for remove_file
    """

    def test_remove_file(self):
        with Capturing() as output:
            self._test_remove_file()

    def _test_remove_file(self):
        """
        Test that remove_file deletes a file that exists
        """

        filepath = os.path.join(TMP_DIR, "test_remove_file_file.txt")
        with open(filepath, "w") as f:
            f.write("test")

        # Precondition: the file really is there before we remove it
        self.assertTrue(os.path.isfile(filepath))

        remove_file(filepath)

        # Check the outcome; without this the test passes even if
        # remove_file leaves the file in place.
        self.assertFalse(os.path.exists(filepath))

    def test_remove_nonexisting_file(self):
        with Capturing() as output:
            self._test_remove_nonexisting_file()

    def _test_remove_nonexisting_file(self):
        """
        Test that trying to remove a nonexistent file does not raise
        """

        filepath = os.path.join(TMP_DIR, "test_remove_nonexistingfile_file.txt")

        # Guarantee the "nonexistent" precondition even if an earlier run
        # happened to leave the file behind.
        if os.path.exists(filepath):
            os.remove(filepath)

        # Must complete without raising for a missing file
        remove_file(filepath)

        self.assertFalse(os.path.exists(filepath))


class test_temp_dir(unittest.TestCase):
    """
    Unittests for temp_dir
    """

    def test_create_temp_dir(self):
        with Capturing() as output:
            self._test_create_temp_dir()

    def _test_create_temp_dir(self):
        """
        Test making a temp directory and comparing that to what it should be:
        <system tmp>/binary_c_python
        """

        binary_c_temp_dir = temp_dir()
        expected_dir = os.path.join(tempfile.gettempdir(), "binary_c_python")

        self.assertTrue(os.path.isdir(expected_dir))

        # The comparison must happen inside the assertion. Writing
        # `assertTrue(x) == y` discards the result of `== y`, and
        # assertTrue(non-empty string) always passes.
        self.assertEqual(expected_dir, binary_c_temp_dir)


class test_create_hdf5(unittest.TestCase):
    """
    Unittests for create_hdf5
    """

    def test_1(self):
        with Capturing() as output:
            self._test_1()

    def _test_1(self):
        """
        Test that creates files, packs them in a hdf5 file and checks the contents
        """

        testdir = os.path.join(TMP_DIR, "test_create_hdf5")
        os.makedirs(testdir, exist_ok=True)

        # Create dummy settings file:
        settings_dict = {"settings_1": 1, "settings_2": [1, 2]}

        with open(os.path.join(testdir, "example_settings.json"), "w") as f:
            f.write(json.dumps(settings_dict))

        with open(os.path.join(testdir, "data1.dat"), "w") as f:
            f.write("time mass\n")
            f.write("1 10")

        create_hdf5(testdir, "testhdf5.hdf5")

        # Read what we need (`[()]` loads the dataset value into memory) and
        # close the HDF5 handle immediately instead of leaving it open until
        # garbage collection.
        with h5py.File(os.path.join(testdir, "testhdf5.hdf5"), "r") as hdf5_file:
            header = hdf5_file.get("data/data1_header")[()]
            used_settings = json.loads(hdf5_file.get("settings/used_settings")[()])

        self.assertIn(b"time", header)
        self.assertIn(b"mass", header)

        self.assertIn("settings_1", used_settings)
        self.assertIn("settings_2", used_settings)


class test_return_binary_c_version_info(unittest.TestCase):
    """
    Unittests for return_binary_c_version_info
    """

    def test_not_parsed(self):
        with Capturing():
            self._test_not_parsed()

    def _test_not_parsed(self):
        """
        The raw version info should be a string that contains a few known
        fragments of binary_c's version output.
        """

        raw_info = return_binary_c_version_info()

        self.assertTrue(isinstance(raw_info, str))
        for fragment in ("Build", "REIMERS_ETA_DEFAULT", "SIGMA_THOMPSON"):
            self.assertIn(fragment, raw_info)

    def test_parsed(self):
        with Capturing():
            self._test_parsed()

    def _test_parsed(self):
        """
        With parsed=True the version info should come back as a dict that
        has every expected top-level section.
        """

        # This also exercises the version-info parser indirectly
        parsed_info = return_binary_c_version_info(parsed=True)

        self.assertTrue(isinstance(parsed_info, dict))

        expected_sections = (
            "isotopes",
            "argpairs",
            "ensembles",
            "macros",
            "elements",
            "dt_limits",
            "nucleosynthesis_sources",
            "miscellaneous",
        )
        for section in expected_sections:
            self.assertIn(section, parsed_info.keys())


class test_parse_binary_c_version_info(unittest.TestCase):
    """
    Unittests for function parse_binary_c_version_info
    """

    def test_1(self):
        with Capturing():
            self._test_1()

    def _test_1(self):
        """
        Parse the raw version info and check that:
        - every top-level section is present,
        - the sections that should always be populated are not None,
        - the nucleosynthesis sections are populated when the matching
          macros are switched on.
        """

        parsed_info = parse_binary_c_version_info(return_binary_c_version_info())

        expected_sections = (
            "isotopes",
            "argpairs",
            "ensembles",
            "macros",
            "elements",
            "dt_limits",
            "nucleosynthesis_sources",
            "miscellaneous",
        )
        for section in expected_sections:
            self.assertIn(section, parsed_info.keys())

        always_populated = (
            "argpairs",
            "ensembles",
            "macros",
            "dt_limits",
            "miscellaneous",
        )
        for section in always_populated:
            self.assertIsNotNone(parsed_info[section])

        macros = parsed_info["macros"]

        # Isotope data is only checked for builds with NUCSYN switched on
        if macros["NUCSYN"] != "on":
            return
        self.assertIsNotNone(parsed_info["isotopes"])

        # The source breakdown additionally requires NUCSYN_ID_SOURCES
        if macros["NUCSYN_ID_SOURCES"] == "on":
            self.assertIsNotNone(parsed_info["nucleosynthesis_sources"])


class test_output_lines(unittest.TestCase):
    """
    Unittests for function output_lines
    """

    def test_1(self):
        with Capturing():
            self._test_1()

    def _test_1(self):
        """
        output_lines should turn a multi-line string into a list that
        contains each of the individual lines.
        """

        lines = output_lines("hallo\ntest\n123")

        self.assertTrue(isinstance(lines, list))
        for expected_line in ["hallo", "test", "123"]:
            self.assertIn(expected_line, lines)


class test_example_parse_output(unittest.TestCase):
    """
    Unittests for function example_parse_output
    """

    def _run_system_with_stellar_logging(self):
        """
        Build custom logging code that prints a MY_STELLAR_DATA line with the
        model time and star[0] mass, run one system with it and return the
        raw output of run_system.
        """

        # The logging statement can be anything; binary_c_log_code wraps it in
        # the surrounding shared-library code.
        logging_line = 'Printf("MY_STELLAR_DATA time=%g mass=%g\\n", stardata->model.time, stardata->star[0].mass)'
        custom_logging_code = binary_c_log_code(logging_line)

        # run_system takes the system settings and the custom logging code as
        # keyword arguments.
        system_settings = {
            "M_1": 1,
            "metallicity": 0.002,
            "M_2": 0.1,
            "separation": 0,
            "orbital_period": 100000000000,
            "custom_logging_code": custom_logging_code,
        }
        return run_system(**system_settings)

    def test_normal_output(self):
        with Capturing():
            self._test_normal_output()

    def _test_normal_output(self):
        """
        Parsing with the header that was actually logged should give a
        mapping with "time" and "mass" entries, where "time" is a
        non-empty list.
        """

        parsed = example_parse_output(
            self._run_system_with_stellar_logging(), "MY_STELLAR_DATA"
        )

        self.assertIn("time", parsed)
        self.assertIn("mass", parsed)
        self.assertTrue(isinstance(parsed["time"], list))
        self.assertTrue(len(parsed["time"]) > 0)

    def test_mismatch_output(self):
        with Capturing():
            self._test_mismatch_output()

    def _test_mismatch_output(self):
        """
        Parsing with a header that never appears in the output should
        return None.
        """

        parsed = example_parse_output(
            self._run_system_with_stellar_logging(), "MY_STELLAR_DATA_MISMATCH"
        )

        self.assertIsNone(parsed)


class test_get_defaults(unittest.TestCase):
    """
    Unittests for function get_defaults
    """

    def test_no_filter(self):
        with Capturing():
            self._test_no_filter()

    def _test_no_filter(self):
        """
        Without filtering, the defaults dict keeps every entry, including
        list_args and the %d-templated argument name.
        """

        defaults = get_defaults()

        self.assertTrue(isinstance(defaults, dict))
        keys = defaults.keys()
        for key in ("colour_log", "M_1", "list_args", "use_fixed_timestep_%d"):
            self.assertIn(key, keys)

    def test_filter(self):
        with Capturing():
            self._test_filter()

    def _test_filter(self):
        """
        With filter_values=True, list_args and the %d-templated argument
        should be dropped while ordinary arguments remain.
        """

        # This also exercises the argument-filtering helper indirectly
        filtered = get_defaults(filter_values=True)

        self.assertTrue(isinstance(filtered, dict))
        keys = filtered.keys()
        for kept in ("colour_log", "M_1"):
            self.assertIn(kept, keys)
        for dropped in ("list_args", "use_fixed_timestep_%d"):
            self.assertNotIn(dropped, keys)


class test_get_arg_keys(unittest.TestCase):
    """
    Unittests for function get_arg_keys
    """

    def test_1(self):
        with Capturing():
            self._test_1()

    def _test_1(self):
        """
        get_arg_keys should return a list containing a sample of known
        argument names, including the %d-templated one.
        """

        arg_keys = get_arg_keys()

        self.assertTrue(isinstance(arg_keys, list))
        for name in ("colour_log", "M_1", "list_args", "use_fixed_timestep_%d"):
            self.assertIn(name, arg_keys)


class test_create_arg_string(unittest.TestCase):
    """
    Unittests for function create_arg_string
    """

    def test_default(self):
        with Capturing() as output:
            self._test_default()

    def _test_default(self):
        """
        Test checking if the argstring is correct: without sorting, the
        arguments appear in the dict's insertion order
        """

        input_dict = {"separation": 40000, "M_1": 10}
        argstring = create_arg_string(input_dict)
        self.assertEqual(argstring, "separation 40000 M_1 10")

    def test_sort(self):
        with Capturing() as output:
            self._test_sort()

    def _test_sort(self):
        """
        Test checking that sort=True reorders the arguments by name
        """

        # The keys are deliberately inserted in non-alphabetical order. With
        # an already-sorted dict, this test would pass even if `sort` were
        # ignored.
        input_dict = {"separation": 40000, "M_1": 10}
        argstring = create_arg_string(input_dict, sort=True)
        self.assertEqual(argstring, "M_1 10 separation 40000")

    def test_filtered(self):
        with Capturing() as output:
            self._test_filtered()

    def _test_filtered(self):
        """
        Test if filtering works: list_args is dropped from the argstring
        """

        input_dict = {"M_1": 10, "separation": 40000, "list_args": "NULL"}
        argstring = create_arg_string(input_dict, filter_values=True)
        self.assertEqual(argstring, "M_1 10 separation 40000")


class test_get_help(unittest.TestCase):
    """
    Unit tests for function get_help
    """

    def test_input_normal(self):
        with Capturing():
            self._test_input_normal()

    def _test_input_normal(self):
        """
        Looking up a known parameter should return a record whose
        parameter_name matches the requested name.
        """

        help_info = get_help("M_1", print_help=False)
        self.assertEqual(
            help_info["parameter_name"],
            "M_1",
            msg="get_help('M_1') should return the correct parameter name",
        )

    def test_no_input(self):
        with Capturing():
            self._test_no_input()

    def _test_no_input(self):
        """
        Calling get_help without a parameter name should return None.
        """

        self.assertIsNone(get_help())

    def test_wrong_input(self):
        with Capturing():
            self._test_wrong_input()

    def _test_wrong_input(self):
        """
        Asking for an unknown parameter name should return None.
        """

        self.assertIsNone(get_help("kaasblokjes"))

    # TODO: the print_help=True path of get_help is not covered yet.


class test_get_help_all(unittest.TestCase):
    """
    Unit test for get_help_all
    """

    def test_all_output(self):
        with Capturing():
            self._test_all_output()

    def _test_all_output(self):
        """
        get_help_all should return a mapping that contains every expected
        top-level help section.
        """

        sections = get_help_all(print_help=False).keys()

        expected_sections = (
            "stars",
            "binary",
            "nucsyn",
            "output",
            "i/o",
            "algorithms",
            "misc",
        )
        for section_name in expected_sections:
            self.assertIn(section_name, sections, "missing section")

    # TODO: the print_help=True path of get_help_all is not covered yet.


class test_get_help_super(unittest.TestCase):
    """
    Unit test for get_help_super
    """

    def test_all_output(self):
        with Capturing():
            self._test_all_output()

    def _test_all_output(self):
        """
        get_help_super should return a mapping that contains every expected
        top-level help section.
        """

        sections = get_help_super().keys()

        expected_sections = (
            "stars",
            "binary",
            "nucsyn",
            "output",
            "i/o",
            "algorithms",
            "misc",
        )
        for section_name in expected_sections:
            self.assertIn(section_name, sections, "missing section")

    # TODO: the print_help=True path of get_help_super is not covered yet.


class test_make_build_text(unittest.TestCase):
    """
    Unittests for function make_build_text
    """

    def test_output(self):
        with Capturing():
            self._test_output()

    def _test_output(self):
        """
        Split the build text on its branch / revision / build-date labels and
        check that each resulting field differs from the string "second".

        NOTE(review): the purpose of comparing against "second" is unclear;
        the intent may have been to check that each field is non-empty. Confirm.
        """

        build_text = make_build_text()

        # Replace every label with a common separator so the text can be
        # split into its individual fields.
        labels = (
            "**binary_c git branch**:",
            "**binary_c git revision**:",
            "**Built on**:",
        )
        for label in labels:
            build_text = build_text.replace(label, ";")

        fields = build_text.split(";")

        for position in (1, 2, 3):
            self.assertNotEqual(fields[position].strip(), "second")


class test_write_binary_c_parameter_descriptions_to_rst_file(unittest.TestCase):
    """
    Unittests for function write_binary_c_parameter_descriptions_to_rst_file
    """

    def test_bad_outputname(self):
        with Capturing():
            self._test_bad_outputname()

    def _test_bad_outputname(self):
        """
        An output name with a .txt extension (not .rst) counts as a bad name,
        so the function should return None.
        """

        bad_name = os.path.join(
            TMP_DIR,
            "test_write_binary_c_parameter_descriptions_to_rst_file_test_1.txt",
        )
        self.assertIsNone(
            write_binary_c_parameter_descriptions_to_rst_file(bad_name)
        )

    def test_checkfile(self):
        with Capturing():
            self._test_checkfile()

    def _test_checkfile(self):
        """
        With a .rst output name, the function should create the file.
        """

        rst_name = os.path.join(
            TMP_DIR,
            "test_write_binary_c_parameter_descriptions_to_rst_file_test_1.rst",
        )
        write_binary_c_parameter_descriptions_to_rst_file(rst_name)
        self.assertTrue(os.path.isfile(rst_name))


class test_inspect_dict(unittest.TestCase):
    """
    Unittests for function inspect_dict
    """

    @staticmethod
    def _example_input():
        """
        Fresh nested dict mixing an int, a float, a list, a function and a
        sub-dict, used as input for the inspect_dict tests.
        """
        return {
            "int": 1,
            "float": 1.2,
            "list": [1, 2, 3],
            "function": os.path.isfile,
            "dict": {"int": 1, "float": 1.2},
        }

    def test_compare_dict(self):
        with Capturing():
            self._test_compare_dict()

    def _test_compare_dict(self):
        """
        inspect_dict should mirror the input's structure, with every leaf
        value replaced by its type (recursing into sub-dicts).
        """

        structure = inspect_dict(self._example_input())
        expected_structure = {
            "int": int,
            "float": float,
            "list": list,
            "function": type(os.path.isfile),
            "dict": {"int": int, "float": float},
        }
        self.assertTrue(expected_structure == structure)

    def test_compare_dict_with_print(self):
        with Capturing():
            self._test_compare_dict_with_print()

    def _test_compare_dict_with_print(self):
        """
        Smoke test for the print_structure=True code path.
        """

        inspect_dict(self._example_input(), print_structure=True)


class test_merge_dicts(unittest.TestCase):
    """
    Unittests for function merge_dicts
    """

    def test_empty(self):
        with Capturing() as output:
            self._test_empty()

    def _test_empty(self):
        """
        Test merging with an empty dict: the result should equal the other input
        """

        input_dict = {
            "int": 1,
            "float": 1.2,
            "list": [1, 2, 3],
            "function": os.path.isfile,
            "dict": {"int": 1, "float": 1.2},
        }
        dict_2 = {}
        output_dict = merge_dicts(input_dict, dict_2)
        self.assertEqual(output_dict, input_dict)

    def test_unequal_types(self):
        with Capturing() as output:
            self._test_unequal_types()

    def _test_unequal_types(self):
        """
        Test merging unequal types: should raise ValueError
        """

        dict_1 = {"input": 10}
        dict_2 = {"input": "hello"}

        self.assertRaises(ValueError, merge_dicts, dict_1, dict_2)

    def test_bools(self):
        with Capturing() as output:
            self._test_bools()

    def _test_bools(self):
        """
        Test merging dict with booleans: True merged with False gives True
        """

        dict_1 = {"bool": True}
        dict_2 = {"bool": False}
        output_dict = merge_dicts(dict_1, dict_2)

        self.assertIsInstance(output_dict["bool"], bool)
        self.assertTrue(output_dict["bool"])

    def test_ints(self):
        with Capturing() as output:
            self._test_ints()

    def _test_ints(self):
        """
        Test merging dict with ints: the values are summed
        """

        dict_1 = {"int": 2}
        dict_2 = {"int": 1}
        output_dict = merge_dicts(dict_1, dict_2)

        self.assertIsInstance(output_dict["int"], int)
        self.assertEqual(output_dict["int"], 3)

    def test_floats(self):
        with Capturing() as output:
            self._test_floats()

    def _test_floats(self):
        """
        Test merging dict with floats: the values are summed
        """

        dict_1 = {"float": 4.5}
        dict_2 = {"float": 4.6}
        output_dict = merge_dicts(dict_1, dict_2)

        self.assertIsInstance(output_dict["float"], float)
        # Binary floating-point addition is inexact in general; compare with
        # a tolerance instead of exact equality.
        self.assertAlmostEqual(output_dict["float"], 9.1)

    def test_lists(self):
        with Capturing() as output:
            self._test_lists()

    def _test_lists(self):
        """
        Test merging dict with lists: the lists are concatenated
        """

        # Must carry the leading underscore. A second method named
        # `test_lists` would silently replace the Capturing wrapper above.
        dict_1 = {"list": [1, 2]}
        dict_2 = {"list": [3, 4]}
        output_dict = merge_dicts(dict_1, dict_2)

        self.assertIsInstance(output_dict["list"], list)
        self.assertEqual(output_dict["list"], [1, 2, 3, 4])

    def test_dicts(self):
        with Capturing() as output:
            self._test_dicts()

    def _test_dicts(self):
        """
        Test merging dict with dicts: shared keys are merged recursively,
        other keys are carried over
        """

        dict_1 = {"dict": {"same": 1, "other_1": 2.0}}
        dict_2 = {"dict": {"same": 2, "other_2": [4.0]}}
        output_dict = merge_dicts(dict_1, dict_2)

        self.assertIsInstance(output_dict["dict"], dict)
        self.assertEqual(
            output_dict["dict"], {"same": 3, "other_1": 2.0, "other_2": [4.0]}
        )

    def test_unsupported(self):
        with Capturing() as output:
            self._test_unsupported()

    def _test_unsupported(self):
        """
        Test merging dict with unsupported types. should raise ValueError
        """

        dict_1 = {"new": dummy("david")}
        dict_2 = {"new": dummy("gio")}

        self.assertRaises(ValueError, merge_dicts, dict_1, dict_2)


class test_binaryc_json_serializer(unittest.TestCase):
    """
    Unittests for function binaryc_json_serializer
    """

    def test_not_function(self):
        with Capturing():
            self._test_not_function()

    def _test_not_function(self):
        """
        A plain string is not a function, so the serializer should hand it
        back unchanged.
        """

        original = "hello"
        serialized = binaryc_json_serializer(original)
        self.assertTrue(original == serialized)

    def test_function(self):
        with Capturing():
            self._test_function()

    def _test_function(self):
        """
        A function should be serialized to its str() representation.
        """

        expected = str(os.path.isfile)
        serialized = binaryc_json_serializer(os.path.isfile)
        self.assertTrue(expected == serialized)


class test_handle_ensemble_string_to_json(unittest.TestCase):
    """
    Unittests for function handle_ensemble_string_to_json
    """

    def test_1(self):
        with Capturing() as output:
            self._test_1()

    def _test_1(self):
        """
        Test passing the string representation of a dictionary: it should be
        decoded into a dict with the same contents.
        """

        input_string = '{"ding": 10, "list_example": [1,2,3]}'
        output_dict = handle_ensemble_string_to_json(input_string)

        self.assertIsInstance(output_dict, dict)
        self.assertEqual(output_dict["ding"], 10)
        self.assertEqual(output_dict["list_example"], [1, 2, 3])


# Run every test case in this module when it is executed as a script.
if __name__ == "__main__":
    unittest.main()