Skip to content

Commit

Permalink
Merge pull request #1147 from exograph/exograph
Browse files Browse the repository at this point in the history
Work with pools that don't support prepared statements
  • Loading branch information
sfackler authored Jul 13, 2024
2 parents 647a925 + 0fa3247 commit 257bcfd
Show file tree
Hide file tree
Showing 7 changed files with 292 additions and 4 deletions.
48 changes: 48 additions & 0 deletions tokio-postgres/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,54 @@ impl Client {
query::query(&self.inner, statement, params).await
}

/// Like `query`, but requires the types of query parameters to be explicitly specified.
///
/// Compared to `query`, this method allows performing queries without three round trips (for
/// prepare, execute, and close) by requiring the caller to specify parameter values along with
/// their Postgres type. Thus, this is suitable in environments where prepared statements aren't
/// supported (such as Cloudflare Workers with Hyperdrive).
///
/// A statement may contain parameters, specified by `$n`, where `n` is the index of the
/// parameter of the list provided, 1-indexed.
///
/// # Examples
///
/// ```no_run
/// # async fn async_main(client: &tokio_postgres::Client) -> Result<(), tokio_postgres::Error> {
/// use tokio_postgres::types::ToSql;
/// use tokio_postgres::types::Type;
/// use futures_util::{pin_mut, TryStreamExt};
///
/// let rows = client.query_typed(
/// "SELECT foo FROM bar WHERE biz = $1 AND baz = $2",
/// &[(&"first param", Type::TEXT), (&2i32, Type::INT4)],
/// ).await?;
///
/// for row in rows {
/// let foo: i32 = row.get("foo");
/// println!("foo: {}", foo);
/// }
/// # Ok(())
/// # }
/// ```
pub async fn query_typed(
&self,
statement: &str,
params: &[(&(dyn ToSql + Sync), Type)],
) -> Result<Vec<Row>, Error> {
fn slice_iter<'a>(
s: &'a [(&'a (dyn ToSql + Sync), Type)],
) -> impl ExactSizeIterator<Item = (&'a dyn ToSql, Type)> + 'a {
s.iter()
.map(|(param, param_type)| (*param as _, param_type.clone()))
}

query::query_typed(&self.inner, statement, slice_iter(params))
.await?
.try_collect()
.await
}

/// Executes a statement, returning the number of rows modified.
///
/// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
Expand Down
23 changes: 23 additions & 0 deletions tokio-postgres/src/generic_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,13 @@ pub trait GenericClient: private::Sealed {
I: IntoIterator<Item = P> + Sync + Send,
I::IntoIter: ExactSizeIterator;

/// Like [`Client::query_typed`]
async fn query_typed(
&self,
statement: &str,
params: &[(&(dyn ToSql + Sync), Type)],
) -> Result<Vec<Row>, Error>;

/// Like [`Client::prepare`].
async fn prepare(&self, query: &str) -> Result<Statement, Error>;

Expand Down Expand Up @@ -139,6 +146,14 @@ impl GenericClient for Client {
self.query_raw(statement, params).await
}

async fn query_typed(
&self,
statement: &str,
params: &[(&(dyn ToSql + Sync), Type)],
) -> Result<Vec<Row>, Error> {
self.query_typed(statement, params).await
}

async fn prepare(&self, query: &str) -> Result<Statement, Error> {
self.prepare(query).await
}
Expand Down Expand Up @@ -229,6 +244,14 @@ impl GenericClient for Transaction<'_> {
self.query_raw(statement, params).await
}

async fn query_typed(
&self,
statement: &str,
params: &[(&(dyn ToSql + Sync), Type)],
) -> Result<Vec<Row>, Error> {
self.query_typed(statement, params).await
}

async fn prepare(&self, query: &str) -> Result<Statement, Error> {
self.prepare(query).await
}
Expand Down
2 changes: 1 addition & 1 deletion tokio-postgres/src/prepare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ fn encode(client: &InnerClient, name: &str, query: &str, types: &[Type]) -> Resu
})
}

