From 834a8aa9fbaf27c97e4fe2e128074435ad76429e Mon Sep 17 00:00:00 2001 From: derkmed Date: Fri, 2 Feb 2024 11:20:26 -0500 Subject: [PATCH] Apply code formatting to audio nodes --- .../angel_system_nodes/audio/asr.py | 17 +++++++---------- .../audio/dialogue_utterance_processing.py | 7 ++++--- .../audio/emotion/base_emotion_detector.py | 10 +++++----- .../audio/intent/base_intent_detector.py | 13 +++++-------- .../audio/intent/gpt_intent_detector.py | 4 ++-- 5 files changed, 23 insertions(+), 28 deletions(-) diff --git a/ros/angel_system_nodes/angel_system_nodes/audio/asr.py b/ros/angel_system_nodes/angel_system_nodes/audio/asr.py index 1d765e919..a24208dea 100644 --- a/ros/angel_system_nodes/angel_system_nodes/audio/asr.py +++ b/ros/angel_system_nodes/angel_system_nodes/audio/asr.py @@ -106,8 +106,9 @@ def __init__(self): self.subscription = self.create_subscription( HeadsetAudioData, self._audio_topic, self.listener_callback, 1 ) - self._publisher = self.create_publisher(DialogueUtterance, - self._utterances_topic, 1) + self._publisher = self.create_publisher( + DialogueUtterance, self._utterances_topic, 1 + ) self.audio_stream = [] self.t = threading.Thread() @@ -204,18 +205,15 @@ def asr_server_request_thread(self, audio_data, num_channels, sample_rate): if response: response_text = json.loads(response.text)["text"] self.log.info("Complete ASR text is:\n" + f'"{response_text}"') - self._publish_response(response_text, - self._is_sentence_tokenize_mode) + self._publish_response(response_text, self._is_sentence_tokenize_mode) def _publish_response(self, response_text: str, tokenize_sentences: bool): if tokenize_sentences: for sentence in sent_tokenize(response_text): - self._publisher.publish( - self._construct_dialogue_utterance(sentence)) + self._publisher.publish(self._construct_dialogue_utterance(sentence)) else: - self._publisher.publish( - self._construct_dialogue_utterance(response_text)) - + self._publisher.publish(self._construct_dialogue_utterance(response_text)) + def _construct_dialogue_utterance(self, msg_text: str) -> DialogueUtterance: msg = DialogueUtterance() msg.header.frame_id = "ASR" @@ -225,7 +223,6 @@ def _construct_dialogue_utterance(self, msg_text: str) -> DialogueUtterance: return msg - main = make_default_main(ASR) diff --git a/ros/angel_system_nodes/angel_system_nodes/audio/dialogue_utterance_processing.py b/ros/angel_system_nodes/angel_system_nodes/audio/dialogue_utterance_processing.py index 16f6b3a44..2674d7000 100644 --- a/ros/angel_system_nodes/angel_system_nodes/audio/dialogue_utterance_processing.py +++ b/ros/angel_system_nodes/angel_system_nodes/audio/dialogue_utterance_processing.py @@ -1,8 +1,9 @@ from angel_msgs.msg import DialogueUtterance -def copy_dialogue_utterance(msg: DialogueUtterance, - node_name, - copy_time) -> DialogueUtterance: + +def copy_dialogue_utterance( + msg: DialogueUtterance, node_name, copy_time +) -> DialogueUtterance: msg = DialogueUtterance() msg.header.frame_id = node_name msg.utterance_text = msg.utterance_text diff --git a/ros/angel_system_nodes/angel_system_nodes/audio/emotion/base_emotion_detector.py b/ros/angel_system_nodes/angel_system_nodes/audio/emotion/base_emotion_detector.py index bb8555a14..f2b98bf89 100644 --- a/ros/angel_system_nodes/angel_system_nodes/audio/emotion/base_emotion_detector.py +++ b/ros/angel_system_nodes/angel_system_nodes/audio/emotion/base_emotion_detector.py @@ -53,9 +53,7 @@ def __init__(self): self.emotion_detection_callback, 1, ) - self._publication = self.create_publisher( - DialogueUtterance, self._out_topic, 1 - ) + self._publication = self.create_publisher(DialogueUtterance, self._out_topic, 1) self.message_queue = queue.Queue() self.handler_thread = threading.Thread(target=self.process_message_queue) @@ -119,8 +117,10 @@ def process_message(self, msg: DialogueUtterance): """ classification, confidence_score = self.get_inference(msg) pub_msg = dialogue_utterance_processing.copy_dialogue_utterance( - msg, node_name="Emotion Detection", - copy_time=self.get_clock().now().to_msg()) + msg, + node_name="Emotion Detection", + copy_time=self.get_clock().now().to_msg(), + ) # Overwrite the user emotion with the latest classification information. pub_msg.emotion = classification pub_msg.emotion_confidence_score = confidence_score diff --git a/ros/angel_system_nodes/angel_system_nodes/audio/intent/base_intent_detector.py b/ros/angel_system_nodes/angel_system_nodes/audio/intent/base_intent_detector.py index 3ff903d18..cb1c7f6fc 100644 --- a/ros/angel_system_nodes/angel_system_nodes/audio/intent/base_intent_detector.py +++ b/ros/angel_system_nodes/angel_system_nodes/audio/intent/base_intent_detector.py @@ -110,30 +110,27 @@ def _tiebreak_intents(intents, confidences): if not intents: colored_utterance = colored(msg.utterance_text, "light_blue") - self.log.info( - f'No intents detected for:\n>>> "{colored_utterance}":') + self.log.info(f'No intents detected for:\n>>> "{colored_utterance}":') return None, -1.0 else: classification, confidence = _tiebreak_intents(intents, confidences) classification = colored(classification, "light_green") self.publish_message(msg.utterance_text, classification, confidence) - def publish_message(self, msg: DialogueUtterance, intent: str, - score: float): + def publish_message(self, msg: DialogueUtterance, intent: str, score: float): """ Handles message publishing for an utterance with a detected intent. """ pub_msg = self.copy_dialogue_utterance( - msg, node_name="Intent Detection", - copy_time=self.get_clock().now().to_msg()) + msg, node_name="Intent Detection", copy_time=self.get_clock().now().to_msg() + ) # Overwrite the user intent with the latest classification information. pub_msg.intent = intent pub_msg.intent_confidence_score = score # Decide which intent topic to publish the message to. published_topic = None - if self._contains_phrase(pub_msg.utterance_text.lower(), - OVERRIDE_KEYPHRASES): + if self._contains_phrase(pub_msg.utterance_text.lower(), OVERRIDE_KEYPHRASES): published_topic = PARAM_EXPECT_USER_INTENT_TOPIC pub_msg.intent_confidence_score = 1.0 self._expected_publisher.publish(pub_msg) diff --git a/ros/angel_system_nodes/angel_system_nodes/audio/intent/gpt_intent_detector.py b/ros/angel_system_nodes/angel_system_nodes/audio/intent/gpt_intent_detector.py index 8c5d29a78..47099c2ce 100644 --- a/ros/angel_system_nodes/angel_system_nodes/audio/intent/gpt_intent_detector.py +++ b/ros/angel_system_nodes/angel_system_nodes/audio/intent/gpt_intent_detector.py @@ -13,7 +13,6 @@ from angel_utils import declare_and_get_parameters, make_default_main - openai.organization = os.getenv("OPENAI_ORG_ID") openai.api_key = os.getenv("OPENAI_API_KEY") @@ -28,6 +27,7 @@ PARAM_TIMEOUT = "timeout" + class GptIntentDetector(BaseIntentDetector): def __init__(self): super().__init__() @@ -100,7 +100,7 @@ def detect_intents(self, msg: DialogueUtterance): Detects the user intent via langchain execution of GPT. """ intent = self.chain.run(utterance=msg.utterance_text) - return intent.split('[eos]')[0], 0.5 + return intent.split("[eos]")[0], 0.5 main = make_default_main(GptIntentDetector)