blend.py
# This file is part of scarlet_lite.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

from __future__ import annotations

__all__ = ["Blend"]

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Callable, Sequence, cast

import numpy as np

from .bbox import Box
from .component import Component, FactorizedComponent
from .image import Image
from .observation import Observation
from .source import Source

if TYPE_CHECKING:
    from .io import ScarletBlendData, ScarletSourceBaseData


class BlendBase(ABC):
    """A base class for blends that can be extended to add additional
    functionality.

    This class holds all of the sources and the observation that are to be
    fit, and it performs the fitting and joint initialization of the
    spectral components (when applicable).

    Parameters
    ----------
    sources:
        The sources to fit.
    observation:
        The observation that contains the images,
        PSF, etc. that are being fit.
    metadata:
        Additional metadata to store with the blend.
    """

    sources: list[Source]
    observation: Observation
    metadata: dict | None

    @property
    def shape(self) -> tuple[int, int, int]:
        """Shape of the model for the entire `Blend`."""
        return self.observation.shape

    @property
    def bbox(self) -> Box:
        """The bounding box of the entire blend."""
        return self.observation.bbox

    @property
    def components(self) -> list[Component]:
        """The list of all components in the blend.

        Since the list of sources might change,
        this is always built on the fly.
        """
        return [c for src in self.sources for c in src.components]

    @abstractmethod
    def get_model(self, convolve: bool = False, use_flux: bool = False) -> Image:
        """Generate a model of the entire blend.

        Parameters
        ----------
        convolve:
            Whether to convolve the model with the observed PSF in each band.
        use_flux:
            Whether to use the re-distributed flux associated with the sources
            instead of the component models.

        Returns
        -------
        model:
            The model created by combining all of the source models.
        """

    @abstractmethod
    def to_data(self) -> ScarletBlendData:
        """Convert the blend into a serializable dictionary format.

        Returns
        -------
        data:
            A dictionary containing all of the information needed to
            reconstruct the blend.
        """


class Blend(BlendBase):
    """A single blend.

    This class holds all of the sources and the observation that are to be
    fit, and it performs the fitting and joint initialization of the
    spectral components (when applicable).

    Parameters
    ----------
    sources:
        The sources to fit.
    observation:
        The observation that contains the images,
        PSF, etc. that are being fit.
    metadata:
        Additional metadata to store with the blend.
    """

    def __init__(self, sources: Sequence[Source], observation: Observation, metadata: dict | None = None):
        self.sources = list(sources)
        self.observation = observation
        if metadata is not None and len(metadata) == 0:
            metadata = None
        self.metadata = metadata

        # Initialize the iteration count and loss function
        self.it = 0
        self.loss: list[float] = []
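
    # Example (illustrative sketch, not part of the module): constructing a
    # blend from previously built inputs. `sources` and `observation` are
    # assumed to be a sequence of `Source` objects and an `Observation`
    # created with the classes imported above.
    #
    #     blend = Blend(sources, observation)
    #     print(len(blend.components), blend.bbox)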

    def get_model(self, convolve: bool = False, use_flux: bool = False) -> Image:
        """Generate a model of the entire blend.

        Parameters
        ----------
        convolve:
            Whether to convolve the model with the observed PSF in each band.
        use_flux:
            Whether to use the re-distributed flux associated with the sources
            instead of the component models.

        Returns
        -------
        model:
            The model created by combining all of the source models.
        """
        model = Image(
            np.zeros(self.shape, dtype=self.observation.images.dtype),
            bands=self.observation.bands,
            yx0=cast(tuple[int, int], self.observation.bbox.origin[-2:]),
        )

        if use_flux:
            for src in self.sources:
                if src.flux_weighted_image is None:
                    raise ValueError(
                        "Some sources do not have 'flux' attribute set. Run measure.conserve_flux"
                    )
                src.flux_weighted_image.insert_into(model)
        else:
            for component in self.components:
                component.get_model().insert_into(model)
        if convolve:
            return self.observation.convolve(model)
        return model
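
    # Example (sketch): building models from a fitted `blend`. Passing
    # `use_flux=True` additionally requires that `conserve_flux` has been run
    # so that every source has a `flux_weighted_image`.
    #
    #     model = blend.get_model()                  # unconvolved model
    #     observed = blend.get_model(convolve=True)  # convolved with the observed PSF
    #     residual = blend.observation.images - observed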

    def _grad_log_likelihood(self) -> tuple[Image, np.ndarray]:
        """Gradient of the likelihood wrt the unconvolved model

        Returns
        -------
        result:
            The gradient of the likelihood wrt the model
        model_data:
            The convolved model data used to calculate the gradient.
            This can be useful for debugging but is not used in
            production.
        """
        model = self.get_model(convolve=True)
        # Update the loss
        self.loss.append(self.observation.log_likelihood(model))
        # Calculate the gradient wrt the model d(logL)/d(model)
        result = self.observation.weights * (model - self.observation.images)
        result = self.observation.convolve(result, grad=True)
        return result, model.data
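
    # Note (added for clarity, assuming the weighted quadratic log-likelihood
    # evaluated by `Observation.log_likelihood`): writing convolution with the
    # observed PSF as C, the weights as W, and the data as d, the code above
    # computes C^T (W * (C m - d)) for the unconvolved model m, where
    # `convolve(..., grad=True)` applies the transpose (gradient) convolution.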

    @property
    def log_likelihood(self) -> float:
        """The current log-likelihood

        This is calculated on the fly to ensure that it is always up to date
        with the current model parameters.
        """
        return self.observation.log_likelihood(self.get_model(convolve=True))

    def fit_spectra(self, clip: bool = False) -> Blend:
        """Fit all of the spectra given their current morphologies with a
        linear least squares algorithm.

        Parameters
        ----------
        clip:
            Whether or not to clip components that were not
            assigned any flux during the fit.

        Returns
        -------
        blend:
            The blend with updated components is returned.
        """
        from .initialization import multifit_spectra

        morphs = []
        spectra = []
        factorized_indices = []
        model = Image.from_box(
            self.observation.bbox,
            bands=self.observation.bands,
            dtype=self.observation.dtype,
        )
        components = self.components
        for idx, component in enumerate(components):
            if hasattr(component, "morph") and hasattr(component, "spectrum"):
                component = cast(FactorizedComponent, component)
                morphs.append(component.morph)
                spectra.append(component.spectrum)
                factorized_indices.append(idx)
            else:
                model.insert(component.get_model())
        model = self.observation.convolve(model, mode="real")

        boxes = [c.bbox for c in components]
        fit_spectra = multifit_spectra(
            self.observation,
            [Image(morph, yx0=cast(tuple[int, int], bbox.origin)) for morph, bbox in zip(morphs, boxes)],
            model,
        )
        for idx in range(len(morphs)):
            component = cast(FactorizedComponent, components[factorized_indices[idx]])
            component.spectrum[:] = fit_spectra[idx]
            component.spectrum[component.spectrum < 0] = 0

        # Run the proxes for all of the components to make sure that the
        # spectra are consistent with the constraints.
        # In practice this usually means making sure that they are
        # non-negative.
        for src in self.sources:
            for component in src.components:
                if (
                    hasattr(component, "spectrum")
                    and hasattr(component, "prox_spectrum")
                    and component.prox_spectrum is not None  # type: ignore
                ):
                    component.prox_spectrum(component.spectrum)  # type: ignore

        if clip:
            # Remove components with no positive flux
            for src in self.sources:
                _components = []
                for component in src.components:
                    component_model = component.get_model()
                    component_model.data[component_model.data < 0] = 0
                    if np.sum(component_model.data) > 0:
                        _components.append(component)
                src.components = _components

        return self
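
    # Example (sketch): a typical initialization step before the main fit.
    # Refitting all spectra at once with linear least squares, given the
    # current morphologies, usually gives a better starting point, and
    # `clip=True` drops components that received no flux.
    #
    #     blend = Blend(sources, observation)
    #     blend.fit_spectra(clip=True)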

    def fit(
        self,
        max_iter: int,
        e_rel: float = 1e-4,
        min_iter: int = 15,
        resize: int | None = 10,
    ) -> tuple[int, float]:
        """Fit all of the parameters

        Parameters
        ----------
        max_iter:
            The maximum number of iterations
        e_rel:
            The relative error to use for determining convergence.
        min_iter:
            The minimum number of iterations.
        resize:
            Number of iterations before attempting to resize the
            resizable components. If `resize` is `None` then
            no resizing is ever attempted.

        Returns
        -------
        it:
            Number of iterations.
        loss:
            Loss for the last solution
        """
        while self.it < max_iter:
            # Calculate the gradient wrt the unconvolved model
            grad_log_likelihood = self._grad_log_likelihood()
            if resize is not None and self.it > 0 and self.it % resize == 0:
                do_resize = True
            else:
                do_resize = False
            # Update each component given the current gradient
            for component in self.components:
                overlap = component.bbox & self.bbox
                component.update(self.it, grad_log_likelihood[0][overlap].data)
                # Check to see if any components need to be resized
                if do_resize:
                    component.resize(self.bbox)
            # Stopping criteria
            self.it += 1
            if self.it > min_iter and np.abs(self.loss[-1] - self.loss[-2]) < e_rel * np.abs(self.loss[-1]):
                break
        return self.it, self.loss[-1]
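
    # Example (sketch): running the main fit and checking convergence. The
    # numbers are illustrative, not recommended defaults.
    #
    #     n_iter, final_loss = blend.fit(max_iter=200, e_rel=1e-4)
    #     # `blend.loss` holds the log-likelihood history, one entry per iteration
    #     converged = n_iter < 200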

    def parameterize(self, parameterization: Callable):
        """Convert the component parameter arrays into Parameter instances

        Parameters
        ----------
        parameterization:
            A function to use to convert parameters of a given type into
            a `Parameter` in place. It should take a single argument that
            is the `Component` or `Source` that is to be parameterized.
        """
        for source in self.sources:
            source.parameterize(parameterization)
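
    # Example (sketch): the callable is simply forwarded to every source, so
    # any function that accepts a `Component` or `Source` and rewraps its
    # arrays as `Parameter` instances in place will work. `my_parameterization`
    # below is a user-supplied, hypothetical helper.
    #
    #     blend.parameterize(my_parameterization)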

    def conserve_flux(self, mask_footprint: bool = True, weight_image: Image | None = None) -> None:
        """Use the source models as templates to re-distribute flux
        from the data

        The source models are used as approximations to the data,
        and the flux in the data is redistributed according to each
        source's share of the total model.
        There is no return value for this function;
        instead it adds (or modifies) a ``flux_weighted_image``
        attribute on each of the sources with the flux attributed to
        that source.

        Parameters
        ----------
        mask_footprint:
            Whether or not to apply a mask for pixels with zero weight.
        weight_image:
            The weight image to use for the redistribution.
            If `None` then the convolved model of the full blend is used.
        """
        observation = self.observation
        py = observation.psfs.shape[-2] // 2
        px = observation.psfs.shape[-1] // 2

        images = observation.images.copy()
        if mask_footprint:
            images.data[observation.weights.data == 0] = 0

        if weight_image is None:
            weight_image = self.get_model()
            # Always convolve in real space to avoid FFT artifacts
            weight_image = observation.convolve(weight_image, mode="real")

            # Due to ringing in the PSF, the convolved model can have
            # negative values. We take the absolute value to avoid
            # negative fluxes in the flux weighted images.
            weight_image.data[:] = np.abs(weight_image.data)

        for src in self.sources:
            if src.is_null:
                src.flux_weighted_image = Image.from_box(Box((0, 0)), bands=observation.bands)  # type: ignore
                continue
            src_model = src.get_model()

            # Grow the model to include the wings of the PSF
            src_box = src.bbox.grow((py, px))
            overlap = observation.bbox & src_box
            src_model = src_model.project(bbox=overlap)
            src_model = observation.convolve(src_model, mode="real")
            src_model.data[:] = np.abs(src_model.data)
            numerator = src_model.data
            denominator = weight_image[overlap].data
            cuts = denominator != 0
            ratio = np.zeros(numerator.shape, dtype=numerator.dtype)
            ratio[cuts] = numerator[cuts] / denominator[cuts]
            ratio[denominator == 0] = 0
            # sometimes numerical errors can cause a hot pixel to have a
            # slightly higher ratio than 1
            ratio[ratio > 1] = 1
            src.flux_weighted_image = src_model.copy_with(data=ratio) * images[overlap]
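
    # Example (sketch): redistributing the observed flux after fitting and
    # reading back the per-source, flux-weighted images.
    #
    #     blend.conserve_flux()
    #     fluxes = [src.flux_weighted_image for src in blend.sources]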

    def to_data(self) -> ScarletBlendData:
        """Convert the Blend into a persistable data object

        Returns
        -------
        blend_data :
            The data model for a single blend.
        """
        from .io import ScarletBlendData

        sources: dict[Any, ScarletSourceBaseData] = {}
        for sidx, source in enumerate(self.sources):
            metadata = source.metadata or {}
            if "id" in metadata:
                sources[metadata["id"]] = source.to_data()
            else:
                sources[sidx] = source.to_data()

        blend_data = ScarletBlendData(
            origin=self.bbox.origin,  # type: ignore
            shape=self.bbox.shape,  # type: ignore
            sources=sources,
            metadata=self.metadata,
        )

        return blend_data
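
    # Example (sketch): serializing a fitted blend. The returned
    # `ScarletBlendData` is the persistable form defined in `.io`; sources with
    # an "id" entry in their metadata are keyed by that id, the rest by index.
    #
    #     blend_data = blend.to_data()
    #     origin, shape = blend_data.origin, blend_data.shape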