Skip to content

Commit

Permalink
fix: Remove unused _get_cluster method and revert to not using a context manager
Browse files Browse the repository at this point in the history
  • Loading branch information
Bhargav Dodla committed Jan 9, 2025
1 parent e84cb95 commit d7c7d13
Showing 1 changed file with 42 additions and 130 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -208,93 +208,6 @@ class CassandraOnlineStore(OnlineStore):
_keyspace: str = "feast_keyspace"
_prepared_statements: Dict[str, PreparedStatement] = {}

def _get_cluster(self, config: RepoConfig):
    """
    Validate the online-store configuration and build a `Cluster`
    handle for it.

    Supports two mutually exclusive connection modes: plain Cassandra
    contact points (`hosts`) or an Astra secure-connect bundle
    (`secure_bundle_path`).

    Raises:
        CassandraInvalidConfig: if the config object has the wrong
            class, is missing required settings, mixes the two
            connection modes, has inconsistent auth credentials, or
            names an unknown load-balancing policy.
    """
    store_cfg = config.online_store
    if not isinstance(store_cfg, CassandraOnlineStoreConfig):
        raise CassandraInvalidConfig(E_CASSANDRA_UNEXPECTED_CONFIGURATION_CLASS)

    # Pull the relevant settings out once for the checks below.
    hosts = store_cfg.hosts
    secure_bundle_path = store_cfg.secure_bundle_path
    port = store_cfg.port or 9042
    keyspace = store_cfg.keyspace
    username = store_cfg.username
    password = store_cfg.password
    protocol_version = store_cfg.protocol_version

    # Exactly one of hosts / secure_bundle_path must be provided, a
    # keyspace is mandatory, and credentials are all-or-nothing.
    if not (hosts or secure_bundle_path) or not keyspace:
        raise CassandraInvalidConfig(E_CASSANDRA_NOT_CONFIGURED)
    if hosts and secure_bundle_path:
        raise CassandraInvalidConfig(E_CASSANDRA_MISCONFIGURED)
    if (username is None) != (password is None):
        raise CassandraInvalidConfig(E_CASSANDRA_INCONSISTENT_AUTH)

    auth_provider = (
        PlainTextAuthProvider(username=username, password=password)
        if username is not None
        else None
    )

    # Optional load-balancing policy, wrapped in a default
    # execution profile when configured.
    execution_profiles = None
    if store_cfg.load_balancing:
        policy_name = store_cfg.load_balancing.load_balancing_policy
        if policy_name == "DCAwareRoundRobinPolicy":
            lb_policy = DCAwareRoundRobinPolicy(
                local_dc=store_cfg.load_balancing.local_dc,
            )
        elif policy_name == "TokenAwarePolicy(DCAwareRoundRobinPolicy)":
            lb_policy = TokenAwarePolicy(
                DCAwareRoundRobinPolicy(
                    local_dc=store_cfg.load_balancing.local_dc,
                )
            )
        else:
            raise CassandraInvalidConfig(E_CASSANDRA_UNKNOWN_LB_POLICY)
        execution_profiles = {
            EXEC_PROFILE_DEFAULT: ExecutionProfile(
                request_timeout=store_cfg.request_timeout,
                load_balancing_policy=lb_policy,
            )
        }

    # Forward only the keyword arguments that were actually set.
    cluster_kwargs = {
        key: value
        for key, value in (
            ("protocol_version", protocol_version),
            ("execution_profiles", execution_profiles),
        )
        if value is not None
    }

    if hosts:
        # Plain Cassandra: connect through the contact points.
        return Cluster(
            hosts, port=port, auth_provider=auth_provider, **cluster_kwargs
        )
    # Astra: connect through the secure-connect bundle.
    return Cluster(
        cloud={"secure_connect_bundle": secure_bundle_path},
        auth_provider=auth_provider,
        **cluster_kwargs,
    )

def _get_session(self, config: RepoConfig):
"""
Establish the database connection, if not yet created,
Expand Down Expand Up @@ -368,6 +281,7 @@ def _get_session(self, config: RepoConfig):
for k, v in {
"protocol_version": protocol_version,
"execution_profiles": execution_profiles,
"idle_heartbeat_interval": 5,
}.items()
if v is not None
}
Expand Down Expand Up @@ -428,53 +342,51 @@ def online_write_batch(
display progress.
"""
project = config.project

keyspace: str = config.online_store.keyspace
session: Session = self._get_session(config)
keyspace: str = self._keyspace
fqtable = CassandraOnlineStore._fq_table_name(keyspace, project, table)
insert_cql = self._get_cql_statement(config, "insert4", fqtable=fqtable)

futures = []
with self._get_cluster(config) as cluster:
with cluster.connect(keyspace) as session:
for entity_key, values, timestamp, created_ts in data:
batch = BatchStatement(batch_type=BatchType.UNLOGGED)
entity_key_bin = serialize_entity_key(
entity_key,
entity_key_serialization_version=config.entity_key_serialization_version,
).hex()
for feature_name, val in values.items():
params: Tuple[str, bytes, str, datetime] = (
feature_name,
val.SerializeToString(),
entity_key_bin,
timestamp,
)
batch.add(insert_cql, params)
# this happens N-1 times, will be corrected outside:
if progress:
progress(1)

futures.append(session.execute_async(batch))
if len(futures) >= config.online_store.write_concurrency:
# Raises exception if at least one of the batch fails
try:
for future in futures:
future.result()
futures = []
except Exception as exc:
logger.error(f"Error writing a batch: {exc}")
print(f"Error writing a batch: {exc}")
raise Exception("Error writing a batch") from exc

if len(futures) > 0:
try:
for future in futures:
future.result()
futures = []
except Exception as exc:
logger.error(f"Error writing a batch: {exc}")
print(f"Error writing a batch: {exc}")
raise Exception("Error writing a batch") from exc
for entity_key, values, timestamp, created_ts in data:
batch = BatchStatement(batch_type=BatchType.UNLOGGED)
entity_key_bin = serialize_entity_key(
entity_key,
entity_key_serialization_version=config.entity_key_serialization_version,
).hex()
for feature_name, val in values.items():
params: Tuple[str, bytes, str, datetime] = (
feature_name,
val.SerializeToString(),
entity_key_bin,
timestamp,
)
batch.add(insert_cql, params)
# this happens N-1 times, will be corrected outside:
if progress:
progress(1)

futures.append(session.execute_async(batch))
if len(futures) >= config.online_store.write_concurrency:
# Raises exception if at least one of the batch fails
try:
for future in futures:
future.result()
futures = []
except Exception as exc:
logger.error(f"Error writing a batch: {exc}")
print(f"Error writing a batch: {exc}")
raise Exception("Error writing a batch") from exc

if len(futures) > 0:
try:
for future in futures:
future.result()
futures = []
except Exception as exc:
logger.error(f"Error writing a batch: {exc}")
print(f"Error writing a batch: {exc}")
raise Exception("Error writing a batch") from exc
# correction for the last missing call to `progress`:
if progress:
progress(1)
Expand Down

0 comments on commit d7c7d13

Please sign in to comment.