Skip to content

Commit

Permalink
Cleanup the stackable resource implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
caugonnet committed Nov 4, 2024
1 parent d3e70d0 commit ff0ba38
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 103 deletions.
2 changes: 1 addition & 1 deletion cudax/examples/stf/binary_fhe_stackable.cu
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ int main()
auto eA = pA.encrypt();
auto eB = pB.encrypt();

ctx.push_graph();
ctx.push();

eA.push(access_mode::read);
eB.push(access_mode::read);
Expand Down
206 changes: 107 additions & 99 deletions cudax/include/cuda/experimental/__stf/utility/stackable_ctx.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -34,139 +34,155 @@ namespace cuda::experimental::stf
template <typename T>
class stackable_logical_data;

/**
* @brief A context that supports nested sub-contexts, each implemented as a local CUDA graph
*/
class stackable_ctx
{
public:
class impl
{
private:
/*
* State of each nested context
*/
struct per_level
{
  // Takes ownership of the three pieces of state that make up one level.
  per_level(context ctx, cudaStream_t support_stream, ::std::optional<stream_adapter> alloc_adapters)
      : ctx(mv(ctx))
      , support_stream(mv(support_stream))
      , alloc_adapters(mv(alloc_adapters))
  {}

  // Context in which work submitted at this level is created
  context ctx;

  // Stream recorded together with the context of this level
  cudaStream_t support_stream;

  // A wrapper to forward allocations from a level to the previous one
  // (left empty at the root level, which has no previous level)
  ::std::optional<stream_adapter> alloc_adapters;
};

public:
impl()
{
push(stream_ctx(), nullptr);
push();
}

~impl() = default;

void push(context ctx, cudaStream_t stream)
/**
* @brief Create a new nested level
*/
void push()
{
s.push_back(mv(ctx));
s_stream.push_back(stream);
}
// These resources are not destroyed when we pop, so we only create a new one if needed
if (async_handles.size() < levels.size())
{
async_handles.emplace_back();
}

void pop()
{
s.back().finalize();
if (levels.size() == 0) {
levels.emplace_back(stream_ctx(), nullptr, ::std::nullopt);
}
else {
// Get a stream from the previous context (we haven't pushed the new one yet)
cudaStream_t stream = levels[depth()].ctx.pick_stream();

s.pop_back();
auto gctx = graph_ctx(stream, async_handles.back());

s_stream.pop_back();
auto wrapper = stream_adapter(gctx, stream);
// FIXME : issue with the deinit phase
// gctx.update_uncached_allocator(wrapper.allocator());

_CCCL_ASSERT(alloc_adapters.size() > 0, "Calling pop from an empty container");
alloc_adapters.back().clear();
alloc_adapters.pop_back();
levels.emplace_back(gctx, stream, wrapper);
}
}

size_t depth() const
/**
* @brief Terminate the current nested level and get back to the previous one
*/
void pop()
{
return s.size() - 1;
}
_CCCL_ASSERT(levels.size() > 0, "Calling pop while no context was pushed");

auto& get()
{
return s.back();
}
auto &current_level = levels.back();

const auto& get() const
{
return s.back();
// Ensure everything is finished in the context
current_level.ctx.finalize();

// Destroy the resources used in the wrapper allocator (if any)
if (current_level.alloc_adapters.has_value())
{
current_level.alloc_adapters.value().clear();
}

// Destroy the current level state
levels.pop_back();
}

auto& operator[](size_t level)
/**
* @brief Get the nesting depth
*/
size_t depth() const
{
_CCCL_ASSERT(level < s.size(), "Out of bound access");
return s[level];
return levels.size() - 1;
}

const auto& operator[](size_t level) const
/**
* @brief Returns a reference to the context at a specific level
*/
auto& get_ctx(size_t level)
{
_CCCL_ASSERT(level < s.size(), "Out of bound access");
return s[level];
return levels[level].ctx;
}

cudaStream_t stream_at(size_t level) const
/**
* @brief Returns a const reference to the context at a specific level
*/
const auto& get_ctx(size_t level) const
{
return s_stream[level];
return levels[level].ctx;
}

void push_graph()
cudaStream_t get_stream(size_t level) const
{
cudaStream_t stream = get().pick_stream();

// These resources are not destroyed when we pop, so we only create a new one if needed
if (async_handles.size() < s_stream.size())
{
async_handles.emplace_back();
}

auto gctx = graph_ctx(stream, async_handles.back());

auto wrapper = stream_adapter(gctx, stream);
// FIXME : issue with the deinit phase
// gctx.update_uncached_allocator(wrapper.allocator());

alloc_adapters.push_back(wrapper);

push(mv(gctx), stream);
return levels[level].support_stream;
}

private:
::std::vector<context> s;
::std::vector<cudaStream_t> s_stream;
// State for each nested level
::std::vector<per_level> levels;

// Handles to retain some asynchronous state. They are maintained separately
// from levels because their entries are kept even when we pop a level
::std::vector<async_resources_handle> async_handles;
::std::vector<stream_adapter> alloc_adapters;
};

/**
 * @brief Create a stackable context
 *
 * Allocates the implementation object with make_shared; the constructor of
 * `impl` pushes the initial level itself.
 */
stackable_ctx()
: pimpl(::std::make_shared<impl>())
{}

const auto& get() const
cudaStream_t get_stream(size_t level) const
{
return pimpl->get();
}
auto& get()
{
return pimpl->get();
}

auto& operator[](size_t level)
{
return pimpl->operator[](level);
return pimpl->get_stream(level);
}

const auto& operator[](size_t level) const
const auto& get_ctx(size_t level) const
{
return pimpl->operator[](level);
return pimpl->get_ctx(level);
}

cudaStream_t stream_at(size_t level) const
auto& get_ctx(size_t level)
{
return pimpl->stream_at(level);
return pimpl->get_ctx(level);
}

const auto& operator()() const
{
return get();
return get_ctx(depth());
}

auto& operator()()
{
return get();
return get_ctx(depth());
}

void push_graph()
void push()
{
pimpl->push_graph();
pimpl->push();
}

void pop()
Expand All @@ -182,33 +198,33 @@ public:
template <typename... Pack>
auto logical_data(Pack&&... pack)
{
return stackable_logical_data(*this, depth(), get().logical_data(::std::forward<Pack>(pack)...));
return stackable_logical_data(*this, depth(), get_ctx(depth()).logical_data(::std::forward<Pack>(pack)...));
}

template <typename... Pack>
auto task(Pack&&... pack)
{
return get().task(::std::forward<Pack>(pack)...);
return get_ctx(depth()).task(::std::forward<Pack>(pack)...);
}

template <typename... Pack>
auto parallel_for(Pack&&... pack)
{
return get().parallel_for(::std::forward<Pack>(pack)...);
return get_ctx(depth()).parallel_for(::std::forward<Pack>(pack)...);
}

template <typename... Pack>
auto host_launch(Pack&&... pack)
{
return get().host_launch(::std::forward<Pack>(pack)...);
return get_ctx(depth()).host_launch(::std::forward<Pack>(pack)...);
}

void finalize()
{
// There must be only one level left
_CCCL_ASSERT(depth() == 0, "All nested levels must have been popped");

get().finalize();
get_ctx(depth()).finalize();
}

public:
Expand All @@ -229,12 +245,12 @@ class stackable_logical_data
s.push_back(ld);
}

