From 2d50457d852db61451ef8ccf3e708a7a5ed58d1d Mon Sep 17 00:00:00 2001
From: David Hendriks <davidhendriks93@gmail.com>
Date: Sat, 9 Jan 2021 13:38:53 +0000
Subject: [PATCH] added new test for distributions and stellar types

---
 binarycpython/tests/main.py               |   1 +
 binarycpython/tests/test_distributions.py | 102 ++++++++++++++++++++++
 binarycpython/tests/test_stellar_types.py |   3 +
 3 files changed, 106 insertions(+)
 create mode 100644 binarycpython/tests/test_stellar_types.py

diff --git a/binarycpython/tests/main.py b/binarycpython/tests/main.py
index 5a00d810a..81fff0a97 100644
--- a/binarycpython/tests/main.py
+++ b/binarycpython/tests/main.py
@@ -12,6 +12,7 @@ from binarycpython.tests.test_run_system_wrapper import *
 from binarycpython.tests.test_spacing_functions import *
 from binarycpython.tests.test_useful_funcs import *
 from binarycpython.tests.test_grid_options_defaults import *
+from binarycpython.tests.test_stellar_types import *
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/binarycpython/tests/test_distributions.py b/binarycpython/tests/test_distributions.py
index bcb038cdf..a78d8e70d 100644
--- a/binarycpython/tests/test_distributions.py
+++ b/binarycpython/tests/test_distributions.py
@@ -25,6 +25,39 @@ class TestDistributions(unittest.TestCase):
 
         self.tolerance = 1e-5
 
+    def test_flat(self):
+        """
+        Unittest for the function flat
+        """
+
+        output_1 = flat()
+
+        self.assertTrue(isinstance(output_1, float))
+        self.assertEqual(output_1, 1.0)
+
+    def test_number(self):
+        """
+        Unittest for function number
+        """
+
+        input_1 = 1.0
+        output_1 = number(input_1)
+
+        self.assertEqual(input_1, output_1)
+
+    def test_const(self):
+        """
+        Unittest for function const
+        """
+
+        output_1 = const(min_bound=0, max_bound=2)
+        self.assertEqual(output_1, 0.5, msg="Value should be 0.5, but is {}".format(output_1))
+
+
+        output_2 = const(min_bound=0, max_bound=2, val=3)
+        self.assertEqual(output_2, 0, msg="Value should be 0, but is {}".format(output_2))
+
+
     def test_powerlaw(self):
         """
         unittest for the powerlaw test
@@ -47,6 +80,9 @@ class TestDistributions(unittest.TestCase):
         for i in range(len(python_results)):
             self.assertLess(np.abs(python_results[i] - perl_results[i]), self.tolerance)
 
+        # extra test for k = -1
+        self.assertRaises(ValueError, powerlaw, 1, 100, -1, 10)
+
     def test_three_part_power_law(self):
         """
         unittest for three_part_power_law
@@ -71,6 +107,11 @@ class TestDistributions(unittest.TestCase):
         for i in range(len(python_results)):
             self.assertLess(np.abs(python_results[i] - perl_results[i]), self.tolerance)
 
+        # Extra test:
+        # M < M0
+        self.assertTrue(three_part_powerlaw(0.05, 0.08, 0.1, 1, 300, -1.3, -2.3, -2.3)==0, msg="Probability should be zero as M < M0")
+
+
     def test_Kroupa2001(self):
         """
         unittest for three_part_power_law
@@ -92,6 +133,9 @@ class TestDistributions(unittest.TestCase):
         # GO over the results and check whether they are equal (within tolerance)
         for i in range(len(python_results)):
             self.assertLess(np.abs(python_results[i] - perl_results[i]), self.tolerance)
+        # Extra tests:
+        self.assertEqual(Kroupa2001(10, newopts={'mmax': 300}), three_part_powerlaw(10, 0.1, 0.5, 1, 300, -1.3, -2.3, -2.3))
+
 
     def test_ktg93(self):
         """
@@ -114,6 +158,61 @@ class TestDistributions(unittest.TestCase):
         # GO over the results and check whether they are equal (within tolerance)
         for i in range(len(python_results)):
             self.assertLess(np.abs(python_results[i] - perl_results[i]), self.tolerance)
+        # extra test:
+        self.assertEqual(ktg93(10, newopts={'mmax': 300}), three_part_powerlaw(10, 0.1, 0.5, 1, 300, -1.3, -2.2, -2.7))
+
+
+    def test_imf_tinsley1980(self):
+        """
+        Unittest for function imf_tinsley1980
+        """
+
+        m = 1.2
+        self.assertEqual(imf_tinsley1980(m), three_part_powerlaw(m, 0.1, 2.0, 10.0, 80.0, -2.0, -2.3, -3.3))
+
+    def test_imf_scalo1986(self):
+        """
+        Unittest for function imf_scalo1986
+        """
+
+        m = 1.2
+        self.assertEqual(imf_scalo1986(m), three_part_powerlaw(m, 0.1, 1.0, 2.0, 80.0, -2.35, -2.35, -2.70))
+
+
+    def test_imf_scalo1998(self):
+        """
+        Unittest for function imf_scalo1986
+        """
+
+        m = 1.2
+        self.assertEqual(imf_scalo1998(m), three_part_powerlaw(m, 0.1, 1.0, 10.0, 80.0, -1.2, -2.7, -2.3))
+
+
+    def test_imf_chabrier2003(self):
+        """
+        Unittest for function imf_chabrier2003
+        """
+
+        input_1 = 0
+        self.assertRaises(ValueError, imf_chabrier2003, input_1)
+
+        # for m=0.5
+        m = 0.5
+        self.assertLess(np.abs(imf_chabrier(m)-0.581457346702825), self.tolerance, msg="Difference is bigger than the tolerance")
+
+        # For m = 2
+        m = 2
+        self.assertLess(np.abs(imf_chabrier(m)-0.581457346702825), self.tolerance, msg="Difference is bigger than the tolerance")
+
+
+
+    def test_duquennoy1991(self):
+        """
+        Unittest for function duquennoy1991
+        """
+
+        self.assertEqual(duquennoy1991(4.2), gaussian(4.2, 4.8, 2.3, -2, 12))
+
 
     def test_gaussian(self):
         """
@@ -137,6 +236,9 @@ class TestDistributions(unittest.TestCase):
         for i in range(len(python_results)):
             self.assertLess(np.abs(python_results[i] - perl_results[i]), self.tolerance)
 
+        # Extra test:
+        self.assertTrue(gaussian(15, 4.8, 2.3, -2.0, 12.0)==0, msg="Probability should be 0 because the input period is out of bounds")
+
     def test_Arenou2010_binary_fraction(self):
         """
         unittest for three_part_power_law
diff --git a/binarycpython/tests/test_stellar_types.py b/binarycpython/tests/test_stellar_types.py
new file mode 100644
index 000000000..6bc033c6a
--- /dev/null
+++ b/binarycpython/tests/test_stellar_types.py
@@ -0,0 +1,3 @@
+import unittest
+
+from binarycpython.utils.stellar_types import *
\ No newline at end of file
-- 
GitLab