Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pipeline #91

Merged
merged 2 commits into from
Jan 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 86 additions & 0 deletions examples/engine/qrant_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from trustrag.modules.engine.qdrant import QdrantEngine
from trustrag.modules.engine.qdrant import SentenceTransformerEmbedding
if __name__ == "__main__":
# Initialize embedding generators
local_embedding_generator = SentenceTransformerEmbedding(model_name_or_path="all-MiniLM-L6-v2", device="cpu")
# openai_embedding_generator = OpenAIEmbedding(api_key="your_key", base_url="https://ark.cn-beijing.volces.com/api/v3", model="your_model_id")

# Initialize QdrantEngine with local embedding generator
qdrant_engine = QdrantEngine(
collection_name="startups",
embedding_generator=local_embedding_generator,
qdrant_client_params={"host": "192.168.1.5", "port": 6333},
)

documents=[
{"name": "SaferCodes", "images": "https:\/\/safer.codes\/img\/brand\/logo-icon.png",
"alt": "SaferCodes Logo QR codes generator system forms for COVID-19",
"description": "QR codes systems for COVID-19.\nSimple tools for bars, restaurants, offices, and other small proximity businesses.",
"link": "https:\/\/safer.codes", "city": "Chicago"},
{"name": "Human Practice",
"images": "https:\/\/d1qb2nb5cznatu.cloudfront.net\/startups\/i\/373036-94d1e190f12f2c919c3566ecaecbda68-thumb_jpg.jpg?buster=1396498835",
"alt": "Human Practice - health care information technology",
"description": "Point-of-care word of mouth\nPreferral is a mobile platform that channels physicians\u2019 interest in networking with their peers to build referrals within a hospital system.\nHospitals are in a race to employ physicians, even though they lose billions each year ($40B in 2014) on employment. Why ...",
"link": "http:\/\/humanpractice.com", "city": "Chicago"},
{"name": "StyleSeek",
"images": "https:\/\/d1qb2nb5cznatu.cloudfront.net\/startups\/i\/3747-bb0338d641617b54f5234a1d3bfc6fd0-thumb_jpg.jpg?buster=1329158692",
"alt": "StyleSeek - e-commerce fashion mass customization online shopping",
"description": "Personalized e-commerce for lifestyle products\nStyleSeek is a personalized e-commerce site for lifestyle products.\nIt works across the style spectrum by enabling users (both men and women) to create and refine their unique StyleDNA.\nStyleSeek also promotes new products via its email newsletter, 100% personalized ...",
"link": "http:\/\/styleseek.com", "city": "Chicago"},
{"name": "Scout",
"images": "https:\/\/d1qb2nb5cznatu.cloudfront.net\/startups\/i\/190790-dbe27fe8cda0614d644431f853b64e8f-thumb_jpg.jpg?buster=1389652078",
"alt": "Scout - security consumer electronics internet of things",
"description": "Hassle-free Home Security\nScout is a self-installed, wireless home security system. We've created a more open, affordable and modern system than what is available on the market today. With month-to-month contracts and portable devices, Scout is a renter-friendly solution for the other ...",
"link": "http:\/\/www.scoutalarm.com", "city": "Chicago"},
{"name": "Invitation codes", "images": "https:\/\/invitation.codes\/img\/inv-brand-fb3.png",
"alt": "Invitation App - Share referral codes community ",
"description": "The referral community\nInvitation App is a social network where people post their referral codes and collect rewards on autopilot.",
"link": "https:\/\/invitation.codes", "city": "Chicago"},
{"name": "Hyde Park Angels",
"images": "https:\/\/d1qb2nb5cznatu.cloudfront.net\/startups\/i\/61114-35cd9d9689b70b4dc1d0b3c5f11c26e7-thumb_jpg.jpg?buster=1427395222",
"alt": "Hyde Park Angels - ",
"description": "Hyde Park Angels is the largest and most active angel group in the Midwest. With a membership of over 100 successful entrepreneurs, executives, and venture capitalists, the organization prides itself on providing critical strategic expertise to entrepreneurs and ...",
"link": "http:\/\/hydeparkangels.com", "city": "Chicago"},
{"name": "GiveForward",
"images": "https:\/\/d1qb2nb5cznatu.cloudfront.net\/startups\/i\/1374-e472ccec267bef9432a459784455c133-thumb_jpg.jpg?buster=1397666635",
"alt": "GiveForward - health care startups crowdfunding",
"description": "Crowdfunding for medical and life events\nGiveForward lets anyone to create a free fundraising page for a friend or loved one's uncovered medical bills, memorial fund, adoptions or any other life events in five minutes or less. Millions of families have used GiveForward to raise more than $165M to let ...",
"link": "http:\/\/giveforward.com", "city": "Chicago"},
{"name": "MentorMob",
"images": "https:\/\/d1qb2nb5cznatu.cloudfront.net\/startups\/i\/19374-3b63fcf38efde624dd79c5cbd96161db-thumb_jpg.jpg?buster=1315734490",
"alt": "MentorMob - digital media education ventures for good crowdsourcing",
"description": "Google of Learning, indexed by experts\nProblem: Google doesn't index for learning. Nearly 1 billion Google searches are done for \"how to\" learn various topics every month, from photography to entrepreneurship, forcing learners to waste their time sifting through the millions of results.\nMentorMob is ...",
"link": "http:\/\/www.mentormob.com", "city": "Chicago"},
{"name": "The Boeing Company",
"images": "https:\/\/d1qb2nb5cznatu.cloudfront.net\/startups\/i\/49394-df6be7a1eca80e8e73cc6699fee4f772-thumb_jpg.jpg?buster=1406172049",
"alt": "The Boeing Company - manufacturing transportation", "description": "",
"link": "http:\/\/www.boeing.com", "city": "Berlin"},
{"name": "NowBoarding \u2708\ufe0f",
"images": "https:\/\/static.above.flights\/img\/lowcost\/envelope_blue.png",
"alt": "Lowcost Email cheap flights alerts",
"description": "Invite-only mailing list.\n\nWe search the best weekend and long-haul flight deals\nso you can book before everyone else.",
"link": "https:\/\/nowboarding.club\/", "city": "Berlin"},
{"name": "Rocketmiles",
"images": "https:\/\/d1qb2nb5cznatu.cloudfront.net\/startups\/i\/158571-e53ddffe9fb3ed5e57080db7134117d0-thumb_jpg.jpg?buster=1361371304",
"alt": "Rocketmiles - e-commerce online travel loyalty programs hotels",
"description": "Fueling more vacations\nWe enable our customers to travel more, travel better and travel further. 20M+ consumers stock away miles & points to satisfy their wanderlust.\nFlying around or using credit cards are the only good ways to fill the stockpile today. We've built the third way. Customers ...",
"link": "http:\/\/www.Rocketmiles.com", "city": "Berlin"}

]
vectors = qdrant_engine.embedding_generator.generate_embedding([doc["description"] for doc in documents])
print(vectors.shape)
payload = [doc for doc in documents]

