Skip to content

Commit

Permalink
Add stream handler
Browse files Browse the repository at this point in the history
  • Loading branch information
apssouza22 committed Oct 22, 2024
1 parent e27a91a commit a9a5249
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 108 deletions.
4 changes: 2 additions & 2 deletions python/rtcconn.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,15 @@ def on_message(message):
self.on_message_fn(message, channel)

async def _on_connectionstatechange(self):
self.on_connection_state_change_fn()
await self.on_connection_state_change_fn()


def _on_track(self, track):
self.on_track_fn(track)

@track.on("ended")
async def on_ended():
self.on_track_end_fn(track)
await self.on_track_end_fn(track)

async def create_answer(self, offer_request: dict[str, str]):
offer = RTCSessionDescription(sdp=offer_request["sdp"], type=offer_request["type"])
Expand Down
60 changes: 10 additions & 50 deletions python/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@
from aiohttp import web
from aiortc.contrib.media import MediaBlackhole, MediaPlayer, MediaRecorder, MediaRelay

from stream_handle import StreamHandler
from rtcconn import RTCConnectionHandler
from video_transform import VideoTransformTrack

ROOT = os.path.dirname(__file__)

logger = logging.getLogger("pc")
pcs = set()
relay = MediaRelay()


async def index(request):
Expand All @@ -32,56 +32,16 @@ async def javascript(request):
async def offer(request):
params = await request.json()
rtc = RTCConnectionHandler()

pc = rtc.conn
pc_id = "PeerConnection(%s)" % uuid.uuid4()
pcs.add(pc)

def log_info(msg, *args):
logger.info(pc_id + " " + msg, *args)

log_info("Created for %s", request.remote)
stream = StreamHandler(rtc)
pcs.add(stream)
player = MediaPlayer(os.path.join(ROOT, "demo-instruct.wav"))
if args.record_to:
recorder = MediaRecorder(args.record_to)
else:
recorder = MediaBlackhole()

def on_message(message, channel):
if isinstance(message, str) and message.startswith("ping"):
channel.send("pong" + message[4:])

async def on_connectionstatechange():
log_info("Connection state is %s", pc.connectionState)
if pc.connectionState == "failed":
await pc.close()
pcs.discard(pc)

def on_track(track):
log_info("Track %s received", track.kind)

if track.kind == "audio":
pc.addTrack(player.audio)
recorder.addTrack(track)

elif track.kind == "video":
# Duplicate incoming video track so different transforms can be applied
track1 = relay.subscribe(track)
track2 = relay.subscribe(track)
pc.addTrack(VideoTransformTrack(track1, transform=params["video_transform"]))
if args.record_to:
recorder.addTrack(track2)

async def on_ended(track):
log_info("Track %s ended", track.kind)
await recorder.stop()

rtc.add_on_track(on_track)
rtc.add_on_message(on_message)
rtc.add_on_track_end(on_ended)
rtc.add_on_connection_state_change(on_connectionstatechange)
answer = await rtc.create_answer(params)
await recorder.start()
file = os.path.join(ROOT, "video.mp4")
print(file)
recorder = MediaRecorder(file)
# recorder = MediaBlackhole()
stream.set_media_player(player)
stream.set_media_recorder(recorder)
answer = await stream.start(params)

