from __future__ import annotations

__all__ = [
    "DEFAULT_ADAPROX_FACTOR",
    # ...
]

from typing import Callable, Sequence, cast

import numpy as np
import numpy.typing as npt

DEFAULT_ADAPROX_FACTOR = 1e-2
47 """Wrapper to make a numerical step into a step function
52 The step to take for a given array.
57 The step function that takes an array and returns the
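# A hedged sketch (an assumption, not the verbatim source body): wrapping a
# numerical step just closes over the constant and ignores whatever arguments
# the optimizer passes, so float steps and callable steps can be used
# interchangeably.
def _constant_step_sketch(step: float) -> Callable:
    def step_function(x: np.ndarray, *args) -> float:
        # The array and any extra arguments are ignored; the constant step is
        # returned for every call.
        return step

    return step_function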
64 """A parameter in a `Component`
69 The array of values that is being fit.
71 A dictionary of helper arrays that are used by an optimizer to
72 persist values like the gradient of `x`, the Hessian of `x`, etc.
74 A numerical step value or function to calculate the step for a
77 A function to calculate the gradient of `x`.
79 A function to take the proximal operator of `x`.
85 helpers: dict[str, np.ndarray],
86 step: Callable | float,
87 grad: Callable |
None =
None,
88 prox: Callable |
None =
None,
93 if isinstance(step, float):
104 """Calculate the step
109 The numerical step if no iteration is given.
115 """The shape of the array that is being fit."""
120 """The numpy dtype of the array that is being fit."""
124 """Copy this parameter, including all of the helper arrays."""
125 helpers = {k: v.copy()
for k, v
in self.
helpers.
items()}
    def update(self, it: int, input_grad: np.ndarray, *args):
        """Update the parameter in one iteration.

        This includes the gradient update, proximal update, and any meta
        parameters that are stored as class attributes to update the
        parameter.

        Parameters
        ----------
        it:
            The current iteration.
        input_grad:
            The gradient from the full model, passed to the parameter.
        """
        raise NotImplementedError("Base Parameters cannot be updated")
    def resize(self, old_box: Box, new_box: Box):
        """Grow the parameter and all of the helper parameters.

        Parameters
        ----------
        old_box:
            The old bounding box for the parameter.
        new_box:
            The new bounding box for the parameter.
        """
        slices = new_box.overlapped_slices(old_box)
        x = np.zeros(new_box.shape, dtype=self.dtype)
        x[slices[0]] = self.x[slices[1]]
        self.x = x
        # Each helper array is grown the same way as `x`.
        for name, value in self.helpers.items():
            result = np.zeros(new_box.shape, dtype=self.dtype)
            result[slices[0]] = value[slices[1]]
            self.helpers[name] = result
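# A hedged, self-contained illustration of the resize pattern above, using
# plain slices in place of `Box.overlapped_slices` (which yields the
# destination and source slices for the region where the two boxes overlap).
def _resize_sketch() -> np.ndarray:
    old = np.arange(12.0).reshape(3, 4)      # old (3, 4) array
    new_shape = (5, 6)                       # grown bounding box shape
    dest = (slice(1, 4), slice(1, 5))        # where the old data lands in the new array
    src = (slice(0, 3), slice(0, 4))         # the part of the old data that overlaps
    grown = np.zeros(new_shape, dtype=old.dtype)
    grown[dest] = old[src]                   # everything outside the overlap stays zero
    return grown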
166 """Convert a `np.ndarray` into a `Parameter`.
171 The array or parameter to convert into a `Parameter`.
176 `x`, converted into a `Parameter` if necessary.
178 if isinstance(x, Parameter):
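# Hedged usage sketch: `parameter` passes an existing `Parameter` through
# unchanged and wraps a bare array otherwise.
_p = Parameter(np.zeros((3, 4)), helpers={}, step=0.1)
assert parameter(_p) is _p   # already a Parameter: returned as-is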
184 """A `Parameter` that updates itself using the Beck-Teboulle 2009
185 FISTA proximal gradient method.
187 See https://www.ceremade.dauphine.fr/~carlier/FISTA
194 grad: Callable |
None =
None,
195 prox: Callable |
None =
None,
197 z0: np.ndarray |
None =
None,
211 def update(self, it: int, input_grad: np.ndarray, *args):
212 """Update the parameter and meta-parameters using the PGM
214 See `Parameter` for the full description.
216 step = self.
step / np.sum(args[0] * args[0])
220 y = _z - step * cast(Callable, self.
grad)(input_grad, _x, *args)
221 if self.
prox is not None:
225 t = 0.5 * (1 + np.sqrt(1 + 4 * self.
t**2))
226 omega = 1 + (self.
t - 1) / t
227 self.
helpers[
"z"] = _x + omega * (x - _x)
def _adam_phi_psi(it, g, m, v, vhat, b1, b2, eps, p):
    # Adam: exponential moving averages of the gradient and its square, both
    # bias-corrected.  (`t = it + 1` and the return are assumed completions.)
    t = it + 1
    m[:] = (1 - b1[it]) * g + b1[it] * m
    v[:] = (1 - b2) * (g**2) + b2 * v
    phi = m / (1 - b1[it] ** t)
    psi = np.sqrt(v / (1 - b2**t)) + eps
    return phi, psi


def _nadam_phi_psi(it, g, m, v, vhat, b1, b2, eps, p):
    # NAdam: Adam with Nesterov-style momentum folded into the numerator.
    t = it + 1
    m[:] = (1 - b1[it]) * g + b1[it] * m
    v[:] = (1 - b2) * (g**2) + b2 * v
    phi = (b1[it] * m[:] + (1 - b1[it]) * g) / (1 - b1[it] ** t)
    psi = np.sqrt(v / (1 - b2**t)) + eps
    return phi, psi


def _amsgrad_phi_psi(it, g, m, v, vhat, b1, b2, eps, p):
    # AMSGrad: track a running maximum of the second moment.
    m[:] = (1 - b1[it]) * g + b1[it] * m
    v[:] = (1 - b2) * (g**2) + b2 * v
    vhat[:] = np.maximum(vhat, v)
    vhat = np.maximum(vhat, eps)
    ...


def _padam_phi_psi(it, g, m, v, vhat, b1, b2, eps, p):
    # PAdam: AMSGrad with a partially adaptive exponent `p` on the second moment.
    m[:] = (1 - b1[it]) * g + b1[it] * m
    v[:] = (1 - b2) * (g**2) + b2 * v
    vhat[:] = np.maximum(vhat, v)
    vhat = np.maximum(vhat, eps)
    ...


def _adamx_phi_psi(it, g, m, v, vhat, b1, b2, eps, p):
    # AdamX: AMSGrad with the running maximum rescaled to follow the b1 schedule.
    m[:] = (1 - b1[it]) * g + b1[it] * m
    v[:] = (1 - b2) * (g**2) + b2 * v
    factor = (1 - b1[it]) ** 2 / (1 - b1[it - 1]) ** 2
    vhat[:] = np.maximum(factor * vhat, v)
    vhat = np.maximum(vhat, eps)
    ...


def _radam_phi_psi(it, g, m, v, vhat, b1, b2, eps, p):
    # RAdam: rectify the adaptive step once the variance of the second moment
    # becomes tractable; fall back to an unrectified step before that.
    t = it + 1
    rho_inf = 2 / (1 - b2) - 1
    m[:] = (1 - b1[it]) * g + b1[it] * m
    v[:] = (1 - b2) * (g**2) + b2 * v
    phi = m / (1 - b1[it] ** t)
    rho = rho_inf - 2 * t * b2**t / (1 - b2**t)
    if rho > 4:  # branch structure assumed from the two alternative `psi` definitions
        psi = np.sqrt(v / (1 - b2**t))
        r = np.sqrt((rho - 4) * (rho - 2) * rho_inf / (rho_inf - 4) / (rho_inf - 2) / rho)
        psi /= r  # assumed: the rectification scales the adaptive step
    else:
        psi = np.ones(g.shape, g.dtype)
    psi = np.maximum(psi, np.sqrt(eps))
    return phi, psi


_phi_psi = {  # mapping name assumed
    "adam": _adam_phi_psi,
    "nadam": _nadam_phi_psi,
    "amsgrad": _amsgrad_phi_psi,
    "padam": _padam_phi_psi,
    "adamx": _adamx_phi_psi,
    "radam": _radam_phi_psi,
}
347 """Mock an array with only a single item"""
357 """Operator updated using te Proximal ADAM algorithm
359 Uses multiple variants of adaptive quasi-Newton gradient descent
360 * Adam (Kingma & Ba 2015)
362 * AMSGrad (Reddi, Kale & Kumar 2018)
363 * PAdam (Chen & Gu 2018)
364 * AdamX (Phuong & Phong 2019)
365 * RAdam (Liu et al. 2019)
366 See details of the algorithms in the respective papers.
372 step: Callable | float,
373 grad: Callable |
None =
None,
374 prox: Callable |
None =
None,
379 m0: np.ndarray |
None =
None,
380 v0: np.ndarray |
None =
None,
381 vhat0: np.ndarray |
None =
None,
382 scheme: str =
"amsgrad",
383 prox_e_rel: float = 1e-6,
388 m0 = np.zeros(shape, dtype=dtype)
391 v0 = np.zeros(shape, dtype=dtype)
394 vhat0 = np.ones(shape, dtype=dtype) * -np.inf
408 if isinstance(b1, float):
    def update(self, it: int, input_grad: np.ndarray, *args):
        """Update the parameter and meta-parameters using the PGM.

        See `~Parameter` for more.
        """
        ...
        grad = cast(Callable, self.grad)(input_grad, _x, *args)
        ...
        _x += -step * phi / psi
        ...
        _x += -step * phi / psi / 10
        ...
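# Hedged usage sketch: constructing an adaprox-style parameter with the
# default AMSGrad scheme.  The grad/prox callbacks are model-specific, so
# trivial stand-ins are used here.
_adaprox = AdaproxParameter(
    np.full((4, 5), 0.5),
    step=DEFAULT_ADAPROX_FACTOR,                   # a plain float is wrapped into a step function
    grad=lambda input_grad, x, *args: input_grad,  # pass the model gradient through
    prox=lambda x, *args: np.maximum(x, 0),        # e.g. project onto non-negative values
    scheme="amsgrad",
)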
455 """A parameter that is not updated"""
460 def update(self, it: int, input_grad: np.ndarray, *args):
def relative_step(
    x: np.ndarray,
    factor: float = 0.1,
    minimum: float = 0,
    axis: int | Sequence[int] | None = None,
):
    """Step size set at `factor` times the mean of `x` in direction `axis`."""
    return np.maximum(minimum, factor * x.mean(axis=axis))
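# Hedged usage sketch: a relative step scales with the current values, for
# example averaging over the non-spectral axes of a (bands, y, x) cube.
_cube = np.full((3, 10, 10), 2.0)
_step = relative_step(_cube, factor=0.1, axis=(1, 2))   # -> array of 0.2 per band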