LSST Applications g0f08755f38+82efc23009,g12f32b3c4e+e7bdf1200e,g1653933729+a8ce1bb630,g1a0ca8cf93+50eff2b06f,g28da252d5a+52db39f6a5,g2bbee38e9b+37c5a29d61,g2bc492864f+37c5a29d61,g2cdde0e794+c05ff076ad,g3156d2b45e+41e33cbcdc,g347aa1857d+37c5a29d61,g35bb328faa+a8ce1bb630,g3a166c0a6a+37c5a29d61,g3e281a1b8c+fb992f5633,g414038480c+7f03dfc1b0,g41af890bb2+11b950c980,g5fbc88fb19+17cd334064,g6b1c1869cb+12dd639c9a,g781aacb6e4+a8ce1bb630,g80478fca09+72e9651da0,g82479be7b0+04c31367b4,g858d7b2824+82efc23009,g9125e01d80+a8ce1bb630,g9726552aa6+8047e3811d,ga5288a1d22+e532dc0a0b,gae0086650b+a8ce1bb630,gb58c049af0+d64f4d3760,gc28159a63d+37c5a29d61,gcf0d15dbbd+2acd6d4d48,gd7358e8bfb+778a810b6e,gda3e153d99+82efc23009,gda6a2b7d83+2acd6d4d48,gdaeeff99f8+1711a396fd,ge2409df99d+6b12de1076,ge79ae78c31+37c5a29d61,gf0baf85859+d0a5978c5a,gf3967379c6+4954f8c433,gfb92a5be7c+82efc23009,gfec2e1e490+2aaed99252,w.2024.46
LSST Data Management Base Package
Loading...
Searching...
No Matches
parameters.py
Go to the documentation of this file.
1# This file is part of scarlet_lite.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21
22from __future__ import annotations
23
# Names exported by ``from <module> import *``.
__all__ = [
    "parameter", "Parameter", "FistaParameter", "AdaproxParameter", "FixedParameter",
    "relative_step", "phi_psi", "DEFAULT_ADAPROX_FACTOR",
]
34
35from typing import Callable, Sequence, cast
36
37import numpy as np
38import numpy.typing as npt
39
40from .bbox import Box
41
# Default factor used when computing the step of adaprox parameters.
DEFAULT_ADAPROX_FACTOR = 0.01
44
45
def step_function_wrapper(step: float) -> Callable:
    """Turn a fixed numerical step into a step function.

    Parameters
    ----------
    step:
        The constant step value to report for any array.

    Returns
    -------
    step_function:
        A function that accepts one argument (which it ignores) and
        always returns ``step``.
    """

    def step_function(x):
        # The array is irrelevant: the step is a constant.
        return step

    return step_function
