23"""A set of matplotlib-based classes that displays a grid of 1-d and 2-d slices through an
26The main class, DensityPlot, manages the grid of matplotlib.axes.Axes objects, and holds
27a sequence of Layer objects that each know how to draw individual 1-d or 2-d plots and a
28data object that abstracts away how the N-d density data is actually represented.
30For simple cases, users can just create a custom data class with an interface like that of
31the ExampleData class provided here, and use the provided HistogramLayer and SurfaceLayer
32classes directly. In more complicated cases, users may want to create their own Layer classes,
33which may define their own relationship with the data object.
39import matplotlib.pyplot
40import matplotlib.ticker
42__all__ = ("HistogramLayer", "SurfaceLayer", "ScatterLayer", "CrossPointsLayer",
43 "DensityPlot",
"ExampleData",
"demo")
47 for label
in axes.get_xticklabels():
48 label.set_visible(
False)
52 for label
in axes.get_yticklabels():
53 label.set_visible(
False)
57 copy = defaults.copy()
64 """A Layer class for DensityPlot for gridded histograms, drawing bar plots in 1-d and
65 colormapped large-pixel images in 2-d.
67 Relies on two data object attributes:
69 values ----- a (M,N) array of data points, where N
is the dimension of the dataset
and M
is the
72 weights ---- (optional) an array of weights
with shape (M,);
if not present, all weights will
75 The need
for these data object attributes can be removed by subclassing HistogramLayer
and overriding
76 the hist1d
and hist2d methods.
79 defaults1d = dict(facecolor='b', alpha=0.5)
80 defaults2d = dict(cmap=matplotlib.cm.Blues, vmin=0.0, interpolation=
'nearest')
82 def __init__(self, tag, bins1d=20, bins2d=(20, 20), kwds1d=
None, kwds2d=
None):
90 """Extract points from the data object and compute a 1-d histogram.
92 Return value should match that of numpy.histogram: a tuple of (hist, edges),
93 where hist is a 1-d array
with size=bins1d,
and edges
is a 1-d array
with
94 size=self.
bins1dbins1d+1 giving the upper
and lower edges of the bins.
96 i = data.dimensions.index(dim)
97 if hasattr(data,
"weights")
and data.weights
is not None:
98 weights = data.weights
101 return numpy.histogram(data.values[:, i], bins=self.
bins1dbins1d, weights=weights,
102 range=limits, normed=
True)
def hist2d(self, data, xDim, yDim, xLimits, yLimits):
    """Extract points from the data object and compute a 2-d histogram.

    Arguments:

      data ------- data object; must provide ``dimensions`` (a sequence supporting ``index``) and
                   ``values`` (a 2-d array whose columns follow ``dimensions``), and may provide
                   ``weights``
      xDim, yDim - names of the dimensions to histogram along the x and y axes
      xLimits ---- (lower, upper) range of the histogram in x
      yLimits ---- (lower, upper) range of the histogram in y

    Return value should match that of numpy.histogram2d: a tuple of (hist, xEdges, yEdges),
    where hist is a 2-d array with shape=bins2d, xEdges is a 1-d array with size=bins2d[0]+1,
    and yEdges is a 1-d array with size=bins2d[1]+1.
    """
    i = data.dimensions.index(yDim)
    j = data.dimensions.index(xDim)
    # Weights are optional: a missing attribute and an explicit None both mean "unweighted".
    weights = getattr(data, "weights", None)
    # 'density' is the supported spelling of the old 'normed' keyword, which current NumPy rejects.
    return numpy.histogram2d(data.values[:, j], data.values[:, i], bins=self.bins2d, weights=weights,
                             range=(xLimits, yLimits), density=True)
121 y, xEdge = self.
hist1dhist1d(data, dim, axes.get_xlim())
122 xCenter = 0.5*(xEdge[:-1] + xEdge[1:])
123 width = xEdge[1:] - xEdge[:-1]
124 return axes.bar(xCenter, y, width=width, align=
'center', **self.
kwds1dkwds1d)
127 x, yEdge = self.
hist1dhist1d(data, dim, axes.get_ylim())
128 yCenter = 0.5*(yEdge[:-1] + yEdge[1:])
129 height = yEdge[1:] - yEdge[:-1]
130 return axes.barh(yCenter, x, height=height, align=
'center', **self.
kwds1dkwds1d)
def plotXY(self, axes, data, xDim, yDim):
    """Draw the 2-d histogram of (xDim, yDim) on the given axes as an image.

    The histogram is computed by self.hist2d over the current x and y limits of the axes.
    Returns whatever axes.imshow returns.
    """
    z, xEdge, yEdge = self.hist2d(data, xDim, yDim, axes.get_xlim(), axes.get_ylim())
    # hist2d returns an array indexed [x, y]; imshow treats the first index as the row (y),
    # so the array is transposed before drawing.
    extent = (xEdge[0], xEdge[-1], yEdge[0], yEdge[-1])
    return axes.imshow(z.transpose(), aspect='auto', extent=extent, origin='lower', **self.kwds2d)
139 """A Layer class that plots individual points in 2-d, and does nothing in 1-d.
141 Relies on two data object attributes:
143 values ----- a (M,N) array of data points, where N is the dimension of the dataset
and M
is the
144 number of data points
146 weights ---- (optional) an array of weights
with shape (M,); will be used to set the color of points
150 defaults = dict(linewidth=0, alpha=0.2)
def plotXY(self, axes, data, xDim, yDim):
    """Scatter the data points in the (xDim, yDim) plane and return the result of axes.scatter.

    If the data object has a non-None ``weights`` attribute, the weights are used to colour the
    points, unless the layer's own keyword arguments already specify ``c``.
    """
    i = data.dimensions.index(yDim)
    j = data.dimensions.index(xDim)
    # Copy so that the layer's stored keyword arguments are never modified.
    kwds = dict(self.kwds)
    weights = getattr(data, "weights", None)
    if weights is not None:
        # The third positional argument of scatter() is the marker size, not the colour,
        # so the weights must be passed by keyword to affect colour.
        kwds.setdefault("c", weights)
    return axes.scatter(data.values[:, j], data.values[:, i], **kwds)
