Tensor Parallelism v2 (mosaicml#3335)
* Revert "Revert TP integration (mosaicml#3328)"

This reverts commit f154ea6.

* fix device mesh

* fix backcompat

* lint

* fix

* elif

* fix test

* fix error

* lint

* filter warnings

* add filter

* link issue

* lint

* make version

---------

Co-authored-by: Your Name <[email protected]>
mvpatel2000 and Your Name authored May 30, 2024
1 parent 54c2f88 commit c9a51d4
Showing 32 changed files with 1,220 additions and 850 deletions.
5 changes: 3 additions & 2 deletions composer/callbacks/checkpoint_saver.py
@@ -468,8 +468,9 @@ def _save_checkpoint(self, state: State, logger: Logger):
is_deepspeed,
keep_placeholders=True,
).lstrip('/')
- assert state.sharded_ckpt_prefix_dir is not None
- remote_prefix = state.sharded_ckpt_prefix_dir
+ assert state.fsdp_config is not None
+ remote_prefix = state.fsdp_config['sharded_ckpt_prefix_dir']
+ assert remote_prefix is not None
ckpt_filename = checkpoint._TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME
remote_file_name = os.path.join(pathlib.Path(remote_file_name).parent, remote_prefix, ckpt_filename)
remote_file_name = format_name_with_dist_and_time(remote_file_name, state.run_name, state.timestamp)
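Note on the hunk above: the per-interval prefix for sharded checkpoints now lives in state.fsdp_config rather than as a standalone State attribute. A minimal sketch of how the remote name is assembled; the values below are hypothetical stand-ins for the real State fields and the checkpoint filename constant.

import os
import pathlib

# Hypothetical stand-ins for what the real code reads off State:
fsdp_config = {'sharded_ckpt_prefix_dir': 'ep{epoch}-ba{batch}'}  # default from set_fsdp_default
remote_file_name = 'my-run/checkpoints/latest-rank{rank}.pt'
ckpt_filename = '__0_0.distcp'  # stand-in for checkpoint._TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME

remote_prefix = fsdp_config['sharded_ckpt_prefix_dir']
assert remote_prefix is not None

# Same join as in the diff: parent directory + per-interval prefix + shard file name.
remote_file_name = os.path.join(pathlib.Path(remote_file_name).parent, remote_prefix, ckpt_filename)

# The real code then calls format_name_with_dist_and_time(...); plain str.format
# stands in for it here.
print(remote_file_name.format(epoch=1, batch=500))  # my-run/checkpoints/ep1-ba500/__0_0.distcp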
249 changes: 183 additions & 66 deletions composer/core/state.py

Large diffs are not rendered by default.

25 changes: 25 additions & 0 deletions composer/distributed/__init__.py
@@ -0,0 +1,25 @@
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Distributed training."""

from composer.distributed.deepspeed import fix_batch_precision_for_deepspeed, parse_deepspeed_config
from composer.distributed.dist_strategy import (
    DDPSyncStrategy,
    ddp_sync_context,
    prepare_ddp_module,
    prepare_fsdp_module,
    prepare_tp_module,
)
from composer.distributed.mosaic_fsdp import set_fsdp_default

__all__ = [
    'fix_batch_precision_for_deepspeed',
    'parse_deepspeed_config',
    'DDPSyncStrategy',
    'ddp_sync_context',
    'prepare_ddp_module',
    'prepare_fsdp_module',
    'prepare_tp_module',
    'set_fsdp_default',
]
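The new composer.distributed package re-exports helpers that previously lived under composer.trainer, including set_fsdp_default (shown removed from dist_strategy.py further down and now imported from mosaic_fsdp). A hedged sketch of filling FSDP defaults through the public path, assuming the moved function keeps the defaults listed in the removed body:

# Sketch, assuming the import path added by this PR and that the moved
# set_fsdp_default behaves like the removed dist_strategy version.
from composer.distributed import set_fsdp_default

fsdp_config = {'sharding_strategy': 'HYBRID_SHARD'}
fsdp_config = set_fsdp_default(fsdp_config)

# Keys the user omitted fall back to defaults, e.g.:
assert fsdp_config['state_dict_type'] == 'full'
assert fsdp_config['sharded_ckpt_prefix_dir'] == 'ep{epoch}-ba{batch}'
assert fsdp_config['use_orig_params'] is True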
composer/distributed/deepspeed.py
@@ -13,7 +13,7 @@
from composer.core import Batch, Precision, State
from composer.utils import dist, map_collection

- __all__ = ['_fix_batch_precision_for_deepspeed', '_parse_deepspeed_config']
+ __all__ = ['fix_batch_precision_for_deepspeed', 'parse_deepspeed_config']


def _add_batch_config(config: Dict[str, Any], state: State):
@@ -105,7 +105,7 @@ def _add_precision_config(config: Dict[str, Any], state: State):
config['bf16'] = cast(Dict[str, Any], {'enabled': True})


- def _parse_deepspeed_config(
+ def parse_deepspeed_config(
config: Dict[str, Any],
state: State,
) -> Dict[str, Any]:
@@ -160,7 +160,7 @@ def _convert_fp32_tensor_to_bf16(tensor: torch.Tensor):
return tensor


- def _fix_batch_precision_for_deepspeed(batch: Batch, precision: Precision) -> Batch:
+ def fix_batch_precision_for_deepspeed(batch: Batch, precision: Precision) -> Batch:
"""Ensures that a batch is properly formatted for DeepSpeed precisions, if active.
.. note:: Just because the precision is set to FP16 doesn't mean the entire batch can
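With the leading underscores dropped, the DeepSpeed helpers become part of the public composer.distributed surface. A small sketch of the precision helper under that assumption (the batch contents are illustrative):

# Sketch assuming the renamed public helpers; floating-point tensors are cast
# to what DeepSpeed expects for the given precision, while integer tensors
# (e.g. token ids) are left alone.
import torch

from composer.core import Precision
from composer.distributed import fix_batch_precision_for_deepspeed

batch = {
    'input_ids': torch.randint(0, 100, (4, 16)),  # int64: should stay int64
    'pixel_values': torch.randn(4, 3, 32, 32),    # float32: expected to downcast
}
batch = fix_batch_precision_for_deepspeed(batch, Precision.AMP_FP16)
print(batch['input_ids'].dtype, batch['pixel_values'].dtype)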
composer/distributed/dist_strategy.py
@@ -23,18 +23,17 @@

from composer.core import Precision, State
from composer.devices import Device
- from composer.trainer.meta_safe_apply import meta_safe_apply
- from composer.trainer.mosaic_fsdp import patch_pytorch
- from composer.trainer.mosaic_fsdp_utils import (
+ from composer.distributed.meta_safe_apply import meta_safe_apply
+ from composer.distributed.mosaic_fsdp import (
BACKWARD_PREFETCH_MAP,
SHARDING_MAP,
- _set_custom_fsdp_module_kwargs,
get_cpu_offload,
get_mixed_precision,
+ set_custom_fsdp_module_kwargs,
)
from composer.utils import StringEnum, dist, ensure_tuple

- __all__ = ['DDPSyncStrategy', 'ddp_sync_context', 'prepare_ddp_module', 'prepare_fsdp_module']
+ __all__ = ['DDPSyncStrategy', 'ddp_sync_context', 'prepare_ddp_module', 'prepare_fsdp_module', 'prepare_tp_module']

log = logging.getLogger(__name__)

@@ -142,35 +141,6 @@ def prepare_ddp_module(module: torch.nn.Module, find_unused_parameters: bool) ->
)


- def set_fsdp_default(fsdp_config: Dict[str, Any]):
-     """Modify fsdp_config to set default values for missing keys."""
-     fsdp_config.setdefault('activation_checkpointing', False)
-     fsdp_config.setdefault('activation_checkpointing_reentrant', True)
-     fsdp_config.setdefault('activation_cpu_offload', False)
-     fsdp_config.setdefault('te_checkpoint_wrapper', False)
-     fsdp_config.setdefault('te_shard_fp8_weight', False)
-     fsdp_config.setdefault('backward_prefetch', 'BACKWARD_POST')
-     fsdp_config.setdefault('backward_prefetch_limit', 1)
-     fsdp_config.setdefault('cpu_offload', False)
-     fsdp_config.setdefault('forward_prefetch', False)
-     fsdp_config.setdefault('forward_prefetch_limit', 1)
-     fsdp_config.setdefault('ignored_modules', None)
-     fsdp_config.setdefault('keep_low_precision_grads', False)
-     fsdp_config.setdefault('limit_all_gathers', True)
-     fsdp_config.setdefault('load_monolith_rank0_only', False)
-     fsdp_config.setdefault('load_planner', None)
-     fsdp_config.setdefault('mixed_precision', 'DEFAULT')
-     fsdp_config.setdefault('process_group', None)
-     fsdp_config.setdefault('save_planner', None)
-     fsdp_config.setdefault('sharded_ckpt_prefix_dir', 'ep{epoch}-ba{batch}')
-     fsdp_config.setdefault('sharding_strategy', 'FULL_SHARD')
-     fsdp_config.setdefault('state_dict_type', 'full')
-     fsdp_config.setdefault('sync_module_states', False)
-     fsdp_config.setdefault('use_orig_params', True)
-     fsdp_config.setdefault('verbose', False)
-     return fsdp_config
-
-
def _recreate_fsdp_param_groups_from_unwrapped_opt_info(
fsdp_wrapped_named_params: Iterator[Tuple[str, torch.nn.Parameter]],
non_wrapped_param_names_to_group_num: Dict[str, int],
@@ -209,6 +179,22 @@ def _recreate_fsdp_param_groups_from_unwrapped_opt_info(
return [group_num_to_optimizer_info[num] for num in sorted(group_num_to_optimizer_info.keys())]


+ def prepare_tp_module(
+     model: torch.nn.Module,
+     tp_config: Dict[str, Any],
+ ) -> None:
+     """Prepare a module (assumed ComposerModel) for use with tensor parallel."""
+     from torch.distributed.tensor.parallel import parallelize_module
+
+     device_mesh = tp_config['device_mesh']
+     layer_plan = tp_config['layer_plan']
+     parallelize_module(
+         module=model,
+         device_mesh=device_mesh,
+         parallelize_plan=layer_plan,
+     )
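prepare_tp_module is a thin wrapper around torch's parallelize_module that reads the mesh and plan out of tp_config. A hedged usage sketch with a toy model (the model, mesh shape, and plan are illustrative, not taken from the PR):

# Sketch: requires torch >= 2.2 and at least 2 GPUs, launched with torchrun.
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel

from composer.distributed import prepare_tp_module


class ToyMLP(nn.Module):  # illustrative model, not part of the PR
    def __init__(self, d: int = 128):
        super().__init__()
        self.up = nn.Linear(d, 4 * d)
        self.down = nn.Linear(4 * d, d)

    def forward(self, x):
        return self.down(torch.relu(self.up(x)))


model = ToyMLP().cuda()
tp_config = {
    # 1D mesh over 2 ranks for pure tensor parallelism.
    'device_mesh': init_device_mesh('cuda', (2,)),
    # Megatron-style pairing: shard the up-projection by columns and the
    # down-projection by rows so activations are re-assembled once per block.
    'layer_plan': {'up': ColwiseParallel(), 'down': RowwiseParallel()},
}
prepare_tp_module(model, tp_config)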


def prepare_fsdp_module(
model: torch.nn.Module,
optimizers: Optional[Union[torch.optim.Optimizer, Sequence[torch.optim.Optimizer]]],
@@ -229,10 +215,6 @@ def prepare_fsdp_module(
auto_microbatching (bool, optional): Whether or not auto microbatching is enabled.
te_rng_seed(int): The seed to use for the Transformer Engine activation checkpointing RNG. Defaults to 1234.
"""
- patch_pytorch()
-
- set_fsdp_default(fsdp_config)
-
# Check sync_module_states is True for mixed initialization or HSDP
if fsdp_config['sync_module_states'] == False:
rank_on_meta = 1 if next(model.parameters()).device.type == 'meta' else 0
@@ -319,31 +301,26 @@ def sync_hook(*args):
    sharding_strategy = SHARDING_MAP[sharding_map_key]

    kwargs = {}
-     if version.parse(torch.__version__.split('.dev')[0]) >= version.parse('2.2.0'):
-         if 'device_mesh' in fsdp_config:
-             device_mesh_size = len(fsdp_config['device_mesh'])
-             if sharding_strategy in [
-                 ShardingStrategy.FULL_SHARD,
-                 ShardingStrategy.SHARD_GRAD_OP,
-                 ShardingStrategy.NO_SHARD,
-             ] and device_mesh_size != 1:
-                 raise ValueError(
-                     f'FSDP sharding strategy {sharding_map_key.upper()} requires a device mesh '
-                     f'of size 1 but got device mesh size of {device_mesh_size}.',
-                 )
-             elif sharding_strategy in [
-                 ShardingStrategy.HYBRID_SHARD,
-                 ShardingStrategy._HYBRID_SHARD_ZERO2,
-             ] and device_mesh_size != 2:
-                 raise ValueError(
-                     f'FSDP sharding strategy {sharding_map_key.upper()} requires a device mesh '
-                     f'of size 2 but got device mesh size of {device_mesh_size}.',
-                 )
-             from torch.distributed._tensor import init_device_mesh
-             kwargs['device_mesh'] = init_device_mesh(
-                 'cuda',
-                 tuple([int(x) for x in fsdp_config['device_mesh']]),
-             )
+     if version.parse(torch.__version__.split('.dev')[0]) >= version.parse('2.2.0') and 'device_mesh' in fsdp_config:
+         if fsdp_config['process_group'] is not None:
+             warnings.warn(
+                 'process_group and device_mesh are set for FSDP, so ignoring device_mesh. Please set process_group to None.',
+             )
+         else:
+             ndim = fsdp_config['device_mesh'].ndim
+             if ndim == 1 and sharding_strategy == ShardingStrategy.HYBRID_SHARD:
+                 sharding_strategy = ShardingStrategy.FULL_SHARD
+                 warnings.warn('HYBRID_SHARD is not supported with 1D device mesh. Using FULL_SHARD instead.')
+             elif ndim == 1 and sharding_strategy == ShardingStrategy._HYBRID_SHARD_ZERO2:
+                 sharding_strategy = ShardingStrategy.SHARD_GRAD_OP
+                 warnings.warn('_HYBRID_SHARD_ZERO2 is not supported with 1D device mesh. Using SHARD_GRAD_OP instead.')
+             elif ndim == 2 and sharding_strategy == ShardingStrategy.SHARD_GRAD_OP:
+                 sharding_strategy = ShardingStrategy._HYBRID_SHARD_ZERO2
+                 warnings.warn('SHARD_GRAD_OP is not supported with 2D device mesh. Using _HYBRID_SHARD_ZERO2 instead.')
+             elif ndim == 2 and sharding_strategy == ShardingStrategy.FULL_SHARD:
+                 sharding_strategy = ShardingStrategy.HYBRID_SHARD
+                 warnings.warn('FULL_SHARD is not supported with 2D device mesh. Using HYBRID_SHARD instead.')
+             kwargs['device_mesh'] = fsdp_config['device_mesh']
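In the rewritten block above, fsdp_config carries an already-constructed DeviceMesh and the sharding strategy is reconciled against mesh.ndim instead of validating a list of sizes. A short sketch of what such a config might look like (mesh shapes are hypothetical; requires a distributed launch, e.g. torchrun on 8 ranks):

# Sketch: a 1D mesh keeps the FULL_SHARD/SHARD_GRAD_OP family; a 2D mesh flips
# the strategy into the HYBRID_SHARD family, as the warnings above describe.
from torch.distributed.device_mesh import init_device_mesh

fsdp_config = {
    'device_mesh': init_device_mesh('cuda', (2, 4)),  # 2 replicas x 4-way shard -> ndim == 2
    'process_group': None,          # must stay None, or device_mesh is ignored with a warning
    'sharding_strategy': 'FULL_SHARD',  # remapped to HYBRID_SHARD for a 2D mesh
}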

cpu_offload = get_cpu_offload(cpu_offload=fsdp_config['cpu_offload'])

@@ -382,7 +359,7 @@ def sync_hook(*args):
process_group = None
if fsdp_config['process_group'] is not None:
process_group_dict = {'process_group': fsdp_config['process_group']}
- process_group = _set_custom_fsdp_module_kwargs(process_group_dict, process_group_cache)['process_group']
+ process_group = set_custom_fsdp_module_kwargs(process_group_dict, process_group_cache)['process_group']
backward_prefetch = BACKWARD_PREFETCH_MAP[fsdp_config['backward_prefetch'].upper()]
activation_checkpointing = fsdp_config['activation_checkpointing']
activation_cpu_offload = fsdp_config['activation_cpu_offload']
@@ -556,7 +533,7 @@ def lambda_fn(module: torch.nn.Module) -> Union[bool, dict]:
elif hasattr(obj, 'fsdp_wrap_fn') and isinstance(obj.fsdp_wrap_fn, Callable):
ret = obj.fsdp_wrap_fn(module)
if isinstance(ret, dict):
- ret = _set_custom_fsdp_module_kwargs(ret, process_group_cache)
+ ret = set_custom_fsdp_module_kwargs(ret, process_group_cache)
if ret and auto_microbatching:
module.register_forward_hook(sync_hook)
module.register_full_backward_hook(sync_hook)
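The hunk above shows the auto-wrap lambda honoring a model-provided fsdp_wrap_fn, with dict returns now routed through the public set_custom_fsdp_module_kwargs. A hedged sketch of a model opting its blocks into wrapping (the classes and sizes are illustrative):

# Sketch: lambda_fn calls obj.fsdp_wrap_fn(module) per submodule. Returning
# True wraps the module with the default kwargs; returning a dict supplies
# per-module overrides that pass through set_custom_fsdp_module_kwargs.
import torch.nn as nn


class Block(nn.Module):  # illustrative transformer-ish block
    def __init__(self, d: int = 256):
        super().__init__()
        self.ff = nn.Sequential(nn.Linear(d, 4 * d), nn.GELU(), nn.Linear(4 * d, d))

    def forward(self, x):
        return x + self.ff(x)


class MyModel(nn.Module):
    def __init__(self, n_layers: int = 4):
        super().__init__()
        self.blocks = nn.ModuleList(Block() for _ in range(n_layers))

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x

    def fsdp_wrap_fn(self, module: nn.Module):
        # Wrap each Block as its own FSDP unit; everything else stays unwrapped.
        return isinstance(module, Block)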
Expand Down
File renamed without changes.