parameters.py
# This file is part of scarlet_lite.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

from __future__ import annotations

__all__ = [
    "parameter",
    "Parameter",
    "FistaParameter",
    "AdaproxParameter",
    "FixedParameter",
    "relative_step",
    "phi_psi",
    "DEFAULT_ADAPROX_FACTOR",
]

from typing import Callable, Sequence, cast

import numpy as np
import numpy.typing as npt

from .bbox import Box

# The default factor used for adaprox parameter steps
DEFAULT_ADAPROX_FACTOR = 1e-2


def step_function_wrapper(step: float) -> Callable:
    """Wrapper to make a numerical step into a step function

    Parameters
    ----------
    step:
        The step to take for a given array.

    Returns
    -------
    step_function:
        The step function that takes an array and returns the
        numerical step.
    """
    return lambda x: step

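# A minimal usage sketch (illustrative only): the returned callable ignores
# its input array and always produces the same numerical step, so constant
# and adaptive steps share a single calling convention:
#
#     step_fn = step_function_wrapper(0.5)
#     step_fn(np.zeros(3))  # -> 0.5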

class Parameter:
    """A parameter in a `Component`

    Parameters
    ----------
    x:
        The array of values that is being fit.
    helpers:
        A dictionary of helper arrays that are used by an optimizer to
        persist values like the gradient of `x`, the Hessian of `x`, etc.
    step:
        A numerical step value or function to calculate the step for a
        given `x`.
    grad:
        A function to calculate the gradient of `x`.
    prox:
        A function to take the proximal operator of `x`.
    """

    def __init__(
        self,
        x: np.ndarray,
        helpers: dict[str, np.ndarray],
        step: Callable | float,
        grad: Callable | None = None,
        prox: Callable | None = None,
    ):
        self.x = x
        self.helpers = helpers

        if isinstance(step, float):
            _step = step_function_wrapper(step)
        else:
            _step = step

        self._step = _step
        self.grad = grad
        self.prox = prox

    @property
    def step(self) -> float:
        """Calculate the step

        Returns
        -------
        step:
            The numerical step for the current value of `x`.
        """
        return self._step(self.x)

    @property
    def shape(self) -> tuple[int, ...]:
        """The shape of the array that is being fit."""
        return self.x.shape

    @property
    def dtype(self) -> npt.DTypeLike:
        """The numpy dtype of the array that is being fit."""
        return self.x.dtype

    def copy(self) -> Parameter:
        """Copy this parameter, including all of the helper arrays."""
        helpers = {k: v.copy() for k, v in self.helpers.items()}
        return Parameter(self.x.copy(), helpers, 0)

    def update(self, it: int, input_grad: np.ndarray, *args):
        """Update the parameter in one iteration.

        This includes the gradient update, the proximal update,
        and updates to any meta-parameters that are stored as
        class attributes.

        Parameters
        ----------
        it:
            The current iteration.
        input_grad:
            The gradient from the full model, passed to the parameter.
        """
        raise NotImplementedError("Base Parameters cannot be updated")

    def resize(self, old_box: Box, new_box: Box):
        """Grow the parameter and all of the helper parameters

        Parameters
        ----------
        old_box:
            The old bounding box for the parameter.
        new_box:
            The new bounding box for the parameter.
        """
        slices = new_box.overlapped_slices(old_box)
        x = np.zeros(new_box.shape, dtype=self.dtype)
        x[slices[0]] = self.x[slices[1]]
        self.x = x

        for name, value in self.helpers.items():
            result = np.zeros(new_box.shape, dtype=self.dtype)
            result[slices[0]] = value[slices[1]]
            self.helpers[name] = result

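# A minimal sketch of constructing a constant-step `Parameter` directly
# (illustrative only; `my_grad` and `my_prox` are hypothetical, with
# signatures matching the calls made by the subclasses below):
#
#     def my_grad(input_grad, x, *args):
#         return input_grad
#
#     def my_prox(x):
#         return np.maximum(x, 0)  # project onto non-negative values
#
#     param = Parameter(np.zeros((10, 10)), {}, step=0.1,
#                       grad=my_grad, prox=my_prox)
#     param.step  # -> 0.1, via step_function_wrapper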

def parameter(x: np.ndarray | Parameter) -> Parameter:
    """Convert a `np.ndarray` into a `Parameter`.

    Parameters
    ----------
    x:
        The array or parameter to convert into a `Parameter`.

    Returns
    -------
    result:
        `x`, converted into a `Parameter` if necessary.
    """
    if isinstance(x, Parameter):
        return x
    return Parameter(x, {}, 0)

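# A minimal usage sketch (illustrative only):
#
#     x = np.zeros((5, 5), dtype=np.float32)
#     param = parameter(x)              # wraps the bare array
#     assert parameter(param) is param  # a Parameter passes through unchanged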

class FistaParameter(Parameter):
    """A `Parameter` that updates itself using the Beck-Teboulle 2009
    FISTA proximal gradient method.

    See https://www.ceremade.dauphine.fr/~carlier/FISTA
    """

    def __init__(
        self,
        x: np.ndarray,
        step: float,
        grad: Callable | None = None,
        prox: Callable | None = None,
        t0: float = 1,
        z0: np.ndarray | None = None,
    ):
        if z0 is None:
            z0 = x

        super().__init__(
            x,
            {"z": z0},
            step,
            grad,
            prox,
        )
        self.t = t0

    def update(self, it: int, input_grad: np.ndarray, *args):
        """Update the parameter and meta-parameters using the PGM

        See `Parameter` for the full description.
        """
        if len(args) == 0:
            step = self.step
        else:
            step = self.step / np.sum(args[0] * args[0])
        _x = self.x
        _z = self.helpers["z"]

        y = _z - step * cast(Callable, self.grad)(input_grad, _x, *args)
        if self.prox is not None:
            x = self.prox(y)
        else:
            x = y
        t = 0.5 * (1 + np.sqrt(1 + 4 * self.t**2))
        omega = 1 + (self.t - 1) / t
        self.helpers["z"] = _x + omega * (x - _x)
        _x[:] = x
        self.t = t

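# For reference, the update above follows the FISTA recurrence
# (Beck & Teboulle 2009), with the gradient term supplied by `self.grad`:
#
#     y       = z_k - step * grad
#     x_k     = prox(y)                 (or y itself when prox is None)
#     t_{k+1} = (1 + sqrt(1 + 4 * t_k**2)) / 2
#     z_{k+1} = x_{k-1} + (1 + (t_k - 1) / t_{k+1}) * (x_k - x_{k-1})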

# The following code block contains different update methods for
# various implementations of ADAM.
# We currently use the `_amsgrad_phi_psi` update by default,
# but it can easily be interchanged by passing a different
# variant name to the `AdaproxParameter`.
# (A short lookup sketch follows the `phi_psi` dictionary below.)


# noinspection PyUnusedLocal
def _adam_phi_psi(it, g, m, v, vhat, b1, b2, eps, p):
    # moving averages
    m[:] = (1 - b1[it]) * g + b1[it] * m
    v[:] = (1 - b2) * (g**2) + b2 * v

    # bias correction
    t = it + 1
    phi = m / (1 - b1[it] ** t)
    psi = np.sqrt(v / (1 - b2**t)) + eps
    return phi, psi


# noinspection PyUnusedLocal
def _nadam_phi_psi(it, g, m, v, vhat, b1, b2, eps, p):
    # moving averages
    m[:] = (1 - b1[it]) * g + b1[it] * m
    v[:] = (1 - b2) * (g**2) + b2 * v

    # bias correction
    t = it + 1
    phi = (b1[it] * m[:] + (1 - b1[it]) * g) / (1 - b1[it] ** t)
    psi = np.sqrt(v / (1 - b2**t)) + eps
    return phi, psi


# noinspection PyUnusedLocal
def _amsgrad_phi_psi(it, g, m, v, vhat, b1, b2, eps, p):
    # moving averages
    m[:] = (1 - b1[it]) * g + b1[it] * m
    v[:] = (1 - b2) * (g**2) + b2 * v

    phi = m
    vhat[:] = np.maximum(vhat, v)
    # sanitize zero-gradient elements
    if eps > 0:
        vhat = np.maximum(vhat, eps)
    psi = np.sqrt(vhat)
    return phi, psi


def _padam_phi_psi(it, g, m, v, vhat, b1, b2, eps, p):
    # moving averages
    m[:] = (1 - b1[it]) * g + b1[it] * m
    v[:] = (1 - b2) * (g**2) + b2 * v

    phi = m
    vhat[:] = np.maximum(vhat, v)
    # sanitize zero-gradient elements
    if eps > 0:
        vhat = np.maximum(vhat, eps)
    psi = vhat**p
    return phi, psi


# noinspection PyUnusedLocal
def _adamx_phi_psi(it, g, m, v, vhat, b1, b2, eps, p):
    # moving averages
    m[:] = (1 - b1[it]) * g + b1[it] * m
    v[:] = (1 - b2) * (g**2) + b2 * v

    phi = m
    factor = (1 - b1[it]) ** 2 / (1 - b1[it - 1]) ** 2
    vhat[:] = np.maximum(factor * vhat, v)
    # sanitize zero-gradient elements
    if eps > 0:
        vhat = np.maximum(vhat, eps)
    psi = np.sqrt(vhat)
    return phi, psi


# noinspection PyUnusedLocal
def _radam_phi_psi(it, g, m, v, vhat, b1, b2, eps, p):
    rho_inf = 2 / (1 - b2) - 1

    # moving averages
    m[:] = (1 - b1[it]) * g + b1[it] * m
    v[:] = (1 - b2) * (g**2) + b2 * v

    # bias correction
    t = it + 1
    phi = m / (1 - b1[it] ** t)
    rho = rho_inf - 2 * t * b2**t / (1 - b2**t)

    if rho > 4:
        psi = np.sqrt(v / (1 - b2**t))
        r = np.sqrt((rho - 4) * (rho - 2) * rho_inf / (rho_inf - 4) / (rho_inf - 2) / rho)
        psi /= r
    else:
        psi = np.ones(g.shape, g.dtype)
    # sanitize zero-gradient elements
    if eps > 0:
        psi = np.maximum(psi, np.sqrt(eps))
    return phi, psi


# Dictionary to link ADAM variation names to their functional algorithms.
phi_psi = {
    "adam": _adam_phi_psi,
    "nadam": _nadam_phi_psi,
    "amsgrad": _amsgrad_phi_psi,
    "padam": _padam_phi_psi,
    "adamx": _adamx_phi_psi,
    "radam": _radam_phi_psi,
}

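# A minimal lookup sketch (illustrative only): selecting a variant by name,
# as `AdaproxParameter` does internally via its `scheme` argument:
#
#     update = phi_psi["amsgrad"]
#     phi, psi = update(it, g, m, v, vhat, b1, b2, eps, p)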

class SingleItemArray:
    """Mock an array with only a single item"""

    def __init__(self, value):
        self.value = value

    def __getitem__(self, item):
        return self.value


class AdaproxParameter(Parameter):
    """Operator updated using the Proximal ADAM algorithm

    Uses multiple variants of adaptive quasi-Newton gradient descent
        * Adam (Kingma & Ba 2015)
        * NAdam (Dozat 2016)
        * AMSGrad (Reddi, Kale & Kumar 2018)
        * PAdam (Chen & Gu 2018)
        * AdamX (Phuong & Phong 2019)
        * RAdam (Liu et al. 2019)
    See details of the algorithms in the respective papers.
    """

    def __init__(
        self,
        x: np.ndarray,
        step: Callable | float,
        grad: Callable | None = None,
        prox: Callable | None = None,
        b1: float = 0.9,
        b2: float = 0.999,
        eps: float = 1e-8,
        p: float = 0.25,
        m0: np.ndarray | None = None,
        v0: np.ndarray | None = None,
        vhat0: np.ndarray | None = None,
        scheme: str = "amsgrad",
        prox_e_rel: float = 1e-6,
    ):
        shape = x.shape
        dtype = x.dtype
        if m0 is None:
            m0 = np.zeros(shape, dtype=dtype)

        if v0 is None:
            v0 = np.zeros(shape, dtype=dtype)

        if vhat0 is None:
            vhat0 = np.ones(shape, dtype=dtype) * -np.inf

        super().__init__(
            x,
            {
                "m": m0,
                "v": v0,
                "vhat": vhat0,
            },
            step,
            grad,
            prox,
        )

        if isinstance(b1, float):
            _b1 = SingleItemArray(b1)
        else:
            _b1 = b1

        self.b1 = _b1
        self.b2 = b2
        self.eps = eps
        self.p = p

        self.phi_psi = phi_psi[scheme]
        self.e_rel = prox_e_rel

    def update(self, it: int, input_grad: np.ndarray, *args):
        """Update the parameter and meta-parameters using the PGM

        See `~Parameter` for more.
        """
        _x = self.x
        # Calculate the gradient
        grad = cast(Callable, self.grad)(input_grad, _x, *args)
        # Get the update for the parameter
        phi, psi = self.phi_psi(
            it,
            grad,
            self.helpers["m"],
            self.helpers["v"],
            self.helpers["vhat"],
            self.b1,
            self.b2,
            self.eps,
            self.p,
        )
        # Calculate the step size
        step = self.step
        if it > 0:
            _x += -step * phi / psi
        else:
            # This is a scheme that Peter Melchior and I came up with to
            # dampen the known effect of ADAM, where the first step
            # is often much larger than desired.
            _x += -step * phi / psi / 10

        self.x = cast(Callable, self.prox)(_x)

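# A minimal usage sketch (illustrative only; `my_grad` and `my_prox` are
# hypothetical stand-ins for real gradient and proximal operators):
#
#     param = AdaproxParameter(
#         np.zeros((10, 10)),
#         step=0.1,
#         grad=my_grad,
#         prox=my_prox,
#         scheme="radam",  # any key in `phi_psi`
#     )
#     param.update(0, model_grad)  # model_grad: gradient from the full model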

class FixedParameter(Parameter):
    """A parameter that is not updated"""

    def __init__(self, x: np.ndarray):
        super().__init__(x, {}, 0)

    def update(self, it: int, input_grad: np.ndarray, *args):
        pass


def relative_step(
    x: np.ndarray,
    factor: float = 0.1,
    minimum: float = 0,
    axis: int | Sequence[int] | None = None,
):
    """Step size set at `factor` times the mean of `x` in direction `axis`."""
    return np.maximum(minimum, factor * x.mean(axis=axis))
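
# A minimal usage sketch (illustrative only): binding the extra arguments
# with `functools.partial` yields a one-argument step function that
# `Parameter` can call as `self._step(self.x)`:
#
#     from functools import partial
#
#     step = partial(relative_step, factor=DEFAULT_ADAPROX_FACTOR)
#     param = AdaproxParameter(np.zeros((10, 10)), step=step)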