Skip to content

Commit

Permalink
fix hello-pt, empty metrics (#2840)
Browse files Browse the repository at this point in the history
  • Loading branch information
SYangster authored Aug 24, 2024
1 parent debc484 commit 76b3bcd
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 8 deletions.
3 changes: 2 additions & 1 deletion examples/hello-world/hello-pt/src/hello-pt_cifar10_fl.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from torchvision.transforms import Compose, Normalize, ToTensor

import nvflare.client as flare
from nvflare.client.tracking import SummaryWriter

DATASET_PATH = "/tmp/nvflare/data"

Expand All @@ -33,7 +34,6 @@ def main():
lr = 0.01
model = SimpleNetwork()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)
loss = nn.CrossEntropyLoss()
optimizer = SGD(model.parameters(), lr=lr, momentum=0.9)
transforms = Compose(
Expand All @@ -58,6 +58,7 @@ def main():
print(f"current_round={input_model.current_round}")

model.load_state_dict(input_model.params)
model.to(device)

steps = epochs * len(train_loader)
for epoch in range(epochs):
Expand Down
18 changes: 11 additions & 7 deletions nvflare/app_common/workflows/base_fedavg.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,22 +83,26 @@ def aggregate_fn(results: List[FLModel]) -> FLModel:

aggr_helper = WeightedAggregationHelper()
aggr_metrics_helper = WeightedAggregationHelper()
all_metrics = True
for _result in results:
aggr_helper.add(
data=_result.params,
weight=_result.meta.get(FLMetaKey.NUM_STEPS_CURRENT_ROUND, 1.0),
contributor_name=_result.meta.get("client_name", AppConstants.CLIENT_UNKNOWN),
contribution_round=_result.current_round,
)
aggr_metrics_helper.add(
data=_result.metrics,
weight=_result.meta.get(FLMetaKey.NUM_STEPS_CURRENT_ROUND, 1.0),
contributor_name=_result.meta.get("client_name", AppConstants.CLIENT_UNKNOWN),
contribution_round=_result.current_round,
)
if not _result.metrics:
all_metrics = False
if all_metrics:
aggr_metrics_helper.add(
data=_result.metrics,
weight=_result.meta.get(FLMetaKey.NUM_STEPS_CURRENT_ROUND, 1.0),
contributor_name=_result.meta.get("client_name", AppConstants.CLIENT_UNKNOWN),
contribution_round=_result.current_round,
)

aggr_params = aggr_helper.get_result()
aggr_metrics = aggr_metrics_helper.get_result()
aggr_metrics = aggr_metrics_helper.get_result() if all_metrics else None

aggr_result = FLModel(
params=aggr_params,
Expand Down

0 comments on commit 76b3bcd

Please sign in to comment.