# Source code for flax.nn.base

# Copyright 2020 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""NN base modules for JAX."""

import abc
import contextlib
import functools
import hashlib
import inspect
from typing import Any
import warnings

from . import utils
from . import stochastic
from flax import jax_utils
from flax import serialization
from flax import struct

import jax
from jax import random


# Stack of the `_ModuleFrame`s that are currently being initialized or
# applied; the last entry is the innermost (current) module.
_module_stack = utils.CallStack()
# Stack of collections that record module outputs
# (pushed by `capture_module_outputs`, filled by `_track_outputs`).
_module_output_trackers = utils.CallStack()
# Stack of state collections pushed by the `stateful` context manager.
_state_stack = utils.CallStack()


def _track_outputs(x):
  """Append `x` to the outputs recorded by every active output tracker.

  Each tracker on `_module_output_trackers` keeps a list of outputs for the
  current module; the list is retrieved (or started empty), extended with `x`
  and stored back.

  Args:
    x: the output of the module that has just been applied.
  """
  for tracker in _module_output_trackers:
    recorded = tracker.retrieve(default=[])
    recorded.append(x)
    tracker.store(recorded)


class _ModuleFrame:
  """Holds the context needed to init or apply a single Module.

  `params` is the dictionary in which the module's parameters are stored
  (during init) and from which they are read (during application). When a
  module calls into a submodule, the submodule gets its own frame whose
  `params` is the nested dictionary stored under the submodule's name in the
  parent's `params`; `parent` links a frame back to the enclosing frame.

  Once initialization is complete the `params` of the top-level frame form the
  nested dictionary of all parameters. During application the same structure
  is walked again, but the parameters are only read.

  Attributes:
    name: name of the module scope, or None for an unnamed root scope.
    parent: the enclosing frame, or None for a top-level frame.
    rng: PRNG key used while initializing; None when applying.
    params: dictionary of parameters and nested submodule dictionaries.
    transparent: if True, children of this frame do not add their own name
      to their `path`.
    shared: maps a child module name to whether that child is shared.
    shared_names: names that were already claimed by shared modules.
  """

  def __init__(self, name,
               parent=None, params=None, rng=None,
               transparent=False):
    self.name = name
    self.parent = parent
    self.rng = rng
    # A fresh dictionary per frame unless the caller supplies one to share.
    self.params = {} if params is None else params
    self.transparent = transparent
    self.shared = {}
    self.shared_names = set()

    # Source of the automatically generated names, see `create_name`.
    self._name_counter = 0

  @property
  def is_init(self):
    """True when this frame carries an rng, i.e. parameters are being created."""
    return self.rng is not None

  @property
  def path(self):
    """Path of the module scope.

    Paths are similar to unix file names (eg. '/module/nested/dense').

    Returns:
      The path of this Module scope.
    """
    parent = self.parent
    if parent is None:
      return '/' if self.name is None else '/' + self.name

    prefix = parent.path
    if parent.transparent:
      # Children of a transparent frame reuse the parent's path unchanged.
      return prefix
    if prefix[-1] != '/':
      prefix += '/'
    return prefix + self.name

  def is_descendent_of(self, frame):
    """Check whether this frame is a descendent of the given frame."""
    node = self
    while True:
      if frame is node.parent:
        return True
      if not node.parent:
        return False
      node = node.parent

  def create_name(self):
    """Return the next automatic name ('0', '1', ...) for this frame."""
    index = self._name_counter
    self._name_counter = index + 1
    return str(index)


def module_method(fn):
  """Decorates a function as a module method.

  A module method lets a module expose additional methods, next to `apply`,
  that make use of the module's parameters.

  Example::

    class MyLinearModule(nn.Module):
      def apply(self, x, features, kernel_init):
        kernel = self.param('kernel', (x.shape[-1], features), kernel_init)
        return jnp.dot(x, kernel)

      @nn.module_method
      def apply_transpose(self, x, **kwargs):
        kernel = self.get_param('kernel')
        return jnp.dot(x, kernel.transpose((1, 0)))

  A module method can be called on a Model instance directly::

    y, initial_params = MyLinearModule.init(rng, x)
    model = nn.Model(MyLinearModule, initial_params)
    z = model.apply_transpose(y)

  Module methods can also be called on shared modules::

    class AutoEncoder(nn.Module):
      def apply(self, x, features):
        linear_fn = MyLinearModule.shared(features=features)
        h = linear_fn(x)
        y = linear_fn.apply_transpose(h)
        return y

  Args:
    fn: the function to be decorated
  Returns:
    the decorated function
  """

  # One generated subclass per class the method is looked up on.
  method_classes = {}

  # A module method is itself a Module class whose `apply` is `fn`. It
  # inherits from the class it is accessed on so that `fn` can call the other
  # methods of that module; a class property tells us which class that is.
  def method_class_for(cls):
    method_cls = method_classes.get(cls)
    if method_cls is None:
      class ModuleMethod(cls):
        apply = fn
      ModuleMethod.__name__ = fn.__name__
      ModuleMethod.__qualname__ = f'{cls.__qualname__}.{fn.__name__}'
      method_classes[cls] = method_cls = ModuleMethod
    return method_cls

  return utils.classproperty(method_class_for)


def _fn_parameters(fn):
  """Return the `inspect.Parameter`s of `fn` as a tuple, in declaration order."""
  signature = inspect.signature(fn)
  return tuple(signature.parameters.values())


# Names of the `Module` classmethods that `_ModuleMeta` re-wraps on every
# Module subclass so that their signature and docstring reflect the
# subclass's `apply` method.
MODULE_CLASSMETHODS = [
    'create', 'create_by_shape', 'init', 'init_by_shape', 'call', 'partial'
]


class _ModuleMeta(abc.ABCMeta):
  """Meta class that derives a Module's docs and signatures from `apply`.

  For every class created with this metaclass the class docstring and
  `__signature__` are taken from its `apply` method. For subclasses of
  `Module`, the classmethods listed in `MODULE_CLASSMETHODS` are additionally
  wrapped so that their docstring and signature include those of `apply`.
  """

  def __init__(cls, name, bases, attrs):
    super(_ModuleMeta, cls).__init__(name, bases, attrs)
    apply_fn = cls.apply
    apply_doc = apply_fn.__doc__
    # The class docstring is replaced by the docstring of `apply`
    # (which is None when `apply` has no docstring).
    cls.__doc__ = apply_doc
    apply_params = _fn_parameters(apply_fn)
    # Drop the first parameter of `apply` (normally `self`) so the class
    # signature lists only the arguments a caller provides.
    cls.__signature__ = inspect.signature(cls).replace(
        parameters=apply_params[1:])

    if not bases:
      return  # skip method signature overrides for Module class.

    def wrap_special_method(name):
      """Override the signature and docstring of one of Module's classmethods."""
      orig_fn = getattr(Module, name)

      @functools.wraps(orig_fn)
      def wrapper(class_, *args, **kwargs):
        # Dispatch to the implementation that follows `cls` in the MRO of
        # `class_`, bound to `class_`.
        super_fn = getattr(super(cls, class_), name)
        return super_fn(*args, **kwargs)
      wrapper.__doc__ = f'''{orig_fn.__doc__}

      Apply docstring:

      {apply_doc}
      '''
      # Keep only the positional-or-keyword parameters of the original
      # classmethod (this drops `*args`, `**kwargs` and keyword-only
      # parameters such as `name`) and append the parameters of `apply`
      # without its first one.
      base_params = tuple(x for x in _fn_parameters(orig_fn)
                          if x.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD)
      new_params = base_params + apply_params[1:]
      wrapper.__signature__ = inspect.signature(orig_fn).replace(
          parameters=new_params)
      setattr(cls, name, classmethod(wrapper))

    for name in MODULE_CLASSMETHODS:
      wrap_special_method(name)


def _fold_in_str(rng, data):
  """Fold a string into a jax.random.PRNGKey using its SHA-1 hash.

  Args:
    rng: the PRNG key to derive a new key from.
    data: the string to fold in; it is hashed as UTF-8.
  Returns:
    The key produced by `jax.random.fold_in`.
  """
  digest = hashlib.sha1(data.encode('utf-8')).digest()
  # Only the leading 4 bytes (32 bits) of the digest are folded in.
  hash_int = int.from_bytes(digest[:4], byteorder='big')
  return random.fold_in(rng, hash_int)


