# Source code for timecast.optim._multiplicative_weights

"""timecast.optim._multiplicative_weights"""
import jax.numpy as jnp
from flax import struct
from flax.optim.base import OptimizerDef


@struct.dataclass
class _MultiplicativeWeightsHyperParams:
    """Hyperparameter container for the MultiplicativeWeights optimizer.

    Declared as a ``flax.struct`` dataclass. ``MultiplicativeWeights.__init__``
    builds one instance from its ``eta`` argument (a float, or ``None`` by
    default) and passes it to ``OptimizerDef.__init__``.
    """

    # Rate that scales the gradient inside the exponential update; read as
    # ``hyper_params.eta`` in ``MultiplicativeWeights.apply_param_gradient``.
    eta: jnp.ndarray


class MultiplicativeWeights(OptimizerDef):
    """Multiplicative weights optimizer.

    Each parameter is rescaled element-wise by ``exp(-eta * grad)`` and then
    divided by the sum of its entries, so the updated parameter sums to one.
    """

    def __init__(self, eta: float = None):
        """Constructor for the MultiplicativeWeights optimizer.

        Args:
            eta (float): rate used to update the parameters. Defaults to
                ``None``; the update itself multiplies by ``eta``, so a real
                value has to be in ``hyper_params`` by the time
                ``apply_param_gradient`` runs.
                NOTE(review): presumably supplied as a hyperparameter
                override at update time -- confirm against the caller.
        """
        hyper_params = _MultiplicativeWeightsHyperParams(eta)
        super().__init__(hyper_params)

    def init_param_state(self, param):
        """Initialize parameter state.

        Args:
            param: the parameter being optimized (unused).

        Returns:
            An empty tuple: this optimizer keeps no per-parameter state.
        """
        return ()

    def apply_param_gradient(self, step, hyper_params, param, state, grad):
        """Apply per-parameter gradients.

        Args:
            step: current optimizer step (unused).
            hyper_params: hyperparameters; only ``hyper_params.eta`` is read.
            param: current parameter value.
            state: per-parameter state, returned unchanged.
            grad: gradient with respect to ``param``.

        Returns:
            A ``(new_param, state)`` tuple where ``new_param`` is
            ``param * exp(-eta * grad)`` normalized by the sum of its entries.
        """
        del step
        exponent = -1 * hyper_params.eta * grad
        # Subtract the largest exponent before exponentiating. The common
        # factor exp(-max) cancels in the normalization below, so the result
        # is mathematically unchanged, but exp() can no longer overflow to inf
        # (which previously produced inf / inf = NaN).
        exponent = exponent - jnp.max(exponent)
        unnormalized = param * jnp.exp(exponent)
        new_param = unnormalized / unnormalized.sum()
        return new_param, state