Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Thread safety improvements #4

Closed
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 48 additions & 13 deletions src/cluster.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,9 @@ mutable struct Worker
add_msgs::Array{Any,1}
@atomic gcflag::Bool
state::WorkerState
c_state::Condition # wait for state changes
ct_time::Float64 # creation time
conn_func::Any # used to setup connections lazily
c_state::Threads.Condition # wait for state changes, lock for state
ct_time::Float64 # creation time
conn_func::Any # used to setup connections lazily

r_stream::IO
w_stream::IO
Expand Down Expand Up @@ -134,7 +134,7 @@ mutable struct Worker
if haskey(map_pid_wrkr, id)
return map_pid_wrkr[id]
end
w=new(id, Threads.ReentrantLock(), [], [], false, W_CREATED, Condition(), time(), conn_func)
w=new(id, Threads.ReentrantLock(), [], [], false, W_CREATED, Threads.Condition(), time(), conn_func)
w.initialized = Event()
register_worker(w)
w
Expand All @@ -144,12 +144,16 @@ mutable struct Worker
end

function set_worker_state(w, state)
w.state = state
notify(w.c_state; all=true)
lock(w.c_state) do
w.state = state
notify(w.c_state; all=true)
end
end

function check_worker_state(w::Worker)
lock(w.c_state)
if w.state === W_CREATED
unlock(w.c_state)
if !isclusterlazy()
if PGRP.topology === :all_to_all
# Since higher pids connect with lower pids, the remote worker
Expand All @@ -169,6 +173,8 @@ function check_worker_state(w::Worker)
errormonitor(t)
wait_for_conn(w)
end
else
unlock(w.c_state)
end
end

Expand All @@ -187,13 +193,25 @@ function exec_conn_func(w::Worker)
end

function wait_for_conn(w)
lock(w.c_state)
if w.state === W_CREATED
unlock(w.c_state)
timeout = worker_timeout() - (time() - w.ct_time)
timeout <= 0 && error("peer $(w.id) has not connected to $(myid())")

@async (sleep(timeout); notify(w.c_state; all=true))
wait(w.c_state)
w.state === W_CREATED && error("peer $(w.id) didn't connect to $(myid()) within $timeout seconds")
T = Threads.@spawn begin
sleep($timeout)
lock(w.c_state) do
notify(w.c_state; all=true)
end
end
errormonitor(T)
lock(w.c_state) do
wait(w.c_state)
w.state === W_CREATED && error("peer $(w.id) didn't connect to $(myid()) within $timeout seconds")
end
else
unlock(w.c_state)
end
nothing
end
Expand Down Expand Up @@ -491,7 +509,10 @@ function addprocs_locked(manager::ClusterManager; kwargs...)
while true
if isempty(launched)
istaskdone(t_launch) && break
@async (sleep(1); notify(launch_ntfy))
@async begin
sleep(1)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this why addprocs() always takes at least 1s for me? 🤔

notify(launch_ntfy)
end
wait(launch_ntfy)
end

Expand Down Expand Up @@ -645,7 +666,12 @@ function create_worker(manager, wconfig)
# require the value of config.connect_at which is set only upon connection completion
for jw in PGRP.workers
if (jw.id != 1) && (jw.id < w.id)
(jw.state === W_CREATED) && wait(jw.c_state)
# wait for wl to join
lock(jw.c_state) do
if jw.state === W_CREATED
wait(jw.c_state)
end
end
push!(join_list, jw)
end
end
Expand All @@ -668,7 +694,12 @@ function create_worker(manager, wconfig)
end

for wl in wlist
(wl.state === W_CREATED) && wait(wl.c_state)
lock(wl.c_state) do
if wl.state === W_CREATED
# wait for wl to join
wait(wl.c_state)
end
end
push!(join_list, wl)
end
end
Expand All @@ -685,7 +716,11 @@ function create_worker(manager, wconfig)
@async manage(w.manager, w.id, w.config, :register)
# wait for rr_ntfy_join with timeout
timedout = false
@async (sleep($timeout); timedout = true; put!(rr_ntfy_join, 1))
@async begin
sleep($timeout)
timedout = true
put!(rr_ntfy_join, 1)
end
wait(rr_ntfy_join)
if timedout
error("worker did not connect within $timeout seconds")
Expand Down
2 changes: 1 addition & 1 deletion src/managers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ function launch(manager::SSHManager, params::Dict, launched::Array, launch_ntfy:
# Wait for all launches to complete.
@sync for (i, (machine, cnt)) in enumerate(manager.machines)
let machine=machine, cnt=cnt
@async try
@async try
launch_on_machine(manager, $machine, $cnt, params, launched, launch_ntfy)
catch e
print(stderr, "exception launching on machine $(machine) : $(e)\n")
Expand Down
1 change: 1 addition & 0 deletions test/distributed_exec.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1925,4 +1925,5 @@ end
# Run topology tests last after removing all workers, since a given
# cluster at any time only supports a single topology.
rmprocs(workers())
include("threads.jl")
include("topology.jl")
63 changes: 63 additions & 0 deletions test/threads.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
using Test
using Distributed, Base.Threads
using Base.Iterators: product

exeflags = ("--startup-file=no",
"--check-bounds=yes",
"--depwarn=error",
"--threads=2")

function call_on(f, wid, tid)
remotecall(wid) do
t = Task(f)
ccall(:jl_set_task_tid, Cvoid, (Any, Cint), t, tid - 1)
schedule(t)
@assert threadid(t) == tid
t
end
end

# Run function on process holding the data to only serialize the result of f.
# This becomes useful for things that cannot be serialized (e.g. running tasks)
# or that would be unnecessarily big if serialized.
fetch_from_owner(f, rr) = remotecall_fetch(f ∘ fetch, rr.where, rr)

isdone(rr) = fetch_from_owner(istaskdone, rr)
isfailed(rr) = fetch_from_owner(istaskfailed, rr)

@testset "RemoteChannel allows put!/take! from thread other than 1" begin
ws = ts = product(1:2, 1:2)
@testset "from worker $w1 to $w2 via 1" for (w1, w2) in ws
@testset "from thread $w1.$t1 to $w2.$t2" for (t1, t2) in ts
# We want (the default) lazyness, so that we wait for `Worker.c_state`!
procs_added = addprocs(2; exeflags, lazy=true)
@everywhere procs_added using Base.Threads

p1 = procs_added[w1]
p2 = procs_added[w2]
chan_id = first(procs_added)
chan = RemoteChannel(chan_id)
send = call_on(p1, t1) do
put!(chan, nothing)
end
recv = call_on(p2, t2) do
take!(chan)
end

# Wait on the spawned tasks on the owner
@sync begin
Threads.@spawn fetch_from_owner(wait, recv)
Threads.@spawn fetch_from_owner(wait, send)
end

# Check the tasks
@test isdone(send)
@test isdone(recv)

@test !isfailed(send)
@test !isfailed(recv)

rmprocs(procs_added)
end
end
end