Skip to content

Commit

Permalink
Define a await_read_until method for IEdgeReadable
Browse files Browse the repository at this point in the history
  • Loading branch information
dagardner-nv committed Dec 19, 2024
1 parent 5ce43a3 commit 5c2d482
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 2 deletions.
33 changes: 32 additions & 1 deletion cpp/mrc/include/mrc/edge/edge_readable.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,10 @@ class IEdgeReadable : public virtual Edge<T>, public IEdgeReadableBase
}

virtual channel::Status await_read(T& t) = 0;
// virtual channel::Status await_read_until(T& t, const mrc::channel::time_point_t& tp) = 0;
virtual channel::Status await_read_until(T& t, const mrc::channel::time_point_t& tp)
{
throw std::runtime_error("Not implemented");
};
};

template <typename InputT, typename OutputT = InputT>
Expand Down Expand Up @@ -112,6 +115,20 @@ class ConvertingEdgeReadable<InputT, OutputT, std::enable_if_t<std::is_convertib

return ret_val;
}

channel::Status await_read_until(OutputT& data, const mrc::channel::time_point_t& tp) override
{
InputT source_data;
auto status = this->upstream().await_read_until(source_data, tp);

if (status == channel::Status::success)
{
// Convert to the sink type
data = std::move(source_data);
}

return status;
}
};

template <typename InputT, typename OutputT>
Expand Down Expand Up @@ -139,6 +156,20 @@ class LambdaConvertingEdgeReadable : public ConvertingEdgeReadableBase<InputT, O
return ret_val;
}

channel::Status await_read_until(output_t& data, const mrc::channel::time_point_t& tp) override
{
input_t source_data;
auto status = this->upstream().await_read_until(source_data, tp);

if (status == channel::Status::success)
{
// Convert to the sink type
data = m_lambda_fn(std::move(source_data));
}

return status;
}

private:
lambda_fn_t m_lambda_fn{};
};
Expand Down
7 changes: 7 additions & 0 deletions cpp/mrc/include/mrc/node/sink_properties.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,13 @@ class NullReadableEdge : public edge::IEdgeReadable<T>

return channel::Status::error;
}

channel::Status await_read_until(T& t, const mrc::channel::time_point_t& tp) override
{
throw std::runtime_error("Attempting to read from a null edge. Ensure an edge was established for all sinks.");

return channel::Status::error;
}
};

/**
Expand Down
2 changes: 1 addition & 1 deletion python/mrc/_pymrc/include/pymrc/asyncio_runnable.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ class AsyncSink : public mrc::node::WritableProvider<T>,
AsyncSink() :
m_read_async([this](T& value, std::stop_source& stop_source) {
using namespace std::chrono_literals;
auto edge = std::static_pointer_cast<edge::EdgeChannelReader<T>>(this->get_readable_edge());
auto edge = this->get_readable_edge();
channel::Status status = channel::Status::timeout;
while ((status == channel::Status::timeout || status == channel::Status::empty) &&
not stop_source.stop_requested())
Expand Down

0 comments on commit 5c2d482

Please sign in to comment.