From d00b02397233f6118dbcb4f681cedd172ec9d062 Mon Sep 17 00:00:00 2001
From: yanqiangmiffy <1185918903@qq.com>
Date: Tue, 14 Jan 2025 11:56:01 +0800
Subject: [PATCH] feature@update llmopenai demo

---
 examples/llms/deepseek_chat_example.py |  29 ++
 trustrag/applications/rag_openai.py    |  88 ++++++
 trustrag/modules/generator/chat.py     | 412 +------------------------
 3 files changed, 131 insertions(+), 398 deletions(-)
 create mode 100644 examples/llms/deepseek_chat_example.py
 create mode 100644 trustrag/applications/rag_openai.py

diff --git a/examples/llms/deepseek_chat_example.py b/examples/llms/deepseek_chat_example.py
new file mode 100644
index 0000000..cdae473
--- /dev/null
+++ b/examples/llms/deepseek_chat_example.py
@@ -0,0 +1,29 @@
+from trustrag.modules.generator.chat import DeepSeekChat
+
+if __name__ == '__main__':
+
+    api_key = "your-deepseek-api-key"  # Replace with your DeepSeek API Key
+    deepseek_chat = DeepSeekChat(key=api_key)
+
+    system_prompt = "You are a helpful assistant."
+    history = [
+        {"role": "user", "content": "Hello"}
+    ]
+    gen_conf = {
+        "temperature": 0.7,
+        "max_tokens": 100
+    }
+
+    # Call the chat method for a single, non-streaming conversation
+
+    response, total_tokens = deepseek_chat.chat(system=system_prompt, history=history, gen_conf=gen_conf)
+    print("Response:", response)
+    print("Total Tokens:", total_tokens)
+
+    # Call the chat_streamly method for a streaming conversation
+
+    for response in deepseek_chat.chat_streamly(system=system_prompt, history=history, gen_conf=gen_conf):
+        if isinstance(response, str):
+            print("Stream Response:", response)
+        else:
+            print("Total Tokens:", response)
\ No newline at end of file
diff --git a/trustrag/applications/rag_openai.py b/trustrag/applications/rag_openai.py
new file mode 100644
index 0000000..104ee6c
--- /dev/null
+++ b/trustrag/applications/rag_openai.py
@@ -0,0 +1,88 @@
+#!/usr/bin/env python
+# -*- coding:utf-8 _*-
+"""
+@author:quincy qiang
+@license: Apache Licence
+@file: RagApplication.py
+@time: 2024/05/20
+@contact: yanqiangmiffy@gamil.com
+@description: demo RAG application backed by OpenAI-compatible LLM services (e.g. DeepSeek)
+"""
+import os
+from trustrag.modules.document.common_parser import CommonParser
+from trustrag.modules.generator.chat import DeepSeekChat
+from trustrag.modules.reranker.bge_reranker import BgeReranker
+from trustrag.modules.retrieval.dense_retriever import DenseRetriever
+from trustrag.modules.document.chunk import TextChunker
+
+class ApplicationConfig():
+    def __init__(self):
+        self.retriever_config = None
+        self.rerank_config = None
+
+
+class RagApplication():
+    def __init__(self, config):
+        self.config = config
+        self.parser = CommonParser()
+        self.retriever = DenseRetriever(self.config.retriever_config)
+        self.reranker = BgeReranker(self.config.rerank_config)
+        self.llm = DeepSeekChat(key=self.config.your_key)
+        self.tc = TextChunker()
+        self.rag_prompt = """请结合参考的上下文内容回答用户问题,如果上下文不能支撑用户问题,那么回答不知道或者我无法根据参考信息回答。
+            问题: {question}
+            可参考的上下文:
+            ···
+            {context}
+            ···
+            如果给定的上下文无法让你做出回答,请回答数据库中没有这个内容,你不知道。
+            有用的回答:"""
+    def init_vector_store(self):
+        """
+        Parse all documents under docs_path, chunk them, and build the dense index.
+        """
+        print("init_vector_store ... ")
+        all_paragraphs = []
+        all_chunks = []
+        for filename in os.listdir(self.config.docs_path):
+            file_path = os.path.join(self.config.docs_path, filename)
+            try:
+                paragraphs = self.parser.parse(file_path)
+                all_paragraphs.append(paragraphs)
+            except Exception:  # skip files that fail to parse
+                pass
+        print("chunking for paragraphs")
+        for paragraphs in all_paragraphs:
+            chunks = self.tc.get_chunks(paragraphs, 256)
+            all_chunks.extend(chunks)
+        self.retriever.build_from_texts(all_chunks)
+        print("init_vector_store done!")
") + self.retriever.save_index(self.config.retriever_config.index_path) + + def load_vector_store(self): + self.retriever.load_index(self.config.retriever_config.index_path) + + def add_document(self, file_path): + chunks = self.parser.parse(file_path) + for chunk in chunks: + self.retriever.add_text(chunk) + print("add_document done!") + + def chat(self, question: str = '', top_k: int = 5): + contents = self.retriever.retrieve(query=question, top_k=top_k) + contents = self.reranker.rerank(query=question, documents=[content['text'] for content in contents]) + content = '\n'.join([content['text'] for content in contents]) + + system_prompt = "你是一个人工智能助手." + prompt = self.rag_prompt.format(system_prompt=system_prompt, question=question,context=content) + history = [ + {"role": "user", "content": prompt} + ] + gen_conf = { + "temperature": 0.7, + "max_tokens": 100 + } + + # 调用 chat 方法进行对话 + result = self.llm.chat(system=system_prompt, history=history, gen_conf=gen_conf) + return result, history, contents diff --git a/trustrag/modules/generator/chat.py b/trustrag/modules/generator/chat.py index f201e6b..db77735 100644 --- a/trustrag/modules/generator/chat.py +++ b/trustrag/modules/generator/chat.py @@ -9,30 +9,19 @@ @software: PyCharm @description: coding.. """ -from zhipuai import ZhipuAI -from dashscope import Generation from abc import ABC from openai import OpenAI import openai -from ollama import Client -from volcengine.maas.v2 import MaasService -import tiktoken import re +from langdetect import detect + def is_english(texts): - eng = 0 - if not texts: return False - for t in texts: - if re.match(r"[a-zA-Z]{2,}", t.strip()): - eng += 1 - if eng / len(texts) > 0.8: - return True - return False -encoder = tiktoken.encoding_for_model("gpt-3.5-turbo") + try: + # 检测文本的语言 + return detect(str(texts[0])) == 'en' + except: + return False -def num_tokens_from_string(string: str) -> int: - """Returns the number of tokens in a text string.""" - num_tokens = len(encoder.encode(string)) - return num_tokens class Base(ABC): def __init__(self, key, model_name, base_url): self.client = OpenAI(api_key=key, base_url=base_url) @@ -80,6 +69,13 @@ def chat_streamly(self, system, history, gen_conf): yield total_tokens +class DeepSeekChat(Base): + def __init__(self, key, model_name="deepseek-chat", base_url="https://api.deepseek.com"): + if not base_url: base_url="https://api.deepseek.com" + super().__init__(key, model_name, base_url) + + + class GptTurbo(Base): def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1"): if not base_url: base_url="https://api.openai.com/v1" @@ -98,386 +94,6 @@ def __init__(self, key=None, model_name="", base_url=""): super().__init__(key, model_name, base_url) -class DeepSeekChat(Base): - def __init__(self, key, model_name="deepseek-chat", base_url="https://api.deepseek.com/v1"): - if not base_url: base_url="https://api.deepseek.com/v1" - super().__init__(key, model_name, base_url) - - -class BaiChuanChat(Base): - def __init__(self, key, model_name="Baichuan3-Turbo", base_url="https://api.baichuan-ai.com/v1"): - if not base_url: - base_url = "https://api.baichuan-ai.com/v1" - super().__init__(key, model_name, base_url) - - @staticmethod - def _format_params(params): - return { - "temperature": params.get("temperature", 0.3), - "max_tokens": params.get("max_tokens", 2048), - "top_p": params.get("top_p", 0.85), - } - - def chat(self, system, history, gen_conf): - if system: - history.insert(0, {"role": "system", "content": system}) - try: - response 
-                model=self.model_name,
-                messages=history,
-                extra_body={
-                    "tools": [{
-                        "type": "web_search",
-                        "web_search": {
-                            "enable": True,
-                            "search_mode": "performance_first"
-                        }
-                    }]
-                },
-                **self._format_params(gen_conf))
-            ans = response.choices[0].message.content.strip()
-            if response.choices[0].finish_reason == "length":
-                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
-                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
-            return ans, response.usage.total_tokens
-        except openai.APIError as e:
-            return "**ERROR**: " + str(e), 0
-
-    def chat_streamly(self, system, history, gen_conf):
-        if system:
-            history.insert(0, {"role": "system", "content": system})
-        ans = ""
-        total_tokens = 0
-        try:
-            response = self.client.chat.completions.create(
-                model=self.model_name,
-                messages=history,
-                extra_body={
-                    "tools": [{
-                        "type": "web_search",
-                        "web_search": {
-                            "enable": True,
-                            "search_mode": "performance_first"
-                        }
-                    }]
-                },
-                stream=True,
-                **self._format_params(gen_conf))
-            for resp in response:
-                if resp.choices[0].finish_reason == "stop":
-                    if not resp.choices[0].delta.content:
-                        continue
-                    total_tokens = resp.usage.get('total_tokens', 0)
-                if not resp.choices[0].delta.content:
-                    continue
-                ans += resp.choices[0].delta.content
-                if resp.choices[0].finish_reason == "length":
-                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
-                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
-                yield ans
-
-        except Exception as e:
-            yield ans + "\n**ERROR**: " + str(e)
-
-        yield total_tokens
-
-
-class QWenChat(Base):
-    def __init__(self, key, model_name=Generation.Models.qwen_turbo, **kwargs):
-        import dashscope
-        dashscope.api_key = key
-        self.model_name = model_name
-
-    def chat(self, system, history, gen_conf):
-        from http import HTTPStatus
-        if system:
-            history.insert(0, {"role": "system", "content": system})
-        response = Generation.call(
-            self.model_name,
-            messages=history,
-            result_format='message',
-            **gen_conf
-        )
-        ans = ""
-        tk_count = 0
-        if response.status_code == HTTPStatus.OK:
-            ans += response.output.choices[0]['message']['content']
-            tk_count += response.usage.total_tokens
-            if response.output.choices[0].get("finish_reason", "") == "length":
-                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
-                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
-            return ans, tk_count
-
-        return "**ERROR**: " + response.message, tk_count
-
-    def chat_streamly(self, system, history, gen_conf):
-        from http import HTTPStatus
-        if system:
-            history.insert(0, {"role": "system", "content": system})
-        ans = ""
-        tk_count = 0
-        try:
-            response = Generation.call(
-                self.model_name,
-                messages=history,
-                result_format='message',
-                stream=True,
-                **gen_conf
-            )
-            for resp in response:
-                if resp.status_code == HTTPStatus.OK:
-                    ans = resp.output.choices[0]['message']['content']
-                    tk_count = resp.usage.total_tokens
-                    if resp.output.choices[0].get("finish_reason", "") == "length":
-                        ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
-                            [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
-                    yield ans
-                else:
-                    yield ans + "\n**ERROR**: " + resp.message if str(resp.message).find("Access")<0 else "Out of credit. Please set the API key in **settings > Model providers.**"
-        except Exception as e:
-            yield ans + "\n**ERROR**: " + str(e)
-
-        yield tk_count
-
-
-class ZhipuChat(Base):
-    def __init__(self, key, model_name="glm-3-turbo", **kwargs):
-        self.client = ZhipuAI(api_key=key)
-        self.model_name = model_name
-
-    def chat(self, system, history, gen_conf):
-        if system:
-            history.insert(0, {"role": "system", "content": system})
-        try:
-            if "presence_penalty" in gen_conf: del gen_conf["presence_penalty"]
-            if "frequency_penalty" in gen_conf: del gen_conf["frequency_penalty"]
-            response = self.client.chat.completions.create(
-                model=self.model_name,
-                messages=history,
-                **gen_conf
-            )
-            ans = response.choices[0].message.content.strip()
-            if response.choices[0].finish_reason == "length":
-                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
-                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
-            return ans, response.usage.total_tokens
-        except Exception as e:
-            return "**ERROR**: " + str(e), 0
-
-    def chat_streamly(self, system, history, gen_conf):
-        if system:
-            history.insert(0, {"role": "system", "content": system})
-        if "presence_penalty" in gen_conf: del gen_conf["presence_penalty"]
-        if "frequency_penalty" in gen_conf: del gen_conf["frequency_penalty"]
-        ans = ""
-        tk_count = 0
-        try:
-            response = self.client.chat.completions.create(
-                model=self.model_name,
-                messages=history,
-                stream=True,
-                **gen_conf
-            )
-            for resp in response:
-                if not resp.choices[0].delta.content:continue
-                delta = resp.choices[0].delta.content
-                ans += delta
-                if resp.choices[0].finish_reason == "length":
-                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
-                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
-                    tk_count = resp.usage.total_tokens
-                if resp.choices[0].finish_reason == "stop": tk_count = resp.usage.total_tokens
-                yield ans
-        except Exception as e:
-            yield ans + "\n**ERROR**: " + str(e)
-
-        yield tk_count
-
-
-class OllamaChat(Base):
-    def __init__(self, key, model_name, **kwargs):
-        self.client = Client(host=kwargs["base_url"])
-        self.model_name = model_name
-
-    def chat(self, system, history, gen_conf):
-        if system:
-            history.insert(0, {"role": "system", "content": system})
-        try:
-            options = {}
-            if "temperature" in gen_conf: options["temperature"] = gen_conf["temperature"]
-            if "max_tokens" in gen_conf: options["num_predict"] = gen_conf["max_tokens"]
-            if "top_p" in gen_conf: options["top_k"] = gen_conf["top_p"]
-            if "presence_penalty" in gen_conf: options["presence_penalty"] = gen_conf["presence_penalty"]
-            if "frequency_penalty" in gen_conf: options["frequency_penalty"] = gen_conf["frequency_penalty"]
-            response = self.client.chat(
-                model=self.model_name,
-                messages=history,
-                options=options,
-                keep_alive=-1
-            )
-            ans = response["message"]["content"].strip()
-            return ans, response["eval_count"] + response.get("prompt_eval_count", 0)
-        except Exception as e:
-            return "**ERROR**: " + str(e), 0
-
-    def chat_streamly(self, system, history, gen_conf):
-        if system:
-            history.insert(0, {"role": "system", "content": system})
-        options = {}
-        if "temperature" in gen_conf: options["temperature"] = gen_conf["temperature"]
-        if "max_tokens" in gen_conf: options["num_predict"] = gen_conf["max_tokens"]
-        if "top_p" in gen_conf: options["top_k"] = gen_conf["top_p"]
-        if "presence_penalty" in gen_conf: options["presence_penalty"] = gen_conf["presence_penalty"]
-        if "frequency_penalty" in gen_conf: options["frequency_penalty"] = gen_conf["frequency_penalty"]
-        ans = ""
-        try:
-            response = self.client.chat(
-                model=self.model_name,
-                messages=history,
-                stream=True,
-                options=options,
-                keep_alive=-1
-            )
-            for resp in response:
-                if resp["done"]:
-                    yield resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0)
-                ans += resp["message"]["content"]
-                yield ans
-        except Exception as e:
-            yield ans + "\n**ERROR**: " + str(e)
-        yield 0
-
-
-class LocalLLM(Base):
-    class RPCProxy:
-        def __init__(self, host, port):
-            self.host = host
-            self.port = int(port)
-            self.__conn()
-
-        def __conn(self):
-            from multiprocessing.connection import Client
-            self._connection = Client(
-                (self.host, self.port), authkey=b'infiniflow-token4kevinhu')
-
-        def __getattr__(self, name):
-            import pickle
-
-            def do_rpc(*args, **kwargs):
-                for _ in range(3):
-                    try:
-                        self._connection.send(
-                            pickle.dumps((name, args, kwargs)))
-                        return pickle.loads(self._connection.recv())
-                    except Exception as e:
-                        self.__conn()
-                raise Exception("RPC connection lost!")
-
-            return do_rpc
-
-    def __init__(self, key, model_name="glm-3-turbo"):
-        self.client = LocalLLM.RPCProxy("127.0.0.1", 7860)
-
-    def chat(self, system, history, gen_conf):
-        if system:
-            history.insert(0, {"role": "system", "content": system})
-        try:
-            ans = self.client.chat(
-                history,
-                gen_conf
-            )
-            return ans, num_tokens_from_string(ans)
-        except Exception as e:
-            return "**ERROR**: " + str(e), 0
-
-    def chat_streamly(self, system, history, gen_conf):
-        if system:
-            history.insert(0, {"role": "system", "content": system})
-        token_count = 0
-        answer = ""
-        try:
-            for ans in self.client.chat_streamly(history, gen_conf):
-                answer += ans
-                token_count += 1
-                yield answer
-        except Exception as e:
-            yield answer + "\n**ERROR**: " + str(e)
-
-        yield token_count
-
-class VolcEngineChat(Base):
-    def __init__(self, key, model_name, base_url):
-        """
-        Since do not want to modify the original database fields, and the VolcEngine authentication method is quite special,
-        Assemble ak, sk, ep_id into api_key, store it as a dictionary type, and parse it for use
-        model_name is for display only
-        """
-        self.client = MaasService('maas-api.ml-platform-cn-beijing.volces.com', 'cn-beijing')
-        self.volc_ak = eval(key).get('volc_ak', '')
-        self.volc_sk = eval(key).get('volc_sk', '')
-        self.client.set_ak(self.volc_ak)
-        self.client.set_sk(self.volc_sk)
-        self.model_name = eval(key).get('ep_id', '')
-
-    def chat(self, system, history, gen_conf):
-        if system:
-            history.insert(0, {"role": "system", "content": system})
-        try:
-            req = {
-                "parameters": {
-                    "min_new_tokens": gen_conf.get("min_new_tokens", 1),
-                    "top_k": gen_conf.get("top_k", 0),
-                    "max_prompt_tokens": gen_conf.get("max_prompt_tokens", 30000),
-                    "temperature": gen_conf.get("temperature", 0.1),
-                    "max_new_tokens": gen_conf.get("max_tokens", 1000),
-                    "top_p": gen_conf.get("top_p", 0.3),
-                },
-                "messages": history
-            }
-            response = self.client.chat(self.model_name, req)
-            ans = response.choices[0].message.content.strip()
-            if response.choices[0].finish_reason == "length":
-                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
-                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
- return ans, response.usage.total_tokens - except Exception as e: - return "**ERROR**: " + str(e), 0 - - def chat_streamly(self, system, history, gen_conf): - if system: - history.insert(0, {"role": "system", "content": system}) - ans = "" - tk_count = 0 - try: - req = { - "parameters": { - "min_new_tokens": gen_conf.get("min_new_tokens", 1), - "top_k": gen_conf.get("top_k", 0), - "max_prompt_tokens": gen_conf.get("max_prompt_tokens", 30000), - "temperature": gen_conf.get("temperature", 0.1), - "max_new_tokens": gen_conf.get("max_tokens", 1000), - "top_p": gen_conf.get("top_p", 0.3), - }, - "messages": history - } - stream = self.client.stream_chat(self.model_name, req) - for resp in stream: - if not resp.choices[0].message.content: - continue - ans += resp.choices[0].message.content - if resp.choices[0].finish_reason == "stop": - tk_count = resp.usage.total_tokens - yield ans - - except Exception as e: - yield ans + "\n**ERROR**: " + str(e) - yield tk_count -class MiniMaxChat(Base): - def __init__(self, key, model_name="abab6.5s-chat", - base_url="https://api.minimax.chat/v1/text/chatcompletion_v2"): - if not base_url: - base_url="https://api.minimax.chat/v1/text/chatcompletion_v2" - super().__init__(key, model_name, base_url) \ No newline at end of file