class Module(metaclass=_ModuleMeta):
  """Functional modules."""

  def __new__(cls, *args, name=None, **kwargs):
    """Apply the module as a submodule of the enclosing module.

    Calling a Module class does not return an instance: it runs `apply` in a
    new frame nested under the construction frame and returns the output.

    Args:
      *args: positional arguments passed to `apply`.
      name: optional name of the submodule. When omitted, the default name of
        the class is used, or else a name is generated from the class name
        and a counter of the parent frame.
      **kwargs: keyword arguments passed to `apply`.
    Returns:
      The output of `apply`.
    Raises:
      ValueError: if called outside of a module, if a name is given for a
        shared module, or if the submodule is applied without having been
        created during initialization.
    """
    if not _module_stack:
      raise ValueError('A Module should only be instantiated directly inside'
                       ' another module.')
    parent = cls._get_construction_frame()
    apply_kwargs = cls._extend_kwargs(kwargs)

    if name is not None:
      if cls._is_shared():
        raise ValueError('Cannot override the name of a shared module')
    else:
      name = cls._default_name()
      if name is None:  # also no default name
        name = cls.__name__ + '_' + parent.create_name()
    cls._check_name(name, parent)

    creating = parent.is_init and name not in parent.params
    if creating:
      # First use during init: derive an rng and register fresh parameters.
      rng = _fold_in_str(parent.rng, name)
      params = {}
      parent.params[name] = params
    elif name in parent.params:
      # Apply (or re-use during init): read the existing parameters.
      params = parent.params[name]
      rng = None
    else:
      raise ValueError(f'No module named {name} was created during'
                       ' initialization.')

    frame = _ModuleFrame(name, parent=parent, rng=rng, params=params,
                         transparent=cls._is_transparent())
    with cls._with_instance(frame) as instance:
      output = instance.apply(*args, **apply_kwargs)
      _track_outputs(output)
    return output

  @abc.abstractmethod
  def apply(self, *args, **kwargs):
    # Abstract hook: concrete modules override this to compute their output.
    # It is invoked by `__new__`, `init` and `call` with the caller's
    # positional arguments and the (possibly extended) keyword arguments.
    # NOTE: `_ModuleMeta` copies `apply.__doc__` onto the class `__doc__`, so
    # adding a docstring here would change `Module.__doc__`.
    pass

  @classmethod
  def shared(class_, *, name=None, **kwargs):
    """Partially applies a module and shares its parameters across calls.

    The frame that is current when `shared` is called becomes the
    construction frame of the returned module: its parameters live there no
    matter where the module is eventually called.

    Args:
      name: name of this module. When omitted a name is generated by the
        current frame.
      **kwargs: keyword arguments that should be partially applied.
    Returns:
      A subclass of Module that shares parameters when called multiple times.
    Raises:
      ValueError: if called outside of a module, or if a shared module with
        the same name already exists in the current frame.
    """
    if not _module_stack:
      raise ValueError(
          'The shared module should be used during Module application')

    construction_frame = _module_stack[-1]
    if name is None:
      name = construction_frame.create_name()
    if name in construction_frame.shared_names:
      raise ValueError(f'Shared module named "{name}" already exists.')
    construction_frame.shared_names.add(name)

    class SharedModule(class_.partial(**kwargs)):
      """Wraps a module to enable shared parameters."""

      @classmethod
      def _get_construction_frame(cls):
        return construction_frame

      @classmethod
      def _is_shared(cls):
        return True

      @classmethod
      def _default_name(cls):
        return name

    SharedModule.__name__ = class_.__name__
    SharedModule.__qualname__ = class_.__qualname__

    return SharedModule

  @classmethod
  def _get_construction_frame(cls):
    """Return the ModuleFrame where this module was constructed.

    Modules can be shared across different parts of a parameter tree.
    We need to ensure that the parameter object is the same in every instance
    of the same shared module. We resolve this by deciding on a canonical
    ModuleFrame (corresponding to a particular part of the top-level parameter
    tree) where parameters are stored. Concretely, it is the
    "construction frame" -- that is, the frame in which the module is first
    defined. For non-shared modules, that's where it's called. For shared
    modules, it's where `submodule.shared(...)` is called (which may or may
    not be the frame in which it is used).

    This base implementation returns the innermost frame on the module stack;
    the class returned by `shared` overrides it to return the frame captured
    when `shared` was called.

    Returns:
      The ModuleFrame instance where this module was constructed.
    """
    return _module_stack[-1]

  @classmethod
  def partial(class_, *, name=None, **kwargs):
    """Partially applies a module with the given arguments.

    Unlike `functools.partial` this will return a subclass of Module.
    Keyword arguments given when the returned module is invoked take
    precedence over the partially applied ones.

    Args:
      name: the default name used for the module
      **kwargs: the keyword arguments to be partially applied.
    Returns:
      A subclass of Module which partially applies the given keyword arguments.
    """

    class PartialModule(class_):
      """Wraps a module with partial application."""

      @classmethod
      def _extend_kwargs(cls, invoke_kwargs):
        # Start from the partially applied kwargs; call-site kwargs win.
        merged_kwargs = {**kwargs, **invoke_kwargs}
        return super()._extend_kwargs(merged_kwargs)

      @classmethod
      def _default_name(cls):
        return super()._default_name() if name is None else name

    # __doc__ is handled by the Module meta class
    PartialModule.__name__ = class_.__name__
    PartialModule.__qualname__ = class_.__qualname__

    return PartialModule

  @classmethod
  def create(cls, _rng, *args, name=None, **kwargs):
    """Create a module instance by evaluating the model.

    DEPRECATION WARNING:
    `create()` is deprecated; use `init()` to initialize parameters and
    then explicitly create a `nn.Model` given the module and initialized
    parameters.

    Use create_by_shape instead to initialize without doing computation.
    Initializer functions can depend both on the shape and the value of inputs.

    Args:
      _rng: the random number generator used to initialize parameters.
      *args: arguments passed to the module's apply function
      name: name of this module
      **kwargs: keyword arguments passed to the module's apply function
    Returns:
      A pair consisting of the model output and an instance of Model
    """
    message = ("`create()` will be removed soon."
               " Use `init()` to initialize parameters and then explicitly"
               " create a `nn.Model` given the module and initialized"
               " parameters.")
    warnings.warn(message, DeprecationWarning)
    output, initial_params = cls.init(_rng, *args, name=name, **kwargs)
    return output, Model(cls, initial_params)

  @classmethod
  def create_by_shape(cls, _rng, input_specs, *args, name=None, **kwargs):
    """Create a module instance using only shape and dtype information.

    DEPRECATION WARNING:
    `create_by_shape()` is deprecated; use `init_by_shape()` to initialize
    parameters and then explicitly create a `nn.Model` given the module and
    initialized parameters.

    This method will initialize the model without computation.
    Initializer functions can depend on the shape but not the value of inputs.

    Args:
      _rng: the random number generator used to initialize parameters.
      input_specs: an iterable of (shape, dtype) pairs specifying the inputs
      *args: other arguments passed to the module's apply function
      name: name of this module.
      **kwargs: keyword arguments passed to the module's apply function
    Returns:
      A pair consisting of the model output and an instance of Model
    """
    message = ("`create_by_shape()` will be removed soon."
               " Use `init_by_shape()` to initialize parameters and then"
               " explicitly create a `nn.Model` given the module and "
               " initialized parameters.")
    warnings.warn(message, DeprecationWarning)

    output, initial_params = cls.init_by_shape(
        _rng, input_specs, *args, name=name, **kwargs)
    return output, Model(cls, initial_params)

  @classmethod
  def init(cls, _rng, *args, name=None, **kwargs):
    """Initialize the module parameters.

    The module is applied in a fresh frame that carries `_rng`, so every
    parameter requested by `apply` is created and recorded.

    Args:
      _rng: the random number generator used to initialize parameters.
      *args: arguments passed to the module's apply function
      name: name of this module.
      **kwargs: keyword arguments passed to the module's apply function
    Returns:
      A pair consisting of the model output and the initialized parameters
    """
    kwargs = cls._extend_kwargs(kwargs)
    # Nest under the current module, if there is one.
    parent = _module_stack[-1] if _module_stack else None
    if name is None:
      name = cls._default_name()

    frame = _ModuleFrame(name, rng=_rng, parent=parent,
                         transparent=cls._is_transparent())
    with cls._with_instance(frame) as instance:
      output = instance.apply(*args, **kwargs)
      _track_outputs(output)
    return output, cls._post_process_params(frame.params)

  @classmethod
  def init_by_shape(cls, _rng, input_specs, *args, name=None, **kwargs):
    """Initialize the module parameters.

    This method will initialize the module parameters without computation.
    Initializer functions can depend on the shape but not the value of inputs.

    Args:
      _rng: the random number generator used to initialize parameters.
      input_specs: an iterable of (shape, dtype) pairs specifying the inputs
      *args: arguments passed to the module's apply function
      name: name of this module.
      **kwargs: keyword arguments passed to the module's apply function
    Returns:
      A pair consisting of the model output and the initialized parameters
    Example:
      ```
      input_shape = (batch_size, image_size, image_size, 3)
      model_output, initial_params = model.init_by_shape(jax.random.PRNGKey(0),
                                      input_specs=[(input_shape, jnp.float32)])
      ```
    """
    # `make_rng` raises ValueError when there is no stochastic scope or when
    # the current scope is invalid due to another jax transformation. In both
    # cases the stochastic scope is not lifted into the lazy evaluation.
    stochastic_rng = None
    with contextlib.suppress(ValueError):
      stochastic_rng = stochastic.make_rng()

    def lazy_init(*inputs):
      def run_init():
        return cls.init(_rng, *inputs, *args, name=name, **kwargs)
      if stochastic_rng is None:
        return run_init()
      # Create a new stochastic scope inside the lazy evaluation so that a
      # stochastic scope can be used in combination with init_by_shape.
      with stochastic.stochastic(stochastic_rng):
        return run_init()

    return jax_utils.partial_eval_by_shape(lazy_init, input_specs)

  @classmethod
  def call(cls, params, *args, name=None, **kwargs):
    """Evaluate the module with the given parameters.

    The module is applied in a frame without an rng, so parameters are only
    read from `params`.

    Args:
      params: the parameters of the module. Typically, initial parameter
        values are constructed using `Module.init` or `Module.init_by_shape`.
      *args: arguments passed to the module's apply function
      name: name of this module.
      **kwargs: keyword arguments passed to the module's apply function
    Returns:
      The output of the module's apply function.
    """
    params = cls._pre_process_params(params)
    kwargs = cls._extend_kwargs(kwargs)
    # Nest under the current module, if there is one.
    parent = _module_stack[-1] if _module_stack else None
    if name is None:
      name = cls._default_name()

    frame = _ModuleFrame(name, params=params, parent=parent,
                         transparent=cls._is_transparent())
    with cls._with_instance(frame) as instance:
      output = instance.apply(*args, **kwargs)
      _track_outputs(output)
    return output

  def param(self, name, shape, initializer):
    """Defines a parameter within the module's apply function.

    During initialization the parameter is created with `initializer` and
    stored under `name`; during application the stored value is looked up.

    Args:
      name: The name of the parameter.
      shape: The shape of the parameter. If None, the shape of the parameter
        is not checked and the parameter may be of any type.
      initializer: An initializer function
                   taking an RNG and the shape as arguments.
    Returns:
      The value of the parameter.
    Raises:
      ValueError: if the name is already in use during initialization, if
        the parameter does not exist during application, or if the shape of
        the stored parameter differs from the requested shape.
    """
    frame = self._frame
    if frame.is_init:
      if name in frame.params:
        raise ValueError(
            "Name '%s' was already used for another parameter." % name)
      key = _fold_in_str(frame.rng, name)
      frame.params[name] = initializer(key, shape)
    if name not in frame.params:
      raise ValueError("Parameter with name '%s' does not exist." % name)
    param = frame.params[name]
    if shape is not None:
      # Compare the shapes as tuples so that an equivalent shape passed as a
      # list (e.g. [3, 4] vs. (3, 4)) is not reported as a mismatch.
      try:
        mismatch = tuple(param.shape) != tuple(shape)
      except TypeError:
        # One of the shapes is not iterable: compare them as given.
        mismatch = param.shape != shape
      if mismatch:
        raise ValueError(
            'Existing shape {} differs from requested shape {}'.format(
                param.shape, shape))
    return param

  def get_param(self, name):
    """Retrieves a parameter within the module's apply function.

    Args:
      name: The name of the parameter.
    Returns:
      The value of the parameter.
    Raises:
      ValueError: if no parameter with the given name exists.
    """
    params = self._frame.params
    if name in params:
      return params[name]
    raise ValueError("Parameter with name '%s' does not exist." % name)

  def state(self, name, shape=None, initializer=None, collection=None):
    """Declare a state variable within the module's apply function.

    A state variable has an attribute value which can be updated by simply
    assigning a value to it. For example::

      class Example(nn.Module):
        def apply(self, inputs, decay=0.9):
          ema = self.state('ema', inputs.shape, initializers.zeros)
          ema.value = decay * ema.value + (1 - decay) * inputs
          return inputs

    By default Modules are stateless. See `flax.nn.stateful` to enable stateful
    computations.

    Args:
      name: the name of the state variable.
      shape: optional shape passed to the initializer (default: None)
      initializer: optional initializer function
        taking an RNG and the shape as arguments.
      collection: optional `flax.nn.Collection` used to store the state.
        By default the state collection passed to the `nn.stateful` context is
        used.
    Returns:
      An instance of ModuleState.
    """
    _top_frame('state')
    if collection is None:
      collection = get_state()
    state = ModuleState(collection, name)

    # The state is only initialized when an initializer is given and at
    # least one frame on the stack is in init mode.
    initializing_frames = [frame for frame in _module_stack if frame.is_init]
    if initializer is None or not initializing_frames:
      return state

    # Use the closest initializing frame as the source of randomness; its
    # rng is replaced so that the next split yields a different key.
    closest = initializing_frames[-1]
    closest.rng, key = random.split(closest.rng)
    state.value = initializer(key, shape)
    return state

  def is_stateful(self):
    """Returns whether a `stateful` scope is currently active.

    Delegates to the module-level `is_stateful` function.
    """
    return is_stateful()

  def is_initializing(self):
    """Returns whether this module is currently being initialized.

    Returns:
      True when the frame of this module carries an rng (init mode).
    Raises:
      ValueError: if called outside of a module's apply function.
    """
    _top_frame('is_initializing')
    return self._frame.is_init

  @classmethod
  @contextlib.contextmanager
  def _with_instance(cls, frame):
    """Private constructor for Module.

    A module instance is constructed using a scope and is tied to a
    _ModuleFrame. This way the methods on the Module instance can rely on the
    _ModuleFrame being available. The frame is pushed onto the module stack
    for the duration of the context.

    Args:
      frame: an instance of _ModuleFrame
    Yields:
      An instance of Module
    """
    # `Module.__new__` applies the module instead of constructing it, so the
    # instance is allocated through `object.__new__` directly.
    instance = object.__new__(cls)
    instance._frame = frame  # pylint: disable=protected-access
    with _module_stack.frame(frame):
      yield instance

  @classmethod
  def _check_name(cls, name, parent):
    """Check whether the name of the module is valid within the parent scope.

    On success the name is recorded in `parent.shared` together with whether
    this module is shared.

    Args:
      name: the name to validate.
      parent: the frame in which the module is constructed.
    Raises:
      ValueError: if the name is not a string, contains a slash or colon, is
        used for both a shared and an unshared module, or is reused by an
        unshared module.
    """
    if name is not None:
      if not isinstance(name, str):
        raise ValueError('Name must be a string.')
      if '/' in name or ':' in name:
        raise ValueError('Name should not contain slashes or colons.')

    shared = cls._is_shared()
    registry = parent.shared
    if name in registry:
      # A module with this name already exists. Check validity of sharing.
      previously_shared = registry[name]
      if shared != previously_shared:
        raise ValueError(f'The name "{name}" is used for both a shared'
                         ' and unshared module.')
      if not previously_shared:
        raise ValueError(f'A module with named "{name}" already exists.')
    registry[name] = shared

  @classmethod
  def _extend_kwargs(cls, kwargs):
    """Hook to extend the keyword arguments before they are passed to `apply`.

    The base implementation returns `kwargs` unchanged; the class returned by
    `partial` overrides it to merge in the partially applied arguments.
    """
    return kwargs

  @classmethod
  def _pre_process_params(cls, params):
    """Hook applied by `call` to the given params before they are used.

    The base implementation returns `params` unchanged.
    """
    return params

  @classmethod
  def _post_process_params(cls, params):
    """Hook applied by `init` to the collected params before returning them.

    The base implementation returns `params` unchanged.
    """
    return params

  @classmethod
  def _is_transparent(cls):
    """Whether frames of this module are transparent (False by default).

    The result is passed as `transparent` to the `_ModuleFrame` of the module.
    """
    return False

  @classmethod
  def _is_shared(cls):
    """Whether this module is a shared module (False by default).

    The class returned by `shared` overrides this to return True.
    """
    return False

  @classmethod
  def _default_name(cls):
    """Name used when no explicit name is given (None by default).

    The classes returned by `partial` and `shared` override this to supply
    the name they were created with.
    """
    return None


