LSST Applications g063fba187b+cac8b7c890,g0f08755f38+6aee506743,g1653933729+a8ce1bb630,g168dd56ebc+a8ce1bb630,g1a2382251a+b4475c5878,g1dcb35cd9c+8f9bc1652e,g20f6ffc8e0+6aee506743,g217e2c1bcf+73dee94bd0,g28da252d5a+1f19c529b9,g2bbee38e9b+3f2625acfc,g2bc492864f+3f2625acfc,g3156d2b45e+6e55a43351,g32e5bea42b+1bb94961c2,g347aa1857d+3f2625acfc,g35bb328faa+a8ce1bb630,g3a166c0a6a+3f2625acfc,g3e281a1b8c+c5dd892a6c,g3e8969e208+a8ce1bb630,g414038480c+5927e1bc1e,g41af890bb2+8a9e676b2a,g7af13505b9+809c143d88,g80478fca09+6ef8b1810f,g82479be7b0+f568feb641,g858d7b2824+6aee506743,g89c8672015+f4add4ffd5,g9125e01d80+a8ce1bb630,ga5288a1d22+2903d499ea,gb58c049af0+d64f4d3760,gc28159a63d+3f2625acfc,gcab2d0539d+b12535109e,gcf0d15dbbd+46a3f46ba9,gda6a2b7d83+46a3f46ba9,gdaeeff99f8+1711a396fd,ge79ae78c31+3f2625acfc,gef2f8181fd+0a71e47438,gf0baf85859+c1f95f4921,gfa517265be+6aee506743,gfa999e8aa5+17cd334064,w.2024.51
LSST Data Management Base Package
Loading...
Searching...
No Matches
wavelet.py
Go to the documentation of this file.
# This file is part of scarlet_lite.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
21
# Names exported by a star-import of this module.
__all__ = [
    "starlet_transform",
    "starlet_reconstruction",
    "multiband_starlet_transform",
    "multiband_starlet_reconstruction",
    "get_multiresolution_support",
]
29
from dataclasses import dataclass
from typing import Callable, Sequence

import numpy as np
34
35
def bspline_convolve(image: np.ndarray, scale: int) -> np.ndarray:
    """Convolve an image with a bspline at a given scale.

    This uses the spline
    `h1d = np.array([1.0 / 16, 1.0 / 4, 3.0 / 8, 1.0 / 4, 1.0 / 16])`
    from Starck et al. 2011, applied separably: first along axis 0 and
    then along axis 1. Taps that would fall outside of the image are
    dropped (there is no padding or wrapping).

    Parameters
    ----------
    image:
        The 2D image or wavelet coefficients to convolve.
        The input array is not modified.
    scale:
        The wavelet scale for the convolution. This sets the
        spacing between adjacent pixels with the spline: the taps sit
        at offsets ``0``, ``+/-2**scale`` and ``+/-2**(scale + 1)``.

    Returns
    -------
    result:
        The result of convolving the `image` with the spline.
        Floating point (and complex) inputs keep their dtype; any other
        dtype (e.g. integer or boolean) is convolved as ``float64``.
    """
    # Casting the kernel to an integer dtype would truncate every tap to 0
    # and silently return an all-zero image, so non-inexact inputs are
    # promoted to float64 before convolving.
    if np.issubdtype(image.dtype, np.inexact):
        work_dtype = image.dtype
    else:
        work_dtype = np.float64
    data = image.astype(work_dtype, copy=False)

    # Filter for the starlet transform. Here bspline
    h1d = np.array([1.0 / 16, 1.0 / 4, 3.0 / 8, 1.0 / 4, 1.0 / 16], dtype=work_dtype)

    near = 2**scale
    far = 2 ** (scale + 1)

    # Pass along axis 0: every output pixel collects its neighbors at
    # offsets -far, -near, +near and +far (when they exist).
    col = data * h1d[2]
    col[far:] += data[:-far] * h1d[0]
    col[near:] += data[:-near] * h1d[1]
    col[:-near] += data[near:] * h1d[3]
    col[:-far] += data[far:] * h1d[4]

    # Pass along axis 1, using the output of the first pass.
    result = col * h1d[2]
    result[:, far:] += col[:, :-far] * h1d[0]
    result[:, near:] += col[:, :-near] * h1d[1]
    result[:, :-near] += col[:, near:] * h1d[3]
    result[:, :-far] += col[:, far:] * h1d[4]
    return result
