diff --git a/src/lerobot/scripts/rl/actor.py b/src/lerobot/scripts/rl/actor.py index cd5e286c0c..1c8f9286bf 100644 --- a/src/lerobot/scripts/rl/actor.py +++ b/src/lerobot/scripts/rl/actor.py @@ -63,12 +63,12 @@ from lerobot.policies.factory import make_policy from lerobot.policies.sac.modeling_sac import SACPolicy from lerobot.robots import so100_follower # noqa: F401 -from lerobot.scripts.rl import learner_service from lerobot.scripts.rl.gym_manipulator import make_robot_env from lerobot.teleoperators import gamepad, so101_leader # noqa: F401 from lerobot.transport import services_pb2, services_pb2_grpc from lerobot.transport.utils import ( bytes_to_state_dict, + grpc_channel_options, python_object_to_bytes, receive_bytes_in_chunks, send_bytes_in_chunks, @@ -399,8 +399,6 @@ def learner_service_client( host: str = "127.0.0.1", port: int = 50051, ) -> tuple[services_pb2_grpc.LearnerServiceStub, grpc.Channel]: - import json - """ Returns a client for the learner service. @@ -408,34 +406,9 @@ def learner_service_client( So we need to create only one client and reuse it. 
""" - service_config = { - "methodConfig": [ - { - "name": [{}], # Applies to ALL methods in ALL services - "retryPolicy": { - "maxAttempts": 5, # Max retries (total attempts = 5) - "initialBackoff": "0.1s", # First retry after 0.1s - "maxBackoff": "2s", # Max wait time between retries - "backoffMultiplier": 2, # Exponential backoff factor - "retryableStatusCodes": [ - "UNAVAILABLE", - "DEADLINE_EXCEEDED", - ], # Retries on network failures - }, - } - ] - } - - service_config_json = json.dumps(service_config) - channel = grpc.insecure_channel( f"{host}:{port}", - options=[ - ("grpc.max_receive_message_length", learner_service.MAX_MESSAGE_SIZE), - ("grpc.max_send_message_length", learner_service.MAX_MESSAGE_SIZE), - ("grpc.enable_retries", 1), - ("grpc.service_config", service_config_json), - ], + grpc_channel_options(), ) stub = services_pb2_grpc.LearnerServiceStub(channel) logging.info("[ACTOR] Learner service client created") diff --git a/src/lerobot/scripts/rl/learner.py b/src/lerobot/scripts/rl/learner.py index edd2363b12..cb88895cfa 100644 --- a/src/lerobot/scripts/rl/learner.py +++ b/src/lerobot/scripts/rl/learner.py @@ -77,6 +77,7 @@ from lerobot.teleoperators import gamepad, so101_leader # noqa: F401 from lerobot.transport import services_pb2_grpc from lerobot.transport.utils import ( + MAX_MESSAGE_SIZE, bytes_to_python_object, bytes_to_transitions, state_to_bytes, @@ -658,8 +659,8 @@ def start_learner( server = grpc.server( ThreadPoolExecutor(max_workers=learner_service.MAX_WORKERS), options=[ - ("grpc.max_receive_message_length", learner_service.MAX_MESSAGE_SIZE), - ("grpc.max_send_message_length", learner_service.MAX_MESSAGE_SIZE), + ("grpc.max_receive_message_length", MAX_MESSAGE_SIZE), + ("grpc.max_send_message_length", MAX_MESSAGE_SIZE), ], ) diff --git a/src/lerobot/scripts/rl/learner_service.py b/src/lerobot/scripts/rl/learner_service.py index 198e52945b..b07c296e6e 100644 --- a/src/lerobot/scripts/rl/learner_service.py +++ 
b/src/lerobot/scripts/rl/learner_service.py @@ -23,7 +23,6 @@ from lerobot.transport.utils import receive_bytes_in_chunks, send_bytes_in_chunks from lerobot.utils.queue import get_last_item_from_queue -MAX_MESSAGE_SIZE = 4 * 1024 * 1024 # 4 MB MAX_WORKERS = 3 # Stream parameters, send transitions and interactions SHUTDOWN_TIMEOUT = 10 diff --git a/src/lerobot/scripts/server/robot_client.py b/src/lerobot/scripts/server/robot_client.py index a6d7b72427..44d9cdf776 100644 --- a/src/lerobot/scripts/server/robot_client.py +++ b/src/lerobot/scripts/server/robot_client.py @@ -76,6 +76,7 @@ async_inference_pb2, # type: ignore async_inference_pb2_grpc, # type: ignore ) +from lerobot.transport.utils import grpc_channel_options class RobotClient: @@ -113,7 +114,9 @@ def __init__(self, config: RobotClientConfig): config.actions_per_chunk, config.policy_device, ) - self.channel = grpc.insecure_channel(self.server_address) + self.channel = grpc.insecure_channel( + self.server_address, grpc_channel_options(initial_backoff=f"{config.environment_dt:.4f}s") + ) self.stub = async_inference_pb2_grpc.AsyncInferenceStub(self.channel) self.logger.info(f"Initializing client to connect to server at {self.server_address}") diff --git a/src/lerobot/transport/utils.py b/src/lerobot/transport/utils.py index 1c66832624..bf1aab7554 100644 --- a/src/lerobot/transport/utils.py +++ b/src/lerobot/transport/utils.py @@ -16,6 +16,7 @@ # limitations under the License. 
import io +import json import logging import pickle # nosec B403: Safe usage for internal serialization only from multiprocessing import Event, Queue @@ -27,6 +28,7 @@ from lerobot.utils.transition import Transition CHUNK_SIZE = 2 * 1024 * 1024 # 2 MB +MAX_MESSAGE_SIZE = 4 * 1024 * 1024 # 4 MB def bytes_buffer_size(buffer: io.BytesIO) -> int: @@ -139,3 +141,42 @@ def transitions_to_bytes(transitions: list[Transition]) -> bytes: buffer = io.BytesIO() torch.save(transitions, buffer) return buffer.getvalue() + + +def grpc_channel_options( + max_receive_message_length: int = MAX_MESSAGE_SIZE, + max_send_message_length: int = MAX_MESSAGE_SIZE, + enable_retries: bool = True, + initial_backoff: str = "0.1s", + max_attempts: int = 5, + backoff_multiplier: float = 2, + max_backoff: str = "2s", +): + service_config = { + "methodConfig": [ + { + "name": [{}], # Applies to ALL methods in ALL services + "retryPolicy": { + "maxAttempts": max_attempts, # Total attempts, including the original call + "initialBackoff": initial_backoff, # Delay before the first retry + "maxBackoff": max_backoff, # Max wait time between retries + "backoffMultiplier": backoff_multiplier, # Exponential backoff factor + "retryableStatusCodes": [ + "UNAVAILABLE", + "DEADLINE_EXCEEDED", + ], # Retries on network failures + }, + } + ] + } + + service_config_json = json.dumps(service_config) + + retries_option = 1 if enable_retries else 0 + + return [ + ("grpc.max_receive_message_length", max_receive_message_length), + ("grpc.max_send_message_length", max_send_message_length), + ("grpc.enable_retries", retries_option), + ("grpc.service_config", service_config_json), + ]