diff --git a/python/rtcconn.py b/python/rtcconn.py
index b06e471..ddb1694 100644
--- a/python/rtcconn.py
+++ b/python/rtcconn.py
@@ -42,7 +42,7 @@ def on_message(message):
             self.on_message_fn(message, channel)
 
     async def _on_connectionstatechange(self):
-        self.on_connection_state_change_fn()
+        await self.on_connection_state_change_fn()
 
 
     def _on_track(self, track):
@@ -50,7 +50,7 @@ def _on_track(self, track):
 
         @track.on("ended")
         async def on_ended():
-            self.on_track_end_fn(track)
+            await self.on_track_end_fn(track)
 
     async def create_answer(self, offer_request: dict[str, str]):
         offer = RTCSessionDescription(sdp=offer_request["sdp"], type=offer_request["type"])
diff --git a/python/server.py b/python/server.py
index 2ea8aa5..2252049 100644
--- a/python/server.py
+++ b/python/server.py
@@ -9,6 +9,7 @@
 from aiohttp import web
 from aiortc.contrib.media import MediaBlackhole, MediaPlayer, MediaRecorder, MediaRelay
 
+from stream_handle import StreamHandler
 from rtcconn import RTCConnectionHandler
 from video_transform import VideoTransformTrack
 
@@ -16,7 +17,6 @@
 logger = logging.getLogger("pc")
 
 pcs = set()
-relay = MediaRelay()
 
 
 async def index(request):
@@ -32,56 +32,16 @@ async def javascript(request):
 async def offer(request):
     params = await request.json()
     rtc = RTCConnectionHandler()
-
-    pc = rtc.conn
-    pc_id = "PeerConnection(%s)" % uuid.uuid4()
-    pcs.add(pc)
-
-    def log_info(msg, *args):
-        logger.info(pc_id + " " + msg, *args)
-
-    log_info("Created for %s", request.remote)
+    stream = StreamHandler(rtc)
+    pcs.add(stream)
 
     player = MediaPlayer(os.path.join(ROOT, "demo-instruct.wav"))
