From 836b68adfda5da4f1d6861c4a953869b4b14e5eb Mon Sep 17 00:00:00 2001
From: David Hendriks <davidhendriks93@gmail.com>
Date: Wed, 5 May 2021 02:09:48 +0100
Subject: [PATCH] Working on the grid M&S

---
 binarycpython/utils/grid.py         | 200 ++++++++++++++++------------
 binarycpython/utils/useful_funcs.py |  40 ++++++
 2 files changed, 156 insertions(+), 84 deletions(-)

diff --git a/binarycpython/utils/grid.py b/binarycpython/utils/grid.py
index 1821f16e8..6c2b67cdd 100644
--- a/binarycpython/utils/grid.py
+++ b/binarycpython/utils/grid.py
@@ -1724,84 +1724,7 @@ class Population:
             # whether this is the last loop.
             if loopnr == len(self.grid_options["_grid_variables"]) - 1:
 
-                code_string = self._write_gridcode_system_call(code_string, indent, depth, grid_variable, dry_run)
-
-def _write_gridcode_system_call(self, code_string, indent, depth, grid_variable, dry_run):
-
-
-                #################################################################################
-                # Here are the calls to the queuing or other solution. this part is for every system
-                # Add comment
-                code_string += indent * (depth + 1) + "#" * 40 + "\n"
-                code_string += (
-                    indent * (depth + 1)
-                    + "# Code below will get evaluated for every generated system\n"
-                )
-
-                # Calculate value
-                code_string += (
-                    indent * (depth + 1)
-                    + 'probability = self.grid_options["weight"] * probabilities_list[{}]'.format(
-                        grid_variable["grid_variable_number"]
-                    )
-                    + "\n"
-                )
-                # TODO: ask rob if just replacing this with probability is enough
-                code_string += (
-                    indent * (depth + 1)
-                    # + 'repeat_probability = probability / self.grid_options["repeat"]'
-                    + 'probability = probability / self.grid_options["repeat"]'
-                    + "\n"
-                )
-
-                # For each repeat of the system this has to be done yes.
-                code_string += (
-                    indent * (depth + 1)
-                    + 'for _ in range(self.grid_options["repeat"]):'
-                    + "\n"
-                )
-
-                code_string += indent * (depth + 2) + "_total_starcount += 1\n"
-
-                # set probability and phasevol values
-                code_string += (
-                    indent * (depth + 2)
-                    + 'parameter_dict["{}"] = {}'.format("probability", "probability")
-                    + "\n"
-                )
-                code_string += (
-                    indent * (depth + 2)
-                    + 'parameter_dict["{}"] = {}'.format("phasevol", "phasevol")
-                    + "\n"
-                )
-
-                # Some prints. will be removed
-                # code_string += indent * (depth + 1) + "print(probabilities)\n"
-                # code_string += (
-                #     indent * (depth + 1) + 'print("_total_starcount: ", _total_starcount)\n'
-                # )
-
-                # code_string += indent * (depth + 1) + "print(probability)\n"
-
-                # Increment total probability
-                code_string += (
-                    indent * (depth + 2) + "self._increment_probtot(probability)\n"
-                )
-
-                if not dry_run:
-                    # Handling of what is returned, or what is not.
-                    # TODO: think of whether this is a good method
-                    code_string += indent * (depth + 2) + "yield(parameter_dict)\n"
-
-                    # The below solution might be a good one to add things to specific queues
-                    # $self->queue_evolution_code_run($self->{_flexigrid}->{thread_q},
-                    # $system);
-
-                # If its a dry run, dont do anything with it
-                else:
-                    code_string += indent * (depth + 2) + "pass\n"
-
-                code_string += indent * (depth + 1) + "#" * 40 + "\n"
+                code_string = self._write_gridcode_system_call(code_string, indent, depth, grid_variable, dry_run, grid_variable['branchpoint'])
 
             # increment depth
             depth += 1
