23 """A set of matplotlib-based classes that displays a grid of 1-d and 2-d slices through an 
   26 The main class, DensityPlot, manages the grid of matplotlib.axes.Axes objects, and holds 
   27 a sequence of Layer objects that each know how to draw individual 1-d or 2-d plots and a 
   28 data object that abstracts away how the N-d density data is actually represented. 
   30 For simple cases, users can just create a custom data class with an interface like that of 
   31 the ExampleData class provided here, and use the provided HistogramLayer and SurfaceLayer 
   32 classes directly.  In more complicated cases, users may want to create their own Layer classes, 
   33 which may define their own relationship with the data object. 
   36 import collections.abc
 
   39 import matplotlib.pyplot
 
   40 import matplotlib.ticker
 
   42 __all__ = (
"HistogramLayer", 
"SurfaceLayer", 
"ScatterLayer", 
"CrossPointsLayer",
 
   43            "DensityPlot", 
"ExampleData", 
"demo")
 
   47     for label 
in axes.get_xticklabels():
 
   48         label.set_visible(
False)
 
   52     for label 
in axes.get_yticklabels():
 
   53         label.set_visible(
False)
 
   57     copy = defaults.copy()
 
   64     """A Layer class for DensityPlot for gridded histograms, drawing bar plots in 1-d and 
   65     colormapped large-pixel images in 2-d. 
   67     Relies on two data object attributes: 
   69        values ----- a (M,N) array of data points, where N is the dimension of the dataset and M is the 
   72        weights ---- (optional) an array of weights with shape (M,); if not present, all weights will 
   75     The need for these data object attributes can be removed by subclassing HistogramLayer and overriding 
   76     the hist1d and hist2d methods. 
   79     defaults1d = dict(facecolor=
'b', alpha=0.5)
 
   80     defaults2d = dict(cmap=matplotlib.cm.Blues, vmin=0.0, interpolation=
'nearest')
 
   82     def __init__(self, tag, bins1d=20, bins2d=(20, 20), kwds1d=
None, kwds2d=
None):
 
   90         """Extract points from the data object and compute a 1-d histogram. 
   92         Return value should match that of numpy.histogram: a tuple of (hist, edges), 
   93         where hist is a 1-d array with size=bins1d, and edges is a 1-d array with 
   94         size=self.bins1d+1 giving the upper and lower edges of the bins. 
   96         i = data.dimensions.index(dim)
 
   97         if hasattr(data, 
"weights") 
and data.weights 
is not None:
 
   98             weights = data.weights
 
  101         return numpy.histogram(data.values[:, i], bins=self.
bins1d, weights=weights,
 
  102                                range=limits, normed=
True)
 
  104     def hist2d(self, data, xDim, yDim, xLimits, yLimits):
 
  105         """Extract points from the data object and compute a 1-d histogram. 
  107         Return value should match that of numpy.histogram2d: a tuple of (hist, xEdges, yEdges), 
  108         where hist is a 2-d array with shape=bins2d, xEdges is a 1-d array with size=bins2d[0]+1, 
  109         and yEdges is a 1-d array with size=bins2d[1]+1. 
  111         i = data.dimensions.index(yDim)
 
  112         j = data.dimensions.index(xDim)
 
  113         if hasattr(data, 
"weights") 
and data.weights 
is not None:
 
  114             weights = data.weights
 
  117         return numpy.histogram2d(data.values[:, j], data.values[:, i], bins=self.
bins2d, weights=weights,
 
  118                                  range=(xLimits, yLimits), normed=
True)
 
  121         y, xEdge = self.
hist1d(data, dim, axes.get_xlim())
 
  122         xCenter = 0.5*(xEdge[:-1] + xEdge[1:])
 
  123         width = xEdge[1:] - xEdge[:-1]
 
  124         return axes.bar(xCenter, y, width=width, align=
'center', **self.
kwds1d)
 
  127         x, yEdge = self.
