Refactor LanceDB-related code and increase type hint coverage
Signed-off-by: Marcel Coetzee <[email protected]>
Pipboyguy committed Jun 4, 2024
1 parent fb7565e commit 4c73541
Showing 6 changed files with 56 additions and 84 deletions.
2 changes: 1 addition & 1 deletion dlt/destinations/impl/lancedb/configuration.py
@@ -29,7 +29,7 @@ class LanceDBCredentials(CredentialsConfiguration):
     __config_gen_annotations__: ClassVar[List[str]] = [
         "uri",
         "api_key",
-        "embedding_model_provider_api_key"
+        "embedding_model_provider_api_key",
     ]


66 changes: 19 additions & 47 deletions dlt/destinations/impl/lancedb/lancedb_client.py
@@ -14,13 +14,13 @@
     Sequence,
 )

-import lancedb
+import lancedb  # type: ignore
 import pyarrow as pa
 from lancedb import DBConnection
-from lancedb.common import DATA
-from lancedb.embeddings import EmbeddingFunctionRegistry, TextEmbeddingFunction
-from lancedb.pydantic import LanceModel
-from lancedb.query import LanceQueryBuilder
+from lancedb.common import DATA  # type: ignore
+from lancedb.embeddings import EmbeddingFunctionRegistry, TextEmbeddingFunction  # type: ignore
+from lancedb.pydantic import LanceModel  # type: ignore
+from lancedb.query import LanceQueryBuilder  # type: ignore
 from numpy import ndarray
 from pyarrow import Array, ChunkedArray

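The bare `# type: ignore` pragmas above silence mypy on the `lancedb` imports, which at the time of this commit ship without type stubs or a `py.typed` marker. A narrower variant, used in `utils.py` further down, scopes the suppression to the missing-stubs error code so other problems on the same line still surface. A sketch, assuming a mypy version that recognizes the `import-untyped` error code:

    # Only the missing-stubs diagnostic is ignored here; any other type
    # error on this line would still be reported.
    from lancedb.pydantic import LanceModel, Vector  # type: ignore[import-untyped]
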
@@ -132,19 +132,15 @@ def _make_qualified_table_name(self, table_name: str) -> str:
     def get_table_schema(self, table_name: str) -> pa.Schema:
         return cast(pa.Schema, self.db_client[table_name].schema)

-    def _create_table(
-        self, table_name: str, schema: Union[pa.Schema, LanceModel]
-    ) -> None:
+    def _create_table(self, table_name: str, schema: Union[pa.Schema, LanceModel]) -> None:
         """Create a LanceDB Table from the provided LanceModel or PyArrow schema.

         Args:
             schema: The table schema to create.
             table_name: The name of the table to create.
         """
-        self.db_client.create_table(
-            table_name, schema=schema, embedding_functions=self.model_func
-        )
+        self.db_client.create_table(table_name, schema=schema, embedding_functions=self.model_func)

     def delete_table(self, table_name: str) -> None:
         """Delete a LanceDB table.
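For reference, a minimal sketch of what `_create_table` drives underneath, assuming a live `lancedb.connect()` handle; the path, table, and fields below are illustrative, not from this commit (the `embedding_functions` argument seen above additionally registers model functions that populate vector columns on ingest):

    import lancedb  # type: ignore
    import pyarrow as pa

    db = lancedb.connect("/tmp/lancedb")  # hypothetical local database
    schema = pa.schema([pa.field("id", pa.string()), pa.field("text", pa.string())])
    db.create_table("my_table", schema=schema)
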
@@ -209,9 +205,7 @@ def add_to_table(
         Returns:
             None
         """
-        self.db_client.open_table(table_name).add(
-            data, mode, on_bad_vectors, fill_value
-        )
+        self.db_client.open_table(table_name).add(data, mode, on_bad_vectors, fill_value)

     def drop_storage(self) -> None:
         """Drop the dataset from the LanceDB instance.
@@ -254,9 +248,7 @@ def is_storage_initialized(self) -> bool:

     def _create_sentinel_table(self) -> None:
         """Create an empty table to indicate that the storage is initialized."""
-        self._create_table(
-            schema=cast(LanceModel, NullSchema), table_name=self.sentinel_table
-        )
+        self._create_table(schema=cast(LanceModel, NullSchema), table_name=self.sentinel_table)

     def _delete_sentinel_table(self) -> None:
         """Delete the sentinel table."""
@@ -293,9 +285,7 @@ def _update_schema_in_storage(self, schema: Schema) -> None:
             "inserted_at": str(pendulum.now()),
             "schema": json.dumps(schema.to_dict()),
         }
-        version_table_name = self._make_qualified_table_name(
-            self.schema.version_table_name
-        )
+        version_table_name = self._make_qualified_table_name(self.schema.version_table_name)
         self._create_record(properties, VersionSchema, version_table_name)

     def _create_record(
@@ -311,9 +301,7 @@ def _create_record(
         try:
             tbl = self.db_client.open_table(self._make_qualified_table_name(table_name))
         except FileNotFoundError:
-            tbl = self.db_client.create_table(
-                self._make_qualified_table_name(table_name)
-            )
+            tbl = self.db_client.create_table(self._make_qualified_table_name(table_name))
         except Exception:
             raise

@@ -333,9 +321,7 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]:
         """Loads compressed state from destination storage by finding a load ID that was completed."""
         while True:
             try:
-                state_table_name = self._make_qualified_table_name(
-                    self.schema.state_table_name
-                )
+                state_table_name = self._make_qualified_table_name(self.schema.state_table_name)
                 state_records = (
                     self.db_client.open_table(state_table_name)
                     .search()
@@ -347,9 +333,7 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]:
                     return None
                 for state in state_records:
                     load_id = state["_dlt_load_id"]
-                    loads_table_name = self._make_qualified_table_name(
-                        self.schema.loads_table_name
-                    )
+                    loads_table_name = self._make_qualified_table_name(self.schema.loads_table_name)
                     load_records = (
                         self.db_client.open_table(loads_table_name)
                         .search()
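The two hunks above share LanceDB's fluent query pattern: an argument-less `search()` starts a plain scan, and the builder then chains filtering and materialization. A hedged sketch, where `tbl` and the filter are illustrative rather than taken from this commit:

    # Assumes an open table handle `tbl`; where() takes a SQL-like predicate.
    states = (
        tbl.search()
        .where('pipeline_name = "my_pipeline"')
        .limit(10)
        .to_list()  # materialize matching rows as Python dicts
    )
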
@@ -381,9 +365,7 @@ def get_stored_schema_by_hash(self, schema_hash: str) -> StorageSchemaInfo:
     def get_stored_schema(self) -> Optional[StorageSchemaInfo]:
         """Retrieves newest schema from destination storage."""
         try:
-            version_table_name = self._make_qualified_table_name(
-                self.schema.version_table_name
-            )
+            version_table_name = self._make_qualified_table_name(self.schema.version_table_name)
             response = (
                 self.db_client[version_table_name]
                 .search()
@@ -420,9 +402,7 @@ def complete_load(self, load_id: str) -> None:
     def restore_file_load(self, file_path: str) -> LoadJob:
         return EmptyLoadJob.from_file_path(file_path, "completed")

-    def start_file_load(
-        self, table: TTableSchema, file_path: str, load_id: str
-    ) -> LoadJob:
+    def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob:
         return LoadLanceDBJob(
             self.schema,
             table,
@@ -458,9 +438,7 @@ def __init__(
         self.table_name = table_name
         self.table_schema: TTableSchema = table_schema
         self.unique_identifiers = self._list_unique_identifiers(table_schema)
-        self.embedding_fields = get_columns_names_with_prop(
-            table_schema, VECTORIZE_HINT
-        )
+        self.embedding_fields = get_columns_names_with_prop(table_schema, VECTORIZE_HINT)
         self.embedding_model_func = model_func
         self.embedding_model_dimensions = client_config.embedding_model_dimensions

@@ -509,21 +487,15 @@ def _upload_data(
        except Exception:
            raise

-        parsed_records: List[LanceModel] = [
-            lancedb_model(**record) for record in records
-        ]
+        parsed_records: List[LanceModel] = [lancedb_model(**record) for record in records]

         # Upsert using reserved ID as the key.
         tbl.merge_insert(
             self.id_field_name
-        ).when_matched_update_all().when_not_matched_insert_all().execute(
-            parsed_records
-        )
+        ).when_matched_update_all().when_not_matched_insert_all().execute(parsed_records)

     @staticmethod
-    def _generate_uuid(
-        data: DictStrAny, unique_identifiers: Sequence[str], table_name: str
-    ) -> str:
+    def _generate_uuid(data: DictStrAny, unique_identifiers: Sequence[str], table_name: str) -> str:
         """Generates deterministic UUID - used for deduplication.

         Args:
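The `merge_insert` chain above is LanceDB's upsert primitive. A minimal sketch against a throwaway database (the path, table, and rows are illustrative, not part of this commit):

    import lancedb  # type: ignore

    db = lancedb.connect("/tmp/lancedb")
    tbl = db.create_table("docs", data=[{"id": "1", "text": "first"}], mode="overwrite")

    # Rows are matched on "id": matched rows are updated in place,
    # unmatched incoming rows are inserted.
    (
        tbl.merge_insert("id")
        .when_matched_update_all()
        .when_not_matched_insert_all()
        .execute([{"id": "1", "text": "updated"}, {"id": "2", "text": "second"}])
    )
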
32 changes: 15 additions & 17 deletions dlt/destinations/impl/lancedb/utils.py
@@ -1,7 +1,7 @@
 from typing import List, Type, Optional

-from lancedb.embeddings import TextEmbeddingFunction
-from lancedb.pydantic import LanceModel, Vector
+from lancedb.embeddings import TextEmbeddingFunction  # type: ignore[import-untyped]
+from lancedb.pydantic import LanceModel, Vector  # type: ignore[import-untyped]
 from pydantic import create_model

 from dlt.common.typing import DictStrAny
@@ -29,35 +29,33 @@ def infer_lancedb_model_from_data(
         Type[LanceModel]: The inferred LanceModel.
     """

-    template_schema: Type[LanceModel] = create_model(
+    template_schema: Type[LanceModel] = create_model(  # type: ignore[call-overload]
         "TemplateSchema",
         __base__=LanceModel,
         __module__=__name__,
         __validators__={},
         **{
             id_field_name: (str, ...),
             vector_field_name: (
-                Vector(
-                    embedding_model_dimensions
-                    if embedding_model_dimensions
-                    else embedding_model_func.ndims()
-                ),
+                Vector(embedding_model_dimensions or embedding_model_func.ndims()),
                 ...,
             ),
         },
     )

-    field_types = {}
-    for field_name in data[0].keys():
-        if field_name != id_field_name and field_name != vector_field_name:
-            field_types[field_name] = (
-                str,  # Infer all fields as str
-                embedding_model_func.SourceField()
-                if field_name in embedding_fields
-                else None,  # Set default to None to make fields optional
-            )
-
-    inferred_schema: Type[LanceModel] = create_model(
+    field_types = {
+        field_name: (
+            str,  # Infer all fields temporarily as str
+            (
+                embedding_model_func.SourceField()
+                if field_name in embedding_fields
+                else None  # Set default to None to make fields optional
+            ),
+        )
+        for field_name in data[0].keys()
+        if field_name not in [id_field_name, vector_field_name]
+    }
+    inferred_schema: Type[LanceModel] = create_model(  # type: ignore[call-overload]
         "InferredSchema",
         __base__=template_schema,
         __module__=__name__,
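Both `create_model` calls above build `LanceModel` subclasses at runtime, which is why mypy needs the `call-overload` ignore. A hedged sketch of the same pattern (the field names and the 2-dimensional vector are illustrative):

    from lancedb.pydantic import LanceModel, Vector  # type: ignore[import-untyped]
    from pydantic import create_model

    InferredSchema = create_model(  # type: ignore[call-overload]
        "InferredSchema",
        __base__=LanceModel,
        id_=(str, ...),           # required string field
        vector=(Vector(2), ...),  # required fixed-size vector field
        text=(str, None),         # optional; defaults to None
    )
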
10 changes: 4 additions & 6 deletions tests/load/lancedb/test_pipeline.py
@@ -29,7 +29,7 @@ def test_adapter_and_hints() -> None:
     generator_instance1 = sequence_generator()

     @dlt.resource(columns=[{"name": "content", "data_type": "text"}])
-    def some_data():  # type: ignore[no-untyped-def]
+    def some_data():
         yield from next(generator_instance1)

     assert some_data.columns["content"] == {"name": "content", "data_type": "text"}  # type: ignore[index]
@@ -50,7 +50,7 @@ def test_basic_state_and_schema() -> None:
     generator_instance1 = sequence_generator()

     @dlt.resource
-    def some_data():  # type: ignore[no-untyped-def]
+    def some_data():
         yield from next(generator_instance1)

     lancedb_adapter(
@@ -361,12 +361,10 @@ def test_empty_dataset_allowed() -> None:
     client: LanceDBClient = p.destination_client()  # type: ignore[assignment]

     assert p.dataset_name is None
-    info = p.run(
-        lancedb_adapter(["context", "created", "not a stop word"], embed=["value"])
-    )
+    info = p.run(lancedb_adapter(["context", "created", "not a stop word"], embed=["value"]))
     # dataset in load info is empty
     assert info.dataset_name is None
     client = p.destination_client()  # type: ignore[assignment]
     assert client.dataset_name is None
-    assert client.sentinel_collection == "DltSentinelCollection"
+    assert client.sentinel_table == "DltSentinelCollection"
     assert_table(p, "content", expected_items_count=3)
20 changes: 8 additions & 12 deletions tests/load/lancedb/utils.py
@@ -1,8 +1,8 @@
 from typing import Union, List, Any

 import numpy as np
-from lancedb.embeddings import TextEmbeddingFunction
-from lancedb.pydantic import LanceModel, Vector
+from lancedb.embeddings import TextEmbeddingFunction  # type: ignore
+from lancedb.pydantic import LanceModel, Vector  # type: ignore

 import dlt
 from dlt.common.configuration.container import Container
@@ -30,12 +30,7 @@ def assert_table(
     assert exists

     qualified_collection_name = client._make_qualified_table_name(collection_name)
-    records = (
-        client.db_client.open_table(qualified_collection_name)
-        .search()
-        .limit(50)
-        .to_list()
-    )
+    records = client.db_client.open_table(qualified_collection_name).search().limit(50).to_list()

     if expected_items_count is not None:
         assert expected_items_count == len(records)
@@ -72,9 +67,10 @@ def has_tables(client: LanceDBClient) -> bool:


 class MockEmbeddingFunc(TextEmbeddingFunction):
-    def generate_embeddings(  # type: ignore
-        self, texts: Union[List[str], np.ndarray]
-    ) -> List[np.array]:
+    def generate_embeddings(
+        self,
+        texts: Union[List[str], np.ndarray],  # type: ignore[type-arg]
+    ) -> List[np.array]:  # type: ignore[valid-type]
         return [np.array(None)]

     def ndims(self) -> int:
@@ -104,5 +100,5 @@ def test_infer_lancedb_model_from_data() -> None:
     }

     assert issubclass(inferred_model, LanceModel)
-    for field_name, (field_type, field_default) in expected_fields.items():
+    for field_name in expected_fields:
         assert field_name in inferred_model.model_fields
10 changes: 9 additions & 1 deletion tests/utils.py
@@ -52,7 +52,15 @@
     "clickhouse",
     "dremio",
 }
-NON_SQL_DESTINATIONS = {"filesystem", "weaviate", "dummy", "motherduck", "qdrant", "lancedb", "destination"}
+NON_SQL_DESTINATIONS = {
+    "filesystem",
+    "weaviate",
+    "dummy",
+    "motherduck",
+    "qdrant",
+    "lancedb",
+    "destination",
+}
 SQL_DESTINATIONS = IMPLEMENTED_DESTINATIONS - NON_SQL_DESTINATIONS

 # exclude destination configs (for now used for athena and athena iceberg separation)
