# Task name; also handed to makeMergeArgumentParser by _makeArgumentParser.
_DefaultName = "writeObjectTable"
# Configuration class associated with this task.
ConfigClass = WriteObjectTableConfig
# Runner class associated with this task.
RunnerClass = MergeSourcesRunner
# Names of the table datasets to be merged. readCatalog fetches each one as
# "<coaddName>Coadd_<dataset>"; the first entry is also passed to
# makeMergeArgumentParser.
inputDatasets = ('forced_src', 'meas', 'ref')
# Tag of the output dataset: write() stores the merged catalog under
# "<coaddName>Coadd_" + outputDataset.
outputDataset = 'obj'
def __init__(self, butler=None, schema=None, **kwargs):
    """Construct the task, forwarding only ``**kwargs`` to the parent.

    Parameters
    ----------
    butler : optional
        Accepted so the call signature is satisfied; not used here.
    schema : optional
        Accepted so the call signature is satisfied; not used here.
    **kwargs
        Forwarded unchanged to the parent class constructor.
    """
    # Re-using the stock CmdLineTask constructor would be nicer, but that
    # needs a dedicated task runner of its own, which costs far more
    # specialised code than this small override does.
    super().__init__(**kwargs)
def runDataRef(self, patchRefList):
    """Merge the per-band coadd catalogs of one patch and persist the result.

    Parameters
    ----------
    patchRefList : sequence
        Data references, one per band, all for the same patch. The data ID
        of the first element supplies the ``tract`` and ``patch`` values.
    """
    # readCatalog returns a (band, {dataset: catalog}) pair for each reference.
    catalogs = {}
    for patchRef in patchRefList:
        band, catalogDict = self.readCatalog(patchRef)
        catalogs[band] = catalogDict

    firstRef = patchRefList[0]
    firstDataId = firstRef.dataId
    merged = self.run(catalogs, tract=firstDataId['tract'], patch=firstDataId['patch'])
    self.write(firstRef, ParquetTable(dataFrame=merged))
def runQuantum(self, butlerQC, inputRefs, outputRefs):
    """Fetch the quantum inputs, merge them with `run`, and store the result.

    Parameters
    ----------
    butlerQC
        Quantum-context butler used to ``get`` the inputs and ``put`` the
        output; its ``quantum.dataId`` supplies ``tract`` and ``patch``.
    inputRefs
        Input references; ``inputCatalogMeas`` and ``inputCatalogForcedSrc``
        are iterated in parallel with the loaded catalogs.
    outputRefs
        Output references handed to ``butlerQC.put``.
    """
    inputs = butlerQC.get(inputRefs)

    # Index the loaded catalogs by the band found in each reference's data ID.
    measByBand = {
        ref.dataId['band']: cat
        for ref, cat in zip(inputRefs.inputCatalogMeas, inputs['inputCatalogMeas'])
    }
    forcedByBand = {
        ref.dataId['band']: cat
        for ref, cat in zip(inputRefs.inputCatalogForcedSrc, inputs['inputCatalogForcedSrc'])
    }

    # The bands of the 'meas' catalogs drive the merge; the same reference
    # catalog is attached to every band.
    catalogs = {
        band: {'meas': measCat,
               'forced_src': forcedByBand[band],
               'ref': inputs['inputCatalogRef']}
        for band, measCat in measByBand.items()
    }

    quantumDataId = butlerQC.quantum.dataId
    merged = self.run(catalogs=catalogs, tract=quantumDataId['tract'], patch=quantumDataId['patch'])
    butlerQC.put(pipeBase.Struct(outputCatalog=merged), outputRefs)
@classmethod
def _makeArgumentParser(cls):
    """Build the argument parser via `makeMergeArgumentParser`.

    The parser is created from the task's default name and the first of
    its input dataset names.
    """
    firstDataset = cls.inputDatasets[0]
    return makeMergeArgumentParser(cls._DefaultName, firstDataset)