61
62
class Parameter:
    """A parameter in a `Component`

    Parameters
    ----------
    x:
        The array of values that is being fit.
    helpers:
        A dictionary of helper arrays that are used by an optimizer to
        persist values like the gradient of `x`, the Hessian of `x`, etc.
    step:
        A numerical step value or a function to calculate the step for a
        given `x`. Any non-callable value (``float``, ``int``, numpy
        scalar, ...) is wrapped into a constant step function.
    grad:
        A function to calculate the gradient of `x`.
    prox:
        A function to take the proximal operator of `x`.
    """

    def __init__(
        self,
        x: np.ndarray,
        helpers: dict[str, np.ndarray],
        step: Callable | float,
        grad: Callable | None = None,
        prox: Callable | None = None,
    ):
        self.x = x
        self.helpers = helpers

        # ``self.step`` always calls ``self._step``, so constant steps must be
        # wrapped. Testing ``callable`` (rather than ``isinstance(step, float)``)
        # also wraps integer steps such as the ``0`` used by `parameter`,
        # `copy` and `FixedParameter`, which would otherwise make the
        # ``step`` property raise ``TypeError``.
        if callable(step):
            _step = step
        else:
            _step = step_function_wrapper(step)

        self._step = _step
        self.grad = grad
        self.prox = prox

    @property
    def step(self) -> float:
        """Calculate the step for the current value of `x`.

        Returns
        -------
        step:
            The result of calling the step function on `x`.
        """
        return self._step(self.x)

    @property
    def shape(self) -> tuple[int, ...]:
        """The shape of the array that is being fit."""
        return self.x.shape

    @property
    def dtype(self) -> npt.DTypeLike:
        """The numpy dtype of the array that is being fit."""
        return self.x.dtype

    def copy(self) -> Parameter:
        """Copy this parameter, including all of the helper arrays.

        `x` and every helper array are copied. The step function, gradient
        function and proximal operator are shared with the original, so the
        copy has a usable ``step``.
        """
        helpers = {k: v.copy() for k, v in self.helpers.items()}
        return Parameter(self.x.copy(), helpers, self._step, self.grad, self.prox)

    def update(self, it: int, input_grad: np.ndarray, *args):
        """Update the parameter in one iteration.

        This includes the gradient update, proximal update,
        and any meta parameters that are stored as class
        attributes to update the parameter.

        Parameters
        ----------
        it:
            The current iteration
        input_grad:
            The gradient from the full model, passed to the parameter.

        Raises
        ------
        NotImplementedError:
            Always; subclasses implement the actual update.
        """
        raise NotImplementedError("Base Parameters cannot be updated")

    def resize(self, old_box: Box, new_box: Box):
        """Resize the parameter and all of the helper arrays.

        Values in the region where `old_box` and `new_box` overlap are
        copied. Every other pixel of the new arrays is zero. Helper arrays
        are created with the dtype of `x`.

        Parameters
        ----------
        old_box:
            The old bounding box for the parameter.
        new_box:
            The new bounding box for the parameter.
        """
        slices = new_box.overlapped_slices(old_box)
        x = np.zeros(new_box.shape, dtype=self.dtype)
        x[slices[0]] = self.x[slices[1]]
        self.x = x

        for name, value in self.helpers.items():
            result = np.zeros(new_box.shape, dtype=self.dtype)
            result[slices[0]] = value[slices[1]]
            self.helpers[name] = result
163
164
def parameter(x: np.ndarray | Parameter) -> Parameter:
    """Ensure that `x` is a `Parameter`.

    Parameters
    ----------
    x:
        Either an existing `Parameter` or an array of values.

    Returns
    -------
    result:
        `x` itself when it is already a `Parameter`. Otherwise, a new
        `Parameter` wrapping `x` with no helper arrays and a step of ``0``.
    """
    return x if isinstance(x, Parameter) else Parameter(x, {}, 0)
181
182
class FistaParameter(Parameter):
    """A `Parameter` updated with the FISTA proximal gradient method
    of Beck & Teboulle (2009).

    See https://www.ceremade.dauphine.fr/~carlier/FISTA

    Parameters
    ----------
    x:
        The array of values that is being fit.
    step:
        The base step size, passed on to `Parameter`.
    grad:
        Function returning the gradient of `x`. It is called as
        ``grad(input_grad, x, *args)``.
    prox:
        Optional proximal operator applied after each gradient step.
    t0:
        Initial value of the FISTA momentum sequence ``t``.
    z0:
        Initial extrapolated point ``z``. Defaults to `x` itself (the same
        array object, not a copy).
    """

    def __init__(
        self,
        x: np.ndarray,
        step: float,
        grad: Callable | None = None,
        prox: Callable | None = None,
        t0: float = 1,
        z0: np.ndarray | None = None,
    ):
        # Without an explicit starting point, extrapolation starts at x.
        helpers = {"z": x if z0 is None else z0}
        super().__init__(x, helpers, step, grad, prox)
        self.t = t0

    def update(self, it: int, input_grad: np.ndarray, *args):
        """Advance `x`, the extrapolated point ``z`` and ``t`` by one
        FISTA iteration.

        The base step is divided by ``sum(args[0] * args[0])``, so at least
        one extra positional argument is required. See `Parameter.update`
        for the other arguments.
        """
        base_step = self.step
        step = base_step / np.sum(args[0] * args[0])
        x_old = self.x
        z_old = self.helpers["z"]

        gradient = cast(Callable, self.grad)(input_grad, x_old, *args)
        y = z_old - step * gradient
        x_new = y if self.prox is None else self.prox(y)

        # Momentum sequence: t_{k+1} = (1 + sqrt(1 + 4 t_k^2)) / 2
        t_new = 0.5 * (1 + np.sqrt(1 + 4 * self.t**2))
        omega = 1 + (self.t - 1) / t_new
        # z must be built from the *old* x, before x is overwritten in place.
        self.helpers["z"] = x_old + omega * (x_new - x_old)
        x_old[:] = x_new
        self.t = t_new
