Skip to content

Commit

Permalink
Feat/reliable GitHub wrapper (#3900)
Browse files Browse the repository at this point in the history
  • Loading branch information
wwzeng1 authored May 27, 2024
1 parent 6735e16 commit 4b71abd
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 14 deletions.
6 changes: 1 addition & 5 deletions sweepai/config/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,7 @@

BOT_TOKEN_NAME = "bot-token"

# goes under Modal 'discord' secret name (optional, can leave env var blank)
DISCORD_WEBHOOK_URL = os.environ.get("DISCORD_WEBHOOK_URL")
DISCORD_MEDIUM_PRIORITY_URL = os.environ.get("DISCORD_MEDIUM_PRIORITY_URL")
DISCORD_LOW_PRIORITY_URL = os.environ.get("DISCORD_LOW_PRIORITY_URL")
DISCORD_FEEDBACK_WEBHOOK_URL = os.environ.get("DISCORD_FEEDBACK_WEBHOOK_URL")
GITHUB_BASE_URL = os.environ.get("GITHUB_BASE_URL", "https://api.github.com") # configure for enterprise

SWEEP_HEALTH_URL = os.environ.get("SWEEP_HEALTH_URL")
DISCORD_STATUS_WEBHOOK_URL = os.environ.get("DISCORD_STATUS_WEBHOOK_URL")
Expand Down
6 changes: 3 additions & 3 deletions sweepai/handlers/on_ticket.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,8 +308,8 @@ def edit_sweep_comment(
+ "\n"
+ message
+ (
"\n\nFor bonus Sweep issues, please report this bug on our"
f" **[community forum](https://community.sweep.dev/)** (tracking ID: `{tracking_id}`)."
"\n\n"
" **[Report a bug](https://community.sweep.dev/)**."
if add_bonus_message
else ""
)
Expand Down Expand Up @@ -383,7 +383,7 @@ def edit_sweep_comment(
logger.warning(f"Validation error: {error_message}")
edit_sweep_comment(
(
f"The issue was rejected with the following response:\n\n{bold(error_message)}"
f"\n\n{bold(error_message)}"
),
-1,
)
Expand Down
53 changes: 47 additions & 6 deletions sweepai/utils/github_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import copy
import datetime
import difflib
Expand All @@ -16,12 +18,18 @@

import git
import requests
from github import Github, PullRequest, Repository, InputGitTreeElement, GithubException
from github import Github
from github.Auth import Token
# get default_base_url from github
from github.Requester import Requester
from github.GithubException import BadCredentialsException, UnknownObjectException
from github import PullRequest, Repository, InputGitTreeElement, GithubException
from jwt import encode
from loguru import logger
from urllib3 import Retry

from sweepai.config.client import SweepConfig
from sweepai.config.server import GITHUB_APP_ID, GITHUB_APP_PEM, GITHUB_BOT_USERNAME
from sweepai.config.server import GITHUB_APP_ID, GITHUB_APP_PEM, GITHUB_BASE_URL, GITHUB_BOT_USERNAME
from sweepai.core.entities import FileChangeRequest
from sweepai.utils.str_utils import get_hash
from sweepai.utils.tree_utils import DirectoryTree, remove_all_not_included
Expand Down Expand Up @@ -91,12 +99,45 @@ def get_app():
response = requests.get("https://api.github.com/app", headers=headers)
return response.json()

class CustomRequester(Requester):
def __init__(self, token, timeout=15, user_agent="PyGithub/Python", per_page=30, verify=True, retry=None, pool_size=None, installation_id=None) -> "CustomRequester":
self.token = token
self.installation_id = installation_id
base_url = GITHUB_BASE_URL
auth = Token(token)
retry = Retry(total=3,) # 3 retries
super().__init__(auth=auth, base_url=base_url, timeout=timeout, user_agent=user_agent, per_page=per_page, verify=verify, retry=retry, pool_size=pool_size)

def _refresh_token(self):
self.token = get_token(self.installation_id)
self._Requester__authorizationHeader = f"token {self.token}"

def requestJsonAndCheck(self, *args, **kwargs): # more endpoints like these may need to be added
try:
return super().requestJsonAndCheck(*args, **kwargs)
except (BadCredentialsException, UnknownObjectException):
self._refresh_token()
return super().requestJsonAndCheck(*args, **kwargs)

class CustomGithub(Github):
def __init__(self, installation_id: int, *args, **kwargs) -> "CustomGithub":
self.installation_id = installation_id
self.token = self._get_token()
super().__init__(self.token, *args, **kwargs)
self._Github__requester = CustomRequester(self.token, installation_id=self.installation_id)

def _get_token(self) -> str:
if not self.installation_id:
return os.environ["GITHUB_PAT"]
return get_token(self.installation_id)

def get_github_client(installation_id: int) -> tuple[str, Github]:
def get_github_client(installation_id: int) -> tuple[str, CustomGithub]:
github_instance = None
if not installation_id:
return os.environ["GITHUB_PAT"], Github(os.environ["GITHUB_PAT"])
token: str = get_token(installation_id)
return token, Github(token)
github_instance = Github(os.environ["GITHUB_PAT"])
else:
github_instance = CustomGithub(installation_id)
return github_instance.token, github_instance

# fetch installation object
def get_installation(username: str):
Expand Down
48 changes: 48 additions & 0 deletions tests/test_github_request_exception.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from sweepai.utils.github_utils import get_token
import os
from github import Github
from github.Auth import Token
# get default_base_url from github
from github.Requester import Requester
from github.GithubException import BadCredentialsException, UnknownObjectException

class CustomRequester(Requester):
def __init__(self, token, timeout=15, user_agent="PyGithub/Python", per_page=30, verify=True, retry=True, pool_size=None, installation_id=None):
self.token = token
self.installation_id = installation_id
base_url = "https://api.github.com"
auth = Token(token)
super().__init__(auth=auth, base_url=base_url, timeout=timeout, user_agent=user_agent, per_page=per_page, verify=verify, retry=retry, pool_size=pool_size)

def _refresh_token(self):
self.token = get_token(self.installation_id)
self._Requester__authorizationHeader = f"token {self.token}"

def requestJsonAndCheck(self, *args, **kwargs):
try:
breakpoint() # test breakpoint to ensure it refreshes credentials
raise BadCredentialsException(status=401, data={"message": "Bad credentials"})
return super().requestJsonAndCheck(*args, **kwargs)
except (BadCredentialsException, UnknownObjectException):
self._refresh_token()
return super().requestJsonAndCheck(*args, **kwargs)

class CustomGithub(Github):
def __init__(self, installation_id: int, *args, **kwargs):
self.installation_id = installation_id
self.token = self._get_token()
super().__init__(self.token, *args, **kwargs)
self._Github__requester = CustomRequester(self.token, installation_id=self.installation_id)

def _get_token(self) -> str:
if not self.installation_id:
return os.environ["GITHUB_PAT"]
return get_token(self.installation_id)

def get_github_client(installation_id: int) -> tuple[str, CustomGithub]:
github_instance = CustomGithub(installation_id)
return github_instance.token, github_instance

installation_id = None
user_token, g = get_github_client(installation_id)
repo = g.get_repo("sweepai/sweep")

0 comments on commit 4b71abd

Please sign in to comment.