diff --git a/.coverage.wenshandeMacBook-Air.local.33294.XMCccIZx b/.coverage.wenshandeMacBook-Air.local.33294.XMCccIZx deleted file mode 100644 index 92c9a68..0000000 Binary files a/.coverage.wenshandeMacBook-Air.local.33294.XMCccIZx and /dev/null differ diff --git a/README.md b/README.md index 7d25ef2..d7c5383 100644 --- a/README.md +++ b/README.md @@ -18,3 +18,6 @@ git clone https://github.com/gomate-community/GoMate.git ``` git install ... ``` + +## Demo +![demo](resources/demo.png) \ No newline at end of file diff --git a/README_zh.md b/README_zh.md new file mode 100644 index 0000000..e85d971 --- /dev/null +++ b/README_zh.md @@ -0,0 +1,124 @@ +# GoMate + +可配置的模块化RAG框架。 + +[![Python](https://img.shields.io/badge/Python-3.10.0-3776AB.svg?style=flat)](https://www.python.org) +![workflow status](https://github.com/gomate-community/rageval/actions/workflows/makefile.yml/badge.svg) +[![codecov](https://codecov.io/gh/gomate-community/GoMate/graph/badge.svg?token=eG99uSM8mC)](https://codecov.io/gh/gomate-community/GoMate) +[![pydocstyle](https://img.shields.io/badge/pydocstyle-enabled-AD4CD3)](http://www.pydocstyle.org/en/stable/) +[![PEP8](https://img.shields.io/badge/code%20style-pep8-orange.svg)](https://www.python.org/dev/peps/pep-0008/) + + +## 🔥Gomate 简介 +GoMate是一款配置化模块化的Retrieval-Augmented Generation (RAG) 框架,旨在提供**可靠的输入与可信的输出**,确保用户在检索问答场景中能够获得高质量且可信赖的结果。 + +GoMate框架的设计核心在于其**高度的可配置性和模块化**,使得用户可以根据具体需求灵活调整和优化各个组件,以满足各种应用场景的要求。 + +## 🔨Gomate框架 +![framework.png](resources%2Fframework.png) +## ✨主要特色 + +**“Reliable input,Trusted output”** + +可靠的输入,可信的输出 + +## 🚀快速上手 + +### 安装环境 +```shell +pip install -r requirements.txt +``` +### 1 文档解析 + +```python +from gomate.modules.document.reader import ReadFiles +from gomate.modules.store import VectorStore + +docs = ReadFiles('./data/docs').get_content(max_token_len=600, cover_content=150) +vector = VectorStore(docs) +``` + +### 2 提取向量 + +```python +from gomate.modules.retrieval.embedding import BgeEmbedding +embedding = BgeEmbedding("BAAI/bge-large-zh-v1.5") # 创建EmbeddingModel +vector.get_vector(EmbeddingModel=embedding) +vector.persist(path='storage') # 将向量和文档内容保存到storage目录下,下次再用就可以直接加载本地的数据库 +vector.load_vector(path='storage') # 加载本地的数据库 +``` + +### 3 检索文档 + +```python +question = '伊朗坠机事故原因是什么?' +contents = vector.query(question, EmbeddingModel=embedding, k=1) +content = '\n'.join(contents[:5]) +print(contents) +``` + +### 4 大模型问答 +```python +from gomate.modules.generator.llm import GLMChat +chat = GLMChat(path='THUDM/chatglm3-6b') +print(chat.chat(question, [], content)) +``` + +### 5 添加文档 +```python +docs = ReadFiles('').get_content_by_file(file='data/add/伊朗问题.txt', max_token_len=600, cover_content=150) +vector.add_documents('storage', docs, embedding) +question = '如今伊朗人的经济生活状况如何?' +contents = vector.query(question, EmbeddingModel=embedding, k=1) +content = '\n'.join(contents[:5]) +print(contents) +print(chat.chat(question, [], content)) +``` + +## 🔧定制化RAG + +> 构建自定义的RAG应用 + +```python +from gomate.modules.document.reader import ReadFiles +from gomate.modules.generator.llm import GLMChat +from gomate.modules.retrieval.embedding import BgeEmbedding +from gomate.modules.store import VectorStore + + +class RagApplication(): + def __init__(self, config): + pass + + def init_vector_store(self): + pass + + def load_vector_store(self): + pass + + def add_document(self, file_path): + pass + + def chat(self, question: str = '', topk: int = 5): + pass +``` + +模块可见[rag.py](gomate/applications/rag.py) + + +### 🌐体验RAG效果 +可以配置本地模型路径 +```text +class ApplicationConfig: + llm_model_name = '/data/users/searchgpt/pretrained_models/chatglm3-6b' # 本地模型文件 or huggingface远程仓库 + embedding_model_name = '/data/users/searchgpt/pretrained_models/bge-reranker-large' # 检索模型文件 or huggingface远程仓库 + vector_store_path = './storage' + docs_path = './data/docs' + +``` + +```shell +python app.py +``` +浏览器访问:[127.0.0.1:7860](127.0.0.1:7860) +![demo.png](resources%2Fdemo.png) \ No newline at end of file diff --git a/app.py b/app.py new file mode 100644 index 0000000..f3b5049 --- /dev/null +++ b/app.py @@ -0,0 +1,225 @@ +#!/usr/bin/env python +# -*- coding:utf-8 _*- +""" +@author:quincy qiang +@license: Apache Licence +@file: app.py +@time: 2024/05/21 +@contact: yanqiangmiffy@gamil.com +""" +import os +import shutil +import gradio as gr +from gomate.applications.rag import RagApplication + + +# 修改成自己的配置!!! +class ApplicationConfig: + llm_model_name = '/data/users/searchgpt/pretrained_models/chatglm3-6b' # 本地模型文件 or huggingface远程仓库 + embedding_model_name = '/data/users/searchgpt/pretrained_models/bge-reranker-large' # 检索模型文件 or huggingface远程仓库 + vector_store_path = './storage' + docs_path = './data/docs' + + +config = ApplicationConfig() +application = RagApplication(config) +application.init_vector_store() + + +def get_file_list(): + if not os.path.exists("./docs"): + return [] + return [f for f in os.listdir("./docs")] + + +file_list = get_file_list() + + +def upload_file(file): + cache_base_dir = './docs/' + if not os.path.exists(cache_base_dir): + os.mkdir(cache_base_dir) + filename = os.path.basename(file.name) + shutil.move(file.name, cache_base_dir + filename) + # file_list首位插入新上传的文件 + file_list.insert(0, filename) + application.add_document("./docs/" + filename) + return gr.Dropdown.update(choices=file_list, value=filename) + + +def set_knowledge(kg_name, history): + try: + application.load_vector_store() + msg_status = f'{kg_name}知识库已成功加载' + except Exception as e: + print(e) + msg_status = f'{kg_name}知识库未成功加载' + return history + [[None, msg_status]] + + +def clear_session(): + return '', None + + +def predict(input, + large_language_model, + embedding_model, + top_k, + use_web, + use_pattern, + history=None): + # print(large_language_model, embedding_model) + print(input) + if history == None: + history = [] + + if use_web == '使用': + web_content = application.source_service.search_web(query=input) + else: + web_content = '' + search_text = '' + if use_pattern == '模型问答': + result = application.get_llm_answer(query=input, web_content=web_content) + history.append((input, result)) + search_text += web_content + return '', history, history, search_text + + else: + response, _,contents = application.chat( + question=input, + topk=top_k, + ) + history.append((input, response)) + for idx, source in enumerate(contents[:5]): + sep = f'----------【搜索结果{idx + 1}:】---------------\n' + search_text += f'{sep}\n{source}\n\n' + print(search_text) + search_text += "----------【网络检索内容】-----------\n" + search_text += web_content + return '', history, history, search_text + + +with gr.Blocks(theme="soft") as demo: + gr.Markdown("""

Gomate Application

