Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add fiber safety to __crystal_once & class_[getter|property]?(&) macros #15340

Draft
wants to merge 7 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion spec/std/socket/spec_helper.cr
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@ require "spec"
require "socket"

module SocketSpecHelper
class_getter?(supports_ipv6 : Bool) do
@@supports_ipv6 : Bool?

class_getter?(supports_ipv6 : Bool) { detect_supports_ipv6? }

private def self.detect_supports_ipv6? : Bool
TCPServer.open("::1", 0) { return true }
false
rescue Socket::Error
Expand Down
4 changes: 3 additions & 1 deletion spec/support/time.cr
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@ end
# Enable the `SeTimeZonePrivilege` privilege before changing the system time
# zone. This is necessary because the privilege is by default granted but
# disabled for any new process. This only needs to be done once per run.
class_getter? time_zone_privilege_enabled : Bool do
class_getter?(time_zone_privilege_enabled : Bool) { detect_time_zone_privilege_enabled? }

private def self.detect_time_zone_privilege_enabled? : Bool
if LibC.LookupPrivilegeValueW(nil, SeTimeZonePrivilege, out time_zone_luid) == 0
raise RuntimeError.from_winerror("LookupPrivilegeValueW")
end
Expand Down
144 changes: 86 additions & 58 deletions src/crystal/once.cr
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,11 @@
# - `__crystal_once`: called each time a constant or class variable has to be
# initialized and is its responsibility to verify the initializer is executed
# only once and to fail on recursion.

# In multithread mode a mutex is used to avoid race conditions between threads.
#
# On Win32, `Crystal::System::FileDescriptor#@@reader_thread` spawns a new
# thread even without the `preview_mt` flag, and the thread can also reference
# Crystal constants, leading to race conditions, so we always enable the mutex.
# Also defines the `Crystal.once(flag, &)` method used to protect lazy
# initialization of class getters & properties.
#
# A `Mutex` is used to avoid race conditions between threads and fibers.

{% if compare_versions(Crystal::VERSION, "1.16.0-dev") >= 0 %}
# This implementation uses an enum over the initialization flag pointer for
Expand All @@ -22,28 +21,32 @@
# :nodoc:
enum OnceState : Int8
Processing = -1
Uninitialized = 0
Initialized = 1
Uninitialized = 0
Initialized = 1
end

{% if flag?(:preview_mt) || flag?(:win32) %}
@@once_mutex = uninitialized Mutex
@@once_mutex = uninitialized Mutex

# :nodoc:
def self.once_mutex=(@@once_mutex : Mutex)
end
{% end %}
# :nodoc:
def self.once_mutex=(@@once_mutex : Mutex)
end

# :nodoc:
#
# Identical to `__crystal_once` but takes a block with possibly closured
# data. Used by `class_[getter|property](declaration, &block)` for example.
def self.once(flag : OnceState*, &) : Nil
return if flag.value.initialized?
once_exec(flag) { yield }
end

# :nodoc:
#
# Using @[NoInline] so LLVM optimizes for the hot path (var already
# initialized).
@[NoInline]
def self.once(flag : OnceState*, initializer : Void*) : Nil
{% if flag?(:preview_mt) || flag?(:win32) %}
@@once_mutex.synchronize { once_exec(flag, initializer) }
{% else %}
once_exec(flag, initializer)
{% end %}
def self.once(flag : OnceState*, initializer : Void*, closure_data : Void*) : Nil
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

question: Why pass the function pointer and closure data as separate values instead of a Proc instance?

once_exec(flag) { Proc(Nil).new(initializer, closure_data).call }

# safety check, and allows to safely call `Intrinsics.unreachable` in
# `__crystal_once`
Expand All @@ -53,25 +56,27 @@
end
end

private def self.once_exec(flag : OnceState*, initializer : Void*) : Nil
case flag.value
in .initialized?
return
in .uninitialized?
flag.value = :processing
Proc(Nil).new(initializer, Pointer(Void).null).call
flag.value = :initialized
in .processing?
raise "Recursion while initializing class variables and/or constants"
private def self.once_exec(flag, &)
@@once_mutex.synchronize do
case flag.value
in .initialized?
return
in .uninitialized?
flag.value = OnceState::Processing
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue: Shouldn't we unlock the mutex at this point? The initialization code might be expensive and delay execution of other initializers. I figure it ought to be possible to execute different initializers concurrently.

