LSST Applications g0f08755f38+c89d42e150,g1635faa6d4+b6cf076a36,g1653933729+a8ce1bb630,g1a0ca8cf93+4c08b13bf7,g28da252d5a+f33f8200ef,g29321ee8c0+0187be18b1,g2bbee38e9b+9634bc57db,g2bc492864f+9634bc57db,g2cdde0e794+c2c89b37c4,g3156d2b45e+41e33cbcdc,g347aa1857d+9634bc57db,g35bb328faa+a8ce1bb630,g3a166c0a6a+9634bc57db,g3e281a1b8c+9f2c4e2fc3,g414038480c+077ccc18e7,g41af890bb2+e740673f1a,g5fbc88fb19+17cd334064,g7642f7d749+c89d42e150,g781aacb6e4+a8ce1bb630,g80478fca09+f8b2ab54e1,g82479be7b0+e2bd23ab8b,g858d7b2824+c89d42e150,g9125e01d80+a8ce1bb630,g9726552aa6+10f999ec6a,ga5288a1d22+065360aec4,gacf8899fa4+9553554aa7,gae0086650b+a8ce1bb630,gb58c049af0+d64f4d3760,gbd46683f8f+ac57cbb13d,gc28159a63d+9634bc57db,gcf0d15dbbd+e37acf7834,gda3e153d99+c89d42e150,gda6a2b7d83+e37acf7834,gdaeeff99f8+1711a396fd,ge2409df99d+cb1e6652d6,ge79ae78c31+9634bc57db,gf0baf85859+147a0692ba,gf3967379c6+02b11634a5,w.2024.45
LSST Data Management Base Package
Loading...
Searching...
No Matches
_task.py
Go to the documentation of this file.
1# This file is part of pipe_tasks.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21
22from __future__ import annotations
23
__all__ = (
    # Per-patch RGB rendering.
    "PrettyPictureTask", "PrettyPictureConnections", "PrettyPictureConfig",
    # Tract-level mosaic assembly.
    "PrettyMosaicTask", "PrettyMosaicConnections", "PrettyMosaicConfig",
)
32
33from collections.abc import Iterable, Mapping
34import numpy as np
35from typing import TYPE_CHECKING, cast, Any
36from lsst.skymap import BaseSkyMap
37
38from lsst.daf.butler import Butler, DeferredDatasetHandle
39from lsst.daf.butler import DatasetRef
40from lsst.pex.config import Field, Config, ConfigDictField, ConfigField, ListField, ChoiceField
41from lsst.pipe.base import (
42 PipelineTask,
43 PipelineTaskConfig,
44 PipelineTaskConnections,
45 Struct,
46 InMemoryDatasetHandle,
47)
48import cv2
49
50from lsst.pipe.base.connectionTypes import Input, Output
51from lsst.geom import Box2I, Point2I, Extent2I
52from lsst.afw.image import Exposure, Mask
53
54from ._plugins import plugins
55from ._colorMapper import lsstRGB
56
57import tempfile
58
59
60if TYPE_CHECKING:
61 from numpy.typing import NDArray
62 from lsst.pipe.base import QuantumContext, InputQuantizedConnection, OutputQuantizedConnection
63 from lsst.skymap import TractInfo, PatchInfo
64
65
class PrettyPictureConnections(
    PipelineTaskConnections,
    dimensions={"tract", "patch", "skymap"},
    defaultTemplates={"coaddTypeName": "deep"},
):
    """Connections for `PrettyPictureTask`.

    All bands of a single patch are read together and combined into one RGB
    array and one fused mask for that patch.
    """

    inputCoadds = Input(
        doc=(
            "Per-band coadds of a single patch that are combined into the RGB image. "
            "The band of each input selects its entry in channelConfig."
        ),
        name="{coaddTypeName}CoaddPsfMatched",
        storageClass="ExposureF",
        dimensions=("tract", "patch", "skymap", "band"),
        multiple=True,
    )

    outputRGB = Output(
        doc="A RGB image created from the input data stored as a 3d array",
        name="rgb_picture_array",
        storageClass="NumpyArray",
        dimensions=("tract", "patch", "skymap"),
    )

    outputRGBMask = Output(
        doc="A Mask corresponding to the fused masks of the input channels",
        name="rgb_picture_mask",
        storageClass="Mask",
        dimensions=("tract", "patch", "skymap"),
    )