def readCatalog(self, patchRef):
    """Read every input catalog for the patch/band of one data reference.

    Parameters
    ----------
    patchRef
        Data reference providing ``get(name, immediate=True)`` and a
        ``dataId``.

    Returns
    -------
    band
        The ``bandLabel`` of the ``<coaddName>Coadd_filter`` dataset.
    catalogDict : `dict`
        Maps each name in ``self.inputDatasets`` to the catalog read from
        ``<coaddName>Coadd_<name>``.
    """
    prefix = self.config.coaddName + "Coadd_"
    band = patchRef.get(prefix + "filter", immediate=True).bandLabel

    catalogDict = {}
    for datasetName in self.inputDatasets:
        loaded = patchRef.get(prefix + datasetName, immediate=True)
        self.log.info("Read %d sources from %s for band %s: %s",
                      len(loaded), datasetName, band, patchRef.dataId)
        catalogDict[datasetName] = loaded
    return band, catalogDict
def run(self, catalogs, tract, patch):
    """Merge per-band, per-dataset catalogs into one wide DataFrame.

    Parameters
    ----------
    catalogs : `dict`
        Mapping of band to a dict that maps dataset name to a table
        offering ``asAstropy().to_pandas()`` and containing an ``id`` column.
    tract
        Value stored in the ``tractId`` column of every sub-table.
    patch
        Value stored in the ``patchId`` column of every sub-table.

    Returns
    -------
    catalog : `pandas.DataFrame`
        All sub-tables joined on their ``id`` index, with a three-level
        column index named ``('dataset', 'band', 'column')``.
    """
    levelNames = ('dataset', 'band', 'column')

    def toLabelledFrame(table, dataset, band):
        # afwTable -> pandas, indexed by the source id.
        frame = table.asAstropy().to_pandas().set_index('id', drop=True)
        # Alphabetical column order keeps the schema identical between patches;
        # the tract/patch columns are appended after the sorted ones.
        frame = frame.reindex(columns=sorted(frame.columns))
        frame['tractId'] = tract
        frame['patchId'] = patch
        # Promote the flat column names to (dataset, band, column) tuples.
        frame.columns = pd.MultiIndex.from_tuples(
            [(dataset, band, name) for name in frame.columns],
            names=levelNames)
        return frame

    frames = [toLabelledFrame(table, dataset, band)
              for band, tableDict in catalogs.items()
              for dataset, table in tableDict.items()]

    # Successive joins on the shared id index, left to right.
    return functools.reduce(lambda left, right: left.join(right), frames)
def write(self, patchRef, catalog):
    """Persist the merged catalog and log the data ID it was written for.

    Parameters
    ----------
    patchRef
        Data reference used to ``put`` the catalog; its ``dataId`` is
        copied (never modified) for the log message.
    catalog
        The merged catalog to store under
        ``<coaddName>Coadd_<outputDataset>``.
    """
    patchRef.put(catalog, self.config.coaddName + "Coadd_" + self.outputDataset)
    # The filter is not really part of the data ID of the dataset being
    # saved, so showing it in the log message would be confusing even if
    # the butler simply ignores it.
    mergeDataId = patchRef.dataId.copy()
    # Tolerate data IDs that carry no "filter" key: the catalog has already
    # been written, so a cosmetic log clean-up must not raise KeyError.
    mergeDataId.pop("filter", None)
    self.log.info("Wrote merged catalog: %s", mergeDataId)
def writeMetadata(self, dataRefList):
    """Do nothing: this task deliberately writes no metadata.

    Parameters
    ----------
    dataRefList
        Ignored.
    """
class WriteSourceTableConnections(
    pipeBase.PipelineTaskConnections,
    defaultTemplates={"catalogType": ""},
    dimensions=("instrument", "visit", "detector"),
):
    """Connections for converting a per-detector source catalog to a DataFrame.

    The ``catalogType`` template (empty by default) prefixes both dataset
    names, giving ``<catalogType>src`` for the input and
    ``<catalogType>source`` for the output.
    """

    # Input: the afw source catalog for one (instrument, visit, detector).
    catalog = connectionTypes.Input(
        name="{catalogType}src",
        storageClass="SourceCatalog",
        dimensions=("instrument", "visit", "detector"),
        doc="Input full-depth catalog of sources produced by CalibrateTask",
    )

    # Output: the same sources stored with the DataFrame storage class.
    outputCatalog = connectionTypes.Output(
        name="{catalogType}source",
        storageClass="DataFrame",
        dimensions=("instrument", "visit", "detector"),
        doc=("Catalog of sources, `src` in Parquet format. "
             "The 'id' column is replaced with an index; "
             "all other columns are unchanged."),
    )
