Merge pull request #32 from YerevaNN/model_loading
Model loading
tigranfah authored Dec 17, 2024
2 parents 8e38015 + 3ead38a commit 48155f2
Showing 8 changed files with 311 additions and 181 deletions.
59 changes: 38 additions & 21 deletions submitit_train.py
@@ -1,39 +1,56 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import submitit
import datetime
import yaml
import os


if __name__ == "__main__":
executor = submitit.AutoExecutor(folder="~/slurm_jobs/titan/job_%j")
n_gpus = 8
n_gpus = 6
node = "h100"
executor.update_parameters(
name="titan", timeout_min=3 * 24 * 60,
name="titan",
timeout_min=24 * 24 * 60,
gpus_per_node=n_gpus,
nodes=1, mem_gb=80, cpus_per_task=n_gpus * 4,
slurm_additional_parameters={
"partition": "h100"
}
nodes=1,
mem_gb=80,
cpus_per_task=n_gpus * 6,
slurm_additional_parameters={"partition": node},
)

jobs = []
with executor.batch():
for _ in range(1):
# train_config = './train_configs/chemlactica_125m.toml'
# train_config = './train_configs/chemlactica_1.3b.toml'
train_config = './train_configs/llama3.2_1b.toml'
train_config = "./train_configs/llama3.2_1b.toml"
# train_config = './train_configs/debug_model.toml'
function = submitit.helpers.CommandFunction([
'python3', '-m', 'torch.distributed.run',
'--nproc_per_node', f'{n_gpus}',
'--rdzv_backend', 'c10d',
'--rdzv_endpoint', 'localhost:0',
'--local-ranks-filter', '0',
'--role', 'rank', '--tee', '3',
'train.py',
'--job.config_file', train_config,
])
print(' '.join(function.command))
function = submitit.helpers.CommandFunction(
[
"python3",
"-m",
"torch.distributed.run",
"--nproc_per_node",
f"{n_gpus}",
"--rdzv_backend",
"c10d",
"--rdzv_endpoint",
"localhost:0",
"--local-ranks-filter",
"0",
"--role",
"rank",
"--tee",
"3",
"train.py",
"--job.config_file",
train_config,
]
)
print(" ".join(function.command))
# subprocess.run(function.command)
job = executor.submit(function)
jobs.append(job)
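The CommandFunction above just assembles a torch.distributed.run invocation that submitit executes on the allocated node; print(" ".join(function.command)) echoes it before submission. As a minimal sketch (not the repository's script), the same launch pattern can be exercised as a single-GPU smoke test, assuming submitit is installed, the cluster exposes an "h100" partition, and the debug config path from the commented-out line exists:

import submitit

# Sketch only: one GPU, short time limit, same launch pattern as submitit_train.py.
executor = submitit.AutoExecutor(folder="~/slurm_jobs/titan/job_%j")
executor.update_parameters(
    name="titan-smoke",  # hypothetical job name
    timeout_min=60,
    gpus_per_node=1,
    nodes=1,
    mem_gb=80,
    cpus_per_task=6,
    slurm_additional_parameters={"partition": "h100"},
)

function = submitit.helpers.CommandFunction([
    "python3", "-m", "torch.distributed.run",
    "--nproc_per_node", "1",
    "--rdzv_backend", "c10d",
    "--rdzv_endpoint", "localhost:0",
    "train.py",
    "--job.config_file", "./train_configs/debug_model.toml",  # assumed to exist
])

job = executor.submit(function)  # returns a submitit Job handle
print(job.job_id)                # Slurm job id once scheduled
# job.result() blocks until the command finishes and re-raises any failure.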
49 changes: 38 additions & 11 deletions torchtitan/checkpoint.py
@@ -133,7 +133,7 @@ class SaveDone:
pass


def checkpoint_mp(recv, send,log_level):
def checkpoint_mp(recv, send, log_level):
init_logger(log_level)
os.environ["MASTER_PORT"] = str(int(os.environ["MASTER_PORT"]) + 2)
os.environ["TORCHELASTIC_USE_AGENT_STORE"] = "False"
@@ -236,12 +236,21 @@ def __init__(
for idx, lr_scheduler in enumerate(lr_schedulers):
self.states[f"lr_scheduler_{idx}"] = lr_scheduler


if job_config.model_download_export.to_hf or job_config.model_download_export.to_titan:
self.save_folder = os.path.join(job_config.job.dump_folder, ckpt_config.save_folder)
if (
job_config.model_download_export.to_hf
or job_config.model_download_export.to_titan
):
self.save_folder = os.path.join(
job_config.job.dump_folder, ckpt_config.save_folder
)
else:
self.save_folder = os.path.join(job_config.job.dump_folder, os.path.join(ckpt_config.save_folder, experiment_hash))
self.load_folder = os.path.join(job_config.job.dump_folder, ckpt_config.load_folder)
self.save_folder = os.path.join(
job_config.job.dump_folder,
os.path.join(ckpt_config.save_folder, experiment_hash),
)
self.load_folder = os.path.join(
job_config.job.dump_folder, ckpt_config.load_folder
)
self.interval_type = (
IntervalType.SECONDS
if ckpt_config.interval_type == "seconds"
@@ -273,7 +282,7 @@ def __init__(
args=(
self.mp_queue_send,
self.mp_queue_recv,
job_config.logging.log_level
job_config.logging.log_level,
),
daemon=True,
)
@@ -330,7 +339,10 @@ def _save_last_step(self, curr_step: int) -> None:
else:
logger.info(f"Saving a full checkpoint at last step, step {curr_step}.")

dcp.save(self.states, checkpoint_id=self._create_checkpoint_id(curr_step, self.save_folder))
dcp.save(
self.states,
checkpoint_id=self._create_checkpoint_id(curr_step, self.save_folder),
)
self.reset()

def _should_save(self, curr_step: int, force: bool = False) -> bool:
@@ -457,7 +469,9 @@ def load(self, step: int = -1) -> bool:
return False
if not os.path.isdir(self.load_folder):
return False
if step != -1 and not os.path.isdir(self._create_checkpoint_id(step, self.load_folder)):
if step != -1 and not os.path.isdir(
self._create_checkpoint_id(step, self.load_folder)
):
return False

if step == -1:
@@ -473,15 +487,28 @@ def load(self, step: int = -1) -> bool:

# We won't have optimizer states to load, if we are loading a seed checkpoint
states = {"model": self.states["model"]} if step == 0 else self.states
# PyTorch bug: (pytorch/pytorch#138575)
# dcp.load() replaces the values of stateful elements in `states` with new objects
# from loading the checkpoint, in addition to updating the states of the original
# objects from `states` in-place. This is a problem because the state_dict no longer
# refers to the objects being used in the train loop, meaning any future checkpoints
# will not include updates to these objects (such as updated optimizer states, etc.)
original_stateful_states = {
k: v for k, v in states.items() if isinstance(v, Stateful)
}
logger.info(f"Loading the checkpoint at step {step}.")
begin = time.monotonic()
dcp.load(
states,
checkpoint_id=self._create_checkpoint_id(step, self.load_folder)
original_stateful_states,
checkpoint_id=self._create_checkpoint_id(step, self.load_folder),
)
logger.info(
f"Finished loading the checkpoint in {time.monotonic() - begin:.2f} seconds."
)
# bugfix from above: restore the original stateful objects,
# whose states were already updated in-place by dcp.load()
# for k, v in original_stateful_states.items():
# states[k].load_state_dict(v)
return True

def _purge_stale_checkpoints(self):
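The load() change above works around pytorch/pytorch#138575 by passing dcp.load() a dict that holds only the Stateful entries, so the live training objects are updated in place and the state dict used for later saves keeps pointing at them. A self-contained toy sketch of that pattern (not the repository's code; it assumes a recent torch with torch.distributed.checkpoint, a single process, and a writable ./tmp_ckpt directory):

import torch
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.stateful import Stateful

class Counter(Stateful):
    # Tiny stateful object standing in for a model/optimizer wrapper.
    def __init__(self):
        self.value = 0

    def state_dict(self):
        return {"value": torch.tensor(self.value)}

    def load_state_dict(self, state_dict):
        self.value = int(state_dict["value"])

counter = Counter()
counter.value = 42
states = {"counter": counter}
dcp.save(states, checkpoint_id="./tmp_ckpt")

counter.value = 0  # pretend we restarted and lost the in-memory state

# Load through a dict that contains only the Stateful entries: dcp.load() calls
# load_state_dict() on `counter` in place, and `states` still references the live
# object, so later dcp.save(states, ...) calls capture its future updates.
stateful_only = {k: v for k, v in states.items() if isinstance(v, Stateful)}
dcp.load(stateful_only, checkpoint_id="./tmp_ckpt")
assert counter.value == 42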
117 changes: 76 additions & 41 deletions torchtitan/datasets/hf_datasets.py
@@ -4,10 +4,10 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pickle
from typing import Any, Dict, List, Optional
import glob
import os
import pickle
from typing import Any, Dict, List, Optional

import numpy as np

@@ -23,8 +23,8 @@
"pip3 install --pre torchdata --index-url https://download.pytorch.org/whl/nightly"
) from e

from torchtitan.tokenizers.tokenizer import Tokenizer
from torchtitan.logging import logger
from torchtitan.tokenizers.tokenizer import Tokenizer
from torchtitan.utils.dataset_utils import chemlactica_style_data_processing

from datasets import load_dataset
@@ -38,10 +38,9 @@
"c4": "allenai/c4",
"chemlactica_train_mini": "test/assets/chemlactica_train_mini",
"chemlactica_train": "/nfs/dgx/raid/chem/data/rdkit_computed_rel+form/train_rdkit_computed_rel+form",

# valid
"chemlactica_valid": "/nfs/dgx/raid/chem/data/rdkit_computed_rel+form",
"chemlactica_valid_mini": "test/assets/chemlactica_valid_mini"
"chemlactica_valid_mini": "test/assets/chemlactica_valid_mini",
}

_supported_data_processing_styles = {
@@ -57,7 +56,7 @@ class HuggingFaceDataset(IterableDataset, Stateful):
dataset_path (Optional[str]):
Path to the dataset in the file system. If provided, data will be loaded
from this path instead of downloaded.
data_processing_style (str): name of the data process style
data_processing_style (str): name of the data process style
tokenizer (Tokenizer):
Tokenizer used to encode data. Tokenize must implement an `encode` and `decode` method.
seq_len (int): max sequence length
@@ -79,7 +78,8 @@ class HuggingFaceDataset(IterableDataset, Stateful):
}
Example use (c4):
>>> ds = HuggingFaceDataset(dataset_name="c4", dataset_path=None, data_processing_style="chemlactica_style", tokenizer=tokenizer)
>>> ds = HuggingFaceDataset(dataset_name="c4", dataset_path=None,
data_processing_style="chemlactica_style", tokenizer=tokenizer)
>>> for batch in Dataloader(ds, batch_size=8):
print(f"Batch size: {len(batch)}")
Batch size: 8
@@ -96,7 +96,7 @@ def __init__(
world_size: int = 1,
rank: int = 0,
infinite: bool = False,
special_mode = None,
special_mode=None,
) -> None:
# allow user to pass in a (local or HF hub) path to use unsupported datasets
if dataset_name not in _supported_datasets:
@@ -123,13 +123,18 @@ def __init__(
ds = load_dataset(dataset_path, split="train")
else:
dataset_files = glob.glob(os.path.join(dataset_path, "*.jsonl"))
ds = load_dataset("text", data_files=dataset_files, split="train", streaming="valid" not in dataset_name)

ds = load_dataset(
"text",
data_files=dataset_files,
split="train",
streaming="valid" not in dataset_name,
)

# try:
data_processing_fn = _supported_data_processing_styles[data_processing_style]
# except KeyError as e:
# raise ValueError(f"Unsupported data processing style: {data_processing_style}")

# TODO: support shuffling and checkpointing
self.dataset_name = dataset_name
self._data = split_dataset_by_node(ds, rank, world_size)
@@ -159,35 +164,42 @@ def __iter__(self):

while True:
if self.special_mode == "yield_tensor":
logger.info("yielding tensor")
yield random_tensor, random_tensor
random_tensor = torch.randint(low=1, high=2, size=(self.seq_len,))
continue

for sample_json in self._get_data_iter():
sample_text = self.data_processing_fn(sample_json["text"], self.rng, self.representation_type)
if self.number_of_samples_to_log > 0:
logger.info(f"Sample: {sample_text}")
logger.info("yielding tensor")
self.number_of_samples_to_log -= 1
sample_tokens = self._tokenizer.encode(sample_text, bos=True, eos=True)
self._all_tokens.extend(sample_tokens)
self._sample_idx += 1

while len(self._all_tokens) >= max_buffer_token_len:
x = torch.LongTensor(self._all_tokens[:max_buffer_token_len])
# update tokens to the remaining tokens
self._all_tokens = self._all_tokens[max_buffer_token_len:]
input = x[:-1]
label = x[1:]
yield input, label

if not self.infinite:
logger.warning(f"Dataset {self.dataset_name} has run out of data")
break
random_tensor = torch.randint(
low=1, high=2, size=(max_buffer_token_len,)
)
yield random_tensor[:-1], random_tensor[1:]
else:
# Reset offset for the next iteration
self._sample_idx = 0
logger.warning(f"Dataset {self.dataset_name} is being re-looped")
for sample_json in self._get_data_iter():
sample_text = self.data_processing_fn(
sample_json["text"], self.rng, self.representation_type
)
if self.number_of_samples_to_log > 0:
logger.info(f"Sample: {sample_text}")
self.number_of_samples_to_log -= 1
sample_tokens = self._tokenizer.encode(
sample_text, bos=True, eos=True
)
self._all_tokens.extend(sample_tokens)
self._sample_idx += 1

while len(self._all_tokens) >= max_buffer_token_len:
x = torch.LongTensor(self._all_tokens[:max_buffer_token_len])
# update tokens to the remaining tokens
self._all_tokens = self._all_tokens[max_buffer_token_len:]
input = x[:-1]
label = x[1:]
yield input, label

if not self.infinite:
logger.warning(f"Dataset {self.dataset_name} has run out of data")
break
else:
# Reset offset for the next iteration
self._sample_idx = 0
logger.warning(f"Dataset {self.dataset_name} is being re-looped")

def _get_data_iter(self):
if self._sample_idx == 0:
Expand Down Expand Up @@ -218,7 +230,15 @@ class DPAwareDataLoader(StatefulDataLoader, Stateful):
"""
A wrapper around the StatefulDataLoader that ensures that the state is stored only once per DP rank.
"""
def __init__(self, dp_rank: int, hf_ds: IterableDataset, batch_size: int, pin_memory: bool, num_workers: int):

def __init__(
self,
dp_rank: int,
hf_ds: IterableDataset,
batch_size: int,
pin_memory: bool,
num_workers: int,
):
super().__init__(hf_ds, batch_size, num_workers=num_workers)
self._dp_rank = dp_rank
self._rank_id = f"dp_rank_{dp_rank}"
@@ -253,10 +273,25 @@ def build_hf_data_loader(
infinite: bool = True,
pin_memory: bool = False,
num_workers: int = 2,
special_mode = None,
special_mode=None,
):
hf_ds = HuggingFaceDataset(
dataset_name, dataset_path, data_processing_style, tokenizer, representation_type, seq_len, world_size, rank, infinite, special_mode
dataset_name,
dataset_path,
data_processing_style,
tokenizer,
representation_type,
seq_len,
world_size,
rank,
infinite,
special_mode,
)

return DPAwareDataLoader(rank, hf_ds, batch_size=batch_size, pin_memory=pin_memory, num_workers=num_workers)
return DPAwareDataLoader(
rank,
hf_ds,
batch_size=batch_size,
pin_memory=pin_memory,
num_workers=num_workers,
)
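The __iter__ refactor above keeps the same packing scheme: encoded samples are appended to a running token buffer, and every seq_len + 1 tokens are cut into an input/label pair shifted by one position. A self-contained sketch of just that buffering logic (the toy fake_encode below is a stand-in for tokenizer.encode(text, bos=True, eos=True)):

import torch

seq_len = 8
max_buffer_token_len = 1 + seq_len
all_tokens = []

def fake_encode(text):
    # Stand-in tokenizer: 1 = BOS, 2 = EOS, then one toy id per character.
    return [1] + [ord(c) % 100 + 3 for c in text] + [2]

for sample_text in ["CCO", "c1ccccc1", "CC(=O)O"]:
    all_tokens.extend(fake_encode(sample_text))
    while len(all_tokens) >= max_buffer_token_len:
        x = torch.LongTensor(all_tokens[:max_buffer_token_len])
        all_tokens = all_tokens[max_buffer_token_len:]  # keep the remainder buffered
        inp, label = x[:-1], x[1:]  # next-token targets: label is input shifted by one
        print(inp.shape, label.shape)  # torch.Size([8]) torch.Size([8])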