LSST Applications g063fba187b+cac8b7c890,g0f08755f38+6aee506743,g1653933729+a8ce1bb630,g168dd56ebc+a8ce1bb630,g1a2382251a+b4475c5878,g1dcb35cd9c+8f9bc1652e,g20f6ffc8e0+6aee506743,g217e2c1bcf+73dee94bd0,g28da252d5a+1f19c529b9,g2bbee38e9b+3f2625acfc,g2bc492864f+3f2625acfc,g3156d2b45e+6e55a43351,g32e5bea42b+1bb94961c2,g347aa1857d+3f2625acfc,g35bb328faa+a8ce1bb630,g3a166c0a6a+3f2625acfc,g3e281a1b8c+c5dd892a6c,g3e8969e208+a8ce1bb630,g414038480c+5927e1bc1e,g41af890bb2+8a9e676b2a,g7af13505b9+809c143d88,g80478fca09+6ef8b1810f,g82479be7b0+f568feb641,g858d7b2824+6aee506743,g89c8672015+f4add4ffd5,g9125e01d80+a8ce1bb630,ga5288a1d22+2903d499ea,gb58c049af0+d64f4d3760,gc28159a63d+3f2625acfc,gcab2d0539d+b12535109e,gcf0d15dbbd+46a3f46ba9,gda6a2b7d83+46a3f46ba9,gdaeeff99f8+1711a396fd,ge79ae78c31+3f2625acfc,gef2f8181fd+0a71e47438,gf0baf85859+c1f95f4921,gfa517265be+6aee506743,gfa999e8aa5+17cd334064,w.2024.51
LSST Data Management Base Package
Loading...
Searching...
No Matches
observation.py
Go to the documentation of this file.
1# This file is part of scarlet_lite.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21
22from __future__ import annotations
23
24__all__ = ["Observation", "convolve"]
25
26from typing import Any, cast
27
28import numpy as np
29import numpy.typing as npt
30
31from .bbox import Box
32from .fft import Fourier, _pad, centered
33from .fft import convolve as fft_convolve
34from .fft import match_kernel
35from .image import Image
36
37
38def get_filter_coords(filter_values: np.ndarray, center: tuple[int, int] | None = None) -> np.ndarray:
39 """Create filter coordinate grid needed for the apply filter function
40
41 Parameters
42 ----------
43 filter_values:
44 The 2D array of the filter to apply.
45 center:
46 The center (y,x) of the filter. If `center` is `None` then
47 `filter_values` must have an odd number of rows and columns
48 and the center will be set to the center of `filter_values`.
49
50 Returns
51 -------
52 coords:
53 The coordinates of the pixels in `filter_values`,
54 where the coordinates of the `center` pixel are `(0,0)`.
55 """
56 if filter_values.ndim != 2:
57 raise ValueError("`filter_values` must be 2D")
58 if center is None:
59 if filter_values.shape[0] % 2 == 0 or filter_values.shape[1] % 2 == 0:
60 msg = """Ambiguous center of the `filter_values` array,
61 you must use a `filter_values` array
62 with an odd number of rows and columns or
63 calculate `coords` on your own."""
64 raise ValueError(msg)
65 center = tuple([filter_values.shape[0] // 2, filter_values.shape[1] // 2]) # type: ignore
66 x = np.arange(filter_values.shape[1])
67 y = np.arange(filter_values.shape[0])
68 x, y = np.meshgrid(x, y)
69 x -= center[1]
70 y -= center[0]
71 coords = np.dstack([y, x])
72 return coords
73
74
def get_filter_bounds(coords: np.ndarray) -> tuple[int, int, int, int]:
    """Get the slices in x and y to apply a filter

    Parameters
    ----------
    coords:
        The coordinates of the filter,
        defined by `get_filter_coords`.

    Returns
    -------
    y_start, y_end, x_start, x_end:
        The start and end of each slice that is passed to `apply_filter`.
    """
    # Split the (N, 2) coordinate array into its y and x offset columns.
    dy = coords[:, 0]
    dx = coords[:, 1]
    zeros = np.zeros((len(coords),), dtype=int)
    # Negative offsets shift the start of the slice; positive offsets
    # shorten the end (hence the negated minimum).
    y_start = np.max([zeros, dy], axis=0)
    y_end = -np.min([zeros, dy], axis=0)
    x_start = np.max([zeros, dx], axis=0)
    x_end = -np.min([zeros, dx], axis=0)
    return y_start, y_end, x_start, x_end
97
98
def convolve(image: np.ndarray, psf: np.ndarray, bounds: tuple[int, int, int, int]):
    """Convolve an image with a PSF in real space

    Parameters
    ----------
    image:
        The multi-band image to convolve.
    psf:
        The psf to convolve the image with.
    bounds:
        The filter bounds required by the ``apply_filter`` C++ method,
        usually obtained by calling `get_filter_bounds`.

    Returns
    -------
    result:
        The convolved (bands, y, x) image.
    """
    from lsst.scarlet.lite.operators_pybind11 import apply_filter  # type: ignore

    result = np.empty(image.shape, dtype=image.dtype)
    for band in range(len(image)):
        img = image[band]
        # apply_filter fills `result[band]` in place with the convolution
        # of `img` and the flattened per-band PSF.
        # (This call line was dropped by the doc extraction and is
        # restored here; its signature matches the pybind11 prototype.)
        apply_filter(
            img,
            psf[band].reshape(-1),
            bounds[0],
            bounds[1],
            bounds[2],
            bounds[3],
            result[band],
        )
    return result
128
129
def _set_image_like(images: np.ndarray | Image, bands: tuple | None = None, bbox: Box | None = None) -> Image:
    """Ensure that an image-like array is cast appropriately as an image

    Parameters
    ----------
    images:
        The multiband image-like array to cast as an Image.
        If it already has `bands` and `bbox` properties then it is returned
        with no modifications.
    bands:
        The bands for the multiband-image.
        If `images` is a numpy array, this parameter is mandatory.
        If `images` is an `Image` and `bands` is not `None`,
        then `bands` is ignored.
    bbox:
        Bounding box containing the image.
        If `images` is a numpy array, this parameter is mandatory.
        If `images` is an `Image` and `bbox` is not `None`,
        then `bbox` is ignored.

    Returns
    -------
    images: Image
        The input images converted into an image.
    """
    # Already an Image: just sanity-check the bounding box and pass through.
    if isinstance(images, Image):
        if bbox is not None and images.bbox != bbox:
            raise ValueError(f"Bounding boxes {images.bbox} and {bbox} do not agree")
        return images

    # Raw array: derive a bounding box from the spatial shape when none
    # was supplied, then wrap the array in an Image.
    if bbox is None:
        bbox = Box(images.shape[-2:])
    origin = cast(tuple[int, int], bbox.origin)
    return Image(images, bands=bands, yx0=origin)
164
165
167 """A single observation
168
169 This class contains all of the observed images and derived
170 properties, like PSFs, variance map, and weight maps,
171 required for most optimizers.
172 This includes methods to match a scarlet model PSF to the observed PSF
173 in each band.
174
175 Notes
176 -----
177 This is effectively a combination of the `Observation` and
178 `Renderer` class from scarlet main, greatly simplified due
179 to the assumptions that the observations are all resampled
180 onto the same pixel grid and that the `images` contain all
181 of the information for all of the model bands.
182
183 Parameters
184 ----------
185 images:
186 (bands, y, x) array of observed images.
187 variance:
188 (bands, y, x) array of variance for each image pixel.
189 weights:
190 (bands, y, x) array of weights to use when calculating the
191 likelihood of each pixel.
192 psfs:
193 (bands, y, x) array of the PSF image in each band.
194 model_psf:
195 (bands, y, x) array of the model PSF image in each band.
196 If `model_psf` is `None` then convolution is performed,
197 which should only be done when the observation is a
198 PSF matched coadd, and the scarlet model has the same PSF.
199 noise_rms:
200 Per-band average noise RMS. If `noise_rms` is `None` then the mean
201 of the sqrt of the variance is used.
202 bbox:
203 The bounding box containing the model. If `bbox` is `None` then
204 a `Box` is created that is the shape of `images` with an origin
205 at `(0, 0)`.
206 padding:
207 Padding to use when performing an FFT convolution.
208 convolution_mode:
209 The method of convolution. This should be either "fft" or "real".
210 """
211
213 self,
214 images: np.ndarray | Image,
215 variance: np.ndarray | Image,
216 weights: np.ndarray | Image,
217 psfs: np.ndarray,
218 model_psf: np.ndarray | None = None,
219 noise_rms: np.ndarray | None = None,
220 bbox: Box | None = None,
221 bands: tuple | None = None,
222 padding: int = 3,
223 convolution_mode: str = "fft",
224 ):
225 # Convert the images to a multi-band `Image` and use the resulting
226 # bbox and bands.
227 images = _set_image_like(images, bands, bbox)
228 bands = images.bands
229 bbox = images.bbox
230 self.images = images
231 self.variance = _set_image_like(variance, bands, bbox)
232 self.weights = _set_image_like(weights, bands, bbox)
233 # make sure that the images and psfs have the same dtype
234 if psfs.dtype != images.dtype:
235 psfs = psfs.astype(images.dtype)
236 self.psfs = psfs
237
238 if convolution_mode not in [
239 "fft",
240 "real",
241 ]:
242 raise ValueError("convolution_mode must be either 'fft' or 'real'")
243 self.mode = convolution_mode
244 if noise_rms is None:
245 noise_rms = np.array(np.mean(np.sqrt(variance.data), axis=(1, 2)))
246 self.noise_rms = noise_rms
247
248 # Create a difference kernel to convolve the model to the PSF
249 # in each band
250 self.model_psf = model_psf
251 self.padding = padding
252 if model_psf is not None:
253 if model_psf.dtype != images.dtype:
254 self.model_psf = model_psf.astype(images.dtype)
255 self.diff_kernel: Fourier | None = cast(Fourier, match_kernel(psfs, model_psf, padding=padding))
256 # The gradient of a convolution is another convolution,
257 # but with the flipped and transposed kernel.
258 diff_img = self.diff_kernel.image
259 self.grad_kernel: Fourier | None = Fourier(diff_img[:, ::-1, ::-1])
260 else:
261 self.diff_kernel = None
262 self.grad_kernel = None
263
264 self._convolution_bounds: tuple[int, int, int, int] | None = None
265
266 @property
267 def bands(self) -> tuple:
268 """The bands in the observations."""
269 return self.images.bands
270
271 @property
272 def bbox(self) -> Box:
273 """The bounding box for the full observation."""
274 return self.images.bbox
275
    def convolve(self, image: Image, mode: str | None = None, grad: bool = False) -> Image:
        """Convolve the model into the observed seeing in each band.

        Parameters
        ----------
        image:
            The 3D image to convolve.
        mode:
            The convolution mode to use.
            This should be "real" or "fft" or `None`,
            where `None` will use the default `convolution_mode`
            specified during init.
        grad:
            Whether this is a backward gradient convolution
            (`grad==True`) or a pure convolution with the PSF.

        Returns
        -------
        result:
            The convolved image.

        Raises
        ------
        ValueError:
            If `mode` is neither "fft" nor "real" (after the `None`
            default is resolved).
        """
        # The gradient kernel is the flipped difference kernel built
        # in `__init__`.
        if grad:
            kernel = self.grad_kernel
        else:
            kernel = self.diff_kernel

        # A `None` kernel means no `model_psf` was provided (the
        # observation is already PSF matched), so convolution is a no-op.
        if kernel is None:
            return image

        if mode is None:
            mode = self.mode
        if mode == "fft":
            # Convolve in Fourier space over the two spatial axes only.
            result = fft_convolve(
                Fourier(image.data),
                kernel,
                axes=(1, 2),
                return_fourier=False,
            )
        elif mode == "real":
            # Size differences between the image and the kernel in y and x.
            dy = image.shape[1] - kernel.image.shape[1]
            dx = image.shape[2] - kernel.image.shape[2]
            if dy < 0 or dx < 0:
                # The image needs to be padded because it is smaller than
                # the psf kernel
                _image = image.data
                newshape = list(_image.shape)
                if dy < 0:
                    newshape[1] += kernel.image.shape[1] - image.shape[1]
                if dx < 0:
                    newshape[2] += kernel.image.shape[2] - image.shape[2]
                _image = _pad(_image, newshape)
                result = convolve(_image, kernel.image, self.convolution_bounds)
                # Trim the padded result back down to the input shape.
                result = centered(result, image.data.shape)  # type: ignore
            else:
                result = convolve(image.data, kernel.image, self.convolution_bounds)
        else:
            raise ValueError(f"mode must be either 'fft' or 'real', got {mode}")
        # Preserve the band labels and origin of the input image.
        return Image(cast(np.ndarray, result), bands=image.bands, yx0=image.yx0)
334
335 def log_likelihood(self, model: Image) -> float:
336 """Calculate the log likelihood of the given model
337
338 Parameters
339 ----------
340 model:
341 Model to compare with the observed images.
342
343 Returns
344 -------
345 result:
346 The log-likelihood of the given model.
347 """
348 result = 0.5 * -np.sum((self.weights * (self.images - model) ** 2).data)
349 return result
350
351 @property
352 def shape(self) -> tuple[int, int, int]:
353 """The shape of the images, variance, etc."""
354 return cast(tuple[int, int, int], self.images.shape)
355
356 @property
357 def n_bands(self) -> int:
358 """The number of bands in the observation"""
359 return self.images.shape[0]
360
361 @property
362 def dtype(self) -> npt.DTypeLike:
363 """The dtype of the observation is the dtype of the images"""
364 return self.images.dtype
365
366 @property
367 def convolution_bounds(self) -> tuple[int, int, int, int]:
368 """Build the slices needed for convolution in real space"""
369 if self._convolution_bounds is None:
370 coords = get_filter_coords(cast(Fourier, self.diff_kernel).image[0])
371 self._convolution_bounds = get_filter_bounds(coords.reshape(-1, 2))
372 return self._convolution_bounds
373
    @staticmethod
    def empty(
        bands: tuple[Any, ...], psfs: np.ndarray, model_psf: np.ndarray, bbox: Box, dtype: npt.DTypeLike
    ) -> Observation:
        """Create an `Observation` with all-zero images, variance, weights
        and noise RMS.

        Parameters
        ----------
        bands:
            The bands for the empty observation.
        psfs:
            (bands, y, x) array of the PSF image in each band.
        model_psf:
            The model PSF passed through to the `Observation` constructor.
        bbox:
            The bounding box that defines the spatial shape of the
            zero-filled image arrays.
        dtype:
            The numpy dtype used for all of the zero-filled arrays.

        Returns
        -------
        observation:
            An `Observation` whose data arrays are all zeros, using
            "real" convolution mode.
        """
        dummy_image = np.zeros((len(bands),) + bbox.shape, dtype=dtype)

        # The same zero array serves as images, variance and weights.
        return Observation(
            images=dummy_image,
            variance=dummy_image,
            weights=dummy_image,
            psfs=psfs,
            model_psf=model_psf,
            noise_rms=np.zeros((len(bands),), dtype=dtype),
            bbox=bbox,
            bands=bands,
            convolution_mode="real",
        )
A class to represent a 2-dimensional array of pixels.
Definition Image.h:51
__init__(self, np.ndarray|Image images, np.ndarray|Image variance, np.ndarray|Image weights, np.ndarray psfs, np.ndarray|None model_psf=None, np.ndarray|None noise_rms=None, Box|None bbox=None, tuple|None bands=None, int padding=3, str convolution_mode="fft")
Observation empty(tuple[Any] bands, np.ndarray psfs, np.ndarray model_psf, Box bbox, npt.DTypeLike dtype)
float log_likelihood(self, Image model)
tuple[int, int, int, int] convolution_bounds(self)
Image convolve(self, Image image, str|None mode=None, bool grad=False)
convolve(np.ndarray image, np.ndarray psf, tuple[int, int, int, int] bounds)
Image _set_image_like(np.ndarray|Image images, tuple|None bands=None, Box|None bbox=None)
np.ndarray get_filter_coords(np.ndarray filter_values, tuple[int, int]|None center=None)
tuple[int, int, int, int] get_filter_bounds(np.ndarray coords)
void apply_filter(Eigen::Ref< const M > image, Eigen::Ref< const V > values, Eigen::Ref< const IndexVector > y_start, Eigen::Ref< const IndexVector > y_end, Eigen::Ref< const IndexVector > x_start, Eigen::Ref< const IndexVector > x_end, Eigen::Ref< M, 0, Eigen::Stride< Eigen::Dynamic, Eigen::Dynamic > > result)