LSST Applications g04e9c324dd+8c5ae1fdc5,g134cb467dc+b203dec576,g18429d2f64+358861cd2c,g199a45376c+0ba108daf9,g1fd858c14a+dd066899e3,g262e1987ae+ebfced1d55,g29ae962dfc+72fd90588e,g2cef7863aa+aef1011c0b,g35bb328faa+8c5ae1fdc5,g3fd5ace14f+b668f15bc5,g4595892280+3897dae354,g47891489e3+abcf9c3559,g4d44eb3520+fb4ddce128,g53246c7159+8c5ae1fdc5,g67b6fd64d1+abcf9c3559,g67fd3c3899+1f72b5a9f7,g74acd417e5+cb6b47f07b,g786e29fd12+668abc6043,g87389fa792+8856018cbb,g89139ef638+abcf9c3559,g8d7436a09f+bcf525d20c,g8ea07a8fe4+9f5ccc88ac,g90f42f885a+6054cc57f1,g97be763408+06f794da49,g9dd6db0277+1f72b5a9f7,ga681d05dcb+7e36ad54cd,gabf8522325+735880ea63,gac2eed3f23+abcf9c3559,gb89ab40317+abcf9c3559,gbf99507273+8c5ae1fdc5,gd8ff7fe66e+1f72b5a9f7,gdab6d2f7ff+cb6b47f07b,gdc713202bf+1f72b5a9f7,gdfd2d52018+8225f2b331,ge365c994fd+375fc21c71,ge410e46f29+abcf9c3559,geaed405ab2+562b3308c0,gf9a733ac38+8c5ae1fdc5,w.2025.35
LSST Data Management Base Package
Loading...
Searching...
No Matches
blend.py
Go to the documentation of this file.
1# This file is part of scarlet_lite.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21
22from __future__ import annotations
23
24__all__ = ["Blend"]
25
26from typing import Callable, Sequence, cast
27
28import numpy as np
29
30from .bbox import Box
31from .component import Component, FactorizedComponent
32from .image import Image
33from .observation import Observation
34from .source import Source
35
36
class Blend:
    """A single blend.

    This class holds all of the sources and observation that are to be fit,
    as well as performing fitting and joint initialization of the
    spectral components (when applicable).

    Parameters
    ----------
    sources:
        The sources to fit.
    observation:
        The observation that contains the images,
        PSF, etc. that are being fit.
    metadata:
        Additional metadata to store with the blend.
        An empty dictionary is stored as `None`.
    """

    def __init__(self, sources: Sequence[Source], observation: Observation, metadata: dict | None = None):
        self.sources = list(sources)
        self.observation = observation
        # Normalize an empty metadata dictionary to ``None`` so that
        # "no metadata" has a single representation.
        if metadata is not None and len(metadata) == 0:
            metadata = None
        self.metadata = metadata

        # Initialize the iteration count and loss function
        self.it = 0
        self.loss: list[float] = []

    @property
    def shape(self) -> tuple[int, int, int]:
        """Shape of the model for the entire `Blend`."""
        return self.observation.shape

    @property
    def bbox(self) -> Box:
        """The bounding box of the entire blend."""
        return self.observation.bbox

    @property
    def components(self) -> list[Component]:
        """The list of all components in the blend.

        Since the list of sources might change,
        this is always built on the fly.
        """
        return [c for src in self.sources for c in src.components]

    def get_model(self, convolve: bool = False, use_flux: bool = False) -> Image:
        """Generate a model of the entire blend.

        Parameters
        ----------
        convolve:
            Whether to convolve the model with the observed PSF in each band.
        use_flux:
            Whether to use the re-distributed flux associated with the sources
            (``flux_weighted_image``, set by `conserve_flux`)
            instead of the component models.

        Returns
        -------
        model:
            The model created by combining all of the source models.

        Raises
        ------
        ValueError
            If ``use_flux`` is set and a source has no ``flux_weighted_image``.
        """
        model = Image(
            np.zeros(self.shape, dtype=self.observation.images.dtype),
            bands=self.observation.bands,
            yx0=cast(tuple[int, int], self.observation.bbox.origin[-2:]),
        )

        if use_flux:
            for src in self.sources:
                if src.flux_weighted_image is None:
                    raise ValueError(
                        "Some sources do not have 'flux' attribute set. Run measure.conserve_flux"
                    )
                src.flux_weighted_image.insert_into(model)
        else:
            for component in self.components:
                component.get_model().insert_into(model)
        if convolve:
            return self.observation.convolve(model)
        return model

    def _grad_log_likelihood(self) -> tuple[Image, np.ndarray]:
        """Gradient of the likelihood wrt the unconvolved model

        As a side effect the current log-likelihood is appended to
        ``self.loss``.

        Returns
        -------
        result:
            The gradient of the likelihood wrt the model
        model_data:
            The convolved model data used to calculate the gradient.
            This can be useful for debugging but is not used in
            production.
        """
        model = self.get_model(convolve=True)
        # Update the loss
        self.loss.append(self.observation.log_likelihood(model))
        # Calculate the gradient wrt the model d(logL)/d(model)
        result = self.observation.weights * (model - self.observation.images)
        result = self.observation.convolve(result, grad=True)
        return result, model.data

    @property
    def log_likelihood(self) -> float:
        """The current log-likelihood

        This is calculated on the fly to ensure that it is always up to date
        with the current model parameters.
        """
        return self.observation.log_likelihood(self.get_model(convolve=True))

    def fit_spectra(self, clip: bool = False) -> Blend:
        """Fit all of the spectra given their current morphologies with a
        linear least squares algorithm.

        Only components exposing both ``morph`` and ``spectrum`` are fit.
        All other components are held fixed, and their PSF-convolved model
        is passed to the fitter as a static contribution.

        Parameters
        ----------
        clip:
            Whether or not to clip components that were not
            assigned any flux during the fit.

        Returns
        -------
        blend:
            The blend with updated components is returned.
        """
        from .initialization import multifit_spectra

        # Factorized components and their morphologies, kept in lockstep so
        # that each morphology is paired with its *own* bounding box.
        factorized: list[FactorizedComponent] = []
        morphs: list[Image] = []
        model = Image.from_box(
            self.observation.bbox,
            bands=self.observation.bands,
            dtype=self.observation.dtype,
        )
        for component in self.components:
            if hasattr(component, "morph") and hasattr(component, "spectrum"):
                component = cast(FactorizedComponent, component)
                factorized.append(component)
                morphs.append(Image(component.morph, yx0=cast(tuple[int, int], component.bbox.origin)))
            else:
                model.insert(component.get_model())
        model = self.observation.convolve(model, mode="real")

        if morphs:
            new_spectra = multifit_spectra(self.observation, morphs, model)
            for component, spectrum in zip(factorized, new_spectra):
                component.spectrum[:] = spectrum
                component.spectrum[component.spectrum < 0] = 0

        # Run the proxes for all of the components to make sure that the
        # spectra are consistent with the constraints.
        # In practice this usually means making sure that they are
        # non-negative.
        for component in self.components:
            prox_spectrum = getattr(component, "prox_spectrum", None)
            if hasattr(component, "spectrum") and prox_spectrum is not None:
                prox_spectrum(component.spectrum)  # type: ignore

        if clip:
            # Remove components with no positive flux
            for src in self.sources:
                kept = []
                for component in src.components:
                    component_model = component.get_model()
                    component_model.data[component_model.data < 0] = 0
                    if np.sum(component_model.data) > 0:
                        kept.append(component)
                src.components = kept

        return self

    def fit(
        self,
        max_iter: int,
        e_rel: float = 1e-4,
        min_iter: int = 15,
        resize: int | None = 10,
    ) -> tuple[int, float]:
        """Fit all of the parameters

        Iteration resumes from ``self.it``, so repeated calls continue a
        previous fit rather than restarting it.

        Parameters
        ----------
        max_iter:
            The maximum number of iterations
        e_rel:
            The relative error to use for determining convergence.
        min_iter:
            The minimum number of iterations.
        resize:
            Number of iterations before attempting to resize the
            resizable components. If `resize` is `None` (or ``0``) then
            no resizing is ever attempted.

        Returns
        -------
        it:
            Number of iterations.
        loss:
            Loss for the last solution. If no iteration has ever been run
            the current log-likelihood is returned instead.
        """
        while self.it < max_iter:
            # Calculate the gradient wrt the un-convolved model
            # (this also appends the current loss to ``self.loss``).
            grad_image, _ = self._grad_log_likelihood()
            do_resize = resize is not None and resize != 0 and self.it > 0 and self.it % resize == 0
            # Update each component given the current gradient
            for component in self.components:
                overlap = component.bbox & self.bbox
                component.update(self.it, grad_image[overlap].data)
                # Check to see if any components need to be resized
                if do_resize:
                    component.resize(self.bbox)
            # Stopping criteria: a relative change in the loss below
            # ``e_rel`` requires at least two recorded losses to compare.
            self.it += 1
            if (
                self.it > min_iter
                and len(self.loss) > 1
                and np.abs(self.loss[-1] - self.loss[-2]) < e_rel * np.abs(self.loss[-1])
            ):
                break
        return self.it, (self.loss[-1] if self.loss else self.log_likelihood)

    def parameterize(self, parameterization: Callable):
        """Convert the component parameter arrays into Parameter instances

        Parameters
        ----------
        parameterization:
            A function to use to convert parameters of a given type into
            a `Parameter` in place. It should take a single argument that
            is the `Component` or `Source` that is to be parameterized.
        """
        for source in self.sources:
            source.parameterize(parameterization)

    def conserve_flux(self, mask_footprint: bool = True, weight_image: Image | None = None) -> None:
        """Use the source models as templates to re-distribute flux
        from the data

        The source models are used as approximations to the data,
        which redistribute the flux in the data according to the
        ratio of the models for each source.
        There is no return value for this function,
        instead it adds (or modifies) a ``flux_weighted_image``
        attribute to each of the sources with the flux attributed to
        that source.

        Parameters
        ----------
        mask_footprint:
            Whether or not to apply a mask for pixels with zero weight.
        weight_image:
            The weight image to use for the redistribution.
            If `None` then the model of the blend, convolved with the PSF
            in real space, is used. A supplied image is not modified.
        """
        observation = self.observation
        # Half-size of the PSF, used to grow each source box so that the
        # convolved source model includes the PSF wings.
        py = observation.psfs.shape[-2] // 2
        px = observation.psfs.shape[-1] // 2

        images = observation.images.copy()
        if mask_footprint:
            images.data[observation.weights.data == 0] = 0

        if weight_image is None:
            weight_image = self.get_model()
            # Always convolve in real space to avoid FFT artifacts
            weight_image = observation.convolve(weight_image, mode="real")

        # Due to ringing in the PSF, the convolved model can have
        # negative values. We take the absolute value to avoid
        # negative fluxes in the flux weighted images.
        # A new image is built so a caller-supplied weight image is untouched.
        weight_image = weight_image.copy_with(data=np.abs(weight_image.data))

        for src in self.sources:
            if src.is_null:
                src.flux_weighted_image = Image.from_box(Box((0, 0)), bands=observation.bands)  # type: ignore
                continue

            # Grow the model to include the wings of the PSF
            src_box = src.bbox.grow((py, px))
            overlap = observation.bbox & src_box
            src_model = src.get_model().project(bbox=overlap)
            src_model = observation.convolve(src_model, mode="real")
            numerator = np.abs(src_model.data)
            denominator = weight_image[overlap].data
            # Pixels with zero total weight keep a ratio of zero.
            cuts = denominator != 0
            ratio = np.zeros(numerator.shape, dtype=numerator.dtype)
            ratio[cuts] = numerator[cuts] / denominator[cuts]
            # sometimes numerical errors can cause a hot pixel to have a
            # slightly higher ratio than 1
            ratio[ratio > 1] = 1
            src.flux_weighted_image = src_model.copy_with(data=ratio) * images[overlap]
A class to represent a 2-dimensional array of pixels.
Definition Image.h:51
None conserve_flux(self, bool mask_footprint=True, Image|None weight_image=None)
Definition blend.py:285
tuple[int, int, int] shape(self)
Definition blend.py:67
parameterize(self, Callable parameterization)
Definition blend.py:272
tuple[Image, np.ndarray] _grad_log_likelihood(self)
Definition blend.py:121
__init__(self, Sequence[Source] sources, Observation observation, dict|None metadata=None)
Definition blend.py:55
Image get_model(self, bool convolve=False, bool use_flux=False)
Definition blend.py:85
Blend fit_spectra(self, bool clip=False)
Definition blend.py:150
Definition __init__.py:1
np.ndarray|Fourier convolve(np.ndarray|Fourier image, np.ndarray|Fourier kernel, int padding=3, int|Sequence[int] axes=(-2, -1), bool return_fourier=True, bool normalize=False)
Definition fft.py:488