Commit fa975e8

Merge pull request #1997 from fermyon/link-llm-v2

Link with llm v2

rylev authored Oct 31, 2023
2 parents 561b3e1 + c1ef2db commit fa975e8
Showing 16 changed files with 787 additions and 49 deletions.
1 change: 0 additions & 1 deletion Cargo.lock

(Generated file; diff not rendered by default.)

5 changes: 2 additions & 3 deletions crates/e2e-testing/Cargo.toml
@@ -10,12 +10,11 @@ doctest = false
 
 [dependencies]
 anyhow = "1.0"
 async-trait = "0.1"
-tokio = { version = "1.23", features = [ "full" ] }
-hyper = { version = "0.14", features = [ "full" ] }
+tokio = { version = "1.23", features = ["full"] }
 regex = "1.5.5"
+reqwest = { version = "0.11", features = ["blocking"] }
 nix = "0.26.1"
 url = "2.2.2"
 derive_builder = "0.12.0"
-hyper-tls = "0.5.0"
 tempfile = "3.3.0"
37 changes: 12 additions & 25 deletions crates/e2e-testing/src/http_asserts.rs
@@ -1,16 +1,14 @@
 use crate::ensure_eq;
 use anyhow::Result;
-use hyper::client::HttpConnector;
-use hyper::{body, Body, Client, Method, Request, Response};
-use hyper_tls::HttpsConnector;
+use reqwest::{Method, Request, Response};
 use std::str;
 
 pub async fn assert_status(url: &str, expected: u16) -> Result<()> {
     let resp = make_request(Method::GET, url, "").await?;
     let status = resp.status();
 
-    let response = body::to_bytes(resp.into_body()).await.unwrap().to_vec();
-    let actual_body = str::from_utf8(&response).unwrap().to_string();
+    let body = resp.bytes().await?;
+    let actual_body = str::from_utf8(&body).unwrap().to_string();
 
     ensure_eq!(status, expected, "{}", actual_body);
 
@@ -29,8 +27,8 @@ pub async fn assert_http_response(
 
     let status = res.status();
     let headers = res.headers().clone();
-    let response = body::to_bytes(res.into_body()).await.unwrap().to_vec();
-    let actual_body = str::from_utf8(&response).unwrap().to_string();
+    let body = res.bytes().await?;
+    let actual_body = str::from_utf8(&body).unwrap().to_string();
 
     ensure_eq!(
         expected,
@@ -55,26 +53,15 @@ pub async fn assert_http_response(
     Ok(())
 }
 
-pub async fn create_request(method: Method, url: &str, body: &str) -> Result<Request<Body>> {
-    let req = Request::builder()
-        .method(method)
-        .uri(url)
-        .body(Body::from(body.to_string()))
-        .expect("request builder");
+pub async fn create_request(method: Method, url: &str, body: &str) -> Result<Request> {
+    let mut req = reqwest::Request::new(method, url.try_into()?);
+    *req.body_mut() = Some(body.to_owned().into());
 
     Ok(req)
 }
 
-pub fn create_client() -> Client<HttpsConnector<HttpConnector>> {
-    let connector = HttpsConnector::new();
-    Client::builder()
-        .pool_max_idle_per_host(0)
-        .build::<_, hyper::Body>(connector)
-}
-
-pub async fn make_request(method: Method, path: &str, body: &str) -> Result<Response<Body>> {
-    let c = create_client();
-    let req = create_request(method, path, body);
-    let resp = c.request(req.await?).await.unwrap();
-    Ok(resp)
+pub async fn make_request(method: Method, path: &str, body: &str) -> Result<Response> {
+    let req = create_request(method, path, body).await?;
+    let client = reqwest::Client::new();
+    Ok(client.execute(req).await?)
 }
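For orientation, a minimal sketch of how the reworked helpers might be exercised from a test. The address, routes, and the `e2e_testing` import path are assumptions for illustration, not part of the commit:

```rust
use anyhow::Result;
use reqwest::Method;

// Assumes a Spin app is already serving on this (hypothetical) address,
// and that the e2e-testing crate is imported as `e2e_testing`.
async fn smoke_test() -> Result<()> {
    // Issue a GET and assert only on the status code; on mismatch the
    // response body is included in the failure message.
    e2e_testing::http_asserts::assert_status("http://localhost:3000/hello", 200).await?;

    // Or build an arbitrary request and inspect the reqwest::Response directly.
    let resp = e2e_testing::http_asserts::make_request(
        Method::POST,
        "http://localhost:3000/echo",
        "hi",
    )
    .await?;
    assert_eq!(resp.status(), 200);
    Ok(())
}
```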
2 changes: 1 addition & 1 deletion crates/llm-local/src/lib.rs
@@ -11,7 +11,7 @@ use llm::{
 use rand::SeedableRng;
 use spin_core::async_trait;
 use spin_llm::{LlmEngine, MODEL_ALL_MINILM_L6_V2};
-use spin_world::v1::llm::{self as wasi_llm};
+use spin_world::v2::llm::{self as wasi_llm};
 use std::{
     collections::hash_map::Entry,
     collections::HashMap,
2 changes: 1 addition & 1 deletion crates/llm-remote-http/src/lib.rs
@@ -7,7 +7,7 @@ use serde::{Deserialize, Serialize};
 use serde_json::json;
 use spin_core::async_trait;
 use spin_llm::LlmEngine;
-use spin_world::v1::llm::{self as wasi_llm};
+use spin_world::v2::llm::{self as wasi_llm};
 
 #[derive(Clone)]
 pub struct RemoteHttpLlmEngine {
3 changes: 2 additions & 1 deletion crates/llm/src/host_component.rs
@@ -25,7 +25,8 @@ impl HostComponent for LlmComponent {
         linker: &mut spin_core::Linker<T>,
         get: impl Fn(&mut spin_core::Data<T>) -> &mut Self::Data + Send + Sync + Copy + 'static,
     ) -> anyhow::Result<()> {
-        spin_world::v1::llm::add_to_linker(linker, get)
+        spin_world::v1::llm::add_to_linker(linker, get)?;
+        spin_world::v2::llm::add_to_linker(linker, get)
     }
 
     fn build_data(&self) -> Self::Data {
59 changes: 44 additions & 15 deletions crates/llm/src/lib.rs
@@ -2,7 +2,8 @@ pub mod host_component;
 
 use spin_app::MetadataKey;
 use spin_core::async_trait;
-use spin_world::v1::llm::{self as wasi_llm};
+use spin_world::v1::llm::{self as v1};
+use spin_world::v2::llm::{self as v2};
 use std::collections::HashSet;
 
 pub use crate::host_component::LlmComponent;
@@ -14,16 +15,16 @@ pub const AI_MODELS_KEY: MetadataKey<HashSet<String>> = MetadataKey::new("ai_models");
 pub trait LlmEngine: Send + Sync {
     async fn infer(
         &mut self,
-        model: wasi_llm::InferencingModel,
+        model: v1::InferencingModel,
         prompt: String,
-        params: wasi_llm::InferencingParams,
-    ) -> Result<wasi_llm::InferencingResult, wasi_llm::Error>;
+        params: v2::InferencingParams,
+    ) -> Result<v2::InferencingResult, v2::Error>;
 
     async fn generate_embeddings(
         &mut self,
-        model: wasi_llm::EmbeddingModel,
+        model: v2::EmbeddingModel,
         data: Vec<String>,
-    ) -> Result<wasi_llm::EmbeddingsResult, wasi_llm::Error>;
+    ) -> Result<v2::EmbeddingsResult, v2::Error>;
 }
 
 pub struct LlmDispatch {
@@ -32,13 +33,13 @@ pub struct LlmDispatch {
 }
 
 #[async_trait]
-impl wasi_llm::Host for LlmDispatch {
+impl v2::Host for LlmDispatch {
     async fn infer(
         &mut self,
-        model: wasi_llm::InferencingModel,
+        model: v2::InferencingModel,
         prompt: String,
-        params: Option<wasi_llm::InferencingParams>,
-    ) -> anyhow::Result<Result<wasi_llm::InferencingResult, wasi_llm::Error>> {
+        params: Option<v2::InferencingParams>,
+    ) -> anyhow::Result<Result<v2::InferencingResult, v2::Error>> {
         if !self.allowed_models.contains(&model) {
             return Ok(Err(access_denied_error(&model)));
         }
@@ -47,7 +48,7 @@ impl wasi_llm::Host for LlmDispatch {
             .infer(
                 model,
                 prompt,
-                params.unwrap_or(wasi_llm::InferencingParams {
+                params.unwrap_or(v2::InferencingParams {
                     max_tokens: 100,
                     repeat_penalty: 1.1,
                     repeat_penalty_last_n_token_count: 64,
@@ -61,18 +62,46 @@ impl wasi_llm::Host for LlmDispatch {
 
     async fn generate_embeddings(
         &mut self,
-        m: wasi_llm::EmbeddingModel,
+        m: v1::EmbeddingModel,
         data: Vec<String>,
-    ) -> anyhow::Result<Result<wasi_llm::EmbeddingsResult, wasi_llm::Error>> {
+    ) -> anyhow::Result<Result<v2::EmbeddingsResult, v2::Error>> {
         if !self.allowed_models.contains(&m) {
             return Ok(Err(access_denied_error(&m)));
         }
         Ok(self.engine.generate_embeddings(m, data).await)
     }
 }
 
-fn access_denied_error(model: &str) -> wasi_llm::Error {
-    wasi_llm::Error::InvalidInput(format!(
+#[async_trait]
+impl v1::Host for LlmDispatch {
+    async fn infer(
+        &mut self,
+        model: v1::InferencingModel,
+        prompt: String,
+        params: Option<v1::InferencingParams>,
+    ) -> anyhow::Result<Result<v1::InferencingResult, v1::Error>> {
+        Ok(
+            <Self as v2::Host>::infer(self, model, prompt, params.map(Into::into))
+                .await?
+                .map(Into::into)
+                .map_err(Into::into),
+        )
+    }
+
+    async fn generate_embeddings(
+        &mut self,
+        model: v1::EmbeddingModel,
+        data: Vec<String>,
+    ) -> anyhow::Result<Result<v1::EmbeddingsResult, v1::Error>> {
+        Ok(<Self as v2::Host>::generate_embeddings(self, model, data)
+            .await?
+            .map(Into::into)
+            .map_err(Into::into))
+    }
+}
+
+fn access_denied_error(model: &str) -> v2::Error {
+    v2::Error::InvalidInput(format!(
         "The component does not have access to use '{model}'. To give the component access, add '{model}' to the 'ai_models' key for the component in your spin.toml manifest"
     ))
 }
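The shape of this change — one real implementation on the `v2::Host` trait, with `v1::Host` reduced to an adapter that converts arguments and results via `Into` — can be shown in a reduced, self-contained form. All names below are invented for the sketch; only the delegation pattern mirrors the code above:

```rust
struct OldParams(u32);
struct NewParams(u32);
struct Host;

impl From<OldParams> for NewParams {
    fn from(p: OldParams) -> Self {
        NewParams(p.0)
    }
}

trait V2Host {
    fn infer(&self, params: NewParams) -> String;
}

trait V1Host {
    fn infer(&self, params: OldParams) -> String;
}

impl V2Host for Host {
    // The single real implementation lives on the new trait.
    fn infer(&self, params: NewParams) -> String {
        format!("v2 inference, max_tokens = {}", params.0)
    }
}

impl V1Host for Host {
    // The old entry point converts its arguments and delegates,
    // exactly as LlmDispatch does above.
    fn infer(&self, params: OldParams) -> String {
        <Self as V2Host>::infer(self, params.into())
    }
}

fn main() {
    // A v1 call is served by the v2 implementation.
    assert_eq!(
        V1Host::infer(&Host, OldParams(100)),
        "v2 inference, max_tokens = 100"
    );
}
```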
2 changes: 1 addition & 1 deletion crates/trigger/src/runtime_config/llm.rs
@@ -1,7 +1,7 @@
 use async_trait::async_trait;
 use spin_llm::LlmEngine;
 use spin_llm_remote_http::RemoteHttpLlmEngine;
-use spin_world::v1::llm as wasi_llm;
+use spin_world::v2::llm as wasi_llm;
 use url::Url;
 
 #[derive(Default)]
50 changes: 50 additions & 0 deletions crates/world/src/conversions.rs
@@ -168,3 +168,53 @@ mod redis {
         }
     }
 }
+
+mod llm {
+    use super::*;
+
+    impl From<v1::llm::InferencingParams> for v2::llm::InferencingParams {
+        fn from(value: v1::llm::InferencingParams) -> Self {
+            Self {
+                max_tokens: value.max_tokens,
+                repeat_penalty: value.repeat_penalty,
+                repeat_penalty_last_n_token_count: value.repeat_penalty_last_n_token_count,
+                temperature: value.temperature,
+                top_k: value.top_k,
+                top_p: value.top_p,
+            }
+        }
+    }
+
+    impl From<v2::llm::InferencingResult> for v1::llm::InferencingResult {
+        fn from(value: v2::llm::InferencingResult) -> Self {
+            Self {
+                text: value.text,
+                usage: v1::llm::InferencingUsage {
+                    prompt_token_count: value.usage.prompt_token_count,
+                    generated_token_count: value.usage.generated_token_count,
+                },
+            }
+        }
+    }
+
+    impl From<v2::llm::EmbeddingsResult> for v1::llm::EmbeddingsResult {
+        fn from(value: v2::llm::EmbeddingsResult) -> Self {
+            Self {
+                embeddings: value.embeddings,
+                usage: v1::llm::EmbeddingsUsage {
+                    prompt_token_count: value.usage.prompt_token_count,
+                },
+            }
+        }
+    }
+
+    impl From<v2::llm::Error> for v1::llm::Error {
+        fn from(value: v2::llm::Error) -> Self {
+            match value {
+                v2::llm::Error::ModelNotSupported => Self::ModelNotSupported,
+                v2::llm::Error::RuntimeError(s) => Self::RuntimeError(s),
+                v2::llm::Error::InvalidInput(s) => Self::InvalidInput(s),
+            }
+        }
+    }
+}
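A hypothetical unit test (not part of this commit) could pin down the round-trip behavior of these conversions, for example for the error type:

```rust
#[cfg(test)]
mod llm_conversion_tests {
    use spin_world::{v1, v2};

    #[test]
    fn invalid_input_survives_downgrade() {
        // v2 errors handed back to a v1 guest must keep their payload.
        let err = v2::llm::Error::InvalidInput("bad prompt".to_string());
        match v1::llm::Error::from(err) {
            v1::llm::Error::InvalidInput(s) => assert_eq!(s, "bad prompt"),
            _ => panic!("expected InvalidInput to map to InvalidInput"),
        }
    }
}
```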
3 changes: 2 additions & 1 deletion sdk/rust/src/http.rs
@@ -262,7 +262,8 @@ impl Response {
         ResponseBuilder { response: self }
     }
 
-    fn builder() -> ResponseBuilder {
+    /// Creates a [`ResponseBuilder`]
+    pub fn builder() -> ResponseBuilder {
         ResponseBuilder::new(200)
    }
 }
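With `builder()` now public, SDK users can construct a `Response` fluently. A rough sketch of a handler body follows; the `status`, `header`, `body`, and `build` methods are assumed from the surrounding SDK and do not appear in this diff:

```rust
use spin_sdk::http::Response;

// Hypothetical handler body. `Response::builder()` (made public above)
// seeds a ResponseBuilder with status 200; the chained methods here are
// assumptions about the rest of the builder API.
fn hello() -> Response {
    Response::builder()
        .status(200)
        .header("content-type", "text/plain")
        .body("Hello, Fermyon!")
        .build()
}
```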
5 changes: 5 additions & 0 deletions tests/spinup_tests.rs
@@ -103,6 +103,11 @@ mod spinup_tests {
         testcases::head_rust_sdk_redis(CONTROLLER).await
     }
 
+    #[tokio::test]
+    async fn llm_works() {
+        testcases::llm_works(CONTROLLER).await
+    }
+
     #[tokio::test]
     async fn header_env_routes_works() {
         testcases::header_env_routes_works(CONTROLLER).await