From 764dcf5989f754bb8d004de237fa152a166f2958 Mon Sep 17 00:00:00 2001 From: Aloysius Lim Date: Fri, 3 Jan 2025 13:55:49 +0800 Subject: [PATCH 1/3] Map join key to original column name in field mapping. Signed-off-by: Aloysius Lim --- .../feast/infra/offline_stores/offline_utils.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/sdk/python/feast/infra/offline_stores/offline_utils.py b/sdk/python/feast/infra/offline_stores/offline_utils.py index 2d4fa268e4..e66be472fa 100644 --- a/sdk/python/feast/infra/offline_stores/offline_utils.py +++ b/sdk/python/feast/infra/offline_stores/offline_utils.py @@ -118,6 +118,10 @@ def get_feature_view_query_context( query_context = [] for feature_view, features in feature_views_to_feature_map.items(): + reverse_field_mapping = { + v: k for k, v in feature_view.batch_source.field_mapping.items() + } + join_keys: List[str] = [] entity_selections: List[str] = [] for entity_column in feature_view.entity_columns: @@ -125,16 +129,16 @@ def get_feature_view_query_context( entity_column.name, entity_column.name ) join_keys.append(join_key) - entity_selections.append(f"{entity_column.name} AS {join_key}") + entity_selections.append( + f"{reverse_field_mapping.get(entity_column.name, entity_column.name)} " + f"AS {join_key}" + ) if isinstance(feature_view.ttl, timedelta): ttl_seconds = int(feature_view.ttl.total_seconds()) else: ttl_seconds = 0 - reverse_field_mapping = { - v: k for k, v in feature_view.batch_source.field_mapping.items() - } features = [reverse_field_mapping.get(feature, feature) for feature in features] timestamp_field = reverse_field_mapping.get( feature_view.batch_source.timestamp_field, From 23a7dc59a61bb27de3dce8c0cf598e983029763a Mon Sep 17 00:00:00 2001 From: Aloysius Lim Date: Fri, 3 Jan 2025 14:56:28 +0800 Subject: [PATCH 2/3] Add test. 
Signed-off-by: Aloysius Lim
---
Review note: the new test now honours its `full_feature_names` parameter.
Hunk sizes, the diffstat and the offsets in PATCH 3/3 were updated by hand to
match; the blob hashes on the `index` lines are carried over from the original
series and were not regenerated.

 .../test_universal_historical_retrieval.py    | 107 +++++++++++++++++-
 1 file changed, 106 insertions(+), 1 deletion(-)

diff --git a/sdk/python/tests/integration/offline_store/test_universal_historical_retrieval.py b/sdk/python/tests/integration/offline_store/test_universal_historical_retrieval.py
index 3f28245f3c..e968580e8e 100644
--- a/sdk/python/tests/integration/offline_store/test_universal_historical_retrieval.py
+++ b/sdk/python/tests/integration/offline_store/test_universal_historical_retrieval.py
@@ -14,7 +14,7 @@ from feast.infra.offline_stores.offline_utils import (
 from feast.infra.offline_stores.offline_utils import (
     DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL,
 )
-from feast.types import Float32, Int32
+from feast.types import Float32, Int32, String
 from feast.utils import _utc_now
 from tests.integration.feature_repos.repo_configuration import (
     construct_universal_feature_views,
@@ -639,3 +639,108 @@ def test_historical_features_containing_backfills(environment):
         actual_df,
         sort_by=["driver_id"],
     )
+
+
+@pytest.mark.integration
+@pytest.mark.universal_offline_stores
+@pytest.mark.parametrize("full_feature_names", [True, False], ids=lambda v: str(v))
+def test_historical_features_field_mapping(
+    environment, universal_data_sources, full_feature_names
+):
+    """Historical retrieval must work when field_mapping renames the join key.
+
+    The source table stores the entity under "id"; the data source's
+    field_mapping renames it to the "driver_id" join key that the entity
+    dataframe and the Entity definition use.
+    """
+    store = environment.feature_store
+
+    now = datetime.now().replace(microsecond=0, second=0, minute=0)
+    tomorrow = now + timedelta(days=1)
+    day_after_tomorrow = now + timedelta(days=2)
+
+    entity_df = pd.DataFrame(
+        data=[
+            {"driver_id": 1001, "event_timestamp": day_after_tomorrow},
+            {"driver_id": 1002, "event_timestamp": day_after_tomorrow},
+        ]
+    )
+
+    driver_stats_df = pd.DataFrame(
+        data=[
+            {
+                "id": 1001,
+                "avg_daily_trips": 20,
+                "event_timestamp": now,
+                "created": tomorrow,
+            },
+            {
+                "id": 1002,
+                "avg_daily_trips": 40,
+                "event_timestamp": tomorrow,
+                "created": now,
+            },
+        ]
+    )
+
+    expected_df = pd.DataFrame(
+        data=[
+            {
+                "driver_id": 1001,
+                "event_timestamp": day_after_tomorrow,
+                "avg_daily_trips": 20,
+            },
+            {
+                "driver_id": 1002,
+                "event_timestamp": day_after_tomorrow,
+                "avg_daily_trips": 40,
+            },
+        ]
+    )
+    if full_feature_names:
+        # With full feature names, columns are prefixed "<feature view>__".
+        expected_df = expected_df.rename(
+            columns={"avg_daily_trips": "driver_stats__avg_daily_trips"}
+        )
+
+    driver_stats_data_source = environment.data_source_creator.create_data_source(
+        df=driver_stats_df,
+        destination_name=f"test_driver_stats_{int(time.time_ns())}_{random.randint(1000, 9999)}",
+        timestamp_field="event_timestamp",
+        created_timestamp_column="created",
+        # Map original "id" column to "driver_id" join key
+        field_mapping={"id": "driver_id"}
+    )
+
+    driver = Entity(name="driver", join_keys=["driver_id"])
+    driver_fv = FeatureView(
+        name="driver_stats",
+        entities=[driver],
+        schema=[
+            Field(name="driver_id", dtype=String),
+            Field(name="avg_daily_trips", dtype=Int32)
+        ],
+        source=driver_stats_data_source,
+    )
+
+    store.apply([driver, driver_fv])
+
+    offline_job = store.get_historical_features(
+        entity_df=entity_df,
+        features=["driver_stats:avg_daily_trips"],
+        full_feature_names=full_feature_names,
+    )
+
+    start_time = _utc_now()
+    actual_df = offline_job.to_df()
+
+    print(f"actual_df shape: {actual_df.shape}")
+    end_time = _utc_now()
+    print(f"Time to execute job_from_df.to_df() = '{(end_time - start_time)}'\n")
+
+    assert sorted(expected_df.columns) == sorted(actual_df.columns)
+    validate_dataframes(
+        expected_df,
+        actual_df,
+        sort_by=["driver_id"],
+    )

From 9b60d4398bd460c645c92178fdc333f7dd869de6 Mon Sep 17 00:00:00 2001
From: Aloysius Lim
Date: Fri, 3 Jan 2025 15:00:11 +0800
Subject: [PATCH 3/3] Format.

Signed-off-by: Aloysius Lim
---
 .../offline_store/test_universal_historical_retrieval.py      | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/sdk/python/tests/integration/offline_store/test_universal_historical_retrieval.py b/sdk/python/tests/integration/offline_store/test_universal_historical_retrieval.py
index e968580e8e..37df649386 100644
--- a/sdk/python/tests/integration/offline_store/test_universal_historical_retrieval.py
+++ b/sdk/python/tests/integration/offline_store/test_universal_historical_retrieval.py
@@ -709,7 +709,7 @@ def test_historical_features_field_mapping(
         timestamp_field="event_timestamp",
         created_timestamp_column="created",
         # Map original "id" column to "driver_id" join key
-        field_mapping={"id": "driver_id"}
+        field_mapping={"id": "driver_id"},
     )
 
     driver = Entity(name="driver", join_keys=["driver_id"])
@@ -718,7 +718,7 @@ def test_historical_features_field_mapping(
         entities=[driver],
         schema=[
             Field(name="driver_id", dtype=String),
-            Field(name="avg_daily_trips", dtype=Int32)
+            Field(name="avg_daily_trips", dtype=Int32),
         ],
         source=driver_stats_data_source,
     )