class WriteSourceTableConfig(pipeBase.PipelineTaskConfig,
                             pipelineConnections=WriteSourceTableConnections):
    """Configuration for `WriteSourceTableTask`.

    Declares no fields of its own; it only binds the task to
    `WriteSourceTableConnections`.
    """
class WriteSourceTableTask(CmdLineTask, pipeBase.PipelineTask):
    """Convert a source catalog into a `ParquetTable` indexed by source id."""

    _DefaultName = "writeSourceTable"
    ConfigClass = WriteSourceTableConfig

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        """Load the inputs, convert them with `run`, and store a DataFrame.

        Parameters
        ----------
        butlerQC
            Quantum-context butler used to ``get`` inputs and ``put`` the
            output; its data ID is packed into ``ccdVisitId``.
        inputRefs
            Input references passed to ``butlerQC.get``.
        outputRefs
            Output references passed to ``butlerQC.put``.
        """
        inputs = butlerQC.get(inputRefs)
        # Packed (visit, detector) identifier, later stored as a column.
        inputs['ccdVisitId'] = butlerQC.quantum.dataId.pack("visit_detector")
        parquetTable = self.run(**inputs).table
        butlerQC.put(pipeBase.Struct(outputCatalog=parquetTable.toDataFrame()), outputRefs)

    def run(self, catalog, ccdVisitId=None, **kwargs):
        """Turn ``catalog`` into a `ParquetTable` with a ``ccdVisitId`` column.

        Parameters
        ----------
        catalog
            Table offering ``asAstropy().to_pandas()`` with an ``id`` column.
        ccdVisitId : optional
            Value written to the ``ccdVisitId`` column of every row.
        **kwargs
            Accepted and ignored, so callers can pass extra inputs.

        Returns
        -------
        result : `lsst.pipe.base.Struct`
            Struct whose ``table`` attribute is the `ParquetTable`.
        """
        self.log.info("Generating parquet table from src catalog ccdVisitId=%s", ccdVisitId)
        frame = catalog.asAstropy().to_pandas()
        # The source id becomes the index and is dropped from the columns.
        frame = frame.set_index('id', drop=True)
        frame['ccdVisitId'] = ccdVisitId
        return pipeBase.Struct(table=ParquetTable(dataFrame=frame))
