LSST Applications g0f08755f38+9522ef2f0f,g1653933729+a905cd61c3,g168dd56ebc+a905cd61c3,g1a2382251a+910d683904,g20f6ffc8e0+9522ef2f0f,g217e2c1bcf+f4af07de8a,g28da252d5a+26a25b978d,g2bbee38e9b+cc7bbd92cc,g2bc492864f+cc7bbd92cc,g32e5bea42b+de24d92311,g347aa1857d+cc7bbd92cc,g35bb328faa+a905cd61c3,g3a166c0a6a+cc7bbd92cc,g3bd4b5ce2c+02735527dc,g3e281a1b8c+2bff41ced5,g414038480c+4de324692b,g41af890bb2+4fc8c6ef01,g43bc871e57+d0d7cc457a,g78460c75b0+4ae99bb757,g80478fca09+615987a4d7,g82479be7b0+970d1d03ea,g8365541083+a905cd61c3,g858d7b2824+9522ef2f0f,g9125e01d80+a905cd61c3,ga5288a1d22+9ad990292e,gb58c049af0+84d1b6ec45,gc28159a63d+cc7bbd92cc,gc5452a3dca+b82ec7cc4c,gcab2d0539d+475d436cbd,gcf0d15dbbd+d816b8a730,gda6a2b7d83+d816b8a730,gdaeeff99f8+686ef0dd99,ge79ae78c31+cc7bbd92cc,gef2f8181fd+c1889b0e42,gf0baf85859+f9edac6842,gf1e97e5484+a55c27affc,gfa517265be+9522ef2f0f,gfa999e8aa5+d85414070d,w.2025.01
LSST Data Management Base Package
Loading...
Searching...
No Matches
_task.py
Go to the documentation of this file.
1# This file is part of pipe_tasks.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21
22from __future__ import annotations
23
# Public API of this module: the per-patch RGB picture task and the
# tract-level mosaic task, each with its config and connections classes.
__all__ = (
    "PrettyPictureTask",
    "PrettyPictureConnections",
    "PrettyPictureConfig",
    "PrettyMosaicTask",
    "PrettyMosaicConnections",
    "PrettyMosaicConfig",
)
32
33from collections.abc import Iterable, Mapping
34import numpy as np
35from typing import TYPE_CHECKING, cast, Any
36from lsst.skymap import BaseSkyMap
37
38from lsst.daf.butler import Butler, DeferredDatasetHandle
39from lsst.daf.butler import DatasetRef
40from lsst.pex.config import Field, Config, ConfigDictField, ConfigField, ListField, ChoiceField
41from lsst.pipe.base import (
42 PipelineTask,
43 PipelineTaskConfig,
44 PipelineTaskConnections,
45 Struct,
46 InMemoryDatasetHandle,
47)
48import cv2
49
50from lsst.pipe.base.connectionTypes import Input, Output
51from lsst.geom import Box2I, Point2I, Extent2I
52from lsst.afw.image import Exposure, Mask
53
54from ._plugins import plugins
55from ._colorMapper import lsstRGB
56
57import tempfile
58
59
60if TYPE_CHECKING:
61 from numpy.typing import NDArray
62 from lsst.pipe.base import QuantumContext, InputQuantizedConnection, OutputQuantizedConnection
63 from lsst.skymap import TractInfo, PatchInfo
64
65
class PrettyPictureConnections(
    PipelineTaskConnections,
    dimensions={"tract", "patch", "skymap"},
    defaultTemplates={"coaddTypeName": "deep"},
):
    """Connections for `PrettyPictureTask`.

    Per-band coadds for one patch go in; one RGB array and the fused mask of
    all input bands come out.
    """

    inputCoadds = Input(
        doc=(
            "Coadded exposures, one per band, that are mixed into the RGB image. "
            "Each band must have an entry in PrettyPictureConfig.channelConfig."
        ),
        name="{coaddTypeName}CoaddPsfMatched",
        storageClass="ExposureF",
        dimensions=("tract", "patch", "skymap", "band"),
        multiple=True,
    )

    outputRGB = Output(
        doc="A RGB image created from the input data stored as a 3d array",
        name="rgb_picture_array",
        storageClass="NumpyArray",
        dimensions=("tract", "patch", "skymap"),
    )

    outputRGBMask = Output(
        doc="A Mask corresponding to the fused masks of the input channels",
        name="rgb_picture_mask",
        storageClass="Mask",
        dimensions=("tract", "patch", "skymap"),
    )
