Adding upsampling for StreamingDataset #451

Closed
JackUrb opened this issue Jan 16, 2025 · 2 comments · Fixed by #453
Labels: enhancement (New feature or request)

Comments

JackUrb (Contributor) commented Jan 16, 2025

🚀 Feature

The goal is to allow the StreamingDataset(..., subsample=...) argument to accept float values greater than 1, which would then act as if additional shuffled copies of the dataset were included.
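
For illustration, the proposed usage might look like the following sketch (the path is hypothetical, and the semantics of subsample > 1 are exactly what this issue proposes):

from litdata import StreamingDataset

# Hypothetical: subsample=2.5 would yield two full shuffled passes over
# the chunks plus a random half of a third pass, instead of raising a
# validation error as it does today.
dataset = StreamingDataset("s3://my-bucket/my-dataset", subsample=2.5, shuffle=True)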

Motivation

The main use case I see for this is augmenting CombinedStreamingDataset's currently limited behavior when mixing datasets of very different sizes. Presently it can either return when any of the source datasets runs out, or continue running (but reweight to drop the dataset that ran out). Allowing upsampling ensures that the larger dataset doesn't need to be reset just so the mixing proportions can remain the same after a smaller dataset runs out.

In an exaggerated case, imagine you have one StreamingDataset A of 2 elements and another B of 50. If you try to combine these with the CombinedStreamingDataset with a 50-50 split, no option presently allows the observed dataset to contain all 50 elements of B alongside 25 copies of the elements of A.
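
With upsampling available, that exaggerated case could be expressed as something like the following sketch (the paths are hypothetical):

from litdata import CombinedStreamingDataset, StreamingDataset

# Dataset A has 2 elements; upsample it 25x to 50 effective elements.
dataset_a = StreamingDataset("s3://bucket/dataset-a", subsample=25.0)
# Dataset B has 50 elements and is used as-is.
dataset_b = StreamingDataset("s3://bucket/dataset-b")

# A 50-50 mix can now cover all 50 elements of B while cycling
# through 25 copies of A, without resetting either dataset.
combined = CombinedStreamingDataset([dataset_a, dataset_b], weights=[0.5, 0.5])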

Pitch

From what I can see, this is enabled by three primary changes:

1. litdata.utilities.dataset_utilities's subsample_streaming_dataset needs to be updated to append shuffled copies of the dataset before a final downsampling of whatever fractional remainder is left:

import math
from typing import List, Tuple

import numpy as np


def subsample_streaming_dataset(...) -> ...:
    ...

    # No resampling needed: return every chunk as-is.
    if math.isclose(subsample, 1.0):
        subsampled_files = [chnk["filename"] for chnk in original_chunks]
        return subsampled_files, roi

    final_files: List[str] = []
    final_roi: List[Tuple[int, int]] = []

    random_seed_sampler = None
    if shuffle:
        random_seed_sampler = np.random.RandomState(seed)

    # Append one full (optionally shuffled) copy of the dataset for each
    # whole unit of subsample.
    while subsample >= 1.0:
        if random_seed_sampler is not None:
            original_chunks, roi = shuffle_lists_together(original_chunks, roi, random_seed_sampler)
        final_files.extend(chnk["filename"] for chnk in original_chunks)
        final_roi.extend(roi)
        subsample -= 1.0

    # Downsample whatever fractional remainder is left.
    if subsample > 0:
        if random_seed_sampler is not None:
            original_chunks, roi = shuffle_lists_together(original_chunks, roi, random_seed_sampler)

        # Count the items covered by all ROIs (avoids shadowing roi
        # inside the comprehension, as the earlier draft did).
        num_items_to_subsample = int(sum(end - start for start, end in roi) * subsample)

        subsampled_files, roi, _, _ = subsample_filenames_and_roi(original_chunks, roi, num_items_to_subsample)
        final_files.extend(subsampled_files)
        final_roi.extend(roi)

    return final_files, final_roi
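
As a quick sanity check of the loop's integer-plus-remainder decomposition (toy values, independent of litdata):

# Toy illustration: subsample=2.5 over chunks holding 10 items total
# should emit 10 + 10 + 5 = 25 items across the extended lists.
subsample = 2.5
total_items = 10
expected = int(total_items * subsample)  # 25

emitted = 0
while subsample >= 1.0:
    emitted += total_items  # one full shuffled copy
    subsample -= 1.0
if subsample > 0:
    emitted += int(total_items * subsample)  # fractional remainder

assert emitted == expected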

2. litdata.streaming.config's load_subsampled_chunks method needs to be updated to handle a many-to-one relationship between _subsampled_chunks and original_chunks:

from collections import defaultdict
from typing import Any, Dict, List


def load_subsampled_chunks(subsampled_files: List[str], original_chunks: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Loads chunks based on the subsample provided, allowing repeated filenames."""
    _subsampled_chunks: List[Dict[str, Any]] = [{} for _ in range(len(subsampled_files))]

    # Map each filename to every index it appears at, since an upsampled
    # file list can reference the same chunk more than once.
    filename_dict = defaultdict(list)
    for index, filename in enumerate(subsampled_files):
        filename_dict[filename].append(index)

    for curr_chunk in original_chunks:
        if curr_chunk["filename"] in filename_dict:
            for idx in filename_dict[curr_chunk["filename"]]:
                _subsampled_chunks[idx] = curr_chunk

    # If any entry is still an empty dict, some element of subsampled_files
    # was not actually part of the original chunks; raise an error.
    if any(not _subsampled_chunk for _subsampled_chunk in _subsampled_chunks):
        raise ValueError(
            "Mismatch between the subsampled files and the chunks loaded. "
            "Make sure the subsampled files are actually part of the original chunks."
        )

    return _subsampled_chunks
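
For example, a file list with repeated names would fan out to repeated chunk entries (toy data; the chunk dicts are simplified to just their filenames):

chunks = [{"filename": "chunk-0.bin"}, {"filename": "chunk-1.bin"}]
# An upsampled file list may reference the same chunk more than once.
files = ["chunk-0.bin", "chunk-1.bin", "chunk-0.bin"]

result = load_subsampled_chunks(files, chunks)
assert result == [chunks[0], chunks[1], chunks[0]]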

3. StreamingDataset's argument validation: simply remove the upper limit, as sketched below.
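
A hypothetical sketch of the relaxed check (the current implementation's exact message and placement may differ):

# Hypothetical: only require a positive value, rather than also
# rejecting subsample values greater than 1 as happens today.
if subsample <= 0:
    raise ValueError(f"subsample={subsample} must be greater than 0.")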

Alternatives

I've considered updating CombinedStreamingDataset's state to handle resetting individual datasets when they run out, but that gets incredibly hairy very fast.

JackUrb added the enhancement (New feature or request) label on Jan 16, 2025
Hi! Thanks for your contribution, great first issue!

tchaton (Collaborator) commented Jan 21, 2025

Hey @JackUrb, feel free to make a contribution ;)
