22 """Support utilities for Measuring sources"""
39 from .
import subtractPsf, fitKernelParamsToImage
43 afwDisplay.setDefaultMaskTransparency(75)
48 objId = int((oid & 0xffff) - 1)
51 return dict(objId=objId)
56 def showSourceSet(sSet, xy0=(0, 0), display=
None, ctype=afwDisplay.GREEN, symb=
"+", size=2):
57 """Draw the (XAstrom, YAstrom) positions of a set of Sources. Image has the given XY0"""
60 display = afwDisplay.Display()
61 with display.Buffering():
63 xc, yc = s.getXAstrom() - xy0[0], s.getYAstrom() - xy0[1]
66 display.dot(
str(
splitId(s.getId(),
True)[
"objId"]), xc, yc, ctype=ctype, size=size)
68 display.dot(symb, xc, yc, ctype=ctype, size=size)
76 symb=None, ctype=None, ctypeUnused=None, ctypeBad=None, size=2, display=None):
77 """Show the SpatialCells.
79 If symb is something that afwDisplay.Display.dot() understands (e.g. "o"),
80 the top nMaxPerCell candidates will be indicated with that symbol, using
85 display = afwDisplay.Display()
86 with display.Buffering():
87 origin = [-exposure.getMaskedImage().getX0(), -exposure.getMaskedImage().getY0()]
88 for cell
in psfCellSet.getCellList():
89 displayUtils.drawBBox(cell.getBBox(), origin=origin, display=display)
95 goodies = ctypeBad
is None
96 for cand
in cell.begin(goodies):
100 xc, yc = cand.getXCenter() + origin[0], cand.getYCenter() + origin[1]
106 color = ctypeBad
if cand.isBad()
else ctype
114 display.dot(symb, xc, yc, ctype=ct, size=size)
116 source = cand.getSource()
119 rchi2 = cand.getChi2()
122 display.dot(
"%d %.1f" % (
splitId(source.getId(),
True)[
"objId"], rchi2),
123 xc - size, yc - size - 4, ctype=color, size=2)
126 display.dot(
"%.2f %.2f %.2f" % (source.getIxx(), source.getIxy(), source.getIyy()),
127 xc-size, yc + size + 4, ctype=color, size=size)
131 def showPsfCandidates(exposure, psfCellSet, psf=None, display=None, normalize=True, showBadCandidates=True,
132 fitBasisComponents=False, variance=None, chi=None):
133 """Display the PSF candidates.
135 If psf is provided include PSF model and residuals; if normalize is true normalize the PSFs
138 If chi is True, generate a plot of residuals/sqrt(variance), i.e. chi
140 If fitBasisComponents is true, also find the best linear combination of the PSF's components
144 display = afwDisplay.Display()
147 if variance
is not None:
152 mos = displayUtils.Mosaic()
154 candidateCenters = []
155 candidateCentersBad = []
158 for cell
in psfCellSet.getCellList():
159 for cand
in cell.begin(
False):
160 rchi2 = cand.getChi2()
164 if not showBadCandidates
and cand.isBad():
168 im_resid = displayUtils.Mosaic(gutter=0, background=-5, mode=
"x")
171 im = cand.getMaskedImage()
172 xc, yc = cand.getXCenter(), cand.getYCenter()
174 margin = 0
if True else 5
175 w, h = im.getDimensions()
179 bim = im.Factory(w + 2*margin, h + 2*margin)
183 bim.getVariance().
set(stdev**2)
190 im = im.Factory(im,
True)
191 im.setXY0(cand.getMaskedImage().getXY0())
196 im_resid.append(im.Factory(im,
True))
200 psfIm = mi.getImage()
201 config = measBase.SingleFrameMeasurementTask.ConfigClass()
202 config.slots.centroid =
"base_SdssCentroid"
204 schema = afwTable.SourceTable.makeMinimalSchema()
205 measureSources = measBase.SingleFrameMeasurementTask(schema, config=config)
209 miBig = mi.Factory(im.getWidth() + 2*extra, im.getHeight() + 2*extra)
210 miBig[extra:-extra, extra:-extra, afwImage.LOCAL] = mi
211 miBig.setXY0(mi.getX0() - extra, mi.getY0() - extra)
221 footprintSet.makeSources(catalog)
223 if len(catalog) == 0:
224 raise RuntimeError(
"Failed to detect any objects")
226 measureSources.run(catalog, exp)
227 if len(catalog) == 1:
231 for i, s
in enumerate(catalog):
232 d = numpy.hypot(xc - s.getX(), yc - s.getY())
233 if i == 0
or d < dmin:
235 xc, yc = source.getCentroid()
245 resid = resid.getImage()
246 var = im.getVariance()
247 var = var.Factory(var,
True)
248 numpy.sqrt(var.getArray(), var.getArray())
251 im_resid.append(resid)
254 if fitBasisComponents:
255 im = cand.getMaskedImage()
257 im = im.Factory(im,
True)
258 im.setXY0(cand.getMaskedImage().getXY0())
261 noSpatialKernel = psf.getKernel()
263 noSpatialKernel =
None
272 outImage = afwImage.ImageD(outputKernel.getDimensions())
273 outputKernel.computeImage(outImage,
False)
275 im -= outImage.convertF()
279 bim = im.Factory(w + 2*margin, h + 2*margin)
283 bim.assign(resid, bbox)
287 resid = resid.getImage()
290 im_resid.append(resid)
292 im = im_resid.makeMosaic()
294 im = cand.getMaskedImage()
299 objId =
splitId(cand.getSource().getId(),
True)[
"objId"]
301 lab =
"%d chi^2 %.1f" % (objId, rchi2)
302 ctype = afwDisplay.RED
if cand.isBad()
else afwDisplay.GREEN
304 lab =
"%d flux %8.3g" % (objId, cand.getSource().getPsfInstFlux())
305 ctype = afwDisplay.GREEN
307 mos.append(im, lab, ctype)
309 if False and numpy.isnan(rchi2):
310 display.mtv(cand.getMaskedImage().getImage(), title=
"showPsfCandidates: candidate")
311 print(
"amp", cand.getAmplitude())
313 im = cand.getMaskedImage()
314 center = (candidateIndex, xc - im.getX0(), yc - im.getY0())
317 candidateCentersBad.append(center)
319 candidateCenters.append(center)
322 title =
"chi(Psf fit)"
324 title =
"Stars & residuals"
325 mosaicImage = mos.makeMosaic(display=display, title=title)
327 with display.Buffering():
328 for centers, color
in ((candidateCenters, afwDisplay.GREEN), (candidateCentersBad, afwDisplay.RED)):
330 bbox = mos.getBBox(cen[0])
331 display.dot(
"+", cen[1] + bbox.getMinX(), cen[2] + bbox.getMinY(), ctype=color)
336 def makeSubplots(fig, nx=2, ny=2, Nx=1, Ny=1, plottingArea=(0.1, 0.1, 0.85, 0.80),
337 pxgutter=0.05, pygutter=0.05, xgutter=0.04, ygutter=0.04,
338 headroom=0.0, panelBorderWeight=0, panelColor=
'black'):
339 """Return a generator of a set of subplots, a set of Nx*Ny panels of nx*ny plots. Each panel is fully
340 filled by row (starting in the bottom left) before the next panel is started. If panelBorderWidth is
341 greater than zero a border is drawn around each panel, adjusted to enclose the axis labels.
344 subplots = makeSubplots(fig, 2, 2, Nx=1, Ny=1, panelColor='k')
345 ax = subplots.next(); ax.text(0.3, 0.5, '[0, 0] (0,0)')
346 ax = subplots.next(); ax.text(0.3, 0.5, '[0, 0] (1,0)')
347 ax = subplots.next(); ax.text(0.3, 0.5, '[0, 0] (0,1)')
348 ax = subplots.next(); ax.text(0.3, 0.5, '[0, 0] (1,1)')
351 @param fig The matplotlib figure to draw
352 @param nx The number of plots in each row of each panel
353 @param ny The number of plots in each column of each panel
354 @param Nx The number of panels in each row of the figure
355 @param Ny The number of panels in each column of the figure
356 @param plottingArea (x0, y0, x1, y1) for the part of the figure containing all the panels
357 @param pxgutter Spacing between columns of panels in units of (x1 - x0)
358 @param pygutter Spacing between rows of panels in units of (y1 - y0)
359 @param xgutter Spacing between columns of plots within a panel in units of (x1 - x0)
360 @param ygutter Spacing between rows of plots within a panel in units of (y1 - y0)
361 @param headroom Extra spacing above each plot for e.g. a title
362 @param panelBorderWeight Width of border drawn around panels
363 @param panelColor Colour of border around panels
368 import matplotlib.pyplot
as plt
369 except ImportError
as e:
370 log.warn(
"Unable to import matplotlib: %s", e)
376 except AttributeError:
377 fig.__show = fig.show
384 fig.show = types.MethodType(myShow, fig)
393 Callback to draw the panel borders when the plots are drawn to the canvas
395 if panelBorderWeight <= 0:
398 for p
in axes.keys():
401 bboxes.append(ax.bbox.union([label.get_window_extent()
for label
in
402 ax.get_xticklabels() + ax.get_yticklabels()]))
409 bbox = ax.bbox.union(bboxes)
411 xy0, xy1 = ax.transData.inverted().
transform(bbox)
414 w, h = x1 - x0, y1 - y0
423 rec = ax.add_patch(plt.Rectangle((x0, y0), w, h, fill=
False,
424 lw=panelBorderWeight, edgecolor=panelColor))
425 rec.set_clip_on(
False)
429 fig.canvas.mpl_connect(
'draw_event', on_draw)
433 x0, y0 = plottingArea[0:2]
434 W, H = plottingArea[2:4]
435 w = (W - (Nx - 1)*pxgutter - (nx*Nx - 1)*xgutter)/float(nx*Nx)
436 h = (H - (Ny - 1)*pygutter - (ny*Ny - 1)*ygutter)/float(ny*Ny)
440 for panel
in range(Nx*Ny):
444 for window
in range(nx*ny):
445 x = nx*px + window%nx
446 y = ny*py + window//nx
447 ax = fig.add_axes((x0 + xgutter + pxgutter + x*w + (px - 1)*pxgutter + (x - 1)*xgutter,
448 y0 + ygutter + pygutter + y*h + (py - 1)*pygutter + (y - 1)*ygutter,
449 w, h), frame_on=
True, facecolor=
'w')
455 matchKernelAmplitudes=False, keepPlots=True):
456 """Plot the PSF spatial model."""
460 import matplotlib.pyplot
as plt
461 import matplotlib
as mpl
462 except ImportError
as e:
463 log.warn(
"Unable to import matplotlib: %s", e)
466 noSpatialKernel = psf.getKernel()
473 for cell
in psfCellSet.getCellList():
474 for cand
in cell.begin(
False):
475 if not showBadCandidates
and cand.isBad():
479 im = cand.getMaskedImage()
487 for p, k
in zip(params, kernels):
488 amp += p * k.getSum()
490 targetFits = badFits
if cand.isBad()
else candFits
491 targetPos = badPos
if cand.isBad()
else candPos
492 targetAmps = badAmps
if cand.isBad()
else candAmps
494 targetFits.append([x / amp
for x
in params])
495 targetPos.append(candCenter)
496 targetAmps.append(amp)
498 xGood = numpy.array([pos.getX()
for pos
in candPos]) - exposure.getX0()
499 yGood = numpy.array([pos.getY()
for pos
in candPos]) - exposure.getY0()
500 zGood = numpy.array(candFits)
502 xBad = numpy.array([pos.getX()
for pos
in badPos]) - exposure.getX0()
503 yBad = numpy.array([pos.getY()
for pos
in badPos]) - exposure.getY0()
504 zBad = numpy.array(badFits)
507 xRange = numpy.linspace(0, exposure.getWidth(), num=numSample)
508 yRange = numpy.linspace(0, exposure.getHeight(), num=numSample)
510 kernel = psf.getKernel()
511 nKernelComponents = kernel.getNKernelParameters()
515 nPanelX = int(math.sqrt(nKernelComponents))
516 nPanelY = nKernelComponents//nPanelX
517 while nPanelY*nPanelX < nKernelComponents:
523 fig.canvas._tkcanvas._root().lift()
529 mpl.rcParams[
"figure.titlesize"] =
"x-small"
530 subplots =
makeSubplots(fig, 2, 2, Nx=nPanelX, Ny=nPanelY, xgutter=0.06, ygutter=0.06, pygutter=0.04)
532 for k
in range(nKernelComponents):
533 func = kernel.getSpatialFunction(k)
534 dfGood = zGood[:, k] - numpy.array([func(pos.getX(), pos.getY())
for pos
in candPos])
538 dfBad = zBad[:, k] - numpy.array([func(pos.getX(), pos.getY())
for pos
in badPos])
539 yMin =
min([yMin, dfBad.min()])
540 yMax =
max([yMax, dfBad.max()])
541 yMin -= 0.05 * (yMax - yMin)
542 yMax += 0.05 * (yMax - yMin)
547 fRange = numpy.ndarray((len(xRange), len(yRange)))
548 for j, yVal
in enumerate(yRange):
549 for i, xVal
in enumerate(xRange):
550 fRange[j][i] = func(xVal, yVal)
554 ax.set_autoscale_on(
False)
555 ax.set_xbound(lower=0, upper=exposure.getHeight())
556 ax.set_ybound(lower=yMin, upper=yMax)
557 ax.plot(yGood, dfGood,
'b+')
559 ax.plot(yBad, dfBad,
'r+')
561 ax.set_title(
'Residuals(y)')
565 if matchKernelAmplitudes
and k == 0:
572 norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
573 im = ax.imshow(fRange, aspect=
'auto', origin=
"lower", norm=norm,
574 extent=[0, exposure.getWidth()-1, 0, exposure.getHeight()-1])
575 ax.set_title(
'Spatial poly')
576 plt.colorbar(im, orientation=
'horizontal', ticks=[vmin, vmax])
579 ax.set_autoscale_on(
False)
580 ax.set_xbound(lower=0, upper=exposure.getWidth())
581 ax.set_ybound(lower=yMin, upper=yMax)
582 ax.plot(xGood, dfGood,
'b+')
584 ax.plot(xBad, dfBad,
'r+')
586 ax.set_title(
'K%d Residuals(x)' % k)
590 photoCalib = exposure.getPhotoCalib()
592 if photoCalib.getCalibrationMean() <= 0:
595 ampMag = [photoCalib.instFluxToMagnitude(candAmp)
for candAmp
in candAmps]
596 ax.plot(ampMag, zGood[:, k],
'b+')
598 badAmpMag = [photoCalib.instFluxToMagnitude(badAmp)
for badAmp
in badAmps]
599 ax.plot(badAmpMag, zBad[:, k],
'r+')
601 ax.set_title(
'Flux variation')
606 if keepPlots
and not keptPlots:
609 print(
"%s: Please close plots when done." % __name__)
614 print(
"Plots closed, exiting...")
616 atexit.register(show)
620 def showPsf(psf, eigenValues=None, XY=None, normalize=True, display=None):
621 """Display a PSF's eigen images
623 If normalize is True, set the largest absolute value of each eigenimage to 1.0 (n.b. sum == 0.0 for i > 0)
629 coeffs = psf.getLocalKernel(
lsst.geom.PointD(XY[0], XY[1])).getKernelParameters()
633 mos = displayUtils.Mosaic(gutter=2, background=-0.1)
634 for i, k
in enumerate(psf.getKernel().getKernelList()):
635 im = afwImage.ImageD(k.getDimensions())
636 k.computeImage(im,
False)
638 im /= numpy.max(numpy.abs(im.getArray()))
641 mos.append(im,
"%g" % (coeffs[i]/coeffs[0]))
646 display = afwDisplay.Display()
647 mos.makeMosaic(display=display, title=
"Kernel Basis Functions")
652 def showPsfMosaic(exposure, psf=None, nx=7, ny=None, showCenter=True, showEllipticity=False,
653 showFwhm=False, stampSize=0, display=None, title=None):
654 """Show a mosaic of Psf images. exposure may be an Exposure (optionally with PSF),
655 or a tuple (width, height)
657 If stampSize is > 0, the psf images will be trimmed to stampSize*stampSize
662 showEllipticity =
True
663 scale = 2*math.log(2)
665 mos = displayUtils.Mosaic()
668 width, height = exposure.getWidth(), exposure.getHeight()
669 x0, y0 = exposure.getXY0()
671 psf = exposure.getPsf()
672 except AttributeError:
674 width, height = exposure[0], exposure[1]
677 raise RuntimeError(
"Unable to extract width/height from object of type %s" %
type(exposure))
680 ny = int(nx*float(height)/width + 0.5)
684 centroidName =
"SdssCentroid"
685 shapeName =
"base_SdssShape"
687 schema = afwTable.SourceTable.makeMinimalSchema()
688 schema.getAliasMap().
set(
"slot_Centroid", centroidName)
689 schema.getAliasMap().
set(
"slot_Centroid_flag", centroidName+
"_flag")
691 control = measBase.SdssCentroidControl()
692 centroider = measBase.SdssCentroidAlgorithm(control, centroidName, schema)
694 sdssShape = measBase.SdssShapeControl()
695 shaper = measBase.SdssShapeAlgorithm(sdssShape, shapeName, schema)
696 table = afwTable.SourceTable.make(schema)
698 table.defineCentroid(centroidName)
699 table.defineShape(shapeName)
704 if stampSize <= w
and stampSize <= h:
712 x = int(ix*(width-1)/(nx-1)) + x0
713 y = int(iy*(height-1)/(ny-1)) + y0
719 im = im.Factory(im, bbox)
720 lab =
"PSF(%d,%d)" % (x, y)
if False else ""
725 w, h = im.getWidth(), im.getHeight()
726 centerX = im.getX0() + w//2
727 centerY = im.getY0() + h//2
728 src = table.makeRecord()
731 foot.addPeak(centerX, centerY, 1)
732 src.setFootprint(foot)
735 centroider.measure(src, exp)
736 centers.append((src.getX() - im.getX0(), src.getY() - im.getY0()))
738 shaper.measure(src, exp)
739 shapes.append((src.getIxx(), src.getIxy(), src.getIyy()))
744 display = afwDisplay.Display()
745 mos.makeMosaic(display=display, title=title
if title
else "Model Psf", mode=nx)
747 if centers
and display:
748 with display.Buffering():
749 for i, (cen, shape)
in enumerate(zip(centers, shapes)):
750 bbox = mos.getBBox(i)
751 xc, yc = cen[0] + bbox.getMinX(), cen[1] + bbox.getMinY()
753 display.dot(
"+", xc, yc, ctype=afwDisplay.BLUE)
756 ixx, ixy, iyy = shape
760 display.dot(
"@:%g,%g,%g" % (ixx, ixy, iyy), xc, yc, ctype=afwDisplay.RED)
766 mimIn = exposure.getMaskedImage()
767 mimIn = mimIn.Factory(mimIn,
True)
769 psf = exposure.getPsf()
770 psfWidth, psfHeight = psf.getLocalKernel().getDimensions()
774 w, h = int(mimIn.getWidth()/scale), int(mimIn.getHeight()/scale)
776 im = mimIn.Factory(w + psfWidth, h + psfHeight)
780 x, y = s.getX(), s.getY()
782 sx, sy = int(x/scale + 0.5), int(y/scale + 0.5)
786 sim = smim.getImage()
790 flux = s.getApInstFlux()
791 elif magType ==
"model":
792 flux = s.getModelInstFlux()
793 elif magType ==
"psf":
794 flux = s.getPsfInstFlux()
796 raise RuntimeError(
"Unknown flux type %s" % magType)
799 except Exception
as e:
803 expIm = mimIn.getImage().
Factory(mimIn.getImage(),
805 int(y) - psfHeight//2),
811 cenPos.append([x - expIm.getX0() + sx, y - expIm.getY0() + sy])
816 display = afwDisplay.Display()
817 display.mtv(im, title=
"showPsfResiduals: image")
818 with display.Buffering():
820 display.dot(
"+", x, y)
826 """Write the contents of a SpatialCellSet to a many-MEF fits file"""
829 for cell
in psfCellSet.getCellList():
830 for cand
in cell.begin(
False):
836 md.set(
"CELL", cell.getLabel())
837 md.set(
"ID", cand.getId())
838 md.set(
"XCENTER", cand.getXCenter())
839 md.set(
"YCENTER", cand.getYCenter())
840 md.set(
"BAD", cand.isBad())
841 md.set(
"AMPL", cand.getAmplitude())
842 md.set(
"FLUX", cand.getSource().getPsfInstFlux())
843 md.set(
"CHI2", cand.getSource().getChi2())
845 im.writeFits(fileName, md, mode)
849 display.mtv(im, title=
"saveSpatialCellSet: image")