Skip to content

Commit

Permalink
WIP: Avoid blocking on read_async, fixes issue where an exception has…
Browse files Browse the repository at this point in the history
… been raised in process_one, but AsyncioRunnable is blocked on read_async in the situation where the source isn't emitting any values

TODO: Remove debug logging
TODO: Remove static_pointer_cast
  • Loading branch information
dagardner-nv committed Dec 19, 2024
1 parent 672097f commit 5ce43a3
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 7 deletions.
6 changes: 6 additions & 0 deletions cpp/mrc/include/mrc/edge/edge_channel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

#pragma once

#include "mrc/channel/types.hpp" // for time_point_t
#include "mrc/edge/edge_readable.hpp"
#include "mrc/edge/edge_writable.hpp"
#include "mrc/edge/forward.hpp"
Expand Down Expand Up @@ -45,6 +46,11 @@ class EdgeChannelReader : public IEdgeReadable<T>
return m_channel->await_read(t);
}

virtual channel::Status await_read_until(T& t, const mrc::channel::time_point_t& tp)
{
return m_channel->await_read_until(t, tp);
}

private:
EdgeChannelReader(std::shared_ptr<mrc::channel::Channel<T>> channel) : m_channel(std::move(channel)) {}

Expand Down
2 changes: 2 additions & 0 deletions cpp/mrc/include/mrc/edge/edge_readable.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "mrc/channel/channel.hpp"
#include "mrc/channel/egress.hpp"
#include "mrc/channel/ingress.hpp"
#include "mrc/channel/types.hpp" // for time_point_t
#include "mrc/edge/edge.hpp"
#include "mrc/exceptions/runtime_error.hpp"
#include "mrc/node/forward.hpp"
Expand Down Expand Up @@ -62,6 +63,7 @@ class IEdgeReadable : public virtual Edge<T>, public IEdgeReadableBase
}

virtual channel::Status await_read(T& t) = 0;
// virtual channel::Status await_read_until(T& t, const mrc::channel::time_point_t& tp) = 0;
};

template <typename InputT, typename OutputT = InputT>
Expand Down
26 changes: 19 additions & 7 deletions python/mrc/_pymrc/include/pymrc/asyncio_runnable.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include <mrc/coroutines/closable_ring_buffer.hpp>
#include <mrc/coroutines/task.hpp>
#include <mrc/coroutines/task_container.hpp>
#include <mrc/edge/edge_channel.hpp> // for EdgeChannelReader
#include <mrc/exceptions/exception_catcher.hpp>
#include <mrc/node/sink_properties.hpp>
#include <mrc/runnable/forward.hpp>
Expand Down Expand Up @@ -118,8 +119,17 @@ class AsyncSink : public mrc::node::WritableProvider<T>,
{
protected:
AsyncSink() :
m_read_async([this](T& value) {
return this->get_readable_edge()->await_read(value);
m_read_async([this](T& value, std::stop_source& stop_source) {
using namespace std::chrono_literals;
auto edge = std::static_pointer_cast<edge::EdgeChannelReader<T>>(this->get_readable_edge());
channel::Status status = channel::Status::timeout;
while ((status == channel::Status::timeout || status == channel::Status::empty) &&
not stop_source.stop_requested())
{
status = edge->await_read_until(value, std::chrono::system_clock::now() + 10ms);
}

return status;
})
{
// Set the default channel
Expand All @@ -129,13 +139,13 @@ class AsyncSink : public mrc::node::WritableProvider<T>,
/**
* @brief Asynchronously reads a value from the sink's channel
*/
coroutines::Task<mrc::channel::Status> read_async(T& value)
coroutines::Task<mrc::channel::Status> read_async(T& value, std::stop_source& stop_source)
{
co_return co_await m_read_async(std::ref(value));
co_return co_await m_read_async(std::ref(value), std::ref(stop_source));
}

private:
BoostFutureAwaitableOperation<mrc::channel::Status(T&)> m_read_async;
BoostFutureAwaitableOperation<mrc::channel::Status(T&, std::stop_source&)> m_read_async;
};

/**
Expand Down Expand Up @@ -297,8 +307,8 @@ coroutines::Task<> AsyncioRunnable<InputT, OutputT>::main_task(std::shared_ptr<m
{
InputT data;

auto read_status = co_await this->read_async(data);

mrc::channel::Status read_status = mrc::channel::Status::success;
read_status = co_await this->read_async(data, m_stop_source);
if (read_status != mrc::channel::Status::success)
{
break;
Expand All @@ -309,6 +319,7 @@ coroutines::Task<> AsyncioRunnable<InputT, OutputT>::main_task(std::shared_ptr<m

co_await outstanding_tasks.garbage_collect_and_yield_until_empty();

// this is a no-op if there are no exceptions
catcher.rethrow_next_exception();
}

Expand Down Expand Up @@ -339,6 +350,7 @@ coroutines::Task<> AsyncioRunnable<InputT, OutputT>::process_one(InputT value,
} catch (...)
{
catcher.push_exception(std::current_exception());
on_state_update(state_t::Kill);
}
}

Expand Down

0 comments on commit 5ce43a3

Please sign in to comment.