diff --git a/binarycpython/utils/dicts.py b/binarycpython/utils/dicts.py
index c474b3dce0e8d94343bbd4a3841460b4724503d8..22b59c88537a8922d73cb949d7ccdcab93eda23a 100644
--- a/binarycpython/utils/dicts.py
+++ b/binarycpython/utils/dicts.py
@@ -3,13 +3,15 @@ Module containing functions that binary_c-python uses to modify
 dictionaries.
 """
 import astropy.units as u
-import copy
-import collections
-import math
-import numpy as np
 from collections import (
     OrderedDict,
 )
+import collections
+import copy
+import math
+from natsort import index_natsorted,natsorted
+import numpy as np
+import pandas as pd
 import sys
 
 # we need to convert keys to floats:
@@ -829,7 +831,7 @@ def normalize_dict(result_dict, verbosity=0):
         result_dict[key] = result_dict[key] / sum_result
     return result_dict
 
-def fill_2d_dict(data,fill):
+def fill_narrow_dict(data,fill):
     """
     Fill a 2D dict with whatever fill is set to when keys are missing.
     """
@@ -844,8 +846,9 @@ def fill_2d_dict(data,fill):
             for y in ys:
                 if not y in data[x]:
                     data[x][y] = fill
+    return data
 
-def format_2d_dict_keys(data,format):
+def format_narrow_dict_keys(data,format):
     """
     Reformat float keys in a 2d dict according to the given format
     """
@@ -866,10 +869,12 @@ def format_2d_dict_keys(data,format):
                 data[x][y] = z
     return data
 
-def keys_of_2d_dict(data,numpy=False):
+def tlist_keys_of_narrow_dict(data,
+                              numpy=False):
     """
     Given a 2D dict data[x][y] = z, return the list of keys as tuples, (x,y),
-    and equivalent z list.
+    and equivalent z list. This is useful for providing to plotting and numpy
+    routines.
 
     Args:
         data: the input 2D dict
@@ -886,6 +891,73 @@ def keys_of_2d_dict(data,numpy=False):
         zs = np.array(zs)
     return xys,zs
 
+def list_keys_of_narrow_dict(data,
+                             logspace=False,
+                             numpy=False,
+                             zmin=None,
+                             zmax=None,
+                             sortx=False,
+                             sorty=False):
+    """
+    Given a 2D dict data[x][y] = z, return a list containing vectors
+    [x,y,z] for all the points in the dict, lists of the x and y values,
+    and the min and max of the z values.
+
+    These data are very useful for plotting and numpy routines.
+
+    Args:
+        data: the input 2D dict (required)
+        numpy: if True convert returned data to numpy arrays (default False)
+        logspace : if True only use z>0 to calculate zmin and zmax (default False)
+        zmin : start minimum z checks at this value (default None)
+        zmax : start maximum z checks at this value (default None)
+        sortx : sorts x key list
+        sorty : sorts y key list
+    Returns:
+
+    ds,xs,ys,zmin,zmax
+
+    ds : list of key vectors (x,y)
+    xs : list of x values
+    ys : list of y values
+    zmin : minimum z
+    zmax : maximum z
+
+    """
+    model_space,model_weights = tlist_keys_of_narrow_dict(data,numpy)
+
+    ds = []
+    xs = {}
+    ys = {}
+    for x,z in zip(model_space,model_weights):
+        ds.append([x[0], x[1], z])
+        xs[x[0]] = 1
+        ys[x[1]] = 1
+        if not logspace or z > 0.0:
+            if zmin:
+                zmin = min(z,zmin)
+            else:
+                zmin = z
+            if zmax:
+                zmax = max(z,zmax)
+            else:
+                zmax = z
+
+    xs = list(xs.keys())
+    ys = list(ys.keys())
+
+    if sortx:
+        xs = sorted(xs)
+    if sorty:
+        ys = sorted(ys)
+
+    if numpy is True:
+        ds = np.array(ds)
+        xs = np.array(xs)
+        ys = np.array(ys)
+
+    return ds,xs,ys,zmin,zmax
+
 def _mindiff(list,tol=1e-10):
     """
     Given a list of unique values, within tol(=1e-7), find the
@@ -919,34 +991,72 @@ def _find_nearest_index(array,value):
     return idx
 
 
-def fillgrid(data,tol=1e-10,format=None):
+def pad_narrow_dict_grid(data,
+                         tol=1e-10,
+                         format=None,
+                         value=0.0,
+                         dx=None,
+                         dy=None):
     """
