From 0b2227f552d76cf359e1364876167f46a021f293 Mon Sep 17 00:00:00 2001
From: Ying Chen
Date: Tue, 7 Jan 2025 11:44:40 -0800
Subject: [PATCH] Use registry when creating Stream in StreamingDataset (#858)

Co-authored-by: Saaketh Narayan
---
 CONTRIBUTING.md                  |   2 +-
 .../mixing_data_sources.md       |  25 ++
 setup.py                         |   1 +
 streaming/base/dataset.py        |  38 ++-
 streaming/base/registry_utils.py | 199 ++++++++++++++
 streaming/base/stream.py         |  11 +
 tests/test_registry.py           | 245 ++++++++++++++++++
 tests/test_stream.py             |  30 +++
 tests/test_streaming.py          |  59 ++++-
 9 files changed, 599 insertions(+), 11 deletions(-)
 create mode 100644 streaming/base/registry_utils.py
 create mode 100644 tests/test_registry.py

diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 3ed168953..2b60b79ce 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -79,8 +79,8 @@ cd docs && make clean && make doctest # run doctests
 ```bash
-cd docs
 pip install -e '.[docs]'
+cd docs
 make clean && make html
 make host # open the output link in a browser.
 ```
diff --git a/docs/source/dataset_configuration/mixing_data_sources.md b/docs/source/dataset_configuration/mixing_data_sources.md
index 2949fc31c..c8a073a44 100644
--- a/docs/source/dataset_configuration/mixing_data_sources.md
+++ b/docs/source/dataset_configuration/mixing_data_sources.md
@@ -8,6 +8,31 @@ A stream is a data source, as a collection of shard files (or set of subdirectories containing shard files)
 It is possible, though not recommended, for streams to have different schemas.
+### Registering a custom Stream implementation
+You can also customize the implementation of a `Stream`. To modify the behavior of a `Stream` used in a `StreamingDataset`, subclass `Stream` and register the subclass as shown in the example below, without forking the library.
+
+```python
+from streaming import Stream, StreamingDataset
+from streaming.base.stream import streams_registry
+
+class MyStream(Stream):
+    # your implementation goes here
+    pass
+
+# Register your custom stream class as 'my_stream'
+streams_registry.register('my_stream', func=MyStream)
+
+# StreamingDataset creates a MyStream object when 'my_stream' is passed as a stream_name
+dataset = StreamingDataset(
+    remote='s3://some/path',
+    local='/local/path',
+    stream_name='my_stream',
+    stream_config={'arg1': 'val1'},
+)
+```
+
+See more ways to register custom Stream classes in [this README section of LLM Foundry](https://github.com/mosaicml/llm-foundry/tree/3269c7399add8ca30842edbeb83d0c82f7906726?tab=readme-ov-file#how-to-register).
+
 ## Configuring the data mix
 The `proportion`, `repeat`, or `choose` arguments to `Stream` are used to configure different dataset mixing schemes. Only one of them may be set at a time, and all streams must use the same mixing scheme (e.g., Stream A with `proportion` and Stream B with `choose` are incompatible).
 - **`proportion`**: Specifies how to sample this Stream relative to other Streams.
diff --git a/setup.py b/setup.py index d7a3aa727..7714b8a9c 100644 --- a/setup.py +++ b/setup.py @@ -59,6 +59,7 @@ 'azure-storage-blob>=12.0.0,<13', 'azure-storage-file-datalake>=12.11.0,<13', 'azure-identity>=1.13.0', + 'catalogue>=2,<3', ] extra_deps = {} diff --git a/streaming/base/dataset.py b/streaming/base/dataset.py index cb5c32ba9..d463d6f0e 100644 --- a/streaming/base/dataset.py +++ b/streaming/base/dataset.py @@ -30,11 +30,12 @@ SHARD_ACCESS_TIMES, SHARD_STATES, TICK) from streaming.base.distributed import maybe_init_dist from streaming.base.format import get_index_basename +from streaming.base.registry_utils import construct_from_registry from streaming.base.sampling import get_sampling from streaming.base.shared import (SharedArray, SharedBarrier, SharedMemory, SharedScalar, _get_path, get_shm_prefix) from streaming.base.spanner import Spanner -from streaming.base.stream import Stream +from streaming.base.stream import Stream, streams_registry from streaming.base.util import bytes_to_int, number_abbrev_to_int from streaming.base.world import World @@ -308,6 +309,9 @@ class StreamingDataset(Array, IterableDataset): replication (int, optional): Determines how many consecutive devices will receive the same samples. Useful for training with tensor or sequence parallelism, where multiple devices need to see the same partition of the dataset. Defaults to ``None``. + stream_name (str): The name of the Stream to use which is registered in streams_registry. + Defaults to ``stream``. + stream_config (dict[str, Any]): Additional arguments to pass to the Stream constructor. """ def __init__(self, @@ -334,7 +338,9 @@ def __init__(self, shuffle_block_size: Optional[int] = None, batching_method: str = 'random', allow_unsafe_types: bool = False, - replication: Optional[int] = None) -> None: + replication: Optional[int] = None, + stream_name: str = 'stream', + stream_config: Optional[dict[str, Any]] = None) -> None: # Global arguments (which do not live in Streams). 
self.predownload = predownload self.cache_limit = cache_limit @@ -438,13 +444,27 @@ def __init__(self, for stream in streams: stream.apply_default(default) else: - default = Stream(remote=remote, - local=local, - split=split, - download_retry=download_retry, - download_timeout=download_timeout, - validate_hash=validate_hash, - keep_zip=keep_zip) + stream_config = stream_config or {} + stream_config.update({ + 'remote': remote, + 'local': local, + 'split': split, + 'download_retry': download_retry, + 'download_timeout': download_timeout, + 'validate_hash': validate_hash, + 'keep_zip': keep_zip, + }) + + # Construct a Stream instance using registry-based construction + default = construct_from_registry( + name=stream_name, + registry=streams_registry, + partial_function=False, + pre_validation_function=None, + post_validation_function=None, + kwargs=stream_config, + ) + streams = [default] # Validate the stream weighting scheme (relative or absolute) to catch errors before we go diff --git a/streaming/base/registry_utils.py b/streaming/base/registry_utils.py new file mode 100644 index 000000000..422727a2a --- /dev/null +++ b/streaming/base/registry_utils.py @@ -0,0 +1,199 @@ +# Copyright 2024 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""Wrapper of catalogue.Registry, copied from llm-foundry.""" + +import copy +import functools +import importlib.util +import os +from contextlib import contextmanager +from pathlib import Path +from types import ModuleType +from typing import Any, Callable, Generic, Optional, Sequence, TypeVar, Union + +import catalogue + +__all__ = [ + 'TypedRegistry', + 'create_registry', + 'construct_from_registry', + 'import_file', + 'save_registry', +] + +T = TypeVar('T') +TypeBoundT = TypeVar('TypeBoundT', bound=type) +CallableBoundT = TypeVar('CallableBoundT', bound=Callable[..., Any]) + + +class TypedRegistry(catalogue.Registry, Generic[T]): + """A thin wrapper around catalogue.Registry to add static typing and descriptions.""" + + def __init__( + self, + namespace: Sequence[str], + entry_points: bool = False, + description: str = '', + ) -> None: + super().__init__(namespace, entry_points=entry_points) + + self.description = description + + def __call__(self, name: str, func: Optional[T] = None) -> Callable[[T], T]: + return super().__call__(name, func) + + def register(self, name: str, *, func: Optional[T] = None) -> T: + return super().register(name, func=func) + + def register_class( + self, + name: str, + *, + func: Optional[TypeBoundT] = None, + ) -> TypeBoundT: + return super().register(name, func=func) + + def get(self, name: str) -> T: + return super().get(name) + + def get_all(self) -> dict[str, T]: + return super().get_all() + + def get_entry_point(self, name: str, default: Optional[T] = None) -> T: + return super().get_entry_point(name, default=default) + + def get_entry_points(self) -> dict[str, T]: + return super().get_entry_points() + + +S = TypeVar('S') + + +def create_registry( + *namespace: str, + generic_type: type[S], + entry_points: bool = False, + description: str = '', +) -> 'TypedRegistry[S]': + """Create a new registry. + + Args: + namespace (str): The namespace, e.g. "streaming.streams_registry" + generic_type (Type[S]): The type of the registry. + entry_points (bool): Accept registered functions from entry points. + description (str): A description of the registry. + + Returns: + The TypedRegistry object. 
+ """ + if catalogue.check_exists(*namespace): + raise catalogue.RegistryError(f'Namespace already exists: {namespace}') + + return TypedRegistry[generic_type]( + namespace, + entry_points=entry_points, + description=description, + ) + + +def construct_from_registry( + name: str, + registry: TypedRegistry, + partial_function: bool = True, + pre_validation_function: Optional[Union[Callable[[Any], None], type]] = None, + post_validation_function: Optional[Callable[[Any], None]] = None, + kwargs: Optional[dict[str, Any]] = None, +) -> Any: + """Helper function to build an item from the registry. + + Args: + name (str): The name of the registered item + registry (catalogue.Registry): The registry to fetch the item from + partial_function (bool, optional): Whether to return a partial function for registered callables. Defaults to True. + pre_validation_function (Optional[Union[Callable[[Any], None], type]], optional): An optional validation function called + before constructing the item to return. This should throw an exception if validation fails. Defaults to None. + post_validation_function (Optional[Callable[[Any], None]], optional): An optional validation function called after + constructing the item to return. This should throw an exception if validation fails. Defaults to None. + kwargs (Optional[Dict[str, Any]]): Other relevant keyword arguments. + + Raises: + ValueError: If the validation functions failed or the registered item is invalid + + Returns: + Any: The constructed item from the registry + """ + if kwargs is None: + kwargs = {} + + registered_constructor = registry.get(name) + + if pre_validation_function is not None: + if isinstance(pre_validation_function, type): + if not issubclass(registered_constructor, pre_validation_function): + raise ValueError( + f'Expected {name} to be of type {pre_validation_function}, but got {type(registered_constructor)}', + ) + elif isinstance(pre_validation_function, Callable): + pre_validation_function(registered_constructor) + else: + raise ValueError( + f'Expected pre_validation_function to be a callable or a type, but got {type(pre_validation_function)}', + ) + + # If it is a class, or a builder function, construct the class with kwargs + # If it is a function, create a partial with kwargs + if isinstance( + registered_constructor, + type, + ) or callable(registered_constructor) and not partial_function: + constructed_item = registered_constructor(**kwargs) + elif callable(registered_constructor): + constructed_item = functools.partial(registered_constructor, **kwargs) + else: + raise ValueError( + f'Expected {name} to be a class or function, but got {type(registered_constructor)}',) + + if post_validation_function is not None: + post_validation_function(constructed_item) + + return constructed_item + + +def import_file(loc: Union[str, Path]) -> ModuleType: + """Import module from a file. + + Used to run arbitrary python code. + + Args: + name (str): Name of module to load. + loc (str / Path): Path to the file. + + Returns: + ModuleType: The module object. 
+ """ + if not os.path.exists(loc): + raise FileNotFoundError(f'File {loc} does not exist.') + + spec = importlib.util.spec_from_file_location('python_code', str(loc)) + + assert spec is not None + assert spec.loader is not None + + module = importlib.util.module_from_spec(spec) + + try: + spec.loader.exec_module(module) + except Exception as e: + raise RuntimeError(f'Error executing {loc}') from e + return module + + +@contextmanager +def save_registry(): + """Save the registry state and restore after the context manager exits.""" + saved_registry_state = copy.deepcopy(catalogue.REGISTRY) + + yield + + catalogue.REGISTRY = saved_registry_state diff --git a/streaming/base/stream.py b/streaming/base/stream.py index 133938e12..357df71f2 100644 --- a/streaming/base/stream.py +++ b/streaming/base/stream.py @@ -18,6 +18,7 @@ from streaming.base.distributed import barrier, get_local_rank from streaming.base.format import FileInfo, Reader, get_index_basename, reader_from_json from streaming.base.hashing import get_hash +from streaming.base.registry_utils import create_registry from streaming.base.storage import CloudDownloader from streaming.base.util import retry, wait_for_file_to_exist from streaming.base.world import World @@ -507,3 +508,13 @@ def get_index_size(self) -> int: """ filename = os.path.join(self.local, self.split, get_index_basename()) return os.stat(filename).st_size + + +streams_registry = create_registry( + 'streaming', + 'streams_registry', + generic_type=type[Stream], + entry_points=True, + description='The streams registry is used for registering Stream classes.') + +streams_registry.register('stream', func=Stream) diff --git a/tests/test_registry.py b/tests/test_registry.py new file mode 100644 index 000000000..556fca654 --- /dev/null +++ b/tests/test_registry.py @@ -0,0 +1,245 @@ +# Copyright 2024 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +import importlib.metadata +import pathlib +from importlib.metadata import EntryPoint +from typing import Any, Callable, Union + +import catalogue +import pytest + +from streaming.base import registry_utils +from streaming.base.stream import Stream, streams_registry + + +def test_streams_registry_setup(): + assert isinstance(streams_registry, registry_utils.TypedRegistry) + assert streams_registry.namespace == ('streaming', 'streams_registry') + + stream = streams_registry.get('stream') + assert stream == Stream + + +# The tests below are adapted with minimal changes from llm-foundry +# to guarantee registry_utils works as expected + + +def test_registry_create(monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(catalogue, 'Registry', {}) + + new_registry = registry_utils.create_registry( + 'streaming', + 'test_registry', + generic_type=str, + entry_points=False, + ) + + assert new_registry.namespace == ('streaming', 'test_registry') + assert isinstance(new_registry, registry_utils.TypedRegistry) + + +def test_registry_typing(monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(catalogue, 'Registry', {}) + new_registry = registry_utils.create_registry( + 'streaming', + 'test_registry', + generic_type=str, + entry_points=False, + ) + new_registry.register('test_name', func='test') + + # This would fail type checking without the type ignore + # It is here to show that the TypedRegistry is working (gives a type error without the ignore), + # although this would not catch a regression in this regard + new_registry.register('test_name', func=1) # type: ignore + + +def test_registry_add(monkeypatch: 
pytest.MonkeyPatch): + monkeypatch.setattr(catalogue, 'Registry', {}) + new_registry = registry_utils.create_registry( + 'streaming', + 'test_registry', + generic_type=str, + entry_points=False, + ) + new_registry.register('test_name', func='test') + + assert new_registry.get('test_name') == 'test' + + +def test_registry_overwrite(monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(catalogue, 'Registry', {}) + new_registry = registry_utils.create_registry( + 'streaming', + 'test_registry', + generic_type=str, + entry_points=False, + ) + new_registry.register('test_name', func='test') + new_registry.register('test_name', func='test2') + + assert new_registry.get('test_name') == 'test2' + + +def test_registry_init_code(tmp_path: pathlib.Path): + register_code = """ +from streaming.base.stream import Stream, streams_registry + +@streams_registry.register('test_stream') +class TestStream(Stream): + pass +""" + + with open(tmp_path / 'init_code.py', 'w') as _f: + _f.write(register_code) + + registry_utils.import_file(tmp_path / 'init_code.py') + + assert issubclass(streams_registry.get('test_stream'), Stream) + + del catalogue.REGISTRY[('streaming', 'streams_registry', 'test_stream')] + + assert 'test_stream' not in streams_registry + + +def test_registry_entrypoint(monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(catalogue, 'Registry', {}) + + monkeypatch.setattr( + importlib.metadata, + 'entry_points', + lambda: { + 'streaming_test_registry': [ + EntryPoint( + name='test_entry', + value='streaming.base.stream:Stream', + group='streaming_test_registry', + ), + ], + }, + ) + + monkeypatch.setattr( + catalogue, + 'AVAILABLE_ENTRY_POINTS', + importlib.metadata.entry_points(), + ) + new_registry = registry_utils.create_registry( + 'streaming', + 'test_registry', + generic_type=str, + entry_points=True, + ) + assert new_registry.get('test_entry') == Stream + + +def test_registry_builder(monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(catalogue, 'Registry', {}) + + new_registry = registry_utils.create_registry( + 'streaming', + 'test_registry', + entry_points=False, + generic_type=Union[type[Stream], Callable[..., Stream]], + ) + + class TestStream(Stream): + + def __init__(self): + pass + + new_registry.register('test_stream', func=TestStream) + + # Valid, no validation + valid_class = registry_utils.construct_from_registry( + 'test_stream', + new_registry, + pre_validation_function=TestStream, + ) + assert isinstance(valid_class, TestStream) + + class NotStream: + pass + + # Invalid, class validation + with pytest.raises( + ValueError, + match='Expected test_stream to be of type', + ): + registry_utils.construct_from_registry( + 'test_stream', + new_registry, + pre_validation_function=NotStream, + ) + + # Invalid, function pre-validation + with pytest.raises(ValueError, match='Invalid'): + + def pre_validation_function(x: Any): + raise ValueError('Invalid') + + registry_utils.construct_from_registry( + 'test_stream', + new_registry, + pre_validation_function=pre_validation_function, + ) + + # Invalid, function post-validation + with pytest.raises(ValueError, match='Invalid'): + + def post_validation_function(x: Any): + raise ValueError('Invalid') + + registry_utils.construct_from_registry( + 'test_stream', + new_registry, + post_validation_function=post_validation_function, + ) + + # Invalid, not a class or function + new_registry.register('non_callable', func=1) # type: ignore + with pytest.raises( + ValueError, + match='Expected non_callable to be a class or function', + ): + 
registry_utils.construct_from_registry('non_callable', new_registry) + + # Valid, partial function + new_registry.register( + 'partial_func', + func=lambda x, y: x * y, + ) # type: ignore + partial_func = registry_utils.construct_from_registry( + 'partial_func', + new_registry, + partial_function=True, + kwargs={'x': 2}, + ) + assert partial_func(y=3) == 6 + + # Valid, builder function + new_registry.register('builder_func', func=lambda: TestStream()) + valid_built_class = registry_utils.construct_from_registry( + 'builder_func', + new_registry, + partial_function=False, + ) + assert isinstance(valid_built_class, TestStream) + + +def test_registry_init_code_fails(tmp_path: pathlib.Path): + register_code = """ +asdf +""" + + with open(tmp_path / 'init_code.py', 'w') as _f: + _f.write(register_code) + + with pytest.raises(RuntimeError, match='Error executing .*init_code.py'): + registry_utils.import_file(tmp_path / 'init_code.py') + + +def test_registry_init_code_dne(tmp_path: pathlib.Path): + with pytest.raises(FileNotFoundError, match='File .* does not exist'): + registry_utils.import_file(tmp_path / 'init_code.py') diff --git a/tests/test_stream.py b/tests/test_stream.py index cd7a11784..f21bb3c3e 100644 --- a/tests/test_stream.py +++ b/tests/test_stream.py @@ -12,6 +12,8 @@ from streaming import Stream, StreamingDataset from streaming.base.distributed import barrier +from streaming.base.registry_utils import construct_from_registry +from streaming.base.stream import streams_registry from tests.common.utils import convert_to_mds @@ -69,3 +71,31 @@ def test_missing_index_json_local(local_remote_dir: Any): stream = Stream(remote=None, local=remote_dir) with pytest.raises(RuntimeError, match='No `remote` provided, but local file.*'): _ = StreamingDataset(streams=[stream], batch_size=1) + + +@pytest.mark.parametrize('remote, local', [('remote_dir', tempfile.mkdtemp()), + ('remote_dir', None), (None, tempfile.mkdtemp())]) +def test_construct_stream_from_registry(remote: Any, local: Any): + kwargs = { + 'remote': remote, + 'local': local, + } + + if local is None: + remote_hash = hashlib.blake2s(remote.encode('utf-8'), digest_size=16).hexdigest() + local = os.path.join(tempfile.gettempdir(), remote_hash) + '/' + shutil.rmtree(local, ignore_errors=True) + barrier() + + stream_instance = construct_from_registry( + 'stream', + streams_registry, + partial_function=False, + kwargs=kwargs, + ) + + assert isinstance(stream_instance, Stream) + assert remote == stream_instance.remote + assert local == stream_instance.local + + shutil.rmtree(local, ignore_errors=True) diff --git a/tests/test_streaming.py b/tests/test_streaming.py index cd113c6e8..1a2d45f16 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -5,13 +5,14 @@ import os import shutil from multiprocessing import Process -from typing import Any +from typing import Any, Optional import pytest from torch.utils.data import DataLoader from streaming.base import Stream, StreamingDataLoader, StreamingDataset from streaming.base.batching import generate_work +from streaming.base.stream import streams_registry from streaming.base.util import clean_stale_shared_memory from streaming.base.world import World from tests.common.utils import convert_to_mds @@ -1053,3 +1054,59 @@ def test_same_local_diff_remote(local_remote_dir: tuple[str, str]): # Build StreamingDataset with pytest.raises(ValueError, match='Reused local directory.*vs.*Provide a different one.'): _ = StreamingDataset(local=local_0, remote=remote_1, batch_size=2, 
num_canonical_nodes=1) + + +@pytest.mark.usefixtures('local_remote_dir') +def test_custom_stream_name_and_kwargs(local_remote_dir: tuple[str, str]): + remote_dir, local_dir = local_remote_dir + convert_to_mds(out_root=remote_dir, + dataset_name='sequencedataset', + num_samples=117, + size_limit=1 << 8) + + class CustomStream(Stream): + + def __init__( + self, + *, + remote: Optional[str] = None, + local: Optional[str] = None, + split: Optional[str] = None, + proportion: Optional[float] = None, + repeat: Optional[float] = None, + choose: Optional[int] = None, + download_retry: Optional[int] = None, + download_timeout: Optional[float] = None, + validate_hash: Optional[str] = None, + keep_zip: Optional[bool] = None, + **kwargs: Any, + ): + super().__init__( + remote=remote, + local=local, + split=split, + proportion=proportion, + repeat=repeat, + choose=choose, + download_retry=download_retry, + download_timeout=download_timeout, + validate_hash=validate_hash, + keep_zip=keep_zip, + ) + + self.custom_arg = kwargs['custom_arg'] + + streams_registry.register('custom_stream', func=CustomStream) + + dataset = StreamingDataset( + local=local_dir, + remote=remote_dir, + stream_name='custom_stream', + stream_config={ + 'custom_arg': 100, + }, + ) + + assert len(dataset.streams) == 1 + assert isinstance(dataset.streams[0], CustomStream) + assert dataset.streams[0].custom_arg == 100
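
Taken together, these changes let a caller pick a `Stream` implementation by name at `StreamingDataset` construction time: `stream_name` selects a class from `streams_registry`, and `stream_config` supplies extra constructor arguments that get merged with `remote`, `local`, `split`, and the download settings before `construct_from_registry` builds the instance. Below is a minimal usage sketch of that flow, assuming this patch is applied; the `TaggedStream` class, its `tag` argument, and the bucket/cache paths are illustrative placeholders, not anything defined in the patch.

```python
# Minimal sketch of the registry flow introduced above. `TaggedStream`, its `tag`
# argument, and the remote/local paths are assumptions for illustration only.
from typing import Any, Optional

from streaming import StreamingDataset
from streaming.base.stream import Stream, streams_registry


class TaggedStream(Stream):
    """Hypothetical Stream subclass carrying one extra keyword argument."""

    def __init__(self, *, tag: Optional[str] = None, **kwargs: Any) -> None:
        # Standard Stream arguments (remote, local, split, download settings) pass through.
        super().__init__(**kwargs)
        # The extra argument arrives via `stream_config`.
        self.tag = tag


# Make the subclass selectable by name.
streams_registry.register('tagged_stream', func=TaggedStream)

# StreamingDataset constructs a TaggedStream internally: `stream_name` looks the class up
# in streams_registry, and `stream_config` is merged with remote/local/split and the
# download settings before being passed to the constructor.
dataset = StreamingDataset(
    remote='s3://my-bucket/my-mds-dataset',  # placeholder: any MDS dataset remote
    local='/tmp/my-mds-cache',               # placeholder: writable local cache dir
    stream_name='tagged_stream',
    stream_config={'tag': 'experiment-1'},
    batch_size=2,
)
assert isinstance(dataset.streams[0], TaggedStream)
```

Since `streams_registry.register('stream', func=Stream)` maps the default `stream_name='stream'` back to the base `Stream` class, callers that never pass the new arguments keep the previous behavior unchanged.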