diff --git a/binarycpython/utils/dicts.py b/binarycpython/utils/dicts.py index c474b3dce0e8d94343bbd4a3841460b4724503d8..22b59c88537a8922d73cb949d7ccdcab93eda23a 100644 --- a/binarycpython/utils/dicts.py +++ b/binarycpython/utils/dicts.py @@ -3,13 +3,15 @@ Module containing functions that binary_c-python uses to modify dictionaries. """ import astropy.units as u -import copy -import collections -import math -import numpy as np from collections import ( OrderedDict, ) +import collections +import copy +import math +from natsort import index_natsorted,natsorted +import numpy as np +import pandas as pd import sys # we need to convert keys to floats: @@ -829,7 +831,7 @@ def normalize_dict(result_dict, verbosity=0): result_dict[key] = result_dict[key] / sum_result return result_dict -def fill_2d_dict(data,fill): +def fill_narrow_dict(data,fill): """ Fill a 2D dict with whatever fill is set to when keys are missing. """ @@ -844,8 +846,9 @@ def fill_2d_dict(data,fill): for y in ys: if not y in data[x]: data[x][y] = fill + return data -def format_2d_dict_keys(data,format): +def format_narrow_dict_keys(data,format): """ Reformat float keys in a 2d dict according to the given format """ @@ -866,10 +869,12 @@ def format_2d_dict_keys(data,format): data[x][y] = z return data -def keys_of_2d_dict(data,numpy=False): +def tlist_keys_of_narrow_dict(data, + numpy=False): """ Given a 2D dict data[x][y] = z, return the list of keys as tuples, (x,y), - and equivalent z list. + and equivalent z list. This is useful for providing to plotting and numpy + routines. Args: data: the input 2D dict @@ -886,6 +891,73 @@ def keys_of_2d_dict(data,numpy=False): zs = np.array(zs) return xys,zs +def list_keys_of_narrow_dict(data, + logspace=False, + numpy=False, + zmin=None, + zmax=None, + sortx=False, + sorty=False): + """ + Given a 2D dict data[x][y] = z, return a list containing vectors + [x,y,z] for all the points in the dict, lists of the x and y values, + and the min and max of the z values. 
+ + These data are very useful for plotting and numpy routines. + + Args: + data: the input 2D dict (required) + numpy: if True convert returned data to numpy arrays (default False) + logspace : if True only use z>0 to calculate zmin and zmax (default False) + zmin : start minimum z checks at this value (default None) + zmax : start maximum z checks at this value (default None) + sortx : sorts x key list + sorty : sorts y key list + Returns: + + ds,xs,ys,zmin,zmax + + ds : list of vectors [x,y,z] + xs : list of x values + ys : list of y values + zmin : minimum z + zmax : maximum z + + """ + model_space,model_weights = tlist_keys_of_narrow_dict(data,numpy) + + ds = [] + xs = {} + ys = {} + for x,z in zip(model_space,model_weights): + ds.append([x[0], x[1], z]) + xs[x[0]] = 1 + ys[x[1]] = 1 + if not logspace or z > 0.0: + if zmin is not None: # a running minimum of 0.0 is valid, only None means unset + zmin = min(z,zmin) + else: + zmin = z + if zmax is not None: # a running maximum of 0.0 is valid, only None means unset + zmax = max(z,zmax) + else: + zmax = z + + xs = list(xs.keys()) + ys = list(ys.keys()) + + if sortx: + xs = sorted(xs) + if sorty: + ys = sorted(ys) + + if numpy is True: + ds = np.array(ds) + xs = np.array(xs) + ys = np.array(ys) + + return ds,xs,ys,zmin,zmax + def _mindiff(list,tol=1e-10): """ Given a list of unique values, within tol(=1e-7), find the @@ -919,34 +991,72 @@ def _find_nearest_index(array,value): return idx -def fillgrid(data,tol=1e-10,format=None): +def pad_narrow_dict_grid(data, + tol=1e-10, + format=None, + value=0.0, + dx=None, + dy=None): """ - Given a data[x][y] dict, fill it to the appropriate bin widths. + Given data on a grid in a narrow dict, data[x][y], fill it with value where + data is missing. We fill data according to the binwidths dx,dy which are + automatically calculated if set to None. 
+ + Args: + + data : the input dictionary in the form data[x][y] + tol : floating point match tolerance when reconstructing the x locations (default 1e-10) + format : if not None reformat x keys with this format statement (default None) + dx,dy : x and y bin widths, usually found automatically (when None) but you may need to set them manually if there is not much data. (default None) + value : the value to fill in gaps in the data (default 0.0) + """ data = ordered(data) - xdata = list(data.keys()) - ydata = list(data.values()) - dx = _mindiff(xdata) - if not dx: + ds,xs,ys,zmin,zmax = list_keys_of_narrow_dict(data, + sortx=True, + sorty=True) + + # automatically get binwidths + if dx is None: + dx = _mindiff(xs) + if dy is None: + dy = _mindiff(ys) + + if format and dx and dy: + dx = float(format.format(dx)) + dy = float(format.format(dy)) + + if not dx or not dy: - print("Warning: mindiff failed in fill() on {} data items. Returning original data.".format(len(xdata))) + print("Warning: mindiff failed in fill() on {} data items. Returning original data.".format(len(xs))) return data - maxx = xdata[-1] - x = xdata[0] - newdata = {} - index = 0 + maxx = xs[-1] + maxy = ys[-1] + x = xs[0] + newdata = AutoVivificationDict() while x<=maxx: - matchi = _find_nearest_index(xdata,x) - matchx = xdata[matchi] + # get the index of the nearest x value + matchx = xs[_find_nearest_index(xs,x)] + + y = ys[0] + while y<=maxy: + # get the index of the nearest y value + matchy = ys[_find_nearest_index(ys,y)] + + if (x==matchx or (x!=0.0 and abs(1.0-matchx/x)<=tol)) and \ + (y==matchy or (y!=0.0 and abs(1.0-matchy/y)<=tol)) and matchy in data[matchx]: + newdata[x][y] = data[matchx][matchy] + + else: + # missing data : fill with value + newdata[x][y] = value + + y += dy + if format: + y = float(format.format(y)) - if (x==0.0 and matchx==0.0) or abs(1.0-matchx/x)<=tol: - newdata[x] = ydata[index] - index += 1 - else: - # missing data - newdata[x] = 0.0 x += dx if format: x = float(format.format(x)) @@ -960,3 +1070,274 @@ def ordered(dict): return collections.OrderedDict( sorted(dict.items(), key = lambda x:(float(x[0])))) + + +def 
dictdepth(d, n=0): + """ + Return the depth of a nested dictionary + """ + if not d or not isinstance(d, dict): + return n + return max(dictdepth(value, n+1) for value in d.values()) + +def narrow_dict(data, + selectx=None, + excludex=None, + selecty=None, + excludey=None, + fill=None): + """ + Given a dict in data, which is either "wide" (with depth 4) + or narrow (with depth 2), convert to narrow. + + "wide" dict is {xkey}->{x}->{ykey}->{y} + "narrow" dict is {x}->{y} + """ + n = dictdepth(data) + if n == 2: + # already have a narrow dict + return data + elif n == 4: + # we have a wide dict : convert + if selectx and not isinstance(selectx,list): + selectx = [selectx] + if selecty and not isinstance(selecty,list): + selecty = [selecty] + if excludex and not isinstance(excludex,list): + excludex = [excludex] + if excludey and not isinstance(excludey,list): + excludey = [excludey] + + newdata = {} + for xkey in data: + if (not excludex or not xkey in excludex) and \ + (not selectx or xkey in selectx): + for x in data[xkey]: + for ykey in data[xkey][x]: + if (not excludey or not ykey in excludey) and (not selecty or ykey in selecty): + # setdefault : several selected keys may share the same x + for y in data[xkey][x][ykey]: + newdata.setdefault(x,{})[y] = data[xkey][x][ykey][y] + + if fill is not None: + newdata = fill_narrow_dict(newdata,fill) + return newdata + else: + # what to do? return None on error + return None + +def narrow_dict_from_keys(data, + xkeys, + ykeys, + fill=None, + pad = False, + dx = None, + dy = None, + format = None, + ): + """ + Function to convert an ensemble dictionary like + + ...->{xkeys}->{xvalue}->{ykeys}->{yvalue} = z + + to a dictionary like + + {xvalue}->{yvalue} = z + + Note that xkeys and ykeys can be lists, in which case + we loop over several keys. 
+ + This function does no other filtering (for that, use narrow_dict() ) + + Args: + data : the base of the ensemble data + xkeys : scalar or list of x keys + ykeys : scalar or list of y keys + fill : if not None, fill the data with whatever fill is set to + pad : if True, calls pad_narrow_dict_grid() to pad the data set + dx : x bin width, sent to pad_narrow_dict_grid() + dy : y bin width, sent to pad_narrow_dict_grid() + format : key formatter (also sent to pad_narrow_dict_grid()) + """ + newdata = AutoVivificationDict() + if not isinstance(xkeys,list): + xkeys = [xkeys] + if not isinstance(ykeys,list): + ykeys = [ykeys] + h = data + + for xk in xkeys: + h = h[xk] + for x in h.keys(): + h2 = h[x] + for yk in ykeys: + h2 = h2[yk] + for y in h2.keys(): + if format: + newdata[float(format.format(x))][float(format.format(y))] += h2[y] # accumulate : keys can collide after formatting + else: + newdata[x][y] += h2[y] + + + if fill is not None: + newdata = fill_narrow_dict(newdata,fill) + if pad: + newdata = pad_narrow_dict_grid(newdata, + dx=dx, + dy=dy, + format=format) + return newdata + + + +def wide_dict(data, + xkey=None, + ykey=None, + selectx=None, + excludex=None, + selecty=None, + excludey=None): + """ + Function to convert data from wide or narrow form into wide form dict. + + If data is narrow, data[x][y] we use xkey and ykey to convert to wide form, + data[xkey][x][ykey][y]. + + If data is wide, we don't need to convert, but we do use the optional + selectx, selecty arrays of keys to select only from these, and use + excludex, excludey to exclude from these. 
+ """ + n = dictdepth(data) + if n == 4: + # we already have a wide dict: select/exclude from it if required + if not excludex and not excludey and not selectx and not selecty: + return data + else: + newdata = AutoVivificationDict() + for xk in data: + if (not excludex or not xk in excludex) and (not selectx or xk in selectx): + for x in data[xk]: + for yk in data[xk][x]: + if (not excludey or not yk in excludey) and \ + (not selecty or yk in selecty): + for y in data[xk][x][yk]: + newdata[xk][x][yk][y] = data[xk][x][yk][y] + return newdata + elif n == 2: + # we have a narrow dict, and want to convert to a wide dict + # using xkey and ykey to nest the dict data + newdata = AutoVivificationDict() + for x in data: + for y in data[x]: + newdata[xkey][x][ykey][y] = data[x][y] + return newdata + else: + # not sure what to do + return None + +def dict_to_dataframe(data, + xykeys = (None,None), + index_name = None, + row_name = None, + dtype = float, + sortx = True, + sorty = True, + invertx = False, + inverty = False, + logspace = False, + weight_label = 'weight', + numpy=False, + zmin=None, + zmax=None, + seaborn_target = None): + """ + Generic function to convert a dict to a pandas dataframe. 
+ + Args: + data : input dictionary in narrow form data[x][y] or wide form data[xtype][x][ytype][y] + (required) + if you want to use data in wide form, you must specify xykeys (see below) + xykeys : tuple (xkey,ykey) of keys to select when converting from a wide dataset to narrow (default (None,None) ) + index_name, row_name : labels of the first and second histplot columns, defaulting to xykeys; index_name also renames the heatmap index (default None) + dtype : data type passed to pandas (default float) + seaborn_target : string containing the type of Seaborn plot for which we wish to convert + one of: "heatmap", "histplot" (default None) + """ + + # convert data to data[x][y] format if possible + + pdata = None # returned + + if xykeys[0]: + selectx = [xykeys[0]] + else: + selectx = None + + if xykeys[1]: + selecty = [xykeys[1]] + else: + selecty = None + + ############################################################ + # data should be in narrow form + data = narrow_dict(data, + selectx=selectx, + selecty=selecty) + + ds,xs,ys,zmin,zmax = list_keys_of_narrow_dict(data, + numpy=numpy, + zmin=zmin, + zmax=zmax, + logspace=logspace) + + if seaborn_target == "histplot": + ############################################################ + # Make data for a histplot + ############################################################ + + columns = [index_name,row_name,weight_label] + if columns[0] is None: + columns[0] = xykeys[0] + if columns[1] is None: + columns[1] = xykeys[1] + + pdata = pd.DataFrame(ds, + columns=columns) + extra = ds,xs,ys,zmin,zmax + + elif seaborn_target == "heatmap": + ############################################################ + # Make data for a heatmap + ############################################################ + + pdata = pd.DataFrame.from_dict( + data=data, + dtype=dtype + ) + + # rename index + if index_name: + pdata.index.rename(index_name,inplace=True) + + extra = ds,xs,ys,zmin,zmax + else: + # no (or unknown) seaborn_target : there is no dataframe to sort or invert + return None,[] + + # sort x axis? + if sortx: + pdata = pdata.reindex(natsorted(pdata.columns),axis=1) + + # sort y axis? 
+ if sorty: + pdata = pdata.reindex(index=natsorted(pdata.index)) + + # invert y axis? + if inverty: + pdata = pdata.reindex(index=pdata.index[::-1]) + + # invert x axis? + if invertx: + pdata = pdata[pdata.columns[::-1]] + + return pdata,extra diff --git a/binarycpython/utils/ensemble.py b/binarycpython/utils/ensemble.py index d7f4fbcadb1f412c0f1a3bb8373865e69a8b83a7..0d018673a2bd24691a36fb04d796ff473298210f 100644 --- a/binarycpython/utils/ensemble.py +++ b/binarycpython/utils/ensemble.py @@ -6,7 +6,8 @@ population ensemble using the binarycpython package from binarycpython.utils.dicts import ( AutoVivificationDict, custom_sort_dict, - fill_2d_dict, + dict_to_dataframe, + fill_narrow_dict, keys_to_floats, merge_dicts, recursive_change_key_to_float, @@ -20,9 +21,11 @@ import gzip from halo import Halo import inspect import json +from matplotlib.colors import LogNorm,Normalize import msgpack # import orjson # not required any more? import py_rinterpolate +import seaborn as sns import simplejson import sys import time @@ -370,46 +373,10 @@ def format_ensemble_results(ensemble_dictionary): # Put back in the dictionary return reformatted_ensemble_results -def ensemble_dist_to_2D_dict(ensemble,xkeys,ykeys,fill=None,format=None): - """ - Function to convert an ensemble dictionary like - - ...->{xkeys}->{xvalue}->{ykeys}->{yvalue} = z - - to a dictionary like - - {xvalue}->{yvalue} = z - - Note that xkeys and ykeys can be lists, in which case - we loop over several keys - - Args: - ensemble : the base of the ensemble data - xkeys : scalar or list of x keys - ykeys : scalar or list of y keys - fill : if not None, fill out the data with whatever fill is set to - """ - data = AutoVivificationDict() - if not isinstance(xkeys,list): - xkeys = [xkeys] - if not isinstance(ykeys,list): - ykeys = [ykeys] - h = ensemble - for xk in xkeys: - h = h[xk] - for x in h.keys(): - h2 = h[x] - for yk in ykeys: - h2 = h2[yk] - for y in h2.keys(): - data[x][y] += h2[y] - if fill is not None: 
- fill_2d_dict(data,fill) - - return data - -def ensemble_flatten(ensemble): +def ensemble_flatten(ensemble, + include_list=(), + exclude_list=()): """ Given a piece of an ensemble which is like this @@ -419,11 +386,139 @@ def ensemble_flatten(ensemble): flatten the key list [key1, key2, ...] by merging all data that lies deeper. + Args: + ensemble : the ensemble data to flatten + include_list : the list of keys to include, ignored if empty + exclude_list : the list of keys to exclude, ignored if empty + Returns: {more data1 + more data2 + ...} """ data = {} for k in ensemble: - data = merge_dicts(data,ensemble[k]) + if k not in exclude_list and \ + (len(include_list)==0 or k in include_list): + data = merge_dicts(data,ensemble[k]) return data + + +def seaborn_histplot(data, + x_type=None, + y_type=None, + invertx=False, + inverty=False, + binwidth=None, + cbar = True, + logspace=False, + logscale=(False,False), + cmap = 'cubehelix_r', + cbar_label = None, + shading=None, + xlabel=None, + ylabel=None, + ): + """ + Function to return a Seaborn histplot of the given data + + Data should be in narrow form dict, i.e. data[x][y]. 
+ + Returns the axis object, ax, that is returned by Seaborn's histplot() + + """ + p,[ds,xs,ys,zmin,zmax] = dict_to_dataframe(data, + xykeys = (x_type,y_type), + numpy = True, + logspace = True, + weight_label = 'weight', + seaborn_target='histplot') + + if logspace: + norm = LogNorm(vmin=zmin,vmax=zmax) + else: + norm = None + + ax = sns.histplot( + data=p, + x = x_type, + y = y_type, + weights='weight', + bins=[len(xs),len(ys)], + binwidth = binwidth, + cbar = cbar, + log_scale=logscale, + # these are passed to matplotlib.Axes.pcolormesh + cmap = cmap, + norm = norm, + vmin=None, + vmax=None, + shading=shading, + cbar_kws={'label':cbar_label}, + ) + + if inverty: + ax.invert_yaxis() + if invertx: + ax.invert_xaxis() + if xlabel: + ax.set_xlabel(xlabel) + if ylabel: + ax.set_ylabel(ylabel) + + return ax + +def seaborn_heatmap(data, + x_type=None, + y_type=None, + invertx=False, + inverty=False, + binwidth=None, + cbar = True, + logspace=False, + logscale=(False,False), + cmap = 'cubehelix_r', + cbar_label = None, + xlabel=None, + ylabel=None, + ): + """ + Function to return a Seaborn heatmap of the given data + + Data should be in narrow form dict, i.e. data[x][y]. + + Returns the axis object, ax, that is returned by Seaborn's heatmap() + + """ + p,[ds,xs,ys,zmin,zmax] = dict_to_dataframe(data, + xykeys = (x_type,y_type), + numpy = True, + logspace = True, + row_name = x_type, + index_name = y_type, + sorty = True, + inverty = False, + sortx = True, + invertx = False, + seaborn_target='heatmap') + + if logspace: + norm = LogNorm(vmin=zmin,vmax=zmax) + else: + norm = None + ax = sns.heatmap(data=p, + cmap = cmap, + linewidths=0, + norm=norm, + cbar_kws={'label':cbar_label}, + annot=False) + + if inverty: + ax.invert_yaxis() + if invertx: + ax.invert_xaxis() + if xlabel: + ax.set_xlabel(xlabel) + if ylabel: + ax.set_ylabel(ylabel) + + return ax