# Upload vectors and payload
qdrant_engine.upload_vectors(vectors=vectors, payload=payload)

# Build a filter for city and category
conditions = [
{"key": "city", "match": "Berlin"},
]
custom_filter = qdrant_engine.build_filter(conditions)

# Search for startups related to "vacations" in Berlin
results = qdrant_engine.search(text="vacations", query_filter=custom_filter, limit=5)
for result in results:
print(result)
Original file line number Diff line number Diff line change
@@ -1,58 +1,99 @@
import json
import numpy as np
from qdrant_client import QdrantClient
from qdrant_client.models import VectorParams, Distance, Filter, FieldCondition, MatchValue
from sentence_transformers import SentenceTransformer
from typing import List, Dict, Any, Optional

from typing import List, Dict, Any, Optional, Union
from abc import ABC, abstractmethod
import numpy as np
from openai import OpenAI

class EmbeddingGenerator(ABC):
@abstractmethod
def generate_embedding(self, text: List[str]) -> np.ndarray:
pass

class SentenceTransformerEmbedding(EmbeddingGenerator):
def __init__(self, model_name_or_path: str = "all-MiniLM-L6-v2", device: str = "cpu"):
from sentence_transformers import SentenceTransformer
self.model = SentenceTransformer(model_name_or_path, device=device)

def generate_embedding(self, text: List[str]) -> np.ndarray:
return self.model.encode(text)

class OpenAIEmbedding(EmbeddingGenerator):
def __init__(self, api_key: str, base_url: str, model: str):
self.client = OpenAI(api_key=api_key, base_url=base_url)
self.model = model

def generate_embedding(self, text: List[str]) -> np.ndarray:
resp = self.client.embeddings.create(
model=self.model,
input=text,
encoding_format="float"
)
return np.array([data.embedding for data in resp.data])