95
96
class ChannelRGBConfig(Config):
    """This describes the rgb values of a given input channel.

    For instance if this channel is red the values would be self.r = 1,
    self.g = 0, self.b = 0. If the channel was cyan the values would be
    self.r = 0, self.g = 1, self.b = 1.
    """

    r = Field[float](doc="The amount of red contained in this channel")
    g = Field[float](doc="The amount of green contained in this channel")
    b = Field[float](doc="The amount of blue contained in this channel")

    def validate(self):
        """Validate the config, requiring each of ``r``, ``g``, ``b`` to lie
        in the closed interval [0, 1].

        Raises
        ------
        ValueError
            Raised if a weight is unset (via the base-class check) or is
            outside [0, 1] (NaN included).
        """
        # Run the base-class checks first so that an unset (None) weight is
        # reported as a missing required field instead of failing with a
        # TypeError in the range comparison below.
        super().validate()
        for name in ("r", "g", "b"):
            value = getattr(self, name)
            # Written as a chained comparison so that NaN is also rejected.
            if not 0 <= value <= 1:
                raise ValueError(
                    f"Field {name}={value} can not have a value less than 0 or greater than one"
                )
114
115
class LumConfig(Config):
    """Luminance-mapping settings for the RGB renderer.

    `PrettyPictureTask` passes these values to ``lsstRGB`` as the
    ``scaleLumKWargs`` dict produced by ``toDict()``.
    """

    stretch = Field[float](default=400, doc="The stretch of the luminance in asinh")
    max = Field[float](default=85, doc="The maximum allowed luminance on a 0 to 100 scale")
    A = Field[float](default=1, doc="A scaling factor to apply post asinh stretching")
    b0 = Field[float](default=0.0, doc="A linear offset to apply post asinh stretching")
    minimum = Field[float](
        default=0,
        doc="The minimum intensity value after stretch, values lower will be set to zero",
    )
    floor = Field[float](
        default=0.0,
        doc="A scaling factor to apply to the luminance before asinh scaling",
    )
    Q = Field[float](default=0.7, doc="softening parameter")
128
129
class LocalContrastConfig(Config):
    """Settings for local contrast enhancement of the luminance channel.

    `PrettyPictureTask` unpacks ``toDict()`` of this config directly into
    the keyword arguments of ``lsstRGB``.
    """

    doLocalContrast = Field[bool](
        default=True,
        doc="Apply local contrast enhancements to the luminance channel",
    )
    highlights = Field[float](default=-0.9, doc="Adjustment factor for the highlights")
    shadows = Field[float](default=0.5, doc="Adjustment factor for the shadows")
    clarity = Field[float](default=0.1, doc="Amount of clarity to apply to contrast modification")
    sigma = Field[float](
        default=30,
        doc="The scale size of what is considered local in the contrast enhancement",
    )
    maxLevel = Field[int](
        optional=True,
        default=4,
        doc="The maximum number of scales the contrast should be enhanced over, if None then all",
    )
148
149
class ScaleColorConfig(Config):
    """Color (chroma) scaling settings for RGB generation.

    `PrettyPictureTask` passes these values to ``lsstRGB`` as the
    ``scaleColorKWargs`` dict produced by ``toDict()``.
    """

    saturation = Field[float](
        default=0.5,
        doc=(
            "The overall saturation factor with the scaled luminance between zero and one. "
            "A value of one is not recommended as it makes bright pixels very saturated"
        ),
    )
    maxChroma = Field[float](
        default=50.0,
        doc=(
            "The maximum chromaticity in the CIELCh color space, large "
            "values will cause bright pixels to fall outside the RGB gamut."
        ),
    )