async fn get_type(client: &Arc<InnerClient>, oid: Oid) -> Result<Type, Error> {
pub(crate) async fn get_type(client: &Arc<InnerClient>, oid: Oid) -> Result<Type, Error> {
if let Some(type_) = Type::from_oid(oid) {
return Ok(type_);
}
Expand Down
95 changes: 92 additions & 3 deletions tokio-postgres/src/query.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
use crate::client::{InnerClient, Responses};
use crate::codec::FrontendMessage;
use crate::connection::RequestMessages;
use crate::prepare::get_type;
use crate::types::{BorrowToSql, IsNull};
use crate::{Error, Portal, Row, Statement};
use crate::{Column, Error, Portal, Row, Statement};
use bytes::{Bytes, BytesMut};
use fallible_iterator::FallibleIterator;
use futures_util::{ready, Stream};
use log::{debug, log_enabled, Level};
use pin_project_lite::pin_project;
use postgres_protocol::message::backend::{CommandCompleteBody, Message};
use postgres_protocol::message::frontend;
use postgres_types::Type;
use std::fmt;
use std::marker::PhantomPinned;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

struct BorrowToSqlParamsDebug<'a, T>(&'a [T]);
Expand Down Expand Up @@ -57,6 +61,71 @@ where
})
}

pub async fn query_typed<'a, P, I>(
client: &Arc<InnerClient>,
query: &str,
params: I,
) -> Result<RowStream, Error>
where
P: BorrowToSql,
I: IntoIterator<Item = (P, Type)>,
I::IntoIter: ExactSizeIterator,
{
let (params, param_types): (Vec<_>, Vec<_>) = params.into_iter().unzip();

let params = params.into_iter();

let param_oids = param_types.iter().map(|t| t.oid()).collect::<Vec<_>>();

let params = params.into_iter();

let buf = client.with_buf(|buf| {
frontend::parse("", query, param_oids.into_iter(), buf).map_err(Error::parse)?;

encode_bind_with_statement_name_and_param_types("", &param_types, params, "", buf)?;

frontend::describe(b'S', "", buf).map_err(Error::encode)?;

frontend::execute("", 0, buf).map_err(Error::encode)?;

frontend::sync(buf);

Ok(buf.split().freeze())
})?;

let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;

loop {
match responses.next().await? {
Message::ParseComplete
| Message::BindComplete
| Message::ParameterDescription(_)
| Message::NoData => {}
Message::RowDescription(row_description) => {
let mut columns: Vec<Column> = vec![];
let mut it = row_description.fields();
while let Some(field) = it.next().map_err(Error::parse)? {
let type_ = get_type(client, field.type_oid()).await?;
let column = Column {
name: field.name().to_string(),
table_oid: Some(field.table_oid()).filter(|n| *n != 0),
column_id: Some(field.column_id()).filter(|n| *n != 0),
r#type: type_,
};
columns.push(column);
}
return Ok(RowStream {
statement: Statement::unnamed(vec![], columns),
responses,
rows_affected: None,
_p: PhantomPinned,
});
}
_ => return Err(Error::unexpected_message()),
}
}
}