173 """A Layer class for analytic N-d distributions that can be evaluated in 1-d or 2-d slices.
175 The 2-d slices are drawn as contours,
and the 1-d slices are drawn
as simple curves.
177 Relies on eval1d
and eval2d methods
in the data object; this can be avoided by subclassing
178 SurfaceLayer
and reimplementing its own eval1d
and eval2d methods.
181 defaults1d = dict(linewidth=2, color='r')
182 defaults2d = dict(linewidths=2, cmap=matplotlib.cm.Reds)
184 def __init__(self, tag, steps1d=200, steps2d=200, filled=False, kwds1d=None, kwds2d=None):
193 """Return analytic function values for the given values."""
194 return data.eval1d(dim, x)
def eval2d(self, data, xDim, yDim, x, y):
    """Return analytic function values for the given x and y values.

    This simply forwards to data.eval2d(xDim, yDim, x, y); subclasses can override it to
    remove the need for an eval2d method on the data object.
    """
    return data.eval2d(xDim, yDim, x, y)
201 xMin, xMax = axes.get_xlim()
202 x = numpy.linspace(xMin, xMax, self.
steps1dsteps1d)
203 z = self.
eval1deval1d(data, dim, x)
206 return axes.plot(x, z, **self.
kwds1dkwds1d)
209 yMin, yMax = axes.get_ylim()
210 y = numpy.linspace(yMin, yMax, self.
steps1dsteps1d)
211 z = self.
eval1deval1d(data, dim, y)
214 return axes.plot(z, y, **self.
kwds1dkwds1d)
216 def plotXY(self, axes, data, xDim, yDim):
217 xMin, xMax = axes.get_xlim()
218 yMin, yMax = axes.get_ylim()
219 xc = numpy.linspace(xMin, xMax, self.
steps2dsteps2d)
220 yc = numpy.linspace(yMin, yMax, self.
steps2dsteps2d)
221 xg, yg = numpy.meshgrid(xc, yc)
222 z = self.
eval2deval2d(data, xDim, yDim, xg, yg)
226 return axes.contourf(xg, yg, z, 6, **self.
kwds2dkwds2d)
228 return axes.contour(xg, yg, z, 6, **self.
kwds2dkwds2d)
232 """A layer that marks a few points with axis-length vertical and horizontal lines.
234 This relies on a "points" data object attribute.
237 defaults = dict(alpha=0.8)
239 def __init__(self, tag, colors=(
"y",
"m",
"c",
"r",
"g",
"b"), **kwds):
245 i = data.dimensions.index(dim)
247 for n, point
in enumerate(data.points):
248 artists.append(axes.axvline(point[i], color=self.
colorscolors[n % len(self.
colorscolors)], **self.
kwdskwds))
252 i = data.dimensions.index(dim)
254 for n, point
in enumerate(data.points):
255 artists.append(axes.axhline(point[i], color=self.
colorscolors[n % len(self.
colorscolors)], **self.
kwdskwds))
def plotXY(self, axes, data, xDim, yDim):
    """Mark each point in data.points with a vertical and a horizontal line across the axes.

    Colours are taken from self.colors in order, wrapping around when there are more points
    than colours; both lines of a point share one colour.
    """
    i = data.dimensions.index(yDim)
    j = data.dimensions.index(xDim)
    # NOTE(review): the initialisation and the return of 'artists' are not visible in this
    # listing; they are reconstructed from the append calls below -- confirm against the
    # real file.
    artists = []
    for n, point in enumerate(data.points):
        color = self.colors[n % len(self.colors)]
        artists.append(axes.axvline(point[j], color=color, **self.kwds))
        artists.append(axes.axhline(point[i], color=color, **self.kwds))
    return artists
