Skip to content

Commit

Permalink
Added check for duplicate RM request (#2858)
Browse files Browse the repository at this point in the history
* Added check for duplicate RM request

* Addressed PR comment
  • Loading branch information
nvidianz authored Aug 27, 2024
1 parent dd90ddb commit 7b01b0f
Showing 1 changed file with 18 additions and 0 deletions.
18 changes: 18 additions & 0 deletions nvflare/apis/utils/reliable_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
STATUS_NOT_RECEIVED = "not_received"
STATUS_REPLIED = "replied"
STATUS_ABORTED = "aborted"
STATUS_DUP_REQUEST = "dup_request"

# Topics for Reliable Message
TOPIC_RELIABLE_REQUEST = "RM.RELIABLE_REQUEST"
Expand Down Expand Up @@ -227,6 +228,7 @@ class ReliableMessage:

_topic_to_handle = {}
_req_receivers = {} # tx id => receiver
_req_completed = {} # tx id => expiration
_enabled = False
_executor = None
_query_interval = 1.0
Expand Down Expand Up @@ -293,6 +295,9 @@ def _receive_request(cls, topic: str, request: Shareable, fl_ctx: FLContext):
# no handler registered for this topic!
cls.error(fl_ctx, f"no handler registered for request {rm_topic=}")
return make_reply(ReturnCode.TOPIC_UNKNOWN)
if cls._req_completed.get(tx_id):
cls.debug(fl_ctx, "Completed tx_id received")
return _status_reply(STATUS_DUP_REQUEST)
receiver = cls._get_or_create_receiver(rm_topic, request, handler_f)
cls.debug(fl_ctx, f"received request {rm_topic=}")
return receiver.process(request, fl_ctx)
Expand Down Expand Up @@ -336,6 +341,7 @@ def release_request_receiver(cls, receiver: _RequestReceiver, fl_ctx: FLContext)
"""
with cls._tx_lock:
cls._register_completed_req(receiver.tx_id, receiver.tx_timeout)
cls._req_receivers.pop(receiver.tx_id, None)
cls.debug(fl_ctx, f"released request receiver of TX {receiver.tx_id}")

Expand Down Expand Up @@ -679,3 +685,15 @@ def _query_result(
cls.debug(fl_ctx, f"will retry query in {cls._query_interval} secs: {rc=} {status=} {op=}")
else:
cls.debug(fl_ctx, f"will retry query in {cls._query_interval} secs: {rc=}")

@classmethod
def _register_completed_req(cls, tx_id, tx_timeout):
# Remove expired entries, need to use a copy of the keys
now = time.time()
for key in list(cls._req_completed.keys()):
expiration = cls._req_completed.get(key)
if expiration and expiration < now:
cls._req_completed.pop(key, None)

# Expire in 2 x tx_timeout
cls._req_completed[tx_id] = now + 2 * tx_timeout

0 comments on commit 7b01b0f

Please sign in to comment.