Source code for council.runners.parallel_for

from concurrent import futures
from typing import Iterable

from council.contexts import ChainContext, IterationContext
from council.utils import Option
from more_itertools import batched

from .errors import RunnerGeneratorError
from .loop_runner_base import LoopRunnerBase
from .runner_executor import RunnerExecutor
from .skill_runner_base import SkillRunnerBase
from .types import RunnerGenerator


class ParallelFor(LoopRunnerBase):
    """
    Invoke a given skill for each value returned by a given generator function.
    Can run multiple iterations in parallel.

    For each invocation, the current iteration is provided through the skill context
    :meth:`.SkillContext.iteration`.

    :meth:`.IterationContext.value` provides the value as returned by the generator function

    :meth:`.IterationContext.index` provides the index of the iteration

    Notes:
        Skill iterations are scheduled in the order given by the generator function.
        However, because multiple iterations can execute in parallel, no assumptions should be made on the order
        of results.
    """

    def __init__(self, generator: RunnerGenerator, skill: SkillRunnerBase, parallelism: int = 5) -> None:
        """
        Initialize a new instance

        Parameters:
            generator (RunnerGenerator): a generator function that yields results
            skill (SkillRunnerBase): the skill invoked once for each value yielded by the generator
            parallelism (int): maximum number of iterations scheduled together in one batch; must be at least 1

        Raises:
            ValueError: if `parallelism` is lower than 1
        """
        # Fail fast: a non-positive batch size would otherwise only surface (or silently run nothing)
        # once the runner is executed.
        if parallelism < 1:
            raise ValueError(f"parallelism must be at least 1, got {parallelism}")
        super().__init__("parallelForRunner")
        self._generator = generator
        self._skill = self.new_monitor("skill", skill)
        self._parallelism = parallelism

    def _run(self, context: ChainContext, executor: RunnerExecutor) -> None:
        """
        Run the skill over the generator values, in batches of at most `parallelism` iterations.

        Each batch is submitted to `executor` and awaited, bounded by the context budget's remaining duration,
        before the next batch is scheduled. Scheduling stops as soon as a batch does not finish in time.
        The forked iteration contexts are always merged back into `context`, even when an iteration or the
        generator raises.
        """
        inner_contexts = []
        all_fs = []
        try:
            for batch in batched(self._generate(context), self._parallelism):
                batch_contexts = [context.fork_for(self._skill) for _ in batch]
                inner_contexts.extend(batch_contexts)
                batch_fs = [
                    executor.submit(self._run_skill, inner, iteration)
                    for inner, iteration in zip(batch_contexts, batch)
                ]
                all_fs.extend(batch_fs)
                dones, not_dones = futures.wait(batch_fs, context.budget.remaining_duration, futures.FIRST_EXCEPTION)
                # Surface a failure from this batch, if any.
                self.rethrow_if_exception(dones)
                if not_dones:
                    # The wait ended before the whole batch finished (presumably the budget ran out, assuming
                    # rethrow_if_exception raised for any failure). Scheduling further batches would exceed
                    # `parallelism` and start work with no budget left, so stop here.
                    break
        finally:
            # Only futures that have not started running can actually be cancelled.
            for f in all_fs:
                f.cancel()
            context.merge(inner_contexts)

    def _run_skill(self, context: ChainContext, iteration: IterationContext) -> None:
        """Run the wrapped skill for a single iteration in the calling thread, logging its start and end."""
        index = iteration.index
        context.logger.debug(f'message="start iteration" index="{index}"')
        try:
            self._skill.inner.run_in_current_thread(context, Option.some(iteration))
        finally:
            context.logger.debug(f'message="end iteration" index="{index}"')

    def _generate(self, context: ChainContext) -> Iterable[IterationContext]:
        """
        Wrap each value produced by the generator function into an :class:`IterationContext` with its 0-based index.

        Raises:
            RunnerGeneratorError: if the generator function raises while being called or iterated
        """
        try:
            for index, item in enumerate(self._generator(context)):
                yield IterationContext(index, item)
        except Exception as e:
            raise RunnerGeneratorError from e