Loading [MathJax]/extensions/tex2jax.js
LSST Applications g0fba68d861+05816baf74,g1ec0fe41b4+f536777771,g1fd858c14a+a9301854fb,g35bb328faa+fcb1d3bbc8,g4af146b050+a5c07d5b1d,g4d2262a081+6e5fcc2a4e,g53246c7159+fcb1d3bbc8,g56a49b3a55+9c12191793,g5a012ec0e7+3632fc3ff3,g60b5630c4e+ded28b650d,g67b6fd64d1+ed4b5058f4,g78460c75b0+2f9a1b4bcd,g786e29fd12+cf7ec2a62a,g8352419a5c+fcb1d3bbc8,g87b7deb4dc+7b42cf88bf,g8852436030+e5453db6e6,g89139ef638+ed4b5058f4,g8e3bb8577d+d38d73bdbd,g9125e01d80+fcb1d3bbc8,g94187f82dc+ded28b650d,g989de1cb63+ed4b5058f4,g9d31334357+ded28b650d,g9f33ca652e+50a8019d8c,gabe3b4be73+1e0a283bba,gabf8522325+fa80ff7197,gb1101e3267+d9fb1f8026,gb58c049af0+f03b321e39,gb665e3612d+2a0c9e9e84,gb89ab40317+ed4b5058f4,gcf25f946ba+e5453db6e6,gd6cbbdb0b4+bb83cc51f8,gdd1046aedd+ded28b650d,gde0f65d7ad+941d412827,ge278dab8ac+d65b3c2b70,ge410e46f29+ed4b5058f4,gf23fb2af72+b7cae620c0,gf5e32f922b+fcb1d3bbc8,gf67bdafdda+ed4b5058f4,w.2025.16
LSST Data Management Base Package
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Modules Pages
fit_coadd_multiband.py
Go to the documentation of this file.
1# This file is part of pipe_tasks.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21
22__all__ = [
23 "CoaddMultibandFitConfig", "CoaddMultibandFitConnections", "CoaddMultibandFitSubConfig",
24 "CoaddMultibandFitSubTask", "CoaddMultibandFitTask",
25]
26
27from .fit_multiband import CatalogExposure, CatalogExposureConfig
28
29import lsst.afw.table as afwTable
30from lsst.meas.base import SkyMapIdGeneratorConfig
31from lsst.meas.extensions.scarlet.io import updateCatalogFootprints
32import lsst.pex.config as pexConfig
33import lsst.pipe.base as pipeBase
34import lsst.pipe.base.connectionTypes as cT
35
36import astropy.table
37from abc import ABC, abstractmethod
38from pydantic import Field
39from pydantic.dataclasses import dataclass
40from typing import Iterable
41
# Default values for the ``{name_coadd}``, ``{name_method}`` and
# ``{name_table}`` placeholders used in the connection dataset type names
# below; passed to the connections class as ``defaultTemplates``.
CoaddMultibandFitBaseTemplates = {
    "name_coadd": "deep",
    "name_method": "multiprofit",
    "name_table": "objects",
}
47
48
49@dataclass(frozen=True, kw_only=True, config=CatalogExposureConfig)
51 table_psf_fits: astropy.table.Table = Field(title="A table of PSF fit parameters for each source")
52
def get_catalog(self):
    """Return the catalog stored on this instance (``self.catalog``)."""
    return self.catalog
