Skip to content
This repository has been archived by the owner on Nov 13, 2024. It is now read-only.

Commit

Permalink
docs: Added QdrantKnowledgeBase docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
Anush008 committed Jan 10, 2024
1 parent 48a01bc commit e343959
Show file tree
Hide file tree
Showing 7 changed files with 401 additions and 43 deletions.
9 changes: 4 additions & 5 deletions src/canopy/knowledge_base/qdrant/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,11 @@ class QdrantConverter:
@staticmethod
def convert_id(_id: str) -> str:
"""
Converts any string into a UUID string in a deterministic way based on a seed.
Converts any string into a UUID string based on a seed.
Qdrant does not accept an arbitrary string as id, so an internal UUID has to be
generated for each point.
We generate deterministic UUIDs based on the original id.
Thereby enabling overwriting of the same point with the original id.
Qdrant accepts UUID strings and unsigned integers as point ID.
We use a seed to convert each string into a UUID string deterministically.
This allows us to overwrite the same point with the original ID.
"""
return str(uuid.uuid5(uuid.UUID(UUID_NAMESPACE), _id))

Expand Down
319 changes: 315 additions & 4 deletions src/canopy/knowledge_base/qdrant/qdrant_knowledge_base.py

Large diffs are not rendered by default.

5 changes: 4 additions & 1 deletion src/canopy/knowledge_base/qdrant/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def generate_clients(
timeout: Optional[float] = None,
host: Optional[str] = None,
path: Optional[str] = None,
force_disable_check_same_thread: bool = False,
**kwargs: Any,
) -> Tuple[QdrantClient, Union[AsyncQdrantClient, None]]:
sync_client = QdrantClient(
Expand All @@ -58,12 +59,13 @@ def generate_clients(
timeout=timeout,
host=host,
path=path,
force_disable_check_same_thread=force_disable_check_same_thread,
**kwargs,
)

if location == ":memory:" or path is not None:
# In-memory Qdrant doesn't interoperate with Sync and Async clients
# We fallback to sync operations in this case
# We fallback to sync operations in this case using @utils.sync_fallback
async_client = None
else:
async_client = AsyncQdrantClient(
Expand All @@ -78,6 +80,7 @@ def generate_clients(
timeout=timeout,
host=host,
path=path,
force_disable_check_same_thread=force_disable_check_same_thread,
**kwargs,
)

Expand Down
24 changes: 24 additions & 0 deletions tests/system/knowledge_base/qdrant/common.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
import numpy as np
import requests
from canopy.knowledge_base.qdrant.constants import DENSE_VECTOR_NAME
from canopy.knowledge_base.qdrant.converter import QdrantConverter
from canopy.knowledge_base.qdrant.qdrant_knowledge_base import QdrantKnowledgeBase

import logging
from typing import List

logger = logging.getLogger(__name__)


def total_vectors_in_collection(knowledge_base: QdrantKnowledgeBase):
return knowledge_base._client.count(knowledge_base.collection_name).count
Expand Down Expand Up @@ -54,3 +60,21 @@ def assert_ids_not_in_collection(knowledge_base, ids):
ids=ids,
)
assert len(fetch_result) == 0, f"Found {len(fetch_result)} unexpected ids"


def qdrant_server_running() -> bool:
"""Check if Qdrant server is running."""

try:
response = requests.get("http://localhost:6333", timeout=10.0)
response_json = response.json()
return response_json.get("title") == "qdrant - vector search engine"
except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
return False


def qdrant_locations() -> List[str]:
if not qdrant_server_running():
logger.warning("Running Qdrant tests in memory mode only.")
return [":memory:"]
return ["http://localhost:6333", ":memory:"]
7 changes: 4 additions & 3 deletions tests/system/knowledge_base/qdrant/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from canopy.knowledge_base.qdrant.constants import COLLECTION_NAME_PREFIX
from canopy.knowledge_base.qdrant.qdrant_knowledge_base import QdrantKnowledgeBase
from canopy.models.data_models import Document
from tests.system.knowledge_base.qdrant.common import qdrant_locations
from tests.system.knowledge_base.test_knowledge_base import _generate_text
from tests.unit.stubs.stub_chunker import StubChunker
from tests.unit.stubs.stub_dense_encoder import StubDenseEncoder
Expand Down Expand Up @@ -29,13 +30,13 @@ def encoder():
return StubRecordEncoder(StubDenseEncoder())


@pytest.fixture(scope="module", autouse=True)
def knowledge_base(collection_name, chunker, encoder):
@pytest.fixture(scope="module", autouse=True, params=qdrant_locations())
def knowledge_base(collection_name, chunker, encoder, request):
kb = QdrantKnowledgeBase(
collection_name=collection_name,
record_encoder=encoder,
chunker=chunker,
location=":memory:",
location=request.param,
)
kb.create_canopy_collection()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@
)
from canopy.knowledge_base.models import DocumentWithScore
from canopy.models.data_models import Query
from qdrant_client.async_qdrant_remote import AsyncQdrantRemote
from tests.unit import random_words
from tests.unit.stubs.stub_chunker import StubChunker


load_dotenv()


