Source code for timecast.learners._linear

"""Linear transformation"""
from typing import Iterable

import flax
import jax
import jax.numpy as jnp

from timecast.learners.base import NewMixin

default_kernel_init = flax.nn.initializers.lecun_normal()


def _normalize_axes(axes, ndim):
    """
    A tuple by convention. len(axes_tuple) then also gives the rank efficiently.

    Warning:
        * This doesn't raise when some axis is greater than ndim - 1
    """
    return tuple([ax if ax >= 0 else ndim + ax for ax in axes])

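# A quick, illustrative sketch of the convention above (the example values are
# hypothetical, not taken from the library's tests):
#
#     >>> _normalize_axes((-1,), ndim=3)
#     (2,)
#     >>> _normalize_axes((0, -2), ndim=4)
#     (0, 2)
#     >>> _normalize_axes((5,), ndim=3)   # out of range, silently passed through
#     (5,)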

class Linear(NewMixin, flax.nn.Module):
    """A linear transformation with flexible axes."""

    def apply(
        self,
        inputs,
        output_shape,
        input_axes=-1,
        batch_axes=(),
        bias=True,
        dtype=jnp.float32,
        kernel_init=default_kernel_init,
        bias_init=flax.nn.initializers.zeros,
        precision=None,
    ):
        """Applies a linear transformation to the inputs along multiple dimensions.

        Args:
            inputs: the nd-array to be transformed.
            output_shape: tuple of output shape.
            input_axes: tuple with axes to apply the transformation on.
            batch_axes: tuple with batch axes.
            bias: whether to add a bias to the output (default: True).
            dtype: the dtype of the computation (default: float32).
            kernel_init: initializer function for the weight matrix.
            bias_init: initializer function for the bias.
            precision: numerical precision of the computation; see
                `jax.lax.Precision` for details.

        Returns:
            The transformed input.
        """
        inputs = jnp.asarray(inputs, dtype)

        if not isinstance(output_shape, Iterable):
            output_shape = (output_shape,)
        if not isinstance(input_axes, Iterable):
            input_axes = (input_axes,)
        if not isinstance(batch_axes, Iterable):
            batch_axes = (batch_axes,)
        output_shape, input_axes, batch_axes = (
            tuple(output_shape),
            tuple(input_axes),
            tuple(batch_axes),
        )

        if len(input_axes) == 0:
            raise IndexError("Must have at least one input dimension")
        if len(output_shape) == 0:
            raise IndexError("Must have at least one output dimension")

        if batch_axes:
            max_dim = jnp.max(batch_axes)
            if set(batch_axes) != set(range(max_dim + 1)):
                raise ValueError(
                    "batch_axes %s must be consecutive leading "
                    "dimensions starting from 0." % str(batch_axes)
                )

        if set(batch_axes) & set(input_axes):
            raise IndexError("Batch axes and input axes must not have reused axes")

        if jnp.max(batch_axes + input_axes) >= inputs.ndim:
            raise IndexError("Not enough dimensions in input for batch/input axes")

        ndim = inputs.ndim
        input_axes = _normalize_axes(input_axes, ndim)
        batch_axes = _normalize_axes(batch_axes, ndim)

        n_input_axes, n_output_axes = len(input_axes), len(output_shape)

        def kernel_init_wrap(rng, shape, dtype=jnp.float32):
            """Initialize a flattened (in, out) kernel, then reshape to ``shape``."""
            flat_shape = (
                jnp.prod(shape[:n_input_axes]),
                jnp.prod(shape[-n_output_axes:]),
            )
            kernel = kernel_init(rng, flat_shape, dtype)
            return jnp.reshape(kernel, shape)

        kernel_shape = tuple([inputs.shape[ax] for ax in input_axes]) + output_shape
        kernel = self.param("kernel", kernel_shape, kernel_init_wrap)
        kernel = jnp.asarray(kernel, dtype)

        contract_ind = tuple(range(n_input_axes))
        out = jax.lax.dot_general(
            inputs, kernel, ((input_axes, contract_ind), ((), ())), precision=precision
        )

        if bias:

            def bias_init_wrap(rng, shape, dtype=jnp.float32):
                """Initialize a flattened bias, then reshape to ``shape``."""
                flat_shape = (jnp.prod(shape[-n_output_axes:]),)
                bias = bias_init(rng, flat_shape, dtype)
                return jnp.reshape(bias, shape)

            bias = self.param("bias", output_shape, bias_init_wrap)

            # Reshape bias for broadcast.
            num_expand_dims = len(set(range(inputs.ndim)) - set(input_axes))
            for _ in range(num_expand_dims):
                bias = bias[None, :]
            bias = jnp.asarray(bias, dtype)
            out = out + bias

        return out
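

if __name__ == "__main__":
    # A minimal, self-contained sketch of the contraction that ``Linear.apply``
    # performs. It is purely illustrative: it bypasses the flax parameter
    # machinery, builds the kernel by hand, and uses hypothetical example
    # shapes that are not taken from the library's tests.
    import numpy as np

    rng = np.random.RandomState(0)
    inputs = jnp.asarray(rng.randn(2, 3, 4))              # (batch, d1, d2)
    input_axes = _normalize_axes((-2, -1), inputs.ndim)   # contract over d1, d2
    output_shape = (5,)

    # Kernel shape = contracted input dims + output dims, as in ``apply``.
    kernel = jnp.asarray(rng.randn(3, 4, *output_shape))

    contract_ind = tuple(range(len(input_axes)))
    out = jax.lax.dot_general(inputs, kernel, ((input_axes, contract_ind), ((), ())))
    assert out.shape == (2, 5)

    # Equivalent to flattening the contracted axes and doing a plain matmul.
    ref = jnp.reshape(inputs, (2, 12)) @ jnp.reshape(kernel, (12, 5))
    assert jnp.allclose(out, ref, atol=1e-5)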
[docs] def apply( self, inputs, output_shape, input_axes=-1, batch_axes=(), bias=True, dtype=jnp.float32, kernel_init=default_kernel_init, bias_init=flax.nn.initializers.zeros, precision=None, ): """Applies a linear transformation to the inputs along multiple dimensions. Args: inputs: The nd-array to be transformed. output_shape: tuple of output shape. input_axes: tuple with axes to apply the transformation on. batch_axes: tuple with batch axes. bias: whether to add a bias to the output (default: True). dtype: the dtype of the computation (default: float32). kernel_init: initializer function for the weight matrix. bias_init: initializer function for the bias. precision: numerical precision of the computation see `jax.lax.Precision` for details. Returns: The transformed input. """ inputs = jnp.asarray(inputs, dtype) if not isinstance(output_shape, Iterable): output_shape = (output_shape,) if not isinstance(input_axes, Iterable): input_axes = (input_axes,) if not isinstance(batch_axes, Iterable): batch_axes = (batch_axes,) output_shape, input_axes, batch_axes = ( tuple(output_shape), tuple(input_axes), tuple(batch_axes), ) if len(input_axes) == 0: raise IndexError("Must have at least one input dimension") if len(output_shape) == 0: raise IndexError("Must have at least one output dimension") if batch_axes: max_dim = jnp.max(batch_axes) if set(batch_axes) != set(range(max_dim + 1)): raise ValueError( "batch_axes %s must be consecutive leading " "dimensions starting from 0." % str(batch_axes) ) if set(batch_axes) & set(input_axes): raise IndexError("Batch axes and input axes must not have reused axes") if jnp.max(batch_axes + input_axes) >= inputs.ndim: raise IndexError("Not enough dimensions in input for batch/input axes") ndim = inputs.ndim input_axes = _normalize_axes(input_axes, ndim) batch_axes = _normalize_axes(batch_axes, ndim) n_input_axes, n_output_axes = len(input_axes), len(output_shape) def kernel_init_wrap(rng, shape, dtype=jnp.float32): """Initializing and inducing correct shapes""" flat_shape = ( jnp.prod(shape[:n_input_axes]), jnp.prod(shape[-n_output_axes:]), ) kernel = kernel_init(rng, flat_shape, dtype) return jnp.reshape(kernel, shape) kernel_shape = tuple([inputs.shape[ax] for ax in input_axes]) + output_shape kernel = self.param("kernel", kernel_shape, kernel_init_wrap) kernel = jnp.asarray(kernel, dtype) contract_ind = tuple(range(n_input_axes)) out = jax.lax.dot_general( inputs, kernel, ((input_axes, contract_ind), ((), ())), precision=precision ) if bias: def bias_init_wrap(rng, shape, dtype=jnp.float32): """Initializing and inducing correct shapes""" flat_shape = (jnp.prod(shape[-n_output_axes:]),) bias = bias_init(rng, flat_shape, dtype) return jnp.reshape(bias, shape) bias = self.param("bias", output_shape, bias_init_wrap) # Reshape bias for broadcast. num_expand_dims = len(set(range(inputs.ndim)) - set(input_axes)) for _ in range(num_expand_dims): bias = bias[None, :] bias = jnp.asarray(bias, dtype) out = out + bias return out