Fix TRT-LLM Multigpu Compatibility #2837

Draft · wants to merge 9 commits into base: main
35 changes: 26 additions & 9 deletions composer/datasets/in_context_learning_evaluation.py
@@ -8,6 +8,7 @@
import json
import os
import random
import time
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union

import torch
@@ -22,6 +23,15 @@
import transformers
from datasets import Dataset as HFDataset # pyright: ignore[reportGeneralTypeIssues]

try:
    import tensorrt_llm

    # Multi-GPU TRT-LLM evaluation runs under MPI rather than torch.distributed.
    if tensorrt_llm.mpi_world_size() > 1:
        TRTLLM_MULTIGPU = True
    else:
        TRTLLM_MULTIGPU = False
except Exception:  # tensorrt_llm is optional; treat any failure as single-GPU
    TRTLLM_MULTIGPU = False

# Allow models to have slightly more tokens than were used in the most verbose CoT in the dataset
_MAX_ANSWER_BUFFER_LENGTH = 10

@@ -208,6 +218,19 @@ def _get_fewshot_sample_idxs(dataset_size: int, num_fewshot: int, example_idx: i
    return fewshot_idxs


def _rank_zero_download(dataset_uri, destination_path):
    if TRTLLM_MULTIGPU:
        # Under an MPI launch torch.distributed is not initialized: rank 0
        # downloads the file and the remaining ranks poll until it appears.
        if tensorrt_llm.mpi_rank() == 0:
            get_file(dataset_uri, destination_path, overwrite=True)
        else:
            while not os.path.exists(destination_path):
                time.sleep(0.1)
    else:
        with dist.local_rank_zero_download_and_wait(destination_path):
            if dist.get_local_rank() == 0:
                get_file(dataset_uri, destination_path, overwrite=True)


class InContextLearningDataset(Dataset):
"""
A base dataset that constructs batches for in-context learning task evaluations.
@@ -403,9 +426,7 @@ def read_dataset(
                assert isinstance(dataset, HFDataset)
                dataset = dataset.map(dataset_parsing_func, remove_columns=dataset.column_names)
        else:
-           with dist.local_rank_zero_download_and_wait(destination_path):
-               if dist.get_local_rank() == 0:
-                   get_file(dataset_uri, destination_path, overwrite=True)
+           _rank_zero_download(dataset_uri, destination_path)
            dataset = load_dataset('json', data_files=destination_path, split='train', streaming=False)
        assert isinstance(dataset, HFDataset)
        return dataset
@@ -625,7 +646,6 @@ def collate_fn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
                batch[batch_key].append(data_pair[data_key])
            if 'continuation_indices' in data_pair:
                batch['continuation_indices'].append(data_pair['continuation_indices'])

        batch = convert_tokens_to_tensors(batch, self.tokenize_labels)
        batch['attention_mask'] = ~(batch['input_ids'] == self.pad_tok_id)
        return batch
@@ -704,7 +724,7 @@ def __init__(
        tensor_keys = ['input_ids', 'attention_mask']
        list_keys = ['labels']
        super().__init__(
-           padding_side='left',
+           padding_side='right',
            tokenize_labels=False,
            static_keys=static_keys,
            list_keys=list_keys,
@@ -1171,7 +1191,6 @@ def _prep_example(
    ) -> Dict[str, Any]:
        """
        Prepares a single example from a HF Dataset into tokenized format with prompt and fewshot examples.

        Each task consists of multiple contexts and a single, correct continuation. Will prepend fewshot examples and
        prompt if present.

@@ -1643,9 +1662,7 @@ def partition_dataset_by_category(
            assert hasattr(dataset, 'column_names')
            dataset = dataset.map(dataset_parsing_func, remove_columns=dataset.column_names)
    else:
-       with dist.local_rank_zero_download_and_wait(destination_path):
-           if dist.get_local_rank() == 0:
-               get_file(dataset_uri, destination_path, overwrite=True)
+       _rank_zero_download(dataset_uri, destination_path)
        dataset = load_dataset('json', data_files=destination_path, split='train', streaming=False)
    assert isinstance(dataset, HFDataset) or isinstance(dataset, IterableDataset)
    assert hasattr(dataset, 'features')
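For reference, a minimal, hypothetical smoke test of the rank-zero-downloads / other-ranks-poll pattern that `_rank_zero_download` relies on above. The `fetch` helper and `DEST` path are stand-ins invented for the sketch (the real code calls Composer's `get_file`); `tensorrt_llm.mpi_rank()` and `tensorrt_llm.mpi_world_size()` are the same helpers the diff uses.

```python
# Hypothetical demo of the rank-0 download + filesystem-polling sync shown above.
# Launch under MPI, e.g.: mpirun -n 2 python rank_zero_download_demo.py
import os
import time

import tensorrt_llm  # provides mpi_rank() / mpi_world_size()

DEST = '/tmp/icl_dataset_demo.jsonl'  # hypothetical destination path


def fetch(path: str) -> None:
    # Stand-in for composer.utils.get_file: write a tiny placeholder dataset.
    with open(path, 'w') as f:
        f.write('{"context": "2+2=", "continuation": "4"}\n')


if tensorrt_llm.mpi_world_size() > 1:
    if tensorrt_llm.mpi_rank() == 0:
        fetch(DEST)
    else:
        # torch.distributed barriers are unavailable under a pure MPI launch,
        # so non-zero ranks poll the filesystem until rank 0 finishes.
        while not os.path.exists(DEST):
            time.sleep(0.1)
else:
    fetch(DEST)

print(f'rank {tensorrt_llm.mpi_rank()}: dataset present = {os.path.exists(DEST)}')
```

One caveat worth noting on the polling approach: `os.path.exists` can return True while rank 0 is still writing, so a downloader that writes to a temporary file and renames it into place (or an explicit MPI barrier, if one is available) would make the handoff airtight.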
8 changes: 7 additions & 1 deletion composer/trainer/_scaler.py
@@ -5,9 +5,15 @@
from typing import Optional, Union

import torch
-from torch.cuda.amp.grad_scaler import GradScaler, OptState, _refresh_per_optimizer_state
+from torch.cuda.amp.grad_scaler import GradScaler, OptState
from torch.optim import Optimizer

+from packaging import version
+if version.parse(torch.__version__) >= version.parse('2.2.9'):
+    from torch.amp.grad_scaler import _refresh_per_optimizer_state  # type: ignore
+else:
+    from torch.cuda.amp.grad_scaler import _refresh_per_optimizer_state  # type: ignore

from composer.utils import dist

__all__ = ['ClosureGradScaler']
8 changes: 7 additions & 1 deletion composer/trainer/trainer.py
@@ -42,7 +42,13 @@
import torch.nn as nn
import torch.utils.data
from torch._dynamo import OptimizedModule
-from torch.cuda.amp.grad_scaler import GradScaler, _refresh_per_optimizer_state

+from packaging import version
+if version.parse(torch.__version__) >= version.parse('2.2.9'):
+    from torch.amp.grad_scaler import _refresh_per_optimizer_state  # type: ignore
+else:
+    from torch.cuda.amp.grad_scaler import _refresh_per_optimizer_state  # type: ignore

from torch.distributed.fsdp import FullyShardedDataParallel
from torch.distributed.fsdp._runtime_utils import _post_backward_final_callback
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
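Both `_scaler.py` and `trainer.py` gate the import the same way. The sentinel `'2.2.9'` (rather than `'2.3.0'`) is presumably chosen so that pre-release and nightly builds of 2.3.0, which sort below `'2.3.0'` under PEP 440, still take the new `torch.amp` import path. A small sketch of that ordering:

```python
# Why gate on '2.2.9' instead of '2.3.0': PEP 440 ordering of dev/local builds.
from packaging import version

# A 2.3.0 nightly sorts *below* the final 2.3.0 release...
assert version.parse('2.3.0.dev20240201') < version.parse('2.3.0')
# ...but above the 2.2.9 sentinel, so it still selects the torch.amp import.
assert version.parse('2.3.0.dev20240201') > version.parse('2.2.9')
# Stable 2.2.x builds (including +cuXXX local versions) stay on the old path.
assert version.parse('2.2.2+cu121') < version.parse('2.2.9')
print('version gate behaves as expected')
```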