31 changes: 2 additions & 29 deletions src/lerobot/scripts/rl/actor.py
@@ -63,12 +63,12 @@
 from lerobot.policies.factory import make_policy
 from lerobot.policies.sac.modeling_sac import SACPolicy
 from lerobot.robots import so100_follower  # noqa: F401
-from lerobot.scripts.rl import learner_service
 from lerobot.scripts.rl.gym_manipulator import make_robot_env
 from lerobot.teleoperators import gamepad, so101_leader  # noqa: F401
 from lerobot.transport import services_pb2, services_pb2_grpc
 from lerobot.transport.utils import (
     bytes_to_state_dict,
+    grpc_channel_options,
     python_object_to_bytes,
     receive_bytes_in_chunks,
     send_bytes_in_chunks,
@@ -399,43 +399,16 @@ def learner_service_client(
     host: str = "127.0.0.1",
     port: int = 50051,
 ) -> tuple[services_pb2_grpc.LearnerServiceStub, grpc.Channel]:
-    import json
-
     """
     Returns a client for the learner service.

     GRPC uses HTTP/2, which is a binary protocol and multiplexes requests over a single connection.
     So we need to create only one client and reuse it.
     """

-    service_config = {
-        "methodConfig": [
-            {
-                "name": [{}],  # Applies to ALL methods in ALL services
-                "retryPolicy": {
-                    "maxAttempts": 5,  # Max retries (total attempts = 5)
-                    "initialBackoff": "0.1s",  # First retry after 0.1s
-                    "maxBackoff": "2s",  # Max wait time between retries
-                    "backoffMultiplier": 2,  # Exponential backoff factor
-                    "retryableStatusCodes": [
-                        "UNAVAILABLE",
-                        "DEADLINE_EXCEEDED",
-                    ],  # Retries on network failures
-                },
-            }
-        ]
-    }
-
-    service_config_json = json.dumps(service_config)
-
     channel = grpc.insecure_channel(
         f"{host}:{port}",
-        options=[
-            ("grpc.max_receive_message_length", learner_service.MAX_MESSAGE_SIZE),
-            ("grpc.max_send_message_length", learner_service.MAX_MESSAGE_SIZE),
-            ("grpc.enable_retries", 1),
-            ("grpc.service_config", service_config_json),
-        ],
+        grpc_channel_options(),
     )
     stub = services_pb2_grpc.LearnerServiceStub(channel)
     logging.info("[ACTOR] Learner service client created")
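With the inline retry config gone, the whole client factory collapses to a few lines. A minimal sketch of the resulting function, assuming the `grpc_channel_options` defaults defined in `lerobot/transport/utils.py` and that the function (truncated above) returns the stub/channel pair its annotation promises:

```python
import logging

import grpc

from lerobot.transport import services_pb2_grpc
from lerobot.transport.utils import grpc_channel_options


def learner_service_client(
    host: str = "127.0.0.1",
    port: int = 50051,
) -> tuple[services_pb2_grpc.LearnerServiceStub, grpc.Channel]:
    # gRPC runs over HTTP/2 and multiplexes requests on one connection,
    # so a single channel and stub are created once and reused.
    channel = grpc.insecure_channel(f"{host}:{port}", grpc_channel_options())
    stub = services_pb2_grpc.LearnerServiceStub(channel)
    logging.info("[ACTOR] Learner service client created")
    return stub, channel
```

The second positional argument of `grpc.insecure_channel` is `options`, so the list the helper returns can be passed directly without the `options=` keyword.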
5 changes: 3 additions & 2 deletions src/lerobot/scripts/rl/learner.py
@@ -77,6 +77,7 @@
 from lerobot.teleoperators import gamepad, so101_leader  # noqa: F401
 from lerobot.transport import services_pb2_grpc
 from lerobot.transport.utils import (
+    MAX_MESSAGE_SIZE,
     bytes_to_python_object,
     bytes_to_transitions,
     state_to_bytes,
@@ -658,8 +659,8 @@ def start_learner(
     server = grpc.server(
         ThreadPoolExecutor(max_workers=learner_service.MAX_WORKERS),
         options=[
-            ("grpc.max_receive_message_length", learner_service.MAX_MESSAGE_SIZE),
-            ("grpc.max_send_message_length", learner_service.MAX_MESSAGE_SIZE),
+            ("grpc.max_receive_message_length", MAX_MESSAGE_SIZE),
+            ("grpc.max_send_message_length", MAX_MESSAGE_SIZE),
         ],
     )

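Note the asymmetry: the retry policy bundled into `grpc_channel_options` is a client-side service config, so the server keeps a plain options list and only needs the shared size caps. Moving `MAX_MESSAGE_SIZE` into `lerobot.transport.utils` (next to `CHUNK_SIZE`) also lets both endpoints read the same limit without `learner.py` reaching into `learner_service.py` for a transport constant. A sketch of the symmetric limits under that assumption:

```python
from concurrent.futures import ThreadPoolExecutor

import grpc

from lerobot.transport.utils import MAX_MESSAGE_SIZE, grpc_channel_options

# Server side: only the size caps matter; retries are configured by clients.
server = grpc.server(
    ThreadPoolExecutor(max_workers=3),
    options=[
        ("grpc.max_receive_message_length", MAX_MESSAGE_SIZE),
        ("grpc.max_send_message_length", MAX_MESSAGE_SIZE),
    ],
)

# Client side: the helper bundles the same caps with the retry policy.
channel = grpc.insecure_channel("127.0.0.1:50051", grpc_channel_options())
```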
1 change: 0 additions & 1 deletion src/lerobot/scripts/rl/learner_service.py
@@ -23,7 +23,6 @@
 from lerobot.transport.utils import receive_bytes_in_chunks, send_bytes_in_chunks
 from lerobot.utils.queue import get_last_item_from_queue

-MAX_MESSAGE_SIZE = 4 * 1024 * 1024  # 4 MB
 MAX_WORKERS = 3  # Stream parameters, send transitions and interactions
 SHUTDOWN_TIMEOUT = 10

5 changes: 4 additions & 1 deletion src/lerobot/scripts/server/robot_client.py
@@ -76,6 +76,7 @@
     async_inference_pb2,  # type: ignore
     async_inference_pb2_grpc,  # type: ignore
 )
+from lerobot.transport.utils import grpc_channel_options


 class RobotClient:
@@ -113,7 +114,9 @@ def __init__(self, config: RobotClientConfig):
             config.actions_per_chunk,
             config.policy_device,
         )
-        self.channel = grpc.insecure_channel(self.server_address)
+        self.channel = grpc.insecure_channel(
+            self.server_address, grpc_channel_options(initial_backoff=f"{config.environment_dt:.4f}s")
+        )
         self.stub = async_inference_pb2_grpc.AsyncInferenceStub(self.channel)
         self.logger.info(f"Initializing client to connect to server at {self.server_address}")

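Here the client ties the first retry delay to the control period, so a retry fires roughly one environment step after a failure. gRPC expects backoff values as string durations, hence the f-string formatting. A small sketch of that step (`environment_dt` is illustrative, standing in for `RobotClientConfig.environment_dt`):

```python
from lerobot.transport.utils import grpc_channel_options

environment_dt = 1 / 30  # e.g. a 30 Hz control loop

# Format the period as the string duration gRPC expects, e.g. "0.0333s".
initial_backoff = f"{environment_dt:.4f}s"

options = grpc_channel_options(initial_backoff=initial_backoff)
```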
41 changes: 41 additions & 0 deletions src/lerobot/transport/utils.py
@@ -16,6 +16,7 @@
 # limitations under the License.

 import io
+import json
 import logging
 import pickle  # nosec B403: Safe usage for internal serialization only
 from multiprocessing import Event, Queue
@@ -27,6 +28,7 @@
 from lerobot.utils.transition import Transition

 CHUNK_SIZE = 2 * 1024 * 1024  # 2 MB
+MAX_MESSAGE_SIZE = 4 * 1024 * 1024  # 4 MB


 def bytes_buffer_size(buffer: io.BytesIO) -> int:
@@ -139,3 +141,42 @@ def transitions_to_bytes(transitions: list[Transition]) -> bytes:
     buffer = io.BytesIO()
     torch.save(transitions, buffer)
     return buffer.getvalue()
+
+
+def grpc_channel_options(
+    max_receive_message_length: int = MAX_MESSAGE_SIZE,
+    max_send_message_length: int = MAX_MESSAGE_SIZE,
+    enable_retries: bool = True,
+    initial_backoff: str = "0.1s",
+    max_attempts: int = 5,
+    backoff_multiplier: float = 2,
+    max_backoff: str = "2s",
+) -> list[tuple[str, str | int]]:
+    service_config = {
+        "methodConfig": [
+            {
+                "name": [{}],  # Applies to ALL methods in ALL services
+                "retryPolicy": {
+                    "maxAttempts": max_attempts,  # Total attempts, including the first call
+                    "initialBackoff": initial_backoff,  # Delay before the first retry
+                    "maxBackoff": max_backoff,  # Max wait time between retries
+                    "backoffMultiplier": backoff_multiplier,  # Exponential backoff factor
+                    "retryableStatusCodes": [
+                        "UNAVAILABLE",
+                        "DEADLINE_EXCEEDED",
+                    ],  # Retry on network failures
+                },
+            }
+        ]
+    }
+
+    service_config_json = json.dumps(service_config)
+
+    retries_option = 1 if enable_retries else 0
+
+    return [
+        ("grpc.max_receive_message_length", max_receive_message_length),
+        ("grpc.max_send_message_length", max_send_message_length),
+        ("grpc.enable_retries", retries_option),
+        ("grpc.service_config", service_config_json),
+    ]
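The helper returns a plain list of `(name, value)` option tuples, with the retry policy serialized to JSON under the `grpc.service_config` key. A quick usage sketch, assuming the function as merged above:

```python
import json

from lerobot.transport.utils import grpc_channel_options

options = grpc_channel_options(initial_backoff="0.05s", max_attempts=3)

# Size limits and the retry toggle are plain integer options...
assert dict(options)["grpc.enable_retries"] == 1

# ...while the retry policy travels as a JSON-encoded service config.
retry = json.loads(dict(options)["grpc.service_config"])["methodConfig"][0]["retryPolicy"]
assert retry["maxAttempts"] == 3
assert retry["initialBackoff"] == "0.05s"
```

Per gRPC retry semantics, `maxAttempts` counts the original call plus its retries, and retries fire only on the listed status codes (`UNAVAILABLE`, `DEADLINE_EXCEEDED`).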