diff --git a/composer/datasets/in_context_learning_evaluation.py b/composer/datasets/in_context_learning_evaluation.py
index 9d5947e7f5..d7c0672f9c 100644
--- a/composer/datasets/in_context_learning_evaluation.py
+++ b/composer/datasets/in_context_learning_evaluation.py
@@ -8,6 +8,7 @@
 import json
 import os
 import random
+import time
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union
 
 import torch
@@ -22,6 +23,15 @@
 import transformers
 from datasets import Dataset as HFDataset  # pyright: ignore[reportGeneralTypeIssues]
 
+try:
+    import tensorrt_llm
+    if tensorrt_llm.mpi_world_size() > 1:
+        TRTLLM_MULTIGPU = True
+    else:
+        TRTLLM_MULTIGPU = False
+except:
+    TRTLLM_MULTIGPU = False
+
 # Allow models to have slightly more tokens than were used in the most verbose CoT in the dataset
 _MAX_ANSWER_BUFFER_LENGTH = 10
 
@@ -208,6 +218,19 @@ def _get_fewshot_sample_idxs(dataset_size: int, num_fewshot: int, example_idx: i
     return fewshot_idxs
 
 
+def _rank_zero_download(dataset_uri, destination_path):
+    if TRTLLM_MULTIGPU == True:
+        if tensorrt_llm.mpi_rank() == 0:
+            get_file(dataset_uri, destination_path, overwrite=True)
+        else:
+            while not os.path.exists(destination_path):
+                time.sleep(0.1)
+    else:
+        with dist.local_rank_zero_download_and_wait(destination_path):
+            if dist.get_local_rank() == 0:
+                get_file(dataset_uri, destination_path, overwrite=True)
+
+
 class InContextLearningDataset(Dataset):
     """
     A base dataset that constructs batches for in-context learning task evaluations.
@@ -403,9 +426,7 @@ def read_dataset(
                 assert isinstance(dataset, HFDataset)
                 dataset = dataset.map(dataset_parsing_func, remove_columns=dataset.column_names)
         else:
-            with dist.local_rank_zero_download_and_wait(destination_path):
-                if dist.get_local_rank() == 0:
-                    get_file(dataset_uri, destination_path, overwrite=True)
+            _rank_zero_download(dataset_uri, destination_path)
             dataset = load_dataset('json', data_files=destination_path, split='train', streaming=False)
         assert isinstance(dataset, HFDataset)
         return dataset
@@ -625,7 +646,6 @@ def collate_fn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
                 batch[batch_key].append(data_pair[data_key])
             if 'continuation_indices' in data_pair:
                 batch['continuation_indices'].append(data_pair['continuation_indices'])
-
         batch = convert_tokens_to_tensors(batch, self.tokenize_labels)
         batch['attention_mask'] = ~(batch['input_ids'] == self.pad_tok_id)
         return batch
@@ -704,7 +724,7 @@ def __init__(
         tensor_keys = ['input_ids', 'attention_mask']
         list_keys = ['labels']
         super().__init__(
-            padding_side='left',
+            padding_side='right',
             tokenize_labels=False,
             static_keys=static_keys,
             list_keys=list_keys,
@@ -1171,7 +1191,6 @@ def _prep_example(
     ) -> Dict[str, Any]:
         """
         Prepares a single example from a HF Dataset into tokenized format with prompt and fewshot examples.
-
         Each task consists of multiple contexts and a single, correct continuation. Will preprend fewshot examples and
         prompt if present.
 
@@ -1643,9 +1662,7 @@ def partition_dataset_by_category(
             assert hasattr(dataset, 'column_names')
             dataset = dataset.map(dataset_parsing_func, remove_columns=dataset.column_names)
     else:
-        with dist.local_rank_zero_download_and_wait(destination_path):
-            if dist.get_local_rank() == 0:
-                get_file(dataset_uri, destination_path, overwrite=True)
+        _rank_zero_download(dataset_uri, destination_path)
         dataset = load_dataset('json', data_files=destination_path, split='train', streaming=False)
     assert isinstance(dataset, HFDataset) or isinstance(dataset, IterableDataset)
     assert hasattr(dataset, 'features')
diff --git a/composer/trainer/_scaler.py b/composer/trainer/_scaler.py
index e36057b1b5..eb45443d2b 100644
--- a/composer/trainer/_scaler.py
+++ b/composer/trainer/_scaler.py
@@ -5,9 +5,15 @@
 from typing import Optional, Union
 
 import torch
-from torch.cuda.amp.grad_scaler import GradScaler, OptState, _refresh_per_optimizer_state
+from torch.cuda.amp.grad_scaler import GradScaler, OptState
 from torch.optim import Optimizer
 
+from packaging import version
+if version.parse(torch.__version__) >= version.parse('2.2.9'):
+    from torch.amp.grad_scaler import _refresh_per_optimizer_state  # type: ignore
+else:
+    from torch.cuda.amp.grad_scaler import _refresh_per_optimizer_state  # type: ignore
+
 from composer.utils import dist
 
 __all__ = ['ClosureGradScaler']
diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py
index 46843efa50..915e76d95a 100644
--- a/composer/trainer/trainer.py
+++ b/composer/trainer/trainer.py
@@ -42,7 +42,13 @@
 import torch.nn as nn
 import torch.utils.data
 from torch._dynamo import OptimizedModule
-from torch.cuda.amp.grad_scaler import GradScaler, _refresh_per_optimizer_state
+
+from packaging import version
+if version.parse(torch.__version__) >= version.parse('2.2.9'):
+    from torch.amp.grad_scaler import _refresh_per_optimizer_state  # type: ignore
+else:
+    from torch.cuda.amp.grad_scaler import _refresh_per_optimizer_state  # type: ignore
+
 from torch.distributed.fsdp import FullyShardedDataParallel
 from torch.distributed.fsdp._runtime_utils import _post_backward_final_callback
 from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
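Note: the version gate applied in composer/trainer/_scaler.py and composer/trainer/trainer.py above can be sanity-checked in isolation. The sketch below is illustrative only, not part of the patch; it reuses the 2.2.9 threshold and import paths from the diff, and the final print is just a hypothetical verification step.

    # Illustrative only: mirrors the version-gated import introduced in the diff above.
    # On torch >= 2.3 the GradScaler helpers live in torch.amp.grad_scaler; older
    # releases only expose them under torch.cuda.amp.grad_scaler.
    import torch
    from packaging import version

    if version.parse(torch.__version__) >= version.parse('2.2.9'):
        from torch.amp.grad_scaler import _refresh_per_optimizer_state  # type: ignore
    else:
        from torch.cuda.amp.grad_scaler import _refresh_per_optimizer_state  # type: ignore

    # Quick check of which module the symbol resolved from (e.g. 'torch.amp.grad_scaler').
    print(_refresh_per_optimizer_state.__module__)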