class WriteRecalibratedSourceTableConnections(
    WriteSourceTableConnections,
    defaultTemplates={"catalogType": "",
                      "skyWcsName": "jointcal",
                      "photoCalibName": "fgcm"},
    dimensions=("instrument", "visit", "detector", "skymap"),
):
    """Connections for writing a source table with re-evaluated calibrations.

    Extends `WriteSourceTableConnections` with the exposure, the sky map and
    the external (global or per-tract) SkyWcs / PhotoCalib catalogs. The
    constructor drops the external-calibration inputs that the supplied
    configuration does not use.
    """

    skyMap = connectionTypes.Input(
        name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
        storageClass="SkyMap",
        dimensions=("skymap",),
        doc="skyMap needed to choose which tract-level calibrations to use when multiple available",
    )
    exposure = connectionTypes.Input(
        name="calexp",
        storageClass="ExposureF",
        dimensions=["instrument", "visit", "detector"],
        doc="Input exposure to perform photometry on.",
    )

    # --- External astrometric calibrations (tract-level and global) ---
    externalSkyWcsTractCatalog = connectionTypes.Input(
        name="{skyWcsName}SkyWcsCatalog",
        storageClass="ExposureCatalog",
        dimensions=["instrument", "visit", "tract"],
        multiple=True,
        doc=("Per-tract, per-visit wcs calibrations. These catalogs use the detector "
             "id for the catalog id, sorted on id for fast lookup."),
    )
    externalSkyWcsGlobalCatalog = connectionTypes.Input(
        name="{skyWcsName}SkyWcsCatalog",
        storageClass="ExposureCatalog",
        dimensions=["instrument", "visit"],
        doc=("Per-visit wcs calibrations computed globally (with no tract information). "
             "These catalogs use the detector id for the catalog id, sorted on id for "
             "fast lookup."),
    )

    # --- External photometric calibrations (tract-level and global) ---
    externalPhotoCalibTractCatalog = connectionTypes.Input(
        name="{photoCalibName}PhotoCalibCatalog",
        storageClass="ExposureCatalog",
        dimensions=["instrument", "visit", "tract"],
        multiple=True,
        doc=("Per-tract, per-visit photometric calibrations. These catalogs use the "
             "detector id for the catalog id, sorted on id for fast lookup."),
    )
    externalPhotoCalibGlobalCatalog = connectionTypes.Input(
        name="{photoCalibName}PhotoCalibCatalog",
        storageClass="ExposureCatalog",
        dimensions=["instrument", "visit"],
        doc=("Per-visit photometric calibrations computed globally (with no tract "
             "information). These catalogs use the detector id for the catalog id, "
             "sorted on id for fast lookup."),
    )

    def __init__(self, *, config=None):
        """Prune the external-calibration inputs that ``config`` does not need.

        An external catalog is only kept when both the matching
        ``doApplyExternal*`` and ``doReevaluate*`` flags are set; in that
        case the ``useGlobalExternal*`` flag chooses between the global and
        the per-tract variant and the other one is removed.
        """
        super().__init__(config=config)

        # Same connection boilerplate as every other task that chooses
        # between global and tract-level calibrations.
        wantExternalSkyWcs = config.doApplyExternalSkyWcs and config.doReevaluateSkyWcs
        if not wantExternalSkyWcs or config.useGlobalExternalSkyWcs:
            self.inputs.remove("externalSkyWcsTractCatalog")
        if not wantExternalSkyWcs or not config.useGlobalExternalSkyWcs:
            self.inputs.remove("externalSkyWcsGlobalCatalog")

        wantExternalPhotoCalib = config.doApplyExternalPhotoCalib and config.doReevaluatePhotoCalib
        if not wantExternalPhotoCalib or config.useGlobalExternalPhotoCalib:
            self.inputs.remove("externalPhotoCalibTractCatalog")
        if not wantExternalPhotoCalib or not config.useGlobalExternalPhotoCalib:
            self.inputs.remove("externalPhotoCalibGlobalCatalog")
class WriteRecalibratedSourceTableConfig(WriteSourceTableConfig,
                                         pipelineConnections=WriteRecalibratedSourceTableConnections):
    """Configuration for `WriteRecalibratedSourceTableTask`.

    The ``doReevaluate*`` flags control whether local calibration columns are
    (re)computed; the ``doApplyExternal*`` flags control whether an external
    calibration replaces the one on the exposure; the ``useGlobalExternal*``
    flags choose between global and per-tract external calibrations.
    """

    doReevaluatePhotoCalib = pexConfig.Field(
        dtype=bool,
        default=True,
        doc=("Add or replace local photoCalib columns from either the calexp.photoCalib or jointcal/FGCM")
    )
    doReevaluateSkyWcs = pexConfig.Field(
        dtype=bool,
        default=True,
        # Fixed the duplicated word ("or or") in the original doc text.
        doc=("Add or replace local WCS columns from either the calexp.wcs or jointcal")
    )
    # NOTE(review): the two docs below mention ``externalPhotoCalibName`` /
    # ``externalSkyWcsName`` fields, which are not defined in this class --
    # confirm whether the connection templates are what is meant.
    doApplyExternalPhotoCalib = pexConfig.Field(
        dtype=bool,
        default=True,
        doc=("Whether to apply external photometric calibration via an "
             "`lsst.afw.image.PhotoCalib` object. Uses the "
             "``externalPhotoCalibName`` field to determine which calibration "
             "to load."),
    )
    doApplyExternalSkyWcs = pexConfig.Field(
        dtype=bool,
        default=True,
        doc=("Whether to apply external astrometric calibration via an "
             "`lsst.afw.geom.SkyWcs` object. Uses ``externalSkyWcsName`` "
             "field to determine which calibration to load."),
    )
    useGlobalExternalPhotoCalib = pexConfig.Field(
        dtype=bool,
        default=True,
        doc=("When using doApplyExternalPhotoCalib, use 'global' calibrations "
             "that are not run per-tract. When False, use per-tract photometric "
             "calibration files.")
    )
    useGlobalExternalSkyWcs = pexConfig.Field(
        dtype=bool,
        default=False,
        doc=("When using doApplyExternalSkyWcs, use 'global' calibrations "
             "that are not run per-tract. When False, use per-tract wcs "
             "files.")
    )

    def validate(self):
        """Run the parent validation, then warn about ineffective flag pairs.

        Requesting an external calibration while the matching
        ``doReevaluate*`` flag is off is not an error, but a warning is
        logged because the external calibration will not be used.
        """
        super().validate()
        if self.doApplyExternalSkyWcs and not self.doReevaluateSkyWcs:
            # The adjacent literals are concatenated verbatim, so the first
            # one must end with ". " to keep the two sentences apart.
            log.warning("doApplyExternalSkyWcs=True but doReevaluateSkyWcs=False. "
                        "External SkyWcs will not be read or evaluated.")
        if self.doApplyExternalPhotoCalib and not self.doReevaluatePhotoCalib:
            log.warning("doApplyExternalPhotoCalib=True but doReevaluatePhotoCalib=False. "
                        "External PhotoCalib will not be read or evaluated.")
