feat: add llama.cpp backend #94

Merged · 3 commits · May 23, 2024
1 change: 0 additions & 1 deletion .github/workflows/release.yml
@@ -12,7 +12,6 @@ env:
   RUSTFLAGS: "-D warnings -W unreachable-pub"
   RUSTUP_MAX_RETRIES: 10
   FETCH_DEPTH: 0 # pull in the tags for the version string
-  MACOSX_DEPLOYMENT_TARGET: 10.15
   CARGO_TARGET_AARCH64_UNKNOWN_LINUX_GNU_LINKER: aarch64-linux-gnu-gcc
   CARGO_TARGET_ARM_UNKNOWN_LINUX_GNUEABIHF_LINKER: arm-linux-gnueabihf-gcc

5 changes: 1 addition & 4 deletions .github/workflows/test.yml
@@ -40,10 +40,7 @@ jobs:
           DEBIAN_FRONTEND=noninteractive apt install -y pkg-config protobuf-compiler libssl-dev curl build-essential git-all gfortran

       - name: Install Rust toolchain
-        uses: actions-rust-lang/setup-rust-toolchain@v1
-        with:
-          rustflags: ''
-          toolchain: nightly
+        uses: dtolnay/rust-toolchain@stable

       - name: Install Python 3.10
         uses: actions/setup-python@v5
7 changes: 3 additions & 4 deletions crates/custom-types/src/llm_ls.rs
@@ -67,10 +67,9 @@ pub enum Backend {
         #[serde(default = "hf_default_url", deserialize_with = "parse_url")]
         url: String,
     },
-    // TODO:
-    // LlamaCpp {
-    //     url: String,
-    // },
+    LlamaCpp {
+        url: String,
+    },
     Ollama {
         url: String,
     },
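
The commented-out TODO becomes a real variant. As a rough sketch of how a configuration might select the new backend once deserialized (the serde tag name and variant casing below are assumptions for illustration, not taken from this diff):

```rust
use serde::Deserialize;

// Reduced mirror of the Backend enum; the real one carries more variants,
// URL defaults, and custom deserializers.
#[derive(Debug, Deserialize)]
#[serde(tag = "backend", rename_all = "lowercase")]
enum Backend {
    LlamaCpp { url: String },
    Ollama { url: String },
}

fn main() -> serde_json::Result<()> {
    // Hypothetical client payload selecting the new backend.
    let config = r#"{ "backend": "llamacpp", "url": "http://localhost:8080" }"#;
    let backend: Backend = serde_json::from_str(config)?;
    println!("{backend:?}"); // LlamaCpp { url: "http://localhost:8080" }
    Ok(())
}
```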
36 changes: 36 additions & 0 deletions crates/llm-ls/src/backend.rs
@@ -67,6 +67,37 @@ fn parse_api_text(text: &str) -> Result<Vec<Generation>> {
     }
 }

+#[derive(Debug, Serialize, Deserialize)]
+struct LlamaCppGeneration {
+    content: String,
+}
+
+impl From<LlamaCppGeneration> for Generation {
+    fn from(value: LlamaCppGeneration) -> Self {
+        Generation {
+            generated_text: value.content,
+        }
+    }
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(untagged)]
+enum LlamaCppAPIResponse {
+    Generation(LlamaCppGeneration),
+    Error(APIError),
+}
+
+fn build_llamacpp_headers() -> HeaderMap {
+    HeaderMap::new()
+}
+
+fn parse_llamacpp_text(text: &str) -> Result<Vec<Generation>> {
+    match serde_json::from_str(text)? {
+        LlamaCppAPIResponse::Generation(gen) => Ok(vec![gen.into()]),
+        LlamaCppAPIResponse::Error(err) => Err(Error::LlamaCpp(err)),
+    }
+}
+
 #[derive(Debug, Serialize, Deserialize)]
 struct OllamaGeneration {
     response: String,
@@ -192,6 +223,9 @@ pub(crate) fn build_body(
                 request_body.insert("parameters".to_owned(), params);
             }
         }
+        Backend::LlamaCpp { .. } => {
+            request_body.insert("prompt".to_owned(), Value::String(prompt));
+        }
         Backend::Ollama { .. } | Backend::OpenAi { .. } => {
             request_body.insert("prompt".to_owned(), Value::String(prompt));
             request_body.insert("model".to_owned(), Value::String(model));
@@ -208,6 +242,7 @@ pub(crate) fn build_headers(
 ) -> Result<HeaderMap> {
     match backend {
         Backend::HuggingFace { .. } => build_api_headers(api_token, ide),
+        Backend::LlamaCpp { .. } => Ok(build_llamacpp_headers()),
         Backend::Ollama { .. } => Ok(build_ollama_headers()),
         Backend::OpenAi { .. } => build_openai_headers(api_token, ide),
         Backend::Tgi { .. } => build_tgi_headers(api_token, ide),
@@ -217,6 +252,7 @@
 pub(crate) fn parse_generations(backend: &Backend, text: &str) -> Result<Vec<Generation>> {
     match backend {
         Backend::HuggingFace { .. } => parse_api_text(text),
+        Backend::LlamaCpp { .. } => parse_llamacpp_text(text),
         Backend::Ollama { .. } => parse_ollama_text(text),
         Backend::OpenAi { .. } => parse_openai_text(text),
         Backend::Tgi { .. } => parse_tgi_text(text),
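
The untagged LlamaCppAPIResponse means a response body either deserializes as a generation or falls through to the shared APIError shape. A minimal standalone sketch of the success path, assuming a typical llama.cpp server payload (the sample JSON is illustrative; real responses carry extra fields that serde skips for this struct):

```rust
use serde::Deserialize;

// Mirror of the struct added in this diff, reduced to what the sketch needs.
#[derive(Debug, Deserialize)]
struct LlamaCppGeneration {
    content: String,
}

fn main() {
    // Illustrative payload: the generated text lives in `content`.
    let body = r#"{ "content": "fn add(a: i32, b: i32) -> i32 { a + b }" }"#;
    let gen: LlamaCppGeneration = serde_json::from_str(body).unwrap();
    assert_eq!(gen.content, "fn add(a: i32, b: i32) -> i32 { a + b }");
}
```

Note that build_llamacpp_headers returns an empty HeaderMap: a llama.cpp server needs no auth header by default, which is why build_headers wraps it in Ok directly.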
34 changes: 24 additions & 10 deletions crates/llm-ls/src/document.rs
@@ -168,7 +168,7 @@ impl TryFrom<Vec<tower_lsp::lsp_types::PositionEncodingKind>> for PositionEncodi
 }

 impl PositionEncodingKind {
-    pub fn to_lsp_type(&self) -> tower_lsp::lsp_types::PositionEncodingKind {
+    pub fn to_lsp_type(self) -> tower_lsp::lsp_types::PositionEncodingKind {
         match self {
             PositionEncodingKind::Utf8 => tower_lsp::lsp_types::PositionEncodingKind::UTF8,
             PositionEncodingKind::Utf16 => tower_lsp::lsp_types::PositionEncodingKind::UTF16,
@@ -205,9 +205,10 @@ impl Document {
     ) -> Result<()> {
         match change.range {
             Some(range) => {
-                if range.start.line < range.end.line
+                if range.start.line > range.end.line
                     || (range.start.line == range.end.line
-                        && range.start.character <= range.end.character) {
+                        && range.start.character > range.end.character)
+                {
                     return Err(Error::InvalidRange(range));
                 }

@@ -219,7 +220,10 @@

                 // 1. Get the line at which the change starts.
                 let change_start_line_idx = range.start.line as usize;
-                let change_start_line = self.text.get_line(change_start_line_idx).ok_or_else(|| Error::OutOfBoundLine(change_start_line_idx, self.text.len_lines()))?;
+                let change_start_line =
+                    self.text.get_line(change_start_line_idx).ok_or_else(|| {
+                        Error::OutOfBoundLine(change_start_line_idx, self.text.len_lines())
+                    })?;

                 // 2. Get the line at which the change ends. (Small optimization
                 // where we first check whether start and end line are the
@@ -228,7 +232,9 @@
                 let change_end_line_idx = range.end.line as usize;
                 let change_end_line = match same_line {
                     true => change_start_line,
-                    false => self.text.get_line(change_end_line_idx).ok_or_else(|| Error::OutOfBoundLine(change_end_line_idx, self.text.len_lines()))?,
+                    false => self.text.get_line(change_end_line_idx).ok_or_else(|| {
+                        Error::OutOfBoundLine(change_end_line_idx, self.text.len_lines())
+                    })?,
                 };

                 fn compute_char_idx(
@@ -330,7 +336,7 @@ impl Document {
                     self.tree = Some(new_tree);
                 }
                 None => {
-                    return Err(Error::TreeSitterParseError);
+                    return Err(Error::TreeSitterParsing);
                 }
             }
         }
@@ -416,7 +422,9 @@ mod test {
         let mut rope = Rope::from_str(
             "let a = '🥸 你好';\rfunction helloWorld() { return '🤲🏿'; }\nlet b = 'Hi, 😊';",
         );
-        let mut doc = Document::open(&LanguageId::JavaScript.to_string(), &rope.to_string()).await.unwrap();
+        let mut doc = Document::open(&LanguageId::JavaScript.to_string(), &rope.to_string())
+            .await
+            .unwrap();
         let mut parser = Parser::new();

         parser
@@ -464,7 +472,9 @@
     #[tokio::test]
     async fn test_text_document_apply_content_change_bounds() {
         let rope = Rope::from_str("");
-        let mut doc = Document::open(&LanguageId::Unknown.to_string(), &rope.to_string()).await.unwrap();
+        let mut doc = Document::open(&LanguageId::Unknown.to_string(), &rope.to_string())
+            .await
+            .unwrap();

         assert!(doc
             .apply_content_change(new_change!(0, 0, 0, 1, ""), PositionEncodingKind::Utf16)
@@ -513,7 +523,9 @@
     async fn test_document_update_tree_consistency_easy() {
         let a = "let a = '你好';\rlet b = 'Hi, 😊';";

-        let mut document = Document::open(&LanguageId::JavaScript.to_string(), a).await.unwrap();
+        let mut document = Document::open(&LanguageId::JavaScript.to_string(), a)
+            .await
+            .unwrap();

         document
             .apply_content_change(new_change!(0, 9, 0, 11, "𐐀"), PositionEncodingKind::Utf16)
@@ -541,7 +553,9 @@
     async fn test_document_update_tree_consistency_medium() {
         let a = "let a = '🥸 你好';\rfunction helloWorld() { return '🤲🏿'; }\nlet b = 'Hi, 😊';";

-        let mut document = Document::open(&LanguageId::JavaScript.to_string(), a).await.unwrap();
+        let mut document = Document::open(&LanguageId::JavaScript.to_string(), a)
+            .await
+            .unwrap();

         document
             .apply_content_change(new_change!(0, 14, 2, 13, ","), PositionEncodingKind::Utf16)
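
Beyond the rustfmt-style reflows, the first hunk in this file is a genuine bug fix: the old predicate flagged well-formed ranges as invalid and let inverted ones through. The corrected check, condensed into a standalone sketch (positions simplified from the LSP types):

```rust
// A range is invalid only when its start sits after its end; a zero-width
// range (start == end, i.e. a pure insertion point) is valid.
struct Pos {
    line: u32,
    character: u32,
}

fn is_invalid(start: &Pos, end: &Pos) -> bool {
    start.line > end.line || (start.line == end.line && start.character > end.character)
}

fn main() {
    let a = Pos { line: 1, character: 5 };
    assert!(!is_invalid(&a, &a)); // insertion point: valid
    let b = Pos { line: 0, character: 9 };
    assert!(is_invalid(&a, &b)); // start after end: rejected
}
```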
6 changes: 4 additions & 2 deletions crates/llm-ls/src/error.rs
@@ -33,6 +33,8 @@ pub enum Error {
     InvalidRepositoryId,
     #[error("invalid tokenizer path")]
     InvalidTokenizerPath,
+    #[error("llama.cpp error: {0}")]
+    LlamaCpp(crate::backend::APIError),
     #[error("ollama error: {0}")]
     Ollama(crate::backend::APIError),
     #[error("openai error: {0}")]
@@ -50,7 +52,7 @@ pub enum Error {
     #[error("tgi error: {0}")]
     Tgi(crate::backend::APIError),
     #[error("tree-sitter parse error: timeout possibly exceeded")]
-    TreeSitterParseError,
+    TreeSitterParsing,
     #[error("tree-sitter language error: {0}")]
     TreeSitterLanguage(#[from] tree_sitter::LanguageError),
     #[error("tokenizer error: {0}")]
@@ -60,7 +62,7 @@ pub enum Error {
     #[error("unknown backend: {0}")]
     UnknownBackend(String),
     #[error("unknown encoding kind: {0}")]
-    UnknownEncodingKind(String)
+    UnknownEncodingKind(String),
 }

 pub(crate) type Result<T> = std::result::Result<T, Error>;
14 changes: 13 additions & 1 deletion crates/llm-ls/src/main.rs
@@ -417,6 +417,17 @@ async fn get_tokenizer(
 fn build_url(backend: Backend, model: &str) -> String {
     match backend {
         Backend::HuggingFace { url } => format!("{url}/models/{model}"),
+        Backend::LlamaCpp { mut url } => {
+            if url.ends_with("/completions") {
+                url
+            } else if url.ends_with('/') {
+                url.push_str("completions");
+                url
+            } else {
+                url.push_str("/completions");
+                url
+            }
+        }
         Backend::Ollama { url } => url,
         Backend::OpenAi { url } => url,
         Backend::Tgi { url } => url,
@@ -540,7 +551,8 @@ impl LanguageServer for LlmService {
                 general_capabilities
                     .position_encodings
                     .map(TryFrom::try_from)
-            }).unwrap_or(Ok(document::PositionEncodingKind::Utf16))?;
+            })
+            .unwrap_or(Ok(document::PositionEncodingKind::Utf16))?;

         *self.position_encoding.write().await = position_encoding;

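The llama.cpp arm of build_url normalizes whatever endpoint the user configured so that all three spellings resolve to the same /completions path. Extracted as a free function for a quick check (the localhost URL is just an example):

```rust
// Same logic as the new match arm, lifted out of build_url.
fn normalize_llamacpp_url(mut url: String) -> String {
    if url.ends_with("/completions") {
        url
    } else if url.ends_with('/') {
        url.push_str("completions");
        url
    } else {
        url.push_str("/completions");
        url
    }
}

fn main() {
    for input in [
        "http://localhost:8080",
        "http://localhost:8080/",
        "http://localhost:8080/completions",
    ] {
        assert_eq!(
            normalize_llamacpp_url(input.to_owned()),
            "http://localhost:8080/completions"
        );
    }
}
```

Note the model name is deliberately unused here, unlike in the HuggingFace arm: a llama.cpp server serves whichever model it was launched with.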
12 changes: 6 additions & 6 deletions crates/testbed/repositories-ci.yaml
@@ -2,10 +2,10 @@
 context_window: 2000
 fim:
   enabled: true
-  prefix: <fim_prefix>
-  middle: <fim_middle>
-  suffix: <fim_suffix>
-model: bigcode/starcoder
+  prefix: "<PRE> "
+  middle: " <MID>"
+  suffix: " <SUF>"
+model: codellama/CodeLlama-13b-hf
 backend: huggingface
 request_body:
   max_new_tokens: 150
@@ -14,8 +14,8 @@ request_body:
   top_p: 0.95
 tls_skip_verify_insecure: false
 tokenizer_config:
-  repository: bigcode/starcoder
-  tokens_to_clear: ["<|endoftext|>"]
+  repository: codellama/CodeLlama-13b-hf
+  tokens_to_clear: ["<EOT>"]
 repositories:
 - source:
     type: local
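
The CI config swaps StarCoder's FIM sentinels for CodeLlama's. For context, a hedged sketch of how the three configured strings compose into a single infill prompt (the exact assembly lives in llm-ls itself; prefix + suffix + middle is the CodeLlama infilling convention):

```rust
// Sketch only: compose a CodeLlama-style fill-in-the-middle prompt from
// the sentinels configured above.
fn build_fim_prompt(code_before: &str, code_after: &str) -> String {
    let prefix = "<PRE> ";
    let suffix = " <SUF>";
    let middle = " <MID>";
    format!("{prefix}{code_before}{suffix}{code_after}{middle}")
}

fn main() {
    let prompt = build_fim_prompt("fn add(a: i32, b: i32) -> i32 {", "}");
    assert_eq!(prompt, "<PRE> fn add(a: i32, b: i32) -> i32 { <SUF>} <MID>");
    // The model generates the middle span and emits <EOT> when done, which
    // is why tokens_to_clear lists it.
}
```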