Merge branch 'main' into pickle_scaling_tensor

wkcn authored Aug 30, 2023
2 parents a750b62 + aed29d6 commit 0367bdc
Showing 40 changed files with 2,195 additions and 451 deletions.
1 change: 1 addition & 0 deletions .flake8
@@ -5,3 +5,4 @@ inline-quotes = single
multiline-quotes = double
docstring-quotes = double
docstring-convention = google
per-file-ignores = *.py:D202
13 changes: 13 additions & 0 deletions .github/workflows/unit-tests.yaml
@@ -28,8 +28,19 @@ jobs:
    steps:
      - name: Checkout msamp
        uses: actions/checkout@v2
        with:
          submodules: true
      - name: Install MSCCL
        run: |
          cd third_party/msccl
          make -j src.build NVCC_GENCODE="\
            -gencode=arch=compute_70,code=sm_70 \
            -gencode=arch=compute_80,code=sm_80 \
            -gencode=arch=compute_90,code=sm_90"
          make install
      - name: Install dependencies
        run: |
          export LD_LIBRARY_PATH="/usr/local/lib:$LD_LIBRARY_PATH"
          python3 -m pip install --upgrade pip
          python3 -m pip install .[test]
          make postinstall
@@ -38,6 +49,8 @@ jobs:
          python3 setup.py lint
      - name: Run unit tests
        run: |
          export LD_LIBRARY_PATH="/usr/local/lib:$LD_LIBRARY_PATH"
          export LD_PRELOAD="/usr/local/lib/libmsamp_dist.so:/usr/local/lib/libnccl.so:${LD_PRELOAD}"
          python3 setup.py test
      # - name: Report coverage results
      #   run: |
8 changes: 4 additions & 4 deletions .gitmodules
@@ -1,4 +1,4 @@
[submodule "third_party/nccl"]
path = third_party/nccl
url = https://github.com/yzygitzh/nccl.git
branch = ziyyang/fp8-support
[submodule "third_party/msccl"]
path = third_party/msccl
url = https://github.com/Azure/msccl-executor-nccl
branch = msccl-v2.17
2 changes: 1 addition & 1 deletion Makefile
@@ -22,5 +22,5 @@ lint: cpplint mdlint
python3 setup.py lint

postinstall:
cd msamp/operators/dist_op && pip install -v . && cd -
cd msamp/operators/dist_op && bash build.sh && cd -
cd msamp/optim && pip install -v . && cd -
48 changes: 35 additions & 13 deletions README.md
@@ -6,7 +6,8 @@ Features:

- Support O1 optimization: Apply FP8 to weights and weight gradients and support FP8 in communication.
- Support O2 optimization: Support FP8 for two optimizers (Adam and AdamW).
- Provide three training examples using FP8: Swin-Transformer, DeiT and RoBERTa.
- Support O3 optimization: Support FP8 in DeepSpeed ZeRO optimizer.
- Provide four training examples using FP8: Swin-Transformer, DeiT, RoBERTa and GPT-3.

MS-AMP has the following benefits compared with Transformer Engine:

@@ -28,10 +29,10 @@ MS-AMP has the following benefits compared with Transformer Engine:
- CUDA version 11 or later (which can be checked by running `nvcc --version`).
- PyTorch version 1.13 or later (which can be checked by running `python -c "import torch; print(torch.__version__)"`).

We strongly recommend using [PyTorch NGC Container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch). For example, to start a PyTorch 1.13 container, run the following command:
We strongly recommend using [PyTorch NGC Container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch). For example, to start a PyTorch 1.14 container, run the following command:

```
sudo docker run -it -d --name=msamp --privileged --net=host --ipc=host --gpus=all nvcr.io/nvidia/pytorch:22.09-py3 bash
sudo docker run -it -d --name=msamp --privileged --net=host --ipc=host --gpus=all nvcr.io/nvidia/pytorch:22.12-py3 bash
sudo docker exec -it msamp bash
```

@@ -45,10 +46,10 @@ cd MS-AMP
git submodule update --init --recursive
```

If you want to train a model with multiple GPUs, you need to install a specific NCCL build to support FP8.
If you want to train a model with multiple GPUs, you need to install MSCCL to support FP8.

```bash
cd third_party/nccl
cd third_party/msccl

# V100
make -j src.build NVCC_GENCODE="-gencode=arch=compute_70,code=sm_70"
@@ -61,18 +62,26 @@ apt-get update
apt install build-essential devscripts debhelper fakeroot
make pkg.debian.build
dpkg -i build/pkg/deb/libnccl2_*.deb
dpkg -i build/pkg/deb/libnccl-dev_2*.deb

cd -
```

Then, you can install MS-AMP from source.

```
```bash
python3 -m pip install --upgrade pip
python3 -m pip install .
make postinstall
```

Before using MS-AMP, you need to preload the msampfp8 library and its dependencies:

```bash
NCCL_LIBRARY=/usr/lib/x86_64-linux-gnu/libnccl.so # Change as needed
export LD_PRELOAD="/usr/local/lib/libmsamp_dist.so:${NCCL_LIBRARY}:${LD_PRELOAD}"
```

After that, you can verify the installation by running:

```bash
@@ -113,7 +122,17 @@ for batch_idx, (data, target) in enumerate(train_loader):
scaler.step(optimizer)
```

A runnable, comprehensive MNIST example demonstrating good practices can be found [here](./examples). For more examples, please go to [MS-AMP-Examples](https://github.com/Azure/MS-AMP-Examples).
To apply MS-AMP to DeepSpeed ZeRO, add a "msamp" section to the DeepSpeed config file:

```json
"msamp": {
"enabled": true,
"opt_level": "O3"
}
```
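
For reference, a minimal sketch of how that section might sit inside a full config passed to `deepspeed.initialize`; only the "msamp" block comes from this change, and the other keys and the toy model are illustrative assumptions:

```python
# Hypothetical call site: embed the "msamp" section in a complete
# DeepSpeed config dict and hand it to deepspeed.initialize.
import deepspeed
import torch

model = torch.nn.Linear(64, 64)  # placeholder model

ds_config = {
    'train_batch_size': 32,  # assumed value
    'msamp': {
        'enabled': True,
        'opt_level': 'O3',
    },
}

model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config=ds_config,
)
```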

Runnable, comprehensive examples demonstrating good practices can be found [here](./examples).
For more examples, please go to [MS-AMP-Examples](https://github.com/Azure/MS-AMP-Examples).

### Optimization Level

@@ -123,13 +142,16 @@ Currently MS-AMP supports two optimization levels: O1 and O2. Try both, and see

- O2: From O1 to O2, our main focus is on enabling the use of low-bit data formats for auxiliary tensors in the Adam/AdamW optimizer without any loss in accuracy. Specifically, we are able to maintain accuracy by representing the first-order optimizer state in FP8 and the second-order state in FP16. This optimization has the potential to save up to 62.5% of GPU memory for the optimizer when the model size is particularly large. (Adam's two FP32 states take 8 bytes per parameter; FP8 + FP16 states take 3 bytes, a 5/8 = 62.5% reduction.)

- O3: This optimization level is specifically designed for the ZeRO optimizer in DeepSpeed, an advanced distributed training framework. ZeRO separates model weights into regular weights and master weights, with the former used for network forward/backward on each GPU, and the latter used for model updating in the optimizer. This separation allows us to use 8-bit data precision for regular weights and weight broadcasting, which reduces GPU memory and bandwidth usage even further.
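
A minimal usage sketch for selecting a level, assuming a CUDA device; the model and optimizer here are placeholders, and the `msamp.initialize` signature is the one shown in the `msamp/__init__.py` diff below:

```python
# Sketch: O1/O2 are selected via msamp.initialize; O3 is enabled through
# the DeepSpeed "msamp" config section shown above instead.
import torch
import msamp

model = torch.nn.Linear(64, 64).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

model, optimizer = msamp.initialize(model, optimizer, opt_level='O2')
```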

Here are details of different MS-AMP optimization levels:
| Optimization Level | Computation(GEMM) | Comm | Weight | Weight Gradient | Optimizer States |
| ------------------- | ----------- | ----- | ------ | --------------- | ---------------- |
| FP16 AMP | FP16 | FP32 | FP32 | FP32 | FP32+FP32 |
| Nvidia TE | FP8 | FP32 | FP32 | FP32 | FP32+FP32 |
| MS-AMP O1 | FP8 | FP8 | FP16 | FP8 | FP32+FP32 |
| MS-AMP O2 | FP8 | FP8 | FP16 | FP8 | FP8+FP16 |
| Optimization Level | Computation(GEMM) | Comm | Weight | Master Weight | Weight Gradient | Optimizer States |
| ------------------- | ----------- | ----- | ------ | ------------- | --------------- | ---------------- |
| FP16 AMP | FP16 | FP32 | FP32 | N/A | FP32 | FP32+FP32 |
| Nvidia TE | FP8 | FP32 | FP32 | N/A | FP32 | FP32+FP32 |
| MS-AMP O1 | FP8 | FP8 | FP16 | N/A | FP8 | FP32+FP32 |
| MS-AMP O2 | FP8 | FP8 | FP16 | N/A | FP8 | FP8+FP16 |
| MS-AMP O3 | FP8 | FP8 | FP8 | FP16 | FP8 | FP8+FP16 |

## Performance

7 changes: 4 additions & 3 deletions dockerfile/torch1.14-cuda11.8.dockerfile
@@ -36,13 +36,12 @@ RUN apt-get update && \

ARG NUM_MAKE_JOBS=
ENV PATH="${PATH}" \
LD_LIBRARY_PATH="/usr/local/lib:${LD_LIBRARY_PATH}" \
PYTHONOPTIMIZE=1
LD_LIBRARY_PATH="/usr/local/lib:${LD_LIBRARY_PATH}"

WORKDIR /opt/msamp

ADD third_party third_party
RUN cd third_party/nccl && \
RUN cd third_party/msccl && \
make -j ${NUM_MAKE_JOBS} src.build NVCC_GENCODE="\
-gencode=arch=compute_70,code=sm_70 \
-gencode=arch=compute_80,code=sm_80 \
@@ -55,3 +54,5 @@ RUN python3 -m pip install --upgrade pip && \
ADD . .
RUN python3 -m pip install . && \
make postinstall

ENV LD_PRELOAD="/usr/local/lib/libmsamp_dist.so:/usr/local/lib/libnccl.so:${LD_PRELOAD}"
7 changes: 4 additions & 3 deletions dockerfile/torch2.1-cuda12.1.dockerfile
@@ -36,13 +36,12 @@ RUN apt-get update && \

ARG NUM_MAKE_JOBS=
ENV PATH="${PATH}" \
LD_LIBRARY_PATH="/usr/local/lib:${LD_LIBRARY_PATH}" \
PYTHONOPTIMIZE=1
LD_LIBRARY_PATH="/usr/local/lib:${LD_LIBRARY_PATH}"

WORKDIR /opt/msamp

ADD third_party third_party
RUN cd third_party/nccl && \
RUN cd third_party/msccl && \
make -j ${NUM_MAKE_JOBS} src.build NVCC_GENCODE="\
-gencode=arch=compute_70,code=sm_70 \
-gencode=arch=compute_80,code=sm_80 \
@@ -55,3 +54,5 @@ RUN python3 -m pip install --upgrade pip && \
ADD . .
RUN python3 -m pip install . && \
make postinstall

ENV LD_PRELOAD="/usr/local/lib/libmsamp_dist.so:/usr/local/lib/libnccl.so:${LD_PRELOAD}"
2 changes: 1 addition & 1 deletion msamp/__init__.py
@@ -94,6 +94,6 @@ def initialize(model, optimizer=None, opt_level='O1'):  # noqa: C901
return cast_model, cast_optimizer


__version__ = '0.1.0'
__version__ = '0.2.0'
__author__ = 'Microsoft'
__all__ = ['clip_grad_norm_', 'initialize']
2 changes: 2 additions & 0 deletions msamp/common/dtype/floating.py
@@ -48,4 +48,6 @@ def _get_fp_max(exp, man, inf_existed=True):
Dtypes.kfloat8_e4m3: Floating._get_fp_max(exp=4, man=3, inf_existed=False),
Dtypes.kfloat8_e5m2: Floating._get_fp_max(exp=5, man=2),
Dtypes.kfloat16: Floating._get_fp_max(exp=5, man=10), # E5M10
Dtypes.kbfloat16: Floating._get_fp_max(exp=8, man=7), # E8M7
Dtypes.kfloat32: Floating._get_fp_max(exp=8, man=23), # E8M23
}
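
The two new entries follow the usual largest-finite-value formula for these formats. A hedged, standalone re-derivation of what `_get_fp_max` appears to compute (this function is illustrative, not the repo's code):

```python
def fp_max(exp, man, inf_existed=True):
    """Largest finite value of a float with `exp` exponent and `man` mantissa bits.

    When inf is reserved (IEEE-like), the top exponent is unusable for finite
    values; when it is not (e.g. FP8 E4M3), only the all-ones mantissa pattern
    is reserved for NaN, so one extra binade of finite values is available.
    """
    if inf_existed:
        emax = 2 ** (exp - 1) - 1
        frac = 2 - 2 ** (-man)        # 1.11...1 mantissa
    else:
        emax = 2 ** (exp - 1)
        frac = 2 - 2 * 2 ** (-man)    # 1.11...0 mantissa (all-ones is NaN)
    return frac * 2 ** emax

assert fp_max(exp=4, man=3, inf_existed=False) == 448.0  # FP8 E4M3
assert fp_max(exp=5, man=2) == 57344.0                   # FP8 E5M2
assert fp_max(exp=5, man=10) == 65504.0                  # FP16 (E5M10)
assert fp_max(exp=8, man=7) == 3.3895313892515355e+38    # BF16 (E8M7)
```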
28 changes: 10 additions & 18 deletions msamp/common/tensor/cast.py
@@ -6,7 +6,7 @@
import torch
import torch.distributed as dist

from msamp.common.dtype import Dtypes, Floating
from msamp.common.dtype import Dtypes
from msamp.common.utils import DistUtil
from msamp.common.utils import TransformerEngineWrapper

@@ -76,8 +76,9 @@ def cast_to_fp16(input, meta, sync=False):
Return:
torch.Tensor: tensor whose dtype is torch.float16.
"""
meta.amax[0] = input.abs().max()
in_time = meta.is_in_time_scaling()
if in_time or sync:
meta.amax[0] = input.abs().max()
if sync:
# convert NAN to INF since NCCL-ReduceMax ignores NAN
# notice: nan and posinf must be INF
@@ -86,13 +87,12 @@ if world_size > 1:
if world_size > 1:
dist.all_reduce(meta.amax[0], op=dist.ReduceOp.MAX, group=meta.group)
if in_time or sync:
# notice: we scale the tensor with qtype FP8-E4M3.
meta.reset_scaling_factor(qtype=Dtypes.kfloat8_e4m3)
meta.scale.clamp_(max=Floating.qfp_max[meta.qtype])

meta.reset_scaling_factor()
meta.scale_inv.data.copy_(torch.reciprocal(meta.scale)) # scale_inv = 1 / scale
input_fp16 = (input * meta.scale).to(torch.float16)
return input_fp16
dtype = Dtypes.get_dtype_from_qtype(meta.qtype)
# reshape scale to a tensor of shape (1,) to avoid overflow
# when scale is larger than the maximum of qtype
return (input * meta.scale.view((1, ))).to(dtype)

@staticmethod
def cast_from_fp8(input, meta, otype):
@@ -111,13 +111,12 @@ def cast_from_fp8(input, meta, otype):
if input.dtype != torch.uint8:
raise ValueError('The dtype of input tensor is not torch.uint8.')

shape = input.shape
return TransformerEngineWrapper.cast_from_fp8(
input.view(1, -1),
meta.scale_inv,
meta.qtype,
otype,
).view(shape)
).view_as(input)

@staticmethod
def cast_from_fp16(input, meta, otype):
@@ -132,13 +131,6 @@ def cast_from_fp16(input, meta, otype):
torch.Tensor: tensor whose type is otype.
"""
dtype = Dtypes.get_dtype_from_qtype(otype)
if input.dtype == dtype:
# return a copy
input = input.clone()
else:
input = input.to(dtype)
if meta.scale_inv != 1:
input.mul_(meta.scale_inv)
return input
return (input * meta.scale_inv.view((1, ))).to(dtype)

cast_from_fp32 = cast_from_fp16
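
Conceptually, every cast in this file is the same scale/unscale round trip: pick `scale = qmax / amax`, multiply before the low-precision cast, and multiply by `scale_inv` to recover. A self-contained sketch of the idea in plain PyTorch (not this repo's API; FP16 stands in for FP8 here, and `qmax` is the E4M3 maximum from floating.py above):

```python
import torch

x = torch.randn(4, 4) * 1e-3         # values far below the target format's range
qmax = 448.0                          # FP8 E4M3 max; FP16 would use 65504.0

amax = x.abs().max()
scale = qmax / amax                   # largest magnitude maps onto qmax
scale_inv = torch.reciprocal(scale)   # stored alongside the scaled tensor

q = (x * scale).to(torch.float16)     # stand-in for the FP8 cast
restored = q.float() * scale_inv      # dequantize with scale_inv

print((x - restored).abs().max())     # small quantization error
```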
24 changes: 24 additions & 0 deletions msamp/common/tensor/meta.py
@@ -82,6 +82,30 @@ def is_in_time_scaling(self):
"""
return ScalingMeta.in_time_scaling and (self.window_size == 1 or self.is_warmup())

@staticmethod
def in_time_scaling_context(enabled):
"""A context manager to set in_time_scaling flag.
Args:
bool: in_time_scaling flag.
Returns:
InTimeScalingContext: A context manager to set in_time_scaling flag.
"""
class InTimeScalingContext:
def __init__(self, enabled):
self.enabled = enabled
self.in_time_scaling = ScalingMeta.in_time_scaling

def __enter__(self):
self.in_time_scaling = ScalingMeta.in_time_scaling
ScalingMeta.in_time_scaling = self.enabled

def __exit__(self, exc_type, exc_val, exc_tb):
ScalingMeta.in_time_scaling = self.in_time_scaling

return InTimeScalingContext(enabled=enabled)

def reset_scaling_factor(self, qtype=None):
"""Reset scaling factor.
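
A usage sketch for the new context manager (hypothetical call site; assumes `ScalingMeta` is importable from `msamp.common.tensor`):

```python
from msamp.common.tensor import ScalingMeta

# Temporarily override the global in_time_scaling flag; __exit__ restores
# the value that was in effect before entering the block.
with ScalingMeta.in_time_scaling_context(enabled=False):
    assert ScalingMeta.in_time_scaling is False
    # casts in this block fall back to delayed (window-based) scaling
```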