Merge pull request #64 from ExpediaGroup/feature/elasticsearch_write
implement write batch method for elasticsearch online store
piket authored Oct 25, 2023
2 parents 1ddee5f + f592055 commit a8eb0a5
Showing 2 changed files with 145 additions and 25 deletions.
@@ -1,10 +1,11 @@
+import base64
 import json
 import logging
 from datetime import datetime
 from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple

 from bidict import bidict
-from elasticsearch import Elasticsearch
+from elasticsearch import Elasticsearch, helpers
 from pydantic.typing import Literal

 from feast import Entity, FeatureView, RepoConfig
@@ -96,8 +97,24 @@ def online_write_batch(
         ],
         progress: Optional[Callable[[int], Any]],
     ) -> None:
-        with ElasticsearchConnectionManager(config):
-            pass
+        with ElasticsearchConnectionManager(config) as es:
+            resp = es.indices.exists(index=table.name)
+            if not resp.body:
+                self._create_index(es, table)
+            bulk_documents = []
+            for entity_key, values, timestamp, created_ts in data:
+                id_val = self._get_value_from_value_proto(entity_key.entity_values[0])
+                document = {entity_key.join_keys[0]: id_val}
+                for feature_name, val in values.items():
+                    document[feature_name] = self._get_value_from_value_proto(val)
+                bulk_documents.append(
+                    {"_index": table.name, "_id": id_val, "doc": document}
+                )
+
+            successes, errors = helpers.bulk(client=es, actions=bulk_documents)
+            logger.info(f"bulk write completed with {successes} successes")
+            if errors:
+                logger.error(f"bulk write returned errors: {errors}")

     def online_read(
         self,
@@ -163,8 +180,28 @@ def _create_index(self, es, fv):
             if is_primary:
                 index_mapping["properties"][feature.name]["index"] = True
         es.indices.create(index=fv.name, mappings=index_mapping)
+        logger.info(f"Index {fv.name} created")

     def _get_data_type(self, t: FeastType) -> str:
         if isinstance(t, ComplexFeastType):
             return "text"
         return TYPE_MAPPING.get(t, "text")
+
+    def _get_value_from_value_proto(self, proto: ValueProto):
+        """
+        Get the raw value from a value proto.
+
+        Parameters:
+            proto (ValueProto): the value proto that contains the data.
+
+        Returns:
+            value (Any): the extracted value.
+        """
+        val_type = proto.WhichOneof("val")
+        value = getattr(proto, val_type)  # type: ignore
+        if val_type == "bytes_val":
+            value = base64.b64encode(value).decode()
+        if val_type == "float_list_val":
+            value = list(value.val)
+
+        return value
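The write path above builds one action dict per row and hands the whole list to `helpers.bulk`. Below is a minimal standalone sketch of that action shape, assuming a local cluster at http://localhost:9200 and a hypothetical `films` index with placeholder field names (none of these are part of this change). With elasticsearch-py, keys prefixed with an underscore (`_index`, `_id`) are routing metadata, and every remaining key becomes part of the stored `_source`, which is why the store's feature values land nested under a top-level `doc` field:

```python
# Sketch only: assumes an Elasticsearch cluster at localhost:9200 and a
# hypothetical "films" index; field names are placeholders.
from elasticsearch import Elasticsearch, helpers

es = Elasticsearch("http://localhost:9200")

# Same action shape as online_write_batch: "_index"/"_id" route the document,
# the remaining "doc" key becomes part of the stored _source.
actions = [
    {
        "_index": "films",
        "_id": str(i),
        "doc": {"film_id": str(i), "title": f"film-{i}"},
    }
    for i in range(3)
]

successes, errors = helpers.bulk(client=es, actions=actions)
print(successes, errors)  # e.g. 3 [] on a healthy cluster

# Reading one back shows the nesting: _source == {"doc": {...}}
print(es.get(index="films", id="0")["_source"]["doc"])
```

Worth noting: `helpers.bulk` raises `BulkIndexError` on failed actions by default (`raise_on_error=True`), so the returned error list in the new method stays empty unless that default is overridden.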
sdk/python/tests/expediagroup/test_elasticsearch_online_store.py (127 changes: 105 additions & 22 deletions)
@@ -94,28 +94,6 @@ def setup_method(self, repo_config):

         yield

-    def create_n_customer_test_samples_elasticsearch_online_read(self, n=10):
-        return [
-            (
-                EntityKeyProto(
-                    join_keys=["film_id"],
-                    entity_values=[ValueProto(int64_val=i)],
-                ),
-                {
-                    "films": ValueProto(
-                        float_list_val=FloatList(
-                            val=[random.random() for _ in range(2)]
-                        )
-                    ),
-                    "film_date": ValueProto(int64_val=n),
-                    "film_id": ValueProto(int64_val=n),
-                },
-                datetime.utcnow(),
-                None,
-            )
-            for i in range(n)
-        ]
-
     @pytest.mark.parametrize("index_params", index_param_list)
     def test_elasticsearch_update_add_index(self, repo_config, caplog, index_params):
         dimensions = 16
@@ -307,6 +285,29 @@ def test_elasticsearch_update_delete_unavailable_index(self, repo_config, caplog
         with ElasticsearchConnectionManager(repo_config.online_store) as es:
             assert es.indices.exists(index=self.index_to_delete).body is False

+    def test_elasticsearch_online_write_batch(self, repo_config, caplog):
+        total_rows_to_write = 100
+        (
+            feature_view,
+            data,
+        ) = self._create_n_customer_test_samples_elasticsearch_online_read(
+            n=total_rows_to_write
+        )
+        ElasticsearchOnlineStore().online_write_batch(
+            config=repo_config.online_store,
+            table=feature_view,
+            data=data,
+            progress=None,
+        )
+
+        with ElasticsearchConnectionManager(repo_config.online_store) as es:
+            es.indices.refresh(index=self.index_to_write)
+            res = es.cat.count(index=self.index_to_write, params={"format": "json"})
+            assert res[0]["count"] == "100"
+            doc = es.get(index=self.index_to_write, id="0")["_source"]["doc"]
+            for feature in feature_view.schema:
+                assert feature.name in doc
+
     def _create_index_in_es(self, index_name, repo_config):
         with ElasticsearchConnectionManager(repo_config.online_store) as es:
             mapping = {
@@ -321,3 +322,85 @@ def _create_index_in_es(self, index_name, repo_config):
                 }
             }
             es.indices.create(index=index_name, mappings=mapping)
+
+    def _create_n_customer_test_samples_elasticsearch_online_read(self, n=10):
+        fv = FeatureView(
+            name=self.index_to_write,
+            source=SOURCE,
+            entities=[Entity(name="id")],
+            schema=[
+                Field(
+                    name="vector",
+                    dtype=Array(Float32),
+                    tags={
+                        "description": "float32",
+                        "dimensions": "10",
+                        "index_type": "HNSW",
+                    },
+                ),
+                Field(
+                    name="id",
+                    dtype=String,
+                ),
+                Field(
+                    name="text",
+                    dtype=String,
+                ),
+                Field(
+                    name="int",
+                    dtype=Int32,
+                ),
+                Field(
+                    name="long",
+                    dtype=Int64,
+                ),
+                Field(
+                    name="float",
+                    dtype=Float32,
+                ),
+                Field(
+                    name="double",
+                    dtype=Float64,
+                ),
+                Field(
+                    name="binary",
+                    dtype=Bytes,
+                ),
+                Field(
+                    name="bool",
+                    dtype=Bool,
+                ),
+                Field(
+                    name="timestamp",
+                    dtype=UnixTimestamp,
+                ),
+            ],
+        )
+        return fv, [
+            (
+                EntityKeyProto(
+                    join_keys=["id"],
+                    entity_values=[ValueProto(string_val=str(i))],
+                ),
+                {
+                    "vector": ValueProto(
+                        float_list_val=FloatList(
+                            val=[random.random() for _ in range(10)]
+                        )
+                    ),
+                    "text": ValueProto(string_val="text"),
+                    "int": ValueProto(int32_val=n),
+                    "long": ValueProto(int64_val=n),
+                    "float": ValueProto(float_val=n * 0.3),
+                    "double": ValueProto(double_val=n * 1.2),
+                    "binary": ValueProto(bytes_val=b"binary"),
+                    "bool": ValueProto(bool_val=True),
+                    "timestamp": ValueProto(
+                        unix_timestamp_val=int(datetime.utcnow().timestamp() * 1000)
+                    ),
+                },
+                datetime.utcnow(),
+                None,
+            )
+            for i in range(n)
+        ]
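The test samples above exercise the new proto-extraction helper across every value type. As a quick illustration, here is a minimal standalone sketch mirroring `_get_value_from_value_proto` outside the store class, assuming only the feast SDK is installed (the `extract` function name is hypothetical): `WhichOneof("val")` names the populated variant, bytes values are base64-encoded into JSON-safe strings, and repeated float fields are flattened to plain lists:

```python
# Sketch only: mirrors the store's _get_value_from_value_proto; "extract" is a
# hypothetical name, not part of the commit.
import base64

from feast.protos.feast.types.Value_pb2 import FloatList
from feast.protos.feast.types.Value_pb2 import Value as ValueProto


def extract(proto: ValueProto):
    # WhichOneof returns the name of the populated oneof field, e.g. "bytes_val"
    val_type = proto.WhichOneof("val")
    value = getattr(proto, val_type)
    if val_type == "bytes_val":
        value = base64.b64encode(value).decode()  # JSON-safe string
    if val_type == "float_list_val":
        value = list(value.val)  # flatten repeated field to a plain list
    return value


print(extract(ValueProto(string_val="text")))    # 'text'
print(extract(ValueProto(bytes_val=b"binary")))  # 'YmluYXJ5'
print(extract(ValueProto(float_list_val=FloatList(val=[0.1, 0.2]))))
# [0.100000..., 0.200000...] (float32 precision, hence the rounding noise)
```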
