LSST Applications g0f08755f38+c89d42e150,g1635faa6d4+b6cf076a36,g1653933729+a8ce1bb630,g1a0ca8cf93+4c08b13bf7,g28da252d5a+f33f8200ef,g29321ee8c0+0187be18b1,g2bbee38e9b+9634bc57db,g2bc492864f+9634bc57db,g2cdde0e794+c2c89b37c4,g3156d2b45e+41e33cbcdc,g347aa1857d+9634bc57db,g35bb328faa+a8ce1bb630,g3a166c0a6a+9634bc57db,g3e281a1b8c+9f2c4e2fc3,g414038480c+077ccc18e7,g41af890bb2+e740673f1a,g5fbc88fb19+17cd334064,g7642f7d749+c89d42e150,g781aacb6e4+a8ce1bb630,g80478fca09+f8b2ab54e1,g82479be7b0+e2bd23ab8b,g858d7b2824+c89d42e150,g9125e01d80+a8ce1bb630,g9726552aa6+10f999ec6a,ga5288a1d22+065360aec4,gacf8899fa4+9553554aa7,gae0086650b+a8ce1bb630,gb58c049af0+d64f4d3760,gbd46683f8f+ac57cbb13d,gc28159a63d+9634bc57db,gcf0d15dbbd+e37acf7834,gda3e153d99+c89d42e150,gda6a2b7d83+e37acf7834,gdaeeff99f8+1711a396fd,ge2409df99d+cb1e6652d6,ge79ae78c31+9634bc57db,gf0baf85859+147a0692ba,gf3967379c6+02b11634a5,w.2024.45
LSST Data Management Base Package
Loading...
Searching...
No Matches
wavelet.py
Go to the documentation of this file.
1# This file is part of scarlet_lite.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21
# Public names exported by a star-import of this module.
__all__ = [
    # single-band forward / inverse transform
    "starlet_transform",
    "starlet_reconstruction",
    # multi-band forward / inverse transform
    "multiband_starlet_transform",
    "multiband_starlet_reconstruction",
    # significance mask for starlet coefficients
    "get_multiresolution_support",
]
29
30from typing import Callable, Sequence
31
32import numpy as np
33
34
def bspline_convolve(image: np.ndarray, scale: int) -> np.ndarray:
    """Smooth an image with the separable B3-spline kernel at a given scale.

    The 1D kernel is ``[1/16, 1/4, 3/8, 1/4, 1/16]`` (Starck et al. 2011).
    It is applied first along axis 0 and then along axis 1. At ``scale`` the
    kernel taps are spread ``2**scale`` pixels apart (the "a trous" scheme).
    Taps that fall outside the image contribute nothing (no padding).

    Parameters
    ----------
    image:
        The 2D image or wavelet coefficients to convolve.
    scale:
        The wavelet scale for the convolution. Adjacent kernel taps are
        ``2**scale`` pixels apart.

    Returns
    -------
    result:
        A new array holding `image` convolved with the spline.
    """
    taps = np.array([1.0 / 16, 1.0 / 4, 3.0 / 8, 1.0 / 4, 1.0 / 16])
    near = 2**scale  # offset of the inner pair of taps
    far = 2 ** (scale + 1)  # offset of the outer pair of taps

    def smooth_axis(data: np.ndarray, axis: int) -> np.ndarray:
        """Apply the 1D kernel along ``axis`` (0 or 1) of ``data``."""

        def on_axis(span: slice) -> tuple:
            # Index tuple that selects ``span`` along ``axis`` only.
            return (span,) if axis == 0 else (slice(None), span)

        out = data * taps[2]
        # Neighbours at offsets -far, -near, +near, +far from each output pixel.
        out[on_axis(slice(far, None))] += data[on_axis(slice(None, -far))] * taps[0]
        out[on_axis(slice(near, None))] += data[on_axis(slice(None, -near))] * taps[1]
        out[on_axis(slice(None, -near))] += data[on_axis(slice(near, None))] * taps[3]
        out[on_axis(slice(None, -far))] += data[on_axis(slice(far, None))] * taps[4]
        return out

    return smooth_axis(smooth_axis(image, 0), 1)