95
96
class ChannelRGBConfig(Config):
    """Describe how much one input band contributes to each RGB output channel.

    For instance if this channel is red the values would be ``r = 1``,
    ``g = 0``, ``b = 0``. If the channel was cyan the values would be
    ``r = 0``, ``g = 1``, ``b = 1``.
    """

    r = Field[float](doc="The amount of red contained in this channel")
    g = Field[float](doc="The amount of green contained in this channel")
    b = Field[float](doc="The amount of blue contained in this channel")

    def validate(self):
        """Validate the config, requiring ``r``, ``g`` and ``b`` to lie in [0, 1].

        Raises
        ------
        ValueError
            Raised if any of ``r``, ``g`` or ``b`` is outside [0, 1].
        """
        # Standard field validation first: unset (None) required fields are
        # reported by pex_config instead of crashing the comparison below.
        super().validate()
        for name in ("r", "g", "b"):
            value = getattr(self, name)
            if value < 0 or value > 1:
                raise ValueError(
                    f"Field {name}={value} can not have a value less than 0 or greater than one"
                )
114
115
class LumConfig(Config):
    """Configurations to control how luminance is mapped in the rgb code.

    `PrettyPictureTask.run` passes ``toDict()`` of this config to `lsstRGB`
    as its ``scaleLumKWargs`` argument. The field names are therefore
    presumably keyword names of the luminance-scaling step in `lsstRGB`
    (not visible here; confirm in ``_colorMapper``) and should not be renamed.
    """

    stretch = Field[float](doc="The stretch of the luminance in asinh", default=400)
    # NOTE: ``max`` is only a config attribute name here; it is forwarded by
    # name through toDict(), so renaming it would likely break lsstRGB.
    max = Field[float](doc="The maximum allowed luminance on a 0 to 100 scale", default=85)
    A = Field[float](doc="A scaling factor to apply post asinh stretching", default=1)
    b0 = Field[float](doc="A linear offset to apply post asinh stretching", default=0.00)
    minimum = Field[float](
        doc="The minimum intensity value after stretch, values lower will be set to zero", default=0
    )
    floor = Field[float](doc="A scaling factor to apply to the luminance before asinh scaling", default=0.0)
    Q = Field[float](doc="softening parameter", default=0.7)
128
129
class LocalContrastConfig(Config):
    """Configuration to control local contrast enhancement of the luminance
    channel.

    `PrettyPictureTask.run` unpacks ``toDict()`` of this config directly into
    the keyword arguments of `lsstRGB` (``**toDict()``). Every field name,
    including ``doLocalContrast``, must therefore be a keyword that `lsstRGB`
    accepts.
    """

    doLocalContrast = Field[bool](
        doc="Apply local contrast enhancements to the luminance channel", default=True
    )
    highlights = Field[float](doc="Adjustment factor for the highlights", default=-0.9)
    shadows = Field[float](doc="Adjustment factor for the shadows", default=0.5)
    clarity = Field[float](doc="Amount of clarity to apply to contrast modification", default=0.1)
    sigma = Field[float](
        doc="The scale size of what is considered local in the contrast enhancement", default=30
    )
    # Optional: None is a legal value and, per the field doc, means "all levels".
    maxLevel = Field[int](
        doc="The maximum number of scales the contrast should be enhanced over, if None then all",
        default=4,
        optional=True,
    )
148
149
class ScaleColorConfig(Config):
    """Controls color scaling in the rgb generation process.

    `PrettyPictureTask.run` passes ``toDict()`` of this config to `lsstRGB`
    as its ``scaleColorKWargs`` argument, so the field names are presumably
    keyword names of the color-scaling step (confirm in ``_colorMapper``).
    """

    saturation = Field[float](
        doc=(
            "The overall saturation factor with the scaled luminance between zero and one. "
            "A value of one is not recommended as it makes bright pixels very saturated"
        ),
        default=0.5,
    )
    maxChroma = Field[float](
        doc=(
            "The maximum chromaticity in the CIELCh color space, large "
            "values will cause bright pixels to fall outside the RGB gamut."
        ),
        default=50.0,
    )