78
79
80def get_starlet_scales(image_shape: Sequence[int], scales: int | None = None) -> int:
81 """Get the number of scales to use in the starlet transform.
82
83 Parameters
84 ----------
85 image_shape:
86 The 2D shape of the image that is being transformed
87 scales:
88 The number of scales to transform with starlets.
89 The total dimension of the starlet will have
90 `scales+1` dimensions, since it will also hold
91 the image at all scales higher than `scales`.
92
93 Returns
94 -------
95 result:
96 Number of scales, adjusted for the size of the image.
97 """
98 # Number of levels for the Starlet decomposition
99 max_scale = int(np.log2(np.min(image_shape[-2:]))) - 1
100 if (scales is None) or scales > max_scale:
101 scales = max_scale
102 return int(scales)
103
104
def starlet_transform(
    image: np.ndarray,
    scales: int | None = None,
    generation: int = 2,
    convolve2d: Callable | None = None,
) -> np.ndarray:
    """Perform a starlet transform, or 2nd gen starlet transform.

    Parameters
    ----------
    image:
        The image to transform into starlet coefficients.
    scales:
        The number of scale to transform with starlets.
        The total dimension of the starlet will have
        `scales+1` dimensions, since it will also hold
        the image at all scales higher than `scales`.
        The value is capped by `get_starlet_scales` and a negative
        count is treated as zero scales.
    generation:
        The generation of the transform.
        This must be `1` or `2`.
    convolve2d:
        The filter function to use to convolve the image
        with starlets in 2D. It is called as ``convolve2d(image, scale)``.
        Defaults to `bspline_convolve`.

    Returns
    -------
    starlet:
        The starlet dictionary for the input `image`, with shape
        ``(scales + 1,) + image.shape`` and the dtype of `image`.
        The last plane holds the smoothed image left over after
        all of the scales have been extracted.

    Raises
    ------
    ValueError
        If `image` is not 2D or `generation` is not 1 or 2.
    """
    if image.ndim != 2:
        raise ValueError(f"Image should be 2D, got {image.ndim}")
    if generation not in (1, 2):
        raise ValueError(f"generation should be 1 or 2, got {generation}")

    # A negative count would allocate an empty dictionary and then fail
    # when storing the coarse plane, so fall back to zero scales.
    scales = max(get_starlet_scales(image.shape, scales), 0)
    if convolve2d is None:
        convolve2d = bspline_convolve

    # wavelet set of coefficients.
    starlet = np.zeros((scales + 1,) + image.shape, dtype=image.dtype)
    smooth = image
    for j in range(scales):
        smoother = convolve2d(smooth, j)
        if generation == 2:
            # 2nd generation: subtract the twice-smoothed image.
            starlet[j] = smooth - convolve2d(smoother, j)
        else:
            starlet[j] = smooth - smoother
        # The once-smoothed image is the input to the next scale.
        smooth = smoother

    starlet[-1] = smooth
    return starlet
159
160
def multiband_starlet_transform(
    image: np.ndarray,
    scales: int | None = None,
    generation: int = 2,
    convolve2d: Callable | None = None,
) -> np.ndarray:
    """Perform a starlet transform of a multiband image.

    See `starlet_transform` for a description of the parameters.

    Parameters
    ----------
    image:
        The 3D image, with dimensions (bands, height, width).
        Each band is transformed independently.

    Returns
    -------
    wavelets:
        The starlet coefficients with shape
        ``(scales + 1, bands, height, width)`` and the dtype of `image`.

    Raises
    ------
    ValueError
        If `image` is not 3D or `generation` is not 1 or 2.
    """
    if image.ndim != 3:
        raise ValueError(f"Image should be 3D (bands, height, width), got shape {image.shape}")
    if generation not in (1, 2):
        raise ValueError(f"generation should be 1 or 2, got {generation}")
    # Resolve the number of scales once so that every band produces
    # the same number of planes; a negative count means zero scales.
    scales = max(get_starlet_scales(image.shape, scales), 0)

    wavelets = np.empty((scales + 1,) + image.shape, dtype=image.dtype)
    for b, band_image in enumerate(image):
        wavelets[:, b] = starlet_transform(
            band_image,
            scales=scales,
            generation=generation,
            convolve2d=convolve2d,
        )
    return wavelets
