Skip to content

Commit

Permalink
Merge pull request #1198 from Codium-ai/tr/polling
Browse files Browse the repository at this point in the history
Tr/polling
  • Loading branch information
mrT23 authored Sep 5, 2024
2 parents 6f17c08 + 85754d2 commit b02fa22
Showing 1 changed file with 149 additions and 38 deletions.
187 changes: 149 additions & 38 deletions pr_agent/servers/github_polling.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import asyncio
import multiprocessing
from collections import deque
import traceback
from datetime import datetime, timezone

import time
import requests
import aiohttp

from pr_agent.agent.pr_agent import PRAgent
Expand All @@ -13,6 +16,15 @@
NOTIFICATION_URL = "https://api.github.com/notifications"


async def mark_notification_as_read(headers, notification, session):
async with session.patch(
f"https://api.github.com/notifications/threads/{notification['id']}",
headers=headers) as mark_read_response:
if mark_read_response.status != 205:
get_logger().error(
f"Failed to mark notification as read. Status code: {mark_read_response.status}")


def now() -> str:
"""
Get the current UTC time in ISO 8601 format.
Expand All @@ -24,6 +36,106 @@ def now() -> str:
now_utc = now_utc.replace("+00:00", "Z")
return now_utc

async def async_handle_request(pr_url, rest_of_comment, comment_id, git_provider):
agent = PRAgent()
success = await agent.handle_request(
pr_url,
rest_of_comment,
notify=lambda: git_provider.add_eyes_reaction(comment_id)
)
return success

def run_handle_request(pr_url, rest_of_comment, comment_id, git_provider):
return asyncio.run(async_handle_request(pr_url, rest_of_comment, comment_id, git_provider))


def process_comment_sync(pr_url, rest_of_comment, comment_id):
try:
# Run the async handle_request in a separate function
git_provider = get_git_provider()(pr_url=pr_url)
success = run_handle_request(pr_url, rest_of_comment, comment_id, git_provider)
except Exception as e:
get_logger().error(f"Error processing comment: {e}", artifact={"traceback": traceback.format_exc()})


async def process_comment(pr_url, rest_of_comment, comment_id):
try:
git_provider = get_git_provider()(pr_url=pr_url)
git_provider.set_pr(pr_url)
agent = PRAgent()
success = await agent.handle_request(
pr_url,
rest_of_comment,
notify=lambda: git_provider.add_eyes_reaction(comment_id)
)
get_logger().info(f"Finished processing comment for PR: {pr_url}")
except Exception as e:
get_logger().error(f"Error processing comment: {e}", artifact={"traceback": traceback.format_exc()})

async def is_valid_notification(notification, headers, handled_ids, session, user_id):
try:
if 'reason' in notification and notification['reason'] == 'mention':
if 'subject' in notification and notification['subject']['type'] == 'PullRequest':
pr_url = notification['subject']['url']
latest_comment = notification['subject']['latest_comment_url']
if not latest_comment or not isinstance(latest_comment, str):
get_logger().debug(f"not latest_comment, but its ok")
# continue
async with session.get(latest_comment, headers=headers) as comment_response:
check_prev_comments = False
if comment_response.status == 200:
comment = await comment_response.json()
if 'id' in comment:
if comment['id'] in handled_ids:
get_logger().debug(f"comment['id'] in handled_ids")
return False, handled_ids
else:
handled_ids.add(comment['id'])
if 'user' in comment and 'login' in comment['user']:
if comment['user']['login'] == user_id:
get_logger().debug(f"comment['user']['login'] == user_id")
check_prev_comments = True
comment_body = comment.get('body', '')
if not comment_body:
get_logger().debug(f"no comment_body")
check_prev_comments = True
commenter_github_user = comment['user']['login'] \
if 'user' in comment else ''
get_logger().info(f"Polling, pr_url: {pr_url}",
artifact={"comment": comment_body})
user_tag = "@" + user_id
if user_tag not in comment_body:
get_logger().debug(f"user_tag not in comment_body")
check_prev_comments = True

if not check_prev_comments:
return True, handled_ids, comment, comment_body, pr_url, user_tag
else: # we could not find the user tag in the latest comment. Check previous comments
# get all comments in the PR
requests_url = f"{pr_url}/comments".replace("pulls", "issues")
comments_response = requests.get(requests_url, headers=headers)
comments = comments_response.json()[::-1]
max_comment_to_scan = 4
for comment in comments[:max_comment_to_scan]:
if 'user' in comment and 'login' in comment['user']:
if comment['user']['login'] == user_id:
continue
comment_body = comment.get('body', '')
if not comment_body:
continue
if user_tag in comment_body:
get_logger().info("found user tag in previous comments")
return True, handled_ids, comment, comment_body, pr_url, user_tag

get_logger().error(f"Failed to fetch comments for PR: {pr_url}")
return False, handled_ids

return False, handled_ids
except Exception as e:
get_logger().error(f"Error processing notification: {e}", artifact={"traceback": traceback.format_exc()})
return False, handled_ids



async def polling_loop():
"""
Expand All @@ -34,7 +146,6 @@ async def polling_loop():
last_modified = [None]
git_provider = get_git_provider()()
user_id = git_provider.get_user_id()
agent = PRAgent()
get_settings().set("CONFIG.PUBLISH_OUTPUT_PROGRESS", False)
get_settings().set("pr_description.publish_description_as_comment", True)

