diff --git a/deepset_cloud_sdk/_service/files_service.py b/deepset_cloud_sdk/_service/files_service.py index 9b22d145..f05819d4 100644 --- a/deepset_cloud_sdk/_service/files_service.py +++ b/deepset_cloud_sdk/_service/files_service.py @@ -490,7 +490,15 @@ async def download( pbar: Optional[tqdm] = None if show_progress: - total = (await self._files.list_paginated(workspace_name, limit=1)).total + total = ( + await self._files.list_paginated( + workspace_name, + name=name, + content=content, + odata_filter=odata_filter, + limit=1, + ) + ).total pbar = tqdm(total=total, desc="Download Progress") after_value = None diff --git a/tests/unit/service/test_files_service.py b/tests/unit/service/test_files_service.py index 42d42231..cfe4a23f 100644 --- a/tests/unit/service/test_files_service.py +++ b/tests/unit/service/test_files_service.py @@ -681,6 +681,59 @@ async def test_download_files_with_filter(self, file_service: FilesService, monk after_value=None, ) + async def test_download_files_with_filter_and_progress_bar( + self, file_service: FilesService, monkeypatch: MonkeyPatch + ) -> None: + mocked_list_paginated = AsyncMock( + return_value=FileList( + total=1, + data=[ + File( + file_id=UUID("cd16435f-f6eb-423f-bf6f-994dc8a36a10"), + url="/api/v1/workspaces/search tests/files/cd16435f-f6eb-423f-bf6f-994dc8a36a10", + name="silly_things_2.txt", + size=611, + created_at=datetime.datetime.fromisoformat("2022-06-21T16:40:00.634653+00:00"), + meta={}, + ) + ], + has_more=False, + ), + ) + + monkeypatch.setattr(file_service._files, "list_paginated", mocked_list_paginated) + + mocked_download = AsyncMock(return_value=None) + monkeypatch.setattr(file_service._files, "download", mocked_download) + + await file_service.download( + workspace_name="test_workspace", + show_progress=True, # This requires a previous cal that checks the total number of files + odata_filter="category eq 'news'", + name="asdf", + content="bsdf", + batch_size=54, + ) + + mocked_list_paginated.mock_calls == [ + call( + workspace_name="test_workspace", + name="asdf", + content="bsdf", + odata_filter="category eq 'news'", + limit=54, + ), + call( + workspace_name="test_workspace", + name="asdf", + content="bsdf", + odata_filter="category eq 'news'", + limit=54, + after_file_id=None, + after_value=None, + ), + ] + async def test_download_all_files_with_file_not_found( self, file_service: FilesService, monkeypatch: MonkeyPatch ) -> None: