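# Scaling test script: times linear vs. multi-core (MP) population evolution over a range of core counts and writes the timings to a JSON file.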
import argparse
import json
import os
import socket
import time

import numpy as np
import psutil

from binarycpython.utils.grid import Population
from binarycpython.utils.functions import get_help_all, get_help, create_hdf5

# Gather hardware info; it is stored in the result metadata.
amount_of_cores = psutil.cpu_count(logical=False)  # physical cores
amount_of_cpus = psutil.cpu_count()  # logical CPUs
hostname = socket.gethostname()

# psutil returns None when the count cannot be determined; fail with a clear
# message instead of a TypeError from the comparison below.
if amount_of_cpus is None:
    raise RuntimeError("Could not determine the number of logical CPUs")

# Core counts to benchmark: always 1, then every `stepsize`-th count up to the
# number of logical CPUs. Machines with more than 4 CPUs are sampled in steps
# of 2 (presumably to bound the total runtime).
# TODO: a coarser step (e.g. 4 for 24 < cpus <= 48) was sketched previously.
stepsize = 1 if amount_of_cpus <= 4 else 2

# The set drops the duplicate 1 when stepsize == 1; sorting keeps the counts
# ascending. NOTE: with stepsize 2 and an odd CPU count, the maximum count is
# not included.
cpu_list = sorted({1, *range(stepsize, amount_of_cpus + 1, stepsize)})


# Benchmark settings
amt_repeats = 5  # timing repetitions per configuration
resolution = {'M_1': 50, 'per': 60}  # grid points per grid variable
total_systems = int(np.prod(list(resolution.values())))  # size of the full grid
result_dir = 'scaling_results'
testcase = 'linear vs MP batched'

# Metadata describing this run; timing results are added to this dictionary
# later in the script.
result_dict = {
    'amt_systems': total_systems,
    'hostname': hostname,
    'amt_logical_cores': amount_of_cpus,
    'amt_of_physical_cores': amount_of_cores,
    'testcase': testcase,
}

#################
# Population set-up: a grid over `M_1` and `period`, with the number of grid
# points per variable taken from `resolution`.
test_pop = Population()
test_pop.set(verbose=1, binary=1)

# Keyword arguments for each add_grid_variable() call, listed in the order the
# variables are registered.
_grid_variable_specs = [
    {
        "name": "M_1",
        "longname": "log primary mass",
        "valuerange": [1, 100],
        "resolution": str(resolution['M_1']),
        "spacingfunc": "const(1, 100, {})".format(resolution['M_1']),
        "probdist": "Kroupa2001(M_1)",
        # Alternative probdist: 'self.custom_options["extra_prob_function"](M_1)'
        "dphasevol": "dlnm1",
        "parameter_name": "M_1",
        "condition": "",
    },
    {
        "name": "period",
        "longname": "period",
        "valuerange": ["M_1", 20],
        "resolution": str(resolution['per']),
        "spacingfunc": "np.linspace(1, 10, {})".format(resolution['per']),
        "precode": "orbital_period = period**2",
        "probdist": "flat(orbital_period)",
        "parameter_name": "orbital_period",
        "dphasevol": "dper",
        "condition": 'self.grid_options["binary"]==1',
    },
]

for _grid_variable_spec in _grid_variable_specs:
    test_pop.add_grid_variable(**_grid_variable_spec)

#######################################################################################
# Execute grids

# Linear runs: baseline timings to compare against the MP runs.
# time.perf_counter() is used because it is monotonic and high-resolution;
# time.time() follows the system clock and can jump when it is adjusted.
linear_times = []
for _ in range(amt_repeats):
    total_lin_start = time.perf_counter()
    # The returned value is reported as the time spent evolving the systems.
    evolve_lin_time = test_pop.test_evolve_population_lin()
    total_lin = time.perf_counter() - total_lin_start

    print("linear run with {} systems: {} of which {} spent on evolving the systems".format(total_systems, total_lin, evolve_lin_time))
    linear_times.append(total_lin)

# Total elapsed seconds, one entry per repeat.
result_dict['linear'] = linear_times

#######################################################################################
# MP runs: repeat the timing measurement for every core count in cpu_list.
# time.perf_counter() is used for the same reason as in the linear runs
# (monotonic, high-resolution duration measurement).
mp_dict = {}
for cpu_amt in cpu_list:
    # Set the `amt_cores` option used by the following runs.
    test_pop.set(amt_cores=cpu_amt)

    mp_times = []
    for _ in range(amt_repeats):
        total_mp_start = time.perf_counter()
        evolve_mp_time = test_pop.evolve_population_mp_chunks()
        total_mp = time.perf_counter() - total_mp_start

        print("MP ({} nodes) run with {} systems: {} of which {} spent on evolving the systems".format(cpu_amt, total_systems, total_mp, evolve_mp_time))
        mp_times.append(total_mp)

    # Note: the integer keys become strings when the results are serialised to JSON.
    mp_dict[cpu_amt] = mp_times

result_dict['mp'] = mp_dict

# Write the results to <result_dir>/<hostname>_<total_systems>_systems.json.
# Create the directory first so that a long benchmark does not crash at the
# very end because the output directory is missing.
os.makedirs(result_dir, exist_ok=True)
result_path = os.path.join(result_dir, '{}_{}_systems.json'.format(hostname, total_systems))
with open(result_path, 'w', encoding='utf-8') as f:
    json.dump(result_dict, f)