LSST Applications g063fba187b+cac8b7c890,g0f08755f38+6aee506743,g1653933729+a8ce1bb630,g168dd56ebc+a8ce1bb630,g1a2382251a+b4475c5878,g1dcb35cd9c+8f9bc1652e,g20f6ffc8e0+6aee506743,g217e2c1bcf+73dee94bd0,g28da252d5a+1f19c529b9,g2bbee38e9b+3f2625acfc,g2bc492864f+3f2625acfc,g3156d2b45e+6e55a43351,g32e5bea42b+1bb94961c2,g347aa1857d+3f2625acfc,g35bb328faa+a8ce1bb630,g3a166c0a6a+3f2625acfc,g3e281a1b8c+c5dd892a6c,g3e8969e208+a8ce1bb630,g414038480c+5927e1bc1e,g41af890bb2+8a9e676b2a,g7af13505b9+809c143d88,g80478fca09+6ef8b1810f,g82479be7b0+f568feb641,g858d7b2824+6aee506743,g89c8672015+f4add4ffd5,g9125e01d80+a8ce1bb630,ga5288a1d22+2903d499ea,gb58c049af0+d64f4d3760,gc28159a63d+3f2625acfc,gcab2d0539d+b12535109e,gcf0d15dbbd+46a3f46ba9,gda6a2b7d83+46a3f46ba9,gdaeeff99f8+1711a396fd,ge79ae78c31+3f2625acfc,gef2f8181fd+0a71e47438,gf0baf85859+c1f95f4921,gfa517265be+6aee506743,gfa999e8aa5+17cd334064,w.2024.51
LSST Data Management Base Package
Loading...
Searching...
No Matches
diff_matched_tract_catalog.py
Go to the documentation of this file.
1# This file is part of pipe_tasks.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21
# Names exported by ``from ... import *``; several of these classes are
# deprecated and scheduled for removal after v28.
__all__ = [
    'DiffMatchedTractCatalogConfig',
    'DiffMatchedTractCatalogTask',
    'MatchedCatalogFluxesConfig',
    'MatchType',
    'MeasurementType',
    'SourceType',
    'Statistic',
    'Median',
    'SigmaIQR',
    'SigmaMAD',
    'Percentile',
]
27
28import lsst.afw.geom as afwGeom
30 ComparableCatalog, ConvertCatalogCoordinatesConfig,
31)
33import lsst.pex.config as pexConfig
34import lsst.pipe.base as pipeBase
35import lsst.pipe.base.connectionTypes as cT
36from lsst.skymap import BaseSkyMap
37
38from abc import ABCMeta, abstractmethod
39from astropy.stats import mad_std
40import astropy.table
41import astropy.units as u
42from dataclasses import dataclass
43from decimal import Decimal
44from deprecated.sphinx import deprecated
45from enum import Enum
46import numpy as np
47import pandas as pd
48from scipy.stats import iqr
49from smatch.matcher import sphdist
50from types import SimpleNamespace
51from typing import Sequence
52import warnings
53
54
def is_sequence_set(x: Sequence):
    """Check that a sequence contains no repeated elements.

    Parameters
    ----------
    x : `Sequence`
        The sequence to check. Its elements must be hashable.

    Returns
    -------
    is_set : `bool`
        True if every element of ``x`` is distinct.
    """
    n_items = len(x)
    n_unique = len(set(x))
    return n_items == n_unique
57
58
@deprecated(reason="This method is no longer being used and will be removed after v28.",
            version="v28.0", category=FutureWarning)
def is_percentile(x: str):
    """Check whether a decimal string lies in the closed range [0, 100].

    Parameters
    ----------
    x : `str`
        A decimal number as a string, e.g. ``'97.725'``.

    Returns
    -------
    valid : `bool`
        True if ``0 <= Decimal(x) <= 100``.
    """
    value = Decimal(x)
    return not (value < 0 or value > 100)
63
64
# Default values for the templated dataset type names used by the
# DiffMatchedTractCatalog connections.
DiffMatchedTractCatalogBaseTemplates = dict(
    name_input_cat_ref="truth_summary",
    name_input_cat_target="objectTable_tract",
    name_skymap=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
)
70
71
class DiffMatchedTractCatalogConnections(
    pipeBase.PipelineTaskConnections,
    dimensions=("tract", "skymap"),
    defaultTemplates=DiffMatchedTractCatalogBaseTemplates,
):
    """Connections for DiffMatchedTractCatalogTask.

    Inputs are the per-tract reference and target catalogs, the reference and
    target match catalogs (plus the target match catalog's column list), and
    the skymap. Outputs are the joined matched catalog and, optionally, a
    table of aggregated difference statistics.
    """
    cat_ref = cT.Input(
        doc="Reference object catalog to match from",
        name="{name_input_cat_ref}",
        storageClass="ArrowAstropy",
        dimensions=("tract", "skymap"),
        deferLoad=True,
    )
    cat_target = cT.Input(
        doc="Target object catalog to match",
        name="{name_input_cat_target}",
        storageClass="ArrowAstropy",
        dimensions=("tract", "skymap"),
        deferLoad=True,
    )
    skymap = cT.Input(
        doc="Input definition of geometry/bbox and projection/wcs for coadded exposures",
        name="{name_skymap}",
        storageClass="SkyMap",
        dimensions=("skymap",),
    )
    cat_match_ref = cT.Input(
        doc="Reference match catalog with indices of target matches",
        name="match_ref_{name_input_cat_ref}_{name_input_cat_target}",
        storageClass="ArrowAstropy",
        dimensions=("tract", "skymap"),
        deferLoad=True,
    )
    cat_match_target = cT.Input(
        doc="Target match catalog with indices of references matches",
        name="match_target_{name_input_cat_ref}_{name_input_cat_target}",
        storageClass="ArrowAstropy",
        dimensions=("tract", "skymap"),
        deferLoad=True,
    )
    columns_match_target = cT.Input(
        doc="Target match catalog columns",
        name="match_target_{name_input_cat_ref}_{name_input_cat_target}.columns",
        storageClass="ArrowColumnList",
        dimensions=("tract", "skymap"),
    )
    cat_matched = cT.Output(
        doc="Catalog with reference and target columns for joined sources",
        name="matched_{name_input_cat_ref}_{name_input_cat_target}",
        storageClass="ArrowAstropy",
        dimensions=("tract", "skymap"),
    )
    diff_matched = cT.Output(
        doc="Table with aggregated counts, difference and chi statistics",
        name="diff_matched_{name_input_cat_ref}_{name_input_cat_target}",
        storageClass="ArrowAstropy",
        dimensions=("tract", "skymap"),
    )

    def __init__(self, *, config=None):
        # An unsharded reference catalog is a single dataset with no
        # dimensions, so replace the input with a dimensionless copy.
        # (Equivalent to the former "!= 'tract'" then "== 'none'" nesting.)
        if config.refcat_sharding_type == "none":
            old = self.cat_ref
            del self.cat_ref
            self.cat_ref = cT.Input(
                doc=old.doc,
                name=old.name,
                storageClass=old.storageClass,
                dimensions=(),
                deferLoad=old.deferLoad,
            )
        # The statistics table is only produced when stats are requested and
        # at least one band of flux columns is configured.
        if not (config.compute_stats and len(config.columns_flux) > 0):
            del self.diff_matched
