Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Workflow communication APIs and Simplified ML Algorithms #2250

Closed
wants to merge 42 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
0029993
controller workflow APIs and simplified FedAvg and Fed Kaplan-Meier e…
chesterxgchen Dec 29, 2023
fce9c61
update
chesterxgchen Dec 29, 2023
e47ff4a
update
chesterxgchen Dec 29, 2023
660bdcf
update
chesterxgchen Dec 29, 2023
8733f60
Add Fed Cyclic example
chesterxgchen Dec 29, 2023
f1ad53d
Add Fed Cyclic example
chesterxgchen Dec 29, 2023
b1fa51b
addres PR comments
chesterxgchen Dec 29, 2023
d9ad8b2
1. Remove Base Class ErrorHandleController, instead move the function…
chesterxgchen Dec 31, 2023
4322602
add header
chesterxgchen Dec 31, 2023
3cab712
add header
chesterxgchen Dec 31, 2023
6221dff
code style format
chesterxgchen Dec 31, 2023
f535c62
make better user experience
chesterxgchen Dec 31, 2023
c6e19c7
code format and import
chesterxgchen Dec 31, 2023
a3fb099
remove comment
chesterxgchen Dec 31, 2023
6b57dc9
remove used method
chesterxgchen Dec 31, 2023
eb2272f
1. add intime aggregate version of fedavg
chesterxgchen Jan 2, 2024
373ce19
update README.md
chesterxgchen Jan 2, 2024
6ac0888
add ask all clients to end run when server in exception
chesterxgchen Jan 3, 2024
c5994c3
rebase and remove extra command
chesterxgchen Jan 5, 2024
4b84310
wip
chesterxgchen Jan 12, 2024
4f3abe5
remove WF dependency
chesterxgchen Jan 13, 2024
e635ea4
1. remove ctrl_msg_Queue, use controller directly.
chesterxgchen Jan 13, 2024
c9ee619
update README.md and cleanup
chesterxgchen Jan 13, 2024
76c3c43
change comm_msg_pull_interval to result_pull_interval
chesterxgchen Jan 13, 2024
a52d27a
1. fix message_bus
chesterxgchen Jan 13, 2024
8343392
design change, broken commit
chesterxgchen Jan 18, 2024
d914592
everything works now
chesterxgchen Jan 20, 2024
30552e4
everything works now
chesterxgchen Jan 20, 2024
c297073
merge with new data bus changes. The code is broken now.
chesterxgchen Jan 28, 2024
b4da21f
fix the lock issue.
chesterxgchen Jan 28, 2024
ae657f6
define strategy.py in case it is needed.
chesterxgchen Jan 28, 2024
f0ae6e3
define strategy.py in case it is needed.
chesterxgchen Jan 28, 2024
207c13a
make sure the publish in parallel instead of sequential
chesterxgchen Jan 29, 2024
a274a21
ADD CODE TO ADDRESS THE NEW DESIGN CHANGES.
chesterxgchen Jan 30, 2024
7c4f62e
format code
chesterxgchen Jan 30, 2024
5fdd1c0
fix the issue with return result
chesterxgchen Jan 30, 2024
5d09732
cleanup, fix original controller parsing
SYangster Feb 6, 2024
8f35c4a
Merge branch 'main' into wl_controller
SYangster Feb 13, 2024
5b8420c
fix format
SYangster Feb 14, 2024
a5f63ad
databus updates
SYangster Feb 14, 2024
da30df2
add docstrings, address comments
SYangster Feb 22, 2024
d37cd23
fix communicator pairing, remove temp example
SYangster Feb 23, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions nvflare/apis/dxo.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class DataKind(object):
COLLECTION = "COLLECTION" # Dict or List of DXO objects
STATISTICS = "STATISTICS"
PSI = "PSI"
RAW = "RAW"


class MetaKey(FLMetaKey):
Expand Down
97 changes: 97 additions & 0 deletions nvflare/apis/wf_controller.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from typing import Callable, List, Optional

from nvflare.app_common import wf_comm

from .fl_constant import ReturnCode

ABORT_WHEN_IN_ERROR = {
ReturnCode.EXECUTION_EXCEPTION: True,
ReturnCode.TASK_UNKNOWN: True,
ReturnCode.EXECUTION_RESULT_ERROR: False,
ReturnCode.TASK_DATA_FILTER_ERROR: True,
ReturnCode.TASK_RESULT_FILTER_ERROR: True,
}


class WFController(ABC):
def __init__(self):
self.communicator = wf_comm.get_wf_comm_api()

@abstractmethod
def run(self):
pass

def broadcast_and_wait(
self,
task_name: str,
min_responses: int,
data: any,
meta: dict = None,
targets: Optional[List[str]] = None,
callback: Callable = None,
):
return self.communicator.broadcast_and_wait(task_name, min_responses, data, meta, targets, callback)