167
168
class RemapBoundsConfig(Config):
    """Remaps input images to a known range of values.

    Often input images are not mapped to any defined range of values
    (for instance if they are in count units). This controls how the units of
    an image are mapped to a zero to one range by determining an upper
    bound.

    `PrettyPictureTask.run` passes ``toDict()`` of this config to `lsstRGB`
    as its ``remapBoundsKwargs`` argument.
    """

    quant = Field[float](
        doc=(
            "The maximum values of each of the three channels will be multiplied by this factor to "
            "determine the maximum flux of the image, values larger than this quantity will be clipped."
        ),
        default=0.8,
    )
    absMax = Field[float](
        doc="Instead of determining the maximum value from the image, use this fixed value instead",
        default=220,
        optional=True,
    )
    # The doc string is built from adjacent literals; each piece ends with a
    # space so the words do not run together.
    scaleBoundFactor = Field[float](
        doc=(
            "Factor used to compare absMax and the empirically determined "
            "maximum. If the empirical maximum is less than scaleBoundFactor*absMax "
            "then the empirical maximum is used instead of absMax, even if it "
            "is set. Do not set this field to skip this comparison."
        ),
        optional=True,
    )
199
200
class PrettyPictureConfig(PipelineTaskConfig, pipelineConnections=PrettyPictureConnections):
    """Configuration for `PrettyPictureTask`."""

    channelConfig = ConfigDictField(
        doc="A dictionary that maps band names to their rgb channel configurations",
        keytype=str,
        itemtype=ChannelRGBConfig,
        default={},
    )
    imageRemappingConfig = ConfigField[RemapBoundsConfig](
        doc="Configuration controlling channel normalization process"
    )
    luminanceConfig = ConfigField[LumConfig](
        doc="Configuration for the luminance scaling when making an RGB image"
    )
    localContrastConfig = ConfigField[LocalContrastConfig](
        doc="Configuration controlling the local contrast correction in RGB image production"
    )
    colorConfig = ConfigField[ScaleColorConfig](
        doc="Configuration to control the color scaling process in RGB image production"
    )
    # A chromaticity white point is exactly one (x, y) pair.
    cieWhitePoint = ListField[float](
        doc="The white point of the input arrays in CIE xy chromaticity coordinates",
        minLength=2,
        maxLength=2,
        default=[0.28, 0.28],
    )
    arrayType = ChoiceField[str](
        doc="The dataset type for the output image array",
        default="uint8",
        allowed={
            "uint8": "Use 8 bit arrays, 255 max",
            "uint16": "Use 16 bit arrays, 65535 max",
            "half": "Use 16 bit float arrays, 1 max",
            "float": "Use 32 bit float arrays, 1 max",
        },
    )
    # NOTE: the field name is misspelled ("Deconcovlve"); it is kept as is
    # because existing config overrides refer to it by this name.
    doPSFDeconcovlve = Field[bool](
        doc="Use the PSF in a richardson lucy deconvolution on the luminance channel.", default=True
    )
    exposureBrackets = ListField[float](
        doc=(
            "Exposure scaling factors used in creating multiple exposures with different scalings which will "
            "then be fused into a final image"
        ),
        optional=True,
        default=[1.25, 1, 0.75],
    )

    def setDefaults(self):
        """Set default band-to-channel mixing: i -> red, r -> green, g -> blue."""
        super().setDefaults()
        self.channelConfig["i"] = ChannelRGBConfig(r=1, g=0, b=0)
        self.channelConfig["r"] = ChannelRGBConfig(r=0, g=1, b=0)
        self.channelConfig["g"] = ChannelRGBConfig(r=0, g=0, b=1)
