Skip to content

Commit

Permalink
Merge pull request #29 from gomate-community/pipeline
Browse files Browse the repository at this point in the history
Pipeline
  • Loading branch information
yanqiangmiffy authored Jun 25, 2024
2 parents 8751428 + 841cac3 commit b750d2d
Show file tree
Hide file tree
Showing 20 changed files with 463 additions and 620 deletions.
99 changes: 69 additions & 30 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,31 +33,54 @@ pip install -r requirements.txt
```
### 1 文档解析

目前支持解析的文件类型包括:`text`,`docx`,`ppt`,`excel`,`html`,`pdf`,`md`

```python
from gomate.modules.document.parset import TextParser
from gomate.modules.store import VectorStore
from gomate.modules.document.common_parser import CommonParser

docs = TextParser('./data/docs').get_content(max_token_len=600, cover_content=150)
vector = VectorStore(docs)
parser = CommonParser()
document_path = 'docs/夏至各地习俗.docx'
chunks = parser.parse(document_path)
print(chunks)
```

### 2 提取向量
### 2 构建检索器

```python
from gomate.modules.retrieval.embedding import BgeEmbedding
embedding = BgeEmbedding("BAAI/bge-large-zh-v1.5") # 创建EmbeddingModel
vector.get_vector(EmbeddingModel=embedding)
vector.persist(path='storage') # 将向量和文档内容保存到storage目录下,下次再用就可以直接加载本地的数据库
vector.load_vector(path='storage') # 加载本地的数据库
import pandas as pd
from tqdm import tqdm

from gomate.modules.retrieval.dense_retriever import DenseRetriever, DenseRetrieverConfig

retriever_config = DenseRetrieverConfig(
model_name_or_path="bge-large-zh-v1.5",
dim=1024,
index_dir='dense_cache'
)
config_info = retriever_config.log_config()
print(config_info)

retriever = DenseRetriever(config=retriever_config)

data = pd.read_json('docs/zh_refine.json', lines=True)[:5]
print(data)
print(data.columns)

retriever.build_from_texts(documents)
```


保存索引
```python
retriever.save_index()
```


### 3 检索文档

```python
question = '伊朗坠机事故原因是什么?'
contents = vector.query(question, EmbeddingModel=embedding, k=1)
content = '\n'.join(contents[:5])
print(contents)
result = retriever.retrieve("RCEP具体包括哪些国家")
print(result)
```

### 4 大模型问答
Expand All @@ -70,24 +93,26 @@ print(chat.chat(question, [], content))
### 5 添加文档

```python
docs = TextParser.get_content_by_file(file='data/docs/伊朗问题.txt', max_token_len=600, cover_content=150)
vector.add_documents('storage', docs, embedding)
question = '如今伊朗人的经济生活状况如何?'
contents = vector.query(question, EmbeddingModel=embedding, k=1)
content = '\n'.join(contents[:5])
print(contents)
print(chat.chat(question, [], content))
for documents in tqdm(data['positive'], total=len(data)):
for document in documents:
retriever.add_text(document)
for documents in tqdm(data['negative'], total=len(data)):
for document in documents:
retriever.add_text(document)
```

## 🔧定制化RAG

> 构建自定义的RAG应用
```python
from gomate.modules.document.reader import ReadFiles
import os

from gomate.modules.document.common_parser import CommonParser
from gomate.modules.generator.llm import GLMChat
from gomate.modules.retrieval.embedding import BgeEmbedding
from gomate.modules.store import VectorStore
from gomate.modules.reranker.bge_reranker import BgeReranker
from gomate.modules.retrieval.dense_retriever import DenseRetriever



class RagApplication():
Expand All @@ -113,12 +138,24 @@ class RagApplication():
### 🌐体验RAG效果
可以配置本地模型路径
```text
class ApplicationConfig:
llm_model_name = '/data/users/searchgpt/pretrained_models/chatglm3-6b' # 本地模型文件 or huggingface远程仓库
embedding_model_name = '/data/users/searchgpt/pretrained_models/bge-reranker-large' # 检索模型文件 or huggingface远程仓库
vector_store_path = './storage'
docs_path = './data/docs'
# 修改成自己的配置!!!
app_config = ApplicationConfig()
app_config.docs_path = "./docs/"
app_config.llm_model_path = "/data/users/searchgpt/pretrained_models/chatglm3-6b/"
retriever_config = DenseRetrieverConfig(
model_name_or_path="/data/users/searchgpt/pretrained_models/bge-large-zh-v1.5",
dim=1024,
index_dir='/data/users/searchgpt/yq/GoMate/examples/retrievers/dense_cache'
)
rerank_config = BgeRerankerConfig(
model_name_or_path="/data/users/searchgpt/pretrained_models/bge-reranker-large"
)
app_config.retriever_config = retriever_config
app_config.rerank_config = rerank_config
application = RagApplication(app_config)
application.init_vector_store()
```