144
145
class MatchedCatalogFluxesConfig(pexConfig.Config):
    """Flux column names for a single band of the matched catalogs.

    One reference flux column is paired with one or more target flux columns
    (one per model) and their corresponding error columns.
    """
    column_ref_flux = pexConfig.Field(
        dtype=str,
        doc='Reference catalog flux column name',
    )
    columns_target_flux = pexConfig.ListField(
        dtype=str,
        listCheck=is_sequence_set,
        doc="List of target catalog flux column names",
    )
    columns_target_flux_err = pexConfig.ListField(
        dtype=str,
        listCheck=is_sequence_set,
        doc="List of target catalog flux error column names",
    )

    @property
    def columns_in_ref(self) -> list[str]:
        """Reference catalog columns required for this band."""
        # TODO: this should be an ordered set
        return [self.column_ref_flux]

    @property
    def columns_in_target(self) -> list[str]:
        """Target catalog columns required for this band, without repeats.

        Flux columns come first, followed by any error columns not already
        present, both in configuration order.
        """
        # TODO: this should also be an ordered set
        columns = list(self.columns_target_flux)
        for column_err in self.columns_target_flux_err:
            if column_err not in columns:
                columns.append(column_err)
        return columns
173
174
class DiffMatchedTractCatalogConfig(
    pipeBase.PipelineTaskConfig,
    pipelineConnections=DiffMatchedTractCatalogConnections,
):
    """Configuration for DiffMatchedTractCatalogTask."""
    column_matched_prefix_ref = pexConfig.Field[str](
        default='refcat_',
        doc='The prefix for matched columns copied from the reference catalog',
    )
    column_ref_extended = pexConfig.Field[str](
        default='is_pointsource',
        deprecated='This field is no longer being used and will be removed after v28.',
        doc='The boolean reference table column specifying if the target is extended',
    )
    column_ref_extended_inverted = pexConfig.Field[bool](
        default=True,
        deprecated='This field is no longer being used and will be removed after v28.',
        doc='Whether column_ref_extended specifies if the object is compact, not extended',
    )
    column_target_extended = pexConfig.Field[str](
        default='refExtendedness',
        deprecated='This field is no longer being used and will be removed after v28.',
        doc='The target table column estimating the extendedness of the object (0 <= x <= 1)',
    )
    compute_stats = pexConfig.Field[bool](
        default=False,
        deprecated='This field is no longer being used and will be removed after v28.',
        doc='Whether to compute matched difference statistics',
    )
    include_unmatched = pexConfig.Field[bool](
        default=False,
        doc='Whether to include unmatched rows in the matched table',
    )

    @property
    def columns_in_ref(self) -> list[str]:
        """Reference catalog columns to load, deduplicated in first-seen order."""
        columns_all = [self.coord_format.column_ref_coord1, self.coord_format.column_ref_coord2]
        if self.compute_stats:
            columns_all.append(self.column_ref_extended)
        for column_lists in (
            (
                self.columns_ref_copy,
            ),
            (x.columns_in_ref for x in self.columns_flux.values()),
        ):
            for column_list in column_lists:
                columns_all.extend(column_list)

        # dict preserves insertion order, so this deduplicates stably
        return list({column: None for column in columns_all}.keys())

    @property
    def columns_in_target(self) -> list[str]:
        """Target catalog columns to load, deduplicated in first-seen order.

        Includes every target column read by the task: coordinates, coordinate
        errors, selection flags, copied columns and per-band flux columns.
        """
        columns_all = [self.coord_format.column_target_coord1, self.coord_format.column_target_coord2]
        if self.compute_stats:
            columns_all.append(self.column_target_extended)
        if self.coord_format.coords_ref_to_convert is not None:
            columns_all.extend(col for col in self.coord_format.coords_ref_to_convert.values()
                               if col not in columns_all)
        for column_lists in (
            (
                self.columns_target_coord_err,
                self.columns_target_select_false,
                self.columns_target_select_true,
                self.columns_target_copy,
            ),
            (x.columns_in_target for x in self.columns_flux.values()),
        ):
            for column_list in column_lists:
                columns_all.extend(col for col in column_list if col not in columns_all)
        return columns_all

    columns_flux = pexConfig.ConfigDictField(
        doc="Configs for flux columns for each band",
        keytype=str,
        itemtype=MatchedCatalogFluxesConfig,
        default={},
    )
    columns_ref_mag_to_nJy = pexConfig.DictField[str, str](
        doc='Reference table AB mag columns to convert to nJy flux columns with new names',
        default={},
    )
    columns_ref_copy = pexConfig.ListField[str](
        doc='Reference table columns to copy into cat_matched',
        default=[],
        listCheck=is_sequence_set,
    )
    columns_target_coord_err = pexConfig.ListField[str](
        doc='Target table coordinate columns with standard errors (sigma)',
        listCheck=lambda x: (len(x) == 2) and (x[0] != x[1]),
    )
    columns_target_copy = pexConfig.ListField[str](
        doc='Target table columns to copy into cat_matched',
        default=('patch',),
        listCheck=is_sequence_set,
    )
    columns_target_mag_to_nJy = pexConfig.DictField[str, str](
        doc='Target table AB mag columns to convert to nJy flux columns with new names',
        default={},
    )
    columns_target_select_true = pexConfig.ListField[str](
        doc='Target table columns to require to be True for selecting sources',
        default=('detect_isPrimary',),
        listCheck=is_sequence_set,
    )
    columns_target_select_false = pexConfig.ListField[str](
        doc='Target table columns to require to be False for selecting sources',
        default=('merge_peak_sky',),
        listCheck=is_sequence_set,
    )
    coord_format = pexConfig.ConfigField[ConvertCatalogCoordinatesConfig](
        doc="Configuration for coordinate conversion",
    )
    extendedness_cut = pexConfig.Field[float](
        deprecated="This field is no longer being used and will be removed after v28.",
        default=0.5,
        doc='Minimum extendedness for a measured source to be considered extended',
    )
    mag_num_bins = pexConfig.Field[int](
        deprecated="This field is no longer being used and will be removed after v28.",
        doc='Number of magnitude bins',
        default=15,
    )
    mag_brightest_ref = pexConfig.Field[float](
        deprecated="This field is no longer being used and will be removed after v28.",
        doc='Brightest magnitude cutoff for binning',
        default=15,
    )
    mag_ceiling_target = pexConfig.Field[float](
        deprecated="This field is no longer being used and will be removed after v28.",
        doc='Ceiling (maximum/faint) magnitude for target sources',
        default=None,
        optional=True,
    )
    mag_faintest_ref = pexConfig.Field[float](
        deprecated="This field is no longer being used and will be removed after v28.",
        doc='Faintest magnitude cutoff for binning',
        default=30,
    )
    mag_zeropoint_ref = pexConfig.Field[float](
        deprecated="This field is no longer being used and will be removed after v28.",
        doc='Magnitude zeropoint for reference sources',
        default=31.4,
    )
    mag_zeropoint_target = pexConfig.Field[float](
        deprecated="This field is no longer being used and will be removed after v28.",
        doc='Magnitude zeropoint for target sources',
        default=31.4,
    )
    percentiles = pexConfig.ListField[str](
        deprecated="This field is no longer being used and will be removed after v28.",
        doc='Percentiles to compute for diff/chi values',
        # -2, -1, +1, +2 sigma percentiles for normal distribution
        default=('2.275', '15.866', '84.134', '97.725'),
        itemCheck=lambda x: 0 <= Decimal(x) <= 100,
        listCheck=is_sequence_set,
    )
    refcat_sharding_type = pexConfig.ChoiceField[str](
        doc="The type of sharding (spatial splitting) for the reference catalog",
        allowed={"tract": "Tract-based shards", "none": "No sharding at all"},
        default="tract",
    )

    def validate(self):
        """Validate the config.

        Raises
        ------
        ValueError
            Raised if a magnitude column to convert is not loaded, or if its
            new flux column name collides with a column being copied.
        """
        super().validate()

        errors = []

        # Check the reference and target mag->nJy conversions symmetrically.
        for columns_mag, name_columns_mag, columns_in, name_columns_copy in (
            (self.columns_ref_mag_to_nJy, "columns_ref_mag_to_nJy", self.columns_in_ref,
             "columns_ref_copy"),
            (self.columns_target_mag_to_nJy, "columns_target_mag_to_nJy", self.columns_in_target,
             "columns_target_copy"),
        ):
            columns_copy = getattr(self, name_columns_copy)
            for column_old, column_new in columns_mag.items():
                # The magnitude column must be loaded to be converted
                if column_old not in columns_in:
                    errors.append(
                        f"{column_old=} key in self.{name_columns_mag} not found in {columns_in=}; did you"
                        f" forget to add it to self.{name_columns_copy}={columns_copy}?"
                    )
                # The new flux column name must not shadow a copied column
                if column_new in columns_copy:
                    errors.append(
                        f"{column_new=} value found in self.{name_columns_copy}={columns_copy}"
                        f" this will cause a collision. Please choose a different name."
                    )
        if errors:
            raise ValueError("\n".join(errors))