# Task derived from WriteSourceTableTask; its runQuantum can attach
# calibrations to the exposure and add calibration columns to the catalog
# before delegating to the inherited run().
class WriteRecalibratedSourceTableTask(WriteSourceTableTask):
# Task name.
_DefaultName = "writeRecalibratedSourceTable"
# Configuration class providing the doReevaluate*/doApplyExternal*/
# useGlobalExternal* flags read by this task.
ConfigClass = WriteRecalibratedSourceTableConfig
def runQuantum(self, butlerQC, inputRefs, outputRefs):
    """Load inputs, optionally recalibrate, then write the source DataFrame.

    When either ``doReevaluate*`` flag is set, calibration columns are added
    to the catalog with ``addCalibColumns``; beforehand, if either
    ``doApplyExternal*`` flag is set, the exposure is replaced by the result
    of `attachCalibs`.

    Parameters
    ----------
    butlerQC
        Quantum-context butler used to ``get`` inputs and ``put`` the
        output.
    inputRefs
        Input references; also forwarded to `attachCalibs`.
    outputRefs
        Output references passed to ``butlerQC.put``.
    """
    inputs = butlerQC.get(inputRefs)
    quantumDataId = butlerQC.quantum.dataId
    inputs['ccdVisitId'] = quantumDataId.pack("visit_detector")
    inputs['exposureIdInfo'] = ExposureIdInfo.fromDataId(quantumDataId, "visit_detector")

    config = self.config
    if config.doReevaluatePhotoCalib or config.doReevaluateSkyWcs:
        if config.doApplyExternalPhotoCalib or config.doApplyExternalSkyWcs:
            inputs['exposure'] = self.attachCalibs(inputRefs, **inputs)
        inputs['catalog'] = self.addCalibColumns(**inputs)

    # run() swallows the extra entries of `inputs` through its **kwargs.
    parquetTable = self.run(**inputs).table
    butlerQC.put(pipeBase.Struct(outputCatalog=parquetTable.toDataFrame()), outputRefs)
