LSST Applications g0f08755f38+c89d42e150,g1635faa6d4+b6cf076a36,g1653933729+a8ce1bb630,g1a0ca8cf93+4c08b13bf7,g28da252d5a+f33f8200ef,g29321ee8c0+0187be18b1,g2bbee38e9b+9634bc57db,g2bc492864f+9634bc57db,g2cdde0e794+c2c89b37c4,g3156d2b45e+41e33cbcdc,g347aa1857d+9634bc57db,g35bb328faa+a8ce1bb630,g3a166c0a6a+9634bc57db,g3e281a1b8c+9f2c4e2fc3,g414038480c+077ccc18e7,g41af890bb2+e740673f1a,g5fbc88fb19+17cd334064,g7642f7d749+c89d42e150,g781aacb6e4+a8ce1bb630,g80478fca09+f8b2ab54e1,g82479be7b0+e2bd23ab8b,g858d7b2824+c89d42e150,g9125e01d80+a8ce1bb630,g9726552aa6+10f999ec6a,ga5288a1d22+065360aec4,gacf8899fa4+9553554aa7,gae0086650b+a8ce1bb630,gb58c049af0+d64f4d3760,gbd46683f8f+ac57cbb13d,gc28159a63d+9634bc57db,gcf0d15dbbd+e37acf7834,gda3e153d99+c89d42e150,gda6a2b7d83+e37acf7834,gdaeeff99f8+1711a396fd,ge2409df99d+cb1e6652d6,ge79ae78c31+9634bc57db,gf0baf85859+147a0692ba,gf3967379c6+02b11634a5,w.2024.45
LSST Data Management Base Package
Loading...
Searching...
No Matches
fft.py
Go to the documentation of this file.
1# This file is part of scarlet_lite.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21
22from __future__ import annotations
23
24__all__ = ["Fourier"]
25
26import operator
27from typing import Callable, Sequence
28
29import numpy as np
30from numpy.typing import DTypeLike
31from scipy import fftpack
32
33
def centered(arr: np.ndarray, newshape: Sequence[int]) -> np.ndarray:
    """Return the central ``newshape`` portion of ``arr``.

    Parameters
    ----------
    arr:
        The array to trim.
    newshape:
        The shape of the returned region. Every entry must be no larger
        than the matching entry of ``arr.shape``.

    Returns
    -------
    result:
        A view of ``arr`` (obtained by basic slicing) with shape ``newshape``.

    Raises
    ------
    ValueError
        If ``newshape`` exceeds ``arr.shape`` along any axis.

    Notes
    -----
    Along each axis the region starts at ``(n - m + 1) // 2``, where ``n``
    is the input length and ``m`` the output length. When ``n - m`` is odd,
    one more element is dropped from the start of the axis than from its
    end. This is the same split that `_pad` uses when adding padding, so
    trimming a `_pad`-ed array back to its original shape recovers the
    original. When an even-length axis is trimmed to an odd length, the
    element at index ``n // 2`` (the zero position under the
    `np.fft.fftshift` convention) lands at the center of the output.
    scipy's internal ``_centered`` helper splits an odd difference the
    other way.
    """
    target = np.array(newshape)
    current = np.array(arr.shape)

    if not (target <= current).all():
        raise ValueError(
            f"arr must be larger than newshape in both dimensions, received {arr.shape}, and {target}"
        )

    # An odd surplus is removed from the low-index side first.
    starts = (current - target + 1) // 2
    stops = starts + target
    window = tuple(slice(lo, hi) for lo, hi in zip(starts, stops))
    return arr[window]
