LSSTApplications  18.1.0
LSSTDataManagementBasePackage
multiBandDriver.py
Go to the documentation of this file.
1 from __future__ import absolute_import, division, print_function
2 import os
3 
4 from builtins import zip
5 
6 from lsst.pex.config import Config, Field, ConfigurableField
7 from lsst.pipe.base import ArgumentParser, TaskRunner
8 from lsst.pipe.tasks.multiBand import (DetectCoaddSourcesTask,
9  MergeDetectionsTask,
10  DeblendCoaddSourcesTask,
11  MeasureMergedCoaddSourcesTask,
12  MergeMeasurementsTask,)
13 from lsst.ctrl.pool.parallel import BatchPoolTask
14 from lsst.ctrl.pool.pool import Pool, abortOnError
15 from lsst.meas.base.references import MultiBandReferencesTask
16 from lsst.meas.base.forcedPhotCoadd import ForcedPhotCoaddTask
17 from lsst.pipe.drivers.utils import getDataRef, TractDataIdContainer
18 
19 import lsst.afw.table as afwTable
20 
21 
class MultiBandDriverConfig(Config):
    """Configuration for the multi-node multiband driver.

    Holds one configurable sub-task per processing stage (detection,
    detection merging, deblending, measurement, measurement merging and
    forced photometry) plus the switches that control which stages run.
    """
    coaddName = Field(dtype=str, default="deep", doc="Name of coadd")
    doDetection = Field(dtype=bool, default=False,
                        doc="Re-run detection? (requires *Coadd dataset to have been written)")
    detectCoaddSources = ConfigurableField(target=DetectCoaddSourcesTask,
                                           doc="Detect sources on coadd")
    mergeCoaddDetections = ConfigurableField(
        target=MergeDetectionsTask, doc="Merge detections")
    deblendCoaddSources = ConfigurableField(target=DeblendCoaddSourcesTask, doc="Deblend merged detections")
    measureCoaddSources = ConfigurableField(target=MeasureMergedCoaddSourcesTask,
                                            doc="Measure merged and (optionally) deblended detections")
    mergeCoaddMeasurements = ConfigurableField(
        target=MergeMeasurementsTask, doc="Merge measurements")
    forcedPhotCoadd = ConfigurableField(target=ForcedPhotCoaddTask,
                                        doc="Forced measurement on coadded images")
    reprocessing = Field(
        dtype=bool, default=False,
        doc=("Are we reprocessing?\n\n"
             "This exists as a workaround for large deblender footprints causing large memory use "
             "and/or very slow processing. We refuse to deblend those footprints when running on a cluster "
             "and return to reprocess on a machine with larger memory or more time "
             "if we consider those footprints important to recover."),
    )

    hasFakes = Field(
        dtype=bool,
        default=False,
        doc="Should be set to True if fakes were inserted into the data being processed."
    )

    def setDefaults(self):
        """Apply the base defaults, then retarget the forced-photometry
        reference sub-task to MultiBandReferencesTask.
        """
        Config.setDefaults(self)
        self.forcedPhotCoadd.references.retarget(MultiBandReferencesTask)

    def validate(self):
        """Validate the configuration.

        Raises
        ------
        RuntimeError
            If any of the checked sub-task configurations has a
            ``coaddName`` that differs from the root ``coaddName``.
        """
        # Run the base-class validation first so that per-field checks are
        # not silently skipped by this override; the cross-field check
        # below only makes sense on individually valid fields.
        Config.validate(self)
        for subtask in ("mergeCoaddDetections", "deblendCoaddSources", "measureCoaddSources",
                        "mergeCoaddMeasurements", "forcedPhotCoadd"):
            coaddName = getattr(self, subtask).coaddName
            if coaddName != self.coaddName:
                raise RuntimeError("%s.coaddName (%s) doesn't match root coaddName (%s)" %
                                   (subtask, coaddName, self.coaddName))
