22from __future__
import annotations
26 "PrettyPictureConnections",
27 "PrettyPictureConfig",
29 "PrettyMosaicConnections",
35from typing
import TYPE_CHECKING, cast, Any
38from lsst.daf.butler
import Butler, DeferredDatasetHandle
39from lsst.daf.butler
import DatasetRef
40from lsst.pex.config import Field, Config, ConfigDictField, ConfigField, ListField, ChoiceField
44 PipelineTaskConnections,
46 InMemoryDatasetHandle,
50from lsst.pipe.base.connectionTypes
import Input, Output
51from lsst.geom import Box2I, Point2I, Extent2I
54from ._plugins
import plugins
55from ._colorMapper
import lsstRGB
61 from numpy.typing
import NDArray
62 from lsst.pipe.base import QuantumContext, InputQuantizedConnection, OutputQuantizedConnection
67 PipelineTaskConnections,
68 dimensions={
"tract",
"patch",
"skymap"},
69 defaultTemplates={
"coaddTypeName":
"deep"},
73 "Model of the static sky, used to find temporal artifacts. Typically a PSF-Matched, "
74 "sigma-clipped coadd. Written if and only if assembleStaticSkyModel.doWrite=True"
76 name=
"{coaddTypeName}CoaddPsfMatched",
77 storageClass=
"ExposureF",
78 dimensions=(
"tract",
"patch",
"skymap",
"band"),
83 doc=
"A RGB image created from the input data stored as a 3d array",
84 name=
"rgb_picture_array",
85 storageClass=
"NumpyArray",
86 dimensions=(
"tract",
"patch",
"skymap"),
89 outputRGBMask = Output(
90 doc=
"A Mask corresponding to the fused masks of the input channels",
91 name=
"rgb_picture_mask",
93 dimensions=(
"tract",
"patch",
"skymap"),
97class ChannelRGBConfig(
Config):
98 """This describes the rgb values of a given input channel.
100 For instance if this channel is red the values would be self.r = 1,
101 self.g = 0, self.b = 0. If the channel was cyan the values would be
102 self.r = 0, self.g = 1, self.b = 1.
105 r = Field[float](doc=
"The amount of red contained in this channel")
106 g = Field[float](doc=
"The amount of green contained in this channel")
107 b = Field[float](doc=
"The amount of blue contained in this channel")
110 for f
in (self.r, self.g, self.b):
112 raise ValueError(f
"Field {f} can not have a value less than 0 or greater than one")
113 return super().validate()
117 """Configurations to control how luminance is mapped in the rgb code"""
119 stretch = Field[float](doc=
"The stretch of the luminance in asinh", default=400)
120 max = Field[float](doc=
"The maximum allowed luminance on a 0 to 100 scale", default=85)
121 A = Field[float](doc=
"A scaling factor to apply post asinh stretching", default=1)
122 b0 = Field[float](doc=
"A linear offset to apply post asinh stretching", default=0.00)
123 minimum = Field[float](
124 doc=
"The minimum intensity value after stretch, values lower will be set to zero", default=0
126 floor = Field[float](doc=
"A scaling factor to apply to the luminance before asinh scaling", default=0.0)
127 Q = Field[float](doc=
"softening parameter", default=0.7)
130class LocalContrastConfig(
Config):
131 """Configuration to control local contrast enhancement of the luminance
134 doLocalContrast = Field[bool](
135 doc=
"Apply local contrast enhancements to the luminance channel", default=
True
137 highlights = Field[float](doc=
"Adjustment factor for the highlights", default=-0.9)
138 shadows = Field[float](doc=
"Adjustment factor for the shadows", default=0.5)
139 clarity = Field[float](doc=
"Amount of clarity to apply to contrast modification", default=0.1)
140 sigma = Field[float](
141 doc=
"The scale size of what is considered local in the contrast enhancement", default=30
143 maxLevel = Field[int](
144 doc=
"The maximum number of scales the contrast should be enhanced over, if None then all",
150class ScaleColorConfig(
Config):
151 """Controls color scaling in the rgb generation process."""
153 saturation = Field[float](
155 "The overall saturation factor with the scaled luminance between zero and one. "
156 "A value of one is not recommended as it makes bright pixels very saturated"
160 maxChroma = Field[float](
162 "The maximum chromaticity in the CIELCh color space, large "
163 "values will cause bright pixels to fall outside the RGB gamut."
169class RemapBoundsConfig(
Config):
170 """Remaps input images to a known range of values.
172 Often input images are not mapped to any defined range of values
173 (for instance if they are in count units). This controls how the units of
174 and image are mapped to a zero to one range by determining an upper
178 quant = Field[float](
180 "The maximum values of each of the three channels will be multiplied by this factor to "
181 "determine the maximum flux of the image, values larger than this quantity will be clipped."
185 absMax = Field[float](
186 doc=
"Instead of determining the maximum value from the image, use this fixed value instead",
190 scaleBoundFactor = Field[float](
192 "Factor used to compare absMax and the emperically determined"
193 "maximim. if emperical_max is less than scaleBoundFactor*absMax"
194 "then the emperical_max is used instead of absMax, even if it"
195 "is set. Do not set this field to skip this comparison."
201class PrettyPictureConfig(PipelineTaskConfig, pipelineConnections=PrettyPictureConnections):
203 doc=
"A dictionary that maps band names to their rgb channel configurations",
205 itemtype=ChannelRGBConfig,
208 imageRemappingConfig = ConfigField[RemapBoundsConfig](
209 doc=
"Configuration controlling channel normalization process"
211 luminanceConfig = ConfigField[LumConfig](
212 doc=
"Configuration for the luminance scaling when making an RGB image"
214 localContrastConfig = ConfigField[LocalContrastConfig](
215 doc=
"Configuration controlling the local contrast correction in RGB image production"
217 colorConfig = ConfigField[ScaleColorConfig](
218 doc=
"Configuration to control the color scaling process in RGB image production"
220 cieWhitePoint = ListField[float](
221 doc=
"The white point of the input arrays in ciexz coordinates", maxLength=2, default=[0.28, 0.28]
223 arrayType = ChoiceField[str](
224 doc=
"The dataset type for the output image array",
227 "uint8":
"Use 8 bit arrays, 255 max",
228 "uint16":
"Use 16 bit arrays, 65535 max",
229 "half":
"Use 16 bit float arrays, 1 max",
230 "float":
"Use 32 bit float arrays, 1 max",
233 doPSFDeconcovlve = Field[bool](
234 doc=
"Use the PSF in a richardson lucy deconvolution on the luminance channel.",
238 def setDefaults(self):
239 self.channelConfig[
"i"] = ChannelRGBConfig(r=1, g=0, b=0)
240 self.channelConfig[
"r"] = ChannelRGBConfig(r=0, g=1, b=0)
241 self.channelConfig[
"g"] = ChannelRGBConfig(r=0, g=0, b=1)
242 return super().setDefaults()
245class PrettyPictureTask(PipelineTask):
246 _DefaultName =
"prettyPictureTask"
247 ConfigClass = PrettyPictureConfig
251 def run(self, images: Mapping[str, Exposure]) -> Struct:
254 jointMask:
None | NDArray =
None
255 maskDict: Mapping[str, int] = {}
256 for channel, imageExposure
in images.items():
257 imageArray = imageExposure.image.array
259 for plug
in plugins.channel():
261 imageArray, imageExposure.mask.array, imageExposure.mask.getMaskPlaneDict()
263 channels[channel] = imageArray
266 shape = imageArray.shape
267 maskDict = imageExposure.mask.getMaskPlaneDict()
268 if jointMask
is None:
269 jointMask = np.zeros(shape, dtype=imageExposure.mask.dtype)
270 jointMask |= imageExposure.mask.array
273 imageRArray = np.zeros(shape, dtype=np.float32)
274 imageGArray = np.zeros(shape, dtype=np.float32)
275 imageBArray = np.zeros(shape, dtype=np.float32)
277 for band, image
in channels.items():
278 mix = self.config.channelConfig[band]
280 imageRArray += mix.r * image
282 imageGArray += mix.g * image
284 imageBArray += mix.b * image
286 exposure = next(iter(images.values()))
287 box: Box2I = exposure.getBBox()
288 boxCenter = box.getCenter()
290 psf = exposure.psf.computeImage(boxCenter).array
295 colorImage = lsstRGB(
299 scaleLumKWargs=self.config.luminanceConfig.toDict(),
300 remapBoundsKwargs=self.config.imageRemappingConfig.toDict(),
301 scaleColorKWargs=self.config.colorConfig.toDict(),
302 **(self.config.localContrastConfig.toDict()),
303 cieWhitePoint=tuple(self.config.cieWhitePoint),
304 psf=psf
if self.config.doPSFDeconcovlve
else None,
309 match self.config.arrayType:
323 assert True,
"This code path should be unreachable"
329 assert jointMask
is not None
331 for plug
in plugins.partial():
332 colorImage = plug(colorImage, jointMask, maskDict)
335 lsstMask =
Mask(width=jointMask.shape[1], height=jointMask.shape[0], planeDefs=maskDict)
336 lsstMask.array = jointMask
337 return Struct(outputRGB=colorImage.astype(dtype), outputRGBMask=lsstMask)
341 butlerQC: QuantumContext,
342 inputRefs: InputQuantizedConnection,
343 outputRefs: OutputQuantizedConnection,
345 imageRefs: list[DatasetRef] = inputRefs.inputCoadds
346 sortedImages = self.makeInputsFromRefs(imageRefs, butlerQC)
347 outputs = self.run(sortedImages)
348 butlerQC.put(outputs, outputRefs)
350 def makeInputsFromRefs(
351 self, refs: Iterable[DatasetRef], butler: Butler | QuantumContext
352 ) -> dict[str, Exposure]:
353 sortedImages: dict[str, Exposure] = {}
355 key: str = cast(str, ref.dataId[
"band"])
356 image = butler.get(ref)
357 sortedImages[key] = image
360 def makeInputsFromArrays(self, **kwargs) -> dict[int, DeferredDatasetHandle]:
363 for key, array
in kwargs.items():
364 temp[key] =
Exposure(
Box2I(Point2I(0, 0), Extent2I(*array.shape)), dtype=array.dtype)
365 temp[key].image.array[:] = array
367 return self.makeInputsFromExposures(**temp)
369 def makeInputsFromExposures(self, **kwargs) -> dict[int, DeferredDatasetHandle]:
371 for key, value
in kwargs.items():
372 sortedImages[key] = value
376class PrettyMosaicConnections(PipelineTaskConnections, dimensions=(
"tract",
"skymap")):
378 doc=
"Individual RGB images that are to go into the mosaic",
379 name=
"rgb_picture_array",
380 storageClass=
"NumpyArray",
381 dimensions=(
"tract",
"patch",
"skymap"),
387 doc=
"The skymap which the data has been mapped onto",
388 storageClass=
"SkyMap",
389 name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
390 dimensions=(
"skymap",),
393 inputRGBMask = Input(
394 doc=
"Individual RGB images that are to go into the mosaic",
395 name=
"rgb_picture_mask",
397 dimensions=(
"tract",
"patch",
"skymap"),
402 outputRGBMosaic = Output(
403 doc=
"A RGB mosaic created from the input data stored as a 3d array",
404 name=
"rgb_mosaic_array",
405 storageClass=
"NumpyArray",
406 dimensions=(
"tract",
"skymap"),
class PrettyMosaicConfig(PipelineTaskConfig, pipelineConnections=PrettyMosaicConnections):
    """Configuration for ``PrettyMosaicTask``.

    The task combines per-patch RGB arrays (``inputRGB``) into a single
    tract-level mosaic (``outputRGBMosaic``).
    """

    # run() divides patch bounding-box origins and extents by this factor and
    # only resizes when it is greater than one. Zero would raise
    # ZeroDivisionError and negative values would produce negative extents,
    # so reject anything below 1 when the config is set.
    binFactor = Field[int](
        doc="The factor to bin by when producing the mosaic",
        check=lambda x: x >= 1,
    )
414class PrettyMosaicTask(PipelineTask):
415 _DefaultName =
"prettyMosaicTask"
416 ConfigClass = PrettyMosaicConfig
422 inputRGB: Iterable[DeferredDatasetHandle],
424 inputRGBMask: Iterable[DeferredDatasetHandle],
431 for handle
in inputRGB:
432 dataId = handle.dataId
433 tractInfo: TractInfo = skyMap[dataId[
"tract"]]
434 patchInfo: PatchInfo = tractInfo[dataId[
"patch"]]
435 bbox = patchInfo.getOuterBBox()
438 tractMaps.append(tractInfo)
443 origin = newBox.getBegin()
444 for iterBox
in boxes:
445 localOrigin = iterBox.getBegin() - origin
446 localOrigin = Point2I(
447 x=int(np.floor(localOrigin.x / self.config.binFactor)),
448 y=int(np.floor(localOrigin.y / self.config.binFactor)),
450 localExtent = Extent2I(
451 x=int(np.floor(iterBox.getWidth() / self.config.binFactor)),
452 y=int(np.floor(iterBox.getHeight() / self.config.binFactor)),
454 tmpBox =
Box2I(localOrigin, localExtent)
455 modifiedBoxes.append(tmpBox)
456 boxes = modifiedBoxes
459 newBoxOrigin = Point2I(0, 0)
460 newBoxExtent = Extent2I(
461 x=int(np.floor(newBox.getWidth() / self.config.binFactor)),
462 y=int(np.floor(newBox.getHeight() / self.config.binFactor)),
464 newBox =
Box2I(newBoxOrigin, newBoxExtent)
467 self.imageHandle = tempfile.NamedTemporaryFile()
468 self.maskHandle = tempfile.NamedTemporaryFile()
469 consolidatedImage =
None
470 consolidatedMask =
None
475 for box, handle, handleMask, tractInfo
in zip(boxes, inputRGB, inputRGBMask, tractMaps):
477 rgbMask = handleMask.get()
478 maskDict = rgbMask.getMaskPlaneDict()
480 if consolidatedImage
is None:
481 consolidatedImage = np.memmap(
482 self.imageHandle.name,
484 shape=(newBox.getHeight(), newBox.getWidth(), 3),
487 if consolidatedMask
is None:
488 consolidatedMask = np.memmap(
489 self.maskHandle.name,
491 shape=(newBox.getHeight(), newBox.getWidth()),
492 dtype=rgbMask.array.dtype,
495 if self.config.binFactor > 1:
497 shape = tuple(box.getDimensions())[::-1]
502 fx=shape[0] / self.config.binFactor,
503 fy=shape[1] / self.config.binFactor,
505 rgbMask = cv2.resize(
506 rgbMask.array.astype(np.float32),
509 fx=shape[0] / self.config.binFactor,
510 fy=shape[1] / self.config.binFactor,
512 existing = ~np.all(consolidatedImage[*box.slices] == 0, axis=2)
513 if tmpImg
is None or tmpImg.shape != rgb.shape:
514 ramp = np.linspace(0, 1, tractInfo.patch_border * 2)
515 tmpImg = np.zeros(rgb.shape[:2])
516 tmpImg[: tractInfo.patch_border * 2, :] = np.repeat(
517 np.expand_dims(ramp, 1), tmpImg.shape[1], axis=1
520 tmpImg[-1 * tractInfo.patch_border * 2:, :] = np.repeat(
521 np.expand_dims(1 - ramp, 1), tmpImg.shape[1], axis=1
523 tmpImg[:, : tractInfo.patch_border * 2] = np.repeat(
524 np.expand_dims(ramp, 0), tmpImg.shape[0], axis=0
527 tmpImg[:, -1 * tractInfo.patch_border * 2:] = np.repeat(
528 np.expand_dims(1 - ramp, 0), tmpImg.shape[0], axis=0
530 tmpImg = np.repeat(np.expand_dims(tmpImg, 2), 3, axis=2)
532 consolidatedImage[*box.slices][~existing, :] = rgb[~existing, :]
533 consolidatedImage[*box.slices][existing, :] = (
534 tmpImg[existing] * rgb[existing]
535 + (1 - tmpImg[existing]) * consolidatedImage[*box.slices][existing, :]
538 tmpMask = np.zeros_like(rgbMask.array)
539 tmpMask[existing] = np.bitwise_or(
540 rgbMask.array[existing], consolidatedMask[*box.slices][existing]
542 tmpMask[~existing] = rgbMask.array[~existing]
543 consolidatedMask[*box.slices] = tmpMask
545 for plugin
in plugins.full():
546 if consolidatedImage
is not None and consolidatedMask
is not None:
547 consolidatedImage = plugin(consolidatedImage, consolidatedMask, maskDict)
550 if consolidatedImage
is None:
551 consolidatedImage = np.zeros((0, 0, 0), dtype=np.uint8)
553 return Struct(outputRGBMosaic=consolidatedImage)
557 butlerQC: QuantumContext,
558 inputRefs: InputQuantizedConnection,
559 outputRefs: OutputQuantizedConnection,
561 inputs = butlerQC.get(inputRefs)
562 outputs = self.run(**inputs)
563 butlerQC.put(outputs, outputRefs)
564 if hasattr(self,
"imageHandle"):
565 self.imageHandle.close()
566 if hasattr(self,
"maskHandle"):
567 self.maskHandle.close()
def makeInputsFromArrays(
    self, inputs: Iterable[tuple[Mapping[str, Any], NDArray]]
) -> Iterable[DeferredDatasetHandle]:
    """Wrap in-memory arrays as dataset handles.

    Each ``(dataId, array)`` pair becomes an ``InMemoryDatasetHandle``
    holding ``array``. The entries of ``dataId`` are forwarded as keyword
    arguments to the handle constructor. The result is presumably intended
    as the ``inputRGB`` argument of ``run``. In that case each ``dataId``
    must supply the ``"tract"`` and ``"patch"`` keys that ``run`` reads.

    Parameters
    ----------
    inputs : `Iterable` [`tuple` [`Mapping`, `numpy.ndarray`]]
        Pairs of data ID mapping and the array to wrap.

    Returns
    -------
    handles : `list` [`InMemoryDatasetHandle`]
        One handle per input pair, in input order.
    """
    return [
        InMemoryDatasetHandle(inMemoryDataset=pixels, **identifiers)
        for identifiers, pixels in inputs
    ]
A class to contain the data, WCS, and other information needed to describe an image of the sky.
Represent a 2-dimensional array of bitmask pixels.
An integer coordinate rectangle.