-    Given a data[x][y] dict, fill it to the appropriate bin widths.
+    Given data on a grid in a narrow dict, data[x][y], fill it with value where
+    data is missing. We fill data according to the binwidths dx,dy which are
+    automatically calculated if set to None.
+
+    Args:
+
+    data : the input dictionary in the form data[x][y]
+    tol : floating point match tolerance when reconstructing the x locations (default 1e-10)
+    format : if not None reformat x keys with this format statement (default None)
+    dx : x bin width, usually found automatically (when dx=None) but you may need to set it manually if there is not much data. (default None)
+    value : the value to fill in gaps in the data (default 0.0)
+
     """
 
     data = ordered(data)
-    xdata = list(data.keys())
-    ydata = list(data.values())
-    dx = _mindiff(xdata)
 
-    if not dx:
+    ds,xs,ys,zmin,zmax = list_keys_of_narrow_dict(data,
+                                                  sortx=True,
+                                                  sorty=True)
+
+    # automatically get binwidths
+    if dx == None:
+        dx = _mindiff(xs)
+    if dy == None:
+        dy = _mindiff(ys)
+
+    if format:
+        dx = float(format.format(dx))
+        dy = float(format.format(dy))
+
+    if not dx or not dy:
         print("Warning: mindiff failed in fill() on {} data items. Returning original data.".format(len(xdata)))
         return data
 
-    maxx = xdata[-1]
-    x = xdata[0]
-    newdata = {}
-    index = 0
+    maxx = xs[-1]
+    maxy = ys[-1]
+    x = xs[0]
+    newdata = AutoVivificationDict()
     while x<=maxx:
-        matchi = _find_nearest_index(xdata,x)
-        matchx = xdata[matchi]
+        # get the index of the nearest x value
+        matchx = xs[_find_nearest_index(xs,x)]
+
+        y = ys[0]
+        while y<=maxy:
+            # get the index of the nearest y value
+            matchy = ys[_find_nearest_index(ys,y)]
+
+            if (x==0.0 and matchx==0.0) or abs(1.0-matchx/x)<=tol and \
+               (y==0.0 and matchy==0.0) or abs(1.0-matchy/y)<=tol:
+                newdata[x][y] = data[matchx][matchy]
+
+            else:
+                # missing data : fill with value
+                newdata[x][y] = value
+
+            y += dy
+            if format:
+                y = float(format.format(y))
 
-        if (x==0.0 and matchx==0.0) or abs(1.0-matchx/x)<=tol:
-            newdata[x] = ydata[index]
-            index += 1
-        else:
-            # missing data
-            newdata[x] = 0.0
         x += dx
         if format:
             x = float(format.format(x))
@@ -960,3 +1070,274 @@ def ordered(dict):
     return collections.OrderedDict(
         sorted(dict.items(),
                key = lambda x:(float(x[0]))))
+
+
+def dictdepth(d, n=0):
+    """
+    Return the depth of a nested dictionary
+    """
+    if not d or not isinstance(d, dict):
+        return n
+    return max(dictdepth(value, n+1) for value in d.values())
+
+def narrow_dict(data,
+                selectx=None,
+                excludex=None,
+                selecty=None,
+                excludey=None,
+                fill=None):
+    """
+    Given a dict in data, which is either "wide" (with depth 4)
+    or narrow (with depth 2), convert to narrow.
+
+    "wide" dict is {xkey}->{x}->{ykey}->{y}
+    "narrow" dict is {x}->{y}
+    """
+    n = dictdepth(data)
+    if n == 2:
+        # already have a narrow dict
+        return data
+    elif n == 4:
+        # we have a wide dict : convert
+        if selectx and not isinstance(selectx,list):
+            selectx = [selectx]
+        if selecty and not isinstance(selecty,list):
+            selecty = [selecty]
+        if excludex and not isinstance(excludex,list):
+            excludex = [excludex]
+        if excludey and not isinstance(excludey,list):
+            excludey = [excludey]
+
+        newdata = {}
+        for xkey in data:
+            if (not excludex or not xkey in excludex) and \
+               (not selectx or xkey in selectx):
+                newdata[x] = {}
+                for ykey in data[xkey][x]:
+
+                    if (not excludey or not ykey in excludey) and \
+                       (not selecty or ykey in selecty):
+                        newdata[x][y] = data[xkey][x][ykey][y]
+
+        if fill is not None:
+            newdata = fill_narrow_dict(newdata,fill)
+        return newdata
+    else:
+        # what to do? return None on error
+        return None
+
+def narrow_dict_from_keys(data,
+                          xkeys,
+                          ykeys,
+                          fill=None,
+                          pad = False,
+                          dx = None,
+                          dy = None,
+                          format = None,
+                          ):
+    """
+    Function to convert an ensemble dictionary like
+
+    ...->{xkeys}->{xvalue}->{ykeys}->{yvalue} = z
+
+    to a dictionary like
+
+    {xvalue}->{yvalue} = z
+
+    Note that xkeys and ykeys can be lists, in which case
+    we loop over several keys.
+
+    This function does no other filtering (for that, use narrow_dict() )
+
+    Args:
+        ensemble : the base of the ensemble data
+        xkeys : scalar or list of x keys
+        ykeys : scalar or list of y keys
+        fill : if not None, fill the data with whatever fill is set to
+        pad : if True, calls pad_narrow_dict_grid() to pad the data set
+        dx : x bin width, sent to pad_narrow_dict_grid()
+        dy : y bin width, sent to pad_narrow_dict_grid()
+        format : key formatter (also sent to pad_narrow_dict_grid())
+    """
+    newdata = AutoVivificationDict()
+    if not isinstance(xkeys,list):
+        xkeys = [xkeys]
+    if not isinstance(ykeys,list):
+        ykeys = [ykeys]
+    h = data
+
+    for xk in xkeys:
+        h = h[xk]
+    for x in h.keys():
+        h2 = h[x]
+        for yk in ykeys:
+            h2 = h2[yk]
+        for y in h2.keys():
+            if format:
+                newdata[float(format.format(x))][float(format.format(y))] = h2[y]
+            else:
+                newdata[x][y] += h2[y]
+
+
+    if fill is not None:
+        newdata = fill_narrow_dict(newdata,fill)
+    if pad:
+        newdata = pad_narrow_dict_grid(newdata,
+                                       dx=dx,
+                                       dy=dy,
+                                       format=format)
+    return newdata
+
+
+
+def wide_dict(data,
+              xkey=None,
+              ykey=None,
+              selectx=None,
+              excludex=None,
+              selecty=None,
+              excludey=None):
+    """
+    Function to convert data from wide or narrow form into wide form dict.
+
+    If data is narrow, data[x][y] we use xkey and ykey to convert to wide form,
+    data[xkey][x][ykey][y].
+
+    If data is wide, we don't need to convert, but we do use the optional
+    selectx, selecty arrays of keys to select only from these, and use
+    excludex, excludey to exclude from these.
+    """
+    n = dictdepth(data)
+    if n == 4:
+        # we already have a wide dict: select/exclude from it if requried
+        if not excludex and not excludey and not selectx and not selecty:
+            return data
+        else:
+            newdata = {}
+            for xkey in data:
+                if (not excludex or not xkey in excludex) and \
+                   (not selectx or xkey in selectx):
+                    newdata[x] = {}
+                    for ykey in data[xkey][x]:
+                        if (not excludey or not ykey in excludey) and \
+                           (not selecty or ykey in selecty):
+                            newdata[xkey][x][ykey][y] = data[xkey][x][ykey][y]
+            return newdata
+    elif n == 2:
+        # we have a narrow dict, and want to convert to a wide dict
+        # using xkey and ykey to nest the dict data
+        newdata = AutoVivificationDict()
+        for x in data:
+            for y in data[x]:
+                newdata[xkey][x][ykey][y] = data[x][y]
+        return newdata
+    else:
+        # not sure what to do
+        return None
+
+def dict_to_dataframe(data,
+                      xykeys = (None,None),
+                      index_name = None,
+                      row_name = None,
+                      dtype = float,
+                      sortx = True,
+                      sorty = True,
+                      invertx = False,
+                      inverty = False,
+                      logspace = False,
+                      weight_label = 'weight',
+                      numpy=False,
+                      zmin=None,
+                      zmax=None,
+                      seaborn_target = None):
+    """
+    Generic function to convert a dict to a pandas dataframe.
+
+    Args:
+        data : input dictionary in narrow form data[x][y] or wide form data[xtype][x][ytype][y]
+               (required)
+               if you want to use data in wide form, you must specify xykeys (see below)
+        xykeys : tuple (xkey,ykey) of keys to select when converting from a wide dataset to narrow  (default (None,None) )
+        ytype : y key (default None)
+        dtype : data type passed to pandas (default float)
+        seaborn_target = string containing the type of Seaborn plot for which we wish to convert
+                         one of: "heatmap", "histplot" (default None)
+    """
+
+    # convert data to data[x][y] format if possible
+
+    pdata = None # returned
+
+    if xykeys[0]:
+        selectx = [xykeys[0]]
+    else:
+        selectx = None
+
+    if xykeys[1]:
+        selecty = [xykeys[1]]
+    else:
+        selecty = None
+
+    ############################################################
+    # data should be in narrow form
+    data = narrow_dict(data,
+                       selectx=selectx,
+                       selecty=selecty)
+
+    ds,xs,ys,zmin,zmax = list_keys_of_narrow_dict(data,
+                                                  numpy=numpy,
+                                                  zmin=zmin,
+                                                  zmax=zmax,
+                                                  logspace=logspace)
+
+    if seaborn_target == "histplot":
+        ############################################################
+        # Make data for a histplot
+        ############################################################
+
+        columns = [index_name,row_name,weight_label]
+        if columns[0] is None:
+            columns[0] = xykeys[0]
+        if columns[1] is None:
+            columns[1] = xykeys[1]
+
+        pdata = pd.DataFrame(ds,
+                             columns=columns)
+        extra = ds,xs,ys,zmin,zmax
+
+    elif seaborn_target == "heatmap":
+        ############################################################
+        # Make data for a heatmap
+        ############################################################
+
+        pdata = pd.DataFrame.from_dict(
+            data=data,
+            dtype=dtype
+        )
+
+        # rename index
+        if index_name:
+            pdata.index.rename(index_name,inplace=True)
+
+        extra = ds,xs,ys,zmin,zmax
+    else:
+        pdata = None
+        extra = []
+
+    # sort x axis?
+    if sortx:
+        pdata = pdata.reindex(natsorted(pdata.columns),axis=1)
+
+    # sort y axis?
+    if sorty:
+        pdata = pdata.reindex(index=natsorted(pdata.index))
+
+    # invert y axis?
+    if inverty:
+        pdata = pdata.reindex(index=pdata.index[::-1])
+
+    # invert x axis?
+    if invertx:
+        pdata = pdata[pdata.columns[::-1]]
+
+    return pdata,extra
diff --git a/binarycpython/utils/ensemble.py b/binarycpython/utils/ensemble.py
index d7f4fbcadb1f412c0f1a3bb8373865e69a8b83a7..0d018673a2bd24691a36fb04d796ff473298210f 100644
--- a/binarycpython/utils/ensemble.py
+++ b/binarycpython/utils/ensemble.py
@@ -6,7 +6,8 @@ population ensemble using the binarycpython package
 from binarycpython.utils.dicts import (
     AutoVivificationDict,
     custom_sort_dict,
-    fill_2d_dict,
+    dict_to_dataframe,
+    fill_narrow_dict,
     keys_to_floats,
     merge_dicts,
     recursive_change_key_to_float,
@@ -20,9 +21,11 @@ import gzip
 from halo import Halo
 import inspect
 import json
+from matplotlib.colors import LogNorm,Normalize
 import msgpack
 # import orjson # not required any more?
 import py_rinterpolate
+import seaborn as sns
 import simplejson
 import sys
 import time
@@ -370,46 +373,10 @@ def format_ensemble_results(ensemble_dictionary):
     # Put back in the dictionary
     return reformatted_ensemble_results
 
-def ensemble_dist_to_2D_dict(ensemble,xkeys,ykeys,fill=None,format=None):
-    """
-    Function to convert an ensemble dictionary like
-
-    ...->{xkeys}->{xvalue}->{ykeys}->{yvalue} = z
-
-    to a dictionary like
-
-    {xvalue}->{yvalue} = z
-
-    Note that xkeys and ykeys can be lists, in which case
-    we loop over several keys
-
-    Args:
-        ensemble : the base of the ensemble data
-        xkeys : scalar or list of x keys
-        ykeys : scalar or list of y keys
-        fill : if not None, fill out the data with whatever fill is set to
 