hist1d(data, dim, axes.get_ylim())
 
  128         yCenter = 0.5*(yEdge[:-1] + yEdge[1:])
 
  129         height = yEdge[1:] - yEdge[:-1]
 
  130         return axes.barh(yCenter, x, height=height, align=
'center', **self.
kwds1d)
 
  132     def plotXY(self, axes, data, xDim, yDim):
 
  133         z, xEdge, yEdge = self.
hist2d(data, xDim, yDim, axes.get_xlim(), axes.get_ylim())
 
  134         return axes.imshow(z.transpose(), aspect=
'auto', extent=(xEdge[0], xEdge[-1], yEdge[0], yEdge[-1]),
 
  135                            origin=
'lower', **self.
kwds2d)
 
  139     """A Layer class that plots individual points in 2-d, and does nothing in 1-d. 
  141     Relies on two data object attributes: 
  143        values ----- a (M,N) array of data points, where N is the dimension of the dataset and M is the 
  144                     number of data points 
  146        weights ---- (optional) an array of weights with shape (M,); will be used to set the color of points 
  150     defaults = dict(linewidth=0, alpha=0.2)
 
  162     def plotXY(self, axes, data, xDim, yDim):
 
  163         i = data.dimensions.index(yDim)
 
  164         j = data.dimensions.index(xDim)
 
  165         if hasattr(data, 
"weights") 
and data.weights 
is not None:
 
  166             args = data.values[:, j], data.values[:, i], data.weights
 
  168             args = data.values[:, j], data.values[:, i]
 
  169         return axes.scatter(*args, **self.kwds)
 
  173     """A Layer class for analytic N-d distributions that can be evaluated in 1-d or 2-d slices. 
  175     The 2-d slices are drawn as contours, and the 1-d slices are drawn as simple curves. 
  177     Relies on eval1d and eval2d methods in the data object; this can be avoided by subclassing 
  178     SurfaceLayer and reimplementing its own eval1d and eval2d methods. 
  181     defaults1d = dict(linewidth=2, color=
'r')
 
  182     defaults2d = dict(linewidths=2, cmap=matplotlib.cm.Reds)
 
  184     def __init__(self, tag, steps1d=200, steps2d=200, filled=False, kwds1d=None, kwds2d=None):
 
  193         """Return analytic function values for the given values.""" 
  194         return data.eval1d(dim, x)
 
  196     def eval2d(self, data, xDim, yDim, x, y):
 
  197         """Return analytic function values for the given values.""" 
  198         return data.eval2d(xDim, yDim, x, y)
 
  201         xMin, xMax = axes.get_xlim()
 
  202         x = numpy.linspace(xMin, xMax, self.
steps1d)
 
  203         z = self.
eval1d(data, dim, x)
 
  206         return axes.plot(x, z, **self.
kwds1d)
 
  209         yMin, yMax = axes.get_ylim()
 
  210         y = numpy.linspace(yMin, yMax, self.
steps1d)
 
  211         z = self.
eval1d(data, dim, y)
 
  214         return axes.plot(z, y, **self.
kwds1d)
 
  216     def plotXY(self, axes, data, xDim, yDim):
 
  217         xMin, xMax = axes.get_xlim()
 
  218         yMin, yMax = axes.get_ylim()
 
  219         xc = numpy.linspace(xMin, xMax, self.
steps2d)
 
  220         yc = numpy.linspace(yMin, yMax, self.
steps2d)
 
  221         xg, yg = numpy.meshgrid(xc, yc)
 
  222         z = self.