230
231
# The functions below implement the per-iteration update rules of several
# ADAM variants. Each one updates the moment helper arrays in place and
# returns ``(phi, psi)``; the parameter is then moved by ``-step * phi / psi``.
# `AdaproxParameter` uses `_amsgrad_phi_psi` by default. A different variant
# is selected by passing its name (a key of `phi_psi`) as the ``scheme``
# argument of `AdaproxParameter`.
237
238
239# noinspection PyUnusedLocal
240def _adam_phi_psi(it, g, m, v, vhat, b1, b2, eps, p):
241 # moving averages
242 m[:] = (1 - b1[it]) * g + b1[it] * m
243 v[:] = (1 - b2) * (g**2) + b2 * v
244
245 # bias correction
246 t = it + 1
247 phi = m / (1 - b1[it] ** t)
248 psi = np.sqrt(v / (1 - b2**t)) + eps
249 return phi, psi
250
251
252# noinspection PyUnusedLocal
253def _nadam_phi_psi(it, g, m, v, vhat, b1, b2, eps, p):
254 # moving averages
255 m[:] = (1 - b1[it]) * g + b1[it] * m
256 v[:] = (1 - b2) * (g**2) + b2 * v
257
258 # bias correction
259 t = it + 1
260 phi = (b1[it] * m[:] + (1 - b1[it]) * g) / (1 - b1[it] ** t)
261 psi = np.sqrt(v / (1 - b2**t)) + eps
262 return phi, psi
263
264
265# noinspection PyUnusedLocal
266def _amsgrad_phi_psi(it, g, m, v, vhat, b1, b2, eps, p):
267 # moving averages
268 m[:] = (1 - b1[it]) * g + b1[it] * m
269 v[:] = (1 - b2) * (g**2) + b2 * v
270
271 phi = m
272 vhat[:] = np.maximum(vhat, v)
273 # sanitize zero-gradient elements
274 if eps > 0:
275 vhat = np.maximum(vhat, eps)
276 psi = np.sqrt(vhat)
277 return phi, psi
278
279
280def _padam_phi_psi(it, g, m, v, vhat, b1, b2, eps, p):
281 # moving averages
282 m[:] = (1 - b1[it]) * g + b1[it] * m
283 v[:] = (1 - b2) * (g**2) + b2 * v
284
285 phi = m
286 vhat[:] = np.maximum(vhat, v)
287 # sanitize zero-gradient elements
288 if eps > 0:
289 vhat = np.maximum(vhat, eps)
290 psi = vhat**p
291 return phi, psi
292
293
294# noinspection PyUnusedLocal
295def _adamx_phi_psi(it, g, m, v, vhat, b1, b2, eps, p):
296 # moving averages
297 m[:] = (1 - b1[it]) * g + b1[it] * m
298 v[:] = (1 - b2) * (g**2) + b2 * v
299
300 phi = m
301 factor = (1 - b1[it]) ** 2 / (1 - b1[it - 1]) ** 2
302 vhat[:] = np.maximum(factor * vhat, v)
303 # sanitize zero-gradient elements
304 if eps > 0:
305 vhat = np.maximum(vhat, eps)
306 psi = np.sqrt(vhat)
307 return phi, psi
308
309
310# noinspection PyUnusedLocal
311def _radam_phi_psi(it, g, m, v, vhat, b1, b2, eps, p):
312 rho_inf = 2 / (1 - b2) - 1
313
314 # moving averages
315 m[:] = (1 - b1[it]) * g + b1[it] * m
316 v[:] = (1 - b2) * (g**2) + b2 * v
317
318 # bias correction
319 t = it + 1
320 phi = m / (1 - b1[it] ** t)
321 rho = rho_inf - 2 * t * b2**t / (1 - b2**t)
322
323 if rho > 4:
324 psi = np.sqrt(v / (1 - b2**t))
325 r = np.sqrt((rho - 4) * (rho - 2) * rho_inf / (rho_inf - 4) / (rho_inf - 2) / rho)
326 psi /= r
327 else:
328 psi = np.ones(g.shape, g.dtype)
329 # sanitize zero-gradient elements
330 if eps > 0:
331 psi = np.maximum(psi, np.sqrt(eps))
332 return phi, psi
333
334
# Map each ADAM variant name (accepted as the ``scheme`` argument of
# `AdaproxParameter`) to the function implementing its update rule.
phi_psi = dict(
    adam=_adam_phi_psi,
    nadam=_nadam_phi_psi,
    amsgrad=_amsgrad_phi_psi,
    padam=_padam_phi_psi,
    adamx=_adamx_phi_psi,
    radam=_radam_phi_psi,
)
344
345
class SingleItemArray:
    """Stand-in for an array that holds the same value at every index.

    Lets a scalar be indexed like a per-iteration sequence, e.g. ``b1[it]``.
    """

    def __init__(self, value):
        self.value = value

    def __getitem__(self, item):
        # Any index (or key) maps to the single stored value.
        return self.value
