Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bug where AsyncioRunnable hangs if process_one throws and the source is not emitting new values #523

Merged
merged 10 commits into from
Jan 13, 2025
6 changes: 6 additions & 0 deletions cpp/mrc/include/mrc/edge/edge_channel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

#pragma once

#include "mrc/channel/types.hpp" // for time_point_t
#include "mrc/edge/edge_readable.hpp"
#include "mrc/edge/edge_writable.hpp"
#include "mrc/edge/forward.hpp"
Expand Down Expand Up @@ -45,6 +46,11 @@ class EdgeChannelReader : public IEdgeReadable<T>
return m_channel->await_read(t);
}

/**
 * @brief Time-bounded counterpart of await_read: forwards the read to the wrapped channel.
 *
 * @param t  Destination handed to the channel's await_read_until.
 * @param tp Time point passed through unchanged to the channel (presumably the read deadline).
 * @return The status reported by the channel's await_read_until.
 */
virtual channel::Status await_read_until(T& t, const mrc::channel::time_point_t& tp)
{
    channel::Status read_status = m_channel->await_read_until(t, tp);
    return read_status;
}

private:
// Constructs a reader over `channel`; the shared_ptr is moved into m_channel, which the read methods forward to.
EdgeChannelReader(std::shared_ptr<mrc::channel::Channel<T>> channel) : m_channel(std::move(channel)) {}

Expand Down
32 changes: 31 additions & 1 deletion cpp/mrc/include/mrc/edge/edge_readable.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "mrc/channel/channel.hpp"
#include "mrc/channel/egress.hpp"
#include "mrc/channel/ingress.hpp"
#include "mrc/channel/types.hpp" // for time_point_t
#include "mrc/edge/edge.hpp"
#include "mrc/exceptions/runtime_error.hpp"
#include "mrc/node/forward.hpp"
Expand Down Expand Up @@ -61,7 +62,8 @@
return EdgeTypeInfo::create<T>();
}

virtual channel::Status await_read(T& t) = 0;
virtual channel::Status await_read(T& t) = 0;
virtual channel::Status await_read_until(T& t, const mrc::channel::time_point_t& tp) = 0;
};

template <typename InputT, typename OutputT = InputT>
Expand Down Expand Up @@ -110,6 +112,20 @@

return ret_val;
}

channel::Status await_read_until(OutputT& data, const mrc::channel::time_point_t& tp) override

Check warning on line 116 in cpp/mrc/include/mrc/edge/edge_readable.hpp

View check run for this annotation

Codecov / codecov/patch

cpp/mrc/include/mrc/edge/edge_readable.hpp#L116

Added line #L116 was not covered by tests
{
InputT source_data;
auto status = this->upstream().await_read_until(source_data, tp);

Check warning on line 119 in cpp/mrc/include/mrc/edge/edge_readable.hpp

View check run for this annotation

Codecov / codecov/patch

cpp/mrc/include/mrc/edge/edge_readable.hpp#L118-L119

Added lines #L118 - L119 were not covered by tests

if (status == channel::Status::success)

Check warning on line 121 in cpp/mrc/include/mrc/edge/edge_readable.hpp

View check run for this annotation

Codecov / codecov/patch

cpp/mrc/include/mrc/edge/edge_readable.hpp#L121

Added line #L121 was not covered by tests
{
// Convert to the sink type
data = std::move(source_data);

Check warning on line 124 in cpp/mrc/include/mrc/edge/edge_readable.hpp

View check run for this annotation

Codecov / codecov/patch

cpp/mrc/include/mrc/edge/edge_readable.hpp#L124

Added line #L124 was not covered by tests
}

return status;

Check warning on line 127 in cpp/mrc/include/mrc/edge/edge_readable.hpp

View check run for this annotation

Codecov / codecov/patch

cpp/mrc/include/mrc/edge/edge_readable.hpp#L127

Added line #L127 was not covered by tests
}
};

template <typename InputT, typename OutputT>
Expand Down Expand Up @@ -137,6 +153,20 @@
return ret_val;
}

channel::Status await_read_until(output_t& data, const mrc::channel::time_point_t& tp) override

Check warning on line 156 in cpp/mrc/include/mrc/edge/edge_readable.hpp

View check run for this annotation

Codecov / codecov/patch

cpp/mrc/include/mrc/edge/edge_readable.hpp#L156