eval2d(data, xDim, yDim, xg, yg)
 
  226             return axes.contourf(xg, yg, z, 6, **self.
kwds2d)
 
  228             return axes.contour(xg, yg, z, 6, **self.
kwds2d)
 
  232     """A layer that marks a few points with axis-length vertical and horizontal lines. 
  234     This relies on a "points" data object attribute. 
  237     defaults = dict(alpha=0.8)
 
  239     def __init__(self, tag, colors=(
"y", 
"m", 
"c", 
"r", 
"g", 
"b"), **kwds):
 
  245         i = data.dimensions.index(dim)
 
  247         for n, point 
in enumerate(data.points):
 
  248             artists.append(axes.axvline(point[i], color=self.
colors[n % len(self.
colors)], **self.
kwds))
 
  252         i = data.dimensions.index(dim)
 
  254         for n, point 
in enumerate(data.points):
 
  255             artists.append(axes.axhline(point[i], color=self.
colors[n % len(self.
colors)], **self.
kwds))
 
  258     def plotXY(self, axes, data, xDim, yDim):
 
  259         i = data.dimensions.index(yDim)
 
  260         j = data.dimensions.index(xDim)
 
  262         for n, point 
in enumerate(data.points):
 
  263             artists.append(axes.axvline(point[j], color=self.
colors[n % len(self.
colors)], **self.
kwds))
 
  264             artists.append(axes.axhline(point[i], color=self.
colors[n % len(self.
colors)], **self.
kwds))
 
  269     """An object that manages a matrix of matplotlib.axes.Axes objects that represent a set of 1-d and 2-d 
  270     slices through an N-d density. 
  280             layer = self.
_dict.pop(name)
 
  281             self.
_parent._dropLayer(name, layer)
 
  285             self.
_dict[name] = layer
 
  286             self.
_parent._plotLayer(name, layer)
 
  289             return self.
_dict[name]
 
  295             return len(self.
_dict)
 
  298             return str(self.
_dict)
 
  301             return repr(self.
_dict)
 
  304             layer = self.
_dict[name]
 
  305             self.
_parent._dropLayer(name, layer)
 
  306             self.
_parent._plotLayer(name, layer)
 
  316         for v 
in self.
data.values():
 
  317             for dim 
in v.dimensions:
 
  318                 if dim 
not in active:
 
  320                     self.
_lower[dim] = v.lower[dim]
 
  321                     self.
_upper[dim] = v.upper[dim]
 
  327         self.
figure.subplots_adjust(left=0.05, right=0.95, bottom=0.05, top=0.95, hspace=0.01, wspace=0.01)
 
  331     def _dropLayer(self, name, layer):
 
  332         def removeArtist(*key):
 
  334                 self.
_objs.pop(key).remove()
 
  335             except AttributeError:
 
  342         for i, yDim 
in enumerate(self.
_active):
 
  343             removeArtist(
None, i, name)
 
  344             removeArtist(i, 
None, name)
 
  345             for j, xDim 
in enumerate(self.
_active):
 
  348                 removeArtist(i, j, name)
 
  350     def _plotLayer(self, name, layer):
 
  351         for i, yDim 
in enumerate(self.
_active):
 
  352             if yDim 
not in self.
data[layer.tag].dimensions:
 
  354             self.
_objs[
None, i, name] = layer.plotX(self.
_axes[
None, i], self.
data[layer.tag], yDim)
 
  355             self.
_objs[i, 
None, name] = layer.plotY(self.
_axes[i, 
None], self.
data[layer.tag], yDim)
 
  356             for j, xDim 
in enumerate(self.
_active):
 
  357                 if xDim 
not in self.
data[layer.tag].dimensions:
 
  361                 self.
_objs[i, j, name] = layer.plotXY(self.
_axes[i, j], self.
data[layer.tag], xDim, yDim)
 
  362             self.
_axes[
None, i].xaxis.set_major_locator(matplotlib.ticker.MaxNLocator(nbins=5, prune=
'both'))
 
  363             self.
_axes[i, 
None].yaxis.set_major_locator(matplotlib.ticker.MaxNLocator(nbins=5, prune=
'both'))
 
  364             self.
_axes[
None, i].xaxis.set_minor_locator(matplotlib.ticker.AutoMinorLocator())
 
  365             self.
