diff --git a/src/ws_pool.ml b/src/ws_pool.ml index 874cbd5c..d32c71f8 100644 --- a/src/ws_pool.ml +++ b/src/ws_pool.ml @@ -5,7 +5,16 @@ include Runner let ( let@ ) = ( @@ ) +module Id = struct + type t = unit ref + (** Unique identifier for a pool *) + + let create () : t = Sys.opaque_identity (ref ()) + let equal : t -> t -> bool = ( == ) +end + type worker_state = { + pool_id_: Id.t; (** Unique per pool *) mutable thread: Thread.t; q: task WSQ.t; (** Work stealing queue *) rng: Random.State.t; @@ -17,6 +26,7 @@ type worker_state = { type around_task = AT_pair : (t -> 'a) * (t -> 'a -> unit) -> around_task type state = { + id_: Id.t; active: bool A.t; (** Becomes [false] when the pool is shutdown. *) workers: worker_state array; (** Fixed set of workers. *) main_q: task Queue.t; (** Main queue for tasks coming from the outside *) @@ -59,7 +69,10 @@ let schedule_task_ (self : state) (w : worker_state option) (task : task) : unit = (* Printf.printf "schedule task now (%d)\n%!" (Thread.id @@ Thread.self ()); *) match w with - | Some w -> + | Some w when Id.equal self.id_ w.pool_id_ -> + (* we're on this same pool, schedule in the worker's state. Otherwise + we might also be on pool A but asking to schedule on pool B, + so we have to check that identifiers match. *) let pushed = WSQ.push w.q task in if pushed then try_wake_someone_ self @@ -70,7 +83,7 @@ let schedule_task_ (self : state) (w : worker_state option) (task : task) : unit if self.n_waiting_nonzero then Condition.signal self.cond; Mutex.unlock self.mutex ) - | None -> + | _ -> if A.get self.active then ( (* push into the main queue *) Mutex.lock self.mutex; @@ -216,6 +229,7 @@ let dummy_task_ () = assert false let create ?(on_init_thread = default_thread_init_exit_) ?(on_exit_thread = default_thread_init_exit_) ?(on_exn = fun _ _ -> ()) ?around_task ?num_threads () : t = + let pool_id_ = Id.create () in (* wrapper *) let around_task = match around_task with @@ -233,6 +247,7 @@ let create ?(on_init_thread = default_thread_init_exit_) let dummy = Thread.self () in Array.init num_threads (fun i -> { + pool_id_; thread = dummy; q = WSQ.create ~dummy:dummy_task_ (); rng = Random.State.make [| i |]; @@ -241,6 +256,7 @@ let create ?(on_init_thread = default_thread_init_exit_) let pool = { + id_ = pool_id_; active = A.make true; workers; main_q = Queue.create (); diff --git a/test/dune b/test/dune index 56261dad..43955ec6 100644 --- a/test/dune +++ b/test/dune @@ -1,6 +1,7 @@ (tests (names t_fib + t_ws_pool_confusion t_bench1 t_fib_rec t_futs1 diff --git a/test/t_ws_pool_confusion.ml b/test/t_ws_pool_confusion.ml new file mode 100644 index 00000000..20488b65 --- /dev/null +++ b/test/t_ws_pool_confusion.ml @@ -0,0 +1,28 @@ +open Moonpool + +let delay () = Thread.delay 0.001 + +let run ~p_main:_ ~p_sub () = + let f1 = + Fut.spawn ~on:p_sub (fun () -> + delay (); + 1) + in + let f2 = + Fut.spawn ~on:p_sub (fun () -> + delay (); + 2) + in + Fut.wait_block_exn f1 + Fut.wait_block_exn f2 + +let () = + let p_main = Ws_pool.create ~num_threads:2 () in + let p_sub = Ws_pool.create ~num_threads:10 () in + + let futs = List.init 8 (fun _ -> Fut.spawn ~on:p_main (run ~p_main ~p_sub)) in + + let l = List.map Fut.wait_block_exn futs in + assert (l = List.init 8 (fun _ -> 3)); + + print_endline "ok"; + ()