Expand Down Expand Up @@ -74,43 +185,43 @@ async def polling_loop():
notifications = await response.json()
if not notifications:
continue
get_logger().info(f"Received {len(notifications)} notifications")
task_queue = deque()
for notification in notifications:
# mark notification as read
await mark_notification_as_read(headers, notification, session)

handled_ids.add(notification['id'])
if 'reason' in notification and notification['reason'] == 'mention':
if 'subject' in notification and notification['subject']['type'] == 'PullRequest':
pr_url = notification['subject']['url']
latest_comment = notification['subject']['latest_comment_url']
if not latest_comment or not isinstance(latest_comment, str):
continue
async with session.get(latest_comment, headers=headers) as comment_response:
if comment_response.status == 200:
comment = await comment_response.json()
if 'id' in comment:
if comment['id'] in handled_ids:
continue
else:
handled_ids.add(comment['id'])
if 'user' in comment and 'login' in comment['user']:
if comment['user']['login'] == user_id:
continue
comment_body = comment.get('body', '')
if not comment_body:
continue
commenter_github_user = comment['user']['login'] \
if 'user' in comment else ''
get_logger().info(f"Polling, pr_url: {pr_url}",
artifact={"comment": comment_body})
user_tag = "@" + user_id
if user_tag not in comment_body:
continue
rest_of_comment = comment_body.split(user_tag)[1].strip()
comment_id = comment['id']
git_provider.set_pr(pr_url)
success = await agent.handle_request(pr_url, rest_of_comment,
notify=lambda: git_provider.add_eyes_reaction(
comment_id)) # noqa E501
if not success:
git_provider.set_pr(pr_url)
output = await is_valid_notification(notification, headers, handled_ids, session, user_id)
if output[0]:
_, handled_ids, comment, comment_body, pr_url, user_tag = output
rest_of_comment = comment_body.split(user_tag)[1].strip()
comment_id = comment['id']

# Add to the task queue
get_logger().info(
f"Adding comment processing to task queue for PR, {pr_url}, comment_body: {comment_body}")
task_queue.append((process_comment_sync, (pr_url, rest_of_comment, comment_id)))
get_logger().info(f"Queued comment processing for PR: {pr_url}")
else:
get_logger().debug(f"Skipping comment processing for PR: {pr_url}")

max_allowed_parallel_tasks = 10
if task_queue:
processes = []
for i, func, args in enumerate(task_queue): # Create parallel tasks
p = multiprocessing.Process(target=func, args=args)
processes.append(p)
p.start()
if i > max_allowed_parallel_tasks:
get_logger().error(
f"Dropping {len(task_queue) - max_allowed_parallel_tasks} tasks from polling session")
break
task_queue.clear()

# Dont wait for all processes to complete. Move on to the next iteration
# for p in processes:
# p.join()

elif response.status != 304:
print(f"Failed to fetch notifications. Status code: {response.status}")
Expand All @@ -121,4 +232,4 @@ async def polling_loop():


if __name__ == '__main__':
asyncio.run(polling_loop())
asyncio.run(polling_loop())

0 comments on commit b02fa22

Please sign in to comment.