69
70
def fast_zero_pad(arr: np.ndarray, pad_width: Sequence[Sequence[int]]) -> np.ndarray:
    """Zero-pad ``arr``; a faster stand-in for ``numpy.pad`` with ``mode="constant"``.

    Instead of going through ``numpy.pad``, the result is allocated with
    ``np.zeros`` and ``arr`` is copied into its interior with a single
    slice assignment, which is much cheaper for the zero-fill case.

    Parameters
    ----------
    arr:
        The array to pad.
    pad_width:
        One ``(before, after)`` pair of non-negative integers per axis of
        ``arr``, giving how many zeros to add on each side. Unlike
        ``numpy.pad``, the shorthand forms (a scalar or a single pair) are
        not accepted.

    Returns
    -------
    result: np.ndarray
        A new array with the same dtype as ``arr``, padded with zeros.
    """
    padded_shape = tuple(
        length + widths[0] + widths[1] for length, widths in zip(arr.shape, pad_width)
    )
    out = np.zeros(padded_shape, dtype=arr.dtype)

    # The original data occupies [before, total - after) along every axis.
    interior = tuple(
        slice(before, total - after) for total, (before, after) in zip(out.shape, pad_width)
    )
    out[interior] = arr
    return out
96
97
98def _pad(
99 arr: np.ndarray,
100 newshape: Sequence[int],
101 axes: int | Sequence[int] | None = None,
102 mode: str = "constant",
103 constant_values: float = 0,
104) -> np.ndarray:
105 """Pad an array to fit into newshape
106
107 Pad `arr` with zeros to fit into newshape,
108 which uses the `np.fft.fftshift` convention of moving
109 the center pixel of `arr` (if `arr.shape` is odd) to
110 the center-right pixel in an even shaped `newshape`.
111
112 Parameters
113 ----------
114 arr:
115 The arrray to pad.
116 newshape:
117 The new shape of the array.
118 axes:
119 The axes that are being reshaped.
120 mode:
121 The numpy mode used to pad the array.
122 In other words, how to fill the new padded elements.
123 See ``numpy.pad`` for details.
124 constant_values:
125 If `mode` == "constant" then this is the value to set all of
126 the new padded elements to.
127 """
128 _newshape = np.asarray(newshape)
129 if axes is None:
130 currshape = np.array(arr.shape)
131 diff = _newshape - currshape
132 startind = (diff + 1) // 2
133 endind = diff - startind
134 pad_width = list(zip(startind, endind))
135 else:
136 # only pad the axes that will be transformed
137 pad_width = [(0, 0) for _ in arr.shape]
138 if isinstance(axes, int):
139 axes = [axes]
140 for a, axis in enumerate(axes):
141 diff = _newshape[a] - arr.shape[axis]
142 startind = (diff + 1) // 2
143 endind = diff - startind
144 pad_width[axis] = (startind, endind)
145 if mode == "constant" and constant_values == 0:
146 result = fast_zero_pad(arr, pad_width)
147 else:
148 result = np.pad(arr, tuple(pad_width), mode=mode) # type: ignore
149 return result
150
151
def get_fft_shape(
    im_or_shape1: np.ndarray | Sequence[int],
    im_or_shape2: np.ndarray | Sequence[int],
    padding: int = 3,
    axes: int | Sequence[int] | None = None,
    use_max: bool = False,
) -> tuple:
    """Return a fast FFT shape large enough to combine two images.

    For each transformed axis, the two input lengths are combined (summed,
    or the larger one taken when ``use_max`` is set), ``padding`` is added,
    and the result is rounded up with `scipy.fftpack.next_fast_len`.

    Parameters
    ----------
    im_or_shape1:
        The left image, or the shape of an image. Any `np.ndarray` is
        treated as an image, and its ``.shape`` is used.
    im_or_shape2:
        The right image, or the shape of an image.
    padding:
        Extra length added along every transformed axis before rounding up.
    axes:
        The axes that are being transformed. ``None`` means every axis.
    use_max:
        If True, combine the two lengths with ``max``; otherwise sum them.

    Returns
    -------
    shape:
        One fast length per transformed axis, in the order of ``axes``. Use
        it as the common k-space shape when the two images are transformed.

    Raises
    ------
    ValueError
        If the two inputs have different numbers of dimensions.
    """

    def _shape_of(obj):
        # Arrays contribute their shape; anything else is already a shape.
        return np.asarray(obj.shape if isinstance(obj, np.ndarray) else obj)

    left = _shape_of(im_or_shape1)
    right = _shape_of(im_or_shape2)
    if len(left) != len(right):
        raise ValueError(
            "img1 and img2 must have the same number of dimensions, "
            f"but got {len(left)} and {len(right)}"
        )

    combine = np.maximum if use_max else np.add
    if axes is None:
        lengths = combine(left, right)
    else:
        picked = [axes] if isinstance(axes, int) else axes
        lengths = np.array([combine(left[ax], right[ax]) for ax in picked], dtype="int")
    lengths += padding

    return tuple(fftpack.next_fast_len(length) for length in lengths)
