Source code for timecast.learners.base

"""Helper functions and classes for timecast.learners"""
import abc
import inspect
from typing import Any
from typing import Iterable
from typing import List
from typing import Tuple

import flax
import jax.numpy as jnp
import numpy as np

from timecast.utils import random


[docs]class FitMixin(abc.ABC): """Mixin class that provides a fit function for offline training / learner initialization"""
[docs] @classmethod @abc.abstractmethod def fit( cls, data: Iterable[Tuple[np.ndarray, np.ndarray, Any]], input_dim: int, output_dim: int = 1, key: jnp.ndarray = None, **kwargs ) -> flax.nn.Model: """Fit and initialize learner on training data Notes: * We could infer input_dim from data, but for now, require users to explicitly provide * output_dim defaults to 1 and is ignored for now Todo: * Really intended for passing in timeseries at a time, not individual time series observations; is this the right general API? * Shape is (1, input_dim); what about mini-batches? Args: data: an iterable of tuples containing input/truth pairs of time series plus any auxiliary value input_dim: number of feature dimensions in input output_dim: number of feature dimensions in output key: random key for jax random kwargs: Extra keyword arguments Returns: flax.nn.Model: initialized model """ raise NotImplementedError()
[docs]class NewMixin: """Mixin class that provides a flax.nn.Model constructor function""" # TODO (flax)
[docs] @classmethod def new( cls, shapes: List[Tuple[int, ...]], *, key: jnp.ndarray = None, name: str = None, **kwargs ) -> flax.nn.Model: """ Args: shapes (List[Tuple[int, ...]]): shapes for initialization key (jnp.ndarray): key for jax random name (str): identifier for top-level module **kwargs: arguments for flax.nn.Module """ argspec = inspect.getfullargspec(cls.apply) num_kwargs = 0 if argspec.defaults is None else len(argspec.defaults) _kwargs = {} # Capture kwarg defaults for i in range(1, num_kwargs + 1): _kwargs[argspec.args[-i]] = argspec.defaults[-i] # Override kwarg defaults, if any for kw, arg in kwargs.items(): if kw in argspec.args or argspec.varkw is not None: _kwargs[kw] = arg else: raise ValueError("Found extra kwarg {}: {}".format(kw, arg)) model_def = cls.partial(**kwargs) if name is None: name = cls.__name__.lower() if key is None: key = random.generate_key() with flax.nn.stateful() as state: _, params = model_def.init_by_shape(key, shapes) model = flax.nn.Model(model_def, params) return model, state