@@ -1813,14 +1736,16 @@ def _write_gridcode_system_call(self, code_string, indent, depth, grid_variable,
         # this has to go in a reverse order:
         # Here comes the stuff that is put after the deepest nested part that calls returns stuff.
         # Here we will have a 
-        for loopnr, grid_variable_el in enumerate(
-            sorted(
+        reverse_sorted_grid_variables = sorted(
                 self.grid_options["_grid_variables"].items(),
                 key=lambda x: x[1]["grid_variable_number"],
                 reverse=True,
-            )
+        )
+        for loopnr, grid_variable_el in enumerate(
+            reverse_sorted_grid_variables
         ):
             grid_variable = grid_variable_el[1]
+
             code_string += indent * (depth + 1) + "#" * 40 + "\n"
             code_string += (
                 indent * (depth + 1)
@@ -1836,6 +1761,28 @@ def _write_gridcode_system_call(self, code_string, indent, depth, grid_variable,
 
             depth -= 1
 
+            # Check the branchpoint part here. The branchpoint makes sure that we can construct 
+            # a grid with several multiplicities and still can make the system calls for each 
+            # multiplicity without reconstructing the grid each time
+            if grid_variable['branchpoint'] == 1:
+
+                # Add comment
+                code_string += (
+                    indent * (depth + 1)
+                    + "# Condition for branchpoint at {}".format(reverse_sorted_grid_variables[loopnr+1][1]["parameter_name"])
+                    + "\n"
+                )
+
+                # Add condition check
+                code_string += (
+                    indent * (depth + 1)
+                    + "if not {}:".format(grid_variable["condition"])
+                    + "\n"
+                )
+
+                code_string = self._write_gridcode_system_call(code_string, indent, depth+1, reverse_sorted_grid_variables[loopnr+1][1], dry_run, grid_variable['branchpoint'])
+                code_string += "\n"
+
         ################
         # Finalising print statements
         #
@@ -1883,6 +1830,91 @@ def _write_gridcode_system_call(self, code_string, indent, depth, grid_variable,
         with open(gridcode_filename, "w") as file:
             file.write(code_string)
 
+
+    def _write_gridcode_system_call(self, code_string, indent, depth, grid_variable, dry_run, branchpoint):
+        #################################################################################
+        # Here are the calls to the queuing or other solution. this part is for every system
+        # Add comment
+        code_string += indent * (depth + 1) + "#" * 40 + "\n"
+
+        if branchpoint:
+            code_string += (
+                indent * (depth + 1)
+                + "# Code below will get evaluated for every system at this level of multiplicity (last one of that being {})\n"
+            ).format(grid_variable["name"])
+        else:
+            code_string += (
+                indent * (depth + 1)
+                + "# Code below will get evaluated for every generated system\n"
+            )
+
+        # Calculate value
+        code_string += (
+            indent * (depth + 1)
+            + 'probability = self.grid_options["weight"] * probabilities_list[{}]'.format(
+                grid_variable["grid_variable_number"]
+            )
+            + "\n"
+        )
+        # TODO: ask rob if just replacing this with probability is enough
+        code_string += (
+            indent * (depth + 1)
+            # + 'repeat_probability = probability / self.grid_options["repeat"]'
+            + 'probability = probability / self.grid_options["repeat"]'
+            + "\n"
+        )
+
+        # For each repeat of the system this has to be done yes.
+        code_string += (
+            indent * (depth + 1)
+            + 'for _ in range(self.grid_options["repeat"]):'
+            + "\n"
+        )
+
+        code_string += indent * (depth + 2) + "_total_starcount += 1\n"
+
+        # set probability and phasevol values
+        code_string += (
+            indent * (depth + 2)
+            + 'parameter_dict["{}"] = {}'.format("probability", "probability")
+            + "\n"
+        )
+        code_string += (
+            indent * (depth + 2)
+            + 'parameter_dict["{}"] = {}'.format("phasevol", "phasevol")
+            + "\n"
+        )
+
+        # Some prints. will be removed
+        # code_string += indent * (depth + 1) + "print(probabilities)\n"
+        # code_string += (
+        #     indent * (depth + 1) + 'print("_total_starcount: ", _total_starcount)\n'
+        # )
+
+        # code_string += indent * (depth + 1) + "print(probability)\n"
+
+        # Increment total probability
+        code_string += (
+            indent * (depth + 2) + "self._increment_probtot(probability)\n"
+        )
+
+        if not dry_run:
+            # Handling of what is returned, or what is not.
+            # TODO: think of whether this is a good method
+            code_string += indent * (depth + 2) + "yield(parameter_dict)\n"
+
+            # The below solution might be a good one to add things to specific queues
+            # $self->queue_evolution_code_run($self->{_flexigrid}->{thread_q},
+            # $system);
+
+        # If its a dry run, dont do anything with it
+        else:
+            code_string += indent * (depth + 2) + "pass\n"
+
+        code_string += indent * (depth + 1) + "#" * 40 + "\n"
+
+        return code_string
+
     def _load_grid_function(self):
         """
         Functon that loads the script containing the grid code.
@@ -3180,7 +3212,7 @@ def _write_gridcode_system_call(self, code_string, indent, depth, grid_variable,
                 ),
                 precode="""orbital_period = 10.0**log10per
 qmin={}/M_1
-qmax=self.maximum_mass_ratio_for_RLOF(M_1, orbital_period)
+qmax=maximum_mass_ratio_for_RLOF(M_1, orbital_period)
     """.format(
                     options.get("Mmin", 0.07)
                 ),
@@ -3264,7 +3296,7 @@ sep = calc_sep_from_period(M_1, M_2, orbital_period)
                     ),
                     precode="""orbital_period_triple = 10.0**log10per2
 q2min={}/(M_1+M_2)
-q2max=self.maximum_mass_ratio_for_RLOF(M_1+M_2, orbital_period_triple)
+q2max=maximum_mass_ratio_for_RLOF(M_1+M_2, orbital_period_triple)
     """.format(
                         options.get("Mmin", 0.07)
                     ),
@@ -3350,7 +3382,7 @@ eccentricity2=0
                         ),
                         precode="""orbital_period_quadruple = 10.0**log10per3
 q3min={}/(M_3)
-q3max=self.maximum_mass_ratio_for_RLOF(M_3, orbital_period_quadruple)
+q3max=maximum_mass_ratio_for_RLOF(M_3, orbital_period_quadruple)
     """.format(
                             options.get("Mmin", 0.07)
                         ),
diff --git a/binarycpython/utils/useful_funcs.py b/binarycpython/utils/useful_funcs.py
index acf28d6a5..53d1abede 100644
--- a/binarycpython/utils/useful_funcs.py
+++ b/binarycpython/utils/useful_funcs.py
@@ -13,6 +13,7 @@ Functions:
 
 Tasks:
     - TODO: check whether these functions are correct
+    - TODO: add unittest for maximum_mass_ratio_for_RLOF
 """
 
 import math
@@ -99,6 +100,45 @@ def minimum_separation_for_RLOF(M1, M2, metallicity, store_memaddr=-1):
 # print(minimum_separation_for_RLOF(0.08, 0.08, 0.00002))
 # print(minimum_separation_for_RLOF(10, 2, 0.02))
 
+def maximum_mass_ratio_for_RLOF(M1, orbital_period, metallicity=0.02, store_memaddr=None):
+    """
+    Wrapper function for _binary_c_bindings.return_maximum_mass_ratio_for_RLOF
+
+    Handles the output and returns the maximum mass ratio at which RLOF just does not occur at ZAMS
+
+    Args:
+        M1: Primary mass in solar mass
+        orbital_period: orbital period in days
+        metallicity: metallicity
+        store_memaddr (optional): store memory adress
+    Returns:
+        maximum mass ratio that just does not cause a RLOF at ZAMS
+    """
+
+    # Convert to orbital period in years 
+    orbital_period = orbital_period/3.651995478818308811241877265275e+02
+
+    bse_dict = {
+        "M_1": M1,
+        "M_2": 0.01,
+        "separation": 0,
+        "orbital_period": orbital_period,
+        "metallicity": metallicity,
+        "maximum_mass_ratio_for_instant_RLOF": 1,
+    }
+
+    argstring = "binary_c " + create_arg_string(bse_dict)
+    output = _binary_c_bindings.return_maximum_mass_ratio_for_RLOF(argstring, store_memaddr)
+    stripped = output.strip()
+
+    if stripped == "NO MAXIMUM MASS RATIO < 1":
+        maximum_mass_ratio = 1
+    else:
+        maximum_mass_ratio = float(stripped.split()[-1])
+    return maximum_mass_ratio
+
+# print(maximum_mass_ratio_for_RLOF(4, 0.1, 0.002))
+# print(maximum_mass_ratio_for_RLOF(4, 1, 0.002))
 
 def calc_period_from_sep(
     M1: Union[int, float], M2: Union[int, float], sep: Union[int, float]
-- 
GitLab