55
56
58 pipeBase.PipelineTaskConnections,
59 dimensions=("tract", "patch", "skymap"),
60 defaultTemplates=CoaddMultibandFitBaseTemplates,
61):
# Input connections. Dataset type names are built from the
# ``name_coadd``/``name_method`` templates. The per-band inputs
# (``multiple=True``) carry the extra "band" dimension.
cat_ref = cT.Input(
    doc="Reference multiband source catalog",
    name="{name_coadd}Coadd_ref",
    storageClass="SourceCatalog",
    dimensions=("tract", "patch", "skymap"),
)
cats_meas = cT.Input(
    doc="Deblended single-band source catalogs",
    name="{name_coadd}Coadd_meas",
    storageClass="SourceCatalog",
    dimensions=("tract", "patch", "band", "skymap"),
    multiple=True,
)
coadds = cT.Input(
    doc="Exposures on which to run fits",
    name="{name_coadd}Coadd_calexp",
    storageClass="ExposureF",
    dimensions=("tract", "patch", "band", "skymap"),
    multiple=True,
)
# This connection is deleted in ``__init__`` when
# ``config.drop_psf_connection`` is set.
models_psf = cT.Input(
    doc="Input PSF model parameter catalog",
    # Consider allowing independent psf fit method
    name="{name_coadd}Coadd_psfs_{name_method}",
    storageClass="ArrowAstropy",
    dimensions=("tract", "patch", "band", "skymap"),
    multiple=True,
)
# Uses the same ``cT`` alias as the other connections
# (``cT`` is ``lsst.pipe.base.connectionTypes``).
models_scarlet = cT.Input(
    doc="Multiband scarlet models produced by the deblender",
    name="{name_coadd}Coadd_scarletModelData",
    storageClass="ScarletModelData",
    dimensions=("tract", "patch", "skymap"),
)
96
def adjustQuantum(self, inputs, outputs, label, data_id):
    """Validates the `lsst.daf.butler.DatasetRef` bands against the
    subtask's list of bands to fit and drops unnecessary bands.

    Parameters
    ----------
    inputs : `dict`
        Dictionary whose keys are an input (regular or prerequisite)
        connection name and whose values are a tuple of the connection
        instance and a collection of associated `DatasetRef` objects.
        The exact type of the nested collections is unspecified; it can be
        assumed to be multi-pass iterable and support `len` and ``in``, but
        it should not be mutated in place. In contrast, the outer
        dictionaries are guaranteed to be temporary copies that are true
        `dict` instances, and hence may be modified and even returned; this
        is especially useful for delegating to `super` (see notes below).
    outputs : `Mapping`
        Mapping of output datasets, with the same structure as ``inputs``.
    label : `str`
        Label for this task in the pipeline (should be used in all
        diagnostic messages).
    data_id : `lsst.daf.butler.DataCoordinate`
        Data ID for this quantum in the pipeline (should be used in all
        diagnostic messages).

    Returns
    -------
    adjusted_inputs : `Mapping`
        Mapping of the same form as ``inputs`` with updated containers of
        input `DatasetRef` objects. All inputs involving the 'band'
        dimension are adjusted to put them in consistent order and remove
        unneeded bands.
    adjusted_outputs : `Mapping`
        Mapping of updated output datasets; always empty for this task.

    Raises
    ------
    lsst.pipe.base.NoWorkFound
        Raised if there are not enough of the right bands to run the task
        on this quantum.
    RuntimeError
        Raised if missing bands are allowed and two band-dependent
        connections do not provide the same set of bands.
    """
    # Check which bands are going to be fit
    bands_fit, bands_read_only = self.config.get_band_sets()
    bands_needed = bands_fit + [band for band in bands_read_only if band not in bands_fit]
    bands_needed_set = set(bands_needed)

    adjusted_inputs = {}
    bands_found, connection_first = None, None
    for connection_name, (connection, dataset_refs) in inputs.items():
        # Datasets without bands in their dimensions should be fine
        if 'band' not in connection.dimensions:
            continue

        datasets_by_band = {dref.dataId['band']: dref for dref in dataset_refs}
        bands_set = set(datasets_by_band.keys())
        if self.config.allow_missing_bands:
            # Use the first dataset found as the reference since all
            # dataset types with band should have the same bands
            # This will only break if one of the calexp/meas datasets
            # is missing from a given band, which would surely be an
            # upstream problem anyway
            if bands_found is None:
                bands_found, connection_first = bands_set, connection_name
                if len(bands_found) == 0:
                    raise pipeBase.NoWorkFound(
                        f'DatasetRefs={dataset_refs} for {connection_name=} is empty'
                    )
                elif not set(bands_read_only).issubset(bands_set):
                    raise pipeBase.NoWorkFound(
                        f'DatasetRefs={dataset_refs} has {bands_set=} which is missing at least one'
                        f' of {bands_read_only=}'
                    )
                # Put the bands to fit first, then any other bands
                # needed for initialization/priors only last.
                # Read-only bands are all present at this point (checked
                # just above), so they are filtered against bands_fit to
                # avoid duplicates. Filtering them against bands_found
                # would always yield an empty list and silently drop them.
                bands_needed = [band for band in bands_fit if band in bands_found] + [
                    band for band in bands_read_only if band not in bands_fit
                ]
            elif bands_found != bands_set:
                raise RuntimeError(
                    f'DatasetRefs={dataset_refs} with {connection_name=} has {bands_set=} !='
                    f' {bands_found=} from {connection_first=}'
                )
        # All configured bands are treated as necessary
        elif not bands_needed_set.issubset(bands_set):
            raise pipeBase.NoWorkFound(
                f'DatasetRefs={dataset_refs} have data with bands in the'
                f' set={set(datasets_by_band.keys())},'
                f' which is not a superset of the required bands={bands_needed} defined by'
                f' {self.config.__class__}.fit_coadd_multiband='
                f'{self.config.fit_coadd_multiband._value.__class__}\'s attributes'
                f' bands_fit={bands_fit} and bands_read_only()={bands_read_only}.'
                f' Add the required bands={set(bands_needed).difference(datasets_by_band.keys())}.'
            )
        # Adjust all datasets with band dimensions to include just
        # the needed bands, in consistent order.
        adjusted_inputs[connection_name] = (
            connection,
            [datasets_by_band[band] for band in bands_needed]
        )

    # Delegate to super for more checks. Its return value is deliberately
    # not used; only the adjustments computed here are returned.
    inputs.update(adjusted_inputs)
    super().adjustQuantum(inputs, outputs, label, data_id)
    return adjusted_inputs, {}
