From 4b71abdf369a56b9e173d43e262a62511d6931bf Mon Sep 17 00:00:00 2001
From: William Zeng <44910023+wwzeng1@users.noreply.github.com>
Date: Mon, 27 May 2024 10:19:31 -0700
Subject: [PATCH] Feat/reliable GitHub wrapper (#3900)

---
 sweepai/config/server.py               |  6 +--
 sweepai/handlers/on_ticket.py          |  6 +--
 sweepai/utils/github_utils.py          | 53 +++++++++++++++++++++++---
 tests/test_github_request_exception.py | 48 +++++++++++++++++++++++
 4 files changed, 99 insertions(+), 14 deletions(-)
 create mode 100644 tests/test_github_request_exception.py

diff --git a/sweepai/config/server.py b/sweepai/config/server.py
index 54a3e8e4c4..db45766dee 100644
--- a/sweepai/config/server.py
+++ b/sweepai/config/server.py
@@ -24,11 +24,7 @@
 
 BOT_TOKEN_NAME = "bot-token"
 
-# goes under Modal 'discord' secret name (optional, can leave env var blank)
-DISCORD_WEBHOOK_URL = os.environ.get("DISCORD_WEBHOOK_URL")
-DISCORD_MEDIUM_PRIORITY_URL = os.environ.get("DISCORD_MEDIUM_PRIORITY_URL")
-DISCORD_LOW_PRIORITY_URL = os.environ.get("DISCORD_LOW_PRIORITY_URL")
-DISCORD_FEEDBACK_WEBHOOK_URL = os.environ.get("DISCORD_FEEDBACK_WEBHOOK_URL")
+GITHUB_BASE_URL = os.environ.get("GITHUB_BASE_URL", "https://api.github.com") # configure for enterprise
 
 SWEEP_HEALTH_URL = os.environ.get("SWEEP_HEALTH_URL")
 DISCORD_STATUS_WEBHOOK_URL = os.environ.get("DISCORD_STATUS_WEBHOOK_URL")
diff --git a/sweepai/handlers/on_ticket.py b/sweepai/handlers/on_ticket.py
index 603c4614b1..4a1c1bed7c 100644
--- a/sweepai/handlers/on_ticket.py
+++ b/sweepai/handlers/on_ticket.py
@@ -308,8 +308,8 @@ def edit_sweep_comment(
                         + "\n"
                         + message
                         + (
-                            "\n\nFor bonus Sweep issues, please report this bug on our"
-                            f" **[community forum](https://community.sweep.dev/)** (tracking ID: `{tracking_id}`)."
+                            "\n\n"
+                            " **[Report a bug](https://community.sweep.dev/)**."
                             if add_bonus_message
                             else ""
                         )
@@ -383,7 +383,7 @@ def edit_sweep_comment(
                 logger.warning(f"Validation error: {error_message}")
                 edit_sweep_comment(
                     (
-                        f"The issue was rejected with the following response:\n\n{bold(error_message)}"
+                        f"\n\n{bold(error_message)}"
                     ),
                     -1,
                 )
diff --git a/sweepai/utils/github_utils.py b/sweepai/utils/github_utils.py
index c15f5a0619..d61f464c30 100644
--- a/sweepai/utils/github_utils.py
+++ b/sweepai/utils/github_utils.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import copy
 import datetime
 import difflib
@@ -16,12 +18,18 @@
 
 import git
 import requests
-from github import Github, PullRequest, Repository, InputGitTreeElement, GithubException
+from github import Github
+from github.Auth import Token
+# get default_base_url from github
+from github.Requester import Requester
+from github.GithubException import BadCredentialsException, UnknownObjectException
+from github import PullRequest, Repository, InputGitTreeElement, GithubException
 from jwt import encode
 from loguru import logger
+from urllib3 import Retry
 
 from sweepai.config.client import SweepConfig
-from sweepai.config.server import GITHUB_APP_ID, GITHUB_APP_PEM, GITHUB_BOT_USERNAME
+from sweepai.config.server import GITHUB_APP_ID, GITHUB_APP_PEM, GITHUB_BASE_URL, GITHUB_BOT_USERNAME
 from sweepai.core.entities import FileChangeRequest
 from sweepai.utils.str_utils import get_hash
 from sweepai.utils.tree_utils import DirectoryTree, remove_all_not_included
@@ -91,12 +99,45 @@ def get_app():
     response = requests.get("https://api.github.com/app", headers=headers)
     return response.json()
 
+class CustomRequester(Requester):
+    def __init__(self, token, timeout=15, user_agent="PyGithub/Python", per_page=30, verify=True, retry=None, pool_size=None, installation_id=None) -> "CustomRequester":
+        self.token = token
+        self.installation_id = installation_id
+        base_url = GITHUB_BASE_URL
+        auth = Token(token)
+        retry = Retry(total=3,) # 3 retries
+        super().__init__(auth=auth, base_url=base_url, timeout=timeout, user_agent=user_agent, per_page=per_page, verify=verify, retry=retry, pool_size=pool_size)
+
+    def _refresh_token(self):
+        self.token = get_token(self.installation_id)
+        self._Requester__authorizationHeader = f"token {self.token}"
+
+    def requestJsonAndCheck(self, *args, **kwargs): # more endpoints like these may need to be added
+        try:
+            return super().requestJsonAndCheck(*args, **kwargs)
+        except (BadCredentialsException, UnknownObjectException):
+            self._refresh_token()
+            return super().requestJsonAndCheck(*args, **kwargs)
+
+class CustomGithub(Github):
+    def __init__(self, installation_id: int, *args, **kwargs) -> "CustomGithub":
+        self.installation_id = installation_id
+        self.token = self._get_token()
+        super().__init__(self.token, *args, **kwargs)
+        self._Github__requester = CustomRequester(self.token, installation_id=self.installation_id)
+
+    def _get_token(self) -> str:
+        if not self.installation_id:
+            return os.environ["GITHUB_PAT"]
+        return get_token(self.installation_id)
 
-def get_github_client(installation_id: int) -> tuple[str, Github]:
+def get_github_client(installation_id: int) -> tuple[str, CustomGithub]:
+    github_instance = None
     if not installation_id:
-        return os.environ["GITHUB_PAT"], Github(os.environ["GITHUB_PAT"])
-    token: str = get_token(installation_id)
-    return token, Github(token)
+        github_instance = Github(os.environ["GITHUB_PAT"])
+    else:
+        github_instance = CustomGithub(installation_id)
+    return github_instance.token, github_instance
 
 # fetch installation object
 def get_installation(username: str): 
diff --git a/tests/test_github_request_exception.py b/tests/test_github_request_exception.py
new file mode 100644
index 0000000000..0e8db4e6e4
--- /dev/null
+++ b/tests/test_github_request_exception.py
@@ -0,0 +1,48 @@
+from sweepai.utils.github_utils import get_token
+import os
+from github import Github
+from github.Auth import Token
+# get default_base_url from github
+from github.Requester import Requester
+from github.GithubException import BadCredentialsException, UnknownObjectException
+
+class CustomRequester(Requester):
+    def __init__(self, token, timeout=15, user_agent="PyGithub/Python", per_page=30, verify=True, retry=True, pool_size=None, installation_id=None):
+        self.token = token
+        self.installation_id = installation_id
+        base_url = "https://api.github.com"
+        auth = Token(token)
+        super().__init__(auth=auth, base_url=base_url, timeout=timeout, user_agent=user_agent, per_page=per_page, verify=verify, retry=retry, pool_size=pool_size)
+
+    def _refresh_token(self):
+        self.token = get_token(self.installation_id)
+        self._Requester__authorizationHeader = f"token {self.token}"
+
+    def requestJsonAndCheck(self, *args, **kwargs):
+        try:
+            breakpoint() # test breakpoint to ensure it refreshes credentials
+            raise BadCredentialsException(status=401, data={"message": "Bad credentials"})
+            return super().requestJsonAndCheck(*args, **kwargs)
+        except (BadCredentialsException, UnknownObjectException):
+            self._refresh_token()
+            return super().requestJsonAndCheck(*args, **kwargs)
+
+class CustomGithub(Github):
+    def __init__(self, installation_id: int, *args, **kwargs):
+        self.installation_id = installation_id
+        self.token = self._get_token()
+        super().__init__(self.token, *args, **kwargs)
+        self._Github__requester = CustomRequester(self.token, installation_id=self.installation_id)
+
+    def _get_token(self) -> str:
+        if not self.installation_id:
+            return os.environ["GITHUB_PAT"]
+        return get_token(self.installation_id)
+
+def get_github_client(installation_id: int) -> tuple[str, CustomGithub]:
+    github_instance = CustomGithub(installation_id)
+    return github_instance.token, github_instance
+
+installation_id = None
+user_token, g = get_github_client(installation_id)
+repo = g.get_repo("sweepai/sweep")
\ No newline at end of file