167
168
class RemapBoundsConfig(Config):
    """Remaps input images to a known range of values.

    Often input images are not mapped to any defined range of values
    (for instance if they are in count units). This controls how the units of
    an image are mapped to a zero to one range by determining an upper
    bound. `PrettyPictureTask` passes these values to ``lsstRGB`` as the
    ``remapBoundsKwargs`` dict produced by ``toDict()``.
    """

    quant = Field[float](
        doc=(
            "The maximum values of each of the three channels will be multiplied by this factor to "
            "determine the maximum flux of the image, values larger than this quantity will be clipped."
        ),
        default=0.8,
    )
    absMax = Field[float](
        doc="Instead of determining the maximum value from the image, use this fixed value instead",
        default=220,
        optional=True,
    )
    # The pieces of this implicitly concatenated string each end in a space;
    # previously they ran together ("determinedmaximim.").
    scaleBoundFactor = Field[float](
        doc=(
            "Factor used to compare absMax and the empirically determined "
            "maximum. If the empirical maximum is less than scaleBoundFactor*absMax "
            "then the empirical maximum is used instead of absMax, even if it "
            "is set. Do not set this field to skip this comparison."
        ),
        optional=True,
    )
199
200
class PrettyPictureConfig(PipelineTaskConfig, pipelineConnections=PrettyPictureConnections):
    """Configuration for `PrettyPictureTask`.

    Controls how bands are mixed into red, green and blue, how the mixed
    planes are stretched and colored, and the dtype of the output array.
    """

    channelConfig = ConfigDictField(
        keytype=str,
        itemtype=ChannelRGBConfig,
        default={},
        doc="A dictionary that maps band names to their rgb channel configurations",
    )
    imageRemappingConfig = ConfigField[RemapBoundsConfig](
        doc="Configuration controlling channel normalization process",
    )
    luminanceConfig = ConfigField[LumConfig](
        doc="Configuration for the luminance scaling when making an RGB image",
    )
    localContrastConfig = ConfigField[LocalContrastConfig](
        doc="Configuration controlling the local contrast correction in RGB image production",
    )
    colorConfig = ConfigField[ScaleColorConfig](
        doc="Configuration to control the color scaling process in RGB image production",
    )
    cieWhitePoint = ListField[float](
        maxLength=2,
        default=[0.28, 0.28],
        doc="The white point of the input arrays in ciexz coordinates",
    )
    arrayType = ChoiceField[str](
        default="uint8",
        allowed={
            "uint8": "Use 8 bit arrays, 255 max",
            "uint16": "Use 16 bit arrays, 65535 max",
            "half": "Use 16 bit float arrays, 1 max",
            "float": "Use 32 bit float arrays, 1 max",
        },
        doc="The dataset type for the output image array",
    )
    # NOTE: the field name is misspelled but is kept as-is because config
    # overrides refer to it by name.
    doPSFDeconcovlve = Field[bool](
        default=True,
        doc="Use the PSF in a richardson lucy deconvolution on the luminance channel.",
    )

    def setDefaults(self):
        # Default mixing: i -> red, r -> green, g -> blue.
        defaultMixing = {
            "i": (1, 0, 0),
            "r": (0, 1, 0),
            "g": (0, 0, 1),
        }
        for band, (red, green, blue) in defaultMixing.items():
            self.channelConfig[band] = ChannelRGBConfig(r=red, g=green, b=blue)
        return super().setDefaults()
