LSST Applications g0f08755f38+c89d42e150,g1635faa6d4+b6cf076a36,g1653933729+a8ce1bb630,g1a0ca8cf93+4c08b13bf7,g28da252d5a+f33f8200ef,g29321ee8c0+0187be18b1,g2bbee38e9b+9634bc57db,g2bc492864f+9634bc57db,g2cdde0e794+c2c89b37c4,g3156d2b45e+41e33cbcdc,g347aa1857d+9634bc57db,g35bb328faa+a8ce1bb630,g3a166c0a6a+9634bc57db,g3e281a1b8c+9f2c4e2fc3,g414038480c+077ccc18e7,g41af890bb2+e740673f1a,g5fbc88fb19+17cd334064,g7642f7d749+c89d42e150,g781aacb6e4+a8ce1bb630,g80478fca09+f8b2ab54e1,g82479be7b0+e2bd23ab8b,g858d7b2824+c89d42e150,g9125e01d80+a8ce1bb630,g9726552aa6+10f999ec6a,ga5288a1d22+065360aec4,gacf8899fa4+9553554aa7,gae0086650b+a8ce1bb630,gb58c049af0+d64f4d3760,gbd46683f8f+ac57cbb13d,gc28159a63d+9634bc57db,gcf0d15dbbd+e37acf7834,gda3e153d99+c89d42e150,gda6a2b7d83+e37acf7834,gdaeeff99f8+1711a396fd,ge2409df99d+cb1e6652d6,ge79ae78c31+9634bc57db,gf0baf85859+147a0692ba,gf3967379c6+02b11634a5,w.2024.45
LSST Data Management Base Package
Loading...
Searching...
No Matches
blend.py
Go to the documentation of this file.
1# This file is part of scarlet_lite.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21
22from __future__ import annotations
23
24__all__ = ["Blend"]
25
26from typing import Callable, Sequence, cast
27
28import numpy as np
29
30from .bbox import Box
31from .component import Component, FactorizedComponent
32from .image import Image
33from .observation import Observation
34from .source import Source
35
36
class Blend:
    """A single blend.

    This class holds all of the sources and observation that are to be fit,
    as well as performing fitting and joint initialization of the
    spectral components (when applicable).

    Parameters
    ----------
    sources:
        The sources to fit.
    observation:
        The observation that contains the images,
        PSF, etc. that are being fit.
    """

    def __init__(self, sources: Sequence[Source], observation: Observation):
        self.sources = list(sources)
        self.observation = observation

        # Iteration counter and history of log-likelihood values (one entry
        # is appended by every call to `_grad_log_likelihood`). Neither is
        # reset by `fit`, so repeated calls to `fit` resume where they left
        # off.
        self.it = 0
        self.loss: list[float] = []

    @property
    def shape(self) -> tuple[int, int, int]:
        """Shape of the model for the entire `Blend`."""
        return self.observation.shape

    @property
    def bbox(self) -> Box:
        """The bounding box of the entire blend."""
        return self.observation.bbox

    @property
    def components(self) -> list[Component]:
        """The list of all components in the blend.

        Since the list of sources might change,
        this is always built on the fly.
        """
        return [c for src in self.sources for c in src.components]

    def get_model(self, convolve: bool = False, use_flux: bool = False) -> Image:
        """Generate a model of the entire blend.

        Parameters
        ----------
        convolve:
            Whether to convolve the model with the observed PSF in each band.
        use_flux:
            Whether to use the re-distributed flux associated with the sources
            (``flux_weighted_image``, populated by `conserve_flux`)
            instead of the component models.

        Returns
        -------
        model:
            The model created by combining all of the source models.

        Raises
        ------
        ValueError
            If ``use_flux`` is set and any source has no
            ``flux_weighted_image``.
        """
        model = Image(
            np.zeros(self.shape, dtype=self.observation.images.dtype),
            bands=self.observation.bands,
            yx0=cast(tuple[int, int], self.observation.bbox.origin[-2:]),
        )

        if use_flux:
            for src in self.sources:
                if src.flux_weighted_image is None:
                    raise ValueError(
                        "Some sources do not have 'flux' attribute set. Run measure.conserve_flux"
                    )
                src.flux_weighted_image.insert_into(model)
        else:
            for component in self.components:
                component.get_model().insert_into(model)
        if convolve:
            return self.observation.convolve(model)
        return model

    def _grad_log_likelihood(self) -> Image:
        """Gradient of the likelihood wrt the unconvolved model.

        As a side effect, the log-likelihood of the current (convolved)
        model is appended to ``self.loss``.
        """
        model = self.get_model(convolve=True)
        # Update the loss
        self.loss.append(self.observation.log_likelihood(model))
        # Calculate the gradient wrt the model d(logL)/d(model):
        # the weighted residual in the convolved (observed) frame...
        result = self.observation.weights * (model - self.observation.images)
        # ...mapped back to the unconvolved frame through the PSF.
        return self.observation.convolve(result, grad=True)

    @property
    def log_likelihood(self) -> float:
        """The current log-likelihood

        This is calculated on the fly to ensure that it is always up to date
        with the current model parameters.
        """
        return self.observation.log_likelihood(self.get_model(convolve=True))

    def fit_spectra(self, clip: bool = False) -> Blend:
        """Fit all of the spectra given their current morphologies with a
        linear least squares algorithm.

        Only components exposing both ``morph`` and ``spectrum`` have their
        spectra fit. The (real-space convolved) model of every other
        component is passed to ``multifit_spectra`` alongside the
        morphologies.

        Parameters
        ----------
        clip:
            Whether or not to clip components that were not
            assigned any flux during the fit.

        Returns
        -------
        blend:
            The blend with updated components is returned.
        """
        from .initialization import multifit_spectra

        factorized: list[FactorizedComponent] = []
        morph_images: list[Image] = []
        model = Image.from_box(
            self.observation.bbox,
            bands=self.observation.bands,
            dtype=self.observation.dtype,
        )
        for component in self.components:
            if hasattr(component, "morph") and hasattr(component, "spectrum"):
                component = cast(FactorizedComponent, component)
                factorized.append(component)
                # Each morphology is placed at the origin of its OWN
                # component's bbox, so morphologies and boxes stay aligned
                # even when non-factorized components are interleaved.
                morph_images.append(
                    Image(component.morph, yx0=cast(tuple[int, int], component.bbox.origin))
                )
            else:
                model.insert(component.get_model())

        if factorized:
            model = self.observation.convolve(model, mode="real")
            fitted = multifit_spectra(self.observation, morph_images, model)
            for component, spectrum in zip(factorized, fitted):
                component.spectrum[:] = spectrum
                component.spectrum[component.spectrum < 0] = 0

        # Run the proxes for all of the components to make sure that the
        # spectra are consistent with the constraints.
        # In practice this usually means making sure that they are
        # non-negative.
        for component in self.components:
            if (
                hasattr(component, "spectrum")
                and hasattr(component, "prox_spectrum")
                and component.prox_spectrum is not None  # type: ignore
            ):
                component.prox_spectrum(component.spectrum)  # type: ignore

        if clip:
            # Remove components with no positive flux. Negative pixels are
            # clipped on a copy (np.maximum) rather than in place, so the
            # array returned by `get_model` -- which may be shared with the
            # component's internal state -- is never modified.
            for src in self.sources:
                src.components = [
                    component
                    for component in src.components
                    if np.sum(np.maximum(component.get_model().data, 0)) > 0
                ]

        return self

    def fit(
        self,
        max_iter: int,
        e_rel: float = 1e-4,
        min_iter: int = 15,
        resize: int | None = 10,
    ) -> tuple[int, float]:
        """Fit all of the parameters

        Iterates until ``self.it`` reaches ``max_iter``, or stops early once
        more than ``min_iter`` iterations have run and the absolute change
        between the last two loss values is below ``e_rel`` times the
        absolute value of the last loss.

        Parameters
        ----------
        max_iter:
            The maximum number of iterations
        e_rel:
            The relative error to use for determining convergence.
        min_iter:
            The minimum number of iterations.
        resize:
            Number of iterations before attempting to resize the
            resizable components. If `resize` is `None` (or not positive)
            then no resizing is ever attempted.

        Returns
        -------
        it:
            Number of iterations.
        loss:
            Loss for the last solution: the last entry of ``self.loss``
            (evaluated at the start of the final iteration). If no iteration
            has ever been run, the log-likelihood of the current model.
        """
        bbox = self.bbox
        resize_enabled = resize is not None and resize > 0
        while self.it < max_iter:
            # Calculate the gradient wrt the unconvolved model
            grad_log_likelihood = self._grad_log_likelihood()
            do_resize = resize_enabled and self.it > 0 and self.it % resize == 0  # type: ignore[operator]
            # Update each component given the current gradient
            for component in self.components:
                overlap = component.bbox & bbox
                component.update(self.it, grad_log_likelihood[overlap].data)
                # Check to see if any components need to be resized
                if do_resize:
                    component.resize(bbox)
            # Stopping criteria (needs at least two loss values to compare)
            self.it += 1
            if (
                self.it > min_iter
                and len(self.loss) > 1
                and np.abs(self.loss[-1] - self.loss[-2]) < e_rel * np.abs(self.loss[-1])
            ):
                break
        if not self.loss:
            # No iteration has ever run: report the current model's loss
            # instead of failing on an empty history.
            return self.it, self.log_likelihood
        return self.it, self.loss[-1]

    def parameterize(self, parameterization: Callable):
        """Convert the component parameter arrays into Parameter instances

        Parameters
        ----------
        parameterization:
            A function to use to convert parameters of a given type into
            a `Parameter` in place. It should take a single argument that
            is the `Component` or `Source` that is to be parameterized.
        """
        for source in self.sources:
            source.parameterize(parameterization)

    def conserve_flux(self, mask_footprint: bool = True) -> None:
        """Use the source models as templates to re-distribute flux
        from the data

        The source models are used as approximations to the data,
        which redistribute the flux in the data according to the
        ratio of the models for each source.
        There is no return value for this function,
        instead it adds (or modifies) a ``flux_weighted_image``
        attribute to each the sources with the flux attributed to
        that source.

        Parameters
        ----------
        mask_footprint:
            Whether or not to apply a mask for pixels with zero weight.
        """
        observation = self.observation
        # Half-widths of the PSF, used to grow each source box so the
        # convolved source model includes the PSF wings.
        py = observation.psfs.shape[-2] // 2
        px = observation.psfs.shape[-1] // 2

        images = observation.images.copy()
        if mask_footprint:
            images.data[observation.weights.data == 0] = 0
        model = self.get_model()
        # Always convolve in real space to avoid FFT artifacts
        model = observation.convolve(model, mode="real")
        model.data[model.data < 0] = 0

        for src in self.sources:
            if src.is_null:
                src.flux_weighted_image = Image.from_box(Box((0, 0)), bands=observation.bands)  # type: ignore
                continue
            src_model = src.get_model()

            # Grow the model to include the wings of the PSF
            src_box = src.bbox.grow((py, px))
            overlap = observation.bbox & src_box
            src_model = src_model.project(bbox=overlap)
            src_model = observation.convolve(src_model, mode="real")
            src_model.data[src_model.data < 0] = 0
            numerator = src_model.data
            denominator = model[overlap].data
            # Fraction of the total model owned by this source; pixels where
            # the total model is zero keep the initial value of 0.
            ratio = np.zeros(numerator.shape, dtype=numerator.dtype)
            np.divide(numerator, denominator, out=ratio, where=denominator != 0)
            # sometimes numerical errors can cause a hot pixel to have a
            # slightly higher ratio than 1
            ratio[ratio > 1] = 1
            src.flux_weighted_image = src_model.copy_with(data=ratio) * images[overlap]
table::Key< table::Array< int > > components
A class to represent a 2-dimensional array of pixels.
Definition Image.h:51
tuple[int, int, int] shape(self)
Definition blend.py:62
parameterize(self, Callable parameterization)
Definition blend.py:257
Image _grad_log_likelihood(self)
Definition blend.py:116
list[Component] components(self)
Definition blend.py:72
None conserve_flux(self, bool mask_footprint=True)
Definition blend.py:270
__init__(self, Sequence[Source] sources, Observation observation)
Definition blend.py:53
Image get_model(self, bool convolve=False, bool use_flux=False)
Definition blend.py:80
Blend fit_spectra(self, bool clip=False)
Definition blend.py:135