diff --git a/sdk/python/feast/infra/online_stores/contrib/cassandra_online_store/cassandra_online_store.py b/sdk/python/feast/infra/online_stores/contrib/cassandra_online_store/cassandra_online_store.py index 077174e3ee..0128817e97 100644 --- a/sdk/python/feast/infra/online_stores/contrib/cassandra_online_store/cassandra_online_store.py +++ b/sdk/python/feast/infra/online_stores/contrib/cassandra_online_store/cassandra_online_store.py @@ -221,10 +221,12 @@ def _get_session(self, config: RepoConfig): raise CassandraInvalidConfig(E_CASSANDRA_UNEXPECTED_CONFIGURATION_CLASS) if self._session: - print("Reusing existing session..") - return self._session - else: - print("Creating a new session..") + if not self._session.is_shutdown: + print("Reusing existing session..") + return self._session + else: + self._session = None + print("Creating a new session..") if not self._session: # configuration consistency checks hosts = online_store_config.hosts @@ -348,7 +350,9 @@ def online_write_batch( session: Session = self._get_session(config) keyspace: str = self._keyspace fqtable = CassandraOnlineStore._fq_table_name(keyspace, project, table) - insert_cql = self._get_cql_statement(config, "insert4", fqtable=fqtable) + insert_cql = self._get_cql_statement( + config, "insert4", fqtable=fqtable, session=session + ) futures = [] for entity_key, values, timestamp, created_ts in data: @@ -390,6 +394,7 @@ def online_write_batch( logger.error(f"Error writing a batch: {exc}") print(f"Error writing a batch: {exc}") raise Exception("Error writing a batch") from exc + session.shutdown() # correction for the last missing call to `progress`: if progress: progress(1) @@ -582,7 +587,12 @@ def _get_cql_statement( This additional layer makes it easy to control whether to use prepared statements and, if so, on which database operations. """ - session: Session = self._get_session(config) + session: Session = None + if "session" in kwargs: + session = kwargs["session"] + else: + session = self._get_session(config) + template, prepare = CQL_TEMPLATE_MAP[op_name] statement = template.format( fqtable=fqtable,