22from __future__
import annotations
26 "PrettyPictureConnections",
27 "PrettyPictureConfig",
29 "PrettyMosaicConnections",
35from typing
import TYPE_CHECKING, cast, Any
38from lsst.daf.butler
import Butler, DeferredDatasetHandle
39from lsst.daf.butler
import DatasetRef
40from lsst.pex.config import Field, Config, ConfigDictField, ConfigField, ListField, ChoiceField
44 PipelineTaskConnections,
46 InMemoryDatasetHandle,
50from lsst.pipe.base.connectionTypes
import Input, Output
51from lsst.geom import Box2I, Point2I, Extent2I
54from ._plugins
import plugins
55from ._colorMapper
import lsstRGB
61 from numpy.typing
import NDArray
62 from lsst.pipe.base import QuantumContext, InputQuantizedConnection, OutputQuantizedConnection
67 PipelineTaskConnections,
68 dimensions={
"tract",
"patch",
"skymap"},
69 defaultTemplates={
"coaddTypeName":
"deep"},
73 "Model of the static sky, used to find temporal artifacts. Typically a PSF-Matched, "
74 "sigma-clipped coadd. Written if and only if assembleStaticSkyModel.doWrite=True"
76 name=
"{coaddTypeName}CoaddPsfMatched",
77 storageClass=
"ExposureF",
78 dimensions=(
"tract",
"patch",
"skymap",
"band"),
83 doc=
"A RGB image created from the input data stored as a 3d array",
84 name=
"rgb_picture_array",
85 storageClass=
"NumpyArray",
86 dimensions=(
"tract",
"patch",
"skymap"),
89 outputRGBMask = Output(
90 doc=
"A Mask corresponding to the fused masks of the input channels",
91 name=
"rgb_picture_mask",
93 dimensions=(
"tract",
"patch",
"skymap"),
97class ChannelRGBConfig(
Config):
98 """This describes the rgb values of a given input channel.
100 For instance if this channel is red the values would be self.r = 1,
101 self.g = 0, self.b = 0. If the channel was cyan the values would be
102 self.r = 0, self.g = 1, self.b = 1.
105 r = Field[float](doc=
"The amount of red contained in this channel")
106 g = Field[float](doc=
"The amount of green contained in this channel")
107 b = Field[float](doc=
"The amount of blue contained in this channel")
110 for f
in (self.r, self.g, self.b):
112 raise ValueError(f
"Field {f} can not have a value less than 0 or greater than one")
113 return super().validate()
117 """Configurations to control how luminance is mapped in the rgb code"""
119 stretch = Field[float](doc=
"The stretch of the luminance in asinh", default=400)
120 max = Field[float](doc=
"The maximum allowed luminance on a 0 to 100 scale", default=85)
121 A = Field[float](doc=
"A scaling factor to apply post asinh stretching", default=1)
122 b0 = Field[float](doc=
"A linear offset to apply post asinh stretching", default=0.00)
123 minimum = Field[float](
124 doc=
"The minimum intensity value after stretch, values lower will be set to zero", default=0
126 floor = Field[float](doc=
"A scaling factor to apply to the luminance before asinh scaling", default=0.0)
127 Q = Field[float](doc=
"softening parameter", default=0.7)
130class LocalContrastConfig(
Config):
131 """Configuration to control local contrast enhancement of the luminance
134 doLocalContrast = Field[bool](
135 doc=
"Apply local contrast enhancements to the luminance channel", default=
True
137 highlights = Field[float](doc=
"Adjustment factor for the highlights", default=-0.9)
138 shadows = Field[float](doc=
"Adjustment factor for the shadows", default=0.5)
139 clarity = Field[float](doc=
"Amount of clarity to apply to contrast modification", default=0.1)
140 sigma = Field[float](
141 doc=
"The scale size of what is considered local in the contrast enhancement", default=30
143 maxLevel = Field[int](
144 doc=
"The maximum number of scales the contrast should be enhanced over, if None then all",
150class ScaleColorConfig(
Config):
151 """Controls color scaling in the rgb generation process."""
153 saturation = Field[float](
155 "The overall saturation factor with the scaled luminance between zero and one. "
156 "A value of one is not recommended as it makes bright pixels very saturated"
160 maxChroma = Field[float](
162 "The maximum chromaticity in the CIELCh color space, large "
163 "values will cause bright pixels to fall outside the RGB gamut."
169class RemapBoundsConfig(
Config):
170 """Remaps input images to a known range of values.
172 Often input images are not mapped to any defined range of values
173 (for instance if they are in count units). This controls how the units of
174 and image are mapped to a zero to one range by determining an upper
178 quant = Field[float](
180 "The maximum values of each of the three channels will be multiplied by this factor to "
181 "determine the maximum flux of the image, values larger than this quantity will be clipped."
185 absMax = Field[float](
186 doc=
"Instead of determining the maximum value from the image, use this fixed value instead",
190 scaleBoundFactor = Field[float](
192 "Factor used to compare absMax and the emperically determined"
193 "maximim. if emperical_max is less than scaleBoundFactor*absMax"
194 "then the emperical_max is used instead of absMax, even if it"
195 "is set. Do not set this field to skip this comparison."
201class PrettyPictureConfig(PipelineTaskConfig, pipelineConnections=PrettyPictureConnections):
203 doc=
"A dictionary that maps band names to their rgb channel configurations",
205 itemtype=ChannelRGBConfig,
208 imageRemappingConfig = ConfigField[RemapBoundsConfig](
209 doc=
"Configuration controlling channel normalization process"
211 luminanceConfig = ConfigField[LumConfig](
212 doc=
"Configuration for the luminance scaling when making an RGB image"
214 localContrastConfig = ConfigField[LocalContrastConfig](
215 doc=
"Configuration controlling the local contrast correction in RGB image production"
217 colorConfig = ConfigField[ScaleColorConfig](
218 doc=
"Configuration to control the color scaling process in RGB image production"
220 cieWhitePoint = ListField[float](
221 doc=
"The white point of the input arrays in ciexz coordinates", maxLength=2, default=[0.28, 0.28]
223 arrayType = ChoiceField[str](
224 doc=
"The dataset type for the output image array",
227 "uint8":
"Use 8 bit arrays, 255 max",
228 "uint16":
"Use 16 bit arrays, 65535 max",
229 "half":
"Use 16 bit float arrays, 1 max",
230 "float":
"Use 32 bit float arrays, 1 max",
233 doPSFDeconcovlve = Field[bool](
234 doc=
"Use the PSF in a richardson lucy deconvolution on the luminance channel.", default=
True
236 exposureBrackets = ListField[float](
238 "Exposure scaling factors used in creating multiple exposures with different scalings which will "
239 "then be fused into a final image"
242 default=[1.25, 1, 0.75],
245 def setDefaults(self):
246 self.channelConfig[
"i"] = ChannelRGBConfig(r=1, g=0, b=0)
247 self.channelConfig[
"r"] = ChannelRGBConfig(r=0, g=1, b=0)
248 self.channelConfig[
"g"] = ChannelRGBConfig(r=0, g=0, b=1)
249 return super().setDefaults()
252class PrettyPictureTask(PipelineTask):
253 _DefaultName =
"prettyPictureTask"
254 ConfigClass = PrettyPictureConfig
258 def run(self, images: Mapping[str, Exposure]) -> Struct:
261 jointMask:
None | NDArray =
None
262 maskDict: Mapping[str, int] = {}
263 for channel, imageExposure
in images.items():
264 imageArray = imageExposure.image.array
266 for plug
in plugins.channel():
268 imageArray, imageExposure.mask.array, imageExposure.mask.getMaskPlaneDict()
270 channels[channel] = imageArray
273 shape = imageArray.shape
274 maskDict = imageExposure.mask.getMaskPlaneDict()
275 if jointMask
is None:
276 jointMask = np.zeros(shape, dtype=imageExposure.mask.dtype)
277 jointMask |= imageExposure.mask.array
280 imageRArray = np.zeros(shape, dtype=np.float32)
281 imageGArray = np.zeros(shape, dtype=np.float32)
282 imageBArray = np.zeros(shape, dtype=np.float32)
284 for band, image
in channels.items():
285 mix = self.config.channelConfig[band]
287 imageRArray += mix.r * image
289 imageGArray += mix.g * image
291 imageBArray += mix.b * image
293 exposure = next(iter(images.values()))
294 box: Box2I = exposure.getBBox()
295 boxCenter = box.getCenter()
297 psf = exposure.psf.computeImage(boxCenter).array
302 colorImage = lsstRGB(
306 scaleLumKWargs=self.config.luminanceConfig.toDict(),
307 remapBoundsKwargs=self.config.imageRemappingConfig.toDict(),
308 scaleColorKWargs=self.config.colorConfig.toDict(),
309 **(self.config.localContrastConfig.toDict()),
310 cieWhitePoint=tuple(self.config.cieWhitePoint),
311 psf=psf
if self.config.doPSFDeconcovlve
else None,
312 brackets=list(self.config.exposureBrackets)
if self.config.exposureBrackets
else None,
317 match self.config.arrayType:
331 assert True,
"This code path should be unreachable"
337 assert jointMask
is not None
339 for plug
in plugins.partial():
340 colorImage = plug(colorImage, jointMask, maskDict)
343 lsstMask =
Mask(width=jointMask.shape[1], height=jointMask.shape[0], planeDefs=maskDict)
344 lsstMask.array = jointMask
345 return Struct(outputRGB=colorImage.astype(dtype), outputRGBMask=lsstMask)
349 butlerQC: QuantumContext,
350 inputRefs: InputQuantizedConnection,
351 outputRefs: OutputQuantizedConnection,
353 imageRefs: list[DatasetRef] = inputRefs.inputCoadds
354 sortedImages = self.makeInputsFromRefs(imageRefs, butlerQC)
355 outputs = self.run(sortedImages)
356 butlerQC.put(outputs, outputRefs)
358 def makeInputsFromRefs(
359 self, refs: Iterable[DatasetRef], butler: Butler | QuantumContext
360 ) -> dict[str, Exposure]:
361 sortedImages: dict[str, Exposure] = {}
363 key: str = cast(str, ref.dataId[
"band"])
364 image = butler.get(ref)
365 sortedImages[key] = image
368 def makeInputsFromArrays(self, **kwargs) -> dict[int, DeferredDatasetHandle]:
371 for key, array
in kwargs.items():
372 temp[key] =
Exposure(
Box2I(Point2I(0, 0), Extent2I(*array.shape)), dtype=array.dtype)
373 temp[key].image.array[:] = array
375 return self.makeInputsFromExposures(**temp)
377 def makeInputsFromExposures(self, **kwargs) -> dict[int, DeferredDatasetHandle]:
379 for key, value
in kwargs.items():
380 sortedImages[key] = value
384class PrettyMosaicConnections(PipelineTaskConnections, dimensions=(
"tract",
"skymap")):
386 doc=
"Individual RGB images that are to go into the mosaic",
387 name=
"rgb_picture_array",
388 storageClass=
"NumpyArray",
389 dimensions=(
"tract",
"patch",
"skymap"),
395 doc=
"The skymap which the data has been mapped onto",
396 storageClass=
"SkyMap",
397 name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
398 dimensions=(
"skymap",),
401 inputRGBMask = Input(
402 doc=
"Individual RGB images that are to go into the mosaic",
403 name=
"rgb_picture_mask",
405 dimensions=(
"tract",
"patch",
"skymap"),
410 outputRGBMosaic = Output(
411 doc=
"A RGB mosaic created from the input data stored as a 3d array",
412 name=
"rgb_mosaic_array",
413 storageClass=
"NumpyArray",
414 dimensions=(
"tract",
"skymap"),
class PrettyMosaicConfig(PipelineTaskConfig, pipelineConnections=PrettyMosaicConnections):
    """Configuration for `PrettyMosaicTask`."""

    # Integer down-sampling factor applied to each input RGB patch when it is
    # placed into the mosaic; ``run`` only resizes the inputs when this value
    # is greater than one. No default is given here, so it needs to be set
    # explicitly.
    binFactor = Field[int](
        doc="The factor to bin by when producing the mosaic",
    )
422class PrettyMosaicTask(PipelineTask):
423 _DefaultName =
"prettyMosaicTask"
424 ConfigClass = PrettyMosaicConfig
430 inputRGB: Iterable[DeferredDatasetHandle],
432 inputRGBMask: Iterable[DeferredDatasetHandle],
439 for handle
in inputRGB:
440 dataId = handle.dataId
441 tractInfo: TractInfo = skyMap[dataId[
"tract"]]
442 patchInfo: PatchInfo = tractInfo[dataId[
"patch"]]
443 bbox = patchInfo.getOuterBBox()
446 tractMaps.append(tractInfo)
451 origin = newBox.getBegin()
452 for iterBox
in boxes:
453 localOrigin = iterBox.getBegin() - origin
454 localOrigin = Point2I(
455 x=int(np.floor(localOrigin.x / self.config.binFactor)),
456 y=int(np.floor(localOrigin.y / self.config.binFactor)),
458 localExtent = Extent2I(
459 x=int(np.floor(iterBox.getWidth() / self.config.binFactor)),
460 y=int(np.floor(iterBox.getHeight() / self.config.binFactor)),
462 tmpBox =
Box2I(localOrigin, localExtent)
463 modifiedBoxes.append(tmpBox)
464 boxes = modifiedBoxes
467 newBoxOrigin = Point2I(0, 0)
468 newBoxExtent = Extent2I(
469 x=int(np.floor(newBox.getWidth() / self.config.binFactor)),
470 y=int(np.floor(newBox.getHeight() / self.config.binFactor)),
472 newBox =
Box2I(newBoxOrigin, newBoxExtent)
475 self.imageHandle = tempfile.NamedTemporaryFile()
476 self.maskHandle = tempfile.NamedTemporaryFile()
477 consolidatedImage =
None
478 consolidatedMask =
None
483 for box, handle, handleMask, tractInfo
in zip(boxes, inputRGB, inputRGBMask, tractMaps):
485 rgbMask = handleMask.get()
486 maskDict = rgbMask.getMaskPlaneDict()
488 if consolidatedImage
is None:
489 consolidatedImage = np.memmap(
490 self.imageHandle.name,
492 shape=(newBox.getHeight(), newBox.getWidth(), 3),
495 if consolidatedMask
is None:
496 consolidatedMask = np.memmap(
497 self.maskHandle.name,
499 shape=(newBox.getHeight(), newBox.getWidth()),
500 dtype=rgbMask.array.dtype,
503 if self.config.binFactor > 1:
505 shape = tuple(box.getDimensions())[::-1]
510 fx=shape[0] / self.config.binFactor,
511 fy=shape[1] / self.config.binFactor,
513 rgbMask = cv2.resize(
514 rgbMask.array.astype(np.float32),
517 fx=shape[0] / self.config.binFactor,
518 fy=shape[1] / self.config.binFactor,
520 existing = ~np.all(consolidatedImage[*box.slices] == 0, axis=2)
521 if tmpImg
is None or tmpImg.shape != rgb.shape:
522 ramp = np.linspace(0, 1, tractInfo.patch_border * 2)
523 tmpImg = np.zeros(rgb.shape[:2])
524 tmpImg[: tractInfo.patch_border * 2, :] = np.repeat(
525 np.expand_dims(ramp, 1), tmpImg.shape[1], axis=1
528 tmpImg[-1 * tractInfo.patch_border * 2:, :] = np.repeat(
529 np.expand_dims(1 - ramp, 1), tmpImg.shape[1], axis=1
531 tmpImg[:, : tractInfo.patch_border * 2] = np.repeat(
532 np.expand_dims(ramp, 0), tmpImg.shape[0], axis=0
535 tmpImg[:, -1 * tractInfo.patch_border * 2:] = np.repeat(
536 np.expand_dims(1 - ramp, 0), tmpImg.shape[0], axis=0
538 tmpImg = np.repeat(np.expand_dims(tmpImg, 2), 3, axis=2)
540 consolidatedImage[*box.slices][~existing, :] = rgb[~existing, :]
541 consolidatedImage[*box.slices][existing, :] = (
542 tmpImg[existing] * rgb[existing]
543 + (1 - tmpImg[existing]) * consolidatedImage[*box.slices][existing, :]
546 tmpMask = np.zeros_like(rgbMask.array)
547 tmpMask[existing] = np.bitwise_or(
548 rgbMask.array[existing], consolidatedMask[*box.slices][existing]
550 tmpMask[~existing] = rgbMask.array[~existing]
551 consolidatedMask[*box.slices] = tmpMask
553 for plugin
in plugins.full():
554 if consolidatedImage
is not None and consolidatedMask
is not None:
555 consolidatedImage = plugin(consolidatedImage, consolidatedMask, maskDict)
558 if consolidatedImage
is None:
559 consolidatedImage = np.zeros((0, 0, 0), dtype=np.uint8)
561 return Struct(outputRGBMosaic=consolidatedImage)
565 butlerQC: QuantumContext,
566 inputRefs: InputQuantizedConnection,
567 outputRefs: OutputQuantizedConnection,
569 inputs = butlerQC.get(inputRefs)
570 outputs = self.run(**inputs)
571 butlerQC.put(outputs, outputRefs)
572 if hasattr(self,
"imageHandle"):
573 self.imageHandle.close()
574 if hasattr(self,
"maskHandle"):
575 self.maskHandle.close()
577 def makeInputsFromArrays(
578 self, inputs: Iterable[tuple[Mapping[str, Any], NDArray]]
579 ) -> Iterable[DeferredDatasetHandle]:
580 structuredInputs = []
581 for dataId, array
in inputs:
582 structuredInputs.append(InMemoryDatasetHandle(inMemoryDataset=array, **dataId))
584 return structuredInputs
A class to contain the data, WCS, and other information needed to describe an image of the sky.
Represent a 2-dimensional array of bitmask pixels.
An integer coordinate rectangle.