-    if args.record_to:
-        recorder = MediaRecorder(args.record_to)
-    else:
-        recorder = MediaBlackhole()
-
-    def on_message(message, channel):
-        if isinstance(message, str) and message.startswith("ping"):
-            channel.send("pong" + message[4:])
-
-    async def on_connectionstatechange():
-        log_info("Connection state is %s", pc.connectionState)
-        if pc.connectionState == "failed":
-            await pc.close()
-            pcs.discard(pc)
-
-    def on_track(track):
-        log_info("Track %s received", track.kind)
-
-        if track.kind == "audio":
-            pc.addTrack(player.audio)
-            recorder.addTrack(track)
-
-        elif track.kind == "video":
-            # Duplicate incoming video track so different transforms can be applied
-            track1 = relay.subscribe(track)
-            track2 = relay.subscribe(track)
-            pc.addTrack(VideoTransformTrack(track1, transform=params["video_transform"]))
-            if args.record_to:
-                recorder.addTrack(track2)
-
-    async def on_ended(track):
-        log_info("Track %s ended", track.kind)
-        await recorder.stop()
-
-    rtc.add_on_track(on_track)
-    rtc.add_on_message(on_message)
-    rtc.add_on_track_end(on_ended)
-    rtc.add_on_connection_state_change(on_connectionstatechange)
-    answer = await rtc.create_answer(params)
-    await recorder.start()
+    file = os.path.join(ROOT, "video.mp4")
+    print(file)
+    recorder = MediaRecorder(file)
+    # recorder = MediaBlackhole()
+    stream.set_media_player(player)
+    stream.set_media_recorder(recorder)
+    answer = await stream.start(params)
     return web.Response(
         content_type="application/json",
diff --git a/python/stream_handle.py b/python/stream_handle.py
new file mode 100644
index 0000000..617bbb4
--- /dev/null
+++ b/python/stream_handle.py
@@ -0,0 +1,67 @@
+import uuid
+
+from aiortc.contrib.media import MediaRelay
+
+from video_transform import VideoTransformTrack
+from rtcconn import RTCConnectionHandler
+import logging
+
+logger = logging.getLogger("pc")
+
+
+class StreamHandler:
+    def __init__(self, conn: RTCConnectionHandler):
+        self.media_player = None
+        self.conn_handler = conn
+        self.rtc = conn.conn
+        self.media_recorder = None
+        self.stream_id = "PeerConnection(%s)" % uuid.uuid4()
+        self.relay = MediaRelay()
+
+    async def close(self):
+        await self.rtc.close()
+
+    def set_media_recorder(self, media_recorder):
+        self.media_recorder = media_recorder
+
+    def set_media_player(self, media_player):
+        self.media_player = media_player
+
+    def log_info(self, msg, *args):
+        logger.info(self.stream_id + " " + msg, *args)
+
+    def on_message(self, message, channel):
+        if isinstance(message, str) and message.startswith("ping"):
+            channel.send("pong" + message[4:])
+
+    async def on_connectionstatechange(self):
+        self.log_info("Connection state is %s", self.rtc.connectionState)
+        if self.rtc.connectionState == "failed":
+            await self.rtc.close()
+
+    def on_track(self, track):
+        self.log_info("Track %s received", track.kind)
+
+        if track.kind == "audio":
+            self.rtc.addTrack(self.media_player.audio)
+            self.media_recorder.addTrack(track)
+
+        elif track.kind == "video":
+            # Duplicate incoming video track so different transforms can be applied
+            track1 = self.relay.subscribe(track)
+            track2 = self.relay.subscribe(track)
+
+            self.rtc.addTrack(VideoTransformTrack(track1))
+            self.media_recorder.addTrack(track2)
+
+    async def on_ended(self, track):
+        self.log_info("Track %s ended", track.kind)
+        await self.media_recorder.stop()
+
+    async def start(self, params):
+        await self.media_recorder.start()
+        self.conn_handler.add_on_track(self.on_track)
+        self.conn_handler.add_on_message(self.on_message)
+        self.conn_handler.add_on_track_end(self.on_ended)
+        self.conn_handler.add_on_connection_state_change(self.on_connectionstatechange)
+        return await self.conn_handler.create_answer(params)
diff --git a/python/video_transform.py b/python/video_transform.py
index 28d4b74..296125e 100644
--- a/python/video_transform.py
+++ b/python/video_transform.py
@@ -10,64 +10,10 @@ class VideoTransformTrack(MediaStreamTrack):
 
     kind = "video"
 
-    def __init__(self, track, transform):
+    def __init__(self, track):
         super().__init__()  # don't forget this!
         self.track = track
-        self.transform = transform
 
     async def recv(self):
         frame = await self.track.recv()
-
-        if self.transform == "cartoon":
-            img = frame.to_ndarray(format="bgr24")
-
-            # prepare color
-            img_color = cv2.pyrDown(cv2.pyrDown(img))
-            for _ in range(6):
-                img_color = cv2.bilateralFilter(img_color, 9, 9, 7)
-            img_color = cv2.pyrUp(cv2.pyrUp(img_color))
-
-            # prepare edges
-            img_edges = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
-            img_edges = cv2.adaptiveThreshold(
-                cv2.medianBlur(img_edges, 7),
-                255,
-                cv2.ADAPTIVE_THRESH_MEAN_C,
-                cv2.THRESH_BINARY,
-                9,
-                2,
-            )
-            img_edges = cv2.cvtColor(img_edges, cv2.COLOR_GRAY2RGB)
-
-            # combine color and edges
-            img = cv2.bitwise_and(img_color, img_edges)
-
-            # rebuild a VideoFrame, preserving timing information
-            new_frame = VideoFrame.from_ndarray(img, format="bgr24")
-            new_frame.pts = frame.pts
-            new_frame.time_base = frame.time_base
-            return new_frame
-        elif self.transform == "edges":
-            # perform edge detection
-            img = frame.to_ndarray(format="bgr24")
-            img = cv2.cvtColor(cv2.Canny(img, 100, 200), cv2.COLOR_GRAY2BGR)
-
-            # rebuild a VideoFrame, preserving timing information
-            new_frame = VideoFrame.from_ndarray(img, format="bgr24")
-            new_frame.pts = frame.pts
-            new_frame.time_base = frame.time_base
-            return new_frame
-        elif self.transform == "rotate":
-            # rotate image
-            img = frame.to_ndarray(format="bgr24")
-            rows, cols, _ = img.shape
-            M = cv2.getRotationMatrix2D((cols / 2, rows / 2), frame.time * 45, 1)
-            img = cv2.warpAffine(img, M, (cols, rows))
-
-            # rebuild a VideoFrame, preserving timing information
-            new_frame = VideoFrame.from_ndarray(img, format="bgr24")
-            new_frame.pts = frame.pts
-            new_frame.time_base = frame.time_base
-            return new_frame
-        else:
-            return frame
\ No newline at end of file
+        return frame
\ No newline at end of file