269 """An object that manages a matrix of matplotlib.axes.Axes objects that represent a set of 1-d and 2-d
270 slices through an N-d density.
276 self.
_dict_dict = dict()
280 layer = self.
_dict_dict.pop(name)
281 self.
_parent_parent._dropLayer(name, layer)
285 self.
_dict_dict[name] = layer
286 self.
_parent_parent._plotLayer(name, layer)
289 return self.
_dict_dict[name]
295 return len(self.
_dict_dict)
301 return repr(self.
_dict_dict)
304 layer = self.
_dict_dict[name]
305 self.
_parent_parent._dropLayer(name, layer)
306 self.
_parent_parent._plotLayer(name, layer)
312 self.
_lower_lower = dict()
313 self.
_upper_upper = dict()
316 for v
in self.
datadata.values():
317 for dim
in v.dimensions:
318 if dim
not in active:
320 self.
_lower_lower[dim] = v.lower[dim]
321 self.
_upper_upper[dim] = v.upper[dim]
325 self.
_active_active = tuple(active)
327 self.
figurefigure.subplots_adjust(left=0.05, right=0.95, bottom=0.05, top=0.95, hspace=0.01, wspace=0.01)
331 def _dropLayer(self, name, layer):
332 def removeArtist(*key):
334 self.
_objs_objs.pop(key).remove()
335 except AttributeError:
342 for i, yDim
in enumerate(self.
_active_active):
343 removeArtist(
None, i, name)
344 removeArtist(i,
None, name)
345 for j, xDim
in enumerate(self.
_active_active):
348 removeArtist(i, j, name)
350 def _plotLayer(self, name, layer):
351 for i, yDim
in enumerate(self.
_active_active):
352 if yDim
not in self.
datadata[layer.tag].dimensions:
354 self.
_objs_objs[
None, i, name] = layer.plotX(self.
_axes_axes[
None, i], self.
datadata[layer.tag], yDim)
355 self.
_objs_objs[i,
None, name] = layer.plotY(self.
_axes_axes[i,
None], self.
datadata[layer.tag], yDim)
356 for j, xDim
in enumerate(self.
_active_active):
357 if xDim
not in self.
datadata[layer.tag].dimensions:
361 self.
_objs_objs[i, j, name] = layer.plotXY(self.
_axes_axes[i, j], self.
datadata[layer.tag], xDim, yDim)
362 self.
_axes_axes[
None, i].xaxis.set_major_locator(matplotlib.ticker.MaxNLocator(nbins=5, prune=
'both'))
363 self.
_axes_axes[i,
None].yaxis.set_major_locator(matplotlib.ticker.MaxNLocator(nbins=5, prune=
'both'))
364 self.
_axes_axes[
None, i].xaxis.set_minor_locator(matplotlib.ticker.AutoMinorLocator())
365 self.
_axes_axes[i,
None].yaxis.set_minor_locator(matplotlib.ticker.AutoMinorLocator())
367 def _get_active(self):
370 def _set_active(self, active):
372 if len(s) != len(active):
373 raise ValueError(
"Active set contains duplicates")
374 if not self.
_all_dims_all_dims.issuperset(s):
375 raise ValueError(
"Invalid values in active set")
376 self.
_active_active = tuple(active)
378 active = property(_get_active, _set_active, doc=
"sequence of active dimensions to plot (sequence of str)")
387 def _build_axes(self):
389 self.
_axes_axes = dict()
390 self.
_objs_objs = dict()
398 axesX = self.
_axes_axes[
None, j] = self.
figurefigure.add_subplot(n+1, n+1, jStart+j*jStride)
399 axesX.autoscale(
False, axis=
'x')
400 axesX.xaxis.tick_top()
403 bbox = axesX.get_position()
405 axesX.set_position(bbox)
406 axesY = self.
_axes_axes[i,
None] = self.
figurefigure.add_subplot(n+1, n+1, iStart + iStart+i*iStride)
407 axesY.autoscale(
False, axis=
'y')
408 axesY.yaxis.tick_right()
411 bbox = axesY.get_position()
413 axesY.set_position(bbox)
416 axesXY = self.
_axes_axes[i, j] = self.
figurefigure.add_subplot(
417 n+1, n+1, iStart+i*iStride + jStart+j*jStride,
418 sharex=self.
_axes_axes[
None, j],
419 sharey=self.
_axes_axes[i,
None]
421 axesXY.autoscale(
False)
428 xbox = self.
_axes_axes[
None, j].get_position()
429 ybox = self.
_axes_axes[i,
None].get_position()
430 self.
figurefigure.
text(0.5*(xbox.x0 + xbox.x1), 0.5*(ybox.y0 + ybox.y1), self.
activeactive[i],
431 ha=
'center', va=
'center', weight=
'bold')
432 self.
_axes_axes[i, j].get_frame().set_facecolor(
'none')
435 self.
figurefigure.canvas.draw()
439 """An example data object for DensityPlot, demonstrating the necessary interface.
441 There are two levels of requirements for a data object. First are the attributes
442 required by the DensityPlot object itself; these must be present on every data object:
444 dimensions ------ a sequence of strings that provide names
for the dimensions
446 lower ----------- a dictionary of {dimension-name: lower-bound}
448 upper ----------- a dictionary of {dimension-name: upper-bound}
450 The second level of requirements are those of the Layer objects provided here. These
451 may be absent
if the associated Layer
is not used
or is subclassed to reimplement the
452 Layer method that calls the data object method. Currently, these include:
454 eval1d, eval2d -- methods used by the SurfaceLayer
class; see their docs
for more info
456 values ---------- attribute used by the HistogramLayer
and ScatterLayer classes, an array
457 with shape (M,N), where N
is the number of dimension
and M
is the number
460 weights --------- optional attribute used by the HistogramLayer
and ScatterLayer classes,
461 a 1-d array
with size=M that provides weights
for each data point
466 self.
mumu = numpy.array([-10.0, 0.0, 10.0])
467 self.
sigmasigma = numpy.array([3.0, 2.0, 1.0])
470 self.
valuesvalues = numpy.random.randn(2000, 3) * self.
sigmasigma[numpy.newaxis, :] + self.
mumu[numpy.newaxis, :]
473 """Evaluate the 1-d analytic function for the given dim at points x (a 1-d numpy array;
474 this method must be numpy-vectorized).
477 return numpy.exp(-0.5*((x-self.
mumu[i])/self.
sigmasigma[i])**2) / ((2.0*numpy.pi)**0.5 * self.
sigmasigma[i])
480 """Evaluate the 2-d analytic function for the given xDim and yDim at points x,y
481 (2-d numpy arrays with the same shape; this method must be numpy-vectorized).
485 return (numpy.exp(-0.5*(((x-self.
mumu[j])/self.
sigmasigma[j])**2 + ((y-self.
mumu[i])/self.
sigmasigma[i])**2))
486 / (2.0*numpy.pi * self.
sigmasigma[j]*self.
sigmasigma[i]))
490 """Create and return a DensityPlot with example data."""
491 fig = matplotlib.pyplot.figure()
std::vector< SchemaItem< Flag > > * items
def __init__(self, tag, colors=("y", "m", "c", "r", "g", "b"), **kwds)
def plotXY(self, axes, data, xDim, yDim)
def plotX(self, axes, data, dim)
def plotY(self, axes, data, dim)
def __setitem__(self, name, layer)
def __delitem__(self, name)
def __getitem__(self, name)
def __init__(self, parent)
def __init__(self, figure, **kwds)
def _plotLayer(self, name, layer)
def eval2d(self, xDim, yDim, x, y)
def hist2d(self, data, xDim, yDim, xLimits, yLimits)
def __init__(self, tag, bins1d=20, bins2d=(20, 20), kwds1d=None, kwds2d=None)
def plotXY(self, axes, data, xDim, yDim)
def hist1d(self, data, dim, limits)
def plotY(self, axes, data, dim)
def plotX(self, axes, data, dim)
def plotY(self, axes, data, dim)
def plotXY(self, axes, data, xDim, yDim)
def __init__(self, tag, **kwds)
def plotX(self, axes, data, dim)
def eval1d(self, data, dim, x)
def plotX(self, axes, data, dim)
def eval2d(self, data, xDim, yDim, x, y)
def plotY(self, axes, data, dim)
def plotXY(self, axes, data, xDim, yDim)
def __init__(self, tag, steps1d=200, steps2d=200, filled=False, kwds1d=None, kwds2d=None)
daf::base::PropertySet * set
def hide_xticklabels(axes)
def mergeDefaults(kwds, defaults)
def hide_yticklabels(axes)