_axes[i, 
None].yaxis.set_minor_locator(matplotlib.ticker.AutoMinorLocator())
 
  367     def _get_active(self):
 
  370     def _set_active(self, active):
 
  372         if len(s) != len(active):
 
  373             raise ValueError(
"Active set contains duplicates")
 
  375             raise ValueError(
"Invalid values in active set")
 
  378     active = property(_get_active, _set_active, doc=
"sequence of active dimensions to plot (sequence of str)")
 
  387     def _build_axes(self):
 
  398             axesX = self.
_axes[
None, j] = self.
figure.add_subplot(n+1, n+1, jStart+j*jStride)
 
  399             axesX.autoscale(
False, axis=
'x')
 
  400             axesX.xaxis.tick_top()
 
  403             bbox = axesX.get_position()
 
  405             axesX.set_position(bbox)
 
  406             axesY = self.
_axes[i, 
None] = self.
figure.add_subplot(n+1, n+1, iStart + iStart+i*iStride)
 
  407             axesY.autoscale(
False, axis=
'y')
 
  408             axesY.yaxis.tick_right()
 
  411             bbox = axesY.get_position()
 
  413             axesY.set_position(bbox)
 
  416                 axesXY = self.
_axes[i, j] = self.
figure.add_subplot(
 
  417                     n+1, n+1, iStart+i*iStride + jStart+j*jStride,
 
  418                     sharex=self.
_axes[
None, j],
 
  419                     sharey=self.
_axes[i, 
None]
 
  421                 axesXY.autoscale(
False)
 
  428             xbox = self.
_axes[
None, j].get_position()
 
  429             ybox = self.
_axes[i, 
None].get_position()
 
  430             self.
figure.text(0.5*(xbox.x0 + xbox.x1), 0.5*(ybox.y0 + ybox.y1), self.
active[i],
 
  431                              ha=
'center', va=
'center', weight=
'bold')
 
  432             self.
_axes[i, j].get_frame().set_facecolor(
'none')
 
  439     """An example data object for DensityPlot, demonstrating the necessary interface. 
  441     There are two levels of requirements for a data object.  First are the attributes 
  442     required by the DensityPlot object itself; these must be present on every data object: 
  444        dimensions ------ a sequence of strings that provide names for the dimensions 
  446        lower ----------- a dictionary of {dimension-name: lower-bound} 
  448        upper ----------- a dictionary of {dimension-name: upper-bound} 
  450     The second level of requirements are those of the Layer objects provided here.  These 
  451     may be absent if the associated Layer is not used or is subclassed to reimplement the 
  452     Layer method that calls the data object method.  Currently, these include: 
  454        eval1d, eval2d -- methods used by the SurfaceLayer class; see their docs for more info 
  456        values ---------- attribute used by the HistogramLayer and ScatterLayer classes, an array 
  457                          with shape (M,N), where N is the number of dimensions and M is the number 
  460        weights --------- optional attribute used by the HistogramLayer and ScatterLayer classes, 
  461                          a 1-d array with size=M that provides weights for each data point 
  466         self.
mu = numpy.array([-10.0, 0.0, 10.0])
 
  467         self.
sigma = numpy.array([3.0, 2.0, 1.0])
 
  470         self.
values = numpy.random.randn(2000, 3) * self.
sigma[numpy.newaxis, :] + self.
mu[numpy.newaxis, :]
 
  473         """Evaluate the 1-d analytic function for the given dim at points x (a 1-d numpy array; 
  474         this method must be numpy-vectorized). 
  477         return numpy.exp(-0.5*((x-self.
mu[i])/self.
sigma[i])**2) / ((2.0*numpy.pi)**0.5 * self.
sigma[i])
 
  480         """Evaluate the 2-d analytic function for the given xDim and yDim at points x,y 
  481         (2-d numpy arrays with the same shape; this method must be numpy-vectorized). 
  485         return (numpy.exp(-0.5*(((x-self.
mu[j])/self.
sigma[j])**2 + ((y-self.
mu[i])/self.
sigma[i])**2))
 
  486                 / (2.0*numpy.pi * self.
sigma[j]*self.
sigma[i]))
 
  490     """Create and return a DensityPlot with example data.""" 
  491     fig = matplotlib.pyplot.figure()