Source code for timecast.learners._arx_history

"""flax.nn.Module for an auto-regressive online learner.

Todo:
    * Implement strided histories
    * Figure out normalizing
"""
import warnings
from typing import Tuple
from typing import Union

import flax
import jax.numpy as jnp
import numpy as np

from timecast.learners._linear import Linear
from timecast.learners.base import NewMixin


default_output_shape: int = 1


class _ARXHistory(NewMixin, flax.nn.Module):
    """AR online learner helper"""

    def apply(
        self,
        data: np.ndarray,
        output_shape: Union[Tuple[int, ...], int] = default_output_shape,
        constrain: bool = True,
        batched: bool = False,
        name: str = "ARXHistory",
    ):
        """
        Args:
            data (np.ndarray): (batch, history_len, input_dim)
            output_shape (Union[Tuple[int, ...], int]): int or tuple
                describing output shape
            constrain (bool): if True, learn one parameter for each slot in
                history rather than a full weight over every input axis
            batched (bool): first axis is batch axis
            name (str): name to pass to Linear

        Returns:
            np.ndarray: result
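
        Example:
            A minimal sketch (shapes are illustrative; assumes the pre-Linen
            ``flax.nn`` API, where ``Module.init`` returns the module output
            and its initial parameters)::

                import jax

                # hypothetical shapes: (batch, history_len, input_dim)
                data = jnp.ones((8, 4, 3))
                preds, params = _ARXHistory.init(
                    jax.random.PRNGKey(0), data=data, batched=True
                )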
        """
        # TODO: Check shape of features and reshape if necessary
        input_shape = data.shape[(2 if batched else 1) :]

        # TODO: We would ideally like to be able to raise with jit
        """
        history_shape = data.shape[(1 if batched else 0) :]
        # TODO (flax): We really shouldn't need state to just have some local
        # variables set in is_initializing...
        self.history_shape = self.state("history_shape", shape=history_shape)
        if self.is_initializing():
            self.history_shape.value = history_shape

        else:
            if history_shape != self.history_shape.value:
                raise ValueError(
                    "Got input_shape {}, expected input_shape {}".format(
                        history_shape, self.history_shape.value
                    )
                )
        """

        if jnp.isscalar(output_shape):
            output_shape = (output_shape,)

        if constrain:
            # If we have a non-default output_shape and it doesn't match input_shape
            if output_shape != (default_output_shape,) and input_shape != output_shape:
                warnings.warn(
                    "When constrained, input data shape must equal output data "
                    "shape. Got input_shape {} and output_shape {}. Coercing "
                    "output_shape to input_shape".format(input_shape, output_shape)
                )
            output_shape = input_shape
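            # Contract only the history axis, so each history slot gets a
            # single parameter shared across the data dimensions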
            input_axes = (1 if batched else 0,)
        else:
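            # Unconstrained: contract every non-batch axis, learning a full
            # weight across both the history and data dimensions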
            input_axes = tuple(range(1 if batched else 0, data.ndim))

        return Linear(
            inputs=data,
            output_shape=output_shape,
            input_axes=input_axes,
            batch_axes=((0,) if batched else ()),
            bias=True,
            dtype=jnp.float32,
            kernel_init=flax.nn.initializers.zeros,
            bias_init=flax.nn.initializers.zeros,
            precision=None,
            name="Linear",
        )


class ARXHistory(NewMixin, flax.nn.Module):
    """AR online learner with history as input"""

    def apply(
        self,
        targets: np.ndarray = None,
        features: np.ndarray = None,
        output_shape: Union[Tuple[int, ...], int] = 1,
        constrain: bool = True,
        batched: bool = False,
    ):
        r"""
        Notation:
            * x = features
            * y = targets
            * H = history_len

        Estimates the following:

            \hat{y} = \sum_{i=1}^{H+1} B_i x_{t-i-1} + \sum_{i=1}^{H} A_i y_{t-i} + b

        Notes:
            * Assumes `features` and `targets` have three dimensions:
              (batch, history, data). Any extra dimensions are part of the
              data shape
            * Doesn't care if features and targets have different history or
              input dimensions

        Args:
            targets (np.ndarray): target data
            features (np.ndarray): feature data
            output_shape (Union[Tuple[int, ...], int]): int or tuple
                describing output shape
            constrain (bool): if True, learn one parameter for each slot in
                history rather than a full weight over every input axis
            batched (bool): first axis is batch axis

        Returns:
            np.ndarray: result
        """
        Ay, Bx = 0, 0

        self.has_targets = self.state("has_targets", shape=())
        self.has_features = self.state("has_features", shape=())

        has_targets = targets is not None and targets.ndim > 0
        has_features = features is not None and features.ndim > 0

        if self.is_initializing():
            self.has_targets.value = has_targets
            self.has_features.value = has_features

            # TODO: We would ideally like to be able to raise with jit
            """
            if not has_targets and not has_features:
                raise ValueError("Need one or both of targets and features")
            """

        # TODO: We would ideally like to be able to raise with jit
        """
        else:
            if not has_targets and self.has_targets.value:
                raise ValueError("Expected targets, but got None")
            if has_targets and not self.has_targets.value:
                raise ValueError("Did not expect targets, but got targets")
            if not has_features and self.has_features.value:
                raise ValueError("Expected features, but got None")
            if has_features and not self.has_features.value:
                raise ValueError("Did not expect features, but got features")
            if has_features and has_targets and features.shape[0] != targets.shape[0]:
                raise ValueError("Targets and features need to have the same batch size")
        """

        if has_targets:
            Ay = _ARXHistory(
                data=targets,
                output_shape=output_shape,
                constrain=constrain,
                batched=batched,
                name="Targets",
            )

        if has_features:
            Bx = _ARXHistory(
                data=features,
                output_shape=output_shape,
                constrain=constrain,
                batched=batched,
                name="Features",
            )

        return Ay + Bx
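

# A minimal usage sketch (shapes are illustrative; assumes the pre-Linen
# ``flax.nn`` API, where ``Module.init`` returns the module output and its
# initial parameters, and ``flax.nn.stateful`` collects the module state
# used by ``ARXHistory``):
#
#     import jax
#
#     # hypothetical shapes: (batch, history_len, input_dim)
#     targets = jnp.ones((8, 4, 3))
#     features = jnp.ones((8, 4, 3))
#     with flax.nn.stateful() as state:
#         preds, params = ARXHistory.init(
#             jax.random.PRNGKey(0),
#             targets=targets,
#             features=features,
#             batched=True,
#         )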