From 56b5e9f15c2da5830ccf001315121466f55ad825 Mon Sep 17 00:00:00 2001 From: Cedric Augonnet Date: Thu, 31 Oct 2024 09:58:24 +0100 Subject: [PATCH 01/18] Introduce "stackable" resources for to improve how we nest contexts --- cudax/examples/stf/CMakeLists.txt | 1 + cudax/examples/stf/binary_fhe_stackable.cu | 201 ++++++++++++ .../__stf/utility/stackable_ctx.cuh | 298 ++++++++++++++++++ cudax/include/cuda/experimental/stf.cuh | 12 + 4 files changed, 512 insertions(+) create mode 100644 cudax/examples/stf/binary_fhe_stackable.cu create mode 100644 cudax/include/cuda/experimental/__stf/utility/stackable_ctx.cuh diff --git a/cudax/examples/stf/CMakeLists.txt b/cudax/examples/stf/CMakeLists.txt index 586a3fe4be4..bd2925ac2d9 100644 --- a/cudax/examples/stf/CMakeLists.txt +++ b/cudax/examples/stf/CMakeLists.txt @@ -11,6 +11,7 @@ set(stf_example_sources 08-cub-reduce.cu axpy-annotated.cu binary_fhe.cu + binary_fhe_stackable.cu cfd.cu custom_data_interface.cu void_data_interface.cu diff --git a/cudax/examples/stf/binary_fhe_stackable.cu b/cudax/examples/stf/binary_fhe_stackable.cu new file mode 100644 index 00000000000..581e7983d6d --- /dev/null +++ b/cudax/examples/stf/binary_fhe_stackable.cu @@ -0,0 +1,201 @@ +//===----------------------------------------------------------------------===// +// +// Part of CUDASTF in CUDA C++ Core Libraries, +// under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. +// +//===----------------------------------------------------------------------===// + +/** + * @file + * @brief A toy example to illustrate how we can compose logical operations + * over encrypted data + */ + +#include "cuda/experimental/__stf/utility/stackable_ctx.cuh" +#include "cuda/experimental/stf.cuh" + +using namespace cuda::experimental::stf; + +class ciphertext; + +class plaintext { +public: + plaintext(const stackable_ctx& ctx) : ctx(ctx) {} + + plaintext(stackable_ctx& ctx, std::vector v) : values(v), ctx(ctx) { + l = ctx.logical_data(&values[0], values.size()); + } + + void set_symbol(std::string s) { + l.set_symbol(s); + symbol = s; + } + + std::string get_symbol() const { return symbol; } + + std::string symbol; + + const stackable_logical_data>& data() const { return l; } + + stackable_logical_data>& data() { return l; } + + // This will asynchronously fill string s + void convert_to_vector(std::vector& v) { + ctx.host_launch(l.read()).set_symbol("to_vector")->*[&](auto dl) { + v.resize(dl.size()); + for (size_t i = 0; i < dl.size(); i++) { + v[i] = dl(i); + } + }; + } + + ciphertext encrypt() const; + + stackable_logical_data> l; + + template + void push(Pack&&... 
pack) { + l.push(::std::forward(pack)...); + } + + void pop() { l.pop(); } + +private: + std::vector values; + mutable stackable_ctx ctx; +}; + +class ciphertext { +public: + ciphertext() = default; + + ciphertext(const stackable_ctx& ctx) : ctx(ctx) {} + + plaintext decrypt() const { + plaintext p(ctx); + p.l = ctx.logical_data(shape_of>(l.shape().size())); + // fprintf(stderr, "Decrypting...\n"); + ctx.parallel_for(l.shape(), l.read(), p.l.write()).set_symbol("decrypt")->* + [] __device__ (size_t i, auto dctxt, auto dptxt) { + dptxt(i) = char((dctxt(i) >> 32)); + // printf("DECRYPT %ld : %lx -> %x\n", i, dctxt(i), (int) dptxt(i)); + }; + return p; + } + + // Copy assignment operator + ciphertext& operator=(const ciphertext& other) { + if (this != &other) { + fprintf(stderr, "COPY ASSIGNMENT OP... this->l.depth() %ld other.l.depth() %ld - ctx depth %ld other.ctx.depth %ld\n", l.depth(), other.l.depth(), ctx.depth(), other.ctx.depth()); + // l = ctx.logical_data(other.data().shape()); + assert(l.shape() == other.l.shape()); + other.ctx.parallel_for(l.shape(), other.l.read(), l.write()).set_symbol("copy")->* + [] __device__ (size_t i, auto other, auto result) { result(i) = other(i); }; + } + return *this; + } + + ciphertext operator|(const ciphertext& other) const { + ciphertext result(ctx); + result.l = ctx.logical_data(data().shape()); + + ctx.parallel_for(data().shape(), data().read(), other.data().read(), result.data().write()).set_symbol("OR")->* + [] __device__(size_t i, auto d_c1, auto d_c2, auto d_res) { d_res(i) = d_c1(i) | d_c2(i); }; + + return result; + } + + ciphertext operator&(const ciphertext& other) const { + ciphertext result(ctx); + result.l = ctx.logical_data(data().shape()); + + ctx.parallel_for(data().shape(), data().read(), other.data().read(), result.data().write()).set_symbol("AND")->* + [] __device__(size_t i, auto d_c1, auto d_c2, auto d_res) { d_res(i) = d_c1(i) & d_c2(i); }; + + return result; + } + + ciphertext operator~() const { + ciphertext result(ctx); + result.l = ctx.logical_data(data().shape()); + ctx.parallel_for(data().shape(), data().read(), result.data().write()).set_symbol("NOT")->* + [] __device__(size_t i, auto d_c, auto d_res) { d_res(i) = ~d_c(i); }; + + return result; + } + + const stackable_logical_data>& data() const { return l; } + + stackable_logical_data>& data() { return l; } + + stackable_logical_data> l; + + template + void push(Pack&&... pack) { + l.push(::std::forward(pack)...); + } + + void pop() { l.pop(); } + +private: + mutable stackable_ctx ctx; +}; + +ciphertext plaintext::encrypt() const { + ciphertext c(ctx); + c.l = ctx.logical_data(shape_of>(l.shape().size())); + + ctx.parallel_for(l.shape(), l.read(), c.l.write()).set_symbol("encrypt")->* + [] __device__(size_t i, auto dptxt, auto dctxt) { + // A super safe encryption ! 
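        // (toy scheme: the plaintext byte is stored in the upper 32 bits and a constant
        // tag is OR'ed into the low bits; decrypt() simply shifts it back down)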
+ dctxt(i) = ((uint64_t) (dptxt(i)) << 32 | 0x4); + }; + + return c; +} + +template +T circuit(const T& a, const T& b) { + return (~((a | ~b) & (~a | b))); +} + +int main() { + stackable_ctx ctx; + + std::vector vA { 3, 3, 2, 2, 17 }; + plaintext pA(ctx, vA); + pA.set_symbol("A"); + + std::vector vB { 1, 7, 7, 7, 49 }; + plaintext pB(ctx, vB); + pB.set_symbol("B"); + + auto eA = pA.encrypt(); + auto eB = pB.encrypt(); + + ctx.push_graph(); + + eA.push(access_mode::read); + eB.push(access_mode::read); + + // TODO find a way to get "out" outside of this scope to do decryption in the main ctx + auto out = circuit(eA, eB); + + std::vector v_out; + out.decrypt().convert_to_vector(v_out); + + eA.pop(); + eB.pop(); + + ctx.pop(); + + ctx.finalize(); + + for (size_t i = 0; i < v_out.size(); i++) { + char expected = circuit(vA[i], vB[i]); + EXPECT(expected == v_out[i]); + } +} diff --git a/cudax/include/cuda/experimental/__stf/utility/stackable_ctx.cuh b/cudax/include/cuda/experimental/__stf/utility/stackable_ctx.cuh new file mode 100644 index 00000000000..33dd0186943 --- /dev/null +++ b/cudax/include/cuda/experimental/__stf/utility/stackable_ctx.cuh @@ -0,0 +1,298 @@ +//===----------------------------------------------------------------------===// +// +// Part of CUDASTF in CUDA C++ Core Libraries, +// under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. +// +//===----------------------------------------------------------------------===// + +/** + * @file + * @brief Stackable context and logical data to nest contexts + */ + +#pragma once + +#include + +#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC) +# pragma GCC system_header +#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG) +# pragma clang system_header +#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC) +# pragma system_header +#endif // no system header + +#include "cuda/experimental/__stf/allocators/adapters.cuh" +#include "cuda/experimental/stf.cuh" + +namespace cuda::experimental::stf { + +template +class stackable_logical_data; + +class stackable_ctx { +public: + class impl { + public: + impl() { push(stream_ctx(), nullptr); } + + ~impl() = default; + + void push(context ctx, cudaStream_t stream) { + s.push_back(mv(ctx)); + s_stream.push_back(stream); + } + + void pop() { + s.back().finalize(); + + s.pop_back(); + + s_stream.pop_back(); + + assert(alloc_adapters.size() > 0); + alloc_adapters.back().clear(); + alloc_adapters.pop_back(); + } + + size_t depth() const { return s.size() - 1; } + + auto& get() { return s.back(); } + + const auto& get() const { return s.back(); } + + auto& at(size_t level) { + assert(level < s.size()); + return s[level]; + } + + const auto& at(size_t level) const { + assert(level < s.size()); + + return s[level]; + } + + cudaStream_t stream_at(size_t level) const { return s_stream[level]; } + + void push_graph() { + cudaStream_t stream = get().pick_stream(); + + // These resources are not destroyed when we pop, so we create it only if needed + if (async_handles.size() < s_stream.size()) { + async_handles.emplace_back(); + } + + auto gctx = graph_ctx(stream, async_handles.back()); + + auto wrapper = stream_adapter(gctx, stream); + // FIXME : issue with the deinit phase + // gctx.update_uncached_allocator(wrapper.allocator()); + + alloc_adapters.push_back(wrapper); + + push(mv(gctx), stream); + } + + private: + 
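        // One context per nesting level: s[0] is the root stream_ctx created in the
        // constructor, deeper entries are graph contexts created by push_graph().
        // s_stream keeps the CUDA stream used to launch each nested graph, and
        // alloc_adapters holds the per-level allocator wrappers built on that stream.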
::std::vector s; + ::std::vector s_stream; + ::std::vector async_handles; + ::std::vector alloc_adapters; + }; + + stackable_ctx() : pimpl(::std::make_shared()) {} + + const auto& get() const { return pimpl->get(); } + auto& get() { return pimpl->get(); } + + const auto& at(size_t level) const { return pimpl->at(level); } + auto& at(size_t level) { return pimpl->at(level); } + + cudaStream_t stream_at(size_t level) const { return pimpl->stream_at(level); } + + const auto& operator()() const { return get(); } + + auto& operator()() { return get(); } + + void push_graph() { pimpl->push_graph(); } + + void pop() { pimpl->pop(); } + + size_t depth() const { return pimpl->depth(); } + + template + auto logical_data(Pack&&... pack) { + return stackable_logical_data(*this, depth(), get().logical_data(::std::forward(pack)...)); + } + + template + auto task(Pack&&... pack) { + return get().task(::std::forward(pack)...); + } + + template + auto parallel_for(Pack&&... pack) { + return get().parallel_for(::std::forward(pack)...); + } + + template + auto host_launch(Pack&&... pack) { + return get().host_launch(::std::forward(pack)...); + } + + void finalize() { + // There must be only one level left + assert(depth() == 0); + + get().finalize(); + } + +public: + ::std::shared_ptr pimpl; +}; + +template +class stackable_logical_data { + class impl { + public: + impl() = default; + impl(stackable_ctx sctx, size_t depth, logical_data ld) : base_depth(depth), sctx(mv(sctx)) { + s.push_back(ld); + } + + const auto& get() const { + check_level_mismatch(); + return s.back(); + } + auto& get() { + check_level_mismatch(); + return s.back(); + } + + void push(access_mode m, data_place where = data_place::invalid) { + // We have not pushed yet, so the current depth is the one before pushing + context& from_ctx = sctx.at(depth()); + context& to_ctx = sctx.at(depth() + 1); + + // Ensure this will match the depth of the context after pushing + assert(sctx.depth() == depth() + 1); + + if (where == data_place::invalid) { + // use the default place + where = from_ctx.default_exec_place().affine_data_place(); + } + + assert(where != data_place::invalid); + + // Freeze the logical data at the top + logical_data& from_data = s.back(); + frozen_logical_data f = from_ctx.freeze(from_data, m, mv(where)); + + // Save the frozen data in a separate vector + frozen_s.push_back(f); + + // FAKE IMPORT : use the stream needed to support the (graph) ctx + cudaStream_t stream = sctx.stream_at(depth()); + + T inst = f.get(where, stream); + auto ld = to_ctx.logical_data(inst, where); + + if (!symbol.empty()) { + ld.set_symbol(symbol + "." + ::std::to_string(depth() + 1)); + } + + s.push_back(mv(ld)); + } + + void pop() { + // We are going to unfreeze the data, which is currently being used + // in a (graph) ctx that uses this stream to launch the graph + cudaStream_t stream = sctx.stream_at(depth()); + + frozen_logical_data& f = frozen_s.back(); + f.unfreeze(stream); + + // Remove frozen logical data + frozen_s.pop_back(); + // Remove aliased logical data + s.pop_back(); + } + + size_t depth() const { return s.size() - 1 + base_depth; } + + void set_symbol(::std::string symbol_) { + symbol = mv(symbol_); + s.back().set_symbol(symbol + "." 
+ ::std::to_string(depth())); + } + + private: + void check_level_mismatch() const { + if (depth() != sctx.depth()) { + fprintf(stderr, "Warning: mismatch between ctx level %ld and data level %ld\n", sctx.depth(), depth()); + } + } + + mutable ::std::vector> s; + + // When stacking data, we freeze data from the lower levels, these are + // their frozen counterparts. This vector has one item less than the + // vector of logical data. + mutable ::std::vector> frozen_s; + + // If the logical data was created at a level that is not directly the root of the context, we remember this + // offset + size_t base_depth = 0; + stackable_ctx sctx; // in which stackable context was this created ? + + ::std::string symbol; + }; + +public: + stackable_logical_data() = default; + + template + stackable_logical_data(stackable_ctx sctx, size_t depth, logical_data ld) + : pimpl(::std::make_shared(mv(sctx), depth, mv(ld))) {} + + const auto& get() const { return pimpl->get(); } + auto& get() { return pimpl->get(); } + + const auto& operator()() const { return get(); } + auto& operator()() { return get(); } + + size_t depth() const { return pimpl->depth(); } + + void push(access_mode m, data_place where = data_place::invalid) { pimpl->push(m, mv(where)); } + void pop() { pimpl->pop(); } + + // Helpers + template + auto read(Pack&&... pack) const { + return get().read(::std::forward(pack)...); + } + + template + auto write(Pack&&... pack) { + return get().write(::std::forward(pack)...); + } + + template + auto rw(Pack&&... pack) { + return get().rw(::std::forward(pack)...); + } + + auto shape() const { return get().shape(); } + + auto& set_symbol(::std::string symbol) { + pimpl->set_symbol(mv(symbol)); + return *this; + } + +private: + ::std::shared_ptr pimpl; +}; + +} // end namespace cuda::experimental::stf diff --git a/cudax/include/cuda/experimental/stf.cuh b/cudax/include/cuda/experimental/stf.cuh index f5ce8f2e4dc..81515ae2431 100644 --- a/cudax/include/cuda/experimental/stf.cuh +++ b/cudax/include/cuda/experimental/stf.cuh @@ -706,6 +706,18 @@ public: } } + auto pick_dstream() { + return ::std::visit([](auto& self) { return self.pick_dstream(); }, payload); + } + + /** + * @brief Get a stream from the stream pool(s) of the context + * + * This is a helper routine which can be used to launch graphs, for example. Using the stream after finalize() + * results in undefined behavior. 
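   *
   * Illustrative usage only (with `ctx` a context object): `cudaStream_t s = ctx.pick_stream();`
   * then `s` can be used as the stream on which a captured CUDA graph is launched.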
+ */ + cudaStream_t pick_stream() { return pick_dstream().stream; } + private: template auto visit(Fun&& fun) From 7b2bd58802987179c8427302f7f536ef4cad9b0c Mon Sep 17 00:00:00 2001 From: Cedric Augonnet Date: Thu, 31 Oct 2024 15:49:02 +0100 Subject: [PATCH 02/18] Add missing constructor --- cudax/examples/stf/binary_fhe_stackable.cu | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cudax/examples/stf/binary_fhe_stackable.cu b/cudax/examples/stf/binary_fhe_stackable.cu index 581e7983d6d..5128d3784a3 100644 --- a/cudax/examples/stf/binary_fhe_stackable.cu +++ b/cudax/examples/stf/binary_fhe_stackable.cu @@ -71,9 +71,11 @@ private: class ciphertext { public: ciphertext() = default; + ciphertext(const ciphertext&) = default; ciphertext(const stackable_ctx& ctx) : ctx(ctx) {} + plaintext decrypt() const { plaintext p(ctx); p.l = ctx.logical_data(shape_of>(l.shape().size())); From 95869b08de89ff22114f53edf6e38f4c238a7d08 Mon Sep 17 00:00:00 2001 From: Cedric Augonnet Date: Mon, 4 Nov 2024 09:51:00 +0100 Subject: [PATCH 03/18] Add two more tests with stacked resources --- cudax/test/stf/CMakeLists.txt | 2 + cudax/test/stf/local_stf/stackable.cu | 87 ++++++++++++++++++++++++++ cudax/test/stf/local_stf/stackable2.cu | 63 +++++++++++++++++++ 3 files changed, 152 insertions(+) create mode 100644 cudax/test/stf/local_stf/stackable.cu create mode 100644 cudax/test/stf/local_stf/stackable2.cu diff --git a/cudax/test/stf/CMakeLists.txt b/cudax/test/stf/CMakeLists.txt index 0b238da03d3..9a33f715af4 100644 --- a/cudax/test/stf/CMakeLists.txt +++ b/cudax/test/stf/CMakeLists.txt @@ -78,6 +78,8 @@ set(stf_test_sources interface/scalar_div.cu local_stf/interop_cuda.cu local_stf/legacy_to_stf.cu + local_stf/stackable.cu + local_stf/stackable2.cu loop_dispatch/dispatch_on_streams.cu loop_dispatch/nested_loop_dispatch.cu loop_dispatch/loop_dispatch.cu diff --git a/cudax/test/stf/local_stf/stackable.cu b/cudax/test/stf/local_stf/stackable.cu new file mode 100644 index 00000000000..d68bef76e98 --- /dev/null +++ b/cudax/test/stf/local_stf/stackable.cu @@ -0,0 +1,87 @@ +//===----------------------------------------------------------------------===// +// +// Part of CUDASTF in CUDA C++ Core Libraries, +// under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. 
+// +//===----------------------------------------------------------------------===// + +/** + * @file + * + * @brief Experiment with local context nesting + * + */ + +#include "cuda/experimental/__stf/utility/stackable_ctx.cuh" +#include + +using namespace cuda::experimental::stf; + +int X0(int i) { + return 17 * i + 45; +} + +int main() { + stackable_ctx sctx; + + int array[1024]; + for (size_t i = 0; i < 1024; i++) { + array[i] = 1 + i * i; + } + + auto lC = sctx.logical_data(array); + + auto lA = sctx.logical_data(lC.shape()); + lA.set_symbol("A"); + + auto lA2 = sctx.logical_data(shape_of>(1024)); + lA2.set_symbol("A2"); + + sctx.parallel_for(lA.shape(), lA.write())->*[] __device__(size_t i, auto a) { a(i) = 42 + 2 * i; }; + + /* Start to use a graph */ + sctx.push_graph(); + + auto lB = sctx.logical_data(shape_of>(512)); + lB.set_symbol("B"); + + sctx.parallel_for(lB.shape(), lB.write())->*[] __device__(size_t i, auto b) { b(i) = 17 - 3 * i; }; + + lC.push(access_mode::rw); + lA.push(access_mode::read); + lA2.push(access_mode::write, data_place::current_device()); + + sctx.parallel_for(lA2.shape(), lA2.write())->*[] __device__(size_t i, auto a2) { a2(i) = 5 * i + 4; }; + + sctx.parallel_for(lB.shape(), lA.read(), lB.rw())->*[] __device__(size_t i, auto a, auto b) { b(i) += a(i); }; + + sctx.parallel_for(lB.shape(), lB.read(), lC.rw())->*[] __device__(size_t i, auto b, auto c) { c(i) += b(i); }; + + lA.pop(); + lA2.pop(); + lC.pop(); + + sctx.pop(); + + sctx.host_launch(lA2.read())->*[](auto a2) { + for (size_t i = 0; i < a2.size(); i++) { + EXPECT(a2(i) == 5 * i + 4); + } + }; + + // Do the same check in another graph + sctx.push_graph(); + lA2.push(access_mode::read); + sctx.host_launch(lA2.read())->*[](auto a2) { + for (size_t i = 0; i < a2.size(); i++) { + EXPECT(a2(i) == 5 * i + 4); + } + }; + lA2.pop(); + sctx.pop(); + + sctx.finalize(); +} diff --git a/cudax/test/stf/local_stf/stackable2.cu b/cudax/test/stf/local_stf/stackable2.cu new file mode 100644 index 00000000000..3213ffa487f --- /dev/null +++ b/cudax/test/stf/local_stf/stackable2.cu @@ -0,0 +1,63 @@ +//===----------------------------------------------------------------------===// +// +// Part of CUDASTF in CUDA C++ Core Libraries, +// under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. 
+// +//===----------------------------------------------------------------------===// + +/** + * @file + * + * @brief Experiment with local context nesting + * + */ + +#include "cuda/experimental/__stf/utility/stackable_ctx.cuh" +#include + +using namespace cuda::experimental::stf; + +int X0(int i) { + return 17 * i + 45; +} + +int main() { + stackable_ctx ctx; + + int array[1024]; + for (size_t i = 0; i < 1024; i++) { + array[i] = 1 + i * i; + } + + auto lA = ctx.logical_data(array).set_symbol("A"); + + // repeat : {tmp = a, a++; tmp*=2; a+=tmp} + for (size_t iter = 0; iter < 10; iter++) { + ctx.push_graph(); + + lA.push(access_mode::rw); + + auto tmp = ctx.logical_data(lA.shape()).set_symbol("tmp"); + + ctx.parallel_for(tmp.shape(), tmp.write(), lA.read())->*[] __device__(size_t i, auto tmp, auto a) { + tmp(i) = a(i); + }; + + ctx.parallel_for(lA.shape(), lA.rw())->*[] __device__(size_t i, auto a) { a(i) += 1; }; + + ctx.parallel_for(tmp.shape(), tmp.rw())->*[] __device__(size_t i, auto tmp) { tmp(i) *= 2; }; + + ctx.parallel_for(lA.shape(), tmp.read(), lA.rw())->*[] __device__(size_t i, auto tmp, auto a) { + a(i) += tmp(i); + }; + + lA.pop(); + + ctx.pop(); + } + + ctx.finalize(); +} From 771bcb0b4bb21f5365c5f0fed93c550b0dcf0e46 Mon Sep 17 00:00:00 2001 From: Cedric Augonnet Date: Mon, 4 Nov 2024 14:03:25 +0100 Subject: [PATCH 04/18] fix copyright notice --- cudax/examples/stf/binary_fhe_stackable.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cudax/examples/stf/binary_fhe_stackable.cu b/cudax/examples/stf/binary_fhe_stackable.cu index 5128d3784a3..e0bef6a79c3 100644 --- a/cudax/examples/stf/binary_fhe_stackable.cu +++ b/cudax/examples/stf/binary_fhe_stackable.cu @@ -4,7 +4,7 @@ // under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. +// SPDX-FileCopyrightText: Copyright (c) 2024-2024 NVIDIA CORPORATION & AFFILIATES. 
// //===----------------------------------------------------------------------===// From 438bfa3fbd28da1b3a8dbc63fa9bc43547c4959f Mon Sep 17 00:00:00 2001 From: Cedric Augonnet Date: Mon, 4 Nov 2024 14:05:03 +0100 Subject: [PATCH 05/18] fix formatting issue in the doxygen comment --- cudax/examples/stf/binary_fhe_stackable.cu | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/cudax/examples/stf/binary_fhe_stackable.cu b/cudax/examples/stf/binary_fhe_stackable.cu index e0bef6a79c3..6e2de8b5fd0 100644 --- a/cudax/examples/stf/binary_fhe_stackable.cu +++ b/cudax/examples/stf/binary_fhe_stackable.cu @@ -10,8 +10,7 @@ /** * @file - * @brief A toy example to illustrate how we can compose logical operations - * over encrypted data + * @brief A toy example to illustrate how we can compose logical operations over encrypted data */ #include "cuda/experimental/__stf/utility/stackable_ctx.cuh" From 22ba9d29add1c5e3bf8d1c9920c3e226dc8a52b6 Mon Sep 17 00:00:00 2001 From: Cedric Augonnet Date: Mon, 4 Nov 2024 14:05:24 +0100 Subject: [PATCH 06/18] remove dead code and debug printfs --- cudax/examples/stf/binary_fhe_stackable.cu | 4 ---- 1 file changed, 4 deletions(-) diff --git a/cudax/examples/stf/binary_fhe_stackable.cu b/cudax/examples/stf/binary_fhe_stackable.cu index 6e2de8b5fd0..d88e733ffec 100644 --- a/cudax/examples/stf/binary_fhe_stackable.cu +++ b/cudax/examples/stf/binary_fhe_stackable.cu @@ -78,11 +78,9 @@ public: plaintext decrypt() const { plaintext p(ctx); p.l = ctx.logical_data(shape_of>(l.shape().size())); - // fprintf(stderr, "Decrypting...\n"); ctx.parallel_for(l.shape(), l.read(), p.l.write()).set_symbol("decrypt")->* [] __device__ (size_t i, auto dctxt, auto dptxt) { dptxt(i) = char((dctxt(i) >> 32)); - // printf("DECRYPT %ld : %lx -> %x\n", i, dctxt(i), (int) dptxt(i)); }; return p; } @@ -90,8 +88,6 @@ public: // Copy assignment operator ciphertext& operator=(const ciphertext& other) { if (this != &other) { - fprintf(stderr, "COPY ASSIGNMENT OP... 
this->l.depth() %ld other.l.depth() %ld - ctx depth %ld other.ctx.depth %ld\n", l.depth(), other.l.depth(), ctx.depth(), other.ctx.depth()); - // l = ctx.logical_data(other.data().shape()); assert(l.shape() == other.l.shape()); other.ctx.parallel_for(l.shape(), other.l.read(), l.write()).set_symbol("copy")->* [] __device__ (size_t i, auto other, auto result) { result(i) = other(i); }; From 0f9425bc38992c0151bf3a97801e67c5bf58f3c2 Mon Sep 17 00:00:00 2001 From: Cedric Augonnet Date: Mon, 4 Nov 2024 14:13:42 +0100 Subject: [PATCH 07/18] replace assert by _CCCL_ASSERT --- .../experimental/__stf/utility/stackable_ctx.cuh | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/cudax/include/cuda/experimental/__stf/utility/stackable_ctx.cuh b/cudax/include/cuda/experimental/__stf/utility/stackable_ctx.cuh index 33dd0186943..0422a8992a0 100644 --- a/cudax/include/cuda/experimental/__stf/utility/stackable_ctx.cuh +++ b/cudax/include/cuda/experimental/__stf/utility/stackable_ctx.cuh @@ -53,7 +53,7 @@ public: s_stream.pop_back(); - assert(alloc_adapters.size() > 0); + _CCCL_ASSERT(alloc_adapters.size() > 0, "Calling pop from an empty container"); alloc_adapters.back().clear(); alloc_adapters.pop_back(); } @@ -65,13 +65,12 @@ public: const auto& get() const { return s.back(); } auto& at(size_t level) { - assert(level < s.size()); + _CCCL_ASSERT(level < s.size(), "Out of bound access"); return s[level]; } const auto& at(size_t level) const { - assert(level < s.size()); - + _CCCL_ASSERT(level < s.size(), "Out of bound access"); return s[level]; } @@ -145,7 +144,7 @@ public: void finalize() { // There must be only one level left - assert(depth() == 0); + _CCCL_ASSERT(depth() == 0, "All nested levels must have been popped"); get().finalize(); } @@ -178,14 +177,14 @@ class stackable_logical_data { context& to_ctx = sctx.at(depth() + 1); // Ensure this will match the depth of the context after pushing - assert(sctx.depth() == depth() + 1); + _CCCL_ASSERT(sctx.depth() == depth() + 1, "Invalid depth"); if (where == data_place::invalid) { // use the default place where = from_ctx.default_exec_place().affine_data_place(); } - assert(where != data_place::invalid); + _CCCL_ASSERT(where != data_place::invalid, "Invalid data place"); // Freeze the logical data at the top logical_data& from_data = s.back(); From c68c5fb0e408277e0ab1ae8543d4c53596ba39aa Mon Sep 17 00:00:00 2001 From: Cedric Augonnet Date: Mon, 4 Nov 2024 14:19:24 +0100 Subject: [PATCH 08/18] Replace at by a [] operator --- .../__stf/utility/stackable_ctx.cuh | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/cudax/include/cuda/experimental/__stf/utility/stackable_ctx.cuh b/cudax/include/cuda/experimental/__stf/utility/stackable_ctx.cuh index 0422a8992a0..2bc8890df6e 100644 --- a/cudax/include/cuda/experimental/__stf/utility/stackable_ctx.cuh +++ b/cudax/include/cuda/experimental/__stf/utility/stackable_ctx.cuh @@ -64,12 +64,12 @@ public: const auto& get() const { return s.back(); } - auto& at(size_t level) { + auto& operator[](size_t level) { _CCCL_ASSERT(level < s.size(), "Out of bound access"); return s[level]; } - const auto& at(size_t level) const { + const auto& operator[](size_t level) const { _CCCL_ASSERT(level < s.size(), "Out of bound access"); return s[level]; } @@ -107,8 +107,13 @@ public: const auto& get() const { return pimpl->get(); } auto& get() { return pimpl->get(); } - const auto& at(size_t level) const { return pimpl->at(level); } - auto& at(size_t level) { return 
pimpl->at(level); } + auto& operator[](size_t level) { + return pimpl->operator[](level); + } + + const auto& operator[](size_t level) const { + return pimpl->operator[](level); + } cudaStream_t stream_at(size_t level) const { return pimpl->stream_at(level); } @@ -173,8 +178,8 @@ class stackable_logical_data { void push(access_mode m, data_place where = data_place::invalid) { // We have not pushed yet, so the current depth is the one before pushing - context& from_ctx = sctx.at(depth()); - context& to_ctx = sctx.at(depth() + 1); + context& from_ctx = sctx[depth()]; + context& to_ctx = sctx[depth() + 1]; // Ensure this will match the depth of the context after pushing _CCCL_ASSERT(sctx.depth() == depth() + 1, "Invalid depth"); From dd8b008291be0ca50ea535d033ce925f81155c81 Mon Sep 17 00:00:00 2001 From: Cedric Augonnet Date: Mon, 4 Nov 2024 14:20:08 +0100 Subject: [PATCH 09/18] clang-format --- cudax/examples/stf/binary_fhe_stackable.cu | 344 ++++++----- .../__stf/utility/stackable_ctx.cuh | 537 +++++++++++------- cudax/include/cuda/experimental/stf.cuh | 28 +- cudax/test/stf/local_stf/stackable.cu | 110 ++-- cudax/test/stf/local_stf/stackable2.cu | 63 +- 5 files changed, 635 insertions(+), 447 deletions(-) diff --git a/cudax/examples/stf/binary_fhe_stackable.cu b/cudax/examples/stf/binary_fhe_stackable.cu index d88e733ffec..f090b3d14d6 100644 --- a/cudax/examples/stf/binary_fhe_stackable.cu +++ b/cudax/examples/stf/binary_fhe_stackable.cu @@ -20,179 +20,231 @@ using namespace cuda::experimental::stf; class ciphertext; -class plaintext { +class plaintext +{ public: - plaintext(const stackable_ctx& ctx) : ctx(ctx) {} - - plaintext(stackable_ctx& ctx, std::vector v) : values(v), ctx(ctx) { - l = ctx.logical_data(&values[0], values.size()); - } - - void set_symbol(std::string s) { - l.set_symbol(s); - symbol = s; - } - - std::string get_symbol() const { return symbol; } - - std::string symbol; - - const stackable_logical_data>& data() const { return l; } - - stackable_logical_data>& data() { return l; } - - // This will asynchronously fill string s - void convert_to_vector(std::vector& v) { - ctx.host_launch(l.read()).set_symbol("to_vector")->*[&](auto dl) { - v.resize(dl.size()); - for (size_t i = 0; i < dl.size(); i++) { - v[i] = dl(i); - } - }; - } - - ciphertext encrypt() const; - - stackable_logical_data> l; - - template - void push(Pack&&... pack) { - l.push(::std::forward(pack)...); - } - - void pop() { l.pop(); } + plaintext(const stackable_ctx& ctx) + : ctx(ctx) + {} + + plaintext(stackable_ctx& ctx, std::vector v) + : values(v) + , ctx(ctx) + { + l = ctx.logical_data(&values[0], values.size()); + } + + void set_symbol(std::string s) + { + l.set_symbol(s); + symbol = s; + } + + std::string get_symbol() const + { + return symbol; + } + + std::string symbol; + + const stackable_logical_data>& data() const + { + return l; + } + + stackable_logical_data>& data() + { + return l; + } + + // This will asynchronously fill string s + void convert_to_vector(std::vector& v) + { + ctx.host_launch(l.read()).set_symbol("to_vector")->*[&](auto dl) { + v.resize(dl.size()); + for (size_t i = 0; i < dl.size(); i++) + { + v[i] = dl(i); + } + }; + } + + ciphertext encrypt() const; + + stackable_logical_data> l; + + template + void push(Pack&&... 
pack) + { + l.push(::std::forward(pack)...); + } + + void pop() + { + l.pop(); + } private: - std::vector values; - mutable stackable_ctx ctx; + std::vector values; + mutable stackable_ctx ctx; }; -class ciphertext { +class ciphertext +{ public: - ciphertext() = default; - ciphertext(const ciphertext&) = default; - - ciphertext(const stackable_ctx& ctx) : ctx(ctx) {} - - - plaintext decrypt() const { - plaintext p(ctx); - p.l = ctx.logical_data(shape_of>(l.shape().size())); - ctx.parallel_for(l.shape(), l.read(), p.l.write()).set_symbol("decrypt")->* - [] __device__ (size_t i, auto dctxt, auto dptxt) { - dptxt(i) = char((dctxt(i) >> 32)); - }; - return p; - } - - // Copy assignment operator - ciphertext& operator=(const ciphertext& other) { - if (this != &other) { - assert(l.shape() == other.l.shape()); - other.ctx.parallel_for(l.shape(), other.l.read(), l.write()).set_symbol("copy")->* - [] __device__ (size_t i, auto other, auto result) { result(i) = other(i); }; - } - return *this; - } - - ciphertext operator|(const ciphertext& other) const { - ciphertext result(ctx); - result.l = ctx.logical_data(data().shape()); - - ctx.parallel_for(data().shape(), data().read(), other.data().read(), result.data().write()).set_symbol("OR")->* - [] __device__(size_t i, auto d_c1, auto d_c2, auto d_res) { d_res(i) = d_c1(i) | d_c2(i); }; - - return result; - } - - ciphertext operator&(const ciphertext& other) const { - ciphertext result(ctx); - result.l = ctx.logical_data(data().shape()); - - ctx.parallel_for(data().shape(), data().read(), other.data().read(), result.data().write()).set_symbol("AND")->* - [] __device__(size_t i, auto d_c1, auto d_c2, auto d_res) { d_res(i) = d_c1(i) & d_c2(i); }; - - return result; - } - - ciphertext operator~() const { - ciphertext result(ctx); - result.l = ctx.logical_data(data().shape()); - ctx.parallel_for(data().shape(), data().read(), result.data().write()).set_symbol("NOT")->* - [] __device__(size_t i, auto d_c, auto d_res) { d_res(i) = ~d_c(i); }; - - return result; - } - - const stackable_logical_data>& data() const { return l; } - - stackable_logical_data>& data() { return l; } - - stackable_logical_data> l; - - template - void push(Pack&&... 
pack) { - l.push(::std::forward(pack)...); + ciphertext() = default; + ciphertext(const ciphertext&) = default; + + ciphertext(const stackable_ctx& ctx) + : ctx(ctx) + {} + + plaintext decrypt() const + { + plaintext p(ctx); + p.l = ctx.logical_data(shape_of>(l.shape().size())); + ctx.parallel_for(l.shape(), l.read(), p.l.write()).set_symbol("decrypt")->* + [] __device__(size_t i, auto dctxt, auto dptxt) { + dptxt(i) = char((dctxt(i) >> 32)); + }; + return p; + } + + // Copy assignment operator + ciphertext& operator=(const ciphertext& other) + { + if (this != &other) + { + assert(l.shape() == other.l.shape()); + other.ctx.parallel_for(l.shape(), other.l.read(), l.write()).set_symbol("copy")->* + [] __device__(size_t i, auto other, auto result) { + result(i) = other(i); + }; } - - void pop() { l.pop(); } + return *this; + } + + ciphertext operator|(const ciphertext& other) const + { + ciphertext result(ctx); + result.l = ctx.logical_data(data().shape()); + + ctx.parallel_for(data().shape(), data().read(), other.data().read(), result.data().write()).set_symbol("OR")->* + [] __device__(size_t i, auto d_c1, auto d_c2, auto d_res) { + d_res(i) = d_c1(i) | d_c2(i); + }; + + return result; + } + + ciphertext operator&(const ciphertext& other) const + { + ciphertext result(ctx); + result.l = ctx.logical_data(data().shape()); + + ctx.parallel_for(data().shape(), data().read(), other.data().read(), result.data().write()).set_symbol("AND")->* + [] __device__(size_t i, auto d_c1, auto d_c2, auto d_res) { + d_res(i) = d_c1(i) & d_c2(i); + }; + + return result; + } + + ciphertext operator~() const + { + ciphertext result(ctx); + result.l = ctx.logical_data(data().shape()); + ctx.parallel_for(data().shape(), data().read(), result.data().write()).set_symbol("NOT")->* + [] __device__(size_t i, auto d_c, auto d_res) { + d_res(i) = ~d_c(i); + }; + + return result; + } + + const stackable_logical_data>& data() const + { + return l; + } + + stackable_logical_data>& data() + { + return l; + } + + stackable_logical_data> l; + + template + void push(Pack&&... pack) + { + l.push(::std::forward(pack)...); + } + + void pop() + { + l.pop(); + } private: - mutable stackable_ctx ctx; + mutable stackable_ctx ctx; }; -ciphertext plaintext::encrypt() const { - ciphertext c(ctx); - c.l = ctx.logical_data(shape_of>(l.shape().size())); +ciphertext plaintext::encrypt() const +{ + ciphertext c(ctx); + c.l = ctx.logical_data(shape_of>(l.shape().size())); - ctx.parallel_for(l.shape(), l.read(), c.l.write()).set_symbol("encrypt")->* - [] __device__(size_t i, auto dptxt, auto dctxt) { - // A super safe encryption ! - dctxt(i) = ((uint64_t) (dptxt(i)) << 32 | 0x4); - }; + ctx.parallel_for(l.shape(), l.read(), c.l.write()).set_symbol("encrypt")->* + [] __device__(size_t i, auto dptxt, auto dctxt) { + // A super safe encryption ! 
+ dctxt(i) = ((uint64_t) (dptxt(i)) << 32 | 0x4); + }; - return c; + return c; } template -T circuit(const T& a, const T& b) { - return (~((a | ~b) & (~a | b))); +T circuit(const T& a, const T& b) +{ + return (~((a | ~b) & (~a | b))); } -int main() { - stackable_ctx ctx; +int main() +{ + stackable_ctx ctx; - std::vector vA { 3, 3, 2, 2, 17 }; - plaintext pA(ctx, vA); - pA.set_symbol("A"); + std::vector vA{3, 3, 2, 2, 17}; + plaintext pA(ctx, vA); + pA.set_symbol("A"); - std::vector vB { 1, 7, 7, 7, 49 }; - plaintext pB(ctx, vB); - pB.set_symbol("B"); + std::vector vB{1, 7, 7, 7, 49}; + plaintext pB(ctx, vB); + pB.set_symbol("B"); - auto eA = pA.encrypt(); - auto eB = pB.encrypt(); + auto eA = pA.encrypt(); + auto eB = pB.encrypt(); - ctx.push_graph(); + ctx.push_graph(); - eA.push(access_mode::read); - eB.push(access_mode::read); + eA.push(access_mode::read); + eB.push(access_mode::read); - // TODO find a way to get "out" outside of this scope to do decryption in the main ctx - auto out = circuit(eA, eB); + // TODO find a way to get "out" outside of this scope to do decryption in the main ctx + auto out = circuit(eA, eB); - std::vector v_out; - out.decrypt().convert_to_vector(v_out); + std::vector v_out; + out.decrypt().convert_to_vector(v_out); - eA.pop(); - eB.pop(); + eA.pop(); + eB.pop(); - ctx.pop(); + ctx.pop(); - ctx.finalize(); + ctx.finalize(); - for (size_t i = 0; i < v_out.size(); i++) { - char expected = circuit(vA[i], vB[i]); - EXPECT(expected == v_out[i]); - } + for (size_t i = 0; i < v_out.size(); i++) + { + char expected = circuit(vA[i], vB[i]); + EXPECT(expected == v_out[i]); + } } diff --git a/cudax/include/cuda/experimental/__stf/utility/stackable_ctx.cuh b/cudax/include/cuda/experimental/__stf/utility/stackable_ctx.cuh index 2bc8890df6e..765542033af 100644 --- a/cudax/include/cuda/experimental/__stf/utility/stackable_ctx.cuh +++ b/cudax/include/cuda/experimental/__stf/utility/stackable_ctx.cuh @@ -28,275 +28,378 @@ #include "cuda/experimental/__stf/allocators/adapters.cuh" #include "cuda/experimental/stf.cuh" -namespace cuda::experimental::stf { +namespace cuda::experimental::stf +{ template class stackable_logical_data; -class stackable_ctx { +class stackable_ctx +{ public: - class impl { - public: - impl() { push(stream_ctx(), nullptr); } - - ~impl() = default; + class impl + { + public: + impl() + { + push(stream_ctx(), nullptr); + } - void push(context ctx, cudaStream_t stream) { - s.push_back(mv(ctx)); - s_stream.push_back(stream); - } + ~impl() = default; - void pop() { - s.back().finalize(); + void push(context ctx, cudaStream_t stream) + { + s.push_back(mv(ctx)); + s_stream.push_back(stream); + } - s.pop_back(); + void pop() + { + s.back().finalize(); - s_stream.pop_back(); + s.pop_back(); - _CCCL_ASSERT(alloc_adapters.size() > 0, "Calling pop from an empty container"); - alloc_adapters.back().clear(); - alloc_adapters.pop_back(); - } + s_stream.pop_back(); - size_t depth() const { return s.size() - 1; } + _CCCL_ASSERT(alloc_adapters.size() > 0, "Calling pop from an empty container"); + alloc_adapters.back().clear(); + alloc_adapters.pop_back(); + } - auto& get() { return s.back(); } + size_t depth() const + { + return s.size() - 1; + } - const auto& get() const { return s.back(); } + auto& get() + { + return s.back(); + } - auto& operator[](size_t level) { - _CCCL_ASSERT(level < s.size(), "Out of bound access"); - return s[level]; - } + const auto& get() const + { + return s.back(); + } - const auto& operator[](size_t level) const { - _CCCL_ASSERT(level < 
s.size(), "Out of bound access"); - return s[level]; - } + auto& operator[](size_t level) + { + _CCCL_ASSERT(level < s.size(), "Out of bound access"); + return s[level]; + } - cudaStream_t stream_at(size_t level) const { return s_stream[level]; } + const auto& operator[](size_t level) const + { + _CCCL_ASSERT(level < s.size(), "Out of bound access"); + return s[level]; + } - void push_graph() { - cudaStream_t stream = get().pick_stream(); + cudaStream_t stream_at(size_t level) const + { + return s_stream[level]; + } - // These resources are not destroyed when we pop, so we create it only if needed - if (async_handles.size() < s_stream.size()) { - async_handles.emplace_back(); - } + void push_graph() + { + cudaStream_t stream = get().pick_stream(); - auto gctx = graph_ctx(stream, async_handles.back()); + // These resources are not destroyed when we pop, so we create it only if needed + if (async_handles.size() < s_stream.size()) + { + async_handles.emplace_back(); + } - auto wrapper = stream_adapter(gctx, stream); - // FIXME : issue with the deinit phase - // gctx.update_uncached_allocator(wrapper.allocator()); + auto gctx = graph_ctx(stream, async_handles.back()); - alloc_adapters.push_back(wrapper); + auto wrapper = stream_adapter(gctx, stream); + // FIXME : issue with the deinit phase + // gctx.update_uncached_allocator(wrapper.allocator()); - push(mv(gctx), stream); - } + alloc_adapters.push_back(wrapper); - private: - ::std::vector s; - ::std::vector s_stream; - ::std::vector async_handles; - ::std::vector alloc_adapters; - }; + push(mv(gctx), stream); + } - stackable_ctx() : pimpl(::std::make_shared()) {} + private: + ::std::vector s; + ::std::vector s_stream; + ::std::vector async_handles; + ::std::vector alloc_adapters; + }; + + stackable_ctx() + : pimpl(::std::make_shared()) + {} + + const auto& get() const + { + return pimpl->get(); + } + auto& get() + { + return pimpl->get(); + } + + auto& operator[](size_t level) + { + return pimpl->operator[](level); + } + + const auto& operator[](size_t level) const + { + return pimpl->operator[](level); + } + + cudaStream_t stream_at(size_t level) const + { + return pimpl->stream_at(level); + } + + const auto& operator()() const + { + return get(); + } + + auto& operator()() + { + return get(); + } + + void push_graph() + { + pimpl->push_graph(); + } + + void pop() + { + pimpl->pop(); + } + + size_t depth() const + { + return pimpl->depth(); + } + + template + auto logical_data(Pack&&... pack) + { + return stackable_logical_data(*this, depth(), get().logical_data(::std::forward(pack)...)); + } + + template + auto task(Pack&&... pack) + { + return get().task(::std::forward(pack)...); + } + + template + auto parallel_for(Pack&&... pack) + { + return get().parallel_for(::std::forward(pack)...); + } + + template + auto host_launch(Pack&&... 
pack) + { + return get().host_launch(::std::forward(pack)...); + } + + void finalize() + { + // There must be only one level left + _CCCL_ASSERT(depth() == 0, "All nested levels must have been popped"); + + get().finalize(); + } - const auto& get() const { return pimpl->get(); } - auto& get() { return pimpl->get(); } +public: + ::std::shared_ptr pimpl; +}; - auto& operator[](size_t level) { - return pimpl->operator[](level); +template +class stackable_logical_data +{ + class impl + { + public: + impl() = default; + impl(stackable_ctx sctx, size_t depth, logical_data ld) + : base_depth(depth) + , sctx(mv(sctx)) + { + s.push_back(ld); } - const auto& operator[](size_t level) const { - return pimpl->operator[](level); + const auto& get() const + { + check_level_mismatch(); + return s.back(); } - - cudaStream_t stream_at(size_t level) const { return pimpl->stream_at(level); } - - const auto& operator()() const { return get(); } - - auto& operator()() { return get(); } - - void push_graph() { pimpl->push_graph(); } - - void pop() { pimpl->pop(); } - - size_t depth() const { return pimpl->depth(); } - - template - auto logical_data(Pack&&... pack) { - return stackable_logical_data(*this, depth(), get().logical_data(::std::forward(pack)...)); + auto& get() + { + check_level_mismatch(); + return s.back(); } - template - auto task(Pack&&... pack) { - return get().task(::std::forward(pack)...); - } + void push(access_mode m, data_place where = data_place::invalid) + { + // We have not pushed yet, so the current depth is the one before pushing + context& from_ctx = sctx[depth()]; + context& to_ctx = sctx[depth() + 1]; - template - auto parallel_for(Pack&&... pack) { - return get().parallel_for(::std::forward(pack)...); - } + // Ensure this will match the depth of the context after pushing + _CCCL_ASSERT(sctx.depth() == depth() + 1, "Invalid depth"); - template - auto host_launch(Pack&&... 
pack) { - return get().host_launch(::std::forward(pack)...); - } + if (where == data_place::invalid) + { + // use the default place + where = from_ctx.default_exec_place().affine_data_place(); + } - void finalize() { - // There must be only one level left - _CCCL_ASSERT(depth() == 0, "All nested levels must have been popped"); + _CCCL_ASSERT(where != data_place::invalid, "Invalid data place"); - get().finalize(); - } + // Freeze the logical data at the top + logical_data& from_data = s.back(); + frozen_logical_data f = from_ctx.freeze(from_data, m, mv(where)); -public: - ::std::shared_ptr pimpl; -}; + // Save the frozen data in a separate vector + frozen_s.push_back(f); -template -class stackable_logical_data { - class impl { - public: - impl() = default; - impl(stackable_ctx sctx, size_t depth, logical_data ld) : base_depth(depth), sctx(mv(sctx)) { - s.push_back(ld); - } - - const auto& get() const { - check_level_mismatch(); - return s.back(); - } - auto& get() { - check_level_mismatch(); - return s.back(); - } - - void push(access_mode m, data_place where = data_place::invalid) { - // We have not pushed yet, so the current depth is the one before pushing - context& from_ctx = sctx[depth()]; - context& to_ctx = sctx[depth() + 1]; - - // Ensure this will match the depth of the context after pushing - _CCCL_ASSERT(sctx.depth() == depth() + 1, "Invalid depth"); - - if (where == data_place::invalid) { - // use the default place - where = from_ctx.default_exec_place().affine_data_place(); - } - - _CCCL_ASSERT(where != data_place::invalid, "Invalid data place"); - - // Freeze the logical data at the top - logical_data& from_data = s.back(); - frozen_logical_data f = from_ctx.freeze(from_data, m, mv(where)); - - // Save the frozen data in a separate vector - frozen_s.push_back(f); - - // FAKE IMPORT : use the stream needed to support the (graph) ctx - cudaStream_t stream = sctx.stream_at(depth()); - - T inst = f.get(where, stream); - auto ld = to_ctx.logical_data(inst, where); - - if (!symbol.empty()) { - ld.set_symbol(symbol + "." + ::std::to_string(depth() + 1)); - } - - s.push_back(mv(ld)); - } - - void pop() { - // We are going to unfreeze the data, which is currently being used - // in a (graph) ctx that uses this stream to launch the graph - cudaStream_t stream = sctx.stream_at(depth()); - - frozen_logical_data& f = frozen_s.back(); - f.unfreeze(stream); - - // Remove frozen logical data - frozen_s.pop_back(); - // Remove aliased logical data - s.pop_back(); - } - - size_t depth() const { return s.size() - 1 + base_depth; } - - void set_symbol(::std::string symbol_) { - symbol = mv(symbol_); - s.back().set_symbol(symbol + "." + ::std::to_string(depth())); - } - - private: - void check_level_mismatch() const { - if (depth() != sctx.depth()) { - fprintf(stderr, "Warning: mismatch between ctx level %ld and data level %ld\n", sctx.depth(), depth()); - } - } - - mutable ::std::vector> s; - - // When stacking data, we freeze data from the lower levels, these are - // their frozen counterparts. This vector has one item less than the - // vector of logical data. - mutable ::std::vector> frozen_s; - - // If the logical data was created at a level that is not directly the root of the context, we remember this - // offset - size_t base_depth = 0; - stackable_ctx sctx; // in which stackable context was this created ? 
- - ::std::string symbol; - }; + // FAKE IMPORT : use the stream needed to support the (graph) ctx + cudaStream_t stream = sctx.stream_at(depth()); -public: - stackable_logical_data() = default; + T inst = f.get(where, stream); + auto ld = to_ctx.logical_data(inst, where); - template - stackable_logical_data(stackable_ctx sctx, size_t depth, logical_data ld) - : pimpl(::std::make_shared(mv(sctx), depth, mv(ld))) {} + if (!symbol.empty()) + { + ld.set_symbol(symbol + "." + ::std::to_string(depth() + 1)); + } - const auto& get() const { return pimpl->get(); } - auto& get() { return pimpl->get(); } + s.push_back(mv(ld)); + } - const auto& operator()() const { return get(); } - auto& operator()() { return get(); } + void pop() + { + // We are going to unfreeze the data, which is currently being used + // in a (graph) ctx that uses this stream to launch the graph + cudaStream_t stream = sctx.stream_at(depth()); - size_t depth() const { return pimpl->depth(); } + frozen_logical_data& f = frozen_s.back(); + f.unfreeze(stream); - void push(access_mode m, data_place where = data_place::invalid) { pimpl->push(m, mv(where)); } - void pop() { pimpl->pop(); } + // Remove frozen logical data + frozen_s.pop_back(); + // Remove aliased logical data + s.pop_back(); + } - // Helpers - template - auto read(Pack&&... pack) const { - return get().read(::std::forward(pack)...); + size_t depth() const + { + return s.size() - 1 + base_depth; } - template - auto write(Pack&&... pack) { - return get().write(::std::forward(pack)...); + void set_symbol(::std::string symbol_) + { + symbol = mv(symbol_); + s.back().set_symbol(symbol + "." + ::std::to_string(depth())); } - template - auto rw(Pack&&... pack) { - return get().rw(::std::forward(pack)...); + private: + void check_level_mismatch() const + { + if (depth() != sctx.depth()) + { + fprintf(stderr, "Warning: mismatch between ctx level %ld and data level %ld\n", sctx.depth(), depth()); + } } - auto shape() const { return get().shape(); } + mutable ::std::vector> s; - auto& set_symbol(::std::string symbol) { - pimpl->set_symbol(mv(symbol)); - return *this; - } + // When stacking data, we freeze data from the lower levels, these are + // their frozen counterparts. This vector has one item less than the + // vector of logical data. + mutable ::std::vector> frozen_s; + + // If the logical data was created at a level that is not directly the root of the context, we remember this + // offset + size_t base_depth = 0; + stackable_ctx sctx; // in which stackable context was this created ? + + ::std::string symbol; + }; + +public: + stackable_logical_data() = default; + + template + stackable_logical_data(stackable_ctx sctx, size_t depth, logical_data ld) + : pimpl(::std::make_shared(mv(sctx), depth, mv(ld))) + {} + + const auto& get() const + { + return pimpl->get(); + } + auto& get() + { + return pimpl->get(); + } + + const auto& operator()() const + { + return get(); + } + auto& operator()() + { + return get(); + } + + size_t depth() const + { + return pimpl->depth(); + } + + void push(access_mode m, data_place where = data_place::invalid) + { + pimpl->push(m, mv(where)); + } + void pop() + { + pimpl->pop(); + } + + // Helpers + template + auto read(Pack&&... pack) const + { + return get().read(::std::forward(pack)...); + } + + template + auto write(Pack&&... pack) + { + return get().write(::std::forward(pack)...); + } + + template + auto rw(Pack&&... 
pack) + { + return get().rw(::std::forward(pack)...); + } + + auto shape() const + { + return get().shape(); + } + + auto& set_symbol(::std::string symbol) + { + pimpl->set_symbol(mv(symbol)); + return *this; + } private: - ::std::shared_ptr pimpl; + ::std::shared_ptr pimpl; }; -} // end namespace cuda::experimental::stf +} // end namespace cuda::experimental::stf diff --git a/cudax/include/cuda/experimental/stf.cuh b/cudax/include/cuda/experimental/stf.cuh index 81515ae2431..0ae01206c47 100644 --- a/cudax/include/cuda/experimental/stf.cuh +++ b/cudax/include/cuda/experimental/stf.cuh @@ -706,17 +706,25 @@ public: } } - auto pick_dstream() { - return ::std::visit([](auto& self) { return self.pick_dstream(); }, payload); - } + auto pick_dstream() + { + return ::std::visit( + [](auto& self) { + return self.pick_dstream(); + }, + payload); + } - /** - * @brief Get a stream from the stream pool(s) of the context - * - * This is a helper routine which can be used to launch graphs, for example. Using the stream after finalize() - * results in undefined behavior. - */ - cudaStream_t pick_stream() { return pick_dstream().stream; } + /** + * @brief Get a stream from the stream pool(s) of the context + * + * This is a helper routine which can be used to launch graphs, for example. Using the stream after finalize() + * results in undefined behavior. + */ + cudaStream_t pick_stream() + { + return pick_dstream().stream; + } private: template diff --git a/cudax/test/stf/local_stf/stackable.cu b/cudax/test/stf/local_stf/stackable.cu index d68bef76e98..d1ec6a91953 100644 --- a/cudax/test/stf/local_stf/stackable.cu +++ b/cudax/test/stf/local_stf/stackable.cu @@ -15,73 +15,89 @@ * */ -#include "cuda/experimental/__stf/utility/stackable_ctx.cuh" #include +#include "cuda/experimental/__stf/utility/stackable_ctx.cuh" + using namespace cuda::experimental::stf; -int X0(int i) { - return 17 * i + 45; +int X0(int i) +{ + return 17 * i + 45; } -int main() { - stackable_ctx sctx; +int main() +{ + stackable_ctx sctx; - int array[1024]; - for (size_t i = 0; i < 1024; i++) { - array[i] = 1 + i * i; - } + int array[1024]; + for (size_t i = 0; i < 1024; i++) + { + array[i] = 1 + i * i; + } - auto lC = sctx.logical_data(array); + auto lC = sctx.logical_data(array); - auto lA = sctx.logical_data(lC.shape()); - lA.set_symbol("A"); + auto lA = sctx.logical_data(lC.shape()); + lA.set_symbol("A"); - auto lA2 = sctx.logical_data(shape_of>(1024)); - lA2.set_symbol("A2"); + auto lA2 = sctx.logical_data(shape_of>(1024)); + lA2.set_symbol("A2"); - sctx.parallel_for(lA.shape(), lA.write())->*[] __device__(size_t i, auto a) { a(i) = 42 + 2 * i; }; + sctx.parallel_for(lA.shape(), lA.write())->*[] __device__(size_t i, auto a) { + a(i) = 42 + 2 * i; + }; - /* Start to use a graph */ - sctx.push_graph(); + /* Start to use a graph */ + sctx.push_graph(); - auto lB = sctx.logical_data(shape_of>(512)); - lB.set_symbol("B"); + auto lB = sctx.logical_data(shape_of>(512)); + lB.set_symbol("B"); - sctx.parallel_for(lB.shape(), lB.write())->*[] __device__(size_t i, auto b) { b(i) = 17 - 3 * i; }; + sctx.parallel_for(lB.shape(), lB.write())->*[] __device__(size_t i, auto b) { + b(i) = 17 - 3 * i; + }; - lC.push(access_mode::rw); - lA.push(access_mode::read); - lA2.push(access_mode::write, data_place::current_device()); + lC.push(access_mode::rw); + lA.push(access_mode::read); + lA2.push(access_mode::write, data_place::current_device()); - sctx.parallel_for(lA2.shape(), lA2.write())->*[] __device__(size_t i, auto a2) { a2(i) = 5 * i + 4; }; + 
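  // Each push() above freezes the logical data in the parent context and imports it into
  // the nested graph context; the matching pop() below unfreezes it (asynchronously, on the
  // stream used to launch the graph).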
sctx.parallel_for(lA2.shape(), lA2.write())->*[] __device__(size_t i, auto a2) { + a2(i) = 5 * i + 4; + }; - sctx.parallel_for(lB.shape(), lA.read(), lB.rw())->*[] __device__(size_t i, auto a, auto b) { b(i) += a(i); }; + sctx.parallel_for(lB.shape(), lA.read(), lB.rw())->*[] __device__(size_t i, auto a, auto b) { + b(i) += a(i); + }; - sctx.parallel_for(lB.shape(), lB.read(), lC.rw())->*[] __device__(size_t i, auto b, auto c) { c(i) += b(i); }; + sctx.parallel_for(lB.shape(), lB.read(), lC.rw())->*[] __device__(size_t i, auto b, auto c) { + c(i) += b(i); + }; - lA.pop(); - lA2.pop(); - lC.pop(); + lA.pop(); + lA2.pop(); + lC.pop(); - sctx.pop(); + sctx.pop(); - sctx.host_launch(lA2.read())->*[](auto a2) { - for (size_t i = 0; i < a2.size(); i++) { - EXPECT(a2(i) == 5 * i + 4); - } - }; - - // Do the same check in another graph - sctx.push_graph(); - lA2.push(access_mode::read); - sctx.host_launch(lA2.read())->*[](auto a2) { - for (size_t i = 0; i < a2.size(); i++) { - EXPECT(a2(i) == 5 * i + 4); - } - }; - lA2.pop(); - sctx.pop(); + sctx.host_launch(lA2.read())->*[](auto a2) { + for (size_t i = 0; i < a2.size(); i++) + { + EXPECT(a2(i) == 5 * i + 4); + } + }; + + // Do the same check in another graph + sctx.push_graph(); + lA2.push(access_mode::read); + sctx.host_launch(lA2.read())->*[](auto a2) { + for (size_t i = 0; i < a2.size(); i++) + { + EXPECT(a2(i) == 5 * i + 4); + } + }; + lA2.pop(); + sctx.pop(); - sctx.finalize(); + sctx.finalize(); } diff --git a/cudax/test/stf/local_stf/stackable2.cu b/cudax/test/stf/local_stf/stackable2.cu index 3213ffa487f..b8ba1e7fc36 100644 --- a/cudax/test/stf/local_stf/stackable2.cu +++ b/cudax/test/stf/local_stf/stackable2.cu @@ -15,49 +15,58 @@ * */ -#include "cuda/experimental/__stf/utility/stackable_ctx.cuh" #include +#include "cuda/experimental/__stf/utility/stackable_ctx.cuh" + using namespace cuda::experimental::stf; -int X0(int i) { - return 17 * i + 45; +int X0(int i) +{ + return 17 * i + 45; } -int main() { - stackable_ctx ctx; +int main() +{ + stackable_ctx ctx; - int array[1024]; - for (size_t i = 0; i < 1024; i++) { - array[i] = 1 + i * i; - } + int array[1024]; + for (size_t i = 0; i < 1024; i++) + { + array[i] = 1 + i * i; + } - auto lA = ctx.logical_data(array).set_symbol("A"); + auto lA = ctx.logical_data(array).set_symbol("A"); - // repeat : {tmp = a, a++; tmp*=2; a+=tmp} - for (size_t iter = 0; iter < 10; iter++) { - ctx.push_graph(); + // repeat : {tmp = a, a++; tmp*=2; a+=tmp} + for (size_t iter = 0; iter < 10; iter++) + { + ctx.push_graph(); - lA.push(access_mode::rw); + lA.push(access_mode::rw); - auto tmp = ctx.logical_data(lA.shape()).set_symbol("tmp"); + auto tmp = ctx.logical_data(lA.shape()).set_symbol("tmp"); - ctx.parallel_for(tmp.shape(), tmp.write(), lA.read())->*[] __device__(size_t i, auto tmp, auto a) { - tmp(i) = a(i); - }; + ctx.parallel_for(tmp.shape(), tmp.write(), lA.read())->*[] __device__(size_t i, auto tmp, auto a) { + tmp(i) = a(i); + }; - ctx.parallel_for(lA.shape(), lA.rw())->*[] __device__(size_t i, auto a) { a(i) += 1; }; + ctx.parallel_for(lA.shape(), lA.rw())->*[] __device__(size_t i, auto a) { + a(i) += 1; + }; - ctx.parallel_for(tmp.shape(), tmp.rw())->*[] __device__(size_t i, auto tmp) { tmp(i) *= 2; }; + ctx.parallel_for(tmp.shape(), tmp.rw())->*[] __device__(size_t i, auto tmp) { + tmp(i) *= 2; + }; - ctx.parallel_for(lA.shape(), tmp.read(), lA.rw())->*[] __device__(size_t i, auto tmp, auto a) { - a(i) += tmp(i); - }; + ctx.parallel_for(lA.shape(), tmp.read(), lA.rw())->*[] __device__(size_t i, auto 
tmp, auto a) { + a(i) += tmp(i); + }; - lA.pop(); + lA.pop(); - ctx.pop(); - } + ctx.pop(); + } - ctx.finalize(); + ctx.finalize(); } From f78594f568d800c43b104b4b4ad506f910c8df88 Mon Sep 17 00:00:00 2001 From: Cedric Augonnet Date: Mon, 4 Nov 2024 14:28:51 +0100 Subject: [PATCH 10/18] avoid copies --- cudax/examples/stf/binary_fhe_stackable.cu | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cudax/examples/stf/binary_fhe_stackable.cu b/cudax/examples/stf/binary_fhe_stackable.cu index f090b3d14d6..27916dcd17b 100644 --- a/cudax/examples/stf/binary_fhe_stackable.cu +++ b/cudax/examples/stf/binary_fhe_stackable.cu @@ -27,11 +27,11 @@ public: : ctx(ctx) {} - plaintext(stackable_ctx& ctx, std::vector v) - : values(v) + plaintext(stackable_ctx& ctx, ::std::vector v) + : values(mv(v)) , ctx(ctx) { - l = ctx.logical_data(&values[0], values.size()); + l = ctx.logical_data(values.data(), values.size()); } void set_symbol(std::string s) From fb3de987f4d39d79ea1847fee25ea1080bc9b35f Mon Sep 17 00:00:00 2001 From: Cedric Augonnet Date: Mon, 4 Nov 2024 22:03:36 +0100 Subject: [PATCH 11/18] Cleanup the stackable resource implementation --- cudax/examples/stf/binary_fhe_stackable.cu | 2 +- .../__stf/utility/stackable_ctx.cuh | 206 +++++++++--------- cudax/test/stf/local_stf/stackable.cu | 4 +- cudax/test/stf/local_stf/stackable2.cu | 2 +- 4 files changed, 111 insertions(+), 103 deletions(-) diff --git a/cudax/examples/stf/binary_fhe_stackable.cu b/cudax/examples/stf/binary_fhe_stackable.cu index 27916dcd17b..36f38c13b0f 100644 --- a/cudax/examples/stf/binary_fhe_stackable.cu +++ b/cudax/examples/stf/binary_fhe_stackable.cu @@ -224,7 +224,7 @@ int main() auto eA = pA.encrypt(); auto eB = pB.encrypt(); - ctx.push_graph(); + ctx.push(); eA.push(access_mode::read); eB.push(access_mode::read); diff --git a/cudax/include/cuda/experimental/__stf/utility/stackable_ctx.cuh b/cudax/include/cuda/experimental/__stf/utility/stackable_ctx.cuh index 765542033af..957e33fb31d 100644 --- a/cudax/include/cuda/experimental/__stf/utility/stackable_ctx.cuh +++ b/cudax/include/cuda/experimental/__stf/utility/stackable_ctx.cuh @@ -34,139 +34,155 @@ namespace cuda::experimental::stf template class stackable_logical_data; +/** + * @brief This class defines a context that behaves as a context which can have nested subcontexts (implemented as local CUDA graphs) + */ class stackable_ctx { public: class impl { + private: + /* + * State of each nested context + */ + struct per_level { + per_level(context ctx, cudaStream_t support_stream, ::std::optional alloc_adapters) : ctx(mv(ctx)), support_stream(mv(support_stream)), alloc_adapters(mv(alloc_adapters)) {} + + context ctx; + cudaStream_t support_stream; + // A wrapper to forward allocations from a level to the previous one (none is used at the root level) + ::std::optional alloc_adapters; + }; + public: impl() { - push(stream_ctx(), nullptr); + push(); } ~impl() = default; - void push(context ctx, cudaStream_t stream) + /** + * @brief Create a new nested level + */ + void push() { - s.push_back(mv(ctx)); - s_stream.push_back(stream); - } + // These resources are not destroyed when we pop, so we create it only if needed + if (async_handles.size() < levels.size()) + { + async_handles.emplace_back(); + } - void pop() - { - s.back().finalize(); + if (levels.size() == 0) { + levels.emplace_back(stream_ctx(), nullptr, ::std::nullopt); + } + else { + // Get a stream from previous context (we haven't pushed the new one yet) + cudaStream_t stream = 
levels[depth()].ctx.pick_stream(); - s.pop_back(); + auto gctx = graph_ctx(stream, async_handles.back()); - s_stream.pop_back(); + auto wrapper = stream_adapter(gctx, stream); + // FIXME : issue with the deinit phase + // gctx.update_uncached_allocator(wrapper.allocator()); - _CCCL_ASSERT(alloc_adapters.size() > 0, "Calling pop from an empty container"); - alloc_adapters.back().clear(); - alloc_adapters.pop_back(); + levels.emplace_back(gctx, stream, wrapper); + } } - size_t depth() const + /** + * @brief Terminate the current nested level and get back to the previous one + */ + void pop() { - return s.size() - 1; - } + _CCCL_ASSERT(levels.size() > 0, "Calling pop while no context was pushed"); - auto& get() - { - return s.back(); - } + auto ¤t_level = levels.back(); - const auto& get() const - { - return s.back(); + // Ensure everything is finished in the context + current_level.ctx.finalize(); + + // Destroy the resources used in the wrapper allocator (if any) + if (current_level.alloc_adapters.has_value()) + { + current_level.alloc_adapters.value().clear(); + } + + // Destroy the current level state + levels.pop_back(); } - auto& operator[](size_t level) + /** + * @brief Get the nesting depth + */ + size_t depth() const { - _CCCL_ASSERT(level < s.size(), "Out of bound access"); - return s[level]; + return levels.size() - 1; } - const auto& operator[](size_t level) const + /** + * @brief Returns a reference to the context at a specific level + */ + auto& get_ctx(size_t level) { - _CCCL_ASSERT(level < s.size(), "Out of bound access"); - return s[level]; + return levels[level].ctx; } - cudaStream_t stream_at(size_t level) const + /** + * @brief Returns a const reference to the context at a specific level + */ + const auto& get_ctx(size_t level) const { - return s_stream[level]; + return levels[level].ctx; } - void push_graph() + cudaStream_t get_stream(size_t level) const { - cudaStream_t stream = get().pick_stream(); - - // These resources are not destroyed when we pop, so we create it only if needed - if (async_handles.size() < s_stream.size()) - { - async_handles.emplace_back(); - } - - auto gctx = graph_ctx(stream, async_handles.back()); - - auto wrapper = stream_adapter(gctx, stream); - // FIXME : issue with the deinit phase - // gctx.update_uncached_allocator(wrapper.allocator()); - - alloc_adapters.push_back(wrapper); - - push(mv(gctx), stream); + return levels[level].support_stream; } private: - ::std::vector s; - ::std::vector s_stream; + // State for each nested level + ::std::vector levels; + + // Handles to retain some asynchronous states, we maintain it separately + // from levels because we keep its entries even when we pop a level ::std::vector async_handles; - ::std::vector alloc_adapters; }; stackable_ctx() : pimpl(::std::make_shared()) {} - const auto& get() const + cudaStream_t get_stream(size_t level) const { - return pimpl->get(); - } - auto& get() - { - return pimpl->get(); - } - - auto& operator[](size_t level) - { - return pimpl->operator[](level); + return pimpl->get_stream(level); } - const auto& operator[](size_t level) const + const auto& get_ctx(size_t level) const { - return pimpl->operator[](level); + return pimpl->get_ctx(level); } - cudaStream_t stream_at(size_t level) const + auto& get_ctx(size_t level) { - return pimpl->stream_at(level); + return pimpl->get_ctx(level); } const auto& operator()() const { - return get(); + return get_ctx(depth()); } auto& operator()() { - return get(); + return get_ctx(depth()); } - void push_graph() + void push() { - 
pimpl->push_graph(); + pimpl->push(); } void pop() @@ -182,25 +198,25 @@ public: template auto logical_data(Pack&&... pack) { - return stackable_logical_data(*this, depth(), get().logical_data(::std::forward(pack)...)); + return stackable_logical_data(*this, depth(), get_ctx(depth()).logical_data(::std::forward(pack)...)); } template auto task(Pack&&... pack) { - return get().task(::std::forward(pack)...); + return get_ctx(depth()).task(::std::forward(pack)...); } template auto parallel_for(Pack&&... pack) { - return get().parallel_for(::std::forward(pack)...); + return get_ctx(depth()).parallel_for(::std::forward(pack)...); } template auto host_launch(Pack&&... pack) { - return get().host_launch(::std::forward(pack)...); + return get_ctx(depth()).host_launch(::std::forward(pack)...); } void finalize() @@ -208,7 +224,7 @@ public: // There must be only one level left _CCCL_ASSERT(depth() == 0, "All nested levels must have been popped"); - get().finalize(); + get_ctx(depth()).finalize(); } public: @@ -229,12 +245,12 @@ class stackable_logical_data s.push_back(ld); } - const auto& get() const + const auto& get_ld() const { check_level_mismatch(); return s.back(); } - auto& get() + auto& get_ld() { check_level_mismatch(); return s.back(); @@ -243,8 +259,8 @@ class stackable_logical_data void push(access_mode m, data_place where = data_place::invalid) { // We have not pushed yet, so the current depth is the one before pushing - context& from_ctx = sctx[depth()]; - context& to_ctx = sctx[depth() + 1]; + context& from_ctx = sctx.get_ctx(depth()); + context& to_ctx = sctx.get_ctx(depth() + 1); // Ensure this will match the depth of the context after pushing _CCCL_ASSERT(sctx.depth() == depth() + 1, "Invalid depth"); @@ -265,7 +281,7 @@ class stackable_logical_data frozen_s.push_back(f); // FAKE IMPORT : use the stream needed to support the (graph) ctx - cudaStream_t stream = sctx.stream_at(depth()); + cudaStream_t stream = sctx.get_stream(depth()); T inst = f.get(where, stream); auto ld = to_ctx.logical_data(inst, where); @@ -282,7 +298,7 @@ class stackable_logical_data { // We are going to unfreeze the data, which is currently being used // in a (graph) ctx that uses this stream to launch the graph - cudaStream_t stream = sctx.stream_at(depth()); + cudaStream_t stream = sctx.get_stream(depth()); frozen_logical_data& f = frozen_s.back(); f.unfreeze(stream); @@ -336,22 +352,13 @@ public: : pimpl(::std::make_shared(mv(sctx), depth, mv(ld))) {} - const auto& get() const - { - return pimpl->get(); - } - auto& get() - { - return pimpl->get(); - } - - const auto& operator()() const + const auto& get_ld() const { - return get(); + return pimpl->get_ld(); } - auto& operator()() + auto& get_ld() { - return get(); + return pimpl->get_ld(); } size_t depth() const @@ -363,6 +370,7 @@ public: { pimpl->push(m, mv(where)); } + void pop() { pimpl->pop(); @@ -372,24 +380,24 @@ public: template auto read(Pack&&... pack) const { - return get().read(::std::forward(pack)...); + return get_ld().read(::std::forward(pack)...); } template auto write(Pack&&... pack) { - return get().write(::std::forward(pack)...); + return get_ld().write(::std::forward(pack)...); } template auto rw(Pack&&... 
pack) { - return get().rw(::std::forward(pack)...); + return get_ld().rw(::std::forward(pack)...); } auto shape() const { - return get().shape(); + return get_ld().shape(); } auto& set_symbol(::std::string symbol) diff --git a/cudax/test/stf/local_stf/stackable.cu b/cudax/test/stf/local_stf/stackable.cu index d1ec6a91953..92989c5e758 100644 --- a/cudax/test/stf/local_stf/stackable.cu +++ b/cudax/test/stf/local_stf/stackable.cu @@ -49,7 +49,7 @@ int main() }; /* Start to use a graph */ - sctx.push_graph(); + sctx.push(); auto lB = sctx.logical_data(shape_of>(512)); lB.set_symbol("B"); @@ -88,7 +88,7 @@ int main() }; // Do the same check in another graph - sctx.push_graph(); + sctx.push(); lA2.push(access_mode::read); sctx.host_launch(lA2.read())->*[](auto a2) { for (size_t i = 0; i < a2.size(); i++) diff --git a/cudax/test/stf/local_stf/stackable2.cu b/cudax/test/stf/local_stf/stackable2.cu index b8ba1e7fc36..62b9ddd55a6 100644 --- a/cudax/test/stf/local_stf/stackable2.cu +++ b/cudax/test/stf/local_stf/stackable2.cu @@ -41,7 +41,7 @@ int main() // repeat : {tmp = a, a++; tmp*=2; a+=tmp} for (size_t iter = 0; iter < 10; iter++) { - ctx.push_graph(); + ctx.push(); lA.push(access_mode::rw); From cd63feba9ec7bc188e3b7203fb87132cb3a8ac60 Mon Sep 17 00:00:00 2001 From: Cedric Augonnet Date: Tue, 14 Jan 2025 16:56:04 +0100 Subject: [PATCH 12/18] Automatically call pop() on (stackable) logical data when the pop() method of the context is called (if logical data's pop() was not called explicitly). --- .../__stf/internal/logical_data.cuh | 2 +- .../__stf/utility/stackable_ctx.cuh | 116 ++++++++++++++---- 2 files changed, 94 insertions(+), 24 deletions(-) diff --git a/cudax/include/cuda/experimental/__stf/internal/logical_data.cuh b/cudax/include/cuda/experimental/__stf/internal/logical_data.cuh index 1615ea3eb2f..6d511302b8a 100644 --- a/cudax/include/cuda/experimental/__stf/internal/logical_data.cuh +++ b/cudax/include/cuda/experimental/__stf/internal/logical_data.cuh @@ -1695,12 +1695,12 @@ public: return pimpl->get_mutex(); } -private: int get_unique_id() const { return pimpl->get_unique_id(); } +private: ::std::shared_ptr pimpl; }; diff --git a/cudax/include/cuda/experimental/__stf/utility/stackable_ctx.cuh b/cudax/include/cuda/experimental/__stf/utility/stackable_ctx.cuh index 957e33fb31d..6b6617b885c 100644 --- a/cudax/include/cuda/experimental/__stf/utility/stackable_ctx.cuh +++ b/cudax/include/cuda/experimental/__stf/utility/stackable_ctx.cuh @@ -35,7 +35,21 @@ template class stackable_logical_data; /** - * @brief This class defines a context that behaves as a context which can have nested subcontexts (implemented as local CUDA graphs) + * @brief Base class with a virtual pop method to enable type erasure + * + * This is used to implement the automatic call to pop() on logical data when a + * context level is popped. 
+ */ +class stackable_logical_data_impl_base +{ +public: + virtual ~stackable_logical_data_impl_base() = default; + virtual void pop() = 0; +}; + +/** + * @brief This class defines a context that behaves as a context which can have nested subcontexts (implemented as local + * CUDA graphs) */ class stackable_ctx { @@ -46,18 +60,28 @@ public: /* * State of each nested context */ - struct per_level { - per_level(context ctx, cudaStream_t support_stream, ::std::optional alloc_adapters) : ctx(mv(ctx)), support_stream(mv(support_stream)), alloc_adapters(mv(alloc_adapters)) {} - - context ctx; - cudaStream_t support_stream; - // A wrapper to forward allocations from a level to the previous one (none is used at the root level) - ::std::optional alloc_adapters; + struct per_level + { + per_level(context ctx, cudaStream_t support_stream, ::std::optional alloc_adapters) + : ctx(mv(ctx)) + , support_stream(mv(support_stream)) + , alloc_adapters(mv(alloc_adapters)) + {} + + context ctx; + cudaStream_t support_stream; + // A wrapper to forward allocations from a level to the previous one (none is used at the root level) + ::std::optional alloc_adapters; + + // This map keeps track of the logical data that were pushed in this level + // key: logical data's unique id + ::std::unordered_map> pushed_data; }; public: impl() { + // Create the root level push(); } @@ -74,20 +98,22 @@ public: async_handles.emplace_back(); } - if (levels.size() == 0) { - levels.emplace_back(stream_ctx(), nullptr, ::std::nullopt); + if (levels.size() == 0) + { + levels.emplace_back(stream_ctx(), nullptr, ::std::nullopt); } - else { - // Get a stream from previous context (we haven't pushed the new one yet) - cudaStream_t stream = levels[depth()].ctx.pick_stream(); + else + { + // Get a stream from previous context (we haven't pushed the new one yet) + cudaStream_t stream = levels[depth()].ctx.pick_stream(); - auto gctx = graph_ctx(stream, async_handles.back()); + auto gctx = graph_ctx(stream, async_handles.back()); - auto wrapper = stream_adapter(gctx, stream); - // FIXME : issue with the deinit phase - // gctx.update_uncached_allocator(wrapper.allocator()); + auto wrapper = stream_adapter(gctx, stream); + // FIXME : issue with the deinit phase + // gctx.update_uncached_allocator(wrapper.allocator()); - levels.emplace_back(gctx, stream, wrapper); + levels.emplace_back(gctx, stream, wrapper); } } @@ -98,7 +124,13 @@ public: { _CCCL_ASSERT(levels.size() > 0, "Calling pop while no context was pushed"); - auto ¤t_level = levels.back(); + auto& current_level = levels.back(); + + // Automatically pop data if needed + for (auto& [key, d_impl] : current_level.pushed_data) + { + d_impl->pop(); + } // Ensure everything is finished in the context current_level.ctx.finalize(); @@ -106,7 +138,7 @@ public: // Destroy the resources used in the wrapper allocator (if any) if (current_level.alloc_adapters.has_value()) { - current_level.alloc_adapters.value().clear(); + current_level.alloc_adapters.value().clear(); } // Destroy the current level state @@ -142,6 +174,19 @@ public: return levels[level].support_stream; } + void track_pushed_data(int data_id, ::std::shared_ptr data_impl) + { + levels[depth()].pushed_data[data_id] = mv(data_impl); + } + + void untrack_pushed_data(int data_id) + { + size_t erased = levels[depth()].pushed_data.erase(data_id); + // We must have erased exactly one value (at least one otherwise it was already removed, and it must be pushed + // only once (TODO check)) + _CCCL_ASSERT(erased == 1, "invalid value"); + } + 
private: // State for each nested level ::std::vector levels; @@ -219,6 +264,16 @@ public: return get_ctx(depth()).host_launch(::std::forward(pack)...); } + void track_pushed_data(int data_id, ::std::shared_ptr data_impl) + { + pimpl->track_pushed_data(data_id, mv(data_impl)); + } + + void untrack_pushed_data(int data_id) + { + pimpl->untrack_pushed_data(data_id); + } + void finalize() { // There must be only one level left @@ -234,7 +289,7 @@ public: template class stackable_logical_data { - class impl + class impl : public stackable_logical_data_impl_base { public: impl() = default; @@ -258,7 +313,8 @@ class stackable_logical_data void push(access_mode m, data_place where = data_place::invalid) { - // We have not pushed yet, so the current depth is the one before pushing + // We have not pushed yet, so the current depth of the logical data is + // the one before pushing context& from_ctx = sctx.get_ctx(depth()); context& to_ctx = sctx.get_ctx(depth() + 1); @@ -294,7 +350,7 @@ class stackable_logical_data s.push_back(mv(ld)); } - void pop() + virtual void pop() override { // We are going to unfreeze the data, which is currently being used // in a (graph) ctx that uses this stream to launch the graph @@ -320,6 +376,11 @@ class stackable_logical_data s.back().set_symbol(symbol + "." + ::std::to_string(depth())); } + auto& get_sctx() + { + return sctx; + } + private: void check_level_mismatch() const { @@ -369,10 +430,19 @@ public: void push(access_mode m, data_place where = data_place::invalid) { pimpl->push(m, mv(where)); + + // Keep track of data that were pushed in this context. Note that the ID + // used is the ID of the logical data at this level. + pimpl->get_sctx().track_pushed_data(get_ld().get_unique_id(), pimpl); } void pop() { + // We remove the data from the map before popping it to have the id of the + // logical data. Doing so will prevent the automatic call to pop() when the + // context level gets popped. 
+ pimpl->get_sctx().untrack_pushed_data(get_ld().get_unique_id()); + pimpl->pop(); } From dc7e04af6e38c026721f3a827e11b19c33575782 Mon Sep 17 00:00:00 2001 From: Cedric Augonnet Date: Tue, 14 Jan 2025 16:58:18 +0100 Subject: [PATCH 13/18] change test to automatically pop data --- cudax/test/stf/local_stf/stackable.cu | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/cudax/test/stf/local_stf/stackable.cu b/cudax/test/stf/local_stf/stackable.cu index 92989c5e758..ead90512f23 100644 --- a/cudax/test/stf/local_stf/stackable.cu +++ b/cudax/test/stf/local_stf/stackable.cu @@ -74,9 +74,7 @@ int main() c(i) += b(i); }; - lA.pop(); - lA2.pop(); - lC.pop(); + // lA, lA2 and lC are automatically popped sctx.pop(); @@ -96,6 +94,7 @@ int main() EXPECT(a2(i) == 5 * i + 4); } }; + // explicit pop lA2.pop(); sctx.pop(); From bdddfe99f84b8b4392e175bb742a18835261817e Mon Sep 17 00:00:00 2001 From: Cedric Augonnet Date: Wed, 15 Jan 2025 08:33:13 +0100 Subject: [PATCH 14/18] stackable_ctx::impl is not copyable but movable --- .../cuda/experimental/__stf/utility/stackable_ctx.cuh | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/cudax/include/cuda/experimental/__stf/utility/stackable_ctx.cuh b/cudax/include/cuda/experimental/__stf/utility/stackable_ctx.cuh index 6b6617b885c..ae170e6c2d7 100644 --- a/cudax/include/cuda/experimental/__stf/utility/stackable_ctx.cuh +++ b/cudax/include/cuda/experimental/__stf/utility/stackable_ctx.cuh @@ -87,6 +87,14 @@ public: ~impl() = default; + // Delete copy constructor and copy assignment operator + impl(const impl&) = delete; + impl& operator=(const impl&) = delete; + + // Define move constructor and move assignment operator + impl(impl&&) noexcept = default; + impl& operator=(impl&&) noexcept = default; + /** * @brief Create a new nested level */ From 8902d6ffa6e0f9413afefb69d183fba54f257edf Mon Sep 17 00:00:00 2001 From: Cedric Augonnet Date: Wed, 15 Jan 2025 08:34:00 +0100 Subject: [PATCH 15/18] clang-format --- .../include/cuda/experimental/__stf/utility/stackable_ctx.cuh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cudax/include/cuda/experimental/__stf/utility/stackable_ctx.cuh b/cudax/include/cuda/experimental/__stf/utility/stackable_ctx.cuh index ae170e6c2d7..530405941ad 100644 --- a/cudax/include/cuda/experimental/__stf/utility/stackable_ctx.cuh +++ b/cudax/include/cuda/experimental/__stf/utility/stackable_ctx.cuh @@ -88,11 +88,11 @@ public: ~impl() = default; // Delete copy constructor and copy assignment operator - impl(const impl&) = delete; + impl(const impl&) = delete; impl& operator=(const impl&) = delete; // Define move constructor and move assignment operator - impl(impl&&) noexcept = default; + impl(impl&&) noexcept = default; impl& operator=(impl&&) noexcept = default; /** From 7c72f4bf6e5622f03f7ddb56f3dcb853a87316af Mon Sep 17 00:00:00 2001 From: Cedric Augonnet Date: Wed, 15 Jan 2025 08:37:13 +0100 Subject: [PATCH 16/18] stackable_logical_data::impl is movable but not copyable --- .../cuda/experimental/__stf/utility/stackable_ctx.cuh | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/cudax/include/cuda/experimental/__stf/utility/stackable_ctx.cuh b/cudax/include/cuda/experimental/__stf/utility/stackable_ctx.cuh index 530405941ad..5284be055fe 100644 --- a/cudax/include/cuda/experimental/__stf/utility/stackable_ctx.cuh +++ b/cudax/include/cuda/experimental/__stf/utility/stackable_ctx.cuh @@ -308,6 +308,16 @@ class stackable_logical_data s.push_back(ld); } + ~impl() = default; + + // Delete 
copy constructor and copy assignment operator + impl(const impl&) = delete; + impl& operator=(const impl&) = delete; + + // Define move constructor and move assignment operator + impl(impl&&) noexcept = default; + impl& operator=(impl&&) noexcept = default; + const auto& get_ld() const { check_level_mismatch(); From 8dce98578b8da5141a3b937b0a8bb6c4069fbd4e Mon Sep 17 00:00:00 2001 From: Cedric Augonnet Date: Wed, 15 Jan 2025 13:17:26 +0100 Subject: [PATCH 17/18] unittests are broken as they compare a fullpath with a relative path in the wrong directory : so we record the full path so that it works at runtime --- cudax/test/stf/CMakeLists.txt | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/cudax/test/stf/CMakeLists.txt b/cudax/test/stf/CMakeLists.txt index 8c7e988f03c..4acb08f5dc0 100644 --- a/cudax/test/stf/CMakeLists.txt +++ b/cudax/test/stf/CMakeLists.txt @@ -241,6 +241,10 @@ function(cudax_add_stf_unittest_header target_name_var source cn_target) string(REPLACE "/" "." test_label "${test_label}") set(test_target "${config_prefix}.test.stf.unittest_headers.${test_label}") + # Pass the full path to configure_file (this is configured from cudax/tests/stf/ + get_filename_component(source_full_path ../../../cudax/include/${source} ABSOLUTE) + set(source ${source_full_path}) + set(ut_template "${cudax_SOURCE_DIR}/cmake/stf_header_unittest.in.cu") set(ut_source "${cudax_BINARY_DIR}/unittest_headers/${test_target}.cu") configure_file(${ut_template} ${ut_source} @ONLY) From de5317b6b87d945b1a00f28e6ba55c5415262f09 Mon Sep 17 00:00:00 2001 From: Cedric Augonnet Date: Wed, 15 Jan 2025 17:42:07 +0100 Subject: [PATCH 18/18] Save WIP : add unittests in the stackable_ctx.cuh header --- .../__stf/utility/stackable_ctx.cuh | 91 +++++++++++++++++++ cudax/test/stf/CMakeLists.txt | 1 + 2 files changed, 92 insertions(+) diff --git a/cudax/include/cuda/experimental/__stf/utility/stackable_ctx.cuh b/cudax/include/cuda/experimental/__stf/utility/stackable_ctx.cuh index 5284be055fe..16b81d44f68 100644 --- a/cudax/include/cuda/experimental/__stf/utility/stackable_ctx.cuh +++ b/cudax/include/cuda/experimental/__stf/utility/stackable_ctx.cuh @@ -34,6 +34,9 @@ namespace cuda::experimental::stf template class stackable_logical_data; +template +class stackable_task_dep; + /** * @brief Base class with a virtual pop method to enable type erasure * @@ -272,6 +275,11 @@ public: return get_ctx(depth()).host_launch(::std::forward(pack)...); } + auto task_fence() + { + return get_ctx(depth()).task_fence(); + } + void track_pushed_data(int data_id, ::std::shared_ptr data_impl) { pimpl->track_pushed_data(data_id, mv(data_impl)); @@ -494,8 +502,91 @@ public: return *this; } + auto get_impl() + { + return pimpl; + } + private: ::std::shared_ptr pimpl; }; +template +class stackable_task_dep : public task_dep +{ +public: + stackable_task_dep(stackable_logical_data _d, access_mode m, data_place _dplace) + : task_dep(d.get_ld(), m, _dplace) + , d(mv(_d)) + , dplace(mv(_dplace)) + {} + +private: + stackable_logical_data d; + data_place dplace; +}; + +#ifdef UNITTESTED_FILE +# ifdef __CUDACC__ +namespace reserved +{ + +template +static __global__ void kernel_set(T *addr, T val) {printf("SETTING ADDR %p at %d\n",addr, val); *addr = val; } + +template +static __global__ void kernel_add(T *addr, T val) {*addr += val; } + +template +static __global__ void kernel_check_value(T *addr, T val) { printf("CHECK %d EXPECTED %d\n", *addr, val); if (*addr != val) ::cuda::std::terminate(); } + +} // namespace reserved + +UNITTEST("stackable 
task_fence") +{ + stackable_ctx ctx; + auto lA = ctx.logical_data(shape_of>(1024)); + ctx.push(); + lA.push(access_mode::write, data_place::current_device()); + ctx.task(lA.write())->*[](cudaStream_t stream, auto a) { reserved::kernel_set<<<1, 1, 0, stream>>>(a.data_handle(), 42); }; + ctx.task_fence(); + ctx.task(lA.read())->*[](cudaStream_t stream, auto a) { reserved::kernel_check_value<<<1, 1, 0, stream>>>(a.data_handle(), 44); }; + ctx.pop(); + ctx.finalize(); +}; + +UNITTEST("stackable host_launch") +{ + stackable_ctx ctx; + auto lA = ctx.logical_data(shape_of>(1024)); + ctx.push(); + lA.push(access_mode::write, data_place::current_device()); + ctx.task(lA.write())->*[](cudaStream_t stream, auto a) { reserved::kernel_set<<<1, 1, 0, stream>>>(a.data_handle(), 42); }; + ctx.host_launch(lA.read())->*[](auto a){ _CCCL_ASSERT(a(0) == 42, "invalid value"); }; + ctx.pop(); + ctx.finalize(); +}; + +UNITTEST("stackable promote mode") +{ + int A[1024]; + stackable_ctx ctx; + auto lA = ctx.logical_data(A); + ctx.push(); + + lA.push(access_mode::read, data_place::current_device()); + ctx.task(lA.read())->*[](cudaStream_t, auto) {}; + lA.pop(); + + lA.push(access_mode::rw, data_place::current_device()); + ctx.task(lA.rw())->*[](cudaStream_t, auto) {}; + lA.pop(); + + ctx.pop(); + ctx.finalize(); +}; + +#endif // __CUDACC__ +#endif // UNITTESTED_FILE + } // end namespace cuda::experimental::stf diff --git a/cudax/test/stf/CMakeLists.txt b/cudax/test/stf/CMakeLists.txt index 4acb08f5dc0..ca47a06ca35 100644 --- a/cudax/test/stf/CMakeLists.txt +++ b/cudax/test/stf/CMakeLists.txt @@ -176,6 +176,7 @@ set(stf_unittested_headers cuda/experimental/__stf/utility/hash.cuh cuda/experimental/__stf/utility/memory.cuh cuda/experimental/__stf/utility/scope_guard.cuh + cuda/experimental/__stf/utility/stackable_ctx.cuh cuda/experimental/__stf/utility/stopwatch.cuh cuda/experimental/__stf/utility/unittest.cuh cuda/experimental/__stf/utility/unstable_unique.cuh