Skip to content

Commit

Permalink
Merged FOBS from 2.1 to dev (#818)
Browse files Browse the repository at this point in the history
  • Loading branch information
nvidianz authored Sep 2, 2022
1 parent cd4a135 commit 88c9419
Show file tree
Hide file tree
Showing 53 changed files with 1,197 additions and 139 deletions.
2 changes: 1 addition & 1 deletion examples/cifar10/pt/utils/cifar10_data_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@
import numpy as np
import torchvision.datasets as datasets

from nvflare.apis.event_type import EventType
from nvflare.apis.fl_component import FLComponent
from nvflare.apis.fl_context import FLContext
from nvflare.apis.event_type import EventType

CIFAR10_ROOT = "/tmp/cifar10" # will be used for all CIFAR-10 experiments

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, Optional, List
from typing import Dict, List, Optional

import numpy as np
import pandas as pd
from load_data_utils import get_app_paths, load_config
from pandas.core.series import Series

from load_data_utils import get_app_paths, load_config
from nvflare.apis.fl_constant import ReservedKey
from nvflare.apis.fl_context import FLContext
from nvflare.app_common.abstract.statistics_spec import Statistics,BinRange, Histogram, HistogramType, Feature
from nvflare.app_common.statistics.numpy_utils import get_std_histogram_buckets
from nvflare.app_common.statistics.numpy_utils import dtype_to_data_type
from nvflare.app_common.abstract.statistics_spec import BinRange, Feature, Histogram, HistogramType, Statistics
from nvflare.app_common.statistics.numpy_utils import dtype_to_data_type, get_std_histogram_buckets


class DFStatistics(Statistics):
Expand Down
2 changes: 1 addition & 1 deletion examples/federated_statistics/df_stats/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import wget
from pyhocon import ConfigFactory

from nvflare.lighter.poc_commands import is_poc_ready, get_nvflare_home
from nvflare.lighter.poc_commands import get_nvflare_home, is_poc_ready


def get_poc_workspace():
Expand Down
2 changes: 1 addition & 1 deletion examples/hello-cyclic/app/config/config_fed_server.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
"id": "persistor",
"path": "tf2_model_persistor.TF2ModelPersistor",
"args": {
"save_name": "tf2weights.pickle"
"save_name": "tf2weights.fobs"
}
},
{
Expand Down
24 changes: 12 additions & 12 deletions examples/hello-cyclic/app/custom/tf2_model_persistor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import pickle
import json
import os

import tensorflow as tf
from tf2_net import Net

from nvflare.apis.event_type import EventType
from nvflare.apis.fl_constant import FLContextKey
from nvflare.apis.fl_context import FLContext
from nvflare.app_common.abstract.model import ModelLearnable
from nvflare.app_common.abstract.model import ModelLearnable, make_model_learnable
from nvflare.app_common.abstract.model_persistor import ModelPersistor
from tf2_net import Net
from nvflare.app_common.app_constant import AppConstants
from nvflare.app_common.abstract.model import make_model_learnable
from nvflare.fuel.utils import fobs


class TF2ModelPersistor(ModelPersistor):
def __init__(self, save_name="tf2_model.pkl"):
def __init__(self, save_name="tf2_model.fobs"):
super().__init__()
self.save_name = save_name

Expand Down Expand Up @@ -65,7 +65,7 @@ def _initialize(self, fl_ctx: FLContext):
self.log_dir = os.path.join(app_root, log_dir)
else:
self.log_dir = app_root
self._pkl_save_path = os.path.join(self.log_dir, self.save_name)
self._fobs_save_path = os.path.join(self.log_dir, self.save_name)
if not os.path.exists(self.log_dir):
os.makedirs(self.log_dir)

Expand All @@ -82,10 +82,10 @@ def load_model(self, fl_ctx: FLContext) -> ModelLearnable:
Model object
"""

if os.path.exists(self._pkl_save_path):
if os.path.exists(self._fobs_save_path):
self.logger.info("Loading server weights")
with open(self._pkl_save_path, "rb") as f:
model_learnable = pickle.load(f)
with open(self._fobs_save_path, "rb") as f:
model_learnable = fobs.load(f)
else:
self.logger.info("Initializing server model")
network = Net()
Expand All @@ -109,5 +109,5 @@ def save_model(self, model_learnable: ModelLearnable, fl_ctx: FLContext):
"""
model_learnable_info = {k: str(type(v)) for k, v in model_learnable.items()}
self.logger.info(f"Saving aggregated server weights: \n {model_learnable_info}")
with open(self._pkl_save_path, "wb") as f:
pickle.dump(model_learnable, f)
with open(self._fobs_save_path, "wb") as f:
fobs.dump(model_learnable, f)
7 changes: 3 additions & 4 deletions examples/hello-cyclic/app/custom/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import tensorflow as tf
import numpy as np
import tensorflow as tf
from tf2_net import Net

from nvflare.apis.dxo import DXO, DataKind, from_shareable
from nvflare.apis.fl_constant import ReturnCode
from nvflare.apis.event_type import EventType
from nvflare.apis.executor import Executor
from nvflare.apis.fl_constant import ReturnCode
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable, make_reply
from nvflare.apis.signal import Signal

from tf2_net import Net


class SimpleTrainer(Executor):
def __init__(self, epochs_per_round):
Expand Down
4 changes: 2 additions & 2 deletions examples/hello-monai-bundle/app/custom/monai_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

import numpy as np
import torch
from bundle_configer import BundleConfiger

from nvflare.apis.dxo import DXO, DataKind, MetaKey, from_shareable
from nvflare.apis.event_type import EventType
from nvflare.apis.executor import Executor
Expand All @@ -25,8 +27,6 @@
from nvflare.apis.signal import Signal
from nvflare.app_common.app_constant import AppConstants

from bundle_configer import BundleConfiger


class MONAIBundleTrainer(Executor):
"""
Expand Down
4 changes: 2 additions & 2 deletions examples/hello-monai/app/custom/monai_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

import numpy as np
import torch
from train_configer import TrainConfiger

from nvflare.apis.dxo import DXO, DataKind, MetaKey, from_shareable
from nvflare.apis.event_type import EventType
from nvflare.apis.executor import Executor
Expand All @@ -25,8 +27,6 @@
from nvflare.apis.signal import Signal
from nvflare.app_common.app_constant import AppConstants

from train_configer import TrainConfiger


class MONAITrainer(Executor):
"""
Expand Down
2 changes: 1 addition & 1 deletion examples/hello-tf2/app/config/config_fed_server.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
"id": "persistor",
"path": "tf2_model_persistor.TF2ModelPersistor",
"args": {
"save_name": "tf2weights.pickle"
"save_name": "tf2weights.fobs"
}
},
{
Expand Down
1 change: 1 addition & 0 deletions examples/hello-tf2/app/custom/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import re

import numpy as np

from nvflare.apis.dxo import DXO, DataKind, from_shareable
from nvflare.apis.filter import Filter
from nvflare.apis.fl_context import FLContext
Expand Down
24 changes: 12 additions & 12 deletions examples/hello-tf2/app/custom/tf2_model_persistor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import pickle
import json
import os

import tensorflow as tf
from tf2_net import Net

from nvflare.apis.event_type import EventType
from nvflare.apis.fl_constant import FLContextKey
from nvflare.apis.fl_context import FLContext
from nvflare.app_common.abstract.model import ModelLearnable
from nvflare.app_common.abstract.model import ModelLearnable, make_model_learnable
from nvflare.app_common.abstract.model_persistor import ModelPersistor
from tf2_net import Net
from nvflare.app_common.app_constant import AppConstants
from nvflare.app_common.abstract.model import make_model_learnable
from nvflare.fuel.utils import fobs


class TF2ModelPersistor(ModelPersistor):
def __init__(self, save_name="tf2_model.pkl"):
def __init__(self, save_name="tf2_model.fobs"):
super().__init__()
self.save_name = save_name

Expand Down Expand Up @@ -65,7 +65,7 @@ def _initialize(self, fl_ctx: FLContext):
self.log_dir = os.path.join(app_root, log_dir)
else:
self.log_dir = app_root
self._pkl_save_path = os.path.join(self.log_dir, self.save_name)
self._fobs_save_path = os.path.join(self.log_dir, self.save_name)
if not os.path.exists(self.log_dir):
os.makedirs(self.log_dir)

Expand All @@ -81,10 +81,10 @@ def load_model(self, fl_ctx: FLContext) -> ModelLearnable:
Model object
"""

if os.path.exists(self._pkl_save_path):
if os.path.exists(self._fobs_save_path):
self.logger.info("Loading server weights")
with open(self._pkl_save_path, "rb") as f:
model_learnable = pickle.load(f)
with open(self._fobs_save_path, "rb") as f:
model_learnable = fobs.load(f)
else:
self.logger.info("Initializing server model")
network = Net()
Expand All @@ -108,5 +108,5 @@ def save_model(self, model_learnable: ModelLearnable, fl_ctx: FLContext):
"""
model_learnable_info = {k: str(type(v)) for k, v in model_learnable.items()}
self.logger.info(f"Saving aggregated server weights: \n {model_learnable_info}")
with open(self._pkl_save_path, "wb") as f:
pickle.dump(model_learnable, f)
with open(self._fobs_save_path, "wb") as f:
fobs.dump(model_learnable, f)
7 changes: 3 additions & 4 deletions examples/hello-tf2/app/custom/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import tensorflow as tf
import numpy as np
import tensorflow as tf
from tf2_net import Net

from nvflare.apis.dxo import DXO, DataKind, from_shareable
from nvflare.apis.fl_constant import ReturnCode
from nvflare.apis.event_type import EventType
from nvflare.apis.executor import Executor
from nvflare.apis.fl_constant import ReturnCode
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable, make_reply
from nvflare.apis.signal import Signal

from tf2_net import Net


class SimpleTrainer(Executor):
def __init__(self, epochs_per_round):
Expand Down
8 changes: 4 additions & 4 deletions nvflare/apis/dxo.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@
# limitations under the License.

import copy
import pickle
from typing import List

from nvflare.apis.shareable import ReservedHeaderKey, Shareable
from nvflare.fuel.utils import fobs


class DataKind(object):
Expand Down Expand Up @@ -115,7 +115,7 @@ def to_bytes(self) -> bytes:
object serialized in bytes.
"""
return pickle.dumps(self)
return fobs.dumps(self)

def validate(self) -> str:
if self.data is None:
Expand Down Expand Up @@ -167,10 +167,10 @@ def from_bytes(data: bytes) -> DXO:
data: a bytes object
Returns:
an object loaded by pickle from data
an object loaded by FOBS from data
"""
x = pickle.loads(data)
x = fobs.loads(data)
if isinstance(x, DXO):
return x
else:
Expand Down
1 change: 1 addition & 0 deletions nvflare/apis/fl_constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ class ReservedKey(object):
EVENT_SCOPE = "__event_scope__"
RUN_ABORT_SIGNAL = "__run_abort_signal__"
SHAREABLE = "__shareable__"
SHARED_FL_CONTEXT = "__shared_fl_context__"
ARGS = "__args__"
WORKSPACE_OBJECT = "__workspace_object__"
RANK_NUMBER = "__rank_number__"
Expand Down
12 changes: 6 additions & 6 deletions nvflare/apis/impl/job_def_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import datetime
import os
import pathlib
import pickle
import shutil
import tempfile
import time
Expand All @@ -29,6 +28,7 @@
from nvflare.apis.server_engine_spec import ServerEngineSpec
from nvflare.apis.storage import StorageException, StorageSpec
from nvflare.fuel.hci.zip_utils import unzip_all_from_bytes, zip_directory_to_bytes
from nvflare.fuel.utils import fobs


class _JobFilter(ABC):
Expand Down Expand Up @@ -108,7 +108,7 @@ def create(self, meta: dict, uploaded_content: bytes, fl_ctx: FLContext) -> Dict
# write it to the store
stored_data = {JobDataKey.JOB_DATA.value: uploaded_content, JobDataKey.WORKSPACE_DATA.value: None}
store = self._get_job_store(fl_ctx)
store.create_object(self.job_uri(jid), pickle.dumps(stored_data), meta, overwrite_existing=True)
store.create_object(self.job_uri(jid), fobs.dumps(stored_data), meta, overwrite_existing=True)
return meta

def delete(self, jid: str, fl_ctx: FLContext):
Expand Down Expand Up @@ -185,12 +185,12 @@ def get_content(self, jid: str, fl_ctx: FLContext) -> Optional[bytes]:
stored_data = store.get_data(self.job_uri(jid))
except StorageException:
return None
return pickle.loads(stored_data).get(JobDataKey.JOB_DATA.value)
return fobs.loads(stored_data).get(JobDataKey.JOB_DATA.value)

def get_job_data(self, jid: str, fl_ctx: FLContext) -> dict:
store = self._get_job_store(fl_ctx)
stored_data = store.get_data(self.job_uri(jid))
return pickle.loads(stored_data)
return fobs.loads(stored_data)

def set_status(self, jid: str, status: RunStatus, fl_ctx: FLContext):
meta = {JobMetaKey.STATUS.value: status.value}
Expand Down Expand Up @@ -259,6 +259,6 @@ def set_approval(
def save_workspace(self, jid: str, data: bytes, fl_ctx: FLContext):
store = self._get_job_store(fl_ctx)
stored_data = store.get_data(self.job_uri(jid))
job_data = pickle.loads(stored_data)
job_data = fobs.loads(stored_data)
job_data[JobDataKey.WORKSPACE_DATA.value] = data
store.update_data(self.job_uri(jid), pickle.dumps(job_data))
store.update_data(self.job_uri(jid), fobs.dumps(job_data))
9 changes: 4 additions & 5 deletions nvflare/apis/shareable.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import pickle

from ..fuel.utils import fobs
from .fl_constant import ReservedKey, ReturnCode


Expand Down Expand Up @@ -113,7 +112,7 @@ def to_bytes(self) -> bytes:
object serialized in bytes.
"""
return pickle.dumps(self)
return fobs.dumps(self)

@classmethod
def from_bytes(cls, data: bytes):
Expand All @@ -123,10 +122,10 @@ def from_bytes(cls, data: bytes):
data: a bytes object
Returns:
an object loaded by pickle from data
an object loaded by FOBS from data
"""
return pickle.loads(data)
return fobs.loads(data)


# some convenience functions
Expand Down
Loading

0 comments on commit 88c9419

Please sign in to comment.