LSST Applications g0603fd7c41+022847dfd1,g0aad566f14+f45185db35,g180d380827+40e913b07a,g2079a07aa2+86d27d4dc4,g2305ad1205+696e5f3872,g2bbee38e9b+047b288a59,g337abbeb29+047b288a59,g33d1c0ed96+047b288a59,g3a166c0a6a+047b288a59,g3d1719c13e+f45185db35,g3de15ee5c7+5201731f0d,g487adcacf7+19f9b77d7d,g50ff169b8f+96c6868917,g52b1c1532d+585e252eca,g591dd9f2cf+248b16177b,g63cd9335cc+585e252eca,g858d7b2824+f45185db35,g88963caddf+0cb8e002cc,g991b906543+f45185db35,g99cad8db69+1747e75aa3,g9b9dfce982+78139cbddb,g9ddcbc5298+9a081db1e4,ga1e77700b3+a912195c07,gae0086650b+585e252eca,gb0e22166c9+60f28cb32d,gb3a676b8dc+b4feba26a1,gb4b16eec92+f82f04eb54,gba4ed39666+c2a2e4ac27,gbb8dafda3b+215b19b0ab,gc120e1dc64+b0284b5341,gc28159a63d+047b288a59,gc3e9b769f7+dcad4ace9a,gcf0d15dbbd+78139cbddb,gdaeeff99f8+f9a426f77a,ge79ae78c31+047b288a59,w.2024.19
LSST Data Management Base Package
Loading...
Searching...
No Matches
observation.py
Go to the documentation of this file.
1# This file is part of scarlet_lite.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21
22from __future__ import annotations
23
24__all__ = ["Observation", "convolve"]
25
26from typing import Any, cast
27
28import numpy as np
29import numpy.typing as npt
30
31from .bbox import Box
32from .fft import Fourier, _pad, centered
33from .fft import convolve as fft_convolve
34from .fft import match_kernel
35from .image import Image
36
37
38def get_filter_coords(filter_values: np.ndarray, center: tuple[int, int] | None = None) -> np.ndarray:
39 """Create filter coordinate grid needed for the apply filter function
40
41 Parameters
42 ----------
43 filter_values:
44 The 2D array of the filter to apply.
45 center:
46 The center (y,x) of the filter. If `center` is `None` then
47 `filter_values` must have an odd number of rows and columns
48 and the center will be set to the center of `filter_values`.
49
50 Returns
51 -------
52 coords:
53 The coordinates of the pixels in `filter_values`,
54 where the coordinates of the `center` pixel are `(0,0)`.
55 """
56 if filter_values.ndim != 2:
57 raise ValueError("`filter_values` must be 2D")
58 if center is None:
59 if filter_values.shape[0] % 2 == 0 or filter_values.shape[1] % 2 == 0:
60 msg = """Ambiguous center of the `filter_values` array,
61 you must use a `filter_values` array
62 with an odd number of rows and columns or
63 calculate `coords` on your own."""
64 raise ValueError(msg)
65 center = tuple([filter_values.shape[0] // 2, filter_values.shape[1] // 2]) # type: ignore
66 center = cast(tuple[int, int], center)
67 x = np.arange(filter_values.shape[1])
68 y = np.arange(filter_values.shape[0])
69 x, y = np.meshgrid(x, y)
70 x -= center[1]
71 y -= center[0]
72 coords = np.dstack([y, x])
73 return coords
74
75
def get_filter_bounds(coords: np.ndarray) -> tuple[int, int, int, int]:
    """Compute the per-pixel slice bounds needed to apply a filter.

    Parameters
    ----------
    coords:
        ``(N, 2)`` array of ``(y, x)`` filter offsets, typically the output
        of `get_filter_coords` reshaped with ``reshape(-1, 2)``.

    Returns
    -------
    y_start, y_end, x_start, x_end:
        Arrays of length ``N``. The ``*_start`` arrays hold the positive
        part of each offset (``max(offset, 0)``) and the ``*_end`` arrays
        hold the magnitude of the negative part (``-min(offset, 0)``).
        These are the bounds passed to ``apply_filter``.
    """
    # Comparing against an integer zero vector keeps the same dtype
    # promotion as stacking the zeros with the offsets.
    zeros = np.zeros((len(coords),), dtype=int)
    dy = coords[:, 0]
    dx = coords[:, 1]
    return (
        np.maximum(dy, zeros),
        -np.minimum(dy, zeros),
        np.maximum(dx, zeros),
        -np.minimum(dx, zeros),
    )
98
99
def convolve(image: np.ndarray, psf: np.ndarray, bounds: tuple[int, int, int, int]) -> np.ndarray:
    """Convolve a multi-band image with a per-band kernel in real space.

    Parameters
    ----------
    image:
        The multi-band image to convolve. The first axis indexes the band.
    psf:
        The kernel to convolve the image with, one kernel per band
        (``psf[band]`` is used for ``image[band]``).
    bounds:
        The filter bounds ``(y_start, y_end, x_start, x_end)`` required by
        the ``apply_filter`` C++ method, usually obtained by calling
        `get_filter_bounds`.

    Returns
    -------
    result:
        Array with the same shape and dtype as `image` holding the
        convolved image in each band.
    """
    # Function-scope import: the compiled extension is only loaded when
    # real-space convolution is actually called.
    from lsst.scarlet.lite.operators_pybind11 import apply_filter  # type: ignore

    y_start, y_end, x_start, x_end = bounds
    result = np.empty(image.shape, dtype=image.dtype)
    for band, img in enumerate(image):
        # `apply_filter` takes the kernel values as a flat vector and
        # writes its output into `result[band]`, a view into `result`.
        apply_filter(
            img,
            psf[band].reshape(-1),
            y_start,
            y_end,
            x_start,
            x_end,
            result[band],
        )
    return result
129
130
def _set_image_like(images: np.ndarray | Image, bands: tuple | None = None, bbox: Box | None = None) -> Image:
    """Ensure that an image-like array is cast appropriately as an `Image`.

    Parameters
    ----------
    images:
        The multi-band image-like array to cast as an `Image`.
        If it is already an `Image` it is returned unmodified.
    bands:
        The bands for the multi-band image. Passed to `Image` when `images`
        is a numpy array; ignored when `images` is already an `Image`.
    bbox:
        Bounding box containing the image.
        When `images` is a numpy array, only ``bbox.origin`` is used (as the
        image ``yx0``); if `bbox` is `None`, a `Box` built from the shape of
        the last two axes of `images` is used instead.
        When `images` is already an `Image`, `bbox` (if given) must equal
        ``images.bbox``.

    Returns
    -------
    images:
        The input converted into an `Image`.

    Raises
    ------
    ValueError
        If `images` is an `Image` and `bbox` is given but does not equal
        ``images.bbox``.
    """
    if isinstance(images, Image):
        # Already an Image: only verify that a requested bbox agrees.
        # `bands` is intentionally not consulted in this case.
        if bbox is not None and images.bbox != bbox:
            raise ValueError(f"Bounding boxes {images.bbox} and {bbox} do not agree")
        return images

    if bbox is None:
        bbox = Box(images.shape[-2:])
    return Image(images, bands=bands, yx0=cast(tuple[int, int], bbox.origin))
165
166
class Observation:
    """A single observation

    This class contains all of the observed images and derived
    properties, like PSFs, variance map, and weight maps,
    required for most optimizers.
    This includes methods to match a scarlet model PSF to the observed PSF
    in each band.

    Notes
    -----
    This is effectively a combination of the `Observation` and
    `Renderer` class from scarlet main, greatly simplified due
    to the assumptions that the observations are all resampled
    onto the same pixel grid and that the `images` contain all
    of the information for all of the model bands.

    Parameters
    ----------
    images:
        (bands, y, x) array of observed images.
    variance:
        (bands, y, x) array of variance for each image pixel.
    weights:
        (bands, y, x) array of weights to use when calculating the
        likelihood of each pixel.
    psfs:
        (bands, y, x) array of the PSF image in each band.
    model_psf:
        (bands, y, x) array of the model PSF image in each band.
        If `model_psf` is `None` then no convolution is performed,
        which should only be done when the observation is a
        PSF matched coadd, and the scarlet model has the same PSF.
    noise_rms:
        Per-band average noise RMS. If `noise_rms` is `None` then the mean
        of the sqrt of the variance is used.
    bbox:
        The bounding box containing the model. If `bbox` is `None` then
        a `Box` is created that is the shape of `images` with an origin
        at `(0, 0)`.
    bands:
        The bands of the observation. Only used when `images` is a numpy
        array; when `images` is an `Image` its own bands are used.
    padding:
        Padding to use when performing an FFT convolution.
    convolution_mode:
        The method of convolution. This should be either "fft" or "real".
    """

    def __init__(
        self,
        images: np.ndarray | Image,
        variance: np.ndarray | Image,
        weights: np.ndarray | Image,
        psfs: np.ndarray,
        model_psf: np.ndarray | None = None,
        noise_rms: np.ndarray | None = None,
        bbox: Box | None = None,
        bands: tuple | None = None,
        padding: int = 3,
        convolution_mode: str = "fft",
    ):
        # Convert the images to a multi-band `Image` and use the resulting
        # bbox and bands for the variance and weights.
        images = _set_image_like(images, bands, bbox)
        bands = images.bands
        bbox = images.bbox
        self.images = images
        self.variance = _set_image_like(variance, bands, bbox)
        self.weights = _set_image_like(weights, bands, bbox)
        # Make sure that the images and psfs have the same dtype
        if psfs.dtype != images.dtype:
            psfs = psfs.astype(images.dtype)
        self.psfs = psfs

        if convolution_mode not in ("fft", "real"):
            raise ValueError("convolution_mode must be either 'fft' or 'real'")
        self.mode = convolution_mode
        if noise_rms is None:
            # Use the normalized `Image` rather than the raw argument:
            # on a plain ndarray `.data` is a memoryview buffer, not the array.
            noise_rms = np.array(np.mean(np.sqrt(self.variance.data), axis=(1, 2)))
        self.noise_rms = noise_rms

        # Create a difference kernel to convolve the model to the PSF
        # in each band
        self.padding = padding
        if model_psf is not None:
            # Cast to the image dtype (as done for `psfs` above) and use the
            # cast array both on `self` and when building the kernel.
            if model_psf.dtype != images.dtype:
                model_psf = model_psf.astype(images.dtype)
            self.model_psf = model_psf
            self.diff_kernel: Fourier | None = cast(Fourier, match_kernel(psfs, model_psf, padding=padding))
            # The gradient of a convolution is another convolution,
            # but with the flipped and transposed kernel.
            diff_img = self.diff_kernel.image
            self.grad_kernel: Fourier | None = Fourier(diff_img[:, ::-1, ::-1])
        else:
            self.model_psf = None
            self.diff_kernel = None
            self.grad_kernel = None

        # Lazily computed by the `convolution_bounds` property.
        self._convolution_bounds: tuple[int, int, int, int] | None = None

    @property
    def bands(self) -> tuple:
        """The bands in the observations."""
        return self.images.bands

    @property
    def bbox(self) -> Box:
        """The bounding box for the full observation."""
        return self.images.bbox

    def convolve(self, image: Image, mode: str | None = None, grad: bool = False) -> Image:
        """Convolve the model into the observed seeing in each band.

        Parameters
        ----------
        image:
            The 3D image to convolve.
        mode:
            The convolution mode to use.
            This should be "real" or "fft" or `None`,
            where `None` will use the default `convolution_mode`
            specified during init.
        grad:
            Whether this is a backward gradient convolution
            (`grad==True`) or a pure convolution with the PSF.

        Returns
        -------
        result:
            The convolved image. If no difference kernel exists
            (`model_psf` was `None`), `image` is returned unchanged.

        Raises
        ------
        ValueError
            If `mode` is not "fft", "real" or `None`.
        """
        if grad:
            kernel = self.grad_kernel
        else:
            kernel = self.diff_kernel

        if kernel is None:
            return image

        if mode is None:
            mode = self.mode
        if mode == "fft":
            result = fft_convolve(
                Fourier(image.data),
                kernel,
                axes=(1, 2),
                return_fourier=False,
            )
        elif mode == "real":
            dy = image.shape[1] - kernel.image.shape[1]
            dx = image.shape[2] - kernel.image.shape[2]
            if dy < 0 or dx < 0:
                # The image needs to be padded because it is smaller than
                # the psf kernel
                _image = image.data
                newshape = list(_image.shape)
                if dy < 0:
                    newshape[1] += kernel.image.shape[1] - image.shape[1]
                if dx < 0:
                    newshape[2] += kernel.image.shape[2] - image.shape[2]
                _image = _pad(_image, newshape)
                result = convolve(_image, kernel.image, self.convolution_bounds)
                # Trim the padded result back to the input shape.
                result = centered(result, image.data.shape)  # type: ignore
            else:
                result = convolve(image.data, kernel.image, self.convolution_bounds)
        else:
            raise ValueError(f"mode must be either 'fft' or 'real', got {mode}")
        return Image(cast(np.ndarray, result), bands=image.bands, yx0=image.yx0)

    def log_likelihood(self, model: Image) -> float:
        """Calculate the log likelihood of the given model

        Parameters
        ----------
        model:
            Model to compare with the observed images.

        Returns
        -------
        result:
            The log-likelihood of the given model, computed as
            ``-0.5 * sum(weights * (images - model) ** 2)``.
        """
        result = 0.5 * -np.sum((self.weights * (self.images - model) ** 2).data)
        return result

    @property
    def shape(self) -> tuple[int, int, int]:
        """The shape of the images, variance, etc."""
        return cast(tuple[int, int, int], self.images.shape)

    @property
    def n_bands(self) -> int:
        """The number of bands in the observation"""
        return self.images.shape[0]

    @property
    def dtype(self) -> npt.DTypeLike:
        """The dtype of the observation is the dtype of the images"""
        return self.images.dtype

    @property
    def convolution_bounds(self) -> tuple[int, int, int, int]:
        """Build the slices needed for convolution in real space.

        The bounds are derived from the first band of the difference
        kernel and cached after the first call. Only valid when a
        difference kernel exists (i.e. `model_psf` was provided).
        """
        if self._convolution_bounds is None:
            coords = get_filter_coords(cast(Fourier, self.diff_kernel).image[0])
            self._convolution_bounds = get_filter_bounds(coords.reshape(-1, 2))
        return self._convolution_bounds

    @staticmethod
    def empty(
        bands: tuple[Any, ...], psfs: np.ndarray, model_psf: np.ndarray, bbox: Box, dtype: npt.DTypeLike
    ) -> Observation:
        """Create an `Observation` with zero-valued pixel data.

        Parameters
        ----------
        bands:
            The bands of the observation.
        psfs:
            (bands, y, x) array of the PSF image in each band.
        model_psf:
            The model PSF used to build the difference kernel.
        bbox:
            The bounding box of the observation; its shape sets the
            spatial shape of the zero images.
        dtype:
            The dtype of the zero images and noise RMS.

        Returns
        -------
        observation:
            An `Observation` whose images, variance, weights and noise RMS
            are all zero, using real-space convolution.
        """
        dummy_image = np.zeros((len(bands),) + bbox.shape, dtype=dtype)

        # NOTE(review): the same array is passed for images, variance and
        # weights; unless `Image` copies its input they share memory.
        return Observation(
            images=dummy_image,
            variance=dummy_image,
            weights=dummy_image,
            psfs=psfs,
            model_psf=model_psf,
            noise_rms=np.zeros((len(bands),), dtype=dtype),
            bbox=bbox,
            bands=bands,
            convolution_mode="real",
        )
A class to represent a 2-dimensional array of pixels.
Definition Image.h:51
__init__(self, np.ndarray|Image images, np.ndarray|Image variance, np.ndarray|Image weights, np.ndarray psfs, np.ndarray|None model_psf=None, np.ndarray|None noise_rms=None, Box|None bbox=None, tuple|None bands=None, int padding=3, str convolution_mode="fft")
Observation empty(tuple[Any] bands, np.ndarray psfs, np.ndarray model_psf, Box bbox, npt.DTypeLike dtype)
float log_likelihood(self, Image model)
tuple[int, int, int, int] convolution_bounds(self)
Image convolve(self, Image image, str|None mode=None, bool grad=False)
convolve(np.ndarray image, np.ndarray psf, tuple[int, int, int, int] bounds)
Image _set_image_like(np.ndarray|Image images, tuple|None bands=None, Box|None bbox=None)
np.ndarray get_filter_coords(np.ndarray filter_values, tuple[int, int]|None center=None)
tuple[int, int, int, int] get_filter_bounds(np.ndarray coords)
void apply_filter(Eigen::Ref< const M > image, Eigen::Ref< const V > values, Eigen::Ref< const IndexVector > y_start, Eigen::Ref< const IndexVector > y_end, Eigen::Ref< const IndexVector > x_start, Eigen::Ref< const IndexVector > x_end, Eigen::Ref< M, 0, Eigen::Stride< Eigen::Dynamic, Eigen::Dynamic > > result)