218
219
class Fourier:
    """An array that stores its Fourier Transform

    The `Fourier` class is used for images that will make
    use of their Fourier Transform multiple times.
    In order to prevent numerical artifacts the same image
    convolved with different images might require different
    padding, so the FFT for each different shape is stored
    in a dictionary.

    Parameters
    ----------
    image: np.ndarray
        The real space image.
    image_fft: dict[Sequence[Sequence[int]], np.ndarray]
        A dictionary of precalculated FFTs. Each key is a
        ``(fft_shape, axes, all_axes)`` tuple of tuples, which is the same
        format `Fourier.fft` builds for its lookups. The dictionary is
        stored as-is (not copied), so FFTs computed later are added to it.
    """

    def __init__(
        self,
        image: np.ndarray,
        image_fft: dict[Sequence[Sequence[int]], np.ndarray] | None = None,
    ):
        if image_fft is None:
            self._fft: dict[Sequence[Sequence[int]], np.ndarray] = {}
        else:
            # NOTE: the caller's dictionary is used directly, not copied.
            self._fft = image_fft
        self._image = image

    @staticmethod
    def from_fft(
        image_fft: np.ndarray,
        fft_shape: Sequence[int],
        image_shape: Sequence[int],
        axes: int | Sequence[int] | None = None,
        dtype: DTypeLike = float,
    ) -> Fourier:
        """Generate a new Fourier object from an FFT

        If the fft of an image has been generated but not its
        real space image (for example when creating a convolution kernel),
        this method can be called to create a new `Fourier` instance
        from the k-space representation.

        Parameters
        ----------
        image_fft:
            The FFT of the image, as produced by `np.fft.rfftn`.
        fft_shape:
            The (padded) real-space shape that was transformed to produce
            `image_fft`, with one entry per axis in `axes`. It is passed to
            `np.fft.irfftn` as ``s``. `rfftn` keeps only ``n // 2 + 1``
            values along the last transformed axis, so the original length
            (even or odd) cannot be recovered from ``image_fft.shape`` alone.
        image_shape:
            The shape of the image *before padding*.
            This will regenerate the image with the extra
            padding stripped.
        axes:
            The dimension(s) of the array that were transformed.
            Defaults to every dimension of `image_shape`.
        dtype:
            The data type of the regenerated real-space image.

        Returns
        -------
        result:
            A `Fourier` object generated from the FFT. `image_fft` is
            cached under the same key that `Fourier.fft` builds for the
            same `fft_shape` and `axes` values.
        """
        if axes is None:
            axes = range(len(image_shape))
        if isinstance(axes, int):
            axes = [axes]
        all_axes = range(len(image_shape))
        image = np.fft.irfftn(image_fft, fft_shape, axes=axes).astype(dtype)
        # Shift the center of the image from the bottom left to the center
        image = np.fft.fftshift(image, axes=axes)
        # Trim the image to remove the padding added
        # to reduce fft artifacts
        image = centered(image, image_shape)
        key = (tuple(fft_shape), tuple(axes), tuple(all_axes))

        return Fourier(image, {key: image_fft})

    @property
    def image(self) -> np.ndarray:
        """The real space image"""
        return self._image

    @property
    def shape(self) -> tuple[int, ...]:
        """The shape of the real space image"""
        return self._image.shape

    def fft(self, fft_shape: Sequence[int], axes: int | Sequence[int]) -> np.ndarray:
        """The FFT of an image for a given `fft_shape` along desired `axes`

        The FFT is computed on first use and cached for later calls with
        the same `fft_shape` and `axes`. Before the transform, the image is
        padded to `fft_shape` with `_pad`, and `np.fft.ifftshift` moves its
        center to index 0 along `axes`.

        Parameters
        ----------
        fft_shape:
            "Fast" shape of the image used to generate the FFT, with one
            entry per axis in `axes`.
        axes:
            The dimension(s) of the array that will be transformed.

        Returns
        -------
        result:
            The complex output of `np.fft.rfftn`.

        Raises
        ------
        ValueError
            If `fft_shape` and `axes` have different lengths. This is only
            checked when the FFT is not already cached.
        """
        if isinstance(axes, int):
            axes = (axes,)
        all_axes = range(len(self.image.shape))
        fft_key = (tuple(fft_shape), tuple(axes), tuple(all_axes))

        # If this is the first time calling `fft` for this shape,
        # generate the FFT.
        if fft_key not in self._fft:
            if len(fft_shape) != len(axes):
                msg = f"fft_shape self.axes must have the same number of dimensions, got {fft_shape}, {axes}"
                raise ValueError(msg)
            image = _pad(self.image, fft_shape, axes)
            self._fft[fft_key] = np.fft.rfftn(np.fft.ifftshift(image, axes), axes=axes)
        return self._fft[fft_key]

    def __len__(self) -> int:
        """Length of the image"""
        return len(self.image)

    def __getitem__(self, index: int | Sequence[int] | slice) -> Fourier:
        """Index the real-space image and every cached FFT with `index`.

        NOTE(review): cache keys are rebuilt by dropping entries whose
        *value* appears in `removed` (the positions indexed by `index`).
        For the ``fft_shape`` and ``axes`` parts of the key, this compares
        sizes and axis numbers against index positions, so the rebuilt keys
        may not match what `fft` later looks up. In that case `fft` simply
        recomputes the transform. TODO: confirm the intended behavior.
        """
        # Make the index a tuple
        if isinstance(index, int):
            index = tuple([index])

        # Axes that are removed from the shape of the new object
        if isinstance(index, slice):
            removed = np.array([])
        else:
            removed = np.array([n for n, idx in enumerate(index) if idx is not None])

        def _kept(part):
            # Drop every entry of one key component whose value is in `removed`.
            return tuple(value for value in part if value not in removed)

        # Create views into the fft transformed values, appropriately adjusting
        # the shapes for the new axes
        fft_kernels = {
            (_kept(key[0]), _kept(key[1]), _kept(key[2])): kernel[index]
            for key, kernel in self._fft.items()
        }
        # mypy doesn't recognize that tuple[int, ...]
        # is a valid Sequence[int] for some reason
        return Fourier(self.image[index], fft_kernels)  # type: ignore