Expand Down Expand Up @@ -57,6 +59,7 @@ async def assert_query_metadata_filter(
assert (
top_k > num_vectors_expected
), "the test might return false positive if top_k is not > num_vectors_expected"

query = Query(text="test", top_k=top_k, metadata_filter=metadata_filter)
query_results = await knowledge_base.aquery([query])
assert len(query_results) == 1
Expand Down Expand Up @@ -95,17 +98,25 @@ async def test_query(knowledge_base, encoded_chunks):
await execute_and_assert_queries(knowledge_base, encoded_chunks)


# @pytest.mark.asyncio
# async def test_query_with_metadata_filter(knowledge_base):
# await assert_query_metadata_filter(
# knowledge_base,
# {
# "must": [
# {"key": "my-key", "match": {"value": "value-1"}},
# ]
# },
# 2,
# )
@pytest.mark.asyncio
async def test_query_with_metadata_filter(knowledge_base):
if knowledge_base._async_client is None or not isinstance(
knowledge_base._async_client._client, AsyncQdrantRemote
):
pytest.skip(
"Dict filter is not supported for QdrantLocal"
"Use qdrant_client.models.Filter instead"
)

await assert_query_metadata_filter(
knowledge_base,
{
"must": [
{"key": "my-key", "match": {"value": "value-1"}},
]
},
2,
)


@pytest.mark.asyncio
Expand Down
47 changes: 28 additions & 19 deletions tests/system/knowledge_base/qdrant/test_qdrant_knowledge_base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy
import random

import pytest
Expand All @@ -16,6 +17,8 @@
from canopy.knowledge_base.record_encoder.base import RecordEncoder
from canopy.knowledge_base.reranker.reranker import Reranker
from canopy.models.data_models import Query

from qdrant_client.qdrant_remote import QdrantRemote
from tests.system.knowledge_base.qdrant.common import (
assert_chunks_in_collection,
assert_ids_in_collection,
Expand Down Expand Up @@ -68,7 +71,7 @@ def assert_query_metadata_filter(
assert len(query_results[0].documents) == num_vectors_expected


def test_create_index(collection_full_name, knowledge_base: QdrantKnowledgeBase):
def test_create_collection(collection_full_name, knowledge_base: QdrantKnowledgeBase):
assert knowledge_base.collection_name == collection_full_name
collection_info = knowledge_base._client.get_collection(collection_full_name)
assert (
Expand All @@ -77,7 +80,7 @@ def test_create_index(collection_full_name, knowledge_base: QdrantKnowledgeBase)
)


def test_list_indexes(collection_full_name, knowledge_base: QdrantKnowledgeBase):
def test_list_collections(collection_full_name, knowledge_base: QdrantKnowledgeBase):
collections_list = knowledge_base.list_canopy_collections()

assert len(collections_list) > 0
Expand Down Expand Up @@ -126,16 +129,22 @@ def test_query(knowledge_base, encoded_chunks):
execute_and_assert_queries(knowledge_base, encoded_chunks)


# def test_query_with_metadata_filter(knowledge_base):
# assert_query_metadata_filter(
# knowledge_base,
# {
# "must": [
# {"key": "my-key", "match": {"value": "value-1"}},
# ]
# },
# 2,
# )
def test_query_with_metadata_filter(knowledge_base):
if not isinstance(knowledge_base._client._client, QdrantRemote):
pytest.skip(
"Dict filter is not supported for QdrantLocal"
"Use qdrant_client.models.Filter instead"
)

assert_query_metadata_filter(
knowledge_base,
{
"must": [
{"key": "my-key", "match": {"value": "value-1"}},
]
},
2,
)


def test_delete_documents(knowledge_base: QdrantKnowledgeBase, encoded_chunks):
Expand Down Expand Up @@ -225,17 +234,17 @@ def test_query_edge_case_documents(knowledge_base, datetime_metadata_encoded_chu
execute_and_assert_queries(knowledge_base, datetime_metadata_encoded_chunks)


def test_create_existing_index(collection_full_name, knowledge_base):
def test_create_existing_collection(collection_full_name, knowledge_base):
with pytest.raises(RuntimeError) as e:
knowledge_base.create_canopy_collection()

assert f"Collection {collection_full_name} already exists" in str(e.value)


def test_kb_non_existing_index(chunker, encoder):
kb = QdrantKnowledgeBase(
"non-existing-collection", record_encoder=encoder, chunker=chunker
)
def test_kb_non_existing_collection(knowledge_base):
kb = copy.copy(knowledge_base)

kb._collection_name = f"{COLLECTION_NAME_PREFIX}non-existing-collection"

with pytest.raises(RuntimeError) as e:
kb.verify_index_connection()
Expand All @@ -245,7 +254,7 @@ def test_kb_non_existing_index(chunker, encoder):
assert expected_msg in str(e.value)


def test_init_defaults(knowledge_base, collection_name, collection_full_name):
def test_init_defaults(collection_name, collection_full_name):
new_kb = QdrantKnowledgeBase(collection_name)
assert isinstance(new_kb._client, QdrantClient)
assert new_kb.collection_name == collection_full_name
Expand Down Expand Up @@ -288,7 +297,7 @@ def test_init_raise_wrong_type(knowledge_base, chunker):
assert "record_encoder must be an instance of RecordEncoder" in str(e.value)


def test_create_with_index_encoder_dimension_none(collection_name, chunker):
def test_create_with_collection_encoder_dimension_none(collection_name, chunker):
encoder = StubRecordEncoder(StubDenseEncoder(dimension=3))
encoder._dense_encoder.dimension = None
with pytest.raises(RuntimeError) as e:
Expand Down

0 comments on commit e343959

Please sign in to comment.