diff --git a/src/sewerrat/query.py b/src/sewerrat/query.py index 717ebc3..610e76e 100644 --- a/src/sewerrat/query.py +++ b/src/sewerrat/query.py @@ -1,10 +1,19 @@ -from typing import Optional, List, Dict +from typing import Optional, List, Dict, Literal import requests +import warnings from . import _utils as ut -def query(url, text: Optional[str] = None, user: Optional[str] = None, path: Optional[str] = None, after: Optional[int] = None, before: Optional[int] = None, number: int = 100) -> List[Dict]: +def query( + url: str, + text: Optional[str] = None, + user: Optional[str] = None, + path: Optional[str] = None, + after: Optional[int] = None, + before: Optional[int] = None, + number: int = 100, + on_truncation: Literal["message", "warning", "none"] = "message") -> List[Dict]: """ Query the metadata in the SewerRat backend based on free text, the owner, creation time, etc. This function does not require filesystem access. @@ -38,6 +47,10 @@ def query(url, text: Optional[str] = None, user: Optional[str] = None, path: Opt number: Integer specifying the maximum number of results to return. + + on_truncation: + String specifying the action to take when the number of search + results is capped by ``number``. Returns: List of dictionaries where each inner dictionary corresponds to a @@ -72,7 +85,11 @@ def query(url, text: Optional[str] = None, user: Optional[str] = None, path: Opt else: raise ValueError("at least one search filter must be present") - stub = "/query?translate=true&limit=" + str(number) + if on_truncation != "none": + original_number = number + number += 1 + + stub = "/query?translate=true" collected = [] while len(collected) < number: @@ -86,4 +103,13 @@ def query(url, text: Optional[str] = None, user: Optional[str] = None, path: Opt break stub = payload["next"] + if on_truncation != "none": + if len(collected) > original_number: + msg = "truncated query results to the first " + str(original_number) + " matches" + if on_truncation == "warning": + warnings.warn(msg) + else: + print(msg) + collected = collected[:original_number] + return collected diff --git a/tests/test_query.py b/tests/test_query.py new file mode 100644 index 0000000..20e825d --- /dev/null +++ b/tests/test_query.py @@ -0,0 +1,30 @@ +import sewerrat +import os +import tempfile +import time +import pytest + + +def test_query_truncation(capfd): + mydir = tempfile.mkdtemp() + with open(os.path.join(mydir, "metadata.json"), "w") as handle: + handle.write('{ "first": "Aaron", "last": "Lun" }') + + os.mkdir(os.path.join(mydir, "diet")) + with open(os.path.join(mydir, "diet", "metadata.json"), "w") as handle: + handle.write('{ "meal": "lunch", "ingredients": "water" }') + + _, url = sewerrat.start_sewerrat() + sewerrat.register(mydir, ["metadata.json"], url=url) + + res = sewerrat.query(url, "lun", number=0) + out, err = capfd.readouterr() + assert "truncated" in out + assert len(res) == 0 + + with pytest.warns(UserWarning, match="truncated"): + res = sewerrat.query(url, "lun", number=0, on_truncation="warning") + assert len(res) == 0 + + res = sewerrat.query(url, "lun", number=0, on_truncation="none") + assert len(res) == 0