Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SystemModel: update_overall_statistics() cache invalidation and refactoring #491

Merged
merged 2 commits into from
Nov 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 40 additions & 24 deletions backend/src/impl/internal_models/system_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from explainaboard import get_processor
from explainaboard.loaders.file_loader import FileLoaderReturn
from explainaboard.metrics.metric import MetricConfig
from explainaboard.serialization.legacy import general_to_dict
from explainaboard_web.impl.db_utils.dataset_db_utils import DatasetDBUtils
from explainaboard_web.impl.db_utils.db_utils import DBUtils
Expand Down Expand Up @@ -63,14 +64,28 @@ def from_dict(cls, dikt: dict) -> SystemModel:

return super().from_dict(document)

def _get_private_properties(self) -> dict:
def _get_private_properties(self, session: ClientSession | None = None) -> dict:
    """Retrieves private properties of the system. These properties are meant
for internal use only.
TODO(lyuyang): store this in memory

Args:
session: A mongodb session. Private properties are stored in the DB so
we need to query the DB to retrieve this data. If multiple DB operations
        need to be performed in one session, the same session should be used
to query private properties.
TODO(lyuyang): cache this data in memory. Even if it is cached in
memory, session is still required in situations where we need to refresh
the cache.

Raises:
ValueError: The system cannot be found in the DB. This method should not be
called on a system that hasn't been created or has been deleted.
"""
sys_doc = DBUtils.find_one_by_id(DBUtils.DEV_SYSTEM_METADATA, self.system_id)
sys_doc = DBUtils.find_one_by_id(
DBUtils.DEV_SYSTEM_METADATA, self.system_id, session=session
)
if not sys_doc:
abort_with_error_message(404, f"system id: {self.system_id} not found")
raise ValueError(f"system {self.system_id} does not exist in the DB")
return sys_doc

def get_system_info(self) -> SystemInfo:
Expand Down Expand Up @@ -113,7 +128,12 @@ def save_to_db(self, session: ClientSession | None = None) -> None:
def save_system_output(
self, system_output: FileLoaderReturn, session: ClientSession | None = None
):
"""TODO(lyuyang): should delete stale data from storage"""
"""Saves `system_output` to storage. If `system_output` has been saved
previously, it is replaced with the new one."""
properties = self._get_private_properties(session=session)
if properties.get("system_output"):
# delete previously saved system_output
get_storage().delete([properties["system_output"]])
sample_list = [general_to_dict(v) for v in system_output.samples]
blob_name = f"{self.system_id}/{self._SYSTEM_OUTPUT_CONST}"
get_storage().compress_and_upload(
Expand All @@ -134,31 +154,20 @@ def update_overall_statistics(
session: ClientSession | None = None,
force_update=False,
) -> None:
"""regenerates overall statistics and updates cache
TODO(lyuyang) This method is not complete. It should only be called once because
it does not remove system cases from cloud storage properly."""
        """Regenerates overall statistics and updates the cache."""
properties = self._get_private_properties(session=session)
if not force_update:
sys_doc = DBUtils.find_one_by_id(
DBUtils.DEV_SYSTEM_METADATA, self.system_id, session=session
)
if not sys_doc:
raise ValueError(f"system {self.system_id} hasn't been created")
if "system_info" in sys_doc and "metric_stats" in sys_doc:
if "system_info" in properties and "metric_stats" in properties:
# cache hit
return
sys_doc = DBUtils.find_one_by_id(
DBUtils.DEV_SYSTEM_METADATA, self.system_id, session=session
)
if sys_doc.get("system_info"):
raise ValueError("update_overall_statistics can only be called once")

def _process():
processor = get_processor(self.task)
metrics_lookup = {
metric.name: metric
for metric in get_processor(self.task).full_metric_list()
}
metric_configs = []
metric_configs: list[MetricConfig] = []
for metric_name in metadata.metric_names:
if metric_name not in metrics_lookup:
abort_with_error_message(
Expand All @@ -168,11 +177,16 @@ def _process():
custom_features = system_output_data.metadata.custom_features or {}
custom_features.update(self.get_dataset_custom_features())
processor_metadata = {
**metadata.to_dict(),
# system properties
"system_name": self.system_name,
"source_language": self.source_language,
"target_language": self.target_language,
"dataset_name": self.dataset.dataset_name if self.dataset else None,
"sub_dataset_name": self.dataset.sub_dataset if self.dataset else None,
"dataset_split": metadata.dataset_split,
"task_name": metadata.task,
"dataset_split": self.dataset.split if self.dataset else None,
"task_name": self.task,
"system_details": self.system_details,
# processor parameters
"metric_configs": metric_configs,
"custom_features": custom_features,
"custom_analyses": system_output_data.metadata.custom_analyses or [],
Expand Down Expand Up @@ -226,7 +240,9 @@ def update_analysis_cases():
analysis_cases_lookup[analysis_level.name] = blob_name
return analysis_cases_lookup

# Insert system output and analysis cases
if properties.get("analysis_cases"):
# invalidate cache
get_storage().delete(properties["analysis_cases"].values())

DBUtils.update_one_by_id(
DBUtils.DEV_SYSTEM_METADATA,
Expand Down
3 changes: 2 additions & 1 deletion backend/src/impl/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import json
import zlib
from collections.abc import Iterable

from flask import current_app, g
from google.cloud import storage as cloud_storage
Expand Down Expand Up @@ -49,7 +50,7 @@ def download(self, blob_name: str) -> bytes:
def download_and_decompress(self, blob_name: str) -> str:
return zlib.decompress(self.download(blob_name)).decode()

def delete(self, blob_names: list[str]) -> None:
def delete(self, blob_names: Iterable[str]) -> None:
self._bucket.delete_blobs([self._bucket.blob(name) for name in blob_names])


Expand Down