Source code for timecast.learners._predict_constant

"""flax.nn.Module for predicting a constant value

Todo:
    * Validate data and add negative tests
"""
from numbers import Real
from typing import Union

import flax
import jax
import jax.numpy as jnp
import numpy as np

from timecast.learners.base import NewMixin

class PredictConstant(NewMixin, flax.nn.Module):
    """Constant online learner"""

    def apply(self, x: np.ndarray, c: Union[np.ndarray, Real] = 0):
        """Predict a constant value for the next time step.

        Note:
            * Returns `c` in the shape of `x` as the prediction for the next
              time step

        Args:
            x (np.ndarray): input data
            c (Union[np.ndarray, Real]): prediction. Either a scalar (Python
                number, NumPy scalar, 0-d array, or traced scalar), which is
                broadcast to the shape of `x`, or an array-like whose shape
                equals the shape of `x`, which is returned as the prediction.

        Returns:
            np.ndarray: result

        Raises:
            ValueError: if `c` is not a scalar or does not match the shape of `x`
        """
        # TODO (flax): Remove this once flax updates
        # NOTE(review): presumably registers a parameter so flax accepts this
        # parameter-free module; the value itself is never used — confirm.
        _ = self.param("dummy", (), jax.nn.initializers.zeros)

        # Dispatch on rank instead of `jnp.isscalar`: `jnp.ndim` reports 0 for
        # Python/NumPy scalars, 0-d arrays and traced scalars under `jit`, all
        # of which should be broadcast. Older `jnp.isscalar` rejects 0-d arrays
        # and tracers, which would then fall through to the shape check below.
        if jnp.ndim(c) == 0:
            # Multiplying (rather than `full_like`) keeps the usual dtype
            # promotion, e.g. integer `x` with float `c` yields floats.
            return jnp.ones_like(x) * c

        # `jnp.shape` also works on plain sequences such as lists, so a
        # mismatched non-array `c` raises the documented ValueError instead of
        # an AttributeError on `.shape`.
        if jnp.shape(c) != jnp.shape(x):
            raise ValueError("Constant must be scalar or match input shape")

        # Arrays are returned unchanged (as before); plain sequences are
        # converted so the prediction is always array-valued.
        return c if hasattr(c, "shape") else jnp.asarray(c)