LSST Applications g063fba187b+cac8b7c890,g0f08755f38+6aee506743,g1653933729+a8ce1bb630,g168dd56ebc+a8ce1bb630,g1a2382251a+b4475c5878,g1dcb35cd9c+8f9bc1652e,g20f6ffc8e0+6aee506743,g217e2c1bcf+73dee94bd0,g28da252d5a+1f19c529b9,g2bbee38e9b+3f2625acfc,g2bc492864f+3f2625acfc,g3156d2b45e+6e55a43351,g32e5bea42b+1bb94961c2,g347aa1857d+3f2625acfc,g35bb328faa+a8ce1bb630,g3a166c0a6a+3f2625acfc,g3e281a1b8c+c5dd892a6c,g3e8969e208+a8ce1bb630,g414038480c+5927e1bc1e,g41af890bb2+8a9e676b2a,g7af13505b9+809c143d88,g80478fca09+6ef8b1810f,g82479be7b0+f568feb641,g858d7b2824+6aee506743,g89c8672015+f4add4ffd5,g9125e01d80+a8ce1bb630,ga5288a1d22+2903d499ea,gb58c049af0+d64f4d3760,gc28159a63d+3f2625acfc,gcab2d0539d+b12535109e,gcf0d15dbbd+46a3f46ba9,gda6a2b7d83+46a3f46ba9,gdaeeff99f8+1711a396fd,ge79ae78c31+3f2625acfc,gef2f8181fd+0a71e47438,gf0baf85859+c1f95f4921,gfa517265be+6aee506743,gfa999e8aa5+17cd334064,w.2024.51
LSST Data Management Base Package
Loading...
Searching...
No Matches
blend.py
Go to the documentation of this file.
1# This file is part of scarlet_lite.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21
22from __future__ import annotations
23
24__all__ = ["Blend"]
25
26from typing import Callable, Sequence, cast
27
28import numpy as np
29
30from .bbox import Box
31from .component import Component, FactorizedComponent
32from .image import Image
33from .observation import Observation
34from .source import Source
35
36
class Blend:
    """A single blend.

    This class holds all of the sources and the observation that are to be
    fit, and performs the fitting and the joint initialization of the
    spectral components (when applicable).

    Parameters
    ----------
    sources:
        The sources to fit.
    observation:
        The observation that contains the images,
        PSF, etc. that are being fit.
    """

    def __init__(self, sources: Sequence[Source], observation: Observation):
        # Copy into a list so that the blend owns its own (mutable) sequence.
        self.sources = list(sources)
        self.observation = observation

        # Initialize the iteration count and the loss history.
        # One loss value is appended per call to `_grad_log_likelihood`.
        self.it = 0
        self.loss: list[float] = []

    @property
    def shape(self) -> tuple[int, int, int]:
        """Shape of the model for the entire `Blend`."""
        return self.observation.shape

    @property
    def bbox(self) -> Box:
        """The bounding box of the entire blend."""
        return self.observation.bbox

    @property
    def components(self) -> list[Component]:
        """The list of all components in the blend.

        Since the list of sources might change,
        this is always built on the fly.
        """
        return [c for src in self.sources for c in src.components]

    def get_model(self, convolve: bool = False, use_flux: bool = False) -> Image:
        """Generate a model of the entire blend.

        Parameters
        ----------
        convolve:
            Whether to convolve the model with the observed PSF in each band.
        use_flux:
            Whether to use the re-distributed flux associated with the sources
            instead of the component models.

        Returns
        -------
        model:
            The model created by combining all of the source models.

        Raises
        ------
        ValueError
            If `use_flux` is set and a source has no
            ``flux_weighted_image``.
        """
        model = Image(
            np.zeros(self.shape, dtype=self.observation.images.dtype),
            bands=self.observation.bands,
            yx0=cast(tuple[int, int], self.observation.bbox.origin[-2:]),
        )

        if use_flux:
            for src in self.sources:
                if src.flux_weighted_image is None:
                    raise ValueError(
                        "Some sources do not have 'flux' attribute set. Run measure.conserve_flux"
                    )
                src.flux_weighted_image.insert_into(model)
        else:
            for component in self.components:
                component.get_model().insert_into(model)
        if convolve:
            return self.observation.convolve(model)
        return model

    def _grad_log_likelihood(self) -> tuple[Image, np.ndarray]:
        """Gradient of the likelihood wrt the unconvolved model

        As a side effect the current loss is appended to ``self.loss``.

        Returns
        -------
        result:
            The gradient of the likelihood wrt the model
        model_data:
            The convolved model data used to calculate the gradient.
            This can be useful for debugging but is not used in
            production.
        """
        model = self.get_model(convolve=True)
        # Update the loss
        self.loss.append(self.observation.log_likelihood(model))
        # Calculate the gradient wrt the model d(logL)/d(model)
        result = self.observation.weights * (model - self.observation.images)
        result = self.observation.convolve(result, grad=True)
        return result, model.data

    @property
    def log_likelihood(self) -> float:
        """The current log-likelihood

        This is calculated on the fly to ensure that it is always up to date
        with the current model parameters.
        """
        return self.observation.log_likelihood(self.get_model(convolve=True))

    def fit_spectra(self, clip: bool = False) -> Blend:
        """Fit all of the spectra given their current morphologies with a
        linear least squares algorithm.

        Components that have both a ``morph`` and a ``spectrum`` attribute
        are fit; the models of all other components are accumulated,
        convolved, and passed to the fitter as a fixed model.

        Parameters
        ----------
        clip:
            Whether or not to clip components that were not
            assigned any flux during the fit.

        Returns
        -------
        blend:
            The blend with updated components is returned.
        """
        # Imported here rather than at module level (as in the original).
        from .initialization import multifit_spectra

        model = Image.from_box(
            self.observation.bbox,
            bands=self.observation.bands,
            dtype=self.observation.dtype,
        )

        # Split the components into those with a fittable spectrum and
        # those that only contribute a fixed model.
        factorized: list[FactorizedComponent] = []
        for component in self.components:
            if hasattr(component, "morph") and hasattr(component, "spectrum"):
                factorized.append(cast(FactorizedComponent, component))
            else:
                model.insert(component.get_model())
        model = self.observation.convolve(model, mode="real")

        # Each morphology must be paired with the bounding box of the
        # component it belongs to. Taking the boxes from the full component
        # list would misalign them whenever a non-factorized component
        # precedes a factorized one.
        morph_images = [
            Image(component.morph, yx0=cast(tuple[int, int], component.bbox.origin))
            for component in factorized
        ]
        fitted = multifit_spectra(self.observation, morph_images, model)
        for component, spectrum in zip(factorized, fitted):
            # Update in place so that existing references stay valid.
            component.spectrum[:] = spectrum
            component.spectrum[component.spectrum < 0] = 0

        # Run the proxes for all of the components to make sure that the
        # spectra are consistent with the constraints.
        # In practice this usually means making sure that they are
        # non-negative.
        for src in self.sources:
            for component in src.components:
                if (
                    hasattr(component, "spectrum")
                    and hasattr(component, "prox_spectrum")
                    and component.prox_spectrum is not None  # type: ignore
                ):
                    component.prox_spectrum(component.spectrum)  # type: ignore

        if clip:
            # Remove components with no positive flux
            for src in self.sources:
                _components = []
                for component in src.components:
                    component_model = component.get_model()
                    component_model.data[component_model.data < 0] = 0
                    if np.sum(component_model.data) > 0:
                        _components.append(component)
                src.components = _components

        return self

    def fit(
        self,
        max_iter: int,
        e_rel: float = 1e-4,
        min_iter: int = 15,
        resize: int | None = 10,
    ) -> tuple[int, float]:
        """Fit all of the parameters

        Iteration continues from the current value of ``self.it``, so
        repeated calls resume the fit rather than restarting it.

        Parameters
        ----------
        max_iter:
            The maximum number of iterations
        e_rel:
            The relative error to use for determining convergence.
        min_iter:
            The minimum number of iterations.
        resize:
            Number of iterations before attempting to resize the
            resizable components. If `resize` is `None` then
            no resizing is ever attempted.

        Returns
        -------
        it:
            Number of iterations.
        loss:
            Loss for the last solution

        Raises
        ------
        IndexError
            If ``self.loss`` is still empty when the loop exits, which
            happens when ``max_iter <= self.it`` on a blend that has
            never been iterated.
        """
        while self.it < max_iter:
            # Calculate the gradient wrt the un-convolved model
            grad_log_likelihood, _ = self._grad_log_likelihood()
            do_resize = resize is not None and self.it > 0 and self.it % resize == 0
            # Update each component given the current gradient
            for component in self.components:
                overlap = component.bbox & self.bbox
                component.update(self.it, grad_log_likelihood[overlap].data)
                # Check to see if any components need to be resized
                if do_resize:
                    component.resize(self.bbox)
            # Stopping criteria
            self.it += 1
            # The relative change needs two loss values; with a small
            # `min_iter` only one may exist after the first iteration.
            if (
                self.it > min_iter
                and len(self.loss) > 1
                and np.abs(self.loss[-1] - self.loss[-2]) < e_rel * np.abs(self.loss[-1])
            ):
                break
        return self.it, self.loss[-1]

    def parameterize(self, parameterization: Callable):
        """Convert the component parameter arrays into Parameter instances

        Parameters
        ----------
        parameterization:
            A function to use to convert parameters of a given type into
            a `Parameter` in place. It should take a single argument that
            is the `Component` or `Source` that is to be parameterized.
        """
        for source in self.sources:
            source.parameterize(parameterization)

    def conserve_flux(self, mask_footprint: bool = True) -> None:
        """Use the source models as templates to re-distribute flux
        from the data

        The source models are used as approximations to the data,
        which redistribute the flux in the data according to the
        ratio of the models for each source.
        There is no return value for this function,
        instead it adds (or modifies) a ``flux_weighted_image``
        attribute to each of the sources with the flux attributed to
        that source.

        Parameters
        ----------
        mask_footprint:
            Whether or not to apply a mask for pixels with zero weight.
        """
        observation = self.observation
        # Half-widths of the PSF, used to grow each source box below.
        py = observation.psfs.shape[-2] // 2
        px = observation.psfs.shape[-1] // 2

        # Work on a copy so the observed images are never modified.
        images = observation.images.copy()
        if mask_footprint:
            images.data[observation.weights.data == 0] = 0
        model = self.get_model()
        # Always convolve in real space to avoid FFT artifacts
        model = observation.convolve(model, mode="real")
        model.data[model.data < 0] = 0

        for src in self.sources:
            if src.is_null:
                src.flux_weighted_image = Image.from_box(Box((0, 0)), bands=observation.bands)  # type: ignore
                continue
            src_model = src.get_model()

            # Grow the model to include the wings of the PSF
            src_box = src.bbox.grow((py, px))
            overlap = observation.bbox & src_box
            src_model = src_model.project(bbox=overlap)
            src_model = observation.convolve(src_model, mode="real")
            src_model.data[src_model.data < 0] = 0
            numerator = src_model.data
            denominator = model[overlap].data
            # Only divide where the blend model is non-zero; everywhere
            # else the ratio keeps its initial value of zero.
            cuts = denominator != 0
            ratio = np.zeros(numerator.shape, dtype=numerator.dtype)
            ratio[cuts] = numerator[cuts] / denominator[cuts]
            # sometimes numerical errors can cause a hot pixel to have a
            # slightly higher ratio than 1
            ratio[ratio > 1] = 1
            src.flux_weighted_image = src_model.copy_with(data=ratio) * images[overlap]
table::Key< table::Array< int > > components
A class to represent a 2-dimensional array of pixels.
Definition Image.h:51
tuple[int, int, int] shape(self)
Definition blend.py:62
parameterize(self, Callable parameterization)
Definition blend.py:267
tuple[Image, np.ndarray] _grad_log_likelihood(self)
Definition blend.py:116
list[Component] components(self)
Definition blend.py:72
None conserve_flux(self, bool mask_footprint=True)
Definition blend.py:280
__init__(self, Sequence[Source] sources, Observation observation)
Definition blend.py:53
Image get_model(self, bool convolve=False, bool use_flux=False)
Definition blend.py:80
Blend fit_spectra(self, bool clip=False)
Definition blend.py:145