359
360
361@deprecated(reason="This class is no longer being used and will be removed after v28.",
362 version="v28.0", category=FutureWarning)
363@dataclass(frozen=True)
365 doc: str
366 name: str
367
368
@deprecated(reason="This class is no longer being used and will be removed after v28.",
            version="v28.0", category=FutureWarning)
class MeasurementType(Enum):
    """Kinds of per-source comparison between measured and reference values."""

    # Raw difference between the measurement and the reference value
    DIFF = SimpleNamespace(doc="difference (measured - reference)", name="diff")
    # Difference divided by the measurement error
    CHI = SimpleNamespace(doc="scaled difference (measured - reference)/error", name="chi")
380
381
@deprecated(reason="This class is no longer being used and will be removed after v28.",
            version="v28.0", category=FutureWarning)
class Statistic(metaclass=ABCMeta):
    """Abstract interface for a summary statistic of a collection of values.

    Concrete statistics implement `doc`, `name_short` and `value`.
    """

    @abstractmethod
    def doc(self) -> str:
        """Return a human-readable description of this statistic."""
        raise NotImplementedError('Subclasses must implement this method')

    @abstractmethod
    def name_short(self) -> str:
        """Return a short identifier, suitable for use in a table column name."""
        raise NotImplementedError('Subclasses must implement this method')

    @abstractmethod
    def value(self, values):
        """Evaluate this statistic.

        Parameters
        ----------
        values : `Collection` [`float`]
            The values to summarize.

        Returns
        -------
        statistic : `float`
            The computed statistic.
        """
        raise NotImplementedError('Subclasses must implement this method')
412
413
@deprecated(reason="This class is no longer being used and will be removed after v28.",
            version="v28.0", category=FutureWarning)
class Median(Statistic):
    """The median of a set of values."""
    @classmethod
    def doc(cls) -> str:
        """Return a human-readable description."""
        return "Median"

    @classmethod
    def name_short(cls) -> str:
        """Return the column-name suffix for this statistic."""
        return "median"

    def value(self, values):
        """Return the median of ``values``."""
        return np.median(values)
428
429
@deprecated(reason="This class is no longer being used and will be removed after v28.",
            version="v28.0", category=FutureWarning)
