From 84db902f882377ea3fec9b189382ed84d6caf02b Mon Sep 17 00:00:00 2001 From: Paige Gulley Date: Tue, 9 Jul 2024 16:58:19 -0400 Subject: [PATCH 1/5] Big changes badum badum- extracted the elasticsearch interfaces into a wrapper and a query builder --- api.py | 358 +++------------------------------------- deploy.sh | 2 +- queries.py | 355 +++++++++++++++++++++++++++++++++++++++ test/__init__.py | 2 +- test/api_test.py | 6 +- test/create_fixtures.py | 7 +- utils.py | 15 ++ 7 files changed, 407 insertions(+), 338 deletions(-) create mode 100644 queries.py diff --git a/api.py b/api.py index 3f5442c..337b9da 100755 --- a/api.py +++ b/api.py @@ -17,34 +17,21 @@ from sentry_sdk.integrations.fastapi import FastApiIntegration from sentry_sdk.integrations.starlette import StarletteIntegration +from queries import EsClientWrapper from utils import ( assert_elasticsearch_connection, env_to_dict, + env_to_float, env_to_list, load_config, logger, ) -def getenv_float(name: str, defval: float | None) -> float | None: - """ - fetch environment variable with name `name` - if not set, return defval - if set to empty string, return None - else interpret as floating point number - """ - val = os.getenv(name) - if val is None: - return defval - if val == "": - return None - return float(val) - - if os.getenv("SENTRY_DSN"): sentry_sdk.init( dsn=os.getenv("SENTRY_DSN"), - traces_sample_rate=getenv_float("TRACING_SAMPLE_RATE", 1.0), - profiles_sample_rate=getenv_float("PROFILES_SAMPLE_RATE", 1.0), + traces_sample_rate=env_to_float("TRACING_SAMPLE_RATE", 1.0), + profiles_sample_rate=env_to_float("PROFILES_SAMPLE_RATE", 1.0), integrations=[ StarletteIntegration(transaction_style="url"), FastApiIntegration(transaction_style="url"), @@ -59,13 +46,10 @@ class ApiVersion(str, Enum): config = load_config() -config["termfields"] = env_to_list("TERMFIELDS") or config.get("termfields", []) -config["termaggrs"] = env_to_list("TERMAGGRS") or config.get("termaggrs", []) config["eshosts"] = env_to_list("ESHOSTS") or config.get( "eshosts", ["http://localhost:9200"] ) config["esopts"] = env_to_dict("ESOPTS") or config.get("esopts", {}) -config["maxpage"] = int(os.getenv("MAXPAGE", config.get("maxpage", 1000))) config["title"] = os.getenv("TITLE", config.get("title", "")) config["description"] = os.getenv("DESCRIPTION", config.get("description", "")) config["debug"] = str(os.getenv("DEBUG", config.get("debug", False))).lower() in ( @@ -74,46 +58,11 @@ class ApiVersion(str, Enum): "t", ) -ELASTICSEARCH_INDEX_NAME_PREFIX = os.getenv("ELASTICSEARCH_INDEX_NAME_PREFIX", "") - -ES = Elasticsearch(config["eshosts"], **config["esopts"]) +ES = EsClientWrapper(config["eshosts"], **config["esopts"]) -max_retries = 10 -retries = 0 -while not assert_elasticsearch_connection(ES): - retries += 1 - if retries < max_retries: - time.sleep(5) - logger.info(f"Connection to elasticsearch failed {retries} times, retrying") - else: - raise RuntimeError( - f"Elasticsearch connection failed {max_retries} times, giving up." - ) - - -def get_allowed_collections(): - # Only expose indexes with the correct prefix, and add a wildcard as well. 
- - all_indexes = [ - index - for index in ES.indices.get(index="*") - if index.startswith(ELASTICSEARCH_INDEX_NAME_PREFIX) - ] - for aliases in ES.indices.get_alias().values(): - # returns: {"index_name":{"aliases":{"alias_name":{"is_write_index":bool}}}} - for alias in aliases["aliases"].keys(): - if alias not in all_indexes: - all_indexes.append(alias) - all_indexes.append(f"{ELASTICSEARCH_INDEX_NAME_PREFIX}-*") - - logger.info(f"Exposed indices: {all_indexes}") - return all_indexes - - -# A little annoying- the list comprehension could be removed in py3.11 by using StrEnum, but this is backwards compatable. -Collection = Enum("Collection", [f"{kv}:{kv}".split(":")[:2] for kv in get_allowed_collections()]) # type: ignore [misc] -TermField = Enum("TermField", [f"{kv}:{kv}".split(":")[:2] for kv in config["termfields"]]) # type: ignore [misc] -TermAggr = Enum("TermAggr", [f"{kv}:{kv}".split(":")[:2] for kv in config["termaggrs"]]) # type: ignore [misc] +Collection = Enum("Collection", [f"{kv}:{kv}".split(":")[:2] for kv in ES.get_allowed_collections()]) # type: ignore [misc] +TermField = Enum("TermField", [f"{kv}:{kv}".split(":")[:2] for kv in env_to_list("TERMFIELDS")]) # type: ignore [misc] +TermAggr = Enum("TermAggr", [f"{kv}:{kv}".split(":")[:2] for kv in env_to_list("TERMFIELDS")]) # type: ignore [misc] tags = [ @@ -134,7 +83,6 @@ def get_allowed_collections(): } ) - app = FastAPI( version=list(ApiVersion)[-1].value, docs_url=None, redoc_url=None, openapi_url=None ) @@ -163,10 +111,6 @@ async def add_api_version_header(req: Request, call_next): ) -VALID_SORT_ORDERS = ["asc", "desc"] -VALID_SORT_FIELDS = ["publication_date", "indexed_date"] - - class Query(BaseModel): q: str @@ -179,188 +123,6 @@ class PagedQuery(Query): page_size: Optional[int] = None -def encode(strng: str): - return base64.b64encode(strng.encode(), b"-_").decode().replace("=", "~") - - -def decode(strng: str): - return base64.b64decode(strng.replace("~", "=").encode(), b"-_").decode() - - -def cs_basic_query(q: str, expanded: bool = False) -> Dict: - default: dict = { - "_source": [ - "article_title", - "normalized_article_title", - "publication_date", - "indexed_date", - "language", - "full_language", - "canonical_domain", - "url", - "normalized_url", - "original_url", - ], - "query": { - "query_string": { - "default_field": "text_content", - "default_operator": "AND", - "query": q, - } - }, - } - if expanded: - default["_source"].extend(["text_content", "text_extraction"]) - return default - - -def cs_overview_query(q: str): - query = cs_basic_query(q) - query.update( - { - "aggregations": { - "daily": { - "date_histogram": { - "field": "publication_date", - "calendar_interval": "day", - "min_doc_count": 1, - } - }, - "lang": {"terms": {"field": "language.keyword", "size": 100}}, - "domain": {"terms": {"field": "canonical_domain", "size": 100}}, - "tld": {"terms": {"field": "tld", "size": 100}}, - }, - "track_total_hits": True, - } - ) - - return query - - -def cs_terms_query(q: str, field: TermField, aggr: TermAggr): - resct = 200 - aggr_map = { - "top": { - "terms": { - "field": field.name, - "size": resct, - "min_doc_count": 10, - "shard_min_doc_count": 5, - } - }, - "significant": { - "significant_terms": { - "field": field.name, - "size": resct, - "min_doc_count": 10, - "shard_min_doc_count": 5, - } - }, - "rare": {"rare_terms": {"field": field.name, "exclude": "[0-9].*"}}, - } - query = cs_basic_query(q) - query.update( - { - "track_total_hits": False, - "_source": False, - "aggregations": { - "sample": { - 
"sampler": {"shard_size": 10 if aggr.name == "rare" else 500}, - "aggregations": {"topterms": aggr_map[aggr.name]}, - } - }, - } - ) - return query - - -def _validate_sort_order(sort_order: Optional[str]): - if sort_order and sort_order not in VALID_SORT_ORDERS: - raise HTTPException( - status_code=400, - detail=f"Invalid sort order (must be on of {', '.join(VALID_SORT_ORDERS)})", - ) - return sort_order - - -def _validate_sort_field(sort_field: Optional[str]): - if sort_field and sort_field not in VALID_SORT_FIELDS: - raise HTTPException( - status_code=400, - detail=f"Invalid sort field (must be on of {', '.join(VALID_SORT_FIELDS)})", - ) - return sort_field - - -def _validate_page_size(page_size: Optional[int]): - if page_size and page_size < 1: - raise HTTPException( - status_code=400, detail="Invalid page size (must be greater than 0)" - ) - return page_size - - -def cs_paged_query( - q: str, - resume: Union[str, None], - expanded: bool, - sort_field=Optional[str], - sort_order=Optional[str], - page_size=Optional[int], -) -> Dict: - query = cs_basic_query(q, expanded) - final_sort_field = _validate_sort_field(sort_field or "publication_date") - final_sort_order = _validate_sort_order(sort_order or "desc") - query.update( - { - "size": _validate_page_size(page_size or config["maxpage"]), - "track_total_hits": False, - "sort": { - final_sort_field: { - "order": final_sort_order, - "format": "basic_date_time_no_millis", - } - }, - } - ) - if resume: - # important to use `search_after` instead of 'from' for memory reasons related to paging through more - # than 10k results - query["search_after"] = [decode(resume)] - return query - - -def format_match(hit: dict, base: str, collection: str, expanded: bool = False): - src = hit["_source"] - res = { - "article_title": src.get("article_title"), - "normalized_article_title": src.get("normalized_article_title"), - "publication_date": src.get("publication_date")[:10] - if src.get("publication_date") - else None, - "indexed_date": src.get("indexed_date"), - "language": src.get("language"), - "full_langauge": src.get("full_language"), - "url": src.get("url"), - "normalized_url": src.get("normalized_url"), - "original_url": src.get("original_url"), - "canonical_domain": src.get("canonical_domain"), - "id": urls.unique_url_hash(src.get("url")), - } - if expanded: - res["text_content"] = src.get("text_content") - res["text_extraction"] = src.get("text_extraction") - return res - - -def format_day_counts(bucket: list): - return {item["key_as_string"][:10]: item["doc_count"] for item in bucket} - - -def format_counts(bucket: list): - return {item["key"]: item["doc_count"] for item in bucket} - - def proxy_base_url(req: Request): return f'{str(os.getenv("PROXY_BASE", req.base_url)).rstrip("/")}/{req.scope.get("root_path").lstrip("/")}' # type: ignore [union-attr] @@ -451,34 +213,13 @@ def search_root(collection: Collection, req: Request): ) -def _search_overview(collection: Collection, q: str, req: Request): - res = ES.search(index=collection.name, body=cs_overview_query(q)) # type: ignore [call-arg] - - if not res["hits"]["hits"]: - raise HTTPException(status_code=404, detail="No results found!") - total = res["hits"]["total"]["value"] - tldsum = sum(item["doc_count"] for item in res["aggregations"]["tld"]["buckets"]) - base = proxy_base_url(req) - return { - "query": q, - "total": max(total, tldsum), - "topdomains": format_counts(res["aggregations"]["domain"]["buckets"]), - "toptlds": format_counts(res["aggregations"]["tld"]["buckets"]), - "toplangs": 
format_counts(res["aggregations"]["lang"]["buckets"]), - "dailycounts": format_day_counts(res["aggregations"]["daily"]["buckets"]), - "matches": [ - format_match(h, base, collection.name) for h in res["hits"]["hits"] - ], - } - - @v1.get("/{collection}/search/overview", tags=["data"]) @v1.head("/{collection}/search/overview", include_in_schema=False) def search_overview_via_query_params(collection: Collection, q: str, req: Request): """ Report overview summary of the search result """ - return _search_overview(collection, q, req) + return ES.search_overview(collection.name, q, req) @v1.post("/{collection}/search/overview", tags=["data"]) @@ -486,33 +227,7 @@ def search_overview_via_payload(collection: Collection, req: Request, payload: Q """ Report summary of the search result """ - return _search_overview(collection, payload.q, req) - - -def _search_result( - collection: Collection, - q: str, - req: Request, - resp: Response, - resume: Union[str, None] = None, - expanded: bool = False, - sort_field: Optional[str] = None, - sort_order: Optional[str] = None, - page_size: Optional[int] = None, -): - query = cs_paged_query(q, resume, expanded, sort_field, sort_order, page_size) - res = ES.search(index=collection.name, body=query) # type: ignore [call-arg] - if not res["hits"]["hits"]: - raise HTTPException(status_code=404, detail="No results found!") - base = proxy_base_url(req) - qurl = f"{base}/{collection.name}/search/result?q={quote_plus(q)}" - if len(res["hits"]["hits"]) == (page_size or config["maxpage"]): - resume_key = encode(str(res["hits"]["hits"][-1]["sort"][0])) - resp.headers["x-resume-token"] = resume_key - - return [ - format_match(h, base, collection.name, expanded) for h in res["hits"]["hits"] - ] + return ES.search_overview(collection.name, payload.q, req) @v1.get("/{collection}/search/result", tags=["data"]) @@ -531,8 +246,16 @@ def search_result_via_query_params( """ Paged response of search result """ - return _search_result( - collection, q, req, resp, resume, expanded, sort_field, sort_order, page_size + return ES.search_result( + collection.name, + q, + req, + resp, + resume, + expanded, + sort_field, + sort_order, + page_size, ) @@ -543,8 +266,8 @@ def search_result_via_payload( """ Paged response of search result """ - return _search_result( - collection, + return ES.search_result( + collection.name, payload.q, req, resp, @@ -563,17 +286,7 @@ def search_esdsl_via_payload(collection: Collection, payload: dict = Body(...)): """ Search using ES Query DSL as JSON payload """ - return ES.search(index=collection.name, body=payload) # type: ignore [call-arg] - - -def _get_terms(collection: Collection, q: str, field: TermField, aggr: TermAggr): - res = ES.search(index=collection.name, body=cs_terms_query(q, field, aggr)) # type: ignore [call-arg] - if ( - not res["hits"]["hits"] - or not res["aggregations"]["sample"]["topterms"]["buckets"] - ): - raise HTTPException(status_code=404, detail="No results found!") - return format_counts(res["aggregations"]["sample"]["topterms"]["buckets"]) + return ES.ES.search(index=collection.name, body=payload) # type: ignore [call-arg] @v1.get("/{collection}/terms", response_class=HTMLResponse, tags=["info"]) @@ -613,7 +326,7 @@ def get_terms_via_query_params( """ Top terms with frequencies in matching articles """ - return _get_terms(collection, q, field, aggr) + return ES.get_terms(collection.name, q, field.name, aggr.name) @v1.post("/{collection}/terms/{field}/{aggr}", tags=["data"]) @@ -626,7 +339,7 @@ def get_terms_via_payload( """ Top 
terms with frequencies in matching articles """ - return _get_terms(collection, payload.q, field, aggr) + return ES.get_terms(collection.name, payload.q, field.name, aggr.name) @v1.get("/{collection}/article/{id}", tags=["data"]) @@ -637,25 +350,8 @@ def get_article( """ Fetch an individual article record by ID. """ - source = {"includes": cs_basic_query(id, expanded=True)["_source"]} - query = {"match": {"_id": id}} - - try: - res = ES.search(index=collection.name, source=source, query=query) - hits = res["hits"]["hits"] - except (TransportError, TypeError, KeyError) as e: - raise HTTPException( - status_code=500, - detail=f"An error occured when searching for article with ID {id}", - ) from e - if len(hits) > 0: - hit = hits[0] - else: - raise HTTPException( - status_code=404, detail=f"An article with ID {id} not found!" - ) - base = proxy_base_url(req) - return format_match(hit, base, collection.name, True) # type: ignore[arg-type] + + return ES.get_article(collection.name, id, req) app.mount(f"/{ApiVersion.v1.name}", v1) diff --git a/deploy.sh b/deploy.sh index b9eef72..8620553 100755 --- a/deploy.sh +++ b/deploy.sh @@ -163,7 +163,7 @@ DOCKER_COMPOSE_FILE="docker-compose.yml" export ESOPTS='{"timeout": 60, "max_retries": 3}' # 'timeout' parameter is deprecated export TERMFIELDS="article_title,text_content" -export TERMAGGRS="top,significant,rare" +export TERMAGGRS="top" export ELASTICSEARCH_INDEX_NAME_PREFIX="mc_search" export API_PORT export API_REPLICAS diff --git a/queries.py b/queries.py new file mode 100644 index 0000000..5bc7c49 --- /dev/null +++ b/queries.py @@ -0,0 +1,355 @@ +import base64 +import os +import time +from enum import Enum +from typing import Dict, Optional, TypeAlias, Union + +import mcmetadata.urls as urls +from elasticsearch import Elasticsearch +from elasticsearch.exceptions import TransportError +from fastapi import HTTPException, Request, Response + +from utils import assert_elasticsearch_connection, env_to_list, logger + + +def decode(strng: str): + return base64.b64decode(strng.replace("~", "=").encode(), b"-_").decode() + + +def encode(strng: str): + return base64.b64encode(strng.encode(), b"-_").decode().replace("=", "~") + + +class QueryBuilder: + + """ + Utility Class to encapsulate the query construction logic for news-search-api + + """ + + def __init__(self, query_text): + self.query_text = query_text + self.VALID_SORT_ORDERS = ["asc", "desc"] + self.VALID_SORT_FIELDS = ["publication_date", "indexed_date"] + self._source = [ + "article_title", + "normalized_article_title", + "publication_date", + "indexed_date", + "language", + "full_language", + "canonical_domain", + "url", + "normalized_url", + "original_url", + ] + self._extended_source = [ + "article_title", + "normalized_article_title", + "publication_date", + "indexed_date", + "language", + "full_language", + "canonical_domain", + "url", + "normalized_url", + "original_url", + "text_content", + "text_extraction", + ] + + def _validate_sort_order(self, sort_order: Optional[str]): + if sort_order and sort_order not in self.VALID_SORT_ORDERS: + raise HTTPException( + status_code=400, + detail=f"Invalid sort order (must be on of {', '.join(self.VALID_SORT_ORDERS)})", + ) + return sort_order + + def _validate_sort_field(self, sort_field: Optional[str]): + if sort_field and sort_field not in self.VALID_SORT_FIELDS: + raise HTTPException( + status_code=400, + detail=f"Invalid sort field (must be on of {', '.join(self.VALID_SORT_FIELDS)})", + ) + return sort_field + + def _validate_page_size(self, 
page_size: Optional[int]): + if page_size and page_size < 1: + raise HTTPException( + status_code=400, detail="Invalid page size (must be greater than 0)" + ) + return page_size + + def basic_query(self, expanded: bool = False) -> Dict: + default: dict = { + "_source": self._extended_source if expanded else self._source, + "query": { + "query_string": { + "default_field": "text_content", + "default_operator": "AND", + "query": self.query_text, + } + }, + } + return default + + def overview_query(self): + query = self.basic_query() + query.update( + { + "aggregations": { + "daily": { + "date_histogram": { + "field": "publication_date", + "calendar_interval": "day", + "min_doc_count": 1, + } + }, + "lang": {"terms": {"field": "language.keyword", "size": 100}}, + "domain": {"terms": {"field": "canonical_domain", "size": 100}}, + "tld": {"terms": {"field": "tld", "size": 100}}, + }, + "track_total_hits": True, + } + ) + return query + + def terms_query(self, field): + resct = 200 + aggr_map = { + "terms": { + "field": field.name, + "size": resct, + "min_doc_count": 10, + "shard_min_doc_count": 5, + } + } + query = self.basic_query() + query.update( + { + "track_total_hits": False, + "_source": False, + "aggregations": { + "sample": { + "sampler": {"shard_size": 500}, + "aggregations": {"topterms": aggr_map}, + } + }, + } + ) + return query + + def paged_query( + self, + resume: Union[str, None], + expanded: bool, + sort_field=Optional[str], + sort_order=Optional[str], + page_size=Optional[int], + ) -> Dict: + query = self.basic_query(expanded) + final_sort_field = self._validate_sort_field(sort_field or "publication_date") + final_sort_order = self._validate_sort_order(sort_order or "desc") + query.update( + { + "size": self._validate_page_size(page_size or 1000), + "track_total_hits": False, + "sort": { + final_sort_field: { + "order": final_sort_order, + "format": "basic_date_time_no_millis", + } + }, + } + ) + if resume: + # important to use `search_after` instead of 'from' for memory reasons related to paging through more + # than 10k results + query["search_after"] = [decode(resume)] + return query + + def article_query(self): + default: dict = { + "_source": self._extended_source, + "query": {"match": {"_id": self.query_text}}, + } + + return default + + +class EsClientWrapper: + # A wrapper to actually make the calls to elasticsearch + def __init__(self, eshosts, **esopts): + self.ES = Elasticsearch(eshosts, **esopts) + self.maxpage = os.getenv("maxpage", 1000) + max_retries = 10 + retries = 0 + + while not assert_elasticsearch_connection(self.ES): + retries += 1 + if retries < max_retries: + time.sleep(5) + logger.info( + f"Connection to elasticsearch failed {retries} times, retrying" + ) + else: + raise RuntimeError( + f"Elasticsearch connection failed {max_retries} times, giving up." + ) + + self.index_name_prefix = os.getenv("ELASTICSEARCH_INDEX_NAME_PREFIX", "") + logger.info("Initialized ES client wrapper") + + def get_allowed_collections(self): + # Only expose indexes with the correct prefix, and add a wildcard as well. 
+ + all_indexes = [ + index + for index in self.ES.indices.get(index="*") + if index.startswith(self.index_name_prefix) + ] + for aliases in self.ES.indices.get_alias().values(): + # returns: {"index_name":{"aliases":{"alias_name":{"is_write_index":bool}}}} + for alias in aliases["aliases"].keys(): + if alias not in all_indexes: + all_indexes.append(alias) + + all_indexes.append(f"{self.index_name_prefix}-*") + + logger.info(f"Exposed indices: {all_indexes}") + return all_indexes + + def format_match( + self, hit: dict, base: str, collection: str, expanded: bool = False + ): + src = hit["_source"] + res = { + "article_title": src.get("article_title"), + "normalized_article_title": src.get("normalized_article_title"), + "publication_date": src.get("publication_date")[:10] + if src.get("publication_date") + else None, + "indexed_date": src.get("indexed_date"), + "language": src.get("language"), + "full_langauge": src.get("full_language"), + "url": src.get("url"), + "normalized_url": src.get("normalized_url"), + "original_url": src.get("original_url"), + "canonical_domain": src.get("canonical_domain"), + "id": urls.unique_url_hash(src.get("url")), + } + if expanded: + res["text_content"] = src.get("text_content") + res["text_extraction"] = src.get("text_extraction") + return res + + def format_day_counts(self, bucket: list): + return {item["key_as_string"][:10]: item["doc_count"] for item in bucket} + + def format_counts(self, bucket: list): + return {item["key"]: item["doc_count"] for item in bucket} + + def proxy_base_url(self, req: Request): + return f'{str(os.getenv("PROXY_BASE", req.base_url)).rstrip("/")}/{req.scope.get("root_path").lstrip("/")}' # type: ignore [union-attr] + + def search_overview(self, collection: str, q: str, req: Request): + """ + Get overview statistics for a query + """ + res = self.ES.search(index=collection, body=QueryBuilder(q).overview_query()) # type: ignore [call-arg] + if not res["hits"]["hits"]: + raise HTTPException(status_code=404, detail="No results found!") + + total = res["hits"]["total"]["value"] + tldsum = sum( + item["doc_count"] for item in res["aggregations"]["tld"]["buckets"] + ) + base = self.proxy_base_url(req) + return { + "query": q, + "total": max(total, tldsum), + "topdomains": self.format_counts(res["aggregations"]["domain"]["buckets"]), + "toptlds": self.format_counts(res["aggregations"]["tld"]["buckets"]), + "toplangs": self.format_counts(res["aggregations"]["lang"]["buckets"]), + "dailycounts": self.format_day_counts( + res["aggregations"]["daily"]["buckets"] + ), + "matches": [ + self.format_match(h, base, collection) for h in res["hits"]["hits"] + ], + } + + def search_result( + self, + collection: str, + q: str, + req: Request, + resp: Response, + resume: Union[str, None] = None, + expanded: bool = False, + sort_field: Optional[str] = None, + sort_order: Optional[str] = None, + page_size: Optional[int] = None, + ): + """ + Get the search results for a query (including full text, if `expanded`) + """ + query = QueryBuilder(q).paged_query( + resume, expanded, sort_field, sort_order, page_size + ) + res = self.ES.search(index=collection, body=query) # type: ignore [call-arg] + base = self.proxy_base_url(req) + + if not res["hits"]["hits"]: + raise HTTPException(status_code=404, detail="No results found!") + + if len(res["hits"]["hits"]) == (page_size or self.maxpage): + resume_key = encode(str(res["hits"]["hits"][-1]["sort"][0])) + resp.headers["x-resume-token"] = resume_key + + return [ + self.format_match(h, base, collection, 
expanded) + for h in res["hits"]["hits"] + ] + + def get_terms( + self, + collection: str, + q: str, + field: str, + aggr: str, + ): + """ + Get top terms associated with a query + """ + res = self.ES.search(index=collection, body=QueryBuilder(q).terms_query(field)) # type: ignore [call-arg] + if ( + not res["hits"]["hits"] + or not res["aggregations"]["sample"]["topterms"]["buckets"] + ): + raise HTTPException(status_code=404, detail="No results found!") + return self.format_counts(res["aggregations"]["sample"]["topterms"]["buckets"]) + + def get_article(self, collection: str, id: str, req): + """ + Get an individual article by id. + """ + try: + res = self.ES.search( + index=collection, body=QueryBuilder(id).article_query() + ) + hits = res["hits"]["hits"] + except (TransportError, TypeError, KeyError) as e: + raise HTTPException( + status_code=500, + detail=f"An error occured when searching for article with ID {id}", + ) from e + if len(hits) > 0: + hit = hits[0] + else: + raise HTTPException( + status_code=404, detail=f"An article with ID {id} not found!" + ) + base = self.proxy_base_url(req) + return self.format_match(hit, base, collection, True) diff --git a/test/__init__.py b/test/__init__.py index a8f739f..415f5bd 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -1,6 +1,6 @@ import os -INDEX_NAME = "mediacloud_test" +INDEX_NAME = "mc_search" ELASTICSEARCH_URL = "http://localhost:9200" FIXTURES_DIR = os.path.join(os.path.dirname(__file__), "fixtures") NUMBER_OF_TEST_STORIES = 5103 diff --git a/test/api_test.py b/test/api_test.py index e1e613f..e2b8557 100644 --- a/test/api_test.py +++ b/test/api_test.py @@ -12,10 +12,12 @@ # the `create_fixtures.py` script os.environ["INDEXES"] = INDEX_NAME os.environ["ESHOSTS"] = ELASTICSEARCH_URL -os.environ["ELASTICSEARCH_INDEX_NAME_PREFIX"] = "mediacloud" +os.environ["TERMFIELDS"] = "article_title,text_content" +os.environ["TERMAGGRS"] = "top" +os.environ["ELASTICSEARCH_INDEX_NAME_PREFIX"] = "mc_search" -TIMEOUT = 30 +TIMEOUT = 60 class ApiTest(TestCase): diff --git a/test/create_fixtures.py b/test/create_fixtures.py index 84dc9d8..69b546a 100644 --- a/test/create_fixtures.py +++ b/test/create_fixtures.py @@ -73,14 +73,15 @@ if (idx % 1000) != 0: fixture["publication_date"] = pub_date.isoformat() else: # make sure some have no publication date, and mark them for easy searching - fixture["publication_date"] = None + fixture["publication_date"] = None # type: ignore [assignment] fixture["article_title"] += " (no publication date)" fixture["text_content"] += " (no publication date)" fixture["normalized_article_title"] = titles.normalize_title( fixture["article_title"] ) - random_time_on_day = (dt.datetime(pub_date.year, pub_date.month, pub_date.day) + - dt.timedelta(minutes=randrange(1440))) + random_time_on_day = dt.datetime( + pub_date.year, pub_date.month, pub_date.day + ) + dt.timedelta(minutes=randrange(1440)) fixture["indexed_date"] = random_time_on_day.isoformat() unique_hash = urls.unique_url_hash(fixture["url"]) try: diff --git a/utils.py b/utils.py index adfda07..5d4de20 100644 --- a/utils.py +++ b/utils.py @@ -32,6 +32,21 @@ def list_to_enum(name: str, koptv: list): return Enum(name, [f"{kv}:{kv}".split(":")[:2] for kv in koptv]) +def env_to_float(name: str, defval: float | None) -> float | None: + """ + fetch environment variable with name `name` + if not set, return defval + if set to empty string, return None + else interpret as floating point number + """ + val = os.getenv(name) + if val is None: + return defval + if val == 
"": + return None + return float(val) + + def assert_elasticsearch_connection(es: Elasticsearch) -> bool: try: info = es.info() From fd21dc98ba1b1afe7aeb7ca27e1777c7f2562bd0 Mon Sep 17 00:00:00 2001 From: Paige Gulley Date: Thu, 11 Jul 2024 17:00:43 -0400 Subject: [PATCH 2/5] Addressing comments --- api.py | 9 +++------ queries.py => client.py | 35 ++++++++++++----------------------- 2 files changed, 15 insertions(+), 29 deletions(-) rename queries.py => client.py (93%) diff --git a/api.py b/api.py index 337b9da..c42a9b3 100755 --- a/api.py +++ b/api.py @@ -17,7 +17,7 @@ from sentry_sdk.integrations.fastapi import FastApiIntegration from sentry_sdk.integrations.starlette import StarletteIntegration -from queries import EsClientWrapper +from client import EsClientWrapper from utils import ( assert_elasticsearch_connection, env_to_dict, @@ -27,6 +27,7 @@ logger, ) +# Initialize our sentry integration if os.getenv("SENTRY_DSN"): sentry_sdk.init( dsn=os.getenv("SENTRY_DSN"), @@ -62,7 +63,7 @@ class ApiVersion(str, Enum): Collection = Enum("Collection", [f"{kv}:{kv}".split(":")[:2] for kv in ES.get_allowed_collections()]) # type: ignore [misc] TermField = Enum("TermField", [f"{kv}:{kv}".split(":")[:2] for kv in env_to_list("TERMFIELDS")]) # type: ignore [misc] -TermAggr = Enum("TermAggr", [f"{kv}:{kv}".split(":")[:2] for kv in env_to_list("TERMFIELDS")]) # type: ignore [misc] +TermAggr = Enum("TermAggr", [f"{kv}:{kv}".split(":")[:2] for kv in env_to_list("TERMAGGRS")]) # type: ignore [misc] tags = [ @@ -123,10 +124,6 @@ class PagedQuery(Query): page_size: Optional[int] = None -def proxy_base_url(req: Request): - return f'{str(os.getenv("PROXY_BASE", req.base_url)).rstrip("/")}/{req.scope.get("root_path").lstrip("/")}' # type: ignore [union-attr] - - @app.get("/", response_class=HTMLResponse) @app.head("/", response_class=HTMLResponse) def api_entrypoint(req: Request): diff --git a/queries.py b/client.py similarity index 93% rename from queries.py rename to client.py index 5bc7c49..ff5cda5 100644 --- a/queries.py +++ b/client.py @@ -12,11 +12,12 @@ from utils import assert_elasticsearch_connection, env_to_list, logger -def decode(strng: str): +# used to package paging keys for url transport +def decode_key(strng: str): return base64.b64decode(strng.replace("~", "=").encode(), b"-_").decode() -def encode(strng: str): +def encode_key(strng: str): return base64.b64encode(strng.encode(), b"-_").decode().replace("=", "~") @@ -27,10 +28,11 @@ class QueryBuilder: """ + VALID_SORT_ORDERS = ["asc", "desc"] + VALID_SORT_FIELDS = ["publication_date", "indexed_date"] + def __init__(self, query_text): self.query_text = query_text - self.VALID_SORT_ORDERS = ["asc", "desc"] - self.VALID_SORT_FIELDS = ["publication_date", "indexed_date"] self._source = [ "article_title", "normalized_article_title", @@ -43,20 +45,7 @@ def __init__(self, query_text): "normalized_url", "original_url", ] - self._extended_source = [ - "article_title", - "normalized_article_title", - "publication_date", - "indexed_date", - "language", - "full_language", - "canonical_domain", - "url", - "normalized_url", - "original_url", - "text_content", - "text_extraction", - ] + self._expanded_source = self._source.extend(["text_content", "text_extraction"]) def _validate_sort_order(self, sort_order: Optional[str]): if sort_order and sort_order not in self.VALID_SORT_ORDERS: @@ -83,7 +72,7 @@ def _validate_page_size(self, page_size: Optional[int]): def basic_query(self, expanded: bool = False) -> Dict: default: dict = { - "_source": 
self._extended_source if expanded else self._source, + "_source": self._expanded_source if expanded else self._source, "query": { "query_string": { "default_field": "text_content", @@ -166,12 +155,12 @@ def paged_query( if resume: # important to use `search_after` instead of 'from' for memory reasons related to paging through more # than 10k results - query["search_after"] = [decode(resume)] + query["search_after"] = [decode_key(resume)] return query def article_query(self): default: dict = { - "_source": self._extended_source, + "_source": self._expanded_source, "query": {"match": {"_id": self.query_text}}, } @@ -182,7 +171,7 @@ class EsClientWrapper: # A wrapper to actually make the calls to elasticsearch def __init__(self, eshosts, **esopts): self.ES = Elasticsearch(eshosts, **esopts) - self.maxpage = os.getenv("maxpage", 1000) + self.maxpage = os.getenv("MAXPAGE", 1000) max_retries = 10 retries = 0 @@ -305,7 +294,7 @@ def search_result( raise HTTPException(status_code=404, detail="No results found!") if len(res["hits"]["hits"]) == (page_size or self.maxpage): - resume_key = encode(str(res["hits"]["hits"][-1]["sort"][0])) + resume_key = encode_key(str(res["hits"]["hits"][-1]["sort"][0])) resp.headers["x-resume-token"] = resume_key return [ From af72455e70e4cf582b0488173dfd539358277e7b Mon Sep 17 00:00:00 2001 From: Paige Gulley Date: Thu, 11 Jul 2024 18:28:29 -0400 Subject: [PATCH 3/5] config overhaul- use pydantic_settings instead of manually shimming around os.get_env --- api.py | 83 +++++++++++++++++++++++++++--------------------- client.py | 38 +++++++++++----------- requirements.txt | 1 + utils.py | 41 ------------------------ 4 files changed, 65 insertions(+), 98 deletions(-) diff --git a/api.py b/api.py index c42a9b3..ae4671d 100755 --- a/api.py +++ b/api.py @@ -3,7 +3,7 @@ import os import time from enum import Enum -from typing import Dict, Optional, TypeAlias, Union +from typing import Dict, List, Optional, TypeAlias, Union from urllib.parse import quote_plus import mcmetadata.urls as urls @@ -13,26 +13,50 @@ from fastapi import Body, FastAPI, HTTPException, Request, Response from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import HTMLResponse, RedirectResponse -from pydantic import BaseModel +from pydantic import BaseModel, computed_field +from pydantic_settings import BaseSettings from sentry_sdk.integrations.fastapi import FastApiIntegration from sentry_sdk.integrations.starlette import StarletteIntegration from client import EsClientWrapper -from utils import ( - assert_elasticsearch_connection, - env_to_dict, - env_to_float, - env_to_list, - load_config, - logger, -) +from utils import assert_elasticsearch_connection, logger + + +class Config(BaseSettings): + eshosts: str = "http://localhost:9200" + termfields: str = "article_title,text_content" + termaggrs: str = "top" + esopts: Dict = {} + title: str = "Interactive API" + description: str = "A wrapper API for ES indexes." 
+ debug: bool = False + sentry_dsn: str = "" + tracing_sample_rate: float = 1.0 + profiles_sample_rate: float = 1.0 + root_path: str = "" + + @computed_field() + def eshosts_list(self) -> List[str]: + return self.eshosts.split(",") + + @computed_field() + def termfields_list(self) -> List[str]: + return self.termfields.split(",") + + @computed_field() + def termaggrs_list(self) -> List[str]: + return self.termaggrs.split(",") + + +config = Config() +logger.info(f"Loaded config: {config}") # Initialize our sentry integration -if os.getenv("SENTRY_DSN"): +if config.sentry_dsn: sentry_sdk.init( - dsn=os.getenv("SENTRY_DSN"), - traces_sample_rate=env_to_float("TRACING_SAMPLE_RATE", 1.0), - profiles_sample_rate=env_to_float("PROFILES_SAMPLE_RATE", 1.0), + dsn=config.sentry_dsn, + traces_sample_rate=config.tracing_sample_rate, + profiles_sample_rate=config.profiles_sample_rate, integrations=[ StarletteIntegration(transaction_style="url"), FastApiIntegration(transaction_style="url"), @@ -46,24 +70,11 @@ class ApiVersion(str, Enum): v1 = "1.3.5" -config = load_config() -config["eshosts"] = env_to_list("ESHOSTS") or config.get( - "eshosts", ["http://localhost:9200"] -) -config["esopts"] = env_to_dict("ESOPTS") or config.get("esopts", {}) -config["title"] = os.getenv("TITLE", config.get("title", "")) -config["description"] = os.getenv("DESCRIPTION", config.get("description", "")) -config["debug"] = str(os.getenv("DEBUG", config.get("debug", False))).lower() in ( - "true", - "1", - "t", -) - -ES = EsClientWrapper(config["eshosts"], **config["esopts"]) +ES = EsClientWrapper(config.eshosts_list, **config.esopts) Collection = Enum("Collection", [f"{kv}:{kv}".split(":")[:2] for kv in ES.get_allowed_collections()]) # type: ignore [misc] -TermField = Enum("TermField", [f"{kv}:{kv}".split(":")[:2] for kv in env_to_list("TERMFIELDS")]) # type: ignore [misc] -TermAggr = Enum("TermAggr", [f"{kv}:{kv}".split(":")[:2] for kv in env_to_list("TERMAGGRS")]) # type: ignore [misc] +TermField = Enum("TermField", [f"{kv}:{kv}".split(":")[:2] for kv in config.termfields_list]) # type: ignore [misc] +TermAggr = Enum("TermAggr", [f"{kv}:{kv}".split(":")[:2] for kv in config.termaggrs_list]) # type: ignore [misc] tags = [ @@ -76,7 +87,7 @@ class ApiVersion(str, Enum): "description": "Data endpoints with machine-readable responses to interact with the collection indexes.", }, ] -if config["debug"]: +if config.debug: tags.append( { "name": "debug", @@ -105,8 +116,8 @@ async def add_api_version_header(req: Request, call_next): v1 = FastAPI( - title=config.get("title", "Interactive API") + " Docs", - description=config.get("description", "A wrapper API for ES indexes."), + title=config.title + " Docs", + description=config.description, version=ApiVersion.v1.value, openapi_tags=tags, ) @@ -276,7 +287,7 @@ def search_result_via_payload( ) -if config["debug"]: +if config.debug: @v1.post("/{collection}/search/esdsl", tags=["debug"]) def search_esdsl_via_payload(collection: Collection, payload: dict = Body(...)): @@ -357,6 +368,4 @@ def get_article( if __name__ == "__main__": import uvicorn - uvicorn.run( - "api:app", host="0.0.0.0", reload=True, root_path=os.getenv("ROOT_PATH", "") - ) + uvicorn.run("api:app", host="0.0.0.0", reload=True, root_path=config.root_path) diff --git a/client.py b/client.py index ff5cda5..8ecd26f 100644 --- a/client.py +++ b/client.py @@ -8,8 +8,18 @@ from elasticsearch import Elasticsearch from elasticsearch.exceptions import TransportError from fastapi import HTTPException, Request, Response +from 
pydantic import BaseModel, computed_field +from pydantic_settings import BaseSettings -from utils import assert_elasticsearch_connection, env_to_list, logger +from utils import assert_elasticsearch_connection, logger + + +class ClientConfig(BaseSettings): + maxpage: int = 1000 + elasticsearch_index_name_prefix: str = "" + + +client_config = ClientConfig() # used to package paging keys for url transport @@ -171,7 +181,7 @@ class EsClientWrapper: # A wrapper to actually make the calls to elasticsearch def __init__(self, eshosts, **esopts): self.ES = Elasticsearch(eshosts, **esopts) - self.maxpage = os.getenv("MAXPAGE", 1000) + self.maxpage = client_config.maxpage max_retries = 10 retries = 0 @@ -187,7 +197,7 @@ def __init__(self, eshosts, **esopts): f"Elasticsearch connection failed {max_retries} times, giving up." ) - self.index_name_prefix = os.getenv("ELASTICSEARCH_INDEX_NAME_PREFIX", "") + self.index_name_prefix = client_config.elasticsearch_index_name_prefix logger.info("Initialized ES client wrapper") def get_allowed_collections(self): @@ -209,9 +219,7 @@ def get_allowed_collections(self): logger.info(f"Exposed indices: {all_indexes}") return all_indexes - def format_match( - self, hit: dict, base: str, collection: str, expanded: bool = False - ): + def format_match(self, hit: dict, collection: str, expanded: bool = False): src = hit["_source"] res = { "article_title": src.get("article_title"), @@ -239,9 +247,6 @@ def format_day_counts(self, bucket: list): def format_counts(self, bucket: list): return {item["key"]: item["doc_count"] for item in bucket} - def proxy_base_url(self, req: Request): - return f'{str(os.getenv("PROXY_BASE", req.base_url)).rstrip("/")}/{req.scope.get("root_path").lstrip("/")}' # type: ignore [union-attr] - def search_overview(self, collection: str, q: str, req: Request): """ Get overview statistics for a query @@ -254,7 +259,6 @@ def search_overview(self, collection: str, q: str, req: Request): tldsum = sum( item["doc_count"] for item in res["aggregations"]["tld"]["buckets"] ) - base = self.proxy_base_url(req) return { "query": q, "total": max(total, tldsum), @@ -264,9 +268,7 @@ def search_overview(self, collection: str, q: str, req: Request): "dailycounts": self.format_day_counts( res["aggregations"]["daily"]["buckets"] ), - "matches": [ - self.format_match(h, base, collection) for h in res["hits"]["hits"] - ], + "matches": [self.format_match(h, collection) for h in res["hits"]["hits"]], } def search_result( @@ -288,7 +290,6 @@ def search_result( resume, expanded, sort_field, sort_order, page_size ) res = self.ES.search(index=collection, body=query) # type: ignore [call-arg] - base = self.proxy_base_url(req) if not res["hits"]["hits"]: raise HTTPException(status_code=404, detail="No results found!") @@ -297,10 +298,7 @@ def search_result( resume_key = encode_key(str(res["hits"]["hits"][-1]["sort"][0])) resp.headers["x-resume-token"] = resume_key - return [ - self.format_match(h, base, collection, expanded) - for h in res["hits"]["hits"] - ] + return [self.format_match(h, collection, expanded) for h in res["hits"]["hits"]] def get_terms( self, @@ -340,5 +338,5 @@ def get_article(self, collection: str, id: str, req): raise HTTPException( status_code=404, detail=f"An article with ID {id} not found!" 
) - base = self.proxy_base_url(req) - return self.format_match(hit, base, collection, True) + + return self.format_match(hit, collection, True) diff --git a/requirements.txt b/requirements.txt index 4edba4c..31afc74 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,6 +5,7 @@ matplotlib==3.8.* mediacloud-metadata==0.11.* pandas==2.2.* pydantic==2.5.* +pydantic_settings==2.* requests streamlit==1.30.* uvicorn[standard] diff --git a/utils.py b/utils.py index 5d4de20..32d32d3 100644 --- a/utils.py +++ b/utils.py @@ -1,52 +1,11 @@ -import json import logging -import os -from enum import Enum -from typing import TypeAlias -import yaml from elasticsearch import Elasticsearch logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) -def load_config(): - conf = os.getenv("CONFIG", "config.yml") - try: - return yaml.safe_load(open(conf, encoding="UTF-8")) - except OSError: - return {} - - -def env_to_list(name: str): - return " ".join(os.getenv(name, "").split(",")).split() - - -def env_to_dict(name: str): - return json.loads(os.getenv(name, "{}")) - - -def list_to_enum(name: str, koptv: list): - # Just use StrEnum? py3.11 feature- let's attempt. - return Enum(name, [f"{kv}:{kv}".split(":")[:2] for kv in koptv]) - - -def env_to_float(name: str, defval: float | None) -> float | None: - """ - fetch environment variable with name `name` - if not set, return defval - if set to empty string, return None - else interpret as floating point number - """ - val = os.getenv(name) - if val is None: - return defval - if val == "": - return None - return float(val) - - def assert_elasticsearch_connection(es: Elasticsearch) -> bool: try: info = es.info() From f5bddc9056a81f693060f44ec916daa8c67b0560 Mon Sep 17 00:00:00 2001 From: Paige Gulley Date: Thu, 11 Jul 2024 18:57:13 -0400 Subject: [PATCH 4/5] fully decoupled the api and the es client- no resp/res into the client --- api.py | 19 ++++++++++--------- client.py | 15 ++++++++------- 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/api.py b/api.py index ae4671d..b62a74e 100755 --- a/api.py +++ b/api.py @@ -227,7 +227,7 @@ def search_overview_via_query_params(collection: Collection, q: str, req: Reques """ Report overview summary of the search result """ - return ES.search_overview(collection.name, q, req) + return ES.search_overview(collection.name, q) @v1.post("/{collection}/search/overview", tags=["data"]) @@ -235,7 +235,7 @@ def search_overview_via_payload(collection: Collection, req: Request, payload: Q """ Report summary of the search result """ - return ES.search_overview(collection.name, payload.q, req) + return ES.search_overview(collection.name, payload.q) @v1.get("/{collection}/search/result", tags=["data"]) @@ -254,17 +254,18 @@ def search_result_via_query_params( """ Paged response of search result """ - return ES.search_result( + + result, resume_key = ES.search_result( collection.name, q, - req, - resp, resume, expanded, sort_field, sort_order, page_size, ) + resp.headers["x-resume-token"] = resume_key + return result @v1.post("/{collection}/search/result", tags=["data"]) @@ -274,17 +275,17 @@ def search_result_via_payload( """ Paged response of search result """ - return ES.search_result( + result, resume_key = ES.search_result( collection.name, payload.q, - req, - resp, payload.resume, payload.expanded, payload.sort_field, payload.sort_order, payload.page_size, ) + resp.headers["x-resume-token"] = resume_key + return result if config.debug: @@ -359,7 +360,7 @@ def get_article( Fetch an individual 
article record by ID. """ - return ES.get_article(collection.name, id, req) + return ES.get_article(collection.name, id) app.mount(f"/{ApiVersion.v1.name}", v1) diff --git a/client.py b/client.py index 8ecd26f..2bca16f 100644 --- a/client.py +++ b/client.py @@ -7,13 +7,14 @@ import mcmetadata.urls as urls from elasticsearch import Elasticsearch from elasticsearch.exceptions import TransportError -from fastapi import HTTPException, Request, Response +from fastapi import HTTPException from pydantic import BaseModel, computed_field from pydantic_settings import BaseSettings from utils import assert_elasticsearch_connection, logger +# Loads values from the environment class ClientConfig(BaseSettings): maxpage: int = 1000 elasticsearch_index_name_prefix: str = "" @@ -247,7 +248,7 @@ def format_day_counts(self, bucket: list): def format_counts(self, bucket: list): return {item["key"]: item["doc_count"] for item in bucket} - def search_overview(self, collection: str, q: str, req: Request): + def search_overview(self, collection: str, q: str): """ Get overview statistics for a query """ @@ -275,8 +276,6 @@ def search_result( self, collection: str, q: str, - req: Request, - resp: Response, resume: Union[str, None] = None, expanded: bool = False, sort_field: Optional[str] = None, @@ -294,11 +293,13 @@ def search_result( if not res["hits"]["hits"]: raise HTTPException(status_code=404, detail="No results found!") + resume_key = None if len(res["hits"]["hits"]) == (page_size or self.maxpage): resume_key = encode_key(str(res["hits"]["hits"][-1]["sort"][0])) - resp.headers["x-resume-token"] = resume_key - return [self.format_match(h, collection, expanded) for h in res["hits"]["hits"]] + return [ + self.format_match(h, collection, expanded) for h in res["hits"]["hits"] + ], resume_key def get_terms( self, @@ -318,7 +319,7 @@ def get_terms( raise HTTPException(status_code=404, detail="No results found!") return self.format_counts(res["aggregations"]["sample"]["topterms"]["buckets"]) - def get_article(self, collection: str, id: str, req): + def get_article(self, collection: str, id: str): """ Get an individual article by id. """ From 04c3a837abb0055cd3bf943212fc77045ffe08fc Mon Sep 17 00:00:00 2001 From: Paige Gulley Date: Thu, 11 Jul 2024 19:02:18 -0400 Subject: [PATCH 5/5] missed this condition locally- was failing when resume_key was none --- api.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/api.py b/api.py index b62a74e..6fed705 100755 --- a/api.py +++ b/api.py @@ -264,7 +264,8 @@ def search_result_via_query_params( sort_order, page_size, ) - resp.headers["x-resume-token"] = resume_key + if resume_key: + resp.headers["x-resume-token"] = resume_key return result @@ -284,7 +285,8 @@ def search_result_via_payload( payload.sort_order, payload.page_size, ) - resp.headers["x-resume-token"] = resume_key + if resume_key: + resp.headers["x-resume-token"] = resume_key return result
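
A minimal usage sketch of the client as it stands after patch 5/5, assuming the interfaces land exactly as in the diffs above (EsClientWrapper, search_result returning a page plus an opaque resume key, format_match output keys). The index pattern, query string, and page size below are illustrative only, not part of the series:

# sketch: page through results with the decoupled client from this series.
# Illustrative only — collection pattern, query, and page_size are hypothetical.
from fastapi import HTTPException

from client import EsClientWrapper

es = EsClientWrapper(["http://localhost:9200"])

collection = "mc_search-*"   # wildcard of the kind get_allowed_collections() exposes
resume = None
try:
    while True:
        # search_result no longer touches Request/Response; it returns the
        # formatted matches plus a resume key, which is None once a short page
        # signals there is nothing left to fetch.
        matches, resume = es.search_result(
            collection, "climate", resume=resume, page_size=1000
        )
        for match in matches:
            print(match["indexed_date"], match["url"])
        if not resume:
            break
except HTTPException:
    # the client raises a 404-style HTTPException when a page comes back empty,
    # e.g. when the previous page happened to be exactly page_size long
    pass

An API-layer caller does the same thing but forwards the resume key as the x-resume-token response header only when it is not None, which is the condition patch 5/5 adds.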