diff --git a/3rdParty/tdigest.LICENSE.txt b/3rdParty/tdigest.LICENSE.txt new file mode 100644 index 0000000000..c4da63fac0 --- /dev/null +++ b/3rdParty/tdigest.LICENSE.txt @@ -0,0 +1,23 @@ +https://github.com/CamDavidsonPilon/tdigest/blob/master/LICENSE.txt + +The MIT License (MIT) + +Copyright (c) 2015 Cameron Davidson-Pilon + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 32843e03e6..07e05e2264 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -44,6 +44,10 @@ To collaborate efficiently, please read through this section and follow them. * [Building documentation](#building-the-documentation) * [Signing your work](#signing-your-work) +> Note: + > some package dependencies requires python-dev in local development such as + > python3.12-dev. + #### Checking the coding style We check code style using flake8 and isort. A bash script (`runtest.sh`) is provided to run all tests locally. diff --git a/README.md b/README.md index 8e7cfa3404..f2de47861e 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ # NVIDIA FLARE -[Website](https://nvidia.github.io/NVFlare) | [Paper](https://arxiv.org/abs/2210.13291) | [Talks & Blogs](https://nvflare.readthedocs.io/en/main/publications_and_talks.html) | [Research](./research/README.md) | [Documentation](https://nvflare.readthedocs.io/en/main) +[Website](https://nvidia.github.io/NVFlare) | [Paper](https://arxiv.org/abs/2210.13291) | [Blogs](https://developer.nvidia.com/blog/tag/federated-learning) | [Talks & Papers](https://nvflare.readthedocs.io/en/main/publications_and_talks.html) | [Research](./research/README.md) | [Documentation](https://nvflare.readthedocs.io/en/main) [![Blossom-CI](https://github.com/NVIDIA/nvflare/workflows/Blossom-CI/badge.svg?branch=main)](https://github.com/NVIDIA/nvflare/actions) [![documentation](https://readthedocs.org/projects/nvflare/badge/?version=main)](https://nvflare.readthedocs.io/en/main/?badge=main) diff --git a/docs/getting_started.rst b/docs/getting_started.rst index 9cb8d7dca7..ed0bfad8e2 100644 --- a/docs/getting_started.rst +++ b/docs/getting_started.rst @@ -4,4 +4,9 @@ Getting Started ############### -See :ref:`installation`. +### To get started, take a look at the instructions and code examples on our [**website**](https://nvidia.github.io/NVFlare/). + +* For more details, see our getting started `tutorials `__. +* For more advanced examples and `step-by-step `__ walk-troughs, see our `examples `__. +* There are also detailed instructions on how to convert your standalone/centralized training code to `federated learning code. `__. +* If you'd like to write your own NVIDIA FLARE components, a detailed programming guide can be found `here `__. diff --git a/docs/publications_and_talks.rst b/docs/publications_and_talks.rst index 422512da72..cf0bef59dd 100644 --- a/docs/publications_and_talks.rst +++ b/docs/publications_and_talks.rst @@ -9,7 +9,15 @@ including papers using NVIDIA FLARE's predecessor libraries included in the `Cla Publications: 2024 ------------------ -* **2024-02** `Empowering Federated Learning for Massive Models with NVIDIA FLARE `__ (Accepted to `FL@FM-TheWebConf'24 `__)) +* **2024-12** `C-FedRAG: A Confidential Federated Retrieval-Augmented Generation System `__ (preprint) +* **2024-11** `Toward a tipping point in federated learning in healthcare and life sciences `__ (`Patterns, Volume 5, Issue 11, 2024, `__) +* **2024-07** `FedBPT: Efficient Federated Black-box Prompt Tuning for Large Language Models `__ (`ICML 2024 `__) +* **2024-07** `Fair evaluation of federated learning algorithms for automated breast density classification: The results of the 2022 ACR-NCI-NVIDIA federated learning challenge `__ (`Medical Image Analysis, Volume 95, July 2024 `__) +* **2024-07** `Easy and Scalable Federated Learning in the Age of Large Language Models with NVIDIA FLARE `__ (`FL@FM-ICME'24 `__) +* **2024-05** `Federated Learning Privacy: Attacks, Defenses, Applications, and Policy Landscape - A Survey `__ (preprint) +* **2024-05** `Supercharging Federated Learning with Flower and NVIDIA FLARE `__ (Presented at `FL@FM-IJCAI `__ In preparation for Lecture Notes in AI) +* **2024-05** `An in-depth evaluation of federated learning on biomedical natural language processing for information extraction `__ (`Nature Digital Medicine 7, 127, 2024 `__) +* **2024-02** `Empowering Federated Learning for Massive Models with NVIDIA FLARE `__ (Presented at `FL@FM-TheWebConf'24 `__, `Springer Book Chapter `__)) Publications: 2023 ------------------ @@ -21,7 +29,6 @@ Publications: 2023 Publications: 2022 ------------------ -* **2022-11** `Federated Learning with Azure Machine Learning `__ (Video) * **2022-10** `Auto-FedRL: Federated Hyperparameter Optimization for Multi-institutional Medical Image Segmentation `__ (`ECCV 2022 `__) * **2022-10** `Joint Multi Organ and Tumor Segmentation from Partial Labels Using Federated Learning `__ (`DeCaF @ MICCAI 2022 `__) * **2022-10** `Split-U-Net: Preventing Data Leakage in Split Learning for Collaborative Multi-modal Brain Tumor Segmentation `__ (`DeCaF @ MICCAI 2022 `__) @@ -52,6 +59,8 @@ NVIDIA FLARE related blogs and other media. Blogs & Videos: 2024 -------------------- +* **2024-04** `Differential Privacy and Federated Learning for Medical Data `__ (Roche Technical Blog) +* **2024-03** `Announcing NVIDIA and Flower Collaboration `__ (Flower Blog) * **2024-03** `Turning Machine Learning to Federated Learning in Minutes with NVIDIA FLARE 2.4 `__ (NVIDIA Technical Blog) * **2024-02** `Scalable Federated Learning with NVIDIA FLARE for Enhanced LLM Performance `__ (NVIDIA Technical Blog) @@ -66,7 +75,6 @@ Blogs & Videos: 2023 Blogs & Videos: 2022 -------------------- - * **2022-10** `Federated Learning from Simulation to Production with NVIDIA FLARE `__ (NVIDIA Technical Blog) * **2022-08** `Using Federated Learning to Bridge Data Silos in Financial Services `__ (NVIDIA Technical Blog) * **2022-06** `Experimenting with Novel Distributed Applications Using NVIDIA Flare 2.1 `__ (NVIDIA Technical Blog) @@ -94,7 +102,9 @@ Recent talks and Webinars covering federated learning research and NVIDIA FLARE. Talks: 2024 ----------- -* **2024-03** `Empowering Federated Learning for Massive Models with NVIDIA FLARE `__ (`SFBigAnalytics Meetup `__) +* **2024-12** `Real-world Federated Learning with NVIDIA FLARE `__ [Passcode: !Ms8Tw.u8H] (`UCSF Biostatistics and Bioinformatics Seminar `__) +* **2024-04** `Federated Learning: Towards Real-world Studies `__ (`SFBigAnalytics Meetup `__) +* **2024-03** `Empowering Federated Learning for Massive Models with NVIDIA FLARE `__ (`SFBigAnalytics Meetup `__) Talks: 2023 ----------- @@ -103,6 +113,7 @@ Talks: 2023 Talks: 2022 ----------- +* **2022-11** `Federated Learning with Azure Machine Learning `__ (Microsoft Developer Video) * **2022-10** `Modern Tools for Collaborative Medical Image Analysis `__ (`Keynote - DART @ MICCAI 2022 `__) * **2022-07** `NVIDIA FLARE Tutorial for Beginners `__ (United Imaging Meetup) * **2022-07** `Techniques and Tools for Collaborative Development of AI Models across Institutes `__ (`VALSE Webinar `__) diff --git a/docs/resources/log.config b/docs/resources/log.config index 07c2963686..a245312cf7 100644 --- a/docs/resources/log.config +++ b/docs/resources/log.config @@ -1,28 +1,77 @@ -[loggers] -keys=root - -[handlers] -keys=consoleHandler,errorFileHandler - -[formatters] -keys=baseFormatter - -[logger_root] -level=INFO -handlers=consoleHandler,errorFileHandler - -[handler_consoleHandler] -class=StreamHandler -level=DEBUG -formatter=baseFormatter -args=(sys.stdout,) - -[handler_errorFileHandler] -class=FileHandler -level=ERROR -formatter=baseFormatter -args=('error_log.txt', 'a') - -[formatter_baseFormatter] -class=nvflare.fuel.utils.log_utils.BaseFormatter -format=%(asctime)s - %(name)s - %(levelname)s - %(message)s \ No newline at end of file +{ + "version": 1, + "disable_existing_loggers": false, + "formatters": { + "baseFormatter": { + "()": "nvflare.fuel.utils.log_utils.BaseFormatter", + "fmt": "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + }, + "colorFormatter": { + "()": "nvflare.fuel.utils.log_utils.ColorFormatter", + "fmt": "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + }, + "jsonFormatter": { + "()": "nvflare.fuel.utils.log_utils.JsonFormatter", + "fmt": "%(asctime)s - %(name)s - %(fullName)s - %(levelname)s - %(message)s" + } + }, + "filters": { + "FLFilter": { + "()": "nvflare.fuel.utils.log_utils.LoggerNameFilter", + "logger_names": ["custom", "nvflare.app_common", "nvflare.app_opt"] + } + }, + "handlers": { + "consoleHandler": { + "class": "logging.StreamHandler", + "level": "DEBUG", + "formatter": "colorFormatter", + "filters": [], + "stream": "ext://sys.stdout" + }, + "logFileHandler": { + "class": "logging.handlers.RotatingFileHandler", + "level": "DEBUG", + "formatter": "baseFormatter", + "filename": "log.txt", + "mode": "a", + "maxBytes": 20971520, + "backupCount": 10 + }, + "errorFileHandler": { + "class": "logging.handlers.RotatingFileHandler", + "level": "ERROR", + "formatter": "baseFormatter", + "filename": "log_error.txt", + "mode": "a", + "maxBytes": 20971520, + "backupCount": 10 + }, + "jsonFileHandler": { + "class": "logging.handlers.RotatingFileHandler", + "level": "DEBUG", + "formatter": "jsonFormatter", + "filename": "log.json", + "mode": "a", + "maxBytes": 20971520, + "backupCount": 10 + }, + "FLFileHandler": { + "class": "logging.handlers.RotatingFileHandler", + "level": "DEBUG", + "formatter": "baseFormatter", + "filters": ["FLFilter"], + "filename": "log_fl.txt", + "mode": "a", + "maxBytes": 20971520, + "backupCount": 10, + "delay": true + } + }, + "loggers": { + "root": { + "level": "INFO", + "handlers": ["consoleHandler", "logFileHandler", "errorFileHandler", "jsonFileHandler", "FLFileHandler"] + } + } +} \ No newline at end of file diff --git a/docs/resources/researcher.svg b/docs/resources/researcher.svg new file mode 100644 index 0000000000..c1d8d9a502 --- /dev/null +++ b/docs/resources/researcher.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/examples/advanced/federated-statistics/df_stats.ipynb b/examples/advanced/federated-statistics/df_stats.ipynb index 93d3e86140..944544e4e5 100644 --- a/examples/advanced/federated-statistics/df_stats.ipynb +++ b/examples/advanced/federated-statistics/df_stats.ipynb @@ -144,15 +144,27 @@ { "cell_type": "code", "execution_count": null, - "id": "0d5041aa-c2e0-4af6-a2c8-bae76e4512d0", + "id": "6361a85e-4187-433c-976c-0dc4021908ac", + "metadata": {}, + "outputs": [], + "source": [ + "! nvflare simulator df_stats/jobs/df_stats -w /tmp/nvflare/df/workdir -n 2 -t 2" + ] + }, + { + "cell_type": "markdown", + "id": "4fdbfb95-90c9-4d45-b727-dab6f5a8bc41", "metadata": { "tags": [] }, - "outputs": [], "source": [ + "Or python code\n", + "```\n", "from nvflare.private.fed.app.simulator.simulator_runner import SimulatorRunner\n", - "runner = SimulatorRunner(job_folder=\"df_stats/jobs/df_stats\", workspace=\"/tmp/nvflare/df_stats/workdir\", n_clients = 2, threads=2)\n", - "runner.run()" + "runner = SimulatorRunner(job_folder=\"df_stats/jobs/df_stats\", workspace=\"/tmp/nvflare/df/workdir\", n_clients = 2, threads=2)\n", + "runner.run()\n", + "\n", + "```" ] }, { @@ -167,7 +179,7 @@ "From a **terminal** one can also the following equivalent CLI\n", "\n", "```\n", - "nvflare simulator df_stats/jobs/df_stats -w /tmp/nvflare/df_stats -n 2 -t 2\n", + "nvflare simulator df_stats/jobs/df_stats -w /tmp/nvflare/df/workdir -n 2 -t 2\n", "\n", "```\n", "\n", @@ -184,9 +196,9 @@ "metadata": {}, "source": [ "\n", - "The results are stored in workspace \"/tmp/nvflare/df_stats/workdir/\"\n", + "The results are stored in workspace \"/tmp/nvflare/df/workdir/\"\n", "```\n", - "/tmp/nvflare/df_stats/workdir/server/simulate_job/statistics/adults_stats.json\n", + "/tmp/nvflare/df/workdir/server/simulate_job/statistics/adults_stats.json\n", "```" ] }, @@ -199,7 +211,7 @@ }, "outputs": [], "source": [ - "cat /tmp/nvflare/df_stats/workdir/server/simulate_job/statistics/adults_stats.json" + "cat /tmp/nvflare/df/workdir/server/simulate_job/statistics/adults_stats.json" ] }, { @@ -222,7 +234,7 @@ }, "outputs": [], "source": [ - "! cp /tmp/nvflare/df_stats/workdir/server/simulate_job/statistics/adults_stats.json df_stats/demo/." + "! cp /tmp/nvflare/df/workdir/server/simulate_job/statistics/adults_stats.json df_stats/demo/." ] }, { @@ -271,7 +283,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.19" + "version": "3.10.2" } }, "nbformat": 4, diff --git a/examples/advanced/federated-statistics/df_stats/demo/visualization.ipynb b/examples/advanced/federated-statistics/df_stats/demo/visualization.ipynb index 99eb5c4c91..283f5279b2 100644 --- a/examples/advanced/federated-statistics/df_stats/demo/visualization.ipynb +++ b/examples/advanced/federated-statistics/df_stats/demo/visualization.ipynb @@ -285,7 +285,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.17" + "version": "3.10.2" } }, "nbformat": 4, diff --git a/examples/advanced/federated-statistics/df_stats/job_api/df_statistics.py b/examples/advanced/federated-statistics/df_stats/job_api/df_statistics.py new file mode 100644 index 0000000000..5078c2f0a9 --- /dev/null +++ b/examples/advanced/federated-statistics/df_stats/job_api/df_statistics.py @@ -0,0 +1,75 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, Optional + +import pandas as pd + +from nvflare.apis.fl_context import FLContext +from nvflare.app_opt.statistics.df.df_core_statistics import DFStatisticsCore + + +class DFStatistics(DFStatisticsCore): + def __init__(self, data_path): + super().__init__() + self.data_root_dir = "/tmp/nvflare/df_stats/data" + self.data_path = data_path + self.data: Optional[Dict[str, pd.DataFrame]] = None + self.data_features = [ + "Age", + "Workclass", + "fnlwgt", + "Education", + "Education-Num", + "Marital Status", + "Occupation", + "Relationship", + "Race", + "Sex", + "Capital Gain", + "Capital Loss", + "Hours per week", + "Country", + "Target", + ] + + # the original dataset has no header, + # we will use the adult.train dataset for site-1, the adult.test dataset for site-2 + # the adult.test dataset has incorrect formatted row at 1st line, we will skip it. + self.skip_rows = { + "site-1": [], + "site-2": [0], + } + + def load_data(self, fl_ctx: FLContext) -> Dict[str, pd.DataFrame]: + client_name = fl_ctx.get_identity_name() + self.log_info(fl_ctx, f"load data for client {client_name}") + try: + skip_rows = self.skip_rows[client_name] + data_path = f"{self.data_root_dir}/{fl_ctx.get_identity_name()}/{self.data_path}" + # example of load data from CSV + df: pd.DataFrame = pd.read_csv( + data_path, names=self.data_features, sep=r"\s*,\s*", skiprows=skip_rows, engine="python", na_values="?" + ) + train = df.sample(frac=0.8, random_state=200) # random state is a seed value + test = df.drop(train.index).sample(frac=1.0) + + self.log_info(fl_ctx, f"load data done for client {client_name}") + return {"train": train, "test": test} + + except Exception as e: + raise Exception(f"Load data for client {client_name} failed! {e}") + + def initialize(self, fl_ctx: FLContext): + self.data = self.load_data(fl_ctx) diff --git a/examples/advanced/federated-statistics/df_stats/job_api/df_stats_job.py b/examples/advanced/federated-statistics/df_stats/job_api/df_stats_job.py new file mode 100644 index 0000000000..696d57170c --- /dev/null +++ b/examples/advanced/federated-statistics/df_stats/job_api/df_stats_job.py @@ -0,0 +1,72 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse + +from df_statistics import DFStatistics + +from nvflare.job_config.stats_job import StatsJob + + +def define_parser(): + parser = argparse.ArgumentParser() + parser.add_argument("-n", "--n_clients", type=int, default=3) + parser.add_argument("-d", "--data_root_dir", type=str, nargs="?", default="/tmp/nvflare/dataset/output") + parser.add_argument("-o", "--stats_output_path", type=str, nargs="?", default="statistics/stats.json") + parser.add_argument("-j", "--job_dir", type=str, nargs="?", default="/tmp/nvflare/jobs/stats_df") + parser.add_argument("-w", "--work_dir", type=str, nargs="?", default="/tmp/nvflare/jobs/stats_df/work_dir") + parser.add_argument("-co", "--export_config", action="store_true", help="config only mode, export config") + + return parser.parse_args() + + +def main(): + args = define_parser() + + n_clients = args.n_clients + data_root_dir = args.data_root_dir + output_path = args.stats_output_path + job_dir = args.job_dir + work_dir = args.work_dir + export_config = args.export_config + + statistic_configs = { + "count": {}, + "mean": {}, + "sum": {}, + "stddev": {}, + "histogram": {"*": {"bins": 20}}, + "Age": {"bins": 20, "range": [0, 10]}, + "percentile": {"*": [25, 50, 75], "Age": [50, 95]}, + } + # define local stats generator + df_stats_generator = DFStatistics(data_root_dir=data_root_dir) + + job = StatsJob( + job_name="stats_df", + statistic_configs=statistic_configs, + stats_generator=df_stats_generator, + output_path=output_path, + ) + + sites = [f"site-{i + 1}" for i in range(n_clients)] + job.setup_clients(sites) + + if export_config: + job.export_job(job_dir) + else: + job.simulator_run(work_dir) + + +if __name__ == "__main__": + main() diff --git a/examples/advanced/federated-statistics/df_stats/jobs/df_stats/app/config/config_fed_client.json b/examples/advanced/federated-statistics/df_stats/jobs/df_stats/app/config/config_fed_client.json index 858431b138..3d3ade8735 100644 --- a/examples/advanced/federated-statistics/df_stats/jobs/df_stats/app/config/config_fed_client.json +++ b/examples/advanced/federated-statistics/df_stats/jobs/df_stats/app/config/config_fed_client.json @@ -14,23 +14,7 @@ } } ], - "task_result_filters": [ - { - "tasks": ["fed_stats"], - "filters":[ - { - "path": "nvflare.app_common.filters.statistics_privacy_filter.StatisticsPrivacyFilter", - "args": { - "result_cleanser_ids": [ - "min_count_cleanser", - "min_max_noise_cleanser", - "hist_bins_cleanser" - ] - } - } - ] - } - ], + "task_data_filters": [], "components": [ { diff --git a/examples/advanced/federated-statistics/df_stats/jobs/df_stats/app/config/config_fed_server.json b/examples/advanced/federated-statistics/df_stats/jobs/df_stats/app/config/config_fed_server.json index 344ed08e4d..58e4b861fb 100644 --- a/examples/advanced/federated-statistics/df_stats/jobs/df_stats/app/config/config_fed_server.json +++ b/examples/advanced/federated-statistics/df_stats/jobs/df_stats/app/config/config_fed_server.json @@ -18,10 +18,14 @@ "bins": 10, "range": [0,120] } + }, + "percentile": { + "*": [25, 50, 75] } }, "writer_id": "stats_writer", - "enable_pre_run_task": false + "enable_pre_run_task": false, + "precision" : 2 } } ], diff --git a/examples/advanced/federated-statistics/df_stats/jobs/df_stats/app/custom/df_statistics.py b/examples/advanced/federated-statistics/df_stats/jobs/df_stats/app/custom/df_statistics.py index 4a87b6c4b9..4277f269e7 100644 --- a/examples/advanced/federated-statistics/df_stats/jobs/df_stats/app/custom/df_statistics.py +++ b/examples/advanced/federated-statistics/df_stats/jobs/df_stats/app/custom/df_statistics.py @@ -12,18 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List, Optional +from typing import Dict, Optional -import numpy as np import pandas as pd -from pandas.core.series import Series from nvflare.apis.fl_context import FLContext -from nvflare.app_common.abstract.statistics_spec import BinRange, Feature, Histogram, HistogramType, Statistics -from nvflare.app_common.statistics.numpy_utils import dtype_to_data_type, get_std_histogram_buckets +from nvflare.app_opt.statistics.df.df_core_statistics import DFStatisticsCore -class DFStatistics(Statistics): +class DFStatistics(DFStatisticsCore): def __init__(self, data_path): super().__init__() self.data_root_dir = "/tmp/nvflare/df_stats/data" @@ -76,67 +73,3 @@ def load_data(self, fl_ctx: FLContext) -> Dict[str, pd.DataFrame]: def initialize(self, fl_ctx: FLContext): self.data = self.load_data(fl_ctx) - if self.data is None: - raise ValueError("data is not loaded. make sure the data is loaded") - - def features(self) -> Dict[str, List[Feature]]: - results: Dict[str, List[Feature]] = {} - for ds_name in self.data: - df = self.data[ds_name] - results[ds_name] = [] - for feature_name in df: - data_type = dtype_to_data_type(df[feature_name].dtype) - results[ds_name].append(Feature(feature_name, data_type)) - - return results - - def count(self, dataset_name: str, feature_name: str) -> int: - df: pd.DataFrame = self.data[dataset_name] - return df[feature_name].count() - - def sum(self, dataset_name: str, feature_name: str) -> float: - df: pd.DataFrame = self.data[dataset_name] - return df[feature_name].sum().item() - - def mean(self, dataset_name: str, feature_name: str) -> float: - - count: int = self.count(dataset_name, feature_name) - sum_value: float = self.sum(dataset_name, feature_name) - return sum_value / count - - def stddev(self, dataset_name: str, feature_name: str) -> float: - df = self.data[dataset_name] - return df[feature_name].std().item() - - def variance_with_mean( - self, dataset_name: str, feature_name: str, global_mean: float, global_count: float - ) -> float: - df = self.data[dataset_name] - tmp = (df[feature_name] - global_mean) * (df[feature_name] - global_mean) - variance = tmp.sum() / (global_count - 1) - return variance.item() - - def histogram( - self, dataset_name: str, feature_name: str, num_of_bins: int, global_min_value: float, global_max_value: float - ) -> Histogram: - - num_of_bins: int = num_of_bins - - df = self.data[dataset_name] - feature: Series = df[feature_name] - flattened = feature.ravel() - flattened = flattened[flattened != np.array(None)] - buckets = get_std_histogram_buckets(flattened, num_of_bins, BinRange(global_min_value, global_max_value)) - return Histogram(HistogramType.STANDARD, buckets) - - def max_value(self, dataset_name: str, feature_name: str) -> float: - """this is needed for histogram calculation, not used for reporting""" - - df = self.data[dataset_name] - return df[feature_name].max() - - def min_value(self, dataset_name: str, feature_name: str) -> float: - """this is needed for histogram calculation, not used for reporting""" - - df = self.data[dataset_name] - return df[feature_name].min() diff --git a/examples/advanced/federated-statistics/df_stats/requirements.txt b/examples/advanced/federated-statistics/df_stats/requirements.txt index aef6212b4c..c766bd7827 100644 --- a/examples/advanced/federated-statistics/df_stats/requirements.txt +++ b/examples/advanced/federated-statistics/df_stats/requirements.txt @@ -1,5 +1,5 @@ -nvflare~=2.5.0rc numpy pandas matplotlib jupyterlab +tdigest diff --git a/examples/advanced/federated-statistics/image_stats/job_api/image_statistics.py b/examples/advanced/federated-statistics/image_stats/job_api/image_statistics.py new file mode 100644 index 0000000000..3bfe2ea61c --- /dev/null +++ b/examples/advanced/federated-statistics/image_stats/job_api/image_statistics.py @@ -0,0 +1,133 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import glob +import os +from typing import Dict, List, Optional + +import numpy as np +from monai.data import ITKReader, load_decathlon_datalist +from monai.transforms import LoadImage + +from nvflare.apis.fl_context import FLContext +from nvflare.app_common.abstract.statistics_spec import Bin, DataType, Feature, Histogram, HistogramType, Statistics +from nvflare.security.logging import secure_log_traceback + + +class ImageStatistics(Statistics): + def __init__(self, data_root: str = "/tmp/nvflare/image_stats/data", data_list_key: str = "data"): + """local image statistics generator . + + Args: + data_root: directory with local image data. + data_list_key: data list key to use. + Returns: + a Shareable with the computed local statistics` + """ + super().__init__() + self.data_list_key = data_list_key + self.data_root = data_root + self.data_list = None + self.client_name = None + + self.loader = None + self.failure_images = 0 + self.fl_ctx = None + + def initialize(self, fl_ctx: FLContext): + self.fl_ctx = fl_ctx + self.client_name = fl_ctx.get_identity_name() + self.loader = LoadImage(image_only=True) + self.loader.register(ITKReader()) + self._load_data_list(self.client_name, fl_ctx) + + if self.data_list is None: + raise ValueError("data is not loaded. make sure the data is loaded") + + def _load_data_list(self, client_name, fl_ctx: FLContext) -> bool: + dataset_json = glob.glob(os.path.join(self.data_root, client_name + "*.json")) + if len(dataset_json) != 1: + self.log_error( + fl_ctx, f"No unique matching dataset list found in {self.data_root} for client {client_name}" + ) + return False + dataset_json = dataset_json[0] + self.log_info(fl_ctx, f"Reading data from {dataset_json}") + + data_list = load_decathlon_datalist( + data_list_file_path=dataset_json, data_list_key=self.data_list_key, base_dir=self.data_root + ) + self.data_list = {"train": data_list} + + self.log_info(fl_ctx, f"Client {client_name} has {len(self.data_list)} images") + return True + + def pre_run( + self, + statistics: List[str], + num_of_bins: Optional[Dict[str, Optional[int]]], + bin_ranges: Optional[Dict[str, Optional[List[float]]]], + ): + return {} + + def features(self) -> Dict[str, List[Feature]]: + return {"train": [Feature("intensity", DataType.FLOAT)]} + + def count(self, dataset_name: str, feature_name: str) -> int: + image_paths = self.data_list[dataset_name] + return len(image_paths) + + def failure_count(self, dataset_name: str, feature_name: str) -> int: + + return self.failure_images + + def histogram( + self, dataset_name: str, feature_name: str, num_of_bins: int, global_min_value: float, global_max_value: float + ) -> Histogram: + histogram_bins: List[Bin] = [] + histogram = np.zeros((num_of_bins,), dtype=np.int64) + bin_edges = [] + for i, entry in enumerate(self.data_list[dataset_name]): + file = entry.get("image") + try: + img = self.loader(file) + curr_histogram, bin_edges = np.histogram( + img, bins=num_of_bins, range=(global_min_value, global_max_value) + ) + histogram += curr_histogram + bin_edges = bin_edges.tolist() + + if i % 100 == 0: + self.logger.info( + f"{self.client_name}, adding {i + 1} of {len(self.data_list[dataset_name])}: {file}" + ) + except Exception as e: + self.failure_images += 1 + self.logger.critical( + f"Failed to load file {file} with exception: {e.__str__()}. " f"Skipping this image..." + ) + + if num_of_bins + 1 != len(bin_edges): + secure_log_traceback() + raise ValueError( + f"bin_edges size: {len(bin_edges)} is not matching with number of bins + 1: {num_of_bins + 1}" + ) + + for j in range(num_of_bins): + low_value = bin_edges[j] + high_value = bin_edges[j + 1] + bin_sample_count = histogram[j] + histogram_bins.append(Bin(low_value=low_value, high_value=high_value, sample_count=bin_sample_count)) + + return Histogram(HistogramType.STANDARD, histogram_bins) diff --git a/examples/advanced/federated-statistics/image_stats/job_api/image_stats_job.py b/examples/advanced/federated-statistics/image_stats/job_api/image_stats_job.py new file mode 100644 index 0000000000..d7c5cf753c --- /dev/null +++ b/examples/advanced/federated-statistics/image_stats/job_api/image_stats_job.py @@ -0,0 +1,64 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse + +from image_statistics import ImageStatistics + +from nvflare.job_config.stats_job import StatsJob + + +def define_parser(): + parser = argparse.ArgumentParser() + parser.add_argument("-n", "--n_clients", type=int, default=3) + parser.add_argument("-d", "--data_root_dir", type=str, nargs="?", default="/tmp/nvflare/dataset/output") + parser.add_argument("-o", "--stats_output_path", type=str, nargs="?", default="statistics/stats.json") + parser.add_argument("-j", "--job_dir", type=str, nargs="?", default="/tmp/nvflare/jobs/stats_df") + parser.add_argument("-w", "--work_dir", type=str, nargs="?", default="/tmp/nvflare/jobs/stats_df/work_dir") + parser.add_argument("-co", "--export_config", action="store_true", help="config only mode, export config") + + return parser.parse_args() + + +def main(): + args = define_parser() + + n_clients = args.n_clients + data_root_dir = args.data_root_dir + output_path = args.stats_output_path + job_dir = args.job_dir + work_dir = args.work_dir + export_config = args.export_config + + statistic_configs = {"count": {}, "histogram": {"*": {"bins": 20, "range": [0, 256]}}} + # define local stats generator + stats_generator = ImageStatistics(data_root_dir) + + job = StatsJob( + job_name="stats_image", + statistic_configs=statistic_configs, + stats_generator=stats_generator, + output_path=output_path, + ) + + sites = [f"site-{i + 1}" for i in range(n_clients)] + job.setup_clients(sites) + + if export_config: + job.export_job(job_dir) + else: + job.simulator_run(work_dir, gpu="0") + + +if __name__ == "__main__": + main() diff --git a/examples/advanced/federated-statistics/image_stats/requirements.txt b/examples/advanced/federated-statistics/image_stats/requirements.txt index 1a77cf1524..227880ff29 100644 --- a/examples/advanced/federated-statistics/image_stats/requirements.txt +++ b/examples/advanced/federated-statistics/image_stats/requirements.txt @@ -1,4 +1,3 @@ -nvflare~=2.5.0rc numpy monai[itk] pandas diff --git a/examples/hello-world/step-by-step/cifar10/stats/image_stats_job.py b/examples/hello-world/step-by-step/cifar10/stats/image_stats_job.py index 35601e839e..6c50493833 100644 --- a/examples/hello-world/step-by-step/cifar10/stats/image_stats_job.py +++ b/examples/hello-world/step-by-step/cifar10/stats/image_stats_job.py @@ -52,7 +52,7 @@ def main(): ) sites = [f"site-{i + 1}" for i in range(n_clients)] - job.setup_client(sites) + job.setup_clients(sites) if export_config: job.export_job(job_dir) diff --git a/examples/hello-world/step-by-step/higgs/stats/code/df_stats_job.py b/examples/hello-world/step-by-step/higgs/stats/code/df_stats_job.py index 151728a54c..39dbed5a61 100644 --- a/examples/hello-world/step-by-step/higgs/stats/code/df_stats_job.py +++ b/examples/hello-world/step-by-step/higgs/stats/code/df_stats_job.py @@ -59,7 +59,7 @@ def main(): ) sites = [f"site-{i + 1}" for i in range(n_clients)] - job.setup_client(sites) + job.setup_clients(sites) if export_config: job.export_job(job_dir) diff --git a/nvflare/apis/fl_constant.py b/nvflare/apis/fl_constant.py index a84a89669c..d03e456ad3 100644 --- a/nvflare/apis/fl_constant.py +++ b/nvflare/apis/fl_constant.py @@ -382,7 +382,7 @@ class WorkspaceConstants: SITE_FOLDER_NAME = "local" CUSTOM_FOLDER_NAME = "custom" - LOGGING_CONFIG = "log.config" + LOGGING_CONFIG = "log_config.json" DEFAULT_LOGGING_CONFIG = LOGGING_CONFIG + ".default" AUDIT_LOG = "audit.log" LOG_FILE_NAME = "log.txt" diff --git a/nvflare/app_common/abstract/statistics_spec.py b/nvflare/app_common/abstract/statistics_spec.py index 3fedc0ead1..b34dd93f68 100644 --- a/nvflare/app_common/abstract/statistics_spec.py +++ b/nvflare/app_common/abstract/statistics_spec.py @@ -318,6 +318,19 @@ def failure_count(self, dataset_name: str, feature_name: str) -> int: """ return 0 + def percentiles(self, dataset_name: str, feature_name: str, percentiles: List) -> Dict: + """Return failed count for given dataset and feature. + + To perform data privacy min_count check, failure_count is always required. + + Args: + dataset_name: + feature_name: + percentiles: List[Int] ex [25,50, 75] corresponding to p25, p50, p75 + Returns: dict + """ + raise NotImplementedError + def finalize(self, fl_ctx: FLContext): """Called to finalize the Statistic calculator (close/release resources gracefully). diff --git a/nvflare/app_common/app_constant.py b/nvflare/app_common/app_constant.py index 98b8bc21d6..a7c34b8444 100644 --- a/nvflare/app_common/app_constant.py +++ b/nvflare/app_common/app_constant.py @@ -163,6 +163,7 @@ class StatisticsConstants(AppConstants): STATS_VAR = "var" STATS_STDDEV = "stddev" STATS_HISTOGRAM = "histogram" + STATS_PERCENTILE = "percentile" STATS_MAX = "max" STATS_MIN = "min" STATS_FEATURES = "stats_features" @@ -173,6 +174,9 @@ class StatisticsConstants(AppConstants): STATS_BIN_RANGE = "range" STATS_TARGET_STATISTICS = "statistics" + STATS_PERCENTILES_KEY = "percentiles" + STATS_CENTROIDS_KEY = "centroids" + FED_STATS_PRE_RUN = "fed_stats_pre_run" FED_STATS_TASK = "fed_stats" STATISTICS_TASK_KEY = "fed_stats_task_key" @@ -184,7 +188,16 @@ class StatisticsConstants(AppConstants): NAME = "Name" ordered_statistics = { - STATS_1st_STATISTICS: [STATS_COUNT, STATS_FAILURE_COUNT, STATS_SUM, STATS_MEAN, STATS_MIN, STATS_MAX], + # statistics can only require one/two-round of calculations + STATS_1st_STATISTICS: [ + STATS_COUNT, + STATS_FAILURE_COUNT, + STATS_SUM, + STATS_MEAN, + STATS_MIN, + STATS_MAX, + STATS_PERCENTILE, + ], STATS_2nd_STATISTICS: [STATS_HISTOGRAM, STATS_VAR, STATS_STDDEV], } diff --git a/nvflare/app_common/executors/statistics/statistics_task_handler.py b/nvflare/app_common/executors/statistics/statistics_task_handler.py index 1e282faac0..e3b1bb45a7 100644 --- a/nvflare/app_common/executors/statistics/statistics_task_handler.py +++ b/nvflare/app_common/executors/statistics/statistics_task_handler.py @@ -23,7 +23,7 @@ from nvflare.app_common.app_constant import StatisticsConstants as StC from nvflare.app_common.statistics.numeric_stats import filter_numeric_features from nvflare.app_common.statistics.statisitcs_objects_decomposer import fobs_registration -from nvflare.app_common.statistics.statistics_config_utils import get_feature_bin_range +from nvflare.app_common.statistics.statistics_config_utils import get_feature_bin_range, get_target_percents from nvflare.fuel.utils import fobs from nvflare.security.logging import secure_format_exception @@ -96,6 +96,7 @@ def statistic_functions(self) -> dict: StC.STATS_HISTOGRAM: self.get_histogram, StC.STATS_MAX: self.get_max_value, StC.STATS_MIN: self.get_min_value, + StC.STATS_PERCENTILE: self.get_percentiles_and_centroids, } def _populate_result_statistics(self, statistics_result, ds_features, tm: StatisticConfig, shareable, fl_ctx, fn): @@ -318,6 +319,19 @@ def get_bin_range( return bin_range + def get_percentiles_and_centroids( + self, + dataset_name: str, + feature_name: str, + statistic_configs: StatisticConfig, + inputs: Shareable, + fl_ctx: FLContext, + ) -> dict: + percentile_config = statistic_configs.config + target_percents = get_target_percents(percentile_config, feature_name) + result = self.stats_generator.percentiles(dataset_name, feature_name, target_percents) + return result + def _get_global_value_from_input(self, statistic_key: str, dataset_name: str, feature_name: str, inputs): global_value = None if dataset_name in inputs[statistic_key]: diff --git a/nvflare/app_common/statistics/numeric_stats.py b/nvflare/app_common/statistics/numeric_stats.py index d9e7988775..046f0ce0b4 100644 --- a/nvflare/app_common/statistics/numeric_stats.py +++ b/nvflare/app_common/statistics/numeric_stats.py @@ -15,8 +15,11 @@ from math import sqrt from typing import Dict, List, TypeVar +from tdigest import TDigest + from nvflare.app_common.abstract.statistics_spec import Bin, BinRange, DataType, Feature, Histogram, HistogramType from nvflare.app_common.app_constant import StatisticsConstants as StC +from nvflare.app_common.statistics.statistics_config_utils import get_target_percents T = TypeVar("T") @@ -37,7 +40,9 @@ def get_global_feature_data_types( return global_feature_data_types -def get_global_stats(global_metrics: dict, client_metrics: dict, metric_task: str) -> dict: +def get_global_stats( + global_metrics: dict, client_metrics: dict, metric_task: str, statistic_configs: Dict[str, dict], precision: int = 4 +) -> dict: # we need to calculate the metrics in specified order ordered_target_metrics = StC.ordered_statistics[metric_task] ordered_metrics = [metric for metric in ordered_target_metrics if metric in client_metrics] @@ -49,21 +54,27 @@ def get_global_stats(global_metrics: dict, client_metrics: dict, metric_task: st stats = client_metrics[metric] if metric == StC.STATS_COUNT or metric == StC.STATS_FAILURE_COUNT or metric == StC.STATS_SUM: for client_name in stats: - global_metrics[metric] = accumulate_metrics(stats[client_name], global_metrics[metric]) + global_metrics[metric] = accumulate_metrics(stats[client_name], global_metrics[metric], precision) elif metric == StC.STATS_MEAN: - global_metrics[metric] = get_means(global_metrics[StC.STATS_SUM], global_metrics[StC.STATS_COUNT]) + global_metrics[metric] = get_means( + global_metrics[StC.STATS_SUM], global_metrics[StC.STATS_COUNT], precision + ) elif metric == StC.STATS_MAX: for client_name in stats: - global_metrics[metric] = get_min_or_max_values(stats[client_name], global_metrics[metric], max) + global_metrics[metric] = get_min_or_max_values( + stats[client_name], global_metrics[metric], max, precision + ) elif metric == StC.STATS_MIN: for client_name in stats: - global_metrics[metric] = get_min_or_max_values(stats[client_name], global_metrics[metric], min) + global_metrics[metric] = get_min_or_max_values( + stats[client_name], global_metrics[metric], min, precision + ) elif metric == StC.STATS_HISTOGRAM: for client_name in stats: global_metrics[metric] = accumulate_hists(stats[client_name], global_metrics[metric]) elif metric == StC.STATS_VAR: for client_name in stats: - global_metrics[metric] = accumulate_metrics(stats[client_name], global_metrics[metric]) + global_metrics[metric] = accumulate_metrics(stats[client_name], global_metrics[metric], precision) elif metric == StC.STATS_STDDEV: ds_vars = global_metrics[StC.STATS_VAR] ds_stddev = {} @@ -71,14 +82,22 @@ def get_global_stats(global_metrics: dict, client_metrics: dict, metric_task: st ds_stddev[ds_name] = {} feature_vars = ds_vars[ds_name] for feature in feature_vars: - ds_stddev[ds_name][feature] = sqrt(feature_vars[feature]) + ds_stddev[ds_name][feature] = round(sqrt(feature_vars[feature]), precision) global_metrics[StC.STATS_STDDEV] = ds_stddev + elif metric == StC.STATS_PERCENTILE: + global_digest = {} + for client_name in stats: + + global_digest = aggregate_centroids(stats[client_name], global_digest) + + percent_config = statistic_configs.get(StC.STATS_PERCENTILE) + global_metrics[metric] = compute_percentiles(global_digest, percent_config, precision) return global_metrics -def accumulate_metrics(metrics: dict, global_metrics: dict) -> dict: +def accumulate_metrics(metrics: dict, global_metrics: dict, precision: int) -> dict: for ds_name in metrics: if ds_name not in global_metrics: global_metrics[ds_name] = {} @@ -87,14 +106,16 @@ def accumulate_metrics(metrics: dict, global_metrics: dict) -> dict: for feature_name in feature_metrics: if feature_metrics[feature_name] is not None: if feature_name not in global_metrics[ds_name]: - global_metrics[ds_name][feature_name] = feature_metrics[feature_name] + global_metrics[ds_name][feature_name] = round(feature_metrics[feature_name], precision) else: - global_metrics[ds_name][feature_name] += feature_metrics[feature_name] + global_metrics[ds_name][feature_name] = round( + global_metrics[ds_name][feature_name] + feature_metrics[feature_name], precision + ) return global_metrics -def get_min_or_max_values(metrics: dict, global_metrics: dict, fn2) -> dict: +def get_min_or_max_values(metrics: dict, global_metrics: dict, fn2, precision: int = 4) -> dict: """Use 2 argument function to calculate fn2(global, client), for example, min or max. .. note:: @@ -105,6 +126,7 @@ def get_min_or_max_values(metrics: dict, global_metrics: dict, fn2) -> dict: metrics: client's metric global_metrics: global metrics fn2: two-argument function such as min or max + precision: decimal number precision Returns: Dict[dataset, Dict[feature, int]] @@ -116,19 +138,21 @@ def get_min_or_max_values(metrics: dict, global_metrics: dict, fn2) -> dict: feature_metrics = metrics[ds_name] for feature_name in feature_metrics: if feature_name not in global_metrics[ds_name]: - global_metrics[ds_name][feature_name] = feature_metrics[feature_name] + global_metrics[ds_name][feature_name] = round(feature_metrics[feature_name], precision) else: - global_metrics[ds_name][feature_name] = fn2( - global_metrics[ds_name][feature_name], feature_metrics[feature_name] + global_metrics[ds_name][feature_name] = round( + fn2(global_metrics[ds_name][feature_name], feature_metrics[feature_name]), precision ) results = {} for ds_name in global_metrics: for feature_name in global_metrics[ds_name]: if feature_name not in results: - results[feature_name] = global_metrics[ds_name][feature_name] + results[feature_name] = round(global_metrics[ds_name][feature_name], precision) else: - results[feature_name] = fn2(results[feature_name], global_metrics[ds_name][feature_name]) + results[feature_name] = round( + fn2(results[feature_name], global_metrics[ds_name][feature_name]), precision + ) for ds_name in global_metrics: for feature_name in global_metrics[ds_name]: @@ -146,7 +170,7 @@ def bins_to_dict(bins: List[Bin]) -> Dict[BinRange, float]: def accumulate_hists( - metrics: Dict[str, Dict[str, Histogram]], global_hists: Dict[str, Dict[str, Histogram]] + metrics: Dict[str, Dict[str, Histogram]], global_hists: Dict[str, Dict[str, Histogram]], precision: int = 4 ) -> Dict[str, Dict[str, Histogram]]: for ds_name in metrics: feature_hists = metrics[ds_name] @@ -158,14 +182,18 @@ def accumulate_hists( if feature not in global_hists[ds_name]: g_bins = [] for bucket in hist.bins: - g_bins.append(Bin(bucket.low_value, bucket.high_value, bucket.sample_count)) + g_bins.append( + Bin( + round(bucket.low_value, precision), round(bucket.high_value, precision), bucket.sample_count + ) + ) g_hist = Histogram(HistogramType.STANDARD, g_bins) global_hists[ds_name][feature] = g_hist else: g_hist = global_hists[ds_name][feature] g_buckets = bins_to_dict(g_hist.bins) for bucket in hist.bins: - bin_range = BinRange(bucket.low_value, bucket.high_value) + bin_range = BinRange(round(bucket.low_value, precision), round(bucket.high_value, precision)) if bin_range in g_buckets: g_buckets[bin_range] += bucket.sample_count else: @@ -174,22 +202,24 @@ def accumulate_hists( # update ordered bins updated_bins = [] for gb in g_hist.bins: - bin_range = BinRange(gb.low_value, gb.high_value) - updated_bins.append(Bin(gb.low_value, gb.high_value, g_buckets[bin_range])) + bin_range = BinRange(round(gb.low_value, precision), round(gb.high_value, precision)) + updated_bins.append( + Bin(round(gb.low_value, precision), round(gb.high_value, precision), g_buckets[bin_range]) + ) global_hists[ds_name][feature] = Histogram(g_hist.hist_type, updated_bins) return global_hists -def get_means(sums: dict, counts: dict) -> dict: +def get_means(sums: dict, counts: dict, precision: int = 4) -> dict: means = {} for ds_name in sums: means[ds_name] = {} feature_sums = sums[ds_name] feature_counts = counts[ds_name] for feature in feature_sums: - means[ds_name][feature] = feature_sums[feature] / feature_counts[feature] + means[ds_name][feature] = round(feature_sums[feature] / feature_counts[feature], precision) return means @@ -201,3 +231,42 @@ def filter_numeric_features(ds_features: Dict[str, List[Feature]]) -> Dict[str, numeric_ds_features[ds_name] = n_features return numeric_ds_features + + +def aggregate_centroids(metrics: Dict[str, Dict[str, Dict]], g_digest: dict) -> dict: + for ds_name in metrics: + if ds_name not in g_digest: + g_digest[ds_name] = {} + + feature_metrics = metrics[ds_name] + for feature_name in feature_metrics: + if feature_metrics[feature_name] is not None: + centroids: List = feature_metrics[feature_name].get(StC.STATS_CENTROIDS_KEY) + if feature_name not in g_digest[ds_name]: + g_digest[ds_name][feature_name] = TDigest() + + for centroid in centroids: + mean = centroid.get("m") + count = centroid.get("c") + g_digest[ds_name][feature_name].update(mean, count) + + return g_digest + + +def compute_percentiles(g_digest: Dict[str, Dict[str, TDigest]], quantile_config: Dict, precision: int = 4) -> dict: + g_ds_metrics = {} + for ds_name in g_digest: + if ds_name not in g_ds_metrics: + g_ds_metrics[ds_name] = {} + + feature_metrics = g_digest[ds_name] + for feature_name in feature_metrics: + digest = feature_metrics[feature_name] + percentiles = get_target_percents(quantile_config, feature_name) + percentile_values = {} + for percentile in percentiles: + percentile_values[percentile] = round(digest.percentile(percentile), precision) + + g_ds_metrics[ds_name][feature_name] = percentile_values + + return g_ds_metrics diff --git a/nvflare/app_common/statistics/statistics_config_utils.py b/nvflare/app_common/statistics/statistics_config_utils.py index 2c2efb76c0..b128df2394 100644 --- a/nvflare/app_common/statistics/statistics_config_utils.py +++ b/nvflare/app_common/statistics/statistics_config_utils.py @@ -28,3 +28,14 @@ def get_feature_bin_range(feature_name: str, hist_config: dict) -> Optional[List bin_range = default_config[StC.STATS_BIN_RANGE] return bin_range + + +def get_target_percents(percentile_config: dict, feature_name: str): + if feature_name in percentile_config: + percents = percentile_config.get(feature_name) + elif "*" in percentile_config: + percents = percentile_config.get("*") + else: + raise ValueError(f"feature: {feature_name} target percents are not defined.") + + return percents diff --git a/nvflare/app_common/tie/py_applet.py b/nvflare/app_common/tie/py_applet.py index 148b415871..406bb61307 100644 --- a/nvflare/app_common/tie/py_applet.py +++ b/nvflare/app_common/tie/py_applet.py @@ -12,14 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. import multiprocessing -import os import sys import threading import time from abc import ABC, abstractmethod from nvflare.apis.workspace import Workspace -from nvflare.fuel.utils.log_utils import add_log_file_handler, configure_logging +from nvflare.fuel.utils.log_utils import configure_logging from nvflare.security.logging import secure_format_exception, secure_log_traceback from .applet import Applet @@ -94,9 +93,7 @@ def start(self, app_ctx: dict): if not self.in_process: # enable logging run_dir = self.workspace.get_run_dir(self.job_id) - log_file_name = os.path.join(run_dir, "applet_log.txt") - configure_logging(self.workspace) - add_log_file_handler(log_file_name) + configure_logging(self.workspace, dir_path=run_dir, file_prefix="applet") self.runner.start(app_ctx) # Note: run_func does not return until it runs to completion! diff --git a/nvflare/app_common/workflows/statistics_controller.py b/nvflare/app_common/workflows/statistics_controller.py index 53fdfbf552..a835c95900 100644 --- a/nvflare/app_common/workflows/statistics_controller.py +++ b/nvflare/app_common/workflows/statistics_controller.py @@ -69,6 +69,10 @@ def __init__( "histogram": { "*": {"bins": 20}, "Age": {"bins": 10, "range": [0, 120]} + }, + percentile: { + "*": [25, 50, 75, 90], + "Age": [50, 75, 95] } }, @@ -207,6 +211,7 @@ def _get_all_statistic_configs(self) -> List[StatisticConfig]: StC.STATS_MEAN: StatisticConfig(StC.STATS_MEAN, {}), StC.STATS_VAR: StatisticConfig(StC.STATS_VAR, {}), StC.STATS_STDDEV: StatisticConfig(StC.STATS_STDDEV, {}), + StC.STATS_PERCENTILE: StatisticConfig(StC.STATS_PERCENTILE, {}), } if StC.STATS_HISTOGRAM in self.statistic_configs: @@ -264,7 +269,9 @@ def statistics_task_flow(self, abort_signal: Signal, fl_ctx: FLContext, statisti abort_signal=abort_signal, ) - self.global_statistics = get_global_stats(self.global_statistics, self.client_statistics, statistic_task) + self.global_statistics = get_global_stats( + self.global_statistics, self.client_statistics, statistic_task, self.statistic_configs, self.precision + ) self.log_info(fl_ctx, f"task {self.task_name} statistics_flow for {statistic_task} flow end.") @@ -402,12 +409,19 @@ def _combine_all_statistics(self): hist: Histogram = self.client_statistics[statistic][client][ds][feature_name] buckets = StatisticsController._apply_histogram_precision(hist.bins, self.precision) result[feature_name][statistic][client][ds] = buckets + elif statistic == StC.STATS_PERCENTILE: + percentiles = self.client_statistics[statistic][client][ds][feature_name][ + StC.STATS_PERCENTILES_KEY + ] + formatted_percentiles = {} + for p in percentiles: + formatted_percentiles[p] = round(percentiles.get(p), self.precision) + result[feature_name][statistic][client][ds] = formatted_percentiles else: result[feature_name][statistic][client][ds] = round( self.client_statistics[statistic][client][ds][feature_name], self.precision ) - precision = self.precision for statistic in filtered_global_statistics: for ds in self.global_statistics[statistic]: for feature_name in self.global_statistics[statistic][ds]: @@ -419,11 +433,13 @@ def _combine_all_statistics(self): if statistic == StC.STATS_HISTOGRAM: hist: Histogram = self.global_statistics[statistic][ds][feature_name] - buckets = StatisticsController._apply_histogram_precision(hist.bins, self.precision) - result[feature_name][statistic][StC.GLOBAL][ds] = buckets + result[feature_name][statistic][StC.GLOBAL][ds] = hist.bins + elif statistic == StC.STATS_PERCENTILE: + percentiles = self.global_statistics[statistic][ds][feature_name] + result[feature_name][statistic][StC.GLOBAL][ds] = percentiles else: result[feature_name][statistic][StC.GLOBAL].update( - {ds: round(self.global_statistics[statistic][ds][feature_name], precision)} + {ds: self.global_statistics[statistic][ds][feature_name]} ) return result @@ -444,9 +460,10 @@ def _apply_histogram_precision(bins: List[Bin], precision) -> List[Bin]: @staticmethod def _get_target_statistics(statistic_configs: dict, ordered_statistics: list) -> List[StatisticConfig]: # make sure the execution order of the statistics calculation + targets = [] if statistic_configs: - for statistic in statistic_configs: + for metric in statistic_configs: # if target statistic has histogram, we are not in 2nd statistic task # we only need to estimate the global min/max if we have histogram statistic, # If the user provided the global min/max for a specified feature, then we do nothing @@ -457,16 +474,16 @@ def _get_target_statistics(statistic_configs: dict, ordered_statistics: list) -> # in all cases, we will still send the STATS_MIN/MAX tasks, but client executor may or may not # delegate to stats generator to calculate the local min/max depends on if the global bin ranges # are specified. to do this, we send over the histogram configuration when calculate the local min/max - if statistic == StC.STATS_HISTOGRAM and statistic not in ordered_statistics: + if metric == StC.STATS_HISTOGRAM and metric not in ordered_statistics: targets.append(StatisticConfig(StC.STATS_MIN, statistic_configs[StC.STATS_HISTOGRAM])) targets.append(StatisticConfig(StC.STATS_MAX, statistic_configs[StC.STATS_HISTOGRAM])) - if statistic == StC.STATS_STDDEV and statistic in ordered_statistics: + if metric == StC.STATS_STDDEV and metric in ordered_statistics: targets.append(StatisticConfig(StC.STATS_VAR, {})) for rm in ordered_statistics: - if rm == statistic: - targets.append(StatisticConfig(statistic, statistic_configs[statistic])) + if rm == metric: + targets.append(StatisticConfig(metric, statistic_configs[metric])) return targets def _prepare_inputs(self, statistic_task: str) -> Shareable: diff --git a/nvflare/app_opt/statistics/df/__init__.py b/nvflare/app_opt/statistics/df/__init__.py new file mode 100644 index 0000000000..d9155f923f --- /dev/null +++ b/nvflare/app_opt/statistics/df/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nvflare/app_opt/statistics/df/df_core_statistics.py b/nvflare/app_opt/statistics/df/df_core_statistics.py new file mode 100644 index 0000000000..5f509f0fe8 --- /dev/null +++ b/nvflare/app_opt/statistics/df/df_core_statistics.py @@ -0,0 +1,116 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from abc import ABC +from typing import Dict, List, Optional + +import numpy as np +import pandas as pd +from pandas.core.series import Series +from tdigest import TDigest + +from nvflare.app_common.abstract.statistics_spec import BinRange, Feature, Histogram, HistogramType, Statistics +from nvflare.app_common.app_constant import StatisticsConstants +from nvflare.app_common.statistics.numpy_utils import dtype_to_data_type, get_std_histogram_buckets + + +class DFStatisticsCore(Statistics, ABC): + + def __init__(self): + # assumption: the data can be loaded and cached in the memory + self.data: Optional[Dict[str, pd.DataFrame]] = None + super(DFStatisticsCore, self).__init__() + + def features(self) -> Dict[str, List[Feature]]: + results: Dict[str, List[Feature]] = {} + for ds_name in self.data: + df = self.data[ds_name] + results[ds_name] = [] + for feature_name in df: + data_type = dtype_to_data_type(df[feature_name].dtype) + results[ds_name].append(Feature(feature_name, data_type)) + + return results + + def count(self, dataset_name: str, feature_name: str) -> int: + df: pd.DataFrame = self.data[dataset_name] + return df[feature_name].count() + + def sum(self, dataset_name: str, feature_name: str) -> float: + df: pd.DataFrame = self.data[dataset_name] + return df[feature_name].sum().item() + + def mean(self, dataset_name: str, feature_name: str) -> float: + + count: int = self.count(dataset_name, feature_name) + sum_value: float = self.sum(dataset_name, feature_name) + return sum_value / count + + def stddev(self, dataset_name: str, feature_name: str) -> float: + df = self.data[dataset_name] + return df[feature_name].std().item() + + def variance_with_mean( + self, dataset_name: str, feature_name: str, global_mean: float, global_count: float + ) -> float: + df = self.data[dataset_name] + tmp = (df[feature_name] - global_mean) * (df[feature_name] - global_mean) + variance = tmp.sum() / (global_count - 1) + return variance.item() + + def histogram( + self, dataset_name: str, feature_name: str, num_of_bins: int, global_min_value: float, global_max_value: float + ) -> Histogram: + + num_of_bins: int = num_of_bins + + df = self.data[dataset_name] + feature: Series = df[feature_name] + flattened = feature.ravel() + flattened = flattened[flattened != np.array(None)] + buckets = get_std_histogram_buckets(flattened, num_of_bins, BinRange(global_min_value, global_max_value)) + return Histogram(HistogramType.STANDARD, buckets) + + def max_value(self, dataset_name: str, feature_name: str) -> float: + """this is needed for histogram calculation, not used for reporting""" + + df = self.data[dataset_name] + return df[feature_name].max() + + def min_value(self, dataset_name: str, feature_name: str) -> float: + """this is needed for histogram calculation, not used for reporting""" + + df = self.data[dataset_name] + return df[feature_name].min() + + def percentiles(self, dataset_name: str, feature_name: str, percents: List) -> Dict: + digest = self._prepare_t_digest(dataset_name, feature_name) + results = {} + p_results = {} + for p in percents: + v = round(digest.percentile(p), 4) + p_results[p] = v + results[StatisticsConstants.STATS_PERCENTILES_KEY] = p_results + + # Extract centroids (mean, count) from the digest to used for merge for the global + x = digest.centroids_to_list() + results[StatisticsConstants.STATS_CENTROIDS_KEY] = x + return results + + def _prepare_t_digest(self, dataset_name: str, feature_name: str) -> TDigest: + df = self.data[dataset_name] + data = df[feature_name] + digest = TDigest() + for value in data: + digest.update(value) + return digest diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/adaptors/adaptor.py b/nvflare/app_opt/xgboost/histogram_based_v2/adaptors/adaptor.py index f252ffdb92..64c9be5f22 100644 --- a/nvflare/app_opt/xgboost/histogram_based_v2/adaptors/adaptor.py +++ b/nvflare/app_opt/xgboost/histogram_based_v2/adaptors/adaptor.py @@ -27,7 +27,7 @@ from nvflare.apis.workspace import Workspace from nvflare.app_opt.xgboost.histogram_based_v2.defs import Constant from nvflare.app_opt.xgboost.histogram_based_v2.runners.xgb_runner import AppRunner -from nvflare.fuel.utils.log_utils import add_log_file_handler, configure_logging, get_obj_logger +from nvflare.fuel.utils.log_utils import configure_logging, get_obj_logger from nvflare.fuel.utils.validation_utils import check_object_type from nvflare.security.logging import secure_format_exception, secure_log_traceback @@ -65,8 +65,7 @@ def start(self, ctx: dict): run_dir = self.workspace.get_run_dir(self.job_id) log_file_name = os.path.join(run_dir, f"{self.app_name}_log.txt") print(f"XGB Log: {log_file_name}") - configure_logging(self.workspace) - add_log_file_handler(log_file_name) + configure_logging(self.workspace, dir_path=run_dir, file_prefix=self.app_name) self.runner.run(ctx) self.stopped = True except Exception as e: diff --git a/nvflare/client/ex_process/api.py b/nvflare/client/ex_process/api.py index d3688c0bfe..a8c9542663 100644 --- a/nvflare/client/ex_process/api.py +++ b/nvflare/client/ex_process/api.py @@ -26,7 +26,6 @@ from nvflare.client.flare_agent import FlareAgentException from nvflare.client.flare_agent_with_fl_model import FlareAgentWithFLModel from nvflare.client.model_registry import ModelRegistry -from nvflare.fuel.utils import fobs from nvflare.fuel.utils.import_utils import optional_import from nvflare.fuel.utils.log_utils import get_obj_logger from nvflare.fuel.utils.pipe.pipe import Pipe @@ -52,14 +51,6 @@ def _create_pipe_using_config(client_config: ClientConfig, section: str) -> Tupl return pipe, pipe_channel_name -def _register_tensor_decomposer(): - tensor_decomposer, ok = optional_import(module="nvflare.app_opt.pt.decomposers", name="TensorDecomposer") - if ok: - fobs.register(tensor_decomposer) - else: - raise RuntimeError(f"Can't import TensorDecomposer for format: {ExchangeFormat.PYTORCH}") - - class ExProcessClientAPI(APISpec): def __init__(self): self.process_model_registry = None @@ -93,8 +84,12 @@ def init(self, rank: Optional[str] = None): flare_agent = None try: if rank == "0": - if client_config.get_exchange_format() == ExchangeFormat.PYTORCH: - _register_tensor_decomposer() + if client_config.get_exchange_format() in [ExchangeFormat.PYTORCH, ExchangeFormat.NUMPY]: + # both numpy and pytorch exchange format can need tensor decomposer + # import here, and register later when needed + _, ok = optional_import(module="nvflare.app_opt.pt.decomposers", name="TensorDecomposer") + if not ok: + raise RuntimeError("Can't import TensorDecomposer") pipe, task_channel_name = None, "" if ConfigKey.TASK_EXCHANGE in client_config.config: diff --git a/nvflare/fuel/utils/fobs/fobs.py b/nvflare/fuel/utils/fobs/fobs.py index 7d0bb00a70..7a69ea17e7 100644 --- a/nvflare/fuel/utils/fobs/fobs.py +++ b/nvflare/fuel/utils/fobs/fobs.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import builtins import importlib import inspect import logging @@ -66,15 +67,12 @@ def _get_type_name(cls: Type) -> str: def _load_class(type_name: str): try: - parts = type_name.split(".") - if len(parts) == 1: - parts = ["builtins", type_name] - - mod = __import__(parts[0]) - for comp in parts[1:]: - mod = getattr(mod, comp) - - return mod + if "." in type_name: + module_name, class_name = type_name.rsplit(".", 1) + module = importlib.import_module(module_name) + return getattr(module, class_name) + else: + return getattr(builtins, type_name) except Exception as ex: raise TypeError(f"Can't load class {type_name}: {ex}") diff --git a/nvflare/fuel/utils/log_utils.py b/nvflare/fuel/utils/log_utils.py index 77a855e126..e889d22759 100644 --- a/nvflare/fuel/utils/log_utils.py +++ b/nvflare/fuel/utils/log_utils.py @@ -12,9 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. import inspect +import json import logging import logging.config import os +import re from logging import Logger from logging.handlers import RotatingFileHandler @@ -22,35 +24,61 @@ class ANSIColor: - GREY = "38" - YELLOW = "33" - RED = "31" - BOLD_RED = "31;1" - CYAN = "36" - RESET = "0" + # Basic ANSI color codes + COLORS = { + "black": "30", + "red": "31", + "bold_red": "31;1", + "green": "32", + "yellow": "33", + "blue": "34", + "magenta": "35", + "cyan": "36", + "white": "37", + "grey": "38", + "reset": "0", + } + + # Default logger level:color mappings + DEFAULT_LEVEL_COLORS = { + "NOTSET": COLORS["grey"], + "DEBUG": COLORS["grey"], + "INFO": COLORS["grey"], + "WARNING": COLORS["yellow"], + "ERROR": COLORS["red"], + "CRITICAL": COLORS["bold_red"], + } + + @classmethod + def colorize(cls, text: str, color: str) -> str: + """Wrap text with the given ANSI SGR color. + Args: + text (str): text to colorize. + color (str): ANSI SGR color code or color name defined in ANSIColor.COLORS. + + Returns: + colorized text + """ + if not any(c.isdigit() for c in color): + color = cls.COLORS.get(color.lower(), cls.COLORS["reset"]) -DEFAULT_LEVEL_COLORS = { - "DEBUG": ANSIColor.GREY, - "INFO": ANSIColor.GREY, - "WARNING": ANSIColor.YELLOW, - "ERROR": ANSIColor.RED, - "CRITICAL": ANSIColor.BOLD_RED, -} + return f"\x1b[{color}m{text}\x1b[{cls.COLORS['reset']}m" class BaseFormatter(logging.Formatter): def __init__(self, fmt="%(asctime)s - %(name)s - %(levelname)s - %(message)s", datefmt=None, style="%"): - """BaseFormatter is the default formatter for log records. + """Default formatter for log records. - Shortens logger %(name)s to the suffix, full name can be accessed with %(fullName)s + Shortens logger %(name)s to the basenames. Full name can be accessed with %(fullName)s Args: - fmt: format string which uses LogRecord attributes. - datefmt: date/time format string. Defaults to '%Y-%m-%d %H:%M:%S'. - style: style character '%' '{' or '$' for format string. + fmt (str): format string which uses LogRecord attributes. + datefmt (str): date/time format string. Defaults to '%Y-%m-%d %H:%M:%S'. + style (str): style character '%' '{' or '$' for format string. """ + self.fmt = fmt super().__init__(fmt=fmt, datefmt=datefmt, style=style) def format(self, record): @@ -61,9 +89,127 @@ def format(self, record): return super().format(record) -def ansi_sgr(code): - # ANSI Select Graphics Rendition - return "\x1b[" + code + "m" +class ColorFormatter(BaseFormatter): + def __init__( + self, + fmt="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + datefmt=None, + style="%", + level_colors=ANSIColor.DEFAULT_LEVEL_COLORS, + logger_colors={}, + ): + """Format colors based on log levels. Optionally can provide mapping based on logger namess. + + Args: + fmt (str): format string which uses LogRecord attributes. + datefmt (str): date/time format string. Defaults to '%Y-%m-%d %H:%M:%S'. + style (str): style character '%' '{' or '$' for format string. + level_colors (Dict[str, str]): dict of levelname: ANSI color. Defaults to ANSIColor.DEFAULT_LEVEL_COLORS. + logger_colors (Dict[str, str]): dict of loggername: ANSI color. Defaults to {}. + + """ + super().__init__(fmt=fmt, datefmt=datefmt, style=style) + self.level_colors = level_colors + self.logger_colors = logger_colors + + def format(self, record): + super().format(record) + + # Apply level_colors based on record levelname + log_color = self.level_colors.get(record.levelname, "reset") + + # Apply logger_color to logger_names if INFO or below + if record.levelno <= logging.INFO: + log_color = self.logger_colors.get(record.name, log_color) + + log_fmt = ANSIColor.colorize(self.fmt, log_color) + + formatter = logging.Formatter(log_fmt) + return formatter.format(record) + + +class JsonFormatter(BaseFormatter): + def __init__( + self, + fmt="%(asctime)s - %(name)s - %(fullName)s - %(levelname)s - %(message)s", + datefmt=None, + style="%", + extract_brackets=True, + ): + """Format log records into JSON. + + Args: + fmt (str): format string which uses LogRecord attributes. Attributes are used for JSON keys. + datefmt (str): date/time format string. Defaults to '%Y-%m-%d %H:%M:%S'. + style (str): style character '%' '{' or '$' for format string. + extract_bracket_fields (bool): whether to extract bracket fields of message into sub-dictionary. Defaults to True. + + """ + super().__init__(fmt=fmt, datefmt=datefmt, style=style) + self.fmt_dict = self.generate_fmt_dict(self.fmt) + self.extract_brackets = extract_brackets + + def generate_fmt_dict(self, fmt: str) -> dict: + # Parse the `fmt` string and create a mapping of keys to LogRecord attributes + matches = re.findall(r"%\((.*?)\)([sd])", fmt) + + fmt_dict = {} + for key, _ in matches: + if key == "shortname": + fmt_dict["name"] = "shortname" + else: + fmt_dict[key] = key + + return fmt_dict + + def extract_bracket_fields(self, message: str) -> dict: + # Extract bracketed fl_ctx_fields eg. [k1=v1, k2=v2...] into sub-dictionary + bracket_fields = {} + match = re.search(r"\[(.*?)\]:", message) + if match: + pairs = match.group(1).split(", ") + for pair in pairs: + if "=" in pair: + key, value = pair.split("=", 1) + bracket_fields[key] = value + return bracket_fields + + def formatMessage(self, record) -> dict: + return {fmt_key: record.__dict__.get(fmt_val, "") for fmt_key, fmt_val in self.fmt_dict.items()} + + def format(self, record) -> str: + super().format(record) + + record.message = record.getMessage() + bracket_fields = self.extract_bracket_fields(record.message) if self.extract_brackets else None + record.asctime = self.formatTime(record) + + formatted_message_dict = self.formatMessage(record) + message_dict = {k: v for k, v in formatted_message_dict.items() if k != "message"} + + if bracket_fields: + message_dict["fl_ctx_fields"] = bracket_fields + record.message = re.sub(r"\[.*?\]:", "", record.message).strip() + + message_dict[self.fmt_dict.get("message", "message")] = record.message + + return json.dumps(message_dict, default=str) + + +class LoggerNameFilter(logging.Filter): + def __init__(self, logger_names=["nvflare"]): + """Filter log records based on logger names. + + Args: + logger_names (List[str]): list of logger names to allow through filter (inclusive) + + """ + super().__init__() + self.logger_names = logger_names + + def filter(self, record): + name = record.fullName if hasattr(record, "fullName") else record.name + return any(name.startswith(logger_name) for logger_name in self.logger_names) def get_module_logger(module=None, name=None): @@ -79,19 +225,41 @@ def get_obj_logger(obj): def get_script_logger(): + # Get script logger name based on filename and package. If not in a package, default to custom. caller_frame = inspect.stack()[1] package = caller_frame.frame.f_globals.get("__package__", "") file = caller_frame.frame.f_globals.get("__file__", "") return logging.getLogger( - f"{package + '.' if package else ''}{os.path.splitext(os.path.basename(file))[0] if file else ''}" + f"{package if package else 'custom'}{'.' + os.path.splitext(os.path.basename(file))[0] if file else ''}" ) -def configure_logging(workspace: Workspace): +def configure_logging(workspace: Workspace, dir_path: str = "", file_prefix: str = ""): + # Read log_config.json from workspace, update with file_prefix, and apply to dir_path log_config_file_path = workspace.get_log_config_file_path() assert os.path.isfile(log_config_file_path), f"missing log config file {log_config_file_path}" - logging.config.fileConfig(fname=log_config_file_path, disable_existing_loggers=False) + + with open(log_config_file_path, "r") as f: + dict_config = json.load(f) + + apply_log_config(dict_config, dir_path, file_prefix) + + +def apply_log_config(dict_config, dir_path: str = "", file_prefix: str = ""): + # Update log config dictionary with file_prefix, and apply to dir_path + stack = [dict_config] + while stack: + current_dict = stack.pop() + for key, value in current_dict.items(): + if isinstance(value, dict): + stack.append(value) + elif key == "filename": + if file_prefix: + value = os.path.join(os.path.dirname(value), file_prefix + "_" + os.path.basename(value)) + current_dict[key] = os.path.join(dir_path, value) + + logging.config.dictConfig(dict_config) def add_log_file_handler(log_file_name): @@ -103,7 +271,7 @@ def add_log_file_handler(log_file_name): root_logger.addHandler(file_handler) -def print_logger_hierarchy(package_name="nvflare", level_colors=DEFAULT_LEVEL_COLORS): +def print_logger_hierarchy(package_name="nvflare", level_colors=ANSIColor.DEFAULT_LEVEL_COLORS): all_loggers = logging.root.manager.loggerDict # Filter for package loggers based on package_name @@ -134,8 +302,8 @@ def print_hierarchy(logger_name, indent_level=0): level_display = f"{level_name} (SET)" if not is_unset else level_name # Print the logger with color and indentation - color = level_colors.get(level_name, ANSIColor.RESET) - print(" " * indent_level + f"{ansi_sgr(color)}{logger_name} [{level_display}]{ansi_sgr(ANSIColor.RESET)}") + color = level_colors.get(level_name, ANSIColor.COLORS["reset"]) + print(" " * indent_level + ANSIColor.colorize(f"{logger_name} [{level_display}]", color)) # Find child loggers based on the current hierarchy level for name in sorted_package_loggers: diff --git a/nvflare/job_config/stats_job.py b/nvflare/job_config/stats_job.py index 63cdf19070..1a1320c123 100644 --- a/nvflare/job_config/stats_job.py +++ b/nvflare/job_config/stats_job.py @@ -14,6 +14,7 @@ from typing import List from nvflare import FedJob, FilterType +from nvflare.apis.job_def import SERVER_SITE_NAME from nvflare.app_common.abstract.statistics_spec import Statistics from nvflare.app_common.executors.statistics.statistics_executor import StatisticsExecutor from nvflare.app_common.filters.statistics_privacy_filter import StatisticsPrivacyFilter @@ -50,15 +51,15 @@ def __init__( self.setup_server() - def setup_server(self): + def setup_server(self, server_name: str = SERVER_SITE_NAME): # define stats controller ctr = self.get_stats_controller() self.to(ctr, "server") # define stat writer to output Json file stats_writer = self.get_stats_output_writer() - self.to(stats_writer, "server", id=self.writer_id) + self.to(stats_writer, server_name, id=self.writer_id) - def setup_client(self, sites: List[str]): + def setup_clients(self, sites: List[str]): # Client side job config # Add client site for site_id in sites: diff --git a/nvflare/lighter/constants.py b/nvflare/lighter/constants.py index 97b18d0595..a6341d3abb 100644 --- a/nvflare/lighter/constants.py +++ b/nvflare/lighter/constants.py @@ -115,7 +115,7 @@ class ProvFileName: FED_SERVER_JSON = "fed_server.json" FED_CLIENT_JSON = "fed_client.json" STOP_FL_SH = "stop_fl.sh" - LOG_CONFIG_DEFAULT = "log.config.default" + LOG_CONFIG_DEFAULT = "log_config.json.default" RESOURCES_JSON_DEFAULT = "resources.json.default" PRIVACY_JSON_SAMPLE = "privacy.json.sample" AUTHORIZATION_JSON_DEFAULT = "authorization.json.default" diff --git a/nvflare/lighter/impl/master_template.yml b/nvflare/lighter/impl/master_template.yml index 09ee86230e..68961077d2 100644 --- a/nvflare/lighter/impl/master_template.yml +++ b/nvflare/lighter/impl/master_template.yml @@ -303,34 +303,83 @@ fl_admin_sh: | python3 -m nvflare.fuel.hci.tools.admin -m $DIR/.. -s fed_admin.json log_config: | - [loggers] - keys=root - - [handlers] - keys=consoleHandler,errorFileHandler - - [formatters] - keys=baseFormatter - - [logger_root] - level=INFO - handlers=consoleHandler,errorFileHandler - - [handler_consoleHandler] - class=StreamHandler - level=DEBUG - formatter=baseFormatter - args=(sys.stdout,) - - [handler_errorFileHandler] - class=FileHandler - level=ERROR - formatter=baseFormatter - args=('error_log.txt', 'a') - - [formatter_baseFormatter] - class=nvflare.fuel.utils.log_utils.BaseFormatter - format=%(asctime)s - %(name)s - %(levelname)s - %(message)s + { + "version": 1, + "disable_existing_loggers": false, + "formatters": { + "baseFormatter": { + "()": "nvflare.fuel.utils.log_utils.BaseFormatter", + "fmt": "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + }, + "colorFormatter": { + "()": "nvflare.fuel.utils.log_utils.ColorFormatter", + "fmt": "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + }, + "jsonFormatter": { + "()": "nvflare.fuel.utils.log_utils.JsonFormatter", + "fmt": "%(asctime)s - %(name)s - %(fullName)s - %(levelname)s - %(message)s" + } + }, + "filters": { + "FLFilter": { + "()": "nvflare.fuel.utils.log_utils.LoggerNameFilter", + "logger_names": ["custom", "nvflare.app_common", "nvflare.app_opt"] + } + }, + "handlers": { + "consoleHandler": { + "class": "logging.StreamHandler", + "level": "DEBUG", + "formatter": "colorFormatter", + "filters": [], + "stream": "ext://sys.stdout" + }, + "logFileHandler": { + "class": "logging.handlers.RotatingFileHandler", + "level": "DEBUG", + "formatter": "baseFormatter", + "filename": "log.txt", + "mode": "a", + "maxBytes": 20971520, + "backupCount": 10 + }, + "errorFileHandler": { + "class": "logging.handlers.RotatingFileHandler", + "level": "ERROR", + "formatter": "baseFormatter", + "filename": "log_error.txt", + "mode": "a", + "maxBytes": 20971520, + "backupCount": 10 + }, + "jsonFileHandler": { + "class": "logging.handlers.RotatingFileHandler", + "level": "DEBUG", + "formatter": "jsonFormatter", + "filename": "log.json", + "mode": "a", + "maxBytes": 20971520, + "backupCount": 10 + }, + "FLFileHandler": { + "class": "logging.handlers.RotatingFileHandler", + "level": "DEBUG", + "formatter": "baseFormatter", + "filters": ["FLFilter"], + "filename": "log_fl.txt", + "mode": "a", + "maxBytes": 20971520, + "backupCount": 10, + "delay": true + } + }, + "loggers": { + "root": { + "level": "INFO", + "handlers": ["consoleHandler", "logFileHandler", "errorFileHandler", "jsonFileHandler", "FLFileHandler"] + } + } + } start_ovsr_sh: | #!/usr/bin/env bash diff --git a/nvflare/lighter/templates/master_template.yml b/nvflare/lighter/templates/master_template.yml index 7b1e51af88..68961077d2 100644 --- a/nvflare/lighter/templates/master_template.yml +++ b/nvflare/lighter/templates/master_template.yml @@ -303,33 +303,83 @@ fl_admin_sh: | python3 -m nvflare.fuel.hci.tools.admin -m $DIR/.. -s fed_admin.json log_config: | - [loggers] - keys=root - - [handlers] - keys=consoleHandler,errorFileHandler - - [formatters] - keys=fullFormatter - - [logger_root] - level=INFO - handlers=consoleHandler,errorFileHandler - - [handler_consoleHandler] - class=StreamHandler - level=DEBUG - formatter=fullFormatter - args=(sys.stdout,) - - [handler_errorFileHandler] - class=FileHandler - level=ERROR - formatter=fullFormatter - args=('error.log', 'a') - - [formatter_fullFormatter] - format=%(asctime)s - %(name)s - %(levelname)s - %(message)s + { + "version": 1, + "disable_existing_loggers": false, + "formatters": { + "baseFormatter": { + "()": "nvflare.fuel.utils.log_utils.BaseFormatter", + "fmt": "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + }, + "colorFormatter": { + "()": "nvflare.fuel.utils.log_utils.ColorFormatter", + "fmt": "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + }, + "jsonFormatter": { + "()": "nvflare.fuel.utils.log_utils.JsonFormatter", + "fmt": "%(asctime)s - %(name)s - %(fullName)s - %(levelname)s - %(message)s" + } + }, + "filters": { + "FLFilter": { + "()": "nvflare.fuel.utils.log_utils.LoggerNameFilter", + "logger_names": ["custom", "nvflare.app_common", "nvflare.app_opt"] + } + }, + "handlers": { + "consoleHandler": { + "class": "logging.StreamHandler", + "level": "DEBUG", + "formatter": "colorFormatter", + "filters": [], + "stream": "ext://sys.stdout" + }, + "logFileHandler": { + "class": "logging.handlers.RotatingFileHandler", + "level": "DEBUG", + "formatter": "baseFormatter", + "filename": "log.txt", + "mode": "a", + "maxBytes": 20971520, + "backupCount": 10 + }, + "errorFileHandler": { + "class": "logging.handlers.RotatingFileHandler", + "level": "ERROR", + "formatter": "baseFormatter", + "filename": "log_error.txt", + "mode": "a", + "maxBytes": 20971520, + "backupCount": 10 + }, + "jsonFileHandler": { + "class": "logging.handlers.RotatingFileHandler", + "level": "DEBUG", + "formatter": "jsonFormatter", + "filename": "log.json", + "mode": "a", + "maxBytes": 20971520, + "backupCount": 10 + }, + "FLFileHandler": { + "class": "logging.handlers.RotatingFileHandler", + "level": "DEBUG", + "formatter": "baseFormatter", + "filters": ["FLFilter"], + "filename": "log_fl.txt", + "mode": "a", + "maxBytes": 20971520, + "backupCount": 10, + "delay": true + } + }, + "loggers": { + "root": { + "level": "INFO", + "handlers": ["consoleHandler", "logFileHandler", "errorFileHandler", "jsonFileHandler", "FLFileHandler"] + } + } + } start_ovsr_sh: | #!/usr/bin/env bash diff --git a/nvflare/lighter/utils.py b/nvflare/lighter/utils.py index 2c26d9f016..ae71e409fa 100644 --- a/nvflare/lighter/utils.py +++ b/nvflare/lighter/utils.py @@ -322,7 +322,7 @@ def _write_common(type, dest_dir, template, tplt, replacement_dict, config): def _write_local(type, dest_dir, template, capacity=""): write( - os.path.join(dest_dir, "log.config.default"), + os.path.join(dest_dir, "log_config.json.default"), template["log_config"], "t", ) diff --git a/nvflare/private/fed/app/client/client_train.py b/nvflare/private/fed/app/client/client_train.py index 544567e0c5..235631c978 100644 --- a/nvflare/private/fed/app/client/client_train.py +++ b/nvflare/private/fed/app/client/client_train.py @@ -25,6 +25,7 @@ from nvflare.fuel.common.excepts import ConfigError from nvflare.fuel.f3.mpm import MainProcessMonitor as mpm from nvflare.fuel.utils.argument_utils import parse_vars +from nvflare.fuel.utils.log_utils import configure_logging from nvflare.private.defs import AppFolderConstants from nvflare.private.fed.app.fl_conf import FLClientStarterConfiger, create_privacy_manager from nvflare.private.fed.app.utils import component_security_check, version_check @@ -32,7 +33,7 @@ from nvflare.private.fed.client.client_engine import ClientEngine from nvflare.private.fed.client.client_status import ClientStatus from nvflare.private.fed.client.fed_client import FederatedClient -from nvflare.private.fed.utils.fed_utils import add_logfile_handler, fobs_initialize, security_init +from nvflare.private.fed.utils.fed_utils import fobs_initialize, security_init from nvflare.private.privacy_manager import PrivacyService from nvflare.security.logging import secure_format_exception @@ -75,8 +76,7 @@ def main(args): ) conf.configure() - log_file = workspace.get_log_file_path() - add_logfile_handler(log_file) + configure_logging(workspace, workspace.get_root_dir()) deployer = conf.base_deployer security_init( diff --git a/nvflare/private/fed/app/client/sub_worker_process.py b/nvflare/private/fed/app/client/sub_worker_process.py index bd2ce4cf8d..321fe3e71b 100644 --- a/nvflare/private/fed/app/client/sub_worker_process.py +++ b/nvflare/private/fed/app/client/sub_worker_process.py @@ -43,7 +43,7 @@ from nvflare.fuel.f3.mpm import MainProcessMonitor as mpm from nvflare.fuel.sec.audit import AuditService from nvflare.fuel.sec.security_content_service import SecurityContentService -from nvflare.fuel.utils.log_utils import get_obj_logger, get_script_logger +from nvflare.fuel.utils.log_utils import configure_logging, get_obj_logger, get_script_logger from nvflare.private.defs import CellChannel, CellChannelTopic, new_cell_message from nvflare.private.fed.app.fl_conf import create_privacy_manager from nvflare.private.fed.app.utils import monitor_parent_process @@ -51,8 +51,6 @@ from nvflare.private.fed.runner import Runner from nvflare.private.fed.simulator.simulator_app_runner import SimulatorClientRunManager from nvflare.private.fed.utils.fed_utils import ( - add_logfile_handler, - configure_logging, create_stats_pool_files_for_job, fobs_initialize, register_ext_decomposers, @@ -309,7 +307,7 @@ def stop(self): def main(args): workspace = Workspace(args.workspace, args.client_name) - configure_logging(workspace) + configure_logging(workspace, workspace.get_run_dir(args.job_id)) fobs_initialize(workspace=workspace, job_id=args.job_id) register_ext_decomposers(args.decomposer_module) @@ -339,8 +337,6 @@ def main(args): thread.start() job_id = args.job_id - log_file = workspace.get_app_log_file_path(job_id) - add_logfile_handler(log_file) logger = get_script_logger() sub_executor.run() diff --git a/nvflare/private/fed/app/client/worker_process.py b/nvflare/private/fed/app/client/worker_process.py index e438e8e8f1..2f4da3db9e 100644 --- a/nvflare/private/fed/app/client/worker_process.py +++ b/nvflare/private/fed/app/client/worker_process.py @@ -25,14 +25,13 @@ from nvflare.fuel.f3.mpm import MainProcessMonitor as mpm from nvflare.fuel.utils.argument_utils import parse_vars from nvflare.fuel.utils.config_service import ConfigService -from nvflare.fuel.utils.log_utils import get_script_logger +from nvflare.fuel.utils.log_utils import configure_logging, get_script_logger from nvflare.private.defs import EngineConstant from nvflare.private.fed.app.fl_conf import FLClientStarterConfiger from nvflare.private.fed.app.utils import monitor_parent_process from nvflare.private.fed.client.client_app_runner import ClientAppRunner from nvflare.private.fed.client.client_status import ClientStatus from nvflare.private.fed.utils.fed_utils import ( - add_logfile_handler, create_stats_pool_files_for_job, fobs_initialize, register_ext_decomposers, @@ -98,8 +97,7 @@ def main(args): ) register_ext_decomposers(decomposer_module) - log_file = workspace.get_app_log_file_path(args.job_id) - add_logfile_handler(log_file) + configure_logging(workspace, workspace.get_run_dir(args.job_id)) logger = get_script_logger() logger.info("Worker_process started.") diff --git a/nvflare/private/fed/app/deployer/server_deployer.py b/nvflare/private/fed/app/deployer/server_deployer.py index 561dc718ed..6c3df2e296 100644 --- a/nvflare/private/fed/app/deployer/server_deployer.py +++ b/nvflare/private/fed/app/deployer/server_deployer.py @@ -16,7 +16,8 @@ import threading from nvflare.apis.event_type import EventType -from nvflare.apis.fl_constant import FLContextKey, SiteType, SystemComponents +from nvflare.apis.fl_constant import FLContextKey, ReservedKey, SiteType, SystemComponents +from nvflare.apis.signal import Signal from nvflare.apis.workspace import Workspace from nvflare.fuel.utils.log_utils import get_obj_logger from nvflare.private.fed.app.utils import component_security_check @@ -25,6 +26,7 @@ from nvflare.private.fed.server.run_manager import RunManager from nvflare.private.fed.server.server_cmd_modules import ServerCommandModules from nvflare.private.fed.server.server_status import ServerStatus +from nvflare.widgets.fed_event import ServerFedEventRunner class ServerDeployer: @@ -119,10 +121,14 @@ def deploy(self, args): services.engine.set_run_manager(run_manager) services.engine.set_job_runner(job_runner, job_manager) + fed_event_runner = ServerFedEventRunner() + run_manager.add_handler(fed_event_runner) + run_manager.add_handler(job_runner) run_manager.add_component(SystemComponents.JOB_RUNNER, job_runner) with services.engine.new_context() as fl_ctx: + fl_ctx.set_prop(ReservedKey.RUN_ABORT_SIGNAL, Signal(), private=True, sticky=True) fl_ctx.set_prop(FLContextKey.WORKSPACE_OBJECT, workspace, private=True) fl_ctx.set_prop(FLContextKey.ARGS, args, private=True, sticky=True) fl_ctx.set_prop(FLContextKey.SITE_OBJ, services, private=True, sticky=True) diff --git a/nvflare/private/fed/app/fl_conf.py b/nvflare/private/fed/app/fl_conf.py index 084bb90248..7a60876637 100644 --- a/nvflare/private/fed/app/fl_conf.py +++ b/nvflare/private/fed/app/fl_conf.py @@ -26,7 +26,6 @@ from nvflare.fuel.utils.json_scanner import Node from nvflare.fuel.utils.wfconf import ConfigContext, ConfigError from nvflare.private.defs import SSLConstants -from nvflare.private.fed.utils.fed_utils import configure_logging from nvflare.private.json_configer import JsonConfigurator from nvflare.private.privacy_manager import PrivacyManager, Scope @@ -63,8 +62,6 @@ def __init__(self, workspace: Workspace, args, kv_list=None): else: self.cmd_vars = {} - configure_logging(workspace) - config_files = workspace.get_config_files_for_startup(is_server=True, for_job=True if args.job_id else False) JsonConfigurator.__init__( @@ -226,8 +223,6 @@ def __init__(self, workspace: Workspace, args, kv_list=None): else: self.cmd_vars = {} - configure_logging(workspace) - config_files = workspace.get_config_files_for_startup(is_server=False, for_job=True if args.job_id else False) JsonConfigurator.__init__( diff --git a/nvflare/private/fed/app/server/runner_process.py b/nvflare/private/fed/app/server/runner_process.py index 76dbf73f01..0777152be2 100644 --- a/nvflare/private/fed/app/server/runner_process.py +++ b/nvflare/private/fed/app/server/runner_process.py @@ -26,14 +26,13 @@ from nvflare.fuel.f3.mpm import MainProcessMonitor as mpm from nvflare.fuel.utils.argument_utils import parse_vars from nvflare.fuel.utils.config_service import ConfigService -from nvflare.fuel.utils.log_utils import get_script_logger +from nvflare.fuel.utils.log_utils import configure_logging, get_script_logger from nvflare.private.defs import AppFolderConstants from nvflare.private.fed.app.fl_conf import FLServerStarterConfiger from nvflare.private.fed.app.utils import monitor_parent_process from nvflare.private.fed.server.server_app_runner import ServerAppRunner from nvflare.private.fed.server.server_state import HotState from nvflare.private.fed.utils.fed_utils import ( - add_logfile_handler, create_stats_pool_files_for_job, fobs_initialize, register_ext_decomposers, @@ -78,8 +77,7 @@ def main(args): args=args, kv_list=args.set, ) - log_file = workspace.get_app_log_file_path(args.job_id) - add_logfile_handler(log_file) + configure_logging(workspace, workspace.get_run_dir(args.job_id)) logger = get_script_logger() logger.info("Runner_process started.") diff --git a/nvflare/private/fed/app/server/server_train.py b/nvflare/private/fed/app/server/server_train.py index 00cf7e121a..2fa79d1513 100644 --- a/nvflare/private/fed/app/server/server_train.py +++ b/nvflare/private/fed/app/server/server_train.py @@ -25,11 +25,12 @@ from nvflare.fuel.common.excepts import ConfigError from nvflare.fuel.f3.mpm import MainProcessMonitor as mpm from nvflare.fuel.utils.argument_utils import parse_vars +from nvflare.fuel.utils.log_utils import configure_logging from nvflare.private.defs import AppFolderConstants from nvflare.private.fed.app.fl_conf import FLServerStarterConfiger, create_privacy_manager from nvflare.private.fed.app.utils import create_admin_server, version_check from nvflare.private.fed.server.server_status import ServerStatus -from nvflare.private.fed.utils.fed_utils import add_logfile_handler, fobs_initialize, security_init +from nvflare.private.fed.utils.fed_utils import fobs_initialize, security_init from nvflare.private.privacy_manager import PrivacyService from nvflare.security.logging import secure_format_exception @@ -83,8 +84,7 @@ def main(args): logger.critical("loglevel critical enabled") conf.configure() - log_file = workspace.get_log_file_path() - add_logfile_handler(log_file) + configure_logging(workspace, workspace.get_root_dir()) deployer = conf.deployer secure_train = conf.cmd_vars.get("secure_train", False) diff --git a/nvflare/private/fed/app/simulator/log.config b/nvflare/private/fed/app/simulator/log.config deleted file mode 100644 index 07c2963686..0000000000 --- a/nvflare/private/fed/app/simulator/log.config +++ /dev/null @@ -1,28 +0,0 @@ -[loggers] -keys=root - -[handlers] -keys=consoleHandler,errorFileHandler - -[formatters] -keys=baseFormatter - -[logger_root] -level=INFO -handlers=consoleHandler,errorFileHandler - -[handler_consoleHandler] -class=StreamHandler -level=DEBUG -formatter=baseFormatter -args=(sys.stdout,) - -[handler_errorFileHandler] -class=FileHandler -level=ERROR -formatter=baseFormatter -args=('error_log.txt', 'a') - -[formatter_baseFormatter] -class=nvflare.fuel.utils.log_utils.BaseFormatter -format=%(asctime)s - %(name)s - %(levelname)s - %(message)s \ No newline at end of file diff --git a/nvflare/private/fed/app/simulator/log_config.json b/nvflare/private/fed/app/simulator/log_config.json new file mode 100644 index 0000000000..a245312cf7 --- /dev/null +++ b/nvflare/private/fed/app/simulator/log_config.json @@ -0,0 +1,77 @@ +{ + "version": 1, + "disable_existing_loggers": false, + "formatters": { + "baseFormatter": { + "()": "nvflare.fuel.utils.log_utils.BaseFormatter", + "fmt": "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + }, + "colorFormatter": { + "()": "nvflare.fuel.utils.log_utils.ColorFormatter", + "fmt": "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + }, + "jsonFormatter": { + "()": "nvflare.fuel.utils.log_utils.JsonFormatter", + "fmt": "%(asctime)s - %(name)s - %(fullName)s - %(levelname)s - %(message)s" + } + }, + "filters": { + "FLFilter": { + "()": "nvflare.fuel.utils.log_utils.LoggerNameFilter", + "logger_names": ["custom", "nvflare.app_common", "nvflare.app_opt"] + } + }, + "handlers": { + "consoleHandler": { + "class": "logging.StreamHandler", + "level": "DEBUG", + "formatter": "colorFormatter", + "filters": [], + "stream": "ext://sys.stdout" + }, + "logFileHandler": { + "class": "logging.handlers.RotatingFileHandler", + "level": "DEBUG", + "formatter": "baseFormatter", + "filename": "log.txt", + "mode": "a", + "maxBytes": 20971520, + "backupCount": 10 + }, + "errorFileHandler": { + "class": "logging.handlers.RotatingFileHandler", + "level": "ERROR", + "formatter": "baseFormatter", + "filename": "log_error.txt", + "mode": "a", + "maxBytes": 20971520, + "backupCount": 10 + }, + "jsonFileHandler": { + "class": "logging.handlers.RotatingFileHandler", + "level": "DEBUG", + "formatter": "jsonFormatter", + "filename": "log.json", + "mode": "a", + "maxBytes": 20971520, + "backupCount": 10 + }, + "FLFileHandler": { + "class": "logging.handlers.RotatingFileHandler", + "level": "DEBUG", + "formatter": "baseFormatter", + "filters": ["FLFilter"], + "filename": "log_fl.txt", + "mode": "a", + "maxBytes": 20971520, + "backupCount": 10, + "delay": true + } + }, + "loggers": { + "root": { + "level": "INFO", + "handlers": ["consoleHandler", "logFileHandler", "errorFileHandler", "jsonFileHandler", "FLFileHandler"] + } + } +} \ No newline at end of file diff --git a/nvflare/private/fed/app/simulator/simulator_runner.py b/nvflare/private/fed/app/simulator/simulator_runner.py index c22e45327c..4896647265 100644 --- a/nvflare/private/fed/app/simulator/simulator_runner.py +++ b/nvflare/private/fed/app/simulator/simulator_runner.py @@ -13,7 +13,6 @@ # limitations under the License. import copy import json -import logging.config import os import shlex import shutil @@ -53,6 +52,7 @@ from nvflare.fuel.utils.argument_utils import parse_vars from nvflare.fuel.utils.config_service import ConfigService from nvflare.fuel.utils.gpu_utils import get_host_gpu_ids +from nvflare.fuel.utils.log_utils import apply_log_config from nvflare.fuel.utils.network_utils import get_open_ports from nvflare.fuel.utils.zip_utils import split_path, unzip_all_from_bytes, zip_directory_to_bytes from nvflare.private.defs import AppFolderConstants @@ -64,7 +64,6 @@ from nvflare.private.fed.simulator.simulator_audit import SimulatorAuditor from nvflare.private.fed.simulator.simulator_const import SimulatorConstants from nvflare.private.fed.utils.fed_utils import ( - add_logfile_handler, custom_fobs_initialize, get_simulator_app_root, nvflare_fobs_initialize, @@ -156,7 +155,9 @@ def setup(self): log_config_file_path = os.path.join(self.args.workspace, "local", WorkspaceConstants.LOGGING_CONFIG) if not os.path.isfile(log_config_file_path): log_config_file_path = os.path.join(os.path.dirname(__file__), WorkspaceConstants.LOGGING_CONFIG) - logging.config.fileConfig(fname=log_config_file_path, disable_existing_loggers=False) + + with open(log_config_file_path, "r") as f: + dict_config = json.load(f) self.args.log_config = None self.args.config_folder = "config" @@ -179,8 +180,8 @@ def setup(self): init_security_content_service(self.args.workspace) os.makedirs(os.path.join(self.simulator_root, SiteType.SERVER)) - log_file = os.path.join(self.simulator_root, SiteType.SERVER, WorkspaceConstants.LOG_FILE_NAME) - add_logfile_handler(log_file) + + apply_log_config(dict_config, os.path.join(self.simulator_root, SiteType.SERVER)) try: data_bytes, job_name, meta = self.validate_job_data() diff --git a/nvflare/private/fed/app/simulator/simulator_worker.py b/nvflare/private/fed/app/simulator/simulator_worker.py index 262f78026f..baae171ce9 100644 --- a/nvflare/private/fed/app/simulator/simulator_worker.py +++ b/nvflare/private/fed/app/simulator/simulator_worker.py @@ -13,7 +13,7 @@ # limitations under the License. import argparse -import logging.config +import json import os import sys import threading @@ -30,6 +30,7 @@ from nvflare.fuel.f3.mpm import MainProcessMonitor as mpm from nvflare.fuel.hci.server.authz import AuthorizationService from nvflare.fuel.sec.audit import AuditService +from nvflare.fuel.utils.log_utils import apply_log_config from nvflare.private.fed.app.deployer.base_client_deployer import BaseClientDeployer from nvflare.private.fed.app.utils import check_parent_alive, init_security_content_service from nvflare.private.fed.client.client_engine import ClientEngine @@ -38,12 +39,7 @@ from nvflare.private.fed.simulator.simulator_app_runner import SimulatorClientAppRunner from nvflare.private.fed.simulator.simulator_audit import SimulatorAuditor from nvflare.private.fed.simulator.simulator_const import SimulatorConstants -from nvflare.private.fed.utils.fed_utils import ( - add_logfile_handler, - fobs_initialize, - get_simulator_app_root, - register_ext_decomposers, -) +from nvflare.private.fed.utils.fed_utils import fobs_initialize, get_simulator_app_root, register_ext_decomposers from nvflare.security.logging import secure_format_exception, secure_log_traceback from nvflare.security.security import EmptyAuthorizer @@ -238,9 +234,10 @@ def main(args): thread = threading.Thread(target=check_parent_alive, args=(parent_pid, stop_event)) thread.start() - logging.config.fileConfig(fname=args.logging_config, disable_existing_loggers=False) - log_file = os.path.join(args.workspace, WorkspaceConstants.LOG_FILE_NAME) - add_logfile_handler(log_file) + with open(args.logging_config, "r") as f: + dict_config = json.load(f) + + apply_log_config(dict_config, args.workspace) os.chdir(args.workspace) startup = os.path.join(args.workspace, WorkspaceConstants.STARTUP_FOLDER_NAME) diff --git a/nvflare/private/fed/client/client_engine.py b/nvflare/private/fed/client/client_engine.py index 4799ec91a5..a3417e280f 100644 --- a/nvflare/private/fed/client/client_engine.py +++ b/nvflare/private/fed/client/client_engine.py @@ -39,6 +39,7 @@ from nvflare.private.fed.utils.fed_utils import security_close from nvflare.private.stream_runner import ObjectStreamer from nvflare.security.logging import secure_format_exception, secure_log_traceback +from nvflare.widgets.fed_event import ClientFedEventRunner from .client_engine_internal_spec import ClientEngineInternalSpec from .client_executor import JobExecutor @@ -99,6 +100,8 @@ def __init__(self, client: FederatedClient, args, rank, workers=5): self.logger = get_obj_logger(self) self.fl_components = [x for x in self.client.components.values() if isinstance(x, FLComponent)] + self.fl_components.append(ClientFedEventRunner()) + def fire_event(self, event_type: str, fl_ctx: FLContext): fire_event(event=event_type, handlers=self.fl_components, ctx=fl_ctx) @@ -470,6 +473,13 @@ def reset_errors(self, job_id): def get_all_job_ids(self): return self.client_executor.get_run_processes_keys() + def fire_and_forget_aux_request( + self, topic: str, request: Shareable, fl_ctx: FLContext, optional=False, secure=False + ) -> dict: + return self.send_aux_request( + topic=topic, request=request, timeout=0.0, fl_ctx=fl_ctx, optional=optional, secure=secure + ) + def shutdown_client(federated_client, touch_file): with open(touch_file, "a"): diff --git a/nvflare/private/fed/utils/fed_utils.py b/nvflare/private/fed/utils/fed_utils.py index be8d27d757..5b84f55909 100644 --- a/nvflare/private/fed/utils/fed_utils.py +++ b/nvflare/private/fed/utils/fed_utils.py @@ -13,13 +13,10 @@ # limitations under the License. import importlib import json -import logging -import logging.config import os import pkgutil import sys import warnings -from logging.handlers import RotatingFileHandler from typing import Any, List, Union from nvflare.apis.app_validation import AppValidator @@ -52,46 +49,6 @@ from .app_authz import AppAuthzService -def add_logfile_handler(log_file: str): - """Adds a log file handler to the root logger. - - The purpose for this is to handle dynamic log file locations. - - If a handler named errorFileHandler is found, it will be used as a template to - create a new handler for writing to the error log file at the same directory as log_file. - The original errorFileHandler will be removed and replaced by the new handler. - - Each log file will be rotated when it reaches 20MB. - - Args: - log_file (str): log file path - """ - root_logger = logging.getLogger() - configured_handlers = root_logger.handlers - main_handler = root_logger.handlers[0] - file_handler = RotatingFileHandler(log_file, maxBytes=20 * 1024 * 1024, backupCount=10) - file_handler.setLevel(main_handler.level) - file_handler.setFormatter(main_handler.formatter) - root_logger.addHandler(file_handler) - - configured_error_handler = None - for handler in configured_handlers: - if handler.get_name() == "errorFileHandler": - configured_error_handler = handler - break - - if not configured_error_handler: - return - - error_log_file = os.path.join(os.path.dirname(log_file), WorkspaceConstants.ERROR_LOG_FILE_NAME) - error_file_handler = RotatingFileHandler(error_log_file, maxBytes=20 * 1024 * 1024, backupCount=10) - error_file_handler.setLevel(configured_error_handler.level) - error_file_handler.setFormatter(configured_error_handler.formatter) - - root_logger.addHandler(error_file_handler) - root_logger.removeHandler(configured_error_handler) - - def _check_secure_content(site_type: str) -> List[str]: """To check the security contents. @@ -253,12 +210,6 @@ def find_char_positions(s, ch): return [i for i, c in enumerate(s) if c == ch] -def configure_logging(workspace: Workspace): - log_config_file_path = workspace.get_log_config_file_path() - assert os.path.isfile(log_config_file_path), f"missing log config file {log_config_file_path}" - logging.config.fileConfig(fname=log_config_file_path, disable_existing_loggers=False) - - def get_scope_info(): try: privacy_manager = PrivacyService.get_manager() diff --git a/nvflare/widgets/fed_event.py b/nvflare/widgets/fed_event.py index 75a9b63529..02272a60da 100644 --- a/nvflare/widgets/fed_event.py +++ b/nvflare/widgets/fed_event.py @@ -52,7 +52,7 @@ def __init__(self, topic=FED_EVENT_TOPIC, regular_interval=0.01, grace_period=2. self.poster = None def handle_event(self, event_type: str, fl_ctx: FLContext): - if event_type == EventType.START_RUN: + if event_type == EventType.START_RUN or event_type == EventType.SYSTEM_BOOTSTRAP: self.engine = fl_ctx.get_engine() self.engine.register_aux_message_handler(topic=self.topic, message_handle_func=self._receive) self.abort_signal = fl_ctx.get_run_abort_signal() @@ -208,7 +208,7 @@ def __init__(self, topic=FED_EVENT_TOPIC): def handle_event(self, event_type: str, fl_ctx: FLContext): super().handle_event(event_type, fl_ctx) - if event_type == EventType.START_RUN: + if event_type == EventType.START_RUN or event_type == EventType.SYSTEM_BOOTSTRAP: self.ready = True def fire_and_forget_request(self, request: Shareable, fl_ctx: FLContext, targets=None, secure=False): diff --git a/research/CONTRIBUTING.md b/research/CONTRIBUTING.md new file mode 100644 index 0000000000..ba1ea0765f --- /dev/null +++ b/research/CONTRIBUTING.md @@ -0,0 +1,70 @@ +# Research Directory +This research directory is the place to host various research work from the community on Federated learning +leveraging NVIDIA FLARE. **The code will not be maintained by NVIDIA FLARE team**, but will require Pull Request +approval process. + +## License +By providing the code in NVFLARE repository, you will grant the research project in NVIDIA repo to be released under Apache v2 License or equivalent open source license. + +## Requirements +Each research project should create a subdirectory with the following requirements. + +* Subdirectory name must be in ASCII string, all in lower, kebab-case, and no longer than 35 characters long +* Each project should include + * README.md -- document must include + * Objective + * Background + * Description + * Setup + * Steps to run the code + * Data download and preparation (if applicable) + * Expected results + * Jobs-folder including configurations and optional custom code + * All code should be in runnable condition, i.e., no broken code + * License file + * Requirements file listing all dependencies, including the NVFLARE version used + +## Example +``` +sample_research$ +. +├── jobs + └── job1 + ├── app_server + ├── config + └── config_fed_server.json + └── custom + └── sample_controller.py + └── app_client + ├── config + └── config_fed_client.json + └── custom + └── sample_executor.py + └── meta.json +└── README.md +└── LICENSE +└── requirements.txt +``` + +## Setup +To run the research code, we recommend using a virtual environment. + +### Set up a virtual environment +``` +python3 -m pip install --user --upgrade pip +python3 -m pip install --user virtualenv +``` +(If needed) make all shell scripts executable using +``` +find . -name ".sh" -exec chmod +x {} \; +``` +initialize virtual environment. +``` +python3 -m venv venv +source venv/bin/activate +``` +within each research folder, install required packages for training +``` +pip install --upgrade pip +pip install -r requirements.txt +``` diff --git a/research/README.md b/research/README.md index ba1ea0765f..508b97fbd2 100644 --- a/research/README.md +++ b/research/README.md @@ -1,70 +1,20 @@ -# Research Directory -This research directory is the place to host various research work from the community on Federated learning -leveraging NVIDIA FLARE. **The code will not be maintained by NVIDIA FLARE team**, but will require Pull Request -approval process. +# Research with NVIDIA FLARE -## License -By providing the code in NVFLARE repository, you will grant the research project in NVIDIA repo to be released under Apache v2 License or equivalent open source license. +Researcher Icon -## Requirements -Each research project should create a subdirectory with the following requirements. +NVIDIA FLARE has been used in several research studies. In this directory, you can find their reference implementations. -* Subdirectory name must be in ASCII string, all in lower, kebab-case, and no longer than 35 characters long -* Each project should include - * README.md -- document must include - * Objective - * Background - * Description - * Setup - * Steps to run the code - * Data download and preparation (if applicable) - * Expected results - * Jobs-folder including configurations and optional custom code - * All code should be in runnable condition, i.e., no broken code - * License file - * Requirements file listing all dependencies, including the NVFLARE version used +## Research Implementations -## Example -``` -sample_research$ -. -├── jobs - └── job1 - ├── app_server - ├── config - └── config_fed_server.json - └── custom - └── sample_controller.py - └── app_client - ├── config - └── config_fed_client.json - └── custom - └── sample_executor.py - └── meta.json -└── README.md -└── LICENSE -└── requirements.txt -``` +1. [FedBPT: Efficient Federated Black-box Prompt Tuning for Large Language Models](./fed-bpt/README.md) [ICML 2024](https://arxiv.org/abs/2310.01467) +2. [ConDistFL: Conditional Distillation for Federated Learning from Partially Annotated Data](./condist-fl/README.md) ([DeCaF 2023](https://arxiv.org/abs/2308.04070)) +3. [Fair Federated Medical Image Segmentation via Client Contribution Estimation](./fed-ce/README.md) ([CVPR 2023](https://arxiv.org/abs/2303.16520)) +4. [Communication-Efficient Vertical Federated Learning with Limited Overlapping Samples](./one-shot-vfl/README.md) [ICCV 2023](https://arxiv.org/abs/2303.16270) +5. [Closing the Generalization Gap of Cross-silo Federated Medical Image Segmentation](./fed-sm/README.md) ([CVPR 2022](https://arxiv.org/abs/2203.10144)) +6. [Do Gradient Inversion Attacks Make Federated Learning Unsafe?](./quantifying-data-leakage/README.md) ([IEEE Transactions on Medical Imaging 2022](https://arxiv.org/abs/2202.06924)) +7. [Auto-FedRL: Federated Hyperparameter Optimization for Multi-institutional Medical Image Segmentation](./auto-fed-rl/README.md) ([ECCV 2022](https://arxiv.org/abs/2203.06338)) +8. [FedBN: Federated Learning on Non-IID Features via Local Batch Normalization](./fed-bn/README.md) [ICLR 2021](https://arxiv.org/abs/2102.07623) -## Setup -To run the research code, we recommend using a virtual environment. +## Contributing -### Set up a virtual environment -``` -python3 -m pip install --user --upgrade pip -python3 -m pip install --user virtualenv -``` -(If needed) make all shell scripts executable using -``` -find . -name ".sh" -exec chmod +x {} \; -``` -initialize virtual environment. -``` -python3 -m venv venv -source venv/bin/activate -``` -within each research folder, install required packages for training -``` -pip install --upgrade pip -pip install -r requirements.txt -``` +To provide your own research implementations, please follow this [contribution guide](./CONTRIBUTING.md). diff --git a/runtest.sh b/runtest.sh index bc4d76a343..78dee66bdc 100755 --- a/runtest.sh +++ b/runtest.sh @@ -92,7 +92,7 @@ function check_license() { folders_to_check_license="nvflare examples tests integration research" echo "checking license header in folder: $folders_to_check_license" (grep -r --include "*.py" --exclude-dir "*protos*" --exclude "modeling_roberta.py" -L \ - "\(# Copyright (c) \(2021\|2022\|2023\|2024\), NVIDIA CORPORATION. All rights reserved.\)\|\(This file is released into the public domain.\)" \ + "\(# Copyright (c) \(2021\|2022\|2023\|2024\|2025\), NVIDIA CORPORATION. All rights reserved.\)\|\(This file is released into the public domain.\)" \ ${folders_to_check_license} || true) > no_license.lst if [ -s no_license.lst ]; then # The file is not-empty. diff --git a/setup.cfg b/setup.cfg index f855afc365..664e334262 100644 --- a/setup.cfg +++ b/setup.cfg @@ -36,6 +36,7 @@ install_requires = docker>=6.0 websockets>=10.4 pyhocon + tdigest [options.extras_require] HE = diff --git a/tests/unit_test/app_common/statistics/numeric_stats_test.py b/tests/unit_test/app_common/statistics/numeric_stats_test.py index ff75786946..6e654ae828 100644 --- a/tests/unit_test/app_common/statistics/numeric_stats_test.py +++ b/tests/unit_test/app_common/statistics/numeric_stats_test.py @@ -53,7 +53,9 @@ def test_accumulate_metrics(self, client_stats, expected_global_stats): global_stats = {} for client_name in client_stats: - global_stats = accumulate_metrics(metrics=client_stats[client_name], global_metrics=global_stats) + global_stats = accumulate_metrics( + metrics=client_stats[client_name], global_metrics=global_stats, precision=4 + ) assert global_stats.keys() == expected_global_stats.keys() assert global_stats == expected_global_stats diff --git a/tests/unit_test/app_opt/statistics/__init__.py b/tests/unit_test/app_opt/statistics/__init__.py new file mode 100644 index 0000000000..d9155f923f --- /dev/null +++ b/tests/unit_test/app_opt/statistics/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unit_test/app_opt/statistics/percentiles_test.py b/tests/unit_test/app_opt/statistics/percentiles_test.py new file mode 100644 index 0000000000..3f1eee5e22 --- /dev/null +++ b/tests/unit_test/app_opt/statistics/percentiles_test.py @@ -0,0 +1,90 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List + +import numpy as np +import pandas as pd + +from nvflare.apis.fl_context import FLContext +from nvflare.app_common.app_constant import StatisticsConstants +from nvflare.app_common.statistics.numeric_stats import aggregate_centroids, compute_percentiles +from nvflare.app_opt.statistics.df.df_core_statistics import DFStatisticsCore + + +class MockDFStats(DFStatisticsCore): + def __init__(self, given_median: int): + super().__init__() + self.median = given_median + self.data = {"train": None} + + def initialize(self, fl_ctx: FLContext): + self.load_data() + + def load_data(self): + data = np.concatenate( + (np.arange(0, self.median), [self.median], np.arange(self.median + 1, self.median * 2 + 1)) + ) + + # Shuffle the data to make it unordered + np.random.shuffle(data) + + # Create the DataFrame + df = pd.DataFrame(data, columns=["Feature"]) + self.data = {"train": df} + + +class MockDFStats2(DFStatisticsCore): + def __init__(self, data_array: List[int]): + super().__init__() + self.raw_data = data_array + self.data = {"train": None} + + def initialize(self, fl_ctx: FLContext): + self.load_data() + + def load_data(self): + # Create the DataFrame + df = pd.DataFrame(self.raw_data, columns=["Feature"]) + self.data = {"train": df} + + +class TestPercentiles: + + def test_percentile_metrics(self): + stats_generator = MockDFStats(given_median=100) + stats_generator.load_data() + percentiles = stats_generator.percentiles("train", "Feature", percents=[50]) + result = percentiles.get(StatisticsConstants.STATS_PERCENTILES_KEY) + print(f"{percentiles=}") + assert result is not None + assert result.get(50) == stats_generator.median + + def test_percentile_metrics_aggregation(self): + stats_generators = [ + MockDFStats2(data_array=[0, 1, 2, 3, 4, 5]), + MockDFStats(given_median=10), + MockDFStats2(data_array=[100, 110, 120, 130, 140, 150]), + ] + global_digest = {} + result = {} + for g in stats_generators: # each site/client + g.load_data() + local_percentiles = g.percentiles("train", "Feature", percents=[50]) + local_metrics = {"train": {"Feature": local_percentiles}} + aggregate_centroids(local_metrics, global_digest) + result = compute_percentiles(global_digest, {"Feature": [50]}, 2) + + expected_median = 10 + assert result["train"]["Feature"].get(50) == expected_median diff --git a/tests/unit_test/fuel/utils/fobs/fobs_test.py b/tests/unit_test/fuel/utils/fobs/fobs_test.py index 360fe415b4..4e800c491c 100644 --- a/tests/unit_test/fuel/utils/fobs/fobs_test.py +++ b/tests/unit_test/fuel/utils/fobs/fobs_test.py @@ -28,6 +28,7 @@ class TestFobs: NUMBER = 123456 FLOAT = 123.456 NAME = "FOBS Test" + SET = {4, 5, 6} NOW = datetime.now() test_data = { @@ -35,7 +36,7 @@ class TestFobs: "number": NUMBER, "float": FLOAT, "list": [7, 8, 9], - "set": {4, 5, 6}, + "set": SET, "tuple": ("abc", "xyz"), "time": NOW, } @@ -44,11 +45,7 @@ def test_builtin(self): buf = fobs.dumps(TestFobs.test_data) data = fobs.loads(buf) assert data["number"] == TestFobs.NUMBER - - def test_aliases(self): - buf = fobs.dumps(TestFobs.test_data) - data = fobs.loads(buf) - assert data["number"] == TestFobs.NUMBER + assert data["set"] == TestFobs.SET def test_unsupported_classes(self): with pytest.raises(TypeError):