From b4cd60483cb8bab1f2e12d6f6f9f1fd9e155b831 Mon Sep 17 00:00:00 2001 From: Bhargav Dodla Date: Sun, 12 Jan 2025 20:57:43 +0530 Subject: [PATCH] fix: Enhance cluster connection management and session handling in CassandraOnlineStore --- .../cassandra_online_store.py | 177 ++++++++++++++---- 1 file changed, 142 insertions(+), 35 deletions(-) diff --git a/sdk/python/feast/infra/online_stores/contrib/cassandra_online_store/cassandra_online_store.py b/sdk/python/feast/infra/online_stores/contrib/cassandra_online_store/cassandra_online_store.py index a32bb3834b..ce2c67c6ef 100644 --- a/sdk/python/feast/infra/online_stores/contrib/cassandra_online_store/cassandra_online_store.py +++ b/sdk/python/feast/infra/online_stores/contrib/cassandra_online_store/cassandra_online_store.py @@ -208,6 +208,111 @@ class CassandraOnlineStore(OnlineStore): _keyspace: str = "feast_keyspace" _prepared_statements: Dict[str, PreparedStatement] = {} + def _get_cluster(self, config: RepoConfig): + """ + Establish the database cluster, if not yet created, + and return it (no session is opened here). + + Also perform basic config validation checks. 
+ """ + + online_store_config = config.online_store + if not isinstance(online_store_config, CassandraOnlineStoreConfig): + raise CassandraInvalidConfig(E_CASSANDRA_UNEXPECTED_CONFIGURATION_CLASS) + + if self._cluster: + if not self._cluster.is_shutdown: + print("Reusing existing cluster..") + return self._cluster + else: + self._cluster = None + print("Creating a new cluster..") + if not self._cluster: + # configuration consistency checks + hosts = online_store_config.hosts + secure_bundle_path = online_store_config.secure_bundle_path + port = online_store_config.port or 9042 + keyspace = online_store_config.keyspace + username = online_store_config.username + password = online_store_config.password + protocol_version = online_store_config.protocol_version + + db_directions = hosts or secure_bundle_path + if not db_directions or not keyspace: + raise CassandraInvalidConfig(E_CASSANDRA_NOT_CONFIGURED) + if hosts and secure_bundle_path: + raise CassandraInvalidConfig(E_CASSANDRA_MISCONFIGURED) + if (username is None) ^ (password is None): + raise CassandraInvalidConfig(E_CASSANDRA_INCONSISTENT_AUTH) + + if username is not None: + auth_provider = PlainTextAuthProvider( + username=username, + password=password, + ) + else: + auth_provider = None + + # handling of load-balancing policy (optional) + if online_store_config.load_balancing: + # construct a proper execution profile embedding + # the configured LB policy + _lbp_name = online_store_config.load_balancing.load_balancing_policy + if _lbp_name == "DCAwareRoundRobinPolicy": + lb_policy = DCAwareRoundRobinPolicy( + local_dc=online_store_config.load_balancing.local_dc, + ) + elif _lbp_name == "TokenAwarePolicy(DCAwareRoundRobinPolicy)": + lb_policy = TokenAwarePolicy( + DCAwareRoundRobinPolicy( + local_dc=online_store_config.load_balancing.local_dc, + ) + ) + else: + raise CassandraInvalidConfig(E_CASSANDRA_UNKNOWN_LB_POLICY) + + # wrap it up in a map of ex.profiles with a default + exe_profile = ExecutionProfile( + 
request_timeout=online_store_config.request_timeout, + load_balancing_policy=lb_policy, + ) + execution_profiles = {EXEC_PROFILE_DEFAULT: exe_profile} + else: + execution_profiles = None + + # additional optional keyword args to Cluster + cluster_kwargs = { + k: v + for k, v in { + "protocol_version": protocol_version, + "execution_profiles": execution_profiles, + "idle_heartbeat_interval": None, + }.items() + if v is not None + } + + # creation of Cluster (Cassandra vs. Astra) + if hosts: + self._cluster = Cluster( + hosts, + port=port, + auth_provider=auth_provider, + **cluster_kwargs, + ) + else: + # we use 'secure_bundle_path' + self._cluster = Cluster( + cloud={"secure_connect_bundle": secure_bundle_path}, + auth_provider=auth_provider, + **cluster_kwargs, + ) + + # creation of Session + self._keyspace = keyspace + # self._session = self._cluster.connect(self._keyspace) + + return self._cluster + def _get_session(self, config: RepoConfig): """ Establish the database connection, if not yet created, @@ -350,35 +455,47 @@ def online_write_batch( display progress. 
""" project = config.project - session: Session = self._get_session(config) + cluster: Cluster = self._get_cluster(config) + # session: Session = self._get_session(config) keyspace: str = self._keyspace fqtable = CassandraOnlineStore._fq_table_name(keyspace, project, table) - insert_cql = self._get_cql_statement( - config, "insert4", fqtable=fqtable, session=session - ) futures = [] - for entity_key, values, timestamp, created_ts in data: - batch = BatchStatement(batch_type=BatchType.UNLOGGED) - entity_key_bin = serialize_entity_key( - entity_key, - entity_key_serialization_version=config.entity_key_serialization_version, - ).hex() - for feature_name, val in values.items(): - params: Tuple[str, bytes, str, datetime] = ( - feature_name, - val.SerializeToString(), - entity_key_bin, - timestamp, - ) - batch.add(insert_cql, params) - # this happens N-1 times, will be corrected outside: - if progress: - progress(1) - - futures.append(session.execute_async(batch)) - if len(futures) >= config.online_store.write_concurrency: - # Raises exception if at least one of the batch fails + with cluster.connect(keyspace) as session: + insert_cql = self._get_cql_statement( + config, "insert4", fqtable=fqtable, session=session + ) + for entity_key, values, timestamp, created_ts in data: + batch = BatchStatement(batch_type=BatchType.UNLOGGED) + entity_key_bin = serialize_entity_key( + entity_key, + entity_key_serialization_version=config.entity_key_serialization_version, + ).hex() + for feature_name, val in values.items(): + params: Tuple[str, bytes, str, datetime] = ( + feature_name, + val.SerializeToString(), + entity_key_bin, + timestamp, + ) + batch.add(insert_cql, params) + # this happens N-1 times, will be corrected outside: + if progress: + progress(1) + + futures.append(session.execute_async(batch)) + if len(futures) >= config.online_store.write_concurrency: + # Raises exception if at least one of the batch fails + try: + for future in futures: + future.result() + futures = [] 
+ except Exception as exc: + logger.error(f"Error writing a batch: {exc}") + print(f"Error writing a batch: {exc}") + raise Exception("Error writing a batch") from exc + + if len(futures) > 0: try: for future in futures: future.result() @@ -387,16 +504,6 @@ def online_write_batch( logger.error(f"Error writing a batch: {exc}") print(f"Error writing a batch: {exc}") raise Exception("Error writing a batch") from exc - - if len(futures) > 0: - try: - for future in futures: - future.result() - futures = [] - except Exception as exc: - logger.error(f"Error writing a batch: {exc}") - print(f"Error writing a batch: {exc}") - raise Exception("Error writing a batch") from exc # correction for the last missing call to `progress`: if progress: progress(1)