Skip to content

Commit

Permalink
feat(router): add base64 encoding_format for OpenAI API
Browse files Browse the repository at this point in the history
  • Loading branch information
OlivierDehaene committed Jun 21, 2024
1 parent ce2f210 commit f668ae8
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 10 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,3 @@ debug = 1
lto = "thin"
codegen-units = 16
strip = "none"
incremental = true
3 changes: 2 additions & 1 deletion router/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ mimalloc = { version = "*", default-features = false }
# HTTP dependencies
axum = { version = "0.7.4", features = ["json"], optional = true }
axum-tracing-opentelemetry = { version = "0.17.0", optional = true }
base64 = { version = "0.21.4", optional = true }
tower-http = { version = "0.5.1", features = ["cors"], optional = true }
utoipa = { version = "4.2", features = ["axum_extras"], optional = true }
utoipa-swagger-ui = { version = "6.0", features = ["axum"], optional = true }
Expand All @@ -66,7 +67,7 @@ tonic-build = { version = "0.10.2", optional = true }

[features]
default = ["candle", "http"]
http = ["dep:axum", "dep:axum-tracing-opentelemetry", "dep:tower-http", "dep:utoipa", "dep:utoipa-swagger-ui"]
http = ["dep:axum", "dep:axum-tracing-opentelemetry", "dep:base64", "dep:tower-http", "dep:utoipa", "dep:utoipa-swagger-ui"]
grpc = ["metrics-exporter-prometheus/http-listener", "dep:prost", "dep:tonic", "dep:tonic-health", "dep:tonic-reflection", "dep:tonic-build", "dep:async-stream", "dep:tokio-stream"]
metal = ["text-embeddings-backend/metal"]
mkl = ["text-embeddings-backend/mkl"]
Expand Down
33 changes: 26 additions & 7 deletions router/src/http/server.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
/// HTTP Server logic
use crate::http::types::{
DecodeRequest, DecodeResponse, EmbedAllRequest, EmbedAllResponse, EmbedRequest, EmbedResponse,
EmbedSparseRequest, EmbedSparseResponse, Input, InputIds, InputType, OpenAICompatEmbedding,
OpenAICompatErrorResponse, OpenAICompatRequest, OpenAICompatResponse, OpenAICompatUsage,
PredictInput, PredictRequest, PredictResponse, Prediction, Rank, RerankRequest, RerankResponse,
Sequence, SimpleToken, SparseValue, TokenizeInput, TokenizeRequest, TokenizeResponse,
VertexPrediction, VertexRequest, VertexResponse,
EmbedSparseRequest, EmbedSparseResponse, Embedding, EncodingFormat, Input, InputIds, InputType,
OpenAICompatEmbedding, OpenAICompatErrorResponse, OpenAICompatRequest, OpenAICompatResponse,
OpenAICompatUsage, PredictInput, PredictRequest, PredictResponse, Prediction, Rank,
RerankRequest, RerankResponse, Sequence, SimpleToken, SparseValue, TokenizeInput,
TokenizeRequest, TokenizeResponse, VertexPrediction, VertexRequest, VertexResponse,
};
use crate::{
shutdown, ClassifierModel, EmbeddingModel, ErrorResponse, ErrorType, Info, ModelType,
Expand All @@ -19,6 +19,8 @@ use axum::http::{Method, StatusCode};
use axum::routing::{get, post};
use axum::{http, Json, Router};
use axum_tracing_opentelemetry::middleware::OtelAxumLayer;
use base64::prelude::BASE64_STANDARD;
use base64::Engine;
use futures::future::join_all;
use futures::FutureExt;
use http::header::AUTHORIZATION;
Expand Down Expand Up @@ -938,6 +940,21 @@ async fn openai_embed(
Json(req): Json<OpenAICompatRequest>,
) -> Result<(HeaderMap, Json<OpenAICompatResponse>), (StatusCode, Json<OpenAICompatErrorResponse>)>
{
let encode_embedding = |array: Vec<f32>| {
match req.encoding_format {
EncodingFormat::Float => Embedding::Float(array),
EncodingFormat::Base64 => {
// Unsafe is fine here since we do not violate memory ownership: bytes
// is only used in this scope and we return an owned string
let bytes = unsafe {
std::slice::from_raw_parts(array.as_ptr() as *const u8, array.len() * 4)
};

Embedding::Base64(BASE64_STANDARD.encode(bytes))
}
}
};

let span = tracing::Span::current();
let start_time = Instant::now();

Expand All @@ -957,10 +974,11 @@ async fn openai_embed(

metrics::increment_counter!("te_request_success", "method" => "single");

let embedding = encode_embedding(response.results);
(
vec![OpenAICompatEmbedding {
object: "embedding",
embedding: response.results,
embedding,
index: 0,
}],
ResponseMetadata::new(
Expand Down Expand Up @@ -1033,9 +1051,10 @@ async fn openai_embed(
total_queue_time += r.metadata.queue.as_nanos() as u64;
total_inference_time += r.metadata.inference.as_nanos() as u64;
total_compute_tokens += r.metadata.prompt_tokens;
let embedding = encode_embedding(r.results);
embeddings.push(OpenAICompatEmbedding {
object: "embedding",
embedding: r.results,
embedding,
index: i,
});
}
Expand Down
20 changes: 19 additions & 1 deletion router/src/http/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,14 @@ pub(crate) enum Input {
Batch(Vec<InputType>),
}

#[derive(Deserialize, ToSchema, Default)]
#[serde(rename_all = "snake_case")]
pub(crate) enum EncodingFormat {
#[default]
Float,
Base64,
}

#[derive(Deserialize, ToSchema)]
pub(crate) struct OpenAICompatRequest {
pub input: Input,
Expand All @@ -294,14 +302,24 @@ pub(crate) struct OpenAICompatRequest {
#[allow(dead_code)]
#[schema(nullable = true, example = "null")]
pub user: Option<String>,
#[schema(default = "float", example = "float")]
#[serde(default)]
pub encoding_format: EncodingFormat,
}

#[derive(Serialize, ToSchema)]
#[serde(untagged)]
pub(crate) enum Embedding {
Float(Vec<f32>),
Base64(String),
}

#[derive(Serialize, ToSchema)]
pub(crate) struct OpenAICompatEmbedding {
#[schema(example = "embedding")]
pub object: &'static str,
#[schema(example = json!([0.0, 1.0, 2.0]))]
pub embedding: Vec<f32>,
pub embedding: Embedding,
#[schema(example = "0")]
pub index: usize,
}
Expand Down

0 comments on commit f668ae8

Please sign in to comment.