def send_and_wait(
self,
task_name: str,
min_responses: int,
data: any,
meta: dict = None,
targets: Optional[List[str]] = None,
send_order: str = "sequential",
callback: Callable = None,
):
return self.communicator.send_and_wait(task_name, min_responses, data, meta, targets, send_order, callback)

def relay_and_wait(
self,
task_name: str,
min_responses: int,
data: any,
meta: dict = None,
targets: Optional[List[str]] = None,
relay_order: str = "sequential",
callback: Callable = None,
):
return self.communicator.relay_and_wait(task_name, min_responses, data, meta, targets, relay_order, callback)

def broadcast(self, task_name: str, data: any, meta: dict = None, targets: Optional[List[str]] = None):
return self.communicator.broadcast(task_name, data, meta, targets)

def send(
self,
task_name: str,
data: any,
meta: dict = None,
targets: Optional[str] = None,
send_order: str = "sequential",
):
return self.communicator.send(task_name, data, meta, targets, send_order)

def relay(
self,
task_name: str,
data: any,
meta: dict = None,
targets: Optional[List[str]] = None,
relay_order: str = "sequential",
):
return self.communicator.send(task_name, data, meta, targets, relay_order)

def get_site_names(self) -> List[str]:
return self.communicator.get_site_names()
2 changes: 2 additions & 0 deletions nvflare/app_common/abstract/fl_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def __init__(
params: Any = None,
optimizer_params: Any = None,
metrics: Optional[Dict] = None,
start_round: int = 0,
current_round: Optional[int] = None,
total_rounds: Optional[int] = None,
meta: Optional[Dict] = None,
Expand Down Expand Up @@ -79,6 +80,7 @@ def __init__(
self.params = params
self.optimizer_params = optimizer_params
self.metrics = metrics
self.start_round = start_round
self.current_round = current_round
self.total_rounds = total_rounds

Expand Down
5 changes: 3 additions & 2 deletions nvflare/app_common/utils/fl_model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,9 @@ def get_configs(model: FLModel) -> Optional[dict]:

@staticmethod
def update_model(model: FLModel, model_update: FLModel, replace_meta: bool = True) -> FLModel:

model.metrics = model_update.metrics

SYangster marked this conversation as resolved.
Show resolved Hide resolved
if model.params_type != ParamsType.FULL:
raise RuntimeError(f"params_type {model.params_type} of `model` not supported! Expected `ParamsType.FULL`.")

Expand All @@ -209,8 +212,6 @@ def update_model(model: FLModel, model_update: FLModel, replace_meta: bool = Tru
else:
model.meta.update(model_update.meta)

model.metrics = model_update.metrics

if model_update.params_type == ParamsType.FULL:
model.params = model_update.params
elif model_update.params_type == ParamsType.DIFF:
Expand Down
58 changes: 58 additions & 0 deletions nvflare/app_common/utils/math_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import operator
from typing import Callable, Optional, Tuple

operator_mapping = {
">=": operator.ge,
"<=": operator.le,
">": operator.gt,
"<": operator.lt,
"=": operator.eq,
}


def parse_compare_criteria(compare_expr: Optional[str] = None) -> Tuple[str, float, Callable]:
SYangster marked this conversation as resolved.
Show resolved Hide resolved
"""
Parse the compare expression into individual component
compare expression is in the format of string literal : "<key> <op> <value>"
such as
accuracy >= 0.5
loss > 2.4
Args:
compare_expr: string literal in the format of "<key> <op> <value>"

Returns: Tuple key, value, operator

"""
tokens = compare_expr.split(" ")
if len(tokens) != 3:
raise ValueError(
f"Invalid early_stop_condition, expecting form of '<metric> <op> value' but got '{compare_expr}'"
)

key = tokens[0]
op = tokens[1]
target = tokens[2]
op_fn = operator_mapping.get(op, None)
if op_fn is None:
raise ValueError("Invalid operator symbol: expecting one of <=, =, >=, <, > ")
if not target:
raise ValueError("Invalid empty or None target value")
try:
target_value = float(target)
except Exception as e:
raise ValueError(f"expect a number, but get '{target}' in '{compare_expr}'")

return key, target_value, op_fn
22 changes: 22 additions & 0 deletions nvflare/app_common/wf_comm/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from nvflare.app_common.wf_comm.wf_comm_api_spec import WFCommAPISpec
from nvflare.fuel.data_event.data_bus import DataBus

data_bus = DataBus()


def get_wf_comm_api() -> WFCommAPISpec:
return data_bus.get_data("wf_comm_api")
SYangster marked this conversation as resolved.
Show resolved Hide resolved
Loading
Loading