Enable run cross-validation without training workflow and examples #2035

Merged
merged 30 commits into from
Dec 7, 2023
3e6967e
Enable re-run cross-validation without training workflow, added the e…
yhwen Sep 26, 2023
70a88e4
codestyle fix.
yhwen Sep 26, 2023
6f4eb66
Added README.md to explain how the example been built.
yhwen Sep 27, 2023
cdc3c34
updated Readme.md
yhwen Sep 27, 2023
c4e1226
Merge branch 'main' into rerun_cross_validation
chesterxgchen Sep 28, 2023
cc19063
re-engineer the re-run cross-validation, making use of the global_mod…
yhwen Nov 6, 2023
a2417e9
updated the README.
yhwen Nov 7, 2023
d4fa366
Updated README.
yhwen Nov 8, 2023
0a7de2e
Updated README.md.
yhwen Nov 8, 2023
ee62f5f
Added hello-numpy-cross-val for cross-validation only example.
yhwen Nov 17, 2023
187c9fc
Merged from main
yhwen Nov 18, 2023
6b1127e
added 2 cross-validation only examples.
yhwen Nov 18, 2023
04e9498
Added the README for the cross-validation examples.
yhwen Nov 20, 2023
779a2e0
Moved the README.md to a general place.
yhwen Nov 20, 2023
35246bf
Added a cross-validation only example which supports providing a list…
yhwen Nov 22, 2023
f712bc7
Added hello-numpy-cross-val only examples.
yhwen Nov 27, 2023
cf3ab9c
Added file_list_model_persistor.py to support list of models with loc…
yhwen Nov 27, 2023
bec53a3
Removed the examples/advanced/cross-validation-without-training.
yhwen Nov 28, 2023
8471396
Updated the REAADME.md.
yhwen Nov 28, 2023
52f6fe8
Removed the persistor ID from the config.
yhwen Nov 28, 2023
3d5874c
Maded the FileListModelLocator general purpose.
yhwen Nov 28, 2023
95ae7ae
Codestyle fix.
yhwen Nov 28, 2023
3e51534
Update the constructor.
yhwen Nov 29, 2023
85ba2a4
Moved the list_model_locator.py to app_common.
yhwen Nov 29, 2023
a8b473d
Changed back to model_name in NPModelLocator.
yhwen Nov 29, 2023
0fa82a0
Updated NPModelLocator type check.
yhwen Nov 29, 2023
af39df6
updated the README for hello-numpy-cross-val.
yhwen Dec 5, 2023
3218269
Use a script to generate the pre-trained models for cross-validation …
yhwen Dec 6, 2023
2a6241c
codestyle fix.
yhwen Dec 6, 2023
ead164d
Merge branch 'main' into rerun_cross_validation
YuanTingHsieh Dec 7, 2023
25 changes: 25 additions & 0 deletions examples/hello-world/hello-numpy-cross-val/README.md
Expand Up @@ -27,3 +27,28 @@ $ ls /tmp/nvflare/simulate_job/
app_server app_site-1 app_site-2 log.txt

```

# Run cross-site validation using previously trained results

## Introduction

The "hello-numpy-cross-val-only" and "hello-numpy-cross-val-only-list-models" jobs show how to run NVFlare cross-site validation without the training workflow by reusing the results of a previous run. The first job uses the default single server model. The second job accepts a list of server models. You can also supply your own previously trained models for the cross-validation.

### Generate the best global model and best local model from a previous run

Run the following command to generate the pre-trained models:

```
python pre_train_models.py
```

### How to run the job

Define two OS environment variables, "SERVER_MODEL_DIR" and "CLIENT_MODEL_DIR", that point to the absolute paths of the server best model location and the local best model location, respectively. Then use the NVFlare admin command "submit_job" to submit and run the cross-validation job.

For example, define the environment variable "SERVER_MODEL_DIR" like this:

```
export SERVER_MODEL_DIR="/path/to/model/location/at/server-side"
```
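
Because both variables must hold absolute paths to existing locations, it can help to check them before submitting the job. The following is a minimal sketch, not part of NVFlare: the helper name `check_model_dirs` is hypothetical, and it assumes each variable points to a directory.

```python
import os
from pathlib import Path

# Variable names taken from the instructions above.
REQUIRED_VARS = ("SERVER_MODEL_DIR", "CLIENT_MODEL_DIR")


def check_model_dirs(env=None):
    """Return a list of problems found with the model-directory variables.

    `env` defaults to os.environ; a plain dict can be passed for testing.
    """
    env = os.environ if env is None else env
    problems = []
    for name in REQUIRED_VARS:
        value = env.get(name)
        if not value:
            problems.append(f"{name} is not set")
        elif not os.path.isabs(value):
            problems.append(f"{name} is not an absolute path: {value}")
        elif not Path(value).is_dir():
            problems.append(f"{name} does not point to a directory: {value}")
    return problems


if __name__ == "__main__":
    # Print one line per problem; no output means both variables look usable.
    for problem in check_model_dirs():
        print(problem)
```

An empty result means both variables are set, absolute, and point to existing directories.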

@@ -0,0 +1,29 @@
{
  "format_version": 2,
  "model_dir": "{$CLIENT_MODEL_DIR}",
  "executors": [
    {
      "tasks": [
        "train",
        "submit_model"
      ],
      "executor": {
        "path": "nvflare.app_common.np.np_trainer.NPTrainer",
        "args": {
          "model_dir": "{model_dir}"
        }
      }
    },
    {
      "tasks": [
        "validate"
      ],
      "executor": {
        "path": "nvflare.app_common.np.np_validator.NPValidator"
      }
    }
  ],
  "task_result_filters": [],
  "task_data_filters": [],
  "components": []
}
@@ -0,0 +1,39 @@
{
  "format_version": 2,
  "model_dir": "{$SERVER_MODEL_DIR}",
  "server": {
    "heart_beat_timeout": 600
  },
  "task_data_filters": [],
  "task_result_filters": [],
  "components": [
    {
      "id": "model_locator",
      "path": "nvflare.app_common.np.np_model_locator.NPModelLocator",
      "args": {
        "model_dir": "{model_dir}",
        "model_names": {
          "server_model_1": "server_1.npy",
          "server_model_2": "server_2.npy"
        }
      }
    },
    {
      "id": "json_generator",
      "path": "nvflare.app_common.widgets.validation_json_generator.ValidationJsonGenerator",
      "args": {}
    }
  ],
  "workflows": [
    {
      "id": "cross_site_model_eval",
      "path": "nvflare.app_common.workflows.cross_site_model_eval.CrossSiteModelEval",
      "args": {
        "model_locator_id": "model_locator",
        "submit_model_timeout": 600,
        "validation_timeout": 6000,
        "cleanup_models": false
      }
    }
  ]
}
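
The config above uses two placeholder forms: `{$SERVER_MODEL_DIR}`, which reads like an environment-variable reference, and `{model_dir}`, which reads like a reference to the top-level "model_dir" entry. The sketch below illustrates that reading only. It is not NVFlare's implementation; the function `resolve_placeholders` and its resolution order are assumptions made for illustration.

```python
import os
import re


def resolve_placeholders(config, env=None):
    """Illustrative resolver (hypothetical, not NVFlare code).

    Replaces {$NAME} with env[NAME], then {key} with the resolved
    top-level string entry named `key`. Unknown names are left as-is.
    """
    env = os.environ if env is None else env

    def sub_env(text):
        return re.sub(r"\{\$(\w+)\}", lambda m: env.get(m.group(1), m.group(0)), text)

    # Top-level string entries, with environment references already resolved.
    top_level = {k: sub_env(v) for k, v in config.items() if isinstance(v, str)}

    def sub_all(value):
        if isinstance(value, str):
            text = sub_env(value)
            return re.sub(r"\{(\w+)\}", lambda m: top_level.get(m.group(1), m.group(0)), text)
        if isinstance(value, dict):
            return {k: sub_all(v) for k, v in value.items()}
        if isinstance(value, list):
            return [sub_all(v) for v in value]
        return value  # numbers, booleans, None pass through unchanged

    return sub_all(config)


if __name__ == "__main__":
    cfg = {
        "model_dir": "{$SERVER_MODEL_DIR}",
        "components": [{"id": "model_locator", "args": {"model_dir": "{model_dir}"}}],
    }
    print(resolve_placeholders(cfg, env={"SERVER_MODEL_DIR": "/models/server"}))
```

Under this reading, the locator's "model_dir" argument ends up holding whatever path the environment variable was set to.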
@@ -0,0 +1,10 @@
{
  "name": "hello-numpy-cross-val",
  "resource_spec": {},
  "min_clients": 2,
  "deploy_map": {
    "app": [
      "@ALL"
    ]
  }
}
@@ -0,0 +1,46 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import numpy as np

from nvflare.app_common.abstract.model import ModelLearnableKey, make_model_learnable
from nvflare.app_common.np.constants import NPConstants

SERVER_MODEL_DIR = "models/server"
CLIENT_MODEL_DIR = "models/client"

if __name__ == "__main__":
    """
    This is the tool to generate the pre-trained models for demonstrating the cross-validation without training.
    """

    model_data = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32)
    model_learnable = make_model_learnable(weights={NPConstants.NUMPY_KEY: model_data}, meta_props={})

    working_dir = os.getcwd()
    model_dir = os.path.join(working_dir, SERVER_MODEL_DIR)
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    model_path = os.path.join(model_dir, "server_1.npy")
    np.save(model_path, model_learnable[ModelLearnableKey.WEIGHTS][NPConstants.NUMPY_KEY])
    model_path = os.path.join(model_dir, "server_2.npy")
    np.save(model_path, model_learnable[ModelLearnableKey.WEIGHTS][NPConstants.NUMPY_KEY])

    model_dir = os.path.join(working_dir, CLIENT_MODEL_DIR)
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    model_save_path = os.path.join(model_dir, "best_numpy.npy")
    np.save(model_save_path, model_data)
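
The script above writes three files relative to the current working directory: `models/server/server_1.npy`, `models/server/server_2.npy`, and `models/client/best_numpy.npy`. A quick way to confirm they exist after running it is sketched below. The helper `missing_model_files` is hypothetical and uses only the standard library.

```python
import os

# Relative paths taken from the generator script above.
EXPECTED_FILES = (
    os.path.join("models", "server", "server_1.npy"),
    os.path.join("models", "server", "server_2.npy"),
    os.path.join("models", "client", "best_numpy.npy"),
)


def missing_model_files(working_dir, expected=EXPECTED_FILES):
    """Return the expected relative paths that are not files under working_dir."""
    return [p for p in expected if not os.path.isfile(os.path.join(working_dir, p))]


if __name__ == "__main__":
    # Run from the directory where the generator script was executed.
    for path in missing_model_files(os.getcwd()):
        print(f"missing: {path}")
```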
@@ -0,0 +1,29 @@
{
  "format_version": 2,
  "model_dir": "{$CLIENT_MODEL_DIR}",
  "executors": [
    {
      "tasks": [
        "train",
        "submit_model"
      ],
      "executor": {
        "path": "nvflare.app_common.np.np_trainer.NPTrainer",
        "args": {
          "model_dir": "{model_dir}"
        }
      }
    },
    {
      "tasks": [
        "validate"
      ],
      "executor": {
        "path": "nvflare.app_common.np.np_validator.NPValidator"
      }
    }
  ],
  "task_result_filters": [],
  "task_data_filters": [],
  "components": []
}
@@ -0,0 +1,35 @@
{
  "format_version": 2,
  "model_dir": "{$SERVER_MODEL_DIR}",
  "server": {
    "heart_beat_timeout": 600
  },
  "task_data_filters": [],
  "task_result_filters": [],
  "components": [
    {
      "id": "model_locator",
      "path": "nvflare.app_common.np.np_model_locator.NPModelLocator",
      "args": {
        "model_dir": "{model_dir}"
      }
    },
    {
      "id": "json_generator",
      "path": "nvflare.app_common.widgets.validation_json_generator.ValidationJsonGenerator",
      "args": {}
    }
  ],
  "workflows": [
    {
      "id": "cross_site_model_eval",
      "path": "nvflare.app_common.workflows.cross_site_model_eval.CrossSiteModelEval",
      "args": {
        "model_locator_id": "model_locator",
        "submit_model_timeout": 600,
        "validation_timeout": 6000,
        "cleanup_models": false
      }
    }
  ]
}
@@ -0,0 +1,10 @@
{
  "name": "hello-numpy-cross-val",
  "resource_spec": {},
  "min_clients": 2,
  "deploy_map": {
    "app": [
      "@ALL"
    ]
  }
}
@@ -0,0 +1,44 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import numpy as np

from nvflare.app_common.abstract.model import ModelLearnableKey, make_model_learnable
from nvflare.app_common.np.constants import NPConstants

SERVER_MODEL_DIR = "models/server"
CLIENT_MODEL_DIR = "models/client"

if __name__ == "__main__":
    """
    This is the tool to generate the pre-trained models for demonstrating the cross-validation without training.
    """

    model_data = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32)
    model_learnable = make_model_learnable(weights={NPConstants.NUMPY_KEY: model_data}, meta_props={})

    working_dir = os.getcwd()
    model_dir = os.path.join(working_dir, SERVER_MODEL_DIR)
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    model_path = os.path.join(model_dir, "server.npy")
    np.save(model_path, model_learnable[ModelLearnableKey.WEIGHTS][NPConstants.NUMPY_KEY])

    model_dir = os.path.join(working_dir, CLIENT_MODEL_DIR)
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    model_save_path = os.path.join(model_dir, "best_numpy.npy")
    np.save(model_save_path, model_data)
13 changes: 13 additions & 0 deletions nvflare/app_common/model_locator/__init__.py
@@ -0,0 +1,13 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
80 changes: 80 additions & 0 deletions nvflare/app_common/model_locator/list_model_locator.py
@@ -0,0 +1,80 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List

from nvflare.apis.dxo import DXO
from nvflare.apis.event_type import EventType
from nvflare.apis.fl_context import FLContext
from nvflare.app_common.abstract.model import model_learnable_to_dxo
from nvflare.app_common.abstract.model_locator import ModelLocator
from nvflare.app_common.abstract.model_persistor import ModelPersistor


class ListModelLocator(ModelLocator):
    def __init__(self, persistor_id: str, model_list=None):
        """The ModelLocator's job is to find and locate the models inventory saved during training.

        Args:
            persistor_id (str): ModelPersistor component ID
            model_list: a dict that maps each model name to its location
        """
        super().__init__()

        self.persistor_id = persistor_id

        self.model_persistor = None
        # Use a fresh dict per instance instead of a shared mutable default argument.
        self.model_list = model_list if model_list is not None else {}

    def handle_event(self, event_type: str, fl_ctx: FLContext):
        if event_type == EventType.START_RUN:
            self._initialize(fl_ctx)

    def _initialize(self, fl_ctx: FLContext):
        engine = fl_ctx.get_engine()
        self.model_persistor: ModelPersistor = engine.get_component(self.persistor_id)
        # The check is against the ModelPersistor base class, so the message names that type.
        if self.model_persistor is None or not isinstance(self.model_persistor, ModelPersistor):
            raise ValueError(
                f"persistor_id component must be ModelPersistor. But got: {type(self.model_persistor)}"
            )

    def get_model_names(self, fl_ctx: FLContext) -> List[str]:
        """Returns the list of server model names that should be included in cross-site validation.

        Args:
            fl_ctx (FLContext): FL Context object.

        Returns:
            List[str]: List of model names.
        """
        return list(self.model_list.keys())

    def locate_model(self, model_name, fl_ctx: FLContext) -> DXO:
        """Call to locate and load the model weights of model_name.

        Args:
            model_name: name of the model
            fl_ctx: FLContext

        Returns: model_weight DXO

        """
        if model_name not in self.model_list:
            raise ValueError(f"model inventory does not contain: {model_name}")

        # Look up the stored location, then let the persistor load the model from it.
        location = self.model_list[model_name]
        model_learnable = self.model_persistor.get_model_from_location(location, fl_ctx)
        dxo = model_learnable_to_dxo(model_learnable)

        return dxo
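
The lookup flow in `locate_model` (check the name, map it to a location, ask the persistor to load from that location) can be tried without an NVFlare runtime. The sketch below is a simplified stand-in, assuming a name-to-location dict as in the class above. `FakePersistor` and `SimpleListLocator` are hypothetical names, and the DXO conversion step is omitted.

```python
class FakePersistor:
    """Hypothetical stand-in for a ModelPersistor: returns canned objects by location."""

    def __init__(self, store):
        self.store = store

    def get_model_from_location(self, location, fl_ctx=None):
        return self.store[location]


class SimpleListLocator:
    """Simplified mirror of the lookup logic, without NVFlare types or events."""

    def __init__(self, persistor, model_list=None):
        self.persistor = persistor
        # Copy the mapping so later changes by the caller do not leak in.
        self.model_list = dict(model_list or {})

    def get_model_names(self):
        return list(self.model_list.keys())

    def locate_model(self, model_name):
        if model_name not in self.model_list:
            raise ValueError(f"model inventory does not contain: {model_name}")
        location = self.model_list[model_name]
        return self.persistor.get_model_from_location(location)


if __name__ == "__main__":
    persistor = FakePersistor({"server_1.npy": "weights-1", "server_2.npy": "weights-2"})
    locator = SimpleListLocator(
        persistor, {"server_model_1": "server_1.npy", "server_model_2": "server_2.npy"}
    )
    for name in locator.get_model_names():
        print(name, locator.locate_model(name))
```

Keeping the name-to-location mapping separate from the persistor means the same locator logic works for any persistor that can load a model from a location.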