Skip to content

Commit

Permalink
feaute@支持多模态RAG
Browse files Browse the repository at this point in the history
  • Loading branch information
yanqiangmiffy committed Dec 27, 2024
1 parent f6d07c7 commit 67c01a6
Show file tree
Hide file tree
Showing 3 changed files with 2,333 additions and 0 deletions.
1,963 changes: 1,963 additions & 0 deletions examples/rag/multimodal_rag.ipynb

Large diffs are not rendered by default.

128 changes: 128 additions & 0 deletions trustrag/applications/rag_multimodal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
import base64
from typing import List, Dict, Any
from zhipuai import ZhipuAI
from PIL import Image
from trustrag.modules.retrieval.multimodal_retriever import MultimodalRetriever,MultimodalRetrieverConfig

class MultimodalRAG:
def __init__(
self,
api_key: str,
retriever_config: MultimodalRetrieverConfig,
model_name: str = "glm-4v-plus",
top_k: int = 3
):
self.client = ZhipuAI(api_key=api_key)
self.retriever = MultimodalRetriever(retriever_config)
# self.retriever.load_index()
self.model_name = model_name
self.top_k = top_k

def _prepare_context(self, results: List[Dict[str, Any]]) -> str:
context = "基于以下相似图片信息:\n"
for idx, result in enumerate(results, 1):
context += f"{idx}. {result['text']} (相似度: {result['score']:.2f})\n"
return context

def _image_to_base64(self, image: Image) -> str:
# Convert the image to RGB mode if it's in RGBA mode
if image.mode == 'RGBA':
image = image.convert('RGB')

# Save the image to a BytesIO buffer in JPEG format
buffered = io.BytesIO()
image.save(buffered, format="JPEG")

# Encode the image data to base64 and return it as a string
return base64.b64encode(buffered.getvalue()).decode('utf-8')

def chat(self, query: str, include_images: bool = True) -> str:
# 1. 检索相似内容
results = self.retriever.retrieve(query, top_k=self.top_k)

# 2. 准备提示信息
context = self._prepare_context(results)
full_prompt = f"{context}\n用户问题: {query}\n请基于用户提供的图片和上述图片信息回答问题。"

# 3. 准备消息内容
messages = [{"role": "user", "content": []}]

# 4. 如果需要,添加检索到的图片
if include_images:
for result in results:
img_base64 = self._image_to_base64(result['image'])
messages[0]["content"].append({
"type": "image_url",
"image_url": {"url": img_base64}
})

messages[0]["content"].append({"type": "text", "text": full_prompt})
# 5. 调用API获取回答
response = self.client.chat.completions.create(
model=self.model_name,
messages=messages
)

return results, response.choices[0].message.content

def chat_with_image(self, query: str, image_path: str) -> str:
# 1. 读取和编码用户提供的图片
with open(image_path, 'rb') as img_file:
user_img_base64 = base64.b64encode(img_file.read()).decode('utf-8')

# 2. 检索相似内容
results = self.retriever.retrieve(query, top_k=self.top_k)

# 3. 准备提示信息
context = self._prepare_context(results)
full_prompt = f"{context}\n用户问题: {query}\n请基于用户提供的图片和上述相似图片信息回答问题。"

# 4. 准备消息内容,首先添加用户的图片
messages = [{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {"url": user_img_base64}
},
{
"type": "text",
"text": full_prompt
}
]
}]

# 5. 添加检索到的相似图片
for result in results:
img_base64 = self._image_to_base64(result['image'])
messages[0]["content"].append({
"type": "image_url",
"image_url": {"url": img_base64}
})

# 6. 调用API获取回答
response = self.client.chat.completions.create(
model=self.model_name,
messages=messages
)

return response.choices[0].message.content