243
244
class PrettyPictureTask(PipelineTask):
    """Render an RGB picture of one patch from a set of per-band exposures.

    Each band is weighted into the red, green and blue planes according to
    ``config.channelConfig``. The planes are passed to ``lsstRGB`` and the
    result is scaled to the output dtype selected by ``config.arrayType``.
    The masks of the bands used are combined with a bitwise OR and returned
    alongside the image.
    """

    _DefaultName = "prettyPictureTask"
    ConfigClass = PrettyPictureConfig

    config: ConfigClass

    # ``arrayType`` choice -> (output numpy dtype, value that a pixel of 1.0
    # from lsstRGB is scaled to).
    _ARRAY_TYPES: dict[str, tuple[Any, int | float]] = {
        "uint8": (np.uint8, 255),
        "uint16": (np.uint16, 65535),
        "half": (np.half, 1.0),
        "float": (np.float32, 1.0),
    }

    def run(self, images: Mapping[str, Exposure]) -> Struct:
        """Create an RGB image and fused mask from per-band exposures.

        Parameters
        ----------
        images : `~collections.abc.Mapping` [`str`, `Exposure`]
            Exposures keyed by band name. Bands with no entry in
            ``config.channelConfig`` are skipped (and logged).

        Returns
        -------
        result : `Struct`
            ``outputRGB``
                The RGB array, cast to the dtype selected by
                ``config.arrayType``.
            ``outputRGBMask``
                `Mask` holding the bitwise OR of the masks of the bands
                that were used.

        Raises
        ------
        ValueError
            Raised if none of the input bands appear in
            ``config.channelConfig``, or if ``config.arrayType`` is not
            a supported choice.
        """
        usable = self._selectConfiguredBands(images)
        # Resolve the output format up front so a bad choice fails fast.
        dtype, maxVal = self._outputFormat()

        channels: dict[str, NDArray] = {}
        shape: tuple[int, ...] = (0, 0)
        jointMask: None | NDArray = None
        maskDict: Mapping[str, int] = {}
        for channel, imageExposure in usable.items():
            imageArray = imageExposure.image.array
            # Run all the plugins designed for array based interaction.
            for plug in plugins.channel():
                imageArray = plug(
                    imageArray, imageExposure.mask.array, imageExposure.mask.getMaskPlaneDict()
                ).astype(np.float32)
            channels[channel] = imageArray
            # Inputs are assumed to share one shape and mask-plane definition
            # (the OR below requires matching shapes), so the last values
            # seen are kept.
            shape = imageArray.shape
            maskDict = imageExposure.mask.getMaskPlaneDict()
            if jointMask is None:
                jointMask = np.zeros(shape, dtype=imageExposure.mask.dtype)
            jointMask |= imageExposure.mask.array

        imageRArray, imageGArray, imageBArray = self._mixChannels(channels, shape)

        # The PSF (if wanted) is taken from the first usable exposure.
        exposure = next(iter(usable.values()))
        psf = self._computePsfArray(exposure) if self.config.doPSFDeconcovlve else None

        colorImage = lsstRGB(
            imageRArray,
            imageGArray,
            imageBArray,
            scaleLumKWargs=self.config.luminanceConfig.toDict(),
            remapBoundsKwargs=self.config.imageRemappingConfig.toDict(),
            scaleColorKWargs=self.config.colorConfig.toDict(),
            **(self.config.localContrastConfig.toDict()),
            cieWhitePoint=tuple(self.config.cieWhitePoint),  # type: ignore
            psf=psf,
        )

        # lsstRGB returns an image in 0-1; scale it to the maximum value of
        # the output type.
        colorImage *= maxVal  # type: ignore

        # Guaranteed because _selectConfiguredBands returns at least one
        # band; asserted for the type checker.
        assert jointMask is not None
        # Run any image level correction plugins.
        for plug in plugins.partial():
            colorImage = plug(colorImage, jointMask, maskDict)

        # Pack the joint mask back into a mask object.
        lsstMask = Mask(width=jointMask.shape[1], height=jointMask.shape[0], planeDefs=maskDict)
        lsstMask.array = jointMask  # type: ignore
        return Struct(outputRGB=colorImage.astype(dtype), outputRGBMask=lsstMask)  # type: ignore

    def _selectConfiguredBands(self, images: Mapping[str, Exposure]) -> dict[str, Exposure]:
        """Return the entries of ``images`` whose band has a channel config.

        Raises
        ------
        ValueError
            Raised if no input band is configured.
        """
        usable = {band: exp for band, exp in images.items() if band in self.config.channelConfig}
        skipped = sorted(set(images) - set(usable))
        if skipped:
            self.log.info("Skipping bands with no channelConfig entry: %s", skipped)
        if not usable:
            raise ValueError(
                f"None of the input bands {sorted(images)} are configured in channelConfig "
                f"(configured bands: {sorted(self.config.channelConfig)})"
            )
        return usable

    def _outputFormat(self) -> tuple[Any, int | float]:
        """Return ``(dtype, full-scale value)`` for ``config.arrayType``.

        Raises
        ------
        ValueError
            Raised if ``config.arrayType`` is not a supported choice.
        """
        try:
            return self._ARRAY_TYPES[self.config.arrayType]
        except KeyError:
            raise ValueError(f"Unsupported arrayType {self.config.arrayType!r}") from None

    def _mixChannels(
        self, channels: Mapping[str, NDArray], shape: tuple[int, ...]
    ) -> tuple[NDArray, NDArray, NDArray]:
        """Return float32 red, green and blue planes, each the sum of the band
        images weighted by that band's ``channelConfig`` entry."""
        imageRArray = np.zeros(shape, dtype=np.float32)
        imageGArray = np.zeros(shape, dtype=np.float32)
        imageBArray = np.zeros(shape, dtype=np.float32)
        for band, image in channels.items():
            mix = self.config.channelConfig[band]
            if mix.r:
                imageRArray += mix.r * image
            if mix.g:
                imageGArray += mix.g * image
            if mix.b:
                imageBArray += mix.b * image
        return imageRArray, imageGArray, imageBArray

    def _computePsfArray(self, exposure: Exposure) -> NDArray | None:
        """Return the PSF image array at the center of ``exposure``'s bounding
        box, or `None` (with a warning) if it cannot be computed."""
        boxCenter = exposure.getBBox().getCenter()
        try:
            return exposure.psf.computeImage(boxCenter).array
        except Exception as err:
            # Deconvolution is optional: continue without it, but say why.
            self.log.warning("Could not compute a PSF image; skipping PSF deconvolution: %s", err)
            return None

    def runQuantum(
        self,
        butlerQC: QuantumContext,
        inputRefs: InputQuantizedConnection,
        outputRefs: OutputQuantizedConnection,
    ) -> None:
        imageRefs: list[DatasetRef] = inputRefs.inputCoadds
        sortedImages = self.makeInputsFromRefs(imageRefs, butlerQC)
        outputs = self.run(sortedImages)
        butlerQC.put(outputs, outputRefs)

    def makeInputsFromRefs(
        self, refs: Iterable[DatasetRef], butler: Butler | QuantumContext
    ) -> dict[str, Exposure]:
        """Read each ref with ``butler.get`` and key the result by its band."""
        sortedImages: dict[str, Exposure] = {}
        for ref in refs:
            key: str = cast(str, ref.dataId["band"])
            sortedImages[key] = butler.get(ref)
        return sortedImages

    def makeInputsFromArrays(self, **kwargs) -> dict[str, Exposure]:
        """Wrap 2-d image arrays, keyed by band, in `Exposure` objects.

        Parameters
        ----------
        **kwargs : `numpy.ndarray`
            2-d image arrays keyed by band name, indexed ``[row, column]``.
        """
        # ignore type because there are not proper stubs for afw
        temp = {}
        for key, array in kwargs.items():
            # numpy shapes are (rows, columns) == (height, width), whereas
            # Extent2I takes (width, height); passing the shape directly
            # transposed the box for non-square arrays.
            height, width = array.shape
            temp[key] = Exposure(Box2I(Point2I(0, 0), Extent2I(width, height)), dtype=array.dtype)
            temp[key].image.array[:] = array
        return self.makeInputsFromExposures(**temp)

    def makeInputsFromExposures(self, **kwargs) -> dict[str, Exposure]:
        """Return the given exposures as a dict keyed by band name."""
        return dict(kwargs)