By setting the flag value to processing we have reserved exclusive write access for the current fiber, so we might not even need to reacquire it for setting the state to initialized (assuming the i8 assignment is atomic).

Copy link
Contributor Author

@ysbaddaden ysbaddaden Jan 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need the mutex reentrancy for thread A to notice the recursion but also to block thread B from accessing the value until thread A has properly initialized it.

Otherwise thread B could lock the mutex, notice that the flag is :processing and fail with a recursion issue (oops) instead of waiting.

To solve this, we'd need a checked mutex (not a reentrant one) per constant and class variable, which might not be a bad idea 🤔

yield
flag.value = OnceState::Initialized
in .processing?
raise "Recursion while initializing class variables and/or constants"
end
end
end
end

# :nodoc:
fun __crystal_once_init : Nil
{% if flag?(:preview_mt) || flag?(:win32) %}
Crystal.once_mutex = Mutex.new(:reentrant)
{% end %}
Thread.init
Fiber.init
Crystal.once_mutex = Mutex.new(:reentrant)
end

# :nodoc:
Expand All @@ -83,7 +88,7 @@
fun __crystal_once(flag : Crystal::OnceState*, initializer : Void*) : Nil
return if flag.value.initialized?

Crystal.once(flag, initializer)
Crystal.once(flag, initializer, Pointer(Void).null)

# tell LLVM that it can optimize away repeated `__crystal_once` calls for
# this global (e.g. repeated access to constant in a single funtion);
Expand All @@ -94,49 +99,72 @@
# This implementation uses a global array to store the initialization flag
# pointers for each value to find infinite loops and raise an error.

# :nodoc:
class Crystal::OnceState
@rec = [] of Bool*

@[NoInline]
def once(flag : Bool*, initializer : Void*)
unless flag.value
if @rec.includes?(flag)
raise "Recursion while initializing class variables and/or constants"
end
@rec << flag
module Crystal
# :nodoc:
class OnceState
@mutex = Mutex.new(:reentrant)
@rec = [] of Bool*

Proc(Nil).new(initializer, Pointer(Void).null).call
flag.value = true
def once(flag : Bool*, &)
return if flag.value
once_exec(flag) { yield }
end

@rec.pop
@[NoInline]
def once(flag : Bool*, initializer : Void*, closure_data : Void*)
once_exec(flag) { Proc(Nil).new(initializer, closure_data).call }
end
end

{% if flag?(:preview_mt) || flag?(:win32) %}
@mutex = Mutex.new(:reentrant)
private def once_exec(flag, &)
@mutex.synchronize do
return if flag.value

@[NoInline]
def once(flag : Bool*, initializer : Void*)
unless flag.value
@mutex.synchronize do
previous_def
if @rec.includes?(flag)
raise "Recursion while initializing class variables and/or constants"
end
@rec << flag

yield
flag.value = true

@rec.pop
end
end
{% end %}
end

@@once_state = uninitialized OnceState

# :nodoc:
def self.once_state=(@@once_state : OnceState)
end

# :nodoc:
def self.once(flag : Bool*, &) : Nil
return if flag.value
@@once_state.once(flag) { yield }
end
end

# :nodoc:
fun __crystal_once_init : Void*
Crystal::OnceState.new.as(Void*)
Thread.init
Fiber.init
(Crystal.once_state = Crystal::OnceState.new).as(Void*)
end

# :nodoc:
@[AlwaysInline]
fun __crystal_once(state : Void*, flag : Bool*, initializer : Void*)
return if flag.value
state.as(Crystal::OnceState).once(flag, initializer)
state.as(Crystal::OnceState).once(flag, initializer, Pointer(Void).null)
Intrinsics.unreachable unless flag.value
end
{% end %}

