Skip to content

Commit

Permalink
Compress Progress bars to a single bar for dataset groups
Browse files Browse the repository at this point in the history
  • Loading branch information
Prajwal Kiran Kumar authored and BenGalewsky committed Dec 1, 2023
1 parent b789933 commit e765f07
Show file tree
Hide file tree
Showing 5 changed files with 195 additions and 26 deletions.
44 changes: 28 additions & 16 deletions servicex/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def __init__(
sx_adapter: ServiceXAdapter,
config: Configuration,
query_cache: QueryCache,
servicex_polling_interval: int = 10,
servicex_polling_interval: int = 5,
minio_polling_interval: int = 5,
result_format: ResultFormat = ResultFormat.parquet,
ignore_cache: bool = False,
Expand Down Expand Up @@ -158,7 +158,8 @@ def set_result_format(self, result_format: ResultFormat):

async def submit_and_download(
self, signed_urls_only: bool,
expandable_progress: ExpandableProgress
expandable_progress: ExpandableProgress,
dataset_group: Optional[bool] = False,
) -> Optional[TransformedResults]:
"""
Submit the transform request to ServiceX. Poll the transform status to see when
Expand Down Expand Up @@ -222,9 +223,10 @@ def transform_complete(task: Task):
# If we get here with a cached record, then we know that the transform
# has been run, but we just didn't get the files from object store in the way
# requested by user
transform_bar_title = "Transform"
if not cached_record:
transform_progress = expandable_progress.add_task(
"Transform", start=False, total=None
transform_bar_title, start=False, total=None
) if expandable_progress else None
else:
self.request_id = cached_record.request_id
Expand All @@ -244,7 +246,8 @@ def transform_complete(task: Task):

monitor_task = loop.create_task(
self.transform_status_listener(
expandable_progress.progress, transform_progress, download_progress
expandable_progress, transform_progress, transform_bar_title,
download_progress, minio_progress_bar_title
)
)
monitor_task.add_done_callback(transform_complete)
Expand All @@ -253,7 +256,7 @@ def transform_complete(task: Task):

download_files_task = loop.create_task(
self.download_files(
signed_urls_only, expandable_progress.progress, download_progress, cached_record
signed_urls_only, expandable_progress, download_progress, cached_record
)
)

Expand Down Expand Up @@ -288,7 +291,8 @@ def transform_complete(task: Task):
rich.print("Aborted file downloads due to transform failure")

async def transform_status_listener(
self, progress: Progress, progress_task: TaskID, download_task: TaskID
self, progress: ExpandableProgress, progress_task: TaskID,
progress_bar_title: str, download_task: TaskID, download_bar_title: str
):
"""
Poll ServiceX for the status of a transform. Update progress bars and keep track
Expand All @@ -309,15 +313,16 @@ async def transform_status_listener(
if not final_count and self.current_status.files:
final_count = self.current_status.files
if progress:
progress.update(progress_task, total=final_count)
progress.start_task(progress_task)
progress.update(progress_task, progress_bar_title, total=final_count)
progress.start_task(task_id=progress_task, task_type="Transform")

progress.update(download_task, total=final_count)
progress.start_task(download_task)
progress.update(download_task, download_bar_title, total=final_count)
progress.start_task(task_id=download_task, task_type="Download")

if progress:
progress.update(
progress_task, completed=self.current_status.files_completed
progress_task, progress_bar_title,
completed=self.current_status.files_completed
)

if self.current_status.status == Status.complete:
Expand Down Expand Up @@ -347,7 +352,7 @@ async def retrieve_current_transform_status(self):
async def download_files(
self,
signed_urls_only: bool,
progress: Progress,
progress: ExpandableProgress,
download_progress: TaskID,
cached_record: Optional[TransformedResults],
) -> List[str]:
Expand All @@ -371,7 +376,7 @@ async def download_file(
filename, self.download_path, shorten_filename=shorten_filename
)
result_uris.append(downloaded_filename.as_posix())
progress.advance(download_progress)
progress.advance(task_id=download_progress, task_type="Download")

async def get_signed_url(
minio: MinioAdapter,
Expand All @@ -382,7 +387,7 @@ async def get_signed_url(
url = await minio.get_signed_url(filename)
result_uris.append(url)
if progress:
progress.advance(download_progress)
progress.advance(task_id=download_progress, task_type="Download")

while True:
if not cached_record:
Expand Down Expand Up @@ -467,16 +472,23 @@ async def as_pandas_async(self,
pass

async def as_signed_urls_async(self, display_progress: bool = True,
provided_progress: Optional[ProgressIndicators] = None) \
provided_progress: Optional[ProgressIndicators] = None,
dataset_group: bool = False) \
-> TransformedResults:
r"""
Presign URLs for each of the transformed files
:return: TransformedResults object with the presigned_urls list populated
"""
if dataset_group:
return await self.submit_and_download(signed_urls_only=True,
expandable_progress=provided_progress,
dataset_group=dataset_group)

with ExpandableProgress(display_progress=display_progress,
provided_progress=provided_progress) as progress:
return await self.submit_and_download(signed_urls_only=True,
expandable_progress=progress)
expandable_progress=progress,
dataset_group=dataset_group)

as_signed_urls = make_sync(as_signed_urls_async)
5 changes: 3 additions & 2 deletions servicex/dataset_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,10 @@ async def as_signed_urls_async(
display_progress: bool = True,
provided_progress: Optional[Progress] = None,
) -> List[TransformedResults]:
with ExpandableProgress(display_progress, provided_progress) as progress:
with ExpandableProgress(display_progress, provided_progress,
overall_progress=True) as progress:
self.tasks = [
d.as_signed_urls_async(provided_progress=progress)
d.as_signed_urls_async(provided_progress=progress, dataset_group=True)
for d in self.datasets
]
return await asyncio.gather(*self.tasks)
Expand Down
107 changes: 100 additions & 7 deletions servicex/expandable_progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,29 @@
from typing import Optional

from rich.progress import Progress, TextColumn, BarColumn, MofNCompleteColumn, \
TimeRemainingColumn
TimeRemainingColumn, TaskID


class ProgressCounts:
def __init__(self,
description: str,
task_id: TaskID,
start: Optional[int] = None,
total: Optional[int] = None,
completed: Optional[int] = None):

self.description = description
self.taskId = task_id
self.start = start
self.total = total
self.completed = completed


class ExpandableProgress:
def __init__(self,
display_progress: bool = True,
provided_progress: Optional[Progress | ExpandableProgress] = None):
provided_progress: Optional[Progress | ExpandableProgress] = None,
overall_progress: bool = False):
"""
We want to be able to use rich progress bars in the async code, but there are
some situtations where the user doesn't want them. Also we might be running
Expand All @@ -52,17 +68,22 @@ def __init__(self,
"""
self.display_progress = display_progress
self.provided_progress = provided_progress
self.overall_progress = overall_progress
self.overall_progress_transform_task = None
self.overall_progress_download_task = None
self.progress_counts = {}
if display_progress:
if provided_progress:
self.progress = provided_progress if isinstance(provided_progress, Progress) \
else provided_progress.progress
else:
if self.overall_progress or not provided_progress:
self.progress = Progress(
TextColumn("[progress.description]{task.description}"),
BarColumn(),
MofNCompleteColumn(),
TimeRemainingColumn(compact=True, elapsed_when_finished=True)
)

if provided_progress:
self.progress = provided_progress if isinstance(provided_progress, Progress) \
else provided_progress.progress
else:
self.progress = None

Expand All @@ -88,5 +109,77 @@ def __exit__(self, exc_type, exc_val, exc_tb):
self.progress.stop()

def add_task(self, param, start, total):
if self.display_progress:
if self.display_progress and self.overall_progress:
if (
not self.overall_progress_download_task
and not self.overall_progress_transform_task
):
self.overall_progress_transform_task = self.progress.add_task("Transform",
start=False,
total=None)
self.overall_progress_download_task = self.progress.add_task("Download/URLs",
start=False,
total=None)

task_id = self.progress.add_task(param, start=start, total=total, visible=False)
new_task = ProgressCounts(param, task_id, start=start, total=total)
self.progress_counts[task_id] = new_task
return task_id
if self.display_progress and not self.overall_progress:
return self.progress.add_task(param, start=start, total=total)

def update(self, task_id, task_type, total=None, completed=None):

if self.display_progress and self.overall_progress:
# Calculate and update
overall_completed = 0
overall_total = 0
if completed:
self.progress_counts[task_id].completed = completed

elif total:
self.progress_counts[task_id].total = total

for task in self.progress_counts:
if (
self.progress_counts[task].description == task_type
and self.progress_counts[task].completed
):
overall_completed += self.progress_counts[task].completed

for task in self.progress_counts:
if (
self.progress_counts[task].description == task_type
and self.progress_counts[task].total
):
overall_total += self.progress_counts[task].total

if task_type == "Transform":
return self.progress.update(self.overall_progress_transform_task,
completed=overall_completed,
total=overall_total)
else:
return self.progress.update(self.overall_progress_download_task,
completed=overall_completed,
total=overall_total)

if self.display_progress and not self.overall_progress:
return self.progress.update(task_id, completed=completed, total=total)

def start_task(self, task_id, task_type):
if self.display_progress and self.overall_progress:
if task_type == "Transform":
self.progress.start_task(task_id=self.overall_progress_transform_task)
else:
self.progress.start_task(task_id=self.overall_progress_download_task)
elif self.display_progress and not self.overall_progress:
self.progress.start_task(task_id=task_id)

def advance(self, task_id, task_type):
if self.display_progress and self.overall_progress:
if task_type == "Transform":
self.progress.advance(task_id=self.overall_progress_transform_task)
else:
self.progress.advance(task_id=self.overall_progress_download_task)
elif self.display_progress and not self.overall_progress:
self.progress.advance(task_id=task_id)
47 changes: 46 additions & 1 deletion tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,20 @@ async def test_as_signed_urls_happy(transformed_result):
assert result == transformed_result


@pytest.mark.asyncio
async def test_as_signed_urls_happy_dataset_group(transformed_result):
# Test when display_progress is True and provided_progress is None
did = FileListDataset("/foo/bar/baz.root")
dataset = PythonDataset(dataset_identifier=did, codegen="uproot",
sx_adapter=None, query_cache=None)
dataset.submit_and_download = AsyncMock()
dataset.submit_and_download.return_value = transformed_result

result = dataset.as_signed_urls(display_progress=True, provided_progress=None,
dataset_group=True)
assert result == transformed_result


@pytest.mark.asyncio
async def test_as_files_happy(transformed_result):
did = FileListDataset("/foo/bar/baz.root")
Expand Down Expand Up @@ -138,7 +152,8 @@ async def test_transform_status_listener_happy(python_dataset):
status = Mock(files=10, files_completed=5, files_failed=1, status=Status.complete)
python_dataset.current_status = status
python_dataset.retrieve_current_transform_status = AsyncMock(return_value=status)
await python_dataset.transform_status_listener(progress, progress_task, download_task)
await python_dataset.transform_status_listener(progress, progress_task, "mock_title",
download_task, "mock_title")

python_dataset.retrieve_current_transform_status.assert_awaited_once()
# progress.update.assert_called_with(progress_task, total=10)
Expand Down Expand Up @@ -217,6 +232,36 @@ async def test_submit_and_download_cache_miss(python_dataset, completed_status):
cache.close()


@pytest.mark.asyncio
async def test_submit_and_download_cache_miss_overall_progress(python_dataset, completed_status):
with tempfile.TemporaryDirectory() as temp_dir:
python_dataset.current_status = None
python_dataset.servicex = AsyncMock()
config = Configuration(cache_path=temp_dir, api_endpoints=[])
cache = QueryCache(config)
python_dataset.cache = cache
python_dataset.configuration = config

python_dataset.servicex = AsyncMock()
python_dataset.cache.get_transform_by_hash = Mock()
python_dataset.cache.get_transform_by_hash.return_value = None
python_dataset.servicex.get_transform_status = AsyncMock(id="12345")
python_dataset.servicex.get_transform_status.return_value = completed_status
python_dataset.servicex.submit_transform = AsyncMock()
python_dataset.download_files = AsyncMock()
python_dataset.download_files.return_value = []

signed_urls_only = False
expandable_progress = ExpandableProgress(overall_progress=True)
dataset_group = True

result = await python_dataset.submit_and_download(signed_urls_only, expandable_progress,
dataset_group)
assert result is not None
assert result.request_id == "b8c508d0-ccf2-4deb-a1f7-65c839eebabf"
cache.close()


@pytest.mark.asyncio
async def test_submit_and_download_no_result_format(python_dataset, completed_status):
with tempfile.TemporaryDirectory() as temp_dir:
Expand Down
18 changes: 18 additions & 0 deletions tests/test_expandable_progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,24 @@ def test_progress(mock_progress):
assert mock_progress.return_value.stop.call_count == 1


@patch("servicex.expandable_progress.Progress", return_value=MagicMock(Progress))
def test_overall_progress(mock_progress):
with ExpandableProgress(overall_progress=True) as progress:
assert progress.progress == mock_progress.return_value
mock_progress.return_value.start.assert_called_once()
assert progress.display_progress
assert mock_progress.return_value.stop.call_count == 1


@patch("servicex.expandable_progress.Progress", return_value=MagicMock(Progress))
def test_overall_progress_mock(mock_progress):
with ExpandableProgress(overall_progress=True) as progress:
assert progress.progress == mock_progress.return_value
mock_progress.return_value.start.assert_called_once()
assert progress.display_progress
assert mock_progress.return_value.stop.call_count == 1


def test_provided_progress(mocker):
class MockedProgress(Progress):
def __init__(self):
Expand Down

0 comments on commit e765f07

Please sign in to comment.