diff --git a/c3s_eqc_automatic_quality_control/download.py b/c3s_eqc_automatic_quality_control/download.py
index dd3aa3d..810cd25 100644
--- a/c3s_eqc_automatic_quality_control/download.py
+++ b/c3s_eqc_automatic_quality_control/download.py
@@ -18,11 +18,13 @@
 # limitations under the License.
 
 import calendar
+import contextlib
 import fnmatch
 import functools
 import itertools
+import os
 import pathlib
-from collections.abc import Callable
+from collections.abc import Callable, Iterator
 from typing import Any
 
 import cacholote
@@ -48,6 +50,19 @@
 _SORTED_REQUEST_PARAMETERS = ("area", "grid")
 
 
+@contextlib.contextmanager
+def _set_env(**kwargs: Any) -> Iterator[None]:
+    old_environ = dict(os.environ)
+    try:
+        os.environ.update(
+            {k.upper(): str(v) for k, v in kwargs.items() if v is not None}
+        )
+        yield
+    finally:
+        os.environ.clear()
+        os.environ.update(old_environ)
+
+
 def compute_stop_date(switch_month_day: int | None = None) -> pd.Period:
     today = pd.Timestamp.today()
     if switch_month_day is None:
@@ -306,7 +321,8 @@ def get_sources(
 ) -> list[str]:
     source: set[str] = set()
 
-    for request in tqdm.tqdm(request_list, disable=len(request_list) <= 1):
+    disable = os.getenv("TQDM_DISABLE", "False") == "True"
+    for request in tqdm.tqdm(request_list, disable=disable):
         data = _cached_retrieve(collection_id, request)
         if content := getattr(data, "_content", None):
             source.update(map(str, content))
@@ -479,6 +495,7 @@ def download_and_transform(
     n_jobs: int | None = None,
     invalidate_cache: bool | None = None,
     cached_open_mfdataset_kwargs: bool | dict[str, Any] = {},
+    quiet: bool = False,
     **open_mfdataset_kwargs: Any,
 ) -> xr.Dataset:
     """
@@ -513,6 +530,8 @@ def download_and_transform(
     cached_open_mfdataset_kwargs: bool | dict
         Kwargs to be passed on to xr.open_mfdataset for cached files.
         If True, use open_mfdataset_kwargs used for raw files.
+    quiet: bool
+        Whether to disable progress bars.
     **open_mfdataset_kwargs:
         Kwargs to be passed on to xr.open_mfdataset for raw files.
 
@@ -520,6 +539,8 @@ def download_and_transform(
     -------
     xr.Dataset
     """
+    assert isinstance(quiet, bool)
+
     if n_jobs is None:
         n_jobs = N_JOBS
 
@@ -557,12 +578,15 @@ def download_and_transform(
     if use_cache and transform_chunks:
         # Cache each chunk transformed
         sources = []
-        for request in tqdm.tqdm(request_list):
+        for request in tqdm.tqdm(request_list, disable=quiet):
             if invalidate_cache:
                 cacholote.delete(
                     func.func, *func.args, request_list=[request], **func.keywords
                 )
-            with cacholote.config.set(return_cache_entry=True):
+            with (
+                cacholote.config.set(return_cache_entry=True),
+                _set_env(tqdm_disable=True),
+            ):
                 sources.append(func(request_list=[request]).result["args"][0]["href"])
         ds = xr.open_mfdataset(sources, **cached_open_mfdataset_kwargs)
     else:
@@ -571,7 +595,8 @@ def download_and_transform(
             cacholote.delete(
                 func.func, *func.args, request_list=request_list, **func.keywords
            )
-        ds = func(request_list=request_list)
+        with _set_env(tqdm_disable=quiet):
+            ds = func(request_list=request_list)
 
     ds.attrs.pop("coordinates", None)  # Previously added to guarantee roundtrip
     return ds
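
For reviewers, a minimal standalone sketch of the `_set_env` semantics introduced above. The context manager body mirrors the diff; the demo keys `tqdm_disable` and `skipped` are illustrative only. Keyword names are upper-cased, values are stringified, `None` values are skipped, and the previous environment is restored on exit:

```python
import contextlib
import os
from collections.abc import Iterator
from typing import Any


@contextlib.contextmanager
def _set_env(**kwargs: Any) -> Iterator[None]:
    # Same behaviour as the helper added in this diff: export upper-cased,
    # stringified variables for the duration of the block, then restore.
    old_environ = dict(os.environ)
    try:
        os.environ.update(
            {k.upper(): str(v) for k, v in kwargs.items() if v is not None}
        )
        yield
    finally:
        os.environ.clear()
        os.environ.update(old_environ)


before = os.environ.get("TQDM_DISABLE")
with _set_env(tqdm_disable=True, skipped=None):
    # tqdm_disable=True becomes TQDM_DISABLE="True", which is exactly the
    # value get_sources now checks via os.getenv("TQDM_DISABLE", "False").
    print(os.environ["TQDM_DISABLE"])        # "True"
    print("SKIPPED" in os.environ)           # False: None-valued kwargs are not exported
print(os.environ.get("TQDM_DISABLE") == before)  # True: previous state restored
```

With this in place, `download_and_transform(..., quiet=True)` disables the outer per-chunk `tqdm` bar directly; in the chunked branch the inner bar in `get_sources` is always silenced via `_set_env(tqdm_disable=True)`, while in the non-chunked branch `_set_env(tqdm_disable=quiet)` forwards the flag, in both cases without permanently mutating the process environment.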