Skip to content

Commit

Permalink
Refactor query class inheritance (#410)
Browse files Browse the repository at this point in the history
* Fix examples

* Do not subclass Query directly

* Fix version issue

* Remove unneeded code

* No need for overriding deepcopy any more

* Explicit dependence for typing_extensions

* Eliminate now invalid union

---------

Co-authored-by: KyungEon Choi <[email protected]>
  • Loading branch information
ponyisi and kyungeonchoi authored Jul 18, 2024
1 parent 8817a5a commit 4ec0929
Show file tree
Hide file tree
Showing 13 changed files with 206 additions and 314 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ dependencies = [
"PyYAML>=6.0",
"types-PyYAML>=6.0",
"importlib_metadata; python_version <= '3.9'",
"typing_extensions; python_version <= '3.10'", # compatible versions controlled through pydantic
"rich>=13.0.0", # databinder
"aiofile", # compatible versions controlled through miniopy-async
"make-it-sync", # compatible versions controlled through func_adl
Expand Down Expand Up @@ -98,7 +99,7 @@ FuncADL_ATLASr21 = "servicex.func_adl.func_adl_dataset:FuncADLQuery_ATLASr21"
FuncADL_ATLASr22 = "servicex.func_adl.func_adl_dataset:FuncADLQuery_ATLASr22"
FuncADL_ATLASxAOD = "servicex.func_adl.func_adl_dataset:FuncADLQuery_ATLASxAOD"
FuncADL_CMS = "servicex.func_adl.func_adl_dataset:FuncADLQuery_CMS"
PythonFunction = "servicex.python_dataset:PythonQuery"
PythonFunction = "servicex.python_dataset:PythonFunction"
UprootRaw = "servicex.uproot_raw.uproot_raw:UprootRawQuery"

[project.entry-points.'servicex.dataset']
Expand Down
4 changes: 2 additions & 2 deletions servicex/databinder_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@

from servicex.dataset_identifier import (DataSetIdentifier, RucioDatasetIdentifier,
FileListDataset)
from servicex.query_core import Query as SXQuery, QueryStringGenerator
from servicex.query_core import QueryStringGenerator
from servicex.models import ResultFormat


Expand All @@ -46,7 +46,7 @@ class Sample(BaseModel):
XRootDFiles: Optional[Union[str, List[str]]] = None
Dataset: Optional[DataSetIdentifier] = None
NFiles: Optional[int] = Field(default=None)
Query: Optional[Union[str, SXQuery, QueryStringGenerator]] = Field(default=None)
Query: Optional[Union[str, QueryStringGenerator]] = Field(default=None)
IgnoreLocalCache: bool = False

model_config = {"arbitrary_types_allowed": True}
Expand Down
6 changes: 2 additions & 4 deletions servicex/dataset_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,16 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import asyncio
from typing import List, Optional, Union

from typing import List, Optional
from rich.progress import Progress

from servicex.query_core import Query
from servicex.expandable_progress import ExpandableProgress
from servicex.func_adl.func_adl_dataset import FuncADLQuery
from servicex.models import TransformedResults, ResultFormat
from make_it_sync import make_sync


DatasetGroupMember = Union[Query, FuncADLQuery]
DatasetGroupMember = Query


class DatasetGroup:
Expand Down
48 changes: 2 additions & 46 deletions servicex/func_adl/func_adl_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,19 +45,14 @@

from func_adl import EventDataset, find_EventDataset
from func_adl.object_stream import S
from servicex.configuration import Configuration
from servicex.query_core import Query
from servicex.query_core import QueryStringGenerator
from servicex.func_adl.util import has_tuple
from servicex.models import ResultFormat
from servicex.query_cache import QueryCache
from servicex.servicex_adapter import ServiceXAdapter
from servicex.types import DID
from abc import ABC

T = TypeVar("T")


class FuncADLQuery(Query, EventDataset[T], ABC):
class FuncADLQuery(QueryStringGenerator, EventDataset[T], ABC):
r"""
ServiceX Dataset class that uses func_adl query syntax.
"""
Expand All @@ -75,53 +70,14 @@ def check_data_format_request(self, f_name: str):

def __init__(
self,
dataset_identifier: DID = None,
sx_adapter: ServiceXAdapter = None,
title: str = "ServiceX Client",
codegen: Optional[str] = None,
config: Configuration = None,
query_cache: QueryCache = None,
result_format: Optional[ResultFormat] = None,
item_type: Type = Any,
ignore_cache: bool = False,
):
Query.__init__(
self,
dataset_identifier=dataset_identifier,
title=title,
codegen=codegen if codegen is not None else self.default_codegen,
sx_adapter=sx_adapter,
config=config,
query_cache=query_cache,
result_format=result_format,
ignore_cache=ignore_cache,
)
EventDataset.__init__(self, item_type=item_type)
self.provided_qastle = None

def set_provided_qastle(self, qastle: str):
self.provided_qastle = qastle

def __deepcopy__(self, memo):
"""
Customize deepcopy behavior for this class.
We need to be careful because the query cache is a tinyDB database that holds an
open file pointer. We are not allowed to clone an open file handle, so for this
property we will copy by reference and share it between the clones
"""
cls = self.__class__
obj = cls.__new__(cls)

memo[id(self)] = obj

for attr, value in vars(self).items():
if type(value) is QueryCache:
setattr(obj, attr, value)
else:
setattr(obj, attr, copy.deepcopy(value, memo))

return obj

def SelectMany(
self, func: Union[str, ast.Lambda, Callable[[T], Iterable[S]]]
) -> FuncADLQuery[S]:
Expand Down
42 changes: 12 additions & 30 deletions servicex/python_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,39 +29,22 @@
from typing import Optional, Union, Callable
from base64 import b64encode
from textwrap import dedent
from servicex.query_core import QueryStringGenerator
import sys
if sys.version_info < (3, 11):
from typing_extensions import Self
else:
from typing import Self

from servicex.configuration import Configuration
from servicex.query_core import Query
from servicex.models import ResultFormat
from servicex.query_cache import QueryCache
from servicex.servicex_adapter import ServiceXAdapter
from servicex.types import DID


class PythonQuery(Query):
class PythonFunction(QueryStringGenerator):
yaml_tag = '!PythonFunction'
default_codegen = 'python'

def __init__(self, dataset_identifier: DID = None,
sx_adapter: Optional[ServiceXAdapter] = None,
title: str = "ServiceX Client",
codegen: str = 'python',
config: Optional[Configuration] = None,
query_cache: Optional[QueryCache] = None,
result_format: Optional[ResultFormat] = None,
ignore_cache: bool = False
):
super().__init__(dataset_identifier=dataset_identifier,
title=title,
codegen=codegen,
sx_adapter=sx_adapter,
config=config,
query_cache=query_cache,
result_format=result_format,
ignore_cache=ignore_cache)

self.python_function = None
def __init__(self, python_function: Optional[Union[str, Callable]] = None):
self.python_function: Optional[Union[str, Callable]] = python_function

def with_uproot_function(self, f: Union[str, Callable]) -> Query:
def with_uproot_function(self, f: Union[str, Callable]) -> Self:
self.python_function = f
return self

Expand All @@ -83,6 +66,5 @@ def from_yaml(cls, _, node):
exec(code)
except SyntaxError as e:
raise SyntaxError(f"Syntax error {e} interpreting\n{code}")
q = PythonQuery()
q.with_uproot_function(code)
q = PythonFunction(code)
return q
27 changes: 8 additions & 19 deletions servicex/query_core.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2022, IRIS-HEP
# Copyright (c) 2024, IRIS-HEP
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
Expand Down Expand Up @@ -72,19 +72,20 @@ class ServiceXException(Exception):
""" Something happened while trying to carry out a ServiceX request """


class Query(ABC):
class Query:
def __init__(
self,
dataset_identifier: DID,
title: str,
codegen: str,
sx_adapter: ServiceXAdapter,
config: Configuration,
query_cache: QueryCache,
query_cache: Optional[QueryCache],
servicex_polling_interval: int = 5,
minio_polling_interval: int = 5,
result_format: ResultFormat = ResultFormat.parquet,
ignore_cache: bool = False,
query_string_generator: Optional[QueryStringGenerator] = None,
):
r"""
This is the main class for constructing transform requests and receiving the
Expand All @@ -103,7 +104,6 @@ def __init__(
:param result_format:
:param ignore_cache: If true, ignore the cache and always submit a new transform
"""
super(Query, self).__init__()
self.servicex = sx_adapter
self.configuration = config
self.cache = query_cache
Expand All @@ -123,14 +123,16 @@ def __init__(

self.request_id = None
self.ignore_cache = ignore_cache
self.query_string_generator = query_string_generator

# Number of seconds in between ServiceX status polls
self.servicex_polling_interval = servicex_polling_interval
self.minio_polling_interval = minio_polling_interval

@abc.abstractmethod
def generate_selection_string(self) -> str:
pass
if self.query_string_generator is None:
raise RuntimeError('query string generator not set')
return self.query_string_generator.generate_selection_string()

@property
def transform_request(self):
Expand Down Expand Up @@ -627,16 +629,3 @@ def __init__(self, query: str, codegen: str):

def generate_selection_string(self) -> str:
return self.query


class GenericQuery(Query):
'''
This class gives a "generic" Query object which doesn't require
overloading the constructor
'''
query_string_generator: Optional[QueryStringGenerator] = None

def generate_selection_string(self) -> str:
if self.query_string_generator is None:
raise RuntimeError('query string generator not set')
return self.query_string_generator.generate_selection_string()
Loading

0 comments on commit 4ec0929

Please sign in to comment.