def module(fun):
  """Convert a function into the apply method of a new Module.

  This is a convenient shortcut for writing higher level modules that don't
  need access to `self` for creating parameters directly.

  Example usage::

    @nn.module
    def DenseLayer(x, features):
      x = flax.nn.Dense(x, features)
      x = flax.nn.relu(x)
      return x

  This is equivalent to defining the following `nn.Module` subclass::

    class DenseLayer(nn.Module):
      def apply(self, x, features):
        x = flax.nn.Dense(x, features)
        x = flax.nn.relu(x)
        return x

  Args:
    fun: the function to convert.
  Returns:
    New Module subclass.
  """
  # NOTE(review): `functools.wraps` sets `__wrapped__`, which
  # `inspect.signature` follows by default, so `_ModuleMeta` appears to see
  # the parameters of `fun` (without `self`) when it drops the first
  # parameter -- confirm the generated class signature is as intended.
  @functools.wraps(fun)
  def apply(self, *args, **kwargs):
    del self  # unused
    return fun(*args, **kwargs)

  class_name = fun.__name__
  namespace = {'apply': apply}
  return type(class_name, (Module,), namespace)


# TODO(flax-dev) consider removing this...
class TransparentModule(Module):
  """Transparent module.

  A transparent module can only have one child, named '0'. Its parameters
  are exposed as the parameters of that child: `init` strips the enclosing
  `{'0': ...}` level and `call` adds it back.
  """

  @classmethod
  def _is_transparent(cls):
    return True

  @classmethod
  def _pre_process_params(cls, params):
    # Re-create the level that `_post_process_params` removed.
    return {'0': params}

  @classmethod
  def _post_process_params(cls, params):
    children = tuple(params.items())
    if len(children) != 1:
      raise ValueError('Transparent modules should have exactly one child.')
    (child_name, child_params), = children
    if child_name != '0':
      raise ValueError('Transparent module should contain an unnamed child.')
    return child_params


