From 88c94191c7edfaf151142c347aa5ec87a9e3c517 Mon Sep 17 00:00:00 2001 From: Zhihong Zhang <100308595+nvidianz@users.noreply.github.com> Date: Thu, 1 Sep 2022 23:04:56 -0400 Subject: [PATCH] Merged FOBS from 2.1 to dev (#818) --- .../cifar10/pt/utils/cifar10_data_splitter.py | 2 +- .../df_stats/custom/df_statistics.py | 9 +- .../df_stats/data_utils.py | 2 +- .../app/config/config_fed_server.json | 2 +- .../app/custom/tf2_model_persistor.py | 24 +-- examples/hello-cyclic/app/custom/trainer.py | 7 +- .../app/custom/monai_trainer.py | 4 +- .../hello-monai/app/custom/monai_trainer.py | 4 +- .../app/config/config_fed_server.json | 2 +- examples/hello-tf2/app/custom/filter.py | 1 + .../app/custom/tf2_model_persistor.py | 24 +-- examples/hello-tf2/app/custom/trainer.py | 7 +- nvflare/apis/dxo.py | 8 +- nvflare/apis/fl_constant.py | 1 + nvflare/apis/impl/job_def_manager.py | 12 +- nvflare/apis/shareable.py | 9 +- nvflare/apis/utils/decomposers/__init__.py | 13 ++ .../utils/decomposers/flare_decomposers.py | 168 ++++++++++++++++ nvflare/apis/utils/fl_context_utils.py | 4 +- nvflare/app_common/abstract/learnable.py | 9 +- nvflare/app_common/decomposers/__init__.py | 13 ++ .../decomposers/common_decomposers.py | 151 +++++++++++++++ nvflare/app_common/np/np_model_locator.py | 2 +- .../storage_state_persistor.py | 10 +- nvflare/fuel/utils/fobs/README.rst | 178 +++++++++++++++++ nvflare/fuel/utils/fobs/__init__.py | 29 +++ nvflare/fuel/utils/fobs/decomposer.py | 60 ++++++ .../fuel/utils/fobs/decomposers/__init__.py | 13 ++ .../fobs/decomposers/core_decomposers.py | 67 +++++++ nvflare/fuel/utils/fobs/fobs.py | 182 ++++++++++++++++++ nvflare/lighter/impl/master_template.yml | 2 +- nvflare/lighter/impl/signature.py | 2 +- nvflare/private/fed/client/fed_client.py | 8 +- nvflare/private/fed/client/process_aux_cmd.py | 7 +- nvflare/private/fed/client/scheduler_cmds.py | 12 +- .../fed/server/collective_command_agent.py | 4 +- nvflare/private/fed/server/fed_server.py | 29 +-- 
.../fed/server/server_command_agent.py | 4 +- nvflare/private/fed/server/server_commands.py | 17 +- nvflare/private/fed/server/server_engine.py | 16 +- nvflare/private/fed/utils/fed_utils.py | 4 +- requirements-min.txt | 2 + setup.py | 2 +- .../apps/cyclic/config/config_fed_server.json | 2 +- .../apps/tf/config/config_fed_server.json | 2 +- tests/integration_test/tf2/model_persistor.py | 16 +- .../validators/tf_model_validator.py | 11 +- .../apis/utils/decomposers/__init__.py | 13 ++ .../decomposers/flare_decomposers_test.py | 43 +++++ tests/unit_test/fuel/utils/fobs/__init__.py | 13 ++ .../fuel/utils/fobs/decomposer_test.py | 30 +++ tests/unit_test/fuel/utils/fobs/fobs_test.py | 78 ++++++++ .../app/simulator/simulator_runner_test.py | 2 +- 53 files changed, 1197 insertions(+), 139 deletions(-) create mode 100644 nvflare/apis/utils/decomposers/__init__.py create mode 100644 nvflare/apis/utils/decomposers/flare_decomposers.py create mode 100644 nvflare/app_common/decomposers/__init__.py create mode 100644 nvflare/app_common/decomposers/common_decomposers.py create mode 100644 nvflare/fuel/utils/fobs/README.rst create mode 100644 nvflare/fuel/utils/fobs/__init__.py create mode 100644 nvflare/fuel/utils/fobs/decomposer.py create mode 100644 nvflare/fuel/utils/fobs/decomposers/__init__.py create mode 100644 nvflare/fuel/utils/fobs/decomposers/core_decomposers.py create mode 100644 nvflare/fuel/utils/fobs/fobs.py create mode 100644 tests/unit_test/apis/utils/decomposers/__init__.py create mode 100644 tests/unit_test/apis/utils/decomposers/flare_decomposers_test.py create mode 100644 tests/unit_test/fuel/utils/fobs/__init__.py create mode 100644 tests/unit_test/fuel/utils/fobs/decomposer_test.py create mode 100644 tests/unit_test/fuel/utils/fobs/fobs_test.py diff --git a/examples/cifar10/pt/utils/cifar10_data_splitter.py b/examples/cifar10/pt/utils/cifar10_data_splitter.py index 448d8b3950..c5f0322371 100644 --- a/examples/cifar10/pt/utils/cifar10_data_splitter.py +++ 
b/examples/cifar10/pt/utils/cifar10_data_splitter.py @@ -44,9 +44,9 @@ import numpy as np import torchvision.datasets as datasets +from nvflare.apis.event_type import EventType from nvflare.apis.fl_component import FLComponent from nvflare.apis.fl_context import FLContext -from nvflare.apis.event_type import EventType CIFAR10_ROOT = "/tmp/cifar10" # will be used for all CIFAR-10 experiments diff --git a/examples/federated_statistics/df_stats/custom/df_statistics.py b/examples/federated_statistics/df_stats/custom/df_statistics.py index d45f792fc4..dfe38f612c 100644 --- a/examples/federated_statistics/df_stats/custom/df_statistics.py +++ b/examples/federated_statistics/df_stats/custom/df_statistics.py @@ -12,18 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Optional, List +from typing import Dict, List, Optional import numpy as np import pandas as pd +from load_data_utils import get_app_paths, load_config from pandas.core.series import Series -from load_data_utils import get_app_paths, load_config from nvflare.apis.fl_constant import ReservedKey from nvflare.apis.fl_context import FLContext -from nvflare.app_common.abstract.statistics_spec import Statistics,BinRange, Histogram, HistogramType, Feature -from nvflare.app_common.statistics.numpy_utils import get_std_histogram_buckets -from nvflare.app_common.statistics.numpy_utils import dtype_to_data_type +from nvflare.app_common.abstract.statistics_spec import BinRange, Feature, Histogram, HistogramType, Statistics +from nvflare.app_common.statistics.numpy_utils import dtype_to_data_type, get_std_histogram_buckets class DFStatistics(Statistics): diff --git a/examples/federated_statistics/df_stats/data_utils.py b/examples/federated_statistics/df_stats/data_utils.py index 73c33502ed..4fb0d23e93 100644 --- a/examples/federated_statistics/df_stats/data_utils.py +++ b/examples/federated_statistics/df_stats/data_utils.py @@ -6,7 +6,7 
@@ import wget from pyhocon import ConfigFactory -from nvflare.lighter.poc_commands import is_poc_ready, get_nvflare_home +from nvflare.lighter.poc_commands import get_nvflare_home, is_poc_ready def get_poc_workspace(): diff --git a/examples/hello-cyclic/app/config/config_fed_server.json b/examples/hello-cyclic/app/config/config_fed_server.json index bbb3985e69..2349d457e6 100755 --- a/examples/hello-cyclic/app/config/config_fed_server.json +++ b/examples/hello-cyclic/app/config/config_fed_server.json @@ -10,7 +10,7 @@ "id": "persistor", "path": "tf2_model_persistor.TF2ModelPersistor", "args": { - "save_name": "tf2weights.pickle" + "save_name": "tf2weights.fobs" } }, { diff --git a/examples/hello-cyclic/app/custom/tf2_model_persistor.py b/examples/hello-cyclic/app/custom/tf2_model_persistor.py index 0771932792..d9e49f73e1 100644 --- a/examples/hello-cyclic/app/custom/tf2_model_persistor.py +++ b/examples/hello-cyclic/app/custom/tf2_model_persistor.py @@ -12,23 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import os -import pickle import json +import os import tensorflow as tf +from tf2_net import Net + from nvflare.apis.event_type import EventType from nvflare.apis.fl_constant import FLContextKey from nvflare.apis.fl_context import FLContext -from nvflare.app_common.abstract.model import ModelLearnable +from nvflare.app_common.abstract.model import ModelLearnable, make_model_learnable from nvflare.app_common.abstract.model_persistor import ModelPersistor -from tf2_net import Net from nvflare.app_common.app_constant import AppConstants -from nvflare.app_common.abstract.model import make_model_learnable +from nvflare.fuel.utils import fobs class TF2ModelPersistor(ModelPersistor): - def __init__(self, save_name="tf2_model.pkl"): + def __init__(self, save_name="tf2_model.fobs"): super().__init__() self.save_name = save_name @@ -65,7 +65,7 @@ def _initialize(self, fl_ctx: FLContext): self.log_dir = os.path.join(app_root, log_dir) else: self.log_dir = app_root - self._pkl_save_path = os.path.join(self.log_dir, self.save_name) + self._fobs_save_path = os.path.join(self.log_dir, self.save_name) if not os.path.exists(self.log_dir): os.makedirs(self.log_dir) @@ -82,10 +82,10 @@ def load_model(self, fl_ctx: FLContext) -> ModelLearnable: Model object """ - if os.path.exists(self._pkl_save_path): + if os.path.exists(self._fobs_save_path): self.logger.info("Loading server weights") - with open(self._pkl_save_path, "rb") as f: - model_learnable = pickle.load(f) + with open(self._fobs_save_path, "rb") as f: + model_learnable = fobs.load(f) else: self.logger.info("Initializing server model") network = Net() @@ -109,5 +109,5 @@ def save_model(self, model_learnable: ModelLearnable, fl_ctx: FLContext): """ model_learnable_info = {k: str(type(v)) for k, v in model_learnable.items()} self.logger.info(f"Saving aggregated server weights: \n {model_learnable_info}") - with open(self._pkl_save_path, "wb") as f: - pickle.dump(model_learnable, f) + with open(self._fobs_save_path, "wb") as f: 
+ fobs.dump(model_learnable, f) diff --git a/examples/hello-cyclic/app/custom/trainer.py b/examples/hello-cyclic/app/custom/trainer.py index fa96b31c8a..06b4deb248 100644 --- a/examples/hello-cyclic/app/custom/trainer.py +++ b/examples/hello-cyclic/app/custom/trainer.py @@ -12,19 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -import tensorflow as tf import numpy as np +import tensorflow as tf +from tf2_net import Net from nvflare.apis.dxo import DXO, DataKind, from_shareable -from nvflare.apis.fl_constant import ReturnCode from nvflare.apis.event_type import EventType from nvflare.apis.executor import Executor +from nvflare.apis.fl_constant import ReturnCode from nvflare.apis.fl_context import FLContext from nvflare.apis.shareable import Shareable, make_reply from nvflare.apis.signal import Signal -from tf2_net import Net - class SimpleTrainer(Executor): def __init__(self, epochs_per_round): diff --git a/examples/hello-monai-bundle/app/custom/monai_trainer.py b/examples/hello-monai-bundle/app/custom/monai_trainer.py index acfb5cf9ac..7c1dff423d 100644 --- a/examples/hello-monai-bundle/app/custom/monai_trainer.py +++ b/examples/hello-monai-bundle/app/custom/monai_trainer.py @@ -16,6 +16,8 @@ import numpy as np import torch +from bundle_configer import BundleConfiger + from nvflare.apis.dxo import DXO, DataKind, MetaKey, from_shareable from nvflare.apis.event_type import EventType from nvflare.apis.executor import Executor @@ -25,8 +27,6 @@ from nvflare.apis.signal import Signal from nvflare.app_common.app_constant import AppConstants -from bundle_configer import BundleConfiger - class MONAIBundleTrainer(Executor): """ diff --git a/examples/hello-monai/app/custom/monai_trainer.py b/examples/hello-monai/app/custom/monai_trainer.py index 5b01218cf0..a3da702c20 100644 --- a/examples/hello-monai/app/custom/monai_trainer.py +++ b/examples/hello-monai/app/custom/monai_trainer.py @@ -16,6 +16,8 @@ import 
numpy as np import torch +from train_configer import TrainConfiger + from nvflare.apis.dxo import DXO, DataKind, MetaKey, from_shareable from nvflare.apis.event_type import EventType from nvflare.apis.executor import Executor @@ -25,8 +27,6 @@ from nvflare.apis.signal import Signal from nvflare.app_common.app_constant import AppConstants -from train_configer import TrainConfiger - class MONAITrainer(Executor): """ diff --git a/examples/hello-tf2/app/config/config_fed_server.json b/examples/hello-tf2/app/config/config_fed_server.json index 59886214bb..2fedc511e9 100644 --- a/examples/hello-tf2/app/config/config_fed_server.json +++ b/examples/hello-tf2/app/config/config_fed_server.json @@ -10,7 +10,7 @@ "id": "persistor", "path": "tf2_model_persistor.TF2ModelPersistor", "args": { - "save_name": "tf2weights.pickle" + "save_name": "tf2weights.fobs" } }, { diff --git a/examples/hello-tf2/app/custom/filter.py b/examples/hello-tf2/app/custom/filter.py index ddbaa8ee49..6332d27836 100644 --- a/examples/hello-tf2/app/custom/filter.py +++ b/examples/hello-tf2/app/custom/filter.py @@ -15,6 +15,7 @@ import re import numpy as np + from nvflare.apis.dxo import DXO, DataKind, from_shareable from nvflare.apis.filter import Filter from nvflare.apis.fl_context import FLContext diff --git a/examples/hello-tf2/app/custom/tf2_model_persistor.py b/examples/hello-tf2/app/custom/tf2_model_persistor.py index 37ac6b9afd..b94699be9c 100644 --- a/examples/hello-tf2/app/custom/tf2_model_persistor.py +++ b/examples/hello-tf2/app/custom/tf2_model_persistor.py @@ -12,23 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import os -import pickle import json +import os import tensorflow as tf +from tf2_net import Net + from nvflare.apis.event_type import EventType from nvflare.apis.fl_constant import FLContextKey from nvflare.apis.fl_context import FLContext -from nvflare.app_common.abstract.model import ModelLearnable +from nvflare.app_common.abstract.model import ModelLearnable, make_model_learnable from nvflare.app_common.abstract.model_persistor import ModelPersistor -from tf2_net import Net from nvflare.app_common.app_constant import AppConstants -from nvflare.app_common.abstract.model import make_model_learnable +from nvflare.fuel.utils import fobs class TF2ModelPersistor(ModelPersistor): - def __init__(self, save_name="tf2_model.pkl"): + def __init__(self, save_name="tf2_model.fobs"): super().__init__() self.save_name = save_name @@ -65,7 +65,7 @@ def _initialize(self, fl_ctx: FLContext): self.log_dir = os.path.join(app_root, log_dir) else: self.log_dir = app_root - self._pkl_save_path = os.path.join(self.log_dir, self.save_name) + self._fobs_save_path = os.path.join(self.log_dir, self.save_name) if not os.path.exists(self.log_dir): os.makedirs(self.log_dir) @@ -81,10 +81,10 @@ def load_model(self, fl_ctx: FLContext) -> ModelLearnable: Model object """ - if os.path.exists(self._pkl_save_path): + if os.path.exists(self._fobs_save_path): self.logger.info("Loading server weights") - with open(self._pkl_save_path, "rb") as f: - model_learnable = pickle.load(f) + with open(self._fobs_save_path, "rb") as f: + model_learnable = fobs.load(f) else: self.logger.info("Initializing server model") network = Net() @@ -108,5 +108,5 @@ def save_model(self, model_learnable: ModelLearnable, fl_ctx: FLContext): """ model_learnable_info = {k: str(type(v)) for k, v in model_learnable.items()} self.logger.info(f"Saving aggregated server weights: \n {model_learnable_info}") - with open(self._pkl_save_path, "wb") as f: - pickle.dump(model_learnable, f) + with open(self._fobs_save_path, "wb") as f: 
+ fobs.dump(model_learnable, f) diff --git a/examples/hello-tf2/app/custom/trainer.py b/examples/hello-tf2/app/custom/trainer.py index fa96b31c8a..06b4deb248 100644 --- a/examples/hello-tf2/app/custom/trainer.py +++ b/examples/hello-tf2/app/custom/trainer.py @@ -12,19 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -import tensorflow as tf import numpy as np +import tensorflow as tf +from tf2_net import Net from nvflare.apis.dxo import DXO, DataKind, from_shareable -from nvflare.apis.fl_constant import ReturnCode from nvflare.apis.event_type import EventType from nvflare.apis.executor import Executor +from nvflare.apis.fl_constant import ReturnCode from nvflare.apis.fl_context import FLContext from nvflare.apis.shareable import Shareable, make_reply from nvflare.apis.signal import Signal -from tf2_net import Net - class SimpleTrainer(Executor): def __init__(self, epochs_per_round): diff --git a/nvflare/apis/dxo.py b/nvflare/apis/dxo.py index b9cecf35f2..a7096d2856 100644 --- a/nvflare/apis/dxo.py +++ b/nvflare/apis/dxo.py @@ -13,10 +13,10 @@ # limitations under the License. import copy -import pickle from typing import List from nvflare.apis.shareable import ReservedHeaderKey, Shareable +from nvflare.fuel.utils import fobs class DataKind(object): @@ -115,7 +115,7 @@ def to_bytes(self) -> bytes: object serialized in bytes. 
""" - return pickle.dumps(self) + return fobs.dumps(self) def validate(self) -> str: if self.data is None: @@ -167,10 +167,10 @@ def from_bytes(data: bytes) -> DXO: data: a bytes object Returns: - an object loaded by pickle from data + an object loaded by FOBS from data """ - x = pickle.loads(data) + x = fobs.loads(data) if isinstance(x, DXO): return x else: diff --git a/nvflare/apis/fl_constant.py b/nvflare/apis/fl_constant.py index aa62430334..6f9fdb0226 100644 --- a/nvflare/apis/fl_constant.py +++ b/nvflare/apis/fl_constant.py @@ -83,6 +83,7 @@ class ReservedKey(object): EVENT_SCOPE = "__event_scope__" RUN_ABORT_SIGNAL = "__run_abort_signal__" SHAREABLE = "__shareable__" + SHARED_FL_CONTEXT = "__shared_fl_context__" ARGS = "__args__" WORKSPACE_OBJECT = "__workspace_object__" RANK_NUMBER = "__rank_number__" diff --git a/nvflare/apis/impl/job_def_manager.py b/nvflare/apis/impl/job_def_manager.py index 41b077e291..cf6e647d89 100644 --- a/nvflare/apis/impl/job_def_manager.py +++ b/nvflare/apis/impl/job_def_manager.py @@ -15,7 +15,6 @@ import datetime import os import pathlib -import pickle import shutil import tempfile import time @@ -29,6 +28,7 @@ from nvflare.apis.server_engine_spec import ServerEngineSpec from nvflare.apis.storage import StorageException, StorageSpec from nvflare.fuel.hci.zip_utils import unzip_all_from_bytes, zip_directory_to_bytes +from nvflare.fuel.utils import fobs class _JobFilter(ABC): @@ -108,7 +108,7 @@ def create(self, meta: dict, uploaded_content: bytes, fl_ctx: FLContext) -> Dict # write it to the store stored_data = {JobDataKey.JOB_DATA.value: uploaded_content, JobDataKey.WORKSPACE_DATA.value: None} store = self._get_job_store(fl_ctx) - store.create_object(self.job_uri(jid), pickle.dumps(stored_data), meta, overwrite_existing=True) + store.create_object(self.job_uri(jid), fobs.dumps(stored_data), meta, overwrite_existing=True) return meta def delete(self, jid: str, fl_ctx: FLContext): @@ -185,12 +185,12 @@ def get_content(self, jid: 
str, fl_ctx: FLContext) -> Optional[bytes]: stored_data = store.get_data(self.job_uri(jid)) except StorageException: return None - return pickle.loads(stored_data).get(JobDataKey.JOB_DATA.value) + return fobs.loads(stored_data).get(JobDataKey.JOB_DATA.value) def get_job_data(self, jid: str, fl_ctx: FLContext) -> dict: store = self._get_job_store(fl_ctx) stored_data = store.get_data(self.job_uri(jid)) - return pickle.loads(stored_data) + return fobs.loads(stored_data) def set_status(self, jid: str, status: RunStatus, fl_ctx: FLContext): meta = {JobMetaKey.STATUS.value: status.value} @@ -259,6 +259,6 @@ def set_approval( def save_workspace(self, jid: str, data: bytes, fl_ctx: FLContext): store = self._get_job_store(fl_ctx) stored_data = store.get_data(self.job_uri(jid)) - job_data = pickle.loads(stored_data) + job_data = fobs.loads(stored_data) job_data[JobDataKey.WORKSPACE_DATA.value] = data - store.update_data(self.job_uri(jid), pickle.dumps(job_data)) + store.update_data(self.job_uri(jid), fobs.dumps(job_data)) diff --git a/nvflare/apis/shareable.py b/nvflare/apis/shareable.py index e21e872d85..903c3d4dfd 100644 --- a/nvflare/apis/shareable.py +++ b/nvflare/apis/shareable.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pickle - +from ..fuel.utils import fobs from .fl_constant import ReservedKey, ReturnCode @@ -113,7 +112,7 @@ def to_bytes(self) -> bytes: object serialized in bytes. 
""" - return pickle.dumps(self) + return fobs.dumps(self) @classmethod def from_bytes(cls, data: bytes): @@ -123,10 +122,10 @@ def from_bytes(cls, data: bytes): data: a bytes object Returns: - an object loaded by pickle from data + an object loaded by FOBS from data """ - return pickle.loads(data) + return fobs.loads(data) # some convenience functions diff --git a/nvflare/apis/utils/decomposers/__init__.py b/nvflare/apis/utils/decomposers/__init__.py new file mode 100644 index 0000000000..2b8f6c7e87 --- /dev/null +++ b/nvflare/apis/utils/decomposers/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nvflare/apis/utils/decomposers/flare_decomposers.py b/nvflare/apis/utils/decomposers/flare_decomposers.py new file mode 100644 index 0000000000..e74484ed7f --- /dev/null +++ b/nvflare/apis/utils/decomposers/flare_decomposers.py @@ -0,0 +1,168 @@ +# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""Decomposers for objects used by NVFlare itself + +This module contains all the decomposers used to run NVFlare. +The decomposers are registered at server/client startup. + +""" +import os +from argparse import Namespace +from typing import Any + +from nvflare.apis.analytix import AnalyticsDataType +from nvflare.apis.client import Client +from nvflare.apis.dxo import DXO +from nvflare.apis.fl_context import FLContext +from nvflare.apis.fl_snapshot import RunSnapshot +from nvflare.apis.shareable import Shareable +from nvflare.apis.signal import Signal +from nvflare.apis.workspace import Workspace +from nvflare.fuel.utils import fobs +from nvflare.fuel.utils.fobs.decomposer import Decomposer + + +class ShareableDecomposer(Decomposer): + @staticmethod + def supported_type(): + return Shareable + + def decompose(self, target: Shareable) -> Any: + return target.copy() + + def recompose(self, data: Any) -> Shareable: + obj = Shareable() + for k, v in data.items(): + obj[k] = v + return obj + + +class ContextDecomposer(Decomposer): + @staticmethod + def supported_type(): + return FLContext + + def decompose(self, target: FLContext) -> Any: + return [target.model, target.props] + + def recompose(self, data: Any) -> FLContext: + obj = FLContext() + obj.model = data[0] + obj.props = data[1] + return obj + + +class DxoDecomposer(Decomposer): + @staticmethod + def supported_type(): + return DXO + + def decompose(self, target: DXO) -> Any: + return [target.data_kind, target.data, target.meta] + + def recompose(self, data: Any) -> DXO: + return DXO(data[0], data[1], data[2]) + + +class ClientDecomposer(Decomposer): + @staticmethod + def supported_type(): + return Client + + def decompose(self, target: Client) -> Any: + return [target.name, target.token, target.last_connect_time, target.props] + + def recompose(self, data: Any) -> Client: + client = Client(data[0], data[1]) 
+ client.last_connect_time = data[2] + client.props = data[3] + return client + + +class RunSnapshotDecomposer(Decomposer): + @staticmethod + def supported_type(): + return RunSnapshot + + def decompose(self, target: RunSnapshot) -> Any: + return [target.component_states, target.completed, target.job_id] + + def recompose(self, data: Any) -> RunSnapshot: + snapshot = RunSnapshot(data[2]) + snapshot.component_states = data[0] + snapshot.completed = data[1] + return snapshot + + +class WorkspaceDecomposer(Decomposer): + @staticmethod + def supported_type(): + return Workspace + + def decompose(self, target: Workspace) -> Any: + return [target.root_dir, target.name, target.config_folder] + + def recompose(self, data: Any) -> Workspace: + return Workspace(data[0], data[1], data[2]) + + +class SignalDecomposer(Decomposer): + @staticmethod + def supported_type(): + return Signal + + def decompose(self, target: Signal) -> Any: + return [target.value, target.trigger_time, target.triggered] + + def recompose(self, data: Any) -> Signal: + signal = Signal() + signal.value = data[0] + signal.trigger_time = data[1] + signal.triggered = data[2] + return signal + + +class AnalyticsDataTypeDecomposer(Decomposer): + @staticmethod + def supported_type(): + return AnalyticsDataType + + def decompose(self, target: AnalyticsDataType) -> Any: + return target.name + + def recompose(self, data: Any) -> AnalyticsDataType: + return AnalyticsDataType[data] + + +class NamespaceDecomposer(Decomposer): + @staticmethod + def supported_type(): + return Namespace + + def decompose(self, target: Namespace) -> Any: + return vars(target) + + def recompose(self, data: Any) -> Namespace: + return Namespace(**data) + + +def register(): + if register.registered: + return + + fobs.register_folder(os.path.dirname(__file__), __package__) + register.registered = True + + +register.registered = False diff --git a/nvflare/apis/utils/fl_context_utils.py b/nvflare/apis/utils/fl_context_utils.py index 
105a4c68d9..74516bf0f7 100644 --- a/nvflare/apis/utils/fl_context_utils.py +++ b/nvflare/apis/utils/fl_context_utils.py @@ -13,11 +13,11 @@ # limitations under the License. import logging -import pickle from nvflare.apis.fl_constant import FLContextKey, NonSerializableKeys from nvflare.apis.fl_context import FLContext from nvflare.apis.shareable import Shareable +from nvflare.fuel.utils import fobs def get_serializable_data(fl_ctx: FLContext): @@ -26,7 +26,7 @@ def get_serializable_data(fl_ctx: FLContext): for k, v in fl_ctx.props.items(): if k not in NonSerializableKeys.KEYS: try: - pickle.dumps(v) + fobs.dumps(v) new_fl_ctx.props[k] = v except: logger.warning(generate_log_message(fl_ctx, f"Object is not serializable (discarded): {k} - {v}")) diff --git a/nvflare/app_common/abstract/learnable.py b/nvflare/app_common/abstract/learnable.py index bc98943a51..d118c662bc 100644 --- a/nvflare/app_common/abstract/learnable.py +++ b/nvflare/app_common/abstract/learnable.py @@ -13,8 +13,7 @@ # limitations under the License. # from __future__ import annotations - -import pickle +from nvflare.fuel.utils import fobs class Learnable(dict): @@ -25,7 +24,7 @@ def to_bytes(self) -> bytes: object serialized in bytes. """ - return pickle.dumps(self) + return fobs.dumps(self) @classmethod def from_bytes(cls, data: bytes): @@ -35,7 +34,7 @@ def from_bytes(cls, data: bytes): data: a bytes object Returns: - an object loaded by pickle from data + an object loaded by FOBS from data """ - return pickle.loads(data) + return fobs.loads(data) diff --git a/nvflare/app_common/decomposers/__init__.py b/nvflare/app_common/decomposers/__init__.py new file mode 100644 index 0000000000..2b8f6c7e87 --- /dev/null +++ b/nvflare/app_common/decomposers/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nvflare/app_common/decomposers/common_decomposers.py b/nvflare/app_common/decomposers/common_decomposers.py new file mode 100644 index 0000000000..dc99ec87d6 --- /dev/null +++ b/nvflare/app_common/decomposers/common_decomposers.py @@ -0,0 +1,151 @@ +# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Decomposers for types from app_common and Machine Learning libraries.""" +import os +from abc import ABC +from io import BytesIO +from typing import Any + +import numpy as np + +from nvflare.app_common.abstract.learnable import Learnable +from nvflare.app_common.widgets.event_recorder import _CtxPropReq, _EventReq, _EventStats +from nvflare.fuel.utils import fobs + + +class LearnableDecomposer(fobs.Decomposer): + @staticmethod + def supported_type(): + return Learnable + + def decompose(self, target: Learnable) -> Any: + return target.copy() + + def recompose(self, data: Any) -> Learnable: + obj = Learnable() + for k, v in data.items(): + obj[k] = v + return obj + + +class NumpyScalarDecomposer(fobs.Decomposer, ABC): + """Decomposer base class for all numpy types with item method.""" + + def decompose(self, target: Any) -> Any: + return target.item() + + def recompose(self, data: Any) -> np.ndarray: + return self.supported_type()(data) + + +class Float64ScalarDecomposer(NumpyScalarDecomposer): + @staticmethod + def supported_type(): + return np.float64 + + +class Float32ScalarDecomposer(NumpyScalarDecomposer): + @staticmethod + def supported_type(): + return np.float32 + + +class Int64ScalarDecomposer(NumpyScalarDecomposer): + @staticmethod + def supported_type(): + return np.int64 + + +class Int32ScalarDecomposer(NumpyScalarDecomposer): + @staticmethod + def supported_type(): + return np.int32 + + +class NumpyArrayDecomposer(fobs.Decomposer): + @staticmethod + def supported_type(): + return np.ndarray + + def decompose(self, target: np.ndarray) -> Any: + stream = BytesIO() + np.save(stream, target, allow_pickle=False) + return stream.getvalue() + + def recompose(self, data: Any) -> np.ndarray: + stream = BytesIO(data) + return np.load(stream, allow_pickle=False) + + +class CtxPropReqDecomposer(fobs.Decomposer): + @staticmethod + def supported_type(): + return _CtxPropReq + + def decompose(self, target: _CtxPropReq) -> Any: + return [target.dtype, 
target.is_private, target.is_sticky, target.allow_none] + + def recompose(self, data: Any) -> _CtxPropReq: + return _CtxPropReq(data[0], data[1], data[2], data[3]) + + +class EventReqDecomposer(fobs.Decomposer): + @staticmethod + def supported_type(): + return _EventReq + + def decompose(self, target: _EventReq) -> Any: + return [target.ctx_reqs, target.peer_ctx_reqs, target.ctx_block_list, target.peer_ctx_block_List] + + def recompose(self, data: Any) -> _EventReq: + return _EventReq(data[0], data[1], data[2], data[3]) + + +class EventStatsDecomposer(fobs.Decomposer): + @staticmethod + def supported_type(): + return _EventStats + + def decompose(self, target: _EventStats) -> Any: + return [ + target.call_count, + target.prop_missing, + target.prop_none_value, + target.prop_dtype_mismatch, + target.prop_attr_mismatch, + target.prop_block_list_violation, + target.peer_ctx_missing, + ] + + def recompose(self, data: Any) -> _EventStats: + stats = _EventStats() + stats.call_count = data[0] + stats.prop_missing = data[1] + stats.prop_none_value = data[2] + stats.prop_dtype_mismatch = data[3] + stats.prop_attr_mismatch = data[4] + stats.prop_block_list_violation = data[5] + stats.peer_ctx_missing = data[6] + return stats + + +def register(): + if register.registered: + return + + fobs.register_folder(os.path.dirname(__file__), __package__) + register.registered = True + + +register.registered = False diff --git a/nvflare/app_common/np/np_model_locator.py b/nvflare/app_common/np/np_model_locator.py index b8d1568f74..640791f3f0 100755 --- a/nvflare/app_common/np/np_model_locator.py +++ b/nvflare/app_common/np/np_model_locator.py @@ -65,7 +65,7 @@ def locate_model(self, model_name, fl_ctx: FLContext) -> DXO: model_load_path = os.path.join(model_path, self.model_file_name) np_data = None try: - np_data = np.load(model_load_path, allow_pickle=True) + np_data = np.load(model_load_path, allow_pickle=False) self.log_info(fl_ctx, f"Loaded {model_name} model from 
{model_load_path}.") except Exception as e: self.log_error(fl_ctx, f"Unable to load NP Model: {e}.") diff --git a/nvflare/app_common/state_persistors/storage_state_persistor.py b/nvflare/app_common/state_persistors/storage_state_persistor.py index f4510964b3..9f94b79b1d 100644 --- a/nvflare/app_common/state_persistors/storage_state_persistor.py +++ b/nvflare/app_common/state_persistors/storage_state_persistor.py @@ -13,11 +13,11 @@ # limitations under the License. import os -import pickle from nvflare.apis.fl_snapshot import FLSnapshot, RunSnapshot from nvflare.apis.state_persistor import StatePersistor from nvflare.apis.storage import StorageSpec +from nvflare.fuel.utils import fobs class StorageStatePersistor(StatePersistor): @@ -45,9 +45,7 @@ def save(self, snapshot: RunSnapshot) -> str: if snapshot.completed: full_uri = self.storage.delete_object(path) else: - full_uri = self.storage.create_object( - uri=path, data=pickle.dumps(snapshot), meta={}, overwrite_existing=True - ) + full_uri = self.storage.create_object(uri=path, data=fobs.dumps(snapshot), meta={}, overwrite_existing=True) return full_uri @@ -60,7 +58,7 @@ def retrieve(self) -> FLSnapshot: all_items = self.storage.list_objects(self.uri_root) fl_snapshot = FLSnapshot() for item in all_items: - snapshot = pickle.loads(self.storage.get_data(item)) + snapshot = fobs.loads(self.storage.get_data(item)) fl_snapshot.add_snapshot(snapshot.job_id, snapshot) return fl_snapshot @@ -75,7 +73,7 @@ def retrieve_run(self, job_id: str) -> RunSnapshot: """ path = os.path.join(self.uri_root, job_id) - snapshot = pickle.loads(self.storage.get_data(uri=path)) + snapshot = fobs.loads(self.storage.get_data(uri=path)) return snapshot def delete(self): diff --git a/nvflare/fuel/utils/fobs/README.rst b/nvflare/fuel/utils/fobs/README.rst new file mode 100644 index 0000000000..da980bc356 --- /dev/null +++ b/nvflare/fuel/utils/fobs/README.rst @@ -0,0 +1,178 @@ +Flare Object Serializer (FOBS) +============================== + + 
+Overview +-------- + +FOBS is a drop-in replacement for Pickle for security purposes. It uses **MessagePack** to +serialize objects. + +FOBS sacrifices convenience for security. With Pickle, most objects are supported +automatically using introspection. To serialize an object using FOBS, a **Decomposer** +must be registered for the class. A few decomposers for commonly used classes are +pre-registered with the module. + +FOBS throws :code:`TypeError` exception when it encounters an object with no decomposer +registered. For example, +:: + TypeError: can not serialize 'xxx' object + +Usage +----- + +FOBS defines following 4 functions, similar to Pickle, + +* :code:`dumps(obj)`: Serializes obj and returns bytes +* :code:`dump(obj, stream)`: Serializes obj and writes the result to stream +* :code:`loads(data)`: Deserializes the data and returns an object +* :code:`load(stream)`: Reads data from stream and deserializes it into an object + + +Examples, +:: + + from nvflare.fuel.utils import fobs + + data = fobs.dumps(dxo) + new_dxo = fobs.loads(data) + + # Pickle/json compatible functions can be used also + data = fobs.dumps(shareable) + new_shareable = fobs.loads(data) + +Decomposers +----------- + +Decomposers are classes that inherit abstract base class :code:`fobs.Decomposer`. FOBS +uses decomposers to break an object into **serializable objects** before serializing it +using MessagePack. + +Decomposers are very similar to serializers, except that they don't have to convert object +into bytes directly, they can just break the object into other objects that are serializable. + +An object is serializable if its type is supported by MessagePack or a decomposer is +registered for its class. + +FOBS recursively decomposes objects till all objects are of types supported by MessagePack. +Decomposing looping must be avoided, which causes stack overflow. 
Decomposers form a loop +when one class is decomposed into another class which is eventually decomposed into the +original class. For example, this scenario forms the simplest loop: X decomposes into Y +and Y decomposes back into X. + +MessagePack supports following types natively, + +* None +* bool +* int +* float +* str +* bytes +* bytearray +* memoryview +* list +* dict + +Decomposers for following classes are included with `fobs` module and auto-registered, + +* tuple +* set +* OrderedDict +* datetime +* Shareable +* FLContext +* DXO +* Client +* RunSnapshot +* Workspace +* Signal +* AnalyticsDataType +* argparse.Namespace +* Learnable +* _CtxPropReq +* _EventReq +* _EventStats +* numpy.float32 +* numpy.float64 +* numpy.int32 +* numpy.int64 +* numpy.ndarray + +All classes defined in :code:`fobs/decomposers` folder are automatically registered. +Other decomposers must be registered manually like this, + +:: + + fobs.register(FooDecomposer) + fobs.register(BarDecomposer()) + + +:code:`fobs.register` takes either a class or an instance as the argument. Decomposer whose +constructor takes arguments must be registered as instance. + +A decomposer can either serialize the class into bytes or decompose it into objects of +serializable types. In most cases, it only involves saving members as a list and reconstructing +the object from the list. + +Here is an example of a simple decomposer. Even though :code:`datetime` is not supported +by MessagePack, a decomposer is included in `fobs` module so no need to further decompose it. 
+ +:: + + from nvflare.fuel.utils import fobs + + + class Simple: + + def __init__(self, num: int, name: str, timestamp: datetime): + self.num = num + self.name = name + self.timestamp = timestamp + + + class SimpleDecomposer(fobs.Decomposer): + + @staticmethod + def supported_type() -> Type[Any]: + return Simple + + def decompose(self, obj) -> Any: + return [obj.num, obj.name, obj.timestamp] + + def recompose(self, data: Any) -> Simple: + return Simple(data[0], data[1], data[2]) + + + fobs.register(SimpleDecomposer) + data = fobs.dumps(Simple(1, 'foo', datetime.now())) + obj = fobs.loads(data) + assert obj.num == 1 + assert obj.name == 'foo' + assert isinstance(obj.timestamp, datetime) + + +The same decomposer can be registered multiple times. Only first one takes effect, the others +are ignored with a warning message. + +Custom Types +------------ + +To support custom types with FOBS, the decomposers for the types must be included +with the custom code and registered. + +The decomposers must be registered in both server and client code before FOBS is used. +A good place for registration is the constructors for controllers and executors. It +can also be done in `START_RUN` event handler. + +Custom object cannot be put in `shareable` directly, +it must be serialized using FOBS first. Assuming `custom_data` contains custom type, +this is how data can be stored in shareable, +:: + shareable[CUSTOM_DATA] = fobs.dumps(custom_data) +On the receiving end, +:: + custom_data = fobs.loads(shareable[CUSTOM_DATA]) + +This doesn't work +:: + shareable[CUSTOM_DATA] = custom_data diff --git a/nvflare/fuel/utils/fobs/__init__.py b/nvflare/fuel/utils/fobs/__init__.py new file mode 100644 index 0000000000..8ca6d76fdb --- /dev/null +++ b/nvflare/fuel/utils/fobs/__init__.py @@ -0,0 +1,29 @@ +# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved. 
from abc import ABC, abstractmethod
from typing import Any, Type, TypeVar

# Type variable standing for whatever concrete type a decomposer handles.
T = TypeVar("T")


class Decomposer(ABC):
    """Interface all FOBS decomposers implement.

    FOBS can only serialize types it knows how to take apart. For every class
    that must travel through FOBS, a concrete subclass of this interface is
    registered; it breaks instances down into msgpack-serializable pieces and
    rebuilds them on the receiving side.
    """

    @staticmethod
    @abstractmethod
    def supported_type() -> Type[T]:
        """Return the class (not an instance) this decomposer handles.

        Returns:
            The supported type/class.
        """
        pass

    @abstractmethod
    def decompose(self, target: T) -> Any:
        """Break target down into objects msgpack can serialize.

        The pieces may themselves be instances of other registered types;
        FOBS keeps decomposing recursively.

        Args:
            target: The instance to be serialized.
        Returns:
            The decomposed, serializable representation.
        """
        pass

    @abstractmethod
    def recompose(self, data: Any) -> T:
        """Rebuild an instance of the supported type from its decomposed form.

        Args:
            data: The decomposed representation produced by decompose().
        Returns:
            The reconstructed object.
        """
        pass
"""Decomposers for Python builtin objects."""
from collections import OrderedDict
from datetime import datetime
from typing import Any

from nvflare.fuel.utils.fobs.decomposer import Decomposer


class TupleDecomposer(Decomposer):
    """Serializes tuples as lists (msgpack has no tuple type)."""

    @staticmethod
    def supported_type():
        return tuple

    def decompose(self, target: tuple) -> Any:
        return list(target)

    def recompose(self, data: Any) -> tuple:
        return tuple(data)


class SetDecomposer(Decomposer):
    """Serializes sets as lists; element order is not preserved."""

    @staticmethod
    def supported_type():
        return set

    def decompose(self, target: set) -> Any:
        return list(target)

    def recompose(self, data: Any) -> set:
        return set(data)


class OrderedDictDecomposer(Decomposer):
    """Serializes OrderedDict as a list of (key, value) pairs to keep insertion order."""

    @staticmethod
    def supported_type():
        return OrderedDict

    def decompose(self, target: OrderedDict) -> Any:
        return list(target.items())

    # Fixed: return annotation was incorrectly "-> set"; this returns an OrderedDict.
    def recompose(self, data: Any) -> OrderedDict:
        return OrderedDict(data)


class DatetimeDecomposer(Decomposer):
    """Serializes datetime via its ISO-8601 string form."""

    @staticmethod
    def supported_type():
        return datetime

    def decompose(self, target: datetime) -> Any:
        return target.isoformat()

    def recompose(self, data: Any) -> datetime:
        # fromisoformat() round-trips isoformat() output exactly
        # (including microseconds and tzinfo).
        return datetime.fromisoformat(data)
"""FOBS (Flare OBject Serializer): a msgpack-based, decomposer-driven replacement for pickle."""
import importlib
import inspect
import logging
import os
from os.path import dirname, join
from typing import Any, BinaryIO, Dict, Type, Union

import msgpack

from nvflare.fuel.utils.fobs.decomposer import Decomposer

__all__ = [
    "register",
    "register_folder",
    "num_decomposers",
    "serialize",
    "serialize_stream",
    "deserialize",
    "deserialize_stream",
]

# Keys used to tag a decomposed object inside the msgpack payload.
FOBS_TYPE = "__fobs_type__"
FOBS_DATA = "__fobs_data__"
# Types msgpack serializes natively; such values are passed through untouched.
# Fixed: holds type(None) (NoneType) rather than the value None, since
# membership is tested against type(obj).
MSGPACK_TYPES = (type(None), bool, int, float, str, bytes, bytearray, memoryview, list, dict)

log = logging.getLogger(__name__)
_decomposers: Dict[str, Decomposer] = {}
_decomposers_registered = False


def _get_type_name(cls: Type) -> str:
    """Return the fully-qualified name of cls; the 'builtins.' prefix is omitted."""
    module = cls.__module__
    if module == "builtins":
        return cls.__qualname__
    return module + "." + cls.__qualname__


def register(decomposer: Union[Decomposer, Type[Decomposer]]) -> None:
    """Register a decomposer. It does nothing if decomposer is already registered for the type

    Args:
        decomposer: The decomposer type or instance
    """
    # Fixed: validate the argument before calling supported_type(), so a
    # non-decomposer argument produces the intended logged error instead of
    # an AttributeError.
    if inspect.isclass(decomposer):
        if not issubclass(decomposer, Decomposer):
            log.error(f"Class {decomposer} is not a decomposer")
            return
        instance = None  # created lazily, only if this type is not yet registered
    else:
        if not isinstance(decomposer, Decomposer):
            log.error(f"Class {decomposer.__class__} is not a decomposer")
            return
        instance = decomposer

    name = _get_type_name(decomposer.supported_type())
    if name in _decomposers:
        # First registration wins; duplicates are silently ignored.
        return

    _decomposers[name] = instance if instance is not None else decomposer()


def register_folder(folder: str, package: str):
    """Scan the folder and register all decomposers found.

    Args:
        folder: The folder to scan
        package: The package to import the decomposers from
    """
    for module in os.listdir(folder):
        # Only regular python modules; skip the package marker itself.
        if module != "__init__.py" and module.endswith(".py"):
            decomposers = package + "." + module[:-3]
            imported = importlib.import_module(decomposers, __package__)
            for _, cls_obj in inspect.getmembers(imported, inspect.isclass):
                # NOTE: this also picks up Decomposer subclasses a module merely
                # imports; register() de-duplicates by supported type.
                if issubclass(cls_obj, Decomposer) and not inspect.isabstract(cls_obj):
                    register(cls_obj)


def _register_decomposers():
    """Register the built-in decomposers shipped in the 'decomposers' subpackage, once."""
    global _decomposers_registered

    if _decomposers_registered:
        return

    register_folder(join(dirname(__file__), "decomposers"), ".decomposers")
    _decomposers_registered = True


def num_decomposers() -> int:
    """Returns the number of decomposers registered.

    Returns:
        The number of decomposers
    """
    return len(_decomposers)


def _fobs_packer(obj: Any) -> dict:
    """msgpack 'default' hook: wrap objects of registered types in a tagged dict."""
    if type(obj) in MSGPACK_TYPES:
        return obj

    type_name = _get_type_name(obj.__class__)
    if type_name not in _decomposers:
        # Let msgpack raise its own TypeError for truly unsupported objects.
        return obj

    decomposed = _decomposers[type_name].decompose(obj)
    return {FOBS_TYPE: type_name, FOBS_DATA: decomposed}


def _fobs_unpacker(obj: Any) -> Any:
    """msgpack 'object_hook': recompose dicts tagged by _fobs_packer."""
    if type(obj) is not dict or FOBS_TYPE not in obj:
        return obj

    type_name = obj[FOBS_TYPE]
    if type_name not in _decomposers:
        raise TypeError(f"Unknown type {type_name}, caused by mismatching decomposers")
    decomposer = _decomposers[type_name]
    return decomposer.recompose(obj[FOBS_DATA])


def serialize(obj: Any, **kwargs) -> bytes:
    """Serialize object into bytes.

    Args:
        obj: Object to be serialized
        kwargs: Arguments passed to msgpack.packb
    Returns:
        Serialized data
    """
    _register_decomposers()
    # strict_types makes msgpack treat e.g. tuple != list, so registered
    # decomposers (not msgpack's coercion) handle those types.
    return msgpack.packb(obj, default=_fobs_packer, strict_types=True, **kwargs)


def serialize_stream(obj: Any, stream: BinaryIO, **kwargs):
    """Serialize object and write the data to a stream.

    Args:
        obj: Object to be serialized
        stream: Stream to write the result to
        kwargs: Arguments passed to msgpack.packb
    """
    data = serialize(obj, **kwargs)
    stream.write(data)


def deserialize(data: bytes, **kwargs) -> Any:
    """Deserialize bytes into an object.

    Args:
        data: Serialized data
        kwargs: Arguments passed to msgpack.unpackb
    Returns:
        Deserialized object
    """
    _register_decomposers()
    return msgpack.unpackb(data, object_hook=_fobs_unpacker, **kwargs)


def deserialize_stream(stream: BinaryIO, **kwargs) -> Any:
    """Deserialize bytes from stream into an object.

    Args:
        stream: Stream to read serialized data from
        kwargs: Arguments passed to msgpack.unpackb
    Returns:
        Deserialized object
    """
    data = stream.read()
    return deserialize(data, **kwargs)
""" def _do_sign(self, root_pri_key, dest_dir): diff --git a/nvflare/private/fed/client/fed_client.py b/nvflare/private/fed/client/fed_client.py index beabeded96..c9522f4576 100644 --- a/nvflare/private/fed/client/fed_client.py +++ b/nvflare/private/fed/client/fed_client.py @@ -14,7 +14,6 @@ """The client of the federated training process.""" -import pickle from typing import List, Optional from nvflare.apis.event_type import EventType @@ -23,6 +22,9 @@ from nvflare.apis.fl_component import FLComponent from nvflare.apis.fl_context import FLContext from nvflare.apis.shareable import Shareable +from nvflare.apis.utils.decomposers import flare_decomposers +from nvflare.app_common.decomposers import common_decomposers +from nvflare.fuel.utils import fobs from nvflare.private.defs import SpecialTaskName from nvflare.private.event import fire_event from nvflare.private.fed.utils.numproto import proto_to_bytes @@ -80,6 +82,8 @@ def __init__( self.executors = executors self.enable_byoc = enable_byoc + flare_decomposers.register() + common_decomposers.register() def fetch_task(self, fl_ctx: FLContext): fire_event(EventType.BEFORE_PULL_TASK, self.handlers, fl_ctx) @@ -97,7 +101,7 @@ def extract_shareable(self, responses, fl_ctx: FLContext): peer_context = FLContext() for item in responses: shareable = shareable.from_bytes(proto_to_bytes(item.data.params["data"])) - peer_context = pickle.loads(proto_to_bytes(item.data.params["fl_context"])) + peer_context = fobs.loads(proto_to_bytes(item.data.params["fl_context"])) fl_ctx.set_peer_context(peer_context) shareable.set_peer_props(peer_context.get_all_public_props()) diff --git a/nvflare/private/fed/client/process_aux_cmd.py b/nvflare/private/fed/client/process_aux_cmd.py index 539b53060f..a98f18c025 100644 --- a/nvflare/private/fed/client/process_aux_cmd.py +++ b/nvflare/private/fed/client/process_aux_cmd.py @@ -12,10 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import pickle - from nvflare.apis.fl_constant import ReservedTopic, ReturnCode from nvflare.apis.shareable import make_reply +from nvflare.fuel.utils import fobs from nvflare.private.admin_defs import Message from nvflare.private.defs import RequestHeader from nvflare.private.fed.client.admin import RequestProcessor @@ -28,13 +27,13 @@ def get_topics(self) -> [str]: def process(self, req: Message, app_ctx) -> Message: engine = app_ctx - shareable = pickle.loads(req.body) + shareable = fobs.loads(req.body) job_id = req.get_header(RequestHeader.JOB_ID) result = engine.send_aux_command(shareable, job_id) if not result: result = make_reply(ReturnCode.EXECUTION_EXCEPTION) - result = pickle.dumps(result) + result = fobs.dumps(result) message = Message(topic="reply_" + req.topic, body=result) return message diff --git a/nvflare/private/fed/client/scheduler_cmds.py b/nvflare/private/fed/client/scheduler_cmds.py index 1b8f52b43f..b808f82e1a 100644 --- a/nvflare/private/fed/client/scheduler_cmds.py +++ b/nvflare/private/fed/client/scheduler_cmds.py @@ -12,12 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import json -import pickle from typing import List from nvflare.apis.fl_constant import ReturnCode, SystemComponents from nvflare.apis.resource_manager_spec import ResourceConsumerSpec, ResourceManagerSpec from nvflare.apis.shareable import Shareable +from nvflare.fuel.utils import fobs from nvflare.private.admin_defs import Message from nvflare.private.defs import ERROR_MSG_PREFIX, RequestHeader, SysCommandTopic, TrainingTopic from nvflare.private.fed.client.admin import RequestProcessor @@ -42,7 +42,7 @@ def process(self, req: Message, app_ctx) -> Message: with engine.new_context() as fl_ctx: result = Shareable() try: - resource_spec = pickle.loads(req.body) + resource_spec = fobs.loads(req.body) check_result, token = resource_manager.check_resources( resource_requirement=resource_spec, fl_ctx=fl_ctx ) @@ -51,7 +51,7 @@ def process(self, req: Message, app_ctx) -> Message: except Exception: result.set_return_code(ReturnCode.EXECUTION_EXCEPTION) - return Message(topic="reply_" + req.topic, body=pickle.dumps(result)) + return Message(topic="reply_" + req.topic, body=fobs.dumps(result)) class StartJobProcessor(RequestProcessor): @@ -76,7 +76,7 @@ def process(self, req: Message, app_ctx) -> Message: allocated_resources = None try: - resource_spec = pickle.loads(req.body) + resource_spec = fobs.loads(req.body) job_id = req.get_header(RequestHeader.JOB_ID) token = req.get_header(ShareableHeader.RESOURCE_RESERVE_TOKEN) except Exception as e: @@ -123,13 +123,13 @@ def process(self, req: Message, app_ctx) -> Message: result = Shareable() try: # resource_spec = req.get_header(ShareableHeader.RESOURCE_SPEC) - resource_spec = pickle.loads(req.body) + resource_spec = fobs.loads(req.body) token = req.get_header(ShareableHeader.RESOURCE_RESERVE_TOKEN) resource_manager.cancel_resources(resource_requirement=resource_spec, token=token, fl_ctx=fl_ctx) except Exception: result.set_return_code(ReturnCode.EXECUTION_EXCEPTION) - return Message(topic="reply_" + req.topic, 
body=pickle.dumps(result)) + return Message(topic="reply_" + req.topic, body=fobs.dumps(result)) class ReportResourcesProcessor(RequestProcessor): diff --git a/nvflare/private/fed/server/collective_command_agent.py b/nvflare/private/fed/server/collective_command_agent.py index fc7c419b1e..7880ffaf0a 100644 --- a/nvflare/private/fed/server/collective_command_agent.py +++ b/nvflare/private/fed/server/collective_command_agent.py @@ -13,7 +13,6 @@ # limitations under the License. import logging -import pickle import threading from nvflare.apis.collective_comm_constants import ( @@ -23,6 +22,7 @@ ) from nvflare.apis.fl_constant import ServerCommandKey from nvflare.apis.shareable import Shareable +from nvflare.fuel.utils import fobs from nvflare.private.fed.utils.fed_utils import listen_command from .server_commands import ServerCommands @@ -88,7 +88,7 @@ def _poll_command(self, conn, engine): try: if conn.poll(1.0): msg = conn.recv() - msg = pickle.loads(msg) + msg = fobs.loads(msg) command_name = msg.get(ServerCommandKey.COMMAND) command = ServerCommands.get_command(command_name) if not command: diff --git a/nvflare/private/fed/server/fed_server.py b/nvflare/private/fed/server/fed_server.py index 599076d633..ec96fda8e7 100644 --- a/nvflare/private/fed/server/fed_server.py +++ b/nvflare/private/fed/server/fed_server.py @@ -14,7 +14,6 @@ import logging import os -import pickle import shutil import threading import time @@ -37,7 +36,7 @@ from nvflare.apis.fl_constant import ( FLContextKey, MachineStatus, - RunProcessKey, + ReservedKey, ServerCommandKey, ServerCommandNames, SnapshotKey, @@ -45,8 +44,11 @@ ) from nvflare.apis.fl_context import FLContext from nvflare.apis.shareable import ReservedHeaderKey, ReturnCode, Shareable, make_reply +from nvflare.apis.utils.decomposers import flare_decomposers from nvflare.apis.workspace import Workspace +from nvflare.app_common.decomposers import common_decomposers from nvflare.fuel.hci.zip_utils import unzip_all_from_bytes +from 
nvflare.fuel.utils import fobs from nvflare.fuel.utils.argument_utils import parse_vars from nvflare.private.defs import SpecialTaskName from nvflare.private.fed.server.server_runner import ServerRunner @@ -129,8 +131,6 @@ def close(self): except RuntimeError: self.logger.info("canceling sync locks") try: - if self.admin_server: - self.admin_server.stop() if self.grpc_server: self.grpc_server.stop(0) finally: @@ -294,9 +294,11 @@ def __init__( self.overseer_agent = overseer_agent self.server_state: ServerState = ColdState() self.snapshot_persistor = snapshot_persistor - self._collective_comm_timeout = collective_command_timeout + flare_decomposers.register() + common_decomposers.register() + def _create_server_engine(self, args, snapshot_persistor): return ServerEngine( server=self, args=args, client_manager=self.client_manager, snapshot_persistor=snapshot_persistor @@ -412,7 +414,7 @@ def GetTask(self, request, context): token = client.get_token() # engine = fl_ctx.get_engine() - shared_fl_ctx = pickle.loads(proto_to_bytes(request.context["fl_context"])) + shared_fl_ctx = fobs.loads(proto_to_bytes(request.context["fl_context"])) job_id = str(shared_fl_ctx.get_prop(FLContextKey.CURRENT_RUN)) # fl_ctx.set_peer_context(shared_fl_ctx) @@ -471,7 +473,7 @@ def _process_task_request(self, client, fl_ctx, shared_fl_ctx): } command_conn.send(data) - return_data = pickle.loads(command_conn.recv()) + return_data = fobs.loads(command_conn.recv()) task_name = return_data.get(ServerCommandKey.TASK_NAME) task_id = return_data.get(ServerCommandKey.TASK_ID) shareable = return_data.get(ServerCommandKey.SHAREABLE) @@ -500,7 +502,7 @@ def SubmitUpdate(self, request, context): else: with self.lock: shareable = Shareable.from_bytes(proto_to_bytes(request.data.params["data"])) - shared_fl_context = pickle.loads(proto_to_bytes(request.data.params["fl_context"])) + shared_fl_context = fobs.loads(proto_to_bytes(request.data.params["fl_context"])) job_id = 
str(shared_fl_context.get_prop(FLContextKey.CURRENT_RUN)) if job_id not in self.engine.run_processes.keys(): @@ -525,12 +527,11 @@ def SubmitUpdate(self, request, context): time_seconds or "less than 1", ) - task_id = shareable.get_cookie(FLContextKey.TASK_ID) - shareable.set_header(ServerCommandKey.PEER_FL_CONTEXT, shared_fl_context) shareable.set_header(ServerCommandKey.FL_CLIENT, client) shareable.set_header(ServerCommandKey.TASK_NAME, contribution_task_name) + data = {ReservedKey.SHAREABLE: shareable, ReservedKey.SHARED_FL_CONTEXT: shared_fl_context} - self._submit_update(shareable, shared_fl_context) + self._submit_update(data, shared_fl_context) # self.server_runner.process_submission(client, contribution_task_name, task_id, shareable, fl_ctx) @@ -544,14 +545,14 @@ def SubmitUpdate(self, request, context): return summary_info - def _submit_update(self, shareable, shared_fl_context): + def _submit_update(self, submit_update_data, shared_fl_context): try: with self.engine.lock: job_id = shared_fl_context.get_prop(FLContextKey.CURRENT_RUN) self.engine.send_command_to_child_runner_process( job_id=job_id, command_name=ServerCommandNames.SUBMIT_UPDATE, - command_data=shareable, + command_data=submit_update_data, return_result=False, ) except BaseException: @@ -575,7 +576,7 @@ def AuxCommunicate(self, request, context): shareable = Shareable() shareable = shareable.from_bytes(proto_to_bytes(request.data["data"])) - shared_fl_context = pickle.loads(proto_to_bytes(request.data["fl_context"])) + shared_fl_context = fobs.loads(proto_to_bytes(request.data["fl_context"])) job_id = str(shared_fl_context.get_prop(FLContextKey.CURRENT_RUN)) if job_id not in self.engine.run_processes.keys(): diff --git a/nvflare/private/fed/server/server_command_agent.py b/nvflare/private/fed/server/server_command_agent.py index ba611ed7f6..589e592d7e 100644 --- a/nvflare/private/fed/server/server_command_agent.py +++ b/nvflare/private/fed/server/server_command_agent.py @@ -13,10 +13,10 @@ # 
limitations under the License. import logging -import pickle import threading from nvflare.apis.fl_constant import ServerCommandKey +from nvflare.fuel.utils import fobs from ..utils.fed_utils import listen_command from .server_commands import ServerCommands @@ -46,7 +46,7 @@ def execute_command(self, conn, engine): try: if conn.poll(1.0): msg = conn.recv() - msg = pickle.loads(msg) + msg = fobs.loads(msg) command_name = msg.get(ServerCommandKey.COMMAND) data = msg.get(ServerCommandKey.DATA) command = ServerCommands.get_command(command_name) diff --git a/nvflare/private/fed/server/server_commands.py b/nvflare/private/fed/server/server_commands.py index bc991a1505..53fc698ff0 100644 --- a/nvflare/private/fed/server/server_commands.py +++ b/nvflare/private/fed/server/server_commands.py @@ -15,15 +15,15 @@ """FL Admin commands.""" import copy -import pickle import time from abc import ABC, abstractmethod from typing import List -from nvflare.apis.fl_constant import AdminCommandNames, FLContextKey, ServerCommandKey, ServerCommandNames +from nvflare.apis.fl_constant import AdminCommandNames, FLContextKey, ReservedKey, ServerCommandKey, ServerCommandNames from nvflare.apis.fl_context import FLContext from nvflare.apis.shareable import Shareable from nvflare.apis.utils.fl_context_utils import get_serializable_data +from nvflare.fuel.utils import fobs from nvflare.widgets.widget import WidgetID @@ -139,7 +139,7 @@ def process(self, data: Shareable, fl_ctx: FLContext): ServerCommandKey.SHAREABLE: shareable, ServerCommandKey.FL_CONTEXT: copy.deepcopy(get_serializable_data(fl_ctx).props), } - return pickle.dumps(data) + return fobs.dumps(data) class SubmitUpdateCommand(CommandProcessor): @@ -164,13 +164,14 @@ def process(self, data: Shareable, fl_ctx: FLContext): """ - shared_fl_ctx = data.get_header(ServerCommandKey.PEER_FL_CONTEXT) - client = data.get_header(ServerCommandKey.FL_CLIENT) + shareable = data.get(ReservedKey.SHAREABLE) + shared_fl_ctx = 
data.get(ReservedKey.SHARED_FL_CONTEXT) + client = shareable.get_header(ServerCommandKey.FL_CLIENT) fl_ctx.set_peer_context(shared_fl_ctx) - contribution_task_name = data.get_header(ServerCommandKey.TASK_NAME) - task_id = data.get_cookie(FLContextKey.TASK_ID) + contribution_task_name = shareable.get_header(ServerCommandKey.TASK_NAME) + task_id = shareable.get_cookie(FLContextKey.TASK_ID) server_runner = fl_ctx.get_prop(FLContextKey.RUNNER) - server_runner.process_submission(client, contribution_task_name, task_id, data, fl_ctx) + server_runner.process_submission(client, contribution_task_name, task_id, shareable, fl_ctx) return "" diff --git a/nvflare/private/fed/server/server_engine.py b/nvflare/private/fed/server/server_engine.py index 346ee31268..f94bd996f3 100644 --- a/nvflare/private/fed/server/server_engine.py +++ b/nvflare/private/fed/server/server_engine.py @@ -15,7 +15,6 @@ import copy import logging import os -import pickle import re import shlex import shutil @@ -51,6 +50,7 @@ from nvflare.apis.utils.fl_context_utils import get_serializable_data from nvflare.apis.workspace import Workspace from nvflare.fuel.hci.zip_utils import zip_directory_to_bytes +from nvflare.fuel.utils import fobs from nvflare.private.admin_defs import Message from nvflare.private.defs import RequestHeader, TrainingTopic from nvflare.private.fed.server.server_json_config import ServerJsonConfigurator @@ -72,7 +72,7 @@ def __init__(self, client): self.client = client def send(self, data): - data = pickle.dumps(data) + data = fobs.dumps(data) self.client.send(data) def recv(self): @@ -574,7 +574,7 @@ def aux_send(self, targets: [], topic: str, request: Shareable, timeout: float, # Send the aux messages through admin_server request.set_peer_props(fl_ctx.get_all_public_props()) - message = Message(topic=ReservedTopic.AUX_COMMAND, body=pickle.dumps(request)) + message = Message(topic=ReservedTopic.AUX_COMMAND, body=fobs.dumps(request)) message.set_header(RequestHeader.JOB_ID, 
str(fl_ctx.get_prop(FLContextKey.CURRENT_RUN))) requests = {} for n in targets: @@ -586,7 +586,7 @@ def aux_send(self, targets: [], topic: str, request: Shareable, timeout: float, client_name = self.get_client_name_from_token(r.client_token) if r.reply: try: - results[client_name] = pickle.loads(r.reply.body) + results[client_name] = fobs.loads(r.reply.body) except BaseException: results[client_name] = make_reply(ReturnCode.COMMUNICATION_ERROR) self.logger.error( @@ -736,7 +736,7 @@ def check_client_resources(self, resource_reqs) -> Dict[str, Tuple[bool, str]]: # assume server resource is unlimited if site_name == "server": continue - request = Message(topic=TrainingTopic.CHECK_RESOURCE, body=pickle.dumps(resource_requirements)) + request = Message(topic=TrainingTopic.CHECK_RESOURCE, body=fobs.dumps(resource_requirements)) client = self.get_client_from_name(site_name) if client: requests.update({client.token: request}) @@ -747,7 +747,7 @@ def check_client_resources(self, resource_reqs) -> Dict[str, Tuple[bool, str]]: for r in replies: site_name = self.get_client_name_from_token(r.client_token) if r.reply: - resp = pickle.loads(r.reply.body) + resp = fobs.loads(r.reply.body) result[site_name] = ( resp.get_header(ShareableHeader.CHECK_RESOURCE_RESULT, False), resp.get_header(ShareableHeader.RESOURCE_RESERVE_TOKEN, ""), @@ -764,7 +764,7 @@ def cancel_client_resources( check_result, token = result if check_result and token: resource_requirements = resource_reqs[site_name] - request = Message(topic=TrainingTopic.CANCEL_RESOURCE, body=pickle.dumps(resource_requirements)) + request = Message(topic=TrainingTopic.CANCEL_RESOURCE, body=fobs.dumps(resource_requirements)) request.set_header(ShareableHeader.RESOURCE_RESERVE_TOKEN, token) client = self.get_client_from_name(site_name) if client: @@ -777,7 +777,7 @@ def start_client_job(self, job_id, client_sites): for site, dispatch_info in client_sites.items(): resource_requirement = dispatch_info.resource_requirements token = 
dispatch_info.token - request = Message(topic=TrainingTopic.START_JOB, body=pickle.dumps(resource_requirement)) + request = Message(topic=TrainingTopic.START_JOB, body=fobs.dumps(resource_requirement)) request.set_header(RequestHeader.JOB_ID, job_id) request.set_header(ShareableHeader.RESOURCE_RESERVE_TOKEN, token) client = self.get_client_from_name(site) diff --git a/nvflare/private/fed/utils/fed_utils.py b/nvflare/private/fed/utils/fed_utils.py index e38f1f71ac..998f95e536 100644 --- a/nvflare/private/fed/utils/fed_utils.py +++ b/nvflare/private/fed/utils/fed_utils.py @@ -14,7 +14,6 @@ import logging import os -import pickle import shutil from logging.handlers import RotatingFileHandler from multiprocessing.connection import Listener @@ -24,6 +23,7 @@ from nvflare.apis.fl_context import FLContext from nvflare.fuel.hci.zip_utils import unzip_all_from_bytes from nvflare.fuel.sec.security_content_service import LoadResult, SecurityContentService +from nvflare.fuel.utils import fobs from nvflare.private.defs import SSLConstants from nvflare.private.fed.protos.federated_pb2 import ModelData from nvflare.private.fed.utils.numproto import bytes_to_proto @@ -47,7 +47,7 @@ def make_shareable_data(shareable): def make_context_data(fl_ctx): shared_fl_ctx = FLContext() shared_fl_ctx.set_public_props(fl_ctx.get_all_public_props()) - props = pickle.dumps(shared_fl_ctx) + props = fobs.dumps(shared_fl_ctx) context_data = bytes_to_proto(props) return context_data diff --git a/requirements-min.txt b/requirements-min.txt index 33062f4187..5f6c76dfa4 100644 --- a/requirements-min.txt +++ b/requirements-min.txt @@ -10,3 +10,5 @@ PyYAML==6.0 requests>=2.28.0 six>=1.15.0 tenseal==0.3.0 +msgpack==1.0.3 + diff --git a/setup.py b/setup.py index 8a5b530fe6..c8732f612b 100644 --- a/setup.py +++ b/setup.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import datetime import os import shutil -import datetime from setuptools import find_packages, setup diff --git a/tests/integration_test/data/apps/cyclic/config/config_fed_server.json b/tests/integration_test/data/apps/cyclic/config/config_fed_server.json index f6fe809494..c0f30408e6 100755 --- a/tests/integration_test/data/apps/cyclic/config/config_fed_server.json +++ b/tests/integration_test/data/apps/cyclic/config/config_fed_server.json @@ -10,7 +10,7 @@ "id": "persistor", "path": "tests.integration_test.tf2.model_persistor.TF2ModelPersistor", "args": { - "save_name": "tf2weights.pickle" + "save_name": "tf2weights.fobs" } }, { diff --git a/tests/integration_test/data/apps/tf/config/config_fed_server.json b/tests/integration_test/data/apps/tf/config/config_fed_server.json index 68ec78930b..ec2238c2ac 100644 --- a/tests/integration_test/data/apps/tf/config/config_fed_server.json +++ b/tests/integration_test/data/apps/tf/config/config_fed_server.json @@ -10,7 +10,7 @@ "id": "persistor", "path": "tests.integration_test.tf2.model_persistor.TF2ModelPersistor", "args": { - "save_name": "tf2weights.pickle" + "save_name": "tf2weights.fobs" } }, { diff --git a/tests/integration_test/tf2/model_persistor.py b/tests/integration_test/tf2/model_persistor.py index a8484be506..a0174bcd28 100644 --- a/tests/integration_test/tf2/model_persistor.py +++ b/tests/integration_test/tf2/model_persistor.py @@ -14,7 +14,6 @@ import json import os -import pickle import tensorflow as tf @@ -24,12 +23,13 @@ from nvflare.app_common.abstract.model import ModelLearnable, make_model_learnable from nvflare.app_common.abstract.model_persistor import ModelPersistor from nvflare.app_common.app_constant import AppConstants +from nvflare.fuel.utils import fobs from .net import Net class TF2ModelPersistor(ModelPersistor): - def __init__(self, save_name="tf2_model.pkl"): + def __init__(self, save_name="tf2_model.fobs"): super().__init__() self.save_name = save_name @@ -66,7 +66,7 @@ def 
_initialize(self, fl_ctx: FLContext): self.log_dir = os.path.join(app_root, log_dir) else: self.log_dir = app_root - self._pkl_save_path = os.path.join(self.log_dir, self.save_name) + self._fobs_save_path = os.path.join(self.log_dir, self.save_name) if not os.path.exists(self.log_dir): os.makedirs(self.log_dir) @@ -82,10 +82,10 @@ def load_model(self, fl_ctx: FLContext) -> ModelLearnable: Model object """ - if os.path.exists(self._pkl_save_path): + if os.path.exists(self._fobs_save_path): self.logger.info("Loading server weights") - with open(self._pkl_save_path, "rb") as f: - model_learnable = pickle.load(f) + with open(self._fobs_save_path, "rb") as f: + model_learnable = fobs.load(f) else: self.logger.info("Initializing server model") network = Net() @@ -109,5 +109,5 @@ def save_model(self, model_learnable: ModelLearnable, fl_ctx: FLContext): """ model_learnable_info = {k: str(type(v)) for k, v in model_learnable.items()} self.logger.info(f"Saving aggregated server weights: \n {model_learnable_info}") - with open(self._pkl_save_path, "wb") as f: - pickle.dump(model_learnable, f) + with open(self._fobs_save_path, "wb") as f: + fobs.dump(model_learnable, f) diff --git a/tests/integration_test/validators/tf_model_validator.py b/tests/integration_test/validators/tf_model_validator.py index 49897dd8a6..ab1390bccd 100644 --- a/tests/integration_test/validators/tf_model_validator.py +++ b/tests/integration_test/validators/tf_model_validator.py @@ -13,9 +13,11 @@ # limitations under the License. 
import os -import pickle from nvflare.apis.fl_constant import WorkspaceConstants +from nvflare.apis.utils.decomposers import flare_decomposers +from nvflare.app_common.decomposers import common_decomposers +from nvflare.fuel.utils import fobs from .job_result_validator import FinishJobResultValidator @@ -28,13 +30,16 @@ def validate_finished_results(self, job_result, client_props) -> bool: self.logger.error(f"models dir {server_models_dir} doesn't exist.") return False - model_path = os.path.join(server_models_dir, "tf2weights.pickle") + model_path = os.path.join(server_models_dir, "tf2weights.fobs") if not os.path.isfile(model_path): self.logger.error(f"model_path {model_path} doesn't exist.") return False try: - data = pickle.load(open(model_path, "rb")) + flare_decomposers.register() + common_decomposers.register() + + data = fobs.load(open(model_path, "rb")) self.logger.info(f"Data loaded: {data}.") assert "weights" in data assert "meta" in data diff --git a/tests/unit_test/apis/utils/decomposers/__init__.py b/tests/unit_test/apis/utils/decomposers/__init__.py new file mode 100644 index 0000000000..2b8f6c7e87 --- /dev/null +++ b/tests/unit_test/apis/utils/decomposers/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/tests/unit_test/apis/utils/decomposers/flare_decomposers_test.py b/tests/unit_test/apis/utils/decomposers/flare_decomposers_test.py new file mode 100644 index 0000000000..176f8eae63 --- /dev/null +++ b/tests/unit_test/apis/utils/decomposers/flare_decomposers_test.py @@ -0,0 +1,43 @@ +# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from nvflare.apis.fl_constant import ReservedKey, ServerCommandKey +from nvflare.apis.shareable import Shareable +from nvflare.apis.utils.decomposers import flare_decomposers +from nvflare.fuel.utils import fobs + + +class TestFlareDecomposers: + + ID1 = "abc" + ID2 = "xyz" + + @classmethod + def setup_class(cls): + flare_decomposers.register() + + def test_nested_shareable(self): + shareable = Shareable() + shareable[ReservedKey.TASK_ID] = TestFlareDecomposers.ID1 + + command_shareable = Shareable() + command_shareable[ReservedKey.TASK_ID] = TestFlareDecomposers.ID2 + command_shareable.set_header(ServerCommandKey.SHAREABLE, shareable) + + buf = fobs.dumps(command_shareable) + + new_command_shareable = fobs.loads(buf) + assert new_command_shareable[ReservedKey.TASK_ID] == TestFlareDecomposers.ID2 + + new_shareable = new_command_shareable.get_header(ServerCommandKey.SHAREABLE) + assert new_shareable[ReservedKey.TASK_ID] == TestFlareDecomposers.ID1 diff --git a/tests/unit_test/fuel/utils/fobs/__init__.py b/tests/unit_test/fuel/utils/fobs/__init__.py new file mode 
100644 index 0000000000..2b8f6c7e87 --- /dev/null +++ b/tests/unit_test/fuel/utils/fobs/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unit_test/fuel/utils/fobs/decomposer_test.py b/tests/unit_test/fuel/utils/fobs/decomposer_test.py new file mode 100644 index 0000000000..50044b569f --- /dev/null +++ b/tests/unit_test/fuel/utils/fobs/decomposer_test.py @@ -0,0 +1,30 @@ +# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from collections import OrderedDict + +from nvflare.fuel.utils.fobs.decomposers.core_decomposers import OrderedDictDecomposer + + +class TestDecomposers: + def test_sorted_dict(self): + + test_list = [(3, "First"), (1, "Middle"), (2, "Last")] + test_data = OrderedDict(test_list) + + decomposer = OrderedDictDecomposer() + serializable = decomposer.decompose(test_data) + result = decomposer.recompose(serializable) + new_list = list(result.items()) + + assert test_list == new_list diff --git a/tests/unit_test/fuel/utils/fobs/fobs_test.py b/tests/unit_test/fuel/utils/fobs/fobs_test.py new file mode 100644 index 0000000000..c518970305 --- /dev/null +++ b/tests/unit_test/fuel/utils/fobs/fobs_test.py @@ -0,0 +1,78 @@ +# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import queue +from datetime import datetime +from typing import Any, Type + +import pytest + +from nvflare.fuel.utils import fobs +from nvflare.fuel.utils.fobs import Decomposer + + +class TestFobs: + + NUMBER = 123456 + FLOAT = 123.456 + NOW = datetime.now() + + test_data = { + "str": "Test string", + "number": NUMBER, + "float": FLOAT, + "list": [7, 8, 9], + "set": {4, 5, 6}, + "tuple": ("abc", "xyz"), + "time": NOW, + } + + def test_builtin(self): + buf = fobs.dumps(TestFobs.test_data) + data = fobs.loads(buf) + assert data["number"] == TestFobs.NUMBER + + def test_aliases(self): + buf = fobs.dumps(TestFobs.test_data) + data = fobs.loads(buf) + assert data["number"] == TestFobs.NUMBER + + def test_unsupported_classes(self): + with pytest.raises(TypeError): + # Queue is just a random built-in class not supported by FOBS + unsupported_class = queue.Queue() + fobs.dumps(unsupported_class) + + def test_decomposers(self): + test_class = TestClass(TestFobs.NUMBER) + fobs.register(TestClassDecomposer) + buf = fobs.dumps(test_class) + new_class = fobs.loads(buf) + assert new_class.number == TestFobs.NUMBER + + +class TestClass: + def __init__(self, number): + self.number = number + + +class TestClassDecomposer(Decomposer): + @staticmethod + def supported_type(): + return TestClass + + def decompose(self, target: TestClass) -> Any: + return target.number + + def recompose(self, data: Any) -> TestClass: + return TestClass(data) diff --git a/tests/unit_test/private/fed/app/simulator/simulator_runner_test.py b/tests/unit_test/private/fed/app/simulator/simulator_runner_test.py index b91f5144b8..780e804b2d 100644 --- a/tests/unit_test/private/fed/app/simulator/simulator_runner_test.py +++ b/tests/unit_test/private/fed/app/simulator/simulator_runner_test.py @@ -16,8 +16,8 @@ import os import shutil import tempfile -from unittest.mock import patch from argparse import Namespace +from unittest.mock import patch import pytest