class SigmaIQR(Statistic):
    """The re-scaled interquartile range (sigma equivalent)."""
    @classmethod
    def doc(cls) -> str:
        """Return a human-readable description."""
        return "Interquartile range divided by ~1.349 (sigma-equivalent)"

    @classmethod
    def name_short(cls) -> str:
        """Return the column-name suffix for this statistic."""
        return "sig_iqr"

    def value(self, values):
        """Return the IQR of ``values`` scaled to a normal-distribution sigma."""
        # scale='normal' divides by ~1.349, the IQR of a unit normal
        return iqr(values, scale='normal')
444
445
@deprecated(reason="This class is no longer being used and will be removed after v28.",
            version="v28.0", category=FutureWarning)
class SigmaMAD(Statistic):
    """The re-scaled median absolute deviation (sigma equivalent)."""
    @classmethod
    def doc(cls) -> str:
        """Return a human-readable description."""
        return "Median absolute deviation multiplied by ~1.4826 (sigma-equivalent)"

    @classmethod
    def name_short(cls) -> str:
        """Return the column-name suffix for this statistic."""
        return "sig_mad"

    def value(self, values):
        """Return the MAD of ``values`` scaled to a normal-distribution sigma."""
        return mad_std(values)
460
461
@deprecated(reason="This class is no longer being used and will be removed after v28.",
            version="v28.0", category=FutureWarning)
@dataclass(frozen=True)
class Percentile(Statistic):
    """An arbitrary percentile.

    Parameters
    ----------
    percentile : `float`
        A valid percentile (0 <= p <= 100).

    Raises
    ------
    ValueError
        Raised on construction if ``percentile`` is outside [0, 100] or NaN.
    """
    percentile: float

    def doc(self) -> str:
        """Return a human-readable description of this percentile."""
        return f"Percentile (p={self.percentile})"

    def name_short(self) -> str:
        """Return a column-name suffix from the fractional digits of p/100.

        For example, p=2.275 gives ``'pctl02275'``.
        """
        # NOTE(review): p=0 and p=100 both map to "pctl00000"; confirm this
        # collision is acceptable before using both in one stats dict.
        return f"pctl{f'{self.percentile/100:.5f}'[2:]}"

    def value(self, values):
        """Return the configured percentile of ``values``."""
        return np.percentile(values, self.percentile)

    def __post_init__(self):
        # Written as a negated conjunction so that NaN is also rejected
        if not ((self.percentile >= 0) and (self.percentile <= 100)):
            raise ValueError(f'percentile={self.percentile} not >=0 and <= 100')
487
488
@deprecated(reason="This method is no longer being used and will be removed after v28.",
            version="v28.0", category=FutureWarning)
def _get_stat_name(*args):
    """Join name components with underscores, e.g. ('g', 'diff') -> 'g_diff'."""
    separator = '_'
    return separator.join(args)
493
494
@deprecated(reason="This method is no longer being used and will be removed after v28.",
            version="v28.0", category=FutureWarning)
def _get_column_name(band, *args):
    """Prefix an underscore-joined statistic name with the band name."""
    stat_name = _get_stat_name(*args)
    return "{}_{}".format(band, stat_name)
499
500
@deprecated(reason="This method is no longer being used and will be removed after v28.",
            version="v28.0", category=FutureWarning)
def compute_stats(values_ref, values_target, errors_target, row, stats, suffixes, prefix, skip_diff=False):
    """Compute statistics on differences and store results in a row.

    Parameters
    ----------
    values_ref : `numpy.ndarray`, (N,)
        Reference values.
    values_target : `numpy.ndarray`, (N,)
        Measured values.
    errors_target : `numpy.ndarray`, (N,) or `None`
        Errors (standard deviations) on `values_target`. If None, no chi
        statistics are computed.
    row : `numpy.ndarray`, (1, C)
        A numpy array with pre-assigned column names.
    stats : `Dict` [`str`, `Statistic`]
        A dict of `Statistic` values to measure, keyed by their column suffix.
    suffixes : `Dict` [`MeasurementType`, `str`]
        A dict of measurement type column suffixes, keyed by the measurement type.
    prefix : `str`
        A prefix for all column names (e.g. band).
    skip_diff : `bool`
        Whether to skip computing statistics on differences. Note that
        differences will still be computed for chi statistics.

    Returns
    -------
    row_with_stats : `numpy.ndarray`, (1, C)
        The original `row` with statistic values assigned (unchanged if
        there are no reference values).

    Raises
    ------
    ValueError
        Raised if the input arrays do not all have the same length.
    """
    n_ref = len(values_ref)
    if n_ref == 0:
        # Nothing to summarize
        return row

    n_target = len(values_target)
    do_chi = errors_target is not None
    n_target_err = len(errors_target) if do_chi else n_ref
    if n_target != n_ref or n_target_err != n_ref:
        raise ValueError(f'lengths of values_ref={n_ref}, values_target={n_target}'
                         f', error_target={n_target_err} must match')

    diff = values_target - values_ref
    chi = diff/errors_target if do_chi else diff
    # Could make this configurable, but non-finite values/errors are not really usable
    valid = np.isfinite(chi)

    values_type = {}
    if not skip_diff:
        values_type[MeasurementType.DIFF] = diff[valid]
    if do_chi:
        values_type[MeasurementType.CHI] = chi[valid]

    for suffix_type, suffix in suffixes.items():
        values = values_type.get(suffix_type)
        if values is None or len(values) == 0:
            continue
        for stat_name, stat in stats.items():
            row[_get_stat_name(prefix, suffix, stat_name)] = stat.value(values)
    return row
554
555
556@deprecated(reason="This class is no longer being used and will be removed after v28.",
557 version="v28.0", category=FutureWarning)
558@dataclass(frozen=True)
560 is_extended: bool | None
561 label: str
562
563
@deprecated(reason="This class is no longer being used and will be removed after v28.",
            version="v28.0", category=FutureWarning)
class SourceType(Enum):
    """Source classes by extendedness, used to label statistics columns.

    ``is_extended`` is None for the class that includes every source.
    """

    # No extendedness selection
    ALL = SimpleNamespace(is_extended=None, label='all')
    # Extended sources only
    RESOLVED = SimpleNamespace(is_extended=True, label='resolved')
    # Compact sources only
    UNRESOLVED = SimpleNamespace(is_extended=False, label='unresolved')
570
571
@deprecated(reason="This class is no longer being used and will be removed after v28.",
            version="v28.0", category=FutureWarning)
