Source code for timecast.modules.core

"""timecast.modules.core"""
# TODO
# - Optimizers should apply to all children unless children have specified version
# - hierarchical parameters?
# - Tree flatten is very crude (only applies to params)
# - How to identify params (right now just ndarray)
# - Pass class directly to jax
# - How to handle buffers vs parameters
# - Users can do bad things with naming
import inspect

import jax
import jax.numpy as jnp


def tree_flatten(module):
    """Flatten module parameters for Jax"""
    leaves, aux = jax.tree_util.tree_flatten(module.get_param_tree())
    aux = {
        "treedef": aux,
        "arguments": module.arguments,
        "attrs": module.attrs,
        "class": module.__class__,
    }
    return leaves, aux


def tree_unflatten(aux, leaves):
    """Unflatten module parameters for Jax"""
    module = aux["class"](*aux["arguments"].args, **aux["arguments"].kwargs)
    module.set_param_tree(jax.tree_util.tree_unflatten(aux["treedef"], leaves))
    for attr in aux["attrs"]:
        if attr in module.__dict__["params"]:
            module.__dict__[attr] = module.__dict__["params"][attr]
    return module


[docs]class Module: """Core module class""" def __new__(cls, *args, **kwargs): """For avoiding super().__init__()""" obj = object.__new__(cls) obj.__setattr__("attrs", set()) obj.__setattr__("modules", {}) obj.__setattr__("params", {}) obj.__setattr__("arguments", inspect.signature(obj.__init__).bind(*args)) obj.arguments.apply_defaults() return obj @classmethod def __init_subclass__(cls, *args, **kwargs): """For avoiding a decorator for each subclass""" super().__init_subclass__(*args, **kwargs) jax.tree_util.register_pytree_node(cls, tree_flatten, tree_unflatten) def __setattr__(self, name, value): """Setting attributes Notes: * Any attribute of type Module is added to a modules dict * Any attribute of type jnp.ndarray is added to a params dict """ self.__dict__[name] = value self.attrs.add(name) if isinstance(value, Module): self.__dict__["modules"][name] = value elif isinstance(value, jnp.ndarray): self.__dict__["params"][name] = value
[docs] def get_param_tree(self): """Return recursed parameter tree""" params = self.params for name, module in self.modules.items(): params[name] = module.get_param_tree() return params
[docs] def set_param_tree(self, tree): """Apply parameter tree""" for param in self.params: self.params[param] = tree[param] self.__dict__[param] = tree[param] for name, module in self.modules.items(): module.set_param_tree(tree[name])
[docs] def add_module(self, module, name=None): """Add module outside attributes""" counter = 0 while name is None or name in self.__dict__["modules"]: name = "{}_{}".format(type(module).__name__, counter) counter += 1 self.__dict__["modules"][name] = module
[docs] def add_param(self, param, name): """Add parameter outside attributes""" counter = 0 while name is None or name in self.__dict__["params"]: name = "{}_{}".format(name, counter) counter += 1 self.__dict__["params"][name] = param