374
375
class PrettyMosaicConnections(PipelineTaskConnections, dimensions=("tract", "skymap")):
    """Connections for `PrettyMosaicTask`: per-patch RGB arrays and masks in,
    one tract-level RGB mosaic out."""

    inputRGB = Input(
        doc="Individual RGB images that are to go into the mosaic",
        name="rgb_picture_array",
        storageClass="NumpyArray",
        dimensions=("tract", "patch", "skymap"),
        multiple=True,
        deferLoad=True,
    )

    skyMap = Input(
        doc="The skymap which the data has been mapped onto",
        storageClass="SkyMap",
        name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
        dimensions=("skymap",),
    )

    inputRGBMask = Input(
        doc="Individual masks of the RGB images that are to go into the mosaic",
        name="rgb_picture_mask",
        storageClass="Mask",
        dimensions=("tract", "patch", "skymap"),
        multiple=True,
        deferLoad=True,
    )

    outputRGBMosaic = Output(
        doc="A RGB mosaic created from the input data stored as a 3d array",
        name="rgb_mosaic_array",
        storageClass="NumpyArray",
        dimensions=("tract", "skymap"),
    )
408
409
class PrettyMosaicConfig(PipelineTaskConfig, pipelineConnections=PrettyMosaicConnections):
    """Configuration for `PrettyMosaicTask`."""

    binFactor = Field[int](
        doc="The factor to bin by when producing the mosaic",
        # Patch and mosaic sizes are divided by this factor, so anything
        # below 1 would divide by zero or produce negative sizes.
        check=lambda factor: factor >= 1,
    )