181
182
184 starlets: np.ndarray,
185 generation: int = 2,
186 convolve2d: Callable | None = None,
187) -> np.ndarray:
188 """Reconstruct an image from a dictionary of starlets
189
190 Parameters
191 ----------
192 starlets:
193 The starlet dictionary used to reconstruct the image
194 with dimension (scales+1, Ny, Nx).
195 generation:
196 The generation of the starlet transform (either ``1`` or ``2``).
197 convolve2d:
198 The filter function to use to convolve the image
199 with starlets in 2D.
200
201 Returns
202 -------
203 image:
204 The 2D image reconstructed from the input `starlet`.
205 """
206 if generation == 1:
207 return np.sum(starlets, axis=0)
208 if convolve2d is None:
209 convolve2d = bspline_convolve
210 scales = len(starlets) - 1
211
212 c = starlets[-1]
213 for i in range(1, scales + 1):
214 j = scales - i
215 cj = convolve2d(c, j)
216 c = cj + starlets[j]
217 return c
218
219
def multiband_starlet_reconstruction(
    starlets: np.ndarray,
    generation: int = 2,
    convolve2d: Callable | None = None,
) -> np.ndarray:
    """Reconstruct a multiband image.

    See `starlet_reconstruction` for a description of the
    remainder of the parameters.

    Parameters
    ----------
    starlets:
        The starlet coefficients with four dimensions. The first is the
        scale and the second is the band; each band is reconstructed
        independently from ``starlets[:, band]``.

    Returns
    -------
    result:
        The reconstructed image, with the shape of the last three
        dimensions of `starlets` and the dtype of `starlets`.

    Raises
    ------
    ValueError
        If `starlets` is not 4D.
    """
    if starlets.ndim != 4:
        raise ValueError(
            f"Starlets should be 4D (scales+1, bands, height, width), got shape {starlets.shape}"
        )
    _, bands, height, width = starlets.shape
    result = np.zeros((bands, height, width), dtype=starlets.dtype)
    for band in range(bands):
        result[band] = starlet_reconstruction(
            starlets[:, band],
            generation=generation,
            convolve2d=convolve2d,
        )
    return result
235
236
@dataclass
class MultiResolutionSupport:
    """The multi-resolution support of a set of starlet coefficients."""

    # Mask of the significant coefficients (non-zero where significant).
    support: np.ndarray
    # The sigma associated with each scale of the coefficients.
    sigma: np.ndarray
241
242
def get_multiresolution_support(
    image: np.ndarray,
    starlets: np.ndarray,
    sigma: float,
    sigma_scaling: float = 3,
    epsilon: float = 1e-1,
    max_iter: int = 20,
    image_type: str = "ground",
) -> MultiResolutionSupport:
    """Calculate the multi-resolution support for a
    dictionary of starlet coefficients.

    This is different for ground and space based telescopes.
    For space-based telescopes the procedure in Starck and Murtagh 1998
    iteratively calculates the multi-resolution support.
    For ground based images, where the PSF is much wider and there are no
    pixels with no signal at all scales, we use a modified method that
    estimates support at each scale independently.

    Parameters
    ----------
    image:
        The image to transform into starlet coefficients.
    starlets:
        The starlet dictionary used to reconstruct `image` with
        dimension (scales+1, Ny, Nx).
    sigma:
        The standard deviation of the `image`.
    sigma_scaling:
        The multiple of `sigma` to use to calculate significance.
        Coefficients `w` where `|w| > K*sigma_j`, where `sigma_j` is
        standard deviation at the jth scale, are considered significant.
    epsilon:
        The convergence criteria of the algorithm.
        Once `|new_sigma_j - sigma_j|/new_sigma_j < epsilon` the
        algorithm has completed.
    max_iter:
        Maximum number of iterations to fit `sigma_j` at each scale.
        This must be at least 1.
    image_type:
        The type of image that is being used.
        This should be "ground" for ground based images with wide PSFs or
        "space" for images from space-based telescopes with a narrow PSF.

    Returns
    -------
    result:
        A `MultiResolutionSupport` whose ``support`` is an integer mask
        with the significant coefficients in `starlets` set to 1, and
        whose ``sigma`` holds one value per plane of `starlets`.
        For "ground" images ``sigma`` is the fitted standard deviation
        at each scale; for "space" images it is the standard deviation
        of each scale of a transformed unit gaussian noise image.

    Raises
    ------
    ValueError
        If `image_type` is not "ground" or "space", or if `max_iter`
        is less than 1.
    """
    if image_type not in ("ground", "space"):
        raise ValueError(f"image_type must be 'ground' or 'space', got {image_type}")
    if max_iter < 1:
        # Without at least one pass there is no mask to return.
        raise ValueError(f"max_iter must be at least 1, got {max_iter}")

    if image_type == "space":
        # Calculate sigma_je, the standard deviation at
        # each scale due to gaussian noise.
        # The noise image is drawn from the unseeded global numpy RNG,
        # so sigma_je varies slightly from call to call.
        noise_img = np.random.normal(size=image.shape)
        noise_starlet = starlet_transform(noise_img, generation=1, scales=len(starlets) - 1)
        sigma_je = np.array([np.std(plane) for plane in noise_starlet])
        noise = image - starlets[-1]

        last_sigma_i = sigma
        # NOTE(review): the threshold below uses the input `sigma` on
        # every pass and never the updated `sigma_i`, so the mask does
        # not change between iterations. Confirm against
        # Starck and Murtagh 1998 before changing the numerics.
        for _ in range(max_iter):
            m = np.abs(starlets) > sigma_scaling * sigma * sigma_je[:, None, None]
            # Pixels that are insignificant at every scale
            s = np.sum(m, axis=0) == 0
            sigma_i = np.std(noise * s)
            if np.abs(sigma_i - last_sigma_i) / sigma_i < epsilon:
                break
            last_sigma_i = sigma_i
        sigma_j = sigma_je
    else:
        # Sigma to use for significance at each scale
        # Initially we use the input `sigma`
        sigma_j = np.ones((len(starlets),), dtype=image.dtype) * sigma
        last_sigma_j = sigma_j
        for _ in range(max_iter):
            m = np.abs(starlets) > sigma_scaling * sigma_j[:, None, None]
            # Take the standard deviation of the current
            # insignificant coeffs at each scale
            # (significant coefficients are zeroed, not removed).
            s = ~m
            sigma_j = np.std(starlets * s.astype(int), axis=(1, 2))
            # At lower scales all of the pixels may be significant,
            # so sigma is effectively zero. To avoid infinities we
            # only check the scales with non-zero sigma
            cut = sigma_j > 0
            if np.all(np.abs(sigma_j[cut] - last_sigma_j[cut]) / sigma_j[cut] < epsilon):
                break

            last_sigma_j = sigma_j
    return MultiResolutionSupport(support=m.astype(int), sigma=sigma_j)
