
Commit 05b080f

Add MPIWorkerManager (#22)
1 parent bcc76f5 commit 05b080f

8 files changed: +364 −93 lines

src/MPIClusterManagers.jl

Lines changed: 7 additions & 1 deletion
@@ -1,10 +1,16 @@
 module MPIClusterManagers

-export MPIManager, launch, manage, kill, procs, connect, mpiprocs, @mpi_do, TransportMode, MPI_ON_WORKERS, TCP_TRANSPORT_ALL, MPI_TRANSPORT_ALL
+export MPIManager, launch, manage, kill, procs, connect, mpiprocs, @mpi_do, TransportMode, MPI_ON_WORKERS, TCP_TRANSPORT_ALL, MPI_TRANSPORT_ALL, MPIWorkerManager

 using Distributed, Serialization
 import MPI

+import Base: kill
+import Sockets: connect, listenany, accept, IPv4, getsockname, getaddrinfo, wait_connected, IPAddr
+
+include("workermanager.jl")
 include("mpimanager.jl")
+include("worker.jl")
+include("mpido.jl")

 end # module

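With MPIWorkerManager now exported alongside MPIManager, the driver-side workflow looks roughly like the sketch below. The MPIWorkerManager(4) constructor lives in src/workermanager.jl, which is not part of this excerpt, so its exact signature is an assumption; addprocs and rmprocs are standard Distributed calls.

    # Minimal sketch (assumes an MPIWorkerManager(np) constructor; see workermanager.jl).
    using Distributed, MPIClusterManagers   # MPIWorkerManager is a new export

    mgr = MPIWorkerManager(4)    # request 4 MPI ranks as Distributed workers
    addprocs(mgr)                # launches the ranks via mpiexec, TCP transport

    # ... use the workers, e.g. with @mpi_do (see mpido.jl below) ...

    rmprocs(workers())
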
src/mpido.jl

Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
+################################################################################
+# MPI-specific communication methods
+
+# Execute a command on all MPI ranks
+# This uses MPI as communication method even if @everywhere uses TCP
+function mpi_do(mgr::Union{MPIManager,MPIWorkerManager}, expr)
+    !mgr.initialized && wait(mgr.cond_initialized)
+    jpids = keys(mgr.j2mpi)
+    refs = Array{Any}(undef, length(jpids))
+    for (i,p) in enumerate(Iterators.filter(x -> x != myid(), jpids))
+        refs[i] = remotecall(expr, p)
+    end
+    # Execution on local process should be last, since it can block the main
+    # event loop
+    if myid() in jpids
+        refs[end] = remotecall(expr, myid())
+    end
+
+    # Retrieve remote exceptions if any
+    @sync begin
+        for r in refs
+            @async begin
+                resp = remotecall_fetch(r.where, r) do rr
+                    wrkr_result = rr[]
+                    # Only return result if it is an exception, i.e. don't
+                    # return a valid result of a worker computation. This is
+                    # a mpi_do and not mpi_callfetch.
+                    isa(wrkr_result, Exception) ? wrkr_result : nothing
+                end
+                isa(resp, Exception) && throw(resp)
+            end
+        end
+    end
+    nothing
+end
+
+macro mpi_do(mgr, expr)
+    quote
+        # Evaluate expression in Main module
+        thunk = () -> (Core.eval(Main, $(Expr(:quote, expr))); nothing)
+        mpi_do($(esc(mgr)), thunk)
+    end
+end

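@mpi_do evaluates the quoted expression in Main on every process managed by mgr; results are discarded (it is a "do", not a fetch), but any remote exception is rethrown on the master. A usage sketch, assuming mgr already has its workers attached as in the sketch above:

    # Runs on every managed process; works with both MPIManager and MPIWorkerManager.
    @mpi_do mgr begin
        using MPI
        rank = MPI.Comm_rank(MPI.COMM_WORLD)
        println("hello from MPI rank $rank")
    end
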
src/mpimanager.jl

Lines changed: 20 additions & 80 deletions
@@ -1,7 +1,3 @@
-import Base: kill
-import Sockets: connect, listenany, accept, IPv4, getsockname, getaddrinfo, wait_connected
-
-
 ################################################################################
 # MPI Cluster Manager
 # Note: The cluster manager object lives only in the manager process,
@@ -53,6 +49,10 @@ mutable struct MPIManager <: ClusterManager
                         launch_timeout::Real = 60.0,
                         mode::TransportMode = MPI_ON_WORKERS,
                         master_tcp_interface::String="" )
+        if mode == MPI_ON_WORKERS
+            @warn "MPIManager with MPI_ON_WORKERS is deprecated and will be removed in the next release. Use MPIWorkerManager instead."
+        end
+
         mgr = new()
         mgr.np = np
         mgr.mpi2j = Dict{Int,Int}()
@@ -123,10 +123,11 @@

 Distributed.default_addprocs_params(::MPIManager) =
     merge(Distributed.default_addprocs_params(),
-          Dict{Symbol,Any}(
-              :mpiexec => nothing,
-              :mpiflags => ``,
-          ))
+        Dict{Symbol,Any}(
+            :mpiexec => nothing,
+            :mpiflags => ``,
+            :threadlevel => :serialized,
+        ))
 ################################################################################
 # Cluster Manager functionality required by Base, mostly targeting the
 # MPI_ON_WORKERS case
@@ -142,11 +143,16 @@ function Distributed.launch(mgr::MPIManager, params::Dict,
         println("Try again with a different instance of MPIManager.")
         throw(ErrorException("Reuse of MPIManager is not allowed."))
     end
-    cookie = string(":cookie_",Distributed.cluster_cookie())
-    setup_cmds = `import MPIClusterManagers\;MPIClusterManagers.setup_worker'('$(mgr.ip),$(mgr.port),$cookie')'`
+    cookie = Distributed.cluster_cookie()
+    setup_cmds = "using Distributed; import MPIClusterManagers; MPIClusterManagers.setup_worker($(repr(string(mgr.ip))),$(mgr.port),$(repr(cookie)); threadlevel=$(repr(params[:threadlevel])))"
     MPI.mpiexec() do mpiexec
         mpiexec = something(params[:mpiexec], mpiexec)
-        mpi_cmd = `$mpiexec $(params[:mpiflags]) -n $(mgr.np) $(params[:exename]) $(params[:exeflags]) -e $(Base.shell_escape(setup_cmds))`
+        mpiflags = params[:mpiflags]
+        mpiflags = `$mpiflags -n $(mgr.np)`
+        exename = params[:exename]
+        exeflags = params[:exeflags]
+        dir = params[:dir]
+        mpi_cmd = Cmd(`$mpiexec $mpiflags $exename $exeflags -e $setup_cmds`, dir=dir)
         open(detach(mpi_cmd))
     end
     mgr.launched = true
@@ -173,6 +179,7 @@ function Distributed.launch(mgr::MPIManager, params::Dict,
        # Add config to the correct slot so that MPI ranks and
        # Julia pids are in the same order
        rank = Serialization.deserialize(io)
+       _ = Serialization.deserialize(io) # not used
        idx = mgr.mode == MPI_ON_WORKERS ? rank+1 : rank
        configs[idx] = config
    end
@@ -196,31 +203,6 @@ function Distributed.launch(mgr::MPIManager, params::Dict,
     end
 end

-# Entry point for MPI worker processes for MPI_ON_WORKERS and TCP_TRANSPORT_ALL
-setup_worker(host, port; kwargs...) = setup_worker(host, port, nothing; kwargs...)
-function setup_worker(host, port, cookie; stdout_to_master=true, stderr_to_master=true)
-    !MPI.Initialized() && MPI.Init()
-    # Connect to the manager
-    io = connect(IPv4(host), port)
-    wait_connected(io)
-    stdout_to_master && redirect_stdout(io)
-    stderr_to_master && redirect_stderr(io)
-
-    # Send our MPI rank to the manager
-    rank = MPI.Comm_rank(MPI.COMM_WORLD)
-    Serialization.serialize(io, rank)
-
-    # Hand over control to Base
-    if cookie == nothing
-        Distributed.start_worker(io)
-    else
-        if isa(cookie, Symbol)
-            cookie = string(cookie)[8:end] # strip the leading "cookie_"
-        end
-        Distributed.start_worker(io, cookie)
-    end
-end
-
 # Manage a worker (e.g. register / deregister it)
 function Distributed.manage(mgr::MPIManager, id::Integer, config::WorkerConfig, op::Symbol)
     if op == :register
@@ -332,10 +314,11 @@

 # Enter the MPI cluster manager's main loop (does not return on the workers)
 function start_main_loop(mode::TransportMode=TCP_TRANSPORT_ALL;
+                         threadlevel=:serialized,
                          comm::MPI.Comm=MPI.COMM_WORLD,
                          stdout_to_master=true,
                          stderr_to_master=true)
-    !MPI.Initialized() && MPI.Init()
+    MPI.Initialized() || MPI.Init(;threadlevel=threadlevel)
     @assert MPI.Initialized() && !MPI.Finalized()
     if mode == TCP_TRANSPORT_ALL
         # Base is handling the workers and their event loop
@@ -475,49 +458,6 @@ function stop_main_loop(mgr::MPIManager)
     end
 end

-################################################################################
-# MPI-specific communication methods
-
-# Execute a command on all MPI ranks
-# This uses MPI as communication method even if @everywhere uses TCP
-function mpi_do(mgr::MPIManager, expr)
-    !mgr.initialized && wait(mgr.cond_initialized)
-    jpids = keys(mgr.j2mpi)
-    refs = Array{Any}(undef, length(jpids))
-    for (i,p) in enumerate(Iterators.filter(x -> x != myid(), jpids))
-        refs[i] = remotecall(expr, p)
-    end
-    # Execution on local process should be last, since it can block the main
-    # event loop
-    if myid() in jpids
-        refs[end] = remotecall(expr, myid())
-    end
-
-    # Retrieve remote exceptions if any
-    @sync begin
-        for r in refs
-            @async begin
-                resp = remotecall_fetch(r.where, r) do rr
-                    wrkr_result = rr[]
-                    # Only return result if it is an exception, i.e. don't
-                    # return a valid result of a worker computation. This is
-                    # a mpi_do and not mpi_callfetch.
-                    isa(wrkr_result, Exception) ? wrkr_result : nothing
-                end
-                isa(resp, Exception) && throw(resp)
-            end
-        end
-    end
-    nothing
-end
-
-macro mpi_do(mgr, expr)
-    quote
-        # Evaluate expression in Main module
-        thunk = () -> (Core.eval(Main, $(Expr(:quote, expr))); nothing)
-        mpi_do($(esc(mgr)), thunk)
-    end
-end

 # All managed Julia processes
 Distributed.procs(mgr::MPIManager) = sort(collect(keys(mgr.j2mpi)))

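The new threadlevel keyword on start_main_loop is forwarded to MPI.Init. A sketch of the MPI_TRANSPORT_ALL flavour, where the same script is launched on every rank under mpiexec and rank 0 becomes the Julia master; the script name is hypothetical:

    # script.jl, run as:  mpiexec -n 4 julia script.jl
    using Distributed, MPIClusterManagers

    # Does not return on the worker ranks; on rank 0 it returns the manager.
    mgr = MPIClusterManagers.start_main_loop(MPI_TRANSPORT_ALL;
                                             threadlevel = :funneled)

    @mpi_do mgr println("this runs on every rank, over MPI transport")

    MPIClusterManagers.stop_main_loop(mgr)
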
src/worker.jl

Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
+"""
+    setup_worker(host, port[, cookie];
+                 threadlevel=:serialized, stdout_to_master=true, stderr_to_master=true)
+
+This is the entrypoint for MPI workers using TCP transport.
+
+1. it connects to the socket on master
+2. sends the process rank and size
+3. hands over control via [`Distributed.start_worker`](https://docs.julialang.org/en/v1/stdlib/Distributed/#Distributed.start_worker)
+"""
+function setup_worker(host::Union{Integer, String}, port::Integer, cookie::Union{String, Symbol, Nothing}=nothing;
+                      threadlevel=:serialized, stdout_to_master=true, stderr_to_master=true)
+    # Connect to the manager
+    ip = host isa Integer ? IPv4(host) : parse(IPAddr, host)
+
+    io = connect(ip, port)
+    wait_connected(io)
+    stdout_to_master && redirect_stdout(io)
+    stderr_to_master && redirect_stderr(io)
+
+    MPI.Initialized() || MPI.Init(;threadlevel=threadlevel)
+    rank = MPI.Comm_rank(MPI.COMM_WORLD)
+    nprocs = MPI.Comm_size(MPI.COMM_WORLD)
+    Serialization.serialize(io, rank)
+    Serialization.serialize(io, nprocs)
+
+    # Hand over control to Base
+    if isnothing(cookie)
+        Distributed.start_worker(io)
+    else
+        if isa(cookie, Symbol)
+            cookie = string(cookie)[8:end] # strip the leading "cookie_"
+        end
+        Distributed.start_worker(io, cookie)
+    end
+end

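For reference, the command string the manager now builds for every rank (the setup_cmds value in the launch hunk above) boils down to an invocation like the following; the host, port, and cookie values are placeholders:

    # Executed on each MPI rank via `julia -e ...` under mpiexec.
    using Distributed
    import MPIClusterManagers
    MPIClusterManagers.setup_worker("192.168.1.42", 9009, "0123456789abcdef";
                                    threadlevel = :serialized)
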