class MatchType(Enum):
    """Match categories whose values name per-band count columns."""

    ALL = 'all'
    MATCH_RIGHT = 'match_right'
    MATCH_WRONG = 'match_wrong'
578
579
@deprecated(reason="This method is no longer being used and will be removed after v28.",
            version="v28.0", category=FutureWarning)
def _get_columns(bands_columns: dict, suffixes: dict, suffixes_flux: dict, suffixes_mag: dict,
                 stats: dict, target: ComparableCatalog, column_dist: str):
    """Get column names for a table of difference statistics.

    Parameters
    ----------
    bands_columns : `Dict` [`str`,`MatchedCatalogFluxesConfig`]
        Dict keyed by band of flux column configuration.
    suffixes, suffixes_flux, suffixes_mag : `Dict` [`MeasurementType`, `str`]
        Dict of suffixes for each `MeasurementType` type, for general columns (e.g.
        coordinates), fluxes and magnitudes, respectively.
    stats : `Dict` [`str`, `Statistic`]
        Dict of suffixes for each `Statistic` type.
    target : `ComparableCatalog`
        A target catalog with coordinate column names.
    column_dist : `str`
        The name of the distance column.

    Returns
    -------
    columns : `Dict` [`str`, `type`]
        Dictionary of column types keyed by name. Insertion order is the
        column order of the resulting table.
    n_models : `int`
        The number of models measurements will be made for.

    Raises
    ------
    RuntimeError
        Raised if any band's number of target flux or flux error columns
        differs from the first band's number of target flux columns.

    Notes
    -----
    Presently, models must be identical for each band.
    """
    # Initial columns
    columns = {
        "bin": int,
        "mag_min": float,
        "mag_max": float,
    }

    # pre-assign all of the columns with appropriate types
    n_models = 0

    bands = list(bands_columns.keys())
    n_bands = len(bands)

    for idx, (band, config_flux) in enumerate(bands_columns.items()):
        columns_suffix = [
            ('flux', suffixes_flux),
            ('mag', suffixes_mag),
        ]
        # The first band defines the number of models all bands must match
        if idx == 0:
            n_models = len(config_flux.columns_target_flux)
        # Add a color versus the previous band. With more than two bands the
        # first band wraps around to bands[-1]; with exactly two, only one
        # color is added (presumably to avoid a redundant pair).
        if (idx > 0) or (n_bands > 2):
            columns_suffix.append((f'color_{bands[idx - 1]}_m_{band}', suffixes))
        n_models_flux = len(config_flux.columns_target_flux)
        n_models_err = len(config_flux.columns_target_flux_err)

        # TODO: Do equivalent validation earlier, in the config
        if (n_models_flux != n_models) or (n_models_err != n_models):
            raise RuntimeError(f'{config_flux} len(columns_target_flux)={n_models_flux} and'
                               f' len(columns_target_flux_err)={n_models_err} must equal {n_models}')

        for sourcetype in SourceType:
            label = sourcetype.value.label
            # Totals would be redundant
            if sourcetype != SourceType.ALL:
                # Integer counts for each (ref/target, match type) combination
                for item in (f'n_{itype}_{mtype.value}' for itype in ('ref', 'target')
                             for mtype in MatchType):
                    columns[_get_column_name(band, label, item)] = int

            # Statistics of coordinate and distance differences
            for item in (target.column_coord1, target.column_coord2, column_dist):
                for suffix in suffixes.values():
                    for stat in stats.keys():
                        columns[_get_column_name(band, label, item, suffix, stat)] = float

            # Statistics of flux, magnitude and (optionally) color differences per model
            for item in config_flux.columns_target_flux:
                for prefix_item, suffixes_col in columns_suffix:
                    for suffix in suffixes_col.values():
                        for stat in stats.keys():
                            columns[_get_column_name(band, label, prefix_item, item, suffix, stat)] = float

    return columns, n_models
