From 6370fa0c124eff28ce177b660edb6e899c90d993 Mon Sep 17 00:00:00 2001 From: Benson Ma Date: Thu, 2 Jan 2025 21:47:15 -0800 Subject: [PATCH] Add support for `int32_t` indices in TBE training (2H/N) Summary: - Update benchmark test for `int32_t` Indicies Differential Revision: D67784746 --- .../split_table_batched_embeddings_benchmark.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py b/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py index f439ed678..68da7659c 100644 --- a/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py +++ b/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py @@ -125,6 +125,7 @@ def cli() -> None: @click.option("--flush-gpu-cache-size-mb", default=0) @click.option("--dense", is_flag=True, default=False) @click.option("--output-dtype", type=SparseType, default=SparseType.FP32) +@click.option("--indices-dtype", type=click.Choice(["32", "64"]), default="64") @click.option("--requests_data_file", type=str, default=None) @click.option("--tables", type=str, default=None) @click.option("--export-trace", is_flag=True, default=False) @@ -166,6 +167,7 @@ def device( # noqa C901 flush_gpu_cache_size_mb: int, dense: bool, output_dtype: SparseType, + indices_dtype: int, requests_data_file: Optional[str], tables: Optional[str], export_trace: bool, @@ -176,6 +178,9 @@ def device( # noqa C901 cache_load_factor: float, ) -> None: assert not ssd or not dense, "--ssd cannot be used together with --dense" + indices_dtype_torch: torch.dtype = ( + torch.int32 if indices_dtype == 32 else torch.int64 + ) np.random.seed(42) torch.manual_seed(42) B = batch_size @@ -352,8 +357,8 @@ def context_factory(on_trace_ready: Callable[[profile], None]): time_per_iter = benchmark_requests( requests, lambda indices, offsets, per_sample_weights: emb.forward( - indices.long(), - offsets.long(), + indices.to(dtype=indices_dtype_torch), + offsets.to(dtype=indices_dtype_torch), per_sample_weights, feature_requires_grad=feature_requires_grad, ), @@ -384,8 +389,8 @@ def context_factory(on_trace_ready: Callable[[profile], None]): time_per_iter = benchmark_requests( requests, lambda indices, offsets, per_sample_weights: emb( - indices.long(), - offsets.long(), + indices.to(dtype=indices_dtype_torch), + offsets.to(dtype=indices_dtype_torch), per_sample_weights, feature_requires_grad=feature_requires_grad, ),