class TruncatedModule(TransparentModule):
  """Wraps a Module and returns the requested intermediate outputs instead.

  See `Model.truncate_at` for a simple api to get the intermediate outputs of
  an existing Model.
  """

  def apply(self, *args, wrapped_module=None, truncate_path=None, **kwargs):
    """Apply the wrapped module and return some of its intermediate outputs.

    Args:
      *args: the positional arguments for the wrapped module.
      wrapped_module: The module class to be wrapped.
      truncate_path: the full name of the module (eg. '/module/sub_module').
        A list or dict of module paths can be provided to obtain the
        intermediate outputs of multiple modules.
      **kwargs: the keyword arguments for the wrapped module.
    Returns:
      The intermediate outputs specified by truncate_path.
    Raises:
      ValueError: if `wrapped_module` or `truncate_path` is not provided.
    """
    if wrapped_module is None or truncate_path is None:
      raise ValueError(
          '`wrapped_module` and `truncate_path` are required keyword arguments')
    # The wrapped module is the single child, named '0', of this module.
    with capture_module_outputs() as module_outputs:
      wrapped_module(*args, **kwargs, name='0')

    def lookup_output(path):
      return module_outputs[path]
    # `jax.tree_map` is deprecated and has been removed from recent JAX
    # releases; `jax.tree_util.tree_map` is the equivalent function.
    return jax.tree_util.tree_map(lookup_output, truncate_path)


