23 """A set of matplotlib-based classes that displays a grid of 1-d and 2-d slices through an
26 The main class, DensityPlot, manages the grid of matplotlib.axes.Axes objects, and holds
27 a sequence of Layer objects that each know how to draw individual 1-d or 2-d plots and a
28 data object that abstracts away how the N-d density data is actually represented.
30 For simple cases, users can just create a custom data class with an interface like that of
31 the ExampleData class provided here, and use the provided HistogramLayer and SurfaceLayer
32 classes directly. In more complicated cases, users may want to create their own Layer classes,
33 which may define their own relationship with the data object.
36 import collections.abc
39 import matplotlib.pyplot
40 import matplotlib.ticker
42 __all__ = (
"HistogramLayer",
"SurfaceLayer",
"ScatterLayer",
"CrossPointsLayer",
43 "DensityPlot",
"ExampleData",
"demo")
47 for label
in axes.get_xticklabels():
48 label.set_visible(
False)
52 for label
in axes.get_yticklabels():
53 label.set_visible(
False)
57 copy = defaults.copy()
64 """A Layer class for DensityPlot for gridded histograms, drawing bar plots in 1-d and
65 colormapped large-pixel images in 2-d.
67 Relies on two data object attributes:
69 values ----- a (M,N) array of data points, where N is the dimension of the dataset and M is the
72 weights ---- (optional) an array of weights with shape (M,); if not present, all weights will
75 The need for these data object attributes can be removed by subclassing HistogramLayer and overriding
76 the hist1d and hist2d methods.
79 defaults1d = dict(facecolor=
'b', alpha=0.5)
80 defaults2d = dict(cmap=matplotlib.cm.Blues, vmin=0.0, interpolation=
'nearest')
82 def __init__(self, tag, bins1d=20, bins2d=(20, 20), kwds1d=
None, kwds2d=
None):
90 """Extract points from the data object and compute a 1-d histogram.
92 Return value should match that of numpy.histogram: a tuple of (hist, edges),
93 where hist is a 1-d array with size=bins1d, and edges is a 1-d array with
94 size=self.bins1d+1 giving the upper and lower edges of the bins.
96 i = data.dimensions.index(dim)
97 if hasattr(data,
"weights")
and data.weights
is not None:
98 weights = data.weights
101 return numpy.histogram(data.values[:, i], bins=self.
bins1d, weights=weights,
102 range=limits, normed=
True)
104 def hist2d(self, data, xDim, yDim, xLimits, yLimits):
105 """Extract points from the data object and compute a 1-d histogram.
107 Return value should match that of numpy.histogram2d: a tuple of (hist, xEdges, yEdges),
108 where hist is a 2-d array with shape=bins2d, xEdges is a 1-d array with size=bins2d[0]+1,
109 and yEdges is a 1-d array with size=bins2d[1]+1.
111 i = data.dimensions.index(yDim)
112 j = data.dimensions.index(xDim)
113 if hasattr(data,
"weights")
and data.weights
is not None:
114 weights = data.weights
117 return numpy.histogram2d(data.values[:, j], data.values[:, i], bins=self.
bins2d, weights=weights,
118 range=(xLimits, yLimits), normed=
True)
121 y, xEdge = self.
hist1d(data, dim, axes.get_xlim())
122 xCenter = 0.5*(xEdge[:-1] + xEdge[1:])
123 width = xEdge[1:] - xEdge[:-1]
124 return axes.bar(xCenter, y, width=width, align=
'center', **self.
kwds1d)
127 x, yEdge = self.
hist1d(data, dim, axes.get_ylim())
128 yCenter = 0.5*(yEdge[:-1] + yEdge[1:])
129 height = yEdge[1:] - yEdge[:-1]
130 return axes.barh(yCenter, x, height=height, align=
'center', **self.
kwds1d)
132 def plotXY(self, axes, data, xDim, yDim):
133 z, xEdge, yEdge = self.
hist2d(data, xDim, yDim, axes.get_xlim(), axes.get_ylim())
134 return axes.imshow(z.transpose(), aspect=
'auto', extent=(xEdge[0], xEdge[-1], yEdge[0], yEdge[-1]),
135 origin=
'lower', **self.
kwds2d)
139 """A Layer class that plots individual points in 2-d, and does nothing in 1-d.
141 Relies on two data object attributes:
143 values ----- a (M,N) array of data points, where N is the dimension of the dataset and M is the
144 number of data points
146 weights ---- (optional) an array of weights with shape (M,); will be used to set the color of points
150 defaults = dict(linewidth=0, alpha=0.2)
162 def plotXY(self, axes, data, xDim, yDim):
163 i = data.dimensions.index(yDim)
164 j = data.dimensions.index(xDim)
165 if hasattr(data,
"weights")
and data.weights
is not None:
166 args = data.values[:, j], data.values[:, i], data.weights
168 args = data.values[:, j], data.values[:, i]
169 return axes.scatter(*args, **self.kwds)
173 """A Layer class for analytic N-d distributions that can be evaluated in 1-d or 2-d slices.
175 The 2-d slices are drawn as contours, and the 1-d slices are drawn as simple curves.
177 Relies on eval1d and eval2d methods in the data object; this can be avoided by subclassing
178 SurfaceLayer and reimplementing its own eval1d and eval2d methods.
181 defaults1d = dict(linewidth=2, color=
'r')
182 defaults2d = dict(linewidths=2, cmap=matplotlib.cm.Reds)
184 def __init__(self, tag, steps1d=200, steps2d=200, filled=False, kwds1d=None, kwds2d=None):
193 """Return analytic function values for the given values."""
194 return data.eval1d(dim, x)
196 def eval2d(self, data, xDim, yDim, x, y):
197 """Return analytic function values for the given values."""
198 return data.eval2d(xDim, yDim, x, y)
201 xMin, xMax = axes.get_xlim()
202 x = numpy.linspace(xMin, xMax, self.
steps1d)
203 z = self.
eval1d(data, dim, x)
206 return axes.plot(x, z, **self.
kwds1d)
209 yMin, yMax = axes.get_ylim()
210 y = numpy.linspace(yMin, yMax, self.
steps1d)
211 z = self.
eval1d(data, dim, y)
214 return axes.plot(z, y, **self.
kwds1d)
216 def plotXY(self, axes, data, xDim, yDim):
217 xMin, xMax = axes.get_xlim()
218 yMin, yMax = axes.get_ylim()
219 xc = numpy.linspace(xMin, xMax, self.
steps2d)
220 yc = numpy.linspace(yMin, yMax, self.
steps2d)
221 xg, yg = numpy.meshgrid(xc, yc)
222 z = self.
eval2d(data, xDim, yDim, xg, yg)
226 return axes.contourf(xg, yg, z, 6, **self.
kwds2d)
228 return axes.contour(xg, yg, z, 6, **self.
kwds2d)
232 """A layer that marks a few points with axis-length vertical and horizontal lines.
234 This relies on a "points" data object attribute.
237 defaults = dict(alpha=0.8)
239 def __init__(self, tag, colors=(
"y",
"m",
"c",
"r",
"g",
"b"), **kwds):
245 i = data.dimensions.index(dim)
247 for n, point
in enumerate(data.points):
248 artists.append(axes.axvline(point[i], color=self.
colors[n % len(self.
colors)], **self.
kwds))
252 i = data.dimensions.index(dim)
254 for n, point
in enumerate(data.points):
255 artists.append(axes.axhline(point[i], color=self.
colors[n % len(self.
colors)], **self.
kwds))
258 def plotXY(self, axes, data, xDim, yDim):
259 i = data.dimensions.index(yDim)
260 j = data.dimensions.index(xDim)
262 for n, point
in enumerate(data.points):
263 artists.append(axes.axvline(point[j], color=self.
colors[n % len(self.
colors)], **self.
kwds))
264 artists.append(axes.axhline(point[i], color=self.
colors[n % len(self.
colors)], **self.
kwds))
269 """An object that manages a matrix of matplotlib.axes.Axes objects that represent a set of 1-d and 2-d
270 slices through an N-d density.
280 layer = self.
_dict.pop(name)
281 self.
_parent._dropLayer(name, layer)
285 self.
_dict[name] = layer
286 self.
_parent._plotLayer(name, layer)
289 return self.
_dict[name]
295 return len(self.
_dict)
298 return str(self.
_dict)
301 return repr(self.
_dict)
304 layer = self.
_dict[name]
305 self.
_parent._dropLayer(name, layer)
306 self.
_parent._plotLayer(name, layer)
316 for v
in self.
data.values():
317 for dim
in v.dimensions:
318 if dim
not in active:
320 self.
_lower[dim] = v.lower[dim]
321 self.
_upper[dim] = v.upper[dim]
327 self.
figure.subplots_adjust(left=0.05, right=0.95, bottom=0.05, top=0.95, hspace=0.01, wspace=0.01)
331 def _dropLayer(self, name, layer):
332 def removeArtist(*key):
334 self.
_objs.pop(key).remove()
335 except AttributeError:
342 for i, yDim
in enumerate(self.
_active):
343 removeArtist(
None, i, name)
344 removeArtist(i,
None, name)
345 for j, xDim
in enumerate(self.
_active):
348 removeArtist(i, j, name)
350 def _plotLayer(self, name, layer):
351 for i, yDim
in enumerate(self.
_active):
352 if yDim
not in self.
data[layer.tag].dimensions:
354 self.
_objs[
None, i, name] = layer.plotX(self.
_axes[
None, i], self.
data[layer.tag], yDim)
355 self.
_objs[i,
None, name] = layer.plotY(self.
_axes[i,
None], self.
data[layer.tag], yDim)
356 for j, xDim
in enumerate(self.
_active):
357 if xDim
not in self.
data[layer.tag].dimensions:
361 self.
_objs[i, j, name] = layer.plotXY(self.
_axes[i, j], self.
data[layer.tag], xDim, yDim)
362 self.
_axes[
None, i].xaxis.set_major_locator(matplotlib.ticker.MaxNLocator(nbins=5, prune=
'both'))
363 self.
_axes[i,
None].yaxis.set_major_locator(matplotlib.ticker.MaxNLocator(nbins=5, prune=
'both'))
364 self.
_axes[
None, i].xaxis.set_minor_locator(matplotlib.ticker.AutoMinorLocator())
365 self.
_axes[i,
None].yaxis.set_minor_locator(matplotlib.ticker.AutoMinorLocator())
367 def _get_active(self):
370 def _set_active(self, active):
372 if len(s) != len(active):
373 raise ValueError(
"Active set contains duplicates")
375 raise ValueError(
"Invalid values in active set")
378 active = property(_get_active, _set_active, doc=
"sequence of active dimensions to plot (sequence of str)")
387 def _build_axes(self):
398 axesX = self.
_axes[
None, j] = self.
figure.add_subplot(n+1, n+1, jStart+j*jStride)
399 axesX.autoscale(
False, axis=
'x')
400 axesX.xaxis.tick_top()
403 bbox = axesX.get_position()
405 axesX.set_position(bbox)
406 axesY = self.
_axes[i,
None] = self.
figure.add_subplot(n+1, n+1, iStart + iStart+i*iStride)
407 axesY.autoscale(
False, axis=
'y')
408 axesY.yaxis.tick_right()
411 bbox = axesY.get_position()
413 axesY.set_position(bbox)
416 axesXY = self.
_axes[i, j] = self.
figure.add_subplot(
417 n+1, n+1, iStart+i*iStride + jStart+j*jStride,
418 sharex=self.
_axes[
None, j],
419 sharey=self.
_axes[i,
None]
421 axesXY.autoscale(
False)
428 xbox = self.
_axes[
None, j].get_position()
429 ybox = self.
_axes[i,
None].get_position()
430 self.
figure.text(0.5*(xbox.x0 + xbox.x1), 0.5*(ybox.y0 + ybox.y1), self.
active[i],
431 ha=
'center', va=
'center', weight=
'bold')
432 self.
_axes[i, j].get_frame().set_facecolor(
'none')
439 """An example data object for DensityPlot, demonstrating the necessarity interface.
441 There are two levels of requirements for a data object. First are the attributes
442 required by the DensityPlot object itself; these must be present on every data object:
444 dimensions ------ a sequence of strings that provide names for the dimensions
446 lower ----------- a dictionary of {dimension-name: lower-bound}
448 upper ----------- a dictionary of {dimension-name: upper-bound}
450 The second level of requirements are those of the Layer objects provided here. These
451 may be absent if the associated Layer is not used or is subclassed to reimplement the
452 Layer method that calls the data object method. Currently, these include:
454 eval1d, eval2d -- methods used by the SurfaceLayer class; see their docs for more info
456 values ---------- attribute used by the HistogramLayer and ScatterLayer classes, an array
457 with shape (M,N), where N is the number of dimension and M is the number
460 weights --------- optional attribute used by the HistogramLayer and ScatterLayer classes,
461 a 1-d array with size=M that provides weights for each data point
466 self.
mu = numpy.array([-10.0, 0.0, 10.0])
467 self.
sigma = numpy.array([3.0, 2.0, 1.0])
470 self.
values = numpy.random.randn(2000, 3) * self.
sigma[numpy.newaxis, :] + self.
mu[numpy.newaxis, :]
473 """Evaluate the 1-d analytic function for the given dim at points x (a 1-d numpy array;
474 this method must be numpy-vectorized).
477 return numpy.exp(-0.5*((x-self.
mu[i])/self.
sigma[i])**2) / ((2.0*numpy.pi)**0.5 * self.
sigma[i])
480 """Evaluate the 2-d analytic function for the given xDim and yDim at points x,y
481 (2-d numpy arrays with the same shape; this method must be numpy-vectorized).
485 return (numpy.exp(-0.5*(((x-self.
mu[j])/self.
sigma[j])**2 + ((y-self.
mu[i])/self.
sigma[i])**2))
486 / (2.0*numpy.pi * self.
sigma[j]*self.
sigma[i]))
490 """Create and return a DensityPlot with example data."""
491 fig = matplotlib.pyplot.figure()