369 # mpypy doesn't recognize that tuple[int, ...]
370 # is a valid Sequence[int] for some reason
371 return Fourier(self.imageimage[index], fft_kernels) # type: ignore
372
373
def _kspace_operation(
    image1: Fourier,
    image2: Fourier,
    padding: int,
    op: Callable,
    shape: Sequence[int],
    axes: int | Sequence[int],
) -> Fourier:
    """Combine two images in k-space using a given `operator`

    Both images are transformed with a common fast FFT shape (from
    `get_fft_shape`), combined element-wise with `op`, and transformed back.

    Parameters
    ----------
    image1:
        The LHS of the equation.
    image2:
        The RHS of the equation.
    padding:
        The amount of padding to add before transforming into k-space.
    op:
        The operator used to combine the two images.
        This is either ``operator.mul`` for a convolution
        or ``operator.truediv`` for deconvolution. For the division
        operators, elements where the RHS transform is zero are set to
        zero instead of being divided.
    shape:
        The shape of the output image.
    axes:
        The dimension(s) of the array that will be transformed.

    Returns
    -------
    result:
        A `Fourier` whose real-space image has shape `shape` and the dtype
        of ``image1.image``.

    Raises
    ------
    ValueError
        If the images have different numbers of dimensions, or if their
        transforms cannot be broadcast against each other.
    """
    if len(image1.shape) != len(image2.shape):
        msg = (
            "Both images must have the same number of axes, "
            f"got {len(image1.shape)} and {len(image2.shape)}"
        )
        raise ValueError(msg)

    fft_shape = get_fft_shape(image1.image, image2.image, padding, axes)
    lhs = image1.fft(fft_shape, axes)
    rhs = image2.fft(fft_shape, axes)

    if op in (operator.truediv, operator.floordiv, operator.itruediv, operator.ifloordiv):
        # Broadcast size-1 axes (for example a single kernel shared by every
        # band), as the non-division branch gets implicitly from numpy, so the
        # boolean mask below selects matching elements of both arrays.
        lhs, rhs = np.broadcast_arrays(lhs, rhs)
        # Prevent divide by zero: only divide where the denominator is
        # non-zero, and leave the result at 0 elsewhere.
        nonzero = rhs != 0
        transformed_fft = np.zeros(lhs.shape, dtype=np.result_type(lhs, rhs))
        transformed_fft[nonzero] = op(lhs[nonzero], rhs[nonzero])
    else:
        transformed_fft = op(lhs, rhs)
    return Fourier.from_fft(transformed_fft, fft_shape, shape, axes, image1.image.dtype)
