Skip to content

Commit

Permalink
fix a few stream/future issues (#1118)
Browse files Browse the repository at this point in the history
- The generated lift/lower code for stream/future payloads was not always calculating module paths correctly when generating type names.
- Also, we were moving raw pointers into `async move` blocks and returning them without capturing the pointed-to memory.  This would have been caught by runtime tests, but we don't have those yet since the Wasmtime async PR hasn't been merged yet.  Fortunately, it was easy enough to find and fix when I updated that PR to use the latest wit-bindgen.
- The generated lift/lower code for reading and writing streams needs to return a `Box<dyn Future>` that captures the lifetimes of the parameters.

Signed-off-by: Joel Dice <[email protected]>
  • Loading branch information
dicej authored Jan 9, 2025
1 parent e067c16 commit 4f52883
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 65 deletions.
4 changes: 2 additions & 2 deletions crates/guest-rust/rt/src/async_support/stream_support.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@ fn ceiling(x: usize, y: usize) -> usize {

#[doc(hidden)]
pub struct StreamVtable<T> {
pub write: fn(future: u32, values: &[T]) -> Pin<Box<dyn Future<Output = Option<usize>>>>,
pub write: fn(future: u32, values: &[T]) -> Pin<Box<dyn Future<Output = Option<usize>> + '_>>,
pub read: fn(
future: u32,
values: &mut [MaybeUninit<T>],
) -> Pin<Box<dyn Future<Output = Option<usize>>>>,
) -> Pin<Box<dyn Future<Output = Option<usize>> + '_>>,
pub cancel_write: fn(future: u32),
pub cancel_read: fn(future: u32),
pub close_writable: fn(future: u32),
Expand Down
4 changes: 2 additions & 2 deletions crates/rust/src/bindgen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,7 @@ impl Bindgen for FunctionBindgen<'_, '_> {
.as_ref()
.map(|ty| {
self.gen
.full_type_name_owned(ty, Identifier::StreamOrFuturePayload)
.type_name_owned_with_id(ty, Identifier::StreamOrFuturePayload)
})
.unwrap_or_else(|| "()".into());
let ordinal = self.gen.gen.future_payloads.get_index_of(&name).unwrap();
Expand All @@ -496,7 +496,7 @@ impl Bindgen for FunctionBindgen<'_, '_> {
let op = &operands[0];
let name = self
.gen
.full_type_name_owned(payload, Identifier::StreamOrFuturePayload);
.type_name_owned_with_id(payload, Identifier::StreamOrFuturePayload);
let ordinal = self.gen.gen.stream_payloads.get_index_of(&name).unwrap();
let path = self.gen.path_to_root();
results.push(format!(
Expand Down
116 changes: 55 additions & 61 deletions crates/rust/src/interface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,8 @@ macro_rules! {macro_name} {{
}

fn generate_payloads(&mut self, prefix: &str, func: &Function, interface: Option<&WorldKey>) {
let old_identifier = mem::replace(&mut self.identifier, Identifier::StreamOrFuturePayload);

for (index, ty) in func
.find_futures_and_streams(self.resolve)
.into_iter()
Expand All @@ -500,7 +502,7 @@ macro_rules! {macro_name} {{
match &self.resolve.types[ty].kind {
TypeDefKind::Future(payload_type) => {
let name = if let Some(payload_type) = payload_type {
self.full_type_name_owned(payload_type, Identifier::StreamOrFuturePayload)
self.type_name_owned(payload_type)
} else {
"()".into()
};
Expand Down Expand Up @@ -533,7 +535,7 @@ macro_rules! {macro_name} {{
(String::new(), "let value = ();\n".into())
};

let box_ = format!("super::super::{}", self.path_to_box());
let box_ = self.path_to_box();
let code = format!(
r#"
#[doc(hidden)]
Expand All @@ -545,7 +547,7 @@ pub mod vtable{ordinal} {{
}}
#[cfg(target_arch = "wasm32")]
{{
{box_}::pin(async move {{
#[repr(align({align}))]
struct Buffer([::core::mem::MaybeUninit::<u8>; {size}]);
let mut buffer = Buffer([::core::mem::MaybeUninit::uninit(); {size}]);
Expand All @@ -558,10 +560,8 @@ pub mod vtable{ordinal} {{
fn wit_import(_: u32, _: *mut u8) -> u32;
}}
{box_}::pin(async move {{
unsafe {{ {async_support}::await_future_result(wit_import, future, address).await }}
}})
}}
unsafe {{ {async_support}::await_future_result(wit_import, future, address).await }}
}})
}}
fn read(future: u32) -> ::core::pin::Pin<{box_}<dyn ::core::future::Future<Output = Option<{name}>>>> {{
Expand All @@ -571,7 +571,7 @@ pub mod vtable{ordinal} {{
}}
#[cfg(target_arch = "wasm32")]
{{
{box_}::pin(async move {{
struct Buffer([::core::mem::MaybeUninit::<u8>; {size}]);
let mut buffer = Buffer([::core::mem::MaybeUninit::uninit(); {size}]);
let address = buffer.0.as_mut_ptr() as *mut u8;
Expand All @@ -582,15 +582,13 @@ pub mod vtable{ordinal} {{
fn wit_import(_: u32, _: *mut u8) -> u32;
}}
{box_}::pin(async move {{
if unsafe {{ {async_support}::await_future_result(wit_import, future, address).await }} {{
{lift}
Some(value)
}} else {{
None
}}
}})
}}
if unsafe {{ {async_support}::await_future_result(wit_import, future, address).await }} {{
{lift}
Some(value)
}} else {{
None
}}
}})
}}
fn cancel_write(writer: u32) {{
Expand Down Expand Up @@ -691,8 +689,7 @@ pub mod vtable{ordinal} {{
}
}
TypeDefKind::Stream(payload_type) => {
let name =
self.full_type_name_owned(payload_type, Identifier::StreamOrFuturePayload);
let name = self.type_name_owned(payload_type);

if !self.gen.stream_payloads.contains_key(&name) {
let ordinal = self.gen.stream_payloads.len();
Expand Down Expand Up @@ -747,19 +744,19 @@ for (index, dst) in values.iter_mut().take(count).enumerate() {{
(address.clone(), lower, address, lift)
};

let box_ = format!("super::super::{}", self.path_to_box());
let box_ = self.path_to_box();
let code = format!(
r#"
#[doc(hidden)]
pub mod vtable{ordinal} {{
fn write(stream: u32, values: &[{name}]) -> ::core::pin::Pin<{box_}<dyn ::core::future::Future<Output = Option<usize>>>> {{
fn write(stream: u32, values: &[{name}]) -> ::core::pin::Pin<{box_}<dyn ::core::future::Future<Output = Option<usize>> + '_>> {{
#[cfg(not(target_arch = "wasm32"))]
{{
unreachable!();
}}
#[cfg(target_arch = "wasm32")]
{{
{box_}::pin(async move {{
{lower_address}
{lower}
Expand All @@ -769,27 +766,25 @@ pub mod vtable{ordinal} {{
fn wit_import(_: u32, _: *mut u8, _: u32) -> u32;
}}
{box_}::pin(async move {{
unsafe {{
{async_support}::await_stream_result(
wit_import,
stream,
address,
u32::try_from(values.len()).unwrap()
).await
}}
}})
}}
unsafe {{
{async_support}::await_stream_result(
wit_import,
stream,
address,
u32::try_from(values.len()).unwrap()
).await
}}
}})
}}
fn read(stream: u32, values: &mut [::core::mem::MaybeUninit::<{name}>]) -> ::core::pin::Pin<{box_}<dyn ::core::future::Future<Output = Option<usize>>>> {{
fn read(stream: u32, values: &mut [::core::mem::MaybeUninit::<{name}>]) -> ::core::pin::Pin<{box_}<dyn ::core::future::Future<Output = Option<usize>> + '_>> {{
#[cfg(not(target_arch = "wasm32"))]
{{
unreachable!();
}}
#[cfg(target_arch = "wasm32")]
{{
{box_}::pin(async move {{
{lift_address}
#[link(wasm_import_module = "{module}")]
Expand All @@ -798,22 +793,20 @@ pub mod vtable{ordinal} {{
fn wit_import(_: u32, _: *mut u8, _: u32) -> u32;
}}
{box_}::pin(async move {{
let count = unsafe {{
{async_support}::await_stream_result(
wit_import,
stream,
address,
u32::try_from(values.len()).unwrap()
).await
}};
#[allow(unused)]
if let Some(count) = count {{
{lift}
}}
count
}})
}}
let count = unsafe {{
{async_support}::await_stream_result(
wit_import,
stream,
address,
u32::try_from(values.len()).unwrap()
).await
}};
#[allow(unused)]
if let Some(count) = count {{
{lift}
}}
count
}})
}}
fn cancel_write(writer: u32) {{
Expand Down Expand Up @@ -916,6 +909,8 @@ pub mod vtable{ordinal} {{
_ => unreachable!(),
}
}

self.identifier = old_identifier;
}

fn generate_guest_import(&mut self, func: &Function, interface: Option<&WorldKey>) {
Expand Down Expand Up @@ -1699,25 +1694,24 @@ pub mod vtable{ordinal} {{
}
}

pub(crate) fn full_type_name_owned(&mut self, ty: &Type, id: Identifier<'i>) -> String {
self.full_type_name(
pub(crate) fn type_name_owned_with_id(&mut self, ty: &Type, id: Identifier<'i>) -> String {
let old_identifier = mem::replace(&mut self.identifier, id);
let name = self.type_name_owned(ty);
self.identifier = old_identifier;
name
}

fn type_name_owned(&mut self, ty: &Type) -> String {
self.type_name(
ty,
TypeMode {
lifetime: None,
lists_borrowed: false,
style: TypeOwnershipStyle::Owned,
},
id,
)
}

fn full_type_name(&mut self, ty: &Type, mode: TypeMode, id: Identifier<'i>) -> String {
let old_identifier = mem::replace(&mut self.identifier, id);
let name = self.type_name(ty, mode);
self.identifier = old_identifier;
name
}

fn type_name(&mut self, ty: &Type, mode: TypeMode) -> String {
let old = mem::take(&mut self.src);
self.print_ty(ty, mode);
Expand Down
15 changes: 15 additions & 0 deletions tests/codegen/streams.wit
Original file line number Diff line number Diff line change
@@ -1,5 +1,19 @@
package foo:foo;

interface transmit {
variant control {
read-stream(string),
read-future(string),
write-stream(string),
write-future(string),
}

exchange: func(control: stream<control>,
caller-stream: stream<string>,
caller-future1: future<string>,
caller-future2: future<string>) -> tuple<stream<string>, future<string>, future<string>>;
}

interface streams {
stream-u8-param: func(x: stream<u8>);
stream-u16-param: func(x: stream<u16>);
Expand Down Expand Up @@ -82,4 +96,5 @@ interface streams {
world the-streams {
import streams;
export streams;
export transmit;
}

0 comments on commit 4f52883

Please sign in to comment.