From fa7845a6bbee107832f84e7e5d4c98bd9fbec637 Mon Sep 17 00:00:00 2001 From: samsja <55492238+samsja@users.noreply.github.com> Date: Tue, 5 Nov 2024 13:15:01 +0800 Subject: [PATCH] add custom parquet and interleave datasets (#130) * add custom parquet and interleave datasets * update datalaoding * take multi process dataloader in consideration * remove log * remove lin * use parquet file direclty * use parquet file direclty * fix read tabl * fix tests * raise error * remove probabilites from state dict * allow to process math web * skip tests * should fix open web math * should fix open web math * remove unused data key * change default path * conversion script * mass conversion scripts --------- Co-authored-by: Jackmin801 --- Untitled.ipynb | 162 +++++++++++++++++++++++ configs/test.toml | 6 +- pyproject.toml | 5 +- scripts/convert_dl_ckpt.sh | 35 +++++ scripts/convert_dl_state.py | 139 ++++++++++++++++++++ src/zeroband/data.py | 247 +++++++++++++++++++++++++++++------- tests/test_data.py | 166 +++++++++++++++++++++++- uv.lock | 17 +++ 8 files changed, 721 insertions(+), 56 deletions(-) create mode 100644 Untitled.ipynb create mode 100755 scripts/convert_dl_ckpt.sh create mode 100755 scripts/convert_dl_state.py diff --git a/Untitled.ipynb b/Untitled.ipynb new file mode 100644 index 00000000..cdcc7c02 --- /dev/null +++ b/Untitled.ipynb @@ -0,0 +1,162 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "495e4800-4cbc-4975-a915-f740ece8eddc", + "metadata": {}, + "outputs": [], + "source": [ + "import pyarrow.parquet as pq\n", + "import os" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "8449146c-c171-4d5f-8b53-32684536e415", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "# Method 1: Read entire file into a Table\n", + "table = pq.read_table(\"data0/train-00000.parquet\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "273411f2-6714-4b20-95b3-9a876c096556", + "metadata": {}, + "outputs": [], + "source": [ + "ds = pq.ParquetDataset([os.path.join(\"data0\", f) for f in os.listdir(\"data0\")], memory_map=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "c4ef983d-86a2-4fff-bcad-7501a260a5b1", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "text: string\n", + "-- schema metadata --\n", + "huggingface: '{\"info\": {\"features\": {\"text\": {\"dtype\": \"string\", \"_type\":' + 12" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ds.schema" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "412f2c4a-feca-4e98-a9e0-0d65d06020a3", + "metadata": {}, + "outputs": [], + "source": [ + "table = ds.read()" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "56dd0338-a139-465e-b8be-7536ccba2978", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "400000" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(table)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "feaab25a-36c5-40db-a0f3-c2264fd024cd", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "400000" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "table.num_rows" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "79b36ada-8a63-4f0c-b532-bd232add7ff3", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'MIDI: The Musical Language of the Digital Age\\nMIDI (Musical Instrument Digital Interface) is a standard that allows electronic musical instruments, computers, and other devices to communicate with each other. It was developed in the early 1980s, and has since become the de facto standard for digital music production.\\nMIDI is not a recording format, like WAV or MP3. Instead, it is a way of sending instructions from one device to another. These instructions can include things like note pitches, velocity, and duration. MIDI data can be recorded and played back on a computer, or it can be sent directly to a synthesizer or other MIDI-compatible device.\\nMIDI has revolutionized the way music is made. It allows musicians to create complex arrangements and sounds without having to be skilled in traditional music notation. It also makes it possible to collaborate on music projects with musicians from all over the world.\\n6 benefits of using MIDI\\n- MIDI is a universal language. MIDI is supported by virtually all electronic musical instruments and computers, making it easy to connect different devices and create music.\\n- MIDI is versatile. MIDI can be used to control a wide variety of synthesizers, samplers, and other electronic instruments. This makes it possible to create a wide range of sounds and effects.\\n- MIDI is easy to use. MIDI is a simple protocol that can be learned quickly. This makes it a great option for beginners who are just starting to learn about music production.\\n- MIDI is efficient. MIDI data is small and compact, making it easy to share and store. This makes it a great option for collaborating with other musicians or sharing your music online.\\n- MIDI is future-proof. MIDI is a standard that has been around for decades and is still widely used today. This means that your MIDI files will be compatible with future music production software and hardware.\\n- MIDI is affordable. MIDI is a very affordable way to create and produce music. You can get started with MIDI for just a few dollars.\\nVirtual Instruments: The Power of Creation\\nVirtual instruments are software synthesizers that can be used to create a wide variety of sounds. They are typically used in conjunction with MIDI controllers to allow musicians to play and control them in real time.\\nVirtual instruments offer a number of advantages over traditional hardware synthesizers. They are typically more affordable, easier to use, and more versatile. They also allow musicians to create sounds that would be impossible to produce with a physical synthesizer.\\nFrom highly accessible options to a more virtuoso approach MIDI instruments are the future of music Production Integration with Music Software As discussed earlier MIDI controllers can seamlessly integrate with popular music Production software and all DAWs remember those With some tools connecting via bluetooth students are free to This guide will give you everything you need to know for making MIDI a powerful part of your processfrom basic MIDI connections to using MIDI effectively in your music Production workflow If youre already using MIDI Ill also cover some useful tips to help you get the most out of your current setupVSTis are separate to VSTfx in that they do no alter sounds they generate sounds Generally using MIDI input data to recognize and recreate melodies and musical tones\\nthese instruments often work both as the primary plugin on a DAW Digital Audio Workstation software track or as a standalone program on your computerStep 1 Install the software on your computer and connect any hardware controllers you plan to use Step 2 Open your DAW and navigate to the plugins section Locate and open the VIP plugin Step 3 Once the VIP plugin is open you will see a browser window that displays your plugin and virtual instrument libraryWhat is a VST The basics of Virtual Studio Technology amp how it integrates with your DAW The significance of VST plugins and their myriad types The beauty of VST instruments and the different versions available An indepth look at VST effects Pro tips to help you make the most out of your VST pluginsWhat is a DAW Now that you have your\\nBeyerdynamic M90 PRO X microphone youll need to send the signal to your DAW Because digital audio workstation DAW is about seven syllables too many most people just say DAWMIDI is a method of sending data that allows you to create music digitally It doesnt matter if you are a bedroom producer or a Grammy Award winner MIDI will play a big part in your working life In this article we will go into more detail about what MIDI is how it works and how to use itNo matter your skill level find your creative outlet with our free song maker and beatmaking app Enter our multitrack Studio an intuitive Digital Audio Workstation DAW to record edit and remix your music Easily record music on the go or build a beat with loops and samples from our royaltyfree sound packs\\n6 benefits of using virtual instruments\\n- Virtual instruments are affordable. Virtual instruments are typically much more affordable than hardware synthesizers. This makes them a great option for budget-minded musicians.\\n- Virtual instruments are easy to use. Virtual instruments are typically very easy to use, even for beginners. This makes them a great way to get started with music production.\\n- Virtual instruments are versatile. Virtual instruments can be used to create a wide variety of sounds. This makes them a great option for musicians of all genres.\\n- Virtual instruments are portable. Virtual instruments can be used on any computer with a sound card. This makes them a great option for musicians who need to be able to create music on the go.\\n- Virtual instruments are constantly evolving. New virtual instruments are being released all the time. This means that musicians always have access to the latest and greatest sounds.\\n- Virtual instruments are compatible with any DAW. Virtual instruments can be used with any digital audio workstation (DAW). This makes them a great option for musicians who already have a DAW that they like.\\nMIDI and virtual instruments have revolutionized the way music is made. They have made it possible for anyone to create music, regardless of their skill level or budget. They have also opened up new possibilities for collaboration and creativity. If you are interested in getting started with music production, MIDI and virtual instruments are a great place to start.'" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "table[\"text\"][4000]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "68723e22-4c0e-4592-bfe4-4e1d8e43610b", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/configs/test.toml b/configs/test.toml index 99b6c832..46abc536 100644 --- a/configs/test.toml +++ b/configs/test.toml @@ -7,9 +7,9 @@ micro_bs = 4 # change this base on the gpu [data] seq_length = 8192 -dataset_name_or_paths = "PrimeIntellect/fineweb-edu,PrimeIntellect/fineweb,PrimeIntellect/StackV1-popular,mlfoundations/dclm-baseline-1.0-parquet,open-web-math/open-web-math" -dataset_ratio = "55:10:20:10:5" -num_workers = 8 +dataset_name_or_paths = "/data/datasets/open-web-math" +dataset_ratio = "100" +num_workers = 1 [optim] batch_size = 128 diff --git a/pyproject.toml b/pyproject.toml index f7003637..e15087c5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,8 @@ dependencies = [ "torchdata>=0.8.0", "fsspec[gcs]>=2024.3.1", "ninja", - "zstandard" + "zstandard", + "pyarrow", ] [project.optional-dependencies] @@ -37,4 +38,4 @@ allow-direct-references = true # allow direct references to git repos in depende line-length = 120 [tool.uv] -dev-dependencies = ["ruff>=0.5.0", "pre-commit>=3.0.0","pytest>=7.0.0"] +dev-dependencies = ["ruff>=0.5.0", "pre-commit>=3.0.0","pytest>=7.0.0", "faker"] diff --git a/scripts/convert_dl_ckpt.sh b/scripts/convert_dl_ckpt.sh new file mode 100755 index 00000000..efc1defc --- /dev/null +++ b/scripts/convert_dl_ckpt.sh @@ -0,0 +1,35 @@ +#!/bin/bash +set -e + +# Wrapper script to run the Python command on 8 checkpoints in parallel +# Usage: ./convert_all.sh /data/10b/step_50800/diloco_0/data + +# Input path prefix +INPUT_PATH=$1 + +# Run the commands for each checkpoint in parallel +for i in {0..7}; do + CHECKPOINT_PATH="${INPUT_PATH}/_${i}.pt" + BACKUP_PATH="${INPUT_PATH}/_${i}_old.pt" + TMP_PATH="${INPUT_PATH}/_${i}_tmp.pt" + + if [ -f "$BACKUP_PATH" ]; then + echo "Checkpoint ${CHECKPOINT_PATH} has already been processed, skipping." & + else + ( + uv run python scripts/convert_dl_state.py @configs/10B/H100.toml \ + --input_path "$CHECKPOINT_PATH" \ + --output_path "$TMP_PATH" \ + --rank "$i" \ + --world_size 8 && \ + mv "$CHECKPOINT_PATH" "$BACKUP_PATH" && \ + mv "$TMP_PATH" "$CHECKPOINT_PATH" && \ + echo "Processed ${CHECKPOINT_PATH} and moved to ${BACKUP_PATH}" + ) & + fi +done + +# Wait for all background jobs to complete +wait + +echo "All checkpoints processed" diff --git a/scripts/convert_dl_state.py b/scripts/convert_dl_state.py new file mode 100755 index 00000000..3fe8d004 --- /dev/null +++ b/scripts/convert_dl_state.py @@ -0,0 +1,139 @@ +#!/usr/bin/env python +# coding: utf-8 +# Example Usage: +# python scripts/convert_dl_state.py @configs/10B/H100.toml --input_path /workspace/step_49200/diloco_0/data/_3.pt --output_path ./meow.pt --rank 3 --world_size 8 + +import torch +from zeroband.data import get_dataloader +from transformers import AutoTokenizer +from zeroband.train import Config +from zeroband.utils.logging import get_logger +from pydantic_config import parse_argv + +COMMON_KEYS = [ + "_snapshot._main_snapshot._sampler_iter_yielded", + "_snapshot._snapshot_step", + "_snapshot._main_snapshot._index_sampler_state.samples_yielded", + "_snapshot._main_snapshot._num_workers", + "_snapshot._main_snapshot._sampler_iter_state", + "_snapshot._main_snapshot._shared_seed", + "_snapshot._last_yielded_worker_id", + "_snapshot._main_snapshot._base_seed", +] + + +def traverse_dict(d: dict, key: str): + _k = key.split(".") + for k in _k: + d = d[k] + return d + + +def transfer_states(old_state_dict: dict, new_state_dict: dict): + for k in COMMON_KEYS: + parent, _, child = k.rpartition(".") + if parent: + traverse_dict(new_state_dict, parent)[child] = traverse_dict(old_state_dict, parent)[child] + for worker_id in range(4): + ex_iterables = [ + ds_state["ex_iterable"] + for ds_state in traverse_dict( + old_state_dict, f"_snapshot._worker_snapshots.worker_{worker_id}.dataset_state.ex_iterable.ex_iterables" + ) + ] + num_ds = len(ex_iterables) + new_ds_state = traverse_dict( + new_state_dict, f"_snapshot._worker_snapshots.worker_{worker_id}.dataset_state.dataset" + ) + # HACK: dataset_4 is openwebmath which is not always present + if "dataset_4" not in new_ds_state.keys(): + num_ds -= 1 + new_ds_state = [ + traverse_dict( + new_state_dict, f"_snapshot._worker_snapshots.worker_{worker_id}.dataset_state.dataset.dataset_{i}" + ) + for i in range(num_ds) + ] + + for new_state, old_state in zip(new_ds_state, ex_iterables): + # HACK: We might index error because of skipping into a different sized shard for dclm + new_state["file_index"] = (old_state["shard_idx"] + 1) % len(new_state["files"]) + new_state["row_index"] = 0 # old_state["shard_example_idx"] + + +class ExportConfig(Config): + input_path: str + output_path: str + rank: int + world_size: int + + +def main(config: ExportConfig): + old_state_dict = torch.load(config.input_path)["data_loader"] + + if config.type_model == "llama2": + tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=True) + elif config.type_model == "llama3": + tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B", use_fast=True) + else: + raise ValueError(f"Model type {config.type_model} not supported") + + dl = get_dataloader( + tokenizer=tokenizer, + world_size=config.world_size, + rank=config.rank, + batch_size=config.train.micro_bs, + data_config=config.data, + ) + + iter_dl = iter(dl) + + # Needed to init the states because they are lazy + while True: + try: + _ = next(iter_dl) + new_state_dict = dl.state_dict() + transfer_states(old_state_dict, new_state_dict) + break + except KeyError: + print("Not inited, sampling again") + pass + + print(f"Saving to {config.output_path}") + torch.save({"data_loader": new_state_dict}, config.output_path) + + del dl + + +def test_dl(config: ExportConfig): + if config.type_model == "llama2": + tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=True) + elif config.type_model == "llama3": + tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B", use_fast=True) + else: + raise ValueError(f"Model type {config.type_model} not supported") + + dl = get_dataloader( + tokenizer=tokenizer, + world_size=config.world_size, + rank=config.rank, + batch_size=config.train.micro_bs, + data_config=config.data, + ) + dl.load_state_dict(torch.load(config.output_path, weights_only=True)["data_loader"]) + + iter_dl = iter(dl) + + # Needed to init the states because they are lazy + for i in range(10): + batch = next(iter_dl) + print(batch.keys(), batch["input_ids"].shape) + + +if __name__ == "__main__": + logger = get_logger() + config = ExportConfig(**parse_argv()) + logger.debug(f"config: {config.model_dump()}") + + main(config) + test_dl(config) diff --git a/src/zeroband/data.py b/src/zeroband/data.py index b4689482..90794796 100644 --- a/src/zeroband/data.py +++ b/src/zeroband/data.py @@ -1,5 +1,8 @@ +from dataclasses import dataclass, asdict import random from typing import Any, Generator, Optional, List, Dict, TypedDict, Union +import functools + from pydantic_config import BaseConfig from zeroband.utils.logging import get_logger @@ -9,9 +12,10 @@ from torchdata.stateful_dataloader import StatefulDataLoader from torch.distributed.checkpoint.stateful import Stateful -from datasets import load_dataset, interleave_datasets, load_dataset_builder, BuilderConfig -from datasets.distributed import split_dataset_by_node -import functools +from datasets import load_dataset_builder, BuilderConfig +from pyarrow import parquet as pq +from transformers import PreTrainedTokenizer + TEST_VOCAB_SIZE = 1024 @@ -21,12 +25,11 @@ class DataConfig(BaseConfig): - dataset_name_or_paths: str = "allenai/c4:en" + dataset_name_or_paths: str = "/data/datasets/fineweb-edu" val_dataset_name_or_paths: Optional[str] = None seq_length: int = 1024 fake: bool = False num_workers: int = 4 - streaming: bool = True max_train_samples: Optional[int] = None max_eval_samples: Optional[int] = None dataset_ratio: Optional[str] = None @@ -62,6 +65,13 @@ class BatchOutput(TypedDict): seqlens: list[int] +@dataclass +class SequencePackingDataSetState: + inputs_ids: list[int] + labels: list[int] + seqlens: list[int] + + class SequencePackingDataSet(IterableDataset, Stateful): """ This class wrap a dataset and wrap it into an iterable that return sequence of max_seq_length @@ -73,11 +83,9 @@ def __init__(self, dataset: Dataset, max_seq_length: int, eos_token: int): self.max_seq_length = max_seq_length self.eos_token = eos_token - def __iter__(self) -> Generator[BatchOutput, Any, None]: - inputs_ids = [] - labels = [] - seqlens = [] + self.state = SequencePackingDataSetState(inputs_ids=[], labels=[], seqlens=[]) + def __iter__(self) -> Generator[BatchOutput, Any, None]: for og_sample in self.dataset: og_sample: list[int] = og_sample["input_ids"] @@ -85,32 +93,35 @@ def __iter__(self) -> Generator[BatchOutput, Any, None]: sample_inputs_ids = og_sample[:-1] sample_labels = og_sample[1:] - token_remaining = self.max_seq_length - len(inputs_ids) + token_remaining = self.max_seq_length - len(self.state.inputs_ids) if len(sample_inputs_ids) < token_remaining: - inputs_ids.extend(sample_inputs_ids) - labels.extend(sample_labels) - seqlens.append(len(sample_inputs_ids)) + self.state.inputs_ids.extend(sample_inputs_ids) + self.state.labels.extend(sample_labels) + self.state.seqlens.append(len(sample_inputs_ids)) else: - inputs_ids.extend(sample_inputs_ids[:token_remaining]) - labels.extend(sample_labels[:token_remaining]) - seqlens.append(token_remaining) - - yield { - "input_ids": torch.Tensor(inputs_ids).to(dtype=torch.long), - "labels": torch.Tensor(labels).to(dtype=torch.long), - "seqlens": seqlens, + self.state.inputs_ids.extend(sample_inputs_ids[:token_remaining]) + self.state.labels.extend(sample_labels[:token_remaining]) + self.state.seqlens.append(token_remaining) + + data = { + "input_ids": torch.Tensor(self.state.inputs_ids).to(dtype=torch.long), + "labels": torch.Tensor(self.state.labels).to(dtype=torch.long), + "seqlens": self.state.seqlens, } - inputs_ids = [] - labels = [] - seqlens = [] + self.state.inputs_ids = [] + self.state.labels = [] + self.state.seqlens = [] + + yield data def state_dict(self): - return self.dataset.state_dict() + return {"dataset": self.dataset.state_dict(), "state": asdict(self.state)} def load_state_dict(self, state_dict): - self.dataset.load_state_dict(state_dict) + self.dataset.load_state_dict(state_dict["dataset"]) + self.state = SequencePackingDataSetState(**state_dict["state"]) def collate_fn(samples: list[dict[str, torch.LongTensor]]) -> dict[str, torch.LongTensor]: @@ -133,6 +144,150 @@ def collate_fn(samples: list[dict[str, torch.LongTensor]]) -> dict[str, torch.Lo } +@dataclass +class PQDatasetState: + files: List[str] + file_index: int + row_index: int + increment: int + init_row_index: int + + +class ParquetDataset(IterableDataset, Stateful): + """ + this class is a wrapper around a parquet dataset compatible with datasets and statefull compatible. The dataset is infinite and will restart from the last state if the iterator is exhausted. + TODO: + * [ ] handle mutli proc dataloader pytorch + """ + + def __init__(self, files: List[str], tokenizer: PreTrainedTokenizer): + self.arg_files = files + self.tokenizer = tokenizer + + self.state = None + + def _lazy_init(self): + worker_info = torch.utils.data.get_worker_info() + if worker_info is not None: + if worker_info.num_workers > len(self.arg_files): + logger.warning( + f"dataloader rank {worker_info.id} Number of workers {worker_info.num_workers} is greater than the number of files {len(self.arg_files)}" + ) + self.state = PQDatasetState( + files=self.arg_files, + file_index=0, + row_index=worker_info.id, + increment=worker_info.num_workers, + init_row_index=worker_info.id, + ) + return + + files = self.arg_files[worker_info.id :: worker_info.num_workers] + else: + files = self.arg_files + + self.state = PQDatasetState(files=files, file_index=0, row_index=0, increment=1, init_row_index=0) + + def __iter__(self): + # we lazy init the parquet dataset to get the worker info from dataloader multi process + if self.state is None: + self._lazy_init() + + while True: + file = self.state.files[self.state.file_index] + + parquet_file = pq.ParquetFile(file) + table = parquet_file.read()["text"] + + while True: + row = table[self.state.row_index] + + self.state.row_index += self.state.increment + if self.state.row_index >= len(table): + self.state.row_index = self.state.init_row_index + self.state.file_index += 1 + if self.state.file_index >= len(self.state.files): # infinite datasets + self.state.file_index = 0 + + yield {"input_ids": self.tokenizer.encode(str(row))} + + @property + def is_empty(self): + return len(self.arg_files) == 0 + + def state_dict(self) -> dict[str, Any]: + return asdict(self.state) if self.state is not None else {} + + def load_state_dict(self, state_dict): + self.state = PQDatasetState(**state_dict) + + +@dataclass +class InterleaveDatasetState: + current_index: int + seed: int + + +class InterleaveDataset(IterableDataset, Stateful): + """This class take a list of datasets and interleave them. It is stateful and can be used with pytorch dataloader. + + It draw a sample from each dataset with a probability given by the probabilities list. + + The state can be saved and restored. Under the hood we just fast forward the random generator to the current position. + """ + + def __init__(self, datasets: List[ParquetDataset], probabilities: Optional[List[float]] = None, seed: int = 42): + assert len(datasets) > 0, "At least one dataset is required" + assert len(datasets) == len(probabilities), "The number of datasets and probabilities must be the same" + + self.probabilities = [] + self.datasets = [] + + for dataset, prob in zip(datasets, probabilities): + if not dataset.is_empty: + self.datasets.append(dataset) + self.probabilities.append(prob) + else: + logger.warning(f"Dataset {dataset} is empty. Skipping.") + + self.state = InterleaveDatasetState(current_index=0, seed=seed) + self._init_random_state() + + def _init_random_state(self): + """Initialize random generator and advance to current position""" + ... + self.random_generator = random.Random(self.state.seed) + # Advance the RNG to the current position + for _ in range(self.state.current_index): + self._get_dataset_to_yield_from() + + def _get_dataset_to_yield_from(self) -> int: + return self.random_generator.choices(range(len(self.datasets)), weights=self.probabilities, k=1)[0] + + def __iter__(self): + data_iters = [iter(dataset) for dataset in self.datasets] + while True: + dataset_to_yield_from = self._get_dataset_to_yield_from() + + sample = next(data_iters[dataset_to_yield_from]) + self.state.current_index += 1 + + yield sample + + def state_dict(self): + state = {"interleave_state": asdict(self.state)} + + for i, dataset in enumerate(self.datasets): + state[f"dataset_{i}"] = dataset.state_dict() + return state + + def load_state_dict(self, state_dict): + self.state = InterleaveDatasetState(**state_dict["interleave_state"]) + for i, dataset in enumerate(self.datasets): + dataset.load_state_dict(state_dict[f"dataset_{i}"]) + self._init_random_state() + + def get_dataloader( tokenizer, world_size: int, @@ -143,14 +298,9 @@ def get_dataloader( if data_config.fake: train_dataset = FakeTokenizedDataset(data_config.seq_length, TEST_VOCAB_SIZE) else: - ds = load_all_datasets(data_config=data_config, split="train") - - def tokenize_function(data): - outputs = tokenizer(data["text"], truncation=True, max_length=data_config.seq_length) - return outputs - - tokenized_datasets = ds.map(tokenize_function, batched=True, remove_columns=["text", "attention_mask"]) - train_dataset = split_dataset_by_node(tokenized_datasets, world_size=world_size, rank=rank) + train_dataset = load_all_datasets( + data_config=data_config, split="train", tokenizer=tokenizer, rank=rank, world_size=world_size + ) dataset = SequencePackingDataSet(train_dataset, data_config.seq_length, eos_token=tokenizer.eos_token_id) @@ -191,12 +341,13 @@ def _foo(a): def _load_datasets( dataset_names: str, split: str, + tokenizer: PreTrainedTokenizer, data_rank: Optional[int] = None, data_world_size: Optional[int] = None, streaming: bool = True, probabilities: Optional[List[float]] = None, reverse_data_files: bool = False, -) -> Dataset: +) -> InterleaveDataset: logger.debug(dataset_names) ds_args = [] for _ds in dataset_names.split(","): @@ -210,6 +361,7 @@ def _load_datasets( _ds_args["data_files"] = _data_files if data_rank is not None and data_world_size is not None: _ds_args["data_files"] = _data_files[data_rank::data_world_size] + ds_args.append(_ds_args) # logger.debug(f"Datasets ({split}):\n" + "\n".join(map(_nice_print, ds_args))) @@ -217,13 +369,12 @@ def _load_datasets( logger.debug(f"Loading datasets{' in streaming mode' if streaming else ''}") datasets = [] for ds_arg in ds_args: - # logger.debug(f"Loading dataset: {ds_arg}") - _ds = load_dataset(**ds_arg, split=split, streaming=streaming) - _ds = _ds.remove_columns([i for i in _ds.column_names if i not in ["text"]]) + logger.debug(f"Loading dataset: {ds_arg['data_files']}") + _ds = ParquetDataset(files=ds_arg["data_files"], tokenizer=tokenizer) datasets.append(_ds) - # logger.debug(f"Loaded dataset: {ds_arg}") - ds = interleave_datasets(datasets=datasets, probabilities=probabilities, stopping_strategy="all_exhausted") + ds = InterleaveDataset(datasets=datasets, probabilities=probabilities) + logger.info(f"Loaded datasets ({split})") return ds @@ -238,22 +389,20 @@ def _get_probabilities(data_config: DataConfig) -> Optional[List[float]]: return [i / denom for i in nums] -def load_all_datasets(data_config: DataConfig, split: str, max_samples: Optional[int] = None) -> IterableDataset: +def load_all_datasets( + data_config: DataConfig, split: str, tokenizer: PreTrainedTokenizer, rank: int, world_size: int +) -> InterleaveDataset: """Load all datasets and interleave them""" - if max_samples is not None and not data_config.streaming: - split = f"{split}[:{max_samples}]" ds = _load_datasets( dataset_names=data_config.dataset_name_or_paths, split=split, - data_rank=data_config.data_rank, - data_world_size=data_config.data_world_size, - streaming=data_config.streaming, + data_rank=rank, + data_world_size=world_size, probabilities=_get_probabilities(data_config), reverse_data_files=data_config.reverse_data_files, + tokenizer=tokenizer, ) - if max_samples is not None and data_config.streaming: - if data_config.max_train_samples is not None: - ds = ds.take(data_config.max_train_samples) + logger.info(f"Train dataset:\n{ds}") return ds diff --git a/tests/test_data.py b/tests/test_data.py index 162974aa..85a3aa68 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -1,13 +1,21 @@ +import copy import torch -from zeroband.data import SequencePackingDataSet +from zeroband.data import InterleaveDataset, ParquetDataset, SequencePackingDataSet, collate_fn from torch.utils.data import DataLoader from zeroband.data import load_all_datasets, DataConfig, logger as data_logger from collections import Counter from itertools import chain import pytest import logging +import pyarrow as pa +import pyarrow.parquet as pq +from faker import Faker +from typing import List +import string +from torchdata.stateful_dataloader import StatefulDataLoader +@pytest.mark.skip(reason="not using hf for now") @pytest.mark.parametrize( "ratio, lower, upper", [ @@ -37,6 +45,7 @@ def test_load_all_datasets_vanilla(ratio: str, lower: float, upper: float): assert letter_count["A"] / letter_count["C"] > lower +@pytest.mark.skip(reason="not using hf for now") @pytest.mark.parametrize( "ratio, lower, upper, data_rank, data_world_size", [ @@ -89,7 +98,7 @@ def __len__(self): return len(self.data) def __getitem__(self, index): - return {'input_ids': self.data[index]} + return {"input_ids": self.data[index]} MAX_SEQ_LEN = 8 dataset = SequencePackingDataSet(FakeDataset(), max_seq_length=MAX_SEQ_LEN, eos_token=0) @@ -106,3 +115,156 @@ def __getitem__(self, index): assert input_ids == [[6, 1, 2, 3, 4, 6, 3, 3], [3, 2, 1, 2, 1, 4, 5, 3]] assert labels == [[1, 2, 3, 4, 0, 3, 3, 4], [2, 0, 2, 0, 4, 5, 3, 4]] + + +class SimpleTokenizer: + def __init__(self): + # Create vocabulary: a-z (0-25) and unknown token (26) + self.char_to_id = {char: idx for idx, char in enumerate(string.ascii_lowercase)} + self.unknown_token = 26 + + def encode(self, text: str) -> List[int]: + """Convert text to list of token ids""" + return [self.char_to_id.get(char.lower(), self.unknown_token) for char in text] + + +@pytest.fixture +def fake_sentences(): + """Generate 500 fake sentences (100 per file * 5 files)""" + fake = Faker() + return [fake.sentence() for _ in range(10_000)] + + +@pytest.fixture +def parquet_files(tmp_path, fake_sentences): + """Create 10 parquet files with 100 sentences each""" + files = [] + for i in range(10): + # Create data for this file + start_idx = i * 100 + sentences = fake_sentences[start_idx : start_idx + 100] + + # Create arrow table + table = pa.Table.from_arrays([pa.array(sentences)], names=["text"]) + + # Write to parquet file + file_path = tmp_path / f"data_{i}.parquet" + pq.write_table(table, file_path) + files.append(str(file_path)) + + return files + + +@pytest.fixture +def tokenizer(): + """Get a simple character-based tokenizer""" + return SimpleTokenizer() + + +def test_parquet_dataset_ckpt(parquet_files, tokenizer): + # Create first dataset and iterate halfway + dataset1 = ParquetDataset(parquet_files, tokenizer) + halfway_point = 100 + + for _, data in zip(range(halfway_point), dataset1): + pass + # Save state + state_dict = dataset1.state_dict() + + # Create new dataset and load state + dataset2 = ParquetDataset(parquet_files, tokenizer) + dataset2.load_state_dict(state_dict) + + max_to_yield = 200 + # Continue first dataset + + for _, data1, data2 in zip(range(max_to_yield), dataset1, dataset2): + assert data1["input_ids"] == data2["input_ids"] + + +def test_sequence_packing_dataset_ckpt(parquet_files, tokenizer): + dataset1 = SequencePackingDataSet(ParquetDataset(parquet_files, tokenizer), max_seq_length=16, eos_token=0) + + halfway_point = 100 + + for _, data in zip(range(halfway_point), dataset1): + pass + # Save state + state_dict = dataset1.state_dict() + + # Create new dataset and load state + dataset2 = SequencePackingDataSet(ParquetDataset(parquet_files, tokenizer), max_seq_length=16, eos_token=0) + dataset2.load_state_dict(state_dict) + + assert dataset1.state_dict() == dataset2.state_dict() + + max_to_yield = 199 + # Continue first dataset + + for _, data1, data2 in zip(range(max_to_yield), dataset1, dataset2): + assert (data1["input_ids"] == data2["input_ids"]).all() + assert (data1["labels"] == data2["labels"]).all() + assert data1["seqlens"] == data2["seqlens"] + + +def test_interleave_dataset_ckpt(parquet_files, tokenizer): + # Split parquet files into two groups to create two datasets + files1 = parquet_files[:2] # First two files + files2 = parquet_files[2:4] # Next two files + + # Create first dataset and iterate halfway + dataset1 = InterleaveDataset( + [ParquetDataset(files1, tokenizer), ParquetDataset(files2, tokenizer)], probabilities=[0.5, 0.5] + ) + + halfway_point = 100 + + for _, data in zip(range(halfway_point), dataset1): + pass + # Save state + state_dict = dataset1.state_dict() + + # Create new dataset and load state + dataset2 = InterleaveDataset( + [ParquetDataset(files1, tokenizer), ParquetDataset(files2, tokenizer)], probabilities=[0.5, 0.5] + ) + dataset2.load_state_dict(state_dict=copy.deepcopy(state_dict)) + + assert dataset1.state_dict() == dataset2.state_dict() + + max_to_yield = 250 + + for _, data1, data2 in zip(range(max_to_yield), dataset1, dataset2): + assert data1["input_ids"] == data2["input_ids"] + + +@pytest.mark.parametrize("num_workers", [0, 2, 16]) +def test_dataloader_parquet_dataset(parquet_files, tokenizer, num_workers): + dataset = SequencePackingDataSet(ParquetDataset(parquet_files, tokenizer), max_seq_length=8, eos_token=0) + + loader = StatefulDataLoader(dataset, batch_size=8, num_workers=num_workers, collate_fn=collate_fn) + + total_samples = 100 + + for _, _batch in zip(range(total_samples), loader): + ... + + # Save state + state_dict = loader.state_dict() + + # Create new loader and load state + dataset2 = SequencePackingDataSet(ParquetDataset(parquet_files, tokenizer), max_seq_length=8, eos_token=0) + + loader2 = StatefulDataLoader(dataset2, batch_size=8, num_workers=num_workers, collate_fn=collate_fn) + + print(state_dict) + + loader2.load_state_dict(state_dict) + + warmup = 10 + + for i, batch1, batch2 in zip(range(total_samples), loader, loader2): + if i > warmup: + assert (batch1["input_ids"] == batch2["input_ids"]).all() + assert (batch1["labels"] == batch2["labels"]).all() + assert (batch1["seqlens"] == batch2["seqlens"]).all() diff --git a/uv.lock b/uv.lock index 08f66c2e..0d24dd02 100644 --- a/uv.lock +++ b/uv.lock @@ -390,6 +390,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/02/cc/b7e31358aac6ed1ef2bb790a9746ac2c69bcb3c8588b41616914eb106eaf/exceptiongroup-1.2.2-py3-none-any.whl", hash = "sha256:3111b9d131c238bec2f8f516e123e14ba243563fb135d3fe885990585aa7795b", size = 16453 }, ] +[[package]] +name = "faker" +version = "30.8.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "python-dateutil" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/86/45/528a560000078f33166300acc8b60e17129ce540962b573e8a28aa8bf4d9/faker-30.8.0.tar.gz", hash = "sha256:3608c7fcac2acde0eaa6da28dae97628f18f14d54eaa2a92b96ae006f1621bd7", size = 1808343 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4f/ca/f7a812229555391b0630982c0522b643cc530e7a7feabef8f8caec10ba9a/Faker-30.8.0-py3-none-any.whl", hash = "sha256:4cd0c5ea4bc1e4c902967f6e662f5f5da69f1674d9a94f54e516d27f3c2a6a16", size = 1846806 }, +] + [[package]] name = "filelock" version = "3.16.0" @@ -2132,6 +2145,7 @@ dependencies = [ { name = "fsspec", extra = ["gcs"] }, { name = "ninja" }, { name = "numpy" }, + { name = "pyarrow" }, { name = "pydantic-config" }, { name = "setuptools" }, { name = "torch" }, @@ -2150,6 +2164,7 @@ all = [ [package.dev-dependencies] dev = [ + { name = "faker" }, { name = "pre-commit" }, { name = "pytest" }, { name = "ruff" }, @@ -2164,6 +2179,7 @@ requires-dist = [ { name = "fsspec", extras = ["gcs"], specifier = ">=2024.3.1" }, { name = "ninja" }, { name = "numpy" }, + { name = "pyarrow" }, { name = "pydantic-config", git = "https://github.com/samsja/pydantic_config.git?rev=e529c9c" }, { name = "requests", marker = "extra == 'all'", specifier = ">=2.32.3" }, { name = "setuptools" }, @@ -2176,6 +2192,7 @@ requires-dist = [ [package.metadata.requires-dev] dev = [ + { name = "faker" }, { name = "pre-commit", specifier = ">=3.0.0" }, { name = "pytest", specifier = ">=7.0.0" }, { name = "ruff", specifier = ">=0.5.0" },