354
355
class AdaproxParameter(Parameter):
    """Operator updated using the Proximal ADAM algorithm

    Uses multiple variants of adaptive quasi-Newton gradient descent
    * Adam (Kingma & Ba 2015)
    * NAdam (Dozat 2016)
    * AMSGrad (Reddi, Kale & Kumar 2018)
    * PAdam (Chen & Gu 2018)
    * AdamX (Phuong & Phong 2019)
    * RAdam (Liu et al. 2019)
    See details of the algorithms in the respective papers.

    Parameters
    ----------
    x:
        The array of values that is being fit. The gradient step modifies
        it in place.
    step:
        Numerical step value or step function (see `Parameter`).
    grad:
        Function computing the gradient of `x`. It is called as
        ``grad(input_grad, x, *args)``.
    prox:
        Optional proximal operator applied to `x` after each step.
    b1:
        First-moment decay rate. A scalar is used for every iteration.
        Otherwise, ``b1[it]`` is used at iteration ``it``.
    b2:
        Second-moment decay rate.
    eps:
        Small constant protecting the denominator ``psi``. Depending on the
        scheme, it is added to ``psi`` or used as a floor.
    p:
        Partial exponent used by the ``"padam"`` scheme.
    m0, v0, vhat0:
        Initial first moment, second moment and running maximum of the
        second moment. They default to zeros, zeros and ``-inf``.
    scheme:
        Name of the ADAM variant: a key of `phi_psi`.
    prox_e_rel:
        Relative error stored as ``e_rel``. `update` does not use it.
    """

    def __init__(
        self,
        x: np.ndarray,
        step: Callable | float,
        grad: Callable | None = None,
        prox: Callable | None = None,
        b1: float = 0.9,
        b2: float = 0.999,
        eps: float = 1e-8,
        p: float = 0.25,
        m0: np.ndarray | None = None,
        v0: np.ndarray | None = None,
        vhat0: np.ndarray | None = None,
        scheme: str = "amsgrad",
        prox_e_rel: float = 1e-6,
    ):
        shape = x.shape
        dtype = x.dtype
        if m0 is None:
            m0 = np.zeros(shape, dtype=dtype)

        if v0 is None:
            v0 = np.zeros(shape, dtype=dtype)

        if vhat0 is None:
            # Start the running maximum at -inf so the first maximum is v.
            vhat0 = np.ones(shape, dtype=dtype) * -np.inf

        super().__init__(
            x,
            {
                "m": m0,
                "v": v0,
                "vhat": vhat0,
            },
            step,
            grad,
            prox,
        )

        # The update rules index ``b1[it]``, so scalars are wrapped.
        # ``np.isscalar`` covers floats, ints (e.g. ``b1=0``) and numpy
        # scalars. Sequences and schedules are used as given.
        if np.isscalar(b1):
            _b1 = SingleItemArray(b1)
        else:
            _b1 = b1

        self.b1 = _b1
        self.b2 = b2
        self.eps = eps
        self.p = p

        if scheme not in phi_psi:
            raise KeyError(
                f"Unknown ADAM scheme {scheme!r}; expected one of {list(phi_psi)}"
            )
        self.phi_psi = phi_psi[scheme]
        self.e_rel = prox_e_rel

    def update(self, it: int, input_grad: np.ndarray, *args):
        """Update the parameter and meta-parameters using proximal ADAM.

        Takes a gradient step on `x`, scaled by the selected scheme, then
        applies ``prox`` if one was given. See `Parameter.update` for the
        arguments.
        """
        _x = self.x
        # Calculate the gradient
        grad = cast(Callable, self.grad)(input_grad, _x, *args)
        # Get the update for the parameter; this also updates the helper
        # arrays m, v (and vhat for some schemes) in place.
        phi, psi = self.phi_psi(
            it,
            grad,
            self.helpers["m"],
            self.helpers["v"],
            self.helpers["vhat"],
            self.b1,
            self.b2,
            self.eps,
            self.p,
        )
        # Calculate the step size
        step = self.step
        if it > 0:
            _x += -step * phi / psi
        else:
            # Heuristic from the original authors: damp the first iteration
            # by a factor of 10, because ADAM's first step is often much
            # larger than desired.
            _x += -step * phi / psi / 10

        # Without a proximal operator, the in-place step above is the update.
        if self.prox is not None:
            self.x = self.prox(_x)
