Skip to content

Commit

Permalink
fix elasticsearch connection manager calls (#70)
Browse files Browse the repository at this point in the history
* fix elasticsearch connection manager calls

* fix tests
  • Loading branch information
piket authored Nov 15, 2023
1 parent 5eb27da commit 9ee716f
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,6 @@ class ElasticsearchOnlineStoreConfig(FeastConfigBaseModel):
password: str
""" password to connect to Elasticsearch """

token: str
""" bearer token for authentication """


class ElasticsearchConnectionManager:
def __init__(self, online_config: RepoConfig):
Expand All @@ -85,15 +82,10 @@ def __enter__(self):
logger.info(
f"Connecting to Elasticsearch with endpoint {self.online_config.endpoint}"
)
if len(self.online_config.token) > 0:
self.client = Elasticsearch(
self.online_config.endpoint, bearer_auth=self.online_config.token
)
else:
self.client = Elasticsearch(
self.online_config.endpoint,
basic_auth=(self.online_config.username, self.online_config.password),
)
self.client = Elasticsearch(
self.online_config.endpoint,
basic_auth=(self.online_config.username, self.online_config.password),
)
return self.client

def __exit__(self, exc_type, exc_value, traceback):
Expand All @@ -112,7 +104,7 @@ def online_write_batch(
],
progress: Optional[Callable[[int], Any]],
) -> None:
with ElasticsearchConnectionManager(config) as es:
with ElasticsearchConnectionManager(config.online_store) as es:
resp = es.indices.exists(index=table.name)
if not resp.body:
self._create_index(es, table)
Expand All @@ -138,7 +130,7 @@ def online_read(
entity_keys: List[EntityKeyProto],
requested_features: Optional[List[str]] = None,
) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
with ElasticsearchConnectionManager(config) as es:
with ElasticsearchConnectionManager(config.online_store) as es:
id_list = []
for entity in entity_keys:
for val in entity.entity_values:
Expand Down Expand Up @@ -191,7 +183,7 @@ def update(
entities_to_keep: Sequence[Entity],
partial: bool,
):
with ElasticsearchConnectionManager(config) as es:
with ElasticsearchConnectionManager(config.online_store) as es:
for fv in tables_to_delete:
resp = es.indices.exists(index=fv.name)
if resp.body:
Expand Down
19 changes: 9 additions & 10 deletions sdk/python/tests/expediagroup/test_elasticsearch_online_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ def repo_config(embedded_elasticsearch):
endpoint=f"http://{embedded_elasticsearch['host']}:{embedded_elasticsearch['port']}",
username=embedded_elasticsearch["username"],
password=embedded_elasticsearch["password"],
token=embedded_elasticsearch["token"],
),
offline_store=FileOfflineStoreConfig(),
entity_key_serialization_version=2,
Expand Down Expand Up @@ -132,7 +131,7 @@ def test_elasticsearch_update_add_index(self, repo_config, caplog, index_params)
Field(name="feature10", dtype=UnixTimestamp),
]
ElasticsearchOnlineStore().update(
config=repo_config.online_store,
config=repo_config,
tables_to_delete=[],
tables_to_keep=[
FeatureView(
Expand Down Expand Up @@ -195,7 +194,7 @@ def test_elasticsearch_update_add_existing_index(self, repo_config, caplog):
]
self._create_index_in_es(self.index_to_write, repo_config)
ElasticsearchOnlineStore().update(
config=repo_config.online_store,
config=repo_config,
tables_to_delete=[],
tables_to_keep=[
FeatureView(
Expand Down Expand Up @@ -236,7 +235,7 @@ def test_elasticsearch_update_delete_index(self, repo_config, caplog):
assert es.indices.exists(index=self.index_to_delete).body is True

ElasticsearchOnlineStore().update(
config=repo_config.online_store,
config=repo_config,
tables_to_delete=[
FeatureView(
name=self.index_to_delete,
Expand Down Expand Up @@ -276,7 +275,7 @@ def test_elasticsearch_update_delete_unavailable_index(self, repo_config, caplog
assert es.indices.exists(index=self.index_to_delete).body is False

ElasticsearchOnlineStore().update(
config=repo_config.online_store,
config=repo_config,
tables_to_delete=[
FeatureView(
name=self.index_to_delete,
Expand Down Expand Up @@ -304,7 +303,7 @@ def test_elasticsearch_online_write_batch(self, repo_config, caplog):
n=total_rows_to_write,
)
ElasticsearchOnlineStore().online_write_batch(
config=repo_config.online_store,
config=repo_config,
table=feature_view,
data=data,
progress=None,
Expand Down Expand Up @@ -334,7 +333,7 @@ def test_elasticsearch_online_read(self, repo_config, caplog):
]
store = ElasticsearchOnlineStore()
store.online_write_batch(
config=repo_config.online_store,
config=repo_config,
table=feature_view,
data=data,
progress=None,
Expand All @@ -344,7 +343,7 @@ def test_elasticsearch_online_read(self, repo_config, caplog):
es.indices.refresh(index=self.index_to_read)

result = store.online_read(
config=repo_config.online_store,
config=repo_config,
table=feature_view,
entity_keys=ids,
)
Expand Down Expand Up @@ -376,7 +375,7 @@ def test_elasticsearch_online_read_with_requested_features(
]
store = ElasticsearchOnlineStore()
store.online_write_batch(
config=repo_config.online_store,
config=repo_config,
table=feature_view,
data=data,
progress=None,
Expand All @@ -386,7 +385,7 @@ def test_elasticsearch_online_read_with_requested_features(
es.indices.refresh(index=self.index_to_read)

result = store.online_read(
config=repo_config.online_store,
config=repo_config,
table=feature_view,
entity_keys=ids,
requested_features=requested_features,
Expand Down

0 comments on commit 9ee716f

Please sign in to comment.