return web.Response(
content_type="application/json",
Expand Down
67 changes: 67 additions & 0 deletions python/stream_handle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import uuid

from aiortc.contrib.media import MediaRelay

from video_transform import VideoTransformTrack
from rtcconn import RTCConnectionHandler
import logging

logger = logging.getLogger("pc")


class StreamHandler:
    """Connect one peer connection to an optional media player and recorder.

    The handler registers its callbacks on an ``RTCConnectionHandler``,
    answers ``ping`` data-channel messages, sends the player's audio and a
    ``VideoTransformTrack`` back to the peer, and feeds incoming tracks to
    the recorder.

    The player and recorder are optional: both start as ``None`` and are
    supplied through ``set_media_player`` / ``set_media_recorder``. Every
    callback tolerates either of them being unset.
    """

    def __init__(self, conn: "RTCConnectionHandler"):
        """Bind the handler to ``conn`` and give the stream a unique id.

        Args:
            conn: connection wrapper exposing the peer connection as
                ``conn.conn`` plus the ``add_on_*`` registration methods
                and ``create_answer`` used by ``start``.
        """
        self.media_player = None
        self.conn_handler = conn
        self.rtc = conn.conn
        self.media_recorder = None
        # Prefix used by log_info so log lines of one stream can be grouped.
        self.stream_id = "PeerConnection(%s)" % uuid.uuid4()
        self.relay = MediaRelay()

    async def close(self):
        """Close the underlying peer connection."""
        await self.rtc.close()

    def set_media_recorder(self, media_recorder):
        """Set the recorder that incoming tracks are added to."""
        self.media_recorder = media_recorder

    def set_media_player(self, media_player):
        """Set the player whose ``audio`` track is sent back to the peer."""
        self.media_player = media_player

    def log_info(self, msg, *args):
        """Log ``msg`` at INFO level, prefixed with this stream's id.

        ``args`` are forwarded to the logger as lazy %-format arguments.
        """
        logger.info(self.stream_id + " " + msg, *args)

    def on_message(self, message, channel):
        """Reply ``pong<rest>`` on ``channel`` to a text ``ping<rest>``."""
        if isinstance(message, str) and message.startswith("ping"):
            channel.send("pong" + message[4:])

    async def on_connectionstatechange(self):
        """Log the connection state and close the connection once it failed."""
        self.log_info("Connection state is %s", self.rtc.connectionState)
        if self.rtc.connectionState == "failed":
            await self.rtc.close()

    def on_track(self, track):
        """Route an incoming track to the peer connection and the recorder.

        Audio: the player's audio track (if a player is set) is added to the
        peer connection and the incoming track is recorded (if a recorder is
        set). Video: the track is relayed into a ``VideoTransformTrack`` sent
        back to the peer and, if a recorder is set, into the recorder.
        Tracks of any other kind are only logged.
        """
        self.log_info("Track %s received", track.kind)

        if track.kind == "audio":
            if self.media_player is not None:
                self.rtc.addTrack(self.media_player.audio)
            if self.media_recorder is not None:
                self.media_recorder.addTrack(track)

        elif track.kind == "video":
            # Duplicate incoming video track so different transforms can be applied
            track1 = self.relay.subscribe(track)
            self.rtc.addTrack(VideoTransformTrack(track1))

            # Only take a second relay subscription when something will
            # actually consume it.
            if self.media_recorder is not None:
                track2 = self.relay.subscribe(track)
                self.media_recorder.addTrack(track2)

    async def on_ended(self, track):
        """Log the end of ``track`` and stop the recorder, if one is set."""
        self.log_info("Track %s ended", track.kind)
        if self.media_recorder is not None:
            await self.media_recorder.stop()

    async def start(self, params):
        """Register the callbacks, answer the offer, then start recording.

        Args:
            params: offer payload, passed unchanged to
                ``conn_handler.create_answer``.

        Returns:
            Whatever ``conn_handler.create_answer(params)`` returns.
        """
        # Callbacks must be in place before the offer is processed.
        self.conn_handler.add_on_track(self.on_track)
        self.conn_handler.add_on_message(self.on_message)
        self.conn_handler.add_on_track_end(self.on_ended)
        self.conn_handler.add_on_connection_state_change(self.on_connectionstatechange)

        answer = await self.conn_handler.create_answer(params)

        # Start the recorder only after the offer has been handled.
        # NOTE(review): this assumes on_track fires while create_answer
        # applies the remote offer, so the recorder already knows its tracks
        # here; a recorder started earlier would have nothing to record.
        # Confirm against RTCConnectionHandler.create_answer.
        if self.media_recorder is not None:
            await self.media_recorder.start()
        return answer
58 changes: 2 additions & 56 deletions python/video_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,64 +10,10 @@ class VideoTransformTrack(MediaStreamTrack):

kind = "video"

def __init__(self, track, transform):
def __init__(self, track):
super().__init__() # don't forget this!
self.track = track
self.transform = transform

async def recv(self):
frame = await self.track.recv()

if self.transform == "cartoon":
img = frame.to_ndarray(format="bgr24")

# prepare color
img_color = cv2.pyrDown(cv2.pyrDown(img))
for _ in range(6):
img_color = cv2.bilateralFilter(img_color, 9, 9, 7)
img_color = cv2.pyrUp(cv2.pyrUp(img_color))

# prepare edges
img_edges = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
img_edges = cv2.adaptiveThreshold(
cv2.medianBlur(img_edges, 7),
255,
cv2.ADAPTIVE_THRESH_MEAN_C,
cv2.THRESH_BINARY,
9,
2,
)
img_edges = cv2.cvtColor(img_edges, cv2.COLOR_GRAY2RGB)

# combine color and edges
img = cv2.bitwise_and(img_color, img_edges)

# rebuild a VideoFrame, preserving timing information
new_frame = VideoFrame.from_ndarray(img, format="bgr24")
new_frame.pts = frame.pts
new_frame.time_base = frame.time_base
return new_frame
elif self.transform == "edges":
# perform edge detection
img = frame.to_ndarray(format="bgr24")
img = cv2.cvtColor(cv2.Canny(img, 100, 200), cv2.COLOR_GRAY2BGR)

# rebuild a VideoFrame, preserving timing information
new_frame = VideoFrame.from_ndarray(img, format="bgr24")
new_frame.pts = frame.pts
new_frame.time_base = frame.time_base
return new_frame
elif self.transform == "rotate":
# rotate image
img = frame.to_ndarray(format="bgr24")
rows, cols, _ = img.shape
M = cv2.getRotationMatrix2D((cols / 2, rows / 2), frame.time * 45, 1)
img = cv2.warpAffine(img, M, (cols, rows))

# rebuild a VideoFrame, preserving timing information
new_frame = VideoFrame.from_ndarray(img, format="bgr24")
new_frame.pts = frame.pts
new_frame.time_base = frame.time_base
return new_frame
else:
return frame
return frame

0 comments on commit a9a5249

Please sign in to comment.