{% if flag?(:interpreted) %}
# make sure to initialize the mutex so we can use Crystal.once in the
# class_[getter|property]? macros; the compiler does the call by itself, but
# the interpreter doesn't (it doesn't use __crystal_once to protect the
# initialization of constants and class vars).
__crystal_once_init
{% end %}
13 changes: 12 additions & 1 deletion src/crystal/system/thread.cr
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
module Crystal::System::Thread
# alias Handle

# def self.init : Nil

# def self.new_handle(thread_obj : ::Thread) : Handle

# def self.current_handle : Handle
Expand Down Expand Up @@ -48,7 +50,16 @@ class Thread
include Crystal::System::Thread

# all thread objects, so the GC can see them (it doesn't scan thread locals)
protected class_getter(threads) { Thread::LinkedList(Thread).new }
@@threads = uninitialized Thread::LinkedList(Thread)

protected def self.threads : Thread::LinkedList(Thread)
@@threads
end

def self.init : Nil
@@threads = Thread::LinkedList(Thread).new
Crystal::System::Thread.init
end

@system_handle : Crystal::System::Thread::Handle
@exception : Exception?
Expand Down
29 changes: 20 additions & 9 deletions src/crystal/system/unix/pthread.cr
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,16 @@ module Crystal::System::Thread
raise RuntimeError.from_os_error("pthread_create", Errno.new(ret)) unless ret == 0
end

def self.init : Nil
{% if flag?(:musl) %}
@@main_handle = current_handle
{% elsif flag?(:openbsd) || flag?(:android) %}
ret = LibC.pthread_key_create(out current_key, nil)
raise RuntimeError.from_os_error("pthread_key_create", Errno.new(ret)) unless ret == 0
@@current_key = current_key
{% end %}
end

def self.thread_proc(data : Void*) : Void*
th = data.as(::Thread)

Expand Down Expand Up @@ -53,13 +63,7 @@ module Crystal::System::Thread
# Android appears to support TLS to some degree, but executables fail with
# an underaligned TLS segment, see https://github.com/crystal-lang/crystal/issues/13951
{% if flag?(:openbsd) || flag?(:android) %}
@@current_key : LibC::PthreadKeyT

@@current_key = begin
ret = LibC.pthread_key_create(out current_key, nil)
raise RuntimeError.from_os_error("pthread_key_create", Errno.new(ret)) unless ret == 0
current_key
end
@@current_key = uninitialized LibC::PthreadKeyT

def self.current_thread : ::Thread
if ptr = LibC.pthread_getspecific(@@current_key)
Expand All @@ -84,11 +88,18 @@ module Crystal::System::Thread
end
{% else %}
@[ThreadLocal]
class_property current_thread : ::Thread { ::Thread.new }
@@current_thread : ::Thread?

def self.current_thread : ::Thread
@@current_thread ||= ::Thread.new
end

def self.current_thread? : ::Thread?
@@current_thread
end

def self.current_thread=(@@current_thread : ::Thread)
end
{% end %}

def self.sleep(time : ::Time::Span) : Nil
Expand Down Expand Up @@ -169,7 +180,7 @@ module Crystal::System::Thread
end

{% if flag?(:musl) %}
@@main_handle : Handle = current_handle
@@main_handle = uninitialized Handle

def self.current_is_main?
current_handle == @@main_handle
Expand Down
14 changes: 13 additions & 1 deletion src/crystal/system/wasi/thread.cr
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
module Crystal::System::Thread
alias Handle = Nil

def self.init : Nil
end

def self.new_handle(thread_obj : ::Thread) : Handle
raise NotImplementedError.new("Crystal::System::Thread.new_handle")
end
Expand All @@ -13,7 +16,16 @@ module Crystal::System::Thread
raise NotImplementedError.new("Crystal::System::Thread.yield_current")
end

class_property current_thread : ::Thread { ::Thread.new }
def self.current_thread : ::Thread
@@current_thread ||= ::Thread.new
end

def self.current_thread? : ::Thread?
@@current_thread
end

def self.current_thread=(@@current_thread : ::Thread)
end

def self.sleep(time : ::Time::Span) : Nil
req = uninitialized LibC::Timespec
Expand Down
Loading
Loading