pub async fn query_portal(
client: &InnerClient,
portal: &Portal,
Expand Down Expand Up @@ -164,7 +233,27 @@ where
I: IntoIterator<Item = P>,
I::IntoIter: ExactSizeIterator,
{
let param_types = statement.params();
encode_bind_with_statement_name_and_param_types(
statement.name(),
statement.params(),
params,
portal,
buf,
)
}

fn encode_bind_with_statement_name_and_param_types<P, I>(
statement_name: &str,
param_types: &[Type],
params: I,
portal: &str,
buf: &mut BytesMut,
) -> Result<(), Error>
where
P: BorrowToSql,
I: IntoIterator<Item = P>,
I::IntoIter: ExactSizeIterator,
{
let params = params.into_iter();

if param_types.len() != params.len() {
Expand All @@ -181,7 +270,7 @@ where
let mut error_idx = 0;
let r = frontend::bind(
portal,
statement.name(),
statement_name,
param_formats,
params.zip(param_types).enumerate(),
|(idx, (param, ty)), buf| match param.borrow_to_sql().to_sql_checked(ty, buf) {
Expand Down
13 changes: 13 additions & 0 deletions tokio-postgres/src/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ struct StatementInner {

impl Drop for StatementInner {
fn drop(&mut self) {
if self.name.is_empty() {
// Unnamed statements don't need to be closed
return;
}
if let Some(client) = self.client.upgrade() {
let buf = client.with_buf(|buf| {
frontend::close(b'S', &self.name, buf).unwrap();
Expand Down Expand Up @@ -46,6 +50,15 @@ impl Statement {
}))
}

pub(crate) fn unnamed(params: Vec<Type>, columns: Vec<Column>) -> Statement {
Statement(Arc::new(StatementInner {
client: Weak::new(),
name: String::new(),
params,
columns,
}))
}

pub(crate) fn name(&self) -> &str {
&self.0.name
}
Expand Down
9 changes: 9 additions & 0 deletions tokio-postgres/src/transaction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,15 @@ impl<'a> Transaction<'a> {
query::query_portal(self.client.inner(), portal, max_rows).await
}

/// Like `Client::query_typed`.
pub async fn query_typed(
&self,
statement: &str,
params: &[(&(dyn ToSql + Sync), Type)],
) -> Result<Vec<Row>, Error> {
self.client.query_typed(statement, params).await
}

/// Like `Client::copy_in`.
pub async fn copy_in<T, U>(&self, statement: &T) -> Result<CopyInSink<U>, Error>
where
Expand Down
106 changes: 106 additions & 0 deletions tokio-postgres/tests/test/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -959,3 +959,109 @@ async fn deferred_constraint() {
.await
.unwrap_err();
}

#[tokio::test]
async fn query_typed_no_transaction() {
let client = connect("user=postgres").await;

client
.batch_execute(
"
CREATE TEMPORARY TABLE foo (
name TEXT,
age INT
);
INSERT INTO foo (name, age) VALUES ('alice', 20), ('bob', 30), ('carol', 40);
",
)
.await
.unwrap();

let rows: Vec<tokio_postgres::Row> = client
.query_typed(
"SELECT name, age, 'literal', 5 FROM foo WHERE name <> $1 AND age < $2 ORDER BY age",
&[(&"alice", Type::TEXT), (&50i32, Type::INT4)],
)
.await
.unwrap();

assert_eq!(rows.len(), 2);
let first_row = &rows[0];
assert_eq!(first_row.get::<_, &str>(0), "bob");
assert_eq!(first_row.get::<_, i32>(1), 30);
assert_eq!(first_row.get::<_, &str>(2), "literal");
assert_eq!(first_row.get::<_, i32>(3), 5);

let second_row = &rows[1];
assert_eq!(second_row.get::<_, &str>(0), "carol");
assert_eq!(second_row.get::<_, i32>(1), 40);
assert_eq!(second_row.get::<_, &str>(2), "literal");
assert_eq!(second_row.get::<_, i32>(3), 5);
}

#[tokio::test]
async fn query_typed_with_transaction() {
let mut client = connect("user=postgres").await;

client
.batch_execute(
"
CREATE TEMPORARY TABLE foo (
name TEXT,
age INT
);
",
)
.await
.unwrap();

let transaction = client.transaction().await.unwrap();

let rows: Vec<tokio_postgres::Row> = transaction
.query_typed(
"INSERT INTO foo (name, age) VALUES ($1, $2), ($3, $4), ($5, $6) returning name, age",
&[
(&"alice", Type::TEXT),
(&20i32, Type::INT4),
(&"bob", Type::TEXT),
(&30i32, Type::INT4),
(&"carol", Type::TEXT),
(&40i32, Type::INT4),
],
)
.await
.unwrap();
let inserted_values: Vec<(String, i32)> = rows
.iter()
.map(|row| (row.get::<_, String>(0), row.get::<_, i32>(1)))
.collect();
assert_eq!(
inserted_values,
[
("alice".to_string(), 20),
("bob".to_string(), 30),
("carol".to_string(), 40)
]
);

let rows: Vec<tokio_postgres::Row> = transaction
.query_typed(
"SELECT name, age, 'literal', 5 FROM foo WHERE name <> $1 AND age < $2 ORDER BY age",
&[(&"alice", Type::TEXT), (&50i32, Type::INT4)],
)
.await
.unwrap();

assert_eq!(rows.len(), 2);
let first_row = &rows[0];
assert_eq!(first_row.get::<_, &str>(0), "bob");
assert_eq!(first_row.get::<_, i32>(1), 30);
assert_eq!(first_row.get::<_, &str>(2), "literal");
assert_eq!(first_row.get::<_, i32>(3), 5);

let second_row = &rows[1];
assert_eq!(second_row.get::<_, &str>(0), "carol");
assert_eq!(second_row.get::<_, i32>(1), 40);
assert_eq!(second_row.get::<_, &str>(2), "literal");
assert_eq!(second_row.get::<_, i32>(3), 5);
}

0 comments on commit 257bcfd

Please sign in to comment.