+
+
+ """) + state = gr.State() + + with gr.Row(): + with gr.Column(scale=1): + embedding_model = gr.Dropdown([ + "text2vec-base", + "bge-large-v1.5", + "bge-base-v1.5", + ], + label="Embedding model", + value="bge-large-v1.5") + + large_language_model = gr.Dropdown( + [ + "ChatGLM-6B-int4", + ], + label="large language model", + value="ChatGLM-6B-int4") + + top_k = gr.Slider(1, + 20, + value=4, + step=1, + label="检索top-k文档", + interactive=True) + + use_web = gr.Radio(["使用", "不使用"], label="web search", + info="是否使用网络搜索,使用时确保网络通常", + value="不使用", interactive=False + ) + use_pattern = gr.Radio( + [ + '模型问答', + '知识库问答', + ], + label="模式", + value='知识库问答', + interactive=False) + + kg_name = gr.Radio(["伊朗新闻"], + label="知识库", + value=None, + info="使用知识库问答,请加载知识库", + interactive=True) + set_kg_btn = gr.Button("加载知识库") + + file = gr.File(label="将文件上传到知识库库,内容要尽量匹配", + visible=True, + file_types=['.txt', '.md', '.docx', '.pdf'] + ) + + with gr.Column(scale=4): + with gr.Row(): + chatbot = gr.Chatbot(label='Gomate Application').style(height=400) + with gr.Row(): + message = gr.Textbox(label='请输入问题') + with gr.Row(): + clear_history = gr.Button("🧹 清除历史对话") + send = gr.Button("🚀 发送") + with gr.Row(): + gr.Markdown("""提醒:
+ [Gomate Application](https://github.com/gomate-community/GoMate)
+ 有任何使用问题[Github Issue区](https://github.com/gomate-community/GoMate)进行反馈. +
+ """) + with gr.Column(scale=2): + search = gr.Textbox(label='搜索结果') + + # ============= 触发动作============= + file.upload(upload_file, + inputs=file, + outputs=None) + set_kg_btn.click( + set_knowledge, + show_progress=True, + inputs=[kg_name, chatbot], + outputs=chatbot + ) + # 发送按钮 提交 + send.click(predict, + inputs=[ + message, + large_language_model, + embedding_model, + top_k, + use_web, + use_pattern, + state + ], + outputs=[message, chatbot, state, search]) + + # 清空历史对话按钮 提交 + clear_history.click(fn=clear_session, + inputs=[], + outputs=[chatbot, state], + queue=False) + + # 输入框 回车 + message.submit(predict, + inputs=[ + message, + large_language_model, + embedding_model, + top_k, + use_web, + use_pattern, + state + ], + outputs=[message, chatbot, state, search]) + +demo.queue(concurrency_count=2).launch( + server_name='0.0.0.0', + server_port=7860, + share=True, + show_error=True, + debug=True, + enable_queue=True, + inbrowser=False, +) diff --git "a/data/add/\344\274\212\346\234\227.txt" "b/data/add/\344\274\212\346\234\227.txt" new file mode 100644 index 0000000..5d19e18 --- /dev/null +++ "b/data/add/\344\274\212\346\234\227.txt" @@ -0,0 +1,124 @@ +伊朗总统莱希坠机事件:目前为止知道哪些细节?如何影响政局? +2024年5月21日 +Rescuers found the crash site after daybreak on Monday图像来源,EPA +图像加注文字,救援人员在事发第二天赶到直升机坠毁现场。 +伊朗官方证实,总统莱希和同行七人乘坐的直升机在阿塞拜疆边境附近坠毁,全员遇难。 + +伊朗官方尚未公布具体事故原因,但描述了过程中的恶劣天气。前外长指责美国负有间接责任,因为美方长期制裁伊朗,使该国无法购买新的飞机。 + +终年63岁的莱希属于强硬派,被视为85岁的最高精神领袖哈梅内伊的潜在接班人。外界关注他遇难对伊朗、中东乃至全球政局的影响。 + +以下是BBC了解到的细节。 + +伊朗总统莱希直升机坠毁遇难,他被视为哈梅内伊接班人 +伊朗强硬派领军人物之一莱西当选总统 +伊朗为什么要袭击以色列? +总统为何乘坐直升机? +5月19日上午,莱希飞往伊朗境内西北部的东阿塞拜疆省,参加齐兹加拉西(Qiz Qalasi)和霍达阿法林(Khoda Afarin)两座大坝的落成典礼,阿塞拜疆总统阿利耶夫也一同出席。这是两国在阿拉斯河(Aras river)上合作的水电项目。 + +阿利耶夫说,在直升机离开大坝地区飞往南面约130公里的大不里士市(Tabriz)之前,他已向莱希“友好道别”。 + +莱希原定随后前往大不里士炼油厂,参加一个项目的落成典礼。 + +机上还有谁? +President Ebrahim Raisi (3rd L) was in north-western Iran for the inauguration of a dam, along with Foreign Minister Hossein Amir-Abdollahian (2nd R), East Azerbaijan Governor Malek Rahmati (2nd L) and Ayatollah Mohammad Ali Al-e Hashem (5th R)图像来源,EPA +图像加注文字,莱希(前排左三)前往东阿塞拜疆省参加大坝落成典礼,随机人员包括外长阿卜杜拉希扬(右二)、东阿塞拜疆省省长拉赫马蒂(前排左二)和最高领袖派驻东阿塞拜疆省代表阿勒哈希姆。 +同机的伊朗外长也已遇难图像来源,EPA +图像加注文字,同机的伊朗外长阿卜杜拉希扬也已遇难。 +伊斯兰革命卫队(IRGC)总司令萨拉米(Hossein Salami)少将表示,七名与总统同行的人也在空难中丧生,包括: + +伊朗外长阿卜杜拉希扬(Hossein Amir-Abdollahian) +东阿塞拜疆省省长拉赫马蒂(Malek Rahmati) +最高领袖派驻东阿塞拜疆省代表阿勒哈希姆(Ayatollah Mohammad Ali Ale-Hashem) +伊斯兰革命卫队准将兼总统安全小组组长穆萨维(Mohammad Mehdi Mousavi) +机师Mohsen Daryanush上校 +机师Seyyed Taher Mostafavi上校 +技术员Behrouz Qadimi少校 +直升机在哪里坠毁? +坠机地点 +图像加注文字,坠机地点靠近伊朗和阿塞拜疆的边境。 +伊朗官媒发布的照片显示,事发在当地时间周日约13:30(格林威治标准时间10:00),地点位于齐兹加拉西(Qiz-Qalasi)大坝以南约58公里、乌兹村(Uzi)西南2公里处的偏远山区。 + +但直到当地时间16:00(格林威治标准时间12:45)之后,伊朗国家电视台才报道称,载有总统的直升机在浓雾和大雨中飞往大不里士(Tabriz)“硬着陆”。 + +伊朗内政部长瓦希迪(Ahmad Vahidi )其后证实,总统代表团乘坐三架直升机出行,总统乘坐的飞机“因恶劣天气和该区浓雾而被迫硬着陆”。 + +他表示,多个救援队尝试前往事发地点,但浓雾、雨雪以及当地地形阻碍了搜索行动。 + +另外两架直升机也一度失联,当局后曾启动15至20分钟的搜索,它们随后紧急着陆。行动一直持续到深夜。 + +负责行政事务的副总统曼苏里(Mohsen Mansouri)当时曾表示,与总统直升机上的一名机组人员及另一人取得联系,“这表明事件的严重性不是很高,因为直升机内的两个人设法与我们的团队多次沟通”,但他没有提供更多细节。 + +然而,到了周一黎明,机上有人生还的希望破灭了。 + +Photographs from the scene on Monday showed rescuers climbing a steep mountainside, shrouded in fog图像来源,REUTERS +图像加注文字,现场照片显示,搜救人员周一爬过陡峭的山峰,四周雾气弥漫。 +现场发现了什么? +伊朗国家电视台引述伊朗红新月会负责人科利万德 (Pirhossein Kolivand) 表示,救援人员到达海拔约2,200米的坠机地点后,发现“没有生命迹象”。 + +国家电视台发布从山谷对面看到的坠机现场模糊画面,显示一架直升机的蓝白色尾部,旁边是一些烧焦的灌木丛。 + +在莱希死讯确认后,该台播放一段影片,显示一名记者站在机尾和残骸的其它部分面前。 + +官媒伊通社(IRNA)影片显示,救援人员用担架抬着一具裹着毯子的遗体。 + +国家电视台称,遗体已转移到大不里士的一个墓地。 + +半官方的塔斯尼姆通讯社(Tasnim)引述伊朗危机管理机构负责人纳米 (Mohammad Nami) 称,所有遗体都可辨认,“无需进行DNA检查”。 + +他还表示,最高领袖派驻东阿塞拜疆省代表阿勒哈希姆(Ayatollah Mohammad Ali Ale-Hashem)在坠机后一小时内还活着,并在去世前与总统办公室负责人取得了联系。 + +事故原因是什么? +伊朗政府官员描述了直升机在大雾和大雨中陷入困境后坠毁的情况。 + +但当局迄今尚未公布具体的事故原因。 + +直升机是什么型号? +伊朗官媒确认失事直升机机型为贝尔212(Bell-212),是一家美国公司在上世纪60年代为加拿大军方开发的型号。 + +The Bell 212 carrying President Raisi was filmed taking off from the Qiz-Qalasi Dam before the crash图像来源,REUTERS +图像加注文字,这架贝尔212直升机事故发生前从Qiz-Qalasi 大坝起飞的画面 +根据FlightGlobal的2024年世界空军名录,伊朗海军和空军共有10架贝尔212,但尚不清楚伊朗政府拥有该型号飞机的数目。 + +官媒伊通社(IRNA)称,总统乘坐的直升机可载六名乘客和两名机组人员。 + +据飞行安全基金会称,上一宗涉及贝尔212的致命事件发生在2018年4月一次医疗运送过程中。 + +伊朗国内有何反应? +伊朗最高精神领袖哈梅内伊对“惨痛悲剧”致哀,宣布全国哀悼五天。 + +他说:“带着深切的悲痛和遗憾,我收到了人民总统、能干、勤奋的莱希及其尊敬的随行者殉难的悲痛消息。” + +伊朗内阁发表声明称,总统“在为国家服务的道路上作出了最终牺牲”,并向伊朗人民承诺将追随莱希的榜样,“国家治理不会有任何问题”。 + +莱希的温和派竞争对手、前任总统鲁哈尼(Hassan Rouhani)也表达哀悼,表示“伊斯兰革命的历史翻开了悲痛的一页”。 + +前外交部长扎里夫(Mohammad Javad Zarif)对国家电视台表示,美国对事故负有间接责任,因为美方多年来一直维持制裁,阻止伊朗购买新飞机。 + +谁接替莱希担任总统? +Mohammad Mokhber, Ebrahim Raisi's deputy, has been named acting president图像来源,REUTERS +图像加注文字,现任副总统穆赫贝尔(Mohammad Mokhber)已被任命为代理总统。 +哈梅内伊确认,根据伊朗宪法第131条,副总统穆赫贝尔(Mohammad Mokhber)已被任命为代理总统。接着,伊朗必须在50天内举行选举选出新总统。 + +伊朗外长伊朗外长阿卜杜拉希扬也在坠机事故中遇难,他的职位由副外长、伊朗资深核谈判代表巴盖里·卡尼(Ali Baqeri Kani)代理。在议会提名并批准继任人之前,他最多可以代理三个月。 + +BBC首席国际事务记者丽斯·杜塞特(Lyse Doucet)分析,即使在莱希的团队中,似乎也没有明显的总统继任者。 + +柏林智库SWP的客座研究员哈米德雷扎·阿齐兹 (Hamidreza Azizi) 指出,“这个保守派内部有不同阵营,包括立场更强硬的,以及被认为更务实的。” + +他认为,这将加剧目前新议会和地方各级职位争夺。 + +如何影响伊朗政局? +在伊朗,最高精神领袖拥有最终决策权,也控制了伊斯兰革命卫队(IRGC)。 + +因此,无论谁担任总统都权力有限,莱希遇难不会影响伊朗的政策方向或冲击内政。 + +但分析认为,这将考验保守强硬派目前主导的权力体制。 + +查塔姆研究所智库中东和北非项目主任萨纳姆·瓦基尔(Sanam Vakil)博士说:“伊朗政府将大肆宣扬莱希的死讯,并坚持宪法程序以显示其职能,同时寻找一位能维持保守派团结并忠于哈梅内伊的新成员。” + +莱希被指在上世纪80年代末大规模处决数千名政治犯,他的反对者庆祝其死讯,希望可以加速终结这个政权。 + +对于伊朗的执政保守派来说,他们知道全世界都在关注,将利用国葬释放出继续执政的信号。 + +德黑兰大学的穆罕默德·马兰迪(Mohammed Marandi)教授对BBC说:“40多年来在西方论述中,伊朗本应崩溃瓦解。但不知何故,它依然奇迹般地存在,而且我预测它在未来几年依然存在。” \ No newline at end of file diff --git "a/data/add/\344\274\212\346\234\227\346\200\273\347\273\237\350\216\261\345\270\214\345\217\212\345\244\232\344\275\215\351\253\230\347\272\247\345\256\230\345\221\230\351\201\207\351\232\276\347\232\204\347\233\264\345\215\207\346\234\272\344\272\213\346\225\205.txt" "b/data/add/\344\274\212\346\234\227\346\200\273\347\273\237\350\216\261\345\270\214\345\217\212\345\244\232\344\275\215\351\253\230\347\272\247\345\256\230\345\221\230\351\201\207\351\232\276\347\232\204\347\233\264\345\215\207\346\234\272\344\272\213\346\225\205.txt" new file mode 100644 index 0000000..00b7dbb --- /dev/null +++ "b/data/add/\344\274\212\346\234\227\346\200\273\347\273\237\350\216\261\345\270\214\345\217\212\345\244\232\344\275\215\351\253\230\347\272\247\345\256\230\345\221\230\351\201\207\351\232\276\347\232\204\347\233\264\345\215\207\346\234\272\344\272\213\346\225\205.txt" @@ -0,0 +1,21 @@ +越来越多的信息显示,5月19日导致伊朗总统莱希及多位高级官员遇难的直升机事故,可能是一场多重因素叠加的意外悲剧。据伊朗国家通讯社报道,伊朗武装部队总参谋长巴盖里已指派一高级代表团,调查导致莱希及其随行人员遇难的直升机坠毁事件。 + +图片 + +5月20日,伊朗首都德黑兰,一些民众在广场聚集,对伊朗总统莱希、外长阿卜杜拉希扬等高级官员在直升机事故中罹难表示哀悼。图/视觉中国5月20日,伊朗国家通讯社在报道中首次提到,这次直升机失事的原因可能是技术故障(technical failure)。同一天,伊朗前外交部长扎里夫在接受国家电视台采访时,指责美国对伊朗的航空零部件制裁损害了伊朗的航空能力,美国“是悲剧的幕后罪魁祸首之一”。这些报道是否意味着,伊朗官方已经初步排除了事故是爆炸或外部袭击的可能性?资深空难调查专家丹尼尔·阿杰库姆对《中国新闻周刊》指出,这是有可能的,因为调查人员可以通过残骸的分布情况得出初步结论。“就我们目前看到的影像来说,残骸都分布在一个较小的区域内。如果飞机在空中爆炸,则碎片的分布范围会更广,也会更零碎。”阿杰库姆是国际民航组织指定的飞行安全主题专家(ICAO SME),拥有25年军机及大型民航客机驾驶经验,曾担任多国空军、联合国维护部队及大型航空公司的飞行安全官员或顾问,多次参与空难调查。2007年,他曾作为空难调查员主持了加纳空军米-17运输直升机坠毁事件的调查工作。5月21日,阿杰库姆就伊朗总统坠机事故接受了《中国新闻周刊》专访。他坦言,直升机空难调查和一般的空难调查相比,难度更大;而美国对伊朗的航空部件制裁,确实会让调查工作变得“更麻烦”。 + +图片 + +丹尼尔·阿杰库姆。图/受访者提供 + +导弹袭击为何能被初步排除? + +中国新闻周刊:基于目前的已知信息和影像证据,关于坠机事故的成因,你从空难调查的角度能做出哪些判断?阿杰库姆:首先要强调的是,在本次事故中,确认坠机原因会是一件很有挑战的工作。这是一架老式直升机,所以我怀疑机上可能没有装有飞行数据记录仪和语音记录仪。其次,美国对伊朗实施关于飞机零部件的制裁已有很长一段时间。这意味着维护这架美制贝尔-212直升机时,伊朗无法从原始制造商那儿获得零件。不论是仿制还是黑市采购,这都意味着替换过的零部件上不会有原厂认证的序列号。这会让调查变得很麻烦,你将很难确认某块物体具体是什么部件。所以,我们只能基于目前的事实信息进行判断。这主要是伊朗媒体发布的现场影像。我认为这些影像总体上很像空难调查中所说的“可控飞行撞地”(飞机在由飞行员控制的情况下撞上地面、阻碍物或水面坠毁),但还不能简单地归因为自然因素还是机械因素。首先,我们可以看到,事发山区的能见度很差,有雾,且海拔较高。直升机不像固定翼飞机那样可以飞得很高,贝尔-212的最大飞行高度不足4000米。而且,在高海拔山区,由于气压、气流、空气稀薄等原因,直升机发动机的功率输出可能受到影响。在这种情况下,如果叠加一个机械故障,很容易产生严重的安全威胁。其次,我们还不清楚这架直升机上的导航系统如何、是否装备了地形防撞预警系统。当能见度不足以进行目视飞行时,飞行员需要依赖仪表飞行。这时,如果地形防撞预警系统探测到直升机正在接近障碍物,它将发出警报声,提醒飞行员避开面前的大山。问题是,这是一架老式直升机,而且即使装备了地形预警系统,该系统也需要搭配精确的GPS导航系统共同工作,才能发出准确的警报。我有一些朋友曾在伊朗、阿富汗的山区飞行,这些地区被称为GPS定位系统的“暗点”(dark spot),也就是说这里的定位信号不太可靠,而且有时候,即使很高频的信号也会被山脉阻断。我们还要考虑飞行员的心理压力。为非常尊贵的乘客(VVIP)服务时,飞行员的压力会很大,这将直接影响飞行决策。这并不是说来自总统“要求继续飞行”之类的直接压力,也包括无形、间接的压力。综合这些可能性,我们大致可以推测出,飞行员在恶劣的气象条件下迂回前进,可能迷失方向,可能存在机械故障,也可能存在一些人为处理不当,总之最终撞上了山坡。这是一个相对合理的解释。中国新闻周刊:现有证据能排除飞机遭遇直接袭击,比如炸弹或导弹攻击的可能性吗?阿杰库姆:调查人员可以通过残骸的分布情况得出初步结论,就我们目前看到的影像来说,残骸都分布在一个较小的区域内。如果飞机在空中爆炸,则碎片的分布范围会更广,也会更零碎。而本次事故的残骸显示出,飞机是在飞行的状态下撞击了山坡,坠机后燃起大火,烧毁了大部分残骸。当然,后续调查人员还可以通过对残留物进行检测等方式排除爆炸物或导弹袭击的可能性。 + +图片 + +5月20日,伊朗瓦尔扎汗,在浓雾笼罩的山区,救援人员在直升机坠毁现场搜寻。图/视觉中国 + +为何只有贝尔-212发生事故? + +中国新闻周刊:你谈到了恶劣的气象条件。在起飞前制订飞行计划、进行安全评估时,这种气象条件是否会被预测到?同行的两架米-171直升机顺利抵达目的地,是否意味着贝尔-212原本也有“逃出生天”的可能?阿杰库姆:通常来说,在进行任何一次飞行之前,飞行员都会做好飞行计划,包括考虑任务的可行性。对于直升机飞行员来说,在面对高海拔山区飞行时,他们应当对于使用仪表飞行进行评估。贝尔-212直升机也具备进行仪表飞行的条件。我推测,这架直升机上还加装了手持GPS之类的定位导航系统,以提升仪表飞行的精度。从这个角度说,能见度低并不意味着不能飞。在本次事件中,我们看到,和总统座机同期出发的另外两架米-171直升机安全抵达了目的地。我们还不清楚为什么它们可以顺利完成飞行,但据我所知,这些苏制直升机就是为高山极端条件而设计的,它们拥有强大的引擎;而美国、意大利等国制造的直升机,由于性能不足,在高海拔地区动力会有些“挣扎”。还有其他一些因素需要加以考虑。让我们假设这些飞行员都接受过遭遇恶劣天气时进行仪表飞行的训练。但是,这种训练需要每隔半年就进行一次检验,以保证飞行员对仪表飞行的熟练程度,随时做好准备。如果过去6个月,我没有在模拟器或飞行中进行过仪表飞行,那么突然进入一个山区大雾的环境,就会变得很麻烦。要知道,仪表飞行意味着,你失去了外部视觉参考,必须完全依赖仪表。没有目视意味着你可能会产生“空间迷向”,也就是你的身体感觉告诉你的姿态、方位,和仪表告诉你的信息不一样。你明明在转弯,但你的体感是在直着飞。所以,我认为仪表飞行的关键是一种“纪律训练”,你要足够约束自己、信赖仪表。这带来了另一个问题:如果这架贝尔-212上的仪表不值得信赖呢?在老式直升机上,如果姿态仪没有校准或者出现故障导致失灵,飞机又处于没有目视条件的情况,那么仪表就会将飞行员引导向灾难。作为空难调查员,我肯定会检查零部件的维修记录,首先检查这些关键设备是否在事发时正常工作。中国新闻周刊:你参与过直升机和其他类型飞机坠毁事件的调查。与其他类型航空器的空难相比,调查直升机坠毁事件有什么特别之处吗?阿杰库姆:直升机调查更加复杂,因为我们必须了解直升机的空气动力学。坦白说,作为一名固定翼飞机飞行员,当我调查直升机坠机事故时,我在很大程度上依赖于直升机飞行员们提供的专业知识。举例而言,直升机的结构布局和固定翼飞机完全不同,我们称之为“头重脚轻”。包括贝尔-212在内,大多数直升机的发动机、变速箱和传动装置都在机舱的顶部。这意味着,如果直升机撞向地面,这些重物都会因为重力而下坠到客舱中,直接砸向乘员。这意味着机舱里的人几乎没有逃生的机会。而当代民航客机的发动机主要分布于两翼,遇到极端情况时,它们会在遭受撞击后脱离机身,从而带走大部分动能,这在一定程度上有助于增强乘员的生存可能。总的来说,关于直升机的空难调查,一定需要相关机型飞行员和工程师的专业知识辅助。考虑到伊朗和美国的关系,我并不指望美国交通运输安全委员会(NTSB)会受邀参与到本次空难调查中。但即使排除美国人,加拿大、意大利、法国的空难调查机构中也有熟悉贝尔-212的专家。当然,是否邀请第三方专家参与调查工作,是由伊朗方面自己决定的。 \ No newline at end of file diff --git "a/data/add/\344\274\212\346\234\227\351\227\256\351\242\230.txt" "b/data/add/\344\274\212\346\234\227\351\227\256\351\242\230.txt" new file mode 100644 index 0000000..50bd24a --- /dev/null +++ "b/data/add/\344\274\212\346\234\227\351\227\256\351\242\230.txt" @@ -0,0 +1,36 @@ +伊朗问题专家指:莱西总统之死对该政权是一巨大战略政变 +伊朗总统易卜拉欣·莱西 (Ebrahim Raïssi )近日坠机去世。 绰号“德黑兰屠夫”的莱西的死讯受到伊朗反对派人士满意的欢迎,他们现在希望能够利用“伊朗即将开始的这段不确定和内部冲突时期”。 法国南方快讯大媒体(La Dépêche du Midi )采访了伊朗民主反对派 (CNRI) 的伊朗问题专家的政治学者哈米德·埃纳亚特 (Hamid Enayat)。 + +发表时间: 22/05/2024 - 19:45 + +9 分钟 +伊朗哀悼者参加在德黑兰举行的坠机身亡总统易卜拉欣·莱西 (Ebrahim Raïssi )的葬礼 2024 年 5 月 22 日 +伊朗哀悼者参加在德黑兰举行的坠机身亡总统易卜拉欣·莱西 (Ebrahim Raïssi )的葬礼 2024 年 5 月 22 日 © Atta Kenare / AFP +作者: +珍妮特 +广告 +南方快讯:对许多人来说,易卜拉欣·莱西的死不会改变任何事情,因为伊朗的真正主人是最高领袖阿里·哈梅内伊。 你的观点是什么? + +埃纳亚特 : 确实,易卜拉欣·莱西是刽子手,是伊朗政权的肮脏勾当的执行者。 领导人一直是毛拉的最高指导者阿里·哈梅内伊。但据此推断该政权将继续存在,那就是把自己当成假先知了。 我同意伊朗抵抗运动民意代表主席拉贾维女士的观点,她宣称莱西的去世是对伊朗政权的巨大打击,这可能具有战略意义。 哈梅内伊在政府发现自己与内部隔绝的危机背景下刚刚失去了他的代替扑克牌。卷入中东冲突,加上50%的通胀率消耗,伊朗政府也成为两年来民众抗议运动抵制的对象,这从其上次立法选举的第二轮投票时创纪录的弃权率(7%)得到了印证。当然,无论谁接替莱西,都必须遵循哈梅内伊制定的路线。 但伊朗将开始一段充满不确定性和内部冲突的时期。 + +南方快讯:2022年的民众起义遭到血腥镇压。 我们能否期待莱西总统去世后第二天出现新的革命动乱? + +埃纳亚特 :抗议活动并未消失。 它甚至通过采取更结构化和更有效的形式(特别是使用社交网络)得到了加强。 自2022年以来,我们目睹了伊朗年轻抵抗战士的一系列行动,他们攻击政权的象征,焚烧哈梅内伊的肖像,向革命卫队总部投掷燃烧弹。 + +这些勇敢的行动在国内层出不穷,并在政权内部播下了恐惧的种子,而政权无法遏制它们 + +南方快讯:“女性 ! 生活 ! 自由!”这是2022年学生马赫萨·阿米尼在维护道德警察检查中死亡后诞生的民众起义的口号。这场大火是突发的:起义之火因此已经燃烧了这么久? + +埃纳亚特 :四十年来,伊朗妇女一直站在反对伊斯兰共和国这个宗教和厌恶女性的独裁政权的斗争的最前线。 + +在 20 世纪 80 年代,成千上万的人为伊朗伊斯兰共和国的反对运动伊朗人民圣战者组织 (PMOI) 献出了生命。 + +伊朗前总统易卜拉欣·莱西在1988年担任死亡委员会成员时将他们送上了绞刑架……今天,他们鼓励年轻一代不要向政权低头 + +南方快讯:伊朗女性的生活状况如何? + +埃纳亚特 :在伊朗,作为一名女性就意味着成为二等公民。 任何时候,道德警察都可以逮捕你,因为你的一绺头发从面纱中露出来,或者你的衣服被认为是不雅的。 没有男人的建议,你没有行动的自由。 此外,大学的许多学科都禁止女性入学。 大多数伊朗女毕业生找不到工作。 总共有超过 2700 万人没有工作。 简而言之,这些妇女生活在天空敞开的监狱里 + +南方快讯:根据官方数据显示,2023年伊朗通胀率达到49.7%。 如今伊朗人的经济生活状况如何? + +埃纳亚特 :目前,75%的人口生活在贫穷线以下。人们遭受物资短缺和缺乏购买力的困扰。2015年,《维也纳协议》的签署使伊朗解除了对其核计划的控制权,从而解除了外国对其的经济制裁。 随之而来的是外国投资的增加。然而,该政权利用它们并不是为了提振国家经济,而是为了投资其导弹计划以及对中东的干扰行动。 结果,伊朗失业率上升。 大家都知道,这不是2018年美国禁运造成的,而是伊朗政府对经济管理不力造成的 \ No newline at end of file diff --git "a/data/docs/\344\274\212\346\234\227\346\200\273\347\273\237\347\275\271\351\232\276\344\272\213\344\273\266.txt" "b/data/docs/\344\274\212\346\234\227\346\200\273\347\273\237\347\275\271\351\232\276\344\272\213\344\273\266.txt" new file mode 100644 index 0000000..02ddd82 --- /dev/null +++ "b/data/docs/\344\274\212\346\234\227\346\200\273\347\273\237\347\275\271\351\232\276\344\272\213\344\273\266.txt" @@ -0,0 +1,32 @@ +伊朗总统在直升机事故中罹难 事件全过程梳理 +当地时间20日,伊朗总统莱希、外长阿卜杜拉希扬等高级官员确认在一起直升机事故中罹难。看下文,披露事故细节,详解莱希及阿卜杜拉希扬生平↓ + +“硬着陆事故" +当地时间19日,伊朗总统莱希与阿塞拜疆总统阿利耶夫一起参加了霍达阿法林县的水库大坝落成仪式。随后,莱希所乘直升机与另外两架直升机从霍达阿法林县前往伊朗东阿塞拜疆省省会大不里士,计划参加一个石化综合设施的落成典礼。 +莱希与伊朗外长阿卜杜拉希扬、东阿塞拜疆省省长拉赫马提等多名高级官员同乘一架直升机。直升机飞至伊朗东阿塞拜疆省瓦尔扎甘地区上空时,由于不利的天气条件发生硬着陆事故。 + +全力搜救 +事故地点附近为山区,地形复杂,且事发时天气恶劣,多雾寒冷,能见度低,搜救工作难度很大。 +事故发生后,伊朗第一副总统穆赫贝尔召开内阁会议,确定伊朗副总统曼苏尔和伊朗卫生部长前往事故现场,同时调集一切力量来跟进事件进展。伊朗军队、伊朗伊斯兰革命卫队及警察投入全部设施和力量进行搜寻工作。 +曼苏尔表示,此前曾与事故直升机上的人员建立沟通,但信号随后中断。 +伊朗红新月会应急管理指挥中心的消息显示,共有73支救援队参与到事故直升机的搜救工作,其中23支配备专业设备的救援队已经从首都德黑兰以及周边省份派出。 +直升机事故的消息传出后,土耳其、伊拉克、俄罗斯均发布声明,派遣救援人员和设备,努力为搜救行动提供各种支持。 +其中,土耳其国防部派遣了一架“游骑兵”无人机。这架无人机在当地时间20日早识别出热源,很可能是直升机发生事故的地点。该地确切坐标被提供给伊朗方面。 + + +“没有人员生还的迹象” +20日早,莱希乘坐的直升机残骸已找到,救援团队接近事故现场。 +伊朗救援团队使用无人机对直升机事故现场进行探测。救援人员说,直升机的螺旋桨以及驾驶舱被烧毁,通过热成像,救援人员无法探测到温度。通过对事故区域的搜索,没有人员生还的迹象。 +随后,迈赫尔通讯社、伊朗官方通讯社、伊朗国家电视台等伊朗媒体以及伊朗副总统确认了伊朗总统莱希及其他机上人员罹难的消息。 + +伊朗总统易卜拉欣·莱希和伊朗外交部长阿卜杜拉希扬生平 +伊朗伊斯兰共和国总统赛义德·易卜拉欣·莱希1960年出生在位于伊朗东北部的马什哈德。他曾担任伊朗国家监察组织主席、国家总检察长、司法总监等职。2021年6月19日,莱希当选伊朗总统,同年8月5日宣誓就职。莱希在就职演说中表示,人民希望新一届政府兑现承诺,伸张正义,打击腐败和反对歧视。新政府将努力通过解决经济问题改善民生。 +对外方面,莱希主张同邻国以及友好国家展开合作,以抵消美国制裁给伊朗带来的影响。 +此次发生事故的直升机除了载有伊朗总统莱希,还载有伊朗外交部长阿卜杜拉希扬等多名高级官员。伊朗外长阿卜杜拉希扬生于1964年,毕业于德黑兰大学国际关系专业。除了母语波斯语外,他还精通阿拉伯语和英语。阿卜杜拉希扬曾于2007年至2010年担任伊朗驻巴林大使,于2010年至2011年担任伊朗外交部波斯湾和中东司司长,于2011年至2016年任伊朗外交部阿拉伯和非洲事务部副部长,并于2021年起担任伊朗外交部长。 + + +伊朗第一副总统将经最高领袖批准暂代总统职责 +坠机事故发生后,伊朗最高领袖哈梅内伊表示,希望莱希及随行官员平安回家,同时国家和政府工作不会受影响。 +莱希等人确定罹难后,伊朗政府内阁举行了特别会议,将会宣布莱希等人的葬礼安排。 +伊朗政府内阁发布声明,“向国家和人民保证,将继续莱希总统的前进道路,国家的治理不会受到干扰”。 +根据伊朗宪法第131条,总统死亡、被免职、辞职、缺席或患病时间超过两个月,或总统任期结束由于某些原因而未选出新总统时,伊朗第一副总统应在最高领袖的批准下承担总统的权力和职责。由伊朗伊斯兰议会议长、司法总监和第一副总统组成的委员会有义务在最多五十天的时间内安排选举新总统。 diff --git a/exmaple.py b/exmaple.py new file mode 100644 index 0000000..74c55bf --- /dev/null +++ b/exmaple.py @@ -0,0 +1,43 @@ +#!/usr/bin/env python +# -*- coding:utf-8 _*- +""" +@author:quincy qiang +@license: Apache Licence +@file: exmaple.py +@time: 2024/05/22 +@contact: yanqiangmiffy@gamil.com +""" +from gomate.modules.document.reader import ReadFiles +from gomate.modules.generator.llm import GLMChat +from gomate.modules.retrieval.embedding import BgeEmbedding +from gomate.modules.store import VectorStore + +# step1:Document +docs = ReadFiles('./data/docs').get_content(max_token_len=600, cover_content=150) +vector = VectorStore(docs) + +# step2:Extract Embedding +embedding = BgeEmbedding("/data/users/searchgpt/pretrained_models/bge-large-zh-v1.5") # 创建EmbeddingModel +vector.get_vector(EmbeddingModel=embedding) +vector.persist(path='storage') # 将向量和文档内容保存到storage目录下,下次再用就可以直接加载本地的数据库 +vector.load_vector(path='storage') # 加载本地的数据库 + +# step3:retrieval +question = '伊朗坠机事故原因是什么?' +contents = vector.query(question, EmbeddingModel=embedding, k=1) +content = '\n'.join(contents[:5]) +print(contents) + +# step4:QA +chat = GLMChat(path='/data/users/searchgpt/pretrained_models/chatglm3-6b') +print(chat.chat(question, [], content)) + +# step5 追加文档 +docs = ReadFiles('').get_content_by_file(file='data/add/伊朗问题.txt', max_token_len=600, cover_content=150) +vector.add_documents('storage', docs, embedding) +question = '如今伊朗人的经济生活状况如何?' +contents = vector.query(question, EmbeddingModel=embedding, k=1) +content = '\n'.join(contents[:5]) +print(contents) +print(chat.chat(question, [], content)) + diff --git a/gomate/applications/__init__.py b/gomate/applications/__init__.py index ad60ae1..e2a3511 100644 --- a/gomate/applications/__init__.py +++ b/gomate/applications/__init__.py @@ -1,2 +1,2 @@ -from .RewriterApp import RewriterApp -from .RerankerApp import RerankerApp \ No newline at end of file +from .rewriter import RewriterApp +from .rerank import RerankerApp \ No newline at end of file diff --git a/gomate/applications/rag.py b/gomate/applications/rag.py new file mode 100644 index 0000000..f84ead6 --- /dev/null +++ b/gomate/applications/rag.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python +# -*- coding:utf-8 _*- +""" +@author:quincy qiang +@license: Apache Licence +@file: RagApplication.py +@time: 2024/05/20 +@contact: yanqiangmiffy@gamil.com +""" +from gomate.modules.document.reader import ReadFiles +from gomate.modules.generator.llm import GLMChat +from gomate.modules.retrieval.embedding import BgeEmbedding +from gomate.modules.store import VectorStore + +class RagApplication(): + def __init__(self, config): + self.config=config + self.vector_store = VectorStore([]) + self.llm = GLMChat(config.llm_model_name) + self.reader = ReadFiles(config.docs_path) + self.embedding_model = BgeEmbedding(config.embedding_model_name) + def init_vector_store(self): + docs=self.reader.get_content(max_token_len=600, cover_content=150) + self.vector_store.document=docs + self.vector_store.get_vector(EmbeddingModel=self.embedding_model) + self.vector_store.persist(path='storage') # 将向量和文档内容保存到storage目录下,下次再用就可以直接加载本地的数据库 + self.vector_store.load_vector(path='storage') # 加 + def load_vector_store(self): + self.vector_store.load_vector(path=self.config.vector_store_path) # 加载本地的数据库 + + def add_document(self, file_path): + docs = self.reader.get_content_by_file(file=file_path, max_token_len=512, cover_content=60) + self.vector_store.add_documents(self.config.vector_store_path, docs, self.embedding_model) + + def chat(self, question: str = '', topk: int = 5): + contents = self.vector_store.query(question, EmbeddingModel=self.embedding_model, k=topk) + content = '\n'.join(contents[:5]) + print(contents) + response, history = self.llm.chat(question, [], content) + return response, history,contents diff --git a/gomate/applications/RerankerApp.py b/gomate/applications/rerank.py similarity index 100% rename from gomate/applications/RerankerApp.py rename to gomate/applications/rerank.py diff --git a/gomate/applications/RewriterApp.py b/gomate/applications/rewriter.py similarity index 100% rename from gomate/applications/RewriterApp.py rename to gomate/applications/rewriter.py diff --git a/gomate/modules/citation/__init__.py b/gomate/modules/citation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/gomate/modules/citation/match_citation.py b/gomate/modules/citation/match_citation.py new file mode 100644 index 0000000..058f44c --- /dev/null +++ b/gomate/modules/citation/match_citation.py @@ -0,0 +1,162 @@ +import re +from typing import List + +import jieba + + +class MatchCitation(object): + + def __init__(self): + self.stopwords = [ + line.strip() for line in open('./config/stopwords.txt').readlines() + ] + + def cut(self, para: str): + """""" + pattern = [ + '([。!?\?])([^”’])', # 单字符断句符 + '(\.{6})([^”’])', # 英文省略号 + '(\…{2})([^”’])', # 中文省略号 + '([。!?\?][”’])([^,。!?\?])' + ] + for i in pattern: + para = re.sub(i, r"\1\n\2", para) + para = para.rstrip() + return para.split("\n") + + def remove_stopwords(self, query: str): + for word in self.stopwords: + query = query.replace(word, " ") + return query + + def ground_response(self, + response: str, evidences: List[str], + selected_idx: List[int] = None, markdown: bool = False + ) -> List[dict]: + # {'type': 'default', 'texts': ['xxx', 'xxx']} + # {'type': 'quote', 'texts': ['1', '2']} + # if field == 'video': + # return [{'type': 'default', 'texts': [response]}] + + # Step 1: cut response into sentences, line break is removed + # print(response) + sentences = self.cut(response) + # print(sentences) + # get removed line break position + line_breaks = [] + sentences = [s for s in sentences if s] + for i in range(len(sentences) - 1): + current_index = response.index(sentences[i]) + next_sentence_index = response.index(sentences[i + 1]) + dummy_next_sentence_index = current_index + len(sentences[i]) + line_breaks.append(response[dummy_next_sentence_index:next_sentence_index]) + line_breaks.append('') + final_response = [] + + citations = [i + 1 for i in selected_idx] + paragraph_have_citation = False + paragraph = "" + for sentence, line_break in zip(sentences, line_breaks): + origin_sentence = sentence + paragraph += origin_sentence + sentence = self.remove_stopwords(sentence) + sentence_seg_cut = set(jieba.lcut(sentence)) + sentence_seg_cut_length = len(sentence_seg_cut) + if sentence_seg_cut_length <= 0: + continue + topk_evidences = [] + + for evidence, idx in zip(evidences, selected_idx): + evidence_cuts = self.cut(evidence) + for j in range(len(evidence_cuts)): + evidence_cuts[j] = self.remove_stopwords(evidence_cuts[j]) + evidence_seg_cut = set(jieba.lcut(evidence_cuts[j])) + overlap = sentence_seg_cut.intersection(evidence_seg_cut) + topk_evidences.append((len(overlap) / sentence_seg_cut_length, idx)) + + topk_evidences.sort(key=lambda x: x[0], reverse=True) + + idx = 0 + sentence_citations = [] + if len(sentence) > 20: + threshold = 0.4 + else: + threshold = 0.5 + + while (idx < len(topk_evidences)) and (topk_evidences[idx][0] > threshold): + paragraph_have_citation = True + sentence_citations.append(topk_evidences[idx][1] + 1) + if topk_evidences[idx][1] + 1 in citations: + citations.remove(topk_evidences[idx][1] + 1) + idx += 1 + + if sentence != sentences[-1] and line_break and line_break[0] == '\n' or sentence == sentences[-1] and len( + citations) == 0: + if not paragraph_have_citation and len(selected_idx) > 0: + topk_evidences = [] + for evidence, idx in zip(evidences, selected_idx): + evidence = self.remove_stopwords(evidence) + paragraph_seg = set(jieba.lcut(paragraph)) + evidence_seg = set(jieba.lcut(evidence)) + overlap = paragraph_seg.intersection(evidence_seg) + paragraph_seg_length = len(paragraph_seg) + topk_evidences.append((len(overlap) / paragraph_seg_length, idx)) + topk_evidences.sort(key=lambda x: x[0], reverse=True) + if len(paragraph) > 60: + threshold = 0.2 + else: + threshold = 0.3 + if topk_evidences[0][0] > threshold: + sentence_citations.append(topk_evidences[0][1] + 1) + if topk_evidences[0][1] + 1 in citations: + citations.remove(topk_evidences[0][1] + 1) + paragraph_have_citation = False + paragraph = "" + + # Add citation to response, need to consider the punctuation and line break + if origin_sentence[-1] not in [':', ':'] and len(origin_sentence) > 10 and len(sentence_citations) > 0: + sentence_citations = list(set(sentence_citations)) + if origin_sentence[-1] in ['。', ',', '!', '?', '.', ',', '!', '?', ':', ':']: + if markdown: + final_response.append( + origin_sentence[:-1] + ''.join([f'[{c}]' for c in sentence_citations]) + origin_sentence[ + -1]) + else: + final_response.append({'type': 'default', 'texts': [origin_sentence[:-1]]}) + final_response.append({'type': 'quote', 'texts': [str(c) for c in sentence_citations]}) + final_response.append({'type': 'default', 'texts': [origin_sentence[-1]]}) + else: + if markdown: + final_response.append(origin_sentence + ''.join([f'[{c}]' for c in sentence_citations])) + else: + final_response.append({'type': 'default', 'texts': [origin_sentence]}) + final_response.append({'type': 'quote', 'texts': [str(c) for c in sentence_citations]}) + else: + if markdown: + final_response.append(origin_sentence) + else: + final_response.append({'type': 'default', 'texts': [origin_sentence]}) + + if line_break: + if markdown: + final_response.append(line_break) + else: + final_response.append({'type': 'default', 'texts': [line_break]}) + if markdown: + final_response = ''.join(final_response) + return final_response + + +if __name__ == '__main__': + mc = MatchCitation() + + result = mc.ground_response( + response="巨齿鲨2是一部科幻冒险电影,由本·维特利执导,杰森·斯坦森、吴京、蔡书雅和克利夫·柯蒂斯主演。电影讲述了海洋霸主巨齿鲨,今夏再掀狂澜!乔纳斯·泰勒(杰森·斯坦森饰)与科学家张九溟(吴京饰)双雄联手,进入海底7000米深渊执行探索任务。他们意外遭遇史前巨兽海洋霸主巨齿鲨群的攻击,还将对战凶猛危险的远古怪兽群。惊心动魄的深渊冒险,巨燃巨爽的深海大战一触即发。", + evidences=[ + "海洋霸主巨齿鲨,今夏再掀狂澜!乔纳斯·泰勒(杰森·斯坦森 饰)与科学家张九溟(吴京 饰)双雄联手,进入海底7000米深渊执行探索任务。他们意外遭遇史前巨兽海洋霸主巨齿鲨群的攻击,还将对战凶猛危险的远古怪兽群。惊心动魄的深渊冒险,巨燃巨爽的深海大战一触即发", + "本·维特利 编剧:乔·霍贝尔埃里希·霍贝尔迪恩·乔格瑞斯 国家地区:中国 | 美国 发行公司:上海华人影业有限公司五洲电影发行有限公司中国电影股份有限公司北京电影发行分公司 出品公司:上海华人影业有限公司华纳兄弟影片公司北京登峰国际文化传播有限公司 更多片名:巨齿鲨2 剧情:海洋霸主巨齿鲨,今夏再掀狂澜!乔纳斯·泰勒(杰森·斯坦森 饰)与科学家张九溟(吴京 饰)双雄联手,进入海底7000米深渊执行探索任务。他们意外遭遇史前巨兽海洋霸主巨齿鲨群的攻击,还将对战凶猛危险的远古怪兽群。惊心动魄的深渊冒险,巨燃巨爽的深海大战一触即发……" + ], + selected_idx=[1, 2] + ) + + print(result) diff --git a/gomate/modules/document/__init__.py b/gomate/modules/document/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/gomate/modules/document/parser.py b/gomate/modules/document/parser.py new file mode 100644 index 0000000..5b61f7e --- /dev/null +++ b/gomate/modules/document/parser.py @@ -0,0 +1,9 @@ +#!/usr/bin/env python +# -*- coding:utf-8 _*- +""" +@author:quincy qiang +@license: Apache Licence +@file: parser.py +@time: 2024/05/24 +@contact: yanqiangmiffy@gamil.com +""" diff --git a/gomate/modules/document/reader.py b/gomate/modules/document/reader.py new file mode 100644 index 0000000..5aacb4e --- /dev/null +++ b/gomate/modules/document/reader.py @@ -0,0 +1,165 @@ +#!/usr/bin/env python +# -*- coding:utf-8 _*- +""" +document +""" + +import os +from typing import Dict, List, Optional, Tuple, Union + +import PyPDF2 +import markdown +import html2text +import json +from tqdm import tqdm +import tiktoken +from bs4 import BeautifulSoup +import re + + + +class ReadFiles: + """ + class to read files + """ + + def __init__(self, path: str='storage') -> None: + self._path = path + self.file_list = self.get_files() + + def get_files(self): + # args:dir_path,目标文件夹路径 + file_list = [] + for filepath, dirnames, filenames in os.walk(self._path): + # os.walk 函数将递归遍历指定文件夹 + for filename in filenames: + # 通过后缀名判断文件类型是否满足要求 + if filename.endswith(".md"): + # 如果满足要求,将其绝对路径加入到结果列表 + file_list.append(os.path.join(filepath, filename)) + elif filename.endswith(".txt"): + file_list.append(os.path.join(filepath, filename)) + elif filename.endswith(".pdf"): + file_list.append(os.path.join(filepath, filename)) + return file_list + + def get_content(self, max_token_len: int = 600, cover_content: int = 150): + docs = [] + # 读取文件内容 + for file in self.file_list: + content = self.read_file_content(file) + chunk_content = self.get_chunk( + content, max_token_len=max_token_len, cover_content=cover_content) + docs.extend(chunk_content) + return docs + + def get_content_by_file(self,file:str,max_token_len: int = 600, cover_content: int = 150): + docs = [] + + content = self.read_file_content(file) + chunk_content = self.get_chunk( + content, max_token_len=max_token_len, cover_content=cover_content) + docs.extend(chunk_content) + return docs + + @classmethod + def get_chunk(cls, text: str, max_token_len: int = 600, cover_content: int = 150): + chunk_text = [] + + curr_len = 0 + curr_chunk = '' + + token_len = max_token_len - cover_content + lines = text.splitlines() # 假设以换行符分割文本为行 + + for line in lines: + line = line.replace(' ', '') + line_len = len(line) + if line_len > max_token_len: + # 如果单行长度就超过限制,则将其分割成多个块 + num_chunks = (line_len + token_len - 1) // token_len + for i in range(num_chunks): + start = i * token_len + end = start + token_len + # 避免跨单词分割 + while not line[start:end].rstrip().isspace(): + start += 1 + end += 1 + if start >= line_len: + break + curr_chunk = curr_chunk[-cover_content:] + line[start:end] + chunk_text.append(curr_chunk) + # 处理最后一个块 + start = (num_chunks - 1) * token_len + curr_chunk = curr_chunk[-cover_content:] + line[start:end] + chunk_text.append(curr_chunk) + + if curr_len + line_len <= token_len: + curr_chunk += line + curr_chunk += '\n' + curr_len += line_len + curr_len += 1 + else: + chunk_text.append(curr_chunk) + curr_chunk = curr_chunk[-cover_content:] + line + curr_len = line_len + cover_content + + if curr_chunk: + chunk_text.append(curr_chunk) + + return chunk_text + + @classmethod + def read_file_content(cls, file_path: str): + # 根据文件扩展名选择读取方法 + if file_path.endswith('.pdf'): + return cls.read_pdf(file_path) + elif file_path.endswith('.md'): + return cls.read_markdown(file_path) + elif file_path.endswith('.txt'): + return cls.read_text(file_path) + else: + raise ValueError("Unsupported file type") + + @classmethod + def read_pdf(cls, file_path: str): + # 读取PDF文件 + with open(file_path, 'rb') as file: + reader = PyPDF2.PdfReader(file) + text = "" + for page_num in range(len(reader.pages)): + text += reader.pages[page_num].extract_text() + return text + + @classmethod + def read_markdown(cls, file_path: str): + # 读取Markdown文件 + with open(file_path, 'r', encoding='utf-8') as file: + md_text = file.read() + html_text = markdown.markdown(md_text) + # 使用BeautifulSoup从HTML中提取纯文本 + soup = BeautifulSoup(html_text, 'html.parser') + plain_text = soup.get_text() + # 使用正则表达式移除网址链接 + text = re.sub(r'http\S+', '', plain_text) + return text + + @classmethod + def read_text(cls, file_path: str): + # 读取文本文件 + with open(file_path, 'r', encoding='utf-8') as file: + return file.read() + + +class Documents: + """ + 获取已分好类的json格式文档 + """ + + def __init__(self, path: str = '') -> None: + self.path = path + + def get_content(self): + with open(self.path, mode='r', encoding='utf-8') as f: + content = json.load(f) + return content \ No newline at end of file diff --git a/gomate/modules/generator/__init__.py b/gomate/modules/generator/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/gomate/modules/generator/base.py b/gomate/modules/generator/base.py new file mode 100644 index 0000000..46b7bf8 --- /dev/null +++ b/gomate/modules/generator/base.py @@ -0,0 +1,8 @@ +from abc import abstractmethod + + +class BaseLLM: + @abstractmethod + def generate(self, prompt: str) -> str: + """Generate text from a prompt using the given LLM backend.""" + # TODO \ No newline at end of file diff --git a/gomate/modules/generator/huggingface.py b/gomate/modules/generator/huggingface.py new file mode 100644 index 0000000..1d18e79 --- /dev/null +++ b/gomate/modules/generator/huggingface.py @@ -0,0 +1,27 @@ +from typing import Optional, TypedDict, cast + +from transformers import pipeline + +from gomate.modules.generator.base import BaseLLM + + +class Response(TypedDict): + """Typed description of the response from the Transformers model""" + + generated_text: str + + +class HuggingFaceLLM(BaseLLM): + def __init__( + self, model: str, do_sample: bool = False, token: Optional[str] = None + ): + self.model = model + self.do_sample = do_sample + self.pipeline = pipeline( + "text-generation", model=self.model, device_map="auto", token=token + ) + + def generate(self, prompt: str) -> str: + """Generate text from a prompt using the OpenAI API.""" + response = cast(Response, self.pipeline(prompt)[0]) + return response["generated_text"] diff --git a/gomate/modules/generator/llm.py b/gomate/modules/generator/llm.py new file mode 100644 index 0000000..1ad21eb --- /dev/null +++ b/gomate/modules/generator/llm.py @@ -0,0 +1,146 @@ +#!/usr/bin/env python +# -*- coding:utf-8 _*- +""" +@author:quincy qiang +@license: Apache Licence +@file: llm.py +@time: 2024/05/16 +@contact: yanqiangmiffy@gamil.com +@software: PyCharm +@description: coding.. +""" +import os +from typing import Dict, List, Optional, Tuple, Union, Any +import torch +from transformers import AutoTokenizer, AutoModelForCausalLM +import torch +from transformers import AutoTokenizer, AutoModelForCausalLM +from openai import OpenAI + +PROMPT_TEMPLATE = dict( + RAG_PROMPT_TEMPALTE="""使用以上下文来回答用户的问题。如果你不知道答案,就说你不知道。总是使用中文回答。 + 问题: {question} + 可参考的上下文: + ··· + {context} + ··· + 如果给定的上下文无法让你做出回答,请回答数据库中没有这个内容,你不知道。 + 有用的回答:""", + InternLM_PROMPT_TEMPALTE="""先对上下文进行内容总结,再使用上下文来回答用户的问题。如果你不知道答案,就说你不知道。总是使用中文回答。 + 问题: {question} + 可参考的上下文: + ··· + {context} + ··· + 如果给定的上下文无法让你做出回答,请回答数据库中没有这个内容,你不知道。 + 有用的回答:""", + GLM_PROMPT_TEMPALTE="""请结合参考的上下文内容回答用户问题,如果上下文不能支撑用户问题,那么回答不知道或者我无法根据参考信息回答。 + 问题: {question} + 可参考的上下文: + ··· + {context} + ··· + 如果给定的上下文无法让你做出回答,请回答数据库中没有这个内容,你不知道。 + 有用的回答:""" +) + + +class BaseModel: + def __init__(self, path: str = '') -> None: + self.path = path + + def chat(self, prompt: str, history: List[dict], content: str) -> str: + pass + + def load_model(self): + pass + + +class OpenAIChat(BaseModel): + def __init__(self, path: str = '', model: str = "gpt-3.5-turbo-1106") -> None: + super().__init__(path) + self.model = model + + def chat(self, prompt: str, history: List[dict], content: str) -> str: + client = OpenAI() + client.api_key = os.getenv("OPENAI_API_KEY") + client.base_url = os.getenv("OPENAI_BASE_URL") + history.append({'role': 'user', + 'content': PROMPT_TEMPLATE['RAG_PROMPT_TEMPALTE'].format(question=prompt, context=content)}) + response = client.chat.completions.create( + model=self.model, + messages=history, + max_tokens=150, + temperature=0.1 + ) + return response.choices[0].message.content + + +class InternLMChat(BaseModel): + def __init__(self, path: str = '') -> None: + super().__init__(path) + self.load_model() + + def chat(self, prompt: str, history: List = [], content: str = '') -> str: + prompt = PROMPT_TEMPLATE['InternLM_PROMPT_TEMPALTE'].format(question=prompt, context=content) + response, history = self.model.chat(self.tokenizer, prompt, history) + return response + + def load_model(self): + + self.tokenizer = AutoTokenizer.from_pretrained(self.path, trust_remote_code=True) + self.model = AutoModelForCausalLM.from_pretrained(self.path, torch_dtype=torch.float16, + trust_remote_code=True).cuda() +class GLMChat(BaseModel): + def __init__(self, path: str = '') -> None: + super().__init__(path) + self.load_model() + + def chat(self, prompt: str, history: List = [], content: str = '') -> tuple[Any, Any]: + prompt = PROMPT_TEMPLATE['GLM_PROMPT_TEMPALTE'].format(question=prompt, context=content) + response, history = self.model.chat(self.tokenizer, prompt, history) + return response, history + + def load_model(self): + + self.tokenizer = AutoTokenizer.from_pretrained(self.path, trust_remote_code=True) + self.model = AutoModelForCausalLM.from_pretrained(self.path, torch_dtype=torch.float16, + trust_remote_code=True).cuda() + +class DashscopeChat(BaseModel): + def __init__(self, path: str = '', model: str = "qwen-turbo") -> None: + super().__init__(path) + self.model = model + + def chat(self, prompt: str, history: List[Dict], content: str) -> str: + import dashscope + dashscope.api_key = os.getenv("DASHSCOPE_API_KEY") + history.append({'role': 'user', + 'content': PROMPT_TEMPLATE['RAG_PROMPT_TEMPALTE'].format(question=prompt, context=content)}) + response = dashscope.Generation.call( + model=self.model, + messages=history, + result_format='message', + max_tokens=150, + temperature=0.1 + ) + return response.output.choices[0].message.content + + +class ZhipuChat(BaseModel): + def __init__(self, path: str = '', model: str = "glm-4") -> None: + super().__init__(path) + from zhipuai import ZhipuAI + self.client = ZhipuAI(api_key=os.getenv("ZHIPUAI_API_KEY")) + self.model = model + + def chat(self, prompt: str, history: List[Dict], content: str) -> str: + history.append({'role': 'user', + 'content': PROMPT_TEMPLATE['RAG_PROMPT_TEMPALTE'].format(question=prompt, context=content)}) + response = self.client.chat.completions.create( + model=self.model, + messages=history, + max_tokens=150, + temperature=0.1 + ) + return response.choices[0].message \ No newline at end of file diff --git a/gomate/modules/generator/openai.py b/gomate/modules/generator/openai.py new file mode 100644 index 0000000..35a1c7e --- /dev/null +++ b/gomate/modules/generator/openai.py @@ -0,0 +1,47 @@ +from enum import Enum +from typing import List, TypedDict, Union + +from gomate.modules.generator.base import BaseLLM +from openai import OpenAI + + +class ModelType(str, Enum): + GPT_3_5_TURBO = "gpt-3.5-turbo" + GPT_3_5_TURBO_1106 = "gpt-3.5-turbo-1106" + GPT_4_TURBO_1106 = "gpt-4-1106-preview" + + +class Message(TypedDict): + role: str + content: str + + +class Choice(TypedDict): + finish_reason: str + index: int + message: Message + + +class Response(TypedDict): + """Typed description of the response from the OpenAI API""" + + # TODO: Add other response fields. See: + choices: List[Choice] + + +class OpenAILLM(BaseLLM): + def __init__(self, model: Union[ModelType, str], OPENAI_API_KEY: str): + if isinstance(model, str): + model = ModelType(model) + self.model = model + self.client = OpenAI(api_key=OPENAI_API_KEY) + + def generate(self, prompt: str) -> str: + """Generate text from a prompt using the OpenAI API.""" + openai_response = self.client.chat.completions.create( + model=self.model.value, + messages=[{"role": "user", "content": prompt}], + ) + text = openai_response.choices[0].message.content + if text is None: + raise ValueError("OpenAI response was empty") diff --git a/gomate/modules/postprocess/__init__.py b/gomate/modules/postprocess/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/gomate/modules/processor/__init__.py b/gomate/modules/processor/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/gomate/modules/prompt/__init__.py b/gomate/modules/prompt/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/gomate/modules/reranker/reranker.py b/gomate/modules/reranker/reranker.py new file mode 100644 index 0000000..4b6e357 --- /dev/null +++ b/gomate/modules/reranker/reranker.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python +# -*- coding:utf-8 _*- +""" +@author:quincy qiang +@license: Apache Licence +@file: reranker.py +@time: 2024/05/22 +@contact: yanqiangmiffy@gamil.com +@software: PyCharm +@description: coding.. +""" +from typing import List +import numpy as np +import torch +from transformers import AutoModelForSequenceClassification, AutoTokenizer + +class BaseReranker: + """ + Base class for reranker + """ + + def __init__(self, path: str) -> None: + self.path = path + + def rerank(self, text: str, content: List[str], k: int) -> List[str]: + raise NotImplementedError + + +class BgeReranker(BaseReranker): + """ + class for Bge reranker + """ + + def __init__(self, path: str = 'BAAI/bge-reranker-base') -> None: + super().__init__(path) + self._model, self._tokenizer = self.load_model(path) + + def rerank(self, text: str, content: List[str], k: int) -> List[str]: + import torch + pairs = [(text, c) for c in content] + with torch.no_grad(): + inputs = self._tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=512) + inputs = {k: v.to(self._model.device) for k, v in inputs.items()} + scores = self._model(**inputs, return_dict=True).logits.view(-1, ).float() + index = np.argsort(scores.tolist())[-k:][::-1] + return [content[i] for i in index] + + def load_model(self, path: str): + + if torch.cuda.is_available(): + device = torch.device("cuda") + else: + device = torch.device("cpu") + tokenizer = AutoTokenizer.from_pretrained(path) + model = AutoModelForSequenceClassification.from_pretrained(path).to(device) + model.eval() + return model, tokenizer \ No newline at end of file diff --git a/gomate/modules/retrieval/__init__.py b/gomate/modules/retrieval/__init__.py new file mode 100644 index 0000000..53157e2 --- /dev/null +++ b/gomate/modules/retrieval/__init__.py @@ -0,0 +1,11 @@ +#!/usr/bin/env python +# -*- coding:utf-8 _*- +""" +@author:quincy qiang +@license: Apache Licence +@file: __init__.py.py +@time: 2024/05/27 +@contact: yanqiangmiffy@gamil.com +@software: PyCharm +@description: coding.. +""" diff --git a/gomate/modules/retrieval/embedding.py b/gomate/modules/retrieval/embedding.py new file mode 100644 index 0000000..c921d0f --- /dev/null +++ b/gomate/modules/retrieval/embedding.py @@ -0,0 +1,148 @@ +#!/usr/bin/env python +# -*- coding:utf-8 _*- +""" +@author:quincy qiang +@license: Apache Licence +@file: embedding.py +@contact: yanqiangmiffy@gamil.com +@software: PyCharm +""" +import os +from copy import copy +from typing import Dict, List, Optional, Tuple, Union +import numpy as np + +os.environ['CURL_CA_BUNDLE'] = '' +from dotenv import load_dotenv, find_dotenv +_ = load_dotenv(find_dotenv()) + + +class BaseEmbeddings: + """ + Base class for embeddings + """ + def __init__(self, path: str, is_api: bool) -> None: + self.path = path + self.is_api = is_api + + def get_embedding(self, text: str, model: str) -> List[float]: + raise NotImplementedError + + @classmethod + def cosine_similarity(cls, vector1: List[float], vector2: List[float]) -> float: + """ + calculate cosine similarity between two vectors + """ + dot_product = np.dot(vector1, vector2) + magnitude = np.linalg.norm(vector1) * np.linalg.norm(vector2) + if not magnitude: + return 0 + return dot_product / magnitude + + +class OpenAIEmbedding(BaseEmbeddings): + """ + class for OpenAI embeddings + """ + def __init__(self, path: str = '', is_api: bool = True) -> None: + super().__init__(path, is_api) + if self.is_api: + from openai import OpenAI + self.client = OpenAI() + self.client.api_key = os.getenv("OPENAI_API_KEY") + self.client.base_url = os.getenv("OPENAI_BASE_URL") + + def get_embedding(self, text: str, model: str = "text-embedding-3-large") -> List[float]: + if self.is_api: + text = text.replace("\n", " ") + return self.client.embeddings.create(input=[text], model=model).data[0].embedding + else: + raise NotImplementedError + +class JinaEmbedding(BaseEmbeddings): + """ + class for Jina embeddings + """ + def __init__(self, path: str = 'jinaai/jina-embeddings-v2-base-zh', is_api: bool = False) -> None: + super().__init__(path, is_api) + self._model = self.load_model() + + def get_embedding(self, text: str) -> List[float]: + return self._model.encode([text])[0].tolist() + + def load_model(self): + import torch + from transformers import AutoModel + if torch.cuda.is_available(): + device = torch.device("cuda") + else: + device = torch.device("cpu") + model = AutoModel.from_pretrained(self.path, trust_remote_code=True).to(device) + return model + +class ZhipuEmbedding(BaseEmbeddings): + """ + class for Zhipu embeddings + """ + def __init__(self, path: str = '', is_api: bool = True) -> None: + super().__init__(path, is_api) + if self.is_api: + from zhipuai import ZhipuAI + self.client = ZhipuAI(api_key=os.getenv("ZHIPUAI_API_KEY")) + + def get_embedding(self, text: str) -> List[float]: + response = self.client.embeddings.create( + model="embedding-2", + input=text, + ) + return response.data[0].embedding + +class DashscopeEmbedding(BaseEmbeddings): + """ + class for Dashscope embeddings + """ + def __init__(self, path: str = '', is_api: bool = True) -> None: + super().__init__(path, is_api) + if self.is_api: + import dashscope + dashscope.api_key = os.getenv("DASHSCOPE_API_KEY") + self.client = dashscope.TextEmbedding + + def get_embedding(self, text: str, model: str='text-embedding-v1') -> List[float]: + response = self.client.call( + model=model, + input=text + ) + return response.output['embeddings'][0]['embedding'] + + +class BgeEmbedding(BaseEmbeddings): + """ + class for BGE embeddings + """ + + def __init__(self, path: str = 'BAAI/bge-base-zh-v1.5', is_api: bool = False) -> None: + super().__init__(path, is_api) + self._model, self._tokenizer = self.load_model(path) + + def get_embedding(self, text: str) -> List[float]: + import torch + encoded_input = self._tokenizer([text], padding=True, truncation=True, return_tensors='pt') + encoded_input = {k: v.to(self._model.device) for k, v in encoded_input.items()} + with torch.no_grad(): + model_output = self._model(**encoded_input) + sentence_embeddings = model_output[0][:, 0] + sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1) + return sentence_embeddings[0].tolist() + + def load_model(self, path: str): + import torch + from transformers import AutoModel, AutoTokenizer + if torch.cuda.is_available(): + device = torch.device("cuda") + else: + device = torch.device("cpu") + tokenizer = AutoTokenizer.from_pretrained(path) + model = AutoModel.from_pretrained(path).to(device) + model.eval() + return model, tokenizer \ No newline at end of file diff --git a/gomate/modules/retrieval/faiss.py b/gomate/modules/retrieval/faiss.py new file mode 100644 index 0000000..fa7f320 --- /dev/null +++ b/gomate/modules/retrieval/faiss.py @@ -0,0 +1,92 @@ +import logging +from pathlib import Path +from typing import Optional, Union + +from haystack.document_stores import FAISSDocumentStore +from haystack.lazy_imports import LazyImport + +logger = logging.getLogger(__name__) + +with LazyImport( + "Install Faiss by running `pip install .[faiss-cpu]` or `faiss-gpu`" +) as faiss_import: + import faiss + + +class FastRAGFAISSStore(FAISSDocumentStore): + def __init__( + self, + sql_url: str = "sqlite:///faiss_document_store.db", + vector_dim: Optional[int] = None, + embedding_dim: int = 768, + faiss_index_factory_str: str = "Flat", + faiss_index=None, + return_embedding: bool = False, + index: str = "document", + similarity: str = "dot_product", + embedding_field: str = "embedding", + progress_bar: bool = True, + duplicate_documents: str = "overwrite", + faiss_index_path: Optional[Union[str, Path]] = None, + faiss_config_path: Optional[Union[str, Path]] = None, + isolation_level: Optional[str] = None, + n_links: int = 64, + ef_search: int = 20, + ef_construction: int = 80, + validate_index_sync: bool = True, + ): + faiss_import.check() + if faiss_index_path is None: + try: + file_path = sql_url.split("sqlite:///")[1] + logger.error(file_path) + except OSError: + pass + validate_index_sync = False + faiss_index_path = None + faiss_config_path = None + super().__init__( + sql_url=sql_url, + validate_index_sync=validate_index_sync, + vector_dim=vector_dim, + embedding_dim=embedding_dim, + faiss_index_factory_str=faiss_index_factory_str, + faiss_index=faiss_index, + return_embedding=return_embedding, + index=index, + similarity=similarity, + embedding_field=embedding_field, + progress_bar=progress_bar, + duplicate_documents=duplicate_documents, + isolation_level=isolation_level, + n_links=n_links, + ef_search=ef_search, + ef_construction=ef_construction, + ) + else: + validate_index_sync = True + sql_url = None + super().__init__( + faiss_index_path=faiss_index_path, + faiss_config_path=faiss_config_path, + ) + + total_gpus_to_use = faiss.get_num_gpus() + use_gpu = True if total_gpus_to_use > 0 else False + if use_gpu: + faiss_index_cpu = self.faiss_indexes["document"] + total_gpus = faiss.get_num_gpus() + if total_gpus_to_use is not None: + total_gpus_to_use = min(total_gpus, total_gpus_to_use) + faiss_index_gpu = faiss.index_cpu_to_gpus_list( + index=faiss_index_cpu, ngpu=total_gpus_to_use + ) + logger.info(f"Faiss index uses {total_gpus_to_use} gpus out of {total_gpus}") + else: + faiss_index_gpu = faiss.index_cpu_to_all_gpus(faiss_index_cpu) + logger.info(f"Faiss index uses all {total_gpus} gpus") + + if faiss_index_path is not None: + assert faiss_index_gpu.ntotal > 0 + logger.info(f"Faiss gpu index size: {faiss_index_gpu.ntotal}") + super().__init__(sql_url=sql_url, faiss_index=faiss_index_gpu) \ No newline at end of file diff --git a/gomate/modules/retrieval/vector.py b/gomate/modules/retrieval/vector.py new file mode 100644 index 0000000..341fc74 --- /dev/null +++ b/gomate/modules/retrieval/vector.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python +# -*- coding:utf-8 _*- +""" +@author:quincy qiang +@license: Apache Licence +@file: vectorbase.py +@time: 2024/05/23 +@contact: yanqiangmiffy@gamil.com +""" +import json +import os +from typing import List + +import numpy as np +from tqdm import tqdm + +from gomate.modules.retrieval.embedding import BaseEmbeddings, BgeEmbedding + + +class VectorStore: + def __init__(self, document: List[str] = ['']) -> None: + self.document = document + + def get_vector(self, EmbeddingModel: BaseEmbeddings) -> List[List[float]]: + + self.vectors = [] + for doc in tqdm(self.document, desc="Calculating embeddings"): + self.vectors.append(EmbeddingModel.get_embedding(doc)) + return self.vectors + + def persist(self, path: str = 'storage'): + if not os.path.exists(path): + os.makedirs(path) + with open(f"{path}/doecment.json", 'w', encoding='utf-8') as f: + json.dump(self.document, f, indent=2, ensure_ascii=False) + if self.vectors: + with open(f"{path}/vectors.json", 'w', encoding='utf-8') as f: + json.dump(self.vectors, f, indent=2, ensure_ascii=False) + + def load_vector(self, path: str = 'storage'): + with open(f"{path}/vectors.json", 'r', encoding='utf-8') as f: + self.vectors = json.load(f) + with open(f"{path}/doecment.json", 'r', encoding='utf-8') as f: + self.document = json.load(f) + + def get_similarity(self, vector1: List[float], vector2: List[float]) -> float: + return BaseEmbeddings.cosine_similarity(vector1, vector2) + + def query(self, query: str, EmbeddingModel: BaseEmbeddings, k: int = 1) -> List[str]: + query_vector = EmbeddingModel.get_embedding(query) + result = np.array([self.get_similarity(query_vector, vector) + for vector in self.vectors]) + return np.array(self.document)[result.argsort()[-k:][::-1]].tolist() + + def add_documents( + self, + path: str = 'storage', + documents: List[str] = [''], + EmbeddingModel: BaseEmbeddings = BgeEmbedding + ) -> List[List[float]]: + # load existed vector + self.load_vector(path) + for doc in documents: + self.document.append(doc) + self.vectors.append(EmbeddingModel.get_embedding(doc)) + print("len(self.document),len(self.vectors):", len(self.document), len(self.vectors)) + self.persist(path) diff --git a/requirements.txt b/requirements.txt index db5e972..cc19e3d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,7 +12,12 @@ pydocstyle == 2.1 openai == 1.10.0 datasets == 2.16.1 langchain == 0.1.4 -transformers == 4.37.2 +transformers == 4.41.1 torch == 2.2.0 pandas == 2.0.0 nltk == 3.8.1 +sentencepiece==0.2.0 +PyPDF2==3.0.1 +html2text +beautifulsoup4==4.12.3 +faiss-cpu \ No newline at end of file diff --git a/resources/demo.png b/resources/demo.png new file mode 100644 index 0000000..79a295b Binary files /dev/null and b/resources/demo.png differ diff --git a/resources/framework.png b/resources/framework.png new file mode 100644 index 0000000..ae6d9dc Binary files /dev/null and b/resources/framework.png differ