Added line #L156 was not covered by tests
{
input_t source_data;
auto status = this->upstream().await_read_until(source_data, tp);

Check warning on line 159 in cpp/mrc/include/mrc/edge/edge_readable.hpp

View check run for this annotation

Codecov / codecov/patch

cpp/mrc/include/mrc/edge/edge_readable.hpp#L158-L159

Added lines #L158 - L159 were not covered by tests

if (status == channel::Status::success)

Check warning on line 161 in cpp/mrc/include/mrc/edge/edge_readable.hpp

View check run for this annotation

Codecov / codecov/patch

cpp/mrc/include/mrc/edge/edge_readable.hpp#L161

Added line #L161 was not covered by tests
{
// Convert to the sink type
data = m_lambda_fn(std::move(source_data));

Check warning on line 164 in cpp/mrc/include/mrc/edge/edge_readable.hpp

View check run for this annotation

Codecov / codecov/patch

cpp/mrc/include/mrc/edge/edge_readable.hpp#L164

Added line #L164 was not covered by tests
}

return status;

Check warning on line 167 in cpp/mrc/include/mrc/edge/edge_readable.hpp

View check run for this annotation

Codecov / codecov/patch

cpp/mrc/include/mrc/edge/edge_readable.hpp#L167

Added line #L167 was not covered by tests
}

private:
lambda_fn_t m_lambda_fn{};
};
Expand Down
5 changes: 4 additions & 1 deletion cpp/mrc/include/mrc/node/sink_properties.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,11 @@
channel::Status await_read(T& t) override
{
throw std::runtime_error("Attempting to read from a null edge. Ensure an edge was established for all sinks.");
}

return channel::Status::error;
channel::Status await_read_until(T& t, const mrc::channel::time_point_t& tp) override

Check warning on line 45 in cpp/mrc/include/mrc/node/sink_properties.hpp

View check run for this annotation

Codecov / codecov/patch

cpp/mrc/include/mrc/node/sink_properties.hpp#L45

Added line #L45 was not covered by tests
{
throw std::runtime_error("Attempting to read from a null edge. Ensure an edge was established for all sinks.");

Check warning on line 47 in cpp/mrc/include/mrc/node/sink_properties.hpp

View check run for this annotation

Codecov / codecov/patch

cpp/mrc/include/mrc/node/sink_properties.hpp#L47

Added line #L47 was not covered by tests
dagardner-nv marked this conversation as resolved.
Show resolved Hide resolved
}
};

Expand Down
5 changes: 5 additions & 0 deletions cpp/mrc/include/mrc/node/source_properties.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,11 @@
return m_parent.get_next(t);
}

channel::Status await_read_until(T& t, const mrc::channel::time_point_t& tp) override

Check warning on line 291 in cpp/mrc/include/mrc/node/source_properties.hpp

View check run for this annotation

Codecov / codecov/patch

cpp/mrc/include/mrc/node/source_properties.hpp#L291

Added line #L291 was not covered by tests
{
throw std::runtime_error("Not implemented");

Check warning on line 293 in cpp/mrc/include/mrc/node/source_properties.hpp

View check run for this annotation

Codecov / codecov/patch

cpp/mrc/include/mrc/node/source_properties.hpp#L293

Added line #L293 was not covered by tests
}

private:
ForwardingReadableProvider<T>& m_parent;
};
Expand Down
5 changes: 5 additions & 0 deletions cpp/mrc/tests/node/test_nodes.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,11 @@ class EdgeReadableLambda : public edge::IEdgeReadable<T>
return m_on_await_read(t);
}

/**
 * @brief Timed reads are not supported by this lambda-backed test edge.
 *
 * Both arguments are ignored; every call throws.
 *
 * @throws std::runtime_error Always, with the message "Not implemented".
 */
channel::Status await_read_until(T& t, const mrc::channel::time_point_t& tp) override
{
    throw std::runtime_error{"Not implemented"};
}