@contextlib.contextmanager
def capture_module_outputs():
  """A context manager that captures all model outputs.

  A new mutable collection is registered as an output tracker for the
  duration of the context.

  Yields:
    A `flax.nn.Collection` containing all module outputs.
  """
  with Collection().mutate() as outputs, \
       _module_output_trackers.frame(outputs):
    yield outputs


class ModuleState:
  """Tracks a state variable.

  ModuleState instances should not be created directly. See `Module.state` on
  how to create state variables inside modules.

  The variable lives in a dictionary that is retrieved from and stored back
  into the collection given at construction, keyed by the variable's name.
  """

  def __init__(self, collection, name):
    self._name = name
    self._collection = collection

  def _get_state_dict(self):
    """Return the dictionary of state variables held by the collection."""
    stored = self._collection.retrieve(default={})
    assert isinstance(stored, dict)
    return stored

  @property
  def name(self):
    """The name of the state variable."""
    return self._name

  @property
  def value(self):
    """The current value; raises ValueError if it was never assigned."""
    variables = self._get_state_dict()
    if self._name in variables:
      return variables[self._name]
    raise ValueError(f'No state variable named `{self._name}` exists.')

  @value.setter
  def value(self, v):
    variables = self._get_state_dict()
    variables[self._name] = v
    self._collection.store(variables)


@contextlib.contextmanager
def stateful(state=None, mutable=True):
  """A context manager for stateful computations.

  Modules that use `Module.state` store their state, by default, inside the
  `Collection` specified by the (innermost) `nn.stateful` context manager.

  Typically stateful is used in 3 different modes:

  1. During init no existing state is available and the stateful context creates
     a new state collection.
  2. During training the state is passed to `nn.stateful` and the new state
     is returned which will contain the updated state.
  3. During evaluation the state is passed with `mutable=False` such that the
     model can retrieve the state but is not allowed to mutate it.

  Example::

    class MyModel(nn.Module):
      def apply(self, x):
        x = nn.Dense(x, 12)
        x = nn.BatchNorm(x)
        return x

    with nn.stateful() as state:
      _, initial_params = MyModel.init(rng, x)
      model = nn.Model(MyModel, initial_params)

    with nn.stateful(state) as new_state:
      model(x2)

    with nn.stateful(new_state, mutable=False):
      evaluate_model(model)

  Args:
    state: a `flax.nn.Collection` containing the current state.
      By default a new collection will be created.
    mutable: If true the state will be mutable otherwise it will be frozen.
  Yields:
    A `flax.nn.Collection` containing the new state.
  """
  if state is None:
    state = Collection()

  if not mutable:
    # Read-only mode: expose the given collection as is.
    with _state_stack.frame(state):
      yield state
    return

  # Mutable mode: expose the collection yielded by `mutate`.
  with state.mutate() as new_state, _state_stack.frame(new_state):
    yield new_state


def is_stateful():
  """Returns true if a stateful scope is currently active (see `flax.nn.stateful`)."""
  return True if _state_stack else False


def get_state():
  """Return the state collection of the innermost `stateful` scope.

  Raises:
    ValueError: if no `stateful` scope is active.
  """
  if _state_stack:
    return _state_stack[-1]
  raise ValueError('Use the flax.nn.stateful context manager to enable'
                   ' stateful computations.')


