diff --git a/binarycpython/utils/grid.py b/binarycpython/utils/grid.py index 293215d0431e193b7353d30f29531c329946e4f4..a487989ad17c9a5e94df01783f1fa343a61f790e 100644 --- a/binarycpython/utils/grid.py +++ b/binarycpython/utils/grid.py @@ -358,6 +358,7 @@ class Population: == number - 1 ): return grid_variable + def add_grid_variable( self, name: str, @@ -368,7 +369,7 @@ class Population: probdist: str, dphasevol: Union[str, int], parameter_name: str, - gridtype: str = "edge", + gridtype: str = "centred", branchpoint: int = 0, precode: Union[str, None] = None, condition: Union[str, None] = None, @@ -391,25 +392,25 @@ class Population: name: name of parameter. This is evaluated as a parameter and you can use it throughout the rest of the function - + Examples: name = 'lnm1' longname: Long name of parameter - + Examples: longname = 'Primary mass' range: Range of values to take. Does not get used really, the spacingfunction is used to get the values from - + Examples: range = [math.log(m_min), math.log(m_max)] resolution: Resolution of the sampled range (amount of samples). TODO: check if this is used anywhere - Examples: + Examples: resolution = resolution["M_1"] spacingfunction: Function determining how the range is sampled. You can either use a real function, @@ -422,12 +423,12 @@ class Population: precode: Extra room for some code. This code will be evaluated within the loop of the sampling function (i.e. a value for lnm1 is chosen already) - + Examples: precode = 'M_1=math.exp(lnm1);' probdist: Function determining the probability that gets assigned to the sampled parameter - + Examples: probdist = 'Kroupa2001(M_1)*M_1' dphasevol: @@ -441,7 +442,7 @@ class Population: condition = 'self.grid_options['binary']==1' gridtype: Method on how the value range is sampled. Can be either 'edge' (steps starting at - the lower edge of the value range) or 'center' + the lower edge of the value range) or 'centred' (steps starting at lower edge + 0.5 * stepsize). """ @@ -464,6 +465,11 @@ class Population: "grid_variable_number": len(self.grid_options["_grid_variables"]), } + # Check for gridtype input + if not gridtype in ['edge', 'centred']: + msg = "Unknown gridtype value. Please start another one" + raise ValueError(msg) + # Load it into the grid_options self.grid_options["_grid_variables"][grid_variable["name"]] = grid_variable verbose_print( @@ -1920,34 +1926,40 @@ class Population: # TODO: make sure this works # Adding for loop structure - code_string += ( - indent * depth - + "for {} in sampled_values_{}:".format( - grid_variable["name"], grid_variable["name"] - ) - + "\n" - ) - # code_string += ( # indent * depth - # + "for {}_sample_number in range({}):".format( - # grid_variable["name"], grid_variable["resolution"] + # + "for {} in sampled_values_{}:".format( + # grid_variable["name"], grid_variable["name"] # ) # + "\n" # ) - # code_string += ( - # indent * (depth+1) - # + "{} = sampled_values_{}[0] + ((sampled_values_{}[-1]-sampled_values_{}[0])/{}) * {}_sample_number".format( - # grid_variable["name"], grid_variable["name"], grid_variable["name"], grid_variable["name"], grid_variable["resolution"], grid_variable["name"] - # ) - # + "\n" - # ) - - - - - + code_string += ( + indent * depth + + "for {}_sample_number in range({}):".format( + grid_variable["name"], grid_variable["resolution"] + ) + + "\n" + ) + if grid_variable['gridtype'] == 'edge': + code_string += ( + indent * (depth+1) + + "{} = sampled_values_{}[0] + (sampled_values_{}[1]-sampled_values_{}[0]) * {}_sample_number".format( + grid_variable["name"], grid_variable["name"], grid_variable["name"], grid_variable["name"], grid_variable["name"] + ) + + "\n" + ) + elif grid_variable['gridtype'] == 'centred': + code_string += ( + indent * (depth+1) + + "{} = sampled_values_{}[0] + 0.5 * (sampled_values_{}[1]-sampled_values_{}[0]) + (sampled_values_{}[1]-sampled_values_{}[0]) * {}_sample_number".format( + grid_variable["name"], grid_variable["name"], grid_variable["name"], grid_variable["name"], grid_variable["name"], grid_variable["name"], grid_variable["name"] + ) + + "\n" + ) + else: + msg = "Unknown gridtype value. PLease choose a different one" + raise ValueError(msg) ################################################################################# # Check condition and generate for loop diff --git a/binarycpython/utils/spacing_functions.py b/binarycpython/utils/spacing_functions.py index 5cdf3522d857dce357ed8202e77f1e015819d95c..2284267736e45386d9db541e94081bb8dbce748d 100644 --- a/binarycpython/utils/spacing_functions.py +++ b/binarycpython/utils/spacing_functions.py @@ -21,7 +21,7 @@ def const( steps: amount of segments between min_bound and max_bound Returns: - np.linspace(min_bound, max_bound, steps) + np.linspace(min_bound, max_bound, steps+1) """ - return np.linspace(min_bound, max_bound, steps) + return np.linspace(min_bound, max_bound, steps+1)