const auto& get() const
const auto& get_ld() const
{
check_level_mismatch();
return s.back();
}
auto& get()
auto& get_ld()
{
check_level_mismatch();
return s.back();
Expand All @@ -243,8 +259,8 @@ class stackable_logical_data
void push(access_mode m, data_place where = data_place::invalid)
{
// We have not pushed yet, so the current depth is the one before pushing
context& from_ctx = sctx[depth()];
context& to_ctx = sctx[depth() + 1];
context& from_ctx = sctx.get_ctx(depth());
context& to_ctx = sctx.get_ctx(depth() + 1);

// Ensure this will match the depth of the context after pushing
_CCCL_ASSERT(sctx.depth() == depth() + 1, "Invalid depth");
Expand All @@ -265,7 +281,7 @@ class stackable_logical_data
frozen_s.push_back(f);

// FAKE IMPORT : use the stream needed to support the (graph) ctx
cudaStream_t stream = sctx.stream_at(depth());
cudaStream_t stream = sctx.get_stream(depth());

T inst = f.get(where, stream);
auto ld = to_ctx.logical_data(inst, where);
Expand All @@ -282,7 +298,7 @@ class stackable_logical_data
{
// We are going to unfreeze the data, which is currently being used
// in a (graph) ctx that uses this stream to launch the graph
cudaStream_t stream = sctx.stream_at(depth());
cudaStream_t stream = sctx.get_stream(depth());

frozen_logical_data<T>& f = frozen_s.back();
f.unfreeze(stream);
Expand Down Expand Up @@ -336,22 +352,13 @@ public:
: pimpl(::std::make_shared<impl>(mv(sctx), depth, mv(ld)))
{}

const auto& get() const
{
return pimpl->get();
}
auto& get()
{
return pimpl->get();
}

const auto& operator()() const
const auto& get_ld() const
{
return get();
return pimpl->get_ld();
}
auto& operator()()
auto& get_ld()
{
return get();
return pimpl->get_ld();
}

size_t depth() const
Expand All @@ -363,6 +370,7 @@ public:
{
pimpl->push(m, mv(where));
}

void pop()
{
pimpl->pop();
Expand All @@ -372,24 +380,24 @@ public:
template <typename... Pack>
auto read(Pack&&... pack) const
{
return get().read(::std::forward<Pack>(pack)...);
return get_ld().read(::std::forward<Pack>(pack)...);
}

template <typename... Pack>
auto write(Pack&&... pack)
{
return get().write(::std::forward<Pack>(pack)...);
return get_ld().write(::std::forward<Pack>(pack)...);
}

template <typename... Pack>
auto rw(Pack&&... pack)
{
return get().rw(::std::forward<Pack>(pack)...);
return get_ld().rw(::std::forward<Pack>(pack)...);
}

auto shape() const
{
return get().shape();
return get_ld().shape();
}

auto& set_symbol(::std::string symbol)
Expand Down
Loading

0 comments on commit ff0ba38

Please sign in to comment.