def _top_frame(call_name):
  """Return the innermost module frame.

  Args:
    call_name: name of the calling API, used in the error message.
  Raises:
    ValueError: if no module is currently being initialized or applied.
  """
  if _module_stack:
    return _module_stack[-1]
  raise ValueError('%s should only be used inside a '
                   'module\'s apply function.' % call_name)


@struct.dataclass
class Model:
  """A Model contains the model parameters, state and definition.

  Calling a Model applies its module with its parameters. Module methods of
  the module are available as attributes of the Model and are applied with
  the Model's parameters as well.
  """

  # The Module class that defines the computation (not part of the pytree).
  module: Module = struct.field(pytree_node=False)
  # The parameters passed to `module.call`.
  params: Any

  def __call__(self, *args, **kwargs):
    """Apply the module with the parameters of this Model."""
    return self.module.call(self.params, *args, **kwargs)

  def truncate_at(self, module_path):
    """Truncate the model by returning the outputs of the given sub-module.

    Args:
      module_path: the full name of the module (eg. '/module/sub_module').
        A list or dict of module paths can be provided to obtain the
        intermediate outputs of multiple modules.
    Returns:
      A new model with the truncated outputs. If module_path is a pytree of
      paths the outputs will have the same structure where each path is
      replaced by the corresponding intermediate output.
    """
    truncated_module_cls = TruncatedModule.partial(
        wrapped_module=self.module, truncate_path=module_path)
    return self.replace(module=truncated_module_cls)

  def __getattr__(self, name):
    """Expose the module methods of `module`, bound to this Model's params.

    Only invoked when regular attribute lookup fails.

    Raises:
      AttributeError: if `name` is not a module method of `module`.
    """
    if name == 'module':
      # `module` only reaches `__getattr__` when the field has not been set
      # on the instance (e.g. while `copy` or `pickle` probe a bare instance
      # for `__setstate__`). Without this guard `self.module` below would
      # re-enter `__getattr__` until a RecursionError is raised.
      raise AttributeError(f'No attribute named "{name}".')
    value = getattr(self.module, name)
    if inspect.isclass(value) and issubclass(value, Module):
      def wrapper(*args, **kwargs):
        return value.call(self.params, *args, **kwargs)
      return wrapper
    raise AttributeError(f'No attribute named "{name}".')

  def __hash__(self):
    # Jax will call hash when model is passed to a function transform.
    # the compiled function should not be shared among model instances because
    # it closes over the specific parameters of this model instance.
    return id(self)