77
78
79def get_starlet_scales(image_shape: Sequence[int], scales: int | None = None) -> int:
80 """Get the number of scales to use in the starlet transform.
81
82 Parameters
83 ----------
84 image_shape:
85 The 2D shape of the image that is being transformed
86 scales:
87 The number of scales to transform with starlets.
88 The total dimension of the starlet will have
89 `scales+1` dimensions, since it will also hold
90 the image at all scales higher than `scales`.
91
92 Returns
93 -------
94 result:
95 Number of scales, adjusted for the size of the image.
96 """
97 # Number of levels for the Starlet decomposition
98 max_scale = int(np.log2(np.min(image_shape[-2:]))) - 1
99 if (scales is None) or scales > max_scale:
100 scales = max_scale
101 return int(scales)
102
103
105 image: np.ndarray,
106 scales: int | None = None,
107 generation: int = 2,
108 convolve2d: Callable | None = None,
109) -> np.ndarray:
110 """Perform a starlet transform, or 2nd gen starlet transform.
111
112 Parameters
113 ----------
114 image:
115 The image to transform into starlet coefficients.
116 scales:
117 The number of scale to transform with starlets.
118 The total dimension of the starlet will have
119 `scales+1` dimensions, since it will also hold
120 the image at all scales higher than `scales`.
121 generation:
122 The generation of the transform.
123 This must be `1` or `2`.
124 convolve2d:
125 The filter function to use to convolve the image
126 with starlets in 2D.
127
128 Returns
129 -------
130 starlet:
131 The starlet dictionary for the input `image`.
132 """
133 if len(image.shape) != 2:
134 raise ValueError(f"Image should be 2D, got {len(image.shape)}")
135 if generation not in (1, 2):
136 raise ValueError(f"generation should be 1 or 2, got {generation}")
137
138 scales = get_starlet_scales(image.shape, scales)
139 c = image
140 if convolve2d is None:
141 convolve2d = bspline_convolve
142
143 # wavelet set of coefficients.
144 starlet = np.zeros((scales + 1,) + image.shape, dtype=image.dtype)
145 for j in range(scales):
146 gen1 = convolve2d(c, j)
147
148 if generation == 2:
149 gen2 = convolve2d(gen1, j)
150 starlet[j] = c - gen2
151 else:
152 starlet[j] = c - gen1
153
154 c = gen1
155
156 starlet[-1] = c
157 return starlet
158
159
161 image: np.ndarray,
162 scales: int | None = None,
163 generation: int = 2,
164 convolve2d: Callable | None = None,
165) -> np.ndarray:
166 """Perform a starlet transform of a multiband image.
167
168 See `starlet_transform` for a description of the parameters.
169 """
170 if len(image.shape) != 3:
171 raise ValueError(f"Image should be 3D (bands, height, width), got shape {len(image.shape)}")
172 if generation not in (1, 2):
173 raise ValueError(f"generation should be 1 or 2, got {generation}")
174 scales = get_starlet_scales(image.shape, scales)
175
176 wavelets = np.empty((scales + 1,) + image.shape, dtype=image.dtype)
177 for b, image in enumerate(image):
178 wavelets[:, b] = starlet_transform(image, scales=scales, generation=generation, convolve2d=convolve2d)
179 return wavelets
180
181
183 starlets: np.ndarray,
184 generation: int = 2,
185 convolve2d: Callable | None = None,
186) -> np.ndarray:
187 """Reconstruct an image from a dictionary of starlets
188
189 Parameters
190 ----------
191 starlets:
192 The starlet dictionary used to reconstruct the image
193 with dimension (scales+1, Ny, Nx).
194 generation:
195 The generation of the starlet transform (either ``1`` or ``2``).
196 convolve2d:
197 The filter function to use to convolve the image
198 with starlets in 2D.
199
200 Returns
201 -------
202 image:
203 The 2D image reconstructed from the input `starlet`.
204 """
205 if generation == 1:
206 return np.sum(starlets, axis=0)
207 if convolve2d is None:
208 convolve2d = bspline_convolve
209 scales = len(starlets) - 1
210
211 c = starlets[-1]
212 for i in range(1, scales + 1):
213 j = scales - i
214 cj = convolve2d(c, j)
215 c = cj + starlets[j]
216 return c
217
218
220 starlets: np.ndarray,
221 generation: int = 2,
222 convolve2d: Callable | None = None,
223) -> np.ndarray:
224 """Reconstruct a multiband image.
225
226 See `starlet_reconstruction` for a description of the
227 remainder of the parameters.
228 """
229 scales, bands, width, height = starlets.shape
230 result = np.zeros((bands, width, height), dtype=starlets.dtype)
231 for band in range(bands):
232 result[band] = starlet_reconstruction(starlets[:, band], generation=generation, convolve2d=convolve2d)
233 return result
234
235
def get_multiresolution_support(
    image: np.ndarray,
    starlets: np.ndarray,
    sigma: float,
    sigma_scaling: float = 3,
    epsilon: float = 1e-1,
    max_iter: int = 20,
    image_type: str = "ground",
) -> np.ndarray:
    """Calculate the multi-resolution support for a
    dictionary of starlet coefficients.

    This is different for ground and space based telescopes.
    For space-based telescopes the procedure in Starck and Murtagh 1998
    is used. It iteratively re-estimates the image noise level from the
    pixels that are not significant at any scale, and recomputes the
    support with each new estimate.
    For ground based images, where the PSF is much wider and there are no
    pixels with no signal at all scales, we use a modified method that
    estimates support at each scale independently.

    Parameters
    ----------
    image:
        The image to transform into starlet coefficients.
    starlets:
        The starlet dictionary used to reconstruct `image` with
        dimension (scales+1, Ny, Nx).
    sigma:
        The standard deviation of the `image`. It is used as the initial
        noise estimate.
    sigma_scaling:
        The multiple of `sigma` to use to calculate significance.
        Coefficients `w` where `|w| > K*sigma_j`, where `sigma_j` is
        standard deviation at the jth scale, are considered significant.
    epsilon:
        The convergence criteria of the algorithm.
        Once `|new_sigma_j - sigma_j|/new_sigma_j < epsilon` the
        algorithm has completed.
    max_iter:
        Maximum number of iterations to fit `sigma_j` at each scale.
        The support is always computed at least once, even if
        ``max_iter < 1``.
    image_type:
        The type of image that is being used.
        This should be "ground" for ground based images with wide PSFs or
        "space" for images from space-based telescopes with a narrow PSF.

    Returns
    -------
    M:
        Integer array with the same shape as `starlets`: 1 where a
        coefficient is significant, 0 elsewhere.

    Raises
    ------
    ValueError:
        If `image_type` is not "ground" or "space".
    """
    if image_type not in ("ground", "space"):
        raise ValueError(f"image_type must be 'ground' or 'space', got {image_type}")

    # Evaluate the support at least once so that `m` is always defined.
    n_iter = max(max_iter, 1)

    if image_type == "space":
        # Calculate sigma_je, the standard deviation at each scale of the
        # (1st generation) transform of unit-variance gaussian noise.
        # NOTE: this draws from NumPy's global random state.
        noise_img = np.random.normal(size=image.shape)
        noise_starlet = starlet_transform(noise_img, generation=1, scales=len(starlets) - 1)
        sigma_je = np.std(noise_starlet, axis=(1, 2))
        # Image minus its coarsest scale; on pixels with no significant
        # coefficient this residual is used to estimate the noise level.
        noise = image - starlets[-1]

        sigma_i = sigma
        for _ in range(n_iter):
            # Threshold with the *current* noise estimate, so that each
            # iteration refines the support (Starck & Murtagh 1998).
            m = np.abs(starlets) > sigma_scaling * sigma_i * sigma_je[:, None, None]
            # Pixels that are not significant at any scale.
            s = ~np.any(m, axis=0)
            if not np.any(s):
                # Every pixel is significant: nothing left to estimate from.
                break
            new_sigma_i = np.std(noise[s])
            if new_sigma_i == 0 or np.abs(new_sigma_i - sigma_i) / new_sigma_i < epsilon:
                break
            sigma_i = new_sigma_i
    else:
        # Sigma to use for significance at each scale.
        # Initially we use the input `sigma`.
        sigma_j = np.ones((len(starlets),), dtype=image.dtype) * sigma
        last_sigma_j = sigma_j
        for _ in range(n_iter):
            m = np.abs(starlets) > sigma_scaling * sigma_j[:, None, None]
            # Standard deviation of the currently insignificant coefficients at
            # each scale (significant coefficients are zeroed, not removed).
            sigma_j = np.std(starlets * (~m).astype(int), axis=(1, 2))
            # At lower scales all of the pixels may be significant,
            # so sigma is effectively zero. To avoid infinities we
            # only check the scales with non-zero sigma.
            cut = sigma_j > 0
            if np.all(np.abs(sigma_j[cut] - last_sigma_j[cut]) / sigma_j[cut] < epsilon):
                break
            last_sigma_j = sigma_j
    return m.astype(int)