431
432
def match_kernel(
    kernel1: np.ndarray | Fourier,
    kernel2: np.ndarray | Fourier,
    padding: int = 3,
    axes: int | Sequence[int] = (-2, -1),
    return_fourier: bool = True,
    normalize: bool = False,
) -> Fourier | np.ndarray:
    """Calculate the difference kernel between kernel1 and kernel2

    The result ``d`` is the kernel whose FFT is ``FFT(kernel1) / FFT(kernel2)``,
    set to zero wherever ``FFT(kernel2)`` is zero. Convolving ``kernel2``
    with ``d`` therefore approximately reproduces ``kernel1``.

    Parameters
    ----------
    kernel1:
        The first kernel (numerator), either as array or as `Fourier` object
    kernel2:
        The second kernel (denominator), either as array or as `Fourier` object
    padding:
        Additional padding to use when generating the FFT
        to suppress artifacts.
    axes:
        Axes that contain the spatial information for the kernels.
    return_fourier:
        Whether to return `Fourier` or array
    normalize:
        If True, each kernel is first divided by its sum over `axes`, so
        that it sums to one. FFTs already cached on a `Fourier` input are
        not reused in that case. A kernel that sums to zero produces
        non-finite values.

    Returns
    -------
    result:
        The difference kernel. Its shape is that of whichever input kernel
        is larger along axis 0 (``kernel1`` on ties). It is a `Fourier` when
        `return_fourier` is True, and the real-space array otherwise.
    """
    sum_axes = (axes,) if isinstance(axes, int) else tuple(axes)

    def _prepare(kernel: np.ndarray | Fourier) -> Fourier:
        # Wrap the kernel as `Fourier`, rescaling it to unit sum if requested.
        if not normalize:
            return kernel if isinstance(kernel, Fourier) else Fourier(kernel)
        image = kernel.image if isinstance(kernel, Fourier) else np.asarray(kernel)
        return Fourier(image / np.sum(image, axis=sum_axes, keepdims=True))

    kernel1 = _prepare(kernel1)
    kernel2 = _prepare(kernel2)

    if kernel1.shape[0] < kernel2.shape[0]:
        shape = kernel2.shape
    else:
        shape = kernel1.shape

    diff = _kspace_operation(kernel1, kernel2, padding, operator.truediv, shape, axes=axes)
    if return_fourier:
        return diff
    return np.real(diff.image)