if __name__ == '__main__':
# 初始化配置
car_retriever_config = MultimodalRetrieverConfig(
model_name='ViT-B-16',
index_path='./index_car',
batch_size=32,
dim=512,
download_root="data/chinese-clip-vit-base-patch16/"
)
# 初始化
car_rag = MultimodalRAG(
api_key="xxx",
retriever_config=car_retriever_config,
top_k=1
)
query_text = "冷却系统检查"
retrieve_results, response = car_rag.chat(query_text)
retrieve_results, response
242 changes: 242 additions & 0 deletions trustrag/modules/retrieval/multimodal_retriever.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,242 @@
import torch
import numpy as np
import faiss
import os
import gc
from PIL import Image
import base64
from io import BytesIO
import cn_clip.clip as clip
from typing import List, Tuple, Union, Dict
from tqdm import tqdm
import matplotlib.pyplot as plt


class MultimodalRetrieverConfig():
"""
Configuration class for Multimodal Retriever.
Attributes:
model_name (str): Name of the CLIP model variant (e.g., 'ViT-B-16').
dim (int): Dimension of the CLIP embeddings (768 for ViT-B-16).
index_path (str): Path to save or load the FAISS index.
download_root (str): Directory for downloading CLIP models.
batch_size (int): Batch size for processing multiple documents.
"""

def __init__(
self,
model_name='ViT-B-16',
dim=768,
index_path='./index',
download_root='./',
batch_size=32
):
self.model_name = model_name
self.dim = dim
self.index_path = index_path
self.download_root = download_root
self.batch_size = batch_size

def validate(self):
"""Validate Multimodal configuration parameters."""
if not isinstance(self.model_name, str) or not self.model_name:
raise ValueError("Model name must be a non-empty string.")
if not isinstance(self.dim, int) or self.dim <= 0:
raise ValueError("Dimension must be a positive integer.")
if not isinstance(self.index_path, str):
raise ValueError("Index path must be a string.")
if not isinstance(self.download_root, str):
raise ValueError("Download root must be a string.")
if not isinstance(self.batch_size, int) or self.batch_size <= 0:
raise ValueError("Batch size must be a positive integer.")
print("Multimodal configuration is valid.")


class MultimodalRetriever():
def __init__(self, config):
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.model, self.preprocess = clip.load_from_name(
config.model_name,
device=self.device,
download_root=config.download_root
)
self.model.eval()
self.dim = config.dim # CLIP embedding dimension
self.index = faiss.IndexFlatIP(self.dim)
self.embeddings = []
self.documents = [] # List to store (image_path, text) pairs
self.num_documents = 0
self.index_path = config.index_path
self.batch_size = config.batch_size

def convert_base642image(self, image_base64):
image_data = base64.b64decode(image_base64)
image = Image.open(BytesIO(image_data))
return image

def merge_mm_embeddings(self, img_emb=None, text_emb=None):
if text_emb is not None and img_emb is not None:
return np.mean([img_emb, text_emb], axis=0)
elif text_emb is not None:
return text_emb
elif img_emb is not None:
return img_emb
raise ValueError("Must specify one of `img_emb` or `text_emb`")

def _embed(self, image=None, text=None) -> np.ndarray:
if image is None and text is None:
raise ValueError("Must specify one of image or text")

img_emb = None
text_emb = None

if image is not None:
image = self.preprocess(image).unsqueeze(0).to(self.device)
with torch.no_grad():
img_emb = self.model.encode_image(image)
img_emb /= img_emb.norm(dim=-1, keepdim=True)
img_emb = img_emb.cpu().numpy()

if text is not None:
text = clip.tokenize([text]).to(self.device)
with torch.no_grad():
text_emb = self.model.encode_text(text)
text_emb /= text_emb.norm(dim=-1, keepdim=True)
text_emb = text_emb.cpu().numpy()

return self.merge_mm_embeddings(img_emb, text_emb)

def add_image_text(self, image: Union[str, Image.Image], text: str):
"""Add a single image-text pair to the index."""
if isinstance(image, str):
image = self.convert_base642image(image_base64=image)