64 
65 
class MultiBandDriverTaskRunner(TaskRunner):
    """TaskRunner for running MultiBandTask

    This is similar to the lsst.pipe.base.ButlerInitializedTaskRunner,
    except that we have a list of data references instead of a single
    data reference being passed to the Task.run, and we pass the results
    of the '--reuse-outputs-from' command option to the Task constructor.
    """

    def __init__(self, TaskClass, parsedCmd, doReturnResults=False):
        """Construct the runner and remember the parsed 'reuse' option.

        @param TaskClass: The task class to run
        @param parsedCmd: Parsed command line; its ``reuse`` attribute is stored
        @param doReturnResults: Forwarded unchanged to TaskRunner.__init__
        """
        TaskRunner.__init__(self, TaskClass, parsedCmd, doReturnResults)
        self.reuse = parsedCmd.reuse

    def makeTask(self, parsedCmd=None, args=None):
        """A variant of the base version that passes a butler argument to the task's constructor

        parsedCmd or args must be specified.

        @param parsedCmd: Parsed command line; its ``butler`` is used when given
        @param args: A ``(dataRefList, kwargs)`` pair; the butler is taken from
            the first data reference. Only consulted when parsedCmd is None.
        @return a new TaskClass instance built with this runner's config,
            log and reuse list
        @throw RuntimeError if neither parsedCmd nor args is provided
        """
        if parsedCmd is not None:
            butler = parsedCmd.butler
        elif args is not None:
            # Only the data references are needed here; the keyword
            # arguments half of the pair is deliberately ignored.
            dataRefList, _ = args
            butler = dataRefList[0].butlerSubset.butler
        else:
            raise RuntimeError("parsedCmd or args must be specified")
        return self.TaskClass(config=self.config, log=self.log, butler=butler, reuse=self.reuse)
91 
92 
def unpickle(factory, args, kwargs):
    """Rebuild an object by invoking a factory callable.

    Used as the reconstruction function returned from ``__reduce__``.

    @param factory: Callable that creates the object
    @param args: Sequence of positional arguments for the factory
    @param kwargs: Mapping of keyword arguments for the factory
    @return whatever ``factory(*args, **kwargs)`` returns
    """
    rebuilt = factory(*args, **kwargs)
    return rebuilt
