From dec6718662f13f610f74df1afb0d2c7750591d5a Mon Sep 17 00:00:00 2001
From: Mihir Patel
Date: Thu, 29 Aug 2024 07:49:50 -0700
Subject: [PATCH] Init Dist Default None (#3585)

---
 composer/utils/dist.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/composer/utils/dist.py b/composer/utils/dist.py
index 0515828a10..26b135217f 100644
--- a/composer/utils/dist.py
+++ b/composer/utils/dist.py
@@ -498,7 +498,7 @@ def is_initialized():
     return dist.is_initialized()


-def initialize_dist(device: Union[str, Device], timeout: float = 300.0) -> None:
+def initialize_dist(device: Optional[Union[str, Device]] = None, timeout: float = 300.0) -> None:
     """Initialize the default PyTorch distributed process group.

     This function assumes that the following environment variables are set:
@@ -517,9 +517,9 @@ def initialize_dist(device: Union[str, Device], timeout: float = 300.0) -> None:
     .. seealso:: :func:`torch.distributed.init_process_group`

     Args:
-        device (str | Device): The device from which the distributed backend is
+        device (Optional[str | Device]): The device from which the distributed backend is
             interpreted. Either a string corresponding to a device (one of ``'cpu'``,
-            ``'gpu'``, ``'mps'``, or ``'tpu'``) or a :class:`.Device`.
+            ``'gpu'``, ``'mps'``, or ``'tpu'``) or a :class:`.Device`. (default: ``None``)
         timeout (float, optional): The timeout for operations executed against the
             process group, expressed in seconds. (default: ``300.0``).
     """
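
A brief usage sketch of what this change enables, for readers skimming the patch. Two assumptions here are not visible in the diff itself: that the process is launched with the distributed environment variables the docstring refers to (for example via torchrun or the composer launcher), and that a None device is resolved by auto-detecting the local hardware rather than raising.

    from composer.utils import dist

    # Before this patch, a device had to be passed explicitly:
    #     dist.initialize_dist(device='gpu', timeout=300.0)

    # After this patch, the argument can be omitted entirely. How a
    # None device is resolved is not shown in this diff; auto-detection
    # of the local hardware is an assumption.
    dist.initialize_dist()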