250
251
252class PrettyPictureTask(PipelineTask):
253 _DefaultName = "prettyPictureTask"
254 ConfigClass = PrettyPictureConfig
255
256 config: ConfigClass
257
258 def run(self, images: Mapping[str, Exposure]) -> Struct:
259 channels = {}
260 shape = (0, 0)
261 jointMask: None | NDArray = None
262 maskDict: Mapping[str, int] = {}
263 for channel, imageExposure in images.items():
264 imageArray = imageExposure.image.array
265 # run all the plugins designed for array based interaction
266 for plug in plugins.channel():
267 imageArray = plug(
268 imageArray, imageExposure.mask.array, imageExposure.mask.getMaskPlaneDict()
269 ).astype(np.float32)
270 channels[channel] = imageArray
271 # This will get done each loop, but they are trivial lookups so it
272 # does not matter
273 shape = imageArray.shape
274 maskDict = imageExposure.mask.getMaskPlaneDict()
275 if jointMask is None:
276 jointMask = np.zeros(shape, dtype=imageExposure.mask.dtype)
277 jointMask |= imageExposure.mask.array
278
279 # mix the images to rgb
280 imageRArray = np.zeros(shape, dtype=np.float32)
281 imageGArray = np.zeros(shape, dtype=np.float32)
282 imageBArray = np.zeros(shape, dtype=np.float32)
283
284 for band, image in channels.items():
285 mix = self.config.channelConfig[band]
286 if mix.r:
287 imageRArray += mix.r * image
288 if mix.g:
289 imageGArray += mix.g * image
290 if mix.b:
291 imageBArray += mix.b * image
292
293 exposure = next(iter(images.values()))
294 box: Box2I = exposure.getBBox()
295 boxCenter = box.getCenter()
296 try:
297 psf = exposure.psf.computeImage(boxCenter).array
298 except Exception:
299 psf = None
300 # Ignore type because Exposures do in fact have a bbox, but it is c++
301 # and not typed.
302 colorImage = lsstRGB(
303 imageRArray,
304 imageGArray,
305 imageBArray,
306 scaleLumKWargs=self.config.luminanceConfig.toDict(),
307 remapBoundsKwargs=self.config.imageRemappingConfig.toDict(),
308 scaleColorKWargs=self.config.colorConfig.toDict(),
309 **(self.config.localContrastConfig.toDict()),
310 cieWhitePoint=tuple(self.config.cieWhitePoint), # type: ignore
311 psf=psf if self.config.doPSFDeconcovlve else None,
312 brackets=list(self.config.exposureBrackets) if self.config.exposureBrackets else None,
313 )
314
315 # Find the dataset type and thus the maximum values as well
316 maxVal: int | float
317 match self.config.arrayType:
318 case "uint8":
319 dtype = np.uint8
320 maxVal = 255
321 case "uint16":
322 dtype = np.uint16
323 maxVal = 65535
324 case "half":
325 dtype = np.half
326 maxVal = 1.0
327 case "float":
328 dtype = np.float32
329 maxVal = 1.0
330 case _:
331 assert True, "This code path should be unreachable"
332
333 # lsstRGB returns an image in 0-1 scale it to the maximum value
334 colorImage *= maxVal # type: ignore
335
336 # assert for typing reasons
337 assert jointMask is not None
338 # Run any image level correction plugins
339 for plug in plugins.partial():
340 colorImage = plug(colorImage, jointMask, maskDict)
341
342 # pack the joint mask back into a mask object
343 lsstMask = Mask(width=jointMask.shape[1], height=jointMask.shape[0], planeDefs=maskDict)
344 lsstMask.array = jointMask # type: ignore
345 return Struct(outputRGB=colorImage.astype(dtype), outputRGBMask=lsstMask) # type: ignore
346
347 def runQuantum(
348 self,
349 butlerQC: QuantumContext,
350 inputRefs: InputQuantizedConnection,
351 outputRefs: OutputQuantizedConnection,
352 ) -> None:
353 imageRefs: list[DatasetRef] = inputRefs.inputCoadds
354 sortedImages = self.makeInputsFromRefs(imageRefs, butlerQC)
355 outputs = self.run(sortedImages)
356 butlerQC.put(outputs, outputRefs)
357
358 def makeInputsFromRefs(
359 self, refs: Iterable[DatasetRef], butler: Butler | QuantumContext
360 ) -> dict[str, Exposure]:
361 sortedImages: dict[str, Exposure] = {}
362 for ref in refs:
363 key: str = cast(str, ref.dataId["band"])
364 image = butler.get(ref)
365 sortedImages[key] = image
366 return sortedImages
367
368 def makeInputsFromArrays(self, **kwargs) -> dict[int, DeferredDatasetHandle]:
369 # ignore type because there are not proper stubs for afw
370 temp = {}
371 for key, array in kwargs.items():
372 temp[key] = Exposure(Box2I(Point2I(0, 0), Extent2I(*array.shape)), dtype=array.dtype)
373 temp[key].image.array[:] = array
374
375 return self.makeInputsFromExposures(**temp)
376
377 def makeInputsFromExposures(self, **kwargs) -> dict[int, DeferredDatasetHandle]:
378 sortedImages = {}
379 for key, value in kwargs.items():
380 sortedImages[key] = value
381 return sortedImages
382
383
384class PrettyMosaicConnections(PipelineTaskConnections, dimensions=("tract", "skymap")):
385 inputRGB = Input(
386 doc="Individual RGB images that are to go into the mosaic",
387 name="rgb_picture_array",
388 storageClass="NumpyArray",
389 dimensions=("tract", "patch", "skymap"),
390 multiple=True,
391 deferLoad=True,
392 )
393
394 skyMap = Input(
395 doc="The skymap which the data has been mapped onto",
396 storageClass="SkyMap",
397 name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
398 dimensions=("skymap",),
399 )
400
401 inputRGBMask = Input(
402 doc="Individual RGB images that are to go into the mosaic",
403 name="rgb_picture_mask",
404 storageClass="Mask",
405 dimensions=("tract", "patch", "skymap"),
406 multiple=True,
407 deferLoad=True,
408 )
409
410 outputRGBMosaic = Output(
411 doc="A RGB mosaic created from the input data stored as a 3d array",
412 name="rgb_mosaic_array",
413 storageClass="NumpyArray",
414 dimensions=("tract", "skymap"),
415 )
416
417
418class PrettyMosaicConfig(PipelineTaskConfig, pipelineConnections=PrettyMosaicConnections):
419 binFactor = Field[int](doc="The factor to bin by when producing the mosaic")
420
421
422class PrettyMosaicTask(PipelineTask):
423 _DefaultName = "prettyMosaicTask"
424 ConfigClass = PrettyMosaicConfig
425
426 config: ConfigClass
427
428 def run(
429 self,
430 inputRGB: Iterable[DeferredDatasetHandle],
431 skyMap: BaseSkyMap,
432 inputRGBMask: Iterable[DeferredDatasetHandle],
433 ) -> Struct:
434 # create the bounding region
435 newBox = Box2I()
436 # store the bounds as they are retrieved from the skymap
437 boxes = []
438 tractMaps = []
439 for handle in inputRGB:
440 dataId = handle.dataId
441 tractInfo: TractInfo = skyMap[dataId["tract"]]
442 patchInfo: PatchInfo = tractInfo[dataId["patch"]]
443 bbox = patchInfo.getOuterBBox()
444 boxes.append(bbox)
445 newBox.include(bbox)
446 tractMaps.append(tractInfo)
447
448 # fixup the boxes to be smaller if needed, and put the origin at zero,
449 # this must be done after constructing the complete outer box
450 modifiedBoxes = []
451 origin = newBox.getBegin()
452 for iterBox in boxes:
453 localOrigin = iterBox.getBegin() - origin
454 localOrigin = Point2I(
455 x=int(np.floor(localOrigin.x / self.config.binFactor)),
456 y=int(np.floor(localOrigin.y / self.config.binFactor)),
457 )
458 localExtent = Extent2I(
459 x=int(np.floor(iterBox.getWidth() / self.config.binFactor)),
460 y=int(np.floor(iterBox.getHeight() / self.config.binFactor)),
461 )
462 tmpBox = Box2I(localOrigin, localExtent)
463 modifiedBoxes.append(tmpBox)
464 boxes = modifiedBoxes
465
466 # scale the container box
467 newBoxOrigin = Point2I(0, 0)
468 newBoxExtent = Extent2I(
469 x=int(np.floor(newBox.getWidth() / self.config.binFactor)),
470 y=int(np.floor(newBox.getHeight() / self.config.binFactor)),
471 )
472 newBox = Box2I(newBoxOrigin, newBoxExtent)
473
474 # Allocate storage for the mosaic
475 self.imageHandle = tempfile.NamedTemporaryFile()
476 self.maskHandle = tempfile.NamedTemporaryFile()
477 consolidatedImage = None
478 consolidatedMask = None
479
480 # Actually assemble the mosaic
481 maskDict = {}
482 tmpImg = None
483 for box, handle, handleMask, tractInfo in zip(boxes, inputRGB, inputRGBMask, tractMaps):
484 rgb = handle.get()
485 rgbMask = handleMask.get()
486 maskDict = rgbMask.getMaskPlaneDict()
487 # allocate the memory for the mosaic
488 if consolidatedImage is None:
489 consolidatedImage = np.memmap(
490 self.imageHandle.name,
491 mode="w+",
492 shape=(newBox.getHeight(), newBox.getWidth(), 3),
493 dtype=rgb.dtype,
494 )
495 if consolidatedMask is None:
496 consolidatedMask = np.memmap(
497 self.maskHandle.name,
498 mode="w+",
499 shape=(newBox.getHeight(), newBox.getWidth()),
500 dtype=rgbMask.array.dtype,
501 )
502
503 if self.config.binFactor > 1:
504 # opencv wants things in x, y dimensions
505 shape = tuple(box.getDimensions())[::-1]
506 rgb = cv2.resize(
507 rgb,
508 dst=None,
509 dsize=shape,
510 fx=shape[0] / self.config.binFactor,
511 fy=shape[1] / self.config.binFactor,
512 )
513 rgbMask = cv2.resize(
514 rgbMask.array.astype(np.float32),
515 dst=None,
516 dsize=shape,
517 fx=shape[0] / self.config.binFactor,
518 fy=shape[1] / self.config.binFactor,
519 )
520 existing = ~np.all(consolidatedImage[*box.slices] == 0, axis=2)
521 if tmpImg is None or tmpImg.shape != rgb.shape:
522 ramp = np.linspace(0, 1, tractInfo.patch_border * 2)
523 tmpImg = np.zeros(rgb.shape[:2])
524 tmpImg[: tractInfo.patch_border * 2, :] = np.repeat(
525 np.expand_dims(ramp, 1), tmpImg.shape[1], axis=1
526 )
527
528 tmpImg[-1 * tractInfo.patch_border * 2:, :] = np.repeat(
529 np.expand_dims(1 - ramp, 1), tmpImg.shape[1], axis=1
530 )
531 tmpImg[:, : tractInfo.patch_border * 2] = np.repeat(
532 np.expand_dims(ramp, 0), tmpImg.shape[0], axis=0
533 )
534
535 tmpImg[:, -1 * tractInfo.patch_border * 2:] = np.repeat(
536 np.expand_dims(1 - ramp, 0), tmpImg.shape[0], axis=0
537 )
538 tmpImg = np.repeat(np.expand_dims(tmpImg, 2), 3, axis=2)
539
540 consolidatedImage[*box.slices][~existing, :] = rgb[~existing, :]
541 consolidatedImage[*box.slices][existing, :] = (
542 tmpImg[existing] * rgb[existing]
543 + (1 - tmpImg[existing]) * consolidatedImage[*box.slices][existing, :]
544 )
545
546 tmpMask = np.zeros_like(rgbMask.array)
547 tmpMask[existing] = np.bitwise_or(
548 rgbMask.array[existing], consolidatedMask[*box.slices][existing]
549 )
550 tmpMask[~existing] = rgbMask.array[~existing]
551 consolidatedMask[*box.slices] = tmpMask
552
553 for plugin in plugins.full():
554 if consolidatedImage is not None and consolidatedMask is not None:
555 consolidatedImage = plugin(consolidatedImage, consolidatedMask, maskDict)
556 # If consolidated image still None, that means there was no work to do.
557 # Return an empty image instead of letting this task fail.
558 if consolidatedImage is None:
559 consolidatedImage = np.zeros((0, 0, 0), dtype=np.uint8)
560
561 return Struct(outputRGBMosaic=consolidatedImage)
562
563 def runQuantum(
564 self,
565 butlerQC: QuantumContext,
566 inputRefs: InputQuantizedConnection,
567 outputRefs: OutputQuantizedConnection,
568 ) -> None:
569 inputs = butlerQC.get(inputRefs)
570 outputs = self.run(**inputs)
571 butlerQC.put(outputs, outputRefs)
572 if hasattr(self, "imageHandle"):
573 self.imageHandle.close()
574 if hasattr(self, "maskHandle"):
575 self.maskHandle.close()
576
577 def makeInputsFromArrays(
578 self, inputs: Iterable[tuple[Mapping[str, Any], NDArray]]
579 ) -> Iterable[DeferredDatasetHandle]:
580 structuredInputs = []
581 for dataId, array in inputs:
582 structuredInputs.append(InMemoryDatasetHandle(inMemoryDataset=array, **dataId))
583
584 return structuredInputs
A class to contain the data, WCS, and other information needed to describe an image of the sky.
Definition Exposure.h:72
Represent a 2-dimensional array of bitmask pixels.
Definition Mask.h:82
An integer coordinate rectangle.
Definition Box.h:55