"""timecast.optim._projected_sgd"""
import jax
import jax.numpy as jnp
from flax import struct
from flax.optim.base import OptimizerDef


@struct.dataclass
class _ProjectedSGDHyperParams:
    """ProjectedSGD hyperparameters"""

    learning_rate: jnp.ndarray
    projection_threshold: float


class ProjectedSGD(OptimizerDef):
    """Gradient descent optimizer with projections."""

    def __init__(self, learning_rate: float = None, projection_threshold: float = None):
        """Constructor for the ProjectedSGD optimizer.

        Args:
            learning_rate (float): the step size used to update the parameters.
            projection_threshold (float): maximum allowed parameter norm
                (Frobenius norm for matrices, 2-norm for vectors); updates
                that exceed it are projected back onto the norm ball.
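
        Example (a minimal sketch; assumes the deprecated ``flax.optim``
        training API that ``OptimizerDef`` comes from, with placeholder
        parameter/gradient pytrees)::

            params = {"w": jnp.ones((3, 3))}
            grads = {"w": jnp.ones((3, 3))}
            optimizer_def = ProjectedSGD(learning_rate=0.1, projection_threshold=5.0)
            optimizer = optimizer_def.create(params)
            optimizer = optimizer.apply_gradient(grads)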
"""
projection_threshold = projection_threshold or float("inf")
hyper_params = _ProjectedSGDHyperParams(learning_rate, projection_threshold)
super().__init__(hyper_params)

    def init_param_state(self, param):
        """Initialize parameter state"""
        return ()

    def apply_param_gradient(self, step, hyper_params, param, state, grad):
        """Apply per-parameter gradients"""
        del step
        assert hyper_params.learning_rate is not None, "no learning rate provided."
        new_param = param - hyper_params.learning_rate * grad
        norm = jnp.linalg.norm(new_param)
        # Project back onto the norm ball of radius projection_threshold: if the
        # updated parameter's norm exceeds the threshold, rescale it onto the ball.
        new_param = jax.lax.cond(
            norm > hyper_params.projection_threshold,
            lambda x: hyper_params.projection_threshold / norm * x,
            lambda x: x,
            new_param,
        )
        return new_param, state
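

# A minimal runnable sketch (not part of the original module): it drives
# ProjectedSGD through the deprecated flax.optim API (``create`` /
# ``apply_gradient``) and checks that each updated parameter is projected
# back inside the norm ball of radius ``projection_threshold``. The shapes
# and values below are illustrative assumptions.
if __name__ == "__main__":
    params = {"w": jnp.ones((3, 3)), "b": jnp.zeros((3,))}
    grads = {"w": jnp.full((3, 3), -10.0), "b": jnp.full((3,), 10.0)}

    optimizer_def = ProjectedSGD(learning_rate=0.1, projection_threshold=1.0)
    optimizer = optimizer_def.create(params)  # wrap params in an Optimizer
    optimizer = optimizer.apply_gradient(grads)

    # Every parameter norm should now satisfy ||param|| <= projection_threshold.
    for name, value in optimizer.target.items():
        print(name, float(jnp.linalg.norm(value)))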