Source code for timecast.learners._precomputed

"""flax.nn.Module for wrapping an array

Todo:
    * Validate data and add negative tests
    * Implement batching
"""
import flax


class Precomputed(flax.nn.Module):
    """Wraps an array so that successive calls replay it one entry at a time.

    Notes:
        * Assumes the first dimension is a time dimension
        * Assumes the data is accessed in order (i.e., no shuffling)

    Warning:
        * Ignores init_by_shape
    """

    def apply(self, x, arr):
        """Return the entry of ``arr`` at the current position, then advance.

        Args:
            x: Unused by this method. Presumably accepted so the call
                signature matches other learners; verify against callers.
            arr: Array being replayed, indexed along its first axis.

        Returns:
            ``arr[i]``, where ``i`` is the value held in the ``"index"``
            state when this call begins.
        """
        # Scalar (shape ``()``) position counter kept in module state and
        # initialized with flax.nn.initializers.zeros.
        # NOTE(review): the counter is cast with ``astype(int)`` on read,
        # which suggests it is stored as a floating-point array. If so,
        # ``+= 1`` stops being exact once the count grows very large
        # (e.g. past 2**24 for float32). Confirm the dtype.
        self.index = self.state(
            "index",
            shape=(),
            initializer=flax.nn.initializers.zeros,
        )
        row = arr[self.index.value.astype(int)]

        # Initialization must not consume an entry, so the counter is
        # advanced only on real applications.
        if self.is_initializing():
            return row

        self.index.value += 1
        return row