Skip to content

Commit

Permalink
Move sharding optimization flag to global_settings (pytorch#2665)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#2665

As per title move the configuration flag to separate module for better abstraction and simpler rollout

Reviewed By: iamzainhuda

Differential Revision: D67777011

fbshipit-source-id: 8a659bee7b81d3181c4014fdf2678c69b306b8c1
  • Loading branch information
Boris Sarana authored and facebook-github-bot committed Jan 6, 2025
1 parent e1b96a6 commit 6f4bfe2
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 3 deletions.
7 changes: 4 additions & 3 deletions torchrec/distributed/embedding_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

import abc
import copy
import os
from dataclasses import dataclass
from enum import Enum, unique
from typing import Any, Dict, Generic, Iterator, List, Optional, Tuple, TypeVar, Union
Expand All @@ -21,6 +20,9 @@
from torch.distributed._tensor.placement_types import Placement
from torch.nn.modules.module import _addindent
from torch.nn.parallel import DistributedDataParallel
from torchrec.distributed.global_settings import (
construct_sharded_tensor_from_metadata_enabled,
)
from torchrec.distributed.types import (
get_tensor_size_bytes,
ModuleSharder,
Expand Down Expand Up @@ -346,8 +348,7 @@ def __init__(

# option to construct ShardedTensor from metadata avoiding expensive all-gather
self._construct_sharded_tensor_from_metadata: bool = (
os.environ.get("TORCHREC_CONSTRUCT_SHARDED_TENSOR_FROM_METADATA", "0")
== "1"
construct_sharded_tensor_from_metadata_enabled()
)

def prefetch(
Expand Down
12 changes: 12 additions & 0 deletions torchrec/distributed/global_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,14 @@

# pyre-strict

import os

PROPOGATE_DEVICE: bool = False

TORCHREC_CONSTRUCT_SHARDED_TENSOR_FROM_METADATA_ENV = (
"TORCHREC_CONSTRUCT_SHARDED_TENSOR_FROM_METADATA"
)


def set_propogate_device(val: bool) -> None:
global PROPOGATE_DEVICE
Expand All @@ -18,3 +24,9 @@ def set_propogate_device(val: bool) -> None:
def get_propogate_device() -> bool:
global PROPOGATE_DEVICE
return PROPOGATE_DEVICE


def construct_sharded_tensor_from_metadata_enabled() -> bool:
return (
os.environ.get(TORCHREC_CONSTRUCT_SHARDED_TENSOR_FROM_METADATA_ENV, "0") == "1"
)

0 comments on commit 6f4bfe2

Please sign in to comment.