def attachCalibs(self, inputRefs, skyMap, exposure, externalSkyWcsGlobalCatalog=None,
                 externalSkyWcsTractCatalog=None, externalPhotoCalibGlobalCatalog=None,
                 externalPhotoCalibTractCatalog=None, **kwargs):
    """Choose the external calibration catalogs and apply them to ``exposure``.

    For each of SkyWcs and PhotoCalib the configuration decides whether no
    external catalog, the global catalog, or a tract-level catalog is used.
    With several tract-level catalogs, the one selected by
    `getClosestTract` is taken.

    Parameters
    ----------
    inputRefs
        Input references; ``external<Kind>TractCatalog`` is read from it
        only when a tract-level calibration is requested.
    skyMap
        Sky map forwarded to `getClosestTract`.
    exposure
        Exposure to calibrate; provides ``getBBox()`` and ``getWcs()``.
    externalSkyWcsGlobalCatalog, externalSkyWcsTractCatalog : optional
        Global catalog, and sequence of tract-level catalogs, for SkyWcs.
    externalPhotoCalibGlobalCatalog, externalPhotoCalibTractCatalog : optional
        Global catalog, and sequence of tract-level catalogs, for PhotoCalib.
    **kwargs
        Accepted and ignored, so the full input dict can be splatted in.

    Returns
    -------
    exposure
        Whatever ``prepareCalibratedExposure`` returns for the exposure and
        the two selected catalogs (either may be `None`).
    """
    def selectTractCatalog(calibName, tractCatalogs):
        # Shared tract-level selection for both calibration kinds;
        # calibName is 'SkyWcs' or 'PhotoCalib'.
        connectionName = "external" + calibName + "TractCatalog"
        tracts = [ref.dataId['tract'] for ref in getattr(inputRefs, connectionName)]
        if len(tracts) == 1:
            ind = 0
            self.log.info('Applying tract-level %s from tract %s', calibName, tracts[ind])
        else:
            ind = self.getClosestTract(tracts, skyMap,
                                       exposure.getBBox(), exposure.getWcs())
            self.log.info('Multiple overlapping %ss found (%s). '
                          'Applying closest to detector center: tract=%s',
                          connectionName, str(tracts), tracts[ind])
        return tractCatalogs[ind]

    if not self.config.doApplyExternalSkyWcs:
        # Do not modify the exposure's SkyWcs
        externalSkyWcsCatalog = None
    elif self.config.useGlobalExternalSkyWcs:
        # Use the global external SkyWcs
        externalSkyWcsCatalog = externalSkyWcsGlobalCatalog
        self.log.info('Applying global SkyWcs')
    else:
        # Use tract-level external SkyWcs from the closest overlapping tract
        externalSkyWcsCatalog = selectTractCatalog('SkyWcs', externalSkyWcsTractCatalog)

    if not self.config.doApplyExternalPhotoCalib:
        # Do not modify the exposure's PhotoCalib
        externalPhotoCalibCatalog = None
    elif self.config.useGlobalExternalPhotoCalib:
        # Use the global external PhotoCalib
        externalPhotoCalibCatalog = externalPhotoCalibGlobalCatalog
        self.log.info('Applying global PhotoCalib')
    else:
        # Use tract-level external PhotoCalib from the closest overlapping tract
        externalPhotoCalibCatalog = selectTractCatalog('PhotoCalib', externalPhotoCalibTractCatalog)

    return self.prepareCalibratedExposure(exposure, externalSkyWcsCatalog, externalPhotoCalibCatalog)
def getClosestTract(self, tracts, skyMap, bbox, wcs):
    """Return the index of the tract whose center is nearest the bbox center.

    Parameters
    ----------
    tracts : sequence
        Tract ids; each is used as a key into ``skyMap``.
    skyMap
        Mapping from tract id to an object offering ``getWcs()`` and
        ``getBBox()``.
    bbox
        Bounding box whose ``getCenter()`` is converted to sky coordinates.
    wcs
        WCS used to convert the bbox center via ``pixelToSky``.

    Returns
    -------
    index : `int`
        Position in ``tracts`` of the closest tract; ``0`` when only one
        tract is given (no geometry is evaluated in that case).
    """
    if len(tracts) == 1:
        return 0

    detectorCenter = wcs.pixelToSky(bbox.getCenter())

    def separationTo(tractId):
        # Separation between the detector center and this tract's center.
        tractInfo = skyMap[tractId]
        tractCenter = tractInfo.getWcs().pixelToSky(tractInfo.getBBox().getCenter())
        return detectorCenter.separation(tractCenter)

    separations = [separationTo(tractId) for tractId in tracts]
    return np.argmin(separations)
def prepareCalibratedExposure(self, exposure, externalSkyWcsCatalog=None, externalPhotoCalibCatalog=None):
Definition at line 588 of file postprocess.py.