326
327
def apply_wavelet_denoising(
    image: np.ndarray,
    sigma: float | None = None,
    sigma_scaling: float = 3,
    epsilon: float = 1e-1,
    max_iter: int = 20,
    image_type: str = "ground",
    positive: bool = True,
) -> np.ndarray:
    """Apply wavelet denoising

    Uses the algorithm and notation from Starck et al. 2011, section 4.1.
    The estimate starts from the coefficients inside the multiresolution
    support only. It is then iteratively corrected by adding back the
    significant part of the residual ``W(image) - W(x)``, optionally
    clipping negative values after each step.

    Parameters
    ----------
    image:
        The image to denoise
    sigma:
        The standard deviation of the image. If `None`, the median absolute
        deviation of `image` about its median is used.
    sigma_scaling:
        The threshold in units of sigma to declare a coefficient significant
    epsilon:
        Convergence criteria for determining the support
    max_iter:
        The maximum number of iterations.
        This applies to both finding the support and the denoising loop.
    image_type:
        The type of image that is being used.
        This should be "ground" for ground based images with wide PSFs or
        "space" for images from space-based telescopes with a narrow PSF.
    positive:
        Whether or not the expected result should be positive

    Returns
    -------
    result:
        The resulting denoised image after `max_iter` iterations.
    """
    image_coeffs = starlet_transform(image)
    if sigma is None:
        # NOTE(review): this is the raw median absolute deviation. For
        # Gaussian noise the standard deviation is ~1.4826 * MAD; confirm
        # whether the unscaled value is intended.
        sigma = np.median(np.absolute(image - np.median(image)))
    support = get_multiresolution_support(
        image=image,
        starlets=image_coeffs,
        sigma=sigma,
        sigma_scaling=sigma_scaling,
        epsilon=epsilon,
        max_iter=max_iter,
        image_type=image_type,
    )
    # Start from the significant coefficients only. The (2nd generation)
    # reconstruction of all coefficients returns `image` itself (up to
    # rounding), which would make the residual below zero and leave the
    # noise untouched.
    x = starlet_reconstruction(support * image_coeffs)

    for _ in range(max_iter):
        coeffs = starlet_transform(x)
        # Add back only the significant part of the residual.
        x = x + starlet_reconstruction(support * (image_coeffs - coeffs))
        if positive:
            x[x < 0] = 0
    return x
np.ndarray starlet_transform(np.ndarray image, int|None scales=None, int generation=2, Callable|None convolve2d=None)
Definition wavelet.py:109
int get_starlet_scales(Sequence[int] image_shape, int|None scales=None)
Definition wavelet.py:79
np.ndarray multiband_starlet_reconstruction(np.ndarray starlets, int generation=2, Callable|None convolve2d=None)
Definition wavelet.py:223
np.ndarray apply_wavelet_denoising(np.ndarray image, float|None sigma=None, float sigma_scaling=3, float epsilon=1e-1, int max_iter=20, str image_type="ground", bool positive=True)
Definition wavelet.py:336
np.ndarray multiband_starlet_transform(np.ndarray image, int|None scales=None, int generation=2, Callable|None convolve2d=None)
Definition wavelet.py:165
np.ndarray bspline_convolve(np.ndarray image, int scale)
Definition wavelet.py:35
np.ndarray starlet_reconstruction(np.ndarray starlets, int generation=2, Callable|None convolve2d=None)
Definition wavelet.py:186
np.ndarray get_multiresolution_support(np.ndarray image, np.ndarray starlets, float sigma, float sigma_scaling=3, float epsilon=1e-1, int max_iter=20, str image_type="ground")
Definition wavelet.py:244