96 
97 
class MultiBandDriverTask(BatchPoolTask):
    """Multi-node driver for multiband processing"""
    ConfigClass = MultiBandDriverConfig
    _DefaultName = "multiBandDriver"
    RunnerClass = MultiBandDriverTaskRunner

    def __init__(self, butler=None, schema=None, refObjLoader=None, reuse=tuple(), **kwargs):
        """!
        @param[in] butler: the butler can be used to retrieve schema or passed to the refObjLoader constructor
            in case it is needed.
        @param[in] schema: the schema of the source detection catalog used as input.
        @param[in] refObjLoader: an instance of LoadReferenceObjectsTasks that supplies an external reference
            catalog.  May be None if the butler argument is provided or all steps requiring a reference
            catalog are disabled.
        @param[in] reuse: names of the sub-tasks whose existing outputs may be reused instead of
            being regenerated; stored as a tuple.
        @param[in] kwargs: forwarded unchanged to BatchPoolTask.__init__
        """
        BatchPoolTask.__init__(self, **kwargs)
        if schema is None:
            assert butler is not None, "Butler not provided"
            schema = butler.get(self.config.coaddName +
                                "Coadd_det_schema", immediate=True).schema
        self.butler = butler
        self.reuse = tuple(reuse)
        self.makeSubtask("detectCoaddSources")
        self.makeSubtask("mergeCoaddDetections", schema=schema)
        # NOTE(review): measurementInput, deblenderOutput and the
        # deblendCoaddSources sub-task are only created when the measurement
        # input catalog name starts with "deblended", yet runDeblendMerged
        # reads them unconditionally -- confirm other inputs are unsupported.
        if self.config.measureCoaddSources.inputCatalog.startswith("deblended"):
            # Ensure that the output from deblendCoaddSources matches the input to measureCoaddSources
            self.measurementInput = self.config.measureCoaddSources.inputCatalog
            self.deblenderOutput = []
            if self.config.deblendCoaddSources.simultaneous:
                self.deblenderOutput.append("deblendedModel")
            else:
                self.deblenderOutput.append("deblendedFlux")
            if self.measurementInput not in self.deblenderOutput:
                err = "Measurement input '{0}' is not in the list of deblender output catalogs '{1}'"
                raise ValueError(err.format(self.measurementInput, self.deblenderOutput))

            self.makeSubtask("deblendCoaddSources",
                             schema=afwTable.Schema(self.mergeCoaddDetections.schema),
                             peakSchema=afwTable.Schema(self.mergeCoaddDetections.merged.getPeakSchema()),
                             butler=butler)
            measureInputSchema = afwTable.Schema(self.deblendCoaddSources.schema)
        else:
            measureInputSchema = afwTable.Schema(self.mergeCoaddDetections.schema)
        self.makeSubtask("measureCoaddSources", schema=measureInputSchema,
                         peakSchema=afwTable.Schema(
                             self.mergeCoaddDetections.merged.getPeakSchema()),
                         refObjLoader=refObjLoader, butler=butler)
        self.makeSubtask("mergeCoaddMeasurements", schema=afwTable.Schema(
            self.measureCoaddSources.schema))
        self.makeSubtask("forcedPhotCoadd", refSchema=afwTable.Schema(
            self.mergeCoaddMeasurements.schema))
        # Dataset name prefix for the coadd images; carries a "fakes_"
        # prefix when fakes were inserted. Catalog dataset names below are
        # built from config.coaddName directly.
        if self.config.hasFakes:
            self.coaddType = "fakes_" + self.config.coaddName
        else:
            self.coaddType = self.config.coaddName

    def __reduce__(self):
        """Pickler

        Reconstructs the task through `unpickle` by calling the class again
        with the stored config, name, parent task, log, butler and reuse list.
        """
        return unpickle, (self.__class__, [], dict(config=self.config, name=self._name,
                                                   parentTask=self._parentTask, log=self.log,
                                                   butler=self.butler, reuse=self.reuse))

    @classmethod
    def _makeArgumentParser(cls, *args, **kwargs):
        """Build the command-line parser for this driver.

        The ``doBatch`` keyword, if present, is discarded before the
        remaining arguments are handed to ArgumentParser.
        """
        kwargs.pop("doBatch", False)
        parser = ArgumentParser(name=cls._DefaultName, *args, **kwargs)
        parser.add_id_argument("--id", "deepCoadd", help="data ID, e.g. --id tract=12345 patch=1,2",
                               ContainerClass=TractDataIdContainer)
        parser.addReuseOption(["detectCoaddSources", "mergeCoaddDetections", "measureCoaddSources",
                               "mergeCoaddMeasurements", "forcedPhotCoadd", "deblendCoaddSources"])
        return parser

    @classmethod
    def batchWallTime(cls, time, parsedCmd, numCpus):
        """!Return walltime request for batch job

        @param time: Requested time per iteration
        @param parsedCmd: Results of argument parsing
        @param numCpus: Number of CPUs the work is divided over
        @return time multiplied by the total number of data references and
            divided by numCpus
        """
        numTargets = sum(len(refList) for refList in parsedCmd.id.refList)
        return time*numTargets/float(numCpus)

    @abortOnError
    def runDataRef(self, patchRefList):
        """!Run multiband processing on coadds

        Only the master node runs this method.

        No real MPI communication (scatter/gather) takes place: all I/O goes
        through the disk. We want the intermediate stages on disk, and the
        component Tasks are implemented around this, so we just follow suit.

        @param patchRefList:  Data references to run measurement
        @throw RuntimeError if patchRefList contains no valid data reference
        """
        for patchRef in patchRefList:
            if patchRef:
                butler = patchRef.getButler()
                break
        else:
            raise RuntimeError("No valid patches")
        pool = Pool("all")
        pool.cacheClear()
        pool.storeSet(butler=butler)
        # MultiBand measurements require that the detection stage be completed
        # before measurements can be made.
        #
        # The configuration for coaddDriver.py allows detection to be turned
        # off in the event that fake objects are to be added during the
        # detection process.  This allows the long co-addition process to be
        # run once, and multiple different MultiBand reruns (with different
        # fake objects) to exist from the same base co-addition.
        #
        # However, we only re-run detection if doDetection is explicitly True
        # here (this should always be the opposite of coaddDriver.doDetection);
        # otherwise we have no way to tell reliably whether any detections
        # present in an input repo are safe to use.
        if self.config.doDetection:
            detectionList = []
            for patchRef in patchRefList:
                if ("detectCoaddSources" in self.reuse and
                        patchRef.datasetExists(self.coaddType + "Coadd_calexp", write=True)):
                    self.log.info("Skipping detectCoaddSources for %s; output already exists." %
                                  patchRef.dataId)
                    continue
                if not patchRef.datasetExists(self.coaddType + "Coadd"):
                    # Report the dataset that was actually looked up
                    # (coaddType carries the "fakes_" prefix when applicable).
                    self.log.debug("Not processing %s; required input %sCoadd missing." %
                                   (patchRef.dataId, self.coaddType))
                    continue
                detectionList.append(patchRef)

            pool.map(self.runDetection, detectionList)

        patchRefList = [patchRef for patchRef in patchRefList if
                        patchRef.datasetExists(self.coaddType + "Coadd_calexp") and
                        patchRef.datasetExists(self.config.coaddName + "Coadd_det",
                                               write=self.config.doDetection)]
        dataIdList = [patchRef.dataId for patchRef in patchRefList]

        # Group by patch
        patches = {}
        tract = None
        for patchRef in patchRefList:
            dataId = patchRef.dataId
            if tract is None:
                tract = dataId["tract"]
            else:
                assert tract == dataId["tract"]
            patches.setdefault(dataId["patch"], []).append(dataId)

        pool.map(self.runMergeDetections, patches.values())

        # Deblend merged detections, and test for reprocessing
        #
        # The reprocessing allows us to have multiple attempts at deblending large footprints. Large
        # footprints can suck up a lot of memory in the deblender, which means that when we process on a
        # cluster, we want to refuse to deblend them (they're flagged "deblend.parent-too-big"). But since
        # they may have astronomically interesting data, we want the ability to go back and reprocess them
        # with a more permissive configuration when we have more memory or processing time.
        #
        # self.runDeblendMerged returns, for each patch, whether there are any footprints that required
        # reprocessing. We convert that list of booleans into a dict mapping the patchId (x,y) to
        # a boolean. That tells us whether the merge measurement and forced photometry need to be re-run on
        # a particular patch.
        #
        # This determination of which patches need to be reprocessed exists only in memory (the measurements
        # have been written, clobbering the old ones), so if there was an exception we would lose this
        # information, leaving things in an inconsistent state (measurements, merged measurements and
        # forced photometry old). To attempt to preserve this status, we touch a file (dataset named
        # "deepCoadd_multibandReprocessing") --- if this file exists, we need to re-run the measurements,
        # merge and forced photometry.
        #
        # This is, hopefully, a temporary workaround until we can improve the
        # deblender.
        #
        # Fix the patch order explicitly so each result can be paired with
        # the patch it belongs to (assumes pool.map returns results in input
        # order, as the pairing here always did).
        patchIdList = list(patches)
        # If this map raises there is nothing to persist yet, so the
        # exception is simply allowed to propagate.
        reprocessed = pool.map(self.runDeblendMerged, [patches[patchId] for patchId in patchIdList])

        if self.config.reprocessing:
            # One deblender result per patch, so pair results with patch
            # identifiers rather than with the per-filter data identifiers.
            patchReprocessing = {}
            for patchId, reprocess in zip(patchIdList, reprocessed):
                patchReprocessing[patchId] = patchReprocessing.get(patchId, False) or reprocess
            # Persist the determination, to make error recovery easier
            reprocessDataset = self.config.coaddName + "Coadd_multibandReprocessing"
            for patchId in patchReprocessing:
                dataId = dict(tract=tract, patch=patchId)
                if patchReprocessing[patchId]:
                    filename = butler.get(
                        reprocessDataset + "_filename", dataId)[0]
                    with open(filename, 'a'):  # Touch file
                        pass
                elif butler.datasetExists(reprocessDataset, dataId):
                    # We must have failed at some point while reprocessing
                    # and we're starting over
                    patchReprocessing[patchId] = True

            # Only process patches that have been identified as needing it
            measureIdList = [dataId1 for dataId1 in dataIdList if patchReprocessing[dataId1["patch"]]]
            mergeIdLists = [idList for patchId, idList in patches.items() if patchReprocessing[patchId]]
        else:
            measureIdList = list(dataIdList)
            mergeIdLists = list(patches.values())

        pool.map(self.runMeasurements, measureIdList)
        pool.map(self.runMergeMeasurements, mergeIdLists)
        pool.map(self.runForcedPhot, measureIdList)

        # Remove persisted reprocessing determination
        if self.config.reprocessing:
            for patchId in patchReprocessing:
                if not patchReprocessing[patchId]:
                    continue
                dataId = dict(tract=tract, patch=patchId)
                filename = butler.get(
                    reprocessDataset + "_filename", dataId)[0]
                os.unlink(filename)

    def runDetection(self, cache, patchRef):
        """! Run detection on a patch

        Only slave nodes execute this method.

        @param cache: Pool cache, containing butler
        @param patchRef: Patch on which to do detection
        """
        with self.logOperation("do detections on {}".format(patchRef.dataId)):
            idFactory = self.detectCoaddSources.makeIdFactory(patchRef)
            coadd = patchRef.get(self.coaddType + "Coadd", immediate=True)
            expId = int(patchRef.get(self.config.coaddName + "CoaddId"))
            self.detectCoaddSources.emptyMetadata()
            detResults = self.detectCoaddSources.run(coadd, idFactory, expId=expId)
            self.detectCoaddSources.write(detResults, patchRef)
            self.detectCoaddSources.writeMetadata(patchRef)

    def runMergeDetections(self, cache, dataIdList):
        """!Run detection merging on a patch

        Only slave nodes execute this method.

        @param cache: Pool cache, containing butler
        @param dataIdList: List of data identifiers for the patch in different filters
        """
        with self.logOperation("merge detections from %s" % (dataIdList,)):
            dataRefList = [getDataRef(cache.butler, dataId, self.coaddType + "Coadd_calexp") for
                           dataId in dataIdList]
            if ("mergeCoaddDetections" in self.reuse and
                    dataRefList[0].datasetExists(self.config.coaddName + "Coadd_mergeDet", write=True)):
                self.log.info("Skipping mergeCoaddDetections for %s; output already exists." %
                              dataRefList[0].dataId)
                return
            self.mergeCoaddDetections.runDataRef(dataRefList)

    def runDeblendMerged(self, cache, dataIdList):
        """Run the deblender on a list of dataId's

        Only slave nodes execute this method.

        Parameters
        ----------
        cache: Pool cache
            Pool cache with butler.
        dataIdList: list
            Data identifier for patch in each band.

        Returns
        -------
        result: bool
            whether the patch requires reprocessing.
        """
        with self.logOperation("deblending %s" % (dataIdList,)):
            dataRefList = [getDataRef(cache.butler, dataId, self.coaddType + "Coadd_calexp") for
                           dataId in dataIdList]
            reprocessing = False  # Does this patch require reprocessing?
            if ("deblendCoaddSources" in self.reuse and
                    all(dataRef.datasetExists(self.config.coaddName + "Coadd_" + self.measurementInput,
                                              write=True) for dataRef in dataRefList)):
                if not self.config.reprocessing:
                    self.log.info("Skipping deblendCoaddSources for %s; output already exists" %
                                  (dataIdList,))
                    return False

                # Footprints are the same every band, therefore we can check just one
                catalog = dataRefList[0].get(self.config.coaddName + "Coadd_" + self.measurementInput)
                bigFlag = catalog["deblend_parentTooBig"]
                # Footprints marked too large by the previous deblender run
                numOldBig = bigFlag.sum()
                if numOldBig == 0:
                    self.log.info("No large footprints in %s" % (dataRefList[0].dataId,))
                    return False

                # This if-statement can be removed after DM-15662
                if self.config.deblendCoaddSources.simultaneous:
                    deblender = self.deblendCoaddSources.multiBandDeblend
                else:
                    deblender = self.deblendCoaddSources.singleBandDeblend

                # isLargeFootprint() can potentially return False for a source that is marked
                # too big in the catalog, because of "new"/different deblender configs.
                # numNewBig is the number of footprints that *will* be too big if reprocessed
                numNewBig = sum(deblender.isLargeFootprint(src.getFootprint()) for
                                src in catalog[bigFlag])
                if numNewBig == numOldBig:
                    self.log.info("All %d formerly large footprints continue to be large in %s" %
                                  (numOldBig, dataRefList[0].dataId,))
                    return False
                self.log.info("Found %d large footprints to be reprocessed in %s" %
                              (numOldBig - numNewBig, [dataRef.dataId for dataRef in dataRefList]))
                reprocessing = True

            self.deblendCoaddSources.runDataRef(dataRefList)
            return reprocessing

    def runMeasurements(self, cache, dataId):
        """Run measurement on a patch for a single filter

        Only slave nodes execute this method.

        Parameters
        ----------
        cache: Pool cache
            Pool cache, with butler
        dataId: data identifier
            Data identifier for patch
        """
        with self.logOperation("measurements on %s" % (dataId,)):
            dataRef = getDataRef(cache.butler, dataId, self.coaddType + "Coadd_calexp")
            # Existing outputs are never reused while reprocessing: the
            # point of reprocessing is to replace them.
            if ("measureCoaddSources" in self.reuse and
                    not self.config.reprocessing and
                    dataRef.datasetExists(self.config.coaddName + "Coadd_meas", write=True)):
                self.log.info("Skipping measureCoaddSources for %s; output already exists" % dataId)
                return
            self.measureCoaddSources.runDataRef(dataRef)

    def runMergeMeasurements(self, cache, dataIdList):
        """!Run measurement merging on a patch

        Only slave nodes execute this method.

        @param cache: Pool cache, containing butler
        @param dataIdList: List of data identifiers for the patch in different filters
        """
        with self.logOperation("merge measurements from %s" % (dataIdList,)):
            dataRefList = [getDataRef(cache.butler, dataId, self.coaddType + "Coadd_calexp") for
                           dataId in dataIdList]
            if ("mergeCoaddMeasurements" in self.reuse and
                    not self.config.reprocessing and
                    dataRefList[0].datasetExists(self.config.coaddName + "Coadd_ref", write=True)):
                self.log.info("Skipping mergeCoaddMeasurements for %s; output already exists" %
                              dataRefList[0].dataId)
                return
            self.mergeCoaddMeasurements.runDataRef(dataRefList)

    def runForcedPhot(self, cache, dataId):
        """!Run forced photometry on a patch for a single filter

        Only slave nodes execute this method.

        @param cache: Pool cache, with butler
        @param dataId: Data identifier for patch
        """
        with self.logOperation("forced photometry on %s" % (dataId,)):
            dataRef = getDataRef(cache.butler, dataId,
                                 self.coaddType + "Coadd_calexp")
            if ("forcedPhotCoadd" in self.reuse and
                    not self.config.reprocessing and
                    dataRef.datasetExists(self.config.coaddName + "Coadd_forced_src", write=True)):
                self.log.info("Skipping forcedPhotCoadd for %s; output already exists" % dataId)
                return
            self.forcedPhotCoadd.runDataRef(dataRef)

    def writeMetadata(self, dataRef):
        """We don't collect any metadata, so skip"""
        pass
