Skip to content

Commit

Permalink
Use registry when creating Stream in StreamingDataset (#858)
Browse files Browse the repository at this point in the history
Co-authored-by: Saaketh Narayan <[email protected]>
  • Loading branch information
es94129 and snarayan21 authored Jan 7, 2025
1 parent 9165c9e commit 0b2227f
Show file tree
Hide file tree
Showing 9 changed files with 601 additions and 11 deletions.
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@ cd docs && make clean && make doctest # run doctests

<!--pytest.mark.skip-->
```bash
cd docs
pip install -e '.[docs]'
cd docs
make clean && make html
make host # open the output link in a browser.
```
Expand Down
27 changes: 27 additions & 0 deletions docs/source/dataset_configuration/mixing_data_sources.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from matplotlib.artist import kwdoc

# Mixing Datasets

Training a model often requires combining data from multiple different sources. Streaming makes combining these data sources, or streams, easy and configurable. See the [main concepts page](../getting_started/main_concepts.md#distributed-model-training) for a high-level view of distributed training with multiple streams.
Expand All @@ -8,6 +10,31 @@ A stream is a data source, as a collection of shard files (or set of subdirector

It is possible, though not recommended, for streams to have different schemas.

### Registering a custom Stream implementation
You can also customize the implementation of a `Stream`. To modify the behavior of a `Stream` that is used in a `StreamingDataset`, you can subclass `Stream`, and register the subclass as shown in the below example without forking the library.

<!--pytest.mark.skip-->
```python
from streaming.base.stream import streams_registry

class MyStream(Stream):
# your implementation goes here
pass

# Register your custom stream class as 'my_stream'
streams_registry.register('my_stream', func=MyStream)

# StreamingDataset creates a MyStream object when 'my_stream' is passed as a stream_name
dataset = StreamingDataset(
remote='s3://some/path',
local='/local/path',
stream_name='my_stream',
stream_config={'arg1': 'val1'},
)
```

See more methods for registering custom Stream classes in [this README section of LLM Foundry](https://github.com/mosaicml/llm-foundry/tree/3269c7399add8ca30842edbeb83d0c82f7906726?tab=readme-ov-file#how-to-register).

## Configuring the data mix
The `proportion`, `repeat`, or `choose` arguments to `Stream` are used to configure different dataset mixing schemes. Only one of them may be set at a time, and all streams must use the same mixing scheme (e.g., Stream A with `proportion` and Stream B with `choose` are incompatible).
- **`proportion`**: Specifies how to sample this Stream relative to other Streams.
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
'azure-storage-blob>=12.0.0,<13',
'azure-storage-file-datalake>=12.11.0,<13',
'azure-identity>=1.13.0',
'catalogue>=2,<3',
]

extra_deps = {}
Expand Down
38 changes: 29 additions & 9 deletions streaming/base/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,12 @@
SHARD_ACCESS_TIMES, SHARD_STATES, TICK)
from streaming.base.distributed import maybe_init_dist
from streaming.base.format import get_index_basename
from streaming.base.registry_utils import construct_from_registry
from streaming.base.sampling import get_sampling
from streaming.base.shared import (SharedArray, SharedBarrier, SharedMemory, SharedScalar,
_get_path, get_shm_prefix)
from streaming.base.spanner import Spanner
from streaming.base.stream import Stream
from streaming.base.stream import Stream, streams_registry
from streaming.base.util import bytes_to_int, number_abbrev_to_int
from streaming.base.world import World

Expand Down Expand Up @@ -308,6 +309,9 @@ class StreamingDataset(Array, IterableDataset):
replication (int, optional): Determines how many consecutive devices will receive the same
samples. Useful for training with tensor or sequence parallelism, where multiple
devices need to see the same partition of the dataset. Defaults to ``None``.
stream_name (str): The name of the Stream to use which is registered in streams_registry.
Defaults to ``stream``.
stream_config (dict[str, Any]): Additional arguments to pass to the Stream constructor.
"""

def __init__(self,
Expand All @@ -334,7 +338,9 @@ def __init__(self,
shuffle_block_size: Optional[int] = None,
batching_method: str = 'random',
allow_unsafe_types: bool = False,
replication: Optional[int] = None) -> None:
replication: Optional[int] = None,
stream_name: str = 'stream',
stream_config: Optional[dict[str, Any]] = None) -> None:
# Global arguments (which do not live in Streams).
self.predownload = predownload
self.cache_limit = cache_limit
Expand Down Expand Up @@ -438,13 +444,27 @@ def __init__(self,
for stream in streams:
stream.apply_default(default)
else:
default = Stream(remote=remote,
local=local,
split=split,
download_retry=download_retry,
download_timeout=download_timeout,
validate_hash=validate_hash,
keep_zip=keep_zip)
stream_config = stream_config or {}
stream_config.update({
'remote': remote,
'local': local,
'split': split,
'download_retry': download_retry,
'download_timeout': download_timeout,
'validate_hash': validate_hash,
'keep_zip': keep_zip,
})

# Construct a Stream instance using registry-based construction
default = construct_from_registry(
name=stream_name,
registry=streams_registry,
partial_function=False,
pre_validation_function=None,
post_validation_function=None,
kwargs=stream_config,
)

streams = [default]

# Validate the stream weighting scheme (relative or absolute) to catch errors before we go
Expand Down
199 changes: 199 additions & 0 deletions streaming/base/registry_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
# Copyright 2024 MosaicML Streaming authors
# SPDX-License-Identifier: Apache-2.0

"""Wrapper of catalogue.Registry, copied from llm-foundry."""

import copy
import functools
import importlib.util
import os
from contextlib import contextmanager
from pathlib import Path
from types import ModuleType
from typing import Any, Callable, Generic, Optional, Sequence, TypeVar, Union

import catalogue

__all__ = [
'TypedRegistry',
'create_registry',
'construct_from_registry',
'import_file',
'save_registry',
]

T = TypeVar('T')
TypeBoundT = TypeVar('TypeBoundT', bound=type)
CallableBoundT = TypeVar('CallableBoundT', bound=Callable[..., Any])


class TypedRegistry(catalogue.Registry, Generic[T]):
"""A thin wrapper around catalogue.Registry to add static typing and descriptions."""

def __init__(
self,
namespace: Sequence[str],
entry_points: bool = False,
description: str = '',
) -> None:
super().__init__(namespace, entry_points=entry_points)

self.description = description

def __call__(self, name: str, func: Optional[T] = None) -> Callable[[T], T]:
return super().__call__(name, func)

def register(self, name: str, *, func: Optional[T] = None) -> T:
return super().register(name, func=func)

def register_class(
self,
name: str,
*,
func: Optional[TypeBoundT] = None,
) -> TypeBoundT:
return super().register(name, func=func)

def get(self, name: str) -> T:
return super().get(name)

def get_all(self) -> dict[str, T]:
return super().get_all()

def get_entry_point(self, name: str, default: Optional[T] = None) -> T:
return super().get_entry_point(name, default=default)

def get_entry_points(self) -> dict[str, T]:
return super().get_entry_points()


S = TypeVar('S')


def create_registry(
*namespace: str,
generic_type: type[S],
entry_points: bool = False,
description: str = '',
) -> 'TypedRegistry[S]':
"""Create a new registry.
Args:
namespace (str): The namespace, e.g. "streaming.streams_registry"
generic_type (Type[S]): The type of the registry.
entry_points (bool): Accept registered functions from entry points.
description (str): A description of the registry.
Returns:
The TypedRegistry object.
"""
if catalogue.check_exists(*namespace):
raise catalogue.RegistryError(f'Namespace already exists: {namespace}')

return TypedRegistry[generic_type](
namespace,
entry_points=entry_points,
description=description,
)


def construct_from_registry(
name: str,
registry: TypedRegistry,
partial_function: bool = True,
pre_validation_function: Optional[Union[Callable[[Any], None], type]] = None,
post_validation_function: Optional[Callable[[Any], None]] = None,
kwargs: Optional[dict[str, Any]] = None,
) -> Any:
"""Helper function to build an item from the registry.
Args:
name (str): The name of the registered item
registry (catalogue.Registry): The registry to fetch the item from
partial_function (bool, optional): Whether to return a partial function for registered callables. Defaults to True.
pre_validation_function (Optional[Union[Callable[[Any], None], type]], optional): An optional validation function called
before constructing the item to return. This should throw an exception if validation fails. Defaults to None.
post_validation_function (Optional[Callable[[Any], None]], optional): An optional validation function called after
constructing the item to return. This should throw an exception if validation fails. Defaults to None.
kwargs (Optional[Dict[str, Any]]): Other relevant keyword arguments.
Raises:
ValueError: If the validation functions failed or the registered item is invalid
Returns:
Any: The constructed item from the registry
"""
if kwargs is None:
kwargs = {}

registered_constructor = registry.get(name)

if pre_validation_function is not None:
if isinstance(pre_validation_function, type):
if not issubclass(registered_constructor, pre_validation_function):
raise ValueError(
f'Expected {name} to be of type {pre_validation_function}, but got {type(registered_constructor)}',
)
elif isinstance(pre_validation_function, Callable):
pre_validation_function(registered_constructor)
else:
raise ValueError(
f'Expected pre_validation_function to be a callable or a type, but got {type(pre_validation_function)}',
)

# If it is a class, or a builder function, construct the class with kwargs
# If it is a function, create a partial with kwargs
if isinstance(
registered_constructor,
type,
) or callable(registered_constructor) and not partial_function:
constructed_item = registered_constructor(**kwargs)
elif callable(registered_constructor):
constructed_item = functools.partial(registered_constructor, **kwargs)
else:
raise ValueError(
f'Expected {name} to be a class or function, but got {type(registered_constructor)}',)

if post_validation_function is not None:
post_validation_function(constructed_item)

return constructed_item


def import_file(loc: Union[str, Path]) -> ModuleType:
"""Import module from a file.
Used to run arbitrary python code.
Args:
name (str): Name of module to load.
loc (str / Path): Path to the file.
Returns:
ModuleType: The module object.
"""
if not os.path.exists(loc):
raise FileNotFoundError(f'File {loc} does not exist.')

spec = importlib.util.spec_from_file_location('python_code', str(loc))

assert spec is not None
assert spec.loader is not None

module = importlib.util.module_from_spec(spec)

try:
spec.loader.exec_module(module)
except Exception as e:
raise RuntimeError(f'Error executing {loc}') from e
return module


@contextmanager
def save_registry():
"""Save the registry state and restore after the context manager exits."""
saved_registry_state = copy.deepcopy(catalogue.REGISTRY)

yield

catalogue.REGISTRY = saved_registry_state
11 changes: 11 additions & 0 deletions streaming/base/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from streaming.base.distributed import barrier, get_local_rank
from streaming.base.format import FileInfo, Reader, get_index_basename, reader_from_json
from streaming.base.hashing import get_hash
from streaming.base.registry_utils import create_registry
from streaming.base.storage import CloudDownloader
from streaming.base.util import retry, wait_for_file_to_exist
from streaming.base.world import World
Expand Down Expand Up @@ -507,3 +508,13 @@ def get_index_size(self) -> int:
"""
filename = os.path.join(self.local, self.split, get_index_basename())
return os.stat(filename).st_size


streams_registry = create_registry(
'streaming',
'streams_registry',
generic_type=type[Stream],
entry_points=True,
description='The streams registry is used for registering Stream classes.')

streams_registry.register('stream', func=Stream)
Loading

0 comments on commit 0b2227f

Please sign in to comment.