Skip to content

Commit

Permalink
add feature: compress_json (#120)
Browse files Browse the repository at this point in the history
* fix sentence-filter adding separator bug and add document for structured_compress_prompt
* update unittest data
* fix one-sentence bug
* update readme
* update compress_json feature
* update compress_json document
* change format
  • Loading branch information
SiyunZhao authored Mar 26, 2024
1 parent c3c7001 commit 60abc0f
Show file tree
Hide file tree
Showing 3 changed files with 240 additions and 4 deletions.
44 changes: 43 additions & 1 deletion DOCUMENT.md
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ llm_lingua = PromptCompressor("TheBloke/Llama-2-7b-Chat-GPTQ", model_config={"re

### Structured Prompt Compression

Split text into sections, decide on whether to compress and its rate. Use `<llmlingua></llmlingua>` tags for context segmentation, with optional rate and compress parameters.
Split text into sections, decide on whether to compress and its rate. Use `<llmlingua></llmlingua>` tags for context segmentation, with optional `rate` and `compress` parameters.

```python
structured_prompt = """<llmlingua, compress=False>Speaker 4:</llmlingua><llmlingua, rate=0.4> Thank you. And can we do the functions for content? Items I believe are 11, three, 14, 16 and 28, I believe.</llmlingua><llmlingua, compress=False>
Expand All @@ -117,6 +117,48 @@ print(compressed_prompt['compressed_prompt'])
# Speaker 4: We have a promotion and a second time as councilman served Councilman Ringa and customers and they have any comments.
```

### Compress JSON Data

You can specify the compression method for each key and value by passing a config dictionary or the path to a YAML config file. Each key must include four parameters: `rate` indicates the compression ratio for the corresponding value, `compress` indicates whether the corresponding value is compressed, `value_type` indicates the data type of the value, and `pair_remove` indicates whether the key-value pair can be completely deleted.

```python
json_data = {
"id": 987654,
"name": "John Doe",
"skills": ["Java","Python","Machine Learning","Cloud Computing","AI Development"],
"biography": "John Doe, born in New York in 1985, is a renowned software engineer with over 10 years of experience in the field. John graduated from MIT with a degree in Computer Science and has since worked with several Fortune 500 companies. He has a passion for developing innovative software solutions and has contributed to numerous open source projects. John is also an avid writer and speaker at tech conferences, sharing his insights on emerging technologies and their impact on the business world. In his free time, John enjoys hiking, reading science fiction novels, and playing the piano. At TechCorp, John was responsible for leading a team of software engineers and overseeing the development of scalable web applications. He played a key role in driving the adoption of cloud technologies within the company, significantly enhancing the efficiency of their digital operations. In his John on developingedge AI and implementing machine learning solutions for various business applications. He was instrumental in developing a predictive analytics tool that transformed the company's approach to data-driven decision making."
}
json_config = {
"id": {
"rate": 1,
"compress": False,
"value_type": "int",
"pair_remove": True
},
"name": {
"rate": 0.7,
"compress": False,
"value_type": "str",
"pair_remove": False
},
"skills": {
"rate": 0.2,
"compress": True,
"value_type": "list",
"pair_remove": True
},
"biography": {
"rate": 0.3,
"compress": True,
"value_type": "str",
"pair_remove": True
}
}
compressed_prompt = llm_lingua.compress_json(json_data, json_config, use_keyvalue_level_filter=True)
print(compressed_prompt['compressed_prompt'])
# > {'id': 987654, 'name': 'John Doe', 'skills': ['', '', '', '', 'AI'], 'biography': ",York in a has several for developing has avid and speaker at,on and enjoys reading fiction playing. At Tech John for and of scalable He in the of cloud technologies,significantly enhancing the efficiency of their digital operations. In his John on developingedge AI and implementing machine learning solutions for various business applications. He was instrumental in developing a predictive analytics tool that transformed the company's approach to data-driven decision making."}
```

### Integration with LangChain

Thanks to the contributions of Ayo Ayibiowu (@thehapyone), (Long)LLMLingua can be seamlessly integrated into LangChain. Here's an example of how to initialize (Long)LLMLingua within LangChain:
Expand Down
73 changes: 70 additions & 3 deletions llmlingua/prompt_compressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@

import bisect
import copy
import json
import re
import string
from collections import defaultdict
from typing import List
from typing import List, Union

import nltk
import numpy as np
Expand All @@ -25,6 +26,8 @@
TokenClfDataset,
get_pure_token,
is_begin_of_new_word,
process_structured_json_data,
remove_consecutive_commas,
replace_added_token,
seed_everything,
)
Expand Down Expand Up @@ -207,6 +210,65 @@ def get_ppl(
def __call__(self, *args, **kwargs):
return self.compress_prompt(*args, **kwargs)

def compress_json(
    self,
    json_data: dict,
    json_config: Union[str, dict],
    instruction: str = "",
    question: str = "",
    rate: float = 0.5,
    target_token: float = -1,
    iterative_size: int = 200,
    use_sentence_level_filter: bool = False,
    use_keyvalue_level_filter: bool = False,
    use_token_level_filter: bool = True,
    keep_split: bool = False,
    keep_first_sentence: int = 0,
    keep_last_sentence: int = 0,
    keep_sentence_number: int = 0,
    high_priority_bonus: int = 100,
    context_budget: str = "+100",
    token_budget_ratio: float = 1.4,
    condition_in_question: str = "none",
    reorder_keyvalue: str = "original",
    condition_compare: bool = False,
    rank_method: str = "llmlingua",
):
    """Compress a JSON object according to a per-key compression config.

    The JSON data is first converted into <llmlingua>-tagged context
    segments (one per key-value pair) by ``process_structured_json_data``,
    which also reports which segments must be kept. Those segments are
    then compressed with ``structured_compress_prompt`` (key-value level
    filtering maps onto its context-level filter), stray commas left by
    removed pairs are collapsed, and the compressed text is parsed back
    into a Python object.

    Args:
        json_data: The JSON object (as a dict) to compress.
        json_config: Per-key settings — either a dict or a path to a YAML
            file. Each key's entry provides ``rate``, ``compress``,
            ``value_type`` and ``pair_remove``.
        reorder_keyvalue: Forwarded as ``reorder_context`` to control the
            ordering of surviving key-value pairs.
        (Remaining arguments are forwarded unchanged to
        ``structured_compress_prompt``.)

    Returns:
        The dict returned by ``structured_compress_prompt``, with
        ``compressed_prompt`` replaced by the decoded JSON object.

    Raises:
        json.JSONDecodeError: If the compressed text is no longer valid
            JSON after comma cleanup.
    """
    tagged_context, forced_ids = process_structured_json_data(
        json_data, json_config
    )
    result = self.structured_compress_prompt(
        context=tagged_context,
        instruction=instruction,
        question=question,
        rate=rate,
        target_token=target_token,
        iterative_size=iterative_size,
        force_context_ids=forced_ids,
        use_sentence_level_filter=use_sentence_level_filter,
        use_context_level_filter=use_keyvalue_level_filter,
        use_token_level_filter=use_token_level_filter,
        keep_split=keep_split,
        keep_first_sentence=keep_first_sentence,
        keep_last_sentence=keep_last_sentence,
        keep_sentence_number=keep_sentence_number,
        high_priority_bonus=high_priority_bonus,
        context_budget=context_budget,
        token_budget_ratio=token_budget_ratio,
        condition_in_question=condition_in_question,
        reorder_context=reorder_keyvalue,
        condition_compare=condition_compare,
        add_instruction=False,
        rank_method=rank_method,
        concate_question=False,
        strict_preserve_uncompressed=False,
    )
    # Removed pairs leave dangling separators; collapse them so the text
    # round-trips through json.loads.
    result["compressed_prompt"] = json.loads(
        remove_consecutive_commas(result["compressed_prompt"])
    )
    return result

def structured_compress_prompt(
self,
context: List[str],
Expand Down Expand Up @@ -234,6 +296,7 @@ def structured_compress_prompt(
add_instruction: bool = False,
rank_method: str = "llmlingua",
concate_question: bool = True,
strict_preserve_uncompressed: bool = True,
):
"""
Compresses the given prompt context based on a specified structure.
Expand Down Expand Up @@ -355,6 +418,7 @@ def structured_compress_prompt(
context_segs=context_segs,
context_segs_rate=context_segs_rate,
context_segs_compress=context_segs_compress,
strict_preserve_uncompressed=strict_preserve_uncompressed,
)

def compress_prompt(
Expand Down Expand Up @@ -398,6 +462,7 @@ def compress_prompt(
force_reserve_digit: bool = False,
drop_consecutive: bool = False,
chunk_end_tokens: List[str] = [".", "\n"],
strict_preserve_uncompressed: bool = True,
):
"""
Compresses the given context.
Expand Down Expand Up @@ -547,6 +612,7 @@ def compress_prompt(
context_segs=context_segs,
context_segs_rate=context_segs_rate,
context_segs_compress=context_segs_compress,
strict_preserve_uncompressed=strict_preserve_uncompressed,
)
if context_segs is not None:
context_segs = [context_segs[idx] for idx in context_used]
Expand Down Expand Up @@ -1119,6 +1185,7 @@ def control_context_budget(
context_segs: List[List[str]] = None,
context_segs_rate: List[List[float]] = None,
context_segs_compress: List[List[bool]] = None,
strict_preserve_uncompressed: bool = True,
):
demostrations_sort = self.get_rank_results(
context,
Expand All @@ -1133,9 +1200,9 @@ def control_context_budget(
target_token = eval("target_token" + context_budget)
res = []
used = force_context_ids if force_context_ids is not None else []
if context_segs is not None:
if context_segs is not None and strict_preserve_uncompressed:
for idx, _ in enumerate(context):
if False in context_segs_compress[idx]:
if False in context_segs_compress[idx] and idx not in used:
used.append(idx)

self.context_idxs.append([x for idx, (x, _) in enumerate(demostrations_sort)])
Expand Down
127 changes: 127 additions & 0 deletions llmlingua/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import json
import os
import random
import re
import string

import numpy as np
import torch
import yaml
from torch.utils.data import Dataset


Expand Down Expand Up @@ -107,3 +110,127 @@ def get_pure_token(token, model_name):
return token.lstrip("▁")
else:
raise NotImplementedError()


def process_structured_json_data(json_data, json_config):
    """Convert a JSON object into a list of <llmlingua>-tagged segments.

    Produces one segment per key-value pair, preceded by an uncompressed
    "{" segment and followed by an uncompressed "}" segment. Also collects
    the indices of segments that must never be dropped: the braces plus
    every pair whose config has ``pair_remove`` set to False.

    Args:
        json_data: The JSON object (dict) to render.
        json_config: Per-key settings dict, or a path to a YAML file
            containing one. Keys must match ``json_data`` exactly.

    Returns:
        A ``(context, forced_context_ids)`` tuple: the tagged segment list
        and the indices of segments that must be preserved.

    Raises:
        ValueError: If ``json_config`` is neither a dict nor a string path.
        AssertionError: If the key sets of data and config differ.
    """
    if isinstance(json_config, str):
        # A string config is treated as the path to a YAML file.
        with open(json_config, "r") as file:
            json_config = yaml.safe_load(file)
    elif not isinstance(json_config, dict):
        raise ValueError(
            "Invalid json config file. It should be a dictionary or a path to a yaml file."
        )
    assert set(json_data.keys()) == set(
        json_config.keys()
    ), "Keys in json data and json config file do not match."
    context = ["<llmlingua, compress=False>{</llmlingua>"]
    forced_context_ids = [0]
    for idx, (key, value) in enumerate(json_data.items(), start=1):
        cfg = json_config[key]
        if not cfg["pair_remove"]:
            forced_context_ids.append(idx)
        # A pair marked compress=False is rendered at rate 1 regardless of
        # its configured rate.
        pair_rate = cfg["rate"] if cfg["compress"] else 1
        context.append(precess_jsonKVpair(key, value, cfg["value_type"], pair_rate))
    # Strip the trailing ", " separator (which sits inside the closing tag)
    # from the last pair, then restore the closing tag.
    context[-1] = context[-1][: -len(", </llmlingua>")] + "</llmlingua>"
    context.append("<llmlingua, compress=False>}</llmlingua>")
    forced_context_ids.append(len(json_data) + 1)

    return context, forced_context_ids


def precess_jsonKVpair(k, v, value_type, rate):
    """Render one JSON key-value pair as an <llmlingua>-tagged segment.

    The key (and JSON punctuation) is always wrapped in a compress=False
    tag; the value is wrapped in a ``rate=<rate>`` tag so only the value
    text is eligible for compression. Every returned segment ends with a
    ", " separator inside its closing tag.

    Args:
        k: The JSON key.
        v: The value; coerced according to ``value_type``.
        value_type: One of "str"/"string", "int"/"integer",
            "float"/"number", "bool"/"boolean", "list"/"List",
            "dict"/"dictionary" or "tuple". "set" is rejected.
        rate: Compression rate for the value; ``rate == 1`` emits the pair
            fully uncompressed.

    Raises:
        ValueError: For an unrecognized ``value_type``, for "set", or for
            a value that cannot be interpreted as a boolean.
    """
    # Fully uncompressed pair: serialize it verbatim inside one tag.
    if rate == 1:
        pair_text = json.dumps({k: v})[1:-1]
        return f"<llmlingua, compress=False>{pair_text}, </llmlingua>"
    if value_type in ("str", "string"):
        wrapped = (
            f"</llmlingua><llmlingua, rate={rate}>{str(v)}"
            "</llmlingua><llmlingua, compress=False>"
        )
        pair_text = json.dumps({k: wrapped})[1:-1]
        return f"<llmlingua, compress=False>{pair_text}, </llmlingua>"
    if value_type in ("int", "float", "integer", "number"):
        num = int(v) if value_type in ("int", "integer") else float(v)
        return (
            "<llmlingua, compress=False>"
            f'"{k}": </llmlingua><llmlingua, rate={rate}>{num}'
            "</llmlingua><llmlingua, compress=False>, </llmlingua>"
        )
    if value_type in ("bool", "boolean"):
        if v in ("True", "true", "TRUE", True):
            flag = "true"
        elif v in ("False", "false", "FALSE", False):
            flag = "false"
        else:
            raise ValueError(f"Invalid boolean value: {v}")
        wrapped = (
            f"</llmlingua><llmlingua, rate={rate}>{flag}"
            "</llmlingua><llmlingua, compress=False>"
        )
        pair_text = json.dumps({k: wrapped})[1:-1]
        return f"<llmlingua, compress=False>{pair_text}, </llmlingua>"
    if value_type in ("list", "List"):
        return (
            "<llmlingua, compress=False>"
            + f'"{k}": {process_sequence_data(rate, "[", "]", v)}'
        )
    if value_type in ("dict", "dictionary"):
        # NOTE(review): dicts are deliberately rendered with square brackets
        # — the braces-based "set" branch below is disabled — presumably
        # because curly braces clash with downstream parsing; confirm before
        # changing.
        return (
            "<llmlingua, compress=False>"
            + f'"{k}": {process_sequence_data(rate, "[", "]", v, is_dict=True)}'
        )
    if value_type == "set":
        raise ValueError(f"Invalid value type: {value_type}")
        # return '<llmlingua, compress=False>' + f'"{k}": {process_sequence_data(rate, "{", "}", v)}'
    if value_type == "tuple":
        return (
            "<llmlingua, compress=False>"
            + f'"{k}": {process_sequence_data(rate, "(", ")", v)}'
        )
    raise ValueError(f"Invalid value type: {value_type}")


def process_sequence_data(rate, start, end, sequence, is_dict=False):
    """Render a sequence (or dict) value as an <llmlingua>-tagged sequence.

    Each element is individually wrapped in a ``rate=<rate>`` tag so it can
    be compressed on its own, while the surrounding quotes, separators, and
    ``start``/``end`` delimiters stay inside compress=False regions. The
    returned text ends with a ", " separator inside its closing tag.

    Args:
        rate: Compression rate applied to each element.
        start: Opening delimiter (e.g. "[" or "(").
        end: Closing delimiter (e.g. "]" or ")").
        sequence: A list/tuple of items, or a dict when ``is_dict`` is True.
        is_dict: Render ``sequence`` as "key: value" items instead of plain
            items.

    Returns:
        The tagged string representation of the sequence.
    """
    res = f'{start}"'
    n = len(sequence)
    if not is_dict:
        for i, item in enumerate(sequence):
            item = str(item)
            res += f"</llmlingua><llmlingua, rate={rate}>{item}</llmlingua><llmlingua, compress=False>"
            if i != n - 1:
                res += '", "'
    else:
        for i, (k, v) in enumerate(sequence.items()):
            # Swap double quotes for single quotes so the embedded item
            # cannot break the surrounding JSON string quoting.
            # Bug fix: str.replace returns a new string; the original code
            # discarded its result, so the replacement never happened.
            item = f"{k}: {v}".replace('"', "'")
            res += f"</llmlingua><llmlingua, rate={rate}>{item}</llmlingua><llmlingua, compress=False>"
            if i != n - 1:
                res += '", "'
    res += f'"{end}, </llmlingua>'
    return res


def remove_consecutive_commas(text):
    """Collapse runs of commas (and the whitespace after commas) in *text*.

    Equivalent to stripping the whitespace that follows every comma and
    then merging adjacent commas into one — used to clean up the dangling
    separators left behind when key-value pairs are dropped during
    compression.
    """
    # One pass: a comma, any further whitespace-separated commas, and any
    # trailing whitespace all collapse to a single comma.
    return re.sub(r",(?:\s*,)*\s*", ",", text)

0 comments on commit 60abc0f

Please sign in to comment.