Skip to content

Commit

Permalink
type checking etc
Browse files Browse the repository at this point in the history
  • Loading branch information
ottojo committed Aug 25, 2023
1 parent 74469a3 commit 6d0e854
Show file tree
Hide file tree
Showing 15 changed files with 119 additions and 37 deletions.
9 changes: 4 additions & 5 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
{
"python.formatting.autopep8Args": [
"--max-line-length=150"
],
"python.autoComplete.extraPaths": [
"/home/gja38/sandbox_otto/install/orchestrator_interfaces/local/lib/python3.10/dist-packages",
"/opt/carolo/install/spatz_interfaces/local/lib/python3.10/dist-packages",
Expand All @@ -18,7 +15,6 @@
"/home/gja38/sandbox_otto/src/orchestrator",
"/workspaces/ros2_def/install/orchestrator_interfaces/lib/python3.10/site-packages"
],
"python.analysis.typeCheckingMode": "basic",
"python.analysis.inlayHints.functionReturnTypes": true,
"python.analysis.inlayHints.variableTypes": true,
"files.exclude": {
Expand Down Expand Up @@ -47,5 +43,8 @@
"editor.stickyScroll.enabled": true,
"terminal.integrated.scrollback": 10000,
"python.analysis.stubPath": "ros2/orchestrator/typings",
"cmake.configureOnOpen": false
"cmake.configureOnOpen": false,
"[python]": {
"editor.defaultFormatter": "ms-python.autopep8"
}
}
2 changes: 2 additions & 0 deletions ros2/orchestrator/orchestrator/lib.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
from orchestrator.orchestrator_lib.orchestrator import Orchestrator

__all__ = ["Orchestrator"]
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def __init__(self, name: str, remappings: SimpleRemapRules) -> None:
super().__init__()

@abstractmethod
def state_sequence_push(self, x: Any):
def state_sequence_push(self, message: Any):
...

@abstractmethod
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
# pyright: strict

from __future__ import annotations

import base64
import json
from dataclasses import dataclass
from typing import cast, final, Dict, Optional, Any
from typing import Iterable, cast, final, Optional, Any

from orchestrator.orchestrator_lib.name_utils import normalize_topic_name
from orchestrator.orchestrator_lib.node_model import Cause, Effect, NodeModel, ServiceCall, ServiceName, \
Expand Down Expand Up @@ -41,23 +43,23 @@ def dump_state_sequence(self):
with open('state_sequence_' + self.get_name() + '.json', 'w') as f:
json.dump(self.state_recording, f)

def __init__(self, node_config: dict, name, remappings: Dict[str, str],
state_sequence: Optional[list] = None) -> None:
def __init__(self, node_config: dict[str, Any], name: str, remappings: dict[str, str],
state_sequence: Optional[list[str]] = None) -> None:

self.state_sequence = state_sequence
self.state_sequence: Optional[list[str]] = state_sequence
if self.state_sequence is not None:
self.state_sequence.reverse()
self.state_recording = []
self.state_recording: list[str] = []

# Mappings from internal to external name
mappings: dict[str, str] = {}

inputs = set()
inputs: set[str] = set()

# Initialize mappings by identity for all known inputs and outputs from
# node config.
for callback in node_config["callbacks"]:
trigger = callback["trigger"]
for callback in cast(Iterable[dict[str, Any]], node_config["callbacks"]):
trigger: str | dict[str, Any] = cast(str | dict[str, Any], callback["trigger"])

if isinstance(trigger, str):
trigger = normalize_topic_name(trigger)
Expand Down Expand Up @@ -114,7 +116,7 @@ def __init__(self, node_config: dict, name, remappings: Dict[str, str],
# Mapping from external topic input to external topic outputs
self.effects: dict[Cause, Callback] = {}

def add_effect(trigger: Cause, outputs, service_calls, changes_dp_state: bool, may_reconfigure: bool):
def add_effect(trigger: Cause, outputs: Iterable[str], service_calls: Iterable[str], changes_dp_state: bool, may_reconfigure: bool):
output_effects: list[Effect] = []
for output in outputs:
output = normalize_topic_name(output)
Expand All @@ -136,19 +138,19 @@ def add_effect(trigger: Cause, outputs, service_calls, changes_dp_state: bool, m

for callback in node_config["callbacks"]:
trigger = callback["trigger"]

cause: Cause
if isinstance(trigger, str):
trigger = normalize_topic_name(trigger)
trigger = self.internal_topic_input(trigger)
add_effect(trigger,
cause = self.internal_topic_input(trigger)
add_effect(cause,
callback.get("outputs", []),
callback.get("service_calls", []),
callback.get("changes_dataprovider_state", False),
callback.get("may_cause_reconfiguration", False))
elif trigger.get("type", None) == "topic" and "name" in trigger:
topic_name = normalize_topic_name(trigger["name"])
trigger = self.internal_topic_input(topic_name)
add_effect(trigger,
cause = self.internal_topic_input(topic_name)
add_effect(cause,
callback.get("outputs", []),
callback.get("service_calls", []),
callback.get("changes_dataprovider_state", False),
Expand All @@ -164,8 +166,7 @@ def add_effect(trigger: Cause, outputs, service_calls, changes_dp_state: bool, m
callback.get("service_calls", []),
callback.get("changes_dataprovider_state", False),
callback.get("may_cause_reconfiguration", False))
elif trigger.get("type",
None) == "approximate_time_sync" and "input_topics" in trigger and "slop" in trigger and "queue_size" in trigger:
elif trigger.get("type", None) == "approximate_time_sync" and "input_topics" in trigger and "slop" in trigger and "queue_size" in trigger:
input_topics = trigger["input_topics"]
slop = trigger["slop"]
queue = trigger["queue_size"]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# at some point, enable: pyright: strict, reportMissingTypeStubs=true
# pyright: basic

import datetime
from dataclasses import dataclass
Expand Down Expand Up @@ -235,7 +235,7 @@ def intercept_topic(canonical_name: TopicName, node: NodeModel):
TopicType,
effect.output_topic,
lambda msg,
topic_name=effect.output_topic: self.__interception_subscription_callback(
topic_name=effect.output_topic: self.__interception_subscription_callback(
topic_name, msg),
10, raw=(TopicType != rosgraph_msgs.msg.Clock))
self.interception_subs[effect.output_topic] = sub
Expand Down Expand Up @@ -648,7 +648,7 @@ def __buffer_nodes_with_data(self) -> Generator[Tuple[GraphNodeId, OrchestratorB
yield (id, action)

def __buffer_childs_of_parent(self, parent: GraphNodeId) -> Generator[
Tuple[GraphNodeId, OrchestratorBufferAction], None, None]:
Tuple[GraphNodeId, OrchestratorBufferAction], None, None]:
for id, data in self.__buffer_nodes_with_data():
if self.graph.has_edge(id, parent):
yield id, data
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# pyright: strict

import sys

from typing import List, Dict
Expand All @@ -6,8 +8,8 @@
from orchestrator.orchestrator_lib.name_utils import intercepted_name, normalize_topic_name
from .model_loader import *

from launch_ros.actions import SetRemap
from launch.substitutions import TextSubstitution
from launch_ros.actions import SetRemap # pyright: ignore [reportMissingTypeStubs]
from launch.substitutions import TextSubstitution # pyright: ignore [reportMissingTypeStubs]


def _find_node_model(name: str, models: List[NodeModel]) -> NodeModel:
Expand Down Expand Up @@ -38,7 +40,7 @@ def generate_remappings_from_config_file(package_name: str, launch_config_file:
return generate_remappings_from_config(launch_config)


def generate_remappings_from_config(launch_config: dict) -> List[SetRemap]:
def generate_remappings_from_config(launch_config: dict[str, Any]) -> List[SetRemap]:
"""
Generate remappings for topic interception by orchestrator.
Expand All @@ -49,7 +51,7 @@ def generate_remappings_from_config(launch_config: dict) -> List[SetRemap]:
"""
node_models = load_models(launch_config, load_node_config_schema())

remap_actions = []
remap_actions: list[SetRemap] = []

for node_name, node in launch_config["nodes"].items():
if "/" in node_name:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,2 +1,6 @@
def lc(logger, msg):
# pyright: strict
from rclpy.impl.rcutils_logger import RcutilsLogger


def lc(logger: RcutilsLogger, msg: str) -> bool:
return logger.info('\033[96m' + msg + '\033[0m')
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# pyright: strict

from typing import Any, List, Mapping
from threading import Lock
from message_filters import SimpleFilter, ApproximateTimeSynchronizer
Expand All @@ -19,7 +21,7 @@ def __init__(self, topic_names: List[str], queue_size: int, slop: float) -> None
self._lock = Lock()
self._time_synchronizer.registerCallback(self.__callback)

def __callback(self, *_msgs):
def __callback(self, *_msgs: Any):
self._was_called = True

def test_input(self, topic_name: str, msg: Any) -> bool:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import Type, Optional
# pyright: strict

from typing import Any, Type, Optional

from rclpy.task import Future
from rclpy.executors import Executor
Expand All @@ -8,7 +10,7 @@
from orchestrator.orchestrator_lib.name_utils import TopicName, type_from_string


def wait_for_topic(name: TopicName, logger: RcutilsLogger, node: Node, executor: Executor) -> Type:
def wait_for_topic(name: TopicName, logger: RcutilsLogger, node: Node, executor: Executor) -> Type[Any]:
name = node.resolve_topic_name(name)

def find_type():
Expand All @@ -29,10 +31,10 @@ def find_type():
return msgtype


def wait_for_node_sub(topic_name: str, node_name: str, logger: RcutilsLogger, node: Node, executor: Executor) -> Type:
def wait_for_node_sub(topic_name: str, node_name: str, logger: RcutilsLogger, node: Node, executor: Executor) -> Type[Any]:
topic_name = node.resolve_topic_name(topic_name)

def try_get_type() -> Optional[Type]:
def try_get_type() -> Optional[Type[Any]]:
for info in node.get_subscriptions_info_by_topic(topic_name):
if info.node_name == node_name:
return type_from_string(info.topic_type)
Expand All @@ -59,7 +61,7 @@ def wait_for_node_pub(topic_name: str, node_name: str, logger: RcutilsLogger, no

def node_has_pub():
by_node = node.get_publisher_names_and_types_by_node(node_name, node.get_namespace())
for topic, types in by_node:
for topic, _types in by_node:
if topic == topic_name:
return True
return False
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# pyright: strict

import datetime
from rclpy.executors import Executor

Expand Down
4 changes: 3 additions & 1 deletion ros2/orchestrator/pyrightconfig.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
{
"include": [
"orchestrator"
]
],
"typeCheckingMode": "strict",
"reportUnnecessaryTypeIgnoreComment": "error"
}
8 changes: 8 additions & 0 deletions ros2/orchestrator/typings/deepdiff/__init__.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
"""
This type stub file was generated by pyright.
"""

import logging
from .diff import DeepDiff

__all__ = ["DeepDiff"]
8 changes: 8 additions & 0 deletions ros2/orchestrator/typings/deepdiff/diff.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
"""
This type stub file was generated by pyright.
"""


class DeepDiff:
def __init__(self, t1, t2) -> None:
...
37 changes: 37 additions & 0 deletions ros2/orchestrator/typings/message_filters/__init__.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
"""
This type stub file was generated by pyright.
"""


from typing import Any, Callable, Iterable


class SimpleFilter:
def __init__(self) -> None:
...

def registerCallback(self, cb: Callable[..., None], *args: Any) -> int:
...

def signalMessage(self, *msg: Any) -> None:
...


class TimeSynchronizer(SimpleFilter):
def __init__(self, fs, queue_size) -> None:
...

def connectInput(self, fs): # -> None:
...

def add(self, msg, my_queue, my_queue_index=...): # -> None:
...


class ApproximateTimeSynchronizer(TimeSynchronizer):

def __init__(self, fs: Iterable[SimpleFilter], queue_size, slop, allow_headerless=...) -> None:
...

def add(self, msg, my_queue, my_queue_index=...): # -> None:
...
13 changes: 13 additions & 0 deletions ros2/orchestrator/typings/rclpy/serialization.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
"""
This type stub file was generated by pyright.
"""

from typing import Any, Type


def serialize_message(message: Any) -> bytes:
...


def deserialize_message(serialized_message: bytes, message_type: Type[Any]) -> Any:
...

0 comments on commit 6d0e854

Please sign in to comment.