
Migrate to earthkit-data #45

Merged
32 commits merged on Aug 29, 2024
Changes from 13 commits
110 changes: 45 additions & 65 deletions c3s_eqc_automatic_quality_control/download.py
@@ -18,33 +18,25 @@
 # limitations under the License.

 import calendar
-import fnmatch
 import functools
 import itertools
 import pathlib
 from collections.abc import Callable
 from typing import Any

 import cacholote
-import cads_toolbox
 import cf_xarray  # noqa: F401
 import cgul
-import emohawk.readers.directory
-import emohawk.readers.shapefile
+import earthkit.data
 import fsspec
+import fsspec.implementations.local
 import joblib
 import pandas as pd
 import tqdm
 import xarray as xr

-cads_toolbox.config.USE_CACHE = True
-
 N_JOBS = 1
 INVALIDATE_CACHE = False
-# TODO: This kwargs should somehow be handle upstream by the toolbox.
-TO_XARRAY_KWARGS: dict[str, Any] = {
-    "pandas_read_csv_kwargs": {"comment": "#"},
-}

 _SORTED_REQUEST_PARAMETERS = ("area", "grid")

@@ -294,28 +286,42 @@ def ensure_request_gets_cached(request: dict[str, Any]) -> dict[str, Any]:
     return cacheable_request


-def _cached_retrieve(collection_id: str, request: dict[str, Any]) -> emohawk.Data:
-    with cacholote.config.set(return_cache_entry=False):
-        return cads_toolbox.catalogue.retrieve(collection_id, request).data
+def get_paths(sources: list[Any]) -> list[str]:
+    paths = []
+    for source in sources:
+        if hasattr(source, "path"):
+            paths.append(source.path)
+        else:
+            paths.extend([index.path for index in source._indexes])
+    return paths
+
+
+@cacholote.cacheable
+def _cached_retrieve(
+    collection_id: str, request: dict[str, Any]
+) -> list[fsspec.implementations.local.LocalFileOpener]:
+    ds = earthkit.data.from_source("cds", collection_id, request, prompt=False)
+    sources = ds.sources if hasattr(ds, "sources") else [ds]
+    fs = fsspec.filesystem("file")
+    return [fs.open(path) for path in get_paths(sources)]
+
+
+def retrieve(collection_id: str, request: dict[str, Any]) -> list[str]:
+    with cacholote.config.set(
+        return_cache_entry=False,
+        io_delete_original=True,
+    ):
+        return [file.path for file in _cached_retrieve(collection_id, request)]


 def get_sources(
     collection_id: str,
     request_list: list[dict[str, Any]],
-    exclude: list[str] = ["*.png", "*.json"],
 ) -> list[str]:
-    source: set[str] = set()
-
+    sources: set[str] = set()
     for request in tqdm.tqdm(request_list, disable=len(request_list) <= 1):
-        data = _cached_retrieve(collection_id, request)
-        if content := getattr(data, "_content", None):
-            source.update(map(str, content))
-        else:
-            source.add(str(data.source))
-
-    for pattern in exclude:
-        source -= set(fnmatch.filter(source, pattern))
-    return list(source)
+        sources.update(retrieve(collection_id, request))
+    return list(sources)


 def _set_bound_coords(ds: xr.Dataset) -> xr.Dataset:
@@ -394,57 +400,24 @@ def _preprocess(
     return harmonise(ds, collection_id)


-def get_data(source: list[str]) -> Any:
-    if len(source) == 1:
-        return emohawk.open(source[0])
-
-    # TODO: emohawk not able to open a list of files
-    emohwak_dir = emohawk.readers.directory.DirectoryReader("")
-    emohwak_dir._content = source
-    return emohwak_dir
-
-
 def _download_and_transform_requests(
     collection_id: str,
     request_list: list[dict[str, Any]],
     transform_func: Callable[..., xr.Dataset] | None,
     transform_func_kwargs: dict[str, Any],
     **open_mfdataset_kwargs: Any,
 ) -> xr.Dataset:
-    # TODO: Ideally, we would always use emohawk.
-    # However, there is not a consistent behavior across backends.
-    # For example, GRIB silently ignore open_mfdataset_kwargs
     sources = get_sources(collection_id, request_list)
-    try:
-        engine = open_mfdataset_kwargs.get(
-            "engine",
-            {xr.backends.plugins.guess_engine(source) for source in sources},
-        )
-        use_emohawk = len(engine) != 1
-    except ValueError:
-        use_emohawk = True
-
     open_mfdataset_kwargs["preprocess"] = functools.partial(
         _preprocess,
         collection_id=collection_id,
         preprocess=open_mfdataset_kwargs.get("preprocess", None),
     )

-    if use_emohawk:
-        data = get_data(sources)
-        if isinstance(data, emohawk.readers.shapefile.ShapefileReader):
-            # FIXME: emohawk NotImplementedError
-            ds: xr.Dataset = data.to_pandas().to_xarray()
-        else:
-            ds = data.to_xarray(
-                xarray_open_mfdataset_kwargs=open_mfdataset_kwargs,
-                **TO_XARRAY_KWARGS,
-            )
-        if not isinstance(ds, xr.Dataset):
-            # When emohawk fails to concat, it silently return a list
-            raise TypeError(f"`emohawk` returned {type(ds)} instead of a xr.Dataset")
-    else:
-        ds = xr.open_mfdataset(sources, **open_mfdataset_kwargs)
+    ds = earthkit.data.from_source("file", sources).to_xarray(
+        xarray_open_mfdataset_kwargs=open_mfdataset_kwargs
+    )
+    if not isinstance(ds, xr.Dataset):
+        raise TypeError(f"`earthkit.data` returned {type(ds)} instead of a xr.Dataset")

     if transform_func is not None:
         with cacholote.config.set(return_cache_entry=False):
@@ -465,7 +438,7 @@ def _delayed_download(
     collection_id: str, request: dict[str, Any], config: cacholote.config.Settings
 ) -> None:
     with cacholote.config.set(**dict(config)):
-        _cached_retrieve(collection_id, request)
+        retrieve(collection_id, request)


 def download_and_transform(
Expand Down Expand Up @@ -547,6 +520,13 @@ def download_and_transform(
     for request in ensure_list(requests):
         request_list.extend(split_request(request, chunks, split_all))

+    if invalidate_cache and not use_cache:
+        # Delete raw data
+        for request in request_list:
+            cacholote.delete(
+                _cached_retrieve, collection_id=collection_id, request=request
+            )
+
     if n_jobs != 1:
         # Download all data in parallel
         joblib.Parallel(n_jobs=n_jobs)(
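For context when reviewing, here is a minimal sketch of how the pieces above fit together after the migration. It assumes a configured CDS API key; the collection id and request are illustrative, not taken from this PR.

import cacholote

from c3s_eqc_automatic_quality_control import download

# Illustrative request; any CDS collection id and request dict would do.
collection_id = "reanalysis-era5-single-levels"
request = {"variable": "2m_temperature", "year": "2022", "month": "01"}

# _cached_retrieve is wrapped by @cacholote.cacheable, so the CDS download
# runs once per (collection_id, request); retrieve() then resolves the cached
# fsspec openers back to plain local paths.
paths = download.retrieve(collection_id, request)

# What invalidate_cache does in download_and_transform: drop the raw-data
# cache entry so the next retrieve() hits the CDS again.
cacholote.delete(
    download._cached_retrieve, collection_id=collection_id, request=request
)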
5 changes: 2 additions & 3 deletions environment.yml
@@ -9,7 +9,6 @@ channels:
 # DO NOT EDIT ABOVE THIS LINE, ADD DEPENDENCIES BELOW AS SHOWN IN THE EXAMPLE
 dependencies:
   - cartopy
-  - cdsapi
   - cfgrib
   - cf-units
   - cf_xarray
@@ -40,8 +39,8 @@ dependencies:
   - xesmf
   - xskillscore
   - pip:
+      - git+https://github.com/ecmwf/cdsapi.git
       - cacholote
-      - cads-toolbox
       - cgul
-      - emohawk
+      - earthkit-data>=0.7.0
       - kaleido
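The pip-installed git build of cdsapi and earthkit-data>=0.7.0 supply the "cds" source that download.py now relies on. A standalone sanity check might look like the sketch below, assuming a configured CDS API key; the collection id and request are illustrative:

import earthkit.data

# Same call shape as _cached_retrieve in download.py; prompt=False suppresses
# the interactive credential prompt. Collection id and request are illustrative.
ds = earthkit.data.from_source(
    "cds",
    "reanalysis-era5-single-levels",
    {
        "variable": "2m_temperature",
        "year": "2022",
        "month": "01",
        "day": "01",
        "time": "00:00",
    },
    prompt=False,
)
# Multi-source results expose .sources; single files expose .path
# (this mirrors get_paths in download.py).
sources = ds.sources if hasattr(ds, "sources") else [ds]
print([source.path for source in sources if hasattr(source, "path")])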
5 changes: 3 additions & 2 deletions pyproject.toml
@@ -32,9 +32,10 @@ ignore_missing_imports = true
 module = [
     "cads_toolbox",
     "cartopy.*",
+    "cdsapi",
     "cgul",
-    "emohawk.*",
-    "fsspec",
+    "earthkit.*",
+    "fsspec.*",
     "joblib",
     "plotly.*",
     "shapely",
39 changes: 39 additions & 0 deletions tests/conftest.py
@@ -1,8 +1,47 @@
 import pathlib
+import tempfile
 from collections.abc import Generator
+from typing import Any

 import cacholote
+import cdsapi
+import fsspec
 import pytest
+import xarray as xr


+class MockResult:
+    def __init__(self, name: str, request: dict[str, Any]) -> None:
+        self.name = name
+        self.request = request
+
+    @property
+    def location(self) -> str:
+        return tempfile.NamedTemporaryFile(suffix=".nc", delete=False).name
+
+    def download(self, target: str | pathlib.Path | None = None) -> str | pathlib.Path:
+        ds = xr.tutorial.open_dataset(self.name).sel(**self.request)
+        ds.to_netcdf(path := target or self.location)
+        return path
+
+
+def mock_retrieve(
+    self: cdsapi.Client,
+    name: str,
+    request: dict[str, Any],
+    target: str | pathlib.Path | None = None,
+) -> fsspec.spec.AbstractBufferedFile:
+    result = MockResult(name, request)
+    if target is None:
+        return result
+    return result.download(target)
+
+
+@pytest.fixture(autouse=True)
+def mock_download(monkeypatch: pytest.MonkeyPatch) -> None:
+    monkeypatch.setenv("CDSAPI_URL", "")
+    monkeypatch.setenv("CDSAPI_KEY", "123456:1123e4567-e89b-12d3-a456-42665544000")
+    monkeypatch.setattr(cdsapi.Client, "retrieve", mock_retrieve)
+
+
 @pytest.fixture(autouse=True)
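With this autouse fixture, tests no longer need to patch cads_toolbox themselves (see the simplifications in tests/test_10_download.py below). A sketch of a test relying on the mock; the request keys are illustrative, and the collection name must be an xarray tutorial dataset for MockResult to resolve it:

import xarray as xr

from c3s_eqc_automatic_quality_control import download


def test_download_is_mocked() -> None:
    # cdsapi.Client.retrieve is monkeypatched, so nothing touches the real CDS:
    # MockResult writes xr.tutorial.open_dataset("air_temperature").sel(...)
    # to a temporary netCDF, which the downloader then opens.
    ds = download.download_and_transform(
        "air_temperature",
        {"time": "2013-01-01"},
    )
    assert isinstance(ds, xr.Dataset)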
14 changes: 1 addition & 13 deletions tests/test_10_download.py
@@ -1,11 +1,9 @@
 import datetime
 from typing import Any

-import cads_toolbox
 import pandas as pd
 import pytest
 import xarray as xr
-from utils import mock_download

 from c3s_eqc_automatic_quality_control import download

@@ -271,12 +269,9 @@ def test_ensure_request_gets_cached() -> None:
     ],
 )
 def test_download_no_transform(
-    monkeypatch: pytest.MonkeyPatch,
     chunks: dict[str, int],
     dask_chunks: dict[str, tuple[int, ...]],
 ) -> None:
-    monkeypatch.setattr(cads_toolbox.catalogue, "_download", mock_download)
-
     ds = download.download_and_transform(*AIR_TEMPERATURE_REQUEST, chunks=chunks)
     assert dict(ds.chunks) == dask_chunks

@@ -289,12 +284,9 @@ def test_download_no_transform(
     ],
 )
 def test_download_and_transform(
-    monkeypatch: pytest.MonkeyPatch,
     transform_chunks: bool,
     dask_chunks: dict[str, tuple[int, ...]],
 ) -> None:
-    monkeypatch.setattr(cads_toolbox.catalogue, "_download", mock_download)
-
     def transform_func(ds: xr.Dataset) -> xr.Dataset:
         return ds.round().mean(("longitude", "latitude"))

@@ -310,11 +302,7 @@ def transform_func(ds: xr.Dataset) -> xr.Dataset:

 @pytest.mark.parametrize("transform_chunks", [True, False])
 @pytest.mark.parametrize("invalidate_cache", [True, False])
-def test_invalidate_cache(
-    monkeypatch: pytest.MonkeyPatch, transform_chunks: bool, invalidate_cache: bool
-) -> None:
-    monkeypatch.setattr(cads_toolbox.catalogue, "_download", mock_download)
-
+def test_invalidate_cache(transform_chunks: bool, invalidate_cache: bool) -> None:
     def transform_func(ds: xr.Dataset) -> xr.Dataset:
         return ds * 0

19 changes: 0 additions & 19 deletions tests/utils.py

This file was deleted.