private:
std::function<channel::Status(T&)> m_on_await_read;
std::function<void()> m_on_complete;
Expand Down
26 changes: 19 additions & 7 deletions python/mrc/_pymrc/include/pymrc/asyncio_runnable.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include <mrc/coroutines/closable_ring_buffer.hpp>
#include <mrc/coroutines/task.hpp>
#include <mrc/coroutines/task_container.hpp>
#include <mrc/edge/edge_channel.hpp> // for EdgeChannelReader
#include <mrc/exceptions/exception_catcher.hpp>
#include <mrc/node/sink_properties.hpp>
#include <mrc/runnable/forward.hpp>
Expand Down Expand Up @@ -118,8 +119,17 @@ class AsyncSink : public mrc::node::WritableProvider<T>,
{
protected:
AsyncSink() :
m_read_async([this](T& value) {
return this->get_readable_edge()->await_read(value);
m_read_async([this](T& value, std::stop_source& stop_source) {
using namespace std::chrono_literals;
auto edge = this->get_readable_edge();
channel::Status status = channel::Status::timeout;
while ((status == channel::Status::timeout || status == channel::Status::empty) &&
not stop_source.stop_requested())
{
status = edge->await_read_until(value, std::chrono::system_clock::now() + 10ms);
}

return status;
})
{
// Set the default channel
Expand All @@ -129,13 +139,13 @@ class AsyncSink : public mrc::node::WritableProvider<T>,
/**
* @brief Asynchronously reads a value from the sink's channel
*/
coroutines::Task<mrc::channel::Status> read_async(T& value)
coroutines::Task<mrc::channel::Status> read_async(T& value, std::stop_source& stop_source)
{
co_return co_await m_read_async(std::ref(value));
co_return co_await m_read_async(std::ref(value), std::ref(stop_source));
}

private:
BoostFutureAwaitableOperation<mrc::channel::Status(T&)> m_read_async;
BoostFutureAwaitableOperation<mrc::channel::Status(T&, std::stop_source&)> m_read_async;
};

/**
Expand Down Expand Up @@ -297,8 +307,8 @@ coroutines::Task<> AsyncioRunnable<InputT, OutputT>::main_task(std::shared_ptr<m
{
InputT data;

auto read_status = co_await this->read_async(data);

mrc::channel::Status read_status = mrc::channel::Status::success;
read_status = co_await this->read_async(data, m_stop_source);
if (read_status != mrc::channel::Status::success)
{
break;
Expand All @@ -309,6 +319,7 @@ coroutines::Task<> AsyncioRunnable<InputT, OutputT>::main_task(std::shared_ptr<m

co_await outstanding_tasks.garbage_collect_and_yield_until_empty();

// this is a no-op if there are no exceptions
catcher.rethrow_next_exception();
}

Expand Down Expand Up @@ -339,6 +350,7 @@ coroutines::Task<> AsyncioRunnable<InputT, OutputT>::process_one(InputT value,
} catch (...)
{
catcher.push_exception(std::current_exception());
on_state_update(state_t::Kill);
}
}

Expand Down
60 changes: 60 additions & 0 deletions python/mrc/_pymrc/include/pymrc/node.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,25 @@

return ret_val;
}

channel::Status await_read_until(output_t& data, const mrc::channel::time_point_t& tp) override

Check warning on line 198 in python/mrc/_pymrc/include/pymrc/node.hpp

View check run for this annotation

Codecov / codecov/patch

python/mrc/_pymrc/include/pymrc/node.hpp#L198

Added line #L198 was not covered by tests
{
input_t source_data;
auto status = this->upstream().await_read_until(source_data, tp);

Check warning on line 201 in python/mrc/_pymrc/include/pymrc/node.hpp

View check run for this annotation

Codecov / codecov/patch

python/mrc/_pymrc/include/pymrc/node.hpp#L200-L201

Added lines #L200 - L201 were not covered by tests

if (status == channel::Status::success)

Check warning on line 203 in python/mrc/_pymrc/include/pymrc/node.hpp

View check run for this annotation

Codecov / codecov/patch

python/mrc/_pymrc/include/pymrc/node.hpp#L203

Added line #L203 was not covered by tests
{
// We need to hold the GIL here, because casting from C++ -> pybind11::object allocates memory with
// Py_Malloc.
// It's also important to note that you do not want to hold the GIL when calling m_output->await_write, as
// that can trigger a deadlock with another fiber reading from the end of the channel.
pymrc::AcquireGIL gil;

Check warning on line 209 in python/mrc/_pymrc/include/pymrc/node.hpp

View check run for this annotation

Codecov / codecov/patch

python/mrc/_pymrc/include/pymrc/node.hpp#L209

Added line #L209 was not covered by tests

data = pybind11::cast(std::move(source_data));

Check warning on line 211 in python/mrc/_pymrc/include/pymrc/node.hpp

View check run for this annotation

Codecov / codecov/patch

python/mrc/_pymrc/include/pymrc/node.hpp#L211

Added line #L211 was not covered by tests
}

return status;

Check warning on line 214 in python/mrc/_pymrc/include/pymrc/node.hpp

View check run for this annotation

Codecov / codecov/patch

python/mrc/_pymrc/include/pymrc/node.hpp#L214

Added line #L214 was not covered by tests
}
};

template <typename OutputT>
Expand Down Expand Up @@ -224,6 +243,21 @@
return ret_val;
}

channel::Status await_read_until(output_t& data, const mrc::channel::time_point_t& tp) override

Check warning on line 246 in python/mrc/_pymrc/include/pymrc/node.hpp

View check run for this annotation

Codecov / codecov/patch

python/mrc/_pymrc/include/pymrc/node.hpp#L246

Added line #L246 was not covered by tests
{
input_t source_data;
auto status = this->upstream().await_read_until(source_data, tp);

Check warning on line 249 in python/mrc/_pymrc/include/pymrc/node.hpp

View check run for this annotation

Codecov / codecov/patch

python/mrc/_pymrc/include/pymrc/node.hpp#L248-L249

Added lines #L248 - L249 were not covered by tests

if (status == channel::Status::success)

Check warning on line 251 in python/mrc/_pymrc/include/pymrc/node.hpp

View check run for this annotation

Codecov / codecov/patch

python/mrc/_pymrc/include/pymrc/node.hpp#L251

Added line #L251 was not covered by tests
{
pymrc::AcquireGIL gil;

Check warning on line 253 in python/mrc/_pymrc/include/pymrc/node.hpp

View check run for this annotation

Codecov / codecov/patch

python/mrc/_pymrc/include/pymrc/node.hpp#L253

Added line #L253 was not covered by tests

data = pybind11::cast<output_t>(pybind11::object(std::move(source_data)));

Check warning on line 255 in python/mrc/_pymrc/include/pymrc/node.hpp

View check run for this annotation

Codecov / codecov/patch

python/mrc/_pymrc/include/pymrc/node.hpp#L255

Added line #L255 was not covered by tests
}

return status;

Check warning on line 258 in python/mrc/_pymrc/include/pymrc/node.hpp

View check run for this annotation

Codecov / codecov/patch

python/mrc/_pymrc/include/pymrc/node.hpp#L258

Added line #L258 was not covered by tests
}

static void register_converter()
{
EdgeConnector<input_t, output_t>::register_converter();
Expand All @@ -249,6 +283,19 @@

return ret_val;
}

channel::Status await_read_until(output_t& data, const mrc::channel::time_point_t& tp) override

Check warning on line 287 in python/mrc/_pymrc/include/pymrc/node.hpp

View check run for this annotation

Codecov / codecov/patch

python/mrc/_pymrc/include/pymrc/node.hpp#L287

Added line #L287 was not covered by tests
{
input_t source_data;
auto status = this->upstream().await_read_until(source_data, tp);

Check warning on line 290 in python/mrc/_pymrc/include/pymrc/node.hpp

View check run for this annotation

Codecov / codecov/patch

python/mrc/_pymrc/include/pymrc/node.hpp#L289-L290

Added lines #L289 - L290 were not covered by tests

if (status == channel::Status::success)

Check warning on line 292 in python/mrc/_pymrc/include/pymrc/node.hpp

View check run for this annotation

Codecov / codecov/patch

python/mrc/_pymrc/include/pymrc/node.hpp#L292

Added line #L292 was not covered by tests
{
data = std::move(source_data);

Check warning on line 294 in python/mrc/_pymrc/include/pymrc/node.hpp

View check run for this annotation

Codecov / codecov/patch

python/mrc/_pymrc/include/pymrc/node.hpp#L294

Added line #L294 was not covered by tests
}

return status;

Check warning on line 297 in python/mrc/_pymrc/include/pymrc/node.hpp

View check run for this annotation

Codecov / codecov/patch

python/mrc/_pymrc/include/pymrc/node.hpp#L297

Added line #L297 was not covered by tests
}
};

template <>
Expand All @@ -271,6 +318,19 @@
return ret_val;
}

channel::Status await_read_until(output_t& data, const mrc::channel::time_point_t& tp) override

Check warning on line 321 in python/mrc/_pymrc/include/pymrc/node.hpp

View check run for this annotation

Codecov / codecov/patch

python/mrc/_pymrc/include/pymrc/node.hpp#L321

Added line #L321 was not covered by tests
{
input_t source_data;
auto status = this->upstream().await_read_until(source_data, tp);

Check warning on line 324 in python/mrc/_pymrc/include/pymrc/node.hpp

View check run for this annotation

Codecov / codecov/patch

python/mrc/_pymrc/include/pymrc/node.hpp#L323-L324

Added lines #L323 - L324 were not covered by tests

if (status == channel::Status::success)

Check warning on line 326 in python/mrc/_pymrc/include/pymrc/node.hpp

View check run for this annotation

Codecov / codecov/patch

python/mrc/_pymrc/include/pymrc/node.hpp#L326

Added line #L326 was not covered by tests
{
data = pymrc::PyObjectHolder(std::move(source_data));

Check warning on line 328 in python/mrc/_pymrc/include/pymrc/node.hpp

View check run for this annotation

Codecov / codecov/patch

python/mrc/_pymrc/include/pymrc/node.hpp#L328

Added line #L328 was not covered by tests
}

return status;

Check warning on line 331 in python/mrc/_pymrc/include/pymrc/node.hpp

View check run for this annotation

Codecov / codecov/patch

python/mrc/_pymrc/include/pymrc/node.hpp#L331

Added line #L331 was not covered by tests
}

static void register_converter()
{
EdgeConnector<input_t, output_t>::register_converter();
Expand Down
66 changes: 65 additions & 1 deletion python/mrc/_pymrc/tests/test_asyncio_runnable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
#include <atomic>
#include <chrono>
#include <coroutine>
#include <cstddef> // for size_t
#include <functional>
#include <memory>
#include <stdexcept>
Expand All @@ -60,6 +61,7 @@ class Scheduler;

namespace py = pybind11;
namespace pymrc = mrc::pymrc;
using namespace std::chrono_literals;
using namespace std::string_literals;
using namespace py::literals;

Expand Down Expand Up @@ -102,6 +104,11 @@ class __attribute__((visibility("default"))) PythonCallbackAsyncioRunnable : pub
result = co_await pymrc::coro::PyTaskToCppAwaitable(std::move(coroutine));
}

if (result.is_none())
{
co_return;
}

auto result_casted = py::cast<int>(result);

py::gil_scoped_release release;
Expand Down Expand Up @@ -316,7 +323,6 @@ auto run_operation(OperationT& operation) -> mrc::coroutines::Task<int>
TEST_F(TestAsyncioRunnable, BoostFutureAwaitableOperationCanReturn)
{
auto operation = mrc::pymrc::BoostFutureAwaitableOperation<int()>([]() {
using namespace std::chrono_literals;
boost::this_fiber::sleep_for(10ms);
return 5;
});
Expand All @@ -333,3 +339,61 @@ TEST_F(TestAsyncioRunnable, BoostFutureAwaitableOperationCanThrow)

ASSERT_THROW(mrc::coroutines::sync_wait(run_operation(operation)), std::runtime_error);
}

// Regression test: an async Python sink callback raises while the upstream source has stopped emitting values.
// The test passes only if joining the executor throws std::runtime_error.
TEST_F(TestAsyncioRunnable, UseAsyncioTasksThrows2086)
{
// Reproduces Morpheus issue #2086, where an exception is thrown in async Python code and the source does not emit
// any additional values afterwards. In the buggy behavior the pipeline only completed, and the exception only
// reached the caller, once the source emitted another value or called on_completed.
pymrc::Pipeline p;

// Define the async callback `fn` in the interpreter globals; it raises RuntimeError("oops") when value == 1.
py::object globals = py::globals();
py::exec(
R"(
async def fn(value):
print(f"Sink received value={value}")
if value == 1:
print("Sink raising exception", flush=True)
raise RuntimeError("oops")
)",
globals);

// Hold the Python callable so it can be handed to the sink constructed below.
pymrc::PyObjectHolder fn = static_cast<py::object>(globals["fn"]);

// Segment initializer: an int source wired directly to the asyncio-based sink under test.
auto init = [&fn](mrc::segment::IBuilder& seg) {
auto src = seg.make_source<int>("src", [](rxcpp::subscriber<int>& s) {
std::size_t i = 0;
// Loop for as long as the subscriber stays subscribed, sleeping 10ms per iteration.
while (s.is_subscribed())
{
// Only the first two counter values (0 and 1) are emitted; after that the source stays silent.
if (i < 2)
{
s.on_next(i);
}

boost::this_fiber::sleep_for(10ms);

++i;
}

// Only reached once the subscriber is no longer subscribed.
s.on_completed();
});

auto sink = seg.construct_object<PythonCallbackAsyncioRunnable>("sink", fn);

seg.make_edge(src, sink);
};

p.make_segment("seg1"s, init);

auto options = std::make_shared<mrc::Options>();

// AsyncioRunnable only works with the Thread engine due to asyncio loops being thread-specific.
options->engine_factories().set_default_engine_type(mrc::runnable::EngineType::Thread);

pymrc::Executor exec{options};
exec.register_pipeline(p);

exec.start();

// join() must surface the sink's failure as std::runtime_error.
ASSERT_THROW(exec.join(), std::runtime_error);
}
Loading