Skip to content

Commit

Permalink
Add ShardedQuantManagedCollisionEmbeddingCollection (pytorch#2649)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#2649

Sharded MCEC is extended from Sharded EC to reuse the lookups of sharded embeddings.

Reviewed By: emlin

Differential Revision: D61388755

fbshipit-source-id: d222a9db8842ab3c5adc568d0083c53e768683ce
  • Loading branch information
kausv authored and facebook-github-bot committed Dec 23, 2024
1 parent 5f607ff commit 464a0e9
Show file tree
Hide file tree
Showing 8 changed files with 1,064 additions and 89 deletions.
18 changes: 18 additions & 0 deletions torchrec/distributed/embedding_sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
from torchrec.streamable import Multistreamable


torch.fx.wrap("len")

CACHE_LOAD_FACTOR_STR: str = "cache_load_factor"
Expand All @@ -61,6 +62,15 @@ def _fx_wrap_tensor_to_device_dtype(
return t.to(device=tensor_device_dtype.device, dtype=tensor_device_dtype.dtype)


@torch.fx.wrap
def _fx_wrap_optional_tensor_to_device_dtype(
t: Optional[torch.Tensor], tensor_device_dtype: torch.Tensor
) -> Optional[torch.Tensor]:
if t is None:
return None
return t.to(device=tensor_device_dtype.device, dtype=tensor_device_dtype.dtype)


@torch.fx.wrap
def _fx_wrap_batch_size_per_feature(kjt: KeyedJaggedTensor) -> Optional[torch.Tensor]:
return (
Expand Down Expand Up @@ -121,6 +131,7 @@ def _fx_wrap_seq_block_bucketize_sparse_features_inference(
block_sizes: torch.Tensor,
bucketize_pos: bool = False,
block_bucketize_pos: Optional[List[torch.Tensor]] = None,
total_num_blocks: Optional[torch.Tensor] = None,
) -> Tuple[
torch.Tensor,
torch.Tensor,
Expand All @@ -142,6 +153,7 @@ def _fx_wrap_seq_block_bucketize_sparse_features_inference(
bucketize_pos=bucketize_pos,
sequence=True,
block_sizes=block_sizes,
total_num_blocks=total_num_blocks,
my_size=num_buckets,
weights=kjt.weights_or_none(),
max_B=_fx_wrap_max_B(kjt),
Expand Down Expand Up @@ -289,6 +301,7 @@ def bucketize_kjt_inference(
kjt: KeyedJaggedTensor,
num_buckets: int,
block_sizes: torch.Tensor,
total_num_buckets: Optional[torch.Tensor] = None,
bucketize_pos: bool = False,
block_bucketize_row_pos: Optional[List[torch.Tensor]] = None,
is_sequence: bool = False,
Expand All @@ -303,6 +316,7 @@ def bucketize_kjt_inference(
Args:
num_buckets (int): number of buckets to bucketize the values into.
block_sizes: (torch.Tensor): bucket sizes for the keyed dimension.
total_num_blocks: (Optional[torch.Tensor]): number of blocks per feature, useful for two-level bucketization
bucketize_pos (bool): output the changed position of the bucketized values or
not.
block_bucketize_row_pos (Optional[List[torch.Tensor]]): The offsets of shard size for each feature.
Expand All @@ -318,6 +332,9 @@ def bucketize_kjt_inference(
f"Expecting block sizes for {num_features} features, but {block_sizes.numel()} received.",
)
block_sizes_new_type = _fx_wrap_tensor_to_device_dtype(block_sizes, kjt.values())
total_num_buckets_new_type = _fx_wrap_optional_tensor_to_device_dtype(
total_num_buckets, kjt.values()
)
unbucketize_permute = None
bucket_mapping = None
if is_sequence:
Expand All @@ -332,6 +349,7 @@ def bucketize_kjt_inference(
kjt,
num_buckets=num_buckets,
block_sizes=block_sizes_new_type,
total_num_blocks=total_num_buckets_new_type,
bucketize_pos=bucketize_pos,
block_bucketize_pos=block_bucketize_row_pos,
)
Expand Down
Loading

0 comments on commit 464a0e9

Please sign in to comment.