Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tr/polling #1198

Merged
merged 4 commits into from
Sep 5, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
187 changes: 149 additions & 38 deletions pr_agent/servers/github_polling.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import asyncio
import multiprocessing
from collections import deque
import traceback
from datetime import datetime, timezone

import time
import requests
import aiohttp

from pr_agent.agent.pr_agent import PRAgent
Expand All @@ -13,6 +16,15 @@
NOTIFICATION_URL = "https://api.github.com/notifications"


async def mark_notification_as_read(headers, notification, session):
async with session.patch(
f"https://api.github.com/notifications/threads/{notification['id']}",
headers=headers) as mark_read_response:
if mark_read_response.status != 205:
get_logger().error(
f"Failed to mark notification as read. Status code: {mark_read_response.status}")


def now() -> str:
"""
Get the current UTC time in ISO 8601 format.
Expand All @@ -24,6 +36,106 @@ def now() -> str:
now_utc = now_utc.replace("+00:00", "Z")
return now_utc

async def async_handle_request(pr_url, rest_of_comment, comment_id, git_provider):
agent = PRAgent()
success = await agent.handle_request(
pr_url,
rest_of_comment,
notify=lambda: git_provider.add_eyes_reaction(comment_id)
)
return success

def run_handle_request(pr_url, rest_of_comment, comment_id, git_provider):
return asyncio.run(async_handle_request(pr_url, rest_of_comment, comment_id, git_provider))


def process_comment_sync(pr_url, rest_of_comment, comment_id):
try:
# Run the async handle_request in a separate function
git_provider = get_git_provider()(pr_url=pr_url)
success = run_handle_request(pr_url, rest_of_comment, comment_id, git_provider)
except Exception as e:
get_logger().error(f"Error processing comment: {e}", artifact={"traceback": traceback.format_exc()})


async def process_comment(pr_url, rest_of_comment, comment_id):
try:
git_provider = get_git_provider()(pr_url=pr_url)
git_provider.set_pr(pr_url)
agent = PRAgent()
success = await agent.handle_request(
pr_url,
rest_of_comment,
notify=lambda: git_provider.add_eyes_reaction(comment_id)
)
get_logger().info(f"Finished processing comment for PR: {pr_url}")
except Exception as e:
get_logger().error(f"Error processing comment: {e}", artifact={"traceback": traceback.format_exc()})

async def is_valid_notification(notification, headers, handled_ids, session, user_id):
try:
if 'reason' in notification and notification['reason'] == 'mention':
if 'subject' in notification and notification['subject']['type'] == 'PullRequest':
pr_url = notification['subject']['url']
latest_comment = notification['subject']['latest_comment_url']
if not latest_comment or not isinstance(latest_comment, str):
get_logger().debug(f"not latest_comment, but its ok")
# continue
async with session.get(latest_comment, headers=headers) as comment_response:
check_prev_comments = False
if comment_response.status == 200:
comment = await comment_response.json()
if 'id' in comment:
if comment['id'] in handled_ids:
get_logger().debug(f"comment['id'] in handled_ids")
return False, handled_ids
else:
handled_ids.add(comment['id'])
if 'user' in comment and 'login' in comment['user']:
if comment['user']['login'] == user_id:
get_logger().debug(f"comment['user']['login'] == user_id")
check_prev_comments = True
comment_body = comment.get('body', '')
if not comment_body:
get_logger().debug(f"no comment_body")
check_prev_comments = True
commenter_github_user = comment['user']['login'] \
if 'user' in comment else ''
get_logger().info(f"Polling, pr_url: {pr_url}",
artifact={"comment": comment_body})
user_tag = "@" + user_id
if user_tag not in comment_body:
get_logger().debug(f"user_tag not in comment_body")
check_prev_comments = True

if not check_prev_comments:
return True, handled_ids, comment, comment_body, pr_url, user_tag
else: # we could not find the user tag in the latest comment. Check previous comments
# get all comments in the PR
requests_url = f"{pr_url}/comments".replace("pulls", "issues")
comments_response = requests.get(requests_url, headers=headers)
comments = comments_response.json()[::-1]
max_comment_to_scan = 4
for comment in comments[:max_comment_to_scan]:
if 'user' in comment and 'login' in comment['user']:
if comment['user']['login'] == user_id:
continue
comment_body = comment.get('body', '')
if not comment_body:
continue
if user_tag in comment_body:
get_logger().info("found user tag in previous comments")
return True, handled_ids, comment, comment_body, pr_url, user_tag

get_logger().error(f"Failed to fetch comments for PR: {pr_url}")
return False, handled_ids

return False, handled_ids
except Exception as e:
get_logger().error(f"Error processing notification: {e}", artifact={"traceback": traceback.format_exc()})
return False, handled_ids



async def polling_loop():
"""
Expand All @@ -34,7 +146,6 @@ async def polling_loop():
last_modified = [None]
git_provider = get_git_provider()()
user_id = git_provider.get_user_id()
agent = PRAgent()
get_settings().set("CONFIG.PUBLISH_OUTPUT_PROGRESS", False)
get_settings().set("pr_description.publish_description_as_comment", True)