452
453
class FixedParameter(Parameter):
    """A parameter whose value is never changed by `update`.

    Parameters
    ----------
    x:
        The array of values, which stays fixed.
    """

    def __init__(self, x: np.ndarray):
        # No helper arrays and a step of zero.
        super().__init__(x, {}, 0)

    def update(self, it: int, input_grad: np.ndarray, *args):
        """Do nothing: fixed parameters are not updated."""
        return None
462
463
def relative_step(
    x: np.ndarray,
    factor: float = 0.1,
    minimum: float = 0,
    axis: int | Sequence[int] | None = None,
):
    """Step size set at `factor` times the mean of `x` along `axis`.

    Parameters
    ----------
    x:
        The array whose mean sets the scale of the step.
    factor:
        Multiplicative factor applied to the mean.
    minimum:
        Lower bound for the returned step.
    axis:
        Axis or axes along which the mean is taken. ``None`` averages over
        the whole array.

    Returns
    -------
    step:
        ``maximum(minimum, factor * mean(x, axis))``, taken element-wise
        when the mean is an array.
    """
    scaled_mean = factor * x.mean(axis=axis)
    return np.maximum(minimum, scaled_mean)
std::vector< SchemaItem< Flag > > * items
int const step
update(self, int it, np.ndarray input_grad, *args)
__init__(self, np.ndarray x, Callable|float step, Callable|None grad=None, Callable|None prox=None, float b1=0.9, float b2=0.999, float eps=1e-8, float p=0.25, np.ndarray|None m0=None, np.ndarray|None v0=None, np.ndarray|None vhat0=None, str scheme="amsgrad", float prox_e_rel=1e-6)
__init__(self, np.ndarray x, float step, Callable|None grad=None, Callable|None prox=None, float t0=1, np.ndarray|None z0=None)
update(self, int it, np.ndarray input_grad, *args)
update(self, int it, np.ndarray input_grad, *args)
resize(self, Box old_box, Box new_box)
__init__(self, np.ndarray x, dict[str, np.ndarray] helpers, Callable|float step, Callable|None grad=None, Callable|None prox=None)
Definition parameters.py:89
update(self, int it, np.ndarray input_grad, *args)
_padam_phi_psi(it, g, m, v, vhat, b1, b2, eps, p)
_radam_phi_psi(it, g, m, v, vhat, b1, b2, eps, p)
Callable step_function_wrapper(float step)
Definition parameters.py:46
_amsgrad_phi_psi(it, g, m, v, vhat, b1, b2, eps, p)
_nadam_phi_psi(it, g, m, v, vhat, b1, b2, eps, p)
_adamx_phi_psi(it, g, m, v, vhat, b1, b2, eps, p)
Parameter parameter(np.ndarray|Parameter x)
relative_step(np.ndarray x, float factor=0.1, float minimum=0, int|Sequence[int]|None axis=None)
_adam_phi_psi(it, g, m, v, vhat, b1, b2, eps, p)