Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed data race in all_type_info in free-threading mode #5419

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 1 addition & 8 deletions include/pybind11/detail/type_caster_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,6 @@ PYBIND11_NOINLINE void all_type_info_populate(PyTypeObject *t, std::vector<type_
for (handle parent : reinterpret_borrow<tuple>(t->tp_bases)) {
check.push_back((PyTypeObject *) parent.ptr());
}

auto const &type_dict = get_internals().registered_types_py;
for (size_t i = 0; i < check.size(); i++) {
auto *type = check[i];
Expand Down Expand Up @@ -176,13 +175,7 @@ PYBIND11_NOINLINE void all_type_info_populate(PyTypeObject *t, std::vector<type_
* The value is cached for the lifetime of the Python type.
*/
inline const std::vector<detail::type_info *> &all_type_info(PyTypeObject *type) {
auto ins = all_type_info_get_cache(type);
if (ins.second) {
// New cache entry: populate it
all_type_info_populate(type, ins.first->second);
}

return ins.first->second;
return all_type_info_get_cache(type).first->second;
}

/**
Expand Down
15 changes: 11 additions & 4 deletions include/pybind11/pybind11.h
Original file line number Diff line number Diff line change
Expand Up @@ -2326,13 +2326,20 @@ keep_alive_impl(size_t Nurse, size_t Patient, function_call &call, handle ret) {
inline std::pair<decltype(internals::registered_types_py)::iterator, bool>
all_type_info_get_cache(PyTypeObject *type) {
auto res = with_internals([type](internals &internals) {
return internals
.registered_types_py
auto ins = internals
.registered_types_py
#ifdef __cpp_lib_unordered_map_try_emplace
.try_emplace(type);
.try_emplace(type);
#else
.emplace(type, std::vector<detail::type_info *>());
.emplace(type, std::vector<detail::type_info *>());
#endif
if (ins.second) {
// For free-threading mode, this call must be under
// the with_internals() mutex lock, to avoid that other threads
// continue running with the empty ins.first->second.
all_type_info_populate(type, ins.first->second);
}
return ins;
});
if (res.second) {
// New cache entry created; set up a weak reference to automatically remove it if the type
Expand Down
5 changes: 5 additions & 0 deletions tests/pybind11_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,4 +128,9 @@ PYBIND11_MODULE(pybind11_tests, m, py::mod_gil_not_used()) {
for (const auto &initializer : initializers()) {
initializer(m);
}

py::class_<TestContext>(m, "TestContext")
.def(py::init<>(&TestContext::createNewContextForInit))
.def("__enter__", &TestContext::contextEnter)
.def("__exit__", &TestContext::contextExit);
}
21 changes: 21 additions & 0 deletions tests/pybind11_tests.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,24 @@ void ignoreOldStyleInitWarnings(F &&body) {
)",
py::dict(py::arg("body") = py::cpp_function(body)));
}

// See PR #5419 for background.
class TestContext {
vfdev-5 marked this conversation as resolved.
Show resolved Hide resolved
public:
TestContext() = delete;
TestContext(const TestContext &) = delete;
TestContext(TestContext &&) = delete;
static TestContext *createNewContextForInit() { return new TestContext("new-context"); }

pybind11::object contextEnter() {
py::object contextObj = py::cast(*this);
return contextObj;
}
void contextExit(const pybind11::object & /*excType*/,
const pybind11::object & /*excVal*/,
const pybind11::object & /*excTb*/) {}

private:
explicit TestContext(const std::string &context) : context(context) {}
std::string context;
};
29 changes: 29 additions & 0 deletions tests/test_class.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import sys
from unittest import mock

import pytest
Expand Down Expand Up @@ -508,3 +509,31 @@ def test_pr4220_tripped_over_this():
m.Empty0().get_msg()
== "This is really only meant to exercise successful compilation."
)


@pytest.mark.skipif(sys.platform.startswith("emscripten"), reason="Requires threads")
def test_all_type_info_multithreaded():
# See PR #5419 for background.
import threading

from pybind11_tests import TestContext

class Context(TestContext):
pass

num_runs = 10
num_threads = 4
barrier = threading.Barrier(num_threads)

def func():
barrier.wait()
with Context():
pass

for _ in range(num_runs):
threads = [threading.Thread(target=func) for _ in range(num_threads)]
for thread in threads:
thread.start()

for thread in threads:
thread.join()