Skip to content

Commit

Permalink
feat: add strategies for building query embedding vector
Browse files Browse the repository at this point in the history
  • Loading branch information
McPatate committed Feb 28, 2024
1 parent baedf85 commit 58f814d
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 13 deletions.
20 changes: 18 additions & 2 deletions crates/llm-ls/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ use uuid::Uuid;
use crate::backend::{build_body, build_headers, parse_generations};
use crate::document::Document;
use crate::error::{internal_error, Error, Result};
use crate::retrieval::BuildFrom;

mod backend;
mod config;
Expand Down Expand Up @@ -238,11 +239,21 @@ async fn build_prompt(
after_line = after_iter.next();
}
let before = before.into_iter().rev().collect::<Vec<_>>().join("");
let query = snippet_retriever
.read()
.await
.build_query(
format!("{before}{after}"),
BuildFrom::Cursor {
cursor_position: before.len(),
},
)
.await?;
let snippets = snippet_retriever
.read()
.await
.search(
format!("{before}{after}"),
&query,
Some(FilterBuilder::new().comparison(
"file_url".to_owned(),
Compare::Neq,
Expand Down Expand Up @@ -281,11 +292,16 @@ async fn build_prompt(
before.push(line);
}
let prompt = before.into_iter().rev().collect::<Vec<_>>().join("");
let query = snippet_retriever
.read()
.await
.build_query(prompt.clone(), BuildFrom::End)
.await?;
let snippets = snippet_retriever
.read()
.await
.search(
prompt.clone(),
&query,
Some(FilterBuilder::new().comparison(
"file_url".to_owned(),
Compare::Neq,
Expand Down
62 changes: 51 additions & 11 deletions crates/llm-ls/src/retrieval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -377,29 +377,60 @@ impl SnippetRetriever {
Ok(())
}

pub(crate) async fn search(
/// Builds the query embedding vector for `snippet`.
///
/// `strategy` decides which part of the tokenized text survives when it is longer than
/// `model_config.max_input_size` tokens:
///
/// * [`BuildFrom::Start`] truncates with `TruncationDirection::Right`.
/// * [`BuildFrom::End`] truncates with `TruncationDirection::Left`.
/// * [`BuildFrom::Cursor`] splits the text at `cursor_position` (a byte offset), gives each
///   side half of the token budget, truncates the text before the cursor on the left and the
///   text after it on the right, then merges both halves into one encoding.
///
/// A `cursor_position` that lies past the end of `snippet`, or inside a multi-byte UTF-8
/// character, is clamped down to the nearest valid character boundary instead of panicking.
///
/// # Errors
///
/// Propagates any error returned by the tokenizer's `encode` or by `generate_embedding`.
pub(crate) async fn build_query(
    &self,
    snippet: String,
    strategy: BuildFrom,
) -> Result<Vec<f32>> {
    let max_input_size = self.model_config.max_input_size;
    match strategy {
        BuildFrom::Start => {
            // `snippet` is owned and not used afterwards, so it is moved rather than cloned.
            let mut encoding = self.tokenizer.encode(snippet, true)?;
            encoding.truncate(max_input_size, 1, TruncationDirection::Right);
            self.generate_embedding(encoding, self.model.clone()).await
        }
        BuildFrom::Cursor { cursor_position } => {
            // `str::split_at` panics on an out-of-range offset or one that is not on a char
            // boundary; walk back to the closest valid boundary (offset 0 always is one).
            let mut split_index = cursor_position.min(snippet.len());
            while !snippet.is_char_boundary(split_index) {
                split_index -= 1;
            }
            let (before, after) = snippet.split_at(split_index);
            // NOTE(review): both halves are encoded with special tokens enabled, so the merged
            // sequence may contain them twice — confirm this is intended.
            let mut before_encoding = self.tokenizer.encode(before, true)?;
            let mut after_encoding = self.tokenizer.encode(after, true)?;
            // Each side of the cursor gets half of the model's input budget.
            let share = max_input_size / 2;
            before_encoding.truncate(share, 1, TruncationDirection::Left);
            after_encoding.truncate(share, 1, TruncationDirection::Right);
            // Drop the truncated-off overflow so it is not carried along by the merge.
            before_encoding.take_overflowing();
            after_encoding.take_overflowing();
            before_encoding.merge_with(after_encoding, false);
            self.generate_embedding(before_encoding, self.model.clone())
                .await
        }
        BuildFrom::End => {
            // `snippet` is owned and not used afterwards, so it is moved rather than cloned.
            let mut encoding = self.tokenizer.encode(snippet, true)?;
            encoding.truncate(max_input_size, 1, TruncationDirection::Left);
            self.generate_embedding(encoding, self.model.clone()).await
        }
    }
}

pub(crate) async fn search(
&self,
query: &[f32],
filter: Option<FilterBuilder>,
) -> Result<Vec<Snippet>> {
let db = match self.db.as_ref() {
Some(db) => db.clone(),
None => return Err(Error::UninitialisedDatabase),
};
let col = db.get_collection(&self.collection_name).await?;
let mut encoding = self.tokenizer.encode(snippet.clone(), true)?;
encoding.truncate(
self.model_config.max_input_size,
1,
TruncationDirection::Right,
);
let query = self
.generate_embedding(encoding, self.model.clone())
.await?;
let result = col
.read()
.await
.get(&query, 5, filter)
.get(query, 5, filter)
.await?
.iter()
.map(TryInto::try_into)
Expand Down Expand Up @@ -537,3 +568,12 @@ impl SnippetRetriever {
Ok(())
}
}

/// Strategy used by `SnippetRetriever::build_query` to pick which part of a snippet is kept
/// when its tokenized form exceeds the model's maximum input size.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum BuildFrom {
    /// Keep the tokens surrounding the cursor: the text before the cursor is truncated on the
    /// left, the text after it on the right, and each side gets half of the token budget.
    Cursor {
        /// Byte offset of the cursor inside the snippet (the snippet is split with
        /// `str::split_at` at this position).
        cursor_position: usize,
    },
    /// Keep the end of the snippet (the encoding is truncated on the left).
    End,
    /// Keep the start of the snippet (the encoding is truncated on the right).
    // NOTE(review): this variant appears to have no caller yet, hence the lint suppression —
    // confirm and drop the `allow` once it is used.
    #[allow(dead_code)]
    Start,
}

0 comments on commit 58f814d

Please sign in to comment.