Source code for timecast.utils.experiment

"""Experiment decorator for running experiments in parallel"""
import inspect
from collections.abc import Iterable
from typing import Any
from typing import Callable
from typing import Dict
from typing import List
from typing import Tuple
from typing import Union

import multiprocess.context as ctx
import pathos

# Import-time side effect: force the multiprocess context to use the "spawn"
# start method.
# NOTE(review): presumably so worker processes created by the pathos pool start
# from a fresh interpreter instead of forking — confirm the original motivation.
ctx._force_start_method("spawn")


def runner(args: Tuple[Callable, Dict[Any, Any]]) -> Any:
    """Call a packed ``(function, kwargs)`` job; worker function for the process pool

    Args:
        args: A pair whose first element is the function to call and whose
        second element is a dictionary of keyword arguments to call it with

    Returns:
        Any: whatever the function in the first element returns
    """
    func = args[0]
    kwargs = args[1]
    return func(**kwargs)


def product(*args: Any) -> Iterable:
    """Recursive generator yielding the cartesian product of its arguments

    Each positional argument is an iterable of candidate values. Every yielded
    item is a tuple holding one value from each argument, in argument order,
    with the last argument varying fastest.

    Args:
        *args: The iterables to combine. The first one is consumed lazily,
        exactly once; every following one is snapshotted into a tuple so that
        it can be replayed for each value of the first. This makes one-shot
        iterators and generators safe to pass in any position.

    Yields:
        tuple: one combination per iteration

    Notes:
        * Values are never unpacked: a tuple drawn from one of the arguments
        stays a single tuple-valued entry in the resulting combination
        * Called with no arguments this yields nothing (itertools.product
        would yield a single empty tuple)
    """
    if not args:
        return

    # Snapshot the trailing arguments once. Re-iterating them directly for
    # every element of args[0] would silently drop combinations whenever one
    # of them is a generator, because it is exhausted after the first pass.
    tails = [tuple(arg) for arg in args[1:]]

    for a in args[0]:
        for prod in product(*tails) if tails else ((),):
            yield (a,) + prod