class QdrantEngine:
def __init__(self, collection_name: str, vector_size: int = 384, distance: Distance = Distance.COSINE):
def __init__(
self,
collection_name: str,
embedding_generator: EmbeddingGenerator,
qdrant_client_params: Dict[str, Any] = {"host": "localhost", "port": 6333},
vector_size: int = 384,
distance: Distance = Distance.COSINE,
):
"""
Initialize the Qdrant vector store.

:param collection_name: Name of the Qdrant collection.
:param vector_size: Size of the vectors (default is 384 for all-MiniLM-L6-v2).
:param embedding_generator: An instance of EmbeddingGenerator to generate embeddings.
:param qdrant_client_params: Dictionary of parameters to pass to QdrantClient.
:param vector_size: Size of the vectors.
:param distance: Distance metric for vector comparison (default is cosine similarity).
"""
self.collection_name = collection_name
self.vector_size = vector_size
self.distance = distance
self.client = QdrantClient("http://localhost:6333")
self.model = SentenceTransformer("all-MiniLM-L6-v2", device="cpu")
self.embedding_generator = embedding_generator

# Initialize QdrantClient with provided parameters
self.client = QdrantClient(**qdrant_client_params)

# Create collection if it doesn't exist
if not self.client.collection_exists(self.collection_name):
self.client.create_collection(
collection_name=self.collection_name,
vectors_config=VectorParams(size=self.vector_size, distance=self.distance),
timeout=500,
)

def upload_vectors(self, vectors_path: str, payload_path: str, batch_size: int = 256):
def upload_vectors(
self, vectors: Union[np.ndarray, List[List[float]]],
payload: List[Dict[str, Any]],
batch_size: int = 256
):
"""
Upload vectors and payload to the Qdrant collection.

:param vectors_path: Path to the .npy file containing the vectors.
:param payload_path: Path to the .json file containing the payload.
:param vectors: A numpy array or list of vectors to upload.
:param payload: A list of dictionaries containing the payload for each vector.
:param batch_size: Number of vectors to upload in a single batch.
"""
# Load vectors from .npy file
vectors = np.load(vectors_path)

# Load payload from .json file
with open(payload_path) as fd:
payload = map(json.loads, fd)

# Upload vectors and payload to Qdrant
if not isinstance(vectors, np.ndarray):
vectors = np.array(vectors)
if len(vectors) != len(payload):
raise ValueError("Vectors and payload must have the same length.")
self.client.upload_collection(
collection_name=self.collection_name,
vectors=vectors,
payload=payload,
ids=None, # Vector ids will be assigned automatically
ids=None,
batch_size=batch_size,
)

def search(self, text: str, query_filter: Optional[Filter] = None, limit: int = 5) -> List[Dict[str, Any]]:
def search(
self, text: str,
query_filter: Optional[Filter] = None,
limit: int = 5
) -> List[Dict[str, Any]]:
"""
Search for the closest vectors in the collection based on the input text.

Expand All @@ -61,13 +102,13 @@ def search(self, text: str, query_filter: Optional[Filter] = None, limit: int =
:param limit: Number of closest results to return.
:return: List of payloads from the closest vectors.
"""
# Convert text query into vector
vector = self.model.encode(text).tolist()
# Generate embedding using the provided embedding generator
vector = self.embedding_generator.generate_embedding([text])

# Search for closest vectors in the collection
search_result = self.client.query_points(
collection_name=self.collection_name,
query=vector,
query=vector[0], # Use the first (and only) embedding
query_filter=query_filter,
limit=limit,
).points
Expand Down Expand Up @@ -99,23 +140,3 @@ def build_filter(self, conditions: List[Dict[str, Any]]) -> Filter:

return Filter(must=filter_conditions)


# Example usage
if __name__ == "__main__":
# Initialize the QdrantVectorStore
vector_store = QdrantVectorStore(collection_name="startups")

# Upload vectors and payload
vector_store.upload_vectors(vectors_path="./startup_vectors.npy", payload_path="./startups_demo.json")

# Build a filter for city and category
conditions = [
{"key": "city", "match": "Berlin"},
{"key": "category", "match": "AI"},
]
custom_filter = vector_store.build_filter(conditions)

# Search for startups related to "AI" in Berlin
results = vector_store.search(text="AI", query_filter=custom_filter, limit=5)
for result in results:
print(result)
Loading