From 1326b21e7f67981182da9181dfe5ca7390ae76f2 Mon Sep 17 00:00:00 2001
From: David Hendriks <davidhendriks93@gmail.com>
Date: Tue, 8 Jun 2021 01:40:42 +0100
Subject: [PATCH] Fixed correct branchpointing

---
 binarycpython/utils/distribution_functions.py | 16 +++--
 binarycpython/utils/grid.py                   | 61 ++++++-------------
 2 files changed, 29 insertions(+), 48 deletions(-)

diff --git a/binarycpython/utils/distribution_functions.py b/binarycpython/utils/distribution_functions.py
index ee4a0e81a..93003faf4 100644
--- a/binarycpython/utils/distribution_functions.py
+++ b/binarycpython/utils/distribution_functions.py
@@ -1006,6 +1006,17 @@ def _poisson(lambda_val, n):
 
     return (lambda_val ** n) * np.exp(-lambda_val) / (1.0 * math.factorial(n))
 
+def get_max_multiplicity(multiplicity_array):
+    """
+    Function to get the maximum multiplicity
+    """
+
+    max_multiplicity = 0
+    for n in range(4):
+        if multiplicity_array[n] > 0:
+            max_multiplicity = n + 1
+    return max_multiplicity
+
 
 def Moe_de_Stefano_2017_multiplicity_fractions(options, verbosity=0):
     """
@@ -1123,10 +1134,7 @@ def Moe_de_Stefano_2017_multiplicity_fractions(options, verbosity=0):
 
         # TODO: ask rob about this part. Not sure if i understand it.
 
-        max_multiplicity = 0
-        for n in range(4):
-            if options["multiplicity_modulator"][n] > 0:
-                max_multiplicity = n + 1
+        max_multiplicity = get_max_multiplicity(options["multiplicity_modulator"])
 
         if max_multiplicity == 1:
             result[0] = 1.0
diff --git a/binarycpython/utils/grid.py b/binarycpython/utils/grid.py
index bda3c3527..614c652a6 100644
--- a/binarycpython/utils/grid.py
+++ b/binarycpython/utils/grid.py
@@ -80,6 +80,7 @@ from binarycpython.utils.distribution_functions import (
     LOG_LN_CONVERTER,
     fill_data,
     Moe_de_Stefano_2017_pdf,
+    get_max_multiplicity,
 )
 
 from binarycpython.utils.spacing_functions import (
@@ -1639,6 +1640,7 @@ class Population:
         # TODO: import only the necessary packages/functions
         # TODO: Put all the masses, eccentricities and periods in there already
         # TODO: Put the certain blocks that are repeated in some subfunctions
+        # TODO: make sure running systems with multicplity 3+ is also possible.
 
         Results in a generated file that contains a system_generator function.
         """
@@ -1986,7 +1988,7 @@ class Population:
             # Check the branchpoint part here. The branchpoint makes sure that we can construct
             # a grid with several multiplicities and still can make the system calls for each
             # multiplicity without reconstructing the grid each time
-            if grid_variable["branchpoint"] == 1:
+            if grid_variable["branchpoint"] > 0:
 
                 # Add comment
                 code_string += (
@@ -1997,10 +1999,17 @@ class Population:
                     + "\n"
                 )
 
-                # Add condition check
+                # # Add condition check
+                # code_string += (
+                #     indent * (depth + 1)
+                #     + "if not {}:".format(grid_variable["condition"])
+                #     + "\n"
+                # )
+
+                # Add branchpoint
                 code_string += (
                     indent * (depth + 1)
-                    + "if not {}:".format(grid_variable["condition"])
+                    + "if multiplicity=={}:".format(grid_variable["branchpoint"])
                     + "\n"
                 )
 
@@ -3359,39 +3368,6 @@ class Population:
         with open(os.path.join(ms_tmp_dir, "moecache.json"), "w") as cache_filehandle:
             cache_filehandle.write(json.dumps(Moecache, indent=4))
 