Expand Down Expand Up @@ -74,43 +185,43 @@ async def polling_loop():
notifications = await response.json()
if not notifications:
continue
get_logger().info(f"Received {len(notifications)} notifications")
task_queue = deque()
for notification in notifications:
# mark notification as read
await mark_notification_as_read(headers, notification, session)

handled_ids.add(notification['id'])
if 'reason' in notification and notification['reason'] == 'mention':
if 'subject' in notification and notification['subject']['type'] == 'PullRequest':
pr_url = notification['subject']['url']
latest_comment = notification['subject']['latest_comment_url']
if not latest_comment or not isinstance(latest_comment, str):
continue
async with session.get(latest_comment, headers=headers) as comment_response:
if comment_response.status == 200:
comment = await comment_response.json()
if 'id' in comment:
if comment['id'] in handled_ids:
continue
else:
handled_ids.add(comment['id'])
if 'user' in comment and 'login' in comment['user']:
if comment['user']['login'] == user_id:
continue
comment_body = comment.get('body', '')
if not comment_body:
continue
commenter_github_user = comment['user']['login'] \
if 'user' in comment else ''
get_logger().info(f"Polling, pr_url: {pr_url}",
artifact={"comment": comment_body})
user_tag = "@" + user_id
if user_tag not in comment_body:
continue
rest_of_comment = comment_body.split(user_tag)[1].strip()
comment_id = comment['id']
git_provider.set_pr(pr_url)
success = await agent.handle_request(pr_url, rest_of_comment,
notify=lambda: git_provider.add_eyes_reaction(
comment_id)) # noqa E501
if not success:
git_provider.set_pr(pr_url)
output = await is_valid_notification(notification, headers, handled_ids, session, user_id)
if output[0]:
_, handled_ids, comment, comment_body, pr_url, user_tag = output
rest_of_comment = comment_body.split(user_tag)[1].strip()
comment_id = comment['id']

# Add to the task queue
get_logger().info(
f"Adding comment processing to task queue for PR, {pr_url}, comment_body: {comment_body}")
task_queue.append((process_comment_sync, (pr_url, rest_of_comment, comment_id)))
get_logger().info(f"Queued comment processing for PR: {pr_url}")
else:
get_logger().debug(f"Skipping comment processing for PR: {pr_url}")

max_allowed_parallel_tasks = 10
if task_queue:
processes = []
for i, func, args in enumerate(task_queue): # Create parallel tasks
p = multiprocessing.Process(target=func, args=args)
processes.append(p)
p.start()
if i > max_allowed_parallel_tasks:
get_logger().error(
f"Dropping {len(task_queue) - max_allowed_parallel_tasks} tasks from polling session")
break
task_queue.clear()

# Dont wait for all processes to complete. Move on to the next iteration
# for p in processes:
# p.join()

elif response.status != 304:
print(f"Failed to fetch notifications. Status code: {response.status}")
Expand All @@ -121,4 +232,4 @@ async def polling_loop():


if __name__ == '__main__':
asyncio.run(polling_loop())
asyncio.run(polling_loop())
Loading