class Collection:
  """A collection of tensors useful for tracking state.

  A Collection can be used to associate data with the application of a Module.
  For example a collection can be used to collect activations across modules.
  Another common use case for collections is to track internal state.
  For example, the running averages in BatchNorm can be stored in a collection.

  Attributes:
    state: a dictionary that maps module paths to the stored values. By
      default an empty collection is created.
  """

  def __init__(self, state=None):
    """Creates a Collection.

    Args:
      state: the initial state. The given object is used as-is (it is not
        copied); when omitted an empty dictionary is created.
    """
    if state is None:
      state = {}
    self.state = state
    # The anchor is used to determine the prefix of the collection.
    # This way we can create/nest collections inside modules.
    self._anchor = _module_stack[-1] if _module_stack else None

    # `store` refuses to write unless this flag is set (see `mutate`).
    self._mutable = False
    # Trace level recorded by `mutate`; `store` compares values against it.
    self._master_level = None
    # The first root frame that stored a value (see `store`).
    self._root = None

  def as_dict(self):
    """Returns a dictionary with module paths as keys and the stored values.

    Returns:
      The stored values as a dictionary (the result of `state.copy()`).
    """
    return self.state.copy()

  def __getitem__(self, key):
    """Returns the value stored under `key` in the state."""
    return self.state[key]

  @contextlib.contextmanager
  def mutate(self):
    """Context manager that yields a mutable clone of this collection.

    The clone is produced by an identity `tree_map` over this collection and
    is marked immutable again when the context exits.

    Yields:
      A clone of this collection that accepts `store` calls.
    """
    # pylint: disable=protected-access
    # Clone the collection. `jax.tree_util.tree_map` is used because the
    # top-level `jax.tree_map` alias is deprecated.
    new_col = jax.tree_util.tree_map(lambda x: x, self)
    new_col._mutable = True
    new_col._master_level = utils._trace_level(utils._current_trace())
    try:
      yield new_col
    finally:
      new_col._mutable = False

  def retrieve(self, default=None):
    """Retrieves a value from the Collection.

    This function should only be called within the apply function of a module.

    Args:
      default: The default returned when nothing is stored (default: None)
    Returns:
      The value previously stored in the collection.
    Raises:
      ValueError: if no module scope is active.
    """
    # Only called for its check that a module scope is active.
    _top_frame('retrieve')
    path = self._current_path()
    return self.state.get(path, default)

  def store(self, value):
    """Stores a value in the Collection.

    This function should only be called within the apply function of a module.

    Args:
      value: The value to be stored in the collection
    Returns:
      The previous value stored in the collection or None.
    Raises:
      ValueError: if no module scope is active, if the collection is not
        mutable, if the value belongs to a nested Jax transformation, if the
        value is stored directly from the anchor scope or from outside the
        scope of the collection, or if multiple top-level module calls use
        the collection while one of them has no name.
    """
    frame = _top_frame('store')
    if not self._mutable:
      raise ValueError('Collection is not mutable. Use the `mutate` method to'
                       ' create a mutable copy.')
    # Use the Jax TraceMaster to determine if a Collection is modified from
    # inside a nested jax transformation.
    # In this case, we throw an error because transforming a stateful function
    # is ill-defined (eg. what does vmap of BatchNorm do?).
    # TODO(jheek): Add doc guide on combining jax transforms and state.
    # TODO(jheek): Should some transformations be exempt from this error?
    value_level = utils._level_of_value(value)
    if value_level > self._master_level:
      raise ValueError('Stateful operations are not allowed when the Collection'
                       ' is created outside of the current Jax transformation')

    # The root of a Collection is the first module scope that gets created
    # inside the mutate scope of the Collection. By allowing only one unique
    # root scope, we guarantee that state is not accidentally shared
    # between different models. When a user specifies an explicit name we can
    # distinguish models and a collection can have multiple roots.
    if frame == self._anchor:
      # Example:
      # with nn.Collection.mutate() as coll:
      #   coll.store(1)
      raise ValueError('State should be stored from within a module.'
                       ' Consider using the value directly instead of'
                       ' storing it in a Collection.')
    if not frame.is_descendent_of(self._anchor):
      # Edge case where the Collection cannot capture the scope of a shared
      # Module. See test_collection_store_fails_if_out_of_scope in nn_test.py
      raise ValueError('Trying to capture state outside the scope of this Collection.'
                       ' Most likely due to passing around a shared Module.')
    root = self._find_root(frame)
    if self._root is None:
      self._root = root
    elif self._root != root:
      if self._root.name is None or root.name is None:
        # In the following example, should the two calls to `StatefulModule`
        # share state or not? Because it's ambiguous, we throw an error and
        # require the user to explicitly separate state by giving each
        # instance a separate name, or to explicitly pass the same name
        # in order to share state.
        # with nn.stateful(state) as new_state:
        #   StatefulModule.call(params)
        #   StatefulModule.call(params2)
        raise ValueError('When multiple top-level module calls use a Collection'
                         ' each top-level module should have a name.')
    path = self._current_path()
    old_value = self.state.get(path, None)
    self.state[path] = value
    return old_value

  def _find_root(self, frame):
    """Find the root frame with respect to the anchor.

    The root frame is defined as the child of anchor
    that is an ancestor of frame.
    The root is used to verify that a Collection does not
    have multiple unnamed roots.

    Args:
      frame: the frame of which we want to know the root
    Returns:
      The root of the given frame.
    """
    assert frame.is_descendent_of(self._anchor)
    root = frame
    # Walk up the parent chain until the direct child of the anchor is found.
    while root.parent != self._anchor:
      root = root.parent
    return root

  def _current_path(self):
    """The path of the active module scope relative to the collection anchor.

    For example: If a collection is created in the path '/module/nested' and
    something is stored by a module with the path '/module/nested/block/conv'
    the key in the collection dict will be '/block/conv'.

    Returns:
      the relative path of the active module scope.
    """
    frame = _module_stack[-1]
    assert frame.is_descendent_of(self._anchor)
    path = frame.path
    if self._anchor is not None and self._anchor.path != '/':
      # Strip the path of the anchor so that keys are relative to it.
      prefix = self._anchor.path
      assert prefix == path[:len(prefix)]
      return path[len(prefix):]
    else:
      return path

def iterate_collection(collection):
  """Splits a collection into pytree children and static metadata.

  Jax iterates through pytrees for each argument/return value of a functional
  transformation. A mutable collection is rejected here so that the impurity
  of a traced function does not lead to silent errors.

  Args:
    collection: the collection to split up.
  Returns:
    A pair `(children, meta)`: `children` is a 1-tuple that holds the state of
    the collection and `meta` is the pair `(type(collection), anchor)`.
  Raises:
    ValueError: if the collection is mutable.
  """
  # pylint: disable=protected-access
  if collection._mutable:
    raise ValueError('A mutable collection should not be transformed by Jax.')
  aux_data = (type(collection), collection._anchor)
  children = (collection.state,)
  return children, aux_data


def collection_from_iterable(meta, state):
  """Rebuilds a collection from static metadata and pytree children.

  Args:
    meta: a pair `(collection_type, anchor)`.
    state: a sequence whose first element is the state of the collection.
  Returns:
    An instance of `collection_type` constructed from `state[0]` with its
    anchor set to `anchor`.
  """
  collection_type, saved_anchor = meta
  restored = collection_type(state[0])
  # The constructor picks its own anchor; put back the one from `meta`.
  restored._anchor = saved_anchor  # pylint: disable=protected-access
  return restored

# Register Collection as a pytree node to make sure a collection is traced.
jax.tree_util.register_pytree_node(
    Collection, iterate_collection, collection_from_iterable)


def _collection_state_dict(collection):
  """Converts a collection into a state dictionary.

  Args:
    collection: the collection to convert.
  Returns:
    The result of `collection.as_dict()`.
  """
  state_dict = collection.as_dict()
  return state_dict


def _collection_from_state_dict(_, state):
  """Creates a new Collection from a state dictionary.

  Args:
    _: unused first argument.
    state: the state used to construct the new Collection.
  Returns:
    A new Collection that holds `state`.
  """
  restored = Collection(state)
  return restored


# Register the state dict conversion functions for Collection.
serialization.register_serialization_state(
    Collection,
    _collection_state_dict,
    _collection_from_state_dict)