Source code for timecast.learners._take

"""flax.nn.Module for taking an index from input"""
import flax
import jax.numpy as jnp
import numpy as np


class Take(flax.nn.Module):
    """Online learner that predicts ``x[index]``, i.e. the identity on one index."""

    def apply(self, x: np.ndarray, index: int):
        """Return ``x[index]`` as the prediction for the next time step.

        Note:
            * This is a workaround for the case where we have a blackbox
              series of predictions (see documentation)

        Args:
            x (np.ndarray): input data; must be an array, not a scalar
            index (int): non-negative integer index into the first axis of
                ``x``. Python ``int`` and NumPy integer scalars are accepted;
                ``bool`` is rejected.

        Returns:
            np.ndarray: ``x[index]``

        Raises:
            ValueError: if ``x`` is a scalar
            IndexError: if ``index`` is not a non-negative integer
        """
        if jnp.isscalar(x):
            raise ValueError("Input x must be an array for Take learner")

        # ``bool`` is a subclass of ``int``, but NumPy/JAX treat a bool index
        # as a boolean mask (adding an axis) rather than as x[1] / x[0], so it
        # must be rejected explicitly. NumPy integer scalars are valid indices.
        if isinstance(index, bool) or not isinstance(index, (int, np.integer)):
            raise IndexError(
                "Take index must be a non-negative integer. Got index {}".format(index)
            )
        index = int(index)

        # We don't check for index < x.shape[0] because this confuses flax's
        # init_by_shape
        if index < 0:
            raise IndexError(
                "Take index must be a non-negative integer. Got index {}".format(index)
            )

        return x[index]