Feature(LLMLingua): add KV-Cache Compression & HF Space Demo
iofu728 committed Oct 9, 2023
1 parent b44fd13 commit bd31e43
Showing 3 changed files with 69 additions and 11 deletions.
8 changes: 5 additions & 3 deletions README.md
@@ -2,20 +2,22 @@
<img src="images/LLMLingua_logo.png" alt="LLMLingua" style="width: 20%; min-width: 100px; display: block; margin: auto;">
</p>

# LLMLingua: Compressing Prompts for Accelerated Inference of Large Language Models [[paper]()] & LongLLMLingua [[paper]()]
# LLMLingua: Compressing Prompts for Accelerated Inference of Large Language Models [paper] & LongLLMLingua [paper]

https://github.com/microsoft/LLMLingua/assets/30883354/ef52995c-ef3c-4eac-a9fd-1acb491c325b

You can try the LLMLingua demo on [HF Space](https://huggingface.co/spaces/microsoft/LLMLingua).

## TL;DR

LLMLingua uses a well-trained small language model after alignment, such as GPT2-small or LLaMA-7B, to detect unimportant tokens in the prompt, enabling inference with the compressed prompt in black-box LLMs and achieving up to 20x compression with minimal performance loss.

[LLMLingua: Compressing Prompts for Accelerated Inference of Large Language Models]() (EMNLP 2023).
LLMLingua: Compressing Prompts for Accelerated Inference of Large Language Models (EMNLP 2023).<br>
_Huiqiang Jiang, Qianhui Wu, Chin-Yew Lin, Yuqing Yang and Lili Qiu_

LongLLMLingua is a method that enhances LLMs' ability to perceive key information in long-context scenarios using prompt compression, achieving up to $28.5 in cost savings per 1,000 samples while also improving performance.

[LongLLMLingua: Accelerating and Enhancing LLMs in Long Context Scenarios via Prompt Compression]() (Under Review).
LongLLMLingua: Accelerating and Enhancing LLMs in Long Context Scenarios via Prompt Compression (Under Review).<br>
_Huiqiang Jiang, Qianhui Wu, Xufang Luo, Dongsheng Li, Chin-Yew Lin, Yuqing Yang and Lili Qiu_

## 🎥 Overview
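Before the implementation diff, here is a minimal usage sketch of the compressor this commit extends. It assumes the package exposes `PromptCompressor` from `prompt_compressor.py` below and that `compress_prompt` accepts these keyword arguments; the `target_token` parameter and the return fields are assumptions, not confirmed by this commit.

```python
# Sketch only: class/method names follow prompt_compressor.py in this diff;
# parameter defaults and return fields are assumed, not part of the commit.
from llmlingua import PromptCompressor

llm_lingua = PromptCompressor()  # loads a small causal LM (e.g. a LLaMA-7B-class model)
result = llm_lingua.compress_prompt(
    context=["<long demonstration or document text>"],
    instruction="Answer the question based on the given context.",
    question="What does LLMLingua do?",
    target_token=200,  # assumed compression budget parameter
)
print(result["compressed_prompt"])  # assumed return field
```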
70 changes: 63 additions & 7 deletions llmlingua/prompt_compressor.py
@@ -24,12 +24,13 @@ def __init__(
self.load_model(model_name, device_map, use_auth_token)
self.sbert = None
self.open_api_config = open_api_config
self.cache_bos_num = 10

def load_model(
self, model_name: str, device_map: str = "cuda", use_auth_token: bool = False
):
config = AutoConfig.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.padding_side = "left"
tokenizer.pad_token_id = (
config.pad_token_id if config.pad_token_id else tokenizer.eos_token_id
@@ -40,12 +41,11 @@ def load_model(
if "cuda" in device_map or "cpu" in device_map:
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype="auto",
torch_dtype="auto" if device_map == "cuda" else torch.float32,
config=config,
ignore_mismatched_sizes=True,
trust_remote_code=True,
).to(device_map)
if device_map == "cpu":
model = model.type(torch.float32)
else:
model = AutoModelForCausalLM.from_pretrained(
model_name,
@@ -56,10 +56,12 @@
offload_state_dict=True,
cache_dir="/tmp/cache",
use_auth_token=use_auth_token,
trust_remote_code=True,
)
self.tokenizer = tokenizer
self.model = model
self.context_idxs = []
self.max_position_embeddings = config.max_position_embeddings

def get_ppl(
self,
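As a standalone illustration of the loading changes in the hunks above, a sketch under the assumption that `transformers` is installed: `trust_remote_code=True` is passed through so custom model code can load, and fp32 is chosen up front on CPU instead of casting the model afterwards. The helper name is hypothetical.

```python
import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

def load_causal_lm(model_name: str, device_map: str = "cuda"):
    # Mirrors the pattern above: choose the dtype at load time rather than
    # casting to float32 after moving the model to CPU.
    config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    tokenizer.padding_side = "left"
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype="auto" if device_map == "cuda" else torch.float32,
        config=config,
        trust_remote_code=True,
    ).to(device_map)
    return config, tokenizer, model
```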
@@ -83,7 +85,7 @@ def get_ppl(
past_length = 0
if end is None:
end = input_ids.shape[1]
end = min(end, past_length + 4096)
end = min(end, past_length + self.max_position_embeddings)
with torch.no_grad():
response = self.model(
input_ids[:, past_length:end],
@@ -145,11 +147,17 @@ def compress_prompt(
assert not (
rank_method == "longllmlingua" and not question
), "In the LongLLMLingua, it is necessary to set a question."
if condition_compare and "_condition" not in condition_in_question:
condition_in_question += "_condition"
if rank_method == "longllmlingua":
if condition_in_question == "none":
condition_in_question = "after"
elif rank_method == "llmlingua":
condition_in_question = "none"
condition_in_question = (
"none"
if "_condition" not in condition_in_question
else "none_condition"
)
origin_tokens = len(
encoding.encode("\n\n".join([instruction] + context + [question]).strip())
)
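Continuing the usage sketch from the README section, a call that satisfies the assertion above supplies a question whenever `rank_method` is `"longllmlingua"`; the argument names come from the surrounding diff, while the concrete values are illustrative.

```python
# Sketch only: llm_lingua and long_documents come from the earlier example;
# condition_compare=True makes the code above append "_condition" internally.
result = llm_lingua.compress_prompt(
    context=long_documents,                 # list of retrieved passages
    question="Which passage answers the user's query?",
    rank_method="longllmlingua",            # question-aware ranking
    condition_in_question="after",
    condition_compare=True,
)
```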
@@ -653,8 +661,52 @@ def iterative_compress_prompt(
keep_flag = torch.tensor(keep_flag).to(self.device)
past_key_values, past_loss, ready_end = None, None, 0
self_past_key_values, self_past_loss, self_ready_end = None, None, 0
pop_compressed_input_ids, pop_self_compressed_input_ids = None, None
idx = 0
while end <= compressed_input_ids.shape[1]:
if end > self.max_position_embeddings and past_key_values is not None:
# KV-Cache Compression
e, s = end - self.max_position_embeddings, self.cache_bos_num
if pop_compressed_input_ids is None:
pop_compressed_input_ids = compressed_input_ids[:, :e]
else:
pop_compressed_input_ids = torch.cat(
[pop_compressed_input_ids, compressed_input_ids[:, :e]], dim=-1
)
compressed_input_ids = compressed_input_ids[:, e:]
compressed_attention_mask = compressed_attention_mask[:, e:]
past_key_values = [
[
torch.cat([k[..., :s, :], k[..., s + e :, :]], dim=-2),
torch.cat([v[..., :s, :], v[..., s + e :, :]], dim=-2),
]
for k, v in past_key_values
]
end, ready_end = end - e, ready_end - e
if condition_compare:
self_ready_end -= e
if pop_self_compressed_input_ids is None:
pop_self_compressed_input_ids = self_compressed_input_ids[:, :e]
else:
pop_self_compressed_input_ids = torch.cat(
[
pop_self_compressed_input_ids,
self_compressed_input_ids[:, :e],
],
dim=-1,
)
self_compressed_input_ids = self_compressed_input_ids[:, e:]
self_compressed_attention_mask = self_compressed_attention_mask[
:, e:
]
self_past_key_values = [
[
torch.cat([k[..., :s, :], k[..., s + e :, :]], dim=-2),
torch.cat([v[..., :s, :], v[..., s + e :, :]], dim=-2),
]
for k, v in self_past_key_values
]

loss, past_key_values = self.get_ppl(
"",
"token",
@@ -762,6 +814,10 @@ def iterative_compress_prompt(
)
end += iterative_size
idx += 1
if pop_compressed_input_ids is not None:
compressed_input_ids = torch.cat(
[pop_compressed_input_ids, compressed_input_ids], dim=-1
)
return compressed_input_ids[:, start:], compressed_attention_mask[:, start:]

def recover(
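The core of the new KV-Cache Compression path in `iterative_compress_prompt` is a sliding-window trim: once the processed length exceeds the model's `max_position_embeddings`, the oldest `e` cached positions are dropped while the first `cache_bos_num` entries are kept, and the popped input ids are stitched back on at the end. Below is a self-contained sketch of just the trimming step; the helper name and toy shapes are assumptions, not part of the library.

```python
import torch

def trim_kv_cache(past_key_values, overflow: int, keep_bos: int = 10):
    # Drop `overflow` of the oldest cached positions along the sequence
    # dimension (-2), keeping the first `keep_bos` entries (the BOS region).
    return [
        [
            torch.cat([k[..., :keep_bos, :], k[..., keep_bos + overflow:, :]], dim=-2),
            torch.cat([v[..., :keep_bos, :], v[..., keep_bos + overflow:, :]], dim=-2),
        ]
        for k, v in past_key_values
    ]

# Toy check: one layer, batch=1, 2 heads, 16 cached positions, head_dim=4.
kv = [[torch.randn(1, 2, 16, 4), torch.randn(1, 2, 16, 4)]]
trimmed = trim_kv_cache(kv, overflow=6, keep_bos=2)
print(trimmed[0][0].shape)  # torch.Size([1, 2, 10, 4]); 6 positions removed
```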
2 changes: 1 addition & 1 deletion llmlingua/version.py
@@ -2,7 +2,7 @@
_MINOR = "1"
# On master and in a nightly release the patch should be one ahead of the last
# released build.
_PATCH = "1"
_PATCH = "2"
# This is mainly for nightly builds which have the suffix ".dev$DATE". See
# https://semver.org/#is-v123-a-semantic-version for the semantics.
_SUFFIX = ""
