Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve json serialization to accomodate numpy float32 #3028

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
18 changes: 15 additions & 3 deletions nvflare/app_common/widgets/validation_json_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@

import json
import os.path
from functools import singledispatch

import numpy as np

from nvflare.apis.dxo import DataKind, from_shareable, get_leaf_dxos
from nvflare.apis.event_type import EventType
Expand All @@ -23,6 +26,17 @@
from nvflare.widgets.widget import Widget


@singledispatch
def to_serializable(val):
"""Default json serializable method."""
return str(val)


@to_serializable.register(np.float32)
def ts_float32(val):
return np.float64(val)


class ValidationJsonGenerator(Widget):
def __init__(self, results_dir=AppConstants.CROSS_VAL_DIR, json_file_name="cross_val_results.json"):
"""Catches VALIDATION_RESULT_RECEIVED event and generates a results.json containing accuracy of each
Expand Down Expand Up @@ -58,7 +72,6 @@ def handle_event(self, event_type: str, fl_ctx: FLContext):
if val_results:
try:
dxo = from_shareable(val_results)
dxo.validate()
YuanTingHsieh marked this conversation as resolved.
Show resolved Hide resolved

if dxo.data_kind == DataKind.METRICS:
if data_client not in self._val_results:
Expand All @@ -71,7 +84,6 @@ def handle_event(self, event_type: str, fl_ctx: FLContext):
for err in errors:
self.log_error(fl_ctx, f"Bad result from {data_client}: {err}")
for _sub_data_client, _dxo in leaf_dxos.items():
_dxo.validate()
if _sub_data_client not in self._val_results:
self._val_results[_sub_data_client] = {}
self._val_results[_sub_data_client][model_owner] = _dxo.data
Expand All @@ -93,4 +105,4 @@ def handle_event(self, event_type: str, fl_ctx: FLContext):

res_file_path = os.path.join(cross_val_res_dir, self._json_file_name)
with open(res_file_path, "w") as f:
json.dump(self._val_results, f)
json.dump(self._val_results, f, default=to_serializable)
Loading