-        # TODO: remove this and make the plotting of the multiplicity fractions more modular
-        # options['M1'] = 10
-
-        # # multiplicity_fractions = Moe_de_Stefano_2017_multiplicity_fractions(options)
-
-        # import matplotlib.pyplot as plt
-
-        # mass_range = range(1, 80)
-
-        # multiplicity_dict = {}
-        # for mass in mass_range:
-        #     options["M1"] = mass
-        #     multiplicity_fractions = Moe_de_Stefano_2017_multiplicity_fractions(options)
-        #     multiplicity_dict[mass] = multiplicity_fractions
-
-        # single_values = [multiplicity_dict[key][0] for key in multiplicity_dict.keys()]
-        # binary_values = [multiplicity_dict[key][1] for key in multiplicity_dict.keys()]
-        # triple_values = [multiplicity_dict[key][2] for key in multiplicity_dict.keys()]
-        # quad_values = [multiplicity_dict[key][3] for key in multiplicity_dict.keys()]
-
-        # print("mass = {}".format(list(mass_range)))
-        # print("single_values={}".format(single_values))
-        # print("binary_values={}".format(binary_values))
-        # print("triple_values={}".format(triple_values))
-        # print("quad_values={}".format(quad_values))
-
-        # plt.plot(mass_range, single_values, label="single")
-        # plt.plot(mass_range, binary_values, label="binary")
-        # plt.plot(mass_range, triple_values, label="triple")
-        # plt.plot(mass_range, quad_values, label="Quadruple")
-        # plt.legend()
-        # plt.xscale("log")
-        # plt.show()
 
         ############################################################
         # construct the grid here
@@ -3404,11 +3380,8 @@ class Population:
         ############################################################
         # first, the multiplicity, this is 1,2,3,4, ...
         # for singles, binaries, triples, quadruples, ...
-        max_multiplicity = 0
-        for i in range(1, 5):
-            mod = options["multiplicity_modulator"][i - 1]
-            if mod > 0:
-                max_multiplicity = i
+
+        max_multiplicity = get_max_multiplicity(options["multiplicity_modulator"])
         verbose_print(
             "\tMoe_de_Stefano_2017: Max multiplicity = {}".format(max_multiplicity),
             self.grid_options["verbosity"],
@@ -3475,7 +3448,7 @@ class Population:
                 resolution=options["resolutions"]["logP"][0],
                 probdist=1.0,
                 condition='(self.grid_options["multiplicity"] >= 2)',
-                # branchpoint=1,
+                branchpoint=1 if max_multiplicity > 1 else 0, # Signal here to put a branchpoint if we have a max multiplicity higher than 1. 
                 gridtype="centred",
                 dphasevol="({} * dlog10per)".format(LOG_LN_CONVERTER),
                 valuerange=[options["ranges"]["logP"][0], options["ranges"]["logP"][1]],
@@ -3556,7 +3529,7 @@ sep = calc_sep_from_period(M_1, M_2, orbital_period)
                     resolution=options["resolutions"]["logP"][1],
                     probdist=1.0,
                     condition='(self.grid_options["multiplicity"] >= 3)',
-                    branchpoint=1,
+                    branchpoint=2 if max_multiplicity > 2 else 0, # Signal here to put a branchpoint if we have a max multiplicity higher than 1. 
                     gridtype="centred",
                     dphasevol="({} * dlog10per2)".format(LOG_LN_CONVERTER),
                     valuerange=[
@@ -3642,7 +3615,7 @@ eccentricity2=0
                         resolution=options["resolutions"]["logP"][2],
                         probdist=1.0,
                         condition='(self.grid_options["multiplicity"] >= 4)',
-                        branchpoint=1,
+                        branchpoint=3 if max_multiplicity > 3 else 0, # Signal here to put a branchpoint if we have a max multiplicity higher than 1.
                         gridtype="centred",
                         dphasevol="({} * dlog10per3)".format(LOG_LN_CONVERTER),
                         valuerange=[
-- 
GitLab