412
413
class PrettyMosaicTask(PipelineTask):
    """Assemble per-patch RGB arrays into a single mosaic for a tract.

    Each patch is placed at its outer bounding box from the skymap, binned
    by ``config.binFactor``. Where a patch lands on pixels that are already
    filled, it is blended with the existing data using linear ramps along
    its edges, and the masks are combined with a bitwise OR. The mosaic is
    accumulated in ``numpy.memmap`` arrays backed by temporary files.
    """

    _DefaultName = "prettyMosaicTask"
    ConfigClass = PrettyMosaicConfig

    config: ConfigClass

    def run(
        self,
        inputRGB: Iterable[DeferredDatasetHandle],
        skyMap: BaseSkyMap,
        inputRGBMask: Iterable[DeferredDatasetHandle],
    ) -> Struct:
        """Build the mosaic.

        Parameters
        ----------
        inputRGB : `~collections.abc.Iterable` [`DeferredDatasetHandle`]
            Handles to the per-patch RGB arrays; each ``dataId`` must
            provide ``tract`` and ``patch``.
        skyMap : `BaseSkyMap`
            Skymap used to look up the outer bounding box of each patch.
        inputRGBMask : `~collections.abc.Iterable` [`DeferredDatasetHandle`]
            Handles to the per-patch masks. They are matched to
            ``inputRGB`` by ``(tract, patch)``, not by position.

        Returns
        -------
        result : `Struct`
            ``outputRGBMosaic``
                The mosaic array shaped ``(height, width, 3)``, or an empty
                ``uint8`` array if there were no inputs.

        Raises
        ------
        RuntimeError
            Raised if an RGB input has no mask with the same tract and patch.
        """
        binFactor = self.config.binFactor
        # Materialize the handles: they are traversed more than once below.
        rgbHandles = list(inputRGB)
        masksByPatch = {self._patchKey(handle): handle for handle in inputRGBMask}

        tractInfos, boxes, mosaicBox = self._computeBoxes(rgbHandles, skyMap, binFactor)
        mosaicShape = (mosaicBox.getHeight(), mosaicBox.getWidth())

        # Disk-backed storage for the mosaic; the arrays are created once the
        # dtypes of the inputs are known.
        self.imageHandle = tempfile.NamedTemporaryFile()
        self.maskHandle = tempfile.NamedTemporaryFile()
        consolidatedImage = None
        consolidatedMask = None

        maskDict = {}
        weights = None
        weightsKey = None
        for box, handle, tractInfo in zip(boxes, rgbHandles, tractInfos):
            handleMask = masksByPatch.get(self._patchKey(handle))
            if handleMask is None:
                raise RuntimeError(f"No RGB mask found for the patch with data ID {handle.dataId}")
            rgb = handle.get()
            rgbMask = handleMask.get()
            maskDict = rgbMask.getMaskPlaneDict()
            maskArray = rgbMask.array

            if consolidatedImage is None:
                consolidatedImage = np.memmap(
                    self.imageHandle.name, mode="w+", shape=(*mosaicShape, 3), dtype=rgb.dtype
                )
            if consolidatedMask is None:
                consolidatedMask = np.memmap(
                    self.maskHandle.name, mode="w+", shape=mosaicShape, dtype=maskArray.dtype
                )

            if binFactor > 1:
                # OpenCV takes the output size as (width, height); when dsize
                # is given the scale factors are derived from it.
                rgb = cv2.resize(rgb, dsize=(box.getWidth(), box.getHeight()))
                # Bitmasks must not be interpolated; OR the bits in each bin.
                maskArray = self._binMask(maskArray, binFactor, box.getHeight(), box.getWidth())

            region = tuple(box.slices)
            imageView = consolidatedImage[region]
            maskView = consolidatedMask[region]
            # Pixels already written by an earlier patch (an all-zero pixel
            # is treated as empty).
            existing = ~np.all(imageView == 0, axis=2)

            # The blend ramp spans 2 * patch_border unbinned pixels; convert
            # it to binned pixels so it covers the same sky area.
            rampWidth = (2 * tractInfo.patch_border) // binFactor
            key = (rgb.shape[:2], rampWidth)
            if weights is None or key != weightsKey:
                weights = self._makeBlendWeights(rgb.shape[:2], rampWidth)
                weightsKey = key

            fresh = ~existing
            imageView[fresh, :] = rgb[fresh, :]
            weight = weights[existing][:, np.newaxis]
            imageView[existing, :] = weight * rgb[existing] + (1 - weight) * imageView[existing, :]

            maskView[...] = np.where(existing, maskArray | maskView, maskArray)

        for plugin in plugins.full():
            if consolidatedImage is None or consolidatedMask is None:
                break
            consolidatedImage = plugin(consolidatedImage, consolidatedMask, maskDict)
        # If consolidated image is still None, there was no work to do.
        # Return an empty image instead of letting this task fail.
        if consolidatedImage is None:
            consolidatedImage = np.zeros((0, 0, 0), dtype=np.uint8)

        return Struct(outputRGBMosaic=consolidatedImage)

    @staticmethod
    def _patchKey(handle: DeferredDatasetHandle) -> tuple[Any, Any]:
        """Return the ``(tract, patch)`` pair used to match images and masks."""
        return (handle.dataId["tract"], handle.dataId["patch"])

    @staticmethod
    def _computeBoxes(
        rgbHandles: list[DeferredDatasetHandle], skyMap: BaseSkyMap, binFactor: int
    ) -> tuple[list[TractInfo], list[Box2I], Box2I]:
        """Return, per handle, its tract info and binned placement box (origin
        at the union's corner), plus the binned box of the whole mosaic."""
        tractInfos = []
        outerBoxes = []
        unionBox = Box2I()
        for handle in rgbHandles:
            dataId = handle.dataId
            tractInfo = skyMap[dataId["tract"]]
            patchInfo = tractInfo[dataId["patch"]]
            bbox = patchInfo.getOuterBBox()
            outerBoxes.append(bbox)
            unionBox.include(bbox)
            tractInfos.append(tractInfo)

        # Shifting to a zero origin must happen after the union is complete.
        origin = unionBox.getBegin()
        boxes = []
        for bbox in outerBoxes:
            offset = bbox.getBegin() - origin
            boxes.append(
                Box2I(
                    Point2I(x=offset.x // binFactor, y=offset.y // binFactor),
                    Extent2I(x=bbox.getWidth() // binFactor, y=bbox.getHeight() // binFactor),
                )
            )
        mosaicBox = Box2I(
            Point2I(0, 0),
            Extent2I(x=unionBox.getWidth() // binFactor, y=unionBox.getHeight() // binFactor),
        )
        return tractInfos, boxes, mosaicBox

    @staticmethod
    def _binMask(maskArray: NDArray, binFactor: int, height: int, width: int) -> NDArray:
        """Downsample a bitmask to ``(height, width)`` by OR-ing the bits of
        each ``binFactor`` x ``binFactor`` block.

        Trailing rows/columns that do not fill a whole block are dropped,
        matching the floor used for the binned boxes.
        """
        trimmed = maskArray[: height * binFactor, : width * binFactor]
        blocks = trimmed.reshape(height, binFactor, width, binFactor)
        return np.bitwise_or.reduce(np.bitwise_or.reduce(blocks, axis=3), axis=1)

    @staticmethod
    def _makeBlendWeights(shape: tuple[int, int], rampWidth: int) -> NDArray:
        """Return 2-d weights given to an incoming patch where it overlaps
        data already in the mosaic.

        Weights ramp linearly from 0 at each edge to 1 at ``rampWidth``
        pixels in; pixels outside the edge ramps get 0. Column ramps
        override row ramps in the corners.
        """
        height, width = shape
        weights = np.zeros(shape)
        # Clip so a ramp can never be longer than the image it is applied to.
        rampWidth = min(rampWidth, height, width)
        if rampWidth <= 0:
            return weights
        ramp = np.linspace(0, 1, rampWidth)
        weights[:rampWidth, :] = ramp[:, np.newaxis]
        weights[height - rampWidth:, :] = (1 - ramp)[:, np.newaxis]
        weights[:, :rampWidth] = ramp[np.newaxis, :]
        weights[:, width - rampWidth:] = (1 - ramp)[np.newaxis, :]
        return weights

    def _closeTemporaryFiles(self) -> None:
        """Close (and thereby delete) the temporary files created by `run`."""
        for name in ("imageHandle", "maskHandle"):
            handle = getattr(self, name, None)
            if handle is not None:
                handle.close()

    def runQuantum(
        self,
        butlerQC: QuantumContext,
        inputRefs: InputQuantizedConnection,
        outputRefs: OutputQuantizedConnection,
    ) -> None:
        inputs = butlerQC.get(inputRefs)
        try:
            outputs = self.run(**inputs)
            butlerQC.put(outputs, outputRefs)
        finally:
            # Release the mosaic's temporary files even if run or put fails.
            self._closeTemporaryFiles()

    def makeInputsFromArrays(
        self, inputs: Iterable[tuple[Mapping[str, Any], NDArray]]
    ) -> Iterable[DeferredDatasetHandle]:
        """Wrap ``(dataId, array)`` pairs in in-memory dataset handles."""
        return [InMemoryDatasetHandle(inMemoryDataset=array, **dataId) for dataId, array in inputs]
A class to contain the data, WCS, and other information needed to describe an image of the sky.
Definition Exposure.h:72
Represent a 2-dimensional array of bitmask pixels.
Definition Mask.h:81
An integer coordinate rectangle.
Definition Box.h:55