479
480
def convolve(
    image: np.ndarray | Fourier,
    kernel: np.ndarray | Fourier,
    padding: int = 3,
    axes: int | Sequence[int] = (-2, -1),
    return_fourier: bool = True,
    normalize: bool = False,
) -> np.ndarray | Fourier:
    """Convolve image with a kernel

    The convolution is computed as a product in k-space.

    Parameters
    ----------
    image:
        Image either as array or as `Fourier` object
    kernel:
        Convolution kernel either as array or as `Fourier` object
    padding:
        Additional padding to use when generating the FFT
        to suppress artifacts.
    axes:
        Axes that contain the spatial information for the PSFs.
    return_fourier:
        Whether to return `Fourier` or array
    normalize:
        If True, the kernel is first divided by its sum over `axes`, so
        that it sums to one. The image is not rescaled. FFTs already cached
        on a `Fourier` kernel are not reused in that case. A kernel that
        sums to zero produces non-finite values.

    Returns
    -------
    result:
        The convolution of the image with the kernel, with the same shape
        as `image`. It is a `Fourier` when `return_fourier` is True, and
        the real-space array otherwise.
    """
    if not isinstance(image, Fourier):
        image = Fourier(image)

    if normalize:
        kernel_image = kernel.image if isinstance(kernel, Fourier) else np.asarray(kernel)
        sum_axes = (axes,) if isinstance(axes, int) else tuple(axes)
        kernel = Fourier(kernel_image / np.sum(kernel_image, axis=sum_axes, keepdims=True))
    elif not isinstance(kernel, Fourier):
        kernel = Fourier(kernel)

    convolved = _kspace_operation(image, kernel, padding, operator.mul, image.shape, axes=axes)
    if return_fourier:
        return convolved
    return np.real(convolved.image)
std::vector< SchemaItem< Flag > > * items
__init__(self, np.ndarray image, dict[Sequence[Sequence[int]], np.ndarray]|None image_fft=None)
Definition fft.py:243
np.ndarray image(self)
Definition fft.py:304
tuple[int,...] shape(self)
Definition fft.py:309
Fourier __getitem__(self, int|Sequence[int]|slice index)
Definition fft.py:347
Fourier from_fft(np.ndarray image_fft, Sequence[int] fft_shape, Sequence[int] image_shape, int|Sequence[int]|None axes=None, DTypeLike dtype=float)
Definition fft.py:257
np.ndarray|Fourier convolve(np.ndarray|Fourier image, np.ndarray|Fourier kernel, int padding=3, int|Sequence[int] axes=(-2, -1), bool return_fourier=True, bool normalize=False)
Definition fft.py:488
Fourier _kspace_operation(Fourier image1, Fourier image2, int padding, Callable op, Sequence[int] shape, int|Sequence[int] axes)
Definition fft.py:381
np.ndarray _pad(np.ndarray arr, Sequence[int] newshape, int|Sequence[int]|None axes=None, str mode="constant", float constant_values=0)
Definition fft.py:104
tuple get_fft_shape(np.ndarray|Sequence[int] im_or_shape1, np.ndarray|Sequence[int] im_or_shape2, int padding=3, int|Sequence[int]|None axes=None, bool use_max=False)
Definition fft.py:158
np.ndarray centered(np.ndarray arr, Sequence[int] newshape)
Definition fft.py:34
np.ndarray fast_zero_pad(np.ndarray arr, Sequence[Sequence[int]] pad_width)
Definition fft.py:71
Fourier|np.ndarray match_kernel(np.ndarray|Fourier kernel1, np.ndarray|Fourier kernel2, int padding=3, int|Sequence[int] axes=(-2, -1), bool return_fourier=True, bool normalize=False)
Definition fft.py:440