From 2bd6da228d3e3f0f258c982b7a2a3571718d3688 Mon Sep 17 00:00:00 2001 From: "Nathan.fooo" <86001920+appflowy@users.noreply.github.com> Date: Wed, 8 Jan 2025 01:11:25 +0800 Subject: [PATCH] chore: chat stream timeout (#1137) * chore: increase timeout * chore: keepalive event --- libs/appflowy-ai-client/src/client.rs | 5 +++-- libs/appflowy-ai-client/src/dto.rs | 1 + libs/client-api-test/src/test_client.rs | 1 + libs/client-api/src/http_chat.rs | 7 ++++++- libs/infra/src/reqwest.rs | 2 -- src/api/chat.rs | 2 +- tests/ai_test/chat_test.rs | 1 + 7 files changed, 13 insertions(+), 6 deletions(-) diff --git a/libs/appflowy-ai-client/src/client.rs b/libs/appflowy-ai-client/src/client.rs index edf9dfb66..ceb98c986 100644 --- a/libs/appflowy-ai-client/src/client.rs +++ b/libs/appflowy-ai-client/src/client.rs @@ -263,20 +263,21 @@ impl AppFlowyAIClient { rag_ids, }, }; - self.stream_question_v3(model, json).await + self.stream_question_v3(model, json, Some(30)).await } pub async fn stream_question_v3( &self, model: &AIModel, question: ChatQuestion, + timeout_secs: Option, ) -> Result>, AIError> { let url = format!("{}/v2/chat/message/stream", self.url); let resp = self .async_http_client(Method::POST, &url)? .header(AI_MODEL_HEADER_KEY, model.to_str()) .json(&question) - .timeout(Duration::from_secs(30)) + .timeout(Duration::from_secs(timeout_secs.unwrap_or(30))) .send() .await?; AIResponse::<()>::stream_response(resp).await diff --git a/libs/appflowy-ai-client/src/dto.rs b/libs/appflowy-ai-client/src/dto.rs index f7f6dea40..4eb5a1d39 100644 --- a/libs/appflowy-ai-client/src/dto.rs +++ b/libs/appflowy-ai-client/src/dto.rs @@ -8,6 +8,7 @@ use std::str::FromStr; pub const STREAM_METADATA_KEY: &str = "0"; pub const STREAM_ANSWER_KEY: &str = "1"; pub const STREAM_IMAGE_KEY: &str = "2"; +pub const STREAM_KEEP_ALIVE_KEY: &str = "3"; #[derive(Clone, Debug, Serialize, Deserialize)] pub struct SummarizeRowResponse { pub text: String, diff --git a/libs/client-api-test/src/test_client.rs b/libs/client-api-test/src/test_client.rs index e8c30ac01..e6474a568 100644 --- a/libs/client-api-test/src/test_client.rs +++ b/libs/client-api-test/src/test_client.rs @@ -1271,6 +1271,7 @@ pub async fn collect_answer(mut stream: QuestionStream) -> String { answer.push_str(&value); }, QuestionStreamValue::Metadata { .. } => {}, + QuestionStreamValue::KeepAlive => {}, } } answer diff --git a/libs/client-api/src/http_chat.rs b/libs/client-api/src/http_chat.rs index b9b5c575b..26b332ce3 100644 --- a/libs/client-api/src/http_chat.rs +++ b/libs/client-api/src/http_chat.rs @@ -12,7 +12,7 @@ use reqwest::Method; use serde_json::Value; use shared_entity::dto::ai_dto::{ CalculateSimilarityParams, ChatQuestionQuery, RepeatedRelatedQuestion, SimilarityResponse, - STREAM_ANSWER_KEY, STREAM_IMAGE_KEY, STREAM_METADATA_KEY, + STREAM_ANSWER_KEY, STREAM_IMAGE_KEY, STREAM_KEEP_ALIVE_KEY, STREAM_METADATA_KEY, }; use shared_entity::dto::chat_dto::{ChatSettings, UpdateChatParams}; use shared_entity::response::{AppResponse, AppResponseError}; @@ -366,6 +366,7 @@ pub enum QuestionStreamValue { Metadata { value: serde_json::Value, }, + KeepAlive, } impl Stream for QuestionStream { type Item = Result; @@ -394,6 +395,10 @@ impl Stream for QuestionStream { return Poll::Ready(Some(Ok(QuestionStreamValue::Answer { value: image }))); } + if value.remove(STREAM_KEEP_ALIVE_KEY).is_some() { + return Poll::Ready(Some(Ok(QuestionStreamValue::KeepAlive))); + } + error!("Invalid streaming value: {:?}", value); Poll::Ready(None) }, diff --git a/libs/infra/src/reqwest.rs b/libs/infra/src/reqwest.rs index 4a3e635ab..65365baab 100644 --- a/libs/infra/src/reqwest.rs +++ b/libs/infra/src/reqwest.rs @@ -92,7 +92,6 @@ where // Poll for the next chunk of data from the underlying stream match ready!(this.stream.as_mut().poll_next(cx)) { Some(Ok(bytes)) => { - // Append the new bytes to the buffer this.buffer.extend_from_slice(&bytes); // Create a StreamDeserializer to deserialize the bytes into T @@ -112,7 +111,6 @@ where return Poll::Pending; }, Some(Err(err)) => { - // Return other deserialization errors wrapped in SE return Poll::Ready(Some(Err(err.into()))); }, None => { diff --git a/src/api/chat.rs b/src/api/chat.rs index ebc10853e..fca0dc8c0 100644 --- a/src/api/chat.rs +++ b/src/api/chat.rs @@ -380,7 +380,7 @@ async fn answer_stream_v3_handler( trace!("[Chat] stream v3 {:?}", question); match state .ai_client - .stream_question_v3(&ai_model, question) + .stream_question_v3(&ai_model, question, Some(60)) .await { Ok(answer_stream) => { diff --git a/tests/ai_test/chat_test.rs b/tests/ai_test/chat_test.rs index 1006aa33c..f2bf54916 100644 --- a/tests/ai_test/chat_test.rs +++ b/tests/ai_test/chat_test.rs @@ -450,6 +450,7 @@ async fn collect_answer(mut stream: QuestionStream) -> String { answer.push_str(&value); }, QuestionStreamValue::Metadata { .. } => {}, + QuestionStreamValue::KeepAlive => {}, } } answer