-    """
-    data = AutoVivificationDict()
-    if not isinstance(xkeys,list):
-        xkeys = [xkeys]
-    if not isinstance(ykeys,list):
-        ykeys = [ykeys]
-    h = ensemble
-    for xk in xkeys:
-        h = h[xk]
-    for x in h.keys():
-        h2 = h[x]
-        for yk in ykeys:
-            h2 = h2[yk]
-        for y in h2.keys():
-            data[x][y] += h2[y]
-    if fill is not None:
-        fill_2d_dict(data,fill)
-
-    return data
-
-def ensemble_flatten(ensemble):
+def ensemble_flatten(ensemble,
+                     include_list=[],
+                     exclude_list=[]):
     """
     Given a piece of an ensemble which is like this
 
@@ -419,11 +386,139 @@ def ensemble_flatten(ensemble):
 
     flatten the key list [key1, key2, ...] by merging all data that lies deeper.
 
+    Args:
+        ensemble : the ensemble data to flatten
+        include_list : the list of keys to include, ignored if empty
+        exclude_list : the list of keys to exclude, ignored if empty
+
     Returns:
 
     {more data1 + more data2 + ...}
     """
     data = {}
     for k in ensemble:
-        data = merge_dicts(data,ensemble[k])
+        if k not in exclude_list and \
+           (len(include_list)==0 or k in include_list):
+            data = merge_dicts(data,ensemble[k])
     return data
+
+
+def seaborn_histplot(data,
+                     x_type=None,
+                     y_type=None,
+                     invertx=False,
+                     inverty=False,
+                     binwidth=None,
+                     cbar = True,
+                     logspace=False,
+                     logscale=(False,False),
+                     cmap = 'cubehelix_r',
+                     cbar_label = None,
+                     shading=None,
+                     xlabel=None,
+                     ylabel=None,
+                     ):
+    """
+    Function to return a Seaborn histplot of the given data
+
+    Data should be in narrow form dict, i.e. data[x][y].
+
+    Returns the axis object, ax, that is returned by Seaborn's histplot()
+
+    """
+    p,[ds,xs,ys,zmin,zmax] = dict_to_dataframe(data,
+                                               xykeys = (x_type,y_type),
+                                               numpy = True,
+                                               logspace = True,
+                                               weight_label = 'weight',
+                                               seaborn_target='histplot')
+
+    if logspace:
+        norm = LogNorm(vmin=zmin,vmax=zmax)
+    else:
+        norm = None
+
+    ax = sns.histplot(
+        data=p,
+        x = x_type,
+        y = y_type,
+        weights='weight',
+        bins=[len(xs),len(ys)],
+        binwidth = binwidth,
+        cbar = cbar,
+        log_scale=logscale,
+        # these are passed to matplotlib.Axes.pcolormesh
+        cmap = cmap,
+        norm = norm,
+        vmin=None,
+        vmax=None,
+        shading=shading,
+        cbar_kws={'label':cbar_label},
+    )
+
+    if inverty:
+        ax.invert_yaxis()
+    if invertx:
+        ax.invert_xaxis()
+    if xlabel:
+        ax.set_xlabel(xlabel)
+    if ylabel:
+        ax.set_ylabel(ylabel)
+
+    return ax
+
+def seaborn_heatmap(data,
+                    x_type=None,
+                    y_type=None,
+                    invertx=False,
+                    inverty=False,
+                    binwidth=None,
+                    cbar = True,
+                    logspace=False,
+                    logscale=(False,False),
+                    cmap = 'cubehelix_r',
+                    cbar_label = None,
+                    xlabel=None,
+                    ylabel=None,
+                    ):
+    """
+    Function to return a Seaborn heatmap of the given data
+
+    Data should be in narrow form dict, i.e. data[x][y].
+
+    Returns the axis object, ax, that is returned by Seaborn's heatmap()
+
+    """
+    p,[ds,xs,ys,zmin,zmax] = dict_to_dataframe(data,
+                                               xykeys = (x_type,y_type),
+                                               numpy = True,
+                                               logspace = True,
+                                               row_name = x_type,
+                                               index_name = y_type,
+                                               sorty = True,
+                                               inverty = False,
+                                               sortx = True,
+                                               invertx = False,
+                                               seaborn_target='heatmap')
+
+    if logspace:
+        norm = LogNorm(vmin=zmin,vmax=zmax)
+    else:
+        norm = None
+    ax = sns.heatmap(data=p,
+                     cmap = cmap,
+                     linewidths=0,
+                     norm=norm,
+                     cbar_kws={'label':cbar_label},
+                     annot=False)
+
+    if inverty:
+        ax.invert_yaxis()
+    if invertx:
+        ax.invert_xaxis()
+    if xlabel:
+        ax.set_xlabel(xlabel)
+    if ylabel:
+        ax.set_ylabel(ylabel)
+
+    return ax