334
335
def apply_wavelet_denoising(
    image: np.ndarray,
    sigma: float | None = None,
    sigma_scaling: float = 3,
    epsilon: float = 1e-1,
    max_iter: int = 20,
    image_type: str = "ground",
    positive: bool = True,
) -> np.ndarray:
    """Apply wavelet denoising

    Uses the algorithm and notation from Starck et al. 2011, section 4.1

    Parameters
    ----------
    image:
        The image to denoise
    sigma:
        The standard deviation of the image.
        If `None`, the median absolute deviation of `image` from its
        median is used (without any rescaling).
    sigma_scaling:
        The threshold in units of sigma to declare a coefficient significant
    epsilon:
        Convergence criteria for determining the support
    max_iter:
        The maximum number of iterations.
        This applies to both finding the support and the denoising loop.
    image_type:
        The type of image that is being used.
        This should be "ground" for ground based images with wide PSFs or
        "space" for images from space-based telescopes with a narrow PSF.
    positive:
        Whether or not the expected result should be positive.
        If `True`, negative pixels are set to zero after every iteration.

    Returns
    -------
    result:
        The resulting denoised image after `max_iter` iterations.
    """
    image_coeffs = starlet_transform(image)
    if sigma is None:
        sigma = np.median(np.absolute(image - np.median(image)))
    # Work on a copy so that `image_coeffs` stays untouched for the
    # residual calculation in the loop below.
    coeffs = image_coeffs.copy()
    support = get_multiresolution_support(
        image=image,
        starlets=coeffs,
        sigma=sigma,
        sigma_scaling=sigma_scaling,
        epsilon=epsilon,
        max_iter=max_iter,
        image_type=image_type,
    )
    x = starlet_reconstruction(coeffs)

    for _ in range(max_iter):
        coeffs = starlet_transform(x)
        # Add back the part of the residual that lies on the
        # significant coefficients only.
        x = x + starlet_reconstruction(support.support * (image_coeffs - coeffs))
        if positive:
            x[x < 0] = 0
    return x
np.ndarray starlet_transform(np.ndarray image, int|None scales=None, int generation=2, Callable|None convolve2d=None)
Definition wavelet.py:110
int get_starlet_scales(Sequence[int] image_shape, int|None scales=None)
Definition wavelet.py:80
np.ndarray multiband_starlet_reconstruction(np.ndarray starlets, int generation=2, Callable|None convolve2d=None)
Definition wavelet.py:224
np.ndarray apply_wavelet_denoising(np.ndarray image, float|None sigma=None, float sigma_scaling=3, float epsilon=1e-1, int max_iter=20, str image_type="ground", bool positive=True)
Definition wavelet.py:344
np.ndarray multiband_starlet_transform(np.ndarray image, int|None scales=None, int generation=2, Callable|None convolve2d=None)
Definition wavelet.py:166
np.ndarray bspline_convolve(np.ndarray image, int scale)
Definition wavelet.py:36
MultiResolutionSupport get_multiresolution_support(np.ndarray image, np.ndarray starlets, float sigma, float sigma_scaling=3, float epsilon=1e-1, int max_iter=20, str image_type="ground")
Definition wavelet.py:251
np.ndarray starlet_reconstruction(np.ndarray starlets, int generation=2, Callable|None convolve2d=None)
Definition wavelet.py:187