199
def __init__(self, *, config=None):
    """Remove the ``models_psf`` connection when the config requests it.

    Parameters
    ----------
    config
        Task configuration; its ``drop_psf_connection`` flag decides
        whether the ``models_psf`` input is deleted from this instance.
    """
    # NOTE(review): ``config`` defaults to None in the signature but is
    # dereferenced unconditionally - presumably the framework always
    # supplies it; confirm.
    if config.drop_psf_connection:
        del self.models_psf
203
204
# Single per-patch output table; its dataset type name combines the
# ``name_coadd``, ``name_table`` and ``name_method`` templates.
cat_output = cT.Output(
    doc="Output source model fit parameter catalog",
    name="{name_coadd}Coadd_{name_table}_{name_method}",
    storageClass="ArrowTable",
    dimensions=("tract", "patch", "skymap"),
)
212
213
class CoaddMultibandFitSubConfig(pexConfig.Config):
    """Configuration for implementing fitter subtasks.
    """

    # Must be non-empty and free of duplicates (enforced by ``listCheck``).
    bands_fit = pexConfig.ListField[str](
        default=[],
        doc="list of bandpass filters to fit",
        listCheck=lambda x: (len(x) > 0) and (len(set(x)) == len(x)),
    )

    # NOTE(review): ``abstractmethod`` is only enforced at instantiation
    # when the class's metaclass derives from ``abc.ABCMeta``; confirm
    # whether that holds for ``pexConfig.Config``. This base implementation
    # has no body besides the docstring and therefore returns None.
    @abstractmethod
    def bands_read_only(self) -> set:
        """Return the set of bands that the Task needs to read (e.g. for
        defining priors) but not necessarily fit.

        Returns
        -------
        The set of such bands.
        """
233
234
class CoaddMultibandFitSubTask(pipeBase.Task, ABC):
    """Subtask interface for multiband fitting of deblended sources.

    Parameters
    ----------
    **kwargs
        Additional arguments to be passed to the `lsst.pipe.base.Task`
        constructor.
    """
    ConfigClass = CoaddMultibandFitSubConfig

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @abstractmethod
    def run(
        self, catexps: Iterable[CatalogExposureInputs], cat_ref: afwTable.SourceCatalog
    ) -> pipeBase.Struct:
        """Fit models to deblended sources from multi-band inputs.

        Parameters
        ----------
        catexps : `typing.Iterable` [`CatalogExposureInputs`]
            Catalog-exposure pairs with metadata, each in a given band.
        cat_ref : `lsst.afw.table.SourceCatalog`
            A reference source catalog to fit.

        Returns
        -------
        retStruct : `lsst.pipe.base.Struct`
            A struct with a cat_output attribute containing the output
            measurement catalog.

        Notes
        -----
        Subclasses may have further requirements on the input parameters,
        including:
            - Passing only one catexp per band;
            - Catalogs containing HeavyFootprints with deblended images;
            - Fitting only a subset of the sources.
        If any requirements are not met, the subtask should fail as soon as
        possible.

        NOTE(review): `CoaddMultibandFitTask.run` in this file calls the
        subtask with ``catalog_multi=`` rather than ``cat_ref=`` and reads
        ``.output`` rather than ``.cat_output`` from the result. Confirm
        which contract implementations are expected to follow.
        """
