-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #195 from EMMC-ASBL/filter-strategy
Added new filter strategy
- Loading branch information
Showing
6 changed files
with
259 additions
and
58 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,63 +1,155 @@ | ||
"""Trivial filter that adds an empty collection to the session.""" | ||
"""Filter that removes all but specified instances in the collection.""" | ||
# pylint: disable=unused-argument | ||
from typing import TYPE_CHECKING, Any, Dict | ||
import re | ||
from typing import TYPE_CHECKING | ||
|
||
import dlite | ||
from oteapi.datacache import DataCache | ||
from oteapi.models import FilterConfig | ||
from dlite.utils import get_referred_instances | ||
from oteapi.models import AttrDict, FilterConfig | ||
from pydantic import Field | ||
from pydantic.dataclasses import dataclass | ||
|
||
from oteapi_dlite.models import DLiteSessionUpdate | ||
from oteapi_dlite.utils import get_collection, update_collection | ||
|
||
if TYPE_CHECKING: | ||
from typing import Optional | ||
from typing import Any, Dict, Optional | ||
|
||
|
||
@dataclass | ||
class CreateCollectionStrategy: | ||
"""Trivial filter that adds an empty collection to the session. | ||
class DLiteQueryConfig(AttrDict): | ||
"""Configuration for the DLite filter strategy. | ||
**Registers strategies**: | ||
First the `remove_label` and `remove_datamodel` configurations are | ||
used to mark matching instances for removal. If neither | ||
`remove_label` or `remove_datamodel` are given, all instances are | ||
marked for removal. | ||
Then instances matching `keep_label` and `keep_datamodel` are unmarked | ||
for removal. | ||
- `("filterType", "dlite/create-collection")` | ||
If `keep_referred` is true, any instance that is referred to by | ||
an instance not marked for removal is also unmarked for removal. | ||
Finally, the instances that are still marked for removal are removed | ||
from the collection. | ||
""" | ||
|
||
filter_config: FilterConfig | ||
remove_label: str = Field( | ||
None, description="Regular expression matching labels to remove." | ||
) | ||
remove_datamodel: str = Field( | ||
None, | ||
description="Regular expression matching datamodel URIs to remove.", | ||
) | ||
keep_label: str = Field( | ||
None, | ||
description="Regular expression matching labels to keep." | ||
"This configuration overrides `remove_label` and `remove_datamodel`. " | ||
"Alias for the FilterStrategy `query` configuration, that is " | ||
"inherited from the oteapi-core Filter data model.", | ||
) | ||
keep_datamodel: str = Field( | ||
None, | ||
description="Regular expression matching datamodel URIs to keep in " | ||
"collection. " | ||
"This configuration overrides `remove_label` and `remove_datamodel`.", | ||
) | ||
keep_referred: bool = Field( | ||
True, | ||
description="Whether to keep all instances in the collection that are " | ||
"directly or indirectly referred to (via ref-types or collections) " | ||
"by kept instances.", | ||
) | ||
|
||
|
||
# Find a better way to keep collections alive!!! | ||
# Need to be `Any`, because otherwise `pydantic` complains. | ||
collection_refs: Dict[str, Any] = Field( | ||
{}, | ||
description="A dictionary of DLite Collections.", | ||
class DLiteFilterConfig(FilterConfig): | ||
"""DLite generate strategy config.""" | ||
|
||
configuration: DLiteQueryConfig = Field( | ||
..., description="DLite filter strategy-specific configuration." | ||
) | ||
|
||
def initialize( | ||
self, session: "Optional[Dict[str, Any]]" = None | ||
) -> DLiteSessionUpdate: | ||
"""Initialize.""" | ||
if session is None: | ||
raise ValueError("Missing session") | ||
if "collection_id" in session: | ||
raise KeyError("`collection_id` already exists in session.") | ||
|
||
coll = dlite.Collection() | ||
@dataclass | ||
class DLiteFilterStrategy: | ||
"""Filter that removes all but specified instances in the collection. | ||
The `query` configuration should be a regular expression matching labels | ||
to keep in the collection. All other labels will be removed. | ||
# Make sure that collection stays alive | ||
# It will never be deallocated... | ||
coll._incref() # pylint: disable=protected-access | ||
**Registers strategies**: | ||
# Store the collection in the data cache | ||
cache = DataCache() | ||
cache.add(value=coll.asjson(), key=coll.uuid) | ||
- `("filterType", "dlite/filter")` | ||
return DLiteSessionUpdate(collection_id=coll.uuid) | ||
""" | ||
|
||
filter_config: DLiteFilterConfig | ||
|
||
def initialize( | ||
self, | ||
session: "Optional[Dict[str, Any]]" = None, | ||
) -> DLiteSessionUpdate: | ||
"""Initialize.""" | ||
return DLiteSessionUpdate(collection_id=get_collection(session).uuid) | ||
|
||
def get( | ||
self, session: "Optional[Dict[str, Any]]" = None | ||
) -> DLiteSessionUpdate: | ||
"""Execute the strategy.""" | ||
if session is None: | ||
raise ValueError("Missing session") | ||
return DLiteSessionUpdate(collection_id=session["collection_id"]) | ||
# pylint: disable=too-many-branches | ||
config = self.filter_config.configuration | ||
|
||
# Alias for query configuration | ||
keep_label = ( | ||
config.keep_label if config.keep_label else self.filter_config.query | ||
) | ||
|
||
instdict = {} # Map instance labels to [uuid, metaURI] | ||
coll = get_collection(session) | ||
for s, _, o in coll.get_relations(p="_has-uuid"): | ||
instdict[s] = [o] | ||
for s, _, o in coll.get_relations(p="_has-meta"): | ||
instdict[s].append(o) | ||
|
||
removal = set() # Labels marked for removal | ||
|
||
# 1: remove_label, remove_datamodel | ||
if config.remove_label or config.remove_datamodel: | ||
for label, (uuid, metauri) in instdict.items(): | ||
if config.remove_label and re.match(config.remove_label, label): | ||
removal.add(label) | ||
|
||
if config.remove_datamodel and re.match( | ||
config.remove_datamodel, metauri | ||
): | ||
removal.add(label) | ||
else: | ||
removal.update(instdict.keys()) | ||
|
||
# 2: keep_label, keep_datamodel | ||
for label in set(removal): | ||
if keep_label and re.match(keep_label, label): | ||
removal.remove(label) | ||
|
||
uuid, metauri = instdict[label] | ||
if config.keep_datamodel and re.match( | ||
config.keep_datamodel, metauri | ||
): | ||
removal.remove(label) | ||
|
||
# 3: keep_referred | ||
if config.keep_referred: | ||
labels = {uuid: label for label, (uuid, _) in instdict.items()} | ||
kept = set(instdict.keys()).difference(removal) | ||
for label in kept: | ||
removal.difference_update( | ||
labels[inst.uuid] | ||
for inst in get_referred_instances(coll.get(label)) | ||
if inst.uuid in labels | ||
) | ||
|
||
# 4: remove from collection | ||
for label in removal: | ||
coll.remove(label) | ||
|
||
update_collection(coll) | ||
return DLiteSessionUpdate(collection_id=get_collection(session).uuid) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
DLite-Python>=0.3.3,<1.0 | ||
DLite-Python>=0.4.5,<1.0 | ||
numpy>=1.21,<2 | ||
oteapi-core>=0.1.2,<0.6.0 | ||
Pillow>=9.0.1,<11 | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
invoke~=2.2 | ||
otelib>=0.3,<0.4.0 | ||
pre-commit~=3.5 | ||
pylint~=3.0 | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,29 +1,137 @@ | ||
"""Tests filter strategies.""" | ||
from typing import TYPE_CHECKING | ||
from pathlib import Path | ||
|
||
if TYPE_CHECKING: | ||
from oteapi.interfaces import IFilterStrategy | ||
import dlite | ||
|
||
from oteapi_dlite.strategies.filter import ( | ||
DLiteFilterConfig, | ||
DLiteFilterStrategy, | ||
) | ||
from oteapi_dlite.utils import get_meta | ||
|
||
def test_create_collection() -> None: | ||
"""Test the create_collection filter.""" | ||
import dlite | ||
thisdir = Path(__file__).resolve().parent | ||
entitydir = thisdir / ".." / "entities" | ||
outdir = thisdir / ".." / "output" | ||
|
||
from oteapi_dlite.strategies.filter import CreateCollectionStrategy | ||
Image = get_meta("http://onto-ns.com/meta/1.0/Image") | ||
image1 = Image([2, 2, 1]) | ||
image2 = Image([2, 2, 1]) | ||
image3 = Image([2, 2, 1]) | ||
image4 = Image([2, 2, 1]) | ||
innercoll = dlite.Collection() | ||
innercoll.add("im1", image1) | ||
innercoll.add("im2", image2) | ||
|
||
config = {"filterType": "dlite/create_collection"} | ||
coll = dlite.Collection() | ||
coll.add("innercoll", innercoll) | ||
coll.add("image1", image1) | ||
coll.add("image2", image2) | ||
coll.add("image3", image3) | ||
coll.add("image4", image4) | ||
|
||
session = {} | ||
|
||
collfilter: "IFilterStrategy" = CreateCollectionStrategy(config) | ||
session.update(collfilter.initialize(session)) | ||
# Test simple use of query | ||
# Here keeping all instances with label containing "im" in the collection | ||
config = DLiteFilterConfig( | ||
filterType="dlite/filter", | ||
query="^im", | ||
configuration={}, | ||
) | ||
coll0 = coll.copy() | ||
session = {"collection_id": coll0.uuid} | ||
|
||
assert "collection_id" in session | ||
coll_id = session["collection_id"] | ||
coll = dlite.get_instance(coll_id) | ||
assert isinstance(coll, dlite.Collection) | ||
strategy = DLiteFilterStrategy(config) | ||
session.update(strategy.initialize(session)) | ||
|
||
collfilter = CreateCollectionStrategy(config) | ||
session.update(collfilter.get(session)) | ||
assert "collection_id" in session | ||
assert session["collection_id"] == coll_id | ||
strategy = DLiteFilterStrategy(config) | ||
session.update(strategy.get(session)) | ||
|
||
assert set(coll0.get_labels()) == set( | ||
[ | ||
"image1", | ||
"image2", | ||
"image3", | ||
"image4", | ||
] | ||
) | ||
|
||
|
||
# Same test as above, but use use `keep_label` instead of `query` | ||
config = DLiteFilterConfig( | ||
filterType="dlite/filter", | ||
configuration={ | ||
"keep_label": "^im", | ||
}, | ||
) | ||
coll1 = coll.copy() | ||
session = {"collection_id": coll1.uuid} | ||
|
||
strategy = DLiteFilterStrategy(config) | ||
session.update(strategy.initialize(session)) | ||
|
||
strategy = DLiteFilterStrategy(config) | ||
session.update(strategy.get(session)) | ||
|
||
assert set(coll1.get_labels()) == set( | ||
[ | ||
"image1", | ||
"image2", | ||
"image3", | ||
"image4", | ||
] | ||
) | ||
|
||
|
||
# Test combining remove and keep | ||
config = DLiteFilterConfig( | ||
filterType="dlite/filter", | ||
configuration={ | ||
"remove_datamodel": Image.uri, | ||
"keep_label": "(image2)|(image4)", | ||
"keep_referred": False, | ||
}, | ||
) | ||
coll2 = coll.copy() | ||
session = {"collection_id": coll2.uuid} | ||
|
||
strategy = DLiteFilterStrategy(config) | ||
session.update(strategy.initialize(session)) | ||
|
||
strategy = DLiteFilterStrategy(config) | ||
session.update(strategy.get(session)) | ||
|
||
assert set(coll2.get_labels()) == set( | ||
[ | ||
"innercoll", | ||
"image2", | ||
"image4", | ||
] | ||
) | ||
|
||
|
||
# Test with keep_referred=True | ||
config = DLiteFilterConfig( | ||
filterType="dlite/filter", | ||
configuration={ | ||
"remove_datamodel": Image.uri, | ||
"keep_label": "(image2)|(image4)", | ||
"keep_referred": True, | ||
}, | ||
) | ||
coll3 = coll.copy() | ||
session = {"collection_id": coll3.uuid} | ||
|
||
strategy = DLiteFilterStrategy(config) | ||
session.update(strategy.initialize(session)) | ||
|
||
strategy = DLiteFilterStrategy(config) | ||
session.update(strategy.get(session)) | ||
|
||
assert set(coll3.get_labels()) == set( | ||
[ | ||
"innercoll", | ||
"image1", | ||
"image2", | ||
"image4", | ||
] | ||
) |