# Source code for timecast.learners._predict_last

"""flax.nn.Module for predicting last value

Todo:
    * Implement last value n steps ago (requires state)
"""
import flax
import jax
import numpy as np

from timecast.learners.base import NewMixin


class PredictLast(NewMixin, flax.nn.Module):
    """Identity online learner.

    The prediction for the next time step is simply the value that was just
    observed: `apply` hands its input back unchanged.
    """

    def apply(self, x: np.ndarray):
        """Return the input unchanged as the next-step prediction.

        Note:
            * Returns `x` as the prediction for the next time step

        Args:
            x (np.ndarray): input data

        Returns:
            np.ndarray: result
        """
        # TODO (flax): Remove this once flax updates
        # Registers a zero-initialised scalar parameter named "dummy" whose
        # value is never used in the prediction. Presumably a workaround for
        # flax expecting a module to declare at least one parameter -- confirm
        # against the flax version in use before removing.
        _ = self.param("dummy", (), jax.nn.initializers.zeros)
        return x