Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make Distributed.Worker threadsafe [Take 2] #38134

Merged
merged 4 commits into from
Oct 22, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions stdlib/Distributed/src/cluster.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ mutable struct Worker
add_msgs::Array{Any,1}
gcflag::Bool
state::WorkerState
c_state::Condition # wait for state changes
c_state::Event # wait for state changes
ct_time::Float64 # creation time
conn_func::Any # used to setup connections lazily

Expand Down Expand Up @@ -133,7 +133,7 @@ mutable struct Worker
if haskey(map_pid_wrkr, id)
return map_pid_wrkr[id]
end
w=new(id, [], [], false, W_CREATED, Condition(), time(), conn_func)
w=new(id, [], [], false, W_CREATED, Event(), time(), conn_func)
w.initialized = Event()
register_worker(w)
w
Expand All @@ -144,7 +144,7 @@ end

function set_worker_state(w, state)
w.state = state
notify(w.c_state; all=true)
notify(w.c_state)
end

function check_worker_state(w::Worker)
Expand Down Expand Up @@ -189,7 +189,7 @@ function wait_for_conn(w)
timeout = worker_timeout() - (time() - w.ct_time)
timeout <= 0 && error("peer $(w.id) has not connected to $(myid())")

@async (sleep(timeout); notify(w.c_state; all=true))
@async (sleep(timeout); notify(w.c_state))
wait(w.c_state)
w.state === W_CREATED && error("peer $(w.id) didn't connect to $(myid()) within $timeout seconds")
end
Expand Down
1 change: 1 addition & 0 deletions stdlib/Distributed/test/distributed_exec.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1711,4 +1711,5 @@ include("splitrange.jl")
# Run topology tests last after removing all workers, since a given
# cluster at any time only supports a single topology.
rmprocs(workers())
include("threads.jl")
include("topology.jl")
61 changes: 61 additions & 0 deletions stdlib/Distributed/test/threads.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
using Test
using Distributed, Base.Threads
using Base.Iterators: product

exeflags = ("--startup-file=no",
"--check-bounds=yes",
"--depwarn=error",
"--threads=2")

function call_on(f, wid, tid)
remotecall(wid) do
t = Task(f)
ccall(:jl_set_task_tid, Cvoid, (Any, Cint), t, tid-1)
schedule(t)
@assert threadid(t) == tid
t
end
end

# Run function on process holding the data to only serialize the result of f.
# This becomes useful for things that cannot be serialized (e.g. running tasks)
# or that would be unnecessarily big if serialized.
fetch_from_owner(f, rr) = remotecall_fetch(ffetch, rr.where, rr)

isdone(rr) = fetch_from_owner(istaskdone, rr)
isfailed(rr) = fetch_from_owner(istaskfailed, rr)

@testset "RemoteChannel allows put!/take! from thread other than 1" begin
ws = ts = product(1:2, 1:2)
@testset "from worker $w1 to $w2 via 1" for (w1, w2) in ws
@testset "from thread $w1.$t1 to $w2.$t2" for (t1, t2) in ts
# We want (the default) lazyness, so that we wait for `Worker.c_state`!
procs_added = addprocs(2; exeflags, lazy=true)
@everywhere procs_added using Base.Threads
p1 = procs_added[w1]
p2 = procs_added[w2]
chan_id = first(procs_added)
chan = RemoteChannel(chan_id)
send = call_on(p1, t1) do
put!(chan, nothing)
end
recv = call_on(p2, t2) do
take!(chan)
end

# Wait on the spawned tasks on the owner
@sync begin
@async fetch_from_owner(wait, recv)
@async fetch_from_owner(wait, send)
end

# Check the tasks
@test isdone(send)
@test isdone(recv)

@test !isfailed(send)
@test !isfailed(recv)
rmprocs(procs_added)
end
end
end