class experiment:
    """Class decorator to run experiments

    Each ``@experiment(argnames, arglists)`` contributes one specification.
    Decorators can be stacked on the same function; ``run`` then calls the
    function once for every combination in the cartesian product of the
    stacked arglists.
    """

    def __init__(self, argnames: Union[List[str], str], arglists: List[Any]) -> None:
        """Initializes the experiment with argnames and arglists

        Args:
            argnames: List of argument names or comma-delimited string of
            names (as in pytest). Any other non-string sequence of names
            (e.g. a tuple) is accepted as well
            arglists: A list of one or more tuples of values for each name in
            argnames
        """
        if isinstance(argnames, str):
            self._argnames = [argname.strip() for argname in argnames.split(",")]
        else:
            self._argnames = list(argnames)

        # NOTE: arglists is deliberately not checked against the number of
        # argnames here, because doing so would realize generators passed in
        # as args
        self._arglists = arglists
        self._spec = [(self._argnames, self._arglists)]

    def __call__(self, funcOrExp):
        """Decorator magic

        Args:
            funcOrExp: The function to wrap, or an already-decorated
            experiment object whose function and specs are taken over

        Returns:
            experiment: this object, so that decorators can be chained

        Raises:
            TypeError: If funcOrExp is neither an experiment nor a callable

        Notes:
            * Takes either a Callable or experiment object so we can chain
            decorators
        """
        if hasattr(funcOrExp, "_func"):
            # If we see an experiment, pass on the function to execute and
            # append its specs after our own
            self._func = funcOrExp._func
            self._spec.extend(funcOrExp._spec)
        elif callable(funcOrExp):
            # If we see a callable, we know we're wrapping a naked function
            self._func = funcOrExp
        else:
            # Fail at decoration time instead of leaving _func unset and
            # crashing later inside run()
            raise TypeError(
                "experiment must wrap a callable or another experiment, got {!r}".format(
                    funcOrExp
                )
            )

        return self

    def _validate(self):
        """Validate an experiment's argnames and arglists against _func's arguments

        Raises:
            ValueError: If an argname is specified more than once, if a
            parameter of _func without a default is not supplied, or if a
            supplied argname is not a parameter of _func (and _func has no
            **kwargs)

        Notes:
            * Each decorated call to @experiment must specify arguments
            disjoint with every other call to @experiment (i.e., can't have
            a call "a,b,c" and "b")
            * The union of all arguments specified must be a subset of
            _func's arguments and a superset of _func's arguments without
            defaults; keyword-only parameters count as arguments too
            * If _func has *args, we ignore
            * If _func has **kwargs, it will eat up any unused arguments
        """
        # Check for duplicate argument specifications
        argnames = set()
        for spec in self._spec:
            for argname in spec[0]:
                if argname in argnames:
                    raise ValueError("Found duplicate argname {} in {}".format(argname, spec))
                argnames.add(argname)

        func_fullargspec = inspect.getfullargspec(self._func)

        positional = func_fullargspec.args
        defaults = func_fullargspec.defaults
        # Defaults always belong to the trailing positional parameters
        required = list(positional if defaults is None else positional[: -len(defaults)])

        # Keyword-only parameters are just as much part of the signature:
        # those without a default are required, all of them are acceptable
        kwonly = func_fullargspec.kwonlyargs
        kwonly_defaults = func_fullargspec.kwonlydefaults or {}
        required.extend(arg for arg in kwonly if arg not in kwonly_defaults)
        accepted = list(positional) + list(kwonly)

        # NOTE: We don't use set operations because we want more descriptive
        # error messages (in signature order)

        # Find any missing arguments
        missing_args = [arg for arg in required if arg not in argnames]
        if missing_args:
            raise ValueError(
                "Arguments without defaults not found. Required: {}. Supplied: {}".format(
                    missing_args, argnames
                )
            )

        # Find any extra arguments; with **kwargs nothing can be "extra"
        if func_fullargspec.varkw is None:
            extra_args = [arg for arg in argnames if arg not in accepted]
            if extra_args:
                raise ValueError("Found unused arguments: {}".format(extra_args))

    def _generate_arglists(self):
        """Create a generator to pass function and arguments to a multiprocessing pool

        Yields:
            tuple: ``(_func, kwargs)`` pairs in the shape consumed by
            ``runner``, one per combination of the stacked arglists

        Notes:
            * Every non-string iterable value in a combination is flattened
            into the positional list of values, so a value that is itself a
            list or tuple is spread over several argnames rather than bound
            to a single one
            * Names and values are paired with zip, so surplus names or
            values are dropped silently
        """
        argnames = [argname for spec in self._spec for argname in spec[0]]

        for arglist in product(*[spec[1] for spec in self._spec]):
            # Flatten a list of lists and atoms
            flattened = []
            for arg in arglist:
                if isinstance(arg, Iterable) and not isinstance(arg, str):
                    flattened.extend(arg)
                else:
                    flattened.append(arg)

            yield (self._func, dict(zip(argnames, flattened)))

    def run(self, processes=1, chunksize=1, tqdm=None):
        """Execute the experiment

        Args:
            processes: Number of worker processes. With 1 the calls are made
            serially in the current process using map; otherwise they are
            dispatched through a pathos ProcessPool
            chunksize: Passed to the pool's imap; unused when processes is 1
            tqdm: Optional object exposing ``tqdm(iterable)`` (e.g. the tqdm
            module) used to wrap the result iterator for progress reporting

        Returns:
            list: The collected results of every call

        Raises:
            ValueError: If validation of the argument specification fails
        """
        self._validate()

        jobs = self._generate_arglists()
        if processes == 1:
            results = map(runner, jobs)
        else:
            # TODO: Figure out why parallel hangs in pytest
            results = pathos.pools.ProcessPool(nodes=processes).imap(
                runner, jobs, chunksize=chunksize
            )

        if tqdm is not None:
            results = tqdm.tqdm(results)

        return list(results)