```shell
Expand All @@ -127,7 +164,9 @@ python app.py
浏览器访问:[127.0.0.1:7860](127.0.0.1:7860)
![demo.png](resources%2Fdemo.png)

app后台日志:

![app_logging.png](resources%2Fapp_logging.png)
## ⭐️ Star History

[![Star History Chart](https://api.star-history.com/svg?repos=gomate-community/GoMate&type=Date)](https://star-history.com/#gomate-community/GoMate&Date)
Expand Down
59 changes: 38 additions & 21 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,43 +9,55 @@
"""
import os
import shutil

import gradio as gr
from gomate.applications.rag import RagApplication

from gomate.applications.rag import RagApplication, ApplicationConfig
from gomate.modules.reranker.bge_reranker import BgeRerankerConfig
from gomate.modules.retrieval.dense_retriever import DenseRetrieverConfig

# 修改成自己的配置!!!
class ApplicationConfig:
llm_model_name = '/data/users/searchgpt/pretrained_models/chatglm3-6b' # 本地模型文件 or huggingface远程仓库
embedding_model_name = '/data/users/searchgpt/pretrained_models/bge-reranker-large' # 检索模型文件 or huggingface远程仓库
vector_store_path = './storage'
docs_path = './data/docs'

app_config = ApplicationConfig()
app_config.docs_path = "./docs/"
app_config.llm_model_path = "/data/users/searchgpt/pretrained_models/chatglm3-6b/"

retriever_config = DenseRetrieverConfig(
model_name_or_path="/data/users/searchgpt/pretrained_models/bge-large-zh-v1.5",
dim=1024,
index_dir='/data/users/searchgpt/yq/GoMate/examples/retrievers/dense_cache'
)
rerank_config = BgeRerankerConfig(
model_name_or_path="/data/users/searchgpt/pretrained_models/bge-reranker-large"
)

config = ApplicationConfig()
application = RagApplication(config)
app_config.retriever_config = retriever_config
app_config.rerank_config = rerank_config
application = RagApplication(app_config)
application.init_vector_store()


def get_file_list():
if not os.path.exists("./docs"):
if not os.path.exists(app_config.docs_path):
return []
return [f for f in os.listdir("./docs")]
return [f for f in os.listdir(app_config.docs_path)]


file_list = get_file_list()

def info_fn(filename):
gr.Info(f"upload file:{filename} success!")

def upload_file(file):
cache_base_dir = './docs/'
cache_base_dir = app_config.docs_path
if not os.path.exists(cache_base_dir):
os.mkdir(cache_base_dir)
filename = os.path.basename(file.name)
shutil.move(file.name, cache_base_dir + filename)
# file_list首位插入新上传的文件
file_list.insert(0, filename)
application.add_document("./docs/" + filename)
return gr.Dropdown.update(choices=file_list, value=filename)

application.add_document(app_config.docs_path + filename)
info_fn(filename)
return gr.Dropdown(choices=file_list, value=filename,interactive=True)

def set_knowledge(kg_name, history):
try:
Expand Down Expand Up @@ -74,7 +86,7 @@ def predict(input,
history = []

if use_web == '使用':
web_content = application.source_service.search_web(query=input)
web_content = application.retriever.search_web(query=input)
else:
web_content = ''
search_text = ''
Expand All @@ -85,9 +97,9 @@ def predict(input,
return '', history, history, search_text

else:
response, _,contents = application.chat(
response, _, contents = application.chat(
question=input,
topk=top_k,
top_k=top_k,
)
history.append((input, response))
for idx, source in enumerate(contents[:5]):
Expand Down Expand Up @@ -118,10 +130,10 @@ def predict(input,

large_language_model = gr.Dropdown(
[
"ChatGLM-6B-int4",
"ChatGLM3-6B",
],
label="large language model",
value="ChatGLM-6B-int4")
value="ChatGLM3-6B")

top_k = gr.Slider(1,
20,
Expand Down Expand Up @@ -154,7 +166,12 @@ def predict(input,
visible=True,
file_types=['.txt', '.md', '.docx', '.pdf']
)

# uploaded_files = gr.Dropdown(
# file_list,
# label="已上传的文件列表",
# value=file_list[0] if len(file_list) > 0 else '',
# interactive=True
# )
with gr.Column(scale=4):
with gr.Row():
chatbot = gr.Chatbot(label='Gomate Application').style(height=400)
Expand Down
7 changes: 7 additions & 0 deletions examples/parsers/parser_examples.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from gomate.modules.document.common_parser import CommonParser

if __name__ == '__main__':
parser = CommonParser()
document_path = '/data/users/searchgpt/yq/GoMate_dev/docs/夏至各地习俗.docx'
chunks = parser.parse(document_path)
print(chunks)
24 changes: 11 additions & 13 deletions examples/retrievers/denseretriever_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,34 +9,32 @@
@software: PyCharm
@description: coding..
"""
from tqdm import tqdm
import pandas as pd
from gomate.modules.retrieval.dense_retriever import DenseRetriever,DenseRetrieverConfig

from tqdm import tqdm

from gomate.modules.retrieval.dense_retriever import DenseRetriever, DenseRetrieverConfig

if __name__ == '__main__':
retriever_config=DenseRetrieverConfig(
model_name="/home/test/pretrained_models/bge-large-zh-v1.5",
retriever_config = DenseRetrieverConfig(
model_name_or_path="/data/users/searchgpt/pretrained_models/bge-large-zh-v1.5",
dim=1024,
top_k=3
index_dir='/data/users/searchgpt/yq/GoMate/examples/retrievers/dense_cache'
)
config_info=retriever_config.log_config()
config_info = retriever_config.log_config()
print(config_info)

retriever=DenseRetriever(config=retriever_config)
retriever = DenseRetriever(config=retriever_config)

data = pd.read_json('../../data/zh_refine.json', lines=True)[:5]
data = pd.read_json('/data/users/searchgpt/yq/GoMate/data/docs/zh_refine.json', lines=True)[:5]
print(data)
print(data.columns)

for documents in tqdm(data['positive'], total=len(data)):
for document in documents:
retriever.add_doc(document)

retriever.add_text(document)
for documents in tqdm(data['negative'], total=len(data)):
for document in documents:
retriever.add_doc(document)

retriever.add_text(document)
result = retriever.retrieve("RCEP具体包括哪些国家")
print(result)
retriever.save_index()
22 changes: 5 additions & 17 deletions examples/retrievers/faissretriever_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,33 +13,21 @@
from gomate.modules.retrieval.faiss_retriever import FaissRetriever, FaissRetrieverConfig

if __name__ == '__main__':
from transformers import AutoTokenizer

embedding_model_path = "/home/test/pretrained_models/bge-large-zh-v1.5"
embedding_model_path = "/data/users/searchgpt/pretrained_models/bge-large-zh-v1.5"
embedding_model = SBertEmbeddingModel(embedding_model_path)
tokenizer = AutoTokenizer.from_pretrained(embedding_model_path)
retriever_config = FaissRetrieverConfig(
max_tokens=100,
max_context_tokens=3500,
use_top_k=True,
embedding_model=embedding_model,
top_k=5,
tokenizer=tokenizer,
embedding_model_string="bge-large-zh-v1.5",
index_path="faiss_index.bin",
index_path="/data/users/searchgpt/yq/GoMate/examples/retrievers/faiss_index.bin",
rebuild_index=True
)

faiss_retriever = FaissRetriever(config=retriever_config)

documents = []
with open('/home/test/codes/GoMate/data/zh_refine.json', 'r', encoding="utf-8") as f:
with open('/data/users/searchgpt/yq/GoMate/data/docs/zh_refine.json', 'r', encoding="utf-8") as f:
for line in f.readlines():
data = json.loads(line)
documents.extend(data['positive'])
documents.extend(data['negative'])
print(len(documents))
faiss_retriever.build_from_texts(documents[:200])

contexts = faiss_retriever.retrieve("2022年冬奥会开幕式总导演是谁")
print(contexts)
search_contexts = faiss_retriever.retrieve("2021年香港GDP增长了多少")
print(search_contexts)
Loading

0 comments on commit b750d2d

Please sign in to comment.