Source code for timecast.learners._arx

"""flax.nn.Module for an auto-regressive online learner.

Todo:
    * Implement batching (efficiently! Historify gets crushed)
    * Implement projections
"""
from typing import Tuple
from typing import Union

import flax
import jax
import jax.numpy as jnp
import numpy as np

from timecast.learners._arx_history import ARXHistory
from timecast.learners.base import NewMixin
from timecast.utils.ar import historify
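
# NOTE (assumption for readers): `historify` is used below as a sliding-window
# transform that turns a (T, d) time series into one history window per step.
# A minimal stand-in with the assumed contract, written only with jax.numpy,
# would look like the sketch below; the real implementation lives in
# timecast.utils.ar and may differ in details:
#
#     def _windows(X, history_len):
#         """Stack rolling windows: (T, d) -> (T - history_len + 1, history_len, d)."""
#         return jnp.stack(
#             [X[t : t + history_len] for t in range(X.shape[0] - history_len + 1)]
#         )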


class ARX(NewMixin, flax.nn.Module):
    """ARX online learner"""

    def apply(
        self,
        targets: np.ndarray = None,
        features: np.ndarray = None,
        history_len: int = 1,
        output_shape: Union[Tuple[int, ...], int] = 1,
        constrain: bool = True,
        batched: bool = False,
    ):
        """
        Notation
            * x = features
            * y = targets
            * H = history_len

        Estimates the following:

            \hat{y}_t = \sum_{i = 1}^{H} B_i x_{t - i + 1} + \sum_{i = 1}^{H} A_i y_{t - i} + b

        Notes:
            * If batched, assume that the first axis is the time axis
            * If not batched, assume that features and / or targets are a single
              time step and have no time or batch axis
            * Delegates much of the error checking to ARXHistory

        Args:
            targets (np.ndarray): target data
            features (np.ndarray): feature data
            history_len (int): length of history
            output_shape (Union[Tuple[int, ...], int]): int or tuple describing output shape
            constrain (bool): force one parameter for each slot in history.
                TODO: explain this better
            batched (bool): first axis is batch axis

        Returns:
            np.ndarray: result
        """
        # TODO: check that if batched, data has enough dimensions
        # TODO: check that targets / features have at least 1 dimension (i.e., not scalar)
        if history_len < 1:
            raise ValueError("Features require a history length of at least 1")

        has_targets = targets is not None and targets.ndim > 0
        has_features = features is not None and features.ndim > 0

        self.T = self.state("T", shape=())
        target_history = self.state("target_history", shape=())
        target_shape = self.state("target_shape", shape=())
        feature_history = self.state("feature_history", shape=())
        feature_shape = self.state("feature_shape", shape=())

        if self.is_initializing():
            self.T.value = 0

            if has_targets:
                target_shape.value = targets.shape[(1 if batched else 0) :]
                target_history.value = jnp.zeros((history_len,) + target_shape.value)
            if has_features:
                feature_shape.value = features.shape[(1 if batched else 0) :]
                feature_history.value = jnp.zeros((history_len,) + feature_shape.value)

        target_histories, feature_histories = None, None
        if has_targets:
            target_histories = target_history.value
            if batched:
                target_histories = historify(
                    jnp.vstack((target_histories, targets))[:-1, :], history_len=history_len
                )
        if has_features:
            feature_histories = jnp.vstack((feature_history.value, features))
            if batched:
                feature_histories = historify(feature_histories[1:, :], history_len=history_len)
            else:
                pass

        y_hat = ARXHistory(
            targets=target_histories,
            features=feature_histories,
            output_shape=output_shape,
            constrain=constrain,
            batched=batched,
        )

        # TODO: Don't duplicate the vstacks (modulo index difference for target_history)
        if not self.is_initializing():
            # Update target / feature history with data _after_ we have made calculations
            if has_targets:
                target_history.value = jnp.vstack((target_history.value, targets))[
                    targets.shape[0] :
                ]
            if has_features:
                feature_history.value = jnp.vstack((feature_history.value, features))[
                    features.shape[0] :
                ]

            self.T.value += 1

        # If we have targets, then we need to wait one additional time step to
        # have a full target window
        return jax.lax.cond(
            self.T.value + (1 if has_targets else 0) >= history_len,
            y_hat,
            lambda x: x,
            y_hat,
            lambda x: jax.lax.stop_gradient(y_hat),
        )
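

# Illustrative sketch (not part of the module API): what the ARX estimate in
# `apply` computes, written out directly with jax.numpy. The parameters A, B,
# and b below stand in for the kernels learned inside ARXHistory; their names,
# shapes, and values are assumptions made for this example only.
if __name__ == "__main__":
    H, d_in, d_out = 3, 2, 1  # history length, feature dim, target dim

    # Toy window of the H most recent features (including x_t) and the H most
    # recent targets (excluding y_t, which is what we are predicting).
    x_window = jnp.ones((H, d_in))  # x_t, x_{t-1}, ..., x_{t-H+1}
    y_window = jnp.ones((H, d_out))  # y_{t-1}, ..., y_{t-H}

    # Hypothetical parameters: one kernel per history slot plus a bias.
    B = jnp.full((H, d_in, d_out), 0.1)
    A = jnp.full((H, d_out, d_out), 0.2)
    b = jnp.zeros((d_out,))

    # \hat{y}_t = \sum_i B_i x_{t-i+1} + \sum_i A_i y_{t-i} + b
    y_hat = (
        jnp.einsum("hi,hio->o", x_window, B)
        + jnp.einsum("hp,hpo->o", y_window, A)
        + b
    )
    print(y_hat)  # shape (d_out,)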