From 4ec092902716b2e68821178d83526e75af77ec53 Mon Sep 17 00:00:00 2001 From: ponyisi Date: Thu, 18 Jul 2024 01:47:11 -0500 Subject: [PATCH] Refactor query class inheritance (#410) * Fix examples * Do not subclass Query directly * Fix version issue * Remove unneeded code * No need for overriding deepcopy any more * Explicit dependence for typing_extensions * Eliminate now invalid union --------- Co-authored-by: KyungEon Choi <54450665+kyungeonchoi@users.noreply.github.com> --- pyproject.toml | 3 +- servicex/databinder_models.py | 4 +- servicex/dataset_group.py | 6 +- servicex/func_adl/func_adl_dataset.py | 48 +------- servicex/python_dataset.py | 42 ++----- servicex/query_core.py | 27 ++--- servicex/servicex_client.py | 154 ++++---------------------- tests/conftest.py | 14 ++- tests/test_databinder.py | 4 +- tests/test_dataset.py | 49 ++++++-- tests/test_func_adl_dataset.py | 79 ++++++++++++- tests/test_python_dataset.py | 21 +--- tests/test_servicex_dataset.py | 69 +++++------- 13 files changed, 206 insertions(+), 314 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 4449627a..1a1482d6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,7 @@ dependencies = [ "PyYAML>=6.0", "types-PyYAML>=6.0", "importlib_metadata; python_version <= '3.9'", + "typing_extensions; python_version <= '3.10'", # compatible versions controlled through pydantic "rich>=13.0.0", # databinder "aiofile", # compatible versions controlled through miniopy-async "make-it-sync", # compatible versions controlled through func_adl @@ -98,7 +99,7 @@ FuncADL_ATLASr21 = "servicex.func_adl.func_adl_dataset:FuncADLQuery_ATLASr21" FuncADL_ATLASr22 = "servicex.func_adl.func_adl_dataset:FuncADLQuery_ATLASr22" FuncADL_ATLASxAOD = "servicex.func_adl.func_adl_dataset:FuncADLQuery_ATLASxAOD" FuncADL_CMS = "servicex.func_adl.func_adl_dataset:FuncADLQuery_CMS" -PythonFunction = "servicex.python_dataset:PythonQuery" +PythonFunction = "servicex.python_dataset:PythonFunction" UprootRaw = "servicex.uproot_raw.uproot_raw:UprootRawQuery" [project.entry-points.'servicex.dataset'] diff --git a/servicex/databinder_models.py b/servicex/databinder_models.py index f030b9ad..1109bd96 100644 --- a/servicex/databinder_models.py +++ b/servicex/databinder_models.py @@ -35,7 +35,7 @@ from servicex.dataset_identifier import (DataSetIdentifier, RucioDatasetIdentifier, FileListDataset) -from servicex.query_core import Query as SXQuery, QueryStringGenerator +from servicex.query_core import QueryStringGenerator from servicex.models import ResultFormat @@ -46,7 +46,7 @@ class Sample(BaseModel): XRootDFiles: Optional[Union[str, List[str]]] = None Dataset: Optional[DataSetIdentifier] = None NFiles: Optional[int] = Field(default=None) - Query: Optional[Union[str, SXQuery, QueryStringGenerator]] = Field(default=None) + Query: Optional[Union[str, QueryStringGenerator]] = Field(default=None) IgnoreLocalCache: bool = False model_config = {"arbitrary_types_allowed": True} diff --git a/servicex/dataset_group.py b/servicex/dataset_group.py index d7d94f2a..7d459834 100644 --- a/servicex/dataset_group.py +++ b/servicex/dataset_group.py @@ -26,18 +26,16 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import asyncio -from typing import List, Optional, Union - +from typing import List, Optional from rich.progress import Progress from servicex.query_core import Query from servicex.expandable_progress import ExpandableProgress -from servicex.func_adl.func_adl_dataset import FuncADLQuery from servicex.models import TransformedResults, ResultFormat from make_it_sync import make_sync -DatasetGroupMember = Union[Query, FuncADLQuery] +DatasetGroupMember = Query class DatasetGroup: diff --git a/servicex/func_adl/func_adl_dataset.py b/servicex/func_adl/func_adl_dataset.py index 372dd102..785fbdf5 100644 --- a/servicex/func_adl/func_adl_dataset.py +++ b/servicex/func_adl/func_adl_dataset.py @@ -45,19 +45,14 @@ from func_adl import EventDataset, find_EventDataset from func_adl.object_stream import S -from servicex.configuration import Configuration -from servicex.query_core import Query +from servicex.query_core import QueryStringGenerator from servicex.func_adl.util import has_tuple -from servicex.models import ResultFormat -from servicex.query_cache import QueryCache -from servicex.servicex_adapter import ServiceXAdapter -from servicex.types import DID from abc import ABC T = TypeVar("T") -class FuncADLQuery(Query, EventDataset[T], ABC): +class FuncADLQuery(QueryStringGenerator, EventDataset[T], ABC): r""" ServiceX Dataset class that uses func_adl query syntax. """ @@ -75,53 +70,14 @@ def check_data_format_request(self, f_name: str): def __init__( self, - dataset_identifier: DID = None, - sx_adapter: ServiceXAdapter = None, - title: str = "ServiceX Client", - codegen: Optional[str] = None, - config: Configuration = None, - query_cache: QueryCache = None, - result_format: Optional[ResultFormat] = None, item_type: Type = Any, - ignore_cache: bool = False, ): - Query.__init__( - self, - dataset_identifier=dataset_identifier, - title=title, - codegen=codegen if codegen is not None else self.default_codegen, - sx_adapter=sx_adapter, - config=config, - query_cache=query_cache, - result_format=result_format, - ignore_cache=ignore_cache, - ) EventDataset.__init__(self, item_type=item_type) self.provided_qastle = None def set_provided_qastle(self, qastle: str): self.provided_qastle = qastle - def __deepcopy__(self, memo): - """ - Customize deepcopy behavior for this class. - We need to be careful because the query cache is a tinyDB database that holds an - open file pointer. We are not allowed to clone an open file handle, so for this - property we will copy by reference and share it between the clones - """ - cls = self.__class__ - obj = cls.__new__(cls) - - memo[id(self)] = obj - - for attr, value in vars(self).items(): - if type(value) is QueryCache: - setattr(obj, attr, value) - else: - setattr(obj, attr, copy.deepcopy(value, memo)) - - return obj - def SelectMany( self, func: Union[str, ast.Lambda, Callable[[T], Iterable[S]]] ) -> FuncADLQuery[S]: diff --git a/servicex/python_dataset.py b/servicex/python_dataset.py index 38cd8774..3d1f4573 100644 --- a/servicex/python_dataset.py +++ b/servicex/python_dataset.py @@ -29,39 +29,22 @@ from typing import Optional, Union, Callable from base64 import b64encode from textwrap import dedent +from servicex.query_core import QueryStringGenerator +import sys +if sys.version_info < (3, 11): + from typing_extensions import Self +else: + from typing import Self -from servicex.configuration import Configuration -from servicex.query_core import Query -from servicex.models import ResultFormat -from servicex.query_cache import QueryCache -from servicex.servicex_adapter import ServiceXAdapter -from servicex.types import DID - -class PythonQuery(Query): +class PythonFunction(QueryStringGenerator): yaml_tag = '!PythonFunction' + default_codegen = 'python' - def __init__(self, dataset_identifier: DID = None, - sx_adapter: Optional[ServiceXAdapter] = None, - title: str = "ServiceX Client", - codegen: str = 'python', - config: Optional[Configuration] = None, - query_cache: Optional[QueryCache] = None, - result_format: Optional[ResultFormat] = None, - ignore_cache: bool = False - ): - super().__init__(dataset_identifier=dataset_identifier, - title=title, - codegen=codegen, - sx_adapter=sx_adapter, - config=config, - query_cache=query_cache, - result_format=result_format, - ignore_cache=ignore_cache) - - self.python_function = None + def __init__(self, python_function: Optional[Union[str, Callable]] = None): + self.python_function: Optional[Union[str, Callable]] = python_function - def with_uproot_function(self, f: Union[str, Callable]) -> Query: + def with_uproot_function(self, f: Union[str, Callable]) -> Self: self.python_function = f return self @@ -83,6 +66,5 @@ def from_yaml(cls, _, node): exec(code) except SyntaxError as e: raise SyntaxError(f"Syntax error {e} interpreting\n{code}") - q = PythonQuery() - q.with_uproot_function(code) + q = PythonFunction(code) return q diff --git a/servicex/query_core.py b/servicex/query_core.py index c68d375b..bae2ab27 100644 --- a/servicex/query_core.py +++ b/servicex/query_core.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022, IRIS-HEP +# Copyright (c) 2024, IRIS-HEP # All rights reserved. # # Redistribution and use in source and binary forms, with or without @@ -72,7 +72,7 @@ class ServiceXException(Exception): """ Something happened while trying to carry out a ServiceX request """ -class Query(ABC): +class Query: def __init__( self, dataset_identifier: DID, @@ -80,11 +80,12 @@ def __init__( codegen: str, sx_adapter: ServiceXAdapter, config: Configuration, - query_cache: QueryCache, + query_cache: Optional[QueryCache], servicex_polling_interval: int = 5, minio_polling_interval: int = 5, result_format: ResultFormat = ResultFormat.parquet, ignore_cache: bool = False, + query_string_generator: Optional[QueryStringGenerator] = None, ): r""" This is the main class for constructing transform requests and receiving the @@ -103,7 +104,6 @@ def __init__( :param result_format: :param ignore_cache: If true, ignore the cache and always submit a new transform """ - super(Query, self).__init__() self.servicex = sx_adapter self.configuration = config self.cache = query_cache @@ -123,14 +123,16 @@ def __init__( self.request_id = None self.ignore_cache = ignore_cache + self.query_string_generator = query_string_generator # Number of seconds in between ServiceX status polls self.servicex_polling_interval = servicex_polling_interval self.minio_polling_interval = minio_polling_interval - @abc.abstractmethod def generate_selection_string(self) -> str: - pass + if self.query_string_generator is None: + raise RuntimeError('query string generator not set') + return self.query_string_generator.generate_selection_string() @property def transform_request(self): @@ -627,16 +629,3 @@ def __init__(self, query: str, codegen: str): def generate_selection_string(self) -> str: return self.query - - -class GenericQuery(Query): - ''' - This class gives a "generic" Query object which doesn't require - overloading the constructor - ''' - query_string_generator: Optional[QueryStringGenerator] = None - - def generate_selection_string(self) -> str: - if self.query_string_generator is None: - raise RuntimeError('query string generator not set') - return self.query_string_generator.generate_selection_string() diff --git a/servicex/servicex_client.py b/servicex/servicex_client.py index af61fc00..352a659b 100644 --- a/servicex/servicex_client.py +++ b/servicex/servicex_client.py @@ -26,22 +26,19 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import logging -from typing import Optional, List, TypeVar, Any, Type, Mapping, Union +from typing import Optional, List, TypeVar, Any, Mapping, Union from pathlib import Path from servicex.configuration import Configuration -from servicex.func_adl.func_adl_dataset import FuncADLQuery from servicex.models import ResultFormat, TransformStatus, TransformedResults from servicex.query_cache import QueryCache from servicex.servicex_adapter import ServiceXAdapter from servicex.query_core import ( - GenericQuery, + Query, QueryStringGenerator, GenericQueryStringGenerator, - Query, ) from servicex.types import DID -from servicex.python_dataset import PythonQuery from servicex.dataset_group import DatasetGroup from make_it_sync import make_sync @@ -105,54 +102,18 @@ def get_codegen(_sample: Sample, _general: General): sx = ServiceXClient(backend=servicex_name, config_path=config_path) datasets = [] for sample in config.Sample: - # if string or QueryStringGenerator, turn into a Query - if isinstance(sample.Query, str) or isinstance( - sample.Query, QueryStringGenerator - ): - logger.debug("sample.Query from string or QueryStringGenerator") - sample.Query = sx.generic_query( - dataset_identifier=sample.dataset_identifier, - title=sample.Name, - codegen=get_codegen(sample, config.General), - result_format=config.General.OutputFormat, - ignore_cache=sample.IgnoreLocalCache, - query=sample.Query, - ) - elif isinstance(sample.Query, FuncADLQuery): - logger.debug("sample.Query from FuncADLQuery") - logger.debug( - f"qastle_query from ServiceXSpec: {sample.Query.generate_selection_string()}" - ) - query = sx.func_adl_dataset( - dataset_identifier=sample.dataset_identifier, - title=sample.Name, - codegen=get_codegen(sample, config.General), - result_format=config.General.OutputFormat, - ignore_cache=sample.IgnoreLocalCache, - ) - query.set_provided_qastle(sample.Query.generate_selection_string()) - - sample.Query = query - - logger.debug( - f"final qastle_query: {sample.Query.generate_selection_string()}" - ) - elif isinstance(sample.Query, PythonQuery): - logger.debug("sample.Query from PythonQuery") - query = sx.python_dataset( - dataset_identifier=sample.dataset_identifier, - title=sample.Name, - codegen=get_codegen(sample, config.General), - result_format=config.General.OutputFormat, - ignore_cache=sample.IgnoreLocalCache, - ) - query.python_function = sample.Query.python_function - sample.Query = query - else: - logger.debug(f"Unknown Query type: {sample.Query}") - sample.Query.ignore_cache = sample.IgnoreLocalCache + query = sx.generic_query( + dataset_identifier=sample.dataset_identifier, + title=sample.Name, + codegen=get_codegen(sample, config.General), + result_format=config.General.OutputFormat, + ignore_cache=sample.IgnoreLocalCache, + query=sample.Query, + ) + logger.debug(f"Query string: {query.generate_selection_string()}") + query.ignore_cache = sample.IgnoreLocalCache - datasets.append(sample.Query) + datasets.append(query) return datasets @@ -275,94 +236,15 @@ def get_code_generators(self, backend=None): self.query_cache.update_codegen_by_backend(backend, code_generators) return code_generators - def func_adl_dataset( - self, - dataset_identifier: DID, - title: str = "ServiceX Client", - codegen: str = "uproot", - result_format: Optional[ResultFormat] = None, - item_type: Type[T] = Any, - ignore_cache: bool = False, - ) -> FuncADLQuery[T]: - r""" - Generate a dataset that can use func_adl query language - - :param dataset_identifier: The dataset identifier or filelist to be the source of files - :param title: Title to be applied to the transform. This is also useful for - relating transform results. - :param codegen: Name of the code generator to use with this transform - :param result_format: Do you want Paqrquet or Root? This can be set later with - the set_result_format method - :param item_type: The type of the items that will be returned from the query - :param ignore_cache: Ignore the query cache and always run the query - :return: A func_adl dataset ready to accept query statements. - """ - if codegen not in self.code_generators: - raise NameError( - f"{codegen} code generator not supported by serviceX " - f"deployment at {self.servicex.url}" - ) - - return FuncADLQuery( - dataset_identifier, - sx_adapter=self.servicex, - title=title, - codegen=codegen, - config=self.config, - query_cache=self.query_cache, - result_format=result_format, - item_type=item_type, - ignore_cache=ignore_cache, - ) - - def python_dataset( - self, - dataset_identifier: DID, - title: str = "ServiceX Client", - codegen: str = "uproot", - result_format: Optional[ResultFormat] = None, - ignore_cache: bool = False, - ) -> PythonQuery: - r""" - Generate a dataset that can use accept a python function for the query - - :param dataset_identifier: The dataset identifier or filelist to be the source of files - :param title: Title to be applied to the transform. This is also useful for - relating transform results. - :param codegen: Name of the code generator to use with this transform - :param result_format: Do you want Paqrquet or Root? This can be set later with - the set_result_format method - :param ignore_cache: Ignore the query cache and always run the query - :return: A func_adl dataset ready to accept a python function statements. - - """ - - if codegen not in self.code_generators: - raise NameError( - f"{codegen} code generator not supported by serviceX " - f"deployment at {self.servicex.url}" - ) - - return PythonQuery( - dataset_identifier, - sx_adapter=self.servicex, - title=title, - codegen=codegen, - config=self.config, - query_cache=self.query_cache, - result_format=result_format, - ignore_cache=ignore_cache, - ) - def generic_query( self, dataset_identifier: DID, query: Union[str, QueryStringGenerator], - codegen: str = None, + codegen: Optional[str] = None, title: str = "ServiceX Client", result_format: ResultFormat = ResultFormat.parquet, ignore_cache: bool = False, - ) -> GenericQuery: + ) -> Query: r""" Generate a Query object for a generic codegen specification @@ -378,6 +260,8 @@ def generic_query( """ if isinstance(query, str): + if codegen is None: + raise RuntimeError("A pure string query requires a codegen argument as well") query = GenericQueryStringGenerator(query, codegen) if not isinstance(query, QueryStringGenerator): raise ValueError("query argument must be string or QueryStringGenerator") @@ -394,7 +278,7 @@ def generic_query( f"deployment at {self.servicex.url}" ) - qobj = GenericQuery( + qobj = Query( dataset_identifier=dataset_identifier, sx_adapter=self.servicex, title=title, @@ -403,6 +287,6 @@ def generic_query( query_cache=self.query_cache, result_format=result_format, ignore_cache=ignore_cache, + query_string_generator=query ) - qobj.query_string_generator = query return qobj diff --git a/tests/conftest.py b/tests/conftest.py index 05b5a627..4e0d1e99 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022, IRIS-HEP +# Copyright (c) 2024, IRIS-HEP # All rights reserved. # # Redistribution and use in source and binary forms, with or without @@ -28,7 +28,8 @@ from datetime import datetime from pytest_asyncio import fixture -from servicex.python_dataset import PythonQuery +from servicex.python_dataset import PythonFunction +from servicex.query_core import Query from servicex.models import ( TransformRequest, ResultDestination, @@ -64,17 +65,20 @@ def minio_adapter() -> MinioAdapter: @fixture def python_dataset(dummy_parquet_file): did = FileListDataset(dummy_parquet_file) - dataset = PythonQuery( + dataset = Query( title="Test submission", dataset_identifier=did, codegen="uproot", - result_format=ResultFormat.parquet, # type: ignore + result_format=ResultFormat.parquet, + sx_adapter=None, # type: ignore + config=None, # type: ignore + query_cache=None # type: ignore ) # type: ignore def foo(): return - dataset.with_uproot_function(foo) + dataset.query_string_generator = PythonFunction(foo) return dataset diff --git a/tests/test_databinder.py b/tests/test_databinder.py index 698dee29..119e98da 100644 --- a/tests/test_databinder.py +++ b/tests/test_databinder.py @@ -218,7 +218,7 @@ def run_query(input_filenames=None): """) f.flush() result = _load_ServiceXSpec(path) - assert type(result.Sample[0].Query).__name__ == 'PythonQuery' + assert type(result.Sample[0].Query).__name__ == 'PythonFunction' assert type(result.Sample[1].Query).__name__ == 'FuncADLQuery_Uproot' assert type(result.Sample[2].Query).__name__ == 'UprootRawQuery' assert isinstance(result.Sample[3].dataset_identifier, Rucio) @@ -232,7 +232,7 @@ def run_query(input_filenames=None): # Path from string result2 = _load_ServiceXSpec(str(path)) - assert type(result2.Sample[0].Query).__name__ == 'PythonQuery' + assert type(result2.Sample[0].Query).__name__ == 'PythonFunction' # Python syntax error with open(path := (tmp_path / "python.yaml"), "w") as f: diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 3b38eaad..b0631db2 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -1,3 +1,30 @@ +# Copyright (c) 2024, IRIS-HEP +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import pytest import pandas as pd import tempfile @@ -7,7 +34,7 @@ from servicex.dataset_identifier import FileListDataset from servicex.configuration import Configuration from servicex.minio_adapter import MinioAdapter -from servicex.python_dataset import PythonQuery +from servicex.query_core import Query from servicex.query_cache import QueryCache from servicex.expandable_progress import ExpandableProgress from servicex.query_core import ServiceXException @@ -24,8 +51,9 @@ async def test_as_signed_urls_happy(transformed_result): # Test when display_progress is True and provided_progress is None did = FileListDataset("/foo/bar/baz.root") - dataset = PythonQuery(dataset_identifier=did, codegen="uproot", - sx_adapter=None, query_cache=None) + dataset = Query(dataset_identifier=did, codegen="uproot", + title="", config=None, + sx_adapter=None, query_cache=None) dataset.submit_and_download = AsyncMock() dataset.submit_and_download.return_value = transformed_result @@ -37,8 +65,9 @@ async def test_as_signed_urls_happy(transformed_result): async def test_as_signed_urls_happy_dataset_group(transformed_result): # Test when display_progress is True and provided_progress is None did = FileListDataset("/foo/bar/baz.root") - dataset = PythonQuery(dataset_identifier=did, codegen="uproot", - sx_adapter=None, query_cache=None) + dataset = Query(dataset_identifier=did, codegen="uproot", + title="", config=None, + sx_adapter=None, query_cache=None) dataset.submit_and_download = AsyncMock() dataset.submit_and_download.return_value = transformed_result @@ -50,8 +79,9 @@ async def test_as_signed_urls_happy_dataset_group(transformed_result): @pytest.mark.asyncio async def test_as_files_happy(transformed_result): did = FileListDataset("/foo/bar/baz.root") - dataset = PythonQuery(dataset_identifier=did, codegen="uproot", - sx_adapter=None, query_cache=None) + dataset = Query(dataset_identifier=did, codegen="uproot", + title="", config=None, + sx_adapter=None, query_cache=None) dataset.submit_and_download = AsyncMock() dataset.submit_and_download.return_value = transformed_result @@ -66,8 +96,9 @@ async def test_as_pandas_happy(transformed_result): with tempfile.TemporaryDirectory() as temp_dir: config = Configuration(cache_path=temp_dir, api_endpoints=[]) cache = QueryCache(config) - dataset = PythonQuery(dataset_identifier=did, codegen="uproot", sx_adapter=servicex, - query_cache=cache) + dataset = Query(dataset_identifier=did, codegen="uproot", sx_adapter=servicex, + title="", config=None, + query_cache=cache) dataset.submit_and_download = AsyncMock() dataset.submit_and_download.return_value = transformed_result result = dataset.as_pandas(display_progress=False) diff --git a/tests/test_func_adl_dataset.py b/tests/test_func_adl_dataset.py index 09c2a2a1..c3e3d153 100644 --- a/tests/test_func_adl_dataset.py +++ b/tests/test_func_adl_dataset.py @@ -1,11 +1,80 @@ -from servicex.func_adl.func_adl_dataset import FuncADLQuery_Uproot -from servicex.dataset_identifier import FileListDataset +# Copyright (c) 2024, IRIS-HEP +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +from servicex.func_adl.func_adl_dataset import FuncADLQuery_Uproot, FuncADLQuery +from typing import Any def test_set_from_tree(): - did = FileListDataset("/foo/bar/baz.root") - query = FuncADLQuery_Uproot(dataset_identifier=did, - codegen="uproot") + query = FuncADLQuery_Uproot() query = query.FromTree("TREE_NAME") assert "TREE_NAME" in query.generate_selection_string() + + +def test_a_query(): + query = FuncADLQuery_Uproot() + query = query.FromTree("nominal") \ + .Select(lambda e: {"lep_pt": e["lep_pt"]}) + + assert (query.generate_selection_string() + == "(call Select (call EventDataset 'bogus.root' 'nominal') " + "(lambda (list e) (dict (list 'lep_pt') " + "(list (subscript e 'lep_pt')))))" + ) + + +def test_set_query(): + qastle = "(call Select (call EventDataset 'bogus.root' 'nominal') " \ + "(lambda (list e) (dict (list 'lep_pt') " \ + "(list (subscript e 'lep_pt')))))" + query = FuncADLQuery_Uproot() + query.set_provided_qastle(qastle) + + assert (query.generate_selection_string() == qastle) + + +def test_type(): + "Test that the item type for a dataset is correctly propagated" + + class my_type_info: + "typespec for possible event type" + + def fork_it_over(self) -> int: + ... + + datasource = FuncADLQuery[my_type_info]( + item_type=my_type_info + ) + + assert datasource.item_type == my_type_info + + +def test_type_any(): + "Test the type is any if no type is given" + datasource = FuncADLQuery() + assert datasource.item_type == Any diff --git a/tests/test_python_dataset.py b/tests/test_python_dataset.py index 1c2c541e..05181514 100644 --- a/tests/test_python_dataset.py +++ b/tests/test_python_dataset.py @@ -29,27 +29,18 @@ import pytest -from servicex.dataset_identifier import FileListDataset -from servicex.python_dataset import PythonQuery +from servicex.python_dataset import PythonFunction def test_no_provided_function(): - did = FileListDataset("/foo/bar/baz.root") - datasource = PythonQuery(dataset_identifier=did, - codegen="uproot", - sx_adapter=None, - query_cache=None) + datasource = PythonFunction() with pytest.raises(ValueError): print(datasource.generate_selection_string()) def test_generate_transform(): - did = FileListDataset("/foo/bar/baz.root") - datasource = PythonQuery(dataset_identifier=did, - codegen="uproot", - sx_adapter=None, - query_cache=None) + datasource = PythonFunction() def run_query(input_filenames=None): print("Greetings from your query") @@ -62,11 +53,7 @@ def run_query(input_filenames=None): def test_function_as_string(): - did = FileListDataset("/foo/bar/baz.root") - datasource = PythonQuery(dataset_identifier=did, - codegen="uproot", - sx_adapter=None, - query_cache=None) + datasource = PythonFunction() string_function = """ def run_query(input_filenames=None): diff --git a/tests/test_servicex_dataset.py b/tests/test_servicex_dataset.py index 8b8a3bd7..3b533e70 100644 --- a/tests/test_servicex_dataset.py +++ b/tests/test_servicex_dataset.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022, IRIS-HEP +# Copyright (c) 2024, IRIS-HEP # All rights reserved. # # Redistribution and use in source and binary forms, with or without @@ -26,7 +26,7 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import tempfile -from typing import Any, List +from typing import List from unittest.mock import AsyncMock from pathlib import PurePath @@ -35,11 +35,11 @@ from servicex.configuration import Configuration from servicex.dataset_identifier import FileListDataset from servicex.expandable_progress import ExpandableProgress -from servicex.func_adl.func_adl_dataset import FuncADLQuery +from servicex.func_adl.func_adl_dataset import FuncADLQuery_Uproot from servicex.models import (TransformStatus, Status, ResultFile, ResultFormat, TransformRequest, TransformedResults) from servicex.query_cache import QueryCache -from servicex.query_core import ServiceXException +from servicex.query_core import ServiceXException, Query from servicex.servicex_client import ServiceXClient from servicex.uproot_raw.uproot_raw import UprootRawQuery @@ -201,13 +201,15 @@ async def test_submit(mocker): mock_cache.cache_path_for_transform = mocker.MagicMock(return_value=PurePath('.')) mocker.patch("servicex.minio_adapter.MinioAdapter", return_value=mock_minio) did = FileListDataset("/foo/bar/baz.root") - datasource = FuncADLQuery( + datasource = Query( dataset_identifier=did, + title="ServiceX Client", codegen="uproot", sx_adapter=servicex, query_cache=mock_cache, config=Configuration(api_endpoints=[]), ) + datasource.query_string_generator = FuncADLQuery_Uproot() with ExpandableProgress(display_progress=False) as progress: datasource.result_format = ResultFormat.parquet result = await datasource.submit_and_download(signed_urls_only=False, @@ -240,13 +242,15 @@ async def test_submit_partial_success(mocker): mock_cache.cache_path_for_transform = mocker.MagicMock(return_value=PurePath('.')) mocker.patch("servicex.minio_adapter.MinioAdapter", return_value=mock_minio) did = FileListDataset("/foo/bar/baz.root") - datasource = FuncADLQuery( + datasource = Query( dataset_identifier=did, + title="ServiceX Client", codegen="uproot", sx_adapter=servicex, query_cache=mock_cache, config=Configuration(api_endpoints=[]), ) + datasource.query_string_generator = FuncADLQuery_Uproot() with ExpandableProgress(display_progress=False) as progress: datasource.result_format = ResultFormat.parquet result = await datasource.submit_and_download(signed_urls_only=False, @@ -279,13 +283,15 @@ async def test_use_of_cache(mocker): with tempfile.TemporaryDirectory() as temp_dir: config = Configuration(cache_path=temp_dir, api_endpoints=[]) cache = QueryCache(config) - datasource = FuncADLQuery( + datasource = Query( dataset_identifier=did, + title="ServiceX Client", codegen="uproot", sx_adapter=servicex, query_cache=cache, config=config, ) + datasource.query_string_generator = FuncADLQuery_Uproot() datasource.result_format = ResultFormat.parquet upd = mocker.patch.object(cache, 'update_record', side_effect=cache.update_record) with ExpandableProgress(display_progress=False) as progress: @@ -332,13 +338,15 @@ async def test_submit_cancel(mocker): mock_cache.cache_path_for_transform = mocker.MagicMock(return_value=PurePath('.')) mocker.patch("servicex.minio_adapter.MinioAdapter", return_value=mock_minio) did = FileListDataset("/foo/bar/baz.root") - datasource = FuncADLQuery( + datasource = Query( dataset_identifier=did, + title="ServiceX Client", codegen="uproot", sx_adapter=servicex, query_cache=mock_cache, config=Configuration(api_endpoints=[]), ) + datasource.query_string_generator = FuncADLQuery_Uproot() with ExpandableProgress(display_progress=False) as progress: datasource.result_format = ResultFormat.parquet with pytest.raises(ServiceXException): @@ -368,13 +376,15 @@ async def test_submit_fatal(mocker): mock_cache.cache_path_for_transform = mocker.MagicMock(return_value=PurePath('.')) mocker.patch("servicex.minio_adapter.MinioAdapter", return_value=mock_minio) did = FileListDataset("/foo/bar/baz.root") - datasource = FuncADLQuery( + datasource = Query( dataset_identifier=did, + title="ServiceX Client", codegen="uproot", sx_adapter=servicex, query_cache=mock_cache, config=Configuration(api_endpoints=[]), ) + datasource.query_string_generator = FuncADLQuery_Uproot() with ExpandableProgress(display_progress=False) as progress: datasource.result_format = ResultFormat.parquet with pytest.raises(ServiceXException): @@ -469,42 +479,23 @@ def test_transform_request(): with tempfile.TemporaryDirectory() as temp_dir: config = Configuration(cache_path=temp_dir, api_endpoints=[]) cache = QueryCache(config) - datasource = FuncADLQuery( + datasource = Query( dataset_identifier=did, + title="ServiceX Client", codegen="uproot", sx_adapter=servicex, - config=config, - query_cache=cache, + query_cache=None, + config=Configuration(api_endpoints=[]), ) + datasource.query_string_generator = (FuncADLQuery_Uproot() + .FromTree("nominal") + .Select(lambda e: {"lep_pt": e["lep_pt"]})) q = ( - datasource.Select(lambda e: {"lep_pt": e["lep_pt"]}) - .set_result_format(ResultFormat.parquet) + datasource.set_result_format(ResultFormat.parquet) .transform_request ) - print("Qastle is ", q) + assert q.selection == "(call Select (call EventDataset 'bogus.root' 'nominal') " \ + "(lambda (list e) (dict (list 'lep_pt') " \ + "(list (subscript e 'lep_pt')))))" cache.close() - - -def test_type(): - "Test that the item type for a dataset is correctly propagated" - - class my_type_info: - "typespec for possible event type" - - def fork_it_over(self) -> int: - ... - - did = FileListDataset("/foo/bar/baz.root") - datasource = FuncADLQuery[my_type_info]( - dataset_identifier=did, codegen="uproot", item_type=my_type_info - ) - - assert datasource.item_type == my_type_info - - -def test_type_any(): - "Test the type is any if no type is given" - did = FileListDataset("/foo/bar/baz.root") - datasource = FuncADLQuery(dataset_identifier=did, codegen="uproot") - assert datasource.item_type == Any