emb = self._embed(image=image, text=text).astype('float32')
self.index.add(emb)
self.embeddings.append(emb)
self.documents.append((image, text))
self.num_documents += 1

def build_from_pairs(self, img_text_pairs: List[Tuple[Union[str, Image.Image], str]]):
"""Build index from image-text pairs in batches."""
if not img_text_pairs:
raise ValueError("Image-text pairs list is empty")

for i in tqdm(range(0, len(img_text_pairs), self.batch_size), desc="Building index"):
batch = img_text_pairs[i:i + self.batch_size]
for img, text in batch:
self.add_image_text(img, text)

def save_index(self, index_path: str = None):
"""Save the index, embeddings, and document pairs."""
if not (self.index and self.embeddings and self.documents):
raise ValueError("No data to save")

if index_path is None:
index_path = self.index_path

os.makedirs(index_path, exist_ok=True)

# Save embeddings and document information
np.savez(
os.path.join(index_path, 'multimodal.vecstore'),
embeddings=np.array(self.embeddings),
documents=np.array(self.documents, dtype=object)
)

# Save FAISS index
faiss.write_index(self.index, os.path.join(index_path, 'multimodal.index'))
print(f"Index saved successfully to {index_path}")

def load_index(self, index_path: str = None):
"""Load the index, embeddings, and document pairs."""
if index_path is None:
index_path = self.index_path

# Load document data
data = np.load(os.path.join(index_path, 'multimodal.vecstore.npz'),
allow_pickle=True)
self.documents = data['documents'].tolist()
self.embeddings = data['embeddings'].tolist()

# Load FAISS index
self.index = faiss.read_index(os.path.join(index_path, 'multimodal.index'))
self.num_documents = len(self.documents)

print(f"Index loaded successfully from {index_path}")
del data
gc.collect()

def retrieve(self, query: Union[str, Image.Image], top_k: int = 5) -> List[Dict]:
"""Retrieve top_k most relevant image-text pairs."""
if self.index is None or self.num_documents == 0:
raise ValueError("Index is empty or not initialized")

# Generate query embedding
query_embedding = self._embed(
image=query if isinstance(query, Image.Image) else None,
text=query if isinstance(query, str) else None
).astype('float32')

# Search index
D, I = self.index.search(query_embedding, min(top_k, self.num_documents))

# Return results with scores
results = []
for idx, score in zip(I[0], D[0]):
image, text = self.documents[idx]
results.append({
'image': image,
'text': text,
'score': float(score)
})

return results

def plot_results(self, query: Union[str, Image.Image], results: List[Dict], font_path: str = None):
"""
Plot query and retrieval results with dynamic sizing and font support.
Args:
query: Text string or PIL Image
results: List of retrieval results
font_path: Path to font file for Chinese text support
"""
# plt.close('all') # Close any existing figures
n_results = len(results)
# Dynamic figure size: base width (3) for query + width for each result (3)
figsize = (3 * (n_results + 1), 4)

fig = plt.figure(figsize=figsize)

# Set font for Chinese characters if provided
if font_path:
from matplotlib.font_manager import FontProperties
font = FontProperties(fname=font_path)
else:
font = None

# Plot query
if isinstance(query, str):
ax = plt.subplot(1, n_results + 1, 1)
ax.text(0.5, 0.5, f"Query Text:\n{query}",
ha='center', va='center', wrap=True,
fontproperties=font)
ax.axis('off')
else:
plt.subplot(1, n_results + 1, 1)
plt.imshow(query)
plt.title("Query Image", fontproperties=font)
plt.axis('off')

# Plot results
for idx, result in enumerate(results, 1):
plt.subplot(1, n_results + 1, idx + 1)
plt.imshow(result['image'])
plt.title(f"Score: {result['score']:.3f}\n{result['text']}",
pad=10, fontproperties=font)
plt.axis('off')

plt.tight_layout()
# return fig

0 comments on commit 67c01a6

Please sign in to comment.