278
279
281 pipeBase.PipelineTaskConfig,
282 pipelineConnections=CoaddMultibandFitInputConnections,
283):
284 """Base class for multiband fitting."""
285
# When True, ``adjustQuantum`` accepts quanta that lack some of the bands
# to fit, and ``build_catexps`` substitutes placeholder entries for them.
allow_missing_bands = pexConfig.Field[bool](
    doc="Whether to still fit even if some bands are missing",
    default=True,
)
# When True, the connections class deletes its ``models_psf`` input.
drop_psf_connection = pexConfig.Field[bool](
    doc="Whether to drop the PSF model connection, e.g. because PSF parameters are in the input catalog",
    default=False,
)
# The default target is the abstract subtask interface.
fit_coadd_multiband = pexConfig.ConfigurableField(
    target=CoaddMultibandFitSubTask,
    doc="Task to fit sources using multiple bands",
)
idGenerator = SkyMapIdGeneratorConfig.make_field()
299
def get_band_sets(self):
    """Get the bands required by the fit_coadd_multiband subtask.

    Returns
    -------
    bands_fit : `list`
        The bands that the subtask will fit, with duplicates removed and
        the configured order preserved.
    bands_read_only : `list`
        The bands that the subtask will only read data
        (measurement catalog and exposure) for, with duplicates removed
        and iteration order preserved.

    Raises
    ------
    RuntimeError
        Raised if the subtask config has no ``bands_fit`` attribute.

    Notes
    -----
    Lists (not sets) are returned on purpose: callers in this file
    concatenate them and rely on their ordering.
    """
    try:
        bands_fit = self.fit_coadd_multiband.bands_fit
    except AttributeError:
        raise RuntimeError(f'{__class__}.fit_coadd_multiband must have bands_fit attribute') from None
    bands_read_only = self.fit_coadd_multiband.bands_read_only()
    # dict.fromkeys de-duplicates while keeping first-seen order.
    return tuple(list(dict.fromkeys(bands)) for bands in (bands_fit, bands_read_only))
317
318
320 CoaddMultibandFitBaseConfig,
321 pipelineConnections=CoaddMultibandFitConnections,
322):
323 """Configuration for a CoaddMultibandFitTask."""
324
325
327 """Base class for tasks that fit or rebuild multiband models.
328
329 This class only implements data reconstruction.
330 """
331
def build_catexps(self, butlerQC, inputRefs, inputs) -> list[CatalogExposureInputs]:
    """Assemble one `CatalogExposureInputs` per band to fit.

    Parameters
    ----------
    butlerQC
        Quantum context; only ``butlerQC.quantum.dataId`` is read here, to
        derive the tract/patch catalog id.
    inputRefs
        Input dataset references; the ``cats_meas``, ``coadds`` and (if
        loaded) ``models_psf`` attributes are read.
    inputs : `dict`
        Loaded input objects keyed by connection name. Must contain
        ``cats_meas``, ``coadds`` and ``models_scarlet``; ``models_psf``
        is optional.

    Returns
    -------
    catexps : `list` [`CatalogExposureInputs`]
        One entry per band returned first by ``config.get_band_sets()``,
        in that order. Bands with no loaded data get a placeholder entry
        with an empty catalog, ``exposure=None`` and an empty PSF table.

    Raises
    ------
    RuntimeError
        Raised if no data IDs were found in the inputs.
    """
    id_tp = self.config.idGenerator.apply(butlerQC.quantum.dataId).catalog_id
    # This is a roundabout way of ensuring all inputs get sorted and matched
    keys = ["cats_meas", "coadds"]
    has_psf_models = "models_psf" in inputs
    if has_psf_models:
        keys.append("models_psf")
    # One {dataId: object} mapping per connection, in the order of ``keys``
    inputs_sorted = tuple(
        {dRef.dataId: obj for dRef, obj in zip(getattr(inputRefs, key), inputs[key])}
        for key in keys
    )
    cats = inputs_sorted[0]
    exps = inputs_sorted[1]
    models_psf = inputs_sorted[2] if has_psf_models else None
    # Union, not intersection: a data ID present in only one of the two
    # mappings raises KeyError in the loop below.
    dataIds = set(cats).union(set(exps))
    models_scarlet = inputs["models_scarlet"]
    catexp_dict = {}
    # Stays None only if the loop never runs; afterwards it is reused as a
    # template data ID for placeholder entries.
    dataId = None
    for dataId in dataIds:
        catalog = cats[dataId]
        exposure = exps[dataId]
        updateCatalogFootprints(
            modelData=models_scarlet,
            catalog=catalog,
            band=dataId['band'],
            imageForRedistribution=exposure,
            removeScarletData=False,
            updateFluxColumns=False,
        )
        catexp_dict[dataId['band']] = CatalogExposureInputs(
            catalog=catalog,
            exposure=exposure,
            table_psf_fits=models_psf[dataId] if has_psf_models else astropy.table.Table(),
            dataId=dataId,
            id_tract_patch=id_tp,
        )
    # This shouldn't happen unless this is called with no inputs, but check anyway
    if dataId is None:
        raise RuntimeError(f"Did not build any catexps for {inputRefs=}")
    catexps = []
    for band in self.config.get_band_sets()[0]:
        if band in catexp_dict:
            catexp = catexp_dict[band]
        else:
            # Make a dummy catexp with a dataId if there's no data
            # This should be handled by any subtasks
            dataId_band = dataId.to_simple(minimal=True)
            dataId_band.dataId["band"] = band
            catexp = CatalogExposureInputs(
                catalog=afwTable.SourceCatalog(),
                exposure=None,
                table_psf_fits=astropy.table.Table(),
                dataId=dataId.from_simple(dataId_band, universe=dataId.universe),
                id_tract_patch=id_tp,
            )
        catexps.append(catexp)
    return catexps
