import matplotlib
import os
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import json
import math

def amdahl(f, n):
    """Return the speed-up predicted by Amdahl's law.

    Evaluates ``1 / ((1 - f) + f / n)``, where ``f`` is the fraction of the
    work that is spread over ``n`` units and ``1 - f`` is the remainder.
    Works element-wise when ``n`` (or ``f``) is a numpy array.
    """
    serial_share = 1 - f
    parallel_share = f / n
    return 1.0 / (serial_share + parallel_share)

#################################
# Files
# Directory (resolved against the current working directory) that holds the
# scaling result files, and the result files that should be plotted.
scaling_result_dir = 'scaling_results'
filenames = [
    'astro2_2500_systems.json',
    'astro2_3000_systems.json',
]

# Absolute path of every result file, in the same order as `filenames`.
results_root = os.path.abspath(scaling_result_dir)
result_jsons = [os.path.join(results_root, name) for name in filenames]

#################################
# Plotting of the scaling results
# One figure with two y-axes sharing the x-axis: ax1 receives the speed-up
# points (with error bars), ax2 receives the efficiency curves.
fig, ax1 = plt.subplots()
ax2 = ax1.twinx()
for jsonfile in result_jsons:
    print(jsonfile)

    # Parse the file and release the handle immediately; nothing below needs
    # the file to stay open.
    with open(jsonfile, 'r') as f:
        result_data = json.load(f)

    # Timings stored under 'linear': the reference the speed-up is measured
    # against (presumably the non-parallel run -- confirm with the producer
    # of the result files).
    linear_data = result_data['linear']
    linear_mean = np.mean(linear_data)
    linear_stdev = np.std(linear_data)

    cpus, speedups, efficiencies, stddev_speedups = [], [], [], []

    # The keys of 'mp' are core counts stored as strings. Visit them in
    # ascending numeric order so the efficiency line on ax2 connects the
    # points from left to right, whatever order the file stores them in.
    for cpu_key in sorted(result_data['mp'], key=int):
        # Timings measured with this amount of cores
        mp_data = result_data['mp'][cpu_key]
        mp_mean = np.mean(mp_data)
        mp_stdev = np.std(mp_data)

        amt_cpus = int(cpu_key)
        speedup = linear_mean / mp_mean
        # Uncertainty of a quotient: the relative errors of numerator and
        # denominator are added in quadrature and scaled by the quotient.
        stddev_speedup = (
            math.sqrt((linear_stdev / linear_mean) ** 2 + (mp_stdev / mp_mean) ** 2)
            * speedup
        )
        efficiency = speedup / amt_cpus

        cpus.append(amt_cpus)
        efficiencies.append(efficiency)
        speedups.append(speedup)
        stddev_speedups.append(stddev_speedup)

    # Speed-up as separate markers with error bars on the left axis ...
    ax1.errorbar(
        cpus,
        speedups,
        stddev_speedups,
        linestyle="None",
        marker="^",
        label="Speed up & efficiency of {} systems".format(result_data['amt_systems']),
    )

    # ... and efficiency as a semi-transparent line on the right axis.
    ax2.plot(cpus, efficiencies, alpha=0.5)


# Do Amdahl's law fitting (currently disabled: the code below is commented out)
# cores = np.arange(1, 48, 0.1)
# values_list = [] 
# par_step = 0.005
# par_vals = np.arange(.95, 1, par_step)


# for par_val in par_vals:
#     values = amdahl(par_val, cores)
#     values_list.append(values)

# for i, values in enumerate(values_list):
#     ax1.plot(cores, values, label="par_val={}".format(par_vals[i]))



#################################
# Adding plot make up
# Name of the test case shown in the title (and used by the optional savefig
# calls below). NOTE(review): this is still the placeholder value -- replace
# it with the real test case name.
name_testcase = 'name_testcase'

ax1.set_title(
    "Speed up ratio vs amount of cores for different amounts of systems on {}".format(
        name_testcase
    )
)

# Optional reference line for perfect scaling:
# ax1.plot([1, max(cpus)], [1, max(cpus)], label='100% scaling')

ax1.set_xlabel("Amount of cores used")
ax1.set_ylabel("Speed up ratio (time_linear/time_parallel)")
# The right-hand axis carries the efficiency curves; label it so its scale
# can be told apart from the speed-up scale on the left.
ax2.set_ylabel("Efficiency (speed up ratio / amount of cores)")

# Optional fixed axis ranges:
# ax1.set_xlim(0, max(cpus) + 4)
# ax2.set_ylim(0, 1)

ax1.grid()
ax1.legend(loc=4)  # loc=4 is the lower right corner
ax1.set_xscale('log')
ax2.set_xscale('log')

# Optional export; requires `img_dir` to be defined first:
# fig.savefig(os.path.join(img_dir, "speedup_scaling_{}.{}".format(name_testcase, "png")))
# fig.savefig(os.path.join(img_dir, "speedup_scaling_{}.{}".format(name_testcase, "pdf")))
# fig.savefig(os.path.join(img_dir, "speedup_scaling_{}.{}".format(name_testcase, "eps")))
plt.show()