Skip to content

Commit

Permalink
Local downloader w cache (#30)
Browse files Browse the repository at this point in the history
  • Loading branch information
mads-oestergaard authored Feb 26, 2024
1 parent 9f584d2 commit 918f7ba
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 2 deletions.
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,16 @@ outputs = optimize(
)
```

## Network Drive On-Prem Support

On-prem compute nodes can mount and use a network drive. To reduce the load on the network, the `StreamingDataset` supports caching the chunks: prefix the input directory with `local:` to enable it.

```python
from lightning.data import StreamingDataset

dataset = StreamingDataset(input_dir="local:/data/shared-drive/some-data")
```

# ⚡ Contributors

We welcome any contributions, pull requests, or issues. If you use the Streaming Dataset for your own project, please reach out to us on Slack or Discord.
8 changes: 7 additions & 1 deletion litdata/streaming/downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,13 @@ def download_file(self, remote_filepath: str, local_filepath: str) -> None:
shutil.copy(remote_filepath, local_filepath)


_DOWNLOADERS = {"s3://": S3Downloader, "": LocalDownloader}
class LocalDownloaderWithCache(LocalDownloader):
    """Downloader for paths given with the ``local:`` scheme prefix.

    The prefix is stripped from the remote path and the actual transfer is
    delegated to ``LocalDownloader.download_file``.
    """

    def download_file(self, remote_filepath: str, local_filepath: str) -> None:
        """Strip the leading ``local:`` marker and delegate to the parent class.

        Args:
            remote_filepath: Source path, optionally prefixed with ``local:``.
            local_filepath: Destination path handed unchanged to the parent.
        """
        prefix = "local:"
        # Only the leading scheme marker is removed. A blanket ``str.replace``
        # would also mangle paths that contain "local:" in a directory or
        # file name further down the path.
        if remote_filepath.startswith(prefix):
            remote_filepath = remote_filepath[len(prefix):]
        super().download_file(remote_filepath, local_filepath)


# Maps a path prefix to the downloader class registered for it.
# NOTE(review): the empty-string key looks like the catch-all fallback; if the
# lookup is a first-match prefix scan it must stay last — confirm against
# get_downloader_cls.
_DOWNLOADERS = {
    "s3://": S3Downloader,
    "local:": LocalDownloaderWithCache,
    "": LocalDownloader,
}


def get_downloader_cls(remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]]) -> Downloader:
Expand Down
3 changes: 3 additions & 0 deletions litdata/streaming/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ def _resolve_dir(dir_path: Optional[Union[str, Dir]]) -> Dir:
if dir_path.startswith("s3://"):
return Dir(path=None, url=dir_path)

if dir_path.startswith("local:"):
return Dir(path=None, url=dir_path)

dir_path = _resolve_time_template(dir_path)

dir_path_absolute = str(Path(dir_path).absolute().resolve())
Expand Down
17 changes: 16 additions & 1 deletion tests/streaming/test_downloader.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
from unittest.mock import MagicMock

from litdata.streaming.downloader import S3Downloader, subprocess
from litdata.streaming.downloader import LocalDownloaderWithCache, S3Downloader, shutil, subprocess


def test_s3_downloader_fast(tmpdir, monkeypatch):
Expand All @@ -11,3 +11,18 @@ def test_s3_downloader_fast(tmpdir, monkeypatch):
downloader = S3Downloader(tmpdir, tmpdir, [])
downloader.download_file("s3://random_bucket/a.txt", os.path.join(tmpdir, "a.txt"))
popen_mock.wait.assert_called()


def test_download_with_cache(tmpdir, monkeypatch):
    """Check that the ``local:`` prefix is stripped before the file is copied."""
    # Keep the fixture file inside tmpdir so the test never creates, clobbers
    # or deletes an "a.txt" in the current working directory. The source and
    # target names differ so the copy is a real source -> destination pair.
    source_path = os.path.join(tmpdir, "source_a.txt")
    target_path = os.path.join(tmpdir, "a.txt")
    with open(source_path, "w", encoding="utf-8") as f:
        f.write("hello")

    shutil_mock = MagicMock()
    monkeypatch.setattr(shutil, "copy", shutil_mock)

    local_downloader = LocalDownloaderWithCache(tmpdir, tmpdir, [])
    local_downloader.download_file(f"local:{source_path}", target_path)

    # Assert the exact arguments, not just that a copy happened, so the test
    # fails if the prefix is not stripped from the source path.
    shutil_mock.assert_called_once_with(source_path, target_path)

0 comments on commit 918f7ba

Please sign in to comment.