390
391
class CoaddMultibandFitTask(CoaddMultibandFitBase, pipeBase.PipelineTask):
    """Fit deblended exposures in multiple bands simultaneously.

    It is generally assumed but not enforced (except optionally by the
    configurable `fit_coadd_multiband` subtask) that there is only one exposure
    per band, presumably a coadd.
    """

    ConfigClass = CoaddMultibandFitConfig
    _DefaultName = "coaddMultibandFit"

    def __init__(self, initInputs, **kwargs):
        super().__init__(initInputs=initInputs, **kwargs)
        self.makeSubtask("fit_coadd_multiband")

    def make_kwargs(self, butlerQC, inputRefs, inputs):
        """Make any kwargs needed to be passed to run.

        This method should be overloaded by subclasses that are configured to
        use a specific subtask that needs additional arguments derived from
        the inputs but do not otherwise need to overload runQuantum.

        Returns
        -------
        kwargs : `dict`
            Extra keyword arguments for `run`; empty in this base
            implementation.
        """
        return {}

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        """Load the inputs, build the per-band catexps, run the fit and
        store the outputs.

        Raises
        ------
        RuntimeError
            Raised if a catexp is None while missing bands are not allowed.
        """
        inputs = butlerQC.get(inputRefs)
        catexps = self.build_catexps(butlerQC, inputRefs, inputs)
        # Defensive check: the attribute in the message must be spelled
        # ``allow_missing_bands``, otherwise formatting the message raises
        # AttributeError and hides the intended RuntimeError.
        if not self.config.allow_missing_bands and any(catexp is None for catexp in catexps):
            raise RuntimeError(
                f"Got a None catexp with {self.config.allow_missing_bands=}; NoWorkFound should have been"
                f" raised earlier"
            )
        kwargs = self.make_kwargs(butlerQC, inputRefs, inputs)
        outputs = self.run(catexps=catexps, cat_ref=inputs['cat_ref'], **kwargs)
        butlerQC.put(outputs, outputRefs)

    def run(
        self,
        catexps: list[CatalogExposure],
        cat_ref: afwTable.SourceCatalog,
        **kwargs
    ) -> pipeBase.Struct:
        """Fit sources from a reference catalog using data from multiple
        exposures in the same region (patch).

        Parameters
        ----------
        catexps : `list` [`CatalogExposure`]
            A list of catalog-exposure pairs in a given band.
        cat_ref : `lsst.afw.table.SourceCatalog`
            A reference source catalog to fit.
        **kwargs
            Additional keyword arguments forwarded to the subtask's
            ``run`` method.

        Returns
        -------
        retStruct : `lsst.pipe.base.Struct`
            A struct with a cat_output attribute containing the output
            measurement catalog.

        Notes
        -----
        Subtasks may have further requirements; see `CoaddMultibandFitSubTask.run`.
        """
        # The subtask receives the reference catalog as ``catalog_multi``
        # and its result is read from the ``output`` attribute.
        cat_output = self.fit_coadd_multiband.run(catalog_multi=cat_ref, catexps=catexps, **kwargs).output
        retStruct = pipeBase.Struct(cat_output=cat_output)
        return retStruct
list[CatalogExposureInputs] build_catexps(self, butlerQC, inputRefs, inputs)
pipeBase.Struct run(self, Iterable[CatalogExposureInputs] catexps, afwTable.SourceCatalog cat_ref)
pipeBase.Struct run(self, list[CatalogExposure] catexps, afwTable.SourceCatalog cat_ref, **kwargs)