Defines the fields and offsets for a table.
Definition: Schema.h:50
def makeSubtask(self, name, keyArgs)
Definition: task.py:275
def emptyMetadata(self)
Definition: task.py:153
def unpickle(factory, args, kwargs)
def runDataRef(self, patchRefList)
Run multiband processing on coadds.
def __init__(self, butler=None, schema=None, refObjLoader=None, reuse=tuple(), kwargs)
std::shared_ptr< FrameSet > append(FrameSet const &first, FrameSet const &second)
Construct a FrameSet that performs two transformations in series.
Definition: functional.cc:33
def getDataRef(butler, dataId, datasetType="raw")
Definition: utils.py:17
bool all(CoordinateExpr< N > const &expr) noexcept
Return true if all elements are true.
def runForcedPhot(self, cache, dataId)
Run forced photometry on a patch for a single filter.
def batchWallTime(cls, time, parsedCmd, numCpus)
Return walltime request for batch job.
def format(config, name=None, writeSourceLine=True, prefix="", verbose=False)
Definition: history.py:168
def __init__(self, TaskClass, parsedCmd, doReturnResults=False)
def logOperation(self, operation, catch=False, trace=True)
Provide a context manager for logging an operation.
Definition: parallel.py:497
def runDetection(self, cache, patchRef)
Run detection on a patch.
def runMergeDetections(self, cache, dataIdList)
Run detection merging on a patch.
def runMergeMeasurements(self, cache, dataIdList)
Run measurement merging on a patch.