661
662
663class DiffMatchedTractCatalogTask(pipeBase.PipelineTask):
664 """Load subsets of matched catalogs and output a merged catalog of matched sources.
665 """
666 ConfigClass = DiffMatchedTractCatalogConfig
667 _DefaultName = "DiffMatchedTractCatalog"
668
669 def runQuantum(self, butlerQC, inputRefs, outputRefs):
670 inputs = butlerQC.get(inputRefs)
671 skymap = inputs.pop("skymap")
672
673 columns_match_target = ['match_row']
674 if 'match_candidate' in inputs['columns_match_target']:
675 columns_match_target.append('match_candidate')
676
677 outputs = self.run(
678 catalog_ref=inputs['cat_ref'].get(parameters={'columns': self.config.columns_in_ref}),
679 catalog_target=inputs['cat_target'].get(parameters={'columns': self.config.columns_in_target}),
680 catalog_match_ref=inputs['cat_match_ref'].get(
681 parameters={'columns': ['match_candidate', 'match_row']},
682 ),
683 catalog_match_target=inputs['cat_match_target'].get(
684 parameters={'columns': columns_match_target},
685 ),
686 wcs=skymap[butlerQC.quantum.dataId["tract"]].wcs,
687 )
688 butlerQC.put(outputs, outputRefs)
689
690 def run(
691 self,
692 catalog_ref: pd.DataFrame | astropy.table.Table,
693 catalog_target: pd.DataFrame | astropy.table.Table,
694 catalog_match_ref: pd.DataFrame | astropy.table.Table,
695 catalog_match_target: pd.DataFrame | astropy.table.Table,
696 wcs: afwGeom.SkyWcs = None,
697 ) -> pipeBase.Struct:
698 """Load matched reference and target (measured) catalogs, measure summary statistics, and output
699 a combined matched catalog with columns from both inputs.
700
701 Parameters
702 ----------
703 catalog_ref : `pandas.DataFrame` | `astropy.table.Table`
704 A reference catalog to diff objects/sources from.
705 catalog_target : `pandas.DataFrame` | `astropy.table.Table`
706 A target catalog to diff reference objects/sources to.
707 catalog_match_ref : `pandas.DataFrame` | `astropy.table.Table`
708 A catalog with match indices of target sources and selection flags
709 for each reference source.
710 catalog_match_target : `pandas.DataFrame` | `astropy.table.Table`
711 A catalog with selection flags for each target source.
712 wcs : `lsst.afw.image.SkyWcs`
713 A coordinate system to convert catalog positions to sky coordinates,
714 if necessary.
715
716 Returns
717 -------
718 retStruct : `lsst.pipe.base.Struct`
719 A struct with output_ref and output_target attribute containing the
720 output matched catalogs.
721 """
722 # Would be nice if this could refer directly to ConfigClass
723 config: DiffMatchedTractCatalogConfig = self.config
724
725 is_ref_pd = isinstance(catalog_ref, pd.DataFrame)
726 is_target_pd = isinstance(catalog_target, pd.DataFrame)
727 is_match_ref_pd = isinstance(catalog_match_ref, pd.DataFrame)
728 is_match_target_pd = isinstance(catalog_match_target, pd.DataFrame)
729 if is_ref_pd:
730 catalog_ref = astropy.table.Table.from_pandas(catalog_ref)
731 if is_target_pd:
732 catalog_target = astropy.table.Table.from_pandas(catalog_target)
733 if is_match_ref_pd:
734 catalog_match_ref = astropy.table.Table.from_pandas(catalog_match_ref)
735 if is_match_target_pd:
736 catalog_match_target = astropy.table.Table.from_pandas(catalog_match_target)
737 # TODO: Remove pandas support in DM-46523
738 if is_ref_pd or is_target_pd or is_match_ref_pd or is_match_target_pd:
739 warnings.warn("pandas usage in MatchProbabilisticTask is deprecated; it will be removed "
740 " in favour of astropy.table after release 28.0.0", category=FutureWarning)
741
742 select_ref = catalog_match_ref['match_candidate']
743 # Add additional selection criteria for target sources beyond those for matching
744 # (not recommended, but can be done anyway)
745 select_target = (catalog_match_target['match_candidate']
746 if 'match_candidate' in catalog_match_target.columns
747 else np.ones(len(catalog_match_target), dtype=bool))
748 for column in config.columns_target_select_true:
749 select_target &= catalog_target[column]
750 for column in config.columns_target_select_false:
751 select_target &= ~catalog_target[column]
752
753 ref, target = config.coord_format.format_catalogs(
754 catalog_ref=catalog_ref, catalog_target=catalog_target,
755 select_ref=None, select_target=select_target, wcs=wcs, radec_to_xy_func=radec_to_xy,
756 )
757 cat_ref = ref.catalog
758 cat_target = target.catalog
759 n_target = len(cat_target)
760
761 if config.include_unmatched:
762 for cat_add, cat_match in ((cat_ref, catalog_match_ref), (cat_target, catalog_match_target)):
763 cat_add['match_candidate'] = cat_match['match_candidate']
764
765 match_row = catalog_match_ref['match_row']
766 matched_ref = match_row >= 0
767 matched_row = match_row[matched_ref]
768 matched_target = np.zeros(n_target, dtype=bool)
769 matched_target[matched_row] = True
770
771 # Add/compute distance columns
772 coord1_target_err, coord2_target_err = config.columns_target_coord_err
773 column_dist, column_dist_err = 'match_distance', 'match_distanceErr'
774 dist = np.full(n_target, np.nan)
775
776 target_match_c1, target_match_c2 = (coord[matched_row] for coord in (target.coord1, target.coord2))
777 target_ref_c1, target_ref_c2 = (coord[matched_ref] for coord in (ref.coord1, ref.coord2))
778
779 dist_err = np.full(n_target, np.nan)
780 dist[matched_row] = sphdist(
781 target_match_c1, target_match_c2, target_ref_c1, target_ref_c2
782 ) if config.coord_format.coords_spherical else np.hypot(
783 target_match_c1 - target_ref_c1, target_match_c2 - target_ref_c2,
784 )
785 cat_target_matched = cat_target[matched_row]
786 # This will convert a masked array to an array filled with nans
787 # wherever there are bad values (otherwise sphdist can raise)
788 c1_err, c2_err = (
789 np.ma.getdata(cat_target_matched[c_err]) for c_err in (coord1_target_err, coord2_target_err)
790 )
791 # Should probably explicitly add cosine terms if ref has errors too
792 dist_err[matched_row] = sphdist(
793 target_match_c1, target_match_c2, target_match_c1 + c1_err, target_match_c2 + c2_err
794 ) if config.coord_format.coords_spherical else np.hypot(c1_err, c2_err)
795 cat_target[column_dist], cat_target[column_dist_err] = dist, dist_err
796
797 # Create a matched table, preserving the target catalog's named index (if it has one)
798 cat_left = cat_target[matched_row]
799 cat_right = cat_ref[matched_ref]
800 cat_right.rename_columns(
801 list(cat_right.columns),
802 new_names=[f'{config.column_matched_prefix_ref}{col}' for col in cat_right.columns],
803 )
804 cat_matched = astropy.table.hstack((cat_left, cat_right))
805
806 if config.include_unmatched:
807 # Create an unmatched table with the same schema as the matched one
808 # ... but only for objects with no matches (for completeness/purity)
809 # and that were selected for matching (or inclusion via config)
810 cat_right = astropy.table.Table(
811 cat_ref[~matched_ref & select_ref]
812 )
813 cat_right.rename_columns(
814 cat_right.colnames,
815 [f"{config.column_matched_prefix_ref}{col}" for col in cat_right.colnames],
816 )
817 match_row_target = catalog_match_target['match_row']
818 cat_left = cat_target[~(match_row_target >= 0) & select_target]
819 # This may be slower than pandas but will, for example, create
820 # masked columns for booleans, which pandas does not support.
821 # See https://github.com/pandas-dev/pandas/issues/46662
822 cat_unmatched = astropy.table.vstack([cat_left, cat_right])
823
824 for columns_convert_base, prefix in (
825 (config.columns_ref_mag_to_nJy, config.column_matched_prefix_ref),
826 (config.columns_target_mag_to_nJy, ""),
827 ):
828 if columns_convert_base:
829 columns_convert = {
830 f"{prefix}{k}": f"{prefix}{v}" for k, v in columns_convert_base.items()
831 } if prefix else columns_convert_base
832 to_convert = [cat_matched]
833 if config.include_unmatched:
834 to_convert.append(cat_unmatched)
835 for cat_convert in to_convert:
836 cat_convert.rename_columns(
837 tuple(columns_convert.keys()),
838 tuple(columns_convert.values()),
839 )
840 for column_flux in columns_convert.values():
841 cat_convert[column_flux] = u.ABmag.to(u.nJy, cat_convert[column_flux])
842
843 data = None
844 band_fluxes = [(band, config_flux) for (band, config_flux) in config.columns_flux.items()]
845 n_bands = len(band_fluxes)
846
847 # TODO: Deprecated by RFC-1017 and to be removed in DM-44988
848 do_stats = self.config.compute_stats and (n_bands > 0)
849 if do_stats:
850 # Slightly smelly hack for when a column (like distance) is already relative to truth
851 column_dummy = 'dummy'
852 cat_ref[column_dummy] = np.zeros_like(ref.coord1)
853
854 # Add a boolean column for whether a match is classified correctly
855 # TODO: remove the assumption of a boolean column
856 extended_ref = cat_ref[config.column_ref_extended] == (not config.column_ref_extended_inverted)
857
858 extended_target = cat_target[config.column_target_extended] >= config.extendedness_cut
859
860 # Define difference/chi columns and statistics thereof
861 suffixes = {MeasurementType.DIFF: 'diff', MeasurementType.CHI: 'chi'}
862 # Skip diff for fluxes - covered by mags
863 suffixes_flux = {MeasurementType.CHI: suffixes[MeasurementType.CHI]}
864 # Skip chi for magnitudes, which have strange errors
865 suffixes_mag = {MeasurementType.DIFF: suffixes[MeasurementType.DIFF]}
866 stats = {stat.name_short(): stat() for stat in (Median, SigmaIQR, SigmaMAD)}
867
868 for percentile in self.config.percentiles:
869 stat = Percentile(percentile=float(Decimal(percentile)))
870 stats[stat.name_short()] = stat
871
872 # Get dict of column names
873 columns, n_models = _get_columns(
874 bands_columns=config.columns_flux,
875 suffixes=suffixes,
876 suffixes_flux=suffixes_flux,
877 suffixes_mag=suffixes_mag,
878 stats=stats,
879 target=target,
880 column_dist=column_dist,
881 )
882
883 # Setup numpy table
884 n_bins = config.mag_num_bins
885 data = np.zeros((n_bins,), dtype=[(key, value) for key, value in columns.items()])
886 data['bin'] = np.arange(n_bins)
887
888 # Setup bins
889 bins_mag = np.linspace(start=config.mag_brightest_ref, stop=config.mag_faintest_ref,
890 num=n_bins + 1)
891 data['mag_min'] = bins_mag[:-1]
892 data['mag_max'] = bins_mag[1:]
893 bins_mag = tuple((bins_mag[idx], bins_mag[idx + 1]) for idx in range(n_bins))
894
895 # Define temporary columns for intermediate storage
896 column_mag_temp = 'mag_temp'
897 column_color_temp = 'color_temp'
898 column_color_err_temp = 'colorErr_temp'
899 flux_err_frac_prev = [None]*n_models
900 mag_prev = [None]*n_models
901
902 columns_target = {
903 target.column_coord1: (
904 ref.column_coord1, target.column_coord1, coord1_target_err, False,
905 ),
906 target.column_coord2: (
907 ref.column_coord2, target.column_coord2, coord2_target_err, False,
908 ),
909 column_dist: (column_dummy, column_dist, column_dist_err, False),
910 }
911
912 # Cheat a little and do the first band last so that the color is
913 # based on the last band
914 band_fluxes.append(band_fluxes[0])
915 flux_err_frac_first = None
916 mag_first = None
917 mag_ref_first = None
918
919 band_prev = None
920 for idx_band, (band, config_flux) in enumerate(band_fluxes):
921 if idx_band == n_bands:
922 # These were already computed earlier
923 mag_ref = mag_ref_first
924 flux_err_frac = flux_err_frac_first
925 mag_model = mag_first
926 else:
927 mag_ref = -2.5*np.log10(cat_ref[config_flux.column_ref_flux]) + config.mag_zeropoint_ref
928 flux_err_frac = [None]*n_models
929 mag_model = [None]*n_models
930
931 if idx_band > 0:
932 cat_ref[column_color_temp] = cat_ref[column_mag_temp] - mag_ref
933
934 cat_ref[column_mag_temp] = mag_ref
935
936 select_ref_bins = [select_ref & (mag_ref > mag_lo) & (mag_ref < mag_hi)
937 for idx_bin, (mag_lo, mag_hi) in enumerate(bins_mag)]
938
939 # Iterate over multiple models, compute their mags and colours (if there's a previous band)
940 for idx_model in range(n_models):
941 column_target_flux = config_flux.columns_target_flux[idx_model]
942 column_target_flux_err = config_flux.columns_target_flux_err[idx_model]
943
944 flux_target = cat_target[column_target_flux]
945 mag_target = -2.5*np.log10(flux_target) + config.mag_zeropoint_target
946 if config.mag_ceiling_target is not None:
947 mag_target[mag_target > config.mag_ceiling_target] = config.mag_ceiling_target
948 mag_model[idx_model] = mag_target
949
950 # These are needed for computing magnitude/color "errors" (which are a sketchy concept)
951 flux_err_frac[idx_model] = cat_target[column_target_flux_err]/flux_target
952
953 # Stop if idx == 0: The rest will be picked up at idx == n_bins
954 if idx_band > 0:
955 # Keep these mags tabulated for convenience
956 column_mag_temp_model = f'{column_mag_temp}{idx_model}'
957 cat_target[column_mag_temp_model] = mag_target
958
959 columns_target[f'flux_{column_target_flux}'] = (
960 config_flux.column_ref_flux,
961 column_target_flux,
962 column_target_flux_err,
963 True,
964 )
965 # Note: magnitude errors are generally problematic and not worth aggregating
966 columns_target[f'mag_{column_target_flux}'] = (
967 column_mag_temp, column_mag_temp_model, None, False,
968 )
969
970 # No need for colors if this is the last band and there are only two bands
971 # (because it would just be the negative of the first color)
972 skip_color = (idx_band == n_bands) and (n_bands <= 2)
973 if not skip_color:
974 column_color_temp_model = f'{column_color_temp}{idx_model}'
975 column_color_err_temp_model = f'{column_color_err_temp}{idx_model}'
976
977 # e.g. if order is ugrizy, first color will be u - g
978 cat_target[column_color_temp_model] = mag_prev[idx_model] - mag_model[idx_model]
979
980 # Sum (in quadrature, and admittedly sketchy for faint fluxes) magnitude errors
981 cat_target[column_color_err_temp_model] = 2.5/np.log(10)*np.hypot(
982 flux_err_frac[idx_model], flux_err_frac_prev[idx_model])
983 columns_target[f'color_{band_prev}_m_{band}_{column_target_flux}'] = (
984 column_color_temp,
985 column_color_temp_model,
986 column_color_err_temp_model,
987 False,
988 )
989
990 for idx_bin, (mag_lo, mag_hi) in enumerate(bins_mag):
991 row = data[idx_bin]
992 # Reference sources only need to be counted once
993 if idx_model == 0:
994 select_ref_bin = select_ref_bins[idx_bin]
995 select_target_bin = select_target & (mag_target > mag_lo) & (mag_target < mag_hi)
996
997 for sourcetype in SourceType:
998 sourcetype_info = sourcetype.value
999 is_extended = sourcetype_info.is_extended
1000 # Counts filtered by match selection and magnitude bin
1001 select_ref_sub = select_ref_bin.copy()
1002 select_target_sub = select_target_bin.copy()
1003 if is_extended is not None:
1004 is_extended_ref = (extended_ref == is_extended)
1005 select_ref_sub &= is_extended_ref
1006 if idx_model == 0:
1007 n_ref_sub = np.count_nonzero(select_ref_sub)
1008 row[_get_column_name(band, sourcetype_info.label, 'n_ref',
1009 MatchType.ALL.value)] = n_ref_sub
1010 select_target_sub &= (extended_target == is_extended)
1011 n_target_sub = np.count_nonzero(select_target_sub)
1012 row[_get_column_name(band, sourcetype_info.label, 'n_target',
1013 MatchType.ALL.value)] = n_target_sub
1014
1015 # Filter matches by magnitude bin and true class
1016 match_row_bin = match_row.copy()
1017 match_row_bin[~select_ref_sub] = -1
1018 match_good = match_row_bin >= 0
1019
1020 n_match = np.count_nonzero(match_good)
1021
1022 # Same for counts of matched target sources (for e.g. purity)
1023
1024 if n_match > 0:
1025 rows_matched = match_row_bin[match_good]
1026 subset_target = cat_target[rows_matched]
1027 if (is_extended is not None) and (idx_model == 0):
1028 right_type = extended_target[rows_matched] == is_extended
1029 n_total = len(right_type)
1030 n_right = np.count_nonzero(right_type)
1031 row[_get_column_name(band, sourcetype_info.label, 'n_ref',
1032 MatchType.MATCH_RIGHT.value)] = n_right
1033 row[_get_column_name(
1034 band,
1035 sourcetype_info.label,
1036 'n_ref',
1037 MatchType.MATCH_WRONG.value,
1038 )] = n_total - n_right
1039
1040 # compute stats for this bin, for all columns
1041 for column, (column_ref, column_target, column_err_target, skip_diff) \
1042 in columns_target.items():
1043 values_ref = cat_ref[column_ref][match_good]
1044 errors_target = (
1045 subset_target[column_err_target]
1046 if column_err_target is not None
1047 else None
1048 )
1050 values_ref,
1051 subset_target[column_target],
1052 errors_target,
1053 row,
1054 stats,
1055 suffixes,
1056 prefix=f'{band}_{sourcetype_info.label}_{column}',
1057 skip_diff=skip_diff,
1058 )
1059
1060 # Count matched target sources with *measured* mags within bin
1061 # Used for e.g. purity calculation
1062 # Should be merged with above code if there's ever a need for
1063 # measuring stats on this source selection
1064 select_target_sub &= matched_target
1065
1066 if is_extended is not None and (np.count_nonzero(select_target_sub) > 0):
1067 n_total = np.count_nonzero(select_target_sub)
1068 right_type = np.zeros(n_target, dtype=bool)
1069 right_type[match_row[matched_ref & is_extended_ref]] = True
1070 right_type &= select_target_sub
1071 n_right = np.count_nonzero(right_type)
1072 row[_get_column_name(band, sourcetype_info.label, 'n_target',
1073 MatchType.MATCH_RIGHT.value)] = n_right
1074 row[_get_column_name(band, sourcetype_info.label, 'n_target',
1075 MatchType.MATCH_WRONG.value)] = n_total - n_right
1076
1077 # delete the flux/color columns since they change with each band
1078 for prefix in ('flux', 'mag'):
1079 del columns_target[f'{prefix}_{column_target_flux}']
1080 if not skip_color:
1081 del columns_target[f'color_{band_prev}_m_{band}_{column_target_flux}']
1082
1083 # keep values needed for colors
1084 flux_err_frac_prev = flux_err_frac
1085 mag_prev = mag_model
1086 band_prev = band
1087 if idx_band == 0:
1088 flux_err_frac_first = flux_err_frac
1089 mag_first = mag_model
1090 mag_ref_first = mag_ref
1091
1092 if config.include_unmatched:
1093 # This is probably less efficient than just doing an outer join originally; worth checking
1094 cat_matched = astropy.table.vstack([cat_matched, cat_unmatched])
1095
1096 retStruct = pipeBase.Struct(cat_matched=cat_matched)
1097 if do_stats:
1098 retStruct.diff_matched = astropy.table.Table(data)
1099 return retStruct
A 2-dimensional celestial WCS that transforms pixels to ICRS RA/Dec, using the LSST standard for pixel...
Definition SkyWcs.h:117
pipeBase.Struct run(self, pd.DataFrame|astropy.table.Table catalog_ref, pd.DataFrame|astropy.table.Table catalog_target, pd.DataFrame|astropy.table.Table catalog_match_ref, pd.DataFrame|astropy.table.Table catalog_match_target, afwGeom.SkyWcs wcs=None)
compute_stats(values_ref, values_target, errors_target, row, stats, suffixes, prefix, skip_diff=False)
_get_columns(dict bands_columns, dict suffixes, dict suffixes_flux, dict suffixes_mag, dict stats, ComparableCatalog target, str column_dist)