class ConstantWorkersPerHostPoolScheduler(PoolScheduler):
    """Pool scheduler that leaves one PE inactive on every host.

    The basic PoolScheduler has one fewer worker on the host with PE 0 than
    on the other hosts. This variant excludes one PE per host instead, so
    the number of workers is the same on all hosts. This can be useful for
    running tasks on a GPU cluster, for example: running five PEs on nodes
    with 4 GPUs would ensure each worker gets a GPU and no GPUs are left
    idle.

    NOTE(review): the excluded PEs are those whose index is a multiple of
    the per-host PE count, which presumes PEs are numbered contiguously host
    by host -- confirm against the launcher's PE-to-host mapping.
    """

    def __init__(self):
        """Build the worker set, skipping one PE per host.

        Raises:
            RuntimeError: if the PEs do not divide evenly among the hosts,
                or if each host has only one PE (no worker would remain).
        """
        super().__init__()
        n_pes = charm.numPes()
        n_hosts = charm.numHosts()
        pes_per_host, leftover = divmod(n_pes, n_hosts)

        # Explicit raises rather than asserts: these checks guard against a
        # scheduler with a wrong or empty worker set, so they must still run
        # when Python is started with -O (which strips assert statements).
        if leftover != 0:
            # Enforce a constant number of PEs per host.
            raise RuntimeError(
                'ConstantWorkersPerHostPoolScheduler requires the same number '
                'of PEs on every host (got %d PEs on %d hosts)'
                % (n_pes, n_hosts))
        if pes_per_host <= 1:
            # One PE on each host is left unused, so at least two are needed.
            raise RuntimeError(
                'ConstantWorkersPerHostPoolScheduler requires more than one '
                'PE per host (got %d)' % pes_per_host)

        # Every PE except the first of each group of pes_per_host is a worker.
        self.idle_workers = {
            pe for pe in range(n_pes) if pe % pes_per_host != 0
        }
        self.num_workers = len(self.idle_workers)