Sweep: allow for rebase. #3284

Open
1 task done
wwzeng1 opened this issue Mar 13, 2024 · 15 comments · May be fixed by #3455, #3456, #3457, #3499 or #3498
Labels
sweep Assigns Sweep to an issue or pull request.

Comments

@wwzeng1
Contributor

wwzeng1 commented Mar 13, 2024

Details

Quote from user:

I've noticed that when the target branch is updated, Sweep merges it into the branch Sweep created. It's nice that Sweep detects and handles this automatically, but I would prefer a rebase instead of a merge commit. Is there such an option now?

Branch

No response

Checklist
@wwzeng1 wwzeng1 added the sweep Assigns Sweep to an issue or pull request. label Mar 13, 2024
Contributor

sweep-nightly bot commented Apr 5, 2024


Tip

I can email you next time I complete a pull request if you set up your email here!



❌ Unable to Complete PR

I'm sorry, but it looks like an error has occurred due to a planning failure. Feel free to add more details to the issue description so Sweep can better address it. Alternatively, reach out to Kevin or William for help at https://discord.gg/sweep.

For bonus GPT-4 tickets, please report this bug on Discord (tracking ID: 019bbe84fc).


Please look at the generated plan. If something looks wrong, please add more details to your issue.

Proposed changes by file:
sweepai/api.py: modify with the following changes:
• In the update_sweep_prs_v2 function, find the code block that performs the merge:
```python
repo.merge(
feature_branch,
pr.base.ref,
f"Merge main into {feature_branch}",
)
```
• Replace the repo.merge call with the following to perform a rebase instead:
```python
repo.rebase(pr.base.ref, feature_branch)
```
• Update the commit message to reflect the rebase operation.
• If there are any merge conflicts during the rebase, catch the exception and handle it appropriately (e.g. by closing the PR similar to the existing merge conflict handling).
sweepai/utils/github_utils.py: modify with the following changes:
• In the ClonedRepo class, check if there are any methods involved in the merge process (e.g. in the clone method).
• If found, update those methods to use git rebase instead of git merge when updating the PR branch.
• Ensure the rebase is performed against the origin/<target_branch> remote branch.
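A caveat on this plan: the proposed `repo.rebase(pr.base.ref, feature_branch)` call does not exist in PyGithub. `Repository` has no `rebase` method, and the `Repository.merge(base, head, commit_message)` call used today wraps GitHub's "merge a branch" endpoint (`POST /repos/{owner}/{repo}/merges`), which merges rather than rebases. One way to get a rebase is to run it in a local clone and force-push the result. The sketch below assumes `repo_dir` is a clone whose `origin` remote can be pushed to; `rebase_commands` and `rebase_pr_branch` are hypothetical helpers, not existing Sweep functions.

```python
import subprocess
from typing import List


def rebase_commands(
    feature_branch: str, base_branch: str, remote: str = "origin"
) -> List[List[str]]:
    """Return the git commands that rebase feature_branch onto remote/base_branch.

    Hypothetical helper, kept separate from execution so the sequence is
    easy to inspect and test.
    """
    return [
        # Refresh the remote-tracking refs for both branches.
        ["git", "fetch", remote, base_branch, feature_branch],
        # Create or reset the local feature branch at the remote tip.
        ["git", "checkout", "-B", feature_branch, f"{remote}/{feature_branch}"],
        # Replay the feature commits on top of the updated base branch.
        ["git", "rebase", f"{remote}/{base_branch}"],
        # Rewrite the PR branch, refusing if it moved on the remote since the fetch.
        ["git", "push", "--force-with-lease", remote, feature_branch],
    ]


def rebase_pr_branch(
    repo_dir: str, feature_branch: str, base_branch: str, remote: str = "origin"
) -> bool:
    """Rebase and force-push; return False if the rebase step fails."""
    for cmd in rebase_commands(feature_branch, base_branch, remote):
        result = subprocess.run(cmd, cwd=repo_dir, capture_output=True, text=True)
        if result.returncode == 0:
            continue
        if cmd[1] == "rebase":
            # Most likely a conflict: abort to restore the working tree.
            # Nothing has been pushed yet, so the remote PR branch is unchanged.
            subprocess.run(
                ["git", "rebase", "--abort"], cwd=repo_dir, capture_output=True
            )
            return False
        raise RuntimeError(f"{' '.join(cmd)} failed: {result.stderr.strip()}")
    return True
```

`--force-with-lease` is used instead of `--force` so that a commit pushed to the PR branch after the fetch is not silently overwritten. A `False` return could be routed to the same path the existing code uses for merge conflicts.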

🎉 Latest improvements to Sweep:
  • New dashboard launched for real-time tracking of Sweep issues, covering all stages from search to coding.
  • Integration of OpenAI's latest Assistant API for more efficient and reliable code planning and editing, improving speed by 3x.
  • Use the GitHub issues extension for creating Sweep issues directly from your editor.

💡 To recreate the pull request edit the issue title or description.

This is an automated message generated by Sweep AI.
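On the user's question of whether an option exists: in the quoted `sweepai/api.py`, `update_sweep_prs_v2` already reads per-repo settings from `sweep.yaml` via `SweepConfig.get_config(repo)` and parses `branch_ttl` defensively, so a merge/rebase switch could follow the same pattern. The key name `branch_update_strategy`, its values, and the helper below are hypothetical; nothing in the quoted code defines them, and the sketch assumes the config object is dict-like, as its `.get(...)` usage suggests.

```python
from typing import Any, Mapping

# Hypothetical sweep.yaml key and values; not part of Sweep's documented config.
STRATEGY_KEY = "branch_update_strategy"
VALID_STRATEGIES = ("merge", "rebase")
DEFAULT_STRATEGY = "merge"  # matches the current repo.merge(...) behavior


def get_branch_update_strategy(config: Mapping[str, Any]) -> str:
    """Pick "merge" or "rebase" from a parsed sweep.yaml mapping.

    Like the branch_ttl parsing, it never raises: a missing, non-string,
    or unrecognized value falls back to the default.
    """
    value = config.get(STRATEGY_KEY, DEFAULT_STRATEGY)
    if not isinstance(value, str):
        return DEFAULT_STRATEGY
    value = value.strip().lower()
    return value if value in VALID_STRATEGIES else DEFAULT_STRATEGY
```

With `branch_update_strategy: rebase` in `sweep.yaml`, the loop in `update_sweep_prs_v2` would keep calling `repo.merge(...)` only when this returns `"merge"` and perform a rebase in a local clone otherwise; defaulting to `"merge"` leaves existing repositories unaffected.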

@sweepai sweepai deleted a comment from sweep-nightly bot Apr 5, 2024
@sweepai sweepai deleted a comment from sweep-nightly bot Apr 5, 2024
Contributor

sweep-nightly bot commented Apr 6, 2024


❌ Unable to Complete PR

I'm sorry, but it looks like an error has occurred due to a planning failure. Feel free to add more details to the issue description so Sweep can better address it. Alternatively, reach out to Kevin or William for help at https://discord.gg/sweep.

For bonus GPT-4 tickets, please report this bug on Discord (tracking ID: 4b21ad1b3d).



Contributor

sweep-nightly bot commented Apr 6, 2024


❌ Unable to Complete PR

I'm sorry, but it looks like an error has occurred due to a planning failure. Feel free to add more details to the issue description so Sweep can better address it. Alternatively, reach out to Kevin or William for help at https://discord.gg/sweep.

For bonus GPT-4 tickets, please report this bug on Discord (tracking ID: 5529a92db2).



Contributor

sweep-nightly bot commented Apr 6, 2024

Sweeping

✨ Track Sweep's progress on our progress dashboard!


25%

💎 Sweep Pro: I'm using GPT-4. You have unlimited GPT-4 tickets. (tracking ID: 0bc8dd4c44)


Step 1: 🔎 Searching

I found the following snippets in your repository. I will now analyze these snippets and come up with a plan.

Some code snippets I think are relevant, in decreasing order of relevance. If a file is missing from here, you can mention its path in the ticket description.

sweep/sweepai/api.py

Lines 1 to 1185 in 0643263

from __future__ import annotations

import ctypes
import json
import threading
import time
from typing import Any, Optional

import requests
from fastapi import (
    Body,
    Depends,
    FastAPI,
    Header,
    HTTPException,
    Path,
    Request,
    Security,
    status,
)
from fastapi.responses import HTMLResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from fastapi.templating import Jinja2Templates
from github.Commit import Commit
from prometheus_fastapi_instrumentator import Instrumentator

from sweepai.config.client import (
    DEFAULT_RULES,
    RESTART_SWEEP_BUTTON,
    REVERT_CHANGED_FILES_TITLE,
    RULES_LABEL,
    RULES_TITLE,
    SWEEP_BAD_FEEDBACK,
    SWEEP_GOOD_FEEDBACK,
    SweepConfig,
    get_gha_enabled,
    get_rules,
)
from sweepai.config.server import (
    BLACKLISTED_USERS,
    DISABLED_REPOS,
    DISCORD_FEEDBACK_WEBHOOK_URL,
    ENV,
    GHA_AUTOFIX_ENABLED,
    GITHUB_BOT_USERNAME,
    GITHUB_LABEL_COLOR,
    GITHUB_LABEL_DESCRIPTION,
    GITHUB_LABEL_NAME,
    IS_SELF_HOSTED,
    MERGE_CONFLICT_ENABLED,
)
from sweepai.core.entities import PRChangeRequest
from sweepai.global_threads import global_threads
from sweepai.handlers.create_pr import (  # type: ignore
    add_config_to_top_repos,
    create_gha_pr,
)
from sweepai.handlers.on_button_click import handle_button_click
from sweepai.handlers.on_check_suite import (  # type: ignore
    clean_gh_logs,
    download_logs,
    on_check_suite,
)
from sweepai.handlers.on_comment import on_comment
from sweepai.handlers.on_merge import on_merge
from sweepai.handlers.on_merge_conflict import on_merge_conflict
from sweepai.handlers.on_ticket import on_ticket
from sweepai.handlers.pr_utils import make_pr
from sweepai.handlers.stack_pr import stack_pr
from sweepai.utils.buttons import (
    Button,
    ButtonList,
    check_button_activated,
    check_button_title_match,
)
from sweepai.utils.chat_logger import ChatLogger
from sweepai.utils.event_logger import logger, posthog
from sweepai.utils.github_utils import CURRENT_USERNAME, get_github_client
from sweepai.utils.progress import TicketProgress
from sweepai.utils.safe_pqueue import SafePriorityQueue
from sweepai.utils.str_utils import BOT_SUFFIX, get_hash
from sweepai.web.events import (
    CheckRunCompleted,
    CommentCreatedRequest,
    InstallationCreatedRequest,
    IssueCommentRequest,
    IssueRequest,
    PREdited,
    PRRequest,
    ReposAddedRequest,
)
from sweepai.web.health import health_check

app = FastAPI()

events = {}
on_ticket_events = {}

security = HTTPBearer()

templates = Jinja2Templates(directory="sweepai/web")

# version_command = r"""git config --global --add safe.directory /app
# timestamp=$(git log -1 --format="%at")
# date -d "@$timestamp" +%y.%m.%d.%H 2>/dev/null || date -r "$timestamp" +%y.%m.%d.%H"""
# try:
# version = subprocess.check_output(version_command, shell=True, text=True).strip()
# except Exception:
version = time.strftime("%y.%m.%d.%H")

logger.bind(application="webhook")


def auth_metrics(credentials: HTTPAuthorizationCredentials = Security(security)):
    if credentials.scheme != "Bearer":
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Invalid authentication scheme.",
        )
    if credentials.credentials != "example_token":  # grafana requires authentication
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token."
        )
    return True


if not IS_SELF_HOSTED:
    Instrumentator().instrument(app).expose(
        app,
        should_gzip=False,
        endpoint="/metrics",
        include_in_schema=True,
        tags=["metrics"],
        dependencies=[Depends(auth_metrics)],
    )


def run_on_ticket(*args, **kwargs):
    tracking_id = get_hash()
    with logger.contextualize(
        **kwargs,
        name="ticket_" + kwargs["username"],
        tracking_id=tracking_id,
    ):
        return on_ticket(*args, **kwargs, tracking_id=tracking_id)


def run_on_comment(*args, **kwargs):
    tracking_id = get_hash()
    with logger.contextualize(
        **kwargs,
        name="comment_" + kwargs["username"],
        tracking_id=tracking_id,
    ):
        on_comment(*args, **kwargs, tracking_id=tracking_id)


def run_on_button_click(*args, **kwargs):
    thread = threading.Thread(target=handle_button_click, args=args, kwargs=kwargs)
    thread.start()
    global_threads.append(thread)


def run_on_check_suite(*args, **kwargs):
    request = kwargs["request"]
    pr_change_request = on_check_suite(request)
    if pr_change_request:
        call_on_comment(**pr_change_request.params, comment_type="github_action")
        logger.info("Done with on_check_suite")
    else:
        logger.info("Skipping on_check_suite as no pr_change_request was returned")


def terminate_thread(thread):
    """Terminate a python threading.Thread."""
    try:
        if not thread.is_alive():
            return

        exc = ctypes.py_object(SystemExit)
        res = ctypes.pythonapi.PyThreadState_SetAsyncExc(
            ctypes.c_long(thread.ident), exc
        )
        if res == 0:
            raise ValueError("Invalid thread ID")
        elif res != 1:
            # Call with exception set to 0 is needed to cleanup properly.
            ctypes.pythonapi.PyThreadState_SetAsyncExc(thread.ident, 0)
            raise SystemError("PyThreadState_SetAsyncExc failed")
    except SystemExit:
        raise SystemExit
    except Exception as e:
        logger.exception(f"Failed to terminate thread: {e}")


# def delayed_kill(thread: threading.Thread, delay: int = 60 * 60):
# time.sleep(delay)
# terminate_thread(thread)


def call_on_ticket(*args, **kwargs):
    global on_ticket_events
    key = f"{kwargs['repo_full_name']}-{kwargs['issue_number']}"  # Full name, issue number as key

    # Use multithreading
    # Check if a previous process exists for the same key, cancel it
    e = on_ticket_events.get(key, None)
    if e:
        logger.info(f"Found previous thread for key {key} and cancelling it")
        terminate_thread(e)

    thread = threading.Thread(target=run_on_ticket, args=args, kwargs=kwargs)
    on_ticket_events[key] = thread
    thread.start()
    global_threads.append(thread)

    # delayed_kill_thread = threading.Thread(target=delayed_kill, args=(thread,))
    # delayed_kill_thread.start()


def call_on_check_suite(*args, **kwargs):
    kwargs["request"].repository.full_name
    kwargs["request"].check_run.pull_requests[0].number
    thread = threading.Thread(target=run_on_check_suite, args=args, kwargs=kwargs)
    thread.start()
    global_threads.append(thread)


def call_on_comment(
    *args, **kwargs
):  # TODO: if its a GHA delete all previous GHA and append to the end
    def worker():
        while not events[key].empty():
            task_args, task_kwargs = events[key].get()
            run_on_comment(*task_args, **task_kwargs)

    global events
    repo_full_name = kwargs["repo_full_name"]
    pr_id = kwargs["pr_number"]
    key = f"{repo_full_name}-{pr_id}"  # Full name, comment number as key

    comment_type = kwargs["comment_type"]
    logger.info(f"Received comment type: {comment_type}")

    if key not in events:
        events[key] = SafePriorityQueue()

    events[key].put(0, (args, kwargs))

    # If a thread isn't running, start one
    if not any(
        thread.name == key and thread.is_alive() for thread in threading.enumerate()
    ):
        thread = threading.Thread(target=worker, name=key)
        thread.start()
        global_threads.append(thread)


def call_on_merge(*args, **kwargs):
    thread = threading.Thread(target=on_merge, args=args, kwargs=kwargs)
    thread.start()
    global_threads.append(thread)


@app.get("/health")
def redirect_to_health():
    return health_check()


@app.get("/", response_class=HTMLResponse)
def home(request: Request):
    return templates.TemplateResponse(
        name="index.html", context={"version": version, "request": request}
    )


@app.get("/ticket_progress/{tracking_id}")
def progress(tracking_id: str = Path(...)):
    ticket_progress = TicketProgress.load(tracking_id)
    return ticket_progress.dict()


def init_hatchet() -> Any | None:
    try:
        from hatchet_sdk import Context, Hatchet

        hatchet = Hatchet(debug=True)

        worker = hatchet.worker("github-worker")

        @hatchet.workflow(on_events=["github:webhook"])
        class OnGithubEvent:
            """Workflow for handling GitHub events."""

            @hatchet.step()
            def run(self, context: Context):
                event_payload = context.workflow_input()

                request_dict = event_payload.get("request")
                event = event_payload.get("event")

                handle_event(request_dict, event)

        workflow = OnGithubEvent()
        worker.register_workflow(workflow)

        # start worker in the background
        thread = threading.Thread(target=worker.start)
        thread.start()
        global_threads.append(thread)

        return hatchet
    except Exception as e:
        print(f"Failed to initialize Hatchet: {e}, continuing with local mode")
        return None


# hatchet = init_hatchet()


def handle_github_webhook(event_payload):
    # if hatchet:
    # hatchet.client.event.push("github:webhook", event_payload)
    # else:
    handle_event(event_payload.get("request"), event_payload.get("event"))


def handle_request(request_dict, event=None):
    """So it can be exported to the listen endpoint."""
    with logger.contextualize(tracking_id="main", env=ENV):
        action = request_dict.get("action")

        try:
            # Send the event to Hatchet
            handle_github_webhook(
                {
                    "request": request_dict,
                    "event": event,
                }
            )
        except Exception as e:
            logger.exception(f"Failed to send event to Hatchet: {e}")

        # try:
        # worker()
        # except Exception as e:
        # discord_log_error(str(e), priority=1)
        logger.info(f"Done handling {event}, {action}")
        return {"success": True}


@app.post("/")
def webhook(
    request_dict: dict = Body(...),
    x_github_event: Optional[str] = Header(None, alias="X-GitHub-Event"),
):
    """Handle a webhook request from GitHub."""
    with logger.contextualize(tracking_id="main", env=ENV):
        action = request_dict.get("action", None)

        logger.info(f"Received event: {x_github_event}, {action}")
        return handle_request(request_dict, event=x_github_event)


# Set up cronjob for this
@app.get("/update_sweep_prs_v2")
def update_sweep_prs_v2(repo_full_name: str, installation_id: int):
    # Get a Github client
    _, g = get_github_client(installation_id)

    # Get the repository
    repo = g.get_repo(repo_full_name)
    config = SweepConfig.get_config(repo)

    try:
        branch_ttl = int(config.get("branch_ttl", 7))
    except Exception:
        branch_ttl = 7
    branch_ttl = max(branch_ttl, 1)

    # Get all open pull requests created by Sweep
    pulls = repo.get_pulls(
        state="open", head="sweep", sort="updated", direction="desc"
    )[:5]

    # For each pull request, attempt to merge the changes from the default branch into the pull request branch
    try:
        for pr in pulls:
            try:
                # make sure it's a sweep ticket
                feature_branch = pr.head.ref
                if not feature_branch.startswith(
                    "sweep/"
                ) and not feature_branch.startswith("sweep_"):
                    continue
                if "Resolve merge conflicts" in pr.title:
                    continue

                if (
                    pr.mergeable_state != "clean"
                    and (time.time() - pr.created_at.timestamp()) > 60 * 60 * 24
                    and pr.title.startswith("[Sweep Rules]")
                ):
                    pr.edit(state="closed")
                    continue

                repo.merge(
                    feature_branch,
                    pr.base.ref,
                    f"Merge main into {feature_branch}",
                )

                # Check if the merged PR is the config PR
                if pr.title == "Configure Sweep" and pr.merged:
                    # Create a new PR to add "gha_enabled: True" to sweep.yaml
                    create_gha_pr(g, repo)
            except Exception as e:
                logger.warning(
                    f"Failed to merge changes from default branch into PR #{pr.number}: {e}"
                )
    except Exception:
        logger.warning("Failed to update sweep PRs")


def handle_event(request_dict, event):
    action = request_dict.get("action")

    if repo_full_name := request_dict.get("repository", {}).get("full_name"):
        if repo_full_name in DISABLED_REPOS:
            logger.warning(f"Repo {repo_full_name} is disabled")
            return {"success": False, "error_message": "Repo is disabled"}

    with logger.contextualize(tracking_id="main", env=ENV):
        match event, action:
            case "check_run", "completed":
                request = CheckRunCompleted(**request_dict)
                _, g = get_github_client(request.installation.id)
                repo = g.get_repo(request.repository.full_name)
                pull_requests = request.check_run.pull_requests
                if pull_requests:
                    logger.info(pull_requests[0].number)
                    pr = repo.get_pull(pull_requests[0].number)
                    if (time.time() - pr.created_at.timestamp()) > 60 * 60 and (
                        pr.title.startswith("[Sweep Rules]")
                        or pr.title.startswith("[Sweep GHA Fix]")
                    ):
                        after_sha = pr.head.sha
                        commit = repo.get_commit(after_sha)
                        check_suites = commit.get_check_suites()
                        for check_suite in check_suites:
                            if check_suite.conclusion == "failure":
                                pr.edit(state="closed")
                                break
                    if (
                        not (time.time() - pr.created_at.timestamp()) > 60 * 15
                        and request.check_run.conclusion == "failure"
                        and pr.state == "open"
                        and get_gha_enabled(repo)
                        and len(
                            [
                                comment
                                for comment in pr.get_issue_comments()
                                if "Fixing PR" in comment.body
                            ]
                        )
                        < 2
                        and GHA_AUTOFIX_ENABLED
                    ):
                        # check if the base branch is passing
                        commits = repo.get_commits(sha=pr.base.ref)
                        latest_commit: Commit = commits[0]
                        if all(
                            status != "failure"
                            for status in [
                                status.state for status in latest_commit.get_statuses()
                            ]
                        ):  # base branch is passing
                            logs = download_logs(
                                request.repository.full_name,
                                request.check_run.run_id,
                                request.installation.id,
                            )
                            logs, user_message = clean_gh_logs(logs)
                            attributor = request.sender.login
                            if attributor.endswith("[bot]"):
                                attributor = commit.author.login
                            if attributor.endswith("[bot]"):
                                attributor = pr.assignee.login
                            if attributor.endswith("[bot]"):
                                return {
                                    "success": False,
                                    "error_message": "The PR was created by a bot, so I won't attempt to fix it.",
                                }
                            tracking_id = get_hash()
                            chat_logger = ChatLogger(
                                data={
                                    "username": attributor,
                                    "title": "[Sweep GHA Fix] Fix the failing GitHub Actions",
                                }
                            )
                            if chat_logger.use_faster_model() and not IS_SELF_HOSTED:
                                return {
                                    "success": False,
                                    "error_message": "Disabled for free users",
                                }
                            stack_pr(
                                request=f"[Sweep GHA Fix] The GitHub Actions run failed on {request.check_run.head_sha[:7]} ({repo.default_branch}) with the following error logs:\n\n```\n\n{logs}\n\n```",
                                pr_number=pr.number,
                                username=attributor,
                                repo_full_name=repo.full_name,
                                installation_id=request.installation.id,
                                tracking_id=tracking_id,
                                commit_hash=pr.head.sha,
                            )
                elif (
                    request.check_run.check_suite.head_branch == repo.default_branch
                    and get_gha_enabled(repo)
                    and GHA_AUTOFIX_ENABLED
                ):
                    if request.check_run.conclusion == "failure":
                        commit = repo.get_commit(request.check_run.head_sha)
                        attributor = request.sender.login
                        if attributor.endswith("[bot]"):
                            attributor = commit.author.login
                        if attributor.endswith("[bot]"):
                            return {
                                "success": False,
                                "error_message": "The PR was created by a bot, so I won't attempt to fix it.",
                            }
                        logs = download_logs(
                            request.repository.full_name,
                            request.check_run.run_id,
                            request.installation.id,
                        )
                        logs, user_message = clean_gh_logs(logs)
                        chat_logger = ChatLogger(
                            data={
                                "username": attributor,
                                "title": "[Sweep GHA Fix] Fix the failing GitHub Actions",
                            }
                        )
                        if chat_logger.use_faster_model() and not IS_SELF_HOSTED:
                            return {
                                "success": False,
                                "error_message": "Disabled for free users",
                            }
                        make_pr(
                            title=f"[Sweep GHA Fix] Fix the failing GitHub Actions on {request.check_run.head_sha[:7]} ({repo.default_branch})",
                            repo_description=repo.description,
                            summary=f"The GitHub Actions run failed with the following error logs:\n\n```\n{logs}\n```",
                            repo_full_name=request_dict["repository"]["full_name"],
                            installation_id=request_dict["installation"]["id"],
                            user_token=None,
                            use_faster_model=chat_logger.use_faster_model(),
                            username=attributor,
                            chat_logger=chat_logger,
                        )
            case "pull_request", "opened":
                _, g = get_github_client(request_dict["installation"]["id"])
                repo = g.get_repo(request_dict["repository"]["full_name"])
                pr = repo.get_pull(request_dict["pull_request"]["number"])
                # if the pr already has a comment from sweep bot do nothing
                time.sleep(10)
                if any(
                    comment.user.login == GITHUB_BOT_USERNAME
                    for comment in pr.get_issue_comments()
                ) or pr.title.startswith("Sweep:"):
                    return {
                        "success": True,
                        "reason": "PR already has a comment from sweep bot",
                    }
                rule_buttons = []
                repo_rules = get_rules(repo) or []
                if repo_rules != [""] and repo_rules != []:
                    for rule in repo_rules or []:
                        if rule:
                            rule_buttons.append(Button(label=f"{RULES_LABEL} {rule}"))
                if len(repo_rules) == 0:
                    for rule in DEFAULT_RULES:
                        rule_buttons.append(Button(label=f"{RULES_LABEL} {rule}"))
                if rule_buttons:
                    rules_buttons_list = ButtonList(
                        buttons=rule_buttons, title=RULES_TITLE
                    )
                    pr.create_issue_comment(rules_buttons_list.serialize() + BOT_SUFFIX)

                if pr.mergeable is False and MERGE_CONFLICT_ENABLED:
                    attributor = pr.user.login
                    if attributor.endswith("[bot]"):
                        attributor = pr.assignee.login
                    if attributor.endswith("[bot]"):
                        return {
                            "success": False,
                            "error_message": "The PR was created by a bot, so I won't attempt to fix it.",
                        }
                    chat_logger = ChatLogger(
                        data={
                            "username": attributor,
                            "title": "[Sweep GHA Fix] Fix the failing GitHub Actions",
                        }
                    )
                    if chat_logger.use_faster_model() and not IS_SELF_HOSTED:
                        return {
                            "success": False,
                            "error_message": "Disabled for free users",
                        }
                    on_merge_conflict(
                        pr_number=pr.number,
                        username=attributor,
                        repo_full_name=request_dict["repository"]["full_name"],
                        installation_id=request_dict["installation"]["id"],
                        tracking_id=get_hash(),
                    )
            case "issues", "opened":
                request = IssueRequest(**request_dict)
                issue_title_lower = request.issue.title.lower()
                if (
                    issue_title_lower.startswith("sweep")
                    or "sweep:" in issue_title_lower
                ):
                    _, g = get_github_client(request.installation.id)
                    repo = g.get_repo(request.repository.full_name)

                    labels = repo.get_labels()
                    label_names = [label.name for label in labels]

                    if GITHUB_LABEL_NAME not in label_names:
                        repo.create_label(
                            name=GITHUB_LABEL_NAME,
                            color=GITHUB_LABEL_COLOR,
                            description=GITHUB_LABEL_DESCRIPTION,
                        )
                    current_issue = repo.get_issue(number=request.issue.number)
                    current_issue.add_to_labels(GITHUB_LABEL_NAME)
            case "issue_comment", "edited":
                request = IssueCommentRequest(**request_dict)
                sweep_labeled_issue = GITHUB_LABEL_NAME in [
                    label.name.lower() for label in request.issue.labels
                ]
                button_title_match = check_button_title_match(
                    REVERT_CHANGED_FILES_TITLE,
                    request.comment.body,
                    request.changes,
                ) or check_button_title_match(
                    RULES_TITLE,
                    request.comment.body,
                    request.changes,
                )
                if (
                    request.comment.user.type == "Bot"
                    and GITHUB_BOT_USERNAME in request.comment.user.login
                    and request.changes.body_from is not None
                    and button_title_match
                    and request.sender.type == "User"
                    and request.comment.user.login not in BLACKLISTED_USERS
                ):
                    run_on_button_click(request_dict)

                restart_sweep = False
                if (
                    request.comment.user.type == "Bot"
                    and GITHUB_BOT_USERNAME in request.comment.user.login
                    and request.changes.body_from is not None
                    and check_button_activated(
                        RESTART_SWEEP_BUTTON,
                        request.comment.body,
                        request.changes,
                    )
                    and sweep_labeled_issue
                    and request.sender.type == "User"
                    and request.comment.user.login not in BLACKLISTED_USERS
                ):
                    # Restart Sweep on this issue
                    restart_sweep = True

                if (
                    request.issue is not None
                    and sweep_labeled_issue
                    and request.comment.user.type == "User"
                    and request.comment.user.login not in BLACKLISTED_USERS
                    and not request.comment.user.login.startswith("sweep")
                    and not (
                        request.issue.pull_request and request.issue.pull_request.url
                    )
                    or restart_sweep
                ):
                    logger.info("New issue comment edited")
                    request.issue.body = request.issue.body or ""
                    request.repository.description = (
                        request.repository.description or ""
                    )

                    if (
                        not request.comment.body.strip()
                        .lower()
                        .startswith(GITHUB_LABEL_NAME)
                        and not restart_sweep
                    ):
                        logger.info("Comment does not start with 'Sweep', passing")
                        return {
                            "success": True,
                            "reason": "Comment does not start with 'Sweep', passing",
                        }

                    call_on_ticket(
                        title=request.issue.title,
                        summary=request.issue.body,
                        issue_number=request.issue.number,
                        issue_url=request.issue.html_url,
                        username=request.issue.user.login,
                        repo_full_name=request.repository.full_name,
                        repo_description=request.repository.description,
                        installation_id=request.installation.id,
                        comment_id=request.comment.id if not restart_sweep else None,
                        edited=True,
                    )
                elif (
                    request.issue.pull_request
                    and request.comment.user.type == "User"
                    and request.comment.user.login not in BLACKLISTED_USERS
                ):  # TODO(sweep): set a limit
                    logger.info(f"Handling comment on PR: {request.issue.pull_request}")
                    _, g = get_github_client(request.installation.id)
                    repo = g.get_repo(request.repository.full_name)
                    pr = repo.get_pull(request.issue.number)
                    labels = pr.get_labels()
                    comment = request.comment.body
                    if (
                        comment.lower().startswith("sweep:")
                        or any(label.name.lower() == "sweep" for label in labels)
                    ) and BOT_SUFFIX not in comment:
                        pr_change_request = PRChangeRequest(
                            params={
                                "comment_type": "comment",
                                "repo_full_name": request.repository.full_name,
                                "repo_description": request.repository.description,
                                "comment": request.comment.body,
                                "pr_path": None,
                                "pr_line_position": None,
                                "username": request.comment.user.login,
                                "installation_id": request.installation.id,
                                "pr_number": request.issue.number,
                                "comment_id": request.comment.id,
                            },
                        )
                        call_on_comment(**pr_change_request.params)
            case "issues", "edited":
                request = IssueRequest(**request_dict)
                if (
                    GITHUB_LABEL_NAME
                    in [label.name.lower() for label in request.issue.labels]
                    and request.sender.type == "User"
                    and request.comment.user.login not in BLACKLISTED_USERS
                    and not request.sender.login.startswith("sweep")
                ):
                    logger.info("New issue edited")
                    call_on_ticket(
                        title=request.issue.title,
                        summary=request.issue.body,
                        issue_number=request.issue.number,
                        issue_url=request.issue.html_url,
                        username=request.issue.user.login,
                        repo_full_name=request.repository.full_name,
                        repo_description=request.repository.description,
                        installation_id=request.installation.id,
                        comment_id=None,
                    )
                else:
                    logger.info("Issue edited, but not a sweep issue")
            case "issues", "labeled":
                request = IssueRequest(**request_dict)
                if (
                    any(
                        label.name.lower() == GITHUB_LABEL_NAME
                        for label in request.issue.labels
                    )
                    and not request.issue.pull_request
                ):
                    request.issue.body = request.issue.body or ""
                    request.repository.description = (
                        request.repository.description or ""
                    )
                    call_on_ticket(
                        title=request.issue.title,
                        summary=request.issue.body,
                        issue_number=request.issue.number,
                        issue_url=request.issue.html_url,
                        username=request.issue.user.login,
                        repo_full_name=request.repository.full_name,
                        repo_description=request.repository.description,
                        installation_id=request.installation.id,
                        comment_id=None,
                    )
            case "issue_comment", "created":
                request = IssueCommentRequest(**request_dict)
                if (
                    request.issue is not None
                    and GITHUB_LABEL_NAME
                    in [label.name.lower() for label in request.issue.labels]
                    and request.comment.user.type == "User"
                    and request.comment.user.login not in BLACKLISTED_USERS
                    and not (
                        request.issue.pull_request and request.issue.pull_request.url
                    )
                    and BOT_SUFFIX not in request.comment.body
                ):
                    request.issue.body = request.issue.body or ""
                    request.repository.description = (
                        request.repository.description or ""
                    )

                    if (
                        not request.comment.body.strip()
                        .lower()
                        .startswith(GITHUB_LABEL_NAME)
                    ):
                        logger.info("Comment does not start with 'Sweep', passing")
                        return {
                            "success": True,
                            "reason": "Comment does not start with 'Sweep', passing",
                        }

                    call_on_ticket(
                        title=request.issue.title,
                        summary=request.issue.body,
                        issue_number=request.issue.number,
                        issue_url=request.issue.html_url,
                        username=request.issue.user.login,
                        repo_full_name=request.repository.full_name,
                        repo_description=request.repository.description,
                        installation_id=request.installation.id,
                        comment_id=request.comment.id,
                    )
                elif (
                    request.issue.pull_request
                    and request.comment.user.type == "User"
                    and request.comment.user.login not in BLACKLISTED_USERS
                    and BOT_SUFFIX not in request.comment.body
                ):  # TODO(sweep): set a limit
                    _, g = get_github_client(request.installation.id)
                    repo = g.get_repo(request.repository.full_name)
                    pr = repo.get_pull(request.issue.number)
                    labels = pr.get_labels()
                    comment = request.comment.body
                    if (
                        comment.lower().startswith("sweep:")
                        or any(label.name.lower() == "sweep" for label in labels)
                        and BOT_SUFFIX not in comment
                    ):
                        pr_change_request = PRChangeRequest(
                            params={
                                "comment_type": "comment",
                                "repo_full_name": request.repository.full_name,
                                "repo_description": request.repository.description,
                                "comment": request.comment.body,
                                "pr_path": None,
                                "pr_line_position": None,
                                "username": request.comment.user.login,
                                "installation_id": request.installation.id,
                                "pr_number": request.issue.number,
                                "comment_id": request.comment.id,
                            },
                        )
                        call_on_comment(**pr_change_request.params)
case "pull_request_review_comment", "created":
request = CommentCreatedRequest(**request_dict)
_, g = get_github_client(request.installation.id)
repo = g.get_repo(request.repository.full_name)
pr = repo.get_pull(request.pull_request.number)
labels = pr.get_labels()
comment = request.comment.body
if (
(
comment.lower().startswith("sweep:")
or any(label.name.lower() == "sweep" for label in labels)
)
and request.comment.user.type == "User"
and request.comment.user.login not in BLACKLISTED_USERS
and BOT_SUFFIX not in comment
):
pr_change_request = PRChangeRequest(
params={
"comment_type": "comment",
"repo_full_name": request.repository.full_name,
"repo_description": request.repository.description,
"comment": request.comment.body,
"pr_path": request.comment.path,
"pr_line_position": request.comment.original_line,
"username": request.comment.user.login,
"installation_id": request.installation.id,
"pr_number": request.pull_request.number,
"comment_id": request.comment.id,
},
)
call_on_comment(**pr_change_request.params)
case "pull_request_review_comment", "edited":
request = CommentCreatedRequest(**request_dict)
_, g = get_github_client(request.installation.id)
repo = g.get_repo(request.repository.full_name)
pr = repo.get_pull(request.pull_request.number)
labels = pr.get_labels()
comment = request.comment.body
if (
(
comment.lower().startswith("sweep:")
or any(label.name.lower() == "sweep" for label in labels)
)
and request.comment.user.type == "User"
and request.comment.user.login not in BLACKLISTED_USERS
and BOT_SUFFIX not in comment
):
pr_change_request = PRChangeRequest(
params={
"comment_type": "comment",
"repo_full_name": request.repository.full_name,
"repo_description": request.repository.description,
"comment": request.comment.body,
"pr_path": request.comment.path,
"pr_line_position": request.comment.original_line,
"username": request.comment.user.login,
"installation_id": request.installation.id,
"pr_number": request.pull_request.number,
"comment_id": request.comment.id,
},
)
call_on_comment(**pr_change_request.params)
case "installation_repositories", "added":
repos_added_request = ReposAddedRequest(**request_dict)
metadata = {
"installation_id": repos_added_request.installation.id,
"repositories": [
repo.full_name
for repo in repos_added_request.repositories_added
],
}
try:
add_config_to_top_repos(
repos_added_request.installation.id,
repos_added_request.installation.account.login,
repos_added_request.repositories_added,
)
except Exception as e:
logger.exception(f"Failed to add config to top repos: {e}")
posthog.capture(
"installation_repositories",
"started",
properties={**metadata},
)
for repo in repos_added_request.repositories_added:
organization, repo_name = repo.full_name.split("/")
posthog.capture(
organization,
"installed_repository",
properties={
"repo_name": repo_name,
"organization": organization,
"repo_full_name": repo.full_name,
},
)
case "installation", "created":
repos_added_request = InstallationCreatedRequest(**request_dict)
try:
add_config_to_top_repos(
repos_added_request.installation.id,
repos_added_request.installation.account.login,
repos_added_request.repositories,
)
except Exception as e:
logger.exception(f"Failed to add config to top repos: {e}")
case "pull_request", "edited":
request = PREdited(**request_dict)
if (
request.pull_request.user.login == GITHUB_BOT_USERNAME
and not request.sender.login.endswith("[bot]")
and DISCORD_FEEDBACK_WEBHOOK_URL is not None
):
good_button = check_button_activated(
SWEEP_GOOD_FEEDBACK,
request.pull_request.body,
request.changes,
)
bad_button = check_button_activated(
SWEEP_BAD_FEEDBACK,
request.pull_request.body,
request.changes,
)
if good_button or bad_button:
emoji = "😕"
if good_button:
emoji = "👍"
elif bad_button:
emoji = "👎"
data = {
"content": f"{emoji} {request.pull_request.html_url} ({request.sender.login})\n{request.pull_request.commits} commits, {request.pull_request.changed_files} files: +{request.pull_request.additions}, -{request.pull_request.deletions}"
}
headers = {"Content-Type": "application/json"}
requests.post(
DISCORD_FEEDBACK_WEBHOOK_URL,
data=json.dumps(data),
headers=headers,
)
# Send feedback to PostHog
posthog.capture(
request.sender.login,
"feedback",
properties={
"repo_name": request.repository.full_name,
"pr_url": request.pull_request.html_url,
"pr_commits": request.pull_request.commits,
"pr_additions": request.pull_request.additions,
"pr_deletions": request.pull_request.deletions,
"pr_changed_files": request.pull_request.changed_files,
"username": request.sender.login,
"good_button": good_button,
"bad_button": bad_button,
},
)
def remove_buttons_from_description(body):
"""
Replace:
### PR Feedback...
...
# (until it hits the next #)
with
### PR Feedback: {emoji}
#
"""
lines = body.split("\n")
if not lines[0].startswith("### PR Feedback"):
return None
# Find when the second # occurs
i = 0
for i, line in enumerate(lines):
if line.startswith("#") and i > 0:
break
return "\n".join(
[
f"### PR Feedback: {emoji}",
*lines[i:],
]
)
# Update PR description to remove buttons
try:
_, g = get_github_client(request.installation.id)
repo = g.get_repo(request.repository.full_name)
pr = repo.get_pull(request.pull_request.number)
new_body = remove_buttons_from_description(
request.pull_request.body
)
if new_body is not None:
pr.edit(body=new_body)
except SystemExit:
raise
except Exception as e:
logger.exception(f"Failed to edit PR description: {e}")
case "pull_request", "closed":
pr_request = PRRequest(**request_dict)
(
organization,
repo_name,
) = pr_request.repository.full_name.split("/")
commit_author = pr_request.pull_request.user.login
merged_by = (
pr_request.pull_request.merged_by.login
if pr_request.pull_request.merged_by
else None
)
if CURRENT_USERNAME == commit_author and merged_by is not None:
event_name = "merged_sweep_pr"
if pr_request.pull_request.title.startswith("[config]"):
event_name = "config_pr_merged"
elif pr_request.pull_request.title.startswith("[Sweep Rules]"):
event_name = "sweep_rules_pr_merged"
_token, g = get_github_client(pr_request.installation.id)
pr = g.get_repo(pr_request.repository.full_name).get_pull(
pr_request.number
)
total_lines_in_commit = 0
total_lines_edited_by_developer = 0
edited_by_developers = False
for commit in pr.get_commits():
lines_modified = commit.stats.additions + commit.stats.deletions
total_lines_in_commit += lines_modified
if commit.author is None or commit.author.login != CURRENT_USERNAME:
total_lines_edited_by_developer += lines_modified
# this was edited by a developer if at least 25% of the lines were edited by a developer
edited_by_developers = total_lines_in_commit > 0 and (total_lines_edited_by_developer / total_lines_in_commit) >= 0.25
posthog.capture(
merged_by,
event_name,
properties={
"repo_name": repo_name,
"organization": organization,
"repo_full_name": pr_request.repository.full_name,
"username": merged_by,
"additions": pr_request.pull_request.additions,
"deletions": pr_request.pull_request.deletions,
"total_changes": pr_request.pull_request.additions
+ pr_request.pull_request.deletions,
"edited_by_developers": edited_by_developers,
"total_lines_in_commit": total_lines_in_commit,
"total_lines_edited_by_developer": total_lines_edited_by_developer,
},
)
chat_logger = ChatLogger({"username": merged_by})
case "push", None:
if event != "pull_request" or request_dict["base"]["merged"] is True:
chat_logger = ChatLogger(
{"username": request_dict["pusher"]["name"]}
)
# on merge
call_on_merge(request_dict, chat_logger)
ref = request_dict["ref"] if "ref" in request_dict else ""
if ref.startswith("refs/heads") and not ref.startswith(
"ref/heads/sweep"
):
_, g = get_github_client(request_dict["installation"]["id"])
repo = g.get_repo(request_dict["repository"]["full_name"])
if ref[len("refs/heads/") :] == SweepConfig.get_branch(repo):
update_sweep_prs_v2(
request_dict["repository"]["full_name"],
installation_id=request_dict["installation"]["id"],
)
if ref.startswith("refs/heads"):
branch_name = ref[len("refs/heads/") :]
# Check if the branch has an associated PR
org_name, repo_name = request_dict["repository"][
"full_name"
].split("/")
pulls = repo.get_pulls(
state="open",
sort="created",
head=org_name + ":" + branch_name,
)
for pr in pulls:
logger.info(
f"PR associated with branch {branch_name}: #{pr.number} - {pr.title}"
)
if pr.mergeable is False and MERGE_CONFLICT_ENABLED:
attributor = pr.user.login
if attributor.endswith("[bot]"):
attributor = pr.assignee.login
if attributor.endswith("[bot]"):
return {
"success": False,
"error_message": "The PR was created by a bot, so I won't attempt to fix it.",
}
chat_logger = ChatLogger(
data={
"username": attributor,
"title": "[Sweep GHA Fix] Fix the failing GitHub Actions",
}
)
if (
chat_logger.use_faster_model()
and not IS_SELF_HOSTED
):
return {
"success": False,
"error_message": "Disabled for free users",
}
on_merge_conflict(
pr_number=pr.number,
username=pr.user.login,
repo_full_name=request_dict["repository"][
"full_name"
],
installation_id=request_dict["installation"]["id"],
tracking_id=get_hash(),
)
case "ping", None:
return {"message": "pong"}
case _:
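When a push lands on the configured base branch, the push handler above calls `update_sweep_prs_v2`, which (as this issue describes) merges the updated target branch into each Sweep branch. The rebase option requested here could, as a rough sketch, be modeled as a choice between two git command sequences. The `plan_branch_update` and `run_plan` helpers and the `"merge"`/`"rebase"` strategy values are hypothetical and not part of Sweep; only the git subcommands and flags (`merge`, `rebase`, `--abort`, `push --force-with-lease`) are standard git.

```python
import subprocess
from typing import List


def plan_branch_update(base: str, strategy: str = "merge") -> List[List[str]]:
    """Hypothetical helper: git commands that bring a branch up to date with `base`."""
    fetch = ["fetch", "origin", base]
    if strategy == "merge":
        # Behaviour described in the issue: a merge commit is added to the branch.
        return [fetch, ["merge", f"origin/{base}"], ["push", "origin", "HEAD"]]
    if strategy == "rebase":
        # Replays the branch's commits on top of the base; the rewritten history
        # must overwrite the remote branch, so a guarded force push is needed.
        return [
            fetch,
            ["rebase", f"origin/{base}"],
            ["push", "--force-with-lease", "origin", "HEAD"],
        ]
    raise ValueError(f"unknown strategy: {strategy!r}")


def run_plan(repo_dir: str, commands: List[List[str]]) -> bool:
    """Run the git commands in order inside repo_dir; stop at the first failure."""
    for args in commands:
        result = subprocess.run(
            ["git", *args], cwd=repo_dir, capture_output=True, text=True
        )
        if result.returncode != 0:
            if args[0] in ("merge", "rebase"):
                # Undo the half-finished merge/rebase so the working tree is clean.
                subprocess.run(
                    ["git", args[0], "--abort"], cwd=repo_dir, capture_output=True
                )
            return False
    return True
```

Merge is arguably the safer default: a rebase rewrites the branch's commits, so anyone who already checked out the Sweep branch has to reset onto the new history, and `--force-with-lease` refuses the push if the remote branch moved since the last fetch.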

import time
import traceback
from git import GitCommandError
from github.PullRequest import PullRequest
from loguru import logger
from sweepai.config.server import PROGRESS_BASE_URL
from sweepai.core import entities
from sweepai.core.entities import FileChangeRequest
from sweepai.core.sweep_bot import SweepBot
from sweepai.handlers.create_pr import create_pr_changes
from sweepai.handlers.on_ticket import get_branch_diff_text, sweeping_gif
from sweepai.utils.chat_logger import ChatLogger, discord_log_error
from sweepai.utils.diff import generate_diff
from sweepai.utils.event_logger import posthog
from sweepai.utils.github_utils import ClonedRepo, get_github_client
from sweepai.utils.progress import (
PaymentContext,
TicketContext,
TicketProgress,
TicketProgressStatus,
)
from sweepai.utils.prompt_constructor import HumanMessagePrompt
from sweepai.utils.str_utils import to_branch_name
from sweepai.utils.ticket_utils import center
instructions_format = """Resolve the merge conflicts in the PR by incorporating changes from both branches into the final code.
Title of PR: {title}
Here were the original changes to this file in the head branch:
Commit message: {head_commit_message}
```diff
{head_diff}
```
Here were the original changes to this file in the base branch:
Commit message: {base_commit_message}
```diff
{base_diff}
```
In the analysis_and_identification, first determine what each change does. Then determine what the final code should be. Then, use the keyword_search to find the merge conflict markers <<<<<<< and >>>>>>>. Finally, make the code changes by writing the old_code and the new_code."""
def on_merge_conflict(
pr_number: int,
username: str,
repo_full_name: str,
installation_id: int,
tracking_id: str,
):
# copied from stack_pr
token, g = get_github_client(installation_id=installation_id)
try:
repo = g.get_repo(repo_full_name)
except Exception as e:
print("Exception occurred while getting repo", e)
raise
pr: PullRequest = repo.get_pull(pr_number)
branch = pr.head.ref
status_message = center(
f"{sweeping_gif}\n\n"
+ f'Resolving merge conflicts: track the progress <a href="{PROGRESS_BASE_URL}/issues/{tracking_id}">here</a>.'
)
header = f"{status_message}\n---\n\nI'm currently resolving the merge conflicts in this PR. I will stack a new PR once I'm done."
comment = None
for current_comment in pr.get_issue_comments():
if (
current_comment.user.login == "sweep-nightly[bot]"
and "Resolving merge conflicts: track the progress" in current_comment.body
):
current_comment.edit(body=header)
comment = current_comment
break
if comment is None:
comment = pr.create_issue_comment(body=header)
def edit_comment(body):
nonlocal comment
comment.edit(header + "\n\n" + body)
metadata = {}
try:
cloned_repo = ClonedRepo(
repo_full_name=repo_full_name,
installation_id=installation_id,
branch=branch,
token=token,
)
time.time()
request = f"Sweep: Resolve merge conflicts for PR #{pr_number}: {pr.title}"
title = request
if len(title) > 50:
title = title[:50] + "..."
chat_logger = ChatLogger(
data={
"username": username,
"metadata": metadata,
"tracking_id": tracking_id,
}
)
is_paying_user = chat_logger.is_paying_user()
chat_logger.is_consumer_tier()
# this logic is partly taken from on_ticket.py, if there is an issue please refer to that file
if chat_logger:
use_faster_model = chat_logger.use_faster_model()
else:
is_paying_user = True
ticket_progress = TicketProgress(
tracking_id=tracking_id,
username=username,
context=TicketContext(
title=title,
description="",
repo_full_name=repo_full_name,
branch_name="sweep/" + to_branch_name(request),
issue_number=pr_number,
is_public=repo.private is False,
start_time=int(time.time()),
# mostly copied from on_ticket, if issue please check that file
payment_context=PaymentContext(
use_faster_model=use_faster_model,
pro_user=is_paying_user,
daily_tickets_used=(
chat_logger.get_ticket_count(use_date=True)
if chat_logger
else 0
),
monthly_tickets_used=(
chat_logger.get_ticket_count() if chat_logger else 0
),
),
),
)
metadata = {
"tracking_id": tracking_id,
"username": username,
"function": "on_merge_conflict",
**ticket_progress.context.dict(),
}
posthog.capture(
username,
"started",
properties=metadata,
)
issue_url = pr.html_url
edit_comment("Configuring branch...")
new_pull_request = entities.PullRequest(
title=title,
branch_name="sweep/" + branch + "-merge-conflict",
content="",
)
# Making sure name is unique
for i in range(30):
try:
repo.get_branch(new_pull_request.branch_name + "_" + str(i))
except Exception:
new_pull_request.branch_name += "_" + str(i)
break
# Branch off the PR head, then merge origin/<base branch> into it so the conflicts can be resolved below
git_repo = cloned_repo.git_repo
old_head_branch = git_repo.branches[branch]
head_branch = git_repo.create_head(
new_pull_request.branch_name,
commit=old_head_branch.commit,
)
head_branch.checkout()
try:
git_repo.config_writer().set_value(
"user", "name", "sweep-nightly[bot]"
).release()
git_repo.config_writer().set_value(
"user", "email", "[email protected]"
).release()
git_repo.git.merge("origin/" + pr.base.ref)
except GitCommandError:
# Assume there are merge conflicts
pass
git_repo.git.add(update=True)
# -m and message are needed otherwise exception is thrown
git_repo.git.commit("-m", "Start of Merge Conflict Resolution")
origin = git_repo.remotes.origin
new_url = f"https://x-access-token:{token}@github.com/{repo_full_name}.git"
origin.set_url(new_url)
git_repo.git.push("--set-upstream", origin, new_pull_request.branch_name)
last_commit = git_repo.head.commit
all_files = [item.a_path for item in last_commit.diff("HEAD~1")]
conflict_files = []
for file in all_files:
try:
contents = open(cloned_repo.repo_dir + "/" + file).read()
if "\n<<<<<<<" in contents and "\n>>>>>>>" in contents:
conflict_files.append(file)
except UnicodeDecodeError:
pass
snippets = []
for conflict_file in conflict_files:
contents = open(cloned_repo.repo_dir + "/" + conflict_file).read()
snippet = entities.Snippet(
file_path=conflict_file,
start=0,
end=len(contents.splitlines()),
content=contents,
)
snippets.append(snippet)
tree = ""
ticket_progress.status = TicketProgressStatus.PLANNING
ticket_progress.save()
human_message = HumanMessagePrompt(
repo_name=repo_full_name,
issue_url=issue_url,
username=username,
repo_description=(repo.description or "").strip(),
title=request,
summary=request,
snippets=snippets,
tree=tree,
)
sweep_bot = SweepBot.from_system_message_content(
human_message=human_message,
repo=repo,
ticket_progress=ticket_progress,
chat_logger=chat_logger,
cloned_repo=cloned_repo,
branch=new_pull_request.branch_name,
)
# can select more precise snippets
file_change_requests = []
base_commits = pr.base.repo.get_commits().get_page(0)
head_commits = list(pr.get_commits())
for conflict_file in conflict_files:
old_code = repo.get_contents(
conflict_file, ref=head_commits[0].parents[0].sha
).decoded_content.decode()
base_code = repo.get_contents(
conflict_file, ref=pr.base.ref
).decoded_content.decode()
head_code = repo.get_contents(
conflict_file, ref=pr.head.ref
).decoded_content.decode()
base_diff = generate_diff(old_code=old_code, new_code=base_code)
head_diff = generate_diff(old_code=old_code, new_code=head_code)
base_commit_message = ""
for commit in base_commits[::-1]:
if any(
commit_file.filename == conflict_file
for commit_file in commit.files
):
base_commit_message = commit.raw_data["commit"]["message"]
break
head_commit_message = ""
for commit in head_commits[::-1]:
if any(
commit_file.filename == conflict_file
for commit_file in commit.files
):
head_commit_message = commit.raw_data["commit"]["message"]
break
file_change_requests.append(
FileChangeRequest(
filename=conflict_file,
instructions=instructions_format.format(
title=pr.title,
base_commit_message=base_commit_message,
base_diff=base_diff,
head_commit_message=head_commit_message,
head_diff=head_diff,
),
change_type="modify",
)
)
ticket_progress.status = TicketProgressStatus.CODING
ticket_progress.save()
edit_comment("Resolving merge conflicts...")
generator = create_pr_changes(
file_change_requests,
new_pull_request,
sweep_bot,
username,
installation_id,
pr_number,
chat_logger=chat_logger,
base_branch=new_pull_request.branch_name,
)
for item in generator:
if isinstance(item, dict):
break
(
file_change_request,
changed_file,
sandbox_response,
commit,
file_change_requests,
) = item
logger.info(f"Status succeeded: {file_change_request.status == 'succeeded'}")
ticket_progress.status = TicketProgressStatus.COMPLETE
ticket_progress.save()
edit_comment("Done creating pull request.")
get_branch_diff_text(repo, new_pull_request.branch_name)
new_description = f"This PR resolves the merge conflicts in #{pr_number}. This branch can be directly merged into {pr.base.ref}.\n\nFixes #{pr_number}."
# Create pull request
new_pull_request.content = new_description
github_pull_request = repo.create_pull(
title=request,
body=new_description,
head=new_pull_request.branch_name,
base=pr.base.ref,
)
ticket_progress.context.pr_id = github_pull_request.number
ticket_progress.context.done_time = time.time()
ticket_progress.save()
edit_comment(f"✨ **Created Pull Request:** {github_pull_request.html_url}")
posthog.capture(
username,
"success",
properties=metadata,
)
return {"success": True}
except Exception as e:
print(f"Exception occurred: {e}")
edit_comment(
f"> [!CAUTION]\n> An error has occurred: {str(e)} (tracking ID: {tracking_id})"
)
discord_log_error(
"Error occurred in on_merge_conflict.py\n"
+ traceback.format_exc()
+ "\n\n"
+ str(e)
+ "\n\n"
+ f"tracking ID: {tracking_id}"
)
posthog.capture(
username,
"failed",
properties=metadata,
)
return {"success": False}
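`on_merge_conflict` flags a file as conflicted only when its contents include `"\n<<<<<<<"` and `"\n>>>>>>>"`, so a conflict block that starts on the very first line of a file is not detected. A line-anchored regular expression is one way to close that gap; `has_conflict_markers` is a hypothetical helper sketched here, not part of Sweep.

```python
import re

# git writes conflict markers at the start of a line ("<<<<<<< HEAD",
# "=======", ">>>>>>> main"). With re.MULTILINE, ^ matches at every line
# start, including position 0, so a marker on the first line is found too.
_START_MARKER = re.compile(r"^<{7}(?: |$)", re.MULTILINE)
_END_MARKER = re.compile(r"^>{7}(?: |$)", re.MULTILINE)


def has_conflict_markers(text: str) -> bool:
    """Return True if text contains both an opening and a closing conflict marker."""
    return bool(_START_MARKER.search(text)) and bool(_END_MARKER.search(text))
```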
if __name__ == "__main__":
on_merge_conflict(
pr_number=68,
username="MartinYe1234",
repo_full_name="MartinYe1234/Chess-Game",
installation_id=45945746,
tracking_id="ADD-BOB-2",
)
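The branch-naming loop in `on_merge_conflict` probes `<name>_0`, `<name>_1`, … with `repo.get_branch` and takes the first candidate that raises. If all 30 candidates exist, the loop falls through and keeps the un-suffixed name. The sketch below isolates that logic behind a predicate so it can be tested without GitHub; `unique_branch_name` is hypothetical, and in practice the predicate might wrap `repo.get_branch`, for example by catching PyGithub's `UnknownObjectException`.

```python
from typing import Callable


def unique_branch_name(base: str, exists: Callable[[str], bool], limit: int = 30) -> str:
    """Return the first `base_<i>` (0 <= i < limit) for which exists() is False."""
    for i in range(limit):
        candidate = f"{base}_{i}"
        if not exists(candidate):
            return candidate
    # Unlike the original loop, fail loudly instead of reusing an existing name.
    raise RuntimeError(f"no free branch name for {base!r} after {limit} tries")
```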

"""
This file contains the on_merge handler, which is called when commits are pushed to the repository's configured base branch (for example, when a pull request is merged).
on_merge is called by sweepai/api.py.
"""
import time
from sweepai.config.client import SweepConfig, get_blocked_dirs, get_rules
from sweepai.core.post_merge import PostMerge
from sweepai.handlers.pr_utils import make_pr
from loguru import logger
from sweepai.utils.chat_logger import ChatLogger
from sweepai.utils.event_logger import posthog
from sweepai.utils.github_utils import get_github_client
# change threshold for number of lines changed
CHANGE_BOUNDS = (10, 1500)
# dictionary to map from github repo to the last time a rule was activated
merge_rule_debounce = {}
# debounce time in seconds
DEBOUNCE_TIME = 120
diff_section_prompt = """
<file_diff file="{diff_file_path}">
{diffs}
</file_diff>"""
def comparison_to_diff(comparison, blocked_dirs):
pr_diffs = []
for file in comparison.files:
diff = file.patch
if (
file.status == "added"
or file.status == "modified"
or file.status == "removed"
):
if any(file.filename.startswith(dir) for dir in blocked_dirs):
continue
pr_diffs.append((file.filename, diff))
else:
logger.info(
f"File status {file.status} not recognized"
) # TODO(sweep): We don't handle renamed files
formatted_diffs = []
for file_name, file_patch in pr_diffs:
format_diff = diff_section_prompt.format(
diff_file_path=file_name, diffs=file_patch
)
formatted_diffs.append(format_diff)
return "\n".join(formatted_diffs)
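`comparison_to_diff` skips a file when `file.filename.startswith(dir)` for any blocked directory. Plain string prefixes also match sibling names: a blocked `docs` excludes `docs_old/a.md` as well. If whole-directory matching is what `blocked_dirs` is meant to express, comparing path components avoids that; `is_blocked` is a hypothetical sketch, not Sweep's implementation.

```python
from pathlib import PurePosixPath
from typing import Iterable


def is_blocked(path: str, blocked_dirs: Iterable[str]) -> bool:
    """True if path lies inside one of blocked_dirs, compared component by component."""
    parts = PurePosixPath(path).parts
    for blocked in blocked_dirs:
        blocked_parts = PurePosixPath(blocked.strip("/")).parts
        # Skip empty entries; otherwise require a full leading-component match.
        if blocked_parts and parts[: len(blocked_parts)] == blocked_parts:
            return True
    return False
```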
def on_merge(request_dict: dict, chat_logger: ChatLogger):
before_sha = request_dict["before"]
after_sha = request_dict["after"]
commit_author = request_dict["sender"]["login"]
ref = request_dict["ref"]
if not ref.startswith("refs/heads/"):
return
user_token, g = get_github_client(request_dict["installation"]["id"])
repo = g.get_repo(
request_dict["repository"]["full_name"]
) # do this after checking ref
if ref[len("refs/heads/") :] != SweepConfig.get_branch(repo):
return
commit = repo.get_commit(after_sha)
check_suites = commit.get_check_suites()
for check_suite in check_suites:
if check_suite.conclusion == "failure":
return # if any check suite failed, return
blocked_dirs = get_blocked_dirs(repo)
comparison = repo.compare(before_sha, after_sha)
commits_diff = comparison_to_diff(comparison, blocked_dirs)
# check if the current repo is in the merge_rule_debounce dictionary
# and if the difference between the current time and the time stored in the dictionary is less than DEBOUNCE_TIME seconds
if (
repo.full_name in merge_rule_debounce
and time.time() - merge_rule_debounce[repo.full_name] < DEBOUNCE_TIME
):
return
merge_rule_debounce[repo.full_name] = time.time()
if not (
commits_diff.count("\n") >= CHANGE_BOUNDS[0]
and commits_diff.count("\n") <= CHANGE_BOUNDS[1]
):
return
rules = get_rules(repo)
rules = [rule for rule in rules if len(rule) > 0]
if not rules:
return
for rule in rules:
chat_logger.data["title"] = f"Sweep Rules - {rule}"
changes_required, issue_title, issue_description = PostMerge(
chat_logger=chat_logger
).check_for_issues(rule=rule, diff=commits_diff)
if changes_required:
make_pr(
title="[Sweep Rules] " + issue_title,
repo_description=repo.description,
summary=issue_description,
repo_full_name=request_dict["repository"]["full_name"],
installation_id=request_dict["installation"]["id"],
user_token=user_token,
use_faster_model=chat_logger.use_faster_model(),
username=commit_author,
chat_logger=chat_logger,
rule=rule,
)
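Before calling the model, `on_merge` applies two gates: a per-repository debounce of `DEBOUNCE_TIME` seconds and a diff-size window of `CHANGE_BOUNDS` lines. The timestamp is recorded before the size check, so a push whose diff is too small still starts a new debounce window. The sketch below reproduces that ordering with an injectable clock so it can be exercised without waiting; `should_check_rules` is a hypothetical helper.

```python
import time
from typing import Callable, Dict, Tuple

CHANGE_BOUNDS: Tuple[int, int] = (10, 1500)  # same values as on_merge.py
DEBOUNCE_TIME = 120  # seconds


def should_check_rules(
    repo_name: str,
    diff: str,
    last_run: Dict[str, float],
    now: Callable[[], float] = time.time,
) -> bool:
    current = now()
    if repo_name in last_run and current - last_run[repo_name] < DEBOUNCE_TIME:
        return False  # rules were checked for this repo too recently
    last_run[repo_name] = current  # recorded before the size check, as in on_merge
    lines = diff.count("\n")
    return CHANGE_BOUNDS[0] <= lines <= CHANGE_BOUNDS[1]
```

Because `merge_rule_debounce` is a module-level dict, each worker process keeps its own debounce state; several workers can each run the rules once within the same window.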

import re
import traceback
from typing import TypeVar
from sweepai.config.server import DEFAULT_GPT4_32K_MODEL
from sweepai.core.chat import ChatGPT
from sweepai.core.entities import Message, RegexMatchableBaseModel
from loguru import logger
system_prompt = """You are a brilliant and meticulous engineer assigned to review the following commit diffs and make sure the file conforms to the user's rules.
If the diffs do not conform to the rules, we should create a GitHub issue telling the user what changes should be made.
Provide your response in the following format:
<rule_analysis>
- Analysis of each file_diff and whether it breaks the rule
...
</rule_analysis>
<changes_required>
Output "True" if the rule is broken, "False" otherwise
</changes_required>
<issue_title>
Write an issue title describing what file and rule to fix.
</issue_title>
<issue_description>
GitHub issue description for what we want to solve. Give general instructions on how to solve it. Mention files to take a look at and other code pointers.
</issue_description>"""
user_message = """Review the following diffs and make sure they conform to the rules:
{diff}
The rule is: {rule}
Provide your response in the following format:
<rule_analysis>
- Analysis of code diff 1 and whether it breaks the rule
- Analysis of code diff 2 and whether it breaks the rule
...
</rule_analysis>
<changes_required>
Output "True" if the rule is broken, "False" otherwise
</changes_required>
<issue_title>
Write an issue title describing what file and rule to fix.
</issue_title>
<issue_description>
GitHub issue description for what we want to solve. Give general instructions on how to solve it. Mention files to take a look at and other code pointers.
</issue_description>"""
Self = TypeVar("Self", bound="RegexMatchableBaseModel")
class IssueTitleAndDescription(RegexMatchableBaseModel):
changes_required: bool = False
issue_title: str
issue_description: str
@classmethod
def from_string(cls: type["IssueTitleAndDescription"], string: str, **kwargs) -> "IssueTitleAndDescription":
changes_required_pattern = (
r"""<changes_required>(\n)?(?P<changes_required>.*)</changes_required>"""
)
changes_required_match = re.search(changes_required_pattern, string, re.DOTALL)
changes_required = (
changes_required_match.groupdict()["changes_required"].strip()
if changes_required_match
else None
)
if changes_required and "true" in changes_required.lower():
changes_required = True
else:
changes_required = False
issue_title_pattern = r"""<issue_title>(\n)?(?P<issue_title>.*)</issue_title>"""
issue_title_match = re.search(issue_title_pattern, string, re.DOTALL)
issue_title = (
issue_title_match.groupdict()["issue_title"].strip()
if issue_title_match
else ""
)
issue_description_pattern = (
r"""<issue_description>(\n)?(?P<issue_description>.*)</issue_description>"""
)
issue_description_match = re.search(
issue_description_pattern, string, re.DOTALL
)
issue_description = (
issue_description_match.groupdict()["issue_description"].strip()
if issue_description_match
else ""
)
return cls(
changes_required=changes_required,
issue_title=issue_title,
issue_description=issue_description,
)
class PostMerge(ChatGPT):
def check_for_issues(self, rule, diff) -> tuple[bool, str, str]:
try:
self.messages = [
Message(
role="system",
content=system_prompt.format(rule=rule),
key="system",
)
]
if self.chat_logger and not self.chat_logger.is_paying_user():
raise ValueError("User is not a paying user")
self.model = DEFAULT_GPT4_32K_MODEL
response = self.chat(
user_message.format(
rule=rule,
diff=diff,
)
)
issue_title_and_description = IssueTitleAndDescription.from_string(response)
return (
issue_title_and_description.changes_required,
issue_title_and_description.issue_title,
issue_title_and_description.issue_description,
)
except SystemExit:
raise
except Exception:
logger.error(f"An error occurred: {traceback.format_exc()}")
return False, "", ""
if __name__ == "__main__":
changes_required_response = """<rule_analysis>
- Analysis of code diff 1 and whether it breaks the rule
The code diff 1 does not break the rule. There are no docstrings or comments that need to be updated.
- Analysis of code diff 2 and whether it breaks the rule
The code diff 2 breaks the rule. There is a commented out code block that should be removed.
</rule_analysis>
<changes_required>
True if the rule is broken, False otherwise
True
</changes_required>
<issue_title>
Outdated Commented Code Block in plan-list.blade.php
</issue_title>
<issue_description>
There is an outdated commented out code block in the file `resources/views/livewire/plan-list.blade.php` that should be removed. The code block starts at line 104 and ends at line 110. Please remove this code block as it is no longer needed.
Please refer to the file `resources/views/livewire/plan-list.blade.php` and remove the commented out code block starting at line 104 and ending at line 110.
</issue_description>"""
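The sample response above echoes the prompt's instruction inside `<changes_required>` ("True if the rule is broken, False otherwise") before the actual answer. `IssueTitleAndDescription.from_string` treats the block as true whenever it contains "true" anywhere, so an echoed instruction followed by "False" would still parse as `True`. One way to harden this, sketched with hypothetical helpers, is to read only the final non-empty line and to match tags non-greedily.

```python
import re
from typing import Optional


def extract_tag(text: str, tag: str) -> Optional[str]:
    """Contents of the first <tag>...</tag> block, stripped, or None if absent."""
    # Non-greedy (.*?) stops at the first closing tag instead of the last one.
    match = re.search(rf"<{tag}>(.*?)</{tag}>", text, re.DOTALL)
    return match.group(1).strip() if match else None


def parse_changes_required(text: str) -> bool:
    body = extract_tag(text, "changes_required")
    if not body:
        return False
    # Only the last non-empty line is taken as the model's answer.
    last_line = [line for line in body.splitlines() if line.strip()][-1]
    return last_line.strip().lower() == "true"
```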

import base64
import os
from dotenv import load_dotenv
from loguru import logger
logger.print = logger.info
load_dotenv(dotenv_path=".env", override=True, verbose=True)
os.environ["GITHUB_APP_PEM"] = os.environ.get("GITHUB_APP_PEM") or base64.b64decode(
os.environ.get("GITHUB_APP_PEM_BASE64", "")
).decode("utf-8")
if os.environ["GITHUB_APP_PEM"]:
os.environ["GITHUB_APP_ID"] = (
(os.environ.get("GITHUB_APP_ID") or os.environ.get("APP_ID"))
.replace("\\n", "\n")
.strip('"')
)
os.environ["TRANSFORMERS_CACHE"] = os.environ.get(
"TRANSFORMERS_CACHE", "/tmp/cache/model"
) # vector_db.py
os.environ["TIKTOKEN_CACHE_DIR"] = os.environ.get(
"TIKTOKEN_CACHE_DIR", "/tmp/cache/tiktoken"
) # utils.py
SENTENCE_TRANSFORMERS_MODEL = os.environ.get(
"SENTENCE_TRANSFORMERS_MODEL",
"sentence-transformers/all-MiniLM-L6-v2", # "all-mpnet-base-v2"
)
TEST_BOT_NAME = "sweep-nightly[bot]"
ENV = os.environ.get("ENV", "dev")
# ENV = os.environ.get("MODAL_ENVIRONMENT", "dev")
# ENV = PREFIX
# ENVIRONMENT = PREFIX
DB_MODAL_INST_NAME = "db"
DOCS_MODAL_INST_NAME = "docs"
API_MODAL_INST_NAME = "api"
UTILS_MODAL_INST_NAME = "utils"
BOT_TOKEN_NAME = "bot-token"
# goes under Modal 'discord' secret name (optional, can leave env var blank)
DISCORD_WEBHOOK_URL = os.environ.get("DISCORD_WEBHOOK_URL")
DISCORD_MEDIUM_PRIORITY_URL = os.environ.get("DISCORD_MEDIUM_PRIORITY_URL")
DISCORD_LOW_PRIORITY_URL = os.environ.get("DISCORD_LOW_PRIORITY_URL")
DISCORD_FEEDBACK_WEBHOOK_URL = os.environ.get("DISCORD_FEEDBACK_WEBHOOK_URL")
SWEEP_HEALTH_URL = os.environ.get("SWEEP_HEALTH_URL")
DISCORD_STATUS_WEBHOOK_URL = os.environ.get("DISCORD_STATUS_WEBHOOK_URL")
# goes under Modal 'github' secret name
GITHUB_APP_ID = os.environ.get("GITHUB_APP_ID", os.environ.get("APP_ID"))
# deprecated: legacy per-environment defaults, kept so upstream deployments keep working
if GITHUB_APP_ID is None:
if ENV == "prod":
GITHUB_APP_ID = "307814"
elif ENV == "dev":
GITHUB_APP_ID = "324098"
elif ENV == "staging":
GITHUB_APP_ID = "327588"
GITHUB_BOT_USERNAME = os.environ.get("GITHUB_BOT_USERNAME")
# deprecated: left to support old logic
if not GITHUB_BOT_USERNAME:
if ENV == "prod":
GITHUB_BOT_USERNAME = "sweep-ai[bot]"
elif ENV == "dev":
GITHUB_BOT_USERNAME = "sweep-nightly[bot]"
elif ENV == "staging":
GITHUB_BOT_USERNAME = "sweep-canary[bot]"
elif not GITHUB_BOT_USERNAME.endswith("[bot]"):
GITHUB_BOT_USERNAME = GITHUB_BOT_USERNAME + "[bot]"
GITHUB_LABEL_NAME = os.environ.get("GITHUB_LABEL_NAME", "sweep")
GITHUB_LABEL_COLOR = os.environ.get("GITHUB_LABEL_COLOR", "9400D3")
GITHUB_LABEL_DESCRIPTION = os.environ.get(
"GITHUB_LABEL_DESCRIPTION", "Sweep your software chores"
)
GITHUB_APP_PEM = os.environ.get("GITHUB_APP_PEM")
GITHUB_APP_PEM = GITHUB_APP_PEM or os.environ.get("PRIVATE_KEY")
if GITHUB_APP_PEM is not None:
GITHUB_APP_PEM = GITHUB_APP_PEM.strip(' \n"') # Remove whitespace and quotes
GITHUB_APP_PEM = GITHUB_APP_PEM.replace("\\n", "\n")
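The GitHub App private key can arrive as `GITHUB_APP_PEM`, `GITHUB_APP_PEM_BASE64`, or `PRIVATE_KEY`, and this module normalizes it in two separate places: base64 decoding near the top, then quote/whitespace stripping and `\n` unescaping here. A single helper makes the order of those steps explicit; `load_pem` is a hypothetical sketch, not part of Sweep's config module.

```python
import base64
from typing import Optional


def load_pem(pem: Optional[str], pem_base64: Optional[str]) -> str:
    """Return a usable PEM string from a raw or base64-encoded environment value."""
    if not pem and pem_base64:
        pem = base64.b64decode(pem_base64).decode("utf-8")
    pem = (pem or "").strip(' \n"')  # drop surrounding whitespace and quotes
    return pem.replace("\\n", "\n")  # turn literal "\n" sequences into newlines
```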
GITHUB_CONFIG_BRANCH = os.environ.get("GITHUB_CONFIG_BRANCH", "sweep/add-sweep-config")
GITHUB_DEFAULT_CONFIG = os.environ.get(
"GITHUB_DEFAULT_CONFIG",
"""# Sweep AI turns bugs & feature requests into code changes (https://sweep.dev)
# For details on our config file, check out our docs at https://docs.sweep.dev/usage/config
# This setting contains a list of rules that Sweep will check for. If any of these rules are broken in a new commit, Sweep will create a pull request to fix the broken rule.
rules:
{additional_rules}
# This is the branch that Sweep will develop from and make pull requests to. Most people use 'main' or 'master' but some users also use 'dev' or 'staging'.
branch: 'main'
# By default Sweep will read the logs and outputs from your existing GitHub Actions. To disable this, set this to false.
gha_enabled: True
# This is the description of your project. It will be used by Sweep when creating PRs. You can tell Sweep what's unique about your project, what frameworks you use, or anything else you want.
#
# Example:
#
# description: sweepai/sweep is a python project. The main api endpoints are in sweepai/api.py. Write code that adheres to PEP8.
description: ''
# This sets whether to create pull requests as drafts. If this is set to True, then all pull requests will be created as drafts and GitHub Actions will not be triggered.
draft: False
# This is a list of directories that Sweep will not be able to edit.
blocked_dirs: []
""",
)
MONGODB_URI = os.environ.get("MONGODB_URI", None)
IS_SELF_HOSTED = os.environ.get("IS_SELF_HOSTED", "true").lower() == "true"
REDIS_URL = os.environ.get("REDIS_URL")
if not REDIS_URL:
REDIS_URL = os.environ.get("redis_url", "redis://0.0.0.0:6379/0")
ORG_ID = os.environ.get("ORG_ID", None)
POSTHOG_API_KEY = os.environ.get(
"POSTHOG_API_KEY", "phc_CnzwIB0W548wN4wEGeRuxXqidOlEUH2AcyV2sKTku8n"
)
E2B_API_KEY = os.environ.get("E2B_API_KEY")
SUPPORT_COUNTRY = os.environ.get("GDRP_LIST", "").split(",")
WHITELISTED_REPOS = os.environ.get("WHITELISTED_REPOS", "").split(",")
BLACKLISTED_USERS = os.environ.get("BLACKLISTED_USERS", "").split(",")
os.environ["TOKENIZERS_PARALLELISM"] = "false"
ACTIVELOOP_TOKEN = os.environ.get("ACTIVELOOP_TOKEN", None)
VECTOR_EMBEDDING_SOURCE = os.environ.get(
"VECTOR_EMBEDDING_SOURCE", "openai"
)  # Alternate options are huggingface or replicate; set the corresponding env vars
BASERUN_API_KEY = os.environ.get("BASERUN_API_KEY", None)
# Huggingface settings, only checked if VECTOR_EMBEDDING_SOURCE == "huggingface"
HUGGINGFACE_URL = os.environ.get("HUGGINGFACE_URL", None)
HUGGINGFACE_TOKEN = os.environ.get("HUGGINGFACE_TOKEN", None)
# Replicate settings, only checked if VECTOR_EMBEDDING_SOURCE == "replicate"
REPLICATE_API_KEY = os.environ.get("REPLICATE_API_KEY", None)
REPLICATE_URL = os.environ.get("REPLICATE_URL", None)
REPLICATE_DEPLOYMENT_URL = os.environ.get("REPLICATE_DEPLOYMENT_URL", None)
# Default OpenAI
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", None)
OPENAI_API_TYPE = os.environ.get("OPENAI_API_TYPE", "anthropic")
assert OPENAI_API_TYPE in ["anthropic", "azure", "openai"], "Invalid OPENAI_API_TYPE"
OPENAI_EMBEDDINGS_API_TYPE = os.environ.get("OPENAI_EMBEDDINGS_API_TYPE", "openai")
AZURE_API_KEY = os.environ.get("AZURE_API_KEY", None)
OPENAI_API_BASE = os.environ.get("OPENAI_API_BASE", None)
OPENAI_API_VERSION = os.environ.get("OPENAI_API_VERSION", None)
AZURE_OPENAI_DEPLOYMENT = os.environ.get("AZURE_OPENAI_DEPLOYMENT", None)
OPENAI_EMBEDDINGS_AZURE_ENDPOINT = os.environ.get(
"OPENAI_EMBEDDINGS_AZURE_ENDPOINT", None
)
OPENAI_EMBEDDINGS_AZURE_API_KEY = os.environ.get(
"OPENAI_EMBEDDINGS_AZURE_API_KEY", None
)
OPENAI_EMBEDDINGS_AZURE_DEPLOYMENT = os.environ.get(
"OPENAI_EMBEDDINGS_AZURE_DEPLOYMENT", None
)
OPENAI_EMBEDDINGS_AZURE_API_VERSION = os.environ.get(
"OPENAI_EMBEDDINGS_AZURE_API_VERSION", None
)
OPENAI_API_ENGINE_GPT35 = os.environ.get("OPENAI_API_ENGINE_GPT35", None)
OPENAI_API_ENGINE_GPT4 = os.environ.get("OPENAI_API_ENGINE_GPT4", None)
OPENAI_API_ENGINE_GPT4_32K = os.environ.get("OPENAI_API_ENGINE_GPT4_32K", None)
MULTI_REGION_CONFIG = os.environ.get("MULTI_REGION_CONFIG", None)
if isinstance(MULTI_REGION_CONFIG, str):
MULTI_REGION_CONFIG = MULTI_REGION_CONFIG.strip("'").replace("\\n", "\n")
MULTI_REGION_CONFIG = [item.split(",") for item in MULTI_REGION_CONFIG.split("\n")]
WHITELISTED_USERS = os.environ.get("WHITELISTED_USERS", None)
if WHITELISTED_USERS:
WHITELISTED_USERS = WHITELISTED_USERS.split(",")
WHITELISTED_USERS.append(GITHUB_BOT_USERNAME)
DEFAULT_GPT4_32K_MODEL = os.environ.get("DEFAULT_GPT4_32K_MODEL", "gpt-4-0125-preview")
DEFAULT_GPT35_MODEL = os.environ.get("DEFAULT_GPT35_MODEL", "gpt-3.5-turbo-1106")
RESEND_API_KEY = os.environ.get("RESEND_API_KEY", None)
LOKI_URL = None
DEBUG = os.environ.get("DEBUG", "false").lower() == "true"
ENV = "prod" if GITHUB_BOT_USERNAME != TEST_BOT_NAME else "dev"
PROGRESS_BASE_URL = os.environ.get(
"PROGRESS_BASE_URL", "https://progress.sweep.dev"
).rstrip("/")
DISABLED_REPOS = os.environ.get("DISABLED_REPOS", "").split(",")
GHA_AUTOFIX_ENABLED: bool = os.environ.get("GHA_AUTOFIX_ENABLED", "false").lower() == "true"
MERGE_CONFLICT_ENABLED: bool = os.environ.get("MERGE_CONFLICT_ENABLED", "false").lower() == "true"
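`os.environ.get` returns a string whenever a variable is set, so a value such as `"false"` is truthy if it is used directly as a flag. A minimal sketch of a boolean parser for settings like `GHA_AUTOFIX_ENABLED`; the helper name `env_flag` and the accepted spellings are assumptions, not part of Sweep's config module.

```python
import os


def env_flag(name: str, default: bool = False) -> bool:
    """Parse a boolean environment variable (hypothetical helper).

    Any set value is compared as text, so "false" or "0" yields False
    instead of being treated as a truthy non-empty string.
    """
    raw = os.environ.get(name)
    if raw is None:
        # Variable not set at all: fall back to the caller's default.
        return default
    return raw.strip().lower() in {"1", "true", "yes", "on"}


if __name__ == "__main__":
    os.environ["DEMO_FLAG"] = "False"
    print(env_flag("DEMO_FLAG"))
```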
INSTALLATION_ID = os.environ.get("INSTALLATION_ID", None)
AWS_ACCESS_KEY=os.environ.get("AWS_ACCESS_KEY")
AWS_SECRET_KEY=os.environ.get("AWS_SECRET_KEY")
AWS_REGION=os.environ.get("AWS_REGION")
ANTHROPIC_AVAILABLE = AWS_ACCESS_KEY and AWS_SECRET_KEY and AWS_REGION
USE_ASSISTANT = os.environ.get("USE_ASSISTANT", "true").lower() == "true"
ANTHROPIC_API_KEY = os.environ.get("ANTHROPIC_API_KEY", None)
VOYAGE_API_KEY = os.environ.get("VOYAGE_API_KEY", None)
VOYAGE_API_AWS_ACCESS_KEY=os.environ.get("VOYAGE_API_AWS_ACCESS_KEY_ID")
VOYAGE_API_AWS_SECRET_KEY=os.environ.get("VOYAGE_API_AWS_SECRET_KEY")
VOYAGE_API_AWS_REGION=os.environ.get("VOYAGE_API_AWS_REGION")
VOYAGE_API_AWS_ENDPOINT_NAME=os.environ.get("VOYAGE_API_AWS_ENDPOINT_NAME", "voyage-code-2")
VOYAGE_API_USE_AWS = VOYAGE_API_AWS_ACCESS_KEY and VOYAGE_API_AWS_SECRET_KEY and VOYAGE_API_AWS_REGION
PAREA_API_KEY = os.environ.get("PAREA_API_KEY", None)
# TODO: we need to make this dynamic + backoff
BATCH_SIZE = int(

import datetime
import difflib
import hashlib
import json
import os
import re
import shutil
import subprocess
import tempfile
import time
import traceback
from dataclasses import dataclass
from functools import cached_property
from typing import Any
import git
import requests
from github import Github, PullRequest, Repository, InputGitTreeElement
from jwt import encode
from loguru import logger
from sweepai.config.client import SweepConfig
from sweepai.config.server import GITHUB_APP_ID, GITHUB_APP_PEM, GITHUB_BOT_USERNAME
from sweepai.utils.tree_utils import DirectoryTree, remove_all_not_included
MAX_FILE_COUNT = 50
def make_valid_string(string: str):
pattern = r"[^\w./-]+"
return re.sub(pattern, "_", string)
def get_jwt():
signing_key = GITHUB_APP_PEM
app_id = GITHUB_APP_ID
payload = {"iat": int(time.time()), "exp": int(time.time()) + 600, "iss": app_id}
return encode(payload, signing_key, algorithm="RS256")
def get_token(installation_id: int):
if int(installation_id) < 0:
return os.environ["GITHUB_PAT"]
for timeout in [5.5, 5.5, 10.5]:
try:
jwt = get_jwt()
headers = {
"Accept": "application/vnd.github+json",
"Authorization": "Bearer " + jwt,
"X-GitHub-Api-Version": "2022-11-28",
}
response = requests.post(
f"https://api.github.com/app/installations/{int(installation_id)}/access_tokens",
headers=headers,
)
obj = response.json()
if "token" not in obj:
logger.error(obj)
raise Exception("Could not get token")
return obj["token"]
except SystemExit:
raise
except Exception:
time.sleep(timeout)
raise Exception(
"Could not get token, please double check your PRIVATE_KEY and GITHUB_APP_ID in the .env file. Make sure to restart uvicorn after."
)
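`get_token` retries the access-token request with a fixed list of sleep intervals (`[5.5, 5.5, 10.5]`). The same pattern can be factored into a reusable helper; this is a sketch with a hypothetical `retry_with_delays` name, and the injectable `sleep` argument exists only so the behavior can be tested without waiting.

```python
import time
from typing import Callable, Iterable, TypeVar

T = TypeVar("T")


def retry_with_delays(
    fn: Callable[[], T],
    delays: Iterable[float],
    sleep: Callable[[float], None] = time.sleep,
) -> T:
    """Call fn once per delay, sleeping after each failure.

    Only Exception subclasses are retried; SystemExit and
    KeyboardInterrupt propagate immediately.
    """
    last_error = None
    for delay in delays:
        try:
            return fn()
        except Exception as e:
            last_error = e
            sleep(delay)
    # Chain the final failure so the original cause stays visible.
    raise RuntimeError("all retries failed") from last_error
```

Passing a custom `sleep` keeps the retry schedule observable, which the original inline loop does not allow.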
def get_app():
jwt = get_jwt()
headers = {
"Accept": "application/vnd.github+json",
"Authorization": "Bearer " + jwt,
"X-GitHub-Api-Version": "2022-11-28",
}
response = requests.get("https://api.github.com/app", headers=headers)
return response.json()
def get_github_client(installation_id: int):
if not installation_id:
return os.environ["GITHUB_PAT"], Github(os.environ["GITHUB_PAT"])
token: str = get_token(installation_id)
return token, Github(token)
# fetch installation object
def get_installation(username: str):
jwt = get_jwt()
try:
# Try user
response = requests.get(
f"https://api.github.com/users/{username}/installation",
headers={
"Accept": "application/vnd.github+json",
"Authorization": "Bearer " + jwt,
"X-GitHub-Api-Version": "2022-11-28",
},
)
obj = response.json()
return obj
except Exception:
# Try org
response = requests.get(
f"https://api.github.com/orgs/{username}/installation",
headers={
"Accept": "application/vnd.github+json",
"Authorization": "Bearer " + jwt,
"X-GitHub-Api-Version": "2022-11-28",
},
)
try:
obj = response.json()
return obj
except Exception as e:
logger.error(e)
logger.error(response.text)
raise Exception("Could not get installation, probably not installed")
def get_installation_id(username: str) -> int:
jwt = get_jwt()
try:
# Try user
response = requests.get(
f"https://api.github.com/users/{username}/installation",
headers={
"Accept": "application/vnd.github+json",
"Authorization": "Bearer " + jwt,
"X-GitHub-Api-Version": "2022-11-28",
},
)
obj = response.json()
return obj["id"]
except Exception:
# Try org
response = requests.get(
f"https://api.github.com/orgs/{username}/installation",
headers={
"Accept": "application/vnd.github+json",
"Authorization": "Bearer " + jwt,
"X-GitHub-Api-Version": "2022-11-28",
},
)
try:
obj = response.json()
return obj["id"]
except Exception as e:
logger.error(e)
logger.error(response.text)
raise Exception("Could not get installation id, probably not installed")
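`get_installation_id` tries the `users/{username}/installation` endpoint and falls back to `orgs/{username}/installation`. Because `requests.get` does not raise on a 404, the fallback above only triggers when `obj["id"]` raises a `KeyError`. A sketch that makes the check explicit; `resolve_installation_id` and the injected `fetch_json` callable are hypothetical, used here so the logic runs without network access.

```python
from typing import Callable


def resolve_installation_id(username: str, fetch_json: Callable[[str], dict]) -> int:
    """Return the App installation id for a user, falling back to an org."""
    for kind in ("users", "orgs"):
        obj = fetch_json(f"https://api.github.com/{kind}/{username}/installation")
        # A missing installation comes back as an error body without "id".
        if "id" in obj:
            return obj["id"]
    raise LookupError(f"No installation found for {username}")
```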
# commits multiple files in a single commit, returns the commit object
def commit_multi_file_changes(repo: Repository, file_changes: dict[str, str], commit_message: str, branch: str):
blobs_to_commit = []
# convert to blob
for path, content in file_changes.items():
blob = repo.create_git_blob(content, "utf-8")
blobs_to_commit.append(InputGitTreeElement(path=path, mode="100644", type="blob", sha=blob.sha))
latest_commit = repo.get_branch(branch).commit
base_tree = latest_commit.commit.tree
# create new git tree
new_tree = repo.create_git_tree(blobs_to_commit, base_tree=base_tree)
# commit the changes
parent = repo.get_git_commit(latest_commit.sha)
commit = repo.create_git_commit(
commit_message,
new_tree,
[parent],
)
# update ref of branch
ref = f"heads/{branch}"
repo.get_git_ref(ref).edit(sha=commit.sha)
return commit
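`commit_multi_file_changes` follows Git's object model: one blob per file, a new tree layered on the branch's `base_tree`, a commit whose parent is the current head, and finally a ref update. Blobs are content-addressed; a sketch of how a blob's object id is derived in a SHA-1 repository (`git_blob_sha` is a hypothetical name, and SHA-256 repositories hash differently).

```python
import hashlib


def git_blob_sha(content: bytes) -> str:
    """Compute the SHA-1 object id Git assigns to a blob with this content."""
    # Git hashes a "blob <size>\0" header followed by the raw bytes.
    header = f"blob {len(content)}\0".encode()
    return hashlib.sha1(header + content).hexdigest()
```

Identical file contents therefore map to the same blob, which is why rebuilding a tree on top of `base_tree` only needs entries for the files that changed.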
REPO_CACHE_BASE_DIR = "/tmp/cache/repos"
@dataclass
class ClonedRepo:
repo_full_name: str
installation_id: str
branch: str | None = None
token: str | None = None
repo: Any | None = None
git_repo: git.Repo | None = None
class Config:
arbitrary_types_allowed = True
@cached_property
def cached_dir(self):
self.repo = (
Github(self.token).get_repo(self.repo_full_name)
if not self.repo
else self.repo
)
self.branch = self.branch or SweepConfig.get_branch(self.repo)
return os.path.join(
REPO_CACHE_BASE_DIR,
self.repo_full_name,
"base",
parse_collection_name(self.branch),
)
@cached_property
def zip_path(self):
logger.info("Zipping repository...")
shutil.make_archive(self.repo_dir, "zip", self.repo_dir)
logger.info("Done zipping")
return f"{self.repo_dir}.zip"
@cached_property
def repo_dir(self):
self.repo = (
Github(self.token).get_repo(self.repo_full_name)
if not self.repo
else self.repo
)
self.branch = self.branch or SweepConfig.get_branch(self.repo)
curr_time_str = str(time.time()).encode("utf-8")
hash_obj = hashlib.sha256(curr_time_str)
hash_hex = hash_obj.hexdigest()
if self.branch:
return os.path.join(
REPO_CACHE_BASE_DIR,
self.repo_full_name,
hash_hex,
parse_collection_name(self.branch),
)
else:
return os.path.join(REPO_CACHE_BASE_DIR, self.repo_full_name, hash_hex)
@property
def clone_url(self):
return (
f"https://x-access-token:{self.token}@github.com/{self.repo_full_name}.git"
)
def clone(self):
if not os.path.exists(self.cached_dir):
logger.info("Cloning repo...")
if self.branch:
repo = git.Repo.clone_from(
self.clone_url, self.cached_dir, branch=self.branch
)
else:
repo = git.Repo.clone_from(self.clone_url, self.cached_dir)
logger.info("Done cloning")
else:
try:
repo = git.Repo(self.cached_dir)
repo.remotes.origin.pull(
kill_after_timeout=60, progress=git.RemoteProgress()
)
except Exception:
logger.error("Could not pull repo")
shutil.rmtree(self.cached_dir, ignore_errors=True)
repo = git.Repo.clone_from(self.clone_url, self.cached_dir)
logger.info("Repo already cached, copying")
logger.info("Copying repo...")
shutil.copytree(
self.cached_dir, self.repo_dir, symlinks=True, copy_function=shutil.copy
)
logger.info("Done copying")
repo = git.Repo(self.repo_dir)
return repo
def __post_init__(self):
subprocess.run(["git", "config", "--global", "http.postBuffer", "524288000"])
self.token = self.token or get_token(self.installation_id)
self.repo = (
Github(self.token).get_repo(self.repo_full_name)
if not self.repo
else self.repo
)
self.commit_hash = self.repo.get_commits()[0].sha
self.git_repo = self.clone()
self.branch = self.branch or SweepConfig.get_branch(self.repo)
def __del__(self):
try:
shutil.rmtree(self.repo_dir)
os.remove(self.zip_path)
return True
except Exception:
return False
def list_directory_tree(
self,
included_directories=None,
excluded_directories: list[str] = None,
included_files=None,
):
"""Display the directory tree.
Arguments:
included_directories -- List of directory paths (relative to the root) to include in the tree. Defaults to None (all directories).
excluded_directories -- List of directory names to exclude from the tree. Defaults to SweepConfig().exclude_dirs.
included_files -- Currently unused. Defaults to None.
Returns a tuple of the directory tree string and the parsed DirectoryTree object.
"""
root_directory = self.repo_dir
sweep_config: SweepConfig = SweepConfig()
# Default values if parameters are not provided
if included_directories is None:
included_directories = [] # gets all directories
if excluded_directories is None:
excluded_directories = sweep_config.exclude_dirs
def list_directory_contents(
current_directory: str,
excluded_directories: list[str],
indentation="",
):
"""Recursively list the contents of directories."""
file_and_folder_names = os.listdir(current_directory)
file_and_folder_names.sort()
directory_tree_string = ""
for name in file_and_folder_names[:MAX_FILE_COUNT]:
relative_path = os.path.join(current_directory, name)[
len(root_directory) + 1 :
]
if name in excluded_directories:
continue
complete_path = os.path.join(current_directory, name)
if os.path.isdir(complete_path):
directory_tree_string += f"{indentation}{relative_path}/\n"
directory_tree_string += list_directory_contents(
complete_path,
excluded_directories,
indentation + " ",
)
else:
directory_tree_string += f"{indentation}{name}\n"
# if os.path.isfile(complete_path) and relative_path in included_files:
# # Todo, use these to fetch neighbors
# ctags_str, names = get_ctags_for_file(ctags, complete_path)
# ctags_str = "\n".join([indentation + line for line in ctags_str.splitlines()])
# if ctags_str.strip():
# directory_tree_string += f"{ctags_str}\n"
return directory_tree_string
dir_obj = DirectoryTree()
directory_tree = list_directory_contents(root_directory, excluded_directories)
dir_obj.parse(directory_tree)
if included_directories:
dir_obj = remove_all_not_included(dir_obj, included_directories)
return directory_tree, dir_obj
def get_file_list(self) -> list[str]:
root_directory = self.repo_dir
files = []
sweep_config: SweepConfig = SweepConfig()
def dfs_helper(directory):
nonlocal files
for item in os.listdir(directory):
if item == ".git":
continue
if item in sweep_config.exclude_dirs: # this saves a lot of time
continue
item_path = os.path.join(directory, item)
if os.path.isfile(item_path):
# make sure the item_path is not in one of the banned directories
if not sweep_config.is_file_excluded(item_path):
files.append(item_path) # Add the file to the list
elif os.path.isdir(item_path):
dfs_helper(item_path) # Recursive call to explore subdirectory
dfs_helper(root_directory)
files = [file[len(root_directory) + 1 :] for file in files]
return files
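`get_file_list` performs a recursive DFS that skips `.git` and excluded directories. A roughly equivalent sketch using `os.walk`, where pruning `dirnames` in place stops the walk from descending into excluded directories. The name `list_repo_files` is hypothetical, and the per-file `is_file_excluded` check from `SweepConfig` is deliberately left out.

```python
import os


def list_repo_files(root: str, exclude_dirs: set[str]) -> list[str]:
    """Return sorted repo-relative file paths, pruning excluded directories."""
    files = []
    for dirpath, dirnames, filenames in os.walk(root):
        # Editing dirnames in place tells os.walk not to descend into them.
        dirnames[:] = [d for d in dirnames if d != ".git" and d not in exclude_dirs]
        for name in filenames:
            files.append(os.path.relpath(os.path.join(dirpath, name), root))
    return sorted(files)
```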
def get_file_contents(self, file_path, ref=None):
local_path = (
f"{self.repo_dir}{file_path}"
if file_path.startswith("/")
else f"{self.repo_dir}/{file_path}"
)
if os.path.exists(local_path):
with open(local_path, "r", encoding="utf-8", errors="replace") as f:
contents = f.read()
return contents
else:
raise FileNotFoundError(f"{local_path} does not exist.")
def get_num_files_from_repo(self):
# subprocess.run(["git", "config", "--global", "http.postBuffer", "524288000"])
self.git_repo.git.checkout(self.branch)
file_list = self.get_file_list()
return len(file_list)
def get_commit_history(
self, username: str = "", limit: int = 200, time_limited: bool = True
):
commit_history = []
try:
if username != "":
commit_list = list(self.git_repo.iter_commits(author=username))
else:
commit_list = list(self.git_repo.iter_commits())
line_count = 0
cut_off_date = datetime.datetime.now() - datetime.timedelta(days=7)
for commit in commit_list:
# must be within a week
if time_limited and commit.authored_datetime.replace(
tzinfo=None
) <= cut_off_date.replace(tzinfo=None):
logger.info("Exceeded cut off date, stopping...")
break
repo = get_github_client(self.installation_id)[1].get_repo(
self.repo_full_name
)
branch = SweepConfig.get_branch(repo)
if branch not in self.git_repo.git.branch():
branch = f"origin/{branch}"
diff = self.git_repo.git.diff(commit, branch, unified=1)
lines = diff.count("\n")
# total diff lines must not exceed limit (200 by default)
if lines + line_count > limit:
logger.info(f"Exceeded {limit} lines of diff, stopping...")
break
commit_history.append(
f"<commit>\nAuthor: {commit.author.name}\nMessage: {commit.message}\n{diff}\n</commit>"
)
line_count += lines
except Exception:
logger.error(f"An error occurred: {traceback.format_exc()}")
return commit_history
def get_similar_file_paths(self, file_path: str, limit: int = 10):
from rapidfuzz.fuzz import ratio
# Fuzzy search over file names
file_name = os.path.basename(file_path)
all_file_paths = self.get_file_list()
# filter for matching extensions if both have extensions
if "." in file_name:
all_file_paths = [
file
for file in all_file_paths
if "." in file and file.split(".")[-1] == file_name.split(".")[-1]
]
files_with_matching_name = []
files_without_matching_name = []
for file_path in all_file_paths:
if file_name in file_path:
files_with_matching_name.append(file_path)
else:
files_without_matching_name.append(file_path)
file_path_to_ratio = {file: ratio(file_name, file) for file in all_file_paths}
files_with_matching_name = sorted(
files_with_matching_name,
key=lambda file_path: file_path_to_ratio[file_path],
reverse=True,
)
files_without_matching_name = sorted(
files_without_matching_name,
key=lambda file_path: file_path_to_ratio[file_path],
reverse=True,
)
# this allows 'config.py' to return 'sweepai/config/server.py', 'sweepai/config/client.py', 'sweepai/config/__init__.py' and no more
filtered_files_without_matching_name = list(filter(lambda file_path: file_path_to_ratio[file_path] > 50, files_without_matching_name))
all_files = files_with_matching_name + filtered_files_without_matching_name
return all_files[:limit]
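`get_similar_file_paths` ranks candidates in two tiers: paths containing the query's basename first, then fuzzy matches scoring above 50. A sketch of the same ranking that swaps rapidfuzz's `ratio` for the standard library's `difflib.SequenceMatcher` (scaled to 0-100), so the scores will differ somewhat from rapidfuzz's; `similar_paths` is a hypothetical name and the extension filter uses `os.path.splitext` as an approximation.

```python
import os
from difflib import SequenceMatcher


def similar_paths(query_path: str, all_paths: list[str], limit: int = 10) -> list[str]:
    """Rank paths: basename substring hits first, then fuzzy matches above 50."""
    name = os.path.basename(query_path)
    ext = os.path.splitext(name)[1]
    if ext:
        # Only compare against files that share the query's extension.
        all_paths = [p for p in all_paths if os.path.splitext(p)[1] == ext]

    def ratio(p: str) -> float:
        # SequenceMatcher.ratio() is in [0, 1]; scale to rapidfuzz's 0-100 range.
        return SequenceMatcher(None, name, p).ratio() * 100

    hits = sorted((p for p in all_paths if name in p), key=ratio, reverse=True)
    rest = sorted((p for p in all_paths if name not in p), key=ratio, reverse=True)
    return (hits + [p for p in rest if ratio(p) > 50])[:limit]
```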
# updates a file with new_contents, returns True if successful
def update_file(root_dir: str, file_path: str, new_contents: str):
local_path = os.path.join(root_dir, file_path)
try:
with open(local_path, "w") as f:
f.write(new_contents)
return True
except Exception as e:
logger.error(f"Failed to update file: {e}")
return False
@dataclass
class MockClonedRepo(ClonedRepo):
_repo_dir: str = ""
git_repo: git.Repo | None = None
def __init__(
self,
_repo_dir: str,
repo_full_name: str,
installation_id: str = "",
branch: str | None = None,
token: str | None = None,
repo: Any | None = None,
git_repo: git.Repo | None = None,
):
self._repo_dir = _repo_dir
self.repo_full_name = repo_full_name
self.installation_id = installation_id
self.branch = branch
self.token = token
self.repo = repo
@classmethod
def from_dir(cls, repo_dir: str, **kwargs):
return cls(_repo_dir=repo_dir, **kwargs)
@property
def cached_dir(self):
return self._repo_dir
@property
def repo_dir(self):
return self._repo_dir
@property
def git_repo(self):
return git.Repo(self.repo_dir)
def clone(self):
return git.Repo(self.repo_dir)
def __post_init__(self):
return self
def __del__(self):
return True
@dataclass
class TemporarilyCopiedClonedRepo(MockClonedRepo):
tmp_dir: tempfile.TemporaryDirectory | None = None
def __init__(
self,
_repo_dir: str,
tmp_dir: tempfile.TemporaryDirectory,
repo_full_name: str,
installation_id: str = "",
branch: str | None = None,
token: str | None = None,
repo: Any | None = None,
git_repo: git.Repo | None = None,
):
self._repo_dir = _repo_dir
self.tmp_dir = tmp_dir
self.repo_full_name = repo_full_name
self.installation_id = installation_id
self.branch = branch
self.token = token
self.repo = repo
@classmethod
def copy_from_cloned_repo(cls, cloned_repo: ClonedRepo, **kwargs):
temp_dir = tempfile.TemporaryDirectory()
new_dir = temp_dir.name + "/" + cloned_repo.repo_full_name.split("/")[1]
print("Copying...")
shutil.copytree(cloned_repo.repo_dir, new_dir)
print("Done copying.")
return cls(
_repo_dir=new_dir,
tmp_dir=temp_dir,
repo_full_name=cloned_repo.repo_full_name,
installation_id=cloned_repo.installation_id,
branch=cloned_repo.branch,
token=cloned_repo.token,
repo=cloned_repo.repo,
**kwargs,
)
def __del__(self):
print(f"Dropping {self.tmp_dir.name}...")
shutil.rmtree(self._repo_dir, ignore_errors=True)
self.tmp_dir.cleanup()
print("Done.")
return True
def get_file_names_from_query(query: str) -> list[str]:
query_file_names = re.findall(r"\b[\w\-\.\/]*\w+\.\w{1,6}\b", query)
return [
query_file_name
for query_file_name in query_file_names
if len(query_file_name) > 3
]
def get_hunks(a: str, b: str, context=10):
differ = difflib.Differ()
diff = [
line
for line in differ.compare(a.splitlines(), b.splitlines())
if line[0] in ("+", "-", " ")
]
show = set()
hunks = []
for i, line in enumerate(diff):
if line.startswith(("+", "-")):
show.update(range(max(0, i - context), min(len(diff), i + context + 1)))
for i in range(len(diff)):
if i in show:
hunks.append(diff[i])
elif i - 1 in show:
hunks.append("...")
if len(hunks) > 0 and hunks[0] == "...":
hunks = hunks[1:]
if len(hunks) > 0 and hunks[-1] == "...":
hunks = hunks[:-1]
return "\n".join(hunks)
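`get_hunks` keeps every diff line within `context` lines of a change and joins non-adjacent regions with `...`. When standard `@@` hunk headers are acceptable, the standard library produces comparable grouping directly; a sketch with a hypothetical `unified_hunks` name, not a drop-in replacement (its output format differs from `get_hunks`).

```python
import difflib


def unified_hunks(a: str, b: str, context: int = 10) -> str:
    """Return unified-diff hunks keeping `context` unchanged lines around edits."""
    diff = list(
        difflib.unified_diff(a.splitlines(), b.splitlines(), n=context, lineterm="")
    )
    # The first two lines are the ---/+++ file headers; slicing drops them.
    return "\n".join(diff[2:])
```

Like `get_hunks`, distant changes are split into separate hunks instead of being joined by one long context run.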
def parse_collection_name(name: str) -> str:
# Replace any character that is not a word character or hyphen with a double hyphen
name = re.sub(r"[^\w-]", "--", name)
# Ensure the name is between 3 and 63 characters and starts/ends with alphanumeric
name = re.sub(r"^(-*\w{0,61}\w)-*$", r"\1", name[:63].ljust(3, "x"))
return name
# set whether or not a pr is a draft, since there is no way to do this using pygithub
def convert_pr_draft_field(pr: PullRequest, is_draft: bool = False):
pr_id = pr.raw_data['node_id']
# GraphQL mutation that converts the PR to a draft or marks it as ready for review
mutation_name = "convertPullRequestToDraft" if is_draft else "markPullRequestReadyForReview"
mutation = """
mutation UpdatePRDraftStatus {
{mutation_name}(input: {pullRequestId: {pull_request_id}}) {
pullRequest {
id
}
}
}
""".replace("{mutation_name}", mutation_name).replace("{pull_request_id}", "\""+pr_id+"\"")
# GraphQL API URL
url = 'https://api.github.com/graphql'
# Headers
headers={
"Accept": "application/vnd.github+json",
"X-Github-Api-Version": "2022-11-28",
"Authorization": "Bearer " + os.environ["GITHUB_PAT"],
}
# Prepare the JSON payload
json_data = {
'query': mutation,
}
# Make the POST request
response = requests.post(url, headers=headers, data=json.dumps(json_data))
if response.status_code != 200 or "errors" in response.json():
logger.error(f"Failed to convert PR to {'draft' if is_draft else 'open'}")
return False
return True
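`convert_pr_draft_field` splices the PR's node id into the query text with `str.replace`. GraphQL also accepts variables, which avoids quoting issues; note too that GitHub's GraphQL API can answer HTTP 200 with an `errors` array, so the body should be inspected as well as the status. A sketch of the request body only; `build_draft_toggle_payload` is hypothetical, while `convertPullRequestToDraft`, `markPullRequestReadyForReview` and `isDraft` are GitHub GraphQL schema names.

```python
import json


def build_draft_toggle_payload(pr_node_id: str, is_draft: bool) -> str:
    """Build a GraphQL request body that converts a PR to draft or ready."""
    field = "convertPullRequestToDraft" if is_draft else "markPullRequestReadyForReview"
    query = (
        "mutation($id: ID!) { "
        f"{field}(input: {{pullRequestId: $id}}) {{ pullRequest {{ id isDraft }} }} "
        "}"
    )
    # The node id travels as a variable instead of being spliced into the query.
    return json.dumps({"query": query, "variables": {"id": pr_node_id}})
```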
try:
g = Github(os.environ.get("GITHUB_PAT"))
CURRENT_USERNAME = g.get_user().login
except Exception:
try:
slug = get_app()["slug"]
CURRENT_USERNAME = f"{slug}[bot]"
except Exception:
CURRENT_USERNAME = GITHUB_BOT_USERNAME
if __name__ == "__main__":
try:
organization_name = "sweepai"
sweep_config = SweepConfig()
installation_id = get_installation_id(organization_name)
user_token, g = get_github_client(installation_id)
cloned_repo = ClonedRepo("sweepai/sweep", installation_id, "main")
directory_tree, dir_obj = cloned_repo.list_directory_tree()
commit_history = cloned_repo.get_commit_history()
similar_file_paths = cloned_repo.get_similar_file_paths("config.py")
# ensure no similar file_paths are sweep excluded
assert not any(sweep_config.is_file_excluded(file) for file in similar_file_paths)
print(f"similar_file_paths: {similar_file_paths}")
str1 = "a\nline1\nline2\nline3\nline4\nline5\nline6\ntest\n"
str2 = "a\nline1\nlineTwo\nline3\nline4\nline5\nlineSix\ntset\n"
print(get_hunks(str1, str2, 1))
mocked_repo = MockClonedRepo.from_dir(
cloned_repo.repo_dir,
repo_full_name="sweepai/sweep",
)
temp_repo = TemporarilyCopiedClonedRepo.copy_from_cloned_repo(mocked_repo)
print(f"mocked repo: {mocked_repo}")
except Exception as e:

import re
from dataclasses import dataclass
from functools import lru_cache
from rapidfuzz import fuzz
from tqdm import tqdm
from sweepai.logn import file_cache
from loguru import logger
@lru_cache()
def score_line(str1: str, str2: str) -> float:
if str1 == str2:
return 100
if str1.lstrip() == str2.lstrip():
whitespace_ratio = abs(len(str1) - len(str2)) / (len(str1) + len(str2))
score = 90 - whitespace_ratio * 10
return max(score, 0)
if str1.strip() == str2.strip():
whitespace_ratio = abs(len(str1) - len(str2)) / (len(str1) + len(str2))
score = 80 - whitespace_ratio * 10
return max(score, 0)
levenshtein_ratio = fuzz.ratio(str1, str2)
score = 85 * (levenshtein_ratio / 100)
return max(score, 0)
def match_without_whitespace(str1: str, str2: str) -> bool:
return str1.strip() == str2.strip()
def line_cost(line: str) -> float:
if line.strip() == "":
return 50
if line.strip().startswith("#") or line.strip().startswith("//"):
return 50 + len(line) / (len(line) + 1) * 30
return len(line) / (len(line) + 1) * 100
def score_multiline(query: list[str], target: list[str]) -> float:
# TODO: add weighting on first and last lines
q, t = 0, 0 # indices for query and target
scores: list[tuple[float, float]] = []
skipped_comments = 0
def get_weight(q: int) -> float:
# Prefers lines at beginning and end of query
# Sequence: 1, 2/3, 1/2, 2/5...
index = min(q, len(query) - q)
return 100 / (index / 2 + 1)
while q < len(query) and t < len(target):
q_line = query[q]
t_line = target[t]
weight = get_weight(q)
if match_without_whitespace(q_line, t_line):
# Case 1: lines match
scores.append((score_line(q_line, t_line), weight))
q += 1
t += 1
elif q_line.strip().startswith("...") or q_line.strip().endswith("..."):
# Case 3: ellipsis wildcard
t += 1
if q + 1 == len(query):
scores.append((100 - (len(target) - t), weight))
q += 1
t = len(target)
break
max_score = 0
# Radix optimization
indices = [
t + i
for i, line in enumerate(target[t:])
if match_without_whitespace(line, query[q + 1])
]
if not indices:
# logger.warning(f"Could not find whitespace match, using brute force")
indices = range(t, len(target))
for i in indices:
score, weight = score_multiline(query[q + 1 :], target[i:]), (
100 - (i - t) / len(target) * 10
)
new_scores = scores + [(score, weight)]
total_score = sum(
[value * weight for value, weight in new_scores]
) / sum([weight for _, weight in new_scores])
max_score = max(max_score, total_score)
return max_score
elif (
t_line.strip() == ""
or t_line.strip().startswith("#")
or t_line.strip().startswith("//")
or t_line.strip().startswith("print")
or t_line.strip().startswith("logger")
or t_line.strip().startswith("console.")
):
# Case 2: skipped comment
skipped_comments += 1
t += 1
scores.append((90, weight))
else:
break
if q < len(query):
scores.extend(
(100 - line_cost(line), get_weight(index))
for index, line in enumerate(query[q:], start=q)
)
if t < len(target):
scores.extend(
(100 - line_cost(line), 100) for index, line in enumerate(target[t:])
)
final_score = (
sum([value * weight for value, weight in scores])
/ sum([weight for _, weight in scores])
if scores
else 0
)
final_score *= 1 - 0.05 * skipped_comments
return final_score
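`score_multiline` combines `(score, weight)` pairs into a weighted mean, where `get_weight` favors query lines near either end (`100 / (index / 2 + 1)` with `index = min(q, len(query) - q)`). A standalone sketch of those two pieces so the weight sequence can be checked; the names `position_weight` and `weighted_mean` are hypothetical.

```python
def position_weight(q: int, n: int) -> float:
    """Weight for query line q of n lines: largest at the ends, decaying inward."""
    index = min(q, n - q)
    return 100 / (index / 2 + 1)


def weighted_mean(scores: list[tuple[float, float]]) -> float:
    """Combine (score, weight) pairs the way score_multiline's final step does."""
    if not scores:
        return 0.0
    total_weight = sum(w for _, w in scores)
    return sum(s * w for s, w in scores) / total_weight
```

For a 10-line query the first weights are 100, 66.67, 50 and 40, matching the "1, 2/3, 1/2, 2/5" comment once scaled by 100.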
@dataclass
class Match:
start: int
end: int
score: float
indent: str = ""
def __gt__(self, other):
return self.score > other.score
def get_indent_type(content: str):
two_spaces = len(re.findall(r"\n {2}[^ ]", content))
four_spaces = len(re.findall(r"\n {4}[^ ]", content))
return "  " if two_spaces > four_spaces else "    "
def get_max_indent(content: str, indent_type: str):
return max(len(line) - len(line.lstrip()) for line in content.split("\n")) // len(
indent_type
)
@file_cache()
def find_best_match(query: str, code_file: str):
best_match = Match(-1, -1, 0)
code_file_lines = code_file.split("\n")
query_lines = query.split("\n")
if len(query_lines) > 0 and query_lines[-1].strip() == "...":
query_lines = query_lines[:-1]
if len(query_lines) > 0 and query_lines[0].strip() == "...":
query_lines = query_lines[1:]
indent = get_indent_type(code_file)
max_indents = get_max_indent(code_file, indent)
top_matches = []
if len(query_lines) == 1:
for i, line in enumerate(code_file_lines):
score = score_line(line, query_lines[0])
if score > best_match.score:
best_match = Match(i, i + 1, score)
return best_match
truncate = min(40, len(code_file_lines) // 5)
if truncate < 1:
truncate = len(code_file_lines)
indent_array = [i for i in range(0, max(min(max_indents + 1, 20), 1))]
if max_indents > 3:
indent_array = [3, 2, 4, 0, 1] + list(range(5, max_indents + 1))
for num_indents in indent_array:
indented_query_lines = [indent * num_indents + line for line in query_lines]
start_pairs = [
(i, score_line(line, indented_query_lines[0]))
for i, line in enumerate(code_file_lines)
]
start_pairs.sort(key=lambda x: x[1], reverse=True)
start_pairs = start_pairs[:truncate]
start_indices = [i for i, _ in start_pairs]
for i in tqdm(
start_indices,
position=0,
desc=f"Indent {num_indents}/{max_indents}",
leave=False,
):
end_pairs = [
(j, score_line(line, indented_query_lines[-1]))
for j, line in enumerate(code_file_lines[i:], start=i)
]
end_pairs.sort(key=lambda x: x[1], reverse=True)
end_pairs = end_pairs[:truncate]
end_indices = [j for j, _ in end_pairs]
for j in tqdm(
end_indices, position=1, leave=False, desc=f"Starting line {i}"
):
candidate = code_file_lines[i : j + 1]
raw_score = score_multiline(indented_query_lines, candidate)
score = raw_score * (1 - num_indents * 0.01)
current_match = Match(i, j + 1, score, indent * num_indents)
if raw_score >= 99.99: # early exit, 99.99 for floating point error
logger.info(f"Exact match found! Returning: {current_match}")
return current_match
top_matches.append(current_match)
if score > best_match.score:
best_match = current_match
unique_top_matches: list[Match] = []
unique_spans = set()
for top_match in sorted(top_matches, reverse=True):
if (top_match.start, top_match.end) not in unique_spans:
unique_top_matches.append(top_match)
unique_spans.add((top_match.start, top_match.end))
for top_match in unique_top_matches[:5]:
logger.info(top_match)
# Todo: on_comment file comments able to modify multiple files
return unique_top_matches[0] if unique_top_matches else Match(-1, -1, 0)
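The tail of `find_best_match` sorts candidates by score and keeps only the best-scoring match for each `(start, end)` span, since the same span can be scored several times at different indent levels. A sketch of that deduplication step in isolation; `Span` and `unique_best_spans` are hypothetical stand-ins for `Match` and the inline loop.

```python
from dataclasses import dataclass


@dataclass
class Span:
    start: int
    end: int
    score: float


def unique_best_spans(matches: list[Span], k: int = 5) -> list[Span]:
    """Keep the highest-scoring match per (start, end) span, best first."""
    seen: set[tuple[int, int]] = set()
    unique: list[Span] = []
    for m in sorted(matches, key=lambda m: m.score, reverse=True):
        # The first occurrence of a span after sorting is its best score.
        if (m.start, m.end) not in seen:
            seen.add((m.start, m.end))
            unique.append(m)
    return unique[:k]
```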
def split_ellipses(query: str) -> list[str]:
queries = []
current_query = ""
for line in query.split("\n"):
if line.strip() == "...":
queries.append(current_query.strip("\n"))
current_query = ""
else:
current_query += line + "\n"
queries.append(current_query.strip("\n"))
return queries
def match_indent(generated: str, original: str) -> str:
indent_type = "\t" if "\t" in original[:5] else " "
generated_indents = len(generated) - len(generated.lstrip())
target_indents = len(original) - len(original.lstrip())
diff_indents = target_indents - generated_indents
if diff_indents > 0:
generated = indent_type * diff_indents + generated.replace(
"\n", "\n" + indent_type * diff_indents
)
return generated
old_code = """
\"\"\"
on_ticket is the main function that is called when a new issue is created.
It is only called by the webhook handler in sweepai/api.py.
\"\"\"
# TODO: Add file validation
import math
import re
import traceback
from time import time
import openai
import requests
from github import BadCredentialsException
from logtail import LogtailHandler
from loguru import logger
from requests.exceptions import Timeout
from tabulate import tabulate
from tqdm import tqdm"""
new_code = """
\"\"\"
on_ticket is the main function that is called when a new issue is created.
It is only called by the webhook handler in sweepai/api.py.
\"\"\"
# TODO: Add file validation
import math
import re
import traceback
from time import time
import hashlib
import openai
import requests
from github import BadCredentialsException
from logtail import LogtailHandler
from loguru import logger
from requests.exceptions import Timeout
from tabulate import tabulate
from tqdm import tqdm"""
# print(match_indent(new_code, old_code))
test_code = """\
def naive_euclidean_profile(X, q, mask):
r\"\"\"
Compute a euclidean distance profile in a brute force way.
A distance profile between a (univariate) time series :math:`X_i = {x_1, ..., x_m}`
and a query :math:`Q = {q_1, ..., q_m}` is defined as a vector of size :math:`m-(
l-1)`, such as :math:`P(X_i, Q) = {d(C_1, Q), ..., d(C_m-(l-1), Q)}` with d the
Euclidean distance, and :math:`C_j = {x_j, ..., x_{j+(l-1)}}` the j-th candidate
subsequence of size :math:`l` in :math:`X_i`.
\"\"\"
return _naive_euclidean_profile(X, q, mask)
"""
if __name__ == "__main__":
# for section in split_ellipses(test_code):
# print(section)
code_file = r"""
from loguru import logger
from github.Repository import Repository
from sweepai.config.client import RESET_FILE, REVERT_CHANGED_FILES_TITLE, RULES_LABEL, RULES_TITLE, get_rules
from sweepai.utils.event_logger import posthog
from sweepai.core.post_merge import PostMerge
from sweepai.core.sweep_bot import SweepBot
from sweepai.events import IssueCommentRequest
from sweepai.handlers.on_merge import comparison_to_diff
from sweepai.handlers.pr_utils import make_pr
from sweepai.utils.buttons import ButtonList, check_button_title_match
from sweepai.utils.chat_logger import ChatLogger
from sweepai.utils.github_utils import get_github_client
def handle_button_click(request_dict):
request = IssueCommentRequest(**request_dict)
user_token, gh_client = get_github_client(request_dict["installation"]["id"])
button_list = ButtonList.deserialize(request_dict["comment"]["body"])
selected_buttons = [button.label for button in button_list.get_clicked_buttons()]
repo = gh_client.get_repo(request_dict["repository"]["full_name"]) # do this after checking ref
comment_id = request.comment.id
pr = repo.get_pull(request_dict["issue"]["number"])
comment = pr.get_issue_comment(comment_id)
if check_button_title_match(REVERT_CHANGED_FILES_TITLE, request.comment.body, request.changes):
revert_files = []
for button_text in selected_buttons:
revert_files.append(button_text.split(f"{RESET_FILE} ")[-1].strip())
handle_revert(revert_files, request_dict["issue"]["number"], repo)
comment.edit(
body=ButtonList(
buttons=[
button
for button in button_list.buttons
if button.label not in selected_buttons
],
title = REVERT_CHANGED_FILES_TITLE,
).serialize()
)
"""
    # Sample target snippet
    target = """
from loguru import logger
from github.Repository import Repository
from sweepai.config.client import RESET_FILE, REVERT_CHANGED_FILES_TITLE, RULES_LABEL, RULES_TITLE, get_rules
from sweepai.utils.event_logger import posthog
from sweepai.core.post_merge import PostMerge
from sweepai.core.sweep_bot import SweepBot
from sweepai.events import IssueCommentRequest
from sweepai.handlers.on_merge import comparison_to_diff
from sweepai.handlers.pr_utils import make_pr
from sweepai.utils.buttons import ButtonList, check_button_title_match
from sweepai.utils.chat_logger import ChatLogger
from sweepai.utils.github_utils import get_github_client
def handle_button_click(request_dict):
request = IssueCommentRequest(**request_dict)
user_token, gh_client = get_github_client(request_dict["installation"]["id"])
button_list = ButtonList.deserialize(request_dict["comment"]["body"])
selected_buttons = [button.label for button in button_list.get_clicked_buttons()]
repo = gh_client.get_repo(request_dict["repository"]["full_name"]) # do this after checking ref
comment_id = request.comment.id
pr = repo.get_pull(request_dict["issue"]["number"])
comment = pr.get_issue_comment(comment_id)
...
""".strip(
        "\n"
    )
    # Find the best match
    # best_span = find_best_match(target, code_file)
    best_span = find_best_match("a\nb", "a\nb")
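The fragment at the top of this snippet shows only the tail of the indentation helper. It prefixes `generated` with `indent_type * diff_indents` and repeats that prefix after every newline. The rest of the function is cut off, so what follows is a minimal, self-contained sketch of the same idea and not Sweep's code. It makes two assumptions. First, the target depth is the number of leading spaces or tabs on `original`. Second, the helper only ever adds indentation.

```python
def match_indent(generated: str, original: str) -> str:
    """Shift `generated` right so its first line starts at the same depth as `original`.

    Sketch only: depth is counted as leading spaces/tabs, and indentation is
    never removed (a negative difference leaves `generated` unchanged).
    """
    indent_type = "\t" if original.startswith("\t") else " "
    target_indents = len(original) - len(original.lstrip(" \t"))
    generated_indents = len(generated) - len(generated.lstrip(" \t"))
    diff_indents = target_indents - generated_indents
    if diff_indents <= 0:
        return generated
    prefix = indent_type * diff_indents
    # Same step as the quoted fragment: indent the first line, then every
    # line that follows a newline.
    return prefix + generated.replace("\n", "\n" + prefix)


print(repr(match_indent("x = 1\ny = 2", "    pass")))  # '    x = 1\n    y = 2'
```

Blank lines inside `generated` also receive the prefix, which leaves trailing whitespace. A fuller version might skip empty lines.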


Step 2: ⌨️ Coding

Working on it...


🎉 Latest improvements to Sweep:
  • New dashboard launched for real-time tracking of Sweep issues, covering all stages from search to coding.
  • Integration of OpenAI's latest Assistant API for more efficient and reliable code planning and editing, improving speed by 3x.
  • Use the GitHub issues extension for creating Sweep issues directly from your editor.

💡 To recreate the pull request, edit the issue title or description.
Something wrong? Let us know.

This is an automated message generated by Sweep AI.
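The plan in the first comment targets the `repo.merge(...)` call in `update_sweep_prs_v2`, which appears in the `api.py` snippet below. That call adds the "Merge main into …" commits the reporter would like to avoid. PyGithub's `Repository.merge` goes through GitHub's merge-a-branch endpoint, which records the update as a merge commit, so a rebase has to happen in a git checkout instead.

The sketch below is minimal and rests on two assumptions. The worker has a local clone at `repo_dir`, and it can force-push to that clone's `origin` remote. `build_rebase_commands` and `rebase_onto_base` are hypothetical helpers, not part of Sweep.

```python
import subprocess
from pathlib import Path


def build_rebase_commands(
    feature_branch: str, base_branch: str, remote: str = "origin"
) -> list[list[str]]:
    """Return the git invocations that replay `feature_branch` on top of `base_branch`."""
    return [
        # Refresh the remote-tracking refs for both branches.
        ["git", "fetch", remote, base_branch, feature_branch],
        # (Re)create the local feature branch from its remote copy.
        ["git", "checkout", "-B", feature_branch, f"{remote}/{feature_branch}"],
        # Replay the feature-only commits onto the latest base branch.
        ["git", "rebase", f"{remote}/{base_branch}"],
        # Rewrite the PR branch, but refuse if someone pushed to it since the fetch.
        ["git", "push", "--force-with-lease", remote, f"{feature_branch}:{feature_branch}"],
    ]


def rebase_onto_base(repo_dir: Path, feature_branch: str, base_branch: str) -> bool:
    """Run the commands in order; return False without touching the remote branch on failure."""
    for cmd in build_rebase_commands(feature_branch, base_branch):
        try:
            subprocess.run(cmd, cwd=repo_dir, check=True, capture_output=True, text=True)
        except subprocess.CalledProcessError:
            if cmd[1] == "rebase":
                # Conflicts: restore the pre-rebase state instead of pushing.
                subprocess.run(["git", "rebase", "--abort"], cwd=repo_dir, check=False)
            return False
    return True
```

In `update_sweep_prs_v2`, a call like `rebase_onto_base(...)` could replace `repo.merge(...)` behind a per-repository opt-in, such as a hypothetical `sweep.yaml` key like `rebase_on_update: true`. It could fall back to the existing merge when the rebase returns False. The trade-off is that rebasing rewrites the PR branch's commit SHAs, so anyone who has the branch checked out locally must reset it.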


sweep-nightly bot commented Apr 6, 2024

🚀 Here's the PR! #3453

See Sweep's progress at the progress dashboard!
💎 Sweep Pro: I'm using GPT-4. You have unlimited GPT-4 tickets. (tracking ID: bb53a6416d)

Step 1: 🔎 Searching

I found the following snippets in your repository. I will now analyze these snippets and come up with a plan.

Some code snippets I think are relevant, in decreasing order of relevance. If a file is missing from here, you can mention its path in the ticket description.

sweep/sweepai/api.py

Lines 1 to 1185 in 0643263

from __future__ import annotations
import ctypes
import json
import threading
import time
from typing import Any, Optional
import requests
from fastapi import (
    Body,
    Depends,
    FastAPI,
    Header,
    HTTPException,
    Path,
    Request,
    Security,
    status,
)
from fastapi.responses import HTMLResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from fastapi.templating import Jinja2Templates
from github.Commit import Commit
from prometheus_fastapi_instrumentator import Instrumentator
from sweepai.config.client import (
    DEFAULT_RULES,
    RESTART_SWEEP_BUTTON,
    REVERT_CHANGED_FILES_TITLE,
    RULES_LABEL,
    RULES_TITLE,
    SWEEP_BAD_FEEDBACK,
    SWEEP_GOOD_FEEDBACK,
    SweepConfig,
    get_gha_enabled,
    get_rules,
)
from sweepai.config.server import (
    BLACKLISTED_USERS,
    DISABLED_REPOS,
    DISCORD_FEEDBACK_WEBHOOK_URL,
    ENV,
    GHA_AUTOFIX_ENABLED,
    GITHUB_BOT_USERNAME,
    GITHUB_LABEL_COLOR,
    GITHUB_LABEL_DESCRIPTION,
    GITHUB_LABEL_NAME,
    IS_SELF_HOSTED,
    MERGE_CONFLICT_ENABLED,
)
from sweepai.core.entities import PRChangeRequest
from sweepai.global_threads import global_threads
from sweepai.handlers.create_pr import ( # type: ignore
    add_config_to_top_repos,
    create_gha_pr,
)
from sweepai.handlers.on_button_click import handle_button_click
from sweepai.handlers.on_check_suite import ( # type: ignore
    clean_gh_logs,
    download_logs,
    on_check_suite,
)
from sweepai.handlers.on_comment import on_comment
from sweepai.handlers.on_merge import on_merge
from sweepai.handlers.on_merge_conflict import on_merge_conflict
from sweepai.handlers.on_ticket import on_ticket
from sweepai.handlers.pr_utils import make_pr
from sweepai.handlers.stack_pr import stack_pr
from sweepai.utils.buttons import (
    Button,
    ButtonList,
    check_button_activated,
    check_button_title_match,
)
from sweepai.utils.chat_logger import ChatLogger
from sweepai.utils.event_logger import logger, posthog
from sweepai.utils.github_utils import CURRENT_USERNAME, get_github_client
from sweepai.utils.progress import TicketProgress
from sweepai.utils.safe_pqueue import SafePriorityQueue
from sweepai.utils.str_utils import BOT_SUFFIX, get_hash
from sweepai.web.events import (
    CheckRunCompleted,
    CommentCreatedRequest,
    InstallationCreatedRequest,
    IssueCommentRequest,
    IssueRequest,
    PREdited,
    PRRequest,
    ReposAddedRequest,
)
from sweepai.web.health import health_check
app = FastAPI()
events = {}
on_ticket_events = {}
security = HTTPBearer()
templates = Jinja2Templates(directory="sweepai/web")
# version_command = r"""git config --global --add safe.directory /app
# timestamp=$(git log -1 --format="%at")
# date -d "@$timestamp" +%y.%m.%d.%H 2>/dev/null || date -r "$timestamp" +%y.%m.%d.%H"""
# try:
# version = subprocess.check_output(version_command, shell=True, text=True).strip()
# except Exception:
version = time.strftime("%y.%m.%d.%H")
logger.bind(application="webhook")
def auth_metrics(credentials: HTTPAuthorizationCredentials = Security(security)):
    if credentials.scheme != "Bearer":
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Invalid authentication scheme.",
        )
    if credentials.credentials != "example_token":  # grafana requires authentication
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token."
        )
    return True
if not IS_SELF_HOSTED:
    Instrumentator().instrument(app).expose(
        app,
        should_gzip=False,
        endpoint="/metrics",
        include_in_schema=True,
        tags=["metrics"],
        dependencies=[Depends(auth_metrics)],
    )
def run_on_ticket(*args, **kwargs):
    tracking_id = get_hash()
    with logger.contextualize(
        **kwargs,
        name="ticket_" + kwargs["username"],
        tracking_id=tracking_id,
    ):
        return on_ticket(*args, **kwargs, tracking_id=tracking_id)
def run_on_comment(*args, **kwargs):
    tracking_id = get_hash()
    with logger.contextualize(
        **kwargs,
        name="comment_" + kwargs["username"],
        tracking_id=tracking_id,
    ):
        on_comment(*args, **kwargs, tracking_id=tracking_id)
def run_on_button_click(*args, **kwargs):
    thread = threading.Thread(target=handle_button_click, args=args, kwargs=kwargs)
    thread.start()
    global_threads.append(thread)
def run_on_check_suite(*args, **kwargs):
    request = kwargs["request"]
    pr_change_request = on_check_suite(request)
    if pr_change_request:
        call_on_comment(**pr_change_request.params, comment_type="github_action")
        logger.info("Done with on_check_suite")
    else:
        logger.info("Skipping on_check_suite as no pr_change_request was returned")
def terminate_thread(thread):
    """Terminate a python threading.Thread."""
    try:
        if not thread.is_alive():
            return
        exc = ctypes.py_object(SystemExit)
        res = ctypes.pythonapi.PyThreadState_SetAsyncExc(
            ctypes.c_long(thread.ident), exc
        )
        if res == 0:
            raise ValueError("Invalid thread ID")
        elif res != 1:
            # Call with exception set to 0 is needed to cleanup properly.
            ctypes.pythonapi.PyThreadState_SetAsyncExc(thread.ident, 0)
            raise SystemError("PyThreadState_SetAsyncExc failed")
    except SystemExit:
        raise SystemExit
    except Exception as e:
        logger.exception(f"Failed to terminate thread: {e}")
# def delayed_kill(thread: threading.Thread, delay: int = 60 * 60):
#     time.sleep(delay)
#     terminate_thread(thread)
def call_on_ticket(*args, **kwargs):
    global on_ticket_events
    key = f"{kwargs['repo_full_name']}-{kwargs['issue_number']}"  # Full name, issue number as key
    # Use multithreading
    # Check if a previous process exists for the same key, cancel it
    e = on_ticket_events.get(key, None)
    if e:
        logger.info(f"Found previous thread for key {key} and cancelling it")
        terminate_thread(e)
    thread = threading.Thread(target=run_on_ticket, args=args, kwargs=kwargs)
    on_ticket_events[key] = thread
    thread.start()
    global_threads.append(thread)
    # delayed_kill_thread = threading.Thread(target=delayed_kill, args=(thread,))
    # delayed_kill_thread.start()
def call_on_check_suite(*args, **kwargs):
    kwargs["request"].repository.full_name
    kwargs["request"].check_run.pull_requests[0].number
    thread = threading.Thread(target=run_on_check_suite, args=args, kwargs=kwargs)
    thread.start()
    global_threads.append(thread)
def call_on_comment(
    *args, **kwargs
):  # TODO: if its a GHA delete all previous GHA and append to the end
    def worker():
        while not events[key].empty():
            task_args, task_kwargs = events[key].get()
            run_on_comment(*task_args, **task_kwargs)
    global events
    repo_full_name = kwargs["repo_full_name"]
    pr_id = kwargs["pr_number"]
    key = f"{repo_full_name}-{pr_id}"  # Full name, comment number as key
    comment_type = kwargs["comment_type"]
    logger.info(f"Received comment type: {comment_type}")
    if key not in events:
        events[key] = SafePriorityQueue()
    events[key].put(0, (args, kwargs))
    # If a thread isn't running, start one
    if not any(
        thread.name == key and thread.is_alive() for thread in threading.enumerate()
    ):
        thread = threading.Thread(target=worker, name=key)
        thread.start()
        global_threads.append(thread)
def call_on_merge(*args, **kwargs):
    thread = threading.Thread(target=on_merge, args=args, kwargs=kwargs)
    thread.start()
    global_threads.append(thread)
@app.get("/health")
def redirect_to_health():
    return health_check()
@app.get("/", response_class=HTMLResponse)
def home(request: Request):
    return templates.TemplateResponse(
        name="index.html", context={"version": version, "request": request}
    )
@app.get("/ticket_progress/{tracking_id}")
def progress(tracking_id: str = Path(...)):
    ticket_progress = TicketProgress.load(tracking_id)
    return ticket_progress.dict()
def init_hatchet() -> Any | None:
    try:
        from hatchet_sdk import Context, Hatchet
        hatchet = Hatchet(debug=True)
        worker = hatchet.worker("github-worker")
        @hatchet.workflow(on_events=["github:webhook"])
        class OnGithubEvent:
            """Workflow for handling GitHub events."""
            @hatchet.step()
            def run(self, context: Context):
                event_payload = context.workflow_input()
                request_dict = event_payload.get("request")
                event = event_payload.get("event")
                handle_event(request_dict, event)
        workflow = OnGithubEvent()
        worker.register_workflow(workflow)
        # start worker in the background
        thread = threading.Thread(target=worker.start)
        thread.start()
        global_threads.append(thread)
        return hatchet
    except Exception as e:
        print(f"Failed to initialize Hatchet: {e}, continuing with local mode")
        return None
# hatchet = init_hatchet()
def handle_github_webhook(event_payload):
    # if hatchet:
    #     hatchet.client.event.push("github:webhook", event_payload)
    # else:
    handle_event(event_payload.get("request"), event_payload.get("event"))
def handle_request(request_dict, event=None):
    """So it can be exported to the listen endpoint."""
    with logger.contextualize(tracking_id="main", env=ENV):
        action = request_dict.get("action")
        try:
            # Send the event to Hatchet
            handle_github_webhook(
                {
                    "request": request_dict,
                    "event": event,
                }
            )
        except Exception as e:
            logger.exception(f"Failed to send event to Hatchet: {e}")
        # try:
        #     worker()
        # except Exception as e:
        #     discord_log_error(str(e), priority=1)
        logger.info(f"Done handling {event}, {action}")
        return {"success": True}
@app.post("/")
def webhook(
    request_dict: dict = Body(...),
    x_github_event: Optional[str] = Header(None, alias="X-GitHub-Event"),
):
    """Handle a webhook request from GitHub."""
    with logger.contextualize(tracking_id="main", env=ENV):
        action = request_dict.get("action", None)
        logger.info(f"Received event: {x_github_event}, {action}")
        return handle_request(request_dict, event=x_github_event)
# Set up cronjob for this
@app.get("/update_sweep_prs_v2")
def update_sweep_prs_v2(repo_full_name: str, installation_id: int):
    # Get a Github client
    _, g = get_github_client(installation_id)
    # Get the repository
    repo = g.get_repo(repo_full_name)
    config = SweepConfig.get_config(repo)
    try:
        branch_ttl = int(config.get("branch_ttl", 7))
    except Exception:
        branch_ttl = 7
    branch_ttl = max(branch_ttl, 1)
    # Get all open pull requests created by Sweep
    pulls = repo.get_pulls(
        state="open", head="sweep", sort="updated", direction="desc"
    )[:5]
    # For each pull request, attempt to merge the changes from the default branch into the pull request branch
    try:
        for pr in pulls:
            try:
                # make sure it's a sweep ticket
                feature_branch = pr.head.ref
                if not feature_branch.startswith(
                    "sweep/"
                ) and not feature_branch.startswith("sweep_"):
                    continue
                if "Resolve merge conflicts" in pr.title:
                    continue
                if (
                    pr.mergeable_state != "clean"
                    and (time.time() - pr.created_at.timestamp()) > 60 * 60 * 24
                    and pr.title.startswith("[Sweep Rules]")
                ):
                    pr.edit(state="closed")
                    continue
                repo.merge(
                    feature_branch,
                    pr.base.ref,
                    f"Merge main into {feature_branch}",
                )
                # Check if the merged PR is the config PR
                if pr.title == "Configure Sweep" and pr.merged:
                    # Create a new PR to add "gha_enabled: True" to sweep.yaml
                    create_gha_pr(g, repo)
            except Exception as e:
                logger.warning(
                    f"Failed to merge changes from default branch into PR #{pr.number}: {e}"
                )
    except Exception:
        logger.warning("Failed to update sweep PRs")
def handle_event(request_dict, event):
    action = request_dict.get("action")
    if repo_full_name := request_dict.get("repository", {}).get("full_name"):
        if repo_full_name in DISABLED_REPOS:
            logger.warning(f"Repo {repo_full_name} is disabled")
            return {"success": False, "error_message": "Repo is disabled"}
    with logger.contextualize(tracking_id="main", env=ENV):
        match event, action:
            case "check_run", "completed":
                request = CheckRunCompleted(**request_dict)
                _, g = get_github_client(request.installation.id)
                repo = g.get_repo(request.repository.full_name)
                pull_requests = request.check_run.pull_requests
                if pull_requests:
                    logger.info(pull_requests[0].number)
                    pr = repo.get_pull(pull_requests[0].number)
                    if (time.time() - pr.created_at.timestamp()) > 60 * 60 and (
                        pr.title.startswith("[Sweep Rules]")
                        or pr.title.startswith("[Sweep GHA Fix]")
                    ):
                        after_sha = pr.head.sha
                        commit = repo.get_commit(after_sha)
                        check_suites = commit.get_check_suites()
                        for check_suite in check_suites:
                            if check_suite.conclusion == "failure":
                                pr.edit(state="closed")
                                break
                    if (
                        not (time.time() - pr.created_at.timestamp()) > 60 * 15
                        and request.check_run.conclusion == "failure"
                        and pr.state == "open"
                        and get_gha_enabled(repo)
                        and len(
                            [
                                comment
                                for comment in pr.get_issue_comments()
                                if "Fixing PR" in comment.body
                            ]
                        )
                        < 2
                        and GHA_AUTOFIX_ENABLED
                    ):
                        # check if the base branch is passing
                        commits = repo.get_commits(sha=pr.base.ref)
                        latest_commit: Commit = commits[0]
                        if all(
                            status != "failure"
                            for status in [
                                status.state for status in latest_commit.get_statuses()
                            ]
                        ):  # base branch is passing
                            logs = download_logs(
                                request.repository.full_name,
                                request.check_run.run_id,
                                request.installation.id,
                            )
                            logs, user_message = clean_gh_logs(logs)
                            attributor = request.sender.login
                            if attributor.endswith("[bot]"):
                                attributor = commit.author.login
                            if attributor.endswith("[bot]"):
                                attributor = pr.assignee.login
                            if attributor.endswith("[bot]"):
                                return {
                                    "success": False,
                                    "error_message": "The PR was created by a bot, so I won't attempt to fix it.",
                                }
                            tracking_id = get_hash()
                            chat_logger = ChatLogger(
                                data={
                                    "username": attributor,
                                    "title": "[Sweep GHA Fix] Fix the failing GitHub Actions",
                                }
                            )
                            if chat_logger.use_faster_model() and not IS_SELF_HOSTED:
                                return {
                                    "success": False,
                                    "error_message": "Disabled for free users",
                                }
                            stack_pr(
                                request=f"[Sweep GHA Fix] The GitHub Actions run failed on {request.check_run.head_sha[:7]} ({repo.default_branch}) with the following error logs:\n\n```\n\n{logs}\n\n```",
                                pr_number=pr.number,
                                username=attributor,
                                repo_full_name=repo.full_name,
                                installation_id=request.installation.id,
                                tracking_id=tracking_id,
                                commit_hash=pr.head.sha,
                            )
                elif (
                    request.check_run.check_suite.head_branch == repo.default_branch
                    and get_gha_enabled(repo)
                    and GHA_AUTOFIX_ENABLED
                ):
                    if request.check_run.conclusion == "failure":
                        commit = repo.get_commit(request.check_run.head_sha)
                        attributor = request.sender.login
                        if attributor.endswith("[bot]"):
                            attributor = commit.author.login
                        if attributor.endswith("[bot]"):
                            return {
                                "success": False,
                                "error_message": "The PR was created by a bot, so I won't attempt to fix it.",
                            }
                        logs = download_logs(
                            request.repository.full_name,
                            request.check_run.run_id,
                            request.installation.id,
                        )
                        logs, user_message = clean_gh_logs(logs)
                        chat_logger = ChatLogger(
                            data={
                                "username": attributor,
                                "title": "[Sweep GHA Fix] Fix the failing GitHub Actions",
                            }
                        )
                        if chat_logger.use_faster_model() and not IS_SELF_HOSTED:
                            return {
                                "success": False,
                                "error_message": "Disabled for free users",
                            }
                        make_pr(
                            title=f"[Sweep GHA Fix] Fix the failing GitHub Actions on {request.check_run.head_sha[:7]} ({repo.default_branch})",
                            repo_description=repo.description,
                            summary=f"The GitHub Actions run failed with the following error logs:\n\n```\n{logs}\n```",
                            repo_full_name=request_dict["repository"]["full_name"],
                            installation_id=request_dict["installation"]["id"],
                            user_token=None,
                            use_faster_model=chat_logger.use_faster_model(),
                            username=attributor,
                            chat_logger=chat_logger,
                        )
            case "pull_request", "opened":
                _, g = get_github_client(request_dict["installation"]["id"])
                repo = g.get_repo(request_dict["repository"]["full_name"])
                pr = repo.get_pull(request_dict["pull_request"]["number"])
                # if the pr already has a comment from sweep bot do nothing
                time.sleep(10)
                if any(
                    comment.user.login == GITHUB_BOT_USERNAME
                    for comment in pr.get_issue_comments()
                ) or pr.title.startswith("Sweep:"):
                    return {
                        "success": True,
                        "reason": "PR already has a comment from sweep bot",
                    }
                rule_buttons = []
                repo_rules = get_rules(repo) or []
                if repo_rules != [""] and repo_rules != []:
                    for rule in repo_rules or []:
                        if rule:
                            rule_buttons.append(Button(label=f"{RULES_LABEL} {rule}"))
                    if len(repo_rules) == 0:
                        for rule in DEFAULT_RULES:
                            rule_buttons.append(Button(label=f"{RULES_LABEL} {rule}"))
                    if rule_buttons:
                        rules_buttons_list = ButtonList(
                            buttons=rule_buttons, title=RULES_TITLE
                        )
                        pr.create_issue_comment(rules_buttons_list.serialize() + BOT_SUFFIX)
                if pr.mergeable is False and MERGE_CONFLICT_ENABLED:
                    attributor = pr.user.login
                    if attributor.endswith("[bot]"):
                        attributor = pr.assignee.login
                    if attributor.endswith("[bot]"):
                        return {
                            "success": False,
                            "error_message": "The PR was created by a bot, so I won't attempt to fix it.",
                        }
                    chat_logger = ChatLogger(
                        data={
                            "username": attributor,
                            "title": "[Sweep GHA Fix] Fix the failing GitHub Actions",
                        }
                    )
                    if chat_logger.use_faster_model() and not IS_SELF_HOSTED:
                        return {
                            "success": False,
                            "error_message": "Disabled for free users",
                        }
                    on_merge_conflict(
                        pr_number=pr.number,
                        username=attributor,
                        repo_full_name=request_dict["repository"]["full_name"],
                        installation_id=request_dict["installation"]["id"],
                        tracking_id=get_hash(),
                    )
            case "issues", "opened":
                request = IssueRequest(**request_dict)
                issue_title_lower = request.issue.title.lower()
                if (
                    issue_title_lower.startswith("sweep")
                    or "sweep:" in issue_title_lower
                ):
                    _, g = get_github_client(request.installation.id)
                    repo = g.get_repo(request.repository.full_name)
                    labels = repo.get_labels()
                    label_names = [label.name for label in labels]
                    if GITHUB_LABEL_NAME not in label_names:
                        repo.create_label(
                            name=GITHUB_LABEL_NAME,
                            color=GITHUB_LABEL_COLOR,
                            description=GITHUB_LABEL_DESCRIPTION,
                        )
                    current_issue = repo.get_issue(number=request.issue.number)
                    current_issue.add_to_labels(GITHUB_LABEL_NAME)
            case "issue_comment", "edited":
                request = IssueCommentRequest(**request_dict)
                sweep_labeled_issue = GITHUB_LABEL_NAME in [
                    label.name.lower() for label in request.issue.labels
                ]
                button_title_match = check_button_title_match(
                    REVERT_CHANGED_FILES_TITLE,
                    request.comment.body,
                    request.changes,
                ) or check_button_title_match(
                    RULES_TITLE,
                    request.comment.body,
                    request.changes,
                )
                if (
                    request.comment.user.type == "Bot"
                    and GITHUB_BOT_USERNAME in request.comment.user.login
                    and request.changes.body_from is not None
                    and button_title_match
                    and request.sender.type == "User"
                    and request.comment.user.login not in BLACKLISTED_USERS
                ):
                    run_on_button_click(request_dict)
                restart_sweep = False
                if (
                    request.comment.user.type == "Bot"
                    and GITHUB_BOT_USERNAME in request.comment.user.login
                    and request.changes.body_from is not None
                    and check_button_activated(
                        RESTART_SWEEP_BUTTON,
                        request.comment.body,
                        request.changes,
                    )
                    and sweep_labeled_issue
                    and request.sender.type == "User"
                    and request.comment.user.login not in BLACKLISTED_USERS
                ):
                    # Restart Sweep on this issue
                    restart_sweep = True
                if (
                    request.issue is not None
                    and sweep_labeled_issue
                    and request.comment.user.type == "User"
                    and request.comment.user.login not in BLACKLISTED_USERS
                    and not request.comment.user.login.startswith("sweep")
                    and not (
                        request.issue.pull_request and request.issue.pull_request.url
                    )
                    or restart_sweep
                ):
                    logger.info("New issue comment edited")
                    request.issue.body = request.issue.body or ""
                    request.repository.description = (
                        request.repository.description or ""
                    )
                    if (
                        not request.comment.body.strip()
                        .lower()
                        .startswith(GITHUB_LABEL_NAME)
                        and not restart_sweep
                    ):
                        logger.info("Comment does not start with 'Sweep', passing")
                        return {
                            "success": True,
                            "reason": "Comment does not start with 'Sweep', passing",
                        }
                    call_on_ticket(
                        title=request.issue.title,
                        summary=request.issue.body,
                        issue_number=request.issue.number,
                        issue_url=request.issue.html_url,
                        username=request.issue.user.login,
                        repo_full_name=request.repository.full_name,
                        repo_description=request.repository.description,
                        installation_id=request.installation.id,
                        comment_id=request.comment.id if not restart_sweep else None,
                        edited=True,
                    )
                elif (
                    request.issue.pull_request
                    and request.comment.user.type == "User"
                    and request.comment.user.login not in BLACKLISTED_USERS
                ):  # TODO(sweep): set a limit
                    logger.info(f"Handling comment on PR: {request.issue.pull_request}")
                    _, g = get_github_client(request.installation.id)
                    repo = g.get_repo(request.repository.full_name)
                    pr = repo.get_pull(request.issue.number)
                    labels = pr.get_labels()
                    comment = request.comment.body
                    if (
                        comment.lower().startswith("sweep:")
                        or any(label.name.lower() == "sweep" for label in labels)
                    ) and BOT_SUFFIX not in comment:
                        pr_change_request = PRChangeRequest(
                            params={
                                "comment_type": "comment",
                                "repo_full_name": request.repository.full_name,
                                "repo_description": request.repository.description,
                                "comment": request.comment.body,
                                "pr_path": None,
                                "pr_line_position": None,
                                "username": request.comment.user.login,
                                "installation_id": request.installation.id,
                                "pr_number": request.issue.number,
                                "comment_id": request.comment.id,
                            },
                        )
                        call_on_comment(**pr_change_request.params)
            case "issues", "edited":
                request = IssueRequest(**request_dict)
                if (
                    GITHUB_LABEL_NAME
                    in [label.name.lower() for label in request.issue.labels]
                    and request.sender.type == "User"
                    and request.comment.user.login not in BLACKLISTED_USERS
                    and not request.sender.login.startswith("sweep")
                ):
                    logger.info("New issue edited")
                    call_on_ticket(
                        title=request.issue.title,
                        summary=request.issue.body,
                        issue_number=request.issue.number,
                        issue_url=request.issue.html_url,
                        username=request.issue.user.login,
                        repo_full_name=request.repository.full_name,
                        repo_description=request.repository.description,
                        installation_id=request.installation.id,
                        comment_id=None,
                    )
                else:
                    logger.info("Issue edited, but not a sweep issue")
            case "issues", "labeled":
                request = IssueRequest(**request_dict)
                if (
                    any(
                        label.name.lower() == GITHUB_LABEL_NAME
                        for label in request.issue.labels
                    )
                    and not request.issue.pull_request
                ):
                    request.issue.body = request.issue.body or ""
                    request.repository.description = (
                        request.repository.description or ""
                    )
                    call_on_ticket(
                        title=request.issue.title,
                        summary=request.issue.body,
                        issue_number=request.issue.number,
                        issue_url=request.issue.html_url,
                        username=request.issue.user.login,
                        repo_full_name=request.repository.full_name,
                        repo_description=request.repository.description,
                        installation_id=request.installation.id,
                        comment_id=None,
                    )
            case "issue_comment", "created":
                request = IssueCommentRequest(**request_dict)
                if (
                    request.issue is not None
                    and GITHUB_LABEL_NAME
                    in [label.name.lower() for label in request.issue.labels]
                    and request.comment.user.type == "User"
                    and request.comment.user.login not in BLACKLISTED_USERS
                    and not (
                        request.issue.pull_request and request.issue.pull_request.url
                    )
                    and BOT_SUFFIX not in request.comment.body
                ):
                    request.issue.body = request.issue.body or ""
                    request.repository.description = (
                        request.repository.description or ""
                    )
                    if (
                        not request.comment.body.strip()
                        .lower()
                        .startswith(GITHUB_LABEL_NAME)
                    ):
                        logger.info("Comment does not start with 'Sweep', passing")
                        return {
                            "success": True,
                            "reason": "Comment does not start with 'Sweep', passing",
                        }
                    call_on_ticket(
                        title=request.issue.title,
                        summary=request.issue.body,
                        issue_number=request.issue.number,
                        issue_url=request.issue.html_url,
                        username=request.issue.user.login,
                        repo_full_name=request.repository.full_name,
                        repo_description=request.repository.description,
                        installation_id=request.installation.id,
                        comment_id=request.comment.id,
                    )
                elif (
                    request.issue.pull_request
                    and request.comment.user.type == "User"
                    and request.comment.user.login not in BLACKLISTED_USERS
                    and BOT_SUFFIX not in request.comment.body
                ):  # TODO(sweep): set a limit
                    _, g = get_github_client(request.installation.id)
                    repo = g.get_repo(request.repository.full_name)
                    pr = repo.get_pull(request.issue.number)
                    labels = pr.get_labels()
                    comment = request.comment.body
                    if (
                        comment.lower().startswith("sweep:")
                        or any(label.name.lower() == "sweep" for label in labels)
                        and BOT_SUFFIX not in comment
                    ):
                        pr_change_request = PRChangeRequest(
                            params={
                                "comment_type": "comment",
                                "repo_full_name": request.repository.full_name,
                                "repo_description": request.repository.description,
                                "comment": request.comment.body,
                                "pr_path": None,
                                "pr_line_position": None,
                                "username": request.comment.user.login,
                                "installation_id": request.installation.id,
                                "pr_number": request.issue.number,
                                "comment_id": request.comment.id,
                            },
                        )
                        call_on_comment(**pr_change_request.params)
            case "pull_request_review_comment", "created":
                request = CommentCreatedRequest(**request_dict)
                _, g = get_github_client(request.installation.id)
                repo = g.get_repo(request.repository.full_name)
                pr = repo.get_pull(request.pull_request.number)
                labels = pr.get_labels()
                comment = request.comment.body
                if (
                    (
                        comment.lower().startswith("sweep:")
                        or any(label.name.lower() == "sweep" for label in labels)
                    )
                    and request.comment.user.type == "User"
                    and request.comment.user.login not in BLACKLISTED_USERS
                    and BOT_SUFFIX not in comment
                ):
                    pr_change_request = PRChangeRequest(
                        params={
                            "comment_type": "comment",
                            "repo_full_name": request.repository.full_name,
                            "repo_description": request.repository.description,
                            "comment": request.comment.body,
                            "pr_path": request.comment.path,
                            "pr_line_position": request.comment.original_line,
                            "username": request.comment.user.login,
                            "installation_id": request.installation.id,
                            "pr_number": request.pull_request.number,
                            "comment_id": request.comment.id,
                        },
                    )
                    call_on_comment(**pr_change_request.params)
            case "pull_request_review_comment", "edited":
                request = CommentCreatedRequest(**request_dict)
                _, g = get_github_client(request.installation.id)
                repo = g.get_repo(request.repository.full_name)
                pr = repo.get_pull(request.pull_request.number)
                labels = pr.get_labels()
                comment = request.comment.body
                if (
                    (
                        comment.lower().startswith("sweep:")
                        or any(label.name.lower() == "sweep" for label in labels)
                    )
                    and request.comment.user.type == "User"
                    and request.comment.user.login not in BLACKLISTED_USERS
                    and BOT_SUFFIX not in comment
                ):
                    pr_change_request = PRChangeRequest(
                        params={
                            "comment_type": "comment",
                            "repo_full_name": request.repository.full_name,
                            "repo_description": request.repository.description,
                            "comment": request.comment.body,
                            "pr_path": request.comment.path,
                            "pr_line_position": request.comment.original_line,
                            "username": request.comment.user.login,
                            "installation_id": request.installation.id,
                            "pr_number": request.pull_request.number,
                            "comment_id": request.comment.id,
                        },
                    )
                    call_on_comment(**pr_change_request.params)
            case "installation_repositories", "added":
                repos_added_request = ReposAddedRequest(**request_dict)
                metadata = {
                    "installation_id": repos_added_request.installation.id,
                    "repositories": [
                        repo.full_name
                        for repo in repos_added_request.repositories_added
                    ],
                }
                try:
                    add_config_to_top_repos(
                        repos_added_request.installation.id,
                        repos_added_request.installation.account.login,
                        repos_added_request.repositories_added,
                    )
                except Exception as e:
                    logger.exception(f"Failed to add config to top repos: {e}")
                posthog.capture(
                    "installation_repositories",
                    "started",
                    properties={**metadata},
                )
                for repo in repos_added_request.repositories_added:
                    organization, repo_name = repo.full_name.split("/")
                    posthog.capture(
                        organization,
                        "installed_repository",
                        properties={
                            "repo_name": repo_name,
                            "organization": organization,
                            "repo_full_name": repo.full_name,
                        },
                    )
            case "installation", "created":
                repos_added_request = InstallationCreatedRequest(**request_dict)
                try:
                    add_config_to_top_repos(
                        repos_added_request.installation.id,
                        repos_added_request.installation.account.login,
                        repos_added_request.repositories,
                    )
                except Exception as e:
                    logger.exception(f"Failed to add config to top repos: {e}")
case "pull_request", "edited":
request = PREdited(**request_dict)
if (
request.pull_request.user.login == GITHUB_BOT_USERNAME
and not request.sender.login.endswith("[bot]")
and DISCORD_FEEDBACK_WEBHOOK_URL is not None
):
good_button = check_button_activated(
SWEEP_GOOD_FEEDBACK,
request.pull_request.body,
request.changes,
)
bad_button = check_button_activated(
SWEEP_BAD_FEEDBACK,
request.pull_request.body,
request.changes,
)
if good_button or bad_button:
emoji = "😕"
if good_button:
emoji = "👍"
elif bad_button:
emoji = "👎"
data = {
"content": f"{emoji} {request.pull_request.html_url} ({request.sender.login})\n{request.pull_request.commits} commits, {request.pull_request.changed_files} files: +{request.pull_request.additions}, -{request.pull_request.deletions}"
}
headers = {"Content-Type": "application/json"}
requests.post(
DISCORD_FEEDBACK_WEBHOOK_URL,
data=json.dumps(data),
headers=headers,
)
# Send feedback to PostHog
posthog.capture(
request.sender.login,
"feedback",
properties={
"repo_name": request.repository.full_name,
"pr_url": request.pull_request.html_url,
"pr_commits": request.pull_request.commits,
"pr_additions": request.pull_request.additions,
"pr_deletions": request.pull_request.deletions,
"pr_changed_files": request.pull_request.changed_files,
"username": request.sender.login,
"good_button": good_button,
"bad_button": bad_button,
},
)
def remove_buttons_from_description(body):
"""
Replace:
### PR Feedback...
...
# (until it hits the next #)
with
### PR Feedback: {emoji}
#
"""
lines = body.split("\n")
if not lines[0].startswith("### PR Feedback"):
return None
# Find when the second # occurs
i = 0
for i, line in enumerate(lines):
if line.startswith("#") and i > 0:
break
return "\n".join(
[
f"### PR Feedback: {emoji}",
*lines[i:],
]
)
# Update PR description to remove buttons
try:
_, g = get_github_client(request.installation.id)
repo = g.get_repo(request.repository.full_name)
pr = repo.get_pull(request.pull_request.number)
new_body = remove_buttons_from_description(
request.pull_request.body
)
if new_body is not None:
pr.edit(body=new_body)
except SystemExit:
raise
except Exception as e:
logger.exception(f"Failed to edit PR description: {e}")
case "pull_request", "closed":
pr_request = PRRequest(**request_dict)
(
organization,
repo_name,
) = pr_request.repository.full_name.split("/")
commit_author = pr_request.pull_request.user.login
merged_by = (
pr_request.pull_request.merged_by.login
if pr_request.pull_request.merged_by
else None
)
if CURRENT_USERNAME == commit_author and merged_by is not None:
event_name = "merged_sweep_pr"
if pr_request.pull_request.title.startswith("[config]"):
event_name = "config_pr_merged"
elif pr_request.pull_request.title.startswith("[Sweep Rules]"):
event_name = "sweep_rules_pr_merged"
edited_by_developers = False
_token, g = get_github_client(pr_request.installation.id)
pr = g.get_repo(pr_request.repository.full_name).get_pull(
pr_request.number
)
total_lines_in_commit = 0
total_lines_edited_by_developer = 0
edited_by_developers = False
for commit in pr.get_commits():
lines_modified = commit.stats.additions + commit.stats.deletions
total_lines_in_commit += lines_modified
if commit.author.login != CURRENT_USERNAME:
total_lines_edited_by_developer += lines_modified
# this was edited by a developer if at least 25% of the lines were edited by a developer
edited_by_developers = total_lines_in_commit > 0 and (total_lines_edited_by_developer / total_lines_in_commit) >= 0.25
posthog.capture(
merged_by,
event_name,
properties={
"repo_name": repo_name,
"organization": organization,
"repo_full_name": pr_request.repository.full_name,
"username": merged_by,
"additions": pr_request.pull_request.additions,
"deletions": pr_request.pull_request.deletions,
"total_changes": pr_request.pull_request.additions
+ pr_request.pull_request.deletions,
"edited_by_developers": edited_by_developers,
"total_lines_in_commit": total_lines_in_commit,
"total_lines_edited_by_developer": total_lines_edited_by_developer,
},
)
chat_logger = ChatLogger({"username": merged_by})
case "push", None:
if event != "pull_request" or request_dict["base"]["merged"] is True:
chat_logger = ChatLogger(
{"username": request_dict["pusher"]["name"]}
)
# on merge
call_on_merge(request_dict, chat_logger)
ref = request_dict["ref"] if "ref" in request_dict else ""
if ref.startswith("refs/heads"):
_, g = get_github_client(request_dict["installation"]["id"])
repo = g.get_repo(request_dict["repository"]["full_name"])
if ref.startswith("refs/heads") and not ref.startswith(
"refs/heads/sweep"
):
if ref[len("refs/heads/") :] == SweepConfig.get_branch(repo):
update_sweep_prs_v2(
request_dict["repository"]["full_name"],
installation_id=request_dict["installation"]["id"],
)
if ref.startswith("refs/heads"):
branch_name = ref[len("refs/heads/") :]
# Check if the branch has an associated PR
org_name, repo_name = request_dict["repository"][
"full_name"
].split("/")
pulls = repo.get_pulls(
state="open",
sort="created",
head=org_name + ":" + branch_name,
)
for pr in pulls:
logger.info(
f"PR associated with branch {branch_name}: #{pr.number} - {pr.title}"
)
if pr.mergeable is False and MERGE_CONFLICT_ENABLED:
attributor = pr.user.login
if attributor.endswith("[bot]"):
attributor = pr.assignee.login if pr.assignee else attributor
if attributor.endswith("[bot]"):
return {
"success": False,
"error_message": "The PR was created by a bot, so I won't attempt to fix it.",
}
chat_logger = ChatLogger(
data={
"username": attributor,
"title": "[Sweep GHA Fix] Fix the failing GitHub Actions",
}
)
if (
chat_logger.use_faster_model()
and not IS_SELF_HOSTED
):
return {
"success": False,
"error_message": "Disabled for free users",
}
on_merge_conflict(
pr_number=pr.number,
username=pr.user.login,
repo_full_name=request_dict["repository"][
"full_name"
],
installation_id=request_dict["installation"]["id"],
tracking_id=get_hash(),
)
case "ping", None:
return {"message": "pong"}
case _:

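The `push` handler above decides what to do by string-prefix checks on `request_dict["ref"]`. A minimal sketch of that parsing as standalone helpers, assuming branch pushes arrive as `refs/heads/<branch>` (tag pushes arrive as `refs/tags/...`) and that Sweep's own branches use a `sweep/` prefix. `branch_from_ref` and `is_sweep_branch` are hypothetical names, not functions in Sweep.

```python
from typing import Optional

HEADS_PREFIX = "refs/heads/"


def branch_from_ref(ref: str) -> Optional[str]:
    """Return the branch name for a branch ref, or None for tags and other refs."""
    if not ref.startswith(HEADS_PREFIX):
        return None
    return ref[len(HEADS_PREFIX):]


def is_sweep_branch(ref: str, prefix: str = "sweep/") -> bool:
    """True if the ref points at a branch created by Sweep (assumed 'sweep/' prefix)."""
    branch = branch_from_ref(ref)
    return branch is not None and branch.startswith(prefix)


if __name__ == "__main__":
    print(branch_from_ref("refs/heads/main"))         # main
    print(branch_from_ref("refs/tags/v1.0"))          # None
    print(is_sweep_branch("refs/heads/sweep/fix-x"))  # True
```

Deriving every check from the single `HEADS_PREFIX` constant avoids drift between hand-typed prefixes such as `ref/heads/` and `refs/heads/`.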
import time
import traceback
from git import GitCommandError
from github.PullRequest import PullRequest
from loguru import logger
from sweepai.config.server import PROGRESS_BASE_URL
from sweepai.core import entities
from sweepai.core.entities import FileChangeRequest
from sweepai.core.sweep_bot import SweepBot
from sweepai.handlers.create_pr import create_pr_changes
from sweepai.handlers.on_ticket import get_branch_diff_text, sweeping_gif
from sweepai.utils.chat_logger import ChatLogger, discord_log_error
from sweepai.utils.diff import generate_diff
from sweepai.utils.event_logger import posthog
from sweepai.utils.github_utils import ClonedRepo, get_github_client
from sweepai.utils.progress import (
PaymentContext,
TicketContext,
TicketProgress,
TicketProgressStatus,
)
from sweepai.utils.prompt_constructor import HumanMessagePrompt
from sweepai.utils.str_utils import to_branch_name
from sweepai.utils.ticket_utils import center
instructions_format = """Resolve the merge conflicts in the PR by incorporating changes from both branches into the final code.
Title of PR: {title}
Here were the original changes to this file in the head branch:
Commit message: {head_commit_message}
```diff
{head_diff}
```
Here were the original changes to this file in the base branch:
Commit message: {base_commit_message}
```diff
{base_diff}
```
In the analysis_and_identification, first determine what each change does. Then determine what the final code should be. Then, use the keyword_search to find the merge conflict markers <<<<<<< and >>>>>>>. Finally, make the code changes by writing the old_code and the new_code."""
def on_merge_conflict(
pr_number: int,
username: str,
repo_full_name: str,
installation_id: int,
tracking_id: str,
):
# copied from stack_pr
token, g = get_github_client(installation_id=installation_id)
try:
repo = g.get_repo(repo_full_name)
except Exception as e:
logger.exception(f"Exception occurred while getting repo: {e}")
raise
pr: PullRequest = repo.get_pull(pr_number)
branch = pr.head.ref
status_message = center(
f"{sweeping_gif}\n\n"
+ f'Resolving merge conflicts: track the progress <a href="{PROGRESS_BASE_URL}/issues/{tracking_id}">here</a>.'
)
header = f"{status_message}\n---\n\nI'm currently resolving the merge conflicts in this PR. I will stack a new PR once I'm done."
comment = None
for current_comment in pr.get_issue_comments():
if (
current_comment.user.login == "sweep-nightly[bot]"
and "Resolving merge conflicts: track the progress" in current_comment.body
):
current_comment.edit(body=header)
comment = current_comment
break
if comment is None:
comment = pr.create_issue_comment(body=header)
def edit_comment(body):
nonlocal comment
comment.edit(header + "\n\n" + body)
metadata = {}
try:
cloned_repo = ClonedRepo(
repo_full_name=repo_full_name,
installation_id=installation_id,
branch=branch,
token=token,
)
time.time()
request = f"Sweep: Resolve merge conflicts for PR #{pr_number}: {pr.title}"
title = request
if len(title) > 50:
title = title[:50] + "..."
chat_logger = ChatLogger(
data={
"username": username,
"metadata": metadata,
"tracking_id": tracking_id,
}
)
is_paying_user = chat_logger.is_paying_user()
chat_logger.is_consumer_tier()
# this logic is partly taken from on_ticket.py, if there is an issue please refer to that file
if chat_logger:
use_faster_model = chat_logger.use_faster_model()
else:
is_paying_user = True
use_faster_model = False
ticket_progress = TicketProgress(
tracking_id=tracking_id,
username=username,
context=TicketContext(
title=title,
description="",
repo_full_name=repo_full_name,
branch_name="sweep/" + to_branch_name(request),
issue_number=pr_number,
is_public=repo.private is False,
start_time=int(time.time()),
# mostly copied from on_ticket, if issue please check that file
payment_context=PaymentContext(
use_faster_model=use_faster_model,
pro_user=is_paying_user,
daily_tickets_used=(
chat_logger.get_ticket_count(use_date=True)
if chat_logger
else 0
),
monthly_tickets_used=(
chat_logger.get_ticket_count() if chat_logger else 0
),
),
),
)
metadata = {
"tracking_id": tracking_id,
"username": username,
"function": "on_merge_conflict",
**ticket_progress.context.dict(),
}
posthog.capture(
username,
"started",
properties=metadata,
)
issue_url = pr.html_url
edit_comment("Configuring branch...")
new_pull_request = entities.PullRequest(
title=title,
branch_name="sweep/" + branch + "-merge-conflict",
content="",
)
# Making sure name is unique
for i in range(30):
try:
repo.get_branch(new_pull_request.branch_name + "_" + str(i))
except Exception:
new_pull_request.branch_name += "_" + str(i)
break
# Create a new branch at the PR head and merge origin/<pr.base.ref> into it in the local clone
git_repo = cloned_repo.git_repo
old_head_branch = git_repo.branches[branch]
head_branch = git_repo.create_head(
new_pull_request.branch_name,
commit=old_head_branch.commit,
)
head_branch.checkout()
try:
git_repo.config_writer().set_value(
"user", "name", "sweep-nightly[bot]"
).release()
git_repo.config_writer().set_value(
"user", "email", "[email protected]"
).release()
git_repo.git.merge("origin/" + pr.base.ref)
except GitCommandError:
# Assume there are merge conflicts
pass
git_repo.git.add(update=True)
# -m and message are needed otherwise exception is thrown
git_repo.git.commit("-m", "Start of Merge Conflict Resolution")
origin = git_repo.remotes.origin
new_url = f"https://x-access-token:{token}@github.com/{repo_full_name}.git"
origin.set_url(new_url)
git_repo.git.push("--set-upstream", origin, new_pull_request.branch_name)
last_commit = git_repo.head.commit
all_files = [item.a_path for item in last_commit.diff("HEAD~1")]
conflict_files = []
for file in all_files:
try:
contents = open(cloned_repo.repo_dir + "/" + file).read()
if "\n<<<<<<<" in contents and "\n>>>>>>>" in contents:
conflict_files.append(file)
except UnicodeDecodeError:
pass
snippets = []
for conflict_file in conflict_files:
contents = open(cloned_repo.repo_dir + "/" + conflict_file).read()
snippet = entities.Snippet(
file_path=conflict_file,
start=0,
end=len(contents.splitlines()),
content=contents,
)
snippets.append(snippet)
tree = ""
ticket_progress.status = TicketProgressStatus.PLANNING
ticket_progress.save()
human_message = HumanMessagePrompt(
repo_name=repo_full_name,
issue_url=issue_url,
username=username,
repo_description=(repo.description or "").strip(),
title=request,
summary=request,
snippets=snippets,
tree=tree,
)
sweep_bot = SweepBot.from_system_message_content(
human_message=human_message,
repo=repo,
ticket_progress=ticket_progress,
chat_logger=chat_logger,
cloned_repo=cloned_repo,
branch=new_pull_request.branch_name,
)
# can select more precise snippets
file_change_requests = []
base_commits = pr.base.repo.get_commits().get_page(0)
head_commits = list(pr.get_commits())
for conflict_file in conflict_files:
old_code = repo.get_contents(
conflict_file, ref=head_commits[0].parents[0].sha
).decoded_content.decode()
base_code = repo.get_contents(
conflict_file, ref=pr.base.ref
).decoded_content.decode()
head_code = repo.get_contents(
conflict_file, ref=pr.head.ref
).decoded_content.decode()
base_diff = generate_diff(old_code=old_code, new_code=base_code)
head_diff = generate_diff(old_code=old_code, new_code=head_code)
base_commit_message = ""
for commit in base_commits[::-1]:
if any(
commit_file.filename == conflict_file
for commit_file in commit.files
):
base_commit_message = commit.raw_data["commit"]["message"]
break
head_commit_message = ""
for commit in head_commits[::-1]:
if any(
commit_file.filename == conflict_file
for commit_file in commit.files
):
head_commit_message = commit.raw_data["commit"]["message"]
break
file_change_requests.append(
FileChangeRequest(
filename=conflict_file,
instructions=instructions_format.format(
title=pr.title,
base_commit_message=base_commit_message,
base_diff=base_diff,
head_commit_message=head_commit_message,
head_diff=head_diff,
),
change_type="modify",
)
)
ticket_progress.status = TicketProgressStatus.CODING
ticket_progress.save()
edit_comment("Resolving merge conflicts...")
generator = create_pr_changes(
file_change_requests,
new_pull_request,
sweep_bot,
username,
installation_id,
pr_number,
chat_logger=chat_logger,
base_branch=new_pull_request.branch_name,
)
for item in generator:
if isinstance(item, dict):
break
(
file_change_request,
changed_file,
sandbox_response,
commit,
file_change_requests,
) = item
logger.info(f"Status succeeded: {file_change_request.status == 'succeeded'}")
ticket_progress.status = TicketProgressStatus.COMPLETE
ticket_progress.save()
edit_comment("Done creating pull request.")
get_branch_diff_text(repo, new_pull_request.branch_name)
new_description = f"This PR resolves the merge conflicts in #{pr_number}. This branch can be directly merged into {pr.base.ref}.\n\nFixes #{pr_number}."
# Create pull request
new_pull_request.content = new_description
github_pull_request = repo.create_pull(
title=request,
body=new_description,
head=new_pull_request.branch_name,
base=pr.base.ref,
)
ticket_progress.context.pr_id = github_pull_request.number
ticket_progress.context.done_time = time.time()
ticket_progress.save()
edit_comment(f"✨ **Created Pull Request:** {github_pull_request.html_url}")
posthog.capture(
username,
"success",
properties=metadata,
)
return {"success": True}
except Exception as e:
logger.exception(f"Exception occurred: {e}")
edit_comment(
f"> [!CAUTION]\n> An error has occurred: {str(e)} (tracking ID: {tracking_id})"
)
discord_log_error(
"Error occurred in on_merge_conflict.py\n\n"
+ traceback.format_exc()
+ "\n\n"
+ str(e)
+ "\n\n"
+ f"tracking ID: {tracking_id}"
)
posthog.capture(
username,
"failed",
properties=metadata,
)
return {"success": False}
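The issue asks for a rebase option instead of the merge commit produced by `git_repo.git.merge("origin/" + pr.base.ref)`. A minimal sketch, under stated assumptions, of how the update step could branch on a strategy setting. `update_strategy` is a hypothetical setting (Sweep does not expose one in this code), and `plan_branch_update` only builds the `git` argument lists rather than running them. The essential difference is that a rebase rewrites the branch's commits, so the push has to be forced; `--force-with-lease` refuses to overwrite the remote branch if it moved since it was last fetched.

```python
from typing import Literal

Strategy = Literal["merge", "rebase"]


def plan_branch_update(branch: str, base: str, strategy: Strategy = "merge") -> list[list[str]]:
    """Return the git commands that bring `branch` up to date with origin/`base`.

    Hypothetical helper: it only builds argument lists; a caller would run each
    one (for example with subprocess.run(cmd, check=True)) inside the clone.
    """
    upstream = f"origin/{base}"
    commands = [
        ["git", "fetch", "origin", base],
        ["git", "checkout", branch],
    ]
    if strategy == "merge":
        # Adds a merge commit on top of the branch; existing commits are kept.
        commands.append(["git", "merge", upstream])
        commands.append(["git", "push", "origin", branch])
    elif strategy == "rebase":
        # Replays the branch's commits onto the new base; their SHAs change,
        # so a plain push would be rejected as non-fast-forward.
        commands.append(["git", "rebase", upstream])
        commands.append(["git", "push", "--force-with-lease", "origin", branch])
    else:
        raise ValueError(f"unknown strategy: {strategy!r}")
    return commands
```

One design consequence: a merge surfaces all conflicts in a single commit, which is what the current conflict-resolution flow expects, whereas `git rebase` stops at the first conflicting commit. A rebase-based flow would therefore need either a per-commit resolution loop or a fallback such as `git rebase --abort` followed by the merge path.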
if __name__ == "__main__":
on_merge_conflict(
pr_number=68,
username="MartinYe1234",
repo_full_name="MartinYe1234/Chess-Game",
installation_id=45945746,
tracking_id="ADD-BOB-2",
)

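The conflict check in `on_merge_conflict` looks for `"\n<<<<<<<"` and `"\n>>>>>>>"`, which require a preceding newline and so miss a conflict that starts on the first line of a file. A minimal sketch of a line-based scan that avoids that edge case and also reports where each hunk sits, which could be used to build tighter snippets than the whole file. `find_conflict_hunks` is a hypothetical helper. It does not validate the `=======` separator or treat diff3-style `|||||||` sections specially.

```python
def find_conflict_hunks(contents: str) -> list[tuple[int, int]]:
    """Return 1-based (start_line, end_line) pairs for each conflict hunk.

    A hunk starts at a line beginning with '<<<<<<<' and ends at the next line
    beginning with '>>>>>>>'. An opening marker without a closing one is ignored.
    """
    hunks: list[tuple[int, int]] = []
    start = None
    for lineno, line in enumerate(contents.splitlines(), start=1):
        if line.startswith("<<<<<<<"):
            start = lineno
        elif line.startswith(">>>>>>>") and start is not None:
            hunks.append((start, lineno))
            start = None
    return hunks


def has_conflicts(contents: str) -> bool:
    """True if the file contains at least one complete conflict hunk."""
    return bool(find_conflict_hunks(contents))
```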
"""
This file contains the on_merge handler, which runs when commits are pushed to the repository's configured Sweep branch (for example, when a pull request is merged into it).
on_merge is called by sweepai/api.py
"""
import time
from sweepai.config.client import SweepConfig, get_blocked_dirs, get_rules
from sweepai.core.post_merge import PostMerge
from sweepai.handlers.pr_utils import make_pr
from loguru import logger
from sweepai.utils.chat_logger import ChatLogger
from sweepai.utils.event_logger import posthog
from sweepai.utils.github_utils import get_github_client
# change threshold for number of lines changed
CHANGE_BOUNDS = (10, 1500)
# dictionary to map from github repo to the last time a rule was activated
merge_rule_debounce = {}
# debounce time in seconds
DEBOUNCE_TIME = 120
diff_section_prompt = """
<file_diff file="{diff_file_path}">
{diffs}
</file_diff>"""
def comparison_to_diff(comparison, blocked_dirs):
pr_diffs = []
for file in comparison.files:
diff = file.patch
if (
file.status == "added"
or file.status == "modified"
or file.status == "removed"
):
if any(file.filename.startswith(dir) for dir in blocked_dirs):
continue
pr_diffs.append((file.filename, diff))
else:
logger.info(
f"File status {file.status} not recognized"
) # TODO(sweep): We don't handle renamed files
formatted_diffs = []
for file_name, file_patch in pr_diffs:
format_diff = diff_section_prompt.format(
diff_file_path=file_name, diffs=file_patch
)
formatted_diffs.append(format_diff)
return "\n".join(formatted_diffs)
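`comparison_to_diff` skips a file when its path starts with any entry in `blocked_dirs`. A plain string prefix also matches sibling directories that share a prefix: a blocked `docs` would also hide `docs_old/readme.md`, and an empty entry would hide every file. A minimal sketch of a path-component-aware check, assuming `blocked_dirs` holds repo-relative POSIX paths with or without a trailing slash. `is_blocked` is a hypothetical name.

```python
from pathlib import PurePosixPath


def is_blocked(filename: str, blocked_dirs: list[str]) -> bool:
    """True if `filename` lies inside one of `blocked_dirs`, compared by path components."""
    path_parts = PurePosixPath(filename).parts
    for blocked in blocked_dirs:
        blocked_parts = PurePosixPath(blocked.strip("/")).parts
        # Empty entries have no parts and therefore block nothing.
        if blocked_parts and path_parts[: len(blocked_parts)] == blocked_parts:
            return True
    return False
```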
def on_merge(request_dict: dict, chat_logger: ChatLogger):
before_sha = request_dict["before"]
after_sha = request_dict["after"]
commit_author = request_dict["sender"]["login"]
ref = request_dict["ref"]
if not ref.startswith("refs/heads/"):
return
user_token, g = get_github_client(request_dict["installation"]["id"])
repo = g.get_repo(
request_dict["repository"]["full_name"]
) # do this after checking ref
if ref[len("refs/heads/") :] != SweepConfig.get_branch(repo):
return
commit = repo.get_commit(after_sha)
check_suites = commit.get_check_suites()
for check_suite in check_suites:
if check_suite.conclusion == "failure":
return # if any check suite failed, return
blocked_dirs = get_blocked_dirs(repo)
comparison = repo.compare(before_sha, after_sha)
commits_diff = comparison_to_diff(comparison, blocked_dirs)
# check if the current repo is in the merge_rule_debounce dictionary
# and if the difference between the current time and the time stored in the dictionary is less than DEBOUNCE_TIME seconds
if (
repo.full_name in merge_rule_debounce
and time.time() - merge_rule_debounce[repo.full_name] < DEBOUNCE_TIME
):
return
merge_rule_debounce[repo.full_name] = time.time()
if not (
commits_diff.count("\n") >= CHANGE_BOUNDS[0]
and commits_diff.count("\n") <= CHANGE_BOUNDS[1]
):
return
rules = get_rules(repo)
rules = [rule for rule in rules if len(rule) > 0]
if not rules:
return
for rule in rules:
chat_logger.data["title"] = f"Sweep Rules - {rule}"
changes_required, issue_title, issue_description = PostMerge(
chat_logger=chat_logger
).check_for_issues(rule=rule, diff=commits_diff)
if changes_required:
make_pr(
title="[Sweep Rules] " + issue_title,
repo_description=repo.description,
summary=issue_description,
repo_full_name=request_dict["repository"]["full_name"],
installation_id=request_dict["installation"]["id"],
user_token=user_token,
use_faster_model=chat_logger.use_faster_model(),
username=commit_author,
chat_logger=chat_logger,
rule=rule,
)

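`on_merge` only runs the rules check when the repository has not been checked in the last `DEBOUNCE_TIME` seconds and the diff contains between `CHANGE_BOUNDS[0]` and `CHANGE_BOUNDS[1]` newlines. A minimal sketch of those two gates as a pure function with an injectable clock, so they can be unit-tested. `should_check_rules` is a hypothetical refactoring, not Sweep code. As in `on_merge`, the timestamp is recorded before the size check, so a diff that is rejected for size still starts a new debounce window.

```python
import time
from typing import Callable, Optional

CHANGE_BOUNDS = (10, 1500)
DEBOUNCE_TIME = 120  # seconds


def should_check_rules(
    repo_full_name: str,
    diff: str,
    last_run: dict[str, float],
    now: Optional[Callable[[], float]] = None,
) -> bool:
    """Mirror of the debounce and diff-size gates in on_merge."""
    clock = now or time.time
    current = clock()
    previous = last_run.get(repo_full_name)
    if previous is not None and current - previous < DEBOUNCE_TIME:
        return False  # checked too recently
    last_run[repo_full_name] = current  # recorded before the size check, as in on_merge
    line_count = diff.count("\n")
    return CHANGE_BOUNDS[0] <= line_count <= CHANGE_BOUNDS[1]
```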
import re
import traceback
from typing import TypeVar
from sweepai.config.server import DEFAULT_GPT4_32K_MODEL
from sweepai.core.chat import ChatGPT
from sweepai.core.entities import Message, RegexMatchableBaseModel
from loguru import logger
system_prompt = """You are a brilliant and meticulous engineer assigned to review the following commit diffs and make sure the file conforms to the user's rules.
If the diffs do not conform to the rules, we should create a GitHub issue telling the user what changes should be made.
Provide your response in the following format:
<rule_analysis>
- Analysis of each file_diff and whether it breaks the rule
...
</rule_analysis>
<changes_required>
Output "True" if the rule is broken, "False" otherwise
</changes_required>
<issue_title>
Write an issue title describing what file and rule to fix.
</issue_title>
<issue_description>
GitHub issue description for what we want to solve. Give general instructions on how to solve it. Mention files to take a look at and other code pointers.
</issue_description>"""
user_message = """Review the following diffs and make sure they conform to the rules:
{diff}
The rule is: {rule}
Provide your response in the following format:
<rule_analysis>
- Analysis of code diff 1 and whether it breaks the rule
- Analysis of code diff 2 and whether it breaks the rule
...
</rule_analysis>
<changes_required>
Output "True" if the rule is broken, "False" otherwise
</changes_required>
<issue_title>
Write an issue title describing what file and rule to fix.
</issue_title>
<issue_description>
GitHub issue description for what we want to solve. Give general instructions on how to solve it. Mention files to take a look at and other code pointers.
</issue_description>"""
Self = TypeVar("Self", bound="RegexMatchableBaseModel")
class IssueTitleAndDescription(RegexMatchableBaseModel):
changes_required: bool = False
issue_title: str
issue_description: str
@classmethod
def from_string(cls: type["IssueTitleAndDescription"], string: str, **kwargs) -> "IssueTitleAndDescription":
changes_required_pattern = (
r"""<changes_required>(\n)?(?P<changes_required>.*)</changes_required>"""
)
changes_required_match = re.search(changes_required_pattern, string, re.DOTALL)
changes_required = (
changes_required_match.groupdict()["changes_required"].strip()
if changes_required_match
else None
)
if changes_required and "true" in changes_required.lower():
changes_required = True
else:
changes_required = False
issue_title_pattern = r"""<issue_title>(\n)?(?P<issue_title>.*)</issue_title>"""
issue_title_match = re.search(issue_title_pattern, string, re.DOTALL)
issue_title = (
issue_title_match.groupdict()["issue_title"].strip()
if issue_title_match
else ""
)
issue_description_pattern = (
r"""<issue_description>(\n)?(?P<issue_description>.*)</issue_description>"""
)
issue_description_match = re.search(
issue_description_pattern, string, re.DOTALL
)
issue_description = (
issue_description_match.groupdict()["issue_description"].strip()
if issue_description_match
else ""
)
return cls(
changes_required=changes_required,
issue_title=issue_title,
issue_description=issue_description,
)
class PostMerge(ChatGPT):
def check_for_issues(self, rule, diff) -> tuple[bool, str, str]:
try:
self.messages = [
Message(
role="system",
content=system_prompt.format(rule=rule),
key="system",
)
]
if self.chat_logger and not self.chat_logger.is_paying_user():
raise ValueError("User is not a paying user")
self.model = DEFAULT_GPT4_32K_MODEL
response = self.chat(
user_message.format(
rule=rule,
diff=diff,
)
)
issue_title_and_description = IssueTitleAndDescription.from_string(response)
return (
issue_title_and_description.changes_required,
issue_title_and_description.issue_title,
issue_title_and_description.issue_description,
)
except SystemExit:
raise
except Exception:
logger.error(f"An error occurred: {traceback.format_exc()}")
return False, "", ""
if __name__ == "__main__":
changes_required_response = """<rule_analysis>
- Analysis of code diff 1 and whether it breaks the rule
The code diff 1 does not break the rule. There are no docstrings or comments that need to be updated.
- Analysis of code diff 2 and whether it breaks the rule
The code diff 2 breaks the rule. There is a commented out code block that should be removed.
</rule_analysis>
<changes_required>
True if the rule is broken, False otherwise
True
</changes_required>
<issue_title>
Outdated Commented Code Block in plan-list.blade.php
</issue_title>
<issue_description>
There is an outdated commented out code block in the file `resources/views/livewire/plan-list.blade.php` that should be removed. The code block starts at line 104 and ends at line 110. Please remove this code block as it is no longer needed.
Please refer to the file `resources/views/livewire/plan-list.blade.php` and remove the commented out code block starting at line 104 and ending at line 110.
</issue_description>"""

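The patterns in `IssueTitleAndDescription.from_string` use a greedy `.*` with `re.DOTALL`, so if a tag appears twice in a response (for example, when the model repeats the format template), the match runs from the first opening tag to the last closing tag. Separately, the `"true" in changes_required.lower()` test also fires on an echoed instruction line such as "True if the rule is broken, False otherwise", as in the `changes_required_response` fixture. A minimal sketch of a non-greedy extractor plus a stricter boolean parse that reads only the last non-empty line. `extract_tag` and `parse_changes_required` are hypothetical helpers.

```python
import re
from typing import Optional


def extract_tag(text: str, tag: str) -> Optional[str]:
    """Return the stripped contents of the first <tag>...</tag> block, or None."""
    pattern = rf"<{re.escape(tag)}>(.*?)</{re.escape(tag)}>"
    match = re.search(pattern, text, re.DOTALL)
    return match.group(1).strip() if match else None


def parse_changes_required(text: str) -> bool:
    """True only if the last non-empty line inside <changes_required> is 'true' (any case)."""
    value = extract_tag(text, "changes_required") or ""
    lines = [line.strip() for line in value.splitlines() if line.strip()]
    return bool(lines) and lines[-1].lower() == "true"
```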
import base64
import os
from dotenv import load_dotenv
from loguru import logger
logger.print = logger.info
load_dotenv(dotenv_path=".env", override=True, verbose=True)
os.environ["GITHUB_APP_PEM"] = os.environ.get("GITHUB_APP_PEM") or base64.b64decode(
os.environ.get("GITHUB_APP_PEM_BASE64", "")
).decode("utf-8")
if os.environ["GITHUB_APP_PEM"]:
os.environ["GITHUB_APP_ID"] = (
(os.environ.get("GITHUB_APP_ID") or os.environ.get("APP_ID"))
.replace("\\n", "\n")
.strip('"')
)
os.environ["TRANSFORMERS_CACHE"] = os.environ.get(
"TRANSFORMERS_CACHE", "/tmp/cache/model"
) # vector_db.py
os.environ["TIKTOKEN_CACHE_DIR"] = os.environ.get(
"TIKTOKEN_CACHE_DIR", "/tmp/cache/tiktoken"
) # utils.py
SENTENCE_TRANSFORMERS_MODEL = os.environ.get(
"SENTENCE_TRANSFORMERS_MODEL",
"sentence-transformers/all-MiniLM-L6-v2", # "all-mpnet-base-v2"
)
TEST_BOT_NAME = "sweep-nightly[bot]"
ENV = os.environ.get("ENV", "dev")
# ENV = os.environ.get("MODAL_ENVIRONMENT", "dev")
# ENV = PREFIX
# ENVIRONMENT = PREFIX
DB_MODAL_INST_NAME = "db"
DOCS_MODAL_INST_NAME = "docs"
API_MODAL_INST_NAME = "api"
UTILS_MODAL_INST_NAME = "utils"
BOT_TOKEN_NAME = "bot-token"
# goes under Modal 'discord' secret name (optional, can leave env var blank)
DISCORD_WEBHOOK_URL = os.environ.get("DISCORD_WEBHOOK_URL")
DISCORD_MEDIUM_PRIORITY_URL = os.environ.get("DISCORD_MEDIUM_PRIORITY_URL")
DISCORD_LOW_PRIORITY_URL = os.environ.get("DISCORD_LOW_PRIORITY_URL")
DISCORD_FEEDBACK_WEBHOOK_URL = os.environ.get("DISCORD_FEEDBACK_WEBHOOK_URL")
SWEEP_HEALTH_URL = os.environ.get("SWEEP_HEALTH_URL")
DISCORD_STATUS_WEBHOOK_URL = os.environ.get("DISCORD_STATUS_WEBHOOK_URL")
# goes under Modal 'github' secret name
GITHUB_APP_ID = os.environ.get("GITHUB_APP_ID", os.environ.get("APP_ID"))
# deprecated: old logic transfer so upstream can use this
if GITHUB_APP_ID is None:
if ENV == "prod":
GITHUB_APP_ID = "307814"
elif ENV == "dev":
GITHUB_APP_ID = "324098"
elif ENV == "staging":
GITHUB_APP_ID = "327588"
GITHUB_BOT_USERNAME = os.environ.get("GITHUB_BOT_USERNAME")
# deprecated: left to support old logic
if not GITHUB_BOT_USERNAME:
if ENV == "prod":
GITHUB_BOT_USERNAME = "sweep-ai[bot]"
elif ENV == "dev":
GITHUB_BOT_USERNAME = "sweep-nightly[bot]"
elif ENV == "staging":
GITHUB_BOT_USERNAME = "sweep-canary[bot]"
elif not GITHUB_BOT_USERNAME.endswith("[bot]"):
GITHUB_BOT_USERNAME = GITHUB_BOT_USERNAME + "[bot]"
GITHUB_LABEL_NAME = os.environ.get("GITHUB_LABEL_NAME", "sweep")
GITHUB_LABEL_COLOR = os.environ.get("GITHUB_LABEL_COLOR", "9400D3")
GITHUB_LABEL_DESCRIPTION = os.environ.get(
"GITHUB_LABEL_DESCRIPTION", "Sweep your software chores"
)
GITHUB_APP_PEM = os.environ.get("GITHUB_APP_PEM")
GITHUB_APP_PEM = GITHUB_APP_PEM or os.environ.get("PRIVATE_KEY")
if GITHUB_APP_PEM is not None:
GITHUB_APP_PEM = GITHUB_APP_PEM.strip(' \n"') # Remove whitespace and quotes
GITHUB_APP_PEM = GITHUB_APP_PEM.replace("\\n", "\n")
GITHUB_CONFIG_BRANCH = os.environ.get("GITHUB_CONFIG_BRANCH", "sweep/add-sweep-config")
GITHUB_DEFAULT_CONFIG = os.environ.get(
"GITHUB_DEFAULT_CONFIG",
"""# Sweep AI turns bugs & feature requests into code changes (https://sweep.dev)
# For details on our config file, check out our docs at https://docs.sweep.dev/usage/config
# This setting contains a list of rules that Sweep will check for. If any of these rules are broken in a new commit, Sweep will create a pull request to fix the broken rule.
rules:
{additional_rules}
# This is the branch that Sweep will develop from and make pull requests to. Most people use 'main' or 'master' but some users also use 'dev' or 'staging'.
branch: 'main'
# By default Sweep will read the logs and outputs from your existing Github Actions. To disable this, set this to false.
gha_enabled: True
# This is the description of your project. It will be used by sweep when creating PRs. You can tell Sweep what's unique about your project, what frameworks you use, or anything else you want.
#
# Example:
#
# description: sweepai/sweep is a python project. The main api endpoints are in sweepai/api.py. Write code that adheres to PEP8.
description: ''
# This sets whether to create pull requests as drafts. If this is set to True, then all pull requests will be created as drafts and GitHub Actions will not be triggered.
draft: False
# This is a list of directories that Sweep will not be able to edit.
blocked_dirs: []
""",
)
MONGODB_URI = os.environ.get("MONGODB_URI", None)
IS_SELF_HOSTED = os.environ.get("IS_SELF_HOSTED", "true").lower() == "true"
REDIS_URL = os.environ.get("REDIS_URL")
if not REDIS_URL:
REDIS_URL = os.environ.get("redis_url", "redis://0.0.0.0:6379/0")
ORG_ID = os.environ.get("ORG_ID", None)
POSTHOG_API_KEY = os.environ.get(
"POSTHOG_API_KEY", "phc_CnzwIB0W548wN4wEGeRuxXqidOlEUH2AcyV2sKTku8n"
)
E2B_API_KEY = os.environ.get("E2B_API_KEY")
SUPPORT_COUNTRY = os.environ.get("GDRP_LIST", "").split(",")
WHITELISTED_REPOS = os.environ.get("WHITELISTED_REPOS", "").split(",")
BLACKLISTED_USERS = os.environ.get("BLACKLISTED_USERS", "").split(",")
os.environ["TOKENIZERS_PARALLELISM"] = "false"
ACTIVELOOP_TOKEN = os.environ.get("ACTIVELOOP_TOKEN", None)
VECTOR_EMBEDDING_SOURCE = os.environ.get(
"VECTOR_EMBEDDING_SOURCE", "openai"
)  # Alternatives are "huggingface" or "replicate"; set the corresponding env vars below
BASERUN_API_KEY = os.environ.get("BASERUN_API_KEY", None)
# Huggingface settings, only checked if VECTOR_EMBEDDING_SOURCE == "huggingface"
HUGGINGFACE_URL = os.environ.get("HUGGINGFACE_URL", None)
HUGGINGFACE_TOKEN = os.environ.get("HUGGINGFACE_TOKEN", None)
# Replicate settings, only checked if VECTOR_EMBEDDING_SOURCE == "replicate"
REPLICATE_API_KEY = os.environ.get("REPLICATE_API_KEY", None)
REPLICATE_URL = os.environ.get("REPLICATE_URL", None)
REPLICATE_DEPLOYMENT_URL = os.environ.get("REPLICATE_DEPLOYMENT_URL", None)
# Default OpenAI
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", None)
OPENAI_API_TYPE = os.environ.get("OPENAI_API_TYPE", "anthropic")
assert OPENAI_API_TYPE in ["anthropic", "azure", "openai"], "Invalid OPENAI_API_TYPE"
OPENAI_EMBEDDINGS_API_TYPE = os.environ.get("OPENAI_EMBEDDINGS_API_TYPE", "openai")
AZURE_API_KEY = os.environ.get("AZURE_API_KEY", None)
OPENAI_API_BASE = os.environ.get("OPENAI_API_BASE", None)
OPENAI_API_VERSION = os.environ.get("OPENAI_API_VERSION", None)
AZURE_OPENAI_DEPLOYMENT = os.environ.get("AZURE_OPENAI_DEPLOYMENT", None)
OPENAI_EMBEDDINGS_AZURE_ENDPOINT = os.environ.get(
"OPENAI_EMBEDDINGS_AZURE_ENDPOINT", None
)
OPENAI_EMBEDDINGS_AZURE_API_KEY = os.environ.get(
"OPENAI_EMBEDDINGS_AZURE_API_KEY", None
)
OPENAI_EMBEDDINGS_AZURE_DEPLOYMENT = os.environ.get(
"OPENAI_EMBEDDINGS_AZURE_DEPLOYMENT", None
)
OPENAI_EMBEDDINGS_AZURE_API_VERSION = os.environ.get(
"OPENAI_EMBEDDINGS_AZURE_API_VERSION", None
)
OPENAI_API_ENGINE_GPT35 = os.environ.get("OPENAI_API_ENGINE_GPT35", None)
OPENAI_API_ENGINE_GPT4 = os.environ.get("OPENAI_API_ENGINE_GPT4", None)
OPENAI_API_ENGINE_GPT4_32K = os.environ.get("OPENAI_API_ENGINE_GPT4_32K", None)
MULTI_REGION_CONFIG = os.environ.get("MULTI_REGION_CONFIG", None)
if isinstance(MULTI_REGION_CONFIG, str):
MULTI_REGION_CONFIG = MULTI_REGION_CONFIG.strip("'").replace("\\n", "\n")
MULTI_REGION_CONFIG = [item.split(",") for item in MULTI_REGION_CONFIG.split("\n")]
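`MULTI_REGION_CONFIG` is read as a single string: surrounding single quotes are stripped, literal `\n` escape sequences become real newlines, and each line is split on commas. A standalone sketch of the same parsing, with one labeled deviation: it also drops blank lines and strips whitespace around each field. The meaning of each column is not stated in this file, so the test values are placeholders.

```python
def parse_multi_region_config(raw: str) -> list[list[str]]:
    """Parse "'a,b\\nc,d'"-style env values into [["a", "b"], ["c", "d"]].

    Unlike the inline parsing in the config module, this sketch skips blank
    lines and strips whitespace around each field.
    """
    text = raw.strip("'").replace("\\n", "\n")
    return [
        [field.strip() for field in line.split(",")]
        for line in text.split("\n")
        if line.strip()
    ]
```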
WHITELISTED_USERS = os.environ.get("WHITELISTED_USERS", None)
if WHITELISTED_USERS:
WHITELISTED_USERS = WHITELISTED_USERS.split(",")
WHITELISTED_USERS.append(GITHUB_BOT_USERNAME)
DEFAULT_GPT4_32K_MODEL = os.environ.get("DEFAULT_GPT4_32K_MODEL", "gpt-4-0125-preview")
DEFAULT_GPT35_MODEL = os.environ.get("DEFAULT_GPT35_MODEL", "gpt-3.5-turbo-1106")
RESEND_API_KEY = os.environ.get("RESEND_API_KEY", None)
LOKI_URL = None
DEBUG = os.environ.get("DEBUG", "false").lower() == "true"
ENV = "prod" if GITHUB_BOT_USERNAME != TEST_BOT_NAME else "dev"
PROGRESS_BASE_URL = os.environ.get(
"PROGRESS_BASE_URL", "https://progress.sweep.dev"
).rstrip("/")
DISABLED_REPOS = os.environ.get("DISABLED_REPOS", "").split(",")
GHA_AUTOFIX_ENABLED: bool = os.environ.get("GHA_AUTOFIX_ENABLED", "false").lower() == "true"
MERGE_CONFLICT_ENABLED: bool = os.environ.get("MERGE_CONFLICT_ENABLED", "false").lower() == "true"
INSTALLATION_ID = os.environ.get("INSTALLATION_ID", None)
AWS_ACCESS_KEY = os.environ.get("AWS_ACCESS_KEY")
AWS_SECRET_KEY = os.environ.get("AWS_SECRET_KEY")
AWS_REGION = os.environ.get("AWS_REGION")
ANTHROPIC_AVAILABLE = AWS_ACCESS_KEY and AWS_SECRET_KEY and AWS_REGION
USE_ASSISTANT = os.environ.get("USE_ASSISTANT", "true").lower() == "true"
ANTHROPIC_API_KEY = os.environ.get("ANTHROPIC_API_KEY", None)
VOYAGE_API_KEY = os.environ.get("VOYAGE_API_KEY", None)
VOYAGE_API_AWS_ACCESS_KEY = os.environ.get("VOYAGE_API_AWS_ACCESS_KEY_ID")
VOYAGE_API_AWS_SECRET_KEY = os.environ.get("VOYAGE_API_AWS_SECRET_KEY")
VOYAGE_API_AWS_REGION = os.environ.get("VOYAGE_API_AWS_REGION")
VOYAGE_API_AWS_ENDPOINT_NAME = os.environ.get("VOYAGE_API_AWS_ENDPOINT_NAME", "voyage-code-2")
VOYAGE_API_USE_AWS = VOYAGE_API_AWS_ACCESS_KEY and VOYAGE_API_AWS_SECRET_KEY and VOYAGE_API_AWS_REGION
PAREA_API_KEY = os.environ.get("PAREA_API_KEY", None)
# TODO: we need to make this dynamic + backoff
BATCH_SIZE = int(

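Reading a boolean flag with `os.environ.get(name, False)` returns the raw string whenever the variable is set, so a value such as `"false"` is still truthy. Other flags in the config module compare against `"true"` instead. A minimal sketch of a helper that centralizes one parsing rule, assuming `1`, `true`, `yes` and `on` (case-insensitive) should count as enabled. `env_flag` is a hypothetical name, and the optional `environ` mapping exists only so the function can be tested without touching the real environment.

```python
import os
from typing import Mapping, Optional

TRUTHY = {"1", "true", "yes", "on"}


def env_flag(name: str, default: bool = False, environ: Optional[Mapping[str, str]] = None) -> bool:
    """Read a boolean feature flag; unset means `default`, otherwise only TRUTHY values are True."""
    env = os.environ if environ is None else environ
    value = env.get(name)
    if value is None:
        return default
    return value.strip().lower() in TRUTHY
```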
```python
import datetime
import difflib
import hashlib
import json
import os
import re
import shutil
import subprocess
import tempfile
import time
import traceback
from dataclasses import dataclass
from functools import cached_property
from typing import Any

import git
import requests
from github import Github, PullRequest, Repository, InputGitTreeElement
from jwt import encode
from loguru import logger

from sweepai.config.client import SweepConfig
from sweepai.config.server import GITHUB_APP_ID, GITHUB_APP_PEM, GITHUB_BOT_USERNAME
from sweepai.utils.tree_utils import DirectoryTree, remove_all_not_included

MAX_FILE_COUNT = 50


def make_valid_string(string: str):
    pattern = r"[^\w./-]+"
    return re.sub(pattern, "_", string)


def get_jwt():
    signing_key = GITHUB_APP_PEM
    app_id = GITHUB_APP_ID
    payload = {"iat": int(time.time()), "exp": int(time.time()) + 600, "iss": app_id}
    return encode(payload, signing_key, algorithm="RS256")


def get_token(installation_id: int):
    if int(installation_id) < 0:
        return os.environ["GITHUB_PAT"]
    for timeout in [5.5, 5.5, 10.5]:
        try:
            jwt = get_jwt()
            headers = {
                "Accept": "application/vnd.github+json",
                "Authorization": "Bearer " + jwt,
                "X-GitHub-Api-Version": "2022-11-28",
            }
            response = requests.post(
                f"https://api.github.com/app/installations/{int(installation_id)}/access_tokens",
                headers=headers,
            )
            obj = response.json()
            if "token" not in obj:
                logger.error(obj)
                raise Exception("Could not get token")
            return obj["token"]
        except SystemExit:
            raise SystemExit
        except Exception:
            time.sleep(timeout)
    raise Exception(
        "Could not get token, please double check your PRIVATE_KEY and GITHUB_APP_ID in the .env file. Make sure to restart uvicorn after."
    )


def get_app():
    jwt = get_jwt()
    headers = {
        "Accept": "application/vnd.github+json",
        "Authorization": "Bearer " + jwt,
        "X-GitHub-Api-Version": "2022-11-28",
    }
    response = requests.get("https://api.github.com/app", headers=headers)
    return response.json()


def get_github_client(installation_id: int):
    if not installation_id:
        return os.environ["GITHUB_PAT"], Github(os.environ["GITHUB_PAT"])
    token: str = get_token(installation_id)
    return token, Github(token)


# fetch installation object
def get_installation(username: str):
    jwt = get_jwt()
    try:
        # Try user
        response = requests.get(
            f"https://api.github.com/users/{username}/installation",
            headers={
                "Accept": "application/vnd.github+json",
                "Authorization": "Bearer " + jwt,
                "X-GitHub-Api-Version": "2022-11-28",
            },
        )
        obj = response.json()
        return obj
    except Exception:
        # Try org
        response = requests.get(
            f"https://api.github.com/orgs/{username}/installation",
            headers={
                "Accept": "application/vnd.github+json",
                "Authorization": "Bearer " + jwt,
                "X-GitHub-Api-Version": "2022-11-28",
            },
        )
        try:
            obj = response.json()
            return obj["id"]
        except Exception as e:
            logger.error(e)
            logger.error(response.text)
            raise Exception("Could not get installation, probably not installed")


def get_installation_id(username: str) -> str:
    jwt = get_jwt()
    try:
        # Try user
        response = requests.get(
            f"https://api.github.com/users/{username}/installation",
            headers={
                "Accept": "application/vnd.github+json",
                "Authorization": "Bearer " + jwt,
                "X-GitHub-Api-Version": "2022-11-28",
            },
        )
        obj = response.json()
        return obj["id"]
    except Exception:
        # Try org
        response = requests.get(
            f"https://api.github.com/orgs/{username}/installation",
            headers={
                "Accept": "application/vnd.github+json",
                "Authorization": "Bearer " + jwt,
                "X-GitHub-Api-Version": "2022-11-28",
            },
        )
        try:
            obj = response.json()
            return obj["id"]
        except Exception as e:
            logger.error(e)
            logger.error(response.text)
            raise Exception("Could not get installation id, probably not installed")


# commits multiple files in a single commit, returns the commit object
def commit_multi_file_changes(repo: Repository, file_changes: dict[str, str], commit_message: str, branch: str):
    blobs_to_commit = []
    # convert to blob
    for path, content in file_changes.items():
        blob = repo.create_git_blob(content, "utf-8")
        blobs_to_commit.append(InputGitTreeElement(path=path, mode="100644", type="blob", sha=blob.sha))
    latest_commit = repo.get_branch(branch).commit
    base_tree = latest_commit.commit.tree
    # create new git tree
    new_tree = repo.create_git_tree(blobs_to_commit, base_tree=base_tree)
    # commit the changes
    parent = repo.get_git_commit(latest_commit.sha)
    commit = repo.create_git_commit(
        commit_message,
        new_tree,
        [parent],
    )
    # update ref of branch
    ref = f"heads/{branch}"
    repo.get_git_ref(ref).edit(sha=commit.sha)
    return commit


REPO_CACHE_BASE_DIR = "/tmp/cache/repos"


@dataclass
class ClonedRepo:
    repo_full_name: str
    installation_id: str
    branch: str | None = None
    token: str | None = None
    repo: Any | None = None
    git_repo: git.Repo | None = None

    class Config:
        arbitrary_types_allowed = True

    @cached_property
    def cached_dir(self):
        self.repo = (
            Github(self.token).get_repo(self.repo_full_name)
            if not self.repo
            else self.repo
        )
        self.branch = self.branch or SweepConfig.get_branch(self.repo)
        return os.path.join(
            REPO_CACHE_BASE_DIR,
            self.repo_full_name,
            "base",
            parse_collection_name(self.branch),
        )

    @cached_property
    def zip_path(self):
        logger.info("Zipping repository...")
        shutil.make_archive(self.repo_dir, "zip", self.repo_dir)
        logger.info("Done zipping")
        return f"{self.repo_dir}.zip"

    @cached_property
    def repo_dir(self):
        self.repo = (
            Github(self.token).get_repo(self.repo_full_name)
            if not self.repo
            else self.repo
        )
        self.branch = self.branch or SweepConfig.get_branch(self.repo)
        curr_time_str = str(time.time()).encode("utf-8")
        hash_obj = hashlib.sha256(curr_time_str)
        hash_hex = hash_obj.hexdigest()
        if self.branch:
            return os.path.join(
                REPO_CACHE_BASE_DIR,
                self.repo_full_name,
                hash_hex,
                parse_collection_name(self.branch),
            )
        else:
            return os.path.join("/tmp/cache/repos", self.repo_full_name, hash_hex)

    @property
    def clone_url(self):
        return (
            f"https://x-access-token:{self.token}@github.com/{self.repo_full_name}.git"
        )

    def clone(self):
        if not os.path.exists(self.cached_dir):
            logger.info("Cloning repo...")
            if self.branch:
                repo = git.Repo.clone_from(
                    self.clone_url, self.cached_dir, branch=self.branch
                )
            else:
                repo = git.Repo.clone_from(self.clone_url, self.cached_dir)
            logger.info("Done cloning")
        else:
            try:
                repo = git.Repo(self.cached_dir)
                repo.remotes.origin.pull(
                    kill_after_timeout=60, progress=git.RemoteProgress()
                )
            except Exception:
                logger.error("Could not pull repo")
                shutil.rmtree(self.cached_dir, ignore_errors=True)
                repo = git.Repo.clone_from(self.clone_url, self.cached_dir)
            logger.info("Repo already cached, copying")
        logger.info("Copying repo...")
        shutil.copytree(
            self.cached_dir, self.repo_dir, symlinks=True, copy_function=shutil.copy
        )
        logger.info("Done copying")
        repo = git.Repo(self.repo_dir)
        return repo

    def __post_init__(self):
        subprocess.run(["git", "config", "--global", "http.postBuffer", "524288000"])
        self.token = self.token or get_token(self.installation_id)
        self.repo = (
            Github(self.token).get_repo(self.repo_full_name)
            if not self.repo
            else self.repo
        )
        self.commit_hash = self.repo.get_commits()[0].sha
        self.git_repo = self.clone()
        self.branch = self.branch or SweepConfig.get_branch(self.repo)

    def __del__(self):
        try:
            shutil.rmtree(self.repo_dir)
            os.remove(self.zip_path)
            return True
        except Exception:
            return False

    def list_directory_tree(
        self,
        included_directories=None,
        excluded_directories: list[str] = None,
        included_files=None,
    ):
        """Display the directory tree.

        Arguments:
        root_directory -- String path of the root directory to display.
        included_directories -- List of directory paths (relative to the root) to include in the tree. Default to None.
        excluded_directories -- List of directory names to exclude from the tree. Default to None.
        """
        root_directory = self.repo_dir
        sweep_config: SweepConfig = SweepConfig()

        # Default values if parameters are not provided
        if included_directories is None:
            included_directories = []  # gets all directories
        if excluded_directories is None:
            excluded_directories = sweep_config.exclude_dirs

        def list_directory_contents(
            current_directory: str,
            excluded_directories: list[str],
            indentation="",
        ):
            """Recursively list the contents of directories."""
            file_and_folder_names = os.listdir(current_directory)
            file_and_folder_names.sort()
            directory_tree_string = ""
            for name in file_and_folder_names[:MAX_FILE_COUNT]:
                relative_path = os.path.join(current_directory, name)[
                    len(root_directory) + 1 :
                ]
                if name in excluded_directories:
                    continue
                complete_path = os.path.join(current_directory, name)
                if os.path.isdir(complete_path):
                    directory_tree_string += f"{indentation}{relative_path}/\n"
                    directory_tree_string += list_directory_contents(
                        complete_path,
                        excluded_directories,
                        indentation + " ",
                    )
                else:
                    directory_tree_string += f"{indentation}{name}\n"
                    # if os.path.isfile(complete_path) and relative_path in included_files:
                    #     # Todo, use these to fetch neighbors
                    #     ctags_str, names = get_ctags_for_file(ctags, complete_path)
                    #     ctags_str = "\n".join([indentation + line for line in ctags_str.splitlines()])
                    #     if ctags_str.strip():
                    #         directory_tree_string += f"{ctags_str}\n"
            return directory_tree_string

        dir_obj = DirectoryTree()
        directory_tree = list_directory_contents(root_directory, excluded_directories)
        dir_obj.parse(directory_tree)
        if included_directories:
            dir_obj = remove_all_not_included(dir_obj, included_directories)
        return directory_tree, dir_obj

    def get_file_list(self) -> str:
        root_directory = self.repo_dir
        files = []
        sweep_config: SweepConfig = SweepConfig()

        def dfs_helper(directory):
            nonlocal files
            for item in os.listdir(directory):
                if item == ".git":
                    continue
                if item in sweep_config.exclude_dirs:  # this saves a lot of time
                    continue
                item_path = os.path.join(directory, item)
                if os.path.isfile(item_path):
                    # make sure the item_path is not in one of the banned directories
                    if not sweep_config.is_file_excluded(item_path):
                        files.append(item_path)  # Add the file to the list
                elif os.path.isdir(item_path):
                    dfs_helper(item_path)  # Recursive call to explore subdirectory

        dfs_helper(root_directory)
        files = [file[len(root_directory) + 1 :] for file in files]
        return files

    def get_file_contents(self, file_path, ref=None):
        local_path = (
            f"{self.repo_dir}{file_path}"
            if file_path.startswith("/")
            else f"{self.repo_dir}/{file_path}"
        )
        if os.path.exists(local_path):
            with open(local_path, "r", encoding="utf-8", errors="replace") as f:
                contents = f.read()
            return contents
        else:
            raise FileNotFoundError(f"{local_path} does not exist.")

    def get_num_files_from_repo(self):
        # subprocess.run(["git", "config", "--global", "http.postBuffer", "524288000"])
        self.git_repo.git.checkout(self.branch)
        file_list = self.get_file_list()
        return len(file_list)

    def get_commit_history(
        self, username: str = "", limit: int = 200, time_limited: bool = True
    ):
        commit_history = []
        try:
            if username != "":
                commit_list = list(self.git_repo.iter_commits(author=username))
            else:
                commit_list = list(self.git_repo.iter_commits())
            line_count = 0
            cut_off_date = datetime.datetime.now() - datetime.timedelta(days=7)
            for commit in commit_list:
                # must be within a week
                if time_limited and commit.authored_datetime.replace(
                    tzinfo=None
                ) <= cut_off_date.replace(tzinfo=None):
                    logger.info("Exceeded cut off date, stopping...")
                    break
                repo = get_github_client(self.installation_id)[1].get_repo(
                    self.repo_full_name
                )
                branch = SweepConfig.get_branch(repo)
                if branch not in self.git_repo.git.branch():
                    branch = f"origin/{branch}"
                diff = self.git_repo.git.diff(commit, branch, unified=1)
                lines = diff.count("\n")
                # total diff lines must not exceed 200
                if lines + line_count > limit:
                    logger.info(f"Exceeded {limit} lines of diff, stopping...")
                    break
                commit_history.append(
                    f"<commit>\nAuthor: {commit.author.name}\nMessage: {commit.message}\n{diff}\n</commit>"
                )
                line_count += lines
        except Exception:
            logger.error(f"An error occurred: {traceback.print_exc()}")
        return commit_history

    def get_similar_file_paths(self, file_path: str, limit: int = 10):
        from rapidfuzz.fuzz import ratio

        # Fuzzy search over file names
        file_name = os.path.basename(file_path)
        all_file_paths = self.get_file_list()
        # filter for matching extensions if both have extensions
        if "." in file_name:
            all_file_paths = [
                file
                for file in all_file_paths
                if "." in file and file.split(".")[-1] == file_name.split(".")[-1]
            ]
        files_with_matching_name = []
        files_without_matching_name = []
        for file_path in all_file_paths:
            if file_name in file_path:
                files_with_matching_name.append(file_path)
            else:
                files_without_matching_name.append(file_path)
        file_path_to_ratio = {file: ratio(file_name, file) for file in all_file_paths}
        files_with_matching_name = sorted(
            files_with_matching_name,
            key=lambda file_path: file_path_to_ratio[file_path],
            reverse=True,
        )
        files_without_matching_name = sorted(
            files_without_matching_name,
            key=lambda file_path: file_path_to_ratio[file_path],
            reverse=True,
        )
        # this allows 'config.py' to return 'sweepai/config/server.py', 'sweepai/config/client.py', 'sweepai/config/__init__.py' and no more
        filtered_files_without_matching_name = list(filter(lambda file_path: file_path_to_ratio[file_path] > 50, files_without_matching_name))
        all_files = files_with_matching_name + filtered_files_without_matching_name
        return all_files[:limit]


# updates a file with new_contents, returns True if successful
def update_file(root_dir: str, file_path: str, new_contents: str):
    local_path = os.path.join(root_dir, file_path)
    try:
        with open(local_path, "w") as f:
            f.write(new_contents)
        return True
    except Exception as e:
        logger.error(f"Failed to update file: {e}")
        return False


@dataclass
class MockClonedRepo(ClonedRepo):
    _repo_dir: str = ""
    git_repo: git.Repo | None = None

    def __init__(
        self,
        _repo_dir: str,
        repo_full_name: str,
        installation_id: str = "",
        branch: str | None = None,
        token: str | None = None,
        repo: Any | None = None,
        git_repo: git.Repo | None = None,
    ):
        self._repo_dir = _repo_dir
        self.repo_full_name = repo_full_name
        self.installation_id = installation_id
        self.branch = branch
        self.token = token
        self.repo = repo

    @classmethod
    def from_dir(cls, repo_dir: str, **kwargs):
        return cls(_repo_dir=repo_dir, **kwargs)

    @property
    def cached_dir(self):
        return self._repo_dir

    @property
    def repo_dir(self):
        return self._repo_dir

    @property
    def git_repo(self):
        return git.Repo(self.repo_dir)

    def clone(self):
        return git.Repo(self.repo_dir)

    def __post_init__(self):
        return self

    def __del__(self):
        return True


@dataclass
class TemporarilyCopiedClonedRepo(MockClonedRepo):
    tmp_dir: tempfile.TemporaryDirectory | None = None

    def __init__(
        self,
        _repo_dir: str,
        tmp_dir: tempfile.TemporaryDirectory,
        repo_full_name: str,
        installation_id: str = "",
        branch: str | None = None,
        token: str | None = None,
        repo: Any | None = None,
        git_repo: git.Repo | None = None,
    ):
        self._repo_dir = _repo_dir
        self.tmp_dir = tmp_dir
        self.repo_full_name = repo_full_name
        self.installation_id = installation_id
        self.branch = branch
        self.token = token
        self.repo = repo

    @classmethod
    def copy_from_cloned_repo(cls, cloned_repo: ClonedRepo, **kwargs):
        temp_dir = tempfile.TemporaryDirectory()
        new_dir = temp_dir.name + "/" + cloned_repo.repo_full_name.split("/")[1]
        print("Copying...")
        shutil.copytree(cloned_repo.repo_dir, new_dir)
        print("Done copying.")
        return cls(
            _repo_dir=new_dir,
            tmp_dir=temp_dir,
            repo_full_name=cloned_repo.repo_full_name,
            installation_id=cloned_repo.installation_id,
            branch=cloned_repo.branch,
            token=cloned_repo.token,
            repo=cloned_repo.repo,
            **kwargs,
        )

    def __del__(self):
        print(f"Dropping {self.tmp_dir.name}...")
        shutil.rmtree(self._repo_dir, ignore_errors=True)
        self.tmp_dir.cleanup()
        print("Done.")
        return True


def get_file_names_from_query(query: str) -> list[str]:
    query_file_names = re.findall(r"\b[\w\-\.\/]*\w+\.\w{1,6}\b", query)
    return [
        query_file_name
        for query_file_name in query_file_names
        if len(query_file_name) > 3
    ]


def get_hunks(a: str, b: str, context=10):
    differ = difflib.Differ()
    diff = [
        line
        for line in differ.compare(a.splitlines(), b.splitlines())
        if line[0] in ("+", "-", " ")
    ]
    show = set()
    hunks = []
    for i, line in enumerate(diff):
        if line.startswith(("+", "-")):
            show.update(range(max(0, i - context), min(len(diff), i + context + 1)))
    for i in range(len(diff)):
        if i in show:
            hunks.append(diff[i])
        elif i - 1 in show:
            hunks.append("...")
    if len(hunks) > 0 and hunks[0] == "...":
        hunks = hunks[1:]
    if len(hunks) > 0 and hunks[-1] == "...":
        hunks = hunks[:-1]
    return "\n".join(hunks)


def parse_collection_name(name: str) -> str:
    # Replace any non-alphanumeric characters with hyphens
    name = re.sub(r"[^\w-]", "--", name)
    # Ensure the name is between 3 and 63 characters and starts/ends with alphanumeric
    name = re.sub(r"^(-*\w{0,61}\w)-*$", r"\1", name[:63].ljust(3, "x"))
    return name


# set whether or not a pr is a draft, there is no way to do this using pygithub
def convert_pr_draft_field(pr: PullRequest, is_draft: bool = False):
    pr_id = pr.raw_data['node_id']
    # GraphQL mutation for marking a PR as ready for review
    mutation = """
    mutation MarkPRReady {
        markPullRequestReadyForReview(input: {pullRequestId: {pull_request_id}}) {
            pullRequest {
                id
            }
        }
    }
    """.replace("{pull_request_id}", "\""+pr_id+"\"")
    # GraphQL API URL
    url = 'https://api.github.com/graphql'
    # Headers
    headers={
        "Accept": "application/vnd.github+json",
        "X-Github-Api-Version": "2022-11-28",
        "Authorization": "Bearer " + os.environ["GITHUB_PAT"],
    }
    # Prepare the JSON payload
    json_data = {
        'query': mutation,
    }
    # Make the POST request
    response = requests.post(url, headers=headers, data=json.dumps(json_data))
    if response.status_code != 200:
        logger.error(f"Failed to convert PR to {'draft' if is_draft else 'open'}")
        return False
    return True


try:
    g = Github(os.environ.get("GITHUB_PAT"))
    CURRENT_USERNAME = g.get_user().login
except Exception:
    try:
        slug = get_app()["slug"]
        CURRENT_USERNAME = f"{slug}[bot]"
    except Exception:
        CURRENT_USERNAME = GITHUB_BOT_USERNAME


if __name__ == "__main__":
    try:
        organization_name = "sweepai"
        sweep_config = SweepConfig()
        installation_id = get_installation_id(organization_name)
        user_token, g = get_github_client(installation_id)
        cloned_repo = ClonedRepo("sweepai/sweep", installation_id, "main")
        dir_ojb = cloned_repo.list_directory_tree()
        commit_history = cloned_repo.get_commit_history()
        similar_file_paths = cloned_repo.get_similar_file_paths("config.py")
        # ensure no similar file_paths are sweep excluded
        assert(not any([file for file in similar_file_paths if sweep_config.is_file_excluded(file)]))
        print(f"similar_file_paths: {similar_file_paths}")
        str1 = "a\nline1\nline2\nline3\nline4\nline5\nline6\ntest\n"
        str2 = "a\nline1\nlineTwo\nline3\nline4\nline5\nlineSix\ntset\n"
        print(get_hunks(str1, str2, 1))
        mocked_repo = MockClonedRepo.from_dir(
            cloned_repo.repo_dir,
            repo_full_name="sweepai/sweep",
        )
        temp_repo = TemporarilyCopiedClonedRepo.copy_from_cloned_repo(mocked_repo)
        print(f"mocked repo: {mocked_repo}")
    except Exception as e:
```

```python
import re
from dataclasses import dataclass
from functools import lru_cache

from rapidfuzz import fuzz
from tqdm import tqdm

from sweepai.logn import file_cache
from loguru import logger


@lru_cache()
def score_line(str1: str, str2: str) -> float:
    if str1 == str2:
        return 100
    if str1.lstrip() == str2.lstrip():
        whitespace_ratio = abs(len(str1) - len(str2)) / (len(str1) + len(str2))
        score = 90 - whitespace_ratio * 10
        return max(score, 0)
    if str1.strip() == str2.strip():
        whitespace_ratio = abs(len(str1) - len(str2)) / (len(str1) + len(str2))
        score = 80 - whitespace_ratio * 10
        return max(score, 0)
    levenshtein_ratio = fuzz.ratio(str1, str2)
    score = 85 * (levenshtein_ratio / 100)
    return max(score, 0)


def match_without_whitespace(str1: str, str2: str) -> bool:
    return str1.strip() == str2.strip()


def line_cost(line: str) -> float:
    if line.strip() == "":
        return 50
    if line.strip().startswith("#") or line.strip().startswith("//"):
        return 50 + len(line) / (len(line) + 1) * 30
    return len(line) / (len(line) + 1) * 100


def score_multiline(query: list[str], target: list[str]) -> float:
    # TODO: add weighting on first and last lines
    q, t = 0, 0  # indices for query and target
    scores: list[tuple[float, float]] = []
    skipped_comments = 0

    def get_weight(q: int) -> float:
        # Prefers lines at beginning and end of query
        # Sequence: 1, 2/3, 1/2, 2/5...
        index = min(q, len(query) - q)
        return 100 / (index / 2 + 1)

    while q < len(query) and t < len(target):
        q_line = query[q]
        t_line = target[t]
        weight = get_weight(q)
        if match_without_whitespace(q_line, t_line):
            # Case 1: lines match
            scores.append((score_line(q_line, t_line), weight))
            q += 1
            t += 1
        elif q_line.strip().startswith("...") or q_line.strip().endswith("..."):
            # Case 3: ellipsis wildcard
            t += 1
            if q + 1 == len(query):
                scores.append((100 - (len(target) - t), weight))
                q += 1
                t = len(target)
                break
            max_score = 0
            # Radix optimization
            indices = [
                t + i
                for i, line in enumerate(target[t:])
                if match_without_whitespace(line, query[q + 1])
            ]
            if not indices:
                # logger.warning(f"Could not find whitespace match, using brute force")
                indices = range(t, len(target))
            for i in indices:
                score, weight = score_multiline(query[q + 1 :], target[i:]), (
                    100 - (i - t) / len(target) * 10
                )
                new_scores = scores + [(score, weight)]
                total_score = sum(
                    [value * weight for value, weight in new_scores]
                ) / sum([weight for _, weight in new_scores])
                max_score = max(max_score, total_score)
            return max_score
        elif (
            t_line.strip() == ""
            or t_line.strip().startswith("#")
            or t_line.strip().startswith("//")
            or t_line.strip().startswith("print")
            or t_line.strip().startswith("logger")
            or t_line.strip().startswith("console.")
        ):
            # Case 2: skipped comment
            skipped_comments += 1
            t += 1
            scores.append((90, weight))
        else:
            break
    if q < len(query):
        scores.extend(
            (100 - line_cost(line), get_weight(index))
            for index, line in enumerate(query[q:])
        )
    if t < len(target):
        scores.extend(
            (100 - line_cost(line), 100) for index, line in enumerate(target[t:])
        )
    final_score = (
        sum([value * weight for value, weight in scores])
        / sum([weight for _, weight in scores])
        if scores
        else 0
    )
    final_score *= 1 - 0.05 * skipped_comments
    return final_score


@dataclass
class Match:
    start: int
    end: int
    score: float
    indent: str = ""

    def __gt__(self, other):
        return self.score > other.score


def get_indent_type(content: str):
    two_spaces = len(re.findall(r"\n {2}[^ ]", content))
    four_spaces = len(re.findall(r"\n {4}[^ ]", content))
    return "  " if two_spaces > four_spaces else "    "


def get_max_indent(content: str, indent_type: str):
    return max(len(line) - len(line.lstrip()) for line in content.split("\n")) // len(
        indent_type
    )


@file_cache()
def find_best_match(query: str, code_file: str):
    best_match = Match(-1, -1, 0)
    code_file_lines = code_file.split("\n")
    query_lines = query.split("\n")
    if len(query_lines) > 0 and query_lines[-1].strip() == "...":
        query_lines = query_lines[:-1]
    if len(query_lines) > 0 and query_lines[0].strip() == "...":
        query_lines = query_lines[1:]
    indent = get_indent_type(code_file)
    max_indents = get_max_indent(code_file, indent)
    top_matches = []
    if len(query_lines) == 1:
        for i, line in enumerate(code_file_lines):
            score = score_line(line, query_lines[0])
            if score > best_match.score:
                best_match = Match(i, i + 1, score)
        return best_match
    truncate = min(40, len(code_file_lines) // 5)
    if truncate < 1:
        truncate = len(code_file_lines)
    indent_array = [i for i in range(0, max(min(max_indents + 1, 20), 1))]
    if max_indents > 3:
        indent_array = [3, 2, 4, 0, 1] + list(range(5, max_indents + 1))
    for num_indents in indent_array:
        indented_query_lines = [indent * num_indents + line for line in query_lines]
        start_pairs = [
            (i, score_line(line, indented_query_lines[0]))
            for i, line in enumerate(code_file_lines)
        ]
        start_pairs.sort(key=lambda x: x[1], reverse=True)
        start_pairs = start_pairs[:truncate]
        start_indices = [i for i, _ in start_pairs]
        for i in tqdm(
            start_indices,
            position=0,
            desc=f"Indent {num_indents}/{max_indents}",
            leave=False,
        ):
            end_pairs = [
                (j, score_line(line, indented_query_lines[-1]))
                for j, line in enumerate(code_file_lines[i:], start=i)
            ]
            end_pairs.sort(key=lambda x: x[1], reverse=True)
            end_pairs = end_pairs[:truncate]
            end_indices = [j for j, _ in end_pairs]
            for j in tqdm(
                end_indices, position=1, leave=False, desc=f"Starting line {i}"
            ):
                candidate = code_file_lines[i : j + 1]
                raw_score = score_multiline(indented_query_lines, candidate)
                score = raw_score * (1 - num_indents * 0.01)
                current_match = Match(i, j + 1, score, indent * num_indents)
                if raw_score >= 99.99:  # early exit, 99.99 for floating point error
                    logger.info(f"Exact match found! Returning: {current_match}")
                    return current_match
                top_matches.append(current_match)
                if score > best_match.score:
                    best_match = current_match
    unique_top_matches: list[Match] = []
    unique_spans = set()
    for top_match in sorted(top_matches, reverse=True):
        if (top_match.start, top_match.end) not in unique_spans:
            unique_top_matches.append(top_match)
            unique_spans.add((top_match.start, top_match.end))
    for top_match in unique_top_matches[:5]:
        logger.print(top_match)
    # Todo: on_comment file comments able to modify multiple files
    return unique_top_matches[0] if unique_top_matches else Match(-1, -1, 0)


def split_ellipses(query: str) -> list[str]:
    queries = []
    current_query = ""
    for line in query.split("\n"):
        if line.strip() == "...":
            queries.append(current_query.strip("\n"))
            current_query = ""
        else:
            current_query += line + "\n"
    queries.append(current_query.strip("\n"))
    return queries


def match_indent(generated: str, original: str) -> str:
    indent_type = "\t" if "\t" in original[:5] else " "
    generated_indents = len(generated) - len(generated.lstrip())
    target_indents = len(original) - len(original.lstrip())
    diff_indents = target_indents - generated_indents
    if diff_indents > 0:
        generated = indent_type * diff_indents + generated.replace(
            "\n", "\n" + indent_type * diff_indents
        )
    return generated


old_code = """
\"\"\"
on_ticket is the main function that is called when a new issue is created.
It is only called by the webhook handler in sweepai/api.py.
\"\"\"
# TODO: Add file validation
import math
import re
import traceback
from time import time
import openai
import requests
from github import BadCredentialsException
from logtail import LogtailHandler
from loguru import logger
from requests.exceptions import Timeout
from tabulate import tabulate
from tqdm import tqdm"""
new_code = """
\"\"\"
on_ticket is the main function that is called when a new issue is created.
It is only called by the webhook handler in sweepai/api.py.
\"\"\"
# TODO: Add file validation
import math
import re
import traceback
from time import time
import hashlib
import openai
import requests
from github import BadCredentialsException
from logtail import LogtailHandler
from loguru import logger
from requests.exceptions import Timeout
from tabulate import tabulate
from tqdm import tqdm"""
# print(match_indent(new_code, old_code))
test_code = """\
def naive_euclidean_profile(X, q, mask):
    r\"\"\"
    Compute a euclidean distance profile in a brute force way.
    A distance profile between a (univariate) time series :math:`X_i = {x_1, ..., x_m}`
    and a query :math:`Q = {q_1, ..., q_m}` is defined as a vector of size :math:`m-(
    l-1)`, such as :math:`P(X_i, Q) = {d(C_1, Q), ..., d(C_m-(l-1), Q)}` with d the
    Euclidean distance, and :math:`C_j = {x_j, ..., x_{j+(l-1)}}` the j-th candidate
    subsequence of size :math:`l` in :math:`X_i`.
    \"\"\"
    return _naive_euclidean_profile(X, q, mask)
"""

if __name__ == "__main__":
    # for section in split_ellipses(test_code):
    #     print(section)
    code_file = r"""
from loguru import logger
from github.Repository import Repository
from sweepai.config.client import RESET_FILE, REVERT_CHANGED_FILES_TITLE, RULES_LABEL, RULES_TITLE, get_rules
from sweepai.utils.event_logger import posthog
from sweepai.core.post_merge import PostMerge
from sweepai.core.sweep_bot import SweepBot
from sweepai.events import IssueCommentRequest
from sweepai.handlers.on_merge import comparison_to_diff
from sweepai.handlers.pr_utils import make_pr
from sweepai.utils.buttons import ButtonList, check_button_title_match
from sweepai.utils.chat_logger import ChatLogger
from sweepai.utils.github_utils import get_github_client
def handle_button_click(request_dict):
    request = IssueCommentRequest(**request_dict)
    user_token, gh_client = get_github_client(request_dict["installation"]["id"])
    button_list = ButtonList.deserialize(request_dict["comment"]["body"])
    selected_buttons = [button.label for button in button_list.get_clicked_buttons()]
    repo = gh_client.get_repo(request_dict["repository"]["full_name"]) # do this after checking ref
    comment_id = request.comment.id
    pr = repo.get_pull(request_dict["issue"]["number"])
    comment = pr.get_issue_comment(comment_id)
    if check_button_title_match(REVERT_CHANGED_FILES_TITLE, request.comment.body, request.changes):
        revert_files = []
        for button_text in selected_buttons:
            revert_files.append(button_text.split(f"{RESET_FILE} ")[-1].strip())
        handle_revert(revert_files, request_dict["issue"]["number"], repo)
        comment.edit(
            body=ButtonList(
                buttons=[
                    button
                    for button in button_list.buttons
                    if button.label not in selected_buttons
                ],
                title = REVERT_CHANGED_FILES_TITLE,
            ).serialize()
        )
"""
    # Sample target snippet
    target = """
from loguru import logger
from github.Repository import Repository
from sweepai.config.client import RESET_FILE, REVERT_CHANGED_FILES_TITLE, RULES_LABEL, RULES_TITLE, get_rules
from sweepai.utils.event_logger import posthog
from sweepai.core.post_merge import PostMerge
from sweepai.core.sweep_bot import SweepBot
from sweepai.events import IssueCommentRequest
from sweepai.handlers.on_merge import comparison_to_diff
from sweepai.handlers.pr_utils import make_pr
from sweepai.utils.buttons import ButtonList, check_button_title_match
from sweepai.utils.chat_logger import ChatLogger
from sweepai.utils.github_utils import get_github_client
def handle_button_click(request_dict):
    request = IssueCommentRequest(**request_dict)
    user_token, gh_client = get_github_client(request_dict["installation"]["id"])
    button_list = ButtonList.deserialize(request_dict["comment"]["body"])
    selected_buttons = [button.label for button in button_list.get_clicked_buttons()]
    repo = gh_client.get_repo(request_dict["repository"]["full_name"]) # do this after checking ref
    comment_id = request.comment.id
    pr = repo.get_pull(request_dict["issue"]["number"])
    comment = pr.get_issue_comment(comment_id)
    ...
""".strip(
        "\n"
    )
    # Find the best match
    # best_span = find_best_match(target, code_file)
    best_span = find_best_match("a\nb", "a\nb")
```


Step 2: ⌨️ Coding

Modify sweepai/api.py with contents:
• In the `update_sweep_prs_v2` function, find the code block that performs the merge:
```python
repo.merge(
    feature_branch,
    pr.base.ref,
    f"Merge main into {feature_branch}",
)
```
• Replace the `repo.merge` call with the following to perform a rebase instead:
```python
repo.rebase(pr.base.ref, feature_branch)
```
• Update the commit message to reflect the rebase operation.
• If there are any merge conflicts during the rebase, catch the exception and handle it appropriately (e.g. by closing the PR, similar to the existing merge conflict handling).
  • Modify sweepai/utils/github_utils.py: no changes made (3068556)
Modify sweepai/utils/github_utils.py with contents:
• In the `ClonedRepo` class, check if there are any methods involved in the merge process (e.g. in the `clone` method).
• If found, update those methods to use `git rebase` instead of `git merge` when updating the PR branch.
• Ensure the rebase is performed against the `origin/<target_branch>` remote branch.
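The plan assumes a single `repo.rebase(...)` call, but a rebase that actually replays commits needs a working copy. Below is a minimal sketch, assuming the PR branch is available in a local clone such as the one `ClonedRepo` creates, and that force-pushing Sweep's own branch is acceptable. `run_git` and `rebase_onto_target` are hypothetical helpers, not part of Sweep; they shell out to the `git` CLI rather than GitPython so the example stays self-contained. On a conflict the sketch aborts the rebase and returns `False`, leaving the caller free to fall back to the existing merge-conflict handling.

```python
import subprocess


def run_git(repo_dir: str, *args: str) -> subprocess.CompletedProcess:
    """Run `git <args>` inside repo_dir, capturing stdout/stderr as text."""
    return subprocess.run(
        ["git", *args], cwd=repo_dir, capture_output=True, text=True
    )


def rebase_onto_target(
    repo_dir: str,
    feature_branch: str,
    target_branch: str,
    remote: str = "origin",
    push: bool = True,
) -> bool:
    """Rebase feature_branch onto <remote>/<target_branch> instead of merging.

    Hypothetical helper. Returns True if the rebase (and optional push)
    succeeded. If the rebase stops on a conflict, it is aborted so the branch
    is left exactly as it was, and False is returned for the caller to handle.
    """
    # Refresh the remote-tracking ref for the target branch.
    fetched = run_git(repo_dir, "fetch", remote, target_branch)
    if fetched.returncode != 0:
        raise RuntimeError(f"git fetch failed: {fetched.stderr.strip()}")

    checked_out = run_git(repo_dir, "checkout", feature_branch)
    if checked_out.returncode != 0:
        raise RuntimeError(f"git checkout failed: {checked_out.stderr.strip()}")

    # Replay the feature branch's own commits on top of the remote target branch.
    rebased = run_git(repo_dir, "rebase", f"{remote}/{target_branch}")
    if rebased.returncode != 0:
        # Conflict (or other failure): restore the pre-rebase state.
        run_git(repo_dir, "rebase", "--abort")
        return False

    if push:
        # The rebase rewrote history, so a plain push would be rejected.
        # --force-with-lease refuses to overwrite commits someone else pushed.
        pushed = run_git(
            repo_dir, "push", "--force-with-lease", remote, feature_branch
        )
        if pushed.returncode != 0:
            raise RuntimeError(f"git push failed: {pushed.stderr.strip()}")
    return True
```

Because a rebase rewrites the branch's commit SHAs, review comments attached to the old commits may be shown as outdated; that is the usual trade-off of rebasing instead of merging.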

Step 3: 🔁 Code Review

I have finished reviewing the code for completeness. I did not find errors for sweep/allow_for_rebase_4e4d7.


🎉 Latest improvements to Sweep:
  • New dashboard launched for real-time tracking of Sweep issues, covering all stages from search to coding.
  • Integration of OpenAI's latest Assistant API for more efficient and reliable code planning and editing, improving speed by 3x.
  • Use the GitHub issues extension for creating Sweep issues directly from your editor.

💡 To recreate the pull request, edit the issue title or description.
Something wrong? Let us know.

This is an automated message generated by Sweep AI.


sweep-nightly bot commented Apr 6, 2024



❌ Unable to Complete PR

I'm sorry, but it looks like an error has occurred due to a planning failure. Feel free to add more details to the issue description so Sweep can better address it. Alternatively, reach out to Kevin or William for help at https://discord.gg/sweep.

For bonus GPT-4 tickets, please report this bug on Discord (tracking ID: dfbd278d30).


Please look at the generated plan. If something looks wrong, please add more details to your issue.

File Path Proposed Changes
sweepai/api.py Modify sweepai/api.py with contents:
• In the update_sweep_prs_v2 function, find the code block that performs the merge:
```python
repo.merge(
feature_branch,
pr.base.ref,
f"Merge main into {feature_branch}",
)
```
• Replace the repo.merge call with the following to perform a rebase instead:
```python
repo.rebase(pr.base.ref, feature_branch)
```
• Update the commit message to reflect the rebase operation.
• If there are any merge conflicts during the rebase, catch the exception and handle it appropriately (e.g. by closing the PR similar to the existing merge conflict handling).
sweepai/utils/github_utils.py Modify sweepai/utils/github_utils.py with contents:
• In the ClonedRepo class, check if there are any methods involved in the merge process (e.g. in the clone method).
• If found, update those methods to use git rebase instead of git merge when updating the PR branch.
• Ensure the rebase is performed against the origin/<target_branch> remote branch.
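Since the original question asks whether rebasing is available as an option, one way to expose it without changing the default is a per-repository setting, read the same way `update_sweep_prs_v2` already reads `branch_ttl` from `SweepConfig.get_config(repo)`. The key name `branch_update_strategy` and both helpers below are hypothetical; missing or unknown values fall back to the current merge behavior.

```python
from typing import Any, Callable, Mapping, TypeVar

T = TypeVar("T")

# Hypothetical sweep.yaml key; "merge" preserves Sweep's current behavior.
STRATEGY_KEY = "branch_update_strategy"
VALID_STRATEGIES = ("merge", "rebase")


def get_update_strategy(config: Mapping[str, Any]) -> str:
    """Return "merge" or "rebase", defaulting to "merge" for missing or unknown values."""
    value = str(config.get(STRATEGY_KEY, "merge")).strip().lower()
    return value if value in VALID_STRATEGIES else "merge"


def update_branch(
    config: Mapping[str, Any],
    merge_fn: Callable[[], T],
    rebase_fn: Callable[[], T],
) -> T:
    """Dispatch to the merge or rebase routine chosen in the repository config."""
    if get_update_strategy(config) == "rebase":
        return rebase_fn()
    return merge_fn()
```

A repository would then opt in with a line such as `branch_update_strategy: rebase` in its `sweep.yaml`, while every other repository keeps getting merge commits.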


sweep-nightly bot commented Apr 6, 2024



❌ Unable to Complete PR

I'm sorry, but it looks like an error has occurred due to a planning failure. Feel free to add more details to the issue description so Sweep can better address it. Alternatively, reach out to Kevin or William for help at https://discord.gg/sweep.

For bonus GPT-4 tickets, please report this bug on Discord (tracking ID: 6aace80a95).


Please look at the generated plan. If something looks wrong, please add more details to your issue.

File Path Proposed Changes
sweepai/api.py Modify sweepai/api.py with contents:
• In the update_sweep_prs_v2 function, find the code block that performs the merge:
```python
repo.merge(
feature_branch,
pr.base.ref,
f"Merge main into {feature_branch}",
)
```
• Replace the repo.merge call with the following to perform a rebase instead:
```python
repo.rebase(pr.base.ref, feature_branch)
```
• Update the commit message to reflect the rebase operation.
• If there are any merge conflicts during the rebase, catch the exception and handle it appropriately (e.g. by closing the PR similar to the existing merge conflict handling).
sweepai/utils/github_utils.py Modify sweepai/utils/github_utils.py with contents:
• In the ClonedRepo class, check if there are any methods involved in the merge process (e.g. in the clone method).
• If found, update those methods to use git rebase instead of git merge when updating the PR branch.
• Ensure the rebase is performed against the origin/<target_branch> remote branch.
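The plan's conflict-handling step could look roughly like this. The sketch assumes a `rebase` callable that returns `False` after aborting a failed rebase (for example, a wrapper around `git rebase` in a local clone); `update_or_close` and the comment text are hypothetical. It relies only on `create_issue_comment` and `edit(state="closed")`, both of which PyGithub's `PullRequest` provides; the existing `update_sweep_prs_v2` code already calls `pr.edit(state="closed")`.

```python
from typing import Callable, Protocol


class PullRequestLike(Protocol):
    """The subset of PyGithub's PullRequest used here."""

    number: int

    def create_issue_comment(self, body: str) -> object: ...

    def edit(self, *, state: str) -> None: ...


CONFLICT_MESSAGE = (
    "Closing this PR: rebasing it onto the base branch produced conflicts "
    "that could not be resolved automatically."
)


def update_or_close(pr: PullRequestLike, rebase: Callable[[], bool]) -> str:
    """Rebase the PR branch; on failure, leave an explanation and close the PR."""
    if rebase():
        return "rebased"
    pr.create_issue_comment(CONFLICT_MESSAGE)
    pr.edit(state="closed")
    return "closed"
```

Posting a comment before closing keeps the reason visible on the PR, which a bare `logger.warning` does not.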


sweep-nightly bot linked a pull request on Apr 6, 2024 that will close this issue

sweep-nightly bot commented Apr 6, 2024



❌ Unable to Complete PR

I'm sorry, but it looks like an error has occurred due to a planning failure. Feel free to add more details to the issue description so Sweep can better address it. Alternatively, reach out to Kevin or William for help at https://discord.gg/sweep.

For bonus GPT-4 tickets, please report this bug on Discord (tracking ID: d3dbd2bab8).


Please look at the generated plan. If something looks wrong, please add more details to your issue.

File Path Proposed Changes
sweepai/api.py Modify sweepai/api.py with contents:
• In the update_sweep_prs_v2 function, find the code block that performs the merge:
```python
repo.merge(
feature_branch,
pr.base.ref,
f"Merge main into {feature_branch}",
)
```
• Replace the repo.merge call with the following to perform a rebase instead:
```python
repo.rebase(pr.base.ref, feature_branch)
```
• Update the commit message to reflect the rebase operation.
• If there are any merge conflicts during the rebase, catch the exception and handle it appropriately (e.g. by closing the PR similar to the existing merge conflict handling).
sweepai/utils/github_utils.py Modify sweepai/utils/github_utils.py with contents:
• In the ClonedRepo class, check if there are any methods involved in the merge process (e.g. in the clone method).
• If found, update those methods to use git rebase instead of git merge when updating the PR branch.
• Ensure the rebase is performed against the origin/<target_branch> remote branch.
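Checking first whether a PR branch is actually behind its base avoids cloning and running git for PRs that need no update. A minimal sketch using PyGithub's `Repository.compare(base, head)`, whose result's `behind_by` counts commits on `base` that `head` lacks; the helper names are hypothetical, and the `sweep/` / `sweep_` prefix check mirrors the filter in `update_sweep_prs_v2`.

```python
from typing import Any

SWEEP_PREFIXES = ("sweep/", "sweep_")


def is_sweep_branch(ref: str) -> bool:
    """Mirror the branch-name filter used in update_sweep_prs_v2."""
    return ref.startswith(SWEEP_PREFIXES)


def needs_update(repo: Any, base_ref: str, feature_branch: str) -> bool:
    """True when `feature_branch` is missing commits that exist on `base_ref`.

    `repo` is expected to behave like PyGithub's Repository: compare(base, head)
    returns an object whose `behind_by` is the number of commits on `base`
    that are not reachable from `head`.
    """
    comparison = repo.compare(base_ref, feature_branch)
    return comparison.behind_by > 0
```

In the cron loop, a PR would then be rebased only when both `is_sweep_branch(pr.head.ref)` and `needs_update(repo, pr.base.ref, pr.head.ref)` hold.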


sweep-nightly bot linked a pull request on Apr 6, 2024 that will close this issue

sweep-nightly bot commented Apr 6, 2024

🚀 Here's the PR! #3457

See Sweep's progress at the progress dashboard!
💎 Sweep Pro: I'm using GPT-4. You have unlimited GPT-4 tickets. (tracking ID: 222741c235)


Step 1: 🔎 Searching

I found the following snippets in your repository. I will now analyze these snippets and come up with a plan.

Some code snippets I think are relevant, in decreasing order of relevance. If some file is missing from here, you can mention the path in the ticket description.

sweep/sweepai/api.py

Lines 1 to 1185 in 0643263

from __future__ import annotations
import ctypes
import json
import threading
import time
from typing import Any, Optional
import requests
from fastapi import (
Body,
Depends,
FastAPI,
Header,
HTTPException,
Path,
Request,
Security,
status,
)
from fastapi.responses import HTMLResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from fastapi.templating import Jinja2Templates
from github.Commit import Commit
from prometheus_fastapi_instrumentator import Instrumentator
from sweepai.config.client import (
DEFAULT_RULES,
RESTART_SWEEP_BUTTON,
REVERT_CHANGED_FILES_TITLE,
RULES_LABEL,
RULES_TITLE,
SWEEP_BAD_FEEDBACK,
SWEEP_GOOD_FEEDBACK,
SweepConfig,
get_gha_enabled,
get_rules,
)
from sweepai.config.server import (
BLACKLISTED_USERS,
DISABLED_REPOS,
DISCORD_FEEDBACK_WEBHOOK_URL,
ENV,
GHA_AUTOFIX_ENABLED,
GITHUB_BOT_USERNAME,
GITHUB_LABEL_COLOR,
GITHUB_LABEL_DESCRIPTION,
GITHUB_LABEL_NAME,
IS_SELF_HOSTED,
MERGE_CONFLICT_ENABLED,
)
from sweepai.core.entities import PRChangeRequest
from sweepai.global_threads import global_threads
from sweepai.handlers.create_pr import ( # type: ignore
add_config_to_top_repos,
create_gha_pr,
)
from sweepai.handlers.on_button_click import handle_button_click
from sweepai.handlers.on_check_suite import ( # type: ignore
clean_gh_logs,
download_logs,
on_check_suite,
)
from sweepai.handlers.on_comment import on_comment
from sweepai.handlers.on_merge import on_merge
from sweepai.handlers.on_merge_conflict import on_merge_conflict
from sweepai.handlers.on_ticket import on_ticket
from sweepai.handlers.pr_utils import make_pr
from sweepai.handlers.stack_pr import stack_pr
from sweepai.utils.buttons import (
Button,
ButtonList,
check_button_activated,
check_button_title_match,
)
from sweepai.utils.chat_logger import ChatLogger
from sweepai.utils.event_logger import logger, posthog
from sweepai.utils.github_utils import CURRENT_USERNAME, get_github_client
from sweepai.utils.progress import TicketProgress
from sweepai.utils.safe_pqueue import SafePriorityQueue
from sweepai.utils.str_utils import BOT_SUFFIX, get_hash
from sweepai.web.events import (
CheckRunCompleted,
CommentCreatedRequest,
InstallationCreatedRequest,
IssueCommentRequest,
IssueRequest,
PREdited,
PRRequest,
ReposAddedRequest,
)
from sweepai.web.health import health_check
app = FastAPI()
events = {}
on_ticket_events = {}
security = HTTPBearer()
templates = Jinja2Templates(directory="sweepai/web")
# version_command = r"""git config --global --add safe.directory /app
# timestamp=$(git log -1 --format="%at")
# date -d "@$timestamp" +%y.%m.%d.%H 2>/dev/null || date -r "$timestamp" +%y.%m.%d.%H"""
# try:
# version = subprocess.check_output(version_command, shell=True, text=True).strip()
# except Exception:
version = time.strftime("%y.%m.%d.%H")
logger.bind(application="webhook")
def auth_metrics(credentials: HTTPAuthorizationCredentials = Security(security)):
    if credentials.scheme != "Bearer":
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Invalid authentication scheme.",
        )
    if credentials.credentials != "example_token":  # grafana requires authentication
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token."
        )
    return True
if not IS_SELF_HOSTED:
    Instrumentator().instrument(app).expose(
        app,
        should_gzip=False,
        endpoint="/metrics",
        include_in_schema=True,
        tags=["metrics"],
        dependencies=[Depends(auth_metrics)],
    )
def run_on_ticket(*args, **kwargs):
    tracking_id = get_hash()
    with logger.contextualize(
        **kwargs,
        name="ticket_" + kwargs["username"],
        tracking_id=tracking_id,
    ):
        return on_ticket(*args, **kwargs, tracking_id=tracking_id)
def run_on_comment(*args, **kwargs):
    tracking_id = get_hash()
    with logger.contextualize(
        **kwargs,
        name="comment_" + kwargs["username"],
        tracking_id=tracking_id,
    ):
        on_comment(*args, **kwargs, tracking_id=tracking_id)
def run_on_button_click(*args, **kwargs):
    thread = threading.Thread(target=handle_button_click, args=args, kwargs=kwargs)
    thread.start()
    global_threads.append(thread)
def run_on_check_suite(*args, **kwargs):
    request = kwargs["request"]
    pr_change_request = on_check_suite(request)
    if pr_change_request:
        call_on_comment(**pr_change_request.params, comment_type="github_action")
        logger.info("Done with on_check_suite")
    else:
        logger.info("Skipping on_check_suite as no pr_change_request was returned")
def terminate_thread(thread):
    """Terminate a python threading.Thread."""
    try:
        if not thread.is_alive():
            return
        exc = ctypes.py_object(SystemExit)
        res = ctypes.pythonapi.PyThreadState_SetAsyncExc(
            ctypes.c_long(thread.ident), exc
        )
        if res == 0:
            raise ValueError("Invalid thread ID")
        elif res != 1:
            # Call with exception set to 0 is needed to cleanup properly.
            ctypes.pythonapi.PyThreadState_SetAsyncExc(thread.ident, 0)
            raise SystemError("PyThreadState_SetAsyncExc failed")
    except SystemExit:
        raise SystemExit
    except Exception as e:
        logger.exception(f"Failed to terminate thread: {e}")
# def delayed_kill(thread: threading.Thread, delay: int = 60 * 60):
#     time.sleep(delay)
#     terminate_thread(thread)
def call_on_ticket(*args, **kwargs):
    global on_ticket_events
    key = f"{kwargs['repo_full_name']}-{kwargs['issue_number']}"  # Full name, issue number as key
    # Use multithreading
    # Check if a previous process exists for the same key, cancel it
    e = on_ticket_events.get(key, None)
    if e:
        logger.info(f"Found previous thread for key {key} and cancelling it")
        terminate_thread(e)
    thread = threading.Thread(target=run_on_ticket, args=args, kwargs=kwargs)
    on_ticket_events[key] = thread
    thread.start()
    global_threads.append(thread)
    # delayed_kill_thread = threading.Thread(target=delayed_kill, args=(thread,))
    # delayed_kill_thread.start()
def call_on_check_suite(*args, **kwargs):
    kwargs["request"].repository.full_name
    kwargs["request"].check_run.pull_requests[0].number
    thread = threading.Thread(target=run_on_check_suite, args=args, kwargs=kwargs)
    thread.start()
    global_threads.append(thread)
def call_on_comment(
    *args, **kwargs
):  # TODO: if its a GHA delete all previous GHA and append to the end
    def worker():
        while not events[key].empty():
            task_args, task_kwargs = events[key].get()
            run_on_comment(*task_args, **task_kwargs)
    global events
    repo_full_name = kwargs["repo_full_name"]
    pr_id = kwargs["pr_number"]
    key = f"{repo_full_name}-{pr_id}"  # Full name, comment number as key
    comment_type = kwargs["comment_type"]
    logger.info(f"Received comment type: {comment_type}")
    if key not in events:
        events[key] = SafePriorityQueue()
    events[key].put(0, (args, kwargs))
    # If a thread isn't running, start one
    if not any(
        thread.name == key and thread.is_alive() for thread in threading.enumerate()
    ):
        thread = threading.Thread(target=worker, name=key)
        thread.start()
        global_threads.append(thread)
def call_on_merge(*args, **kwargs):
    thread = threading.Thread(target=on_merge, args=args, kwargs=kwargs)
    thread.start()
    global_threads.append(thread)
@app.get("/health")
def redirect_to_health():
    return health_check()
@app.get("/", response_class=HTMLResponse)
def home(request: Request):
    return templates.TemplateResponse(
        name="index.html", context={"version": version, "request": request}
    )
@app.get("/ticket_progress/{tracking_id}")
def progress(tracking_id: str = Path(...)):
    ticket_progress = TicketProgress.load(tracking_id)
    return ticket_progress.dict()
def init_hatchet() -> Any | None:
    try:
        from hatchet_sdk import Context, Hatchet
        hatchet = Hatchet(debug=True)
        worker = hatchet.worker("github-worker")
        @hatchet.workflow(on_events=["github:webhook"])
        class OnGithubEvent:
            """Workflow for handling GitHub events."""
            @hatchet.step()
            def run(self, context: Context):
                event_payload = context.workflow_input()
                request_dict = event_payload.get("request")
                event = event_payload.get("event")
                handle_event(request_dict, event)
        workflow = OnGithubEvent()
        worker.register_workflow(workflow)
        # start worker in the background
        thread = threading.Thread(target=worker.start)
        thread.start()
        global_threads.append(thread)
        return hatchet
    except Exception as e:
        print(f"Failed to initialize Hatchet: {e}, continuing with local mode")
        return None
# hatchet = init_hatchet()
def handle_github_webhook(event_payload):
    # if hatchet:
    #     hatchet.client.event.push("github:webhook", event_payload)
    # else:
    handle_event(event_payload.get("request"), event_payload.get("event"))
def handle_request(request_dict, event=None):
    """So it can be exported to the listen endpoint."""
    with logger.contextualize(tracking_id="main", env=ENV):
        action = request_dict.get("action")
        try:
            # Send the event to Hatchet
            handle_github_webhook(
                {
                    "request": request_dict,
                    "event": event,
                }
            )
        except Exception as e:
            logger.exception(f"Failed to send event to Hatchet: {e}")
        # try:
        #     worker()
        # except Exception as e:
        #     discord_log_error(str(e), priority=1)
        logger.info(f"Done handling {event}, {action}")
        return {"success": True}
@app.post("/")
def webhook(
    request_dict: dict = Body(...),
    x_github_event: Optional[str] = Header(None, alias="X-GitHub-Event"),
):
    """Handle a webhook request from GitHub."""
    with logger.contextualize(tracking_id="main", env=ENV):
        action = request_dict.get("action", None)
        logger.info(f"Received event: {x_github_event}, {action}")
        return handle_request(request_dict, event=x_github_event)
# Set up cronjob for this
@app.get("/update_sweep_prs_v2")
def update_sweep_prs_v2(repo_full_name: str, installation_id: int):
    # Get a Github client
    _, g = get_github_client(installation_id)
    # Get the repository
    repo = g.get_repo(repo_full_name)
    config = SweepConfig.get_config(repo)
    try:
        branch_ttl = int(config.get("branch_ttl", 7))
    except Exception:
        branch_ttl = 7
    branch_ttl = max(branch_ttl, 1)
    # Get all open pull requests created by Sweep
    pulls = repo.get_pulls(
        state="open", head="sweep", sort="updated", direction="desc"
    )[:5]
    # For each pull request, attempt to merge the changes from the default branch into the pull request branch
    try:
        for pr in pulls:
            try:
                # make sure it's a sweep ticket
                feature_branch = pr.head.ref
                if not feature_branch.startswith(
                    "sweep/"
                ) and not feature_branch.startswith("sweep_"):
                    continue
                if "Resolve merge conflicts" in pr.title:
                    continue
                if (
                    pr.mergeable_state != "clean"
                    and (time.time() - pr.created_at.timestamp()) > 60 * 60 * 24
                    and pr.title.startswith("[Sweep Rules]")
                ):
                    pr.edit(state="closed")
                    continue
                repo.merge(
                    feature_branch,
                    pr.base.ref,
                    f"Merge main into {feature_branch}",
                )
                # Check if the merged PR is the config PR
                if pr.title == "Configure Sweep" and pr.merged:
                    # Create a new PR to add "gha_enabled: True" to sweep.yaml
                    create_gha_pr(g, repo)
            except Exception as e:
                logger.warning(
                    f"Failed to merge changes from default branch into PR #{pr.number}: {e}"
                )
    except Exception:
        logger.warning("Failed to update sweep PRs")
def handle_event(request_dict, event):
action = request_dict.get("action")
if repo_full_name := request_dict.get("repository", {}).get("full_name"):
if repo_full_name in DISABLED_REPOS:
logger.warning(f"Repo {repo_full_name} is disabled")
return {"success": False, "error_message": "Repo is disabled"}
with logger.contextualize(tracking_id="main", env=ENV):
match event, action:
case "check_run", "completed":
request = CheckRunCompleted(**request_dict)
_, g = get_github_client(request.installation.id)
repo = g.get_repo(request.repository.full_name)
pull_requests = request.check_run.pull_requests
if pull_requests:
logger.info(pull_requests[0].number)
pr = repo.get_pull(pull_requests[0].number)
if (time.time() - pr.created_at.timestamp()) > 60 * 60 and (
pr.title.startswith("[Sweep Rules]")
or pr.title.startswith("[Sweep GHA Fix]")
):
after_sha = pr.head.sha
commit = repo.get_commit(after_sha)
check_suites = commit.get_check_suites()
for check_suite in check_suites:
if check_suite.conclusion == "failure":
pr.edit(state="closed")
break
if (
not (time.time() - pr.created_at.timestamp()) > 60 * 15
and request.check_run.conclusion == "failure"
and pr.state == "open"
and get_gha_enabled(repo)
and len(
[
comment
for comment in pr.get_issue_comments()
if "Fixing PR" in comment.body
]
)
< 2
and GHA_AUTOFIX_ENABLED
):
# check if the base branch is passing
commits = repo.get_commits(sha=pr.base.ref)
latest_commit: Commit = commits[0]
if all(
status != "failure"
for status in [
status.state for status in latest_commit.get_statuses()
]
): # base branch is passing
logs = download_logs(
request.repository.full_name,
request.check_run.run_id,
request.installation.id,
)
logs, user_message = clean_gh_logs(logs)
attributor = request.sender.login
if attributor.endswith("[bot]"):
attributor = commit.author.login
if attributor.endswith("[bot]"):
attributor = pr.assignee.login
if attributor.endswith("[bot]"):
return {
"success": False,
"error_message": "The PR was created by a bot, so I won't attempt to fix it.",
}
tracking_id = get_hash()
chat_logger = ChatLogger(
data={
"username": attributor,
"title": "[Sweep GHA Fix] Fix the failing GitHub Actions",
}
)
if chat_logger.use_faster_model() and not IS_SELF_HOSTED:
return {
"success": False,
"error_message": "Disabled for free users",
}
stack_pr(
request=f"[Sweep GHA Fix] The GitHub Actions run failed on {request.check_run.head_sha[:7]} ({repo.default_branch}) with the following error logs:\n\n```\n\n{logs}\n\n```",
pr_number=pr.number,
username=attributor,
repo_full_name=repo.full_name,
installation_id=request.installation.id,
tracking_id=tracking_id,
commit_hash=pr.head.sha,
)
elif (
request.check_run.check_suite.head_branch == repo.default_branch
and get_gha_enabled(repo)
and GHA_AUTOFIX_ENABLED
):
if request.check_run.conclusion == "failure":
commit = repo.get_commit(request.check_run.head_sha)
attributor = request.sender.login
if attributor.endswith("[bot]"):
attributor = commit.author.login
if attributor.endswith("[bot]"):
return {
"success": False,
"error_message": "The PR was created by a bot, so I won't attempt to fix it.",
}
logs = download_logs(
request.repository.full_name,
request.check_run.run_id,
request.installation.id,
)
logs, user_message = clean_gh_logs(logs)
chat_logger = ChatLogger(
data={
"username": attributor,
"title": "[Sweep GHA Fix] Fix the failing GitHub Actions",
}
)
if chat_logger.use_faster_model() and not IS_SELF_HOSTED:
return {
"success": False,
"error_message": "Disabled for free users",
}
make_pr(
title=f"[Sweep GHA Fix] Fix the failing GitHub Actions on {request.check_run.head_sha[:7]} ({repo.default_branch})",
repo_description=repo.description,
summary=f"The GitHub Actions run failed with the following error logs:\n\n```\n{logs}\n```",
repo_full_name=request_dict["repository"]["full_name"],
installation_id=request_dict["installation"]["id"],
user_token=None,
use_faster_model=chat_logger.use_faster_model(),
username=attributor,
chat_logger=chat_logger,
)
case "pull_request", "opened":
_, g = get_github_client(request_dict["installation"]["id"])
repo = g.get_repo(request_dict["repository"]["full_name"])
pr = repo.get_pull(request_dict["pull_request"]["number"])
# if the pr already has a comment from sweep bot do nothing
time.sleep(10)
if any(
comment.user.login == GITHUB_BOT_USERNAME
for comment in pr.get_issue_comments()
) or pr.title.startswith("Sweep:"):
return {
"success": True,
"reason": "PR already has a comment from sweep bot",
}
rule_buttons = []
repo_rules = get_rules(repo) or []
if repo_rules != [""] and repo_rules != []:
for rule in repo_rules or []:
if rule:
rule_buttons.append(Button(label=f"{RULES_LABEL} {rule}"))
if len(repo_rules) == 0:
for rule in DEFAULT_RULES:
rule_buttons.append(Button(label=f"{RULES_LABEL} {rule}"))
if rule_buttons:
rules_buttons_list = ButtonList(
buttons=rule_buttons, title=RULES_TITLE
)
pr.create_issue_comment(rules_buttons_list.serialize() + BOT_SUFFIX)
if pr.mergeable is False and MERGE_CONFLICT_ENABLED:
attributor = pr.user.login
if attributor.endswith("[bot]"):
attributor = pr.assignee.login
if attributor.endswith("[bot]"):
return {
"success": False,
"error_message": "The PR was created by a bot, so I won't attempt to fix it.",
}
chat_logger = ChatLogger(
data={
"username": attributor,
"title": "[Sweep GHA Fix] Fix the failing GitHub Actions",
}
)
if chat_logger.use_faster_model() and not IS_SELF_HOSTED:
return {
"success": False,
"error_message": "Disabled for free users",
}
on_merge_conflict(
pr_number=pr.number,
username=attributor,
repo_full_name=request_dict["repository"]["full_name"],
installation_id=request_dict["installation"]["id"],
tracking_id=get_hash(),
)
case "issues", "opened":
request = IssueRequest(**request_dict)
issue_title_lower = request.issue.title.lower()
if (
issue_title_lower.startswith("sweep")
or "sweep:" in issue_title_lower
):
_, g = get_github_client(request.installation.id)
repo = g.get_repo(request.repository.full_name)
labels = repo.get_labels()
label_names = [label.name for label in labels]
if GITHUB_LABEL_NAME not in label_names:
repo.create_label(
name=GITHUB_LABEL_NAME,
color=GITHUB_LABEL_COLOR,
description=GITHUB_LABEL_DESCRIPTION,
)
current_issue = repo.get_issue(number=request.issue.number)
current_issue.add_to_labels(GITHUB_LABEL_NAME)
case "issue_comment", "edited":
request = IssueCommentRequest(**request_dict)
sweep_labeled_issue = GITHUB_LABEL_NAME in [
label.name.lower() for label in request.issue.labels
]
button_title_match = check_button_title_match(
REVERT_CHANGED_FILES_TITLE,
request.comment.body,
request.changes,
) or check_button_title_match(
RULES_TITLE,
request.comment.body,
request.changes,
)
if (
request.comment.user.type == "Bot"
and GITHUB_BOT_USERNAME in request.comment.user.login
and request.changes.body_from is not None
and button_title_match
and request.sender.type == "User"
and request.comment.user.login not in BLACKLISTED_USERS
):
run_on_button_click(request_dict)
restart_sweep = False
if (
request.comment.user.type == "Bot"
and GITHUB_BOT_USERNAME in request.comment.user.login
and request.changes.body_from is not None
and check_button_activated(
RESTART_SWEEP_BUTTON,
request.comment.body,
request.changes,
)
and sweep_labeled_issue
and request.sender.type == "User"
and request.comment.user.login not in BLACKLISTED_USERS
):
# Restart Sweep on this issue
restart_sweep = True
if (
request.issue is not None
and sweep_labeled_issue
and request.comment.user.type == "User"
and request.comment.user.login not in BLACKLISTED_USERS
and not request.comment.user.login.startswith("sweep")
and not (
request.issue.pull_request and request.issue.pull_request.url
)
or restart_sweep
):
logger.info("New issue comment edited")
request.issue.body = request.issue.body or ""
request.repository.description = (
request.repository.description or ""
)
if (
not request.comment.body.strip()
.lower()
.startswith(GITHUB_LABEL_NAME)
and not restart_sweep
):
logger.info("Comment does not start with 'Sweep', passing")
return {
"success": True,
"reason": "Comment does not start with 'Sweep', passing",
}
call_on_ticket(
title=request.issue.title,
summary=request.issue.body,
issue_number=request.issue.number,
issue_url=request.issue.html_url,
username=request.issue.user.login,
repo_full_name=request.repository.full_name,
repo_description=request.repository.description,
installation_id=request.installation.id,
comment_id=request.comment.id if not restart_sweep else None,
edited=True,
)
elif (
request.issue.pull_request
and request.comment.user.type == "User"
and request.comment.user.login not in BLACKLISTED_USERS
): # TODO(sweep): set a limit
logger.info(f"Handling comment on PR: {request.issue.pull_request}")
_, g = get_github_client(request.installation.id)
repo = g.get_repo(request.repository.full_name)
pr = repo.get_pull(request.issue.number)
labels = pr.get_labels()
comment = request.comment.body
if (
comment.lower().startswith("sweep:")
or any(label.name.lower() == "sweep" for label in labels)
) and BOT_SUFFIX not in comment:
pr_change_request = PRChangeRequest(
params={
"comment_type": "comment",
"repo_full_name": request.repository.full_name,
"repo_description": request.repository.description,
"comment": request.comment.body,
"pr_path": None,
"pr_line_position": None,
"username": request.comment.user.login,
"installation_id": request.installation.id,
"pr_number": request.issue.number,
"comment_id": request.comment.id,
},
)
call_on_comment(**pr_change_request.params)
case "issues", "edited":
request = IssueRequest(**request_dict)
if (
GITHUB_LABEL_NAME
in [label.name.lower() for label in request.issue.labels]
and request.sender.type == "User"
and request.comment.user.login not in BLACKLISTED_USERS
and not request.sender.login.startswith("sweep")
):
logger.info("New issue edited")
call_on_ticket(
title=request.issue.title,
summary=request.issue.body,
issue_number=request.issue.number,
issue_url=request.issue.html_url,
username=request.issue.user.login,
repo_full_name=request.repository.full_name,
repo_description=request.repository.description,
installation_id=request.installation.id,
comment_id=None,
)
else:
logger.info("Issue edited, but not a sweep issue")
case "issues", "labeled":
request = IssueRequest(**request_dict)
if (
any(
label.name.lower() == GITHUB_LABEL_NAME
for label in request.issue.labels
)
and not request.issue.pull_request
):
request.issue.body = request.issue.body or ""
request.repository.description = (
request.repository.description or ""
)
call_on_ticket(
title=request.issue.title,
summary=request.issue.body,
issue_number=request.issue.number,
issue_url=request.issue.html_url,
username=request.issue.user.login,
repo_full_name=request.repository.full_name,
repo_description=request.repository.description,
installation_id=request.installation.id,
comment_id=None,
)
case "issue_comment", "created":
request = IssueCommentRequest(**request_dict)
if (
request.issue is not None
and GITHUB_LABEL_NAME
in [label.name.lower() for label in request.issue.labels]
and request.comment.user.type == "User"
and request.comment.user.login not in BLACKLISTED_USERS
and not (
request.issue.pull_request and request.issue.pull_request.url
)
and BOT_SUFFIX not in request.comment.body
):
request.issue.body = request.issue.body or ""
request.repository.description = (
request.repository.description or ""
)
if (
not request.comment.body.strip()
.lower()
.startswith(GITHUB_LABEL_NAME)
):
logger.info("Comment does not start with 'Sweep', passing")
return {
"success": True,
"reason": "Comment does not start with 'Sweep', passing",
}
call_on_ticket(
title=request.issue.title,
summary=request.issue.body,
issue_number=request.issue.number,
issue_url=request.issue.html_url,
username=request.issue.user.login,
repo_full_name=request.repository.full_name,
repo_description=request.repository.description,
installation_id=request.installation.id,
comment_id=request.comment.id,
)
elif (
request.issue.pull_request
and request.comment.user.type == "User"
and request.comment.user.login not in BLACKLISTED_USERS
and BOT_SUFFIX not in request.comment.body
): # TODO(sweep): set a limit
_, g = get_github_client(request.installation.id)
repo = g.get_repo(request.repository.full_name)
pr = repo.get_pull(request.issue.number)
labels = pr.get_labels()
comment = request.comment.body
if (
comment.lower().startswith("sweep:")
or any(label.name.lower() == "sweep" for label in labels)
and BOT_SUFFIX not in comment
):
pr_change_request = PRChangeRequest(
params={
"comment_type": "comment",
"repo_full_name": request.repository.full_name,
"repo_description": request.repository.description,
"comment": request.comment.body,
"pr_path": None,
"pr_line_position": None,
"username": request.comment.user.login,
"installation_id": request.installation.id,
"pr_number": request.issue.number,
"comment_id": request.comment.id,
},
)
call_on_comment(**pr_change_request.params)
case "pull_request_review_comment", "created":
request = CommentCreatedRequest(**request_dict)
_, g = get_github_client(request.installation.id)
repo = g.get_repo(request.repository.full_name)
pr = repo.get_pull(request.pull_request.number)
labels = pr.get_labels()
comment = request.comment.body
if (
(
comment.lower().startswith("sweep:")
or any(label.name.lower() == "sweep" for label in labels)
)
and request.comment.user.type == "User"
and request.comment.user.login not in BLACKLISTED_USERS
and BOT_SUFFIX not in comment
):
pr_change_request = PRChangeRequest(
params={
"comment_type": "comment",
"repo_full_name": request.repository.full_name,
"repo_description": request.repository.description,
"comment": request.comment.body,
"pr_path": request.comment.path,
"pr_line_position": request.comment.original_line,
"username": request.comment.user.login,
"installation_id": request.installation.id,
"pr_number": request.pull_request.number,
"comment_id": request.comment.id,
},
)
call_on_comment(**pr_change_request.params)
case "pull_request_review_comment", "edited":
request = CommentCreatedRequest(**request_dict)
_, g = get_github_client(request.installation.id)
repo = g.get_repo(request.repository.full_name)
pr = repo.get_pull(request.pull_request.number)
labels = pr.get_labels()
comment = request.comment.body
if (
(
comment.lower().startswith("sweep:")
or any(label.name.lower() == "sweep" for label in labels)
)
and request.comment.user.type == "User"
and request.comment.user.login not in BLACKLISTED_USERS
and BOT_SUFFIX not in comment
):
pr_change_request = PRChangeRequest(
params={
"comment_type": "comment",
"repo_full_name": request.repository.full_name,
"repo_description": request.repository.description,
"comment": request.comment.body,
"pr_path": request.comment.path,
"pr_line_position": request.comment.original_line,
"username": request.comment.user.login,
"installation_id": request.installation.id,
"pr_number": request.pull_request.number,
"comment_id": request.comment.id,
},
)
call_on_comment(**pr_change_request.params)
case "installation_repositories", "added":
repos_added_request = ReposAddedRequest(**request_dict)
metadata = {
"installation_id": repos_added_request.installation.id,
"repositories": [
repo.full_name
for repo in repos_added_request.repositories_added
],
}
try:
add_config_to_top_repos(
repos_added_request.installation.id,
repos_added_request.installation.account.login,
repos_added_request.repositories_added,
)
except Exception as e:
logger.exception(f"Failed to add config to top repos: {e}")
posthog.capture(
"installation_repositories",
"started",
properties={**metadata},
)
for repo in repos_added_request.repositories_added:
organization, repo_name = repo.full_name.split("/")
posthog.capture(
organization,
"installed_repository",
properties={
"repo_name": repo_name,
"organization": organization,
"repo_full_name": repo.full_name,
},
)
case "installation", "created":
repos_added_request = InstallationCreatedRequest(**request_dict)
try:
add_config_to_top_repos(
repos_added_request.installation.id,
repos_added_request.installation.account.login,
repos_added_request.repositories,
)
except Exception as e:
logger.exception(f"Failed to add config to top repos: {e}")
case "pull_request", "edited":
request = PREdited(**request_dict)
if (
request.pull_request.user.login == GITHUB_BOT_USERNAME
and not request.sender.login.endswith("[bot]")
and DISCORD_FEEDBACK_WEBHOOK_URL is not None
):
good_button = check_button_activated(
SWEEP_GOOD_FEEDBACK,
request.pull_request.body,
request.changes,
)
bad_button = check_button_activated(
SWEEP_BAD_FEEDBACK,
request.pull_request.body,
request.changes,
)
if good_button or bad_button:
emoji = "😕"
if good_button:
emoji = "👍"
elif bad_button:
emoji = "👎"
data = {
"content": f"{emoji} {request.pull_request.html_url} ({request.sender.login})\n{request.pull_request.commits} commits, {request.pull_request.changed_files} files: +{request.pull_request.additions}, -{request.pull_request.deletions}"
}
headers = {"Content-Type": "application/json"}
requests.post(
DISCORD_FEEDBACK_WEBHOOK_URL,
data=json.dumps(data),
headers=headers,
)
# Send feedback to PostHog
posthog.capture(
request.sender.login,
"feedback",
properties={
"repo_name": request.repository.full_name,
"pr_url": request.pull_request.html_url,
"pr_commits": request.pull_request.commits,
"pr_additions": request.pull_request.additions,
"pr_deletions": request.pull_request.deletions,
"pr_changed_files": request.pull_request.changed_files,
"username": request.sender.login,
"good_button": good_button,
"bad_button": bad_button,
},
)
def remove_buttons_from_description(body):
"""
Replace:
### PR Feedback...
...
# (until it hits the next #)
with
### PR Feedback: {emoji}
#
"""
lines = body.split("\n")
if not lines[0].startswith("### PR Feedback"):
return None
# Find the next line after the header that starts with "#"
i = 0
for i, line in enumerate(lines):
if line.startswith("#") and i > 0:
break
return "\n".join(
[
f"### PR Feedback: {emoji}",
*lines[i:],
]
)
# Update PR description to remove buttons
try:
_, g = get_github_client(request.installation.id)
repo = g.get_repo(request.repository.full_name)
pr = repo.get_pull(request.pull_request.number)
new_body = remove_buttons_from_description(
request.pull_request.body
)
if new_body is not None:
pr.edit(body=new_body)
except SystemExit:
raise
except Exception as e:
logger.exception(f"Failed to edit PR description: {e}")
case "pull_request", "closed":
pr_request = PRRequest(**request_dict)
(
organization,
repo_name,
) = pr_request.repository.full_name.split("/")
commit_author = pr_request.pull_request.user.login
merged_by = (
pr_request.pull_request.merged_by.login
if pr_request.pull_request.merged_by
else None
)
if CURRENT_USERNAME == commit_author and merged_by is not None:
event_name = "merged_sweep_pr"
if pr_request.pull_request.title.startswith("[config]"):
event_name = "config_pr_merged"
elif pr_request.pull_request.title.startswith("[Sweep Rules]"):
event_name = "sweep_rules_pr_merged"
_token, g = get_github_client(pr_request.installation.id)
pr = g.get_repo(pr_request.repository.full_name).get_pull(
pr_request.number
)
total_lines_in_commit = 0
total_lines_edited_by_developer = 0
edited_by_developers = False
for commit in pr.get_commits():
lines_modified = commit.stats.additions + commit.stats.deletions
total_lines_in_commit += lines_modified
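# commit.author is None when the commit email is not linked to a GitHub account
if commit.author is None or commit.author.login != CURRENT_USERNAME: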
total_lines_edited_by_developer += lines_modified
# this was edited by a developer if at least 25% of the lines were edited by a developer
edited_by_developers = total_lines_in_commit > 0 and (total_lines_edited_by_developer / total_lines_in_commit) >= 0.25
posthog.capture(
merged_by,
event_name,
properties={
"repo_name": repo_name,
"organization": organization,
"repo_full_name": pr_request.repository.full_name,
"username": merged_by,
"additions": pr_request.pull_request.additions,
"deletions": pr_request.pull_request.deletions,
"total_changes": pr_request.pull_request.additions
+ pr_request.pull_request.deletions,
"edited_by_developers": edited_by_developers,
"total_lines_in_commit": total_lines_in_commit,
"total_lines_edited_by_developer": total_lines_edited_by_developer,
},
)
chat_logger = ChatLogger({"username": merged_by})
case "push", None:
if event != "pull_request" or request_dict["base"]["merged"] is True:
chat_logger = ChatLogger(
{"username": request_dict["pusher"]["name"]}
)
# on merge
call_on_merge(request_dict, chat_logger)
ref = request_dict["ref"] if "ref" in request_dict else ""
if ref.startswith("refs/heads") and not ref.startswith(
"ref/heads/sweep"
):
_, g = get_github_client(request_dict["installation"]["id"])
repo = g.get_repo(request_dict["repository"]["full_name"])
if ref[len("refs/heads/") :] == SweepConfig.get_branch(repo):
update_sweep_prs_v2(
request_dict["repository"]["full_name"],
installation_id=request_dict["installation"]["id"],
)
if ref.startswith("refs/heads"):
branch_name = ref[len("refs/heads/") :]
# Check if the branch has an associated PR
org_name, repo_name = request_dict["repository"][
"full_name"
].split("/")
pulls = repo.get_pulls(
state="open",
sort="created",
head=org_name + ":" + branch_name,
)
for pr in pulls:
logger.info(
f"PR associated with branch {branch_name}: #{pr.number} - {pr.title}"
)
if pr.mergeable is False and MERGE_CONFLICT_ENABLED:
attributor = pr.user.login
if attributor.endswith("[bot]"):
attributor = pr.assignee.login if pr.assignee else attributor
if attributor.endswith("[bot]"):
return {
"success": False,
"error_message": "The PR was created by a bot, so I won't attempt to fix it.",
}
chat_logger = ChatLogger(
data={
"username": attributor,
"title": "[Sweep GHA Fix] Fix the failing GitHub Actions",
}
)
if (
chat_logger.use_faster_model()
and not IS_SELF_HOSTED
):
return {
"success": False,
"error_message": "Disabled for free users",
}
on_merge_conflict(
pr_number=pr.number,
username=attributor,
repo_full_name=request_dict["repository"][
"full_name"
],
installation_id=request_dict["installation"]["id"],
tracking_id=get_hash(),
)
case "ping", None:
return {"message": "pong"}
case _:

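The comment triggers and the merged-PR analytics above reduce to two small decisions: whether a comment should reach `call_on_comment`, and whether a merged Sweep PR counts as edited by developers (at least 25% of changed lines authored by someone other than the bot). A minimal sketch of both as pure functions, assuming plain values in place of the PyGithub objects the handler uses; `should_handle_comment` and `is_developer_edited` are hypothetical helpers, not part of Sweep.

```python
def should_handle_comment(
    comment_body: str,
    label_names: list[str],
    user_type: str,
    user_login: str,
    blacklisted_users: set[str],
    bot_suffix: str,
) -> bool:
    # Compute the either/or trigger first, then and the remaining checks onto it.
    # Without that grouping, `a or b and c` parses as `a or (b and c)`.
    is_sweep_request = comment_body.lower().startswith("sweep:") or any(
        name.lower() == "sweep" for name in label_names
    )
    return (
        is_sweep_request
        and user_type == "User"
        and user_login not in blacklisted_users
        and bot_suffix not in comment_body
    )


def is_developer_edited(
    commit_stats: list[tuple[str, int, int]],
    bot_login: str,
    threshold: float = 0.25,
) -> bool:
    # commit_stats holds (author_login, additions, deletions) for each commit.
    total_lines = 0
    developer_lines = 0
    for author, additions, deletions in commit_stats:
        changed = additions + deletions
        total_lines += changed
        if author != bot_login:
            developer_lines += changed
    # The total_lines > 0 guard avoids dividing by zero for empty PRs.
    return total_lines > 0 and developer_lines / total_lines >= threshold
```

Pulling the predicate out of the webhook handler lets each `match` arm share one tested condition instead of three hand-copied `if` blocks.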
import time
import traceback
from git import GitCommandError
from github.PullRequest import PullRequest
from loguru import logger
from sweepai.config.server import PROGRESS_BASE_URL
from sweepai.core import entities
from sweepai.core.entities import FileChangeRequest
from sweepai.core.sweep_bot import SweepBot
from sweepai.handlers.create_pr import create_pr_changes
from sweepai.handlers.on_ticket import get_branch_diff_text, sweeping_gif
from sweepai.utils.chat_logger import ChatLogger, discord_log_error
from sweepai.utils.diff import generate_diff
from sweepai.utils.event_logger import posthog
from sweepai.utils.github_utils import ClonedRepo, get_github_client
from sweepai.utils.progress import (
PaymentContext,
TicketContext,
TicketProgress,
TicketProgressStatus,
)
from sweepai.utils.prompt_constructor import HumanMessagePrompt
from sweepai.utils.str_utils import to_branch_name
from sweepai.utils.ticket_utils import center
instructions_format = """Resolve the merge conflicts in the PR by incorporating changes from both branches into the final code.
Title of PR: {title}
Here were the original changes to this file in the head branch:
Commit message: {head_commit_message}
```diff
{head_diff}
```
Here were the original changes to this file in the base branch:
Commit message: {base_commit_message}
```diff
{base_diff}
```
In the analysis_and_identification, first determine what each change does. Then determine what the final code should be. Then, use the keyword_search to find the merge conflict markers <<<<<<< and >>>>>>>. Finally, make the code changes by writing the old_code and the new_code."""
def on_merge_conflict(
pr_number: int,
username: str,
repo_full_name: str,
installation_id: int,
tracking_id: str,
):
# copied from stack_pr
token, g = get_github_client(installation_id=installation_id)
try:
repo = g.get_repo(repo_full_name)
except Exception as e:
logger.exception(f"Exception occurred while getting repo: {e}")
return {"success": False}
pr: PullRequest = repo.get_pull(pr_number)
branch = pr.head.ref
status_message = center(
f"{sweeping_gif}\n\n"
+ f'Resolving merge conflicts: track the progress <a href="{PROGRESS_BASE_URL}/issues/{tracking_id}">here</a>.'
)
header = f"{status_message}\n---\n\nI'm currently resolving the merge conflicts in this PR. I will stack a new PR once I'm done."
comment = None
for current_comment in pr.get_issue_comments():
if (
current_comment.user.login == "sweep-nightly[bot]"
and "Resolving merge conflicts: track the progress" in current_comment.body
):
current_comment.edit(body=header)
comment = current_comment
break
if comment is None:
comment = pr.create_issue_comment(body=header)
def edit_comment(body):
nonlocal comment
comment.edit(header + "\n\n" + body)
metadata = {}
try:
cloned_repo = ClonedRepo(
repo_full_name=repo_full_name,
installation_id=installation_id,
branch=branch,
token=token,
)
time.time()
request = f"Sweep: Resolve merge conflicts for PR #{pr_number}: {pr.title}"
title = request
if len(title) > 50:
title = title[:50] + "..."
chat_logger = ChatLogger(
data={
"username": username,
"metadata": metadata,
"tracking_id": tracking_id,
}
)
# This logic is partly taken from on_ticket.py; refer to that file if something looks off.
if chat_logger:
is_paying_user = chat_logger.is_paying_user()
chat_logger.is_consumer_tier()
use_faster_model = chat_logger.use_faster_model()
else:
is_paying_user = True
use_faster_model = False
ticket_progress = TicketProgress(
tracking_id=tracking_id,
username=username,
context=TicketContext(
title=title,
description="",
repo_full_name=repo_full_name,
branch_name="sweep/" + to_branch_name(request),
issue_number=pr_number,
is_public=repo.private is False,
start_time=int(time.time()),
# Mostly copied from on_ticket.py; check that file if something looks off
payment_context=PaymentContext(
use_faster_model=use_faster_model,
pro_user=is_paying_user,
daily_tickets_used=(
chat_logger.get_ticket_count(use_date=True)
if chat_logger
else 0
),
monthly_tickets_used=(
chat_logger.get_ticket_count() if chat_logger else 0
),
),
),
)
metadata = {
"tracking_id": tracking_id,
"username": username,
"function": "on_merge_conflict",
**ticket_progress.context.dict(),
}
posthog.capture(
username,
"started",
properties=metadata,
)
issue_url = pr.html_url
edit_comment("Configuring branch...")
new_pull_request = entities.PullRequest(
title=title,
branch_name="sweep/" + branch + "-merge-conflict",
content="",
)
# Make the branch name unique: append the first numeric suffix (_0, _1, ...) that is not an existing branch
for i in range(30):
try:
repo.get_branch(new_pull_request.branch_name + "_" + str(i))
except Exception:
new_pull_request.branch_name += "_" + str(i)
break
# Merge origin/<base branch> into a new branch created from the PR's head branch (inside cloned_repo.repo_dir)
git_repo = cloned_repo.git_repo
old_head_branch = git_repo.branches[branch]
head_branch = git_repo.create_head(
new_pull_request.branch_name,
commit=old_head_branch.commit,
)
head_branch.checkout()
try:
git_repo.config_writer().set_value(
"user", "name", "sweep-nightly[bot]"
).release()
git_repo.config_writer().set_value(
"user", "email", "[email protected]"
).release()
git_repo.git.merge("origin/" + pr.base.ref)
except GitCommandError:
# Assume there are merge conflicts
pass
git_repo.git.add(update=True)
# Pass -m with a message; otherwise git tries to open an editor for the commit message and the call raises
git_repo.git.commit("-m", "Start of Merge Conflict Resolution")
origin = git_repo.remotes.origin
new_url = f"https://x-access-token:{token}@github.com/{repo_full_name}.git"
origin.set_url(new_url)
git_repo.git.push("--set-upstream", origin, new_pull_request.branch_name)
last_commit = git_repo.head.commit
all_files = [item.a_path for item in last_commit.diff("HEAD~1")]
conflict_files = []
for file in all_files:
try:
contents = open(cloned_repo.repo_dir + "/" + file).read()
if "\n<<<<<<<" in contents and "\n>>>>>>>" in contents:
conflict_files.append(file)
except (UnicodeDecodeError, FileNotFoundError):
pass
snippets = []
for conflict_file in conflict_files:
contents = open(cloned_repo.repo_dir + "/" + conflict_file).read()
snippet = entities.Snippet(
file_path=conflict_file,
start=0,
end=len(contents.splitlines()),
content=contents,
)
snippets.append(snippet)
tree = ""
ticket_progress.status = TicketProgressStatus.PLANNING
ticket_progress.save()
human_message = HumanMessagePrompt(
repo_name=repo_full_name,
issue_url=issue_url,
username=username,
repo_description=(repo.description or "").strip(),
title=request,
summary=request,
snippets=snippets,
tree=tree,
)
sweep_bot = SweepBot.from_system_message_content(
human_message=human_message,
repo=repo,
ticket_progress=ticket_progress,
chat_logger=chat_logger,
cloned_repo=cloned_repo,
branch=new_pull_request.branch_name,
)
# can select more precise snippets
file_change_requests = []
# Recent commits on the PR's base branch (not the repository's default branch)
base_commits = pr.base.repo.get_commits(sha=pr.base.ref).get_page(0)
head_commits = list(pr.get_commits())
for conflict_file in conflict_files:
old_code = repo.get_contents(
conflict_file, ref=head_commits[0].parents[0].sha
).decoded_content.decode()
base_code = repo.get_contents(
conflict_file, ref=pr.base.ref
).decoded_content.decode()
head_code = repo.get_contents(
conflict_file, ref=pr.head.ref
).decoded_content.decode()
base_diff = generate_diff(old_code=old_code, new_code=base_code)
head_diff = generate_diff(old_code=old_code, new_code=head_code)
base_commit_message = ""
# get_commits() lists the newest commit first, so the first match is the most recent change to this file
for commit in base_commits:
if any(
commit_file.filename == conflict_file
for commit_file in commit.files
):
base_commit_message = commit.raw_data["commit"]["message"]
break
head_commit_message = ""
for commit in head_commits[::-1]:
if any(
commit_file.filename == conflict_file
for commit_file in commit.files
):
head_commit_message = commit.raw_data["commit"]["message"]
break
file_change_requests.append(
FileChangeRequest(
filename=conflict_file,
instructions=instructions_format.format(
title=pr.title,
base_commit_message=base_commit_message,
base_diff=base_diff,
head_commit_message=head_commit_message,
head_diff=head_diff,
),
change_type="modify",
)
)
ticket_progress.status = TicketProgressStatus.CODING
ticket_progress.save()
edit_comment("Resolving merge conflicts...")
generator = create_pr_changes(
file_change_requests,
new_pull_request,
sweep_bot,
username,
installation_id,
pr_number,
chat_logger=chat_logger,
base_branch=new_pull_request.branch_name,
)
for item in generator:
if isinstance(item, dict):
break
(
file_change_request,
changed_file,
sandbox_response,
commit,
file_change_requests,
) = item
logger.info(f"Status: {file_change_request.status == 'succeeded'}")
ticket_progress.status = TicketProgressStatus.COMPLETE
ticket_progress.save()
edit_comment("Done creating pull request.")
get_branch_diff_text(repo, new_pull_request.branch_name)
new_description = f"This PR resolves the merge conflicts in #{pr_number}. This branch can be directly merged into {pr.base.ref}.\n\nFixes #{pr_number}."
# Create pull request
new_pull_request.content = new_description
github_pull_request = repo.create_pull(
title=request,
body=new_description,
head=new_pull_request.branch_name,
base=pr.base.ref,
)
ticket_progress.context.pr_id = github_pull_request.number
ticket_progress.context.done_time = time.time()
ticket_progress.save()
edit_comment(f"✨ **Created Pull Request:** {github_pull_request.html_url}")
posthog.capture(
username,
"success",
properties=metadata,
)
return {"success": True}
except Exception as e:
logger.exception(f"Exception occurred: {e}")
edit_comment(
f"> [!CAUTION]\n> An error has occurred: {str(e)} (tracking ID: {tracking_id})"
)
discord_log_error(
"Error occurred in on_merge_conflict.py\n"
+ traceback.format_exc()
+ "\n\n"
+ str(e)
+ "\n\n"
+ f"tracking ID: {tracking_id}"
)
posthog.capture(
username,
"failed",
properties=metadata,
)
return {"success": False}
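`on_merge_conflict` flags a file as conflicted when its contents include both `\n<<<<<<<` and `\n>>>>>>>`. A minimal sketch of a line-based alternative that also splits each conflict into its two sides, assuming git's default two-way marker style (no `|||||||` base section, as produced by `merge.conflictStyle=diff3`); `find_conflict_hunks` and `ConflictHunk` are hypothetical names.

```python
from dataclasses import dataclass, field


@dataclass
class ConflictHunk:
    ours: list[str] = field(default_factory=list)    # lines between <<<<<<< and =======
    theirs: list[str] = field(default_factory=list)  # lines between ======= and >>>>>>>


def find_conflict_hunks(contents: str) -> list[ConflictHunk]:
    hunks: list[ConflictHunk] = []
    current = None  # the hunk currently being read, if any
    side = None     # "ours" or "theirs" while inside a hunk
    for line in contents.splitlines():
        if current is None and line.startswith("<<<<<<<"):
            current, side = ConflictHunk(), "ours"
        elif current is not None and side == "ours" and line.startswith("======="):
            side = "theirs"
        elif current is not None and side == "theirs" and line.startswith(">>>>>>>"):
            hunks.append(current)
            current, side = None, None
        elif current is not None:
            (current.ours if side == "ours" else current.theirs).append(line)
    # An unterminated hunk at end of file is dropped rather than reported.
    return hunks
```

Scanning line by line also catches a marker on the very first line of a file, which a substring check that requires a leading `\n` misses.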
if __name__ == "__main__":
on_merge_conflict(
pr_number=68,
username="MartinYe1234",
repo_full_name="MartinYe1234/Chess-Game",
installation_id=45945746,
tracking_id="ADD-BOB-2",
)

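The issue asks for a rebase instead of the merge commit Sweep currently creates (`on_merge_conflict` runs `git_repo.git.merge("origin/" + pr.base.ref)`). A minimal sketch of a switchable update step, assuming a hypothetical `"merge"`/`"rebase"` setting that Sweep's config does not define today, and a GitPython `Repo` whose `origin` remote has already been fetched. `update_commands` and `update_branch` are illustrative names; the git commands themselves (`merge`, `rebase`, `push --force-with-lease`) are standard.

```python
from typing import Literal

Strategy = Literal["merge", "rebase"]


def update_commands(strategy: Strategy, base_ref: str, branch: str) -> list[list[str]]:
    # Hypothetical helper: the git subcommands (as argument lists) that bring
    # `branch` up to date with origin/<base_ref>.
    upstream = f"origin/{base_ref}"
    if strategy == "merge":
        # Adds a merge commit; existing history is kept, so a plain push works.
        return [["merge", upstream], ["push", "origin", branch]]
    if strategy == "rebase":
        # Replays the branch's commits on top of the base. History is rewritten,
        # so the push must be forced; --force-with-lease refuses to overwrite
        # remote commits that were not seen by the last fetch.
        return [["rebase", upstream], ["push", "--force-with-lease", "origin", branch]]
    raise ValueError(f"unknown strategy: {strategy!r}")


def update_branch(git_repo, strategy: Strategy, base_ref: str, branch: str) -> None:
    # git_repo is assumed to be a GitPython Repo: repo.git.<command>(*args) runs
    # `git <command> <args>` and raises GitCommandError on a non-zero exit.
    for command, *args in update_commands(strategy, base_ref, branch):
        getattr(git_repo.git, command)(*args)
```

If the rebase stops on a conflict, GitPython raises `GitCommandError`; the caller would then need to run `git rebase --abort` (or resolve the files and `git rebase --continue`) before retrying.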
"""
This file contains the on_merge handler, which is called on pushes (such as merged pull requests) to the repository's configured Sweep branch.
on_merge is called by sweepai/api.py
"""
import time
from sweepai.config.client import SweepConfig, get_blocked_dirs, get_rules
from sweepai.core.post_merge import PostMerge
from sweepai.handlers.pr_utils import make_pr
from loguru import logger
from sweepai.utils.chat_logger import ChatLogger
from sweepai.utils.event_logger import posthog
from sweepai.utils.github_utils import get_github_client
# change threshold for number of lines changed
CHANGE_BOUNDS = (10, 1500)
# dictionary to map from github repo to the last time a rule was activated
merge_rule_debounce = {}
# debounce time in seconds
DEBOUNCE_TIME = 120
diff_section_prompt = """
<file_diff file="{diff_file_path}">
{diffs}
</file_diff>"""
def comparison_to_diff(comparison, blocked_dirs):
pr_diffs = []
for file in comparison.files:
diff = file.patch
if (
file.status == "added"
or file.status == "modified"
or file.status == "removed"
):
if any(file.filename.startswith(blocked_dir) for blocked_dir in blocked_dirs):
continue
pr_diffs.append((file.filename, diff))
else:
logger.info(
f"File status {file.status} not recognized"
) # TODO(sweep): We don't handle renamed files
formatted_diffs = []
for file_name, file_patch in pr_diffs:
format_diff = diff_section_prompt.format(
diff_file_path=file_name, diffs=file_patch
)
formatted_diffs.append(format_diff)
return "\n".join(formatted_diffs)
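`comparison_to_diff` skips a file when its name `startswith` a blocked directory. A raw string prefix also matches sibling paths that merely share the prefix (a blocked `docs` skips `docs_v2/guide.md` too), and an empty-string entry would match every path. A minimal sketch of a component-wise check using `pathlib.PurePosixPath`; `is_blocked` is a hypothetical helper, not Sweep's code.

```python
from pathlib import PurePosixPath


def is_blocked(filename: str, blocked_dirs: list[str]) -> bool:
    # Compare whole path components instead of raw string prefixes, so a
    # blocked "docs" matches "docs/a.md" but not "docs_v2/a.md".
    path_parts = PurePosixPath(filename).parts
    for blocked in blocked_dirs:
        blocked_parts = PurePosixPath(blocked.strip("/")).parts
        # Empty entries ("" or "/") have no parts and are ignored.
        if blocked_parts and path_parts[: len(blocked_parts)] == blocked_parts:
            return True
    return False
```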
def on_merge(request_dict: dict, chat_logger: ChatLogger):
before_sha = request_dict["before"]
after_sha = request_dict["after"]
commit_author = request_dict["sender"]["login"]
ref = request_dict["ref"]
if not ref.startswith("refs/heads/"):
return
user_token, g = get_github_client(request_dict["installation"]["id"])
repo = g.get_repo(
request_dict["repository"]["full_name"]
) # do this after checking ref
if ref[len("refs/heads/") :] != SweepConfig.get_branch(repo):
return
commit = repo.get_commit(after_sha)
check_suites = commit.get_check_suites()
for check_suite in check_suites:
if check_suite.conclusion == "failure":
return # if any check suite failed, return
blocked_dirs = get_blocked_dirs(repo)
comparison = repo.compare(before_sha, after_sha)
commits_diff = comparison_to_diff(comparison, blocked_dirs)
# check if the current repo is in the merge_rule_debounce dictionary
# and if the difference between the current time and the time stored in the dictionary is less than DEBOUNCE_TIME seconds
if (
repo.full_name in merge_rule_debounce
and time.time() - merge_rule_debounce[repo.full_name] < DEBOUNCE_TIME
):
return
merge_rule_debounce[repo.full_name] = time.time()
if not (
commits_diff.count("\n") >= CHANGE_BOUNDS[0]
and commits_diff.count("\n") <= CHANGE_BOUNDS[1]
):
return
rules = get_rules(repo)
rules = [rule for rule in rules if len(rule) > 0]
if not rules:
return
for rule in rules:
chat_logger.data["title"] = f"Sweep Rules - {rule}"
changes_required, issue_title, issue_description = PostMerge(
chat_logger=chat_logger
).check_for_issues(rule=rule, diff=commits_diff)
if changes_required:
make_pr(
title="[Sweep Rules] " + issue_title,
repo_description=repo.description,
summary=issue_description,
repo_full_name=request_dict["repository"]["full_name"],
installation_id=request_dict["installation"]["id"],
user_token=user_token,
use_faster_model=chat_logger.use_faster_model(),
username=commit_author,
chat_logger=chat_logger,
rule=rule,
)

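`on_merge` runs the rules check only when the repository has not been checked in the last `DEBOUNCE_TIME` seconds and the diff's newline count falls within `CHANGE_BOUNDS`. A minimal sketch of that gate as a pure function, assuming the timestamps live in a plain dict like `merge_rule_debounce`; `should_check_rules` is a hypothetical name and its defaults copy the module's constants.

```python
import time
from typing import Optional


def should_check_rules(
    repo_full_name: str,
    diff_text: str,
    last_run: dict[str, float],
    now: Optional[float] = None,
    debounce_seconds: float = 120,
    change_bounds: tuple[int, int] = (10, 1500),
) -> bool:
    now = time.time() if now is None else now
    previous = last_run.get(repo_full_name)
    # Skip repositories that were checked less than debounce_seconds ago.
    if previous is not None and now - previous < debounce_seconds:
        return False
    # Record the run before the size check, as on_merge does.
    last_run[repo_full_name] = now
    # on_merge measures diff size by counting newline characters.
    low, high = change_bounds
    return low <= diff_text.count("\n") <= high
```

Taking `now` as a parameter keeps the gate deterministic under test. Because the timestamp is stored before the size check, a push that is too small or too large still starts a debounce window, and since the store is an in-process dict, the debounce applies per worker process.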
import re
import traceback
from typing import TypeVar
from sweepai.config.server import DEFAULT_GPT4_32K_MODEL
from sweepai.core.chat import ChatGPT
from sweepai.core.entities import Message, RegexMatchableBaseModel
from loguru import logger
system_prompt = """You are a brilliant and meticulous engineer assigned to review the following commit diffs and make sure the file conforms to the user's rules.
If the diffs do not conform to the rules, we should create a GitHub issue telling the user what changes should be made.
Provide your response in the following format:
<rule_analysis>
- Analysis of each file_diff and whether it breaks the rule
...
</rule_analysis>
<changes_required>
Output "True" if the rule is broken, "False" otherwise
</changes_required>
<issue_title>
Write an issue title describing what file and rule to fix.
</issue_title>
<issue_description>
GitHub issue description for what we want to solve. Give general instructions on how to solve it. Mention files to take a look at and other code pointers.
</issue_description>"""
user_message = """Review the following diffs and make sure they conform to the rules:
{diff}
The rule is: {rule}
Provide your response in the following format:
<rule_analysis>
- Analysis of code diff 1 and whether it breaks the rule
- Analysis of code diff 2 and whether it breaks the rule
...
</rule_analysis>
<changes_required>
Output "True" if the rule is broken, "False" otherwise
</changes_required>
<issue_title>
Write an issue title describing what file and rule to fix.
</issue_title>
<issue_description>
GitHub issue description for what we want to solve. Give general instructions on how to solve it. Mention files to take a look at and other code pointers.
</issue_description>"""
Self = TypeVar("Self", bound="RegexMatchableBaseModel")
class IssueTitleAndDescription(RegexMatchableBaseModel):
changes_required: bool = False
issue_title: str
issue_description: str
@classmethod
def from_string(cls: type["IssueTitleAndDescription"], string: str, **kwargs) -> "IssueTitleAndDescription":
changes_required_pattern = (
r"""<changes_required>(\n)?(?P<changes_required>.*)</changes_required>"""
)
changes_required_match = re.search(changes_required_pattern, string, re.DOTALL)
changes_required = (
changes_required_match.groupdict()["changes_required"].strip()
if changes_required_match
else None
)
if changes_required and "true" in changes_required.lower():
changes_required = True
else:
changes_required = False
issue_title_pattern = r"""<issue_title>(\n)?(?P<issue_title>.*)</issue_title>"""
issue_title_match = re.search(issue_title_pattern, string, re.DOTALL)
issue_title = (
issue_title_match.groupdict()["issue_title"].strip()
if issue_title_match
else ""
)
issue_description_pattern = (
r"""<issue_description>(\n)?(?P<issue_description>.*)</issue_description>"""
)
issue_description_match = re.search(
issue_description_pattern, string, re.DOTALL
)
issue_description = (
issue_description_match.groupdict()["issue_description"].strip()
if issue_description_match
else ""
)
return cls(
changes_required=changes_required,
issue_title=issue_title,
issue_description=issue_description,
)
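`IssueTitleAndDescription.from_string` sets `changes_required` whenever the word "true" appears anywhere in the `<changes_required>` section. The sample response in this module shows the model echoing the instruction line ("True if the rule is broken, False otherwise") before its answer, so an echoed line followed by "False" would still parse as `True`. A minimal sketch that reads only the last non-empty line of the section, with non-greedy tag extraction; `extract_tag` and `parse_rule_check` are hypothetical helpers, and the last-line rule is a substitute for the current substring check, not Sweep's behavior.

```python
import re
from typing import Optional


def extract_tag(text: str, tag: str) -> Optional[str]:
    # Non-greedy (.*?) stops at the first closing tag; DOTALL lets the
    # content span multiple lines.
    pattern = rf"<{re.escape(tag)}>(.*?)</{re.escape(tag)}>"
    match = re.search(pattern, text, re.DOTALL)
    return match.group(1).strip() if match else None


def parse_rule_check(text: str) -> tuple[bool, str, str]:
    section = extract_tag(text, "changes_required") or ""
    lines = [line.strip() for line in section.splitlines() if line.strip()]
    # Treat only the final non-empty line as the verdict, so an echoed
    # instruction line cannot flip the result.
    changes_required = bool(lines) and lines[-1].strip('."').lower() == "true"
    return (
        changes_required,
        extract_tag(text, "issue_title") or "",
        extract_tag(text, "issue_description") or "",
    )
```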
class PostMerge(ChatGPT):
def check_for_issues(self, rule, diff) -> tuple[bool, str, str]:
try:
self.messages = [
Message(
role="system",
content=system_prompt.format(rule=rule),
key="system",
)
]
if self.chat_logger and not self.chat_logger.is_paying_user():
raise ValueError("User is not a paying user")
self.model = DEFAULT_GPT4_32K_MODEL
response = self.chat(
user_message.format(
rule=rule,
diff=diff,
)
)
issue_title_and_description = IssueTitleAndDescription.from_string(response)
return (
issue_title_and_description.changes_required,
issue_title_and_description.issue_title,
issue_title_and_description.issue_description,
)
except SystemExit:
raise
except Exception:
logger.error(f"An error occurred: {traceback.format_exc()}")
return False, "", ""
if __name__ == "__main__":
changes_required_response = """<rule_analysis>
- Analysis of code diff 1 and whether it breaks the rule
The code diff 1 does not break the rule. There are no docstrings or comments that need to be updated.
- Analysis of code diff 2 and whether it breaks the rule
The code diff 2 breaks the rule. There is a commented out code block that should be removed.
</rule_analysis>
<changes_required>
True if the rule is broken, False otherwise
True
</changes_required>
<issue_title>
Outdated Commented Code Block in plan-list.blade.php
</issue_title>
<issue_description>
There is an outdated commented out code block in the file `resources/views/livewire/plan-list.blade.php` that should be removed. The code block starts at line 104 and ends at line 110. Please remove this code block as it is no longer needed.
Please refer to the file `resources/views/livewire/plan-list.blade.php` and remove the commented out code block starting at line 104 and ending at line 110.
</issue_description>"""

import base64
import os
from dotenv import load_dotenv
from loguru import logger
logger.print = logger.info
load_dotenv(dotenv_path=".env", override=True, verbose=True)
os.environ["GITHUB_APP_PEM"] = os.environ.get("GITHUB_APP_PEM") or base64.b64decode(
os.environ.get("GITHUB_APP_PEM_BASE64", "")
).decode("utf-8")
if os.environ["GITHUB_APP_PEM"]:
os.environ["GITHUB_APP_ID"] = (
(os.environ.get("GITHUB_APP_ID") or os.environ.get("APP_ID"))
.replace("\\n", "\n")
.strip('"')
)
os.environ["TRANSFORMERS_CACHE"] = os.environ.get(
"TRANSFORMERS_CACHE", "/tmp/cache/model"
) # vector_db.py
os.environ["TIKTOKEN_CACHE_DIR"] = os.environ.get(
"TIKTOKEN_CACHE_DIR", "/tmp/cache/tiktoken"
) # utils.py
SENTENCE_TRANSFORMERS_MODEL = os.environ.get(
"SENTENCE_TRANSFORMERS_MODEL",
"sentence-transformers/all-MiniLM-L6-v2", # "all-mpnet-base-v2"
)
TEST_BOT_NAME = "sweep-nightly[bot]"
ENV = os.environ.get("ENV", "dev")
# ENV = os.environ.get("MODAL_ENVIRONMENT", "dev")
# ENV = PREFIX
# ENVIRONMENT = PREFIX
DB_MODAL_INST_NAME = "db"
DOCS_MODAL_INST_NAME = "docs"
API_MODAL_INST_NAME = "api"
UTILS_MODAL_INST_NAME = "utils"
BOT_TOKEN_NAME = "bot-token"
# goes under Modal 'discord' secret name (optional, can leave env var blank)
DISCORD_WEBHOOK_URL = os.environ.get("DISCORD_WEBHOOK_URL")
DISCORD_MEDIUM_PRIORITY_URL = os.environ.get("DISCORD_MEDIUM_PRIORITY_URL")
DISCORD_LOW_PRIORITY_URL = os.environ.get("DISCORD_LOW_PRIORITY_URL")
DISCORD_FEEDBACK_WEBHOOK_URL = os.environ.get("DISCORD_FEEDBACK_WEBHOOK_URL")
SWEEP_HEALTH_URL = os.environ.get("SWEEP_HEALTH_URL")
DISCORD_STATUS_WEBHOOK_URL = os.environ.get("DISCORD_STATUS_WEBHOOK_URL")
# goes under Modal 'github' secret name
GITHUB_APP_ID = os.environ.get("GITHUB_APP_ID", os.environ.get("APP_ID"))
# deprecated: legacy per-environment defaults, kept for backward compatibility
if GITHUB_APP_ID is None:
if ENV == "prod":
GITHUB_APP_ID = "307814"
elif ENV == "dev":
GITHUB_APP_ID = "324098"
elif ENV == "staging":
GITHUB_APP_ID = "327588"
GITHUB_BOT_USERNAME = os.environ.get("GITHUB_BOT_USERNAME")
# deprecated: left to support old logic
if not GITHUB_BOT_USERNAME:
if ENV == "prod":
GITHUB_BOT_USERNAME = "sweep-ai[bot]"
elif ENV == "dev":
GITHUB_BOT_USERNAME = "sweep-nightly[bot]"
elif ENV == "staging":
GITHUB_BOT_USERNAME = "sweep-canary[bot]"
elif not GITHUB_BOT_USERNAME.endswith("[bot]"):
GITHUB_BOT_USERNAME = GITHUB_BOT_USERNAME + "[bot]"
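The fallback chain above resolves the bot account name from `GITHUB_BOT_USERNAME` or, when it is unset, from `ENV`, and appends the `[bot]` suffix that GitHub uses for GitHub App accounts. A minimal table-driven sketch of the same logic; `resolve_bot_username` and `DEFAULT_BOT_USERNAMES` are hypothetical names.

```python
from typing import Optional

# Legacy per-environment defaults, copied from the if/elif chain above.
DEFAULT_BOT_USERNAMES = {
    "prod": "sweep-ai[bot]",
    "dev": "sweep-nightly[bot]",
    "staging": "sweep-canary[bot]",
}


def resolve_bot_username(configured: Optional[str], env: str) -> Optional[str]:
    if not configured:
        # Unknown environments yield None, matching the if/elif chain.
        return DEFAULT_BOT_USERNAMES.get(env)
    # GitHub App bot accounts are addressed as "<app-slug>[bot]".
    return configured if configured.endswith("[bot]") else configured + "[bot]"
```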
GITHUB_LABEL_NAME = os.environ.get("GITHUB_LABEL_NAME", "sweep")
GITHUB_LABEL_COLOR = os.environ.get("GITHUB_LABEL_COLOR", "9400D3")
GITHUB_LABEL_DESCRIPTION = os.environ.get(
"GITHUB_LABEL_DESCRIPTION", "Sweep your software chores"
)
GITHUB_APP_PEM = os.environ.get("GITHUB_APP_PEM")
GITHUB_APP_PEM = GITHUB_APP_PEM or os.environ.get("PRIVATE_KEY")
if GITHUB_APP_PEM is not None:
GITHUB_APP_PEM = GITHUB_APP_PEM.strip(' \n"') # Remove whitespace and quotes
GITHUB_APP_PEM = GITHUB_APP_PEM.replace("\\n", "\n")
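Taken together, the module's key handling checks `GITHUB_APP_PEM`, then a base64-encoded `GITHUB_APP_PEM_BASE64`, then `PRIVATE_KEY`, and finally strips surrounding quotes and whitespace and expands literal `\n` sequences. A minimal sketch of that precedence as one pure function over an environment mapping; `load_github_app_pem` is a hypothetical name, and reading from a mapping instead of `os.environ` is an assumption made for testability.

```python
import base64
from collections.abc import Mapping
from typing import Optional


def load_github_app_pem(env: Mapping[str, str]) -> Optional[str]:
    # Precedence: GITHUB_APP_PEM, then base64-decoded GITHUB_APP_PEM_BASE64,
    # then PRIVATE_KEY. b64decode("") returns b"", which falls through.
    pem = env.get("GITHUB_APP_PEM") or base64.b64decode(
        env.get("GITHUB_APP_PEM_BASE64", "")
    ).decode("utf-8")
    pem = pem or env.get("PRIVATE_KEY")
    if pem is None:
        return None
    # Drop surrounding spaces, newlines and double quotes, then turn literal
    # "\n" escape sequences (common in .env files) into real newlines.
    return pem.strip(' \n"').replace("\\n", "\n")
```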
GITHUB_CONFIG_BRANCH = os.environ.get("GITHUB_CONFIG_BRANCH", "sweep/add-sweep-config")
GITHUB_DEFAULT_CONFIG = os.environ.get(
"GITHUB_DEFAULT_CONFIG",
"""# Sweep AI turns bugs & feature requests into code changes (https://sweep.dev)
# For details on our config file, check out our docs at https://docs.sweep.dev/usage/config
# This setting contains a list of rules that Sweep will check for. If any of these rules are broken in a new commit, Sweep will create a pull request to fix the broken rule.
rules:
{additional_rules}
# This is the branch that Sweep will develop from and make pull requests to. Most people use 'main' or 'master' but some users also use 'dev' or 'staging'.
branch: 'main'
# By default Sweep will read the logs and outputs from your existing Github Actions. To disable this, set this to false.
gha_enabled: True
# This is the description of your project. It will be used by sweep when creating PRs. You can tell Sweep what's unique about your project, what frameworks you use, or anything else you want.
#
# Example:
#
# description: sweepai/sweep is a python project. The main api endpoints are in sweepai/api.py. Write code that adheres to PEP8.
description: ''
# This sets whether to create pull requests as drafts. If this is set to True, then all pull requests will be created as drafts and GitHub Actions will not be triggered.
draft: False
# This is a list of directories that Sweep will not be able to edit.
blocked_dirs: []
""",
)
MONGODB_URI = os.environ.get("MONGODB_URI", None)
IS_SELF_HOSTED = os.environ.get("IS_SELF_HOSTED", "true").lower() == "true"
REDIS_URL = os.environ.get("REDIS_URL")
if not REDIS_URL:
REDIS_URL = os.environ.get("redis_url", "redis://0.0.0.0:6379/0")
ORG_ID = os.environ.get("ORG_ID", None)
POSTHOG_API_KEY = os.environ.get(
"POSTHOG_API_KEY", "phc_CnzwIB0W548wN4wEGeRuxXqidOlEUH2AcyV2sKTku8n"
)
E2B_API_KEY = os.environ.get("E2B_API_KEY")
SUPPORT_COUNTRY = os.environ.get("GDRP_LIST", "").split(",")
WHITELISTED_REPOS = os.environ.get("WHITELISTED_REPOS", "").split(",")
BLACKLISTED_USERS = os.environ.get("BLACKLISTED_USERS", "").split(",")
os.environ["TOKENIZERS_PARALLELISM"] = "false"
ACTIVELOOP_TOKEN = os.environ.get("ACTIVELOOP_TOKEN", None)
VECTOR_EMBEDDING_SOURCE = os.environ.get(
"VECTOR_EMBEDDING_SOURCE", "openai"
) # Alternatives: "huggingface" or "replicate"; set the corresponding env vars below
BASERUN_API_KEY = os.environ.get("BASERUN_API_KEY", None)
# Huggingface settings, only checked if VECTOR_EMBEDDING_SOURCE == "huggingface"
HUGGINGFACE_URL = os.environ.get("HUGGINGFACE_URL", None)
HUGGINGFACE_TOKEN = os.environ.get("HUGGINGFACE_TOKEN", None)
# Replicate settings, only checked if VECTOR_EMBEDDING_SOURCE == "replicate"
REPLICATE_API_KEY = os.environ.get("REPLICATE_API_KEY", None)
REPLICATE_URL = os.environ.get("REPLICATE_URL", None)
REPLICATE_DEPLOYMENT_URL = os.environ.get("REPLICATE_DEPLOYMENT_URL", None)
# Default OpenAI
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", None)
OPENAI_API_TYPE = os.environ.get("OPENAI_API_TYPE", "anthropic")
assert OPENAI_API_TYPE in ["anthropic", "azure", "openai"], "Invalid OPENAI_API_TYPE"
OPENAI_EMBEDDINGS_API_TYPE = os.environ.get("OPENAI_EMBEDDINGS_API_TYPE", "openai")
AZURE_API_KEY = os.environ.get("AZURE_API_KEY", None)
OPENAI_API_BASE = os.environ.get("OPENAI_API_BASE", None)
OPENAI_API_VERSION = os.environ.get("OPENAI_API_VERSION", None)
AZURE_OPENAI_DEPLOYMENT = os.environ.get("AZURE_OPENAI_DEPLOYMENT", None)
OPENAI_EMBEDDINGS_AZURE_ENDPOINT = os.environ.get(
"OPENAI_EMBEDDINGS_AZURE_ENDPOINT", None
)
OPENAI_EMBEDDINGS_AZURE_API_KEY = os.environ.get(
"OPENAI_EMBEDDINGS_AZURE_API_KEY", None
)
OPENAI_EMBEDDINGS_AZURE_DEPLOYMENT = os.environ.get(
"OPENAI_EMBEDDINGS_AZURE_DEPLOYMENT", None
)
OPENAI_EMBEDDINGS_AZURE_API_VERSION = os.environ.get(
"OPENAI_EMBEDDINGS_AZURE_API_VERSION", None
)
OPENAI_API_ENGINE_GPT35 = os.environ.get("OPENAI_API_ENGINE_GPT35", None)
OPENAI_API_ENGINE_GPT4 = os.environ.get("OPENAI_API_ENGINE_GPT4", None)
OPENAI_API_ENGINE_GPT4_32K = os.environ.get("OPENAI_API_ENGINE_GPT4_32K", None)
MULTI_REGION_CONFIG = os.environ.get("MULTI_REGION_CONFIG", None)
if isinstance(MULTI_REGION_CONFIG, str):
MULTI_REGION_CONFIG = MULTI_REGION_CONFIG.strip("'").replace("\\n", "\n")
MULTI_REGION_CONFIG = [item.split(",") for item in MULTI_REGION_CONFIG.split("\n")]
WHITELISTED_USERS = os.environ.get("WHITELISTED_USERS", None)
if WHITELISTED_USERS:
WHITELISTED_USERS = WHITELISTED_USERS.split(",")
WHITELISTED_USERS.append(GITHUB_BOT_USERNAME)
DEFAULT_GPT4_32K_MODEL = os.environ.get("DEFAULT_GPT4_32K_MODEL", "gpt-4-0125-preview")
DEFAULT_GPT35_MODEL = os.environ.get("DEFAULT_GPT35_MODEL", "gpt-3.5-turbo-1106")
RESEND_API_KEY = os.environ.get("RESEND_API_KEY", None)
LOKI_URL = None
DEBUG = os.environ.get("DEBUG", "false").lower() == "true"
ENV = "prod" if GITHUB_BOT_USERNAME != TEST_BOT_NAME else "dev"
PROGRESS_BASE_URL = os.environ.get(
"PROGRESS_BASE_URL", "https://progress.sweep.dev"
).rstrip("/")
DISABLED_REPOS = os.environ.get("DISABLED_REPOS", "").split(",")
GHA_AUTOFIX_ENABLED: bool = os.environ.get("GHA_AUTOFIX_ENABLED", "false").lower() == "true"
MERGE_CONFLICT_ENABLED: bool = os.environ.get("MERGE_CONFLICT_ENABLED", "false").lower() == "true"
INSTALLATION_ID = os.environ.get("INSTALLATION_ID", None)
AWS_ACCESS_KEY=os.environ.get("AWS_ACCESS_KEY")
AWS_SECRET_KEY=os.environ.get("AWS_SECRET_KEY")
AWS_REGION=os.environ.get("AWS_REGION")
ANTHROPIC_AVAILABLE = AWS_ACCESS_KEY and AWS_SECRET_KEY and AWS_REGION
USE_ASSISTANT = os.environ.get("USE_ASSISTANT", "true").lower() == "true"
ANTHROPIC_API_KEY = os.environ.get("ANTHROPIC_API_KEY", None)
VOYAGE_API_KEY = os.environ.get("VOYAGE_API_KEY", None)
VOYAGE_API_AWS_ACCESS_KEY=os.environ.get("VOYAGE_API_AWS_ACCESS_KEY_ID")
VOYAGE_API_AWS_SECRET_KEY=os.environ.get("VOYAGE_API_AWS_SECRET_KEY")
VOYAGE_API_AWS_REGION=os.environ.get("VOYAGE_API_AWS_REGION")
VOYAGE_API_AWS_ENDPOINT_NAME=os.environ.get("VOYAGE_API_AWS_ENDPOINT_NAME", "voyage-code-2")
VOYAGE_API_USE_AWS = VOYAGE_API_AWS_ACCESS_KEY and VOYAGE_API_AWS_SECRET_KEY and VOYAGE_API_AWS_REGION
PAREA_API_KEY = os.environ.get("PAREA_API_KEY", None)
# TODO: we need to make this dynamic + backoff
BATCH_SIZE = int(

import datetime
import difflib
import hashlib
import json
import os
import re
import shutil
import subprocess
import tempfile
import time
import traceback
from dataclasses import dataclass
from functools import cached_property
from typing import Any
import git
import requests
from github import Github, PullRequest, Repository, InputGitTreeElement
from jwt import encode
from loguru import logger
from sweepai.config.client import SweepConfig
from sweepai.config.server import GITHUB_APP_ID, GITHUB_APP_PEM, GITHUB_BOT_USERNAME
from sweepai.utils.tree_utils import DirectoryTree, remove_all_not_included
MAX_FILE_COUNT = 50
def make_valid_string(string: str):
pattern = r"[^\w./-]+"
return re.sub(pattern, "_", string)
def get_jwt():
signing_key = GITHUB_APP_PEM
app_id = GITHUB_APP_ID
payload = {"iat": int(time.time()), "exp": int(time.time()) + 600, "iss": app_id}
return encode(payload, signing_key, algorithm="RS256")
def get_token(installation_id: int):
if int(installation_id) < 0:
return os.environ["GITHUB_PAT"]
for timeout in [5.5, 5.5, 10.5]:
try:
jwt = get_jwt()
headers = {
"Accept": "application/vnd.github+json",
"Authorization": "Bearer " + jwt,
"X-GitHub-Api-Version": "2022-11-28",
}
response = requests.post(
f"https://api.github.com/app/installations/{int(installation_id)}/access_tokens",
headers=headers,
)
obj = response.json()
if "token" not in obj:
logger.error(obj)
raise Exception("Could not get token")
return obj["token"]
except SystemExit:
raise
except Exception:
time.sleep(timeout)
raise Exception(
"Could not get token, please double check your PRIVATE_KEY and GITHUB_APP_ID in the .env file. Make sure to restart uvicorn after."
)
def get_app():
jwt = get_jwt()
headers = {
"Accept": "application/vnd.github+json",
"Authorization": "Bearer " + jwt,
"X-GitHub-Api-Version": "2022-11-28",
}
response = requests.get("https://api.github.com/app", headers=headers)
return response.json()
def get_github_client(installation_id: int):
if not installation_id:
return os.environ["GITHUB_PAT"], Github(os.environ["GITHUB_PAT"])
token: str = get_token(installation_id)
return token, Github(token)
# fetch installation object
def get_installation(username: str):
jwt = get_jwt()
try:
# Try user
response = requests.get(
f"https://api.github.com/users/{username}/installation",
headers={
"Accept": "application/vnd.github+json",
"Authorization": "Bearer " + jwt,
"X-GitHub-Api-Version": "2022-11-28",
},
)
obj = response.json()
return obj
except Exception:
# Try org
response = requests.get(
f"https://api.github.com/orgs/{username}/installation",
headers={
"Accept": "application/vnd.github+json",
"Authorization": "Bearer " + jwt,
"X-GitHub-Api-Version": "2022-11-28",
},
)
try:
obj = response.json()
return obj["id"]
except Exception as e:
logger.error(e)
logger.error(response.text)
raise Exception("Could not get installation, probably not installed")
def get_installation_id(username: str) -> str:
jwt = get_jwt()
try:
# Try user
response = requests.get(
f"https://api.github.com/users/{username}/installation",
headers={
"Accept": "application/vnd.github+json",
"Authorization": "Bearer " + jwt,
"X-GitHub-Api-Version": "2022-11-28",
},
)
obj = response.json()
return obj["id"]
except Exception:
# Try org
response = requests.get(
f"https://api.github.com/orgs/{username}/installation",
headers={
"Accept": "application/vnd.github+json",
"Authorization": "Bearer " + jwt,
"X-GitHub-Api-Version": "2022-11-28",
},
)
try:
obj = response.json()
return obj["id"]
except Exception as e:
logger.error(e)
logger.error(response.text)
raise Exception("Could not get installation id, probably not installed")
# commits multiple files in a single commit, returns the commit object
def commit_multi_file_changes(repo: Repository, file_changes: dict[str, str], commit_message: str, branch: str):
blobs_to_commit = []
# convert to blob
for path, content in file_changes.items():
blob = repo.create_git_blob(content, "utf-8")
blobs_to_commit.append(InputGitTreeElement(path=path, mode="100644", type="blob", sha=blob.sha))
latest_commit = repo.get_branch(branch).commit
base_tree = latest_commit.commit.tree
# create new git tree
new_tree = repo.create_git_tree(blobs_to_commit, base_tree=base_tree)
# commit the changes
parent = repo.get_git_commit(latest_commit.sha)
commit = repo.create_git_commit(
commit_message,
new_tree,
[parent],
)
# update ref of branch
ref = f"heads/{branch}"
repo.get_git_ref(ref).edit(sha=commit.sha)
return commit
REPO_CACHE_BASE_DIR = "/tmp/cache/repos"
@dataclass
class ClonedRepo:
repo_full_name: str
installation_id: str
branch: str | None = None
token: str | None = None
repo: Any | None = None
git_repo: git.Repo | None = None
class Config:
arbitrary_types_allowed = True
@cached_property
def cached_dir(self):
self.repo = (
Github(self.token).get_repo(self.repo_full_name)
if not self.repo
else self.repo
)
self.branch = self.branch or SweepConfig.get_branch(self.repo)
return os.path.join(
REPO_CACHE_BASE_DIR,
self.repo_full_name,
"base",
parse_collection_name(self.branch),
)
@cached_property
def zip_path(self):
logger.info("Zipping repository...")
shutil.make_archive(self.repo_dir, "zip", self.repo_dir)
logger.info("Done zipping")
return f"{self.repo_dir}.zip"
@cached_property
def repo_dir(self):
self.repo = (
Github(self.token).get_repo(self.repo_full_name)
if not self.repo
else self.repo
)
self.branch = self.branch or SweepConfig.get_branch(self.repo)
curr_time_str = str(time.time()).encode("utf-8")
hash_obj = hashlib.sha256(curr_time_str)
hash_hex = hash_obj.hexdigest()
if self.branch:
return os.path.join(
REPO_CACHE_BASE_DIR,
self.repo_full_name,
hash_hex,
parse_collection_name(self.branch),
)
else:
return os.path.join("/tmp/cache/repos", self.repo_full_name, hash_hex)
@property
def clone_url(self):
return (
f"https://x-access-token:{self.token}@github.com/{self.repo_full_name}.git"
)
def clone(self):
if not os.path.exists(self.cached_dir):
logger.info("Cloning repo...")
if self.branch:
repo = git.Repo.clone_from(
self.clone_url, self.cached_dir, branch=self.branch
)
else:
repo = git.Repo.clone_from(self.clone_url, self.cached_dir)
logger.info("Done cloning")
else:
try:
repo = git.Repo(self.cached_dir)
repo.remotes.origin.pull(
kill_after_timeout=60, progress=git.RemoteProgress()
)
except Exception:
logger.error("Could not pull repo")
shutil.rmtree(self.cached_dir, ignore_errors=True)
repo = git.Repo.clone_from(self.clone_url, self.cached_dir)
logger.info("Repo already cached, copying")
logger.info("Copying repo...")
shutil.copytree(
self.cached_dir, self.repo_dir, symlinks=True, copy_function=shutil.copy
)
logger.info("Done copying")
repo = git.Repo(self.repo_dir)
return repo
def __post_init__(self):
subprocess.run(["git", "config", "--global", "http.postBuffer", "524288000"])
self.token = self.token or get_token(self.installation_id)
self.repo = (
Github(self.token).get_repo(self.repo_full_name)
if not self.repo
else self.repo
)
self.commit_hash = self.repo.get_commits()[0].sha
self.git_repo = self.clone()
self.branch = self.branch or SweepConfig.get_branch(self.repo)
def __del__(self):
try:
shutil.rmtree(self.repo_dir)
os.remove(self.zip_path)
return True
except Exception:
return False
def list_directory_tree(
self,
included_directories=None,
excluded_directories: list[str] = None,
included_files=None,
):
"""Build the directory tree rooted at self.repo_dir.
Arguments:
included_directories -- List of directory paths (relative to the root) to include in the tree. Defaults to None (all directories).
excluded_directories -- List of directory names to exclude from the tree. Defaults to None, which uses SweepConfig().exclude_dirs.
included_files -- Currently unused. Defaults to None.
Returns a (tree string, DirectoryTree) tuple.
"""
root_directory = self.repo_dir
sweep_config: SweepConfig = SweepConfig()
# Default values if parameters are not provided
if included_directories is None:
included_directories = [] # gets all directories
if excluded_directories is None:
excluded_directories = sweep_config.exclude_dirs
def list_directory_contents(
current_directory: str,
excluded_directories: list[str],
indentation="",
):
"""Recursively list the contents of directories."""
file_and_folder_names = os.listdir(current_directory)
file_and_folder_names.sort()
directory_tree_string = ""
for name in file_and_folder_names[:MAX_FILE_COUNT]:
relative_path = os.path.join(current_directory, name)[
len(root_directory) + 1 :
]
if name in excluded_directories:
continue
complete_path = os.path.join(current_directory, name)
if os.path.isdir(complete_path):
directory_tree_string += f"{indentation}{relative_path}/\n"
directory_tree_string += list_directory_contents(
complete_path,
excluded_directories,
indentation + " ",
)
else:
directory_tree_string += f"{indentation}{name}\n"
# if os.path.isfile(complete_path) and relative_path in included_files:
# # Todo, use these to fetch neighbors
# ctags_str, names = get_ctags_for_file(ctags, complete_path)
# ctags_str = "\n".join([indentation + line for line in ctags_str.splitlines()])
# if ctags_str.strip():
# directory_tree_string += f"{ctags_str}\n"
return directory_tree_string
dir_obj = DirectoryTree()
directory_tree = list_directory_contents(root_directory, excluded_directories)
dir_obj.parse(directory_tree)
if included_directories:
dir_obj = remove_all_not_included(dir_obj, included_directories)
return directory_tree, dir_obj
def get_file_list(self) -> list[str]:
root_directory = self.repo_dir
files = []
sweep_config: SweepConfig = SweepConfig()
def dfs_helper(directory):
nonlocal files
for item in os.listdir(directory):
if item == ".git":
continue
if item in sweep_config.exclude_dirs: # this saves a lot of time
continue
item_path = os.path.join(directory, item)
if os.path.isfile(item_path):
# make sure the item_path is not in one of the banned directories
if not sweep_config.is_file_excluded(item_path):
files.append(item_path) # Add the file to the list
elif os.path.isdir(item_path):
dfs_helper(item_path) # Recursive call to explore subdirectory
dfs_helper(root_directory)
files = [file[len(root_directory) + 1 :] for file in files]
return files
def get_file_contents(self, file_path, ref=None):
local_path = (
f"{self.repo_dir}{file_path}"
if file_path.startswith("/")
else f"{self.repo_dir}/{file_path}"
)
if os.path.exists(local_path):
with open(local_path, "r", encoding="utf-8", errors="replace") as f:
contents = f.read()
return contents
else:
raise FileNotFoundError(f"{local_path} does not exist.")
def get_num_files_from_repo(self):
# subprocess.run(["git", "config", "--global", "http.postBuffer", "524288000"])
self.git_repo.git.checkout(self.branch)
file_list = self.get_file_list()
return len(file_list)
def get_commit_history(
self, username: str = "", limit: int = 200, time_limited: bool = True
):
commit_history = []
try:
if username != "":
commit_list = list(self.git_repo.iter_commits(author=username))
else:
commit_list = list(self.git_repo.iter_commits())
line_count = 0
cut_off_date = datetime.datetime.now() - datetime.timedelta(days=7)
for commit in commit_list:
# must be within a week
if time_limited and commit.authored_datetime.replace(
tzinfo=None
) <= cut_off_date.replace(tzinfo=None):
logger.info("Exceeded cut off date, stopping...")
break
repo = get_github_client(self.installation_id)[1].get_repo(
self.repo_full_name
)
branch = SweepConfig.get_branch(repo)
if branch not in self.git_repo.git.branch():
branch = f"origin/{branch}"
diff = self.git_repo.git.diff(commit, branch, unified=1)
lines = diff.count("\n")
# total diff lines must not exceed `limit` (200 by default)
if lines + line_count > limit:
logger.info(f"Exceeded {limit} lines of diff, stopping...")
break
commit_history.append(
f"<commit>\nAuthor: {commit.author.name}\nMessage: {commit.message}\n{diff}\n</commit>"
)
line_count += lines
except Exception:
logger.error(f"An error occurred: {traceback.format_exc()}")
return commit_history
def get_similar_file_paths(self, file_path: str, limit: int = 10):
from rapidfuzz.fuzz import ratio
# Fuzzy search over file names
file_name = os.path.basename(file_path)
all_file_paths = self.get_file_list()
# filter for matching extensions if both have extensions
if "." in file_name:
all_file_paths = [
file
for file in all_file_paths
if "." in file and file.split(".")[-1] == file_name.split(".")[-1]
]
files_with_matching_name = []
files_without_matching_name = []
for file_path in all_file_paths:
if file_name in file_path:
files_with_matching_name.append(file_path)
else:
files_without_matching_name.append(file_path)
file_path_to_ratio = {file: ratio(file_name, file) for file in all_file_paths}
files_with_matching_name = sorted(
files_with_matching_name,
key=lambda file_path: file_path_to_ratio[file_path],
reverse=True,
)
files_without_matching_name = sorted(
files_without_matching_name,
key=lambda file_path: file_path_to_ratio[file_path],
reverse=True,
)
# this allows 'config.py' to return 'sweepai/config/server.py', 'sweepai/config/client.py', 'sweepai/config/__init__.py' and no more
filtered_files_without_matching_name = list(filter(lambda file_path: file_path_to_ratio[file_path] > 50, files_without_matching_name))
all_files = files_with_matching_name + filtered_files_without_matching_name
return all_files[:limit]
# updates a file with new_contents, returns True if successful
def update_file(root_dir: str, file_path: str, new_contents: str):
local_path = os.path.join(root_dir, file_path)
try:
with open(local_path, "w") as f:
f.write(new_contents)
return True
except Exception as e:
logger.error(f"Failed to update file: {e}")
return False
@dataclass
class MockClonedRepo(ClonedRepo):
_repo_dir: str = ""
git_repo: git.Repo | None = None
def __init__(
self,
_repo_dir: str,
repo_full_name: str,
installation_id: str = "",
branch: str | None = None,
token: str | None = None,
repo: Any | None = None,
git_repo: git.Repo | None = None,
):
self._repo_dir = _repo_dir
self.repo_full_name = repo_full_name
self.installation_id = installation_id
self.branch = branch
self.token = token
self.repo = repo
@classmethod
def from_dir(cls, repo_dir: str, **kwargs):
return cls(_repo_dir=repo_dir, **kwargs)
@property
def cached_dir(self):
return self._repo_dir
@property
def repo_dir(self):
return self._repo_dir
@property
def git_repo(self):
return git.Repo(self.repo_dir)
def clone(self):
return git.Repo(self.repo_dir)
def __post_init__(self):
return self
def __del__(self):
return True
@dataclass
class TemporarilyCopiedClonedRepo(MockClonedRepo):
tmp_dir: tempfile.TemporaryDirectory | None = None
def __init__(
self,
_repo_dir: str,
tmp_dir: tempfile.TemporaryDirectory,
repo_full_name: str,
installation_id: str = "",
branch: str | None = None,
token: str | None = None,
repo: Any | None = None,
git_repo: git.Repo | None = None,
):
self._repo_dir = _repo_dir
self.tmp_dir = tmp_dir
self.repo_full_name = repo_full_name
self.installation_id = installation_id
self.branch = branch
self.token = token
self.repo = repo
@classmethod
def copy_from_cloned_repo(cls, cloned_repo: ClonedRepo, **kwargs):
temp_dir = tempfile.TemporaryDirectory()
new_dir = temp_dir.name + "/" + cloned_repo.repo_full_name.split("/")[1]
print("Copying...")
shutil.copytree(cloned_repo.repo_dir, new_dir)
print("Done copying.")
return cls(
_repo_dir=new_dir,
tmp_dir=temp_dir,
repo_full_name=cloned_repo.repo_full_name,
installation_id=cloned_repo.installation_id,
branch=cloned_repo.branch,
token=cloned_repo.token,
repo=cloned_repo.repo,
**kwargs,
)
def __del__(self):
print(f"Dropping {self.tmp_dir.name}...")
shutil.rmtree(self._repo_dir, ignore_errors=True)
self.tmp_dir.cleanup()
print("Done.")
return True
def get_file_names_from_query(query: str) -> list[str]:
query_file_names = re.findall(r"\b[\w\-\.\/]*\w+\.\w{1,6}\b", query)
return [
query_file_name
for query_file_name in query_file_names
if len(query_file_name) > 3
]
def get_hunks(a: str, b: str, context=10):
differ = difflib.Differ()
diff = [
line
for line in differ.compare(a.splitlines(), b.splitlines())
if line[0] in ("+", "-", " ")
]
show = set()
hunks = []
for i, line in enumerate(diff):
if line.startswith(("+", "-")):
show.update(range(max(0, i - context), min(len(diff), i + context + 1)))
for i in range(len(diff)):
if i in show:
hunks.append(diff[i])
elif i - 1 in show:
hunks.append("...")
if len(hunks) > 0 and hunks[0] == "...":
hunks = hunks[1:]
if len(hunks) > 0 and hunks[-1] == "...":
hunks = hunks[:-1]
return "\n".join(hunks)
def parse_collection_name(name: str) -> str:
# Replace every character that is not a word character or hyphen with a double hyphen
name = re.sub(r"[^\w-]", "--", name)
# Ensure the name is between 3 and 63 characters and starts/ends with alphanumeric
name = re.sub(r"^(-*\w{0,61}\w)-*$", r"\1", name[:63].ljust(3, "x"))
return name
# set whether or not a pr is a draft via the GraphQL API, there is no way to do this using pygithub
def convert_pr_draft_field(pr: PullRequest, is_draft: bool = False):
    pr_id = pr.raw_data["node_id"]
    # convertPullRequestToDraft makes the PR a draft; markPullRequestReadyForReview undoes it
    mutation_name = (
        "convertPullRequestToDraft" if is_draft else "markPullRequestReadyForReview"
    )
    # pass the node id as a GraphQL variable instead of splicing it into the query text
    mutation = f"""
    mutation SetDraftState($pullRequestId: ID!) {{
      {mutation_name}(input: {{pullRequestId: $pullRequestId}}) {{
        pullRequest {{
          id
        }}
      }}
    }}
    """
    # GraphQL API URL
    url = "https://api.github.com/graphql"
    headers = {
        "Accept": "application/vnd.github+json",
        "X-Github-Api-Version": "2022-11-28",
        "Authorization": "Bearer " + os.environ["GITHUB_PAT"],
    }
    json_data = {"query": mutation, "variables": {"pullRequestId": pr_id}}
    response = requests.post(url, headers=headers, data=json.dumps(json_data))
    # GraphQL reports failures in an "errors" list, often with HTTP 200
    if response.status_code != 200 or response.json().get("errors"):
        logger.error(f"Failed to convert PR to {'draft' if is_draft else 'open'}")
        return False
    return True
try:
g = Github(os.environ.get("GITHUB_PAT"))
CURRENT_USERNAME = g.get_user().login
except Exception:
try:
slug = get_app()["slug"]
CURRENT_USERNAME = f"{slug}[bot]"
except Exception:
CURRENT_USERNAME = GITHUB_BOT_USERNAME
if __name__ == "__main__":
try:
organization_name = "sweepai"
sweep_config = SweepConfig()
installation_id = get_installation_id(organization_name)
user_token, g = get_github_client(installation_id)
cloned_repo = ClonedRepo("sweepai/sweep", installation_id, "main")
dir_ojb = cloned_repo.list_directory_tree()
commit_history = cloned_repo.get_commit_history()
similar_file_paths = cloned_repo.get_similar_file_paths("config.py")
# ensure no similar file_paths are sweep excluded
assert(not any([file for file in similar_file_paths if sweep_config.is_file_excluded(file)]))
print(f"similar_file_paths: {similar_file_paths}")
str1 = "a\nline1\nline2\nline3\nline4\nline5\nline6\ntest\n"
str2 = "a\nline1\nlineTwo\nline3\nline4\nline5\nlineSix\ntset\n"
print(get_hunks(str1, str2, 1))
mocked_repo = MockClonedRepo.from_dir(
cloned_repo.repo_dir,
repo_full_name="sweepai/sweep",
)
temp_repo = TemporarilyCopiedClonedRepo.copy_from_cloned_repo(mocked_repo)
print(f"mocked repo: {mocked_repo}")
except Exception as e:
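The `convert_pr_draft_field` helper above already talks to GitHub's GraphQL endpoint. That API can also update a PR branch on GitHub's side: the `updatePullRequestBranch` mutation accepts an `updateMethod` of `MERGE` or `REBASE`, and it also takes an optional `expectedHeadOid` guard, which is omitted here. The following is a minimal sketch, assuming a token with write access to the repository. `build_update_branch_payload` and `update_pr_branch` are hypothetical helper names, not existing Sweep functions. The PR node ID would come from `pr.raw_data["node_id"]`, as in `convert_pr_draft_field`. The sketch uses `urllib.request` instead of `requests` only to stay dependency-free. A rebase that GitHub cannot complete (for example, because of conflicts) is likely to come back as an entry in `errors` rather than as a non-200 status.

```python
import json
import urllib.request

GITHUB_GRAPHQL_URL = "https://api.github.com/graphql"

# updateMethod takes the PullRequestBranchUpdateMethod enum: MERGE or REBASE.
UPDATE_BRANCH_MUTATION = """
mutation UpdateBranch($pullRequestId: ID!, $updateMethod: PullRequestBranchUpdateMethod) {
  updatePullRequestBranch(input: {pullRequestId: $pullRequestId, updateMethod: $updateMethod}) {
    pullRequest {
      id
      headRefOid
    }
  }
}
"""


def build_update_branch_payload(pr_node_id: str, rebase: bool = True) -> dict:
    """Build the JSON body for the updatePullRequestBranch mutation (hypothetical helper)."""
    return {
        "query": UPDATE_BRANCH_MUTATION,
        "variables": {
            "pullRequestId": pr_node_id,
            "updateMethod": "REBASE" if rebase else "MERGE",
        },
    }


def update_pr_branch(token: str, pr_node_id: str, rebase: bool = True) -> dict:
    """POST the mutation and return its "data" field; raise if GitHub reports errors."""
    request = urllib.request.Request(
        GITHUB_GRAPHQL_URL,
        data=json.dumps(build_update_branch_payload(pr_node_id, rebase)).encode("utf-8"),
        headers={
            "Authorization": f"Bearer {token}",
            "Content-Type": "application/json",
        },
        method="POST",
    )
    with urllib.request.urlopen(request, timeout=30) as response:
        body = json.loads(response.read().decode("utf-8"))
    # GraphQL failures arrive in an "errors" list, usually with HTTP 200.
    if body.get("errors"):
        raise RuntimeError(f"updatePullRequestBranch failed: {body['errors']}")
    return body["data"]
```

Because GitHub performs the rebase itself, this route needs no local clone or force-push. The trade-off is that conflicts cannot be resolved by Sweep and simply surface as an error.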

import re
from dataclasses import dataclass
from functools import lru_cache
from rapidfuzz import fuzz
from tqdm import tqdm
from sweepai.logn import file_cache
from loguru import logger
@lru_cache()
def score_line(str1: str, str2: str) -> float:
if str1 == str2:
return 100
if str1.lstrip() == str2.lstrip():
whitespace_ratio = abs(len(str1) - len(str2)) / (len(str1) + len(str2))
score = 90 - whitespace_ratio * 10
return max(score, 0)
if str1.strip() == str2.strip():
whitespace_ratio = abs(len(str1) - len(str2)) / (len(str1) + len(str2))
score = 80 - whitespace_ratio * 10
return max(score, 0)
levenshtein_ratio = fuzz.ratio(str1, str2)
score = 85 * (levenshtein_ratio / 100)
return max(score, 0)
def match_without_whitespace(str1: str, str2: str) -> bool:
return str1.strip() == str2.strip()
def line_cost(line: str) -> float:
if line.strip() == "":
return 50
if line.strip().startswith("#") or line.strip().startswith("//"):
return 50 + len(line) / (len(line) + 1) * 30
return len(line) / (len(line) + 1) * 100
def score_multiline(query: list[str], target: list[str]) -> float:
# TODO: add weighting on first and last lines
q, t = 0, 0 # indices for query and target
scores: list[tuple[float, float]] = []
skipped_comments = 0
def get_weight(q: int) -> float:
# Prefers lines at beginning and end of query
# Sequence: 1, 2/3, 1/2, 2/5...
index = min(q, len(query) - q)
return 100 / (index / 2 + 1)
while q < len(query) and t < len(target):
q_line = query[q]
t_line = target[t]
weight = get_weight(q)
if match_without_whitespace(q_line, t_line):
# Case 1: lines match
scores.append((score_line(q_line, t_line), weight))
q += 1
t += 1
elif q_line.strip().startswith("...") or q_line.strip().endswith("..."):
# Case 3: ellipsis wildcard
t += 1
if q + 1 == len(query):
scores.append((100 - (len(target) - t), weight))
q += 1
t = len(target)
break
max_score = 0
# Radix optimization
indices = [
t + i
for i, line in enumerate(target[t:])
if match_without_whitespace(line, query[q + 1])
]
if not indices:
# logger.warning(f"Could not find whitespace match, using brute force")
indices = range(t, len(target))
for i in indices:
score, weight = score_multiline(query[q + 1 :], target[i:]), (
100 - (i - t) / len(target) * 10
)
new_scores = scores + [(score, weight)]
total_score = sum(
[value * weight for value, weight in new_scores]
) / sum([weight for _, weight in new_scores])
max_score = max(max_score, total_score)
return max_score
elif (
t_line.strip() == ""
or t_line.strip().startswith("#")
or t_line.strip().startswith("//")
or t_line.strip().startswith("print")
or t_line.strip().startswith("logger")
or t_line.strip().startswith("console.")
):
# Case 2: skipped comment
skipped_comments += 1
t += 1
scores.append((90, weight))
else:
break
if q < len(query):
scores.extend(
(100 - line_cost(line), get_weight(index))
for index, line in enumerate(query[q:])
)
if t < len(target):
scores.extend(
(100 - line_cost(line), 100) for index, line in enumerate(target[t:])
)
final_score = (
sum([value * weight for value, weight in scores])
/ sum([weight for _, weight in scores])
if scores
else 0
)
final_score *= 1 - 0.05 * skipped_comments
return final_score
@dataclass
class Match:
start: int
end: int
score: float
indent: str = ""
def __gt__(self, other):
return self.score > other.score
def get_indent_type(content: str):
two_spaces = len(re.findall(r"\n {2}[^ ]", content))
four_spaces = len(re.findall(r"\n {4}[^ ]", content))
return "  " if two_spaces > four_spaces else "    "
def get_max_indent(content: str, indent_type: str):
return max(len(line) - len(line.lstrip()) for line in content.split("\n")) // len(
indent_type
)
@file_cache()
def find_best_match(query: str, code_file: str):
best_match = Match(-1, -1, 0)
code_file_lines = code_file.split("\n")
query_lines = query.split("\n")
if len(query_lines) > 0 and query_lines[-1].strip() == "...":
query_lines = query_lines[:-1]
if len(query_lines) > 0 and query_lines[0].strip() == "...":
query_lines = query_lines[1:]
indent = get_indent_type(code_file)
max_indents = get_max_indent(code_file, indent)
top_matches = []
if len(query_lines) == 1:
for i, line in enumerate(code_file_lines):
score = score_line(line, query_lines[0])
if score > best_match.score:
best_match = Match(i, i + 1, score)
return best_match
truncate = min(40, len(code_file_lines) // 5)
if truncate < 1:
truncate = len(code_file_lines)
indent_array = [i for i in range(0, max(min(max_indents + 1, 20), 1))]
if max_indents > 3:
indent_array = [3, 2, 4, 0, 1] + list(range(5, max_indents + 1))
for num_indents in indent_array:
indented_query_lines = [indent * num_indents + line for line in query_lines]
start_pairs = [
(i, score_line(line, indented_query_lines[0]))
for i, line in enumerate(code_file_lines)
]
start_pairs.sort(key=lambda x: x[1], reverse=True)
start_pairs = start_pairs[:truncate]
start_indices = [i for i, _ in start_pairs]
for i in tqdm(
start_indices,
position=0,
desc=f"Indent {num_indents}/{max_indents}",
leave=False,
):
end_pairs = [
(j, score_line(line, indented_query_lines[-1]))
for j, line in enumerate(code_file_lines[i:], start=i)
]
end_pairs.sort(key=lambda x: x[1], reverse=True)
end_pairs = end_pairs[:truncate]
end_indices = [j for j, _ in end_pairs]
for j in tqdm(
end_indices, position=1, leave=False, desc=f"Starting line {i}"
):
candidate = code_file_lines[i : j + 1]
raw_score = score_multiline(indented_query_lines, candidate)
score = raw_score * (1 - num_indents * 0.01)
current_match = Match(i, j + 1, score, indent * num_indents)
if raw_score >= 99.99: # early exit, 99.99 for floating point error
logger.info(f"Exact match found! Returning: {current_match}")
return current_match
top_matches.append(current_match)
if score > best_match.score:
best_match = current_match
unique_top_matches: list[Match] = []
unique_spans = set()
for top_match in sorted(top_matches, reverse=True):
if (top_match.start, top_match.end) not in unique_spans:
unique_top_matches.append(top_match)
unique_spans.add((top_match.start, top_match.end))
for top_match in unique_top_matches[:5]:
logger.info(top_match)
# Todo: on_comment file comments able to modify multiple files
return unique_top_matches[0] if unique_top_matches else Match(-1, -1, 0)
def split_ellipses(query: str) -> list[str]:
queries = []
current_query = ""
for line in query.split("\n"):
if line.strip() == "...":
queries.append(current_query.strip("\n"))
current_query = ""
else:
current_query += line + "\n"
queries.append(current_query.strip("\n"))
return queries
def match_indent(generated: str, original: str) -> str:
indent_type = "\t" if "\t" in original[:5] else " "
generated_indents = len(generated) - len(generated.lstrip())
target_indents = len(original) - len(original.lstrip())
diff_indents = target_indents - generated_indents
if diff_indents > 0:
generated = indent_type * diff_indents + generated.replace(
"\n", "\n" + indent_type * diff_indents
)
return generated
old_code = """
\"\"\"
on_ticket is the main function that is called when a new issue is created.
It is only called by the webhook handler in sweepai/api.py.
\"\"\"
# TODO: Add file validation
import math
import re
import traceback
from time import time
import openai
import requests
from github import BadCredentialsException
from logtail import LogtailHandler
from loguru import logger
from requests.exceptions import Timeout
from tabulate import tabulate
from tqdm import tqdm"""
new_code = """
\"\"\"
on_ticket is the main function that is called when a new issue is created.
It is only called by the webhook handler in sweepai/api.py.
\"\"\"
# TODO: Add file validation
import math
import re
import traceback
from time import time
import hashlib
import openai
import requests
from github import BadCredentialsException
from logtail import LogtailHandler
from loguru import logger
from requests.exceptions import Timeout
from tabulate import tabulate
from tqdm import tqdm"""
# print(match_indent(new_code, old_code))
test_code = """\
def naive_euclidean_profile(X, q, mask):
r\"\"\"
Compute a euclidean distance profile in a brute force way.
A distance profile between a (univariate) time series :math:`X_i = {x_1, ..., x_m}`
and a query :math:`Q = {q_1, ..., q_m}` is defined as a vector of size :math:`m-(
l-1)`, such as :math:`P(X_i, Q) = {d(C_1, Q), ..., d(C_m-(l-1), Q)}` with d the
Euclidean distance, and :math:`C_j = {x_j, ..., x_{j+(l-1)}}` the j-th candidate
subsequence of size :math:`l` in :math:`X_i`.
\"\"\"
return _naive_euclidean_profile(X, q, mask)
"""
if __name__ == "__main__":
    # for section in split_ellipses(test_code):
    #     print(section)
    code_file = r"""
from loguru import logger
from github.Repository import Repository
from sweepai.config.client import RESET_FILE, REVERT_CHANGED_FILES_TITLE, RULES_LABEL, RULES_TITLE, get_rules
from sweepai.utils.event_logger import posthog
from sweepai.core.post_merge import PostMerge
from sweepai.core.sweep_bot import SweepBot
from sweepai.events import IssueCommentRequest
from sweepai.handlers.on_merge import comparison_to_diff
from sweepai.handlers.pr_utils import make_pr
from sweepai.utils.buttons import ButtonList, check_button_title_match
from sweepai.utils.chat_logger import ChatLogger
from sweepai.utils.github_utils import get_github_client
def handle_button_click(request_dict):
    request = IssueCommentRequest(**request_dict)
    user_token, gh_client = get_github_client(request_dict["installation"]["id"])
    button_list = ButtonList.deserialize(request_dict["comment"]["body"])
    selected_buttons = [button.label for button in button_list.get_clicked_buttons()]
    repo = gh_client.get_repo(request_dict["repository"]["full_name"]) # do this after checking ref
    comment_id = request.comment.id
    pr = repo.get_pull(request_dict["issue"]["number"])
    comment = pr.get_issue_comment(comment_id)
    if check_button_title_match(REVERT_CHANGED_FILES_TITLE, request.comment.body, request.changes):
        revert_files = []
        for button_text in selected_buttons:
            revert_files.append(button_text.split(f"{RESET_FILE} ")[-1].strip())
        handle_revert(revert_files, request_dict["issue"]["number"], repo)
        comment.edit(
            body=ButtonList(
                buttons=[
                    button
                    for button in button_list.buttons
                    if button.label not in selected_buttons
                ],
                title = REVERT_CHANGED_FILES_TITLE,
            ).serialize()
        )
"""
    # Sample target snippet
    target = """
from loguru import logger
from github.Repository import Repository
from sweepai.config.client import RESET_FILE, REVERT_CHANGED_FILES_TITLE, RULES_LABEL, RULES_TITLE, get_rules
from sweepai.utils.event_logger import posthog
from sweepai.core.post_merge import PostMerge
from sweepai.core.sweep_bot import SweepBot
from sweepai.events import IssueCommentRequest
from sweepai.handlers.on_merge import comparison_to_diff
from sweepai.handlers.pr_utils import make_pr
from sweepai.utils.buttons import ButtonList, check_button_title_match
from sweepai.utils.chat_logger import ChatLogger
from sweepai.utils.github_utils import get_github_client
def handle_button_click(request_dict):
    request = IssueCommentRequest(**request_dict)
    user_token, gh_client = get_github_client(request_dict["installation"]["id"])
    button_list = ButtonList.deserialize(request_dict["comment"]["body"])
    selected_buttons = [button.label for button in button_list.get_clicked_buttons()]
    repo = gh_client.get_repo(request_dict["repository"]["full_name"]) # do this after checking ref
    comment_id = request.comment.id
    pr = repo.get_pull(request_dict["issue"]["number"])
    comment = pr.get_issue_comment(comment_id)
    ...
""".strip(
        "\n"
    )
    # Find the best match
    # best_span = find_best_match(target, code_file)
    best_span = find_best_match("a\nb", "a\nb")


Step 2: ⌨️ Coding

Modify sweepai/api.py with contents:
• In the `update_sweep_prs_v2` function, find the code block that performs the merge:
```python
repo.merge(
    feature_branch,
    pr.base.ref,
    f"Merge main into {feature_branch}",
)
```
• Replace the `repo.merge` call with the following to perform a rebase instead:
```python
repo.rebase(pr.base.ref, feature_branch)
```
• Update the commit message to reflect the rebase operation.
• If there are any merge conflicts during the rebase, catch the exception and handle it appropriately (e.g. by closing the PR similar to the existing merge conflict handling).
  • Modify sweepai/utils/github_utils.py: no changes made (1e1b8c1)
Modify sweepai/utils/github_utils.py with contents:
• In the `ClonedRepo` class, check if there are any methods involved in the merge process (e.g. in the `clone` method).
• If found, update those methods to use `git rebase` instead of `git merge` when updating the PR branch.
• Ensure the rebase is performed against the `origin/` remote branch.
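The plan above calls `repo.rebase(...)`. As far as we know, however, PyGithub's `Repository` has no `rebase` method, and the GitHub REST API's update-branch endpoint only merges the base branch into the PR branch. A rebase would therefore most likely have to run in a local clone, such as the one `ClonedRepo` manages.

The sketch below is minimal and rests on three assumptions: the PR branch is already checked out in `repo_dir`, `origin` is authenticated, and the `git` CLI is available. `rebase_and_push` and its `run` parameter are hypothetical names, not part of Sweep. If the rebase hits a conflict, the helper aborts it and returns `False`, so a caller could fall back to the existing merge-conflict handling.

```python
import subprocess
from typing import Callable

# Hypothetical helper (not part of Sweep). `run` is injectable so the git
# calls can be stubbed out; by default it shells out via subprocess.run.
Runner = Callable[..., subprocess.CompletedProcess]


def rebase_and_push(repo_dir: str, base_ref: str, run: Runner = subprocess.run) -> bool:
    """Rebase the checked-out branch onto origin/<base_ref> and force-push it.

    Returns True on success, or False if the rebase failed (e.g. conflicts),
    in which case the rebase is aborted and nothing is pushed.
    """

    def git(*args: str) -> subprocess.CompletedProcess:
        return run(["git", *args], cwd=repo_dir, capture_output=True, text=True)

    fetched = git("fetch", "origin", base_ref)
    if fetched.returncode != 0:
        raise RuntimeError(f"git fetch failed: {fetched.stderr}")

    rebased = git("rebase", f"origin/{base_ref}")
    if rebased.returncode != 0:
        # Conflict or other failure: restore the branch to its pre-rebase state.
        git("rebase", "--abort")
        return False

    # A rebase rewrites history, so a plain push would be rejected.
    # --force-with-lease refuses to overwrite remote commits we have not seen.
    pushed = git("push", "--force-with-lease", "origin", "HEAD")
    if pushed.returncode != 0:
        raise RuntimeError(f"git push failed: {pushed.stderr}")
    return True
```

With `run` left at its default, the helper calls the real `git` binary. It uses `--force-with-lease` rather than `--force` so that the push fails instead of silently discarding commits someone else pushed to the PR branch in the meantime.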

Step 3: 🔁 Code Review

I have finished reviewing the code for completeness. I did not find errors for sweep/allow_for_rebase_ccbe6.


🎉 Latest improvements to Sweep:
  • New dashboard launched for real-time tracking of Sweep issues, covering all stages from search to coding.
  • Integration of OpenAI's latest Assistant API for more efficient and reliable code planning and editing, improving speed by 3x.
  • Use the GitHub issues extension for creating Sweep issues directly from your editor.

💡 To recreate the pull request, edit the issue title or description.
Something wrong? Let us know.

This is an automated message generated by Sweep AI.

sweep-nightly bot linked a pull request on Apr 6, 2024 that will close this issue

sweep-nightly bot commented Apr 6, 2024

Sweeping

✨ Track Sweep's progress on our progress dashboard!



💎 Sweep Pro: I'm using GPT-4. You have unlimited GPT-4 tickets. (tracking ID: c004ce5405)

Tip

I can email you when I complete this pull request if you set up your email here!



I am currently looking into this ticket! I will update the progress of the ticket in this comment. I am currently searching through your code, looking for relevant snippets.


Step 1: 🔎 Searching

I'm searching for relevant snippets in your repository. If this is your first time using Sweep, I'm indexing your repository. You can monitor the progress using the progress dashboard.




sweep-nightly bot commented Apr 8, 2024

🚀 Here's the PR! #3498

See Sweep's progress at the progress dashboard!
💎 Sweep Pro: I'm using GPT-4. You have unlimited GPT-4 tickets. (tracking ID: 43ca2f81de)


Step 1: 🔎 Searching

I found the following snippets in your repository. I will now analyze these snippets and come up with a plan.

Here are the code snippets I think are relevant, in decreasing order of relevance. If a file is missing, you can mention its path in the ticket description.

import time
import traceback
from git import GitCommandError
from github.PullRequest import PullRequest
from loguru import logger
from sweepai.config.server import PROGRESS_BASE_URL
from sweepai.core import entities
from sweepai.core.entities import FileChangeRequest
from sweepai.core.sweep_bot import SweepBot
from sweepai.handlers.create_pr import create_pr_changes
from sweepai.handlers.on_ticket import get_branch_diff_text, sweeping_gif
from sweepai.utils.chat_logger import ChatLogger, discord_log_error
from sweepai.utils.diff import generate_diff
from sweepai.utils.event_logger import posthog
from sweepai.utils.github_utils import ClonedRepo, get_github_client
from sweepai.utils.progress import (
    PaymentContext,
    TicketContext,
    TicketProgress,
    TicketProgressStatus,
)
from sweepai.utils.prompt_constructor import HumanMessagePrompt
from sweepai.utils.str_utils import to_branch_name
from sweepai.utils.ticket_utils import center
instructions_format = """Resolve the merge conflicts in the PR by incorporating changes from both branches into the final code.
Title of PR: {title}
Here were the original changes to this file in the head branch:
Commit message: {head_commit_message}
```diff
{head_diff}
```
Here were the original changes to this file in the base branch:
Commit message: {base_commit_message}
```diff
{base_diff}
```
In the analysis_and_identification, first determine what each change does. Then determine what the final code should be. Then, use the keyword_search to find the merge conflict markers <<<<<<< and >>>>>>>. Finally, make the code changes by writing the old_code and the new_code."""
def on_merge_conflict(
    pr_number: int,
    username: str,
    repo_full_name: str,
    installation_id: int,
    tracking_id: str,
):
    # copied from stack_pr
    token, g = get_github_client(installation_id=installation_id)
    try:
        repo = g.get_repo(repo_full_name)
    except Exception as e:
        # Without the repo nothing below can run, so log and re-raise.
        logger.error(f"Exception occurred while getting repo: {e}")
        raise
    pr: PullRequest = repo.get_pull(pr_number)
    branch = pr.head.ref
    status_message = center(
        f"{sweeping_gif}\n\n"
        + f'Resolving merge conflicts: track the progress <a href="{PROGRESS_BASE_URL}/issues/{tracking_id}">here</a>.'
    )
    header = f"{status_message}\n---\n\nI'm currently resolving the merge conflicts in this PR. I will stack a new PR once I'm done."
    comment = None
    for current_comment in pr.get_issue_comments():
        if (
            current_comment.user.login == "sweep-nightly[bot]"
            and "Resolving merge conflicts: track the progress" in current_comment.body
        ):
            current_comment.edit(body=header)
            comment = current_comment
            break
    # Only create a new status comment if no existing one was reused.
    if comment is None:
        comment = pr.create_issue_comment(body=header)
    def edit_comment(body):
        nonlocal comment
        comment.edit(header + "\n\n" + body)
    metadata = {}
    try:
        cloned_repo = ClonedRepo(
            repo_full_name=repo_full_name,
            installation_id=installation_id,
            branch=branch,
            token=token,
        )
        time.time()
        request = f"Sweep: Resolve merge conflicts for PR #{pr_number}: {pr.title}"
        title = request
        if len(title) > 50:
            title = title[:50] + "..."
        chat_logger = ChatLogger(
            data={
                "username": username,
                "metadata": metadata,
                "tracking_id": tracking_id,
            }
        )
        is_paying_user = chat_logger.is_paying_user()
        chat_logger.is_consumer_tier()
        # this logic is partly taken from on_ticket.py, if there is an issue please refer to that file
        if chat_logger:
            use_faster_model = chat_logger.use_faster_model()
        else:
            is_paying_user = True
            use_faster_model = False
        ticket_progress = TicketProgress(
            tracking_id=tracking_id,
            username=username,
            context=TicketContext(
                title=title,
                description="",
                repo_full_name=repo_full_name,
                branch_name="sweep/" + to_branch_name(request),
                issue_number=pr_number,
                is_public=repo.private is False,
                start_time=int(time.time()),
                # mostly copied from on_ticket, if issue please check that file
                payment_context=PaymentContext(
                    use_faster_model=use_faster_model,
                    pro_user=is_paying_user,
                    daily_tickets_used=(
                        chat_logger.get_ticket_count(use_date=True)
                        if chat_logger
                        else 0
                    ),
                    monthly_tickets_used=(
                        chat_logger.get_ticket_count() if chat_logger else 0
                    ),
                ),
            ),
        )
        metadata = {
            "tracking_id": tracking_id,
            "username": username,
            "function": "on_merge_conflict",
            **ticket_progress.context.dict(),
        }
        posthog.capture(
            username,
            "started",
            properties=metadata,
        )
        issue_url = pr.html_url
        edit_comment("Configuring branch...")
        new_pull_request = entities.PullRequest(
            title=title,
            branch_name="sweep/" + branch + "-merge-conflict",
            content="",
        )
        # Making sure name is unique
        for i in range(30):
            try:
                repo.get_branch(new_pull_request.branch_name + "_" + str(i))
            except Exception:
                new_pull_request.branch_name += "_" + str(i)
                break
        # Merge the base branch (origin/pr.base.ref) into a new copy of the PR's head branch
        git_repo = cloned_repo.git_repo
        old_head_branch = git_repo.branches[branch]
        head_branch = git_repo.create_head(
            new_pull_request.branch_name,
            commit=old_head_branch.commit,
        )
        head_branch.checkout()
        try:
            git_repo.config_writer().set_value(
                "user", "name", "sweep-nightly[bot]"
            ).release()
            git_repo.config_writer().set_value(
                "user", "email", "[email protected]"
            ).release()
            git_repo.git.merge("origin/" + pr.base.ref)
        except GitCommandError:
            # Assume there are merge conflicts
            pass
        git_repo.git.add(update=True)
        # -m and message are needed otherwise exception is thrown
        git_repo.git.commit("-m", "Start of Merge Conflict Resolution")
        origin = git_repo.remotes.origin
        new_url = f"https://x-access-token:{token}@github.com/{repo_full_name}.git"
        origin.set_url(new_url)
        git_repo.git.push("--set-upstream", origin, new_pull_request.branch_name)
        last_commit = git_repo.head.commit
        all_files = [item.a_path for item in last_commit.diff("HEAD~1")]
        conflict_files = []
        for file in all_files:
            try:
                contents = open(cloned_repo.repo_dir + "/" + file).read()
                if "\n<<<<<<<" in contents and "\n>>>>>>>" in contents:
                    conflict_files.append(file)
            except UnicodeDecodeError:
                pass
        snippets = []
        for conflict_file in conflict_files:
            contents = open(cloned_repo.repo_dir + "/" + conflict_file).read()
            snippet = entities.Snippet(
                file_path=conflict_file,
                start=0,
                end=len(contents.splitlines()),
                content=contents,
            )
            snippets.append(snippet)
        tree = ""
        ticket_progress.status = TicketProgressStatus.PLANNING
        ticket_progress.save()
        human_message = HumanMessagePrompt(
            repo_name=repo_full_name,
            issue_url=issue_url,
            username=username,
            repo_description=(repo.description or "").strip(),
            title=request,
            summary=request,
            snippets=snippets,
            tree=tree,
        )
        sweep_bot = SweepBot.from_system_message_content(
            human_message=human_message,
            repo=repo,
            ticket_progress=ticket_progress,
            chat_logger=chat_logger,
            cloned_repo=cloned_repo,
            branch=new_pull_request.branch_name,
        )
        # can select more precise snippets
        file_change_requests = []
        base_commits = pr.base.repo.get_commits().get_page(0)
        head_commits = list(pr.get_commits())
        for conflict_file in conflict_files:
            old_code = repo.get_contents(
                conflict_file, ref=head_commits[0].parents[0].sha
            ).decoded_content.decode()
            base_code = repo.get_contents(
                conflict_file, ref=pr.base.ref
            ).decoded_content.decode()
            head_code = repo.get_contents(
                conflict_file, ref=pr.head.ref
            ).decoded_content.decode()
            base_diff = generate_diff(old_code=old_code, new_code=base_code)
            head_diff = generate_diff(old_code=old_code, new_code=head_code)
            base_commit_message = ""
            for commit in base_commits[::-1]:
                if any(
                    commit_file.filename == conflict_file
                    for commit_file in commit.files
                ):
                    base_commit_message = commit.raw_data["commit"]["message"]
                    break
            head_commit_message = ""
            for commit in head_commits[::-1]:
                if any(
                    commit_file.filename == conflict_file
                    for commit_file in commit.files
                ):
                    head_commit_message = commit.raw_data["commit"]["message"]
                    break
            file_change_requests.append(
                FileChangeRequest(
                    filename=conflict_file,
                    instructions=instructions_format.format(
                        title=pr.title,
                        base_commit_message=base_commit_message,
                        base_diff=base_diff,
                        head_commit_message=head_commit_message,
                        head_diff=head_diff,
                    ),
                    change_type="modify",
                )
            )
        ticket_progress.status = TicketProgressStatus.CODING
        ticket_progress.save()
        edit_comment("Resolving merge conflicts...")
        generator = create_pr_changes(
            file_change_requests,
            new_pull_request,
            sweep_bot,
            username,
            installation_id,
            pr_number,
            chat_logger=chat_logger,
            base_branch=new_pull_request.branch_name,
        )
        for item in generator:
            if isinstance(item, dict):
                break
            # create_pr_changes yields 4-tuples:
            # (new_file_contents, changed_file, commit, file_change_requests)
            (
                new_file_contents,
                changed_file,
                commit,
                file_change_requests,
            ) = item
            logger.info(
                "Statuses: {}", [fcr.status for fcr in file_change_requests]
            )
        ticket_progress.status = TicketProgressStatus.COMPLETE
        ticket_progress.save()
        edit_comment("Done creating pull request.")
        get_branch_diff_text(repo, new_pull_request.branch_name)
        new_description = f"This PR resolves the merge conflicts in #{pr_number}. This branch can be directly merged into {pr.base.ref}.\n\nFixes #{pr_number}."
        # Create pull request
        new_pull_request.content = new_description
        github_pull_request = repo.create_pull(
            title=request,
            body=new_description,
            head=new_pull_request.branch_name,
            base=pr.base.ref,
        )
        ticket_progress.context.pr_id = github_pull_request.number
        ticket_progress.context.done_time = time.time()
        ticket_progress.save()
        edit_comment(f"✨ **Created Pull Request:** {github_pull_request.html_url}")
        posthog.capture(
            username,
            "success",
            properties=metadata,
        )
        return {"success": True}
    except Exception as e:
        logger.error(f"Exception occurred: {e}")
        edit_comment(
            f"> [!CAUTION]\n> An error has occurred: {str(e)} (tracking ID: {tracking_id})"
        )
        discord_log_error(
            "Error occurred in on_merge_conflict.py\n"
            + traceback.format_exc()
            + "\n\n"
            + str(e)
            + "\n\n"
            + f"tracking ID: {tracking_id}"
        )
        posthog.capture(
            username,
            "failed",
            properties=metadata,
        )
        return {"success": False}
if __name__ == "__main__":
    on_merge_conflict(
        pr_number=68,
        username="MartinYe1234",
        repo_full_name="MartinYe1234/Chess-Game",
        installation_id=45945746,
        tracking_id="ADD-BOB-2",
    )
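The snippet above treats a file as conflicted when its contents include both `"\n<<<<<<<"` and `"\n>>>>>>>"`. That test misses a conflict block that starts on the file's first line. The sketch below is a stricter check. It assumes Git's usual conflict markers (`<<<<<<<`, `=======`, `>>>>>>>`), each at the start of its own line, and requires all three to appear in that order. `has_conflict_markers` is a hypothetical helper, not part of Sweep.

```python
import re

# Hypothetical helper (not part of Sweep). Git conflict markers always start
# a line; "<<<<<<<" and ">>>>>>>" are normally followed by a space and a label.
_START = re.compile(r"^<{7}(?: |\r?$)", re.MULTILINE)
_MIDDLE = re.compile(r"^={7}\r?$", re.MULTILINE)
_END = re.compile(r"^>{7}(?: |\r?$)", re.MULTILINE)


def has_conflict_markers(text: str) -> bool:
    """Return True if text contains a start, middle and end marker, in order."""
    start = _START.search(text)
    if start is None:
        return False
    # Requiring the markers in order means a lone "=======" (for example a
    # reStructuredText heading underline) is not mistaken for a conflict.
    middle = _MIDDLE.search(text, start.end())
    if middle is None:
        return False
    return _END.search(text, middle.end()) is not None
```

In `on_merge_conflict`, a check like this could replace the substring test when building `conflict_files`.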

"""
This file contains the on_merge handler which is called when a pull request is merged to master.
on_merge is called by sweepai/api.py
"""
import time
from sweepai.config.client import SweepConfig, get_blocked_dirs, get_rules
from sweepai.core.post_merge import PostMerge
from sweepai.handlers.pr_utils import make_pr
from loguru import logger
from sweepai.utils.chat_logger import ChatLogger
from sweepai.utils.event_logger import posthog
from sweepai.utils.github_utils import get_github_client
# change threshold for number of lines changed
CHANGE_BOUNDS = (10, 1500)
# dictionary to map from github repo to the last time a rule was activated
merge_rule_debounce = {}
# debounce time in seconds
DEBOUNCE_TIME = 120
diff_section_prompt = """
<file_diff file="{diff_file_path}">
{diffs}
</file_diff>"""
def comparison_to_diff(comparison, blocked_dirs):
    pr_diffs = []
    for file in comparison.files:
        diff = file.patch
        if (
            file.status == "added"
            or file.status == "modified"
            or file.status == "removed"
        ):
            if any(file.filename.startswith(dir) for dir in blocked_dirs):
                continue
            pr_diffs.append((file.filename, diff))
        else:
            logger.info(
                f"File status {file.status} not recognized"
            )  # TODO(sweep): We don't handle renamed files
    formatted_diffs = []
    for file_name, file_patch in pr_diffs:
        format_diff = diff_section_prompt.format(
            diff_file_path=file_name, diffs=file_patch
        )
        formatted_diffs.append(format_diff)
    return "\n".join(formatted_diffs)
def on_merge(request_dict: dict, chat_logger: ChatLogger):
    before_sha = request_dict["before"]
    after_sha = request_dict["after"]
    commit_author = request_dict["sender"]["login"]
    ref = request_dict["ref"]
    if not ref.startswith("refs/heads/"):
        return
    user_token, g = get_github_client(request_dict["installation"]["id"])
    repo = g.get_repo(
        request_dict["repository"]["full_name"]
    )  # do this after checking ref
    if ref[len("refs/heads/") :] != SweepConfig.get_branch(repo):
        return
    commit = repo.get_commit(after_sha)
    check_suites = commit.get_check_suites()
    for check_suite in check_suites:
        if check_suite.conclusion == "failure":
            return  # if any check suite failed, return
    blocked_dirs = get_blocked_dirs(repo)
    comparison = repo.compare(before_sha, after_sha)
    commits_diff = comparison_to_diff(comparison, blocked_dirs)
    # check if the current repo is in the merge_rule_debounce dictionary
    # and if the difference between the current time and the time stored in the dictionary is less than DEBOUNCE_TIME seconds
    if (
        repo.full_name in merge_rule_debounce
        and time.time() - merge_rule_debounce[repo.full_name] < DEBOUNCE_TIME
    ):
        return
    merge_rule_debounce[repo.full_name] = time.time()
    if not (
        commits_diff.count("\n") >= CHANGE_BOUNDS[0]
        and commits_diff.count("\n") <= CHANGE_BOUNDS[1]
    ):
        return
    rules = get_rules(repo)
    rules = [rule for rule in rules if len(rule) > 0]
    if not rules:
        return
    for rule in rules:
        chat_logger.data["title"] = f"Sweep Rules - {rule}"
        changes_required, issue_title, issue_description = PostMerge(
            chat_logger=chat_logger
        ).check_for_issues(rule=rule, diff=commits_diff)
        if changes_required:
            make_pr(
                title="[Sweep Rules] " + issue_title,
                repo_description=repo.description,
                summary=issue_description,
                repo_full_name=request_dict["repository"]["full_name"],
                installation_id=request_dict["installation"]["id"],
                user_token=user_token,
                use_faster_model=chat_logger.use_faster_model(),
                username=commit_author,
                chat_logger=chat_logger,
                rule=rule,
            )
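`on_merge` only checks the rules after two cheap gates pass. First, the repository must not have triggered a check within the last `DEBOUNCE_TIME` (120) seconds. Second, the formatted diff must contain between `CHANGE_BOUNDS` (10 to 1500) newlines. The timestamp is recorded before the size check, so a push with a tiny diff still starts a new debounce window.

Below is a pure-function sketch of that gate, with the clock passed in so it can be tested. `should_check_rules` is a hypothetical name, not part of Sweep.

```python
import time
from typing import Optional

CHANGE_BOUNDS = (10, 1500)
DEBOUNCE_TIME = 120  # seconds


def should_check_rules(
    repo_full_name: str,
    commits_diff: str,
    last_run: dict[str, float],
    now: Optional[float] = None,
) -> bool:
    """Hypothetical restatement of the debounce and diff-size gates in on_merge."""
    now = time.time() if now is None else now
    previous = last_run.get(repo_full_name)
    if previous is not None and now - previous < DEBOUNCE_TIME:
        return False  # a check ran for this repo too recently
    # As in on_merge, record the timestamp before checking the diff size.
    last_run[repo_full_name] = now
    line_count = commits_diff.count("\n")
    return CHANGE_BOUNDS[0] <= line_count <= CHANGE_BOUNDS[1]
```

Passing `last_run` explicitly, rather than using a module-level dict like `merge_rule_debounce`, keeps the sketch free of global state.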

"""
create_pr is a function that creates a pull request from a list of file change requests.
It is also responsible for handling Sweep config PR creation.
"""
import datetime
from typing import Any, Generator
import openai
from github.Repository import Repository
from loguru import logger
from sweepai.config.client import DEFAULT_RULES_STRING, SweepConfig, get_blocked_dirs
from sweepai.config.server import (
    ENV,
    GITHUB_BOT_USERNAME,
    GITHUB_CONFIG_BRANCH,
    GITHUB_DEFAULT_CONFIG,
    GITHUB_LABEL_NAME,
    MONGODB_URI,
)
from sweepai.core.entities import (
    FileChangeRequest,
    MaxTokensExceeded,
    Message,
    MockPR,
    PullRequest,
)
from sweepai.core.sweep_bot import SweepBot
from sweepai.utils.chat_logger import ChatLogger
from sweepai.utils.event_logger import posthog
from sweepai.utils.github_utils import ClonedRepo, get_github_client
from sweepai.utils.str_utils import UPDATES_MESSAGE
num_of_snippets_to_query = 10
max_num_of_snippets = 5
INSTRUCTIONS_FOR_REVIEW = """\
### 💡 To get Sweep to edit this pull request, you can:
* Comment below, and Sweep can edit the entire PR
* Comment on a file, Sweep will only modify the commented file
* Edit the original issue to get Sweep to recreate the PR from scratch"""
def create_pr_changes(
    file_change_requests: list[FileChangeRequest],
    pull_request: PullRequest,
    sweep_bot: SweepBot,
    username: str,
    installation_id: int,
    issue_number: int | None = None,
    chat_logger: ChatLogger = None,
    base_branch: str = None,
    additional_messages: list[Message] = []
) -> Generator[tuple[FileChangeRequest, int, Any], None, dict]:
    # Flow:
    # 1. Get relevant files
    # 2. Get human message
    # 3. Get files to change
    # 4. Get file changes
    # 5. Create PR
    chat_logger = (
        chat_logger
        if chat_logger is not None
        else ChatLogger(
            {
                "username": username,
                "installation_id": installation_id,
                "repo_full_name": sweep_bot.repo.full_name,
                "title": pull_request.title,
                "summary": "",
                "issue_url": "",
            }
        )
        if MONGODB_URI
        else None
    )
    sweep_bot.chat_logger = chat_logger
    organization, repo_name = sweep_bot.repo.full_name.split("/")
    metadata = {
        "repo_full_name": sweep_bot.repo.full_name,
        "organization": organization,
        "repo_name": repo_name,
        "repo_description": sweep_bot.repo.description,
        "username": username,
        "installation_id": installation_id,
        "function": "create_pr",
        "mode": ENV,
        "issue_number": issue_number,
    }
    posthog.capture(username, "started", properties=metadata)
    try:
        logger.info("Making PR...")
        pull_request.branch_name = sweep_bot.create_branch(
            pull_request.branch_name, base_branch=base_branch
        )
        completed_count, fcr_count = 0, len(file_change_requests)
        blocked_dirs = get_blocked_dirs(sweep_bot.repo)
        for (
            new_file_contents,
            changed_file,
            commit,
            file_change_requests,
        ) in sweep_bot.change_files_in_github_iterator(
            file_change_requests,
            pull_request.branch_name,
            blocked_dirs,
            additional_messages=additional_messages
        ):
            completed_count += len(new_file_contents or [])
            logger.info(f"Completed {completed_count}/{fcr_count} files")
            yield new_file_contents, changed_file, commit, file_change_requests
        if completed_count == 0 and fcr_count != 0:
            logger.info("No changes made")
            posthog.capture(
                username,
                "failed",
                properties={
                    "error": "No changes made",
                    "reason": "No changes made",
                    **metadata,
                },
            )
            # If no changes were made, delete branch
            commits = sweep_bot.repo.get_commits(pull_request.branch_name)
            if commits.totalCount == 0:
                branch = sweep_bot.repo.get_git_ref(f"heads/{pull_request.branch_name}")
                branch.delete()
            return
        # Include issue number in PR description
        if issue_number:
            # If the #issue changes, then change on_ticket (f'Fixes #{issue_number}.\n' in pr.body:)
            pr_description = (
                f"{pull_request.content}\n\nFixes"
                f" #{issue_number}.\n\n---\n\n{UPDATES_MESSAGE}\n\n---\n\n{INSTRUCTIONS_FOR_REVIEW}"
            )
        else:
            pr_description = f"{pull_request.content}"
        pr_title = pull_request.title
        if "sweep.yaml" in pr_title:
            pr_title = "[config] " + pr_title
    except MaxTokensExceeded as e:
        logger.error(e)
        posthog.capture(
            username,
            "failed",
            properties={
                "error": str(e),
                "reason": "Max tokens exceeded",
                **metadata,
            },
        )
        raise e
    except openai.BadRequestError as e:
        logger.error(e)
        posthog.capture(
            username,
            "failed",
            properties={
                "error": str(e),
                "reason": "Invalid request error / context length",
                **metadata,
            },
        )
        raise e
    except Exception as e:
        logger.error(e)
        posthog.capture(
            username,
            "failed",
            properties={
                "error": str(e),
                "reason": "Unexpected error",
                **metadata,
            },
        )
        raise e
    posthog.capture(username, "success", properties={**metadata})
    logger.info("create_pr success")
    result = {
        "success": True,
        "pull_request": MockPR(
            file_count=completed_count,
            title=pr_title,
            body=pr_description,
            pr_head=pull_request.branch_name,
            base=sweep_bot.repo.get_branch(
                SweepConfig.get_branch(sweep_bot.repo)
            ).commit,
            head=sweep_bot.repo.get_branch(pull_request.branch_name).commit,
        ),
    }
    yield result  # TODO: refactor this as it doesn't need to be an iterator
    return
def safe_delete_sweep_branch(
    pr,  # Github PullRequest
    repo: Repository,
) -> bool:
    """
    Safely delete Sweep branch
    1. Only edited by Sweep
    2. Prefixed by sweep/
    """
    pr_commits = pr.get_commits()
    pr_commit_authors = set([commit.author.login for commit in pr_commits])
    # Check if only Sweep has edited the PR, and sweep/ prefix
    if (
        len(pr_commit_authors) == 1
        and GITHUB_BOT_USERNAME in pr_commit_authors
        and pr.head.ref.startswith("sweep/")
    ):
        branch = repo.get_git_ref(f"heads/{pr.head.ref}")
        # pr.edit(state='closed')
        branch.delete()
        return True
    else:
        # Failed to delete branch as it was edited by someone else
        return False
def create_config_pr(
    sweep_bot: SweepBot | None, repo: Repository = None, cloned_repo: ClonedRepo = None
):
    if repo is not None:
        # Check if file exists in repo
        try:
            repo.get_contents("sweep.yaml")
            return
        except SystemExit:
            raise SystemExit
        except Exception:
            pass
    title = "Configure Sweep"
    branch_name = GITHUB_CONFIG_BRANCH
    if sweep_bot is not None:
        branch_name = sweep_bot.create_branch(branch_name, retry=False)
        try:
            # commit_history = []
            # if cloned_repo is not None:
            #     commit_history = cloned_repo.get_commit_history(
            #         limit=1000, time_limited=False
            #     )
            # commit_string = "\n".join(commit_history)
            # sweep_yaml_bot = SweepYamlBot()
            # generated_rules = sweep_yaml_bot.get_sweep_yaml_rules(
            #     commit_history=commit_string
            # )
            sweep_bot.repo.create_file(
                "sweep.yaml",
                "Create sweep.yaml",
                GITHUB_DEFAULT_CONFIG.format(
                    branch=sweep_bot.repo.default_branch,
                    additional_rules=DEFAULT_RULES_STRING,
                ),
                branch=branch_name,
            )
            sweep_bot.repo.create_file(
                ".github/ISSUE_TEMPLATE/sweep-template.yml",
                "Create sweep template",
                SWEEP_TEMPLATE,
                branch=branch_name,
            )
        except SystemExit:
            raise SystemExit
        except Exception as e:
            logger.error(e)
    else:
        # Create branch based on default branch
        repo.create_git_ref(
            ref=f"refs/heads/{branch_name}",
            sha=repo.get_branch(repo.default_branch).commit.sha,
        )
        try:
            # commit_history = []
            # if cloned_repo is not None:
            #     commit_history = cloned_repo.get_commit_history(
            #         limit=1000, time_limited=False
            #     )
            # commit_string = "\n".join(commit_history)
            # sweep_yaml_bot = SweepYamlBot()
            # generated_rules = sweep_yaml_bot.get_sweep_yaml_rules(
            #     commit_history=commit_string
            # )
            repo.create_file(
                "sweep.yaml",
                "Create sweep.yaml",
                GITHUB_DEFAULT_CONFIG.format(
                    branch=repo.default_branch, additional_rules=DEFAULT_RULES_STRING
                ),
                branch=branch_name,
            )
            repo.create_file(
                ".github/ISSUE_TEMPLATE/sweep-template.yml",
                "Create sweep template",
                SWEEP_TEMPLATE,
                branch=branch_name,
            )
        except SystemExit:
            raise SystemExit
        except Exception as e:
            logger.error(e)
    repo = sweep_bot.repo if sweep_bot is not None else repo
    # Check if the pull request from this branch to main already exists.
    # If it does, then we don't need to create a new one.
    if repo is not None:
        pull_requests = repo.get_pulls(
            state="open",
            sort="created",
            base=SweepConfig.get_branch(repo)
            if sweep_bot is not None
            else repo.default_branch,
            head=branch_name,
        )
        for pr in pull_requests:
            if pr.title == title:
                return pr
    logger.info(f"Default branch: {repo.default_branch}")
    logger.info(f"New branch: {branch_name}")
    pr = repo.create_pull(
        title=title,
        body="""🎉 Thank you for installing Sweep! We're thrilled to announce the latest update for Sweep, your AI junior developer on GitHub. This PR creates a `sweep.yaml` config file, allowing you to personalize Sweep's performance according to your project requirements.
## What's new?
- **Sweep is now configurable**.
- To configure Sweep, simply edit the `sweep.yaml` file in the root of your repository.
- If you need help, check out the [Sweep Default Config](https://github.com/sweepai/sweep/blob/main/sweep.yaml) or [Join Our Discord](https://discord.gg/sweep) for help.
If you would like me to stop creating this PR, go to issues and say "Sweep: create an empty `sweep.yaml` file".
Thank you for using Sweep! 🧹""".replace(
            "    ", ""
        ),
        head=branch_name,
        base=SweepConfig.get_branch(repo)
        if sweep_bot is not None
        else repo.default_branch,
    )
    pr.add_to_labels(GITHUB_LABEL_NAME)
    return pr
def add_config_to_top_repos(installation_id, username, repositories, max_repos=3):
    user_token, g = get_github_client(installation_id)
    repo_activity = {}
    for repo_entity in repositories:
        repo = g.get_repo(repo_entity.full_name)
        # instead of using total count, use the date of the latest commit
        commits = repo.get_commits(
            author=username,
            since=datetime.datetime.now() - datetime.timedelta(days=30),
        )
        # get latest commit date
        commit_date = datetime.datetime.now() - datetime.timedelta(days=30)
        for commit in commits:
            if commit.commit.author.date > commit_date:
                commit_date = commit.commit.author.date
        # since_date = datetime.datetime.now() - datetime.timedelta(days=30)
        # commits = repo.get_commits(since=since_date, author="lukejagg")
        repo_activity[repo] = commit_date
        # print(repo, commits.totalCount)
        logger.info(f"{repo.full_name}: {commit_date}")
    sorted_repos = sorted(repo_activity, key=repo_activity.get, reverse=True)
    sorted_repos = sorted_repos[:max_repos]
    # For each repo, create a branch based on main branch, then create PR to main branch
    for repo in sorted_repos:
        try:
            logger.info(f"Creating config for {repo.full_name}")
            create_config_pr(
                None,
                repo=repo,
                cloned_repo=ClonedRepo(
                    repo_full_name=repo.full_name,
                    installation_id=installation_id,
                    token=user_token,
                ),
            )
        except SystemExit:
            raise SystemExit
        except Exception as e:
            logger.error(e)
    logger.info("Finished creating configs for top repos")
def create_gha_pr(g, repo):
    # Create a new branch
    branch_name = "sweep/gha-enable"
    repo.create_git_ref(
        ref=f"refs/heads/{branch_name}",
        sha=repo.get_branch(repo.default_branch).commit.sha,
    )
    # Update the sweep.yaml file in this branch to add "gha_enabled: True"
    sweep_yaml_content = (
        repo.get_contents("sweep.yaml", ref=branch_name).decoded_content.decode()
        + "\ngha_enabled: True"
    )
    repo.update_file(
        "sweep.yaml",
        "Enable GitHub Actions",
        sweep_yaml_content,
        repo.get_contents("sweep.yaml", ref=branch_name).sha,
        branch=branch_name,
    )
    # Create a PR from this branch to the main branch
    pr = repo.create_pull(
        title="Enable GitHub Actions",
        body="This PR enables GitHub Actions for this repository.",
        head=branch_name,
        base=repo.default_branch,
    )
    return pr
SWEEP_TEMPLATE = """\
name: Sweep Issue
title: 'Sweep: '
description: For small bugs, features, refactors, and tests to be handled by Sweep, an AI-powered junior developer.
labels: sweep
body:
  - type: textarea
    id: description
    attributes:
      label: Details
      description: Tell Sweep where and what to edit and provide enough context for a new developer to the codebase
      placeholder: |
        Unit Tests: Write unit tests for <FILE>. Test each function in the file. Make sure to test edge cases.
        Bugs: The bug might be in <FILE>. Here are the logs: ...
        Features: the new endpoint should use the ... class from <FILE> because it contains ... logic.
        Refactors: We are migrating this function to ... version because ...
  - type: input
    id: branch
    attributes:
      label: Branch
      description: The branch to work off of (optional)
      placeholder: |
import copy
import re
import traceback
from pathlib import Path
from loguru import logger
from sweepai.agents.assistant_wrapper import (
    client,
    openai_assistant_call,
    run_until_complete,
)
from sweepai.core.entities import AssistantRaisedException, FileChangeRequest, Message
from sweepai.logn.cache import file_cache
from sweepai.utils.chat_logger import ChatLogger, discord_log_error
from sweepai.utils.progress import AssistantConversation, TicketProgress
system_message = r""" You are searching through a codebase to guide a junior developer on how to solve the user request. The junior developer will follow your instructions exactly and make the changes.
# User Request
{user_request}
# Guide
## Step 1: Unzip the file into /mnt/data/repo. Then list all root level directories. You must copy the below code verbatim into the file.
```python
import zipfile
import os
zip_path = '{file_path}'
extract_to_path = '/mnt/data/repo'
os.makedirs(extract_to_path, exist_ok=True)
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
    zip_ref.extractall(extract_to_path)
    zip_contents = zip_ref.namelist()
    root_dirs = {{name.split('/')[0] for name in zip_contents}}
    print(f'Root directories: {{root_dirs}}')
```
## Step 2: Find the relevant files.
You can search by file name or by keyword search in the contents.
## Step 3: Find relevant lines.
1. Locate the lines of code that contain the identified keywords or are at the specified line number. You can use keyword search or manually look through the file 100 lines at a time.
2. Check the surrounding lines to establish the full context of the code block.
3. Adjust the starting line to include the entire functionality that needs to be refactored or moved.
4. Finally determine the exact line spans that include a logical and complete section of code to be edited.
```python
def print_lines_with_keyword(content, keywords):
    max_matches = 5
    context = 10
    lines = content.splitlines()
    # Case-insensitive match: lower-case both the line and each keyword
    matches = [i for i, line in enumerate(lines) if any(keyword.lower() in line.lower() for keyword in keywords)]
    print(f"Found {{len(matches)}} matches, but capping at {{max_matches}}")
    matches = matches[:max_matches]
    expanded_matches = set()
    for match in matches:
        start = max(0, match - context)
        end = min(len(lines), match + context + 1)
        for i in range(start, end):
            expanded_matches.add(i)
    for i in sorted(expanded_matches):
        print(f"{{i}}: {{lines[i]}}")
```
## Step 4: Construct a plan.
Provide the final plan to solve the issue, following these rules:
* DO NOT apply any changes here, they will not be persisted. You must provide the plan and the developer will apply the changes.
* You may only create new files and modify existing files.
* File paths should be relative paths from the root of the repo.
* Use the minimum number of create and modify operations required to solve the issue.
* Start and end lines indicate the exact start and end lines to edit. Expand this to encompass more lines if you're unsure where to make the exact edit.
Respond in the following format:
```xml
<plan>
<create_file file="file_path_1">
* Natural language instructions for creating the new file needed to solve the issue.
* Reference necessary files, imports and entity names.
...
</create_file>
...
<modify_file file="file_path_2" start_line="i" end_line="j">
* Natural language instructions for the modifications needed to solve the issue.
* Be concise and reference necessary files, imports and entity names.
...
</modify_file>
...
</plan>
```"""
@file_cache(ignore_params=["zip_path", "chat_logger", "ticket_progress"])
def new_planning(
    request: str,
    zip_path: str,
    additional_messages: list[Message] = [],
    chat_logger: ChatLogger | None = None,
    assistant_id: str | None = None,
    ticket_progress: TicketProgress | None = None,
) -> list[FileChangeRequest] | None:
    planning_iterations = 3
    try:
        def save_ticket_progress(assistant_id: str, thread_id: str, run_id: str):
            # Nothing to persist when the caller did not pass a TicketProgress
            if ticket_progress is None:
                return
            assistant_conversation = AssistantConversation.from_ids(
                assistant_id=assistant_id, run_id=run_id, thread_id=thread_id
            )
            if not assistant_conversation:
                return
            ticket_progress.planning_progress.assistant_conversation = (
                assistant_conversation
            )
            ticket_progress.save()
        logger.info("Uploading file...")
        zip_file_object = client.files.create(file=Path(zip_path), purpose="assistants")
        logger.info("Done uploading file.")
        zip_file_id = zip_file_object.id
        response = openai_assistant_call(
            request=request,
            assistant_id=assistant_id,
            additional_messages=additional_messages,
            uploaded_file_ids=[zip_file_id],
            chat_logger=chat_logger,
            save_ticket_progress=save_ticket_progress
            if ticket_progress is not None
            else None,
            instructions=system_message.format(
                user_request=request, file_path=f"mnt/data/{zip_file_id}"
            ),
        )
        run_id = response.run_id
        thread_id = response.thread_id
        messages = response.messages
        fcr_matches = []
        for _ in range(planning_iterations):
            save_ticket_progress(
                assistant_id=response.assistant_id,
                thread_id=thread_id,
                run_id=run_id,
            )
            # Parse the newest message; `messages` is refreshed after every retry below
            final_message = messages.data[0].content[0].text.value
            fcr_matches = list(
                re.finditer(FileChangeRequest._regex, final_message, re.DOTALL)
            )
            if len(fcr_matches) > 0:
                break
            else:
                # No <plan> was found: ask the assistant to keep going and run it again
                client.beta.threads.messages.create(
                    thread_id=thread_id,
                    role="user",
                    content="A valid plan (within the <plan> tags) was not provided. Please continue working on the plan. If you are stuck, consider starting over.",
                )
                run = client.beta.threads.runs.create(
                    thread_id=thread_id,
                    assistant_id=response.assistant_id,
                    instructions=system_message.format(
                        user_request=request, file_path=f"mnt/data/{zip_file_id}"
                    ),
                )
                run_id = run.id
                messages = run_until_complete(
                    thread_id=thread_id,
                    run_id=run_id,
                    assistant_id=response.assistant_id,
                )
        fcrs = []
        for match_ in fcr_matches:
            group_dict = match_.groupdict()
            if group_dict["change_type"] == "create_file":
                group_dict["change_type"] = "create"
            if group_dict["change_type"] == "modify_file":
                group_dict["change_type"] = "modify"
            fcr = FileChangeRequest(**group_dict)
            fcr.filename = fcr.filename.lstrip("/")
            fcr.instructions = fcr.instructions.replace("\n*", "\n•")
            fcr.instructions = fcr.instructions.strip("\n")
            if fcr.instructions.startswith("*"):
                fcr.instructions = "•" + fcr.instructions[1:]
            fcrs.append(fcr)
            # Follow every change with a "check" request that points back at it
            new_file_change_request = copy.deepcopy(fcr)
            new_file_change_request.change_type = "check"
            new_file_change_request.parent = fcr
            fcrs.append(new_file_change_request)
        assert len(fcrs) > 0
        return fcrs
    except AssistantRaisedException:
        raise
    except Exception as e:
        logger.exception(e)
        if chat_logger is not None:
            discord_log_error(
                str(e)
                + "\n\n"
                + traceback.format_exc()
                + "\n\n"
                + str(chat_logger.data)
            )
        return None
if __name__ == "__main__":
    request = """## Title: replace the broken tutorial link in installation.md with https://docs.sweep.dev/usage/tutorial\n"""
    additional_messages = [
        Message(
            role="user",
            content='<relevant_snippets_in_repo>\n<snippet source="docs/pages/usage/tutorial.mdx:45-60">\n...\n45: Now to be a Sweep power user, check out [Advanced: becoming a Sweep power user](https://docs.sweep.dev/usage/advanced).\n</snippet>\n<snippet source="docs/pages/usage/tutorial.mdx:30-45">\n...\n30: \n31: ![PR Comment](/tutorial/comment.png)\n32: \n33: c. If you have GitHub Actions set up, it will automatically run the linters, build, and tests and will show any failed logs to Sweep to handle. This only works with GitHub Actions and not other CI providers, so unfortunately for Vercel we have to copy paste manually.\n34: \n35: ![GitHub Actions](/tutorial/github_actions.png)\n36: \n37: 6. Once you are happy with the PR, you can merge it and it will be deployed to production via Vercel.\n38: \n39: \n40: ![Final](/tutorial/final.png)\n41: \n42: \n43: You can see the final example at https://github.com/kevinlu1248/docusaurus-2/pull/4 with preview https://docusaurus-2-ql4cskc5o-sweepai.vercel.app/.\n44: \n45: Now to be a Sweep power user, check out [Advanced: becoming a Sweep power user](https://docs.sweep.dev/usage/advanced).\n...\n</snippet>\n<snippet source="docs/installation.md:45-60">\n...\n45: * Provide any additional context that might be helpful, e.g. see "src/App.test.tsx" for an example of a good unit test.\n46: * For more guidance, visit [Advanced](https://docs.sweep.dev/usage/advanced), or watch the following video.\n47: \n48: [![Video](http://img.youtube.com/vi/Qn9vB71R4UM/0.jpg)](http://www.youtube.com/watch?v=Qn9vB71R4UM "Advanced Sweep Tricks and Feedback Tips")\n49: \n50: For configuring Sweep for your repo, see [Config](https://docs.sweep.dev/usage/config), especially for setting up Sweep Rules and Sweep Sweep.\n51: \n52: ## Limitations of Sweep (for now) ⚠️\n53: \n54: * 🗃️ **Gigantic repos**: >5000 files. We have default extensions and directories to exclude but sometimes this doesn\'t catch them all. You may need to block some directories (see [`blocked_dirs`](https://docs.sweep.dev/usage/config#blocked_dirs))\n55: * If Sweep is stuck at 0% for over 30 min and your repo has a few thousand files, let us know.\n56: \n57: * 🏗️ **Large-scale refactors**: >5 files or >300 lines of code changes (we\'re working on this!)\n58: * We can\'t do this - "Refactor entire codebase from Tensorflow to PyTorch"\n59: \n60: * 🖼️ **Editing images** and other non-text assets\n...\n</snippet>\n<snippet source="docs/pages/usage/tutorial.mdx:0-15">\n0: # Tutorial for Getting Started with Sweep\n1: \n2: We recommend using an existing **real project** for Sweep, but if you must start from scratch, we recommend **using a template**. In particular, we recommend Vercel templates and Vercel auto-deploy, since Vercel\'s auto-generated previews make it **easy to review Sweep\'s PRs**\n3: \n4: We\'ll use [Docusaurus](https://vercel.com/templates/next.js/docusaurus-2) since it\'s is the easiest to set up (no backend). To see other templates see https://vercel.com/templates.\n5: \n6: 1. Go to https://vercel.com/templates/next.js/docusaurus-2 (or another template) and click "Deploy".\n7: \n8: ![Deploy](/tutorial/deployment.png)\n9: \n10: 2. Vercel will prompt you to select a GitHub account and click "Clone" after. This will trigger a build and deploy which will take a few minutes. Once the build is done, you will be greeted with a congratulations message.\n11: \n12: ![Congratulations](/tutorial/congratulations.png)\n13: \n14: 3. Go to the [Sweep Installation](https://github.com/apps/sweep-ai) page and click the grey "Configure" button or the green "Install" button. Ensure that that the Vercel template (i.e. Docusaurus) is configured to use Sweep.\n...\n</snippet>\n</relevant_snippets_in_repo>\ndocs/\n installation.md\n docs/pages/\n docs/pages/usage/\n _meta.json\n advanced.mdx\n config.mdx\n extra-self-host.mdx\n sandbox.mdx\n tutorial.mdx',
            name=None,
            function_call=None,
            key=None,
        )
    ]
    print(
        new_planning(
            request,
            "/tmp/sweep_archive.zip",
            additional_messages=additional_messages,
            chat_logger=ChatLogger(
                {"username": "kevinlu1248", "title": "Unit test for planning"}
            ),
            ticket_progress=TicketProgress(tracking_id="ed47605a38"),
        )
    )

import datetime
import difflib
import hashlib
import json
import os
import re
import shutil
import subprocess
import tempfile
import time
import traceback
from dataclasses import dataclass
from functools import cached_property
from typing import Any
import git
import requests
from github import Github, PullRequest, Repository, InputGitTreeElement
from jwt import encode
from loguru import logger
from sweepai.config.client import SweepConfig
from sweepai.config.server import GITHUB_APP_ID, GITHUB_APP_PEM, GITHUB_BOT_USERNAME
from sweepai.utils.tree_utils import DirectoryTree, remove_all_not_included
MAX_FILE_COUNT = 50
def make_valid_string(string: str):
    pattern = r"[^\w./-]+"
    return re.sub(pattern, "_", string)
def get_jwt():
    signing_key = GITHUB_APP_PEM
    app_id = GITHUB_APP_ID
    # GitHub App JWTs may be valid for at most 10 minutes
    payload = {"iat": int(time.time()), "exp": int(time.time()) + 600, "iss": app_id}
    return encode(payload, signing_key, algorithm="RS256")
def get_token(installation_id: int):
    # A negative installation id means "use the personal access token instead"
    if int(installation_id) < 0:
        return os.environ["GITHUB_PAT"]
    # Try up to three times, sleeping between failed attempts
    for timeout in [5.5, 5.5, 10.5]:
        try:
            jwt = get_jwt()
            headers = {
                "Accept": "application/vnd.github+json",
                "Authorization": "Bearer " + jwt,
                "X-GitHub-Api-Version": "2022-11-28",
            }
            response = requests.post(
                f"https://api.github.com/app/installations/{int(installation_id)}/access_tokens",
                headers=headers,
            )
            obj = response.json()
            if "token" not in obj:
                logger.error(obj)
                raise Exception("Could not get token")
            return obj["token"]
        except SystemExit:
            raise
        except Exception:
            time.sleep(timeout)
    raise Exception(
        "Could not get token, please double check your PRIVATE_KEY and GITHUB_APP_ID in the .env file. Make sure to restart uvicorn after."
    )
def get_app():
    jwt = get_jwt()
    headers = {
        "Accept": "application/vnd.github+json",
        "Authorization": "Bearer " + jwt,
        "X-GitHub-Api-Version": "2022-11-28",
    }
    response = requests.get("https://api.github.com/app", headers=headers)
    return response.json()
def get_github_client(installation_id: int) -> tuple[str, Github]:
    if not installation_id:
        return os.environ["GITHUB_PAT"], Github(os.environ["GITHUB_PAT"])
    token: str = get_token(installation_id)
    return token, Github(token)
# fetch installation object
def get_installation(username: str):
    jwt = get_jwt()
    try:
        # Try user
        response = requests.get(
            f"https://api.github.com/users/{username}/installation",
            headers={
                "Accept": "application/vnd.github+json",
                "Authorization": "Bearer " + jwt,
                "X-GitHub-Api-Version": "2022-11-28",
            },
        )
        obj = response.json()
        return obj
    except Exception:
        # Try org
        response = requests.get(
            f"https://api.github.com/orgs/{username}/installation",
            headers={
                "Accept": "application/vnd.github+json",
                "Authorization": "Bearer " + jwt,
                "X-GitHub-Api-Version": "2022-11-28",
            },
        )
        try:
            obj = response.json()
            return obj["id"]
        except Exception as e:
            logger.error(e)
            logger.error(response.text)
            raise Exception("Could not get installation, probably not installed")
def get_installation_id(username: str) -> str:
    jwt = get_jwt()
    try:
        # Try user
        response = requests.get(
            f"https://api.github.com/users/{username}/installation",
            headers={
                "Accept": "application/vnd.github+json",
                "Authorization": "Bearer " + jwt,
                "X-GitHub-Api-Version": "2022-11-28",
            },
        )
        obj = response.json()
        return obj["id"]
    except Exception:
        # Try org
        response = requests.get(
            f"https://api.github.com/orgs/{username}/installation",
            headers={
                "Accept": "application/vnd.github+json",
                "Authorization": "Bearer " + jwt,
                "X-GitHub-Api-Version": "2022-11-28",
            },
        )
        try:
            obj = response.json()
            return obj["id"]
        except Exception as e:
            logger.error(e)
            logger.error(response.text)
            raise Exception("Could not get installation id, probably not installed")
# commits multiple files in a single commit, returns the commit object
def commit_multi_file_changes(repo: Repository, file_changes: dict[str, str], commit_message: str, branch: str):
    blobs_to_commit = []
    # convert to blob
    for path, content in file_changes.items():
        blob = repo.create_git_blob(content, "utf-8")
        blobs_to_commit.append(InputGitTreeElement(path=path, mode="100644", type="blob", sha=blob.sha))
    latest_commit = repo.get_branch(branch).commit
    base_tree = latest_commit.commit.tree
    # create new git tree
    new_tree = repo.create_git_tree(blobs_to_commit, base_tree=base_tree)
    # commit the changes
    parent = repo.get_git_commit(latest_commit.sha)
    commit = repo.create_git_commit(
        commit_message,
        new_tree,
        [parent],
    )
    # update ref of branch
    ref = f"heads/{branch}"
    repo.get_git_ref(ref).edit(sha=commit.sha)
    return commit
REPO_CACHE_BASE_DIR = "/tmp/cache/repos"
@dataclass
class ClonedRepo:
    repo_full_name: str
    installation_id: str
    branch: str | None = None
    token: str | None = None
    repo: Any | None = None
    git_repo: git.Repo | None = None
    class Config:
        arbitrary_types_allowed = True
    @cached_property
    def cached_dir(self):
        self.repo = (
            Github(self.token).get_repo(self.repo_full_name)
            if not self.repo
            else self.repo
        )
        self.branch = self.branch or SweepConfig.get_branch(self.repo)
        return os.path.join(
            REPO_CACHE_BASE_DIR,
            self.repo_full_name,
            "base",
            parse_collection_name(self.branch),
        )
    @cached_property
    def zip_path(self):
        logger.info("Zipping repository...")
        shutil.make_archive(self.repo_dir, "zip", self.repo_dir)
        logger.info("Done zipping")
        return f"{self.repo_dir}.zip"
    @cached_property
    def repo_dir(self):
        self.repo = (
            Github(self.token).get_repo(self.repo_full_name)
            if not self.repo
            else self.repo
        )
        self.branch = self.branch or SweepConfig.get_branch(self.repo)
        curr_time_str = str(time.time()).encode("utf-8")
        hash_obj = hashlib.sha256(curr_time_str)
        hash_hex = hash_obj.hexdigest()
        if self.branch:
            return os.path.join(
                REPO_CACHE_BASE_DIR,
                self.repo_full_name,
                hash_hex,
                parse_collection_name(self.branch),
            )
        else:
            return os.path.join(REPO_CACHE_BASE_DIR, self.repo_full_name, hash_hex)
    @property
    def clone_url(self):
        return (
            f"https://x-access-token:{self.token}@github.com/{self.repo_full_name}.git"
        )
    def clone(self):
        if not os.path.exists(self.cached_dir):
            logger.info("Cloning repo...")
            if self.branch:
                repo = git.Repo.clone_from(
                    self.clone_url, self.cached_dir, branch=self.branch
                )
            else:
                repo = git.Repo.clone_from(self.clone_url, self.cached_dir)
            logger.info("Done cloning")
        else:
            try:
                repo = git.Repo(self.cached_dir)
                repo.remotes.origin.pull(
                    kill_after_timeout=60, progress=git.RemoteProgress()
                )
            except Exception:
                logger.error("Could not pull repo")
                shutil.rmtree(self.cached_dir, ignore_errors=True)
                repo = git.Repo.clone_from(self.clone_url, self.cached_dir)
            logger.info("Repo already cached, copying")
        logger.info("Copying repo...")
        shutil.copytree(
            self.cached_dir, self.repo_dir, symlinks=True, copy_function=shutil.copy
        )
        logger.info("Done copying")
        repo = git.Repo(self.repo_dir)
        return repo
    def __post_init__(self):
        subprocess.run(["git", "config", "--global", "http.postBuffer", "524288000"])
        self.token = self.token or get_token(self.installation_id)
        self.repo = (
            Github(self.token).get_repo(self.repo_full_name)
            if not self.repo
            else self.repo
        )
        self.commit_hash = self.repo.get_commits()[0].sha
        self.git_repo = self.clone()
        self.branch = self.branch or SweepConfig.get_branch(self.repo)
    def __del__(self):
        try:
            shutil.rmtree(self.repo_dir)
            # Only delete the archive if zip_path was computed; reading the
            # cached_property here for the first time would build a new zip
            if "zip_path" in self.__dict__:
                os.remove(self.zip_path)
            return True
        except Exception:
            return False
    def list_directory_tree(
        self,
        included_directories=None,
        excluded_directories: list[str] = None,
        included_files=None,
    ):
        """Display the directory tree rooted at self.repo_dir.
        Arguments:
        included_directories -- List of directory paths (relative to the root) to include in the tree. Defaults to None (all directories).
        excluded_directories -- List of directory names to exclude from the tree. Defaults to None (SweepConfig.exclude_dirs).
        included_files -- Currently unused.
        Returns:
        A (directory_tree_string, DirectoryTree) tuple.
        """
        root_directory = self.repo_dir
        sweep_config: SweepConfig = SweepConfig()
        # Default values if parameters are not provided
        if included_directories is None:
            included_directories = []  # gets all directories
        if excluded_directories is None:
            excluded_directories = sweep_config.exclude_dirs
        def list_directory_contents(
            current_directory: str,
            excluded_directories: list[str],
            indentation="",
        ):
            """Recursively list the contents of directories."""
            file_and_folder_names = os.listdir(current_directory)
            file_and_folder_names.sort()
            directory_tree_string = ""
            for name in file_and_folder_names[:MAX_FILE_COUNT]:
                relative_path = os.path.join(current_directory, name)[
                    len(root_directory) + 1 :
                ]
                if name in excluded_directories:
                    continue
                complete_path = os.path.join(current_directory, name)
                if os.path.isdir(complete_path):
                    directory_tree_string += f"{indentation}{relative_path}/\n"
                    directory_tree_string += list_directory_contents(
                        complete_path,
                        excluded_directories,
                        indentation + " ",
                    )
                else:
                    directory_tree_string += f"{indentation}{name}\n"
                    # if os.path.isfile(complete_path) and relative_path in included_files:
                    #     # Todo, use these to fetch neighbors
                    #     ctags_str, names = get_ctags_for_file(ctags, complete_path)
                    #     ctags_str = "\n".join([indentation + line for line in ctags_str.splitlines()])
                    #     if ctags_str.strip():
                    #         directory_tree_string += f"{ctags_str}\n"
            return directory_tree_string
        dir_obj = DirectoryTree()
        directory_tree = list_directory_contents(root_directory, excluded_directories)
        dir_obj.parse(directory_tree)
        if included_directories:
            dir_obj = remove_all_not_included(dir_obj, included_directories)
        return directory_tree, dir_obj
    def get_file_list(self) -> list[str]:
        root_directory = self.repo_dir
        files = []
        sweep_config: SweepConfig = SweepConfig()
        def dfs_helper(directory):
            nonlocal files
            for item in os.listdir(directory):
                if item == ".git":
                    continue
                if item in sweep_config.exclude_dirs:  # this saves a lot of time
                    continue
                item_path = os.path.join(directory, item)
                if os.path.isfile(item_path):
                    # make sure the item_path is not in one of the banned directories
                    if not sweep_config.is_file_excluded(item_path):
                        files.append(item_path)  # Add the file to the list
                elif os.path.isdir(item_path):
                    dfs_helper(item_path)  # Recursive call to explore subdirectory
        dfs_helper(root_directory)
        files = [file[len(root_directory) + 1 :] for file in files]
        return files
    def get_file_contents(self, file_path, ref=None):
        local_path = (
            f"{self.repo_dir}{file_path}"
            if file_path.startswith("/")
            else f"{self.repo_dir}/{file_path}"
        )
        if os.path.exists(local_path):
            with open(local_path, "r", encoding="utf-8", errors="replace") as f:
                contents = f.read()
            return contents
        else:
            raise FileNotFoundError(f"{local_path} does not exist.")
    def get_num_files_from_repo(self):
        # subprocess.run(["git", "config", "--global", "http.postBuffer", "524288000"])
        self.git_repo.git.checkout(self.branch)
        file_list = self.get_file_list()
        return len(file_list)
    def get_commit_history(
        self, username: str = "", limit: int = 200, time_limited: bool = True
    ):
        commit_history = []
        try:
            if username != "":
                commit_list = list(self.git_repo.iter_commits(author=username))
            else:
                commit_list = list(self.git_repo.iter_commits())
            line_count = 0
            cut_off_date = datetime.datetime.now() - datetime.timedelta(days=7)
            # Resolve the default branch once instead of once per commit
            repo = get_github_client(self.installation_id)[1].get_repo(
                self.repo_full_name
            )
            branch = SweepConfig.get_branch(repo)
            if branch not in self.git_repo.git.branch():
                branch = f"origin/{branch}"
            for commit in commit_list:
                # must be within a week
                if time_limited and commit.authored_datetime.replace(
                    tzinfo=None
                ) <= cut_off_date.replace(tzinfo=None):
                    logger.info("Exceeded cut off date, stopping...")
                    break
                diff = self.git_repo.git.diff(commit, branch, unified=1)
                lines = diff.count("\n")
                # total diff lines must not exceed `limit`
                if lines + line_count > limit:
                    logger.info(f"Exceeded {limit} lines of diff, stopping...")
                    break
                commit_history.append(
                    f"<commit>\nAuthor: {commit.author.name}\nMessage: {commit.message}\n{diff}\n</commit>"
                )
                line_count += lines
        except Exception:
            logger.error(f"An error occurred: {traceback.format_exc()}")
        return commit_history
    def get_similar_file_paths(self, file_path: str, limit: int = 10):
        from rapidfuzz.fuzz import ratio
        # Fuzzy search over file names
        file_name = os.path.basename(file_path)
        all_file_paths = self.get_file_list()
        # filter for matching extensions if both have extensions
        if "." in file_name:
            all_file_paths = [
                file
                for file in all_file_paths
                if "." in file and file.split(".")[-1] == file_name.split(".")[-1]
            ]
        files_with_matching_name = []
        files_without_matching_name = []
        for candidate_path in all_file_paths:
            if file_name in candidate_path:
                files_with_matching_name.append(candidate_path)
            else:
                files_without_matching_name.append(candidate_path)
        file_path_to_ratio = {file: ratio(file_name, file) for file in all_file_paths}
        files_with_matching_name = sorted(
            files_with_matching_name,
            key=lambda path: file_path_to_ratio[path],
            reverse=True,
        )
        files_without_matching_name = sorted(
            files_without_matching_name,
            key=lambda path: file_path_to_ratio[path],
            reverse=True,
        )
        # this allows 'config.py' to return 'sweepai/config/server.py', 'sweepai/config/client.py', 'sweepai/config/__init__.py' and no more
        filtered_files_without_matching_name = list(filter(lambda path: file_path_to_ratio[path] > 50, files_without_matching_name))
        all_files = files_with_matching_name + filtered_files_without_matching_name
        return all_files[:limit]
# updates a file with new_contents, returns True if successful
def update_file(root_dir: str, file_path: str, new_contents: str):
    local_path = os.path.join(root_dir, file_path)
    try:
        with open(local_path, "w") as f:
            f.write(new_contents)
        return True
    except Exception as e:
        logger.error(f"Failed to update file: {e}")
        return False
@dataclass
class MockClonedRepo(ClonedRepo):
    _repo_dir: str = ""
    git_repo: git.Repo | None = None
    def __init__(
        self,
        _repo_dir: str,
        repo_full_name: str,
        installation_id: str = "",
        branch: str | None = None,
        token: str | None = None,
        repo: Any | None = None,
        git_repo: git.Repo | None = None,
    ):
        self._repo_dir = _repo_dir
        self.repo_full_name = repo_full_name
        self.installation_id = installation_id
        self.branch = branch
        self.token = token
        self.repo = repo
    @classmethod
    def from_dir(cls, repo_dir: str, **kwargs):
        return cls(_repo_dir=repo_dir, **kwargs)
    @property
    def cached_dir(self):
        return self._repo_dir
    @property
    def repo_dir(self):
        return self._repo_dir
    @property
    def git_repo(self):
        return git.Repo(self.repo_dir)
    def clone(self):
        return git.Repo(self.repo_dir)
    def __post_init__(self):
        return self
    def __del__(self):
        return True
@dataclass
class TemporarilyCopiedClonedRepo(MockClonedRepo):
    tmp_dir: tempfile.TemporaryDirectory | None = None
    def __init__(
        self,
        _repo_dir: str,
        tmp_dir: tempfile.TemporaryDirectory,
        repo_full_name: str,
        installation_id: str = "",
        branch: str | None = None,
        token: str | None = None,
        repo: Any | None = None,
        git_repo: git.Repo | None = None,
    ):
        self._repo_dir = _repo_dir
        self.tmp_dir = tmp_dir
        self.repo_full_name = repo_full_name
        self.installation_id = installation_id
        self.branch = branch
        self.token = token
        self.repo = repo
    @classmethod
    def copy_from_cloned_repo(cls, cloned_repo: ClonedRepo, **kwargs):
        temp_dir = tempfile.TemporaryDirectory()
        new_dir = temp_dir.name + "/" + cloned_repo.repo_full_name.split("/")[1]
        print("Copying...")
        shutil.copytree(cloned_repo.repo_dir, new_dir)
        print("Done copying.")
        return cls(
            _repo_dir=new_dir,
            tmp_dir=temp_dir,
            repo_full_name=cloned_repo.repo_full_name,
            installation_id=cloned_repo.installation_id,
            branch=cloned_repo.branch,
            token=cloned_repo.token,
            repo=cloned_repo.repo,
            **kwargs,
        )
    def __del__(self):
        print(f"Dropping {self.tmp_dir.name}...")
        shutil.rmtree(self._repo_dir, ignore_errors=True)
        self.tmp_dir.cleanup()
        print("Done.")
        return True
def get_file_names_from_query(query: str) -> list[str]:
    query_file_names = re.findall(r"\b[\w\-\.\/]*\w+\.\w{1,6}\b", query)
    return [
        query_file_name
        for query_file_name in query_file_names
        if len(query_file_name) > 3
    ]
def get_hunks(a: str, b: str, context=10):
    differ = difflib.Differ()
    # Keep added, removed and unchanged lines; drop Differ's "?" hint lines
    diff = [
        line
        for line in differ.compare(a.splitlines(), b.splitlines())
        if line[0] in ("+", "-", " ")
    ]
    # Mark every line within `context` lines of a change as visible
    show = set()
    hunks = []
    for i, line in enumerate(diff):
        if line.startswith(("+", "-")):
            show.update(range(max(0, i - context), min(len(diff), i + context + 1)))
    # Emit visible lines; each run of hidden lines collapses to a single "..."
    for i in range(len(diff)):
        if i in show:
            hunks.append(diff[i])
        elif i - 1 in show:
            hunks.append("...")
    if len(hunks) > 0 and hunks[0] == "...":
        hunks = hunks[1:]
    if len(hunks) > 0 and hunks[-1] == "...":
        hunks = hunks[:-1]
    return "\n".join(hunks)
def parse_collection_name(name: str) -> str:
    # Replace every character that is not a word character or hyphen with "--"
    name = re.sub(r"[^\w-]", "--", name)
    # Truncate to 63 characters and pad to at least 3 with "x"; trailing hyphens
    # are stripped only when the name contains no internal hyphens
    name = re.sub(r"^(-*\w{0,61}\w)-*$", r"\1", name[:63].ljust(3, "x"))
    return name
# set whether or not a pr is a draft, there is no way to do this using pygithub
def convert_pr_draft_field(pr: PullRequest, is_draft: bool = False):
    pr_id = pr.raw_data['node_id']
    # GraphQL mutation: convert the PR to a draft, or mark it as ready for review
    mutation_name = (
        "convertPullRequestToDraft" if is_draft else "markPullRequestReadyForReview"
    )
    mutation = f"""
    mutation ConvertPRDraftField($pullRequestId: ID!) {{
        {mutation_name}(input: {{pullRequestId: $pullRequestId}}) {{
            pullRequest {{
                id
            }}
        }}
    }}
    """
    # GraphQL API URL
    url = 'https://api.github.com/graphql'
    # Headers
    headers = {
        "Accept": "application/vnd.github+json",
        "X-Github-Api-Version": "2022-11-28",
        "Authorization": "Bearer " + os.environ["GITHUB_PAT"],
    }
    # Prepare the JSON payload, passing the PR node id as a GraphQL variable
    json_data = {
        'query': mutation,
        'variables': {'pullRequestId': pr_id},
    }
    # Make the POST request
    response = requests.post(url, headers=headers, data=json.dumps(json_data))
    # GraphQL usually reports failures in an "errors" field with HTTP 200
    if response.status_code != 200 or "errors" in response.json():
        logger.error(f"Failed to convert PR to {'draft' if is_draft else 'open'}")
        return False
    return True
try:
    g = Github(os.environ.get("GITHUB_PAT"))
    CURRENT_USERNAME = g.get_user().login
except Exception:
    try:
        slug = get_app()["slug"]
        CURRENT_USERNAME = f"{slug}[bot]"
    except Exception:
        CURRENT_USERNAME = GITHUB_BOT_USERNAME
if __name__ == "__main__":
    try:
        organization_name = "sweepai"
        sweep_config = SweepConfig()
        installation_id = get_installation_id(organization_name)
        user_token, g = get_github_client(installation_id)
        cloned_repo = ClonedRepo("sweepai/sweep", installation_id, "main")
        directory_tree, dir_obj = cloned_repo.list_directory_tree()
        commit_history = cloned_repo.get_commit_history()
        similar_file_paths = cloned_repo.get_similar_file_paths("config.py")
        # ensure no similar file_paths are sweep excluded
        assert not any(file for file in similar_file_paths if sweep_config.is_file_excluded(file))
        print(f"similar_file_paths: {similar_file_paths}")
        str1 = "a\nline1\nline2\nline3\nline4\nline5\nline6\ntest\n"
        str2 = "a\nline1\nlineTwo\nline3\nline4\nline5\nlineSix\ntset\n"
        print(get_hunks(str1, str2, 1))
        mocked_repo = MockClonedRepo.from_dir(
            cloned_repo.repo_dir,
            repo_full_name="sweepai/sweep",
        )
        temp_repo = TemporarilyCopiedClonedRepo.copy_from_cloned_repo(mocked_repo)
        print(f"mocked repo: {mocked_repo}")
    except Exception as e:

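The issue asks for rebasing Sweep's branch onto the updated target branch instead of merging the target branch into it. A minimal sketch of that choice follows, using the plain `git` CLI through `subprocess` on a local checkout. `run_git` and `update_branch` are hypothetical helpers, and this is not how Sweep's `update_sweep_prs_v2` is implemented. On a conflicting rebase, the sketch aborts and re-raises so the branch is left exactly as it was.

```python
import subprocess

def run_git(repo_dir: str, *args: str) -> str:
    """Run `git <args>` inside repo_dir and return stdout; raises CalledProcessError on failure."""
    result = subprocess.run(
        ["git", *args], cwd=repo_dir, check=True, capture_output=True, text=True
    )
    return result.stdout.strip()

def update_branch(repo_dir: str, base_ref: str, strategy: str = "rebase") -> None:
    """Bring the checked-out branch up to date with base_ref (hypothetical helper).

    strategy="rebase" replays the branch's own commits on top of base_ref, which
    gives a linear history with no merge commit; strategy="merge" keeps the
    current behaviour of adding a merge commit.
    """
    if strategy == "rebase":
        try:
            run_git(repo_dir, "rebase", base_ref)
        except subprocess.CalledProcessError:
            # Restore the branch to its pre-rebase state, then re-raise.
            # No check=True here: there may be no rebase in progress if git failed early.
            subprocess.run(
                ["git", "rebase", "--abort"], cwd=repo_dir, capture_output=True
            )
            raise
    elif strategy == "merge":
        run_git(repo_dir, "merge", "--no-edit", base_ref)
    else:
        raise ValueError(f"unknown strategy: {strategy!r}")
```

A rebase rewrites the branch's commits, so publishing the result requires `git push --force-with-lease`. That option refuses the push if someone else has updated the remote branch in the meantime.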
sweep/sweepai/api.py

Lines 1 to 1178 in 76aecb2

from __future__ import annotations
import ctypes
import json
import threading
import time
from typing import Any, Optional
import requests
from fastapi import (
    Body,
    FastAPI,
    Header,
    HTTPException,
    Path,
    Request,
    Security,
    status,
)
from fastapi.responses import HTMLResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from fastapi.templating import Jinja2Templates
from github.Commit import Commit
from sweepai.config.client import (
    DEFAULT_RULES,
    RESTART_SWEEP_BUTTON,
    REVERT_CHANGED_FILES_TITLE,
    RULES_LABEL,
    RULES_TITLE,
    SWEEP_BAD_FEEDBACK,
    SWEEP_GOOD_FEEDBACK,
    SweepConfig,
    get_gha_enabled,
    get_rules,
)
from sweepai.config.server import (
    BLACKLISTED_USERS,
    DISABLED_REPOS,
    DISCORD_FEEDBACK_WEBHOOK_URL,
    ENV,
    GHA_AUTOFIX_ENABLED,
    GITHUB_BOT_USERNAME,
    GITHUB_LABEL_COLOR,
    GITHUB_LABEL_DESCRIPTION,
    GITHUB_LABEL_NAME,
    IS_SELF_HOSTED,
    MERGE_CONFLICT_ENABLED,
)
from sweepai.core.entities import PRChangeRequest
from sweepai.global_threads import global_threads
from sweepai.handlers.create_pr import (  # type: ignore
    add_config_to_top_repos,
    create_gha_pr,
)
from sweepai.handlers.on_button_click import handle_button_click
from sweepai.handlers.on_check_suite import (  # type: ignore
    clean_gh_logs,
    download_logs,
    on_check_suite,
)
from sweepai.handlers.on_comment import on_comment
from sweepai.handlers.on_jira_ticket import handle_jira_ticket
from sweepai.handlers.on_merge import on_merge
from sweepai.handlers.on_merge_conflict import on_merge_conflict
from sweepai.handlers.on_ticket import on_ticket
from sweepai.handlers.pr_utils import make_pr
from sweepai.handlers.stack_pr import stack_pr
from sweepai.utils.buttons import (
    Button,
    ButtonList,
    check_button_activated,
    check_button_title_match,
)
from sweepai.utils.chat_logger import ChatLogger
from sweepai.utils.event_logger import logger, posthog
from sweepai.utils.github_utils import CURRENT_USERNAME, get_github_client
from sweepai.utils.progress import TicketProgress
from sweepai.utils.safe_pqueue import SafePriorityQueue
from sweepai.utils.str_utils import BOT_SUFFIX, get_hash
from sweepai.web.events import (
    CheckRunCompleted,
    CommentCreatedRequest,
    InstallationCreatedRequest,
    IssueCommentRequest,
    IssueRequest,
    PREdited,
    PRRequest,
    ReposAddedRequest,
)
from sweepai.web.health import health_check
app = FastAPI()
events = {}
on_ticket_events = {}
security = HTTPBearer()
templates = Jinja2Templates(directory="sweepai/web")
# version_command = r"""git config --global --add safe.directory /app
# timestamp=$(git log -1 --format="%at")
# date -d "@$timestamp" +%y.%m.%d.%H 2>/dev/null || date -r "$timestamp" +%y.%m.%d.%H"""
# try:
# version = subprocess.check_output(version_command, shell=True, text=True).strip()
# except Exception:
version = time.strftime("%y.%m.%d.%H")
logger.bind(application="webhook")
def auth_metrics(credentials: HTTPAuthorizationCredentials = Security(security)):
    if credentials.scheme != "Bearer":
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Invalid authentication scheme.",
        )
    if credentials.credentials != "example_token":  # grafana requires authentication
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token."
        )
    return True
def run_on_ticket(*args, **kwargs):
    tracking_id = get_hash()
    with logger.contextualize(
        **kwargs,
        name="ticket_" + kwargs["username"],
        tracking_id=tracking_id,
    ):
        return on_ticket(*args, **kwargs, tracking_id=tracking_id)
def run_on_comment(*args, **kwargs):
    tracking_id = get_hash()
    with logger.contextualize(
        **kwargs,
        name="comment_" + kwargs["username"],
        tracking_id=tracking_id,
    ):
        on_comment(*args, **kwargs, tracking_id=tracking_id)
def run_on_button_click(*args, **kwargs):
    thread = threading.Thread(target=handle_button_click, args=args, kwargs=kwargs)
    thread.start()
    global_threads.append(thread)
def run_on_check_suite(*args, **kwargs):
    request = kwargs["request"]
    pr_change_request = on_check_suite(request)
    if pr_change_request:
        call_on_comment(**pr_change_request.params, comment_type="github_action")
        logger.info("Done with on_check_suite")
    else:
        logger.info("Skipping on_check_suite as no pr_change_request was returned")
def terminate_thread(thread):
"""Terminate a python threading.Thread."""
try:
if not thread.is_alive():
return
exc = ctypes.py_object(SystemExit)
res = ctypes.pythonapi.PyThreadState_SetAsyncExc(
ctypes.c_long(thread.ident), exc
)
if res == 0:
raise ValueError("Invalid thread ID")
elif res != 1:
# More than one thread state was modified: call again with exc=NULL to clear the pending exception.
ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(thread.ident), None)
raise SystemError("PyThreadState_SetAsyncExc failed")
except SystemExit:
raise
except Exception as e:
logger.exception(f"Failed to terminate thread: {e}")
# def delayed_kill(thread: threading.Thread, delay: int = 60 * 60):
# time.sleep(delay)
# terminate_thread(thread)
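# terminate_thread relies on PyThreadState_SetAsyncExc, which only raises
# SystemExit the next time the target thread executes Python bytecode; a thread
# blocked inside I/O or a C call is not interrupted until that call returns.
# A hypothetical cooperative alternative (not wired into call_on_ticket) is to
# hand each worker a threading.Event and have it check the flag between units
# of work.
import threading


def _example_cooperative_worker(stop_event: threading.Event, steps: int) -> int:
    """Run up to `steps` units of work, stopping early once stop_event is set.

    Returns the number of units completed.
    """
    completed = 0
    for _ in range(steps):
        if stop_event.is_set():
            break
        completed += 1  # placeholder for one unit of real work
    return completed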
def call_on_ticket(*args, **kwargs):
global on_ticket_events
key = f"{kwargs['repo_full_name']}-{kwargs['issue_number']}" # Full name, issue number as key
# Use multithreading
# Check if a previous process exists for the same key, cancel it
e = on_ticket_events.get(key, None)
if e:
logger.info(f"Found previous thread for key {key} and cancelling it")
terminate_thread(e)
thread = threading.Thread(target=run_on_ticket, args=args, kwargs=kwargs)
on_ticket_events[key] = thread
thread.start()
global_threads.append(thread)
def call_on_check_suite(*args, **kwargs):
kwargs["request"].repository.full_name
kwargs["request"].check_run.pull_requests[0].number
thread = threading.Thread(target=run_on_check_suite, args=args, kwargs=kwargs)
thread.start()
global_threads.append(thread)
def call_on_comment(
*args, **kwargs
): # TODO: if its a GHA delete all previous GHA and append to the end
def worker():
while not events[key].empty():
task_args, task_kwargs = events[key].get()
run_on_comment(*task_args, **task_kwargs)
global events
repo_full_name = kwargs["repo_full_name"]
pr_id = kwargs["pr_number"]
key = f"{repo_full_name}-{pr_id}" # Full name, PR number as key
comment_type = kwargs["comment_type"]
logger.info(f"Received comment type: {comment_type}")
if key not in events:
events[key] = SafePriorityQueue()
events[key].put(0, (args, kwargs))
# If a thread isn't running, start one
if not any(
thread.name == key and thread.is_alive() for thread in threading.enumerate()
):
thread = threading.Thread(target=worker, name=key)
thread.start()
global_threads.append(thread)
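# call_on_comment keeps one queue per "<repo>-<pr number>" key and starts at most
# one worker thread per key, so comments on the same PR are handled one at a
# time. SafePriorityQueue's internals are not shown here; this hypothetical
# sketch uses the standard library's queue.PriorityQueue to illustrate the drain
# order. An insertion counter acts as a tiebreaker so equal priorities stay FIFO
# and payloads (e.g. dicts of kwargs) never have to be compared with each other.
import itertools
import queue


def _example_drain_in_priority_order(items):
    """items: iterable of (priority, payload); lower priority values drain first."""
    pending = queue.PriorityQueue()
    counter = itertools.count()
    for priority, payload in items:
        pending.put((priority, next(counter), payload))
    processed = []
    while not pending.empty():
        _, _, payload = pending.get()
        processed.append(payload)
    return processed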
def call_on_merge(*args, **kwargs):
thread = threading.Thread(target=on_merge, args=args, kwargs=kwargs)
thread.start()
global_threads.append(thread)
@app.get("/health")
def redirect_to_health():
return health_check()
@app.get("/", response_class=HTMLResponse)
def home(request: Request):
return templates.TemplateResponse(
name="index.html", context={"version": version, "request": request}
)
@app.get("/ticket_progress/{tracking_id}")
def progress(tracking_id: str = Path(...)):
ticket_progress = TicketProgress.load(tracking_id)
return ticket_progress.dict()
def init_hatchet() -> Any | None:
try:
from hatchet_sdk import Context, Hatchet
hatchet = Hatchet(debug=True)
worker = hatchet.worker("github-worker")
@hatchet.workflow(on_events=["github:webhook"])
class OnGithubEvent:
"""Workflow for handling GitHub events."""
@hatchet.step()
def run(self, context: Context):
event_payload = context.workflow_input()
request_dict = event_payload.get("request")
event = event_payload.get("event")
handle_event(request_dict, event)
workflow = OnGithubEvent()
worker.register_workflow(workflow)
# start worker in the background
thread = threading.Thread(target=worker.start)
thread.start()
global_threads.append(thread)
return hatchet
except Exception as e:
print(f"Failed to initialize Hatchet: {e}, continuing with local mode")
return None
# hatchet = init_hatchet()
def handle_github_webhook(event_payload):
# if hatchet:
# hatchet.client.event.push("github:webhook", event_payload)
# else:
handle_event(event_payload.get("request"), event_payload.get("event"))
def handle_request(request_dict, event=None):
"""So it can be exported to the listen endpoint."""
with logger.contextualize(tracking_id="main", env=ENV):
action = request_dict.get("action")
try:
# Send the event to Hatchet
handle_github_webhook(
{
"request": request_dict,
"event": event,
}
)
except Exception as e:
logger.exception(f"Failed to send event to Hatchet: {e}")
# try:
# worker()
# except Exception as e:
# discord_log_error(str(e), priority=1)
logger.info(f"Done handling {event}, {action}")
return {"success": True}
@app.post("/")
def webhook(
request_dict: dict = Body(...),
x_github_event: Optional[str] = Header(None, alias="X-GitHub-Event"),
):
"""Handle a webhook request from GitHub."""
with logger.contextualize(tracking_id="main", env=ENV):
action = request_dict.get("action", None)
logger.info(f"Received event: {x_github_event}, {action}")
return handle_request(request_dict, event=x_github_event)
@app.post("/jira")
def jira_webhook(
request_dict: dict = Body(...),
) -> None:
def call_jira_ticket(*args, **kwargs):
thread = threading.Thread(target=handle_jira_ticket, args=args, kwargs=kwargs)
thread.start()
call_jira_ticket(event=request_dict)
# Set up cronjob for this
@app.get("/update_sweep_prs_v2")
def update_sweep_prs_v2(repo_full_name: str, installation_id: int):
# Get a Github client
_, g = get_github_client(installation_id)
# Get the repository
repo = g.get_repo(repo_full_name)
config = SweepConfig.get_config(repo)
try:
branch_ttl = int(config.get("branch_ttl", 7))
except Exception:
branch_ttl = 7
branch_ttl = max(branch_ttl, 1)
# Get all open pull requests created by Sweep
pulls = repo.get_pulls(
state="open", head="sweep", sort="updated", direction="desc"
)[:5]
# For each pull request, attempt to merge the latest changes from its base branch into the pull request branch
try:
for pr in pulls:
try:
# make sure it's a sweep ticket
feature_branch = pr.head.ref
if not feature_branch.startswith(
"sweep/"
) and not feature_branch.startswith("sweep_"):
continue
if "Resolve merge conflicts" in pr.title:
continue
if (
pr.mergeable_state != "clean"
and (time.time() - pr.created_at.timestamp()) > 60 * 60 * 24
and pr.title.startswith("[Sweep Rules]")
):
pr.edit(state="closed")
continue
repo.merge(
feature_branch,
pr.base.ref,
f"Merge {pr.base.ref} into {feature_branch}",
)
# Check if the merged PR is the config PR
if pr.title == "Configure Sweep" and pr.merged:
# Create a new PR to add "gha_enabled: True" to sweep.yaml
create_gha_pr(g, repo)
except Exception as e:
logger.warning(
f"Failed to merge changes from default branch into PR #{pr.number}: {e}"
)
except Exception:
logger.warning("Failed to update sweep PRs")
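# update_sweep_prs_v2 brings each Sweep branch up to date with repo.merge, which
# always records a merge commit on the feature branch. Users have asked for a
# rebase instead. The following is a hypothetical sketch of that option; nothing
# here is wired into Sweep, and the function names are illustrative. Rebasing
# needs a local clone and rewrites the branch's history, so it needs a force
# push. --force-with-lease refuses to overwrite the remote branch if it moved
# since our fetch. If any step fails (for example a rebase conflict), the rebase
# is aborted so the clone is left clean and the caller can fall back to merging.
import subprocess


def _example_rebase_commands(
    feature_branch: str, base_branch: str, remote: str = "origin"
) -> list[list[str]]:
    return [
        # Update both remote-tracking refs in one fetch.
        ["git", "fetch", remote, base_branch, feature_branch],
        # -B creates or resets the local branch to the freshly fetched remote state.
        ["git", "checkout", "-B", feature_branch, f"{remote}/{feature_branch}"],
        ["git", "rebase", f"{remote}/{base_branch}"],
        ["git", "push", "--force-with-lease", remote, f"HEAD:{feature_branch}"],
    ]


def _example_rebase_branch(repo_dir: str, feature_branch: str, base_branch: str) -> bool:
    """Return True if the branch was rebased and pushed, False otherwise."""
    for command in _example_rebase_commands(feature_branch, base_branch):
        result = subprocess.run(command, cwd=repo_dir, capture_output=True, text=True)
        if result.returncode != 0:
            if command[1] == "rebase":
                # Leave the working tree in its pre-rebase state.
                subprocess.run(["git", "rebase", "--abort"], cwd=repo_dir)
            return False
    return True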
def handle_event(request_dict, event):
action = request_dict.get("action")
if repo_full_name := request_dict.get("repository", {}).get("full_name"):
if repo_full_name in DISABLED_REPOS:
logger.warning(f"Repo {repo_full_name} is disabled")
return {"success": False, "error_message": "Repo is disabled"}
with logger.contextualize(tracking_id="main", env=ENV):
match event, action:
case "check_run", "completed":
request = CheckRunCompleted(**request_dict)
_, g = get_github_client(request.installation.id)
repo = g.get_repo(request.repository.full_name)
pull_requests = request.check_run.pull_requests
if pull_requests:
logger.info(pull_requests[0].number)
pr = repo.get_pull(pull_requests[0].number)
if (time.time() - pr.created_at.timestamp()) > 60 * 60 and (
pr.title.startswith("[Sweep Rules]")
or pr.title.startswith("[Sweep GHA Fix]")
):
after_sha = pr.head.sha
commit = repo.get_commit(after_sha)
check_suites = commit.get_check_suites()
for check_suite in check_suites:
if check_suite.conclusion == "failure":
pr.edit(state="closed")
break
if (
not (time.time() - pr.created_at.timestamp()) > 60 * 15
and request.check_run.conclusion == "failure"
and pr.state == "open"
and get_gha_enabled(repo)
and len(
[
comment
for comment in pr.get_issue_comments()
if "Fixing PR" in comment.body
]
)
< 2
and GHA_AUTOFIX_ENABLED
):
# check if the base branch is passing
commits = repo.get_commits(sha=pr.base.ref)
latest_commit: Commit = commits[0]
if all(
status != "failure"
for status in [
status.state for status in latest_commit.get_statuses()
]
): # base branch is passing
logs = download_logs(
request.repository.full_name,
request.check_run.run_id,
request.installation.id,
)
logs, user_message = clean_gh_logs(logs)
attributor = request.sender.login
if attributor.endswith("[bot]"):
attributor = commit.author.login
if attributor.endswith("[bot]"):
attributor = pr.assignee.login
if attributor.endswith("[bot]"):
return {
"success": False,
"error_message": "The PR was created by a bot, so I won't attempt to fix it.",
}
tracking_id = get_hash()
chat_logger = ChatLogger(
data={
"username": attributor,
"title": "[Sweep GHA Fix] Fix the failing GitHub Actions",
}
)
if chat_logger.use_faster_model() and not IS_SELF_HOSTED:
return {
"success": False,
"error_message": "Disabled for free users",
}
stack_pr(
request=f"[Sweep GHA Fix] The GitHub Actions run failed on {request.check_run.head_sha[:7]} ({repo.default_branch}) with the following error logs:\n\n```\n\n{logs}\n\n```",
pr_number=pr.number,
username=attributor,
repo_full_name=repo.full_name,
installation_id=request.installation.id,
tracking_id=tracking_id,
commit_hash=pr.head.sha,
)
elif (
request.check_run.check_suite.head_branch == repo.default_branch
and get_gha_enabled(repo)
and GHA_AUTOFIX_ENABLED
):
if request.check_run.conclusion == "failure":
commit = repo.get_commit(request.check_run.head_sha)
attributor = request.sender.login
if attributor.endswith("[bot]"):
attributor = commit.author.login
if attributor.endswith("[bot]"):
return {
"success": False,
"error_message": "The PR was created by a bot, so I won't attempt to fix it.",
}
logs = download_logs(
request.repository.full_name,
request.check_run.run_id,
request.installation.id,
)
logs, user_message = clean_gh_logs(logs)
chat_logger = ChatLogger(
data={
"username": attributor,
"title": "[Sweep GHA Fix] Fix the failing GitHub Actions",
}
)
if chat_logger.use_faster_model() and not IS_SELF_HOSTED:
return {
"success": False,
"error_message": "Disabled for free users",
}
make_pr(
title=f"[Sweep GHA Fix] Fix the failing GitHub Actions on {request.check_run.head_sha[:7]} ({repo.default_branch})",
repo_description=repo.description,
summary=f"The GitHub Actions run failed with the following error logs:\n\n```\n{logs}\n```",
repo_full_name=request_dict["repository"]["full_name"],
installation_id=request_dict["installation"]["id"],
user_token=None,
use_faster_model=chat_logger.use_faster_model(),
username=attributor,
chat_logger=chat_logger,
)
case "pull_request", "opened":
_, g = get_github_client(request_dict["installation"]["id"])
repo = g.get_repo(request_dict["repository"]["full_name"])
pr = repo.get_pull(request_dict["pull_request"]["number"])
# if the pr already has a comment from sweep bot do nothing
time.sleep(10)
if any(
comment.user.login == GITHUB_BOT_USERNAME
for comment in pr.get_issue_comments()
) or pr.title.startswith("Sweep:"):
return {
"success": True,
"reason": "PR already has a comment from sweep bot",
}
rule_buttons = []
repo_rules = get_rules(repo) or []
if repo_rules != [""] and repo_rules != []:
for rule in repo_rules or []:
if rule:
rule_buttons.append(Button(label=f"{RULES_LABEL} {rule}"))
if len(repo_rules) == 0:
for rule in DEFAULT_RULES:
rule_buttons.append(Button(label=f"{RULES_LABEL} {rule}"))
if rule_buttons:
rules_buttons_list = ButtonList(
buttons=rule_buttons, title=RULES_TITLE
)
pr.create_issue_comment(rules_buttons_list.serialize() + BOT_SUFFIX)
if pr.mergeable is False and MERGE_CONFLICT_ENABLED:
attributor = pr.user.login
if attributor.endswith("[bot]"):
attributor = pr.assignee.login
if attributor.endswith("[bot]"):
return {
"success": False,
"error_message": "The PR was created by a bot, so I won't attempt to fix it.",
}
chat_logger = ChatLogger(
data={
"username": attributor,
"title": "[Sweep GHA Fix] Fix the failing GitHub Actions",
}
)
if chat_logger.use_faster_model() and not IS_SELF_HOSTED:
return {
"success": False,
"error_message": "Disabled for free users",
}
on_merge_conflict(
pr_number=pr.number,
username=attributor,
repo_full_name=request_dict["repository"]["full_name"],
installation_id=request_dict["installation"]["id"],
tracking_id=get_hash(),
)
case "issues", "opened":
request = IssueRequest(**request_dict)
issue_title_lower = request.issue.title.lower()
if (
issue_title_lower.startswith("sweep")
or "sweep:" in issue_title_lower
):
_, g = get_github_client(request.installation.id)
repo = g.get_repo(request.repository.full_name)
labels = repo.get_labels()
label_names = [label.name for label in labels]
if GITHUB_LABEL_NAME not in label_names:
repo.create_label(
name=GITHUB_LABEL_NAME,
color=GITHUB_LABEL_COLOR,
description=GITHUB_LABEL_DESCRIPTION,
)
current_issue = repo.get_issue(number=request.issue.number)
current_issue.add_to_labels(GITHUB_LABEL_NAME)
case "issue_comment", "edited":
request = IssueCommentRequest(**request_dict)
sweep_labeled_issue = GITHUB_LABEL_NAME in [
label.name.lower() for label in request.issue.labels
]
button_title_match = check_button_title_match(
REVERT_CHANGED_FILES_TITLE,
request.comment.body,
request.changes,
) or check_button_title_match(
RULES_TITLE,
request.comment.body,
request.changes,
)
if (
request.comment.user.type == "Bot"
and GITHUB_BOT_USERNAME in request.comment.user.login
and request.changes.body_from is not None
and button_title_match
and request.sender.type == "User"
and request.comment.user.login not in BLACKLISTED_USERS
):
run_on_button_click(request_dict)
restart_sweep = False
if (
request.comment.user.type == "Bot"
and GITHUB_BOT_USERNAME in request.comment.user.login
and request.changes.body_from is not None
and check_button_activated(
RESTART_SWEEP_BUTTON,
request.comment.body,
request.changes,
)
and sweep_labeled_issue
and request.sender.type == "User"
and request.comment.user.login not in BLACKLISTED_USERS
):
# Restart Sweep on this issue
restart_sweep = True
if (
request.issue is not None
and sweep_labeled_issue
and request.comment.user.type == "User"
and request.comment.user.login not in BLACKLISTED_USERS
and not request.comment.user.login.startswith("sweep")
and not (
request.issue.pull_request and request.issue.pull_request.url
)
or restart_sweep
):
logger.info("New issue comment edited")
request.issue.body = request.issue.body or ""
request.repository.description = (
request.repository.description or ""
)
if (
not request.comment.body.strip()
.lower()
.startswith(GITHUB_LABEL_NAME)
and not restart_sweep
):
logger.info("Comment does not start with 'Sweep', passing")
return {
"success": True,
"reason": "Comment does not start with 'Sweep', passing",
}
call_on_ticket(
title=request.issue.title,
summary=request.issue.body,
issue_number=request.issue.number,
issue_url=request.issue.html_url,
username=request.issue.user.login,
repo_full_name=request.repository.full_name,
repo_description=request.repository.description,
installation_id=request.installation.id,
comment_id=request.comment.id if not restart_sweep else None,
edited=True,
)
elif (
request.issue.pull_request
and request.comment.user.type == "User"
and request.comment.user.login not in BLACKLISTED_USERS
): # TODO(sweep): set a limit
logger.info(f"Handling comment on PR: {request.issue.pull_request}")
_, g = get_github_client(request.installation.id)
repo = g.get_repo(request.repository.full_name)
pr = repo.get_pull(request.issue.number)
labels = pr.get_labels()
comment = request.comment.body
if (
comment.lower().startswith("sweep:")
or any(label.name.lower() == "sweep" for label in labels)
) and BOT_SUFFIX not in comment:
pr_change_request = PRChangeRequest(
params={
"comment_type": "comment",
"repo_full_name": request.repository.full_name,
"repo_description": request.repository.description,
"comment": request.comment.body,
"pr_path": None,
"pr_line_position": None,
"username": request.comment.user.login,
"installation_id": request.installation.id,
"pr_number": request.issue.number,
"comment_id": request.comment.id,
},
)
call_on_comment(**pr_change_request.params)
case "issues", "edited":
request = IssueRequest(**request_dict)
if (
GITHUB_LABEL_NAME
in [label.name.lower() for label in request.issue.labels]
and request.sender.type == "User"
and request.sender.login not in BLACKLISTED_USERS
and not request.sender.login.startswith("sweep")
):
logger.info("New issue edited")
call_on_ticket(
title=request.issue.title,
summary=request.issue.body,
issue_number=request.issue.number,
issue_url=request.issue.html_url,
username=request.issue.user.login,
repo_full_name=request.repository.full_name,
repo_description=request.repository.description,
installation_id=request.installation.id,
comment_id=None,
)
else:
logger.info("Issue edited, but not a sweep issue")
case "issues", "labeled":
request = IssueRequest(**request_dict)
if (
any(
label.name.lower() == GITHUB_LABEL_NAME
for label in request.issue.labels
)
and not request.issue.pull_request
):
request.issue.body = request.issue.body or ""
request.repository.description = (
request.repository.description or ""
)
call_on_ticket(
title=request.issue.title,
summary=request.issue.body,
issue_number=request.issue.number,
issue_url=request.issue.html_url,
username=request.issue.user.login,
repo_full_name=request.repository.full_name,
repo_description=request.repository.description,
installation_id=request.installation.id,
comment_id=None,
)
case "issue_comment", "created":
request = IssueCommentRequest(**request_dict)
if (
request.issue is not None
and GITHUB_LABEL_NAME
in [label.name.lower() for label in request.issue.labels]
and request.comment.user.type == "User"
and request.comment.user.login not in BLACKLISTED_USERS
and not (
request.issue.pull_request and request.issue.pull_request.url
)
and BOT_SUFFIX not in request.comment.body
):
request.issue.body = request.issue.body or ""
request.repository.description = (
request.repository.description or ""
)
if (
not request.comment.body.strip()
.lower()
.startswith(GITHUB_LABEL_NAME)
):
logger.info("Comment does not start with 'Sweep', passing")
return {
"success": True,
"reason": "Comment does not start with 'Sweep', passing",
}
call_on_ticket(
title=request.issue.title,
summary=request.issue.body,
issue_number=request.issue.number,
issue_url=request.issue.html_url,
username=request.issue.user.login,
repo_full_name=request.repository.full_name,
repo_description=request.repository.description,
installation_id=request.installation.id,
comment_id=request.comment.id,
)
elif (
request.issue.pull_request
and request.comment.user.type == "User"
and request.comment.user.login not in BLACKLISTED_USERS
and BOT_SUFFIX not in request.comment.body
): # TODO(sweep): set a limit
_, g = get_github_client(request.installation.id)
repo = g.get_repo(request.repository.full_name)
pr = repo.get_pull(request.issue.number)
labels = pr.get_labels()
comment = request.comment.body
if (
comment.lower().startswith("sweep:")
or any(label.name.lower() == "sweep" for label in labels)
) and BOT_SUFFIX not in comment:
pr_change_request = PRChangeRequest(
params={
"comment_type": "comment",
"repo_full_name": request.repository.full_name,
"repo_description": request.repository.description,
"comment": request.comment.body,
"pr_path": None,
"pr_line_position": None,
"username": request.comment.user.login,
"installation_id": request.installation.id,
"pr_number": request.issue.number,
"comment_id": request.comment.id,
},
)
call_on_comment(**pr_change_request.params)
case "pull_request_review_comment", "created":
request = CommentCreatedRequest(**request_dict)
_, g = get_github_client(request.installation.id)
repo = g.get_repo(request.repository.full_name)
pr = repo.get_pull(request.pull_request.number)
labels = pr.get_labels()
comment = request.comment.body
if (
(
comment.lower().startswith("sweep:")
or any(label.name.lower() == "sweep" for label in labels)
)
and request.comment.user.type == "User"
and request.comment.user.login not in BLACKLISTED_USERS
and BOT_SUFFIX not in comment
):
pr_change_request = PRChangeRequest(
params={
"comment_type": "comment",
"repo_full_name": request.repository.full_name,
"repo_description": request.repository.description,
"comment": request.comment.body,
"pr_path": request.comment.path,
"pr_line_position": request.comment.original_line,
"username": request.comment.user.login,
"installation_id": request.installation.id,
"pr_number": request.pull_request.number,
"comment_id": request.comment.id,
},
)
call_on_comment(**pr_change_request.params)
case "pull_request_review_comment", "edited":
request = CommentCreatedRequest(**request_dict)
_, g = get_github_client(request.installation.id)
repo = g.get_repo(request.repository.full_name)
pr = repo.get_pull(request.pull_request.number)
labels = pr.get_labels()
comment = request.comment.body
if (
(
comment.lower().startswith("sweep:")
or any(label.name.lower() == "sweep" for label in labels)
)
and request.comment.user.type == "User"
and request.comment.user.login not in BLACKLISTED_USERS
and BOT_SUFFIX not in comment
):
pr_change_request = PRChangeRequest(
params={
"comment_type": "comment",
"repo_full_name": request.repository.full_name,
"repo_description": request.repository.description,
"comment": request.comment.body,
"pr_path": request.comment.path,
"pr_line_position": request.comment.original_line,
"username": request.comment.user.login,
"installation_id": request.installation.id,
"pr_number": request.pull_request.number,
"comment_id": request.comment.id,
},
)
call_on_comment(**pr_change_request.params)
case "installation_repositories", "added":
repos_added_request = ReposAddedRequest(**request_dict)
metadata = {
"installation_id": repos_added_request.installation.id,
"repositories": [
repo.full_name
for repo in repos_added_request.repositories_added
],
}
try:
add_config_to_top_repos(
repos_added_request.installation.id,
repos_added_request.installation.account.login,
repos_added_request.repositories_added,
)
except Exception as e:
logger.exception(f"Failed to add config to top repos: {e}")
posthog.capture(
"installation_repositories",
"started",
properties={**metadata},
)
for repo in repos_added_request.repositories_added:
organization, repo_name = repo.full_name.split("/")
posthog.capture(
organization,
"installed_repository",
properties={
"repo_name": repo_name,
"organization": organization,
"repo_full_name": repo.full_name,
},
)
case "installation", "created":
repos_added_request = InstallationCreatedRequest(**request_dict)
try:
add_config_to_top_repos(
repos_added_request.installation.id,
repos_added_request.installation.account.login,
repos_added_request.repositories,
)
except Exception as e:
logger.exception(f"Failed to add config to top repos: {e}")
case "pull_request", "edited":
request = PREdited(**request_dict)
if (
request.pull_request.user.login == GITHUB_BOT_USERNAME
and not request.sender.login.endswith("[bot]")
and DISCORD_FEEDBACK_WEBHOOK_URL is not None
):
good_button = check_button_activated(
SWEEP_GOOD_FEEDBACK,
request.pull_request.body,
request.changes,
)
bad_button = check_button_activated(
SWEEP_BAD_FEEDBACK,
request.pull_request.body,
request.changes,
)
if good_button or bad_button:
emoji = "😕"
if good_button:
emoji = "👍"
elif bad_button:
emoji = "👎"
data = {
"content": f"{emoji} {request.pull_request.html_url} ({request.sender.login})\n{request.pull_request.commits} commits, {request.pull_request.changed_files} files: +{request.pull_request.additions}, -{request.pull_request.deletions}"
}
headers = {"Content-Type": "application/json"}
requests.post(
DISCORD_FEEDBACK_WEBHOOK_URL,
data=json.dumps(data),
headers=headers,
)
# Send feedback to PostHog
posthog.capture(
request.sender.login,
"feedback",
properties={
"repo_name": request.repository.full_name,
"pr_url": request.pull_request.html_url,
"pr_commits": request.pull_request.commits,
"pr_additions": request.pull_request.additions,
"pr_deletions": request.pull_request.deletions,
"pr_changed_files": request.pull_request.changed_files,
"username": request.sender.login,
"good_button": good_button,
"bad_button": bad_button,
},
)
def remove_buttons_from_description(body):
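"""
Collapse the feedback buttons at the top of a PR description.
Replaces the leading "### PR Feedback..." heading and every following line,
up to (but not including) the next line that starts with "#", with a single
"### PR Feedback: {emoji}" heading. Returns None if the body does not start
with a "### PR Feedback" heading.
"""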
lines = body.split("\n")
if not lines[0].startswith("### PR Feedback"):
return None
# Find when the second # occurs
i = 0
for i, line in enumerate(lines):
if line.startswith("#") and i > 0:
break
return "\n".join(
[
f"### PR Feedback: {emoji}",
*lines[i:],
]
)
# Update PR description to remove buttons
try:
_, g = get_github_client(request.installation.id)
repo = g.get_repo(request.repository.full_name)
pr = repo.get_pull(request.pull_request.number)
new_body = remove_buttons_from_description(
request.pull_request.body
)
if new_body is not None:
pr.edit(body=new_body)
except SystemExit:
raise
except Exception as e:
logger.exception(f"Failed to edit PR description: {e}")
case "pull_request", "closed":
pr_request = PRRequest(**request_dict)
(
organization,
repo_name,
) = pr_request.repository.full_name.split("/")
commit_author = pr_request.pull_request.user.login
merged_by = (
pr_request.pull_request.merged_by.login
if pr_request.pull_request.merged_by
else None
)
if CURRENT_USERNAME == commit_author and merged_by is not None:
event_name = "merged_sweep_pr"
if pr_request.pull_request.title.startswith("[config]"):
event_name = "config_pr_merged"
elif pr_request.pull_request.title.startswith("[Sweep Rules]"):
event_name = "sweep_rules_pr_merged"
edited_by_developers = False
_token, g = get_github_client(pr_request.installation.id)
pr = g.get_repo(pr_request.repository.full_name).get_pull(
pr_request.number
)
total_lines_in_commit = 0
total_lines_edited_by_developer = 0
edited_by_developers = False
for commit in pr.get_commits():
lines_modified = commit.stats.additions + commit.stats.deletions
total_lines_in_commit += lines_modified
if commit.author.login != CURRENT_USERNAME:
total_lines_edited_by_developer += lines_modified
# this was edited by a developer if at least 25% of the lines were edited by a developer
edited_by_developers = total_lines_in_commit > 0 and (total_lines_edited_by_developer / total_lines_in_commit) >= 0.25
posthog.capture(
merged_by,
event_name,
properties={
"repo_name": repo_name,
"organization": organization,
"repo_full_name": pr_request.repository.full_name,
"username": merged_by,
"additions": pr_request.pull_request.additions,
"deletions": pr_request.pull_request.deletions,
"total_changes": pr_request.pull_request.additions
+ pr_request.pull_request.deletions,
"edited_by_developers": edited_by_developers,
"total_lines_in_commit": total_lines_in_commit,
"total_lines_edited_by_developer": total_lines_edited_by_developer,
},
)
chat_logger = ChatLogger({"username": merged_by})
case "push", None:
if event != "pull_request" or request_dict["base"]["merged"] is True:
chat_logger = ChatLogger(
{"username": request_dict["pusher"]["name"]}
)
# on merge
call_on_merge(request_dict, chat_logger)
ref = request_dict["ref"] if "ref" in request_dict else ""
# Fetch the repository once: both checks below use it, and sweep/* branches skip the first one.
_, g = get_github_client(request_dict["installation"]["id"])
repo = g.get_repo(request_dict["repository"]["full_name"])
if ref.startswith("refs/heads") and not ref.startswith(
"refs/heads/sweep"
):
if ref[len("refs/heads/") :] == SweepConfig.get_branch(repo):
update_sweep_prs_v2(
request_dict["repository"]["full_name"],
installation_id=request_dict["installation"]["id"],
)
if ref.startswith("refs/heads"):
branch_name = ref[len("refs/heads/") :]
# Check if the branch has an associated PR
org_name, repo_name = request_dict["repository"][
"full_name"
].split("/")
pulls = repo.get_pulls(
state="open",
sort="created",
head=org_name + ":" + branch_name,
)
for pr in pulls:
logger.info(
f"PR associated with branch {branch_name}: #{pr.number} - {pr.title}"
)
if pr.mergeable is False and MERGE_CONFLICT_ENABLED:
attributor = pr.user.login
if attributor.endswith("[bot]"):
attributor = pr.assignee.login
if attributor.endswith("[bot]"):
return {
"success": False,
"error_message": "The PR was created by a bot, so I won't attempt to fix it.",
}
chat_logger = ChatLogger(
data={
"username": attributor,
"title": "[Sweep GHA Fix] Fix the failing GitHub Actions",
}
)
if (
chat_logger.use_faster_model()
and not IS_SELF_HOSTED
):
return {
"success": False,
"error_message": "Disabled for free users",
}
on_merge_conflict(
pr_number=pr.number,
username=attributor,
repo_full_name=request_dict["repository"][
"full_name"
],
installation_id=request_dict["installation"]["id"],
tracking_id=get_hash(),
)
case "ping", None:
return {"message": "pong"}
case _:

import base64
import os
from dotenv import load_dotenv
from loguru import logger
logger.print = logger.info
load_dotenv(dotenv_path=".env", override=True, verbose=True)
os.environ["GITHUB_APP_PEM"] = os.environ.get("GITHUB_APP_PEM") or base64.b64decode(
os.environ.get("GITHUB_APP_PEM_BASE64", "")
).decode("utf-8")
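# The block above lets the GitHub App private key arrive either verbatim in
# GITHUB_APP_PEM or base64-encoded in GITHUB_APP_PEM_BASE64 (presumably for
# deployment platforms that handle multi-line env vars poorly). A hypothetical,
# testable restatement of that fallback that reads from a plain dict instead
# of os.environ:
import base64


def _example_pem_from_env(env: dict[str, str]) -> str:
    # A non-empty GITHUB_APP_PEM wins; otherwise decode the base64 variant.
    # An absent variable decodes to the empty string, as in the code above.
    return env.get("GITHUB_APP_PEM") or base64.b64decode(
        env.get("GITHUB_APP_PEM_BASE64", "")
    ).decode("utf-8")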
if os.environ["GITHUB_APP_PEM"]:
os.environ["GITHUB_APP_ID"] = (
(os.environ.get("GITHUB_APP_ID") or os.environ.get("APP_ID"))
.replace("\\n", "\n")
.strip('"')
)
os.environ["TRANSFORMERS_CACHE"] = os.environ.get(
"TRANSFORMERS_CACHE", "/tmp/cache/model"
) # vector_db.py
os.environ["TIKTOKEN_CACHE_DIR"] = os.environ.get(
"TIKTOKEN_CACHE_DIR", "/tmp/cache/tiktoken"
) # utils.py
SENTENCE_TRANSFORMERS_MODEL = os.environ.get(
"SENTENCE_TRANSFORMERS_MODEL",
"sentence-transformers/all-MiniLM-L6-v2", # "all-mpnet-base-v2"
)
TEST_BOT_NAME = "sweep-nightly[bot]"
ENV = os.environ.get("ENV", "dev")
# ENV = os.environ.get("MODAL_ENVIRONMENT", "dev")
# ENV = PREFIX
# ENVIRONMENT = PREFIX
DB_MODAL_INST_NAME = "db"
DOCS_MODAL_INST_NAME = "docs"
API_MODAL_INST_NAME = "api"
UTILS_MODAL_INST_NAME = "utils"
BOT_TOKEN_NAME = "bot-token"
# goes under Modal 'discord' secret name (optional, can leave env var blank)
DISCORD_WEBHOOK_URL = os.environ.get("DISCORD_WEBHOOK_URL")
DISCORD_MEDIUM_PRIORITY_URL = os.environ.get("DISCORD_MEDIUM_PRIORITY_URL")
DISCORD_LOW_PRIORITY_URL = os.environ.get("DISCORD_LOW_PRIORITY_URL")
DISCORD_FEEDBACK_WEBHOOK_URL = os.environ.get("DISCORD_FEEDBACK_WEBHOOK_URL")
SWEEP_HEALTH_URL = os.environ.get("SWEEP_HEALTH_URL")
DISCORD_STATUS_WEBHOOK_URL = os.environ.get("DISCORD_STATUS_WEBHOOK_URL")
# goes under Modal 'github' secret name
GITHUB_APP_ID = os.environ.get("GITHUB_APP_ID", os.environ.get("APP_ID"))
# deprecated: old logic transfer so upstream can use this
if GITHUB_APP_ID is None:
if ENV == "prod":
GITHUB_APP_ID = "307814"
elif ENV == "dev":
GITHUB_APP_ID = "324098"
elif ENV == "staging":
GITHUB_APP_ID = "327588"
GITHUB_BOT_USERNAME = os.environ.get("GITHUB_BOT_USERNAME")
# deprecated: left to support old logic
if not GITHUB_BOT_USERNAME:
if ENV == "prod":
GITHUB_BOT_USERNAME = "sweep-ai[bot]"
elif ENV == "dev":
GITHUB_BOT_USERNAME = "sweep-nightly[bot]"
elif ENV == "staging":
GITHUB_BOT_USERNAME = "sweep-canary[bot]"
elif not GITHUB_BOT_USERNAME.endswith("[bot]"):
GITHUB_BOT_USERNAME = GITHUB_BOT_USERNAME + "[bot]"
GITHUB_LABEL_NAME = os.environ.get("GITHUB_LABEL_NAME", "sweep")
GITHUB_LABEL_COLOR = os.environ.get("GITHUB_LABEL_COLOR", "9400D3")
GITHUB_LABEL_DESCRIPTION = os.environ.get(
"GITHUB_LABEL_DESCRIPTION", "Sweep your software chores"
)
GITHUB_APP_PEM = os.environ.get("GITHUB_APP_PEM")
GITHUB_APP_PEM = GITHUB_APP_PEM or os.environ.get("PRIVATE_KEY")
if GITHUB_APP_PEM is not None:
GITHUB_APP_PEM = GITHUB_APP_PEM.strip(' \n"') # Remove whitespace and quotes
GITHUB_APP_PEM = GITHUB_APP_PEM.replace("\\n", "\n")
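# A PEM key may be pasted as a single quoted line with literal "\n" escapes.
# The two lines above undo that: strip spaces, newlines and double quotes from
# both ends, then turn each two-character backslash-n sequence into a real
# newline. The same transformation as a hypothetical standalone helper:
def _example_normalize_pem(raw: str) -> str:
    return raw.strip(' \n"').replace("\\n", "\n")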
GITHUB_CONFIG_BRANCH = os.environ.get("GITHUB_CONFIG_BRANCH", "sweep/add-sweep-config")
GITHUB_DEFAULT_CONFIG = os.environ.get(
"GITHUB_DEFAULT_CONFIG",
"""# Sweep AI turns bugs & feature requests into code changes (https://sweep.dev)
# For details on our config file, check out our docs at https://docs.sweep.dev/usage/config
# This setting contains a list of rules that Sweep will check for. If any of these rules are broken in a new commit, Sweep will create a pull request to fix the broken rule.
rules:
{additional_rules}
# This is the branch that Sweep will develop from and make pull requests to. Most people use 'main' or 'master' but some users also use 'dev' or 'staging'.
branch: 'main'
# By default Sweep will read the logs and outputs from your existing GitHub Actions. To disable this, set this to false.
gha_enabled: True
# This is the description of your project. It will be used by sweep when creating PRs. You can tell Sweep what's unique about your project, what frameworks you use, or anything else you want.
#
# Example:
#
# description: sweepai/sweep is a python project. The main api endpoints are in sweepai/api.py. Write code that adheres to PEP8.
description: ''
# This sets whether to create pull requests as drafts. If this is set to True, then all pull requests will be created as drafts and GitHub Actions will not be triggered.
draft: False
# This is a list of directories that Sweep will not be able to edit.
blocked_dirs: []
""",
)
MONGODB_URI = os.environ.get("MONGODB_URI", None)
IS_SELF_HOSTED = os.environ.get("IS_SELF_HOSTED", "true").lower() == "true"
REDIS_URL = os.environ.get("REDIS_URL")
if not REDIS_URL:
    REDIS_URL = os.environ.get("redis_url", "redis://0.0.0.0:6379/0")
ORG_ID = os.environ.get("ORG_ID", None)
POSTHOG_API_KEY = os.environ.get(
    "POSTHOG_API_KEY", "phc_CnzwIB0W548wN4wEGeRuxXqidOlEUH2AcyV2sKTku8n"
)
E2B_API_KEY = os.environ.get("E2B_API_KEY")
SUPPORT_COUNTRY = os.environ.get("GDRP_LIST", "").split(",")
WHITELISTED_REPOS = os.environ.get("WHITELISTED_REPOS", "").split(",")
BLACKLISTED_USERS = os.environ.get("BLACKLISTED_USERS", "").split(",")
os.environ["TOKENIZERS_PARALLELISM"] = "false"
ACTIVELOOP_TOKEN = os.environ.get("ACTIVELOOP_TOKEN", None)
VECTOR_EMBEDDING_SOURCE = os.environ.get(
    "VECTOR_EMBEDDING_SOURCE", "openai"
)  # Alternatives are "huggingface" or "replicate"; set the corresponding env vars below
BASERUN_API_KEY = os.environ.get("BASERUN_API_KEY", None)
# Huggingface settings, only checked if VECTOR_EMBEDDING_SOURCE == "huggingface"
HUGGINGFACE_URL = os.environ.get("HUGGINGFACE_URL", None)
HUGGINGFACE_TOKEN = os.environ.get("HUGGINGFACE_TOKEN", None)
# Replicate settings, only checked if VECTOR_EMBEDDING_SOURCE == "replicate"
REPLICATE_API_KEY = os.environ.get("REPLICATE_API_KEY", None)
REPLICATE_URL = os.environ.get("REPLICATE_URL", None)
REPLICATE_DEPLOYMENT_URL = os.environ.get("REPLICATE_DEPLOYMENT_URL", None)
# Default OpenAI
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", None)
OPENAI_API_TYPE = os.environ.get("OPENAI_API_TYPE", "anthropic")
assert OPENAI_API_TYPE in ["anthropic", "azure", "openai"], "Invalid OPENAI_API_TYPE"
OPENAI_EMBEDDINGS_API_TYPE = os.environ.get("OPENAI_EMBEDDINGS_API_TYPE", "openai")
AZURE_API_KEY = os.environ.get("AZURE_API_KEY", None)
OPENAI_API_BASE = os.environ.get("OPENAI_API_BASE", None)
OPENAI_API_VERSION = os.environ.get("OPENAI_API_VERSION", None)
AZURE_OPENAI_DEPLOYMENT = os.environ.get("AZURE_OPENAI_DEPLOYMENT", None)
OPENAI_EMBEDDINGS_AZURE_ENDPOINT = os.environ.get(
    "OPENAI_EMBEDDINGS_AZURE_ENDPOINT", None
)
OPENAI_EMBEDDINGS_AZURE_API_KEY = os.environ.get(
    "OPENAI_EMBEDDINGS_AZURE_API_KEY", None
)
OPENAI_EMBEDDINGS_AZURE_DEPLOYMENT = os.environ.get(
    "OPENAI_EMBEDDINGS_AZURE_DEPLOYMENT", None
)
OPENAI_EMBEDDINGS_AZURE_API_VERSION = os.environ.get(
    "OPENAI_EMBEDDINGS_AZURE_API_VERSION", None
)
OPENAI_API_ENGINE_GPT35 = os.environ.get("OPENAI_API_ENGINE_GPT35", None)
OPENAI_API_ENGINE_GPT4 = os.environ.get("OPENAI_API_ENGINE_GPT4", None)
OPENAI_API_ENGINE_GPT4_32K = os.environ.get("OPENAI_API_ENGINE_GPT4_32K", None)
MULTI_REGION_CONFIG = os.environ.get("MULTI_REGION_CONFIG", None)
if isinstance(MULTI_REGION_CONFIG, str):
    MULTI_REGION_CONFIG = MULTI_REGION_CONFIG.strip("'").replace("\\n", "\n")
    MULTI_REGION_CONFIG = [item.split(",") for item in MULTI_REGION_CONFIG.split("\n")]
WHITELISTED_USERS = os.environ.get("WHITELISTED_USERS", None)
if WHITELISTED_USERS:
    WHITELISTED_USERS = WHITELISTED_USERS.split(",")
    WHITELISTED_USERS.append(GITHUB_BOT_USERNAME)
DEFAULT_GPT4_32K_MODEL = os.environ.get("DEFAULT_GPT4_32K_MODEL", "gpt-4-0125-preview")
DEFAULT_GPT35_MODEL = os.environ.get("DEFAULT_GPT35_MODEL", "gpt-3.5-turbo-1106")
RESEND_API_KEY = os.environ.get("RESEND_API_KEY", None)
LOKI_URL = None
DEBUG = os.environ.get("DEBUG", "false").lower() == "true"
ENV = "prod" if GITHUB_BOT_USERNAME != TEST_BOT_NAME else "dev"
PROGRESS_BASE_URL = os.environ.get(
    "PROGRESS_BASE_URL", "https://progress.sweep.dev"
).rstrip("/")
DISABLED_REPOS = os.environ.get("DISABLED_REPOS", "").split(",")
GHA_AUTOFIX_ENABLED: bool = os.environ.get("GHA_AUTOFIX_ENABLED", "false").lower() == "true"
MERGE_CONFLICT_ENABLED: bool = os.environ.get("MERGE_CONFLICT_ENABLED", "false").lower() == "true"
INSTALLATION_ID = os.environ.get("INSTALLATION_ID", None)
AWS_ACCESS_KEY = os.environ.get("AWS_ACCESS_KEY")
AWS_SECRET_KEY = os.environ.get("AWS_SECRET_KEY")
AWS_REGION = os.environ.get("AWS_REGION")
ANTHROPIC_AVAILABLE = AWS_ACCESS_KEY and AWS_SECRET_KEY and AWS_REGION
USE_ASSISTANT = os.environ.get("USE_ASSISTANT", "true").lower() == "true"
ANTHROPIC_API_KEY = os.environ.get("ANTHROPIC_API_KEY", None)
VOYAGE_API_KEY = os.environ.get("VOYAGE_API_KEY", None)
VOYAGE_API_AWS_ACCESS_KEY = os.environ.get("VOYAGE_API_AWS_ACCESS_KEY_ID")
VOYAGE_API_AWS_SECRET_KEY = os.environ.get("VOYAGE_API_AWS_SECRET_KEY")
VOYAGE_API_AWS_REGION = os.environ.get("VOYAGE_API_AWS_REGION")
VOYAGE_API_AWS_ENDPOINT_NAME = os.environ.get("VOYAGE_API_AWS_ENDPOINT_NAME", "voyage-code-2")
VOYAGE_API_USE_AWS = VOYAGE_API_AWS_ACCESS_KEY and VOYAGE_API_AWS_SECRET_KEY and VOYAGE_API_AWS_REGION
PAREA_API_KEY = os.environ.get("PAREA_API_KEY", None)
# TODO: we need to make this dynamic + backoff
BATCH_SIZE = int(
    os.environ.get("BATCH_SIZE", 32 if VOYAGE_API_KEY else 256)  # Voyage only allows 128 items per batch and 120000 tokens per batch
)
DEPLOYMENT_GHA_ENABLED = os.environ.get("DEPLOYMENT_GHA_ENABLED", "true").lower() == "true"
JIRA_USER_NAME = os.environ.get("JIRA_USER_NAME", None)
JIRA_API_TOKEN = os.environ.get("JIRA_API_TOKEN", None)
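Environment variables are always strings, so a flag read as `os.environ.get(name, False)` is truthy even when it is set to `"false"`; the config above avoids this for flags such as `DEBUG` and `USE_ASSISTANT` by comparing against `"true"`. The sketch below is a minimal illustration, not Sweep's code: `env_flag` and `parse_multi_region_config` are hypothetical helper names. The first follows the `.lower() == "true"` convention (and additionally ignores surrounding whitespace); the second mirrors the `MULTI_REGION_CONFIG` parsing above (strip single quotes, turn literal `\n` into newlines, split each line on commas) without assuming what the fields mean.

```python
import os
from typing import List, Optional


def env_flag(name: str, default: bool = False) -> bool:
    """Hypothetical helper: read a boolean env var the way DEBUG/USE_ASSISTANT are read."""
    raw = os.environ.get(name)
    if raw is None:
        return default
    # Only the string "true" (any case, surrounding whitespace ignored) enables the flag.
    return raw.strip().lower() == "true"


def parse_multi_region_config(raw: Optional[str]) -> Optional[List[List[str]]]:
    """Hypothetical helper mirroring the MULTI_REGION_CONFIG parsing above."""
    if raw is None:
        return None
    # Drop wrapping single quotes and turn literal backslash-n sequences into newlines.
    cleaned = raw.strip("'").replace("\\n", "\n")
    # One row per line, one field per comma-separated value.
    return [line.split(",") for line in cleaned.split("\n")]
```

Because `bool("false")` is `True`, a raw `os.environ.get(...)` value should not be used directly as a boolean; routing every flag through one helper keeps the parsing rule in a single place.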

# Advanced Features: becoming a Power User 🧠
## Usage 📖
### Mention important files
To ensure that Sweep scans a file, mention the file name in your ticket. Sweep searches for relevant files at runtime, but specifying the file helps avoid missing important details.
### Giving Sweep feedback
If Sweep's plan isn't accurate, you can respond to Sweep in three places:
1. **Issue**: Sweep will create a new pull request and close the old one. Alternatively, you can edit the issue description to recreate the pull request.
2. **Pull request**: Sweep will update the PR based on your PR comments
3. **Code**: Sweep will only update the file that the comment is on
When Sweep starts looking at one of your messages, it reacts with a 👀 emoji. If you don't see it, make sure the PR/issue is open and that you prefixed the message with "sweep:".
Sweep also updates the PR based on the error message whenever a GitHub Actions run fails.
### Switch branch
To get Sweep to use a different base branch for one issue, add the following to the issue description.
> branch: BRANCH_NAME
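How Sweep reads this line is not shown here. As a hypothetical sketch, assuming the override sits on its own line as `branch: BRANCH_NAME` in the issue body (the `>` above is only the docs' quote formatting), extracting it could look like this; `branch_override` is an illustrative name, not a Sweep function.

```python
import re
from typing import Optional

# One line of the form "branch: NAME"; the key is matched case-insensitively.
BRANCH_LINE = re.compile(r"^[ \t]*branch:[ \t]*(\S+)[ \t]*$", re.IGNORECASE | re.MULTILINE)


def branch_override(issue_body: Optional[str], default_branch: str) -> str:
    """Return the branch named in the issue body, or default_branch if none is given."""
    match = BRANCH_LINE.search(issue_body or "")
    return match.group(1) if match else default_branch
```

Using `[ \t]*` instead of `\s*` keeps the match on a single line, so an empty `branch:` never picks up a word from the following line.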
## Configuration 🛠️
### Use GitHub Actions
We highly recommend linters, as well as Netlify/Vercel preview builds. Sweep auto-corrects based on linter and build errors, and Netlify and Vercel shorten iteration cycles by providing previews of static sites.
### Set up `sweep.yaml`
You can set up `sweep.yaml` to
* Provide up-to-date docs by setting up `docs` (https://docs.sweep.dev/usage/config#docs)
* Set up automated formatting and linting by setting up `sandbox` (https://docs.sweep.dev/usage/config#sandbox). Never have Sweep commit a failing `npm lint` again.
* Give Sweep a high-level description of where to find files in your repo by editing the `description` field.
For more on configs, check out https://docs.sweep.dev/usage/config.
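The server config above defines `GITHUB_DEFAULT_CONFIG`, a starter `sweep.yaml` template (next to `GITHUB_CONFIG_BRANCH = "sweep/add-sweep-config"`) that contains an `{additional_rules}` placeholder under `rules:`. A minimal sketch of filling that placeholder, assuming each rule is written as a double-quoted YAML list item; the trimmed `TEMPLATE` and `render_default_config` are illustrative, not Sweep's code.

```python
import json
from typing import Iterable

# Trimmed, illustrative copy of the template; the real GITHUB_DEFAULT_CONFIG has more keys.
TEMPLATE = """rules:
{additional_rules}
branch: 'main'
gha_enabled: True
"""


def render_default_config(rules: Iterable[str], template: str = TEMPLATE) -> str:
    """Fill {additional_rules} with one YAML list item per rule (assumed format)."""
    # json.dumps yields an escaped double-quoted string, which is also a valid YAML scalar.
    items = [f"  - {json.dumps(rule)}" for rule in rules]
    # str.format is safe here only because the template has no other braces.
    return template.format(additional_rules="\n".join(items) or "  []")
```

Quoting every rule avoids YAML surprises when a rule contains a colon or starts with a special character.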
## Prompting 🗣️
The amount of prompting you need to give Sweep scales directly with the complexity of the problem.
For harder problems, try to provide the same information a human would need, and for simpler problems, providing a single line and a file name should suffice.
### Prompting formats
A good issue should include **where to look** (file name or entity name), **what to do** ("change the logic to do this"), and **additional context** (there's a bug/we need this feature/there's this dependency). Examples:

sweep/sweepai/cli.py (lines 1 to 363 in 76aecb2)

import datetime
import json
import os
import pickle
import threading
import time
import uuid
from itertools import chain, islice
import typer
from github import Github
from github.Event import Event
from github.IssueEvent import IssueEvent
from github.Repository import Repository
from loguru import logger
from rich.console import Console
from rich.prompt import Prompt
from sweepai.api import handle_request
from sweepai.handlers.on_ticket import on_ticket
from sweepai.utils.event_logger import posthog
from sweepai.utils.github_utils import get_github_client
from sweepai.utils.str_utils import get_hash
from sweepai.web.events import Account, Installation, IssueRequest
app = typer.Typer(
    name="sweepai", context_settings={"help_option_names": ["-h", "--help"]}
)
app_dir = typer.get_app_dir("sweepai")
config_path = os.path.join(app_dir, "config.json")
console = Console()
cprint = console.print
def posthog_capture(event_name, properties, *args, **kwargs):
    POSTHOG_DISTINCT_ID = os.environ.get("POSTHOG_DISTINCT_ID")
    if POSTHOG_DISTINCT_ID:
        posthog.capture(POSTHOG_DISTINCT_ID, event_name, properties, *args, **kwargs)
def load_config():
    if os.path.exists(config_path):
        cprint(f"\nLoading configuration from {config_path}", style="yellow")
        with open(config_path, "r") as f:
            config = json.load(f)
        os.environ["GITHUB_PAT"] = config.get("GITHUB_PAT", "")
        os.environ["OPENAI_API_KEY"] = config.get("OPENAI_API_KEY", "")
        os.environ["ANTHROPIC_API_KEY"] = config.get("ANTHROPIC_API_KEY", "")
        os.environ["VOYAGE_API_KEY"] = config.get("VOYAGE_API_KEY", "")
        os.environ["POSTHOG_DISTINCT_ID"] = str(config.get("POSTHOG_DISTINCT_ID", ""))
def fetch_issue_request(issue_url: str, __version__: str = "0"):
    (
        protocol_name,
        _,
        _base_url,
        org_name,
        repo_name,
        _issues,
        issue_number,
    ) = issue_url.split("/")
    cprint("Fetching installation ID...")
    installation_id = -1
    cprint("Fetching access token...")
    _token, g = get_github_client(installation_id)
    g: Github = g
    cprint("Fetching repo...")
    issue = g.get_repo(f"{org_name}/{repo_name}").get_issue(int(issue_number))
    issue_request = IssueRequest(
        action="labeled",
        issue=IssueRequest.Issue(
            title=issue.title,
            number=int(issue_number),
            html_url=issue_url,
            user=IssueRequest.Issue.User(
                login=issue.user.login,
                type="User",
            ),
            body=issue.body,
            labels=[
                IssueRequest.Issue.Label(
                    name="sweep",
                ),
            ],
            assignees=None,
            pull_request=None,
        ),
        repository=IssueRequest.Issue.Repository(
            full_name=issue.repository.full_name,
            description=issue.repository.description,
        ),
        assignee=IssueRequest.Issue.Assignee(login=issue.user.login),
        installation=Installation(
            id=installation_id,
            account=Account(
                id=issue.user.id,
                login=issue.user.login,
                type="User",
            ),
        ),
        sender=IssueRequest.Issue.User(
            login=issue.user.login,
            type="User",
        ),
    )
    return issue_request
def pascal_to_snake(name):
    return "".join(["_" + i.lower() if i.isupper() else i for i in name]).lstrip("_")
def get_event_type(event: Event | IssueEvent):
    if isinstance(event, IssueEvent):
        return "issues"
    else:
        return pascal_to_snake(event.type)[: -len("_event")]
@app.command()
def test():
    cprint("Sweep AI is installed correctly and ready to go!", style="yellow")
@app.command()
def watch(
    repo_name: str,
    debug: bool = False,
    record_events: bool = False,
    max_events: int = 30,
):
    if not os.path.exists(config_path):
        cprint(
            f"\nConfiguration not found at {config_path}. Please run [green]'sweep init'[/green] to initialize the CLI.\n",
            style="yellow",
        )
        raise ValueError(
            "Configuration not found, please run 'sweep init' to initialize the CLI."
        )
    posthog_capture(
        "sweep_watch_started",
        {
            "repo": repo_name,
            "debug": debug,
            "record_events": record_events,
            "max_events": max_events,
        },
    )
    GITHUB_PAT = os.environ.get("GITHUB_PAT", None)
    if GITHUB_PAT is None:
        raise ValueError("GITHUB_PAT environment variable must be set")
    g = Github(os.environ["GITHUB_PAT"])
    repo = g.get_repo(repo_name)
    if debug:
        logger.debug("Debug mode enabled")
    def stream_events(repo: Repository, timeout: int = 2, offset: int = 2 * 60):
        processed_event_ids = set()
        current_time = time.time() - offset
        current_time = datetime.datetime.fromtimestamp(current_time)
        local_tz = datetime.datetime.now(datetime.timezone.utc).astimezone().tzinfo
        while True:
            events_iterator = chain(
                islice(repo.get_events(), max_events),
                islice(repo.get_issues_events(), max_events),
            )
            for i, event in enumerate(events_iterator):
                if event.id not in processed_event_ids:
                    local_time = event.created_at.replace(
                        tzinfo=datetime.timezone.utc
                    ).astimezone(local_tz)
                    if local_time.timestamp() > current_time.timestamp():
                        yield event
                    else:
                        if debug:
                            logger.debug(
                                f"Skipping event {event.id} because it is in the past (local_time={local_time}, current_time={current_time}, i={i})"
                            )
                elif debug:
                    # Only events seen in an earlier poll reach this branch.
                    logger.debug(
                        f"Skipping event {event.id} because it is already handled"
                    )
                processed_event_ids.add(event.id)
            time.sleep(timeout)
    def handle_event(event: Event | IssueEvent, do_async: bool = True):
        if isinstance(event, IssueEvent):
            payload = event.raw_data
            payload["action"] = payload["event"]
        else:
            payload = {**event.raw_data, **event.payload}
        payload["sender"] = payload.get("sender", payload["actor"])
        payload["sender"]["type"] = "User"
        payload["pusher"] = payload.get("pusher", payload["actor"])
        payload["pusher"]["name"] = payload["pusher"]["login"]
        payload["pusher"]["type"] = "User"
        payload["after"] = payload.get("after", payload.get("head"))
        payload["repository"] = repo.raw_data
        payload["installation"] = {"id": -1}
        logger.info(str(event) + " " + str(event.created_at))
        if record_events:
            _type = get_event_type(event) if isinstance(event, Event) else "issue"
            # Use a context manager so the pickle file is closed after writing.
            with open(
                "tests/events/"
                + f"{_type}_{payload.get('action')}_{str(event.id)}.pkl",
                "wb",
            ) as f:
                pickle.dump(event, f)
        if do_async:
            thread = threading.Thread(
                target=handle_request, args=(payload, get_event_type(event))
            )
            thread.start()
            return thread
        else:
            return handle_request(payload, get_event_type(event))
    def main():
        cprint(
            f"\n[bold black on white] Starting server, listening to events from {repo_name}... [/bold black on white]\n",
        )
        cprint(
            f"To create a PR, please create an issue at https://github.com/{repo_name}/issues with a title prefixed with 'Sweep:' or label an existing issue with 'sweep'. The events will be logged here, but there may be a brief delay.\n"
        )
        for event in stream_events(repo):
            handle_event(event)
    # Call main() directly: inside a function, `if __name__ == "__main__":` is false
    # whenever this module is imported (e.g. via the `sweep` entry point).
    main()
@app.command()
def init(override: bool = False):
    # TODO: Fix telemetry
    if not override:
        if os.path.exists(config_path):
            with open(config_path, "r") as f:
                config = json.load(f)
            if "OPENAI_API_KEY" in config and "ANTHROPIC_API_KEY" in config and "GITHUB_PAT" in config:
                override = typer.confirm(
                    f"\nConfiguration already exists at {config_path}. Override?",
                    default=False,
                    abort=True,
                )
    cprint(
        "\n[bold black on white] Initializing Sweep CLI... [/bold black on white]\n",
    )
    cprint(
        "\nFirst, let's store your OpenAI API Key. You can get it here: https://platform.openai.com/api-keys\n",
        style="yellow",
    )
    openai_api_key = Prompt.ask("OpenAI API Key", password=True)
    assert len(openai_api_key) >= 30, "OpenAI API Key must be of length at least 30."
    assert openai_api_key.startswith("sk-"), "OpenAI API Key must start with 'sk-'."
    cprint(
        "\nNext, let's store your Anthropic API key. You can get it here: https://console.anthropic.com/settings/keys.",
        style="yellow",
    )
    anthropic_api_key = Prompt.ask("Anthropic API Key", password=True)
    assert len(anthropic_api_key) >= 30, "Anthropic API Key must be of length at least 30."
    assert anthropic_api_key.startswith("sk-ant-api03-"), "Anthropic API Key must start with 'sk-ant-api03-'."
    cprint(
        "\nGreat! Next, we'll need just your GitHub PAT. Here's a link with all the permissions pre-filled:\nhttps://github.com/settings/tokens/new?description=Sweep%20Self-hosted&scopes=repo,workflow\n",
        style="yellow",
    )
    github_pat = Prompt.ask("GitHub PAT", password=True)
    assert len(github_pat) >= 30, "GitHub PAT must be of length at least 30."
    assert github_pat.startswith("ghp_"), "GitHub PAT must start with 'ghp_'."
    cprint(
        "\nAwesome! Lastly, let's get your Voyage AI API key from https://dash.voyageai.com/api-keys. This is optional, but improves code search by about [cyan]5%[/cyan]. You can always return to this later by re-running 'sweep init'.",
        style="yellow",
    )
    voyage_api_key = Prompt.ask("Voyage AI API key", password=True)
    if voyage_api_key:
        assert len(voyage_api_key) >= 30, "Voyage AI API key must be of length at least 30."
        assert voyage_api_key.startswith("pa-"), "Voyage API key must start with 'pa-'."
    POSTHOG_DISTINCT_ID = None
    enable_telemetry = typer.confirm(
        "\nEnable usage statistics? This will help us improve the product.",
        default=True,
    )
    if enable_telemetry:
        cprint(
            "\nThank you for enabling telemetry. We'll collect anonymous usage statistics to improve the product. You can disable this at any time by rerunning 'sweep init'.",
            style="yellow",
        )
        POSTHOG_DISTINCT_ID = uuid.getnode()
        posthog.capture(POSTHOG_DISTINCT_ID, "sweep_init", {})
    config = {
        "GITHUB_PAT": github_pat,
        "OPENAI_API_KEY": openai_api_key,
        "ANTHROPIC_API_KEY": anthropic_api_key,
        "VOYAGE_API_KEY": voyage_api_key,
    }
    if POSTHOG_DISTINCT_ID:
        config["POSTHOG_DISTINCT_ID"] = POSTHOG_DISTINCT_ID
    os.makedirs(app_dir, exist_ok=True)
    with open(config_path, "w") as f:
        json.dump(config, f)
    cprint(f"\nConfiguration saved to {config_path}\n", style="yellow")
    cprint(
        "Installation complete! You can now run [green]'sweep run <issue-url>'[/green][yellow] to run Sweep on an issue, or [/yellow][green]'sweep watch <org-name>/<repo-name>'[/green] to have Sweep listen for and fix newly created GitHub issues.",
        style="yellow",
    )
@app.command()
def run(issue_url: str):
    if not os.path.exists(config_path):
        cprint(
            f"\nConfiguration not found at {config_path}. Please run [green]'sweep init'[/green] to initialize the CLI.\n",
            style="yellow",
        )
        raise ValueError(
            "Configuration not found, please run 'sweep init' to initialize the CLI."
        )
    cprint(f"\n Running Sweep on issue: {issue_url} \n", style="bold black on white")
    posthog_capture("sweep_run_started", {"issue_url": issue_url})
    request = fetch_issue_request(issue_url)
    try:
        cprint(f'\nRunning Sweep to solve "{request.issue.title}"!\n')
        on_ticket(
            title=request.issue.title,
            summary=request.issue.body,
            issue_number=request.issue.number,
            issue_url=request.issue.html_url,
            username=request.sender.login,
            repo_full_name=request.repository.full_name,
            repo_description=request.repository.description,
            installation_id=request.installation.id,
            comment_id=None,
            edited=False,
            tracking_id=get_hash(),
        )
    except Exception as e:
        posthog_capture("sweep_run_fail", {"issue_url": issue_url, "error": str(e)})
    else:
        posthog_capture("sweep_run_success", {"issue_url": issue_url})
def main():
    cprint(
        "By using the Sweep CLI, you agree to the Sweep AI Terms of Service at https://sweep.dev/tos.pdf",
        style="cyan",
    )
    load_config()
    app()
if __name__ == "__main__":
    main()

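`fetch_issue_request` unpacks `issue_url.split("/")` into exactly seven names, so a URL with a trailing slash or a `#issuecomment-123` fragment fails, and `get_event_type` cuts the last six characters with `[: -len("_event")]` whether or not that suffix is present. Below is a hedged, standard-library-only sketch of more forgiving versions; `parse_issue_url` and `event_type_name` are hypothetical names, not part of the CLI, and `str.removesuffix` needs Python 3.9+.

```python
from typing import Tuple
from urllib.parse import urlparse


def parse_issue_url(issue_url: str) -> Tuple[str, str, int]:
    """Return (org_name, repo_name, issue_number) for a github.com issue URL."""
    # urlparse drops any "#fragment" or "?query"; strip("/") tolerates a trailing slash.
    parts = urlparse(issue_url).path.strip("/").split("/")
    if len(parts) != 4 or parts[2] != "issues":
        raise ValueError(f"not a GitHub issue URL: {issue_url!r}")
    org_name, repo_name, _issues, issue_number = parts
    return org_name, repo_name, int(issue_number)


def pascal_to_snake(name: str) -> str:
    # Same conversion as the CLI's helper: "IssueCommentEvent" -> "issue_comment_event".
    return "".join("_" + c.lower() if c.isupper() else c for c in name).lstrip("_")


def event_type_name(event_type: str) -> str:
    """Map a GitHub event type such as "IssueCommentEvent" to "issue_comment"."""
    # removesuffix only strips "_event" when it is actually present.
    return pascal_to_snake(event_type).removesuffix("_event")
```

For example, `event_type_name("Fork")` returns `"fork"`, whereas slicing off six characters would return an empty string.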
# Frequently Asked Questions
<details id="does-sweep-write-tests">
<summary>Does Sweep write tests?</summary>
Yep! The easiest way to have Sweep write tests is by modifying the `description` parameter in your `sweep.yaml`. You can add something like:
“In [your repository], the tests are written in [your format]. If you modify business logic, modify the tests as well using this format.” You can add anything you’d like to the description parameter, including formatting rules (like PEP8), code style, etc!
</details>
<details id="can-we-trust-code-written-by-sweep">
<summary>Can we trust the code written by Sweep?</summary>
You should always review the PR. That said, Sweep also checks that the PR works by running your existing GitHub Actions.
To get the best performance, add GitHub Actions that lint, test, and validate your code.
</details>
<details id="work-off-another-branch">
<summary>Can I have Sweep work off of another branch besides main?</summary>
Yes! In the `sweep.yaml`, you can set the `branch` parameter to something besides your default branch, and Sweep will use that as a reference.
</details>
<details id="retry-issue-with-sweep">
<summary>How do I retry an issue with Sweep?</summary>
To retry an issue, prefix your issue reply with 'Sweep: '. This will trigger Sweep to retry the issue.
</details>
<details id="give-documentation-to-sweep">
<summary>Can I give documentation to Sweep?</summary>
Yes! In the `sweep.yaml`, you can specify docs. Be sure to use the URL prefix of the docs site, so that we only fetch the docs you need.
Check out the example here: https://github.com/sweepai/sweep/blob/main/sweep.yaml.
</details>
<details id="comment-on-sweeps-prs">
<summary>Can I comment on Sweep’s PRs?</summary>
Yep! You have three options depending on the degree of the change:
1. You can comment on the issue, and Sweep will rewrite the entire pull request. This will use one of your GPT-4 credits.
2. You can comment on the pull request (not a file) and Sweep can make substantial changes to the pull request. Sweep will search the codebase, and is able to modify and create files.
3. You can comment on the file directly, and Sweep will only modify that file. Use this for small single file changes.
</details>

Once Sweep has the reference implementation, it generates the corresponding tests and commits them to a [GitHub PR](https://github.com/sweepai/sweep/pull/2378):
```python
def get_file_contents(self, file_path, ref=None):
    local_path = os.path.join(self.cache_dir, file_path)
    if os.path.exists(local_path):
        with open(local_path, "r", encoding="utf-8", errors="replace") as f:
            contents = f.read()
        return contents
    else:
        raise FileNotFoundError(f"{local_path} does not exist.")
```
We had Sweep generate mocks for `os.path.join` and `open`. <br></br>
This code looks great!
```python
@patch("os.path.join")
@patch("open")
def test_get_file_contents(self, mock_open, mock_join):
    mock_join.return_value = "/tmp/cache/repos/sweepai/sweep/main/file1"
    mock_open.return_value.__enter__.return_value.read.return_value = "file content"
    content = self.cloned_repo.get_file_contents("file1")
    self.assertEqual(content, "file content")
```
We generated mocks for `os.path.join` and `open`, which should return the correct path and file contents. <br></br>
OK, we're done here, right? Can we just write these tests and leave the rest to the developer?
## 3. **Run the tests.**
Most other AI tools stop here, but it’s not enough. <br></br>
If you just committed these tests it would be great, but you’d end up with a frustrating bug. Here it is:
```bash
  File "/usr/lib/python3.10/unittest/mock.py", line 1616, in _get_target
    raise TypeError(
TypeError: Need a valid target to patch. You supplied: 'open'
```
Did we really save time for the developer here? It’s frustrating that most other tools don’t fix these issues.
*Unlike every other tool, Sweep actually runs these tests.*
Sweep ran the code, found the issue, and identified the solution: <br></br>
**“Change the target of the patch in the 'test_get_file_contents' method from 'open' to 'builtins.open'. This will correctly patch the built-in 'open' function during the test.”**
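For reference, here is a self-contained sketch of a test that patches `builtins.open` as suggested. It is not the contents of the linked commit: `ClonedRepo` below is a stand-in that only reproduces `get_file_contents`, and the sketch also patches `os.path.exists`, because `get_file_contents` checks that the file exists before opening it, which would otherwise depend on that path existing on disk. It uses `unittest.mock.mock_open` instead of configuring `__enter__` by hand, and leaves `os.path.join` unpatched.

```python
import os
import unittest
from unittest.mock import mock_open, patch


class ClonedRepo:
    """Stand-in for the real class: only get_file_contents is reproduced."""

    def __init__(self, cache_dir: str):
        self.cache_dir = cache_dir

    def get_file_contents(self, file_path, ref=None):
        local_path = os.path.join(self.cache_dir, file_path)
        if os.path.exists(local_path):
            with open(local_path, "r", encoding="utf-8", errors="replace") as f:
                return f.read()
        raise FileNotFoundError(f"{local_path} does not exist.")


class TestGetFileContents(unittest.TestCase):
    def setUp(self):
        self.cloned_repo = ClonedRepo("/tmp/cache/repos/sweepai/sweep/main")

    # Decorators apply bottom-up, so the lowest patch supplies the first mock argument.
    @patch("os.path.exists", return_value=True)
    @patch("builtins.open", new_callable=mock_open, read_data="file content")
    def test_get_file_contents(self, mocked_open, mock_exists):
        expected_path = os.path.join(self.cloned_repo.cache_dir, "file1")
        content = self.cloned_repo.get_file_contents("file1")
        self.assertEqual(content, "file content")
        mock_exists.assert_called_once_with(expected_path)
        mocked_open.assert_called_once_with(
            expected_path, "r", encoding="utf-8", errors="replace"
        )
```

Run it with `python -m unittest`; `patch("open")` on its own raises the `TypeError` shown above because a patch target must be a dotted `module.attribute` path.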
Sweep added [this commit](https://github.com/sweepai/sweep/pull/2378/commits/0ded79eab77ca3e511257ff0bf3874893b038e9e):
```python
```
Step 2: ⌨️ Coding

  • Modify sweepai/handlers/on_merge_conflict.py (8fe7dbe)
Modify sweepai/handlers/on_merge_conflict.py with contents: In the `on_merge_conflict` function:
• After creating the new branch `new_pull_request.branch_name`, call a new function `rebase_branch` from `github_utils.py` to rebase the new branch onto the target branch `pr.base.ref` instead of performing a merge.
• Remove the existing code that performs the merge using `git_repo.git.merge("origin/" + pr.base.ref)`.
• Update the comment to indicate that a rebase is being performed instead of a merge.
  • Modify sweepai/utils/github_utils.py (8fe7dbe)
Modify sweepai/utils/github_utils.py with contents:
• Add a new function `rebase_branch` that takes the `git_repo`, `source_branch`, and `target_branch` as parameters.
• In the `rebase_branch` function:
  - Check out the `source_branch`.
  - Perform the rebase onto the `target_branch` using `git_repo.git.rebase(target_branch)`.
  - Handle any rebase conflicts by iterating over the conflicted files, resolving the conflicts, adding the resolved files, and continuing the rebase with `git_repo.git.rebase("--continue")`.
  - Return the updated `git_repo` object.
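A minimal sketch of the `rebase_branch` step described in this plan, under stated assumptions: it drives the git CLI through `subprocess` rather than GitPython's `git_repo.git`, takes a `git` runner so it can be exercised without a real repository, and delegates conflict resolution to a hypothetical `resolve(path)` callback (in Sweep this would be the LLM-driven editing step). `subprocess_git`, `GitRunner`, and the signature are illustrative, not Sweep's.

```python
import os
import subprocess
from typing import Callable, List, Sequence

# A runner executes one git command (the arguments after "git") and returns stdout.
# It must raise subprocess.CalledProcessError when git exits with a non-zero status.
GitRunner = Callable[[Sequence[str]], str]


def subprocess_git(repo_dir: str) -> GitRunner:
    """Hypothetical helper: a runner that calls the git CLI inside repo_dir."""
    # GIT_EDITOR=true accepts the prepared commit message, so `rebase --continue`
    # never waits for an interactive editor.
    env = {**os.environ, "GIT_EDITOR": "true"}

    def run(args: Sequence[str]) -> str:
        result = subprocess.run(
            ["git", *args],
            cwd=repo_dir,
            env=env,
            check=True,
            capture_output=True,
            text=True,
        )
        return result.stdout

    return run


def rebase_branch(
    git: GitRunner,
    source_branch: str,
    target_branch: str,
    resolve: Callable[[str], None],
    max_rounds: int = 50,
) -> None:
    """Replay source_branch on top of target_branch instead of merging target into it.

    `resolve(path)` is a hypothetical callback that must rewrite `path` without
    conflict markers before the file is staged.
    """
    git(["checkout", source_branch])
    try:
        git(["rebase", target_branch])
        return  # clean rebase: nothing conflicted
    except subprocess.CalledProcessError:
        pass  # git stopped at the first commit that conflicts

    for _ in range(max_rounds):
        # Paths left unmerged by the commit currently being replayed.
        conflicted: List[str] = git(["diff", "--name-only", "--diff-filter=U"]).splitlines()
        if not conflicted:
            break  # stopped for a reason other than a conflict; abort below
        for path in conflicted:
            resolve(path)
            git(["add", path])
        try:
            git(["rebase", "--continue"])
            return  # all remaining commits applied
        except subprocess.CalledProcessError:
            continue  # a later commit conflicted too; resolve it next round

    try:
        git(["rebase", "--abort"])  # put the branch back the way it was
    except subprocess.CalledProcessError:
        pass  # no rebase was in progress
    raise RuntimeError(f"could not rebase {source_branch} onto {target_branch}")
```

In `on_merge_conflict`, a call like this would take the place of `git_repo.git.merge("origin/" + pr.base.ref)`, e.g. `rebase_branch(subprocess_git(cloned_repo.repo_dir), new_pull_request.branch_name, "origin/" + pr.base.ref, resolve)`. Because a rebase rewrites the branch's commits, pushing it over a branch that already exists on GitHub needs `git push --force-with-lease`, whereas a merge commit can be pushed as a fast-forward.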

Step 3: 🔁 Code Review

I have finished reviewing the code for completeness. I did not find errors for sweep/allow_for_rebase_2bf50.


🎉 Latest improvements to Sweep:
  • New dashboard launched for real-time tracking of Sweep issues, covering all stages from search to coding.
  • Integration of OpenAI's latest Assistant API for more efficient and reliable code planning and editing, improving speed by 3x.
  • Use the GitHub issues extension for creating Sweep issues directly from your editor.

💡 To recreate the pull request edit the issue title or description.
Something wrong? Let us know.

This is an automated message generated by Sweep AI.

sweep-nightly bot linked a pull request Apr 8, 2024 that will close this issue

sweep-nightly bot commented Apr 8, 2024

🚀 Here's the PR! #3499

See Sweep's progress at the progress dashboard!
💎 Sweep Pro: I'm using GPT-4. You have unlimited GPT-4 tickets. (tracking ID: 2593f729a7)


Step 1: 🔎 Searching

I found the following snippets in your repository. I will now analyze these snippets and come up with a plan.

Some code snippets I think are relevant in decreasing order of relevance (click to expand). If some file is missing from here, you can mention the path in the ticket description.

import time
import traceback
from git import GitCommandError
from github.PullRequest import PullRequest
from loguru import logger
from sweepai.config.server import PROGRESS_BASE_URL
from sweepai.core import entities
from sweepai.core.entities import FileChangeRequest
from sweepai.core.sweep_bot import SweepBot
from sweepai.handlers.create_pr import create_pr_changes
from sweepai.handlers.on_ticket import get_branch_diff_text, sweeping_gif
from sweepai.utils.chat_logger import ChatLogger, discord_log_error
from sweepai.utils.diff import generate_diff
from sweepai.utils.event_logger import posthog
from sweepai.utils.github_utils import ClonedRepo, get_github_client
from sweepai.utils.progress import (
    PaymentContext,
    TicketContext,
    TicketProgress,
    TicketProgressStatus,
)
from sweepai.utils.prompt_constructor import HumanMessagePrompt
from sweepai.utils.str_utils import to_branch_name
from sweepai.utils.ticket_utils import center
instructions_format = """Resolve the merge conflicts in the PR by incorporating changes from both branches into the final code.
Title of PR: {title}
Here were the original changes to this file in the head branch:
Commit message: {head_commit_message}
```diff
{head_diff}
```
Here were the original changes to this file in the base branch:
Commit message: {base_commit_message}
```diff
{base_diff}
```
In the analysis_and_identification, first determine what each change does. Then determine what the final code should be. Then, use the keyword_search to find the merge conflict markers <<<<<<< and >>>>>>>. Finally, make the code changes by writing the old_code and the new_code."""
def on_merge_conflict(
pr_number: int,
username: str,
repo_full_name: str,
installation_id: int,
tracking_id: str,
):
# copied from stack_pr
token, g = get_github_client(installation_id=installation_id)
try:
repo = g.get_repo(repo_full_name)
except Exception as e:
print("Exception occured while getting repo", e)
pr: PullRequest = repo.get_pull(pr_number)
branch = pr.head.ref
status_message = center(
f"{sweeping_gif}\n\n"
+ f'Resolving merge conflicts: track the progress <a href="{PROGRESS_BASE_URL}/issues/{tracking_id}">here</a>.'
)
header = f"{status_message}\n---\n\nI'm currently resolving the merge conflicts in this PR. I will stack a new PR once I'm done."
comment = None
for current_comment in pr.get_issue_comments():
if (
current_comment.user.login == "sweep-nightly[bot]"
and "Resolving merge conflicts: track the progress" in current_comment.body
):
current_comment.edit(body=header)
comment = current_comment
break
comment = pr.create_issue_comment(body=header)
def edit_comment(body):
nonlocal comment
comment.edit(header + "\n\n" + body)
metadata = {}
try:
cloned_repo = ClonedRepo(
repo_full_name=repo_full_name,
installation_id=installation_id,
branch=branch,
token=token,
)
time.time()
request = f"Sweep: Resolve merge conflicts for PR #{pr_number}: {pr.title}"
title = request
if len(title) > 50:
title = title[:50] + "..."
chat_logger = ChatLogger(
data={
"username": username,
"metadata": metadata,
"tracking_id": tracking_id,
}
)
is_paying_user = chat_logger.is_paying_user()
chat_logger.is_consumer_tier()
# this logic is partly taken from on_ticket.py, if there is an issue please refer to that file
if chat_logger:
use_faster_model = chat_logger.use_faster_model()
else:
is_paying_user = True
ticket_progress = TicketProgress(
tracking_id=tracking_id,
username=username,
context=TicketContext(
title=title,
description="",
repo_full_name=repo_full_name,
branch_name="sweep/" + to_branch_name(request),
issue_number=pr_number,
is_public=repo.private is False,
start_time=int(time.time()),
# mostly copied from on_ticket, if issue please check that file
payment_context=PaymentContext(
use_faster_model=use_faster_model,
pro_user=is_paying_user,
daily_tickets_used=(
chat_logger.get_ticket_count(use_date=True)
if chat_logger
else 0
),
monthly_tickets_used=(
chat_logger.get_ticket_count() if chat_logger else 0
),
),
),
)
metadata = {
"tracking_id": tracking_id,
"username": username,
"function": "on_merge_conflict",
**ticket_progress.context.dict(),
}
posthog.capture(
username,
"started",
properties=metadata,
)
issue_url = pr.html_url
edit_comment("Configuring branch...")
new_pull_request = entities.PullRequest(
title=title,
branch_name="sweep/" + branch + "-merge-conflict",
content="",
)
# Making sure name is unique
for i in range(30):
try:
repo.get_branch(new_pull_request.branch_name + "_" + str(i))
except Exception:
new_pull_request.branch_name += "_" + str(i)
break
# Merge into base branch from cloned_repo.repo_dir to pr.base.ref
git_repo = cloned_repo.git_repo
old_head_branch = git_repo.branches[branch]
head_branch = git_repo.create_head(
new_pull_request.branch_name,
commit=old_head_branch.commit,
)
head_branch.checkout()
try:
git_repo.config_writer().set_value(
"user", "name", "sweep-nightly[bot]"
).release()
git_repo.config_writer().set_value(
"user", "email", "[email protected]"
).release()
git_repo.git.merge("origin/" + pr.base.ref)
except GitCommandError:
# Assume there are merge conflicts
pass
git_repo.git.add(update=True)
# -m and message are needed otherwise exception is thrown
git_repo.git.commit("-m", "Start of Merge Conflict Resolution")
origin = git_repo.remotes.origin
new_url = f"https://x-access-token:{token}@github.com/{repo_full_name}.git"
origin.set_url(new_url)
git_repo.git.push("--set-upstream", origin, new_pull_request.branch_name)
last_commit = git_repo.head.commit
all_files = [item.a_path for item in last_commit.diff("HEAD~1")]
conflict_files = []
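for file in all_files:
try:
# Close the file handle after reading; skip binary files and files removed by the merge
with open(cloned_repo.repo_dir + "/" + file) as f:
contents = f.read()
if "\n<<<<<<<" in contents and "\n>>>>>>>" in contents:
conflict_files.append(file)
except (UnicodeDecodeError, FileNotFoundError):
pass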
snippets = []
for conflict_file in conflict_files:
with open(cloned_repo.repo_dir + "/" + conflict_file) as f:
contents = f.read()
snippet = entities.Snippet(
file_path=conflict_file,
start=0,
end=len(contents.splitlines()),
content=contents,
)
snippets.append(snippet)
tree = ""
ticket_progress.status = TicketProgressStatus.PLANNING
ticket_progress.save()
human_message = HumanMessagePrompt(
repo_name=repo_full_name,
issue_url=issue_url,
username=username,
repo_description=(repo.description or "").strip(),
title=request,
summary=request,
snippets=snippets,
tree=tree,
)
sweep_bot = SweepBot.from_system_message_content(
human_message=human_message,
repo=repo,
ticket_progress=ticket_progress,
chat_logger=chat_logger,
cloned_repo=cloned_repo,
branch=new_pull_request.branch_name,
)
# can select more precise snippets
file_change_requests = []
base_commits = pr.base.repo.get_commits().get_page(0)
head_commits = list(pr.get_commits())
for conflict_file in conflict_files:
old_code = repo.get_contents(
conflict_file, ref=head_commits[0].parents[0].sha
).decoded_content.decode()
base_code = repo.get_contents(
conflict_file, ref=pr.base.ref
).decoded_content.decode()
head_code = repo.get_contents(
conflict_file, ref=pr.head.ref
).decoded_content.decode()
base_diff = generate_diff(old_code=old_code, new_code=base_code)
head_diff = generate_diff(old_code=old_code, new_code=head_code)
base_commit_message = ""
for commit in base_commits[::-1]:
if any(
commit_file.filename == conflict_file
for commit_file in commit.files
):
base_commit_message = commit.raw_data["commit"]["message"]
break
head_commit_message = ""
for commit in head_commits[::-1]:
if any(
commit_file.filename == conflict_file
for commit_file in commit.files
):
head_commit_message = commit.raw_data["commit"]["message"]
break
file_change_requests.append(
FileChangeRequest(
filename=conflict_file,
instructions=instructions_format.format(
title=pr.title,
base_commit_message=base_commit_message,
base_diff=base_diff,
head_commit_message=head_commit_message,
head_diff=head_diff,
),
change_type="modify",
)
)
ticket_progress.status = TicketProgressStatus.CODING
ticket_progress.save()
edit_comment("Resolving merge conflicts...")
generator = create_pr_changes(
file_change_requests,
new_pull_request,
sweep_bot,
username,
installation_id,
pr_number,
chat_logger=chat_logger,
base_branch=new_pull_request.branch_name,
)
for item in generator:
if isinstance(item, dict):
break
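# create_pr_changes yields 4-tuples: (new_file_contents, changed_file, commit, file_change_requests)
(
new_file_contents,
changed_file,
commit,
file_change_requests,
) = item
logger.info(
f"File change request statuses: {[fcr.status for fcr in file_change_requests]}"
)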
ticket_progress.status = TicketProgressStatus.COMPLETE
ticket_progress.save()
edit_comment("Done creating pull request.")
get_branch_diff_text(repo, new_pull_request.branch_name)
new_description = f"This PR resolves the merge conflicts in #{pr_number}. This branch can be directly merged into {pr.base.ref}.\n\nFixes #{pr_number}."
# Create pull request
new_pull_request.content = new_description
github_pull_request = repo.create_pull(
title=request,
body=new_description,
head=new_pull_request.branch_name,
base=pr.base.ref,
)
ticket_progress.context.pr_id = github_pull_request.number
ticket_progress.context.done_time = time.time()
ticket_progress.save()
edit_comment(f"✨ **Created Pull Request:** {github_pull_request.html_url}")
posthog.capture(
username,
"success",
properties=metadata,
)
return {"success": True}
except Exception as e:
print(f"Exception occurred: {e}")
edit_comment(
f"> [!CAUTION]\n> \nAn error has occurred: {str(e)} (tracking ID: {tracking_id})"
)
discord_log_error(
"Error occurred in on_merge_conflict.py\n"
+ traceback.format_exc()
+ "\n\n"
+ str(e)
+ "\n\n"
+ f"tracking ID: {tracking_id}"
)
posthog.capture(
username,
"failed",
properties=metadata,
)
return {"success": False}
if __name__ == "__main__":
on_merge_conflict(
pr_number=68,
username="MartinYe1234",
repo_full_name="MartinYe1234/Chess-Game",
installation_id=45945746,
tracking_id="ADD-BOB-2",
)

"""
This file contains the on_merge handler which is called when a pull request is merged to master.
on_merge is called by sweepai/api.py
"""
import time
from sweepai.config.client import SweepConfig, get_blocked_dirs, get_rules
from sweepai.core.post_merge import PostMerge
from sweepai.handlers.pr_utils import make_pr
from loguru import logger
from sweepai.utils.chat_logger import ChatLogger
from sweepai.utils.event_logger import posthog
from sweepai.utils.github_utils import get_github_client
# change threshold for number of lines changed
CHANGE_BOUNDS = (10, 1500)
# dictionary to map from github repo to the last time a rule was activated
merge_rule_debounce = {}
# debounce time in seconds
DEBOUNCE_TIME = 120
diff_section_prompt = """
<file_diff file="{diff_file_path}">
{diffs}
</file_diff>"""
def comparison_to_diff(comparison, blocked_dirs):
pr_diffs = []
for file in comparison.files:
diff = file.patch
if file.status in ("added", "modified", "removed"):
if any(file.filename.startswith(dir) for dir in blocked_dirs):
continue
pr_diffs.append((file.filename, diff))
else:
logger.info(
f"File status {file.status} not recognized"
) # TODO(sweep): We don't handle renamed files
formatted_diffs = []
for file_name, file_patch in pr_diffs:
format_diff = diff_section_prompt.format(
diff_file_path=file_name, diffs=file_patch
)
formatted_diffs.append(format_diff)
return "\n".join(formatted_diffs)
def on_merge(request_dict: dict, chat_logger: ChatLogger):
before_sha = request_dict["before"]
after_sha = request_dict["after"]
commit_author = request_dict["sender"]["login"]
ref = request_dict["ref"]
if not ref.startswith("refs/heads/"):
return
user_token, g = get_github_client(request_dict["installation"]["id"])
repo = g.get_repo(
request_dict["repository"]["full_name"]
) # do this after checking ref
if ref[len("refs/heads/") :] != SweepConfig.get_branch(repo):
return
commit = repo.get_commit(after_sha)
check_suites = commit.get_check_suites()
for check_suite in check_suites:
if check_suite.conclusion == "failure":
return # if any check suite failed, return
blocked_dirs = get_blocked_dirs(repo)
comparison = repo.compare(before_sha, after_sha)
commits_diff = comparison_to_diff(comparison, blocked_dirs)
# check if the current repo is in the merge_rule_debounce dictionary
# and if the difference between the current time and the time stored in the dictionary is less than DEBOUNCE_TIME seconds
if (
repo.full_name in merge_rule_debounce
and time.time() - merge_rule_debounce[repo.full_name] < DEBOUNCE_TIME
):
return
merge_rule_debounce[repo.full_name] = time.time()
num_changed_lines = commits_diff.count("\n")
if not (CHANGE_BOUNDS[0] <= num_changed_lines <= CHANGE_BOUNDS[1]):
return
rules = get_rules(repo)
rules = [rule for rule in rules if len(rule) > 0]
if not rules:
return
for rule in rules:
chat_logger.data["title"] = f"Sweep Rules - {rule}"
changes_required, issue_title, issue_description = PostMerge(
chat_logger=chat_logger
).check_for_issues(rule=rule, diff=commits_diff)
if changes_required:
make_pr(
title="[Sweep Rules] " + issue_title,
repo_description=repo.description,
summary=issue_description,
repo_full_name=request_dict["repository"]["full_name"],
installation_id=request_dict["installation"]["id"],
user_token=user_token,
use_faster_model=chat_logger.use_faster_model(),
username=commit_author,
chat_logger=chat_logger,
rule=rule,
)

"""
create_pr is a function that creates a pull request from a list of file change requests.
It is also responsible for handling Sweep config PR creation.
"""
import datetime
from typing import Any, Generator
import openai
from github.Repository import Repository
from loguru import logger
from sweepai.config.client import DEFAULT_RULES_STRING, SweepConfig, get_blocked_dirs
from sweepai.config.server import (
ENV,
GITHUB_BOT_USERNAME,
GITHUB_CONFIG_BRANCH,
GITHUB_DEFAULT_CONFIG,
GITHUB_LABEL_NAME,
MONGODB_URI,
)
from sweepai.core.entities import (
FileChangeRequest,
MaxTokensExceeded,
Message,
MockPR,
PullRequest,
)
from sweepai.core.sweep_bot import SweepBot
from sweepai.utils.chat_logger import ChatLogger
from sweepai.utils.event_logger import posthog
from sweepai.utils.github_utils import ClonedRepo, get_github_client
from sweepai.utils.str_utils import UPDATES_MESSAGE
num_of_snippets_to_query = 10
max_num_of_snippets = 5
INSTRUCTIONS_FOR_REVIEW = """\
### 💡 To get Sweep to edit this pull request, you can:
* Comment below, and Sweep can edit the entire PR
* Comment on a file, Sweep will only modify the commented file
* Edit the original issue to get Sweep to recreate the PR from scratch"""
def create_pr_changes(
file_change_requests: list[FileChangeRequest],
pull_request: PullRequest,
sweep_bot: SweepBot,
username: str,
installation_id: int,
issue_number: int | None = None,
chat_logger: ChatLogger = None,
base_branch: str = None,
additional_messages: list[Message] = []
) -> Generator[tuple[FileChangeRequest, int, Any], None, dict]:
# Flow:
# 1. Get relevant files
# 2: Get human message
# 3. Get files to change
# 4. Get file changes
# 5. Create PR
chat_logger = (
chat_logger
if chat_logger is not None
else ChatLogger(
{
"username": username,
"installation_id": installation_id,
"repo_full_name": sweep_bot.repo.full_name,
"title": pull_request.title,
"summary": "",
"issue_url": "",
}
)
if MONGODB_URI
else None
)
sweep_bot.chat_logger = chat_logger
organization, repo_name = sweep_bot.repo.full_name.split("/")
metadata = {
"repo_full_name": sweep_bot.repo.full_name,
"organization": organization,
"repo_name": repo_name,
"repo_description": sweep_bot.repo.description,
"username": username,
"installation_id": installation_id,
"function": "create_pr",
"mode": ENV,
"issue_number": issue_number,
}
posthog.capture(username, "started", properties=metadata)
try:
logger.info("Making PR...")
pull_request.branch_name = sweep_bot.create_branch(
pull_request.branch_name, base_branch=base_branch
)
completed_count, fcr_count = 0, len(file_change_requests)
blocked_dirs = get_blocked_dirs(sweep_bot.repo)
for (
new_file_contents,
changed_file,
commit,
file_change_requests,
) in sweep_bot.change_files_in_github_iterator(
file_change_requests,
pull_request.branch_name,
blocked_dirs,
additional_messages=additional_messages
):
completed_count += len(new_file_contents or [])
logger.info(f"Completed {completed_count}/{fcr_count} files")
yield new_file_contents, changed_file, commit, file_change_requests
if completed_count == 0 and fcr_count != 0:
logger.info("No changes made")
posthog.capture(
username,
"failed",
properties={
"error": "No changes made",
"reason": "No changes made",
**metadata,
},
)
# If no changes were made, delete branch
commits = sweep_bot.repo.get_commits(pull_request.branch_name)
if commits.totalCount == 0:
branch = sweep_bot.repo.get_git_ref(f"heads/{pull_request.branch_name}")
branch.delete()
return
# Include issue number in PR description
if issue_number:
# If this "Fixes #{issue_number}" format changes, also update the matching check in on_ticket (f'Fixes #{issue_number}.\n' in pr.body)
pr_description = (
f"{pull_request.content}\n\nFixes"
f" #{issue_number}.\n\n---\n\n{UPDATES_MESSAGE}\n\n---\n\n{INSTRUCTIONS_FOR_REVIEW}"
)
else:
pr_description = f"{pull_request.content}"
pr_title = pull_request.title
if "sweep.yaml" in pr_title:
pr_title = "[config] " + pr_title
except MaxTokensExceeded as e:
logger.error(e)
posthog.capture(
username,
"failed",
properties={
"error": str(e),
"reason": "Max tokens exceeded",
**metadata,
},
)
raise e
except openai.BadRequestError as e:
logger.error(e)
posthog.capture(
username,
"failed",
properties={
"error": str(e),
"reason": "Invalid request error / context length",
**metadata,
},
)
raise e
except Exception as e:
logger.error(e)
posthog.capture(
username,
"failed",
properties={
"error": str(e),
"reason": "Unexpected error",
**metadata,
},
)
raise e
posthog.capture(username, "success", properties={**metadata})
logger.info("create_pr success")
result = {
"success": True,
"pull_request": MockPR(
file_count=completed_count,
title=pr_title,
body=pr_description,
pr_head=pull_request.branch_name,
base=sweep_bot.repo.get_branch(
SweepConfig.get_branch(sweep_bot.repo)
).commit,
head=sweep_bot.repo.get_branch(pull_request.branch_name).commit,
),
}
yield result # TODO: refactor this as it doesn't need to be an iterator
return
def safe_delete_sweep_branch(
pr, # Github PullRequest
repo: Repository,
) -> bool:
"""
Safely delete Sweep branch
1. Only edited by Sweep
2. Prefixed by sweep/
"""
pr_commits = pr.get_commits()
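# commit.author is None when the commit email is not linked to a GitHub account;
# keep None in the set so such commits count as a non-Sweep author
pr_commit_authors = set(
commit.author.login if commit.author else None for commit in pr_commits
)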
# Check if only Sweep has edited the PR, and sweep/ prefix
if (
len(pr_commit_authors) == 1
and GITHUB_BOT_USERNAME in pr_commit_authors
and pr.head.ref.startswith("sweep")
):
branch = repo.get_git_ref(f"heads/{pr.head.ref}")
# pr.edit(state='closed')
branch.delete()
return True
else:
# Failed to delete branch as it was edited by someone else
return False
def create_config_pr(
sweep_bot: SweepBot | None, repo: Repository = None, cloned_repo: ClonedRepo = None
):
if repo is not None:
# Check if file exists in repo
try:
repo.get_contents("sweep.yaml")
return
except SystemExit:
raise SystemExit
except Exception:
pass
title = "Configure Sweep"
branch_name = GITHUB_CONFIG_BRANCH
if sweep_bot is not None:
branch_name = sweep_bot.create_branch(branch_name, retry=False)
try:
# commit_history = []
# if cloned_repo is not None:
# commit_history = cloned_repo.get_commit_history(
# limit=1000, time_limited=False
# )
# commit_string = "\n".join(commit_history)
# sweep_yaml_bot = SweepYamlBot()
# generated_rules = sweep_yaml_bot.get_sweep_yaml_rules(
# commit_history=commit_string
# )
sweep_bot.repo.create_file(
"sweep.yaml",
"Create sweep.yaml",
GITHUB_DEFAULT_CONFIG.format(
branch=sweep_bot.repo.default_branch,
additional_rules=DEFAULT_RULES_STRING,
),
branch=branch_name,
)
sweep_bot.repo.create_file(
".github/ISSUE_TEMPLATE/sweep-template.yml",
"Create sweep template",
SWEEP_TEMPLATE,
branch=branch_name,
)
except SystemExit:
raise SystemExit
except Exception as e:
logger.error(e)
else:
# Create branch based on default branch
repo.create_git_ref(
ref=f"refs/heads/{branch_name}",
sha=repo.get_branch(repo.default_branch).commit.sha,
)
try:
# commit_history = []
# if cloned_repo is not None:
# commit_history = cloned_repo.get_commit_history(
# limit=1000, time_limited=False
# )
# commit_string = "\n".join(commit_history)
# sweep_yaml_bot = SweepYamlBot()
# generated_rules = sweep_yaml_bot.get_sweep_yaml_rules(
# commit_history=commit_string
# )
repo.create_file(
"sweep.yaml",
"Create sweep.yaml",
GITHUB_DEFAULT_CONFIG.format(
branch=repo.default_branch, additional_rules=DEFAULT_RULES_STRING
),
branch=branch_name,
)
repo.create_file(
".github/ISSUE_TEMPLATE/sweep-template.yml",
"Create sweep template",
SWEEP_TEMPLATE,
branch=branch_name,
)
except SystemExit:
raise SystemExit
except Exception as e:
logger.error(e)
repo = sweep_bot.repo if sweep_bot is not None else repo
# Check if the pull request from this branch to main already exists.
# If it does, then we don't need to create a new one.
if repo is not None:
pull_requests = repo.get_pulls(
state="open",
sort="created",
base=SweepConfig.get_branch(repo)
if sweep_bot is not None
else repo.default_branch,
head=branch_name,
)
for pr in pull_requests:
if pr.title == title:
return pr
logger.info(f"Default branch: {repo.default_branch}")
logger.info(f"New branch: {branch_name}")
pr = repo.create_pull(
title=title,
body="""🎉 Thank you for installing Sweep! We're thrilled to announce the latest update for Sweep, your AI junior developer on GitHub. This PR creates a `sweep.yaml` config file, allowing you to personalize Sweep's performance according to your project requirements.
## What's new?
- **Sweep is now configurable**.
- To configure Sweep, simply edit the `sweep.yaml` file in the root of your repository.
- If you need help, check out the [Sweep Default Config](https://github.com/sweepai/sweep/blob/main/sweep.yaml) or [Join Our Discord](https://discord.gg/sweep) for help.
If you would like me to stop creating this PR, go to issues and say "Sweep: create an empty `sweep.yaml` file".
Thank you for using Sweep! 🧹""".replace(
"    ", ""
),
head=branch_name,
base=SweepConfig.get_branch(repo)
if sweep_bot is not None
else repo.default_branch,
)
pr.add_to_labels(GITHUB_LABEL_NAME)
return pr
def add_config_to_top_repos(installation_id, username, repositories, max_repos=3):
user_token, g = get_github_client(installation_id)
repo_activity = {}
for repo_entity in repositories:
repo = g.get_repo(repo_entity.full_name)
# instead of using total count, use the date of the latest commit
commits = repo.get_commits(
author=username,
since=datetime.datetime.now() - datetime.timedelta(days=30),
)
# get latest commit date
commit_date = datetime.datetime.now() - datetime.timedelta(days=30)
for commit in commits:
if commit.commit.author.date > commit_date:
commit_date = commit.commit.author.date
# since_date = datetime.datetime.now() - datetime.timedelta(days=30)
# commits = repo.get_commits(since=since_date, author="lukejagg")
repo_activity[repo] = commit_date
# print(repo, commits.totalCount)
logger.info(f"{repo.full_name}: {commit_date}")
sorted_repos = sorted(repo_activity, key=repo_activity.get, reverse=True)
sorted_repos = sorted_repos[:max_repos]
# For each repo, create a branch based on main branch, then create PR to main branch
for repo in sorted_repos:
try:
logger.info(f"Creating config for {repo.full_name}")
create_config_pr(
None,
repo=repo,
cloned_repo=ClonedRepo(
repo_full_name=repo.full_name,
installation_id=installation_id,
token=user_token,
),
)
except SystemExit:
raise SystemExit
except Exception as e:
logger.error(e)
logger.info("Finished creating configs for top repos")
def create_gha_pr(g, repo):
# Create a new branch
branch_name = "sweep/gha-enable"
repo.create_git_ref(
ref=f"refs/heads/{branch_name}",
sha=repo.get_branch(repo.default_branch).commit.sha,
)
# Update the sweep.yaml file in this branch to add "gha_enabled: True"
sweep_yaml_content = (
repo.get_contents("sweep.yaml", ref=branch_name).decoded_content.decode()
+ "\ngha_enabled: True"
)
repo.update_file(
"sweep.yaml",
"Enable GitHub Actions",
sweep_yaml_content,
repo.get_contents("sweep.yaml", ref=branch_name).sha,
branch=branch_name,
)
# Create a PR from this branch to the main branch
pr = repo.create_pull(
title="Enable GitHub Actions",
body="This PR enables GitHub Actions for this repository.",
head=branch_name,
base=repo.default_branch,
)
return pr
SWEEP_TEMPLATE = """\
name: Sweep Issue
title: 'Sweep: '
description: For small bugs, features, refactors, and tests to be handled by Sweep, an AI-powered junior developer.
labels: sweep
body:
- type: textarea
id: description
attributes:
label: Details
description: Tell Sweep where and what to edit and provide enough context for a new developer to the codebase
placeholder: |
Unit Tests: Write unit tests for <FILE>. Test each function in the file. Make sure to test edge cases.
Bugs: The bug might be in <FILE>. Here are the logs: ...
Features: the new endpoint should use the ... class from <FILE> because it contains ... logic.
Refactors: We are migrating this function to ... version because ...
- type: input
id: branch
attributes:
label: Branch
description: The branch to work off of (optional)
placeholder: |

import copy
import re
import traceback
from pathlib import Path
from loguru import logger
from sweepai.agents.assistant_wrapper import (
client,
openai_assistant_call,
run_until_complete,
)
from sweepai.core.entities import AssistantRaisedException, FileChangeRequest, Message
from sweepai.logn.cache import file_cache
from sweepai.utils.chat_logger import ChatLogger, discord_log_error
from sweepai.utils.progress import AssistantConversation, TicketProgress
system_message = r""" You are searching through a codebase to guide a junior developer on how to solve the user request. The junior developer will follow your instructions exactly and make the changes.
# User Request
{user_request}
# Guide
## Step 1: Unzip the file into /mnt/data/repo. Then list all root level directories. You must copy the below code verbatim into the file.
```python
import zipfile
import os
zip_path = '{file_path}'
extract_to_path = 'mnt/data/repo'
os.makedirs(extract_to_path, exist_ok=True)
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
zip_ref.extractall(extract_to_path)
zip_contents = zip_ref.namelist()
root_dirs = {{name.split('/')[0] for name in zip_contents}}
print(f'Root directories: {{root_dirs}}')
```
## Step 2: Find the relevant files.
You can search by file name or by keyword search in the contents.
## Step 3: Find relevant lines.
1. Locate the lines of code that contain the identified keywords or are at the specified line number. You can use keyword search or manually look through the file 100 lines at a time.
2. Check the surrounding lines to establish the full context of the code block.
3. Adjust the starting line to include the entire functionality that needs to be refactored or moved.
4. Finally determine the exact line spans that include a logical and complete section of code to be edited.
```python
def print_lines_with_keyword(content, keywords):
max_matches = 5
context = 10
matches = [i for i, line in enumerate(content.splitlines()) if any(keyword in line.lower() for keyword in keywords)]
print(f"Found {{len(matches)}} matches, but capping at {{max_matches}}")
matches = matches[:max_matches]
expanded_matches = set()
for match in matches:
start = max(0, match - context)
end = min(len(content.splitlines()), match + context + 1)
for i in range(start, end):
expanded_matches.add(i)
for i in sorted(expanded_matches):
print(f"{{i}}: {{content.splitlines()[i]}}")
```
## Step 4: Construct a plan.
Provide the final plan to solve the issue, following these rules:
* DO NOT apply any changes here, they will not be persisted. You must provide the plan and the developer will apply the changes.
* You may only create new files and modify existing files.
* File paths should be relative paths from the root of the repo.
* Use the minimum number of create and modify operations required to solve the issue.
* Start and end lines indicate the exact start and end lines to edit. Expand this to encompass more lines if you're unsure where to make the exact edit.
Respond in the following format:
```xml
<plan>
<create_file file="file_path_1">
* Natural language instructions for creating the new file needed to solve the issue.
* Reference necessary files, imports and entity names.
...
</create_file>
...
<modify_file file="file_path_2" start_line="i" end_line="j">
* Natural language instructions for the modifications needed to solve the issue.
* Be concise and reference necessary files, imports and entity names.
...
</modify_file>
...
</plan>
```"""
@file_cache(ignore_params=["zip_path", "chat_logger", "ticket_progress"])
def new_planning(
request: str,
zip_path: str,
additional_messages: list[Message] = [],
chat_logger: ChatLogger | None = None,
assistant_id: str = None,
ticket_progress: TicketProgress | None = None,
) -> list[FileChangeRequest]:
planning_iterations = 3
try:
def save_ticket_progress(assistant_id: str, thread_id: str, run_id: str):
assistant_conversation = AssistantConversation.from_ids(
assistant_id=assistant_id, run_id=run_id, thread_id=thread_id
)
if not assistant_conversation or ticket_progress is None:
return
ticket_progress.planning_progress.assistant_conversation = (
assistant_conversation
)
ticket_progress.save()
logger.info("Uploading file...")
zip_file_object = client.files.create(file=Path(zip_path), purpose="assistants")
logger.info("Done uploading file.")
zip_file_id = zip_file_object.id
response = openai_assistant_call(
request=request,
assistant_id=assistant_id,
additional_messages=additional_messages,
uploaded_file_ids=[zip_file_id],
chat_logger=chat_logger,
save_ticket_progress=save_ticket_progress
if ticket_progress is not None
else None,
instructions=system_message.format(
user_request=request, file_path=f"mnt/data/{zip_file_id}"
),
)
run_id = response.run_id
thread_id = response.thread_id
# Start from the first run's messages; later iterations reuse the messages of the retry run
messages = response.messages
for _ in range(planning_iterations):
save_ticket_progress(
assistant_id=response.assistant_id,
thread_id=thread_id,
run_id=run_id,
)
final_message = messages.data[0].content[0].text.value
fcrs = []
fcr_matches = list(
re.finditer(FileChangeRequest._regex, final_message, re.DOTALL)
)
if len(fcr_matches) > 0:
break
else:
client.beta.threads.messages.create(
thread_id=thread_id,
role="user",
content="A valid plan (within the <plan> tags) was not provided. Please continue working on the plan. If you are stuck, consider starting over.",
)
run = client.beta.threads.runs.create(
thread_id=response.thread_id,
assistant_id=response.assistant_id,
instructions=system_message.format(
user_request=request, file_path=f"mnt/data/{zip_file_id}"
),
)
run_id = run.id
messages = run_until_complete(
thread_id=thread_id,
run_id=run_id,
assistant_id=response.assistant_id,
)
for match_ in fcr_matches:
group_dict = match_.groupdict()
if group_dict["change_type"] == "create_file":
group_dict["change_type"] = "create"
if group_dict["change_type"] == "modify_file":
group_dict["change_type"] = "modify"
fcr = FileChangeRequest(**group_dict)
fcr.filename = fcr.filename.lstrip("/")
fcr.instructions = fcr.instructions.replace("\n*", "\n•")
fcr.instructions = fcr.instructions.strip("\n")
if fcr.instructions.startswith("*"):
fcr.instructions = "•" + fcr.instructions[1:]
fcrs.append(fcr)
new_file_change_request = copy.deepcopy(fcr)
new_file_change_request.change_type = "check"
new_file_change_request.parent = fcr
fcrs.append(new_file_change_request)
assert len(fcrs) > 0
return fcrs
except AssistantRaisedException as e:
raise e
except Exception as e:
logger.exception(e)
if chat_logger is not None:
discord_log_error(
str(e)
+ "\n\n"
+ traceback.format_exc()
+ "\n\n"
+ str(chat_logger.data)
)
return None
if __name__ == "__main__":
request = """## Title: replace the broken tutorial link in installation.md with https://docs.sweep.dev/usage/tutorial\n"""
additional_messages = [
Message(
role="user",
content='<relevant_snippets_in_repo>\n<snippet source="docs/pages/usage/tutorial.mdx:45-60">\n...\n45: Now to be a Sweep power user, check out [Advanced: becoming a Sweep power user](https://docs.sweep.dev/usage/advanced).\n</snippet>\n<snippet source="docs/pages/usage/tutorial.mdx:30-45">\n...\n30: \n31: ![PR Comment](/tutorial/comment.png)\n32: \n33: c. If you have GitHub Actions set up, it will automatically run the linters, build, and tests and will show any failed logs to Sweep to handle. This only works with GitHub Actions and not other CI providers, so unfortunately for Vercel we have to copy paste manually.\n34: \n35: ![GitHub Actions](/tutorial/github_actions.png)\n36: \n37: 6. Once you are happy with the PR, you can merge it and it will be deployed to production via Vercel.\n38: \n39: \n40: ![Final](/tutorial/final.png)\n41: \n42: \n43: You can see the final example at https://github.com/kevinlu1248/docusaurus-2/pull/4 with preview https://docusaurus-2-ql4cskc5o-sweepai.vercel.app/.\n44: \n45: Now to be a Sweep power user, check out [Advanced: becoming a Sweep power user](https://docs.sweep.dev/usage/advanced).\n...\n</snippet>\n<snippet source="docs/installation.md:45-60">\n...\n45: * Provide any additional context that might be helpful, e.g. see "src/App.test.tsx" for an example of a good unit test.\n46: * For more guidance, visit [Advanced](https://docs.sweep.dev/usage/advanced), or watch the following video.\n47: \n48: [![Video](http://img.youtube.com/vi/Qn9vB71R4UM/0.jpg)](http://www.youtube.com/watch?v=Qn9vB71R4UM "Advanced Sweep Tricks and Feedback Tips")\n49: \n50: For configuring Sweep for your repo, see [Config](https://docs.sweep.dev/usage/config), especially for setting up Sweep Rules and Sweep Sweep.\n51: \n52: ## Limitations of Sweep (for now) ⚠️\n53: \n54: * 🗃️ **Gigantic repos**: >5000 files. We have default extensions and directories to exclude but sometimes this doesn\'t catch them all. 
You may need to block some directories (see [`blocked_dirs`](https://docs.sweep.dev/usage/config#blocked_dirs))\n55: * If Sweep is stuck at 0% for over 30 min and your repo has a few thousand files, let us know.\n56: \n57: * 🏗️ **Large-scale refactors**: >5 files or >300 lines of code changes (we\'re working on this!)\n58: * We can\'t do this - "Refactor entire codebase from Tensorflow to PyTorch"\n59: \n60: * 🖼️ **Editing images** and other non-text assets\n...\n</snippet>\n<snippet source="docs/pages/usage/tutorial.mdx:0-15">\n0: # Tutorial for Getting Started with Sweep\n1: \n2: We recommend using an existing **real project** for Sweep, but if you must start from scratch, we recommend **using a template**. In particular, we recommend Vercel templates and Vercel auto-deploy, since Vercel\'s auto-generated previews make it **easy to review Sweep\'s PRs**\n3: \n4: We\'ll use [Docusaurus](https://vercel.com/templates/next.js/docusaurus-2) since it\'s is the easiest to set up (no backend). To see other templates see https://vercel.com/templates.\n5: \n6: 1. Go to https://vercel.com/templates/next.js/docusaurus-2 (or another template) and click "Deploy".\n7: \n8: ![Deploy](/tutorial/deployment.png)\n9: \n10: 2. Vercel will prompt you to select a GitHub account and click "Clone" after. This will trigger a build and deploy which will take a few minutes. Once the build is done, you will be greeted with a congratulations message.\n11: \n12: ![Congratulations](/tutorial/congratulations.png)\n13: \n14: 3. Go to the [Sweep Installation](https://github.com/apps/sweep-ai) page and click the grey "Configure" button or the green "Install" button. Ensure that that the Vercel template (i.e. Docusaurus) is configured to use Sweep.\n...\n</snippet>\n</relevant_snippets_in_repo>\ndocs/\n installation.md\n docs/pages/\n docs/pages/usage/\n _meta.json\n advanced.mdx\n config.mdx\n extra-self-host.mdx\n sandbox.mdx\n tutorial.mdx',
name=None,
function_call=None,
key=None,
)
]
print(
new_planning(
request,
"/tmp/sweep_archive.zip",
chat_logger=ChatLogger(
{"username": "kevinlu1248", "title": "Unit test for planning"}
),
ticket_progress=TicketProgress(tracking_id="ed47605a38"),
)
)

import datetime
import difflib
import hashlib
import json
import os
import re
import shutil
import subprocess
import tempfile
import time
import traceback
from dataclasses import dataclass
from functools import cached_property
from typing import Any
import git
import requests
from github import Github, PullRequest, Repository, InputGitTreeElement
from jwt import encode
from loguru import logger
from sweepai.config.client import SweepConfig
from sweepai.config.server import GITHUB_APP_ID, GITHUB_APP_PEM, GITHUB_BOT_USERNAME
from sweepai.utils.tree_utils import DirectoryTree, remove_all_not_included
MAX_FILE_COUNT = 50
def make_valid_string(string: str):
pattern = r"[^\w./-]+"
return re.sub(pattern, "_", string)
def get_jwt():
signing_key = GITHUB_APP_PEM
app_id = GITHUB_APP_ID
payload = {"iat": int(time.time()), "exp": int(time.time()) + 600, "iss": app_id}
return encode(payload, signing_key, algorithm="RS256")
def get_token(installation_id: int):
if int(installation_id) < 0:
return os.environ["GITHUB_PAT"]
for timeout in [5.5, 5.5, 10.5]:
try:
jwt = get_jwt()
headers = {
"Accept": "application/vnd.github+json",
"Authorization": "Bearer " + jwt,
"X-GitHub-Api-Version": "2022-11-28",
}
response = requests.post(
f"https://api.github.com/app/installations/{int(installation_id)}/access_tokens",
headers=headers,
)
obj = response.json()
if "token" not in obj:
logger.error(obj)
raise Exception("Could not get token")
return obj["token"]
except SystemExit:
raise SystemExit
except Exception:
time.sleep(timeout)
raise Exception(
"Could not get token, please double check your PRIVATE_KEY and GITHUB_APP_ID in the .env file. Make sure to restart uvicorn after."
)
def get_app():
jwt = get_jwt()
headers = {
"Accept": "application/vnd.github+json",
"Authorization": "Bearer " + jwt,
"X-GitHub-Api-Version": "2022-11-28",
}
response = requests.get("https://api.github.com/app", headers=headers)
return response.json()
def get_github_client(installation_id: int) -> tuple[str, Github]:
if not installation_id:
return os.environ["GITHUB_PAT"], Github(os.environ["GITHUB_PAT"])
token: str = get_token(installation_id)
return token, Github(token)
# fetch installation object
def get_installation(username: str):
jwt = get_jwt()
try:
# Try user
response = requests.get(
f"https://api.github.com/users/{username}/installation",
headers={
"Accept": "application/vnd.github+json",
"Authorization": "Bearer " + jwt,
"X-GitHub-Api-Version": "2022-11-28",
},
)
obj = response.json()
return obj
except Exception:
# Try org
response = requests.get(
f"https://api.github.com/orgs/{username}/installation",
headers={
"Accept": "application/vnd.github+json",
"Authorization": "Bearer " + jwt,
"X-GitHub-Api-Version": "2022-11-28",
},
)
try:
obj = response.json()
return obj
except Exception as e:
logger.error(e)
logger.error(response.text)
raise Exception("Could not get installation, probably not installed")
def get_installation_id(username: str) -> str:
jwt = get_jwt()
try:
# Try user
response = requests.get(
f"https://api.github.com/users/{username}/installation",
headers={
"Accept": "application/vnd.github+json",
"Authorization": "Bearer " + jwt,
"X-GitHub-Api-Version": "2022-11-28",
},
)
obj = response.json()
return obj["id"]
except Exception:
# Try org
response = requests.get(
f"https://api.github.com/orgs/{username}/installation",
headers={
"Accept": "application/vnd.github+json",
"Authorization": "Bearer " + jwt,
"X-GitHub-Api-Version": "2022-11-28",
},
)
try:
obj = response.json()
return obj["id"]
except Exception as e:
logger.error(e)
logger.error(response.text)
raise Exception("Could not get installation id, probably not installed")
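`get_installation_id` falls back to the organization endpoint only when the user lookup raises. `requests.get` does not raise on a 404, so the fallback fires indirectly, through the `KeyError` from `obj["id"]`; in `get_installation`, which returns the user response as-is, it may not fire at all. A sketch that branches on the HTTP status instead, with the HTTP call injected as a hypothetical `fetch(url) -> (status_code, json_body)` callable so the logic can be tested offline:

```python
from typing import Callable

# fetch(url) -> (status_code, json_body); injected so the logic is testable offline.
Fetch = Callable[[str], tuple[int, dict]]


def find_installation_id(owner: str, fetch: Fetch) -> int:
    """Hypothetical helper: look up a GitHub App installation for a user, then an org."""
    for kind in ("users", "orgs"):
        status, body = fetch(f"https://api.github.com/{kind}/{owner}/installation")
        # Only a 200 response carrying an id counts; anything else tries the next kind.
        if status == 200 and "id" in body:
            return body["id"]
    raise LookupError(f"No installation found for {owner}; is the app installed?")
```

In production, `fetch` would wrap `requests.get(url, headers=...)` with the JWT headers shown above and return `(response.status_code, response.json())`.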
# commits multiple files in a single commit, returns the commit object
def commit_multi_file_changes(repo: Repository, file_changes: dict[str, str], commit_message: str, branch: str):
blobs_to_commit = []
# convert to blob
for path, content in file_changes.items():
blob = repo.create_git_blob(content, "utf-8")
blobs_to_commit.append(InputGitTreeElement(path=path, mode="100644", type="blob", sha=blob.sha))
latest_commit = repo.get_branch(branch).commit
base_tree = latest_commit.commit.tree
# create new git tree
new_tree = repo.create_git_tree(blobs_to_commit, base_tree=base_tree)
# commit the changes
parent = repo.get_git_commit(latest_commit.sha)
commit = repo.create_git_commit(
commit_message,
new_tree,
[parent],
)
# update ref of branch
ref = f"heads/{branch}"
repo.get_git_ref(ref).edit(sha=commit.sha)
return commit
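`commit_multi_file_changes` uses GitHub's Git Data API in four steps: create a blob per file, create a tree on top of the branch's current tree, create a commit whose parent is the previous head, then move `heads/<branch>` to the new commit. The REST "create a tree" endpoint also accepts file text inline through a `content` field, in which case GitHub writes the blob itself. A sketch of that request body (the helper name is hypothetical):

```python
def build_tree_payload(base_tree_sha: str, file_changes: dict[str, str]) -> dict:
    """Hypothetical helper: JSON body for POST /repos/{owner}/{repo}/git/trees.

    Each entry carries the new file text inline via `content`, so GitHub creates
    the blob and no separate create-blob call is needed. Mode 100644 marks a
    regular, non-executable file.
    """
    return {
        "base_tree": base_tree_sha,  # paths not listed are inherited from this tree
        "tree": [
            {"path": path, "mode": "100644", "type": "blob", "content": content}
            for path, content in sorted(file_changes.items())
        ],
    }
```

The returned tree SHA then goes into the commit with the old head as its only parent, as in the function above; because the ref update is not forced, GitHub should reject it if someone else pushed to the branch in between.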
REPO_CACHE_BASE_DIR = "/tmp/cache/repos"
@dataclass
class ClonedRepo:
repo_full_name: str
installation_id: str
branch: str | None = None
token: str | None = None
repo: Any | None = None
git_repo: git.Repo | None = None
class Config:
arbitrary_types_allowed = True
@cached_property
def cached_dir(self):
self.repo = (
Github(self.token).get_repo(self.repo_full_name)
if not self.repo
else self.repo
)
self.branch = self.branch or SweepConfig.get_branch(self.repo)
return os.path.join(
REPO_CACHE_BASE_DIR,
self.repo_full_name,
"base",
parse_collection_name(self.branch),
)
@cached_property
def zip_path(self):
logger.info("Zipping repository...")
shutil.make_archive(self.repo_dir, "zip", self.repo_dir)
logger.info("Done zipping")
return f"{self.repo_dir}.zip"
@cached_property
def repo_dir(self):
self.repo = (
Github(self.token).get_repo(self.repo_full_name)
if not self.repo
else self.repo
)
self.branch = self.branch or SweepConfig.get_branch(self.repo)
curr_time_str = str(time.time()).encode("utf-8")
hash_obj = hashlib.sha256(curr_time_str)
hash_hex = hash_obj.hexdigest()
if self.branch:
return os.path.join(
REPO_CACHE_BASE_DIR,
self.repo_full_name,
hash_hex,
parse_collection_name(self.branch),
)
else:
return os.path.join("/tmp/cache/repos", self.repo_full_name, hash_hex)
@property
def clone_url(self):
return (
f"https://x-access-token:{self.token}@github.com/{self.repo_full_name}.git"
)
def clone(self):
if not os.path.exists(self.cached_dir):
logger.info("Cloning repo...")
if self.branch:
repo = git.Repo.clone_from(
self.clone_url, self.cached_dir, branch=self.branch
)
else:
repo = git.Repo.clone_from(self.clone_url, self.cached_dir)
logger.info("Done cloning")
else:
try:
repo = git.Repo(self.cached_dir)
repo.remotes.origin.pull(
kill_after_timeout=60, progress=git.RemoteProgress()
)
except Exception:
logger.error("Could not pull repo")
shutil.rmtree(self.cached_dir, ignore_errors=True)
repo = git.Repo.clone_from(self.clone_url, self.cached_dir)
logger.info("Repo already cached, copying")
logger.info("Copying repo...")
shutil.copytree(
self.cached_dir, self.repo_dir, symlinks=True, copy_function=shutil.copy
)
logger.info("Done copying")
repo = git.Repo(self.repo_dir)
return repo
def __post_init__(self):
subprocess.run(["git", "config", "--global", "http.postBuffer", "524288000"])
self.token = self.token or get_token(self.installation_id)
self.repo = (
Github(self.token).get_repo(self.repo_full_name)
if not self.repo
else self.repo
)
self.commit_hash = self.repo.get_commits()[0].sha
self.git_repo = self.clone()
self.branch = self.branch or SweepConfig.get_branch(self.repo)
def __del__(self):
try:
shutil.rmtree(self.repo_dir)
os.remove(self.zip_path)
return True
except Exception:
return False
def list_directory_tree(
self,
included_directories=None,
excluded_directories: list[str] = None,
included_files=None,
):
"""Display the directory tree rooted at the cloned repository (self.repo_dir).
Arguments:
included_directories -- List of directory paths (relative to the root) to include in the tree. Defaults to None (all directories).
excluded_directories -- List of directory names to exclude from the tree. Defaults to None (SweepConfig().exclude_dirs).
included_files -- Currently unused. Defaults to None.
"""
root_directory = self.repo_dir
sweep_config: SweepConfig = SweepConfig()
# Default values if parameters are not provided
if included_directories is None:
included_directories = [] # gets all directories
if excluded_directories is None:
excluded_directories = sweep_config.exclude_dirs
def list_directory_contents(
current_directory: str,
excluded_directories: list[str],
indentation="",
):
"""Recursively list the contents of directories."""
file_and_folder_names = os.listdir(current_directory)
file_and_folder_names.sort()
directory_tree_string = ""
for name in file_and_folder_names[:MAX_FILE_COUNT]:
relative_path = os.path.join(current_directory, name)[
len(root_directory) + 1 :
]
if name in excluded_directories:
continue
complete_path = os.path.join(current_directory, name)
if os.path.isdir(complete_path):
directory_tree_string += f"{indentation}{relative_path}/\n"
directory_tree_string += list_directory_contents(
complete_path,
excluded_directories,
indentation + " ",
)
else:
directory_tree_string += f"{indentation}{name}\n"
# if os.path.isfile(complete_path) and relative_path in included_files:
# # Todo, use these to fetch neighbors
# ctags_str, names = get_ctags_for_file(ctags, complete_path)
# ctags_str = "\n".join([indentation + line for line in ctags_str.splitlines()])
# if ctags_str.strip():
# directory_tree_string += f"{ctags_str}\n"
return directory_tree_string
dir_obj = DirectoryTree()
directory_tree = list_directory_contents(root_directory, excluded_directories)
dir_obj.parse(directory_tree)
if included_directories:
dir_obj = remove_all_not_included(dir_obj, included_directories)
return directory_tree, dir_obj
def get_file_list(self) -> str:
root_directory = self.repo_dir
files = []
sweep_config: SweepConfig = SweepConfig()
def dfs_helper(directory):
nonlocal files
for item in os.listdir(directory):
if item == ".git":
continue
if item in sweep_config.exclude_dirs: # this saves a lot of time
continue
item_path = os.path.join(directory, item)
if os.path.isfile(item_path):
# make sure the item_path is not in one of the banned directories
if not sweep_config.is_file_excluded(item_path):
files.append(item_path) # Add the file to the list
elif os.path.isdir(item_path):
dfs_helper(item_path) # Recursive call to explore subdirectory
dfs_helper(root_directory)
files = [file[len(root_directory) + 1 :] for file in files]
return files
def get_file_contents(self, file_path, ref=None):
local_path = (
f"{self.repo_dir}{file_path}"
if file_path.startswith("/")
else f"{self.repo_dir}/{file_path}"
)
if os.path.exists(local_path):
with open(local_path, "r", encoding="utf-8", errors="replace") as f:
contents = f.read()
return contents
else:
raise FileNotFoundError(f"{local_path} does not exist.")
def get_num_files_from_repo(self):
# subprocess.run(["git", "config", "--global", "http.postBuffer", "524288000"])
self.git_repo.git.checkout(self.branch)
file_list = self.get_file_list()
return len(file_list)
def get_commit_history(
self, username: str = "", limit: int = 200, time_limited: bool = True
):
commit_history = []
try:
if username != "":
commit_list = list(self.git_repo.iter_commits(author=username))
else:
commit_list = list(self.git_repo.iter_commits())
line_count = 0
cut_off_date = datetime.datetime.now() - datetime.timedelta(days=7)
for commit in commit_list:
# must be within a week
if time_limited and commit.authored_datetime.replace(
tzinfo=None
) <= cut_off_date.replace(tzinfo=None):
logger.info("Exceeded cut off date, stopping...")
break
repo = get_github_client(self.installation_id)[1].get_repo(
self.repo_full_name
)
branch = SweepConfig.get_branch(repo)
if branch not in self.git_repo.git.branch():
branch = f"origin/{branch}"
diff = self.git_repo.git.diff(commit, branch, unified=1)
lines = diff.count("\n")
# total diff lines must not exceed 200
if lines + line_count > limit:
logger.info(f"Exceeded {limit} lines of diff, stopping...")
break
commit_history.append(
f"<commit>\nAuthor: {commit.author.name}\nMessage: {commit.message}\n{diff}\n</commit>"
)
line_count += lines
except Exception:
logger.error(f"An error occurred: {traceback.format_exc()}")
return commit_history
def get_similar_file_paths(self, file_path: str, limit: int = 10):
from rapidfuzz.fuzz import ratio
# Fuzzy search over file names
file_name = os.path.basename(file_path)
all_file_paths = self.get_file_list()
# filter for matching extensions if both have extensions
if "." in file_name:
all_file_paths = [
file
for file in all_file_paths
if "." in file and file.split(".")[-1] == file_name.split(".")[-1]
]
files_with_matching_name = []
files_without_matching_name = []
for file_path in all_file_paths:
if file_name in file_path:
files_with_matching_name.append(file_path)
else:
files_without_matching_name.append(file_path)
file_path_to_ratio = {file: ratio(file_name, file) for file in all_file_paths}
files_with_matching_name = sorted(
files_with_matching_name,
key=lambda file_path: file_path_to_ratio[file_path],
reverse=True,
)
files_without_matching_name = sorted(
files_without_matching_name,
key=lambda file_path: file_path_to_ratio[file_path],
reverse=True,
)
# this allows 'config.py' to return 'sweepai/config/server.py', 'sweepai/config/client.py', 'sweepai/config/__init__.py' and no more
filtered_files_without_matching_name = list(filter(lambda file_path: file_path_to_ratio[file_path] > 50, files_without_matching_name))
all_files = files_with_matching_name + filtered_files_without_matching_name
return all_files[:limit]
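`get_similar_file_paths` keeps only paths with the same extension, ranks paths that contain the file name first, and admits other paths only if their fuzzy score exceeds 50. The same ranking is sketched below with the standard library's `difflib.SequenceMatcher` standing in for `rapidfuzz.fuzz.ratio` (both scaled 0 to 100 here, although the two algorithms will not give identical numbers); `rank_similar_paths` is a hypothetical name.

```python
import os
from difflib import SequenceMatcher


def score(a: str, b: str) -> float:
    # 0-100 similarity, similar in spirit to rapidfuzz.fuzz.ratio
    return 100 * SequenceMatcher(None, a, b).ratio()


def rank_similar_paths(target: str, paths: list[str], limit: int = 10, cutoff: float = 50) -> list[str]:
    name = os.path.basename(target)
    # Keep only paths with the same extension when the target has one.
    if "." in name:
        ext = name.rsplit(".", 1)[-1]
        paths = [p for p in paths if "." in p and p.rsplit(".", 1)[-1] == ext]
    scores = {p: score(name, p) for p in paths}
    # Paths containing the file name come first, best score first.
    containing = sorted((p for p in paths if name in p), key=scores.get, reverse=True)
    # Other paths must clear the cutoff to be included at all.
    others = sorted(
        (p for p in paths if name not in p and scores[p] > cutoff),
        key=scores.get,
        reverse=True,
    )
    return (containing + others)[:limit]


paths = ["sweepai/config.py", "sweepai/api.py", "README.md", "sweepai/utils/config.py"]
print(rank_similar_paths("config.py", paths))
```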
# updates a file with new_contents, returns True if successful
def update_file(root_dir: str, file_path: str, new_contents: str):
local_path = os.path.join(root_dir, file_path)
try:
with open(local_path, "w") as f:
f.write(new_contents)
return True
except Exception as e:
logger.error(f"Failed to update file: {e}")
return False
@dataclass
class MockClonedRepo(ClonedRepo):
_repo_dir: str = ""
git_repo: git.Repo | None = None
def __init__(
self,
_repo_dir: str,
repo_full_name: str,
installation_id: str = "",
branch: str | None = None,
token: str | None = None,
repo: Any | None = None,
git_repo: git.Repo | None = None,
):
self._repo_dir = _repo_dir
self.repo_full_name = repo_full_name
self.installation_id = installation_id
self.branch = branch
self.token = token
self.repo = repo
@classmethod
def from_dir(cls, repo_dir: str, **kwargs):
return cls(_repo_dir=repo_dir, **kwargs)
@property
def cached_dir(self):
return self._repo_dir
@property
def repo_dir(self):
return self._repo_dir
@property
def git_repo(self):
return git.Repo(self.repo_dir)
def clone(self):
return git.Repo(self.repo_dir)
def __post_init__(self):
return self
def __del__(self):
return True
@dataclass
class TemporarilyCopiedClonedRepo(MockClonedRepo):
tmp_dir: tempfile.TemporaryDirectory | None = None
def __init__(
self,
_repo_dir: str,
tmp_dir: tempfile.TemporaryDirectory,
repo_full_name: str,
installation_id: str = "",
branch: str | None = None,
token: str | None = None,
repo: Any | None = None,
git_repo: git.Repo | None = None,
):
self._repo_dir = _repo_dir
self.tmp_dir = tmp_dir
self.repo_full_name = repo_full_name
self.installation_id = installation_id
self.branch = branch
self.token = token
self.repo = repo
@classmethod
def copy_from_cloned_repo(cls, cloned_repo: ClonedRepo, **kwargs):
temp_dir = tempfile.TemporaryDirectory()
new_dir = temp_dir.name + "/" + cloned_repo.repo_full_name.split("/")[1]
print("Copying...")
shutil.copytree(cloned_repo.repo_dir, new_dir)
print("Done copying.")
return cls(
_repo_dir=new_dir,
tmp_dir=temp_dir,
repo_full_name=cloned_repo.repo_full_name,
installation_id=cloned_repo.installation_id,
branch=cloned_repo.branch,
token=cloned_repo.token,
repo=cloned_repo.repo,
**kwargs,
)
def __del__(self):
print(f"Dropping {self.tmp_dir.name}...")
shutil.rmtree(self._repo_dir, ignore_errors=True)
self.tmp_dir.cleanup()
print("Done.")
return True
def get_file_names_from_query(query: str) -> list[str]:
query_file_names = re.findall(r"\b[\w\-\.\/]*\w+\.\w{1,6}\b", query)
return [
query_file_name
for query_file_name in query_file_names
if len(query_file_name) > 3
]
def get_hunks(a: str, b: str, context=10):
differ = difflib.Differ()
diff = [
line
for line in differ.compare(a.splitlines(), b.splitlines())
if line[0] in ("+", "-", " ")
]
show = set()
hunks = []
for i, line in enumerate(diff):
if line.startswith(("+", "-")):
show.update(range(max(0, i - context), min(len(diff), i + context + 1)))
for i in range(len(diff)):
if i in show:
hunks.append(diff[i])
elif i - 1 in show:
hunks.append("...")
if len(hunks) > 0 and hunks[0] == "...":
hunks = hunks[1:]
if len(hunks) > 0 and hunks[-1] == "...":
hunks = hunks[:-1]
return "\n".join(hunks)
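`get_hunks` walks `difflib.Differ` output and keeps `context` lines around each change, collapsing the gaps into `...`. For comparison, the standard library's `difflib.unified_diff` produces context-trimmed hunks directly. A minimal sketch (the helper name is hypothetical), reusing the strings from the module's `__main__` demo:

```python
import difflib


def unified_hunks(a: str, b: str, context: int = 10) -> str:
    """Unified diff of two strings with `context` unchanged lines around each change."""
    diff = difflib.unified_diff(
        a.splitlines(),
        b.splitlines(),
        fromfile="before",
        tofile="after",
        n=context,
        lineterm="",  # lines from splitlines() already lack newlines
    )
    return "\n".join(diff)


str1 = "a\nline1\nline2\nline3\nline4\nline5\nline6\ntest\n"
str2 = "a\nline1\nlineTwo\nline3\nline4\nline5\nlineSix\ntset\n"
print(unified_hunks(str1, str2, 1))
```

Unlike `get_hunks`, the output carries `@@` line-range headers, which makes each hunk traceable back to line numbers in the file.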
def parse_collection_name(name: str) -> str:
# Replace any non-alphanumeric characters with hyphens
name = re.sub(r"[^\w-]", "--", name)
# Ensure the name is between 3 and 63 characters and starts/ends with alphanumeric
name = re.sub(r"^(-*\w{0,61}\w)-*$", r"\1", name[:63].ljust(3, "x"))
return name
# set whether or not a pr is a draft, there is no way to do this using pygithub
def convert_pr_draft_field(pr: PullRequest, is_draft: bool = False):
pr_id = pr.raw_data['node_id']
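# GraphQL mutation: convert the PR to a draft, or mark it ready for review
mutation_name = "convertPullRequestToDraft" if is_draft else "markPullRequestReadyForReview"
mutation = """
mutation SetPRDraftState {
{mutation_name}(input: {pullRequestId: {pull_request_id}}) {
pullRequest {
id
}
}
}
""".replace("{mutation_name}", mutation_name).replace("{pull_request_id}", "\""+pr_id+"\"")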
# GraphQL API URL
url = 'https://api.github.com/graphql'
# Headers
headers={
"Accept": "application/vnd.github+json",
"X-Github-Api-Version": "2022-11-28",
"Authorization": "Bearer " + os.environ["GITHUB_PAT"],
}
# Prepare the JSON payload
json_data = {
'query': mutation,
}
# Make the POST request
response = requests.post(url, headers=headers, data=json.dumps(json_data))
if response.status_code != 200 or "errors" in response.json():
logger.error(f"Failed to convert PR to {'draft' if is_draft else 'open'}")
return False
return True
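Two details of `convert_pr_draft_field` are worth spelling out. GitHub has a separate mutation for each direction (`convertPullRequestToDraft` and `markPullRequestReadyForReview`, both taking `pullRequestId`), and GraphQL failures usually come back with HTTP 200 and an `errors` list, so the status code alone is not enough. A sketch that passes the node id as a GraphQL variable instead of splicing it into the query; both helper names are hypothetical.

```python
def build_draft_state_request(pr_node_id: str, is_draft: bool) -> dict:
    """Hypothetical helper: GraphQL request body that sets a PR's draft state."""
    mutation = "convertPullRequestToDraft" if is_draft else "markPullRequestReadyForReview"
    # The node id travels as the $id variable, so no quoting or escaping is needed.
    query = (
        f"mutation SetDraftState($id: ID!) {{ "
        f"{mutation}(input: {{pullRequestId: $id}}) {{ pullRequest {{ id isDraft }} }} }}"
    )
    return {"query": query, "variables": {"id": pr_node_id}}


def graphql_ok(status_code: int, body: dict) -> bool:
    # GitHub reports GraphQL failures in an "errors" list, often alongside HTTP 200.
    return status_code == 200 and not body.get("errors")
```

It would be sent as `requests.post("https://api.github.com/graphql", headers=headers, json=build_draft_state_request(pr.raw_data["node_id"], is_draft=True))`, checking `graphql_ok(response.status_code, response.json())` afterwards.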
try:
g = Github(os.environ.get("GITHUB_PAT"))
CURRENT_USERNAME = g.get_user().login
except Exception:
try:
slug = get_app()["slug"]
CURRENT_USERNAME = f"{slug}[bot]"
except Exception:
CURRENT_USERNAME = GITHUB_BOT_USERNAME
if __name__ == "__main__":
try:
organization_name = "sweepai"
sweep_config = SweepConfig()
installation_id = get_installation_id(organization_name)
user_token, g = get_github_client(installation_id)
cloned_repo = ClonedRepo("sweepai/sweep", installation_id, "main")
dir_obj = cloned_repo.list_directory_tree()
commit_history = cloned_repo.get_commit_history()
similar_file_paths = cloned_repo.get_similar_file_paths("config.py")
# ensure no similar file_paths are sweep excluded
assert(not any([file for file in similar_file_paths if sweep_config.is_file_excluded(file)]))
print(f"similar_file_paths: {similar_file_paths}")
str1 = "a\nline1\nline2\nline3\nline4\nline5\nline6\ntest\n"
str2 = "a\nline1\nlineTwo\nline3\nline4\nline5\nlineSix\ntset\n"
print(get_hunks(str1, str2, 1))
mocked_repo = MockClonedRepo.from_dir(
cloned_repo.repo_dir,
repo_full_name="sweepai/sweep",
)
temp_repo = TemporarilyCopiedClonedRepo.copy_from_cloned_repo(mocked_repo)
print(f"mocked repo: {mocked_repo}")
except Exception as e:

import re
import traceback
from typing import TypeVar
from sweepai.config.server import DEFAULT_GPT4_32K_MODEL
from sweepai.core.chat import ChatGPT
from sweepai.core.entities import Message, RegexMatchableBaseModel
from loguru import logger
system_prompt = """You are a brilliant and meticulous engineer assigned to review the following commit diffs and make sure the file conforms to the user's rules.
If the diffs do not conform to the rules, we should create a GitHub issue telling the user what changes should be made.
Provide your response in the following format:
<rule_analysis>
- Analysis of each file_diff and whether it breaks the rule
...
</rule_analysis>
<changes_required>
Output "True" if the rule is broken, "False" otherwise
</changes_required>
<issue_title>
Write an issue title describing what file and rule to fix.
</issue_title>
<issue_description>
GitHub issue description for what we want to solve. Give general instructions on how to solve it. Mention files to take a look at and other code pointers.
</issue_description>"""
user_message = """Review the following diffs and make sure they conform to the rules:
{diff}
The rule is: {rule}
Provide your response in the following format:
<rule_analysis>
- Analysis of code diff 1 and whether it breaks the rule
- Analysis of code diff 2 and whether it breaks the rule
...
</rule_analysis>
<changes_required>
Output "True" if the rule is broken, "False" otherwise
</changes_required>
<issue_title>
Write an issue title describing what file and rule to fix.
</issue_title>
<issue_description>
GitHub issue description for what we want to solve. Give general instructions on how to solve it. Mention files to take a look at and other code pointers.
</issue_description>"""
Self = TypeVar("Self", bound="RegexMatchableBaseModel")
class IssueTitleAndDescription(RegexMatchableBaseModel):
changes_required: bool = False
issue_title: str
issue_description: str
@classmethod
def from_string(cls: type["IssueTitleAndDescription"], string: str, **kwargs) -> "IssueTitleAndDescription":
changes_required_pattern = (
r"""<changes_required>(\n)?(?P<changes_required>.*)</changes_required>"""
)
changes_required_match = re.search(changes_required_pattern, string, re.DOTALL)
changes_required = (
changes_required_match.groupdict()["changes_required"].strip()
if changes_required_match
else None
)
if changes_required and "true" in changes_required.lower():
changes_required = True
else:
changes_required = False
issue_title_pattern = r"""<issue_title>(\n)?(?P<issue_title>.*)</issue_title>"""
issue_title_match = re.search(issue_title_pattern, string, re.DOTALL)
issue_title = (
issue_title_match.groupdict()["issue_title"].strip()
if issue_title_match
else ""
)
issue_description_pattern = (
r"""<issue_description>(\n)?(?P<issue_description>.*)</issue_description>"""
)
issue_description_match = re.search(
issue_description_pattern, string, re.DOTALL
)
issue_description = (
issue_description_match.groupdict()["issue_description"].strip()
if issue_description_match
else ""
)
return cls(
changes_required=changes_required,
issue_title=issue_title,
issue_description=issue_description,
)
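`IssueTitleAndDescription.from_string` uses greedy `(.*)` patterns with `re.DOTALL`, which run to the last closing tag if a block appears twice. Its boolean test, `"true" in changes_required.lower()`, also fires whenever the model echoes the prompt's "True if the rule is broken, False otherwise" line (as the sample response in `__main__` does), regardless of the verdict that follows. A sketch with non-greedy matching that reads only the last non-empty line of `<changes_required>`; the helpers are hypothetical and assume the model puts its verdict last.

```python
import re


def extract_tag(text: str, tag: str) -> str:
    """Return the stripped body of the first <tag>...</tag> block, or "" if absent."""
    # (.*?) is non-greedy: it stops at the first closing tag instead of the last
    pattern = rf"<{re.escape(tag)}>(.*?)</{re.escape(tag)}>"
    match = re.search(pattern, text, re.DOTALL)
    return match.group(1).strip() if match else ""


def parse_changes_required(text: str) -> bool:
    """Treat only a final line of exactly "True" (any case) as a positive verdict."""
    body = extract_tag(text, "changes_required")
    lines = [line.strip().lower() for line in body.splitlines() if line.strip()]
    return bool(lines) and lines[-1] == "true"


echoed = "<changes_required>\nTrue if the rule is broken, False otherwise\nFalse\n</changes_required>"
print(parse_changes_required(echoed))  # False
```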
class PostMerge(ChatGPT):
def check_for_issues(self, rule, diff) -> tuple[bool, str, str]:
try:
self.messages = [
Message(
role="system",
content=system_prompt.format(rule=rule),
key="system",
)
]
if self.chat_logger and not self.chat_logger.is_paying_user():
raise ValueError("User is not a paying user")
self.model = DEFAULT_GPT4_32K_MODEL
response = self.chat(
user_message.format(
rule=rule,
diff=diff,
)
)
issue_title_and_description = IssueTitleAndDescription.from_string(response)
return (
issue_title_and_description.changes_required,
issue_title_and_description.issue_title,
issue_title_and_description.issue_description,
)
except SystemExit:
raise
except Exception:
logger.error(f"An error occurred: {traceback.format_exc()}")
return False, "", ""
if __name__ == "__main__":
changes_required_response = """<rule_analysis>
- Analysis of code diff 1 and whether it breaks the rule
The code diff 1 does not break the rule. There are no docstrings or comments that need to be updated.
- Analysis of code diff 2 and whether it breaks the rule
The code diff 2 breaks the rule. There is a commented out code block that should be removed.
</rule_analysis>
<changes_required>
True if the rule is broken, False otherwise
True
</changes_required>
<issue_title>
Outdated Commented Code Block in plan-list.blade.php
</issue_title>
<issue_description>
There is an outdated commented out code block in the file `resources/views/livewire/plan-list.blade.php` that should be removed. The code block starts at line 104 and ends at line 110. Please remove this code block as it is no longer needed.
Please refer to the file `resources/views/livewire/plan-list.blade.php` and remove the commented out code block starting at line 104 and ending at line 110.
</issue_description>"""

sweep/sweepai/api.py (lines 1 to 1178 at commit 0277fad)

from __future__ import annotations
import ctypes
import json
import threading
import time
from typing import Any, Optional
import requests
from fastapi import (
Body,
FastAPI,
Header,
HTTPException,
Path,
Request,
Security,
status,
)
from fastapi.responses import HTMLResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from fastapi.templating import Jinja2Templates
from github.Commit import Commit
from sweepai.config.client import (
DEFAULT_RULES,
RESTART_SWEEP_BUTTON,
REVERT_CHANGED_FILES_TITLE,
RULES_LABEL,
RULES_TITLE,
SWEEP_BAD_FEEDBACK,
SWEEP_GOOD_FEEDBACK,
SweepConfig,
get_gha_enabled,
get_rules,
)
from sweepai.config.server import (
BLACKLISTED_USERS,
DISABLED_REPOS,
DISCORD_FEEDBACK_WEBHOOK_URL,
ENV,
GHA_AUTOFIX_ENABLED,
GITHUB_BOT_USERNAME,
GITHUB_LABEL_COLOR,
GITHUB_LABEL_DESCRIPTION,
GITHUB_LABEL_NAME,
IS_SELF_HOSTED,
MERGE_CONFLICT_ENABLED,
)
from sweepai.core.entities import PRChangeRequest
from sweepai.global_threads import global_threads
from sweepai.handlers.create_pr import ( # type: ignore
add_config_to_top_repos,
create_gha_pr,
)
from sweepai.handlers.on_button_click import handle_button_click
from sweepai.handlers.on_check_suite import ( # type: ignore
clean_gh_logs,
download_logs,
on_check_suite,
)
from sweepai.handlers.on_comment import on_comment
from sweepai.handlers.on_jira_ticket import handle_jira_ticket
from sweepai.handlers.on_merge import on_merge
from sweepai.handlers.on_merge_conflict import on_merge_conflict
from sweepai.handlers.on_ticket import on_ticket
from sweepai.handlers.pr_utils import make_pr
from sweepai.handlers.stack_pr import stack_pr
from sweepai.utils.buttons import (
Button,
ButtonList,
check_button_activated,
check_button_title_match,
)
from sweepai.utils.chat_logger import ChatLogger
from sweepai.utils.event_logger import logger, posthog
from sweepai.utils.github_utils import CURRENT_USERNAME, get_github_client
from sweepai.utils.progress import TicketProgress
from sweepai.utils.safe_pqueue import SafePriorityQueue
from sweepai.utils.str_utils import BOT_SUFFIX, get_hash
from sweepai.web.events import (
CheckRunCompleted,
CommentCreatedRequest,
InstallationCreatedRequest,
IssueCommentRequest,
IssueRequest,
PREdited,
PRRequest,
ReposAddedRequest,
)
from sweepai.web.health import health_check
app = FastAPI()
events = {}
on_ticket_events = {}
security = HTTPBearer()
templates = Jinja2Templates(directory="sweepai/web")
# version_command = r"""git config --global --add safe.directory /app
# timestamp=$(git log -1 --format="%at")
# date -d "@$timestamp" +%y.%m.%d.%H 2>/dev/null || date -r "$timestamp" +%y.%m.%d.%H"""
# try:
# version = subprocess.check_output(version_command, shell=True, text=True).strip()
# except Exception:
version = time.strftime("%y.%m.%d.%H")
logger.bind(application="webhook")
def auth_metrics(credentials: HTTPAuthorizationCredentials = Security(security)):
if credentials.scheme != "Bearer":
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Invalid authentication scheme.",
)
if credentials.credentials != "example_token": # grafana requires authentication
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token."
)
return True
def run_on_ticket(*args, **kwargs):
tracking_id = get_hash()
with logger.contextualize(
**kwargs,
name="ticket_" + kwargs["username"],
tracking_id=tracking_id,
):
return on_ticket(*args, **kwargs, tracking_id=tracking_id)
def run_on_comment(*args, **kwargs):
tracking_id = get_hash()
with logger.contextualize(
**kwargs,
name="comment_" + kwargs["username"],
tracking_id=tracking_id,
):
on_comment(*args, **kwargs, tracking_id=tracking_id)
def run_on_button_click(*args, **kwargs):
thread = threading.Thread(target=handle_button_click, args=args, kwargs=kwargs)
thread.start()
global_threads.append(thread)
def run_on_check_suite(*args, **kwargs):
request = kwargs["request"]
pr_change_request = on_check_suite(request)
if pr_change_request:
call_on_comment(**pr_change_request.params, comment_type="github_action")
logger.info("Done with on_check_suite")
else:
logger.info("Skipping on_check_suite as no pr_change_request was returned")
def terminate_thread(thread):
"""Terminate a python threading.Thread."""
try:
if not thread.is_alive():
return
exc = ctypes.py_object(SystemExit)
# Thread ids are unsigned long since Python 3.7, so wrap them in c_ulong
res = ctypes.pythonapi.PyThreadState_SetAsyncExc(
ctypes.c_ulong(thread.ident), exc
)
if res == 0:
raise ValueError("Invalid thread ID")
elif res != 1:
# More than one thread state was modified: call again with exc=NULL (None) to undo it.
ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_ulong(thread.ident), None)
raise SystemError("PyThreadState_SetAsyncExc failed")
except SystemExit:
raise
except Exception as e:
logger.exception(f"Failed to terminate thread: {e}")
# def delayed_kill(thread: threading.Thread, delay: int = 60 * 60):
# time.sleep(delay)
# terminate_thread(thread)
def call_on_ticket(*args, **kwargs):
global on_ticket_events
key = f"{kwargs['repo_full_name']}-{kwargs['issue_number']}" # Full name, issue number as key
# Use multithreading
# Check if a previous process exists for the same key, cancel it
e = on_ticket_events.get(key, None)
if e:
logger.info(f"Found previous thread for key {key} and cancelling it")
terminate_thread(e)
thread = threading.Thread(target=run_on_ticket, args=args, kwargs=kwargs)
on_ticket_events[key] = thread
thread.start()
global_threads.append(thread)
def call_on_check_suite(*args, **kwargs):
kwargs["request"].repository.full_name
kwargs["request"].check_run.pull_requests[0].number
thread = threading.Thread(target=run_on_check_suite, args=args, kwargs=kwargs)
thread.start()
global_threads.append(thread)
def call_on_comment(
*args, **kwargs
): # TODO: if it's a GHA, delete all previous GHA and append to the end
def worker():
while not events[key].empty():
task_args, task_kwargs = events[key].get()
run_on_comment(*task_args, **task_kwargs)
global events
repo_full_name = kwargs["repo_full_name"]
pr_id = kwargs["pr_number"]
key = f"{repo_full_name}-{pr_id}" # Full name, PR number as key
comment_type = kwargs["comment_type"]
logger.info(f"Received comment type: {comment_type}")
if key not in events:
events[key] = SafePriorityQueue()
events[key].put(0, (args, kwargs))
# If a thread isn't running, start one
if not any(
thread.name == key and thread.is_alive() for thread in threading.enumerate()
):
thread = threading.Thread(target=worker, name=key)
thread.start()
global_threads.append(thread)
def call_on_merge(*args, **kwargs):
thread = threading.Thread(target=on_merge, args=args, kwargs=kwargs)
thread.start()
global_threads.append(thread)
@app.get("/health")
def redirect_to_health():
return health_check()
@app.get("/", response_class=HTMLResponse)
def home(request: Request):
return templates.TemplateResponse(
name="index.html", context={"version": version, "request": request}
)
@app.get("/ticket_progress/{tracking_id}")
def progress(tracking_id: str = Path(...)):
ticket_progress = TicketProgress.load(tracking_id)
return ticket_progress.dict()
def init_hatchet() -> Any | None:
try:
from hatchet_sdk import Context, Hatchet
hatchet = Hatchet(debug=True)
worker = hatchet.worker("github-worker")
@hatchet.workflow(on_events=["github:webhook"])
class OnGithubEvent:
"""Workflow for handling GitHub events."""
@hatchet.step()
def run(self, context: Context):
event_payload = context.workflow_input()
request_dict = event_payload.get("request")
event = event_payload.get("event")
handle_event(request_dict, event)
workflow = OnGithubEvent()
worker.register_workflow(workflow)
# start worker in the background
thread = threading.Thread(target=worker.start)
thread.start()
global_threads.append(thread)
return hatchet
except Exception as e:
print(f"Failed to initialize Hatchet: {e}, continuing with local mode")
return None
# hatchet = init_hatchet()
def handle_github_webhook(event_payload):
# if hatchet:
# hatchet.client.event.push("github:webhook", event_payload)
# else:
handle_event(event_payload.get("request"), event_payload.get("event"))
def handle_request(request_dict, event=None):
"""So it can be exported to the listen endpoint."""
with logger.contextualize(tracking_id="main", env=ENV):
action = request_dict.get("action")
try:
# Send the event to Hatchet
handle_github_webhook(
{
"request": request_dict,
"event": event,
}
)
except Exception as e:
logger.exception(f"Failed to send event to Hatchet: {e}")
# try:
# worker()
# except Exception as e:
# discord_log_error(str(e), priority=1)
logger.info(f"Done handling {event}, {action}")
return {"success": True}
@app.post("/")
def webhook(
request_dict: dict = Body(...),
x_github_event: Optional[str] = Header(None, alias="X-GitHub-Event"),
):
"""Handle a webhook request from GitHub."""
with logger.contextualize(tracking_id="main", env=ENV):
action = request_dict.get("action", None)
logger.info(f"Received event: {x_github_event}, {action}")
return handle_request(request_dict, event=x_github_event)
@app.post("/jira")
def jira_webhook(
request_dict: dict = Body(...),
) -> None:
def call_jira_ticket(*args, **kwargs):
thread = threading.Thread(target=handle_jira_ticket, args=args, kwargs=kwargs)
thread.start()
call_jira_ticket(event=request_dict)
# Set up cronjob for this
@app.get("/update_sweep_prs_v2")
def update_sweep_prs_v2(repo_full_name: str, installation_id: int):
# Get a Github client
_, g = get_github_client(installation_id)
# Get the repository
repo = g.get_repo(repo_full_name)
config = SweepConfig.get_config(repo)
try:
branch_ttl = int(config.get("branch_ttl", 7))
except Exception:
branch_ttl = 7
branch_ttl = max(branch_ttl, 1)
# Get all open pull requests created by Sweep
pulls = repo.get_pulls(
state="open", head="sweep", sort="updated", direction="desc"
)[:5]
# For each pull request, attempt to merge the changes from its base branch into the pull request branch
try:
for pr in pulls:
try:
# make sure it's a sweep ticket
feature_branch = pr.head.ref
if not feature_branch.startswith(
"sweep/"
) and not feature_branch.startswith("sweep_"):
continue
if "Resolve merge conflicts" in pr.title:
continue
if (
pr.mergeable_state != "clean"
and (time.time() - pr.created_at.timestamp()) > 60 * 60 * 24
and pr.title.startswith("[Sweep Rules]")
):
pr.edit(state="closed")
continue
repo.merge(
feature_branch,
pr.base.ref,
f"Merge {pr.base.ref} into {feature_branch}",
)
# Check if the merged PR is the config PR
if pr.title == "Configure Sweep" and pr.merged:
# Create a new PR to add "gha_enabled: True" to sweep.yaml
create_gha_pr(g, repo)
except Exception as e:
logger.warning(
f"Failed to merge changes from default branch into PR #{pr.number}: {e}"
)
except Exception as e:
logger.warning(f"Failed to update sweep PRs: {e}")
def handle_event(request_dict, event):
action = request_dict.get("action")
if repo_full_name := request_dict.get("repository", {}).get("full_name"):
if repo_full_name in DISABLED_REPOS:
logger.warning(f"Repo {repo_full_name} is disabled")
return {"success": False, "error_message": "Repo is disabled"}
with logger.contextualize(tracking_id="main", env=ENV):
match event, action:
case "check_run", "completed":
request = CheckRunCompleted(**request_dict)
_, g = get_github_client(request.installation.id)
repo = g.get_repo(request.repository.full_name)
pull_requests = request.check_run.pull_requests
if pull_requests:
logger.info(pull_requests[0].number)
pr = repo.get_pull(pull_requests[0].number)
if (time.time() - pr.created_at.timestamp()) > 60 * 60 and (
pr.title.startswith("[Sweep Rules]")
or pr.title.startswith("[Sweep GHA Fix]")
):
after_sha = pr.head.sha
commit = repo.get_commit(after_sha)
check_suites = commit.get_check_suites()
for check_suite in check_suites:
if check_suite.conclusion == "failure":
pr.edit(state="closed")
break
if (
not (time.time() - pr.created_at.timestamp()) > 60 * 15
and request.check_run.conclusion == "failure"
and pr.state == "open"
and get_gha_enabled(repo)
and len(
[
comment
for comment in pr.get_issue_comments()
if "Fixing PR" in comment.body
]
)
< 2
and GHA_AUTOFIX_ENABLED
):
# check if the base branch is passing
commits = repo.get_commits(sha=pr.base.ref)
latest_commit: Commit = commits[0]
if all(
status != "failure"
for status in [
status.state for status in latest_commit.get_statuses()
]
): # base branch is passing
logs = download_logs(
request.repository.full_name,
request.check_run.run_id,
request.installation.id,
)
logs, user_message = clean_gh_logs(logs)
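attributor = request.sender.login
if attributor.endswith("[bot]"):
# `commit` is only bound in the auto-close branch above, so fetch the head commit here
attributor = repo.get_commit(pr.head.sha).author.login
if attributor.endswith("[bot]") and pr.assignee:
attributor = pr.assignee.login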
if attributor.endswith("[bot]"):
return {
"success": False,
"error_message": "The PR was created by a bot, so I won't attempt to fix it.",
}
tracking_id = get_hash()
chat_logger = ChatLogger(
data={
"username": attributor,
"title": "[Sweep GHA Fix] Fix the failing GitHub Actions",
}
)
if chat_logger.use_faster_model() and not IS_SELF_HOSTED:
return {
"success": False,
"error_message": "Disabled for free users",
}
stack_pr(
request=f"[Sweep GHA Fix] The GitHub Actions run failed on {request.check_run.head_sha[:7]} ({repo.default_branch}) with the following error logs:\n\n```\n\n{logs}\n\n```",
pr_number=pr.number,
username=attributor,
repo_full_name=repo.full_name,
installation_id=request.installation.id,
tracking_id=tracking_id,
commit_hash=pr.head.sha,
)
elif (
request.check_run.check_suite.head_branch == repo.default_branch
and get_gha_enabled(repo)
and GHA_AUTOFIX_ENABLED
):
if request.check_run.conclusion == "failure":
commit = repo.get_commit(request.check_run.head_sha)
attributor = request.sender.login
if attributor.endswith("[bot]"):
attributor = commit.author.login
if attributor.endswith("[bot]"):
return {
"success": False,
"error_message": "The PR was created by a bot, so I won't attempt to fix it.",
}
logs = download_logs(
request.repository.full_name,
request.check_run.run_id,
request.installation.id,
)
logs, user_message = clean_gh_logs(logs)
chat_logger = ChatLogger(
data={
"username": attributor,
"title": "[Sweep GHA Fix] Fix the failing GitHub Actions",
}
)
if chat_logger.use_faster_model() and not IS_SELF_HOSTED:
return {
"success": False,
"error_message": "Disabled for free users",
}
make_pr(
title=f"[Sweep GHA Fix] Fix the failing GitHub Actions on {request.check_run.head_sha[:7]} ({repo.default_branch})",
repo_description=repo.description,
summary=f"The GitHub Actions run failed with the following error logs:\n\n```\n{logs}\n```",
repo_full_name=request_dict["repository"]["full_name"],
installation_id=request_dict["installation"]["id"],
user_token=None,
use_faster_model=chat_logger.use_faster_model(),
username=attributor,
chat_logger=chat_logger,
)
case "pull_request", "opened":
_, g = get_github_client(request_dict["installation"]["id"])
repo = g.get_repo(request_dict["repository"]["full_name"])
pr = repo.get_pull(request_dict["pull_request"]["number"])
# if the pr already has a comment from sweep bot do nothing
time.sleep(10)
if any(
comment.user.login == GITHUB_BOT_USERNAME
for comment in pr.get_issue_comments()
) or pr.title.startswith("Sweep:"):
return {
"success": True,
"reason": "PR already has a comment from sweep bot",
}
rule_buttons = []
repo_rules = get_rules(repo) or []
if repo_rules != [""] and repo_rules != []:
for rule in repo_rules or []:
if rule:
rule_buttons.append(Button(label=f"{RULES_LABEL} {rule}"))
if len(repo_rules) == 0:
for rule in DEFAULT_RULES:
rule_buttons.append(Button(label=f"{RULES_LABEL} {rule}"))
if rule_buttons:
rules_buttons_list = ButtonList(
buttons=rule_buttons, title=RULES_TITLE
)
pr.create_issue_comment(rules_buttons_list.serialize() + BOT_SUFFIX)
if pr.mergeable is False and MERGE_CONFLICT_ENABLED:
attributor = pr.user.login
if attributor.endswith("[bot]") and pr.assignee:
attributor = pr.assignee.login
if attributor.endswith("[bot]"):
return {
"success": False,
"error_message": "The PR was created by a bot, so I won't attempt to fix it.",
}
chat_logger = ChatLogger(
data={
"username": attributor,
"title": "[Sweep GHA Fix] Fix the failing GitHub Actions",
}
)
if chat_logger.use_faster_model() and not IS_SELF_HOSTED:
return {
"success": False,
"error_message": "Disabled for free users",
}
on_merge_conflict(
pr_number=pr.number,
username=attributor,
repo_full_name=request_dict["repository"]["full_name"],
installation_id=request_dict["installation"]["id"],
tracking_id=get_hash(),
)
case "issues", "opened":
request = IssueRequest(**request_dict)
issue_title_lower = request.issue.title.lower()
if (
issue_title_lower.startswith("sweep")
or "sweep:" in issue_title_lower
):
_, g = get_github_client(request.installation.id)
repo = g.get_repo(request.repository.full_name)
labels = repo.get_labels()
label_names = [label.name for label in labels]
if GITHUB_LABEL_NAME not in label_names:
repo.create_label(
name=GITHUB_LABEL_NAME,
color=GITHUB_LABEL_COLOR,
description=GITHUB_LABEL_DESCRIPTION,
)
current_issue = repo.get_issue(number=request.issue.number)
current_issue.add_to_labels(GITHUB_LABEL_NAME)
case "issue_comment", "edited":
request = IssueCommentRequest(**request_dict)
sweep_labeled_issue = GITHUB_LABEL_NAME in [
label.name.lower() for label in request.issue.labels
]
button_title_match = check_button_title_match(
REVERT_CHANGED_FILES_TITLE,
request.comment.body,
request.changes,
) or check_button_title_match(
RULES_TITLE,
request.comment.body,
request.changes,
)
if (
request.comment.user.type == "Bot"
and GITHUB_BOT_USERNAME in request.comment.user.login
and request.changes.body_from is not None
and button_title_match
and request.sender.type == "User"
and request.comment.user.login not in BLACKLISTED_USERS
):
run_on_button_click(request_dict)
restart_sweep = False
if (
request.comment.user.type == "Bot"
and GITHUB_BOT_USERNAME in request.comment.user.login
and request.changes.body_from is not None
and check_button_activated(
RESTART_SWEEP_BUTTON,
request.comment.body,
request.changes,
)
and sweep_labeled_issue
and request.sender.type == "User"
and request.comment.user.login not in BLACKLISTED_USERS
):
# Restart Sweep on this issue
restart_sweep = True
if (
request.issue is not None
and sweep_labeled_issue
and request.comment.user.type == "User"
and request.comment.user.login not in BLACKLISTED_USERS
and not request.comment.user.login.startswith("sweep")
and not (
request.issue.pull_request and request.issue.pull_request.url
)
or restart_sweep
):
logger.info("New issue comment edited")
request.issue.body = request.issue.body or ""
request.repository.description = (
request.repository.description or ""
)
if (
not request.comment.body.strip()
.lower()
.startswith(GITHUB_LABEL_NAME)
and not restart_sweep
):
logger.info("Comment does not start with 'Sweep', passing")
return {
"success": True,
"reason": "Comment does not start with 'Sweep', passing",
}
call_on_ticket(
title=request.issue.title,
summary=request.issue.body,
issue_number=request.issue.number,
issue_url=request.issue.html_url,
username=request.issue.user.login,
repo_full_name=request.repository.full_name,
repo_description=request.repository.description,
installation_id=request.installation.id,
comment_id=request.comment.id if not restart_sweep else None,
edited=True,
)
elif (
request.issue.pull_request
and request.comment.user.type == "User"
and request.comment.user.login not in BLACKLISTED_USERS
): # TODO(sweep): set a limit
logger.info(f"Handling comment on PR: {request.issue.pull_request}")
_, g = get_github_client(request.installation.id)
repo = g.get_repo(request.repository.full_name)
pr = repo.get_pull(request.issue.number)
labels = pr.get_labels()
comment = request.comment.body
if (
comment.lower().startswith("sweep:")
or any(label.name.lower() == "sweep" for label in labels)
) and BOT_SUFFIX not in comment:
pr_change_request = PRChangeRequest(
params={
"comment_type": "comment",
"repo_full_name": request.repository.full_name,
"repo_description": request.repository.description,
"comment": request.comment.body,
"pr_path": None,
"pr_line_position": None,
"username": request.comment.user.login,
"installation_id": request.installation.id,
"pr_number": request.issue.number,
"comment_id": request.comment.id,
},
)
call_on_comment(**pr_change_request.params)
case "issues", "edited":
request = IssueRequest(**request_dict)
if (
GITHUB_LABEL_NAME
in [label.name.lower() for label in request.issue.labels]
and request.sender.type == "User"
and request.sender.login not in BLACKLISTED_USERS
and not request.sender.login.startswith("sweep")
):
logger.info("New issue edited")
call_on_ticket(
title=request.issue.title,
summary=request.issue.body,
issue_number=request.issue.number,
issue_url=request.issue.html_url,
username=request.issue.user.login,
repo_full_name=request.repository.full_name,
repo_description=request.repository.description,
installation_id=request.installation.id,
comment_id=None,
)
else:
logger.info("Issue edited, but not a sweep issue")
case "issues", "labeled":
request = IssueRequest(**request_dict)
if (
any(
label.name.lower() == GITHUB_LABEL_NAME
for label in request.issue.labels
)
and not request.issue.pull_request
):
request.issue.body = request.issue.body or ""
request.repository.description = (
request.repository.description or ""
)
call_on_ticket(
title=request.issue.title,
summary=request.issue.body,
issue_number=request.issue.number,
issue_url=request.issue.html_url,
username=request.issue.user.login,
repo_full_name=request.repository.full_name,
repo_description=request.repository.description,
installation_id=request.installation.id,
comment_id=None,
)
case "issue_comment", "created":
request = IssueCommentRequest(**request_dict)
if (
request.issue is not None
and GITHUB_LABEL_NAME
in [label.name.lower() for label in request.issue.labels]
and request.comment.user.type == "User"
and request.comment.user.login not in BLACKLISTED_USERS
and not (
request.issue.pull_request and request.issue.pull_request.url
)
and BOT_SUFFIX not in request.comment.body
):
request.issue.body = request.issue.body or ""
request.repository.description = (
request.repository.description or ""
)
if (
not request.comment.body.strip()
.lower()
.startswith(GITHUB_LABEL_NAME)
):
logger.info("Comment does not start with 'Sweep', passing")
return {
"success": True,
"reason": "Comment does not start with 'Sweep', passing",
}
call_on_ticket(
title=request.issue.title,
summary=request.issue.body,
issue_number=request.issue.number,
issue_url=request.issue.html_url,
username=request.issue.user.login,
repo_full_name=request.repository.full_name,
repo_description=request.repository.description,
installation_id=request.installation.id,
comment_id=request.comment.id,
)
elif (
request.issue.pull_request
and request.comment.user.type == "User"
and request.comment.user.login not in BLACKLISTED_USERS
and BOT_SUFFIX not in request.comment.body
): # TODO(sweep): set a limit
_, g = get_github_client(request.installation.id)
repo = g.get_repo(request.repository.full_name)
pr = repo.get_pull(request.issue.number)
labels = pr.get_labels()
comment = request.comment.body
if (
comment.lower().startswith("sweep:")
or any(label.name.lower() == "sweep" for label in labels)
) and BOT_SUFFIX not in comment:
pr_change_request = PRChangeRequest(
params={
"comment_type": "comment",
"repo_full_name": request.repository.full_name,
"repo_description": request.repository.description,
"comment": request.comment.body,
"pr_path": None,
"pr_line_position": None,
"username": request.comment.user.login,
"installation_id": request.installation.id,
"pr_number": request.issue.number,
"comment_id": request.comment.id,
},
)
call_on_comment(**pr_change_request.params)
case "pull_request_review_comment", "created":
request = CommentCreatedRequest(**request_dict)
_, g = get_github_client(request.installation.id)
repo = g.get_repo(request.repository.full_name)
pr = repo.get_pull(request.pull_request.number)
labels = pr.get_labels()
comment = request.comment.body
if (
(
comment.lower().startswith("sweep:")
or any(label.name.lower() == "sweep" for label in labels)
)
and request.comment.user.type == "User"
and request.comment.user.login not in BLACKLISTED_USERS
and BOT_SUFFIX not in comment
):
pr_change_request = PRChangeRequest(
params={
"comment_type": "comment",
"repo_full_name": request.repository.full_name,
"repo_description": request.repository.description,
"comment": request.comment.body,
"pr_path": request.comment.path,
"pr_line_position": request.comment.original_line,
"username": request.comment.user.login,
"installation_id": request.installation.id,
"pr_number": request.pull_request.number,
"comment_id": request.comment.id,
},
)
call_on_comment(**pr_change_request.params)
case "pull_request_review_comment", "edited":
request = CommentCreatedRequest(**request_dict)
_, g = get_github_client(request.installation.id)
repo = g.get_repo(request.repository.full_name)
pr = repo.get_pull(request.pull_request.number)
labels = pr.get_labels()
comment = request.comment.body
if (
(
comment.lower().startswith("sweep:")
or any(label.name.lower() == "sweep" for label in labels)
)
and request.comment.user.type == "User"
and request.comment.user.login not in BLACKLISTED_USERS
and BOT_SUFFIX not in comment
):
pr_change_request = PRChangeRequest(
params={
"comment_type": "comment",
"repo_full_name": request.repository.full_name,
"repo_description": request.repository.description,
"comment": request.comment.body,
"pr_path": request.comment.path,
"pr_line_position": request.comment.original_line,
"username": request.comment.user.login,
"installation_id": request.installation.id,
"pr_number": request.pull_request.number,
"comment_id": request.comment.id,
},
)
call_on_comment(**pr_change_request.params)
case "installation_repositories", "added":
repos_added_request = ReposAddedRequest(**request_dict)
metadata = {
"installation_id": repos_added_request.installation.id,
"repositories": [
repo.full_name
for repo in repos_added_request.repositories_added
],
}
try:
add_config_to_top_repos(
repos_added_request.installation.id,
repos_added_request.installation.account.login,
repos_added_request.repositories_added,
)
except Exception as e:
logger.exception(f"Failed to add config to top repos: {e}")
posthog.capture(
"installation_repositories",
"started",
properties={**metadata},
)
for repo in repos_added_request.repositories_added:
organization, repo_name = repo.full_name.split("/")
posthog.capture(
organization,
"installed_repository",
properties={
"repo_name": repo_name,
"organization": organization,
"repo_full_name": repo.full_name,
},
)
case "installation", "created":
repos_added_request = InstallationCreatedRequest(**request_dict)
try:
add_config_to_top_repos(
repos_added_request.installation.id,
repos_added_request.installation.account.login,
repos_added_request.repositories,
)
except Exception as e:
logger.exception(f"Failed to add config to top repos: {e}")
case "pull_request", "edited":
request = PREdited(**request_dict)
if (
request.pull_request.user.login == GITHUB_BOT_USERNAME
and not request.sender.login.endswith("[bot]")
and DISCORD_FEEDBACK_WEBHOOK_URL is not None
):
good_button = check_button_activated(
SWEEP_GOOD_FEEDBACK,
request.pull_request.body,
request.changes,
)
bad_button = check_button_activated(
SWEEP_BAD_FEEDBACK,
request.pull_request.body,
request.changes,
)
if good_button or bad_button:
emoji = "😕"
if good_button:
emoji = "👍"
elif bad_button:
emoji = "👎"
data = {
"content": f"{emoji} {request.pull_request.html_url} ({request.sender.login})\n{request.pull_request.commits} commits, {request.pull_request.changed_files} files: +{request.pull_request.additions}, -{request.pull_request.deletions}"
}
headers = {"Content-Type": "application/json"}
requests.post(
DISCORD_FEEDBACK_WEBHOOK_URL,
data=json.dumps(data),
headers=headers,
)
# Send feedback to PostHog
posthog.capture(
request.sender.login,
"feedback",
properties={
"repo_name": request.repository.full_name,
"pr_url": request.pull_request.html_url,
"pr_commits": request.pull_request.commits,
"pr_additions": request.pull_request.additions,
"pr_deletions": request.pull_request.deletions,
"pr_changed_files": request.pull_request.changed_files,
"username": request.sender.login,
"good_button": good_button,
"bad_button": bad_button,
},
)
def remove_buttons_from_description(body):
"""
Replace:
### PR Feedback...
...
# (until it hits the next #)
with
### PR Feedback: {emoji}
#
"""
lines = body.split("\n")
if not lines[0].startswith("### PR Feedback"):
return None
# Find when the second # occurs
i = 0
for i, line in enumerate(lines):
if line.startswith("#") and i > 0:
break
return "\n".join(
[
f"### PR Feedback: {emoji}",
*lines[i:],
]
)
# Update PR description to remove buttons
try:
_, g = get_github_client(request.installation.id)
repo = g.get_repo(request.repository.full_name)
pr = repo.get_pull(request.pull_request.number)
new_body = remove_buttons_from_description(
request.pull_request.body
)
if new_body is not None:
pr.edit(body=new_body)
except SystemExit:
raise
except Exception as e:
logger.exception(f"Failed to edit PR description: {e}")
case "pull_request", "closed":
pr_request = PRRequest(**request_dict)
(
organization,
repo_name,
) = pr_request.repository.full_name.split("/")
commit_author = pr_request.pull_request.user.login
merged_by = (
pr_request.pull_request.merged_by.login
if pr_request.pull_request.merged_by
else None
)
if CURRENT_USERNAME == commit_author and merged_by is not None:
event_name = "merged_sweep_pr"
if pr_request.pull_request.title.startswith("[config]"):
event_name = "config_pr_merged"
elif pr_request.pull_request.title.startswith("[Sweep Rules]"):
event_name = "sweep_rules_pr_merged"
edited_by_developers = False
_token, g = get_github_client(pr_request.installation.id)
pr = g.get_repo(pr_request.repository.full_name).get_pull(
pr_request.number
)
total_lines_in_commit = 0
total_lines_edited_by_developer = 0
for commit in pr.get_commits():
lines_modified = commit.stats.additions + commit.stats.deletions
total_lines_in_commit += lines_modified
if commit.author.login != CURRENT_USERNAME:
total_lines_edited_by_developer += lines_modified
# this was edited by a developer if at least 25% of the lines were edited by a developer
edited_by_developers = total_lines_in_commit > 0 and (total_lines_edited_by_developer / total_lines_in_commit) >= 0.25
posthog.capture(
merged_by,
event_name,
properties={
"repo_name": repo_name,
"organization": organization,
"repo_full_name": pr_request.repository.full_name,
"username": merged_by,
"additions": pr_request.pull_request.additions,
"deletions": pr_request.pull_request.deletions,
"total_changes": pr_request.pull_request.additions
+ pr_request.pull_request.deletions,
"edited_by_developers": edited_by_developers,
"total_lines_in_commit": total_lines_in_commit,
"total_lines_edited_by_developer": total_lines_edited_by_developer,
},
)
chat_logger = ChatLogger({"username": merged_by})
case "push", None:
if event != "pull_request" or request_dict["base"]["merged"] is True:
chat_logger = ChatLogger(
{"username": request_dict["pusher"]["name"]}
)
# on merge
call_on_merge(request_dict, chat_logger)
ref = request_dict["ref"] if "ref" in request_dict else ""
if ref.startswith("refs/heads") and not ref.startswith(
"refs/heads/sweep"
):
_, g = get_github_client(request_dict["installation"]["id"])
repo = g.get_repo(request_dict["repository"]["full_name"])
if ref[len("refs/heads/") :] == SweepConfig.get_branch(repo):
update_sweep_prs_v2(
request_dict["repository"]["full_name"],
installation_id=request_dict["installation"]["id"],
)
if ref.startswith("refs/heads"):
# `repo` is only bound above for non-sweep branches, so fetch it here too
_, g = get_github_client(request_dict["installation"]["id"])
repo = g.get_repo(request_dict["repository"]["full_name"])
branch_name = ref[len("refs/heads/") :]
# Check if the branch has an associated PR
org_name, repo_name = request_dict["repository"][
"full_name"
].split("/")
pulls = repo.get_pulls(
state="open",
sort="created",
head=org_name + ":" + branch_name,
)
for pr in pulls:
logger.info(
f"PR associated with branch {branch_name}: #{pr.number} - {pr.title}"
)
if pr.mergeable is False and MERGE_CONFLICT_ENABLED:
attributor = pr.user.login
if attributor.endswith("[bot]") and pr.assignee:
attributor = pr.assignee.login
if attributor.endswith("[bot]"):
return {
"success": False,
"error_message": "The PR was created by a bot, so I won't attempt to fix it.",
}
chat_logger = ChatLogger(
data={
"username": attributor,
"title": "[Sweep GHA Fix] Fix the failing GitHub Actions",
}
)
if (
chat_logger.use_faster_model()
and not IS_SELF_HOSTED
):
return {
"success": False,
"error_message": "Disabled for free users",
}
on_merge_conflict(
pr_number=pr.number,
username=attributor,
repo_full_name=request_dict["repository"][
"full_name"
],
installation_id=request_dict["installation"]["id"],
tracking_id=get_hash(),
)
case "ping", None:
return {"message": "pong"}
case _:

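`update_sweep_prs_v2` above keeps Sweep branches up to date by calling `repo.merge(...)`, which creates a merge commit on the feature branch. That merge commit is the behavior this issue asks to change. A rebase rewrites the branch's commits instead, so one way to get it is to run git in a local clone that has push access. The sketch below is hypothetical. `build_rebase_commands` and `rebase_onto_base` are not part of Sweep, and the sketch assumes `git` is on `PATH` and that the clone's `origin` remote points at the repository.

```python
import subprocess
from typing import List


def build_rebase_commands(
    feature_branch: str, base_branch: str, remote: str = "origin"
) -> List[List[str]]:
    """Return the git invocations that rebase feature_branch onto remote/base_branch.

    Hypothetical helper: it only illustrates a rebase-instead-of-merge flow.
    """
    return [
        # Refresh the remote-tracking refs for both branches.
        ["git", "fetch", remote, base_branch, feature_branch],
        # Create or reset the local feature branch to the remote copy.
        ["git", "checkout", "-B", feature_branch, f"{remote}/{feature_branch}"],
        # Replay the feature commits on top of the updated base branch.
        ["git", "rebase", f"{remote}/{base_branch}"],
        # --force-with-lease refuses to overwrite the remote branch if someone
        # pushed to it after our fetch, unlike a plain --force.
        ["git", "push", "--force-with-lease", remote, f"{feature_branch}:{feature_branch}"],
    ]


def rebase_onto_base(repo_dir: str, feature_branch: str, base_branch: str) -> bool:
    """Run the rebase in an existing clone; abort and return False on any failure."""
    for cmd in build_rebase_commands(feature_branch, base_branch):
        result = subprocess.run(cmd, cwd=repo_dir, capture_output=True, text=True)
        if result.returncode != 0:
            if cmd[1] == "rebase":
                # Leave the working tree clean so the next attempt starts fresh.
                subprocess.run(["git", "rebase", "--abort"], cwd=repo_dir)
            return False
    return True
```

The rebased branch is force-pushed, so anyone who has the old branch checked out will need to reset it. `--force-with-lease` at least avoids clobbering commits pushed after the fetch. On a conflict the sketch aborts and returns `False`, and a caller could handle that the same way the current code handles a failed merge.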
# Advanced Features: becoming a Power User 🧠
## Usage 📖
### Mention important files
To ensure that Sweep scans a file, mention the file name in your ticket. Sweep searches for relevant files at runtime, but specifying the file helps avoid missing important details.
### Giving Sweep feedback
If Sweep's plan isn't accurate, you can respond to Sweep in three places:
1. **Issue**: Sweep will create a new pull request and close the old one. Alternatively, you can edit the issue description to recreate the pull request.
2. **Pull request**: Sweep will update the PR based on your PR comments.
3. **Code**: Sweep will only update the file that the comment is on.
Whenever Sweep is looking at one of your messages, you will see an 👀 emoji. If you don't see it, make sure the PR or issue is open and that you prefixed the message with "sweep:".
In addition, when a GitHub Actions run fails, Sweep will update the PR based on the error message.
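For pull request comments, the `issue_comment` and `pull_request_review_comment` cases in `handle_event` above decide whether a comment is meant for Sweep. The function below is a hypothetical, simplified restatement of those checks. `BOT_SUFFIX` and `BLACKLISTED_USERS` are passed in as parameters because their values are not shown here. In the code above, a comment on a PR labeled `sweep` is also picked up without the `sweep:` prefix.

```python
from typing import Iterable


def should_handle_pr_comment(
    comment_body: str,
    label_names: Iterable[str],
    user_type: str,
    user_login: str,
    blacklisted_users: Iterable[str],
    bot_suffix: str,
) -> bool:
    """Hypothetical restatement of the PR-comment checks in handle_event."""
    # Sweep's own comments carry a suffix marker; skip them to avoid loops.
    if bot_suffix in comment_body:
        return False
    # Only comments from non-blacklisted human users count.
    if user_type != "User" or user_login in set(blacklisted_users):
        return False
    # Either the comment opts in explicitly, or the PR carries the "sweep" label.
    prefixed = comment_body.lower().startswith("sweep:")
    labeled = any(name.lower() == "sweep" for name in label_names)
    return prefixed or labeled
```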
### Switch branch
To get Sweep to use a different base branch for one issue, add the following to the issue description.
> branch: BRANCH_NAME
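The parser Sweep uses for this line is not shown here. As a minimal sketch, assume the line has the form `branch: NAME` on its own line, optionally quoted with `>`. Under that assumption, a hypothetical `extract_base_branch` could look like this:

```python
import re
from typing import Optional

# Hypothetical pattern: "branch: NAME" alone on a line, optionally prefixed by ">".
BRANCH_LINE = re.compile(
    r"^[ \t]*>?[ \t]*branch:[ \t]*(\S+)[ \t]*$", re.IGNORECASE | re.MULTILINE
)


def extract_base_branch(issue_body: Optional[str], default_branch: str) -> str:
    """Return the branch named in the issue body, or the repo's default branch."""
    match = BRANCH_LINE.search(issue_body or "")
    return match.group(1) if match else default_branch
```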
## Configuration 🛠️
### Use GitHub Actions
We highly recommend setting up linters, as well as Netlify/Vercel preview builds. Sweep auto-corrects based on linter and build errors, and Netlify and Vercel preview builds shorten iteration cycles by letting you preview your static site.
### Set up `sweep.yaml`
You can set up `sweep.yaml` to
* Provide up-to-date docs by setting up `docs` (https://docs.sweep.dev/usage/config#docs)
* Set up automated formatting and linting by setting up `sandbox` (https://docs.sweep.dev/usage/config#sandbox). Never have Sweep commit a failing `npm lint` again.
* Give Sweep a high level description of where to find files in your repo by editing the `repo_description` field.
For more on configs, check out https://docs.sweep.dev/usage/config.
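Besides the fields above, `update_sweep_prs_v2` in the code earlier on this page reads a `branch_ttl` key from `sweep.yaml`. It falls back to 7 when the value cannot be converted to an integer and clamps it to at least 1. The unit is not stated in this excerpt. `parse_branch_ttl` is a hypothetical standalone restatement of that coercion:

```python
from typing import Any, Mapping

DEFAULT_BRANCH_TTL = 7  # same fallback value used by update_sweep_prs_v2


def parse_branch_ttl(config: Mapping[str, Any]) -> int:
    """Coerce sweep.yaml's branch_ttl the way update_sweep_prs_v2 does."""
    try:
        branch_ttl = int(config.get("branch_ttl", DEFAULT_BRANCH_TTL))
    except Exception:
        # Non-numeric values (e.g. "a week" or null) fall back to the default.
        branch_ttl = DEFAULT_BRANCH_TTL
    # Never allow a TTL below one.
    return max(branch_ttl, 1)
```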
## Prompting 🗣️
The amount of prompting Sweep needs scales directly with the complexity of the problem.
For harder problems, try to provide the same information a human would need, and for simpler problems, providing a single line and a file name should suffice.
### Prompting formats
A good issue should include **where to look** (file name or entity name), **what to do** ("change the logic to do this"), and **additional context** (there's a bug/we need this feature/there's this dependency). Examples:

From `sweepai/cli.py` (lines 1 to 363 at commit `0277fad`):

import datetime
import json
import os
import pickle
import threading
import time
import uuid
from itertools import chain, islice
import typer
from github import Github
from github.Event import Event
from github.IssueEvent import IssueEvent
from github.Repository import Repository
from loguru import logger
from rich.console import Console
from rich.prompt import Prompt
from sweepai.api import handle_request
from sweepai.handlers.on_ticket import on_ticket
from sweepai.utils.event_logger import posthog
from sweepai.utils.github_utils import get_github_client
from sweepai.utils.str_utils import get_hash
from sweepai.web.events import Account, Installation, IssueRequest
app = typer.Typer(
name="sweepai", context_settings={"help_option_names": ["-h", "--help"]}
)
app_dir = typer.get_app_dir("sweepai")
config_path = os.path.join(app_dir, "config.json")
console = Console()
cprint = console.print
def posthog_capture(event_name, properties, *args, **kwargs):
POSTHOG_DISTINCT_ID = os.environ.get("POSTHOG_DISTINCT_ID")
if POSTHOG_DISTINCT_ID:
posthog.capture(POSTHOG_DISTINCT_ID, event_name, properties, *args, **kwargs)
def load_config():
if os.path.exists(config_path):
cprint(f"\nLoading configuration from {config_path}", style="yellow")
with open(config_path, "r") as f:
config = json.load(f)
os.environ["GITHUB_PAT"] = config.get("GITHUB_PAT", "")
os.environ["OPENAI_API_KEY"] = config.get("OPENAI_API_KEY", "")
os.environ["ANTHROPIC_API_KEY"] = config.get("ANTHROPIC_API_KEY", "")
os.environ["VOYAGE_API_KEY"] = config.get("VOYAGE_API_KEY", "")
os.environ["POSTHOG_DISTINCT_ID"] = str(config.get("POSTHOG_DISTINCT_ID", ""))
def fetch_issue_request(issue_url: str, __version__: str = "0"):
(
protocol_name,
_,
_base_url,
org_name,
repo_name,
_issues,
issue_number,
) = issue_url.split("/")
cprint("Fetching installation ID...")
installation_id = -1
cprint("Fetching access token...")
_token, g = get_github_client(installation_id)
g: Github = g
cprint("Fetching repo...")
issue = g.get_repo(f"{org_name}/{repo_name}").get_issue(int(issue_number))
issue_request = IssueRequest(
action="labeled",
issue=IssueRequest.Issue(
title=issue.title,
number=int(issue_number),
html_url=issue_url,
user=IssueRequest.Issue.User(
login=issue.user.login,
type="User",
),
body=issue.body,
labels=[
IssueRequest.Issue.Label(
name="sweep",
),
],
assignees=None,
pull_request=None,
),
repository=IssueRequest.Issue.Repository(
full_name=issue.repository.full_name,
description=issue.repository.description,
),
assignee=IssueRequest.Issue.Assignee(login=issue.user.login),
installation=Installation(
id=installation_id,
account=Account(
id=issue.user.id,
login=issue.user.login,
type="User",
),
),
sender=IssueRequest.Issue.User(
login=issue.user.login,
type="User",
),
)
return issue_request
def pascal_to_snake(name):
return "".join(["_" + i.lower() if i.isupper() else i for i in name]).lstrip("_")
def get_event_type(event: Event | IssueEvent):
if isinstance(event, IssueEvent):
return "issues"
else:
return pascal_to_snake(event.type)[: -len("_event")]
@app.command()
def test():
cprint("Sweep AI is installed correctly and ready to go!", style="yellow")
@app.command()
def watch(
repo_name: str,
debug: bool = False,
record_events: bool = False,
max_events: int = 30,
):
if not os.path.exists(config_path):
cprint(
f"\nConfiguration not found at {config_path}. Please run [green]'sweep init'[/green] to initialize the CLI.\n",
style="yellow",
)
raise ValueError(
"Configuration not found, please run 'sweep init' to initialize the CLI."
)
posthog_capture(
"sweep_watch_started",
{
"repo": repo_name,
"debug": debug,
"record_events": record_events,
"max_events": max_events,
},
)
GITHUB_PAT = os.environ.get("GITHUB_PAT", None)
if GITHUB_PAT is None:
raise ValueError("GITHUB_PAT environment variable must be set")
g = Github(os.environ["GITHUB_PAT"])
repo = g.get_repo(repo_name)
if debug:
logger.debug("Debug mode enabled")
def stream_events(repo: Repository, timeout: int = 2, offset: int = 2 * 60):
processed_event_ids = set()
current_time = time.time() - offset
current_time = datetime.datetime.fromtimestamp(current_time)
local_tz = datetime.datetime.now(datetime.timezone.utc).astimezone().tzinfo
while True:
events_iterator = chain(
islice(repo.get_events(), max_events),
islice(repo.get_issues_events(), max_events),
)
for i, event in enumerate(events_iterator):
if event.id not in processed_event_ids:
local_time = event.created_at.replace(
tzinfo=datetime.timezone.utc
).astimezone(local_tz)
if local_time.timestamp() > current_time.timestamp():
yield event
else:
if debug:
logger.debug(
f"Skipping event {event.id} because it is in the past (local_time={local_time}, current_time={current_time}, i={i})"
)
if debug:
logger.debug(
f"Skipping event {event.id} because it is already handled"
)
processed_event_ids.add(event.id)
time.sleep(timeout)
def handle_event(event: Event | IssueEvent, do_async: bool = True):
if isinstance(event, IssueEvent):
payload = event.raw_data
payload["action"] = payload["event"]
else:
payload = {**event.raw_data, **event.payload}
payload["sender"] = payload.get("sender", payload["actor"])
payload["sender"]["type"] = "User"
payload["pusher"] = payload.get("pusher", payload["actor"])
payload["pusher"]["name"] = payload["pusher"]["login"]
payload["pusher"]["type"] = "User"
payload["after"] = payload.get("after", payload.get("head"))
payload["repository"] = repo.raw_data
payload["installation"] = {"id": -1}
logger.info(str(event) + " " + str(event.created_at))
if record_events:
_type = get_event_type(event) if isinstance(event, Event) else "issue"
with open(
"tests/events/"
+ f"{_type}_{payload.get('action')}_{str(event.id)}.pkl",
"wb",
) as f:
pickle.dump(event, f)
if do_async:
thread = threading.Thread(
target=handle_request, args=(payload, get_event_type(event))
)
thread.start()
return thread
else:
return handle_request(payload, get_event_type(event))
def main():
cprint(
f"\n[bold black on white] Starting server, listening to events from {repo_name}... [/bold black on white]\n",
)
cprint(
f"To create a PR, please create an issue at https://github.com/{repo_name}/issues with a title prefixed with 'Sweep:' or label an existing issue with 'sweep'. The events will be logged here, but there may be a brief delay.\n"
)
for event in stream_events(repo):
handle_event(event)
if __name__ == "__main__":
main()

@app.command()
def init(override: bool = False):
    # TODO: Fix telemetry
    if not override:
        if os.path.exists(config_path):
            with open(config_path, "r") as f:
                config = json.load(f)
            if "OPENAI_API_KEY" in config and "ANTHROPIC_API_KEY" in config and "GITHUB_PAT" in config:
                override = typer.confirm(
                    f"\nConfiguration already exists at {config_path}. Override?",
                    default=False,
                    abort=True,
                )
    cprint(
        "\n[bold black on white] Initializing Sweep CLI... [/bold black on white]\n",
    )
    cprint(
        "\nFirstly, let's store your OpenAI API Key. You can get it here: https://platform.openai.com/api-keys\n",
        style="yellow",
    )
    openai_api_key = Prompt.ask("OpenAI API Key", password=True)
    assert len(openai_api_key) > 30, "OpenAI API Key must be of length at least 30."
    assert openai_api_key.startswith("sk-"), "OpenAI API Key must start with 'sk-'."
    cprint(
        "\nNext, let's store your Anthropic API key. You can get it here: https://console.anthropic.com/settings/keys.",
        style="yellow",
    )
    anthropic_api_key = Prompt.ask("Anthropic API Key", password=True)
    assert len(anthropic_api_key) > 30, "Anthropic API Key must be of length at least 30."
    assert anthropic_api_key.startswith("sk-ant-api03-"), "Anthropic API Key must start with 'sk-ant-api03-'."
    cprint(
        "\nGreat! Next, we'll need just your GitHub PAT. Here's a link with all the permissions pre-filled:\nhttps://github.com/settings/tokens/new?description=Sweep%20Self-hosted&scopes=repo,workflow\n",
        style="yellow",
    )
    github_pat = Prompt.ask("GitHub PAT", password=True)
    assert len(github_pat) > 30, "GitHub PAT must be of length at least 30."
    assert github_pat.startswith("ghp_"), "GitHub PAT must start with 'ghp_'."
    cprint(
        "\nAwesome! Lastly, let's get your Voyage AI API key from https://dash.voyageai.com/api-keys. This is optional, but improves code search by about [cyan]5%[/cyan]. You can always return to this later by re-running 'sweep init'.",
        style="yellow",
    )
    voyage_api_key = Prompt.ask("Voyage AI API key", password=True)
    if voyage_api_key:
        assert len(voyage_api_key) > 30, "Voyage AI API key must be of length at least 30."
        assert voyage_api_key.startswith("pa-"), "Voyage AI API key must start with 'pa-'."
    POSTHOG_DISTINCT_ID = None
    enable_telemetry = typer.confirm(
        "\nEnable usage statistics? This will help us improve the product.",
        default=True,
    )
    if enable_telemetry:
        cprint(
            "\nThank you for enabling telemetry. We'll collect anonymous usage statistics to improve the product. You can disable this at any time by rerunning 'sweep init'.",
            style="yellow",
        )
        POSTHOG_DISTINCT_ID = uuid.getnode()
        posthog.capture(POSTHOG_DISTINCT_ID, "sweep_init", {})
    config = {
        "GITHUB_PAT": github_pat,
        "OPENAI_API_KEY": openai_api_key,
        "ANTHROPIC_API_KEY": anthropic_api_key,
        "VOYAGE_API_KEY": voyage_api_key,
    }
    if POSTHOG_DISTINCT_ID:
        config["POSTHOG_DISTINCT_ID"] = POSTHOG_DISTINCT_ID
    os.makedirs(app_dir, exist_ok=True)
    with open(config_path, "w") as f:
        json.dump(config, f)
    cprint(f"\nConfiguration saved to {config_path}\n", style="yellow")
    cprint(
        "Installation complete! You can now run [green]'sweep run <issue-url>'[/green][yellow] to run Sweep on an issue, or [/yellow][green]'sweep watch <org-name>/<repo-name>'[/green] to have Sweep listen for and fix newly created GitHub issues.",
        style="yellow",
    )
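`init` validates each key with a length check and a prefix check before saving it. A minimal sketch of factoring those checks into one helper, so every error message names the key it is about; `validate_api_key` is a hypothetical name, and the prefixes and 30-character minimum are taken from the asserts above. It raises `ValueError` instead of using `assert`, because assert statements are skipped when Python runs with `-O`.

```python
def validate_api_key(label: str, value: str, prefix: str, min_length: int = 30) -> str:
    # Length check first, then prefix check, mirroring the order in `init`
    if len(value) < min_length:
        raise ValueError(f"{label} must be of length at least {min_length}.")
    if not value.startswith(prefix):
        raise ValueError(f"{label} must start with '{prefix}'.")
    return value
```

Inside `init`, a call such as `validate_api_key("Anthropic API Key", anthropic_api_key, "sk-ant-api03-")` would replace the two asserts for that key, and the label and prefix can no longer drift apart.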

@app.command()
def run(issue_url: str):
    if not os.path.exists(config_path):
        cprint(
            f"\nConfiguration not found at {config_path}. Please run [green]'sweep init'[/green] to initialize the CLI.\n",
            style="yellow",
        )
        raise ValueError(
            "Configuration not found, please run 'sweep init' to initialize the CLI."
        )
    cprint(f"\n Running Sweep on issue: {issue_url} \n", style="bold black on white")
    posthog_capture("sweep_run_started", {"issue_url": issue_url})
    request = fetch_issue_request(issue_url)
    try:
        cprint(f'\nRunning Sweep to solve "{request.issue.title}"!\n')
        on_ticket(
            title=request.issue.title,
            summary=request.issue.body,
            issue_number=request.issue.number,
            issue_url=request.issue.html_url,
            username=request.sender.login,
            repo_full_name=request.repository.full_name,
            repo_description=request.repository.description,
            installation_id=request.installation.id,
            comment_id=None,
            edited=False,
            tracking_id=get_hash(),
        )
    except Exception as e:
        posthog_capture("sweep_run_fail", {"issue_url": issue_url, "error": str(e)})
    else:
        posthog_capture("sweep_run_success", {"issue_url": issue_url})

def main():
    cprint(
        "By using the Sweep CLI, you agree to the Sweep AI Terms of Service at https://sweep.dev/tos.pdf",
        style="cyan",
    )
    load_config()
    app()
if __name__ == "__main__":

# Frequently Asked Questions
<details id="does-sweep-write-tests">
<summary>Does Sweep write tests?</summary>
Yep! The easiest way to have Sweep write tests is by modifying the `description` parameter in your `sweep.yaml`. You can add something like:
“In [your repository], the tests are written in [your format]. If you modify business logic, modify the tests as well using this format.” You can add anything you’d like to the description parameter, including formatting rules (like PEP8), code style, etc!
</details>
<details id="can-we-trust-code-written-by-sweep">
<summary>Can we trust the code written by Sweep?</summary>
You should always review the PR. However, Sweep also checks that the PR works by running your existing GitHub Actions.
To get the best performance, add GitHub Actions workflows that lint, test, and validate your code.
</details>
<details id="work-off-another-branch">
<summary>Can I have Sweep work off of another branch besides main?</summary>
Yes! In the `sweep.yaml`, you can set the `branch` parameter to something besides your default branch, and Sweep will use that as a reference.
</details>
<details id="retry-issue-with-sweep">
<summary>How do I retry an issue with Sweep?</summary>
To retry an issue, prefix your issue reply with 'Sweep: '. This will trigger Sweep to retry the issue.
</details>
<details id="give-documentation-to-sweep">
<summary>Can I give documentation to Sweep?</summary>
Yes! In the `sweep.yaml`, you can specify docs. Use the URL prefix of the docs site, so that Sweep fetches only the pages you need.
Check out the example here: https://github.com/sweepai/sweep/blob/main/sweep.yaml.
</details>
<details id="comment-on-sweeps-prs">
<summary>Can I comment on Sweep’s PRs?</summary>
Yep! You have three options depending on the degree of the change:
1. You can comment on the issue, and Sweep will rewrite the entire pull request. This will use one of your GPT-4 credits.
2. You can comment on the pull request (not a file) and Sweep can make substantial changes to the pull request. Sweep will search the codebase, and is able to modify and create files.
3. You can comment on the file directly, and Sweep will only modify that file. Use this for small single file changes.
</details>
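The `branch` and `description` settings mentioned above are plain top-level keys in `sweep.yaml`, and Sweep's default config template writes them as single-quoted strings (`branch: 'main'`, `description: ''`). As a hedged illustration, the sketch below renders just those two keys in that style; `render_sweep_yaml` is a hypothetical helper, not part of Sweep.

```python
def render_sweep_yaml(branch: str, description: str) -> str:
    """Render the `branch` and `description` keys of a sweep.yaml file."""

    def quote(value: str) -> str:
        # YAML single-quoted scalars escape an embedded quote by doubling it
        return "'" + value.replace("'", "''") + "'"

    return f"branch: {quote(branch)}\ndescription: {quote(description)}\n"
```

For example, `render_sweep_yaml("dev", "Tests use pytest. If you modify business logic, update the tests.")` yields a two-line fragment that points Sweep at the `dev` branch and tells it how tests are written.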

Once Sweep has the reference implementation, it generates the corresponding test and commits it to a [GitHub PR](https://github.com/sweepai/sweep/pull/2378):
```python
def get_file_contents(self, file_path, ref=None):
    local_path = os.path.join(self.cache_dir, file_path)
    if os.path.exists(local_path):
        with open(local_path, "r", encoding="utf-8", errors="replace") as f:
            contents = f.read()
        return contents
    else:
        raise FileNotFoundError(f"{local_path} does not exist.")
```
Sweep generates a test that mocks `os.path.join` and `open`. <br></br>
This code looks great!
```python
@patch("os.path.join")
@patch("open")
def test_get_file_contents(self, mock_open, mock_join):
    mock_join.return_value = "/tmp/cache/repos/sweepai/sweep/main/file1"
    mock_open.return_value.__enter__.return_value.read.return_value = "file content"
    content = self.cloned_repo.get_file_contents("file1")
    self.assertEqual(content, "file content")
```
The mocks should return the correct path and file contents. <br></br>
OK, we're done here, right? Can we just commit these tests and leave the rest to the developer?
## 3. **Run the tests.**
Most other AI tools stop here, but it’s not enough. <br></br>
If you just committed these tests, you'd end up with a frustrating bug. Here it is:
```bash
  File "/usr/lib/python3.10/unittest/mock.py", line 1616, in _get_target
    raise TypeError(
TypeError: Need a valid target to patch. You supplied: 'open'
```
Did we really save time for the developer here? It’s frustrating that most other tools don’t fix these issues.
*Unlike every other tool, Sweep actually runs these tests.*
Sweep ran the code, found the issue, and identified the solution: <br></br>
**“Change the target of the patch in the 'test_get_file_contents' method from 'open' to 'builtins.open'. This will correctly patch the built-in 'open' function during the test.”**
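To see why that change works, here is a self-contained sketch, not the commit itself. `patch()` needs an importable dotted target, so a bare `"open"` is rejected with the `TypeError` shown above, while `"builtins.open"` names the module where the built-in lives. The sketch rewrites `get_file_contents` as a hypothetical free function and also patches `os.path.exists`, since the method checks for the cached file before opening it; `mock_open` is the standard-library helper for faking file reads.

```python
import os
from unittest.mock import mock_open, patch


def get_file_contents(cache_dir, file_path):
    # Hypothetical free-function version of the method shown above
    local_path = os.path.join(cache_dir, file_path)
    if os.path.exists(local_path):
        with open(local_path, "r", encoding="utf-8", errors="replace") as f:
            return f.read()
    raise FileNotFoundError(f"{local_path} does not exist.")


def bare_open_target_fails():
    # patch() splits the target on its last "."; "open" has no module part
    try:
        patch("open")
    except TypeError:
        return True
    return False


def run_patched_read():
    # Patch the lookup site of the built-in, plus the existence check
    with patch("os.path.exists", return_value=True), patch(
        "builtins.open", mock_open(read_data="file content")
    ) as mocked_open:
        content = get_file_contents("/tmp/cache", "file1")
        mocked_open.assert_called_once_with(
            os.path.join("/tmp/cache", "file1"), "r", encoding="utf-8", errors="replace"
        )
    return content
```

Calling `run_patched_read()` returns `"file content"` without touching the filesystem, and `bare_open_target_fails()` returns `True`, reproducing the original error in isolation.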
Sweep added [this commit](https://github.com/sweepai/sweep/pull/2378/commits/0ded79eab77ca3e511257ff0bf3874893b038e9e):
```python

import base64
import os
from dotenv import load_dotenv
from loguru import logger
logger.print = logger.info
load_dotenv(dotenv_path=".env", override=True, verbose=True)
os.environ["GITHUB_APP_PEM"] = os.environ.get("GITHUB_APP_PEM") or base64.b64decode(
    os.environ.get("GITHUB_APP_PEM_BASE64", "")
).decode("utf-8")
if os.environ["GITHUB_APP_PEM"]:
    os.environ["GITHUB_APP_ID"] = (
        (os.environ.get("GITHUB_APP_ID") or os.environ.get("APP_ID"))
        .replace("\\n", "\n")
        .strip('"')
    )
os.environ["TRANSFORMERS_CACHE"] = os.environ.get(
    "TRANSFORMERS_CACHE", "/tmp/cache/model"
)  # vector_db.py
os.environ["TIKTOKEN_CACHE_DIR"] = os.environ.get(
    "TIKTOKEN_CACHE_DIR", "/tmp/cache/tiktoken"
)  # utils.py
SENTENCE_TRANSFORMERS_MODEL = os.environ.get(
    "SENTENCE_TRANSFORMERS_MODEL",
    "sentence-transformers/all-MiniLM-L6-v2",  # "all-mpnet-base-v2"
)
TEST_BOT_NAME = "sweep-nightly[bot]"
ENV = os.environ.get("ENV", "dev")
# ENV = os.environ.get("MODAL_ENVIRONMENT", "dev")
# ENV = PREFIX
# ENVIRONMENT = PREFIX
DB_MODAL_INST_NAME = "db"
DOCS_MODAL_INST_NAME = "docs"
API_MODAL_INST_NAME = "api"
UTILS_MODAL_INST_NAME = "utils"
BOT_TOKEN_NAME = "bot-token"
# goes under Modal 'discord' secret name (optional, can leave env var blank)
DISCORD_WEBHOOK_URL = os.environ.get("DISCORD_WEBHOOK_URL")
DISCORD_MEDIUM_PRIORITY_URL = os.environ.get("DISCORD_MEDIUM_PRIORITY_URL")
DISCORD_LOW_PRIORITY_URL = os.environ.get("DISCORD_LOW_PRIORITY_URL")
DISCORD_FEEDBACK_WEBHOOK_URL = os.environ.get("DISCORD_FEEDBACK_WEBHOOK_URL")
SWEEP_HEALTH_URL = os.environ.get("SWEEP_HEALTH_URL")
DISCORD_STATUS_WEBHOOK_URL = os.environ.get("DISCORD_STATUS_WEBHOOK_URL")
# goes under Modal 'github' secret name
GITHUB_APP_ID = os.environ.get("GITHUB_APP_ID", os.environ.get("APP_ID"))
# deprecated: old logic transfer so upstream can use this
if GITHUB_APP_ID is None:
    if ENV == "prod":
        GITHUB_APP_ID = "307814"
    elif ENV == "dev":
        GITHUB_APP_ID = "324098"
    elif ENV == "staging":
        GITHUB_APP_ID = "327588"
GITHUB_BOT_USERNAME = os.environ.get("GITHUB_BOT_USERNAME")
# deprecated: left to support old logic
if not GITHUB_BOT_USERNAME:
    if ENV == "prod":
        GITHUB_BOT_USERNAME = "sweep-ai[bot]"
    elif ENV == "dev":
        GITHUB_BOT_USERNAME = "sweep-nightly[bot]"
    elif ENV == "staging":
        GITHUB_BOT_USERNAME = "sweep-canary[bot]"
elif not GITHUB_BOT_USERNAME.endswith("[bot]"):
    GITHUB_BOT_USERNAME = GITHUB_BOT_USERNAME + "[bot]"
GITHUB_LABEL_NAME = os.environ.get("GITHUB_LABEL_NAME", "sweep")
GITHUB_LABEL_COLOR = os.environ.get("GITHUB_LABEL_COLOR", "9400D3")
GITHUB_LABEL_DESCRIPTION = os.environ.get(
    "GITHUB_LABEL_DESCRIPTION", "Sweep your software chores"
)
GITHUB_APP_PEM = os.environ.get("GITHUB_APP_PEM")
GITHUB_APP_PEM = GITHUB_APP_PEM or os.environ.get("PRIVATE_KEY")
if GITHUB_APP_PEM is not None:
    GITHUB_APP_PEM = GITHUB_APP_PEM.strip(' \n"')  # Remove whitespace and quotes
    GITHUB_APP_PEM = GITHUB_APP_PEM.replace("\\n", "\n")
GITHUB_CONFIG_BRANCH = os.environ.get("GITHUB_CONFIG_BRANCH", "sweep/add-sweep-config")
GITHUB_DEFAULT_CONFIG = os.environ.get(
    "GITHUB_DEFAULT_CONFIG",
    """# Sweep AI turns bugs & feature requests into code changes (https://sweep.dev)
# For details on our config file, check out our docs at https://docs.sweep.dev/usage/config
# This setting contains a list of rules that Sweep will check for. If any of these rules are broken in a new commit, Sweep will create a pull request to fix the broken rule.
rules:
{additional_rules}
# This is the branch that Sweep will develop from and make pull requests to. Most people use 'main' or 'master' but some users also use 'dev' or 'staging'.
branch: 'main'
# By default Sweep will read the logs and outputs from your existing GitHub Actions. To disable this, set this to false.
gha_enabled: True
# This is the description of your project. It will be used by sweep when creating PRs. You can tell Sweep what's unique about your project, what frameworks you use, or anything else you want.
#
# Example:
#
# description: sweepai/sweep is a python project. The main api endpoints are in sweepai/api.py. Write code that adheres to PEP8.
description: ''
# This sets whether to create pull requests as drafts. If this is set to True, then all pull requests will be created as drafts and GitHub Actions will not be triggered.
draft: False
# This is a list of directories that Sweep will not be able to edit.
blocked_dirs: []
""",
)
MONGODB_URI = os.environ.get("MONGODB_URI", None)
IS_SELF_HOSTED = os.environ.get("IS_SELF_HOSTED", "true").lower() == "true"
REDIS_URL = os.environ.get("REDIS_URL")
if not REDIS_URL:
    REDIS_URL = os.environ.get("redis_url", "redis://0.0.0.0:6379/0")
ORG_ID = os.environ.get("ORG_ID", None)
POSTHOG_API_KEY = os.environ.get(
    "POSTHOG_API_KEY", "phc_CnzwIB0W548wN4wEGeRuxXqidOlEUH2AcyV2sKTku8n"
)
E2B_API_KEY = os.environ.get("E2B_API_KEY")
SUPPORT_COUNTRY = os.environ.get("GDRP_LIST", "").split(",")
WHITELISTED_REPOS = os.environ.get("WHITELISTED_REPOS", "").split(",")
BLACKLISTED_USERS = os.environ.get("BLACKLISTED_USERS", "").split(",")
os.environ["TOKENIZERS_PARALLELISM"] = "false"
ACTIVELOOP_TOKEN = os.environ.get("ACTIVELOOP_TOKEN", None)
VECTOR_EMBEDDING_SOURCE = os.environ.get(
    "VECTOR_EMBEDDING_SOURCE", "openai"
)  # Alternative options are huggingface or replicate; set the corresponding env vars below
BASERUN_API_KEY = os.environ.get("BASERUN_API_KEY", None)
# Huggingface settings, only checked if VECTOR_EMBEDDING_SOURCE == "huggingface"
HUGGINGFACE_URL = os.environ.get("HUGGINGFACE_URL", None)
HUGGINGFACE_TOKEN = os.environ.get("HUGGINGFACE_TOKEN", None)
# Replicate settings, only checked if VECTOR_EMBEDDING_SOURCE == "replicate"
REPLICATE_API_KEY = os.environ.get("REPLICATE_API_KEY", None)
REPLICATE_URL = os.environ.get("REPLICATE_URL", None)
REPLICATE_DEPLOYMENT_URL = os.environ.get("REPLICATE_DEPLOYMENT_URL", None)
# Default OpenAI
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", None)
OPENAI_API_TYPE = os.environ.get("OPENAI_API_TYPE", "anthropic")
assert OPENAI_API_TYPE in ["anthropic", "azure", "openai"], "Invalid OPENAI_API_TYPE"
OPENAI_EMBEDDINGS_API_TYPE = os.environ.get("OPENAI_EMBEDDINGS_API_TYPE", "openai")
AZURE_API_KEY = os.environ.get("AZURE_API_KEY", None)
OPENAI_API_BASE = os.environ.get("OPENAI_API_BASE", None)
OPENAI_API_VERSION = os.environ.get("OPENAI_API_VERSION", None)
AZURE_OPENAI_DEPLOYMENT = os.environ.get("AZURE_OPENAI_DEPLOYMENT", None)
OPENAI_EMBEDDINGS_AZURE_ENDPOINT = os.environ.get(
    "OPENAI_EMBEDDINGS_AZURE_ENDPOINT", None
)
OPENAI_EMBEDDINGS_AZURE_API_KEY = os.environ.get(
    "OPENAI_EMBEDDINGS_AZURE_API_KEY", None
)
OPENAI_EMBEDDINGS_AZURE_DEPLOYMENT = os.environ.get(
    "OPENAI_EMBEDDINGS_AZURE_DEPLOYMENT", None
)
OPENAI_EMBEDDINGS_AZURE_API_VERSION = os.environ.get(
    "OPENAI_EMBEDDINGS_AZURE_API_VERSION", None
)
OPENAI_API_ENGINE_GPT35 = os.environ.get("OPENAI_API_ENGINE_GPT35", None)
OPENAI_API_ENGINE_GPT4 = os.environ.get("OPENAI_API_ENGINE_GPT4", None)
OPENAI_API_ENGINE_GPT4_32K = os.environ.get("OPENAI_API_ENGINE_GPT4_32K", None)
MULTI_REGION_CONFIG = os.environ.get("MULTI_REGION_CONFIG", None)
if isinstance(MULTI_REGION_CONFIG, str):
    MULTI_REGION_CONFIG = MULTI_REGION_CONFIG.strip("'").replace("\\n", "\n")
    MULTI_REGION_CONFIG = [item.split(",") for item in MULTI_REGION_CONFIG.split("\n")]
WHITELISTED_USERS = os.environ.get("WHITELISTED_USERS", None)
if WHITELISTED_USERS:
    WHITELISTED_USERS = WHITELISTED_USERS.split(",")
    WHITELISTED_USERS.append(GITHUB_BOT_USERNAME)
DEFAULT_GPT4_32K_MODEL = os.environ.get("DEFAULT_GPT4_32K_MODEL", "gpt-4-0125-preview")
DEFAULT_GPT35_MODEL = os.environ.get("DEFAULT_GPT35_MODEL", "gpt-3.5-turbo-1106")
RESEND_API_KEY = os.environ.get("RESEND_API_KEY", None)
LOKI_URL = None
DEBUG = os.environ.get("DEBUG", "false").lower() == "true"
ENV = "prod" if GITHUB_BOT_USERNAME != TEST_BOT_NAME else "dev"
PROGRESS_BASE_URL = os.environ.get(
    "PROGRESS_BASE_URL", "https://progress.sweep.dev"
).rstrip("/")
DISABLED_REPOS = os.environ.get("DISABLED_REPOS", "").split(",")
GHA_AUTOFIX_ENABLED: bool = os.environ.get("GHA_AUTOFIX_ENABLED", "false").lower() == "true"
MERGE_CONFLICT_ENABLED: bool = os.environ.get("MERGE_CONFLICT_ENABLED", "false").lower() == "true"
INSTALLATION_ID = os.environ.get("INSTALLATION_ID", None)
AWS_ACCESS_KEY = os.environ.get("AWS_ACCESS_KEY")
AWS_SECRET_KEY = os.environ.get("AWS_SECRET_KEY")
AWS_REGION = os.environ.get("AWS_REGION")
ANTHROPIC_AVAILABLE = AWS_ACCESS_KEY and AWS_SECRET_KEY and AWS_REGION
USE_ASSISTANT = os.environ.get("USE_ASSISTANT", "true").lower() == "true"
ANTHROPIC_API_KEY = os.environ.get("ANTHROPIC_API_KEY", None)
VOYAGE_API_KEY = os.environ.get("VOYAGE_API_KEY", None)
VOYAGE_API_AWS_ACCESS_KEY = os.environ.get("VOYAGE_API_AWS_ACCESS_KEY_ID")
VOYAGE_API_AWS_SECRET_KEY = os.environ.get("VOYAGE_API_AWS_SECRET_KEY")
VOYAGE_API_AWS_REGION = os.environ.get("VOYAGE_API_AWS_REGION")
VOYAGE_API_AWS_ENDPOINT_NAME = os.environ.get("VOYAGE_API_AWS_ENDPOINT_NAME", "voyage-code-2")
VOYAGE_API_USE_AWS = VOYAGE_API_AWS_ACCESS_KEY and VOYAGE_API_AWS_SECRET_KEY and VOYAGE_API_AWS_REGION
PAREA_API_KEY = os.environ.get("PAREA_API_KEY", None)
# TODO: we need to make this dynamic + backoff
BATCH_SIZE = int(
    os.environ.get("BATCH_SIZE", 32 if VOYAGE_API_KEY else 256)  # Voyage only allows 128 items per batch and 120000 tokens per batch
)
DEPLOYMENT_GHA_ENABLED = os.environ.get("DEPLOYMENT_GHA_ENABLED", "true").lower() == "true"
JIRA_USER_NAME = os.environ.get("JIRA_USER_NAME", None)
JIRA_API_TOKEN = os.environ.get("JIRA_API_TOKEN", None)
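Several settings above are booleans read from environment strings. Because `bool("false")` is `True` in Python, a raw `os.environ.get(...)` value should never be used as a flag directly. A small helper can centralize the parsing; the sketch below uses the hypothetical name `env_flag`, and the set of accepted truthy spellings is an assumption rather than something `server.py` defines.

```python
import os
from typing import Mapping


def env_flag(name: str, default: bool = False, environ: Mapping[str, str] = os.environ) -> bool:
    raw = environ.get(name)
    if raw is None:
        # Unset variables fall back to the caller's default
        return default
    # Only explicit truthy spellings count; "false", "0" and "" are all False
    return raw.strip().lower() in {"1", "true", "yes", "on"}
```

With it, a line such as `DEBUG = env_flag("DEBUG")` reads the same way for every flag, and passing `environ` explicitly makes the parsing easy to test.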


Step 2: ⌨️ Coding

  • Modify sweepai/handlers/on_merge_conflict.py (commit a5da0c3)
Modify sweepai/handlers/on_merge_conflict.py with contents: In the `on_merge_conflict` function:
• Import the new `MERGE_CONFLICT_RESOLUTION_STRATEGY` constant from `sweepai/config/server.py`
• In the `try` block, after creating the new branch `new_pull_request.branch_name`, add an if statement:
  - If `MERGE_CONFLICT_RESOLUTION_STRATEGY` is `"rebase"`, call `git_repo.git.rebase("origin/" + pr.base.ref)` instead of `git_repo.git.merge("origin/" + pr.base.ref)`
  - Otherwise, keep the existing `git_repo.git.merge("origin/" + pr.base.ref)` call
Modify sweepai/config/server.py with contents:
• Add a new constant `MERGE_CONFLICT_RESOLUTION_STRATEGY` with a default value of `"merge"`
• Add a new environment variable `MERGE_CONFLICT_RESOLUTION_STRATEGY` that overrides the default value if set
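The plan above swaps `git_repo.git.merge(...)` for `git_repo.git.rebase(...)` behind a setting. A minimal sketch of that switch, assuming the proposed `MERGE_CONFLICT_RESOLUTION_STRATEGY` environment variable accepts only `"merge"` and `"rebase"` and defaults to `"merge"`. `get_resolution_strategy` and `update_from_base` are hypothetical helpers, and the `git` argument stands in for GitPython's `repo.git` command wrapper (which forwards `merge(...)` and `rebase(...)` to the git CLI), so the logic can be exercised without a repository.

```python
import os
from typing import Mapping, Protocol

# Assumed values for the proposed setting; "merge" keeps today's behavior
VALID_STRATEGIES = ("merge", "rebase")


class GitCommands(Protocol):
    # The subset of GitPython's `repo.git` wrapper used here
    def merge(self, *args: str) -> str: ...
    def rebase(self, *args: str) -> str: ...


def get_resolution_strategy(environ: Mapping[str, str] = os.environ) -> str:
    strategy = environ.get("MERGE_CONFLICT_RESOLUTION_STRATEGY", "merge").strip().lower()
    if strategy not in VALID_STRATEGIES:
        raise ValueError(
            f"MERGE_CONFLICT_RESOLUTION_STRATEGY must be one of {VALID_STRATEGIES}, got {strategy!r}"
        )
    return strategy


def update_from_base(git: GitCommands, base_ref: str, strategy: str) -> None:
    target = "origin/" + base_ref
    if strategy == "rebase":
        # Replays the branch's commits on top of the base: linear history, no merge commit
        git.rebase(target)
    else:
        # Existing behavior: records a merge commit when the histories diverged
        git.merge(target)
```

In `on_merge_conflict`, the call would look like `update_from_base(git_repo.git, pr.base.ref, get_resolution_strategy())` in place of the bare merge. Two caveats apply. A conflicted `git rebase` stops partway with the rebase still in progress, so the `git add` / `git commit` step that currently follows a failed merge would need to continue or abort the rebase instead. And rebasing a branch that has already been pushed, as in the user's scenario of updating Sweep's own PR branch, rewrites its history and therefore needs a force push such as `git push --force-with-lease`.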

Step 3: 🔁 Code Review

I have finished reviewing the code for completeness. I did not find errors for sweep/allow_for_rebase_25d21.


🎉 Latest improvements to Sweep:
  • New dashboard launched for real-time tracking of Sweep issues, covering all stages from search to coding.
  • Integration of OpenAI's latest Assistant API for more efficient and reliable code planning and editing, improving speed by 3x.
  • Use the GitHub issues extension for creating Sweep issues directly from your editor.

💡 To recreate the pull request, edit the issue title or description.
Something wrong? Let us know.

This is an automated message generated by Sweep AI.

sweep-nightly bot linked a pull request on Apr 8, 2024 that will close this issue

sweep-nightly bot commented Apr 8, 2024

🚀 Here's the PR! #3500

See Sweep's progress at the progress dashboard!
💎 Sweep Pro: I'm using GPT-4. You have unlimited GPT-4 tickets. (tracking ID: 1b2187e717)


Step 1: 🔎 Searching

I found the following snippets in your repository. I will now analyze these snippets and come up with a plan.

Some code snippets I think are relevant, in decreasing order of relevance. If some file is missing from here, you can mention the path in the ticket description.

import time
import traceback
from git import GitCommandError
from github.PullRequest import PullRequest
from loguru import logger
from sweepai.config.server import PROGRESS_BASE_URL
from sweepai.core import entities
from sweepai.core.entities import FileChangeRequest
from sweepai.core.sweep_bot import SweepBot
from sweepai.handlers.create_pr import create_pr_changes
from sweepai.handlers.on_ticket import get_branch_diff_text, sweeping_gif
from sweepai.utils.chat_logger import ChatLogger, discord_log_error
from sweepai.utils.diff import generate_diff
from sweepai.utils.event_logger import posthog
from sweepai.utils.github_utils import ClonedRepo, get_github_client
from sweepai.utils.progress import (
    PaymentContext,
    TicketContext,
    TicketProgress,
    TicketProgressStatus,
)
from sweepai.utils.prompt_constructor import HumanMessagePrompt
from sweepai.utils.str_utils import to_branch_name
from sweepai.utils.ticket_utils import center
instructions_format = """Resolve the merge conflicts in the PR by incorporating changes from both branches into the final code.
Title of PR: {title}
Here were the original changes to this file in the head branch:
Commit message: {head_commit_message}
```diff
{head_diff}
```
Here were the original changes to this file in the base branch:
Commit message: {base_commit_message}
```diff
{base_diff}
```
In the analysis_and_identification, first determine what each change does. Then determine what the final code should be. Then, use the keyword_search to find the merge conflict markers <<<<<<< and >>>>>>>. Finally, make the code changes by writing the old_code and the new_code."""
def on_merge_conflict(
    pr_number: int,
    username: str,
    repo_full_name: str,
    installation_id: int,
    tracking_id: str,
):
    # copied from stack_pr
    token, g = get_github_client(installation_id=installation_id)
    try:
        repo = g.get_repo(repo_full_name)
    except Exception as e:
        print("Exception occurred while getting repo", e)
        # Nothing below can run without the repo object, so re-raise
        raise
    pr: PullRequest = repo.get_pull(pr_number)
    branch = pr.head.ref
    status_message = center(
        f"{sweeping_gif}\n\n"
        + f'Resolving merge conflicts: track the progress <a href="{PROGRESS_BASE_URL}/issues/{tracking_id}">here</a>.'
    )
    header = f"{status_message}\n---\n\nI'm currently resolving the merge conflicts in this PR. I will stack a new PR once I'm done."
    comment = None
    for current_comment in pr.get_issue_comments():
        if (
            current_comment.user.login == "sweep-nightly[bot]"
            and "Resolving merge conflicts: track the progress" in current_comment.body
        ):
            current_comment.edit(body=header)
            comment = current_comment
            break
    # Only create a new status comment if no existing one was found
    if comment is None:
        comment = pr.create_issue_comment(body=header)

    def edit_comment(body):
        nonlocal comment
        comment.edit(header + "\n\n" + body)

    metadata = {}
    try:
        cloned_repo = ClonedRepo(
            repo_full_name=repo_full_name,
            installation_id=installation_id,
            branch=branch,
            token=token,
        )
        time.time()
        request = f"Sweep: Resolve merge conflicts for PR #{pr_number}: {pr.title}"
        title = request
        if len(title) > 50:
            title = title[:50] + "..."
        chat_logger = ChatLogger(
            data={
                "username": username,
                "metadata": metadata,
                "tracking_id": tracking_id,
            }
        )
        is_paying_user = chat_logger.is_paying_user()
        chat_logger.is_consumer_tier()
        # this logic is partly taken from on_ticket.py, if there is an issue please refer to that file
        if chat_logger:
            use_faster_model = chat_logger.use_faster_model()
        else:
            is_paying_user = True
            # Define use_faster_model on this branch too, as on_ticket.py does
            use_faster_model = False
        ticket_progress = TicketProgress(
            tracking_id=tracking_id,
            username=username,
            context=TicketContext(
                title=title,
                description="",
                repo_full_name=repo_full_name,
                branch_name="sweep/" + to_branch_name(request),
                issue_number=pr_number,
                is_public=repo.private is False,
                start_time=int(time.time()),
                # mostly copied from on_ticket, if issue please check that file
                payment_context=PaymentContext(
                    use_faster_model=use_faster_model,
                    pro_user=is_paying_user,
                    daily_tickets_used=(
                        chat_logger.get_ticket_count(use_date=True)
                        if chat_logger
                        else 0
                    ),
                    monthly_tickets_used=(
                        chat_logger.get_ticket_count() if chat_logger else 0
                    ),
                ),
            ),
        )
        metadata = {
            "tracking_id": tracking_id,
            "username": username,
            "function": "on_merge_conflict",
            **ticket_progress.context.dict(),
        }
        posthog.capture(
            username,
            "started",
            properties=metadata,
        )
        issue_url = pr.html_url
        edit_comment("Configuring branch...")
        new_pull_request = entities.PullRequest(
            title=title,
            branch_name="sweep/" + branch + "-merge-conflict",
            content="",
        )
        # Making sure name is unique
        for i in range(30):
            try:
                repo.get_branch(new_pull_request.branch_name + "_" + str(i))
            except Exception:
                new_pull_request.branch_name += "_" + str(i)
                break
        # Merge the base branch (pr.base.ref) into a new copy of the head branch in cloned_repo
        git_repo = cloned_repo.git_repo
        old_head_branch = git_repo.branches[branch]
        head_branch = git_repo.create_head(
            new_pull_request.branch_name,
            commit=old_head_branch.commit,
        )
        head_branch.checkout()
        try:
            git_repo.config_writer().set_value(
                "user", "name", "sweep-nightly[bot]"
            ).release()
            git_repo.config_writer().set_value(
                "user", "email", "[email protected]"
            ).release()
            git_repo.git.merge("origin/" + pr.base.ref)
        except GitCommandError:
            # Assume there are merge conflicts
            pass
        git_repo.git.add(update=True)
        # -m and message are needed otherwise exception is thrown
        git_repo.git.commit("-m", "Start of Merge Conflict Resolution")
        origin = git_repo.remotes.origin
        new_url = f"https://x-access-token:{token}@github.com/{repo_full_name}.git"
        origin.set_url(new_url)
        git_repo.git.push("--set-upstream", origin, new_pull_request.branch_name)
        last_commit = git_repo.head.commit
        all_files = [item.a_path for item in last_commit.diff("HEAD~1")]
        conflict_files = []
        for file in all_files:
            try:
                # Close each file after reading it
                with open(cloned_repo.repo_dir + "/" + file) as f:
                    contents = f.read()
                if "\n<<<<<<<" in contents and "\n>>>>>>>" in contents:
                    conflict_files.append(file)
            except UnicodeDecodeError:
                pass
        snippets = []
        for conflict_file in conflict_files:
            with open(cloned_repo.repo_dir + "/" + conflict_file) as f:
                contents = f.read()
            snippet = entities.Snippet(
                file_path=conflict_file,
                start=0,
                end=len(contents.splitlines()),
                content=contents,
            )
            snippets.append(snippet)
        tree = ""
        ticket_progress.status = TicketProgressStatus.PLANNING
        ticket_progress.save()
        human_message = HumanMessagePrompt(
            repo_name=repo_full_name,
            issue_url=issue_url,
            username=username,
            repo_description=(repo.description or "").strip(),
            title=request,
            summary=request,
            snippets=snippets,
            tree=tree,
        )
        sweep_bot = SweepBot.from_system_message_content(
            human_message=human_message,
            repo=repo,
            ticket_progress=ticket_progress,
            chat_logger=chat_logger,
            cloned_repo=cloned_repo,
            branch=new_pull_request.branch_name,
        )
        # can select more precise snippets
        file_change_requests = []
        base_commits = pr.base.repo.get_commits().get_page(0)
        head_commits = list(pr.get_commits())
        for conflict_file in conflict_files:
            old_code = repo.get_contents(
                conflict_file, ref=head_commits[0].parents[0].sha
            ).decoded_content.decode()
            base_code = repo.get_contents(
                conflict_file, ref=pr.base.ref
            ).decoded_content.decode()
            head_code = repo.get_contents(
                conflict_file, ref=pr.head.ref
            ).decoded_content.decode()
            base_diff = generate_diff(old_code=old_code, new_code=base_code)
            head_diff = generate_diff(old_code=old_code, new_code=head_code)
            base_commit_message = ""
            for commit in base_commits[::-1]:
                if any(
                    commit_file.filename == conflict_file
                    for commit_file in commit.files
                ):
                    base_commit_message = commit.raw_data["commit"]["message"]
                    break
            head_commit_message = ""
            for commit in head_commits[::-1]:
                if any(
                    commit_file.filename == conflict_file
                    for commit_file in commit.files
                ):
                    head_commit_message = commit.raw_data["commit"]["message"]
                    break
            file_change_requests.append(
                FileChangeRequest(
                    filename=conflict_file,
                    instructions=instructions_format.format(
                        title=pr.title,
                        base_commit_message=base_commit_message,
                        base_diff=base_diff,
                        head_commit_message=head_commit_message,
                        head_diff=head_diff,
                    ),
                    change_type="modify",
                )
            )
        ticket_progress.status = TicketProgressStatus.CODING
        ticket_progress.save()
        edit_comment("Resolving merge conflicts...")
        generator = create_pr_changes(
            file_change_requests,
            new_pull_request,
            sweep_bot,
            username,
            installation_id,
            pr_number,
            chat_logger=chat_logger,
            base_branch=new_pull_request.branch_name,
        )
        for item in generator:
            if isinstance(item, dict):
                break
            (
                file_change_request,
                changed_file,
                sandbox_response,
                commit,
                file_change_requests,
            ) = item
            # Interpolate the value; extra positional args are ignored by this message
            logger.info(f"Status: {file_change_request.status == 'succeeded'}")
        ticket_progress.status = TicketProgressStatus.COMPLETE
        ticket_progress.save()
        edit_comment("Done creating pull request.")
        get_branch_diff_text(repo, new_pull_request.branch_name)
        new_description = f"This PR resolves the merge conflicts in #{pr_number}. This branch can be directly merged into {pr.base.ref}.\n\nFixes #{pr_number}."
        # Create pull request
        new_pull_request.content = new_description
        github_pull_request = repo.create_pull(
            title=request,
            body=new_description,
            head=new_pull_request.branch_name,
            base=pr.base.ref,
        )
        ticket_progress.context.pr_id = github_pull_request.number
        ticket_progress.context.done_time = time.time()
        ticket_progress.save()
        edit_comment(f"✨ **Created Pull Request:** {github_pull_request.html_url}")
        posthog.capture(
            username,
            "success",
            properties=metadata,
        )
        return {"success": True}
    except Exception as e:
        print(f"Exception occurred: {e}")
        edit_comment(
            f"> [!CAUTION]\n> \nAn error has occurred: {str(e)} (tracking ID: {tracking_id})"
        )
        discord_log_error(
            "Error occurred in on_merge_conflict.py\n"
            + traceback.format_exc()
            + "\n\n"
            + str(e)
            + "\n\n"
            + f"tracking ID: {tracking_id}"
        )
        posthog.capture(
            username,
            "failed",
            properties=metadata,
        )
        return {"success": False}

if __name__ == "__main__":
    on_merge_conflict(
        pr_number=68,
        username="MartinYe1234",
        repo_full_name="MartinYe1234/Chess-Game",
        installation_id=45945746,
        tracking_id="ADD-BOB-2",
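`on_merge_conflict` flags a file as conflicted when its contents contain both `"\n<<<<<<<"` and `"\n>>>>>>>"`. That substring test misses a conflict hunk that starts on the very first line of a file, because no newline precedes the marker. A line-based check avoids that edge case; the sketch below uses the hypothetical name `has_conflict_markers` and is an illustration, not Sweep's code.

```python
def has_conflict_markers(contents: str) -> bool:
    # A conflict hunk has an opening "<<<<<<<" line and a closing ">>>>>>>" line
    lines = contents.splitlines()
    has_start = any(line.startswith("<<<<<<<") for line in lines)
    has_end = any(line.startswith(">>>>>>>") for line in lines)
    return has_start and has_end
```

Checking line starts rather than substrings also avoids false positives from marker-like text embedded mid-line, such as inside a string literal.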

"""
This file contains the on_merge handler which is called when a pull request is merged to master.
on_merge is called by sweepai/api.py
"""
import time
from sweepai.config.client import SweepConfig, get_blocked_dirs, get_rules
from sweepai.core.post_merge import PostMerge
from sweepai.handlers.pr_utils import make_pr
from loguru import logger
from sweepai.utils.chat_logger import ChatLogger
from sweepai.utils.event_logger import posthog
from sweepai.utils.github_utils import get_github_client
# change threshold for number of lines changed
CHANGE_BOUNDS = (10, 1500)
# dictionary to map from github repo to the last time a rule was activated
merge_rule_debounce = {}
# debounce time in seconds
DEBOUNCE_TIME = 120
diff_section_prompt = """
<file_diff file="{diff_file_path}">
{diffs}
</file_diff>"""
def comparison_to_diff(comparison, blocked_dirs):
    pr_diffs = []
    for file in comparison.files:
        diff = file.patch
        if (
            file.status == "added"
            or file.status == "modified"
            or file.status == "removed"
        ):
            if any(file.filename.startswith(dir) for dir in blocked_dirs):
                continue
            pr_diffs.append((file.filename, diff))
        else:
            logger.info(
                f"File status {file.status} not recognized"
            )  # TODO(sweep): We don't handle renamed files
    formatted_diffs = []
    for file_name, file_patch in pr_diffs:
        format_diff = diff_section_prompt.format(
            diff_file_path=file_name, diffs=file_patch
        )
        formatted_diffs.append(format_diff)
    return "\n".join(formatted_diffs)

def on_merge(request_dict: dict, chat_logger: ChatLogger):
before_sha = request_dict["before"]
after_sha = request_dict["after"]
commit_author = request_dict["sender"]["login"]
ref = request_dict["ref"]
if not ref.startswith("refs/heads/"):
return
user_token, g = get_github_client(request_dict["installation"]["id"])
repo = g.get_repo(
request_dict["repository"]["full_name"]
) # do this after checking ref
if ref[len("refs/heads/") :] != SweepConfig.get_branch(repo):
return
commit = repo.get_commit(after_sha)
check_suites = commit.get_check_suites()
for check_suite in check_suites:
if check_suite.conclusion == "failure":
return # if any check suite failed, return
blocked_dirs = get_blocked_dirs(repo)
comparison = repo.compare(before_sha, after_sha)
commits_diff = comparison_to_diff(comparison, blocked_dirs)
# check if the current repo is in the merge_rule_debounce dictionary
# and if the difference between the current time and the time stored in the dictionary is less than DEBOUNCE_TIME seconds
if (
repo.full_name in merge_rule_debounce
and time.time() - merge_rule_debounce[repo.full_name] < DEBOUNCE_TIME
):
return
merge_rule_debounce[repo.full_name] = time.time()
if not CHANGE_BOUNDS[0] <= commits_diff.count("\n") <= CHANGE_BOUNDS[1]:
return
rules = get_rules(repo)
rules = [rule for rule in rules if len(rule) > 0]
if not rules:
return
for rule in rules:
chat_logger.data["title"] = f"Sweep Rules - {rule}"
changes_required, issue_title, issue_description = PostMerge(
chat_logger=chat_logger
).check_for_issues(rule=rule, diff=commits_diff)
if changes_required:
make_pr(
title="[Sweep Rules] " + issue_title,
repo_description=repo.description,
summary=issue_description,
repo_full_name=request_dict["repository"]["full_name"],
installation_id=request_dict["installation"]["id"],
user_token=user_token,
use_faster_model=chat_logger.use_faster_model(),
username=commit_author,
chat_logger=chat_logger,
rule=rule,
)
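The gating in `on_merge` can be isolated as a pure function. It applies a per-repository debounce first, then a line-count window on the diff. The sketch below assumes the caller owns the timestamp dictionary. `should_check_rules` is a hypothetical helper, and the constants copy `CHANGE_BOUNDS` and `DEBOUNCE_TIME` from above.

```python
import time
from typing import Optional

CHANGE_BOUNDS = (10, 1500)
DEBOUNCE_TIME = 120

def should_check_rules(
    repo_name: str, diff: str, last_run: dict[str, float], now: Optional[float] = None
) -> bool:
    now = time.time() if now is None else now
    # Debounce: skip a repo that triggered a check less than DEBOUNCE_TIME seconds ago.
    if repo_name in last_run and now - last_run[repo_name] < DEBOUNCE_TIME:
        return False
    # The timestamp is recorded before the size check, matching on_merge.
    last_run[repo_name] = now
    line_count = diff.count("\n")
    return CHANGE_BOUNDS[0] <= line_count <= CHANGE_BOUNDS[1]
```

Passing `now` explicitly keeps the function deterministic under test. As in the original, a diff outside the bounds still refreshes the debounce timestamp.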

"""
create_pr is a function that creates a pull request from a list of file change requests.
It is also responsible for handling Sweep config PR creation.
"""
import datetime
from typing import Any, Generator
import openai
from github.Repository import Repository
from loguru import logger
from sweepai.config.client import DEFAULT_RULES_STRING, SweepConfig, get_blocked_dirs
from sweepai.config.server import (
ENV,
GITHUB_BOT_USERNAME,
GITHUB_CONFIG_BRANCH,
GITHUB_DEFAULT_CONFIG,
GITHUB_LABEL_NAME,
MONGODB_URI,
)
from sweepai.core.entities import (
FileChangeRequest,
MaxTokensExceeded,
Message,
MockPR,
PullRequest,
)
from sweepai.core.sweep_bot import SweepBot
from sweepai.utils.chat_logger import ChatLogger
from sweepai.utils.event_logger import posthog
from sweepai.utils.github_utils import ClonedRepo, get_github_client
from sweepai.utils.str_utils import UPDATES_MESSAGE
num_of_snippets_to_query = 10
max_num_of_snippets = 5
INSTRUCTIONS_FOR_REVIEW = """\
### 💡 To get Sweep to edit this pull request, you can:
* Comment below, and Sweep can edit the entire PR
* Comment on a file, Sweep will only modify the commented file
* Edit the original issue to get Sweep to recreate the PR from scratch"""
def create_pr_changes(
file_change_requests: list[FileChangeRequest],
pull_request: PullRequest,
sweep_bot: SweepBot,
username: str,
installation_id: int,
issue_number: int | None = None,
chat_logger: ChatLogger | None = None,
base_branch: str | None = None,
additional_messages: list[Message] = []
) -> Generator[tuple[FileChangeRequest, int, Any], None, dict]:
# Flow:
# 1. Get relevant files
# 2. Get human message
# 3. Get files to change
# 4. Get file changes
# 5. Create PR
chat_logger = (
chat_logger
if chat_logger is not None
else ChatLogger(
{
"username": username,
"installation_id": installation_id,
"repo_full_name": sweep_bot.repo.full_name,
"title": pull_request.title,
"summary": "",
"issue_url": "",
}
)
if MONGODB_URI
else None
)
sweep_bot.chat_logger = chat_logger
organization, repo_name = sweep_bot.repo.full_name.split("/")
metadata = {
"repo_full_name": sweep_bot.repo.full_name,
"organization": organization,
"repo_name": repo_name,
"repo_description": sweep_bot.repo.description,
"username": username,
"installation_id": installation_id,
"function": "create_pr",
"mode": ENV,
"issue_number": issue_number,
}
posthog.capture(username, "started", properties=metadata)
try:
logger.info("Making PR...")
pull_request.branch_name = sweep_bot.create_branch(
pull_request.branch_name, base_branch=base_branch
)
completed_count, fcr_count = 0, len(file_change_requests)
blocked_dirs = get_blocked_dirs(sweep_bot.repo)
for (
new_file_contents,
changed_file,
commit,
file_change_requests,
) in sweep_bot.change_files_in_github_iterator(
file_change_requests,
pull_request.branch_name,
blocked_dirs,
additional_messages=additional_messages
):
completed_count += len(new_file_contents or [])
logger.info(f"Completed {completed_count}/{fcr_count} files")
yield new_file_contents, changed_file, commit, file_change_requests
if completed_count == 0 and fcr_count != 0:
logger.info("No changes made")
posthog.capture(
username,
"failed",
properties={
"error": "No changes made",
"reason": "No changes made",
**metadata,
},
)
# If no changes were made, delete branch
commits = sweep_bot.repo.get_commits(pull_request.branch_name)
if commits.totalCount == 0:
branch = sweep_bot.repo.get_git_ref(f"heads/{pull_request.branch_name}")
branch.delete()
return
# Include issue number in PR description
if issue_number:
# If the #issue changes, then change on_ticket (f'Fixes #{issue_number}.\n' in pr.body:)
pr_description = (
f"{pull_request.content}\n\nFixes"
f" #{issue_number}.\n\n---\n\n{UPDATES_MESSAGE}\n\n---\n\n{INSTRUCTIONS_FOR_REVIEW}"
)
else:
pr_description = f"{pull_request.content}"
pr_title = pull_request.title
if "sweep.yaml" in pr_title:
pr_title = "[config] " + pr_title
except MaxTokensExceeded as e:
logger.error(e)
posthog.capture(
username,
"failed",
properties={
"error": str(e),
"reason": "Max tokens exceeded",
**metadata,
},
)
raise e
except openai.BadRequestError as e:
logger.error(e)
posthog.capture(
username,
"failed",
properties={
"error": str(e),
"reason": "Invalid request error / context length",
**metadata,
},
)
raise e
except Exception as e:
logger.error(e)
posthog.capture(
username,
"failed",
properties={
"error": str(e),
"reason": "Unexpected error",
**metadata,
},
)
raise e
posthog.capture(username, "success", properties={**metadata})
logger.info("create_pr success")
result = {
"success": True,
"pull_request": MockPR(
file_count=completed_count,
title=pr_title,
body=pr_description,
pr_head=pull_request.branch_name,
base=sweep_bot.repo.get_branch(
SweepConfig.get_branch(sweep_bot.repo)
).commit,
head=sweep_bot.repo.get_branch(pull_request.branch_name).commit,
),
}
yield result # TODO: refactor this as it doesn't need to be an iterator
return
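`create_pr_changes` is annotated `Generator[..., None, dict]`, but it delivers its summary by yielding it last (see the TODO). The sketch below shows the alternative the annotation suggests: return the summary and read it from `StopIteration.value`. `make_changes` and `drain` are hypothetical names, not part of Sweep.

```python
from typing import Generator

def make_changes(paths: list[str]) -> Generator[str, None, dict]:
    completed = 0
    for path in paths:
        completed += 1
        yield path  # one progress update per changed file
    # The summary travels in the generator's return slot, not as a last yield.
    return {"success": True, "file_count": completed}

def drain(gen: Generator[str, None, dict]) -> tuple[list[str], dict]:
    progress = []
    while True:
        try:
            progress.append(next(gen))
        except StopIteration as stop:
            # StopIteration.value carries the generator's return value.
            return progress, stop.value
```

Inside another generator, `summary = yield from make_changes(paths)` captures the same value while forwarding every progress item.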
def safe_delete_sweep_branch(
pr, # Github PullRequest
repo: Repository,
) -> bool:
"""
Safely delete Sweep branch
1. Only edited by Sweep
2. Prefixed by sweep/
"""
pr_commits = pr.get_commits()
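# commit.author is None when the commit email isn't linked to a GitHub account;
# keep None in the set so such commits block deletion instead of raising.
pr_commit_authors = {commit.author.login if commit.author else None for commit in pr_commits}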
# Check if only Sweep has edited the PR, and sweep/ prefix
if (
len(pr_commit_authors) == 1
and GITHUB_BOT_USERNAME in pr_commit_authors
and pr.head.ref.startswith("sweep/")
):
branch = repo.get_git_ref(f"heads/{pr.head.ref}")
# pr.edit(state='closed')
branch.delete()
return True
else:
# Failed to delete branch as it was edited by someone else
return False
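The deletion guard in `safe_delete_sweep_branch` reduces to a predicate over commit authors and the head ref. This sketch assumes that `None` represents a commit with no linked GitHub account, so such commits block deletion. `can_delete_branch` and the bot name in the test are hypothetical.

```python
from typing import Iterable, Optional

def can_delete_branch(
    commit_authors: Iterable[Optional[str]],
    bot_username: str,
    head_ref: str,
    prefix: str = "sweep/",
) -> bool:
    # None stands for a commit whose email isn't linked to a GitHub account.
    authors = set(commit_authors)
    # Delete only if the bot is the sole author and the branch carries the prefix.
    return authors == {bot_username} and head_ref.startswith(prefix)
```

An empty author set is rejected as well, which matches the original `len(pr_commit_authors) == 1` check.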
def create_config_pr(
sweep_bot: SweepBot | None, repo: Repository = None, cloned_repo: ClonedRepo = None
):
if repo is not None:
# Check if file exists in repo
try:
repo.get_contents("sweep.yaml")
return
except SystemExit:
raise SystemExit
except Exception:
pass
title = "Configure Sweep"
branch_name = GITHUB_CONFIG_BRANCH
if sweep_bot is not None:
branch_name = sweep_bot.create_branch(branch_name, retry=False)
try:
# commit_history = []
# if cloned_repo is not None:
# commit_history = cloned_repo.get_commit_history(
# limit=1000, time_limited=False
# )
# commit_string = "\n".join(commit_history)
# sweep_yaml_bot = SweepYamlBot()
# generated_rules = sweep_yaml_bot.get_sweep_yaml_rules(
# commit_history=commit_string
# )
sweep_bot.repo.create_file(
"sweep.yaml",
"Create sweep.yaml",
GITHUB_DEFAULT_CONFIG.format(
branch=sweep_bot.repo.default_branch,
additional_rules=DEFAULT_RULES_STRING,
),
branch=branch_name,
)
sweep_bot.repo.create_file(
".github/ISSUE_TEMPLATE/sweep-template.yml",
"Create sweep template",
SWEEP_TEMPLATE,
branch=branch_name,
)
except SystemExit:
raise SystemExit
except Exception as e:
logger.error(e)
else:
# Create branch based on default branch
repo.create_git_ref(
ref=f"refs/heads/{branch_name}",
sha=repo.get_branch(repo.default_branch).commit.sha,
)
try:
# commit_history = []
# if cloned_repo is not None:
# commit_history = cloned_repo.get_commit_history(
# limit=1000, time_limited=False
# )
# commit_string = "\n".join(commit_history)
# sweep_yaml_bot = SweepYamlBot()
# generated_rules = sweep_yaml_bot.get_sweep_yaml_rules(
# commit_history=commit_string
# )
repo.create_file(
"sweep.yaml",
"Create sweep.yaml",
GITHUB_DEFAULT_CONFIG.format(
branch=repo.default_branch, additional_rules=DEFAULT_RULES_STRING
),
branch=branch_name,
)
repo.create_file(
".github/ISSUE_TEMPLATE/sweep-template.yml",
"Create sweep template",
SWEEP_TEMPLATE,
branch=branch_name,
)
except SystemExit:
raise SystemExit
except Exception as e:
logger.error(e)
repo = sweep_bot.repo if sweep_bot is not None else repo
# Check if the pull request from this branch to main already exists.
# If it does, then we don't need to create a new one.
if repo is not None:
pull_requests = repo.get_pulls(
state="open",
sort="created",
base=SweepConfig.get_branch(repo)
if sweep_bot is not None
else repo.default_branch,
head=branch_name,
)
for pr in pull_requests:
if pr.title == title:
return pr
logger.info(f"Default branch: {repo.default_branch}")
logger.info(f"New branch: {branch_name}")
pr = repo.create_pull(
title=title,
body="""🎉 Thank you for installing Sweep! We're thrilled to announce the latest update for Sweep, your AI junior developer on GitHub. This PR creates a `sweep.yaml` config file, allowing you to personalize Sweep's performance according to your project requirements.
## What's new?
- **Sweep is now configurable**.
- To configure Sweep, simply edit the `sweep.yaml` file in the root of your repository.
- If you need help, check out the [Sweep Default Config](https://github.com/sweepai/sweep/blob/main/sweep.yaml) or [Join Our Discord](https://discord.gg/sweep) for help.
If you would like me to stop creating this PR, go to issues and say "Sweep: create an empty `sweep.yaml` file".
Thank you for using Sweep! 🧹""".replace(
"    ", ""  # strip source indentation; replacing a single space would delete every space
),
head=branch_name,
base=SweepConfig.get_branch(repo)
if sweep_bot is not None
else repo.default_branch,
)
pr.add_to_labels(GITHUB_LABEL_NAME)
return pr
def add_config_to_top_repos(installation_id, username, repositories, max_repos=3):
user_token, g = get_github_client(installation_id)
repo_activity = {}
for repo_entity in repositories:
repo = g.get_repo(repo_entity.full_name)
# instead of using total count, use the date of the latest commit
commits = repo.get_commits(
author=username,
since=datetime.datetime.now() - datetime.timedelta(days=30),
)
# get latest commit date
commit_date = datetime.datetime.now() - datetime.timedelta(days=30)
for commit in commits:
if commit.commit.author.date > commit_date:
commit_date = commit.commit.author.date
# since_date = datetime.datetime.now() - datetime.timedelta(days=30)
# commits = repo.get_commits(since=since_date, author="lukejagg")
repo_activity[repo] = commit_date
# print(repo, commits.totalCount)
logger.info(f"{repo.full_name}: latest commit {commit_date}")
sorted_repos = sorted(repo_activity, key=repo_activity.get, reverse=True)
sorted_repos = sorted_repos[:max_repos]
# For each repo, create a branch based on main branch, then create PR to main branch
for repo in sorted_repos:
try:
logger.info(f"Creating config for {repo.full_name}")
create_config_pr(
None,
repo=repo,
cloned_repo=ClonedRepo(
repo_full_name=repo.full_name,
installation_id=installation_id,
token=user_token,
),
)
except SystemExit:
raise SystemExit
except Exception as e:
logger.error(e)
logger.info("Finished creating configs for top repos")
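`add_config_to_top_repos` ranks repositories by their most recent commit within a 30-day window. The sketch below is a standalone version of that ranking. It assumes commit dates have already been fetched into a dict, and the hypothetical `top_recent_repos` makes no API calls.

```python
import datetime

def top_recent_repos(
    commit_dates: dict[str, list[datetime.datetime]],
    now: datetime.datetime,
    max_repos: int = 3,
    window_days: int = 30,
) -> list[str]:
    floor = now - datetime.timedelta(days=window_days)
    latest = {}
    for repo, dates in commit_dates.items():
        # A repo with no commits in the window ranks at the window start.
        latest[repo] = max((d for d in dates if d > floor), default=floor)
    # Most recently active first; sorted() keeps input order on ties.
    return sorted(latest, key=latest.get, reverse=True)[:max_repos]
```

All datetimes must be either consistently timezone-aware or consistently naive, because comparing the two raises `TypeError`. Depending on the PyGithub version, commit dates may be timezone-aware. That is worth checking against the naive `datetime.datetime.now()` used above.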
def create_gha_pr(g, repo):
# Create a new branch
branch_name = "sweep/gha-enable"
repo.create_git_ref(
ref=f"refs/heads/{branch_name}",
sha=repo.get_branch(repo.default_branch).commit.sha,
)
# Update the sweep.yaml file in this branch to add "gha_enabled: True"
sweep_yaml_content = (
repo.get_contents("sweep.yaml", ref=branch_name).decoded_content.decode()
+ "\ngha_enabled: True"
)
repo.update_file(
"sweep.yaml",
"Enable GitHub Actions",
sweep_yaml_content,
repo.get_contents("sweep.yaml", ref=branch_name).sha,
branch=branch_name,
)
# Create a PR from this branch to the main branch
pr = repo.create_pull(
title="Enable GitHub Actions",
body="This PR enables GitHub Actions for this repository.",
head=branch_name,
base=repo.default_branch,
)
return pr
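`create_gha_pr` appends `gha_enabled: True` unconditionally. If `sweep.yaml` already sets the key, this produces a duplicate, and YAML loaders differ on duplicates: some take the last value, while others reject the file. The string-level sketch below updates the key in place instead, assuming it is a top-level key. `enable_gha` is hypothetical and is not a YAML parser.

```python
import re

# Matches a top-level gha_enabled line (column 0), whatever its current value.
GHA_LINE = re.compile(r"^gha_enabled:.*$", re.MULTILINE)

def enable_gha(sweep_yaml: str) -> str:
    if GHA_LINE.search(sweep_yaml):
        # Rewrite the existing key rather than appending a duplicate.
        return GHA_LINE.sub("gha_enabled: True", sweep_yaml)
    return sweep_yaml.rstrip("\n") + "\ngha_enabled: True\n"
```

Because the function is idempotent, rerunning the PR flow on an already-enabled repository leaves the file unchanged.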
SWEEP_TEMPLATE = """\
name: Sweep Issue
title: 'Sweep: '
description: For small bugs, features, refactors, and tests to be handled by Sweep, an AI-powered junior developer.
labels: sweep
body:
- type: textarea
id: description
attributes:
label: Details
description: Tell Sweep where and what to edit, and provide enough context for a developer new to the codebase.
placeholder: |
Unit Tests: Write unit tests for <FILE>. Test each function in the file. Make sure to test edge cases.
Bugs: The bug might be in <FILE>. Here are the logs: ...
Features: the new endpoint should use the ... class from <FILE> because it contains ... logic.
Refactors: We are migrating this function to ... version because ...
- type: input
id: branch
attributes:
label: Branch
description: The branch to work off of (optional)
placeholder: |

import copy
import re
import traceback
from pathlib import Path
from loguru import logger
from sweepai.agents.assistant_wrapper import (
client,
openai_assistant_call,
run_until_complete,
)
from sweepai.core.entities import AssistantRaisedException, FileChangeRequest, Message
from sweepai.logn.cache import file_cache
from sweepai.utils.chat_logger import ChatLogger, discord_log_error
from sweepai.utils.progress import AssistantConversation, TicketProgress
system_message = r""" You are searching through a codebase to guide a junior developer on how to solve the user request. The junior developer will follow your instructions exactly and make the changes.
# User Request
{user_request}
# Guide
## Step 1: Unzip the file into /mnt/data/repo. Then list all root level directories. You must copy the below code verbatim into the file.
```python
import zipfile
import os
zip_path = '{file_path}'
extract_to_path = '/mnt/data/repo'
os.makedirs(extract_to_path, exist_ok=True)
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
zip_ref.extractall(extract_to_path)
zip_contents = zip_ref.namelist()
root_dirs = {{name.split('/')[0] for name in zip_contents}}
print(f'Root directories: {{root_dirs}}')
```
## Step 2: Find the relevant files.
You can search by file name or by keyword search in the contents.
## Step 3: Find relevant lines.
1. Locate the lines of code that contain the identified keywords or are at the specified line number. You can use keyword search or manually look through the file 100 lines at a time.
2. Check the surrounding lines to establish the full context of the code block.
3. Adjust the starting line to include the entire functionality that needs to be refactored or moved.
4. Finally determine the exact line spans that include a logical and complete section of code to be edited.
```python
def print_lines_with_keyword(content, keywords):
max_matches=5
context = 10
matches = [i for i, line in enumerate(content.splitlines()) if any(keyword in line.lower() for keyword in keywords)]
print(f"Found {{len(matches)}} matches, but capping at {{max_matches}}")
matches = matches[:max_matches]
expanded_matches = set()
for match in matches:
start = max(0, match - context)
end = min(len(content.splitlines()), match + context + 1)
for i in range(start, end):
expanded_matches.add(i)
for i in sorted(expanded_matches):
print(f"{{i}}: {{content.splitlines()[i]}}")
```
## Step 4: Construct a plan.
Provide the final plan to solve the issue, following these rules:
* DO NOT apply any changes here, they will not be persisted. You must provide the plan and the developer will apply the changes.
* You may only create new files and modify existing files.
* File paths should be relative paths from the root of the repo.
* Use the minimum number of create and modify operations required to solve the issue.
* Start and end lines indicate the exact start and end lines to edit. Expand this to encompass more lines if you're unsure where to make the exact edit.
Respond in the following format:
```xml
<plan>
<create_file file="file_path_1">
* Natural language instructions for creating the new file needed to solve the issue.
* Reference necessary files, imports and entity names.
...
</create_file>
...
<modify_file file="file_path_2" start_line="i" end_line="j">
* Natural language instructions for the modifications needed to solve the issue.
* Be concise and reference necessary files, imports and entity names.
...
</modify_file>
...
</plan>
```"""
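The `<plan>` format requested above can be parsed with a regular expression. The pattern below is a hypothetical sketch written against the format in the prompt. It is not `FileChangeRequest._regex`, whose exact definition is not shown here. It applies the same normalizations as `new_planning`: `create_file`/`modify_file` become `create`/`modify`, and a leading `/` is stripped from the path.

```python
import re

PLAN_ITEM = re.compile(
    r'<(?P<change_type>create_file|modify_file)\s+file="(?P<filename>[^"]+)"'
    r'(?:\s+start_line="(?P<start_line>\d+)"\s+end_line="(?P<end_line>\d+)")?\s*>'
    r"(?P<instructions>.*?)</(?P=change_type)>",  # closing tag must match the opening one
    re.DOTALL,
)

def parse_plan(text: str) -> list[dict]:
    items = []
    for m in PLAN_ITEM.finditer(text):
        items.append({
            "change_type": m["change_type"].removesuffix("_file"),
            "filename": m["filename"].lstrip("/"),
            # Line spans are only present on modify_file entries.
            "start_line": int(m["start_line"]) if m["start_line"] else None,
            "end_line": int(m["end_line"]) if m["end_line"] else None,
            "instructions": m["instructions"].strip(),
        })
    return items
```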
@file_cache(ignore_params=["zip_path", "chat_logger", "ticket_progress"])
def new_planning(
request: str,
zip_path: str,
additional_messages: list[Message] = [],
chat_logger: ChatLogger | None = None,
assistant_id: str = None,
ticket_progress: TicketProgress | None = None,
) -> list[FileChangeRequest]:
planning_iterations = 3
try:
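def save_ticket_progress(assistant_id: str, thread_id: str, run_id: str):
# The planning loop calls this unconditionally, so tolerate a missing tracker.
if ticket_progress is None:
return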
assistant_conversation = AssistantConversation.from_ids(
assistant_id=assistant_id, run_id=run_id, thread_id=thread_id
)
if not assistant_conversation:
return
ticket_progress.planning_progress.assistant_conversation = (
assistant_conversation
)
ticket_progress.save()
logger.info("Uploading file...")
zip_file_object = client.files.create(file=Path(zip_path), purpose="assistants")
logger.info("Done uploading file.")
zip_file_id = zip_file_object.id
response = openai_assistant_call(
request=request,
assistant_id=assistant_id,
additional_messages=additional_messages,
uploaded_file_ids=[zip_file_id],
chat_logger=chat_logger,
save_ticket_progress=save_ticket_progress
if ticket_progress is not None
else None,
instructions=system_message.format(
user_request=request, file_path=f"mnt/data/{zip_file_id}"
),
)
run_id = response.run_id
thread_id = response.thread_id
messages = response.messages
for _ in range(planning_iterations):
save_ticket_progress(
assistant_id=response.assistant_id,
thread_id=thread_id,
run_id=run_id,
)
# Read the latest reply: after a retry, `messages` holds the new run's output.
final_message = messages.data[0].content[0].text.value
fcrs = []
fcr_matches = list(
re.finditer(FileChangeRequest._regex, final_message, re.DOTALL)
)
if len(fcr_matches) > 0:
break
else:
client.beta.threads.messages.create(
thread_id=thread_id,
role="user",
content="A valid plan (within the <plan> tags) was not provided. Please continue working on the plan. If you are stuck, consider starting over.",
)
run = client.beta.threads.runs.create(
thread_id=response.thread_id,
assistant_id=response.assistant_id,
instructions=system_message.format(
user_request=request, file_path=f"mnt/data/{zip_file_id}"
),
)
run_id = run.id
messages = run_until_complete(
thread_id=thread_id,
run_id=run_id,
assistant_id=response.assistant_id,
)
for match_ in fcr_matches:
group_dict = match_.groupdict()
if group_dict["change_type"] == "create_file":
group_dict["change_type"] = "create"
if group_dict["change_type"] == "modify_file":
group_dict["change_type"] = "modify"
fcr = FileChangeRequest(**group_dict)
fcr.filename = fcr.filename.lstrip("/")
fcr.instructions = fcr.instructions.replace("\n*", "\n•")
fcr.instructions = fcr.instructions.strip("\n")
if fcr.instructions.startswith("*"):
fcr.instructions = "•" + fcr.instructions[1:]
fcrs.append(fcr)
new_file_change_request = copy.deepcopy(fcr)
new_file_change_request.change_type = "check"
new_file_change_request.parent = fcr
fcrs.append(new_file_change_request)
assert len(fcrs) > 0
return fcrs
except AssistantRaisedException as e:
raise e
except Exception as e:
logger.exception(e)
if chat_logger is not None:
discord_log_error(
str(e)
+ "\n\n"
+ traceback.format_exc()
+ "\n\n"
+ str(chat_logger.data)
)
return None
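The retry loop in `new_planning` asks the assistant again when no plan is found. The key detail is that each iteration must inspect the newest reply. The sketch below abstracts the assistant call as a hypothetical `ask` callback and uses a simplified `<plan>` check in place of `FileChangeRequest._regex`.

```python
import re
from typing import Callable, Optional

PLAN_RE = re.compile(r"<plan>.*?</plan>", re.DOTALL)
NUDGE = "A valid plan (within the <plan> tags) was not provided. Please continue working on the plan."

def plan_with_retries(
    first_reply: str, ask: Callable[[str], str], iterations: int = 3
) -> Optional[str]:
    reply = first_reply
    for attempt in range(iterations):
        # Always inspect the newest reply, never the first one again.
        if PLAN_RE.search(reply):
            return reply
        # Nudge only if another check will follow.
        if attempt < iterations - 1:
            reply = ask(NUDGE)
    return None
```

Unlike the original loop, this version does not send a final nudge whose reply would never be checked.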
if __name__ == "__main__":
request = """## Title: replace the broken tutorial link in installation.md with https://docs.sweep.dev/usage/tutorial\n"""
additional_messages = [
Message(
role="user",
content='<relevant_snippets_in_repo>\n<snippet source="docs/pages/usage/tutorial.mdx:45-60">\n...\n45: Now to be a Sweep power user, check out [Advanced: becoming a Sweep power user](https://docs.sweep.dev/usage/advanced).\n</snippet>\n<snippet source="docs/pages/usage/tutorial.mdx:30-45">\n...\n30: \n31: ![PR Comment](/tutorial/comment.png)\n32: \n33: c. If you have GitHub Actions set up, it will automatically run the linters, build, and tests and will show any failed logs to Sweep to handle. This only works with GitHub Actions and not other CI providers, so unfortunately for Vercel we have to copy paste manually.\n34: \n35: ![GitHub Actions](/tutorial/github_actions.png)\n36: \n37: 6. Once you are happy with the PR, you can merge it and it will be deployed to production via Vercel.\n38: \n39: \n40: ![Final](/tutorial/final.png)\n41: \n42: \n43: You can see the final example at https://github.com/kevinlu1248/docusaurus-2/pull/4 with preview https://docusaurus-2-ql4cskc5o-sweepai.vercel.app/.\n44: \n45: Now to be a Sweep power user, check out [Advanced: becoming a Sweep power user](https://docs.sweep.dev/usage/advanced).\n...\n</snippet>\n<snippet source="docs/installation.md:45-60">\n...\n45: * Provide any additional context that might be helpful, e.g. see "src/App.test.tsx" for an example of a good unit test.\n46: * For more guidance, visit [Advanced](https://docs.sweep.dev/usage/advanced), or watch the following video.\n47: \n48: [![Video](http://img.youtube.com/vi/Qn9vB71R4UM/0.jpg)](http://www.youtube.com/watch?v=Qn9vB71R4UM "Advanced Sweep Tricks and Feedback Tips")\n49: \n50: For configuring Sweep for your repo, see [Config](https://docs.sweep.dev/usage/config), especially for setting up Sweep Rules and Sweep Sweep.\n51: \n52: ## Limitations of Sweep (for now) ⚠️\n53: \n54: * 🗃️ **Gigantic repos**: >5000 files. We have default extensions and directories to exclude but sometimes this doesn\'t catch them all. 
You may need to block some directories (see [`blocked_dirs`](https://docs.sweep.dev/usage/config#blocked_dirs))\n55: * If Sweep is stuck at 0% for over 30 min and your repo has a few thousand files, let us know.\n56: \n57: * 🏗️ **Large-scale refactors**: >5 files or >300 lines of code changes (we\'re working on this!)\n58: * We can\'t do this - "Refactor entire codebase from Tensorflow to PyTorch"\n59: \n60: * 🖼️ **Editing images** and other non-text assets\n...\n</snippet>\n<snippet source="docs/pages/usage/tutorial.mdx:0-15">\n0: # Tutorial for Getting Started with Sweep\n1: \n2: We recommend using an existing **real project** for Sweep, but if you must start from scratch, we recommend **using a template**. In particular, we recommend Vercel templates and Vercel auto-deploy, since Vercel\'s auto-generated previews make it **easy to review Sweep\'s PRs**\n3: \n4: We\'ll use [Docusaurus](https://vercel.com/templates/next.js/docusaurus-2) since it\'s is the easiest to set up (no backend). To see other templates see https://vercel.com/templates.\n5: \n6: 1. Go to https://vercel.com/templates/next.js/docusaurus-2 (or another template) and click "Deploy".\n7: \n8: ![Deploy](/tutorial/deployment.png)\n9: \n10: 2. Vercel will prompt you to select a GitHub account and click "Clone" after. This will trigger a build and deploy which will take a few minutes. Once the build is done, you will be greeted with a congratulations message.\n11: \n12: ![Congratulations](/tutorial/congratulations.png)\n13: \n14: 3. Go to the [Sweep Installation](https://github.com/apps/sweep-ai) page and click the grey "Configure" button or the green "Install" button. Ensure that that the Vercel template (i.e. Docusaurus) is configured to use Sweep.\n...\n</snippet>\n</relevant_snippets_in_repo>\ndocs/\n installation.md\n docs/pages/\n docs/pages/usage/\n _meta.json\n advanced.mdx\n config.mdx\n extra-self-host.mdx\n sandbox.mdx\n tutorial.mdx',
name=None,
function_call=None,
key=None,
)
]
print(
new_planning(
request,
"/tmp/sweep_archive.zip",
chat_logger=ChatLogger(
{"username": "kevinlu1248", "title": "Unit test for planning"}
),
ticket_progress=TicketProgress(tracking_id="ed47605a38"),
)
)

import datetime
import difflib
import hashlib
import json
import os
import re
import shutil
import subprocess
import tempfile
import time
import traceback
from dataclasses import dataclass
from functools import cached_property
from typing import Any
import git
import requests
from github import Github, PullRequest, Repository, InputGitTreeElement
from jwt import encode
from loguru import logger
from sweepai.config.client import SweepConfig
from sweepai.config.server import GITHUB_APP_ID, GITHUB_APP_PEM, GITHUB_BOT_USERNAME
from sweepai.utils.tree_utils import DirectoryTree, remove_all_not_included
MAX_FILE_COUNT = 50
def make_valid_string(string: str):
pattern = r"[^\w./-]+"
return re.sub(pattern, "_", string)
def get_jwt():
signing_key = GITHUB_APP_PEM
app_id = GITHUB_APP_ID
payload = {"iat": int(time.time()), "exp": int(time.time()) + 600, "iss": app_id}
return encode(payload, signing_key, algorithm="RS256")
def get_token(installation_id: int):
if int(installation_id) < 0:
return os.environ["GITHUB_PAT"]
for timeout in [5.5, 5.5, 10.5]:
try:
jwt = get_jwt()
headers = {
"Accept": "application/vnd.github+json",
"Authorization": "Bearer " + jwt,
"X-GitHub-Api-Version": "2022-11-28",
}
response = requests.post(
f"https://api.github.com/app/installations/{int(installation_id)}/access_tokens",
headers=headers,
)
obj = response.json()
if "token" not in obj:
logger.error(obj)
raise Exception("Could not get token")
return obj["token"]
except SystemExit:
raise SystemExit
except Exception:
time.sleep(timeout)
raise Exception(
"Could not get token, please double check your PRIVATE_KEY and GITHUB_APP_ID in the .env file. Make sure to restart uvicorn after."
)
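`get_token` retries on a fixed schedule of delays (5.5, 5.5, and 10.5 seconds) and then raises. The sketch below is a generic version of that pattern, with an injectable `sleep` for testing; `retry_with_delays` is hypothetical. Note that `SystemExit` derives from `BaseException`, so `except Exception` never catches it. The explicit `except SystemExit: raise SystemExit` clauses in this file are therefore not needed for it to propagate.

```python
import time
from typing import Callable, Iterable, TypeVar

T = TypeVar("T")

def retry_with_delays(
    fn: Callable[[], T],
    delays: Iterable[float],
    sleep: Callable[[float], None] = time.sleep,
) -> T:
    last_error = None
    for delay in delays:
        try:
            return fn()
        except Exception as e:  # SystemExit is a BaseException, so it still propagates
            last_error = e
            # Like get_token, sleep after every failure, including the last.
            sleep(delay)
    raise RuntimeError("all attempts failed") from last_error
```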
def get_app():
jwt = get_jwt()
headers = {
"Accept": "application/vnd.github+json",
"Authorization": "Bearer " + jwt,
"X-GitHub-Api-Version": "2022-11-28",
}
response = requests.get("https://api.github.com/app", headers=headers)
return response.json()
def get_github_client(installation_id: int) -> tuple[str, Github]:
if not installation_id:
return os.environ["GITHUB_PAT"], Github(os.environ["GITHUB_PAT"])
token: str = get_token(installation_id)
return token, Github(token)
# fetch installation object
def get_installation(username: str):
jwt = get_jwt()
try:
# Try user
response = requests.get(
f"https://api.github.com/users/{username}/installation",
headers={
"Accept": "application/vnd.github+json",
"Authorization": "Bearer " + jwt,
"X-GitHub-Api-Version": "2022-11-28",
},
)
obj = response.json()
return obj
except Exception:
# Try org
response = requests.get(
f"https://api.github.com/orgs/{username}/installation",
headers={
"Accept": "application/vnd.github+json",
"Authorization": "Bearer " + jwt,
"X-GitHub-Api-Version": "2022-11-28",
},
)
try:
obj = response.json()
return obj["id"]
except Exception as e:
logger.error(e)
logger.error(response.text)
raise Exception("Could not get installation, probably not installed")
def get_installation_id(username: str) -> str:
jwt = get_jwt()
try:
# Try user
response = requests.get(
f"https://api.github.com/users/{username}/installation",
headers={
"Accept": "application/vnd.github+json",
"Authorization": "Bearer " + jwt,
"X-GitHub-Api-Version": "2022-11-28",
},
)
obj = response.json()
return obj["id"]
except Exception:
# Try org
response = requests.get(
f"https://api.github.com/orgs/{username}/installation",
headers={
"Accept": "application/vnd.github+json",
"Authorization": "Bearer " + jwt,
"X-GitHub-Api-Version": "2022-11-28",
},
)
try:
obj = response.json()
return obj["id"]
except Exception as e:
logger.error(e)
logger.error(response.text)
raise Exception("Could not get installation id, probably not installed")
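`get_installation_id` tries GitHub's user installation endpoint first. It falls back to the organization endpoint only when that attempt raises, for example with a `KeyError` on a 404 body. The sketch below makes the fallback explicit. The HTTP call is injected as a hypothetical `fetch_json` callable, so no network access or JWT is needed.

```python
from typing import Any, Callable

def installation_id_for(username: str, fetch_json: Callable[[str], dict[str, Any]]) -> int:
    # Try the user endpoint first, then the organization endpoint.
    for kind in ("users", "orgs"):
        obj = fetch_json(f"https://api.github.com/{kind}/{username}/installation")
        # A missing "id" (e.g. a 404 error body) means "try the next kind".
        if "id" in obj:
            return obj["id"]
    raise LookupError(f"No installation found for {username!r}; the app is probably not installed")
```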
# commits multiple files in a single commit, returns the commit object
def commit_multi_file_changes(repo: Repository, file_changes: dict[str, str], commit_message: str, branch: str):
blobs_to_commit = []
# convert to blob
for path, content in file_changes.items():
blob = repo.create_git_blob(content, "utf-8")
blobs_to_commit.append(InputGitTreeElement(path=path, mode="100644", type="blob", sha=blob.sha))
latest_commit = repo.get_branch(branch).commit
base_tree = latest_commit.commit.tree
# create new git tree
new_tree = repo.create_git_tree(blobs_to_commit, base_tree=base_tree)
# commit the changes
parent = repo.get_git_commit(latest_commit.sha)
commit = repo.create_git_commit(
commit_message,
new_tree,
[parent],
)
# update ref of branch
ref = f"heads/{branch}"
repo.get_git_ref(ref).edit(sha=commit.sha)
return commit
REPO_CACHE_BASE_DIR = "/tmp/cache/repos"
@dataclass
class ClonedRepo:
repo_full_name: str
installation_id: str
branch: str | None = None
token: str | None = None
repo: Any | None = None
git_repo: git.Repo | None = None
class Config:
arbitrary_types_allowed = True
@cached_property
def cached_dir(self):
self.repo = (
Github(self.token).get_repo(self.repo_full_name)
if not self.repo
else self.repo
)
self.branch = self.branch or SweepConfig.get_branch(self.repo)
return os.path.join(
REPO_CACHE_BASE_DIR,
self.repo_full_name,
"base",
parse_collection_name(self.branch),
)
@cached_property
def zip_path(self):
logger.info("Zipping repository...")
shutil.make_archive(self.repo_dir, "zip", self.repo_dir)
logger.info("Done zipping")
return f"{self.repo_dir}.zip"
@cached_property
def repo_dir(self):
self.repo = (
Github(self.token).get_repo(self.repo_full_name)
if not self.repo
else self.repo
)
self.branch = self.branch or SweepConfig.get_branch(self.repo)
curr_time_str = str(time.time()).encode("utf-8")
hash_obj = hashlib.sha256(curr_time_str)
hash_hex = hash_obj.hexdigest()
if self.branch:
return os.path.join(
REPO_CACHE_BASE_DIR,
self.repo_full_name,
hash_hex,
parse_collection_name(self.branch),
)
else:
return os.path.join(REPO_CACHE_BASE_DIR, self.repo_full_name, hash_hex)
@property
def clone_url(self):
return (
f"https://x-access-token:{self.token}@github.com/{self.repo_full_name}.git"
)
def clone(self):
if not os.path.exists(self.cached_dir):
logger.info("Cloning repo...")
if self.branch:
repo = git.Repo.clone_from(
self.clone_url, self.cached_dir, branch=self.branch
)
else:
repo = git.Repo.clone_from(self.clone_url, self.cached_dir)
logger.info("Done cloning")
else:
try:
repo = git.Repo(self.cached_dir)
repo.remotes.origin.pull(
kill_after_timeout=60, progress=git.RemoteProgress()
)
except Exception:
logger.error("Could not pull repo")
shutil.rmtree(self.cached_dir, ignore_errors=True)
repo = git.Repo.clone_from(self.clone_url, self.cached_dir)
logger.info("Repo already cached, copying")
logger.info("Copying repo...")
shutil.copytree(
self.cached_dir, self.repo_dir, symlinks=True, copy_function=shutil.copy
)
logger.info("Done copying")
repo = git.Repo(self.repo_dir)
return repo
def __post_init__(self):
subprocess.run(["git", "config", "--global", "http.postBuffer", "524288000"])
self.token = self.token or get_token(self.installation_id)
self.repo = (
Github(self.token).get_repo(self.repo_full_name)
if not self.repo
else self.repo
)
self.commit_hash = self.repo.get_commits()[0].sha
self.git_repo = self.clone()
self.branch = self.branch or SweepConfig.get_branch(self.repo)
def __del__(self):
try:
shutil.rmtree(self.repo_dir)
os.remove(self.zip_path)
return True
except Exception:
return False
def list_directory_tree(
self,
included_directories=None,
excluded_directories: list[str] = None,
included_files=None,
):
"""Display the directory tree rooted at self.repo_dir.
Arguments:
included_directories -- List of directory paths (relative to the root) to include in the tree. Defaults to None (all directories).
excluded_directories -- List of directory names to exclude from the tree. Defaults to SweepConfig().exclude_dirs.
included_files -- Currently unused.
Returns:
A tuple of the rendered tree string and the parsed DirectoryTree.
"""
root_directory = self.repo_dir
sweep_config: SweepConfig = SweepConfig()
# Default values if parameters are not provided
if included_directories is None:
included_directories = [] # gets all directories
if excluded_directories is None:
excluded_directories = sweep_config.exclude_dirs
def list_directory_contents(
current_directory: str,
excluded_directories: list[str],
indentation="",
):
"""Recursively list the contents of directories."""
file_and_folder_names = os.listdir(current_directory)
file_and_folder_names.sort()
directory_tree_string = ""
for name in file_and_folder_names[:MAX_FILE_COUNT]:
relative_path = os.path.join(current_directory, name)[
len(root_directory) + 1 :
]
if name in excluded_directories:
continue
complete_path = os.path.join(current_directory, name)
if os.path.isdir(complete_path):
directory_tree_string += f"{indentation}{relative_path}/\n"
directory_tree_string += list_directory_contents(
complete_path,
excluded_directories,
indentation + " ",
)
else:
directory_tree_string += f"{indentation}{name}\n"
# if os.path.isfile(complete_path) and relative_path in included_files:
# # Todo, use these to fetch neighbors
# ctags_str, names = get_ctags_for_file(ctags, complete_path)
# ctags_str = "\n".join([indentation + line for line in ctags_str.splitlines()])
# if ctags_str.strip():
# directory_tree_string += f"{ctags_str}\n"
return directory_tree_string
dir_obj = DirectoryTree()
directory_tree = list_directory_contents(root_directory, excluded_directories)
dir_obj.parse(directory_tree)
if included_directories:
dir_obj = remove_all_not_included(dir_obj, included_directories)
return directory_tree, dir_obj
def get_file_list(self) -> list[str]:
root_directory = self.repo_dir
files = []
sweep_config: SweepConfig = SweepConfig()
def dfs_helper(directory):
nonlocal files
for item in os.listdir(directory):
if item == ".git":
continue
if item in sweep_config.exclude_dirs: # this saves a lot of time
continue
item_path = os.path.join(directory, item)
if os.path.isfile(item_path):
# make sure the item_path is not in one of the banned directories
if not sweep_config.is_file_excluded(item_path):
files.append(item_path) # Add the file to the list
elif os.path.isdir(item_path):
dfs_helper(item_path) # Recursive call to explore subdirectory
dfs_helper(root_directory)
files = [file[len(root_directory) + 1 :] for file in files]
return files
def get_file_contents(self, file_path, ref=None):
local_path = (
f"{self.repo_dir}{file_path}"
if file_path.startswith("/")
else f"{self.repo_dir}/{file_path}"
)
if os.path.exists(local_path):
with open(local_path, "r", encoding="utf-8", errors="replace") as f:
contents = f.read()
return contents
else:
raise FileNotFoundError(f"{local_path} does not exist.")
def get_num_files_from_repo(self):
# subprocess.run(["git", "config", "--global", "http.postBuffer", "524288000"])
self.git_repo.git.checkout(self.branch)
file_list = self.get_file_list()
return len(file_list)
def get_commit_history(
self, username: str = "", limit: int = 200, time_limited: bool = True
):
commit_history = []
try:
if username != "":
commit_list = list(self.git_repo.iter_commits(author=username))
else:
commit_list = list(self.git_repo.iter_commits())
line_count = 0
cut_off_date = datetime.datetime.now() - datetime.timedelta(days=7)
# resolve the comparison branch once instead of once per commit (each lookup is a GitHub API call)
repo = get_github_client(self.installation_id)[1].get_repo(
self.repo_full_name
)
branch = SweepConfig.get_branch(repo)
if branch not in self.git_repo.git.branch():
branch = f"origin/{branch}"
for commit in commit_list:
# must be within a week
if time_limited and commit.authored_datetime.replace(
tzinfo=None
) <= cut_off_date.replace(tzinfo=None):
logger.info("Exceeded cut off date, stopping...")
break
diff = self.git_repo.git.diff(commit, branch, unified=1)
lines = diff.count("\n")
# total diff lines must not exceed `limit`
if lines + line_count > limit:
logger.info(f"Exceeded {limit} lines of diff, stopping...")
break
commit_history.append(
f"<commit>\nAuthor: {commit.author.name}\nMessage: {commit.message}\n{diff}\n</commit>"
)
line_count += lines
except Exception:
logger.error(f"An error occurred: {traceback.format_exc()}")
return commit_history
def get_similar_file_paths(self, file_path: str, limit: int = 10):
from rapidfuzz.fuzz import ratio
# Fuzzy search over file names
file_name = os.path.basename(file_path)
all_file_paths = self.get_file_list()
# filter for matching extensions if both have extensions
if "." in file_name:
all_file_paths = [
file
for file in all_file_paths
if "." in file and file.split(".")[-1] == file_name.split(".")[-1]
]
files_with_matching_name = []
files_without_matching_name = []
for file_path in all_file_paths:
if file_name in file_path:
files_with_matching_name.append(file_path)
else:
files_without_matching_name.append(file_path)
file_path_to_ratio = {file: ratio(file_name, file) for file in all_file_paths}
files_with_matching_name = sorted(
files_with_matching_name,
key=lambda file_path: file_path_to_ratio[file_path],
reverse=True,
)
files_without_matching_name = sorted(
files_without_matching_name,
key=lambda file_path: file_path_to_ratio[file_path],
reverse=True,
)
# this allows 'config.py' to return 'sweepai/config/server.py', 'sweepai/config/client.py', 'sweepai/config/__init__.py' and no more
filtered_files_without_matching_name = list(filter(lambda file_path: file_path_to_ratio[file_path] > 50, files_without_matching_name))
all_files = files_with_matching_name + filtered_files_without_matching_name
return all_files[:limit]
# updates a file with new_contents, returns True if successful
def update_file(root_dir: str, file_path: str, new_contents: str):
local_path = os.path.join(root_dir, file_path)
try:
with open(local_path, "w") as f:
f.write(new_contents)
return True
except Exception as e:
logger.error(f"Failed to update file: {e}")
return False
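# Sketch (an assumption, not Sweep's current behavior): an atomic variant of
# update_file. update_file above writes in place, so a crash mid-write can leave
# a truncated file. Writing to a temporary file in the same directory and then
# os.replace()-ing it over the target means the file is either fully old or
# fully new. The name update_file_atomic is hypothetical.
import os
import stat
import tempfile


def update_file_atomic(root_dir: str, file_path: str, new_contents: str) -> bool:
    local_path = os.path.join(root_dir, file_path)
    directory = os.path.dirname(local_path) or "."
    try:
        # The temp file must be on the same filesystem for os.replace to be atomic
        fd, tmp_path = tempfile.mkstemp(dir=directory)
        try:
            with os.fdopen(fd, "w", encoding="utf-8") as f:
                f.write(new_contents)
            # mkstemp creates files with mode 0600; keep the original file's mode
            if os.path.exists(local_path):
                os.chmod(tmp_path, stat.S_IMODE(os.stat(local_path).st_mode))
            os.replace(tmp_path, local_path)
        except Exception:
            # Do not leave the temp file behind if writing or replacing failed
            os.remove(tmp_path)
            raise
        return True
    except Exception:
        return False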
@dataclass
class MockClonedRepo(ClonedRepo):
_repo_dir: str = ""
git_repo: git.Repo | None = None
def __init__(
self,
_repo_dir: str,
repo_full_name: str,
installation_id: str = "",
branch: str | None = None,
token: str | None = None,
repo: Any | None = None,
git_repo: git.Repo | None = None,
):
self._repo_dir = _repo_dir
self.repo_full_name = repo_full_name
self.installation_id = installation_id
self.branch = branch
self.token = token
self.repo = repo
@classmethod
def from_dir(cls, repo_dir: str, **kwargs):
return cls(_repo_dir=repo_dir, **kwargs)
@property
def cached_dir(self):
return self._repo_dir
@property
def repo_dir(self):
return self._repo_dir
@property
def git_repo(self):
return git.Repo(self.repo_dir)
def clone(self):
return git.Repo(self.repo_dir)
def __post_init__(self):
return self
def __del__(self):
return True
@dataclass
class TemporarilyCopiedClonedRepo(MockClonedRepo):
tmp_dir: tempfile.TemporaryDirectory | None = None
def __init__(
self,
_repo_dir: str,
tmp_dir: tempfile.TemporaryDirectory,
repo_full_name: str,
installation_id: str = "",
branch: str | None = None,
token: str | None = None,
repo: Any | None = None,
git_repo: git.Repo | None = None,
):
self._repo_dir = _repo_dir
self.tmp_dir = tmp_dir
self.repo_full_name = repo_full_name
self.installation_id = installation_id
self.branch = branch
self.token = token
self.repo = repo
@classmethod
def copy_from_cloned_repo(cls, cloned_repo: ClonedRepo, **kwargs):
temp_dir = tempfile.TemporaryDirectory()
new_dir = temp_dir.name + "/" + cloned_repo.repo_full_name.split("/")[1]
print("Copying...")
shutil.copytree(cloned_repo.repo_dir, new_dir)
print("Done copying.")
return cls(
_repo_dir=new_dir,
tmp_dir=temp_dir,
repo_full_name=cloned_repo.repo_full_name,
installation_id=cloned_repo.installation_id,
branch=cloned_repo.branch,
token=cloned_repo.token,
repo=cloned_repo.repo,
**kwargs,
)
def __del__(self):
print(f"Dropping {self.tmp_dir.name}...")
shutil.rmtree(self._repo_dir, ignore_errors=True)
self.tmp_dir.cleanup()
print("Done.")
return True
def get_file_names_from_query(query: str) -> list[str]:
query_file_names = re.findall(r"\b[\w\-\.\/]*\w+\.\w{1,6}\b", query)
return [
query_file_name
for query_file_name in query_file_names
if len(query_file_name) > 3
]
def get_hunks(a: str, b: str, context=10):
differ = difflib.Differ()
diff = [
line
for line in differ.compare(a.splitlines(), b.splitlines())
if line[0] in ("+", "-", " ")
]
show = set()
hunks = []
for i, line in enumerate(diff):
if line.startswith(("+", "-")):
show.update(range(max(0, i - context), min(len(diff), i + context + 1)))
for i in range(len(diff)):
if i in show:
hunks.append(diff[i])
elif i - 1 in show:
hunks.append("...")
if len(hunks) > 0 and hunks[0] == "...":
hunks = hunks[1:]
if len(hunks) > 0 and hunks[-1] == "...":
hunks = hunks[:-1]
return "\n".join(hunks)
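# Sketch of a standard-library alternative to get_hunks (not what Sweep uses):
# difflib.unified_diff already groups changes into hunks with `n` lines of
# context. It emits "@@" hunk headers instead of get_hunks' "..." separators,
# and "-"/"+" prefixes instead of Differ's "- "/"+ ".
# get_unified_hunks is a hypothetical name.
import difflib


def get_unified_hunks(a: str, b: str, context: int = 10) -> str:
    # lineterm="" because splitlines() already removed the newlines
    return "\n".join(
        difflib.unified_diff(a.splitlines(), b.splitlines(), n=context, lineterm="")
    )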
def parse_collection_name(name: str) -> str:
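# Replace each character that is not a word character or hyphen with "--"
name = re.sub(r"[^\w-]", "--", name)
# Truncate to 63 characters and pad to at least 3 with "x"; trailing hyphens are
# stripped only when the rest of the name contains no hyphen, so an alphanumeric
# start/end is not guaranteed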
name = re.sub(r"^(-*\w{0,61}\w)-*$", r"\1", name[:63].ljust(3, "x"))
return name
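# Sketch (hypothetical helper, not part of Sweep): a stricter variant of
# parse_collection_name. It guarantees 3 to 63 characters, only letters,
# digits, "_" and "-", and an alphanumeric first and last character. It
# replaces invalid characters with a single "-" (parse_collection_name uses
# "--") and accepts only ASCII letters, so its output can differ from
# parse_collection_name for the same input.
import re


def sanitize_collection_name(name: str) -> str:
    # Replace every character outside [A-Za-z0-9_-] with a hyphen
    name = re.sub(r"[^A-Za-z0-9_-]", "-", name)
    # Truncate first, then strip non-alphanumeric characters from both ends
    name = re.sub(r"^[^A-Za-z0-9]+|[^A-Za-z0-9]+$", "", name[:63])
    # Pad short results so they reach the 3-character minimum
    return name.ljust(3, "x")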
# set whether or not a pr is a draft, there is no way to do this using pygithub
def convert_pr_draft_field(pr: PullRequest, is_draft: bool = False):
    pr_id = pr.raw_data["node_id"]
    # convertPullRequestToDraft marks the PR as a draft,
    # markPullRequestReadyForReview marks it as ready for review
    mutation_name = (
        "convertPullRequestToDraft" if is_draft else "markPullRequestReadyForReview"
    )
    # Pass the node ID as a GraphQL variable instead of splicing it into the query
    mutation = f"""
    mutation SetDraftState($pullRequestId: ID!) {{
      {mutation_name}(input: {{pullRequestId: $pullRequestId}}) {{
        pullRequest {{
          id
          isDraft
        }}
      }}
    }}
    """
    # GraphQL API URL
    url = "https://api.github.com/graphql"
    # Headers
    headers = {
        "Accept": "application/vnd.github+json",
        "X-Github-Api-Version": "2022-11-28",
        "Authorization": "Bearer " + os.environ["GITHUB_PAT"],
    }
    # Prepare the JSON payload
    json_data = {
        "query": mutation,
        "variables": {"pullRequestId": pr_id},
    }
    # Make the POST request
    response = requests.post(url, headers=headers, json=json_data)
    # GraphQL reports failures in an "errors" field even when the HTTP status is 200
    if response.status_code != 200 or response.json().get("errors"):
        logger.error(f"Failed to convert PR to {'draft' if is_draft else 'open'}")
        return False
    return True
try:
g = Github(os.environ.get("GITHUB_PAT"))
CURRENT_USERNAME = g.get_user().login
except Exception:
try:
slug = get_app()["slug"]
CURRENT_USERNAME = f"{slug}[bot]"
except Exception:
CURRENT_USERNAME = GITHUB_BOT_USERNAME
if __name__ == "__main__":
try:
organization_name = "sweepai"
sweep_config = SweepConfig()
installation_id = get_installation_id(organization_name)
user_token, g = get_github_client(installation_id)
cloned_repo = ClonedRepo("sweepai/sweep", installation_id, "main")
directory_tree, dir_obj = cloned_repo.list_directory_tree()
commit_history = cloned_repo.get_commit_history()
similar_file_paths = cloned_repo.get_similar_file_paths("config.py")
# ensure no similar file_paths are sweep excluded
assert not any(sweep_config.is_file_excluded(file) for file in similar_file_paths)
print(f"similar_file_paths: {similar_file_paths}")
str1 = "a\nline1\nline2\nline3\nline4\nline5\nline6\ntest\n"
str2 = "a\nline1\nlineTwo\nline3\nline4\nline5\nlineSix\ntset\n"
print(get_hunks(str1, str2, 1))
mocked_repo = MockClonedRepo.from_dir(
cloned_repo.repo_dir,
repo_full_name="sweepai/sweep",
)
temp_repo = TemporarilyCopiedClonedRepo.copy_from_cloned_repo(mocked_repo)
print(f"mocked repo: {mocked_repo}")
except Exception as e:

# sweep/sweepai/api.py (lines 1 to 1178 at commit 0277fad)

from __future__ import annotations
import ctypes
import json
import threading
import time
from typing import Any, Optional
import requests
from fastapi import (
Body,
FastAPI,
Header,
HTTPException,
Path,
Request,
Security,
status,
)
from fastapi.responses import HTMLResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from fastapi.templating import Jinja2Templates
from github.Commit import Commit
from sweepai.config.client import (
DEFAULT_RULES,
RESTART_SWEEP_BUTTON,
REVERT_CHANGED_FILES_TITLE,
RULES_LABEL,
RULES_TITLE,
SWEEP_BAD_FEEDBACK,
SWEEP_GOOD_FEEDBACK,
SweepConfig,
get_gha_enabled,
get_rules,
)
from sweepai.config.server import (
BLACKLISTED_USERS,
DISABLED_REPOS,
DISCORD_FEEDBACK_WEBHOOK_URL,
ENV,
GHA_AUTOFIX_ENABLED,
GITHUB_BOT_USERNAME,
GITHUB_LABEL_COLOR,
GITHUB_LABEL_DESCRIPTION,
GITHUB_LABEL_NAME,
IS_SELF_HOSTED,
MERGE_CONFLICT_ENABLED,
)
from sweepai.core.entities import PRChangeRequest
from sweepai.global_threads import global_threads
from sweepai.handlers.create_pr import ( # type: ignore
add_config_to_top_repos,
create_gha_pr,
)
from sweepai.handlers.on_button_click import handle_button_click
from sweepai.handlers.on_check_suite import ( # type: ignore
clean_gh_logs,
download_logs,
on_check_suite,
)
from sweepai.handlers.on_comment import on_comment
from sweepai.handlers.on_jira_ticket import handle_jira_ticket
from sweepai.handlers.on_merge import on_merge
from sweepai.handlers.on_merge_conflict import on_merge_conflict
from sweepai.handlers.on_ticket import on_ticket
from sweepai.handlers.pr_utils import make_pr
from sweepai.handlers.stack_pr import stack_pr
from sweepai.utils.buttons import (
Button,
ButtonList,
check_button_activated,
check_button_title_match,
)
from sweepai.utils.chat_logger import ChatLogger
from sweepai.utils.event_logger import logger, posthog
from sweepai.utils.github_utils import CURRENT_USERNAME, get_github_client
from sweepai.utils.progress import TicketProgress
from sweepai.utils.safe_pqueue import SafePriorityQueue
from sweepai.utils.str_utils import BOT_SUFFIX, get_hash
from sweepai.web.events import (
CheckRunCompleted,
CommentCreatedRequest,
InstallationCreatedRequest,
IssueCommentRequest,
IssueRequest,
PREdited,
PRRequest,
ReposAddedRequest,
)
from sweepai.web.health import health_check
app = FastAPI()
events = {}
on_ticket_events = {}
security = HTTPBearer()
templates = Jinja2Templates(directory="sweepai/web")
# version_command = r"""git config --global --add safe.directory /app
# timestamp=$(git log -1 --format="%at")
# date -d "@$timestamp" +%y.%m.%d.%H 2>/dev/null || date -r "$timestamp" +%y.%m.%d.%H"""
# try:
# version = subprocess.check_output(version_command, shell=True, text=True).strip()
# except Exception:
version = time.strftime("%y.%m.%d.%H")
logger.bind(application="webhook")
def auth_metrics(credentials: HTTPAuthorizationCredentials = Security(security)):
if credentials.scheme != "Bearer":
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Invalid authentication scheme.",
)
if credentials.credentials != "example_token": # grafana requires authentication
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token."
)
return True
def run_on_ticket(*args, **kwargs):
tracking_id = get_hash()
with logger.contextualize(
**kwargs,
name="ticket_" + kwargs["username"],
tracking_id=tracking_id,
):
return on_ticket(*args, **kwargs, tracking_id=tracking_id)
def run_on_comment(*args, **kwargs):
tracking_id = get_hash()
with logger.contextualize(
**kwargs,
name="comment_" + kwargs["username"],
tracking_id=tracking_id,
):
on_comment(*args, **kwargs, tracking_id=tracking_id)
def run_on_button_click(*args, **kwargs):
thread = threading.Thread(target=handle_button_click, args=args, kwargs=kwargs)
thread.start()
global_threads.append(thread)
def run_on_check_suite(*args, **kwargs):
request = kwargs["request"]
pr_change_request = on_check_suite(request)
if pr_change_request:
call_on_comment(**pr_change_request.params, comment_type="github_action")
logger.info("Done with on_check_suite")
else:
logger.info("Skipping on_check_suite as no pr_change_request was returned")
def terminate_thread(thread):
"""Terminate a python threading.Thread."""
try:
if not thread.is_alive():
return
exc = ctypes.py_object(SystemExit)
res = ctypes.pythonapi.PyThreadState_SetAsyncExc(
ctypes.c_long(thread.ident), exc
)
if res == 0:
raise ValueError("Invalid thread ID")
elif res != 1:
# More than one thread state was modified: call again with exc set to NULL (0) to revert the effect.
ctypes.pythonapi.PyThreadState_SetAsyncExc(thread.ident, 0)
raise SystemError("PyThreadState_SetAsyncExc failed")
except SystemExit:
raise SystemExit
except Exception as e:
logger.exception(f"Failed to terminate thread: {e}")
# def delayed_kill(thread: threading.Thread, delay: int = 60 * 60):
# time.sleep(delay)
# terminate_thread(thread)
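# Sketch of an alternative to terminate_thread (not what Sweep does):
# cooperative cancellation with threading.Event. PyThreadState_SetAsyncExc only
# raises the exception when the target thread next runs Python bytecode. A
# thread blocked in a long C call or network read is therefore not stopped
# promptly, and the exception can fire at an arbitrary point. Here the worker
# checks the event between steps instead. run_cancellable and `steps` (a list
# of zero-argument callables) are hypothetical.
import threading


def run_cancellable(steps, cancel_event: threading.Event) -> int:
    """Run steps in order until cancel_event is set; return how many completed."""
    completed = 0
    for step in steps:
        # Check for cancellation only at step boundaries, never mid-step
        if cancel_event.is_set():
            break
        step()
        completed += 1
    return completed

# Usage idea: call_on_ticket could keep one Event per key and call event.set()
# instead of terminate_thread(e) when a newer event for the same ticket arrives.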
def call_on_ticket(*args, **kwargs):
global on_ticket_events
key = f"{kwargs['repo_full_name']}-{kwargs['issue_number']}" # Full name, issue number as key
# Use multithreading
# Check if a previous process exists for the same key, cancel it
e = on_ticket_events.get(key, None)
if e:
logger.info(f"Found previous thread for key {key} and cancelling it")
terminate_thread(e)
thread = threading.Thread(target=run_on_ticket, args=args, kwargs=kwargs)
on_ticket_events[key] = thread
thread.start()
global_threads.append(thread)
def call_on_check_suite(*args, **kwargs):
kwargs["request"].repository.full_name
kwargs["request"].check_run.pull_requests[0].number
thread = threading.Thread(target=run_on_check_suite, args=args, kwargs=kwargs)
thread.start()
global_threads.append(thread)
def call_on_comment(
*args, **kwargs
): # TODO: if it's a GHA, delete all previous GHA and append to the end
def worker():
while not events[key].empty():
task_args, task_kwargs = events[key].get()
run_on_comment(*task_args, **task_kwargs)
global events
repo_full_name = kwargs["repo_full_name"]
pr_id = kwargs["pr_number"]
key = f"{repo_full_name}-{pr_id}" # Full name, PR number as key
comment_type = kwargs["comment_type"]
logger.info(f"Received comment type: {comment_type}")
if key not in events:
events[key] = SafePriorityQueue()
events[key].put(0, (args, kwargs))
# If a thread isn't running, start one
if not any(
thread.name == key and thread.is_alive() for thread in threading.enumerate()
):
thread = threading.Thread(target=worker, name=key)
thread.start()
global_threads.append(thread)
def call_on_merge(*args, **kwargs):
thread = threading.Thread(target=on_merge, args=args, kwargs=kwargs)
thread.start()
global_threads.append(thread)
@app.get("/health")
def redirect_to_health():
return health_check()
@app.get("/", response_class=HTMLResponse)
def home(request: Request):
return templates.TemplateResponse(
name="index.html", context={"version": version, "request": request}
)
@app.get("/ticket_progress/{tracking_id}")
def progress(tracking_id: str = Path(...)):
ticket_progress = TicketProgress.load(tracking_id)
return ticket_progress.dict()
def init_hatchet() -> Any | None:
try:
from hatchet_sdk import Context, Hatchet
hatchet = Hatchet(debug=True)
worker = hatchet.worker("github-worker")
@hatchet.workflow(on_events=["github:webhook"])
class OnGithubEvent:
"""Workflow for handling GitHub events."""
@hatchet.step()
def run(self, context: Context):
event_payload = context.workflow_input()
request_dict = event_payload.get("request")
event = event_payload.get("event")
handle_event(request_dict, event)
workflow = OnGithubEvent()
worker.register_workflow(workflow)
# start worker in the background
thread = threading.Thread(target=worker.start)
thread.start()
global_threads.append(thread)
return hatchet
except Exception as e:
print(f"Failed to initialize Hatchet: {e}, continuing with local mode")
return None
# hatchet = init_hatchet()
def handle_github_webhook(event_payload):
# if hatchet:
# hatchet.client.event.push("github:webhook", event_payload)
# else:
handle_event(event_payload.get("request"), event_payload.get("event"))
def handle_request(request_dict, event=None):
"""So it can be exported to the listen endpoint."""
with logger.contextualize(tracking_id="main", env=ENV):
action = request_dict.get("action")
try:
# Dispatch the event (Hatchet is disabled above, so this is handled locally)
handle_github_webhook(
{
"request": request_dict,
"event": event,
}
)
except Exception as e:
logger.exception(f"Failed to send event to Hatchet: {e}")
# try:
# worker()
# except Exception as e:
# discord_log_error(str(e), priority=1)
logger.info(f"Done handling {event}, {action}")
return {"success": True}
@app.post("/")
def webhook(
request_dict: dict = Body(...),
x_github_event: Optional[str] = Header(None, alias="X-GitHub-Event"),
):
"""Handle a webhook request from GitHub."""
with logger.contextualize(tracking_id="main", env=ENV):
action = request_dict.get("action", None)
logger.info(f"Received event: {x_github_event}, {action}")
return handle_request(request_dict, event=x_github_event)
@app.post("/jira")
def jira_webhook(
request_dict: dict = Body(...),
) -> None:
def call_jira_ticket(*args, **kwargs):
thread = threading.Thread(target=handle_jira_ticket, args=args, kwargs=kwargs)
thread.start()
call_jira_ticket(event=request_dict)
# Set up cronjob for this
@app.get("/update_sweep_prs_v2")
def update_sweep_prs_v2(repo_full_name: str, installation_id: int):
# Get a Github client
_, g = get_github_client(installation_id)
# Get the repository
repo = g.get_repo(repo_full_name)
config = SweepConfig.get_config(repo)
try:
branch_ttl = int(config.get("branch_ttl", 7))
except Exception:
branch_ttl = 7
branch_ttl = max(branch_ttl, 1)
# Get all open pull requests created by Sweep
pulls = repo.get_pulls(
state="open", head="sweep", sort="updated", direction="desc"
)[:5]
# For each pull request, attempt to merge the changes from its base branch into the pull request branch
try:
for pr in pulls:
try:
# make sure it's a sweep ticket
feature_branch = pr.head.ref
if not feature_branch.startswith(
"sweep/"
) and not feature_branch.startswith("sweep_"):
continue
if "Resolve merge conflicts" in pr.title:
continue
if (
pr.mergeable_state != "clean"
and (time.time() - pr.created_at.timestamp()) > 60 * 60 * 24
and pr.title.startswith("[Sweep Rules]")
):
pr.edit(state="closed")
continue
repo.merge(
feature_branch,
pr.base.ref,
f"Merge {pr.base.ref} into {feature_branch}",
)
# Check if the merged PR is the config PR
if pr.title == "Configure Sweep" and pr.merged:
# Create a new PR to add "gha_enabled: True" to sweep.yaml
create_gha_pr(g, repo)
except Exception as e:
logger.warning(
f"Failed to merge changes from default branch into PR #{pr.number}: {e}"
)
except Exception:
logger.warning("Failed to update sweep PRs")
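# Sketch toward this issue (rebase instead of merge), not Sweep's implementation.
# repo.merge above calls GitHub's merge endpoint, which can only create a merge
# commit, so a rebase has to run in a local clone and be force-pushed. The
# `strategy` value would come from a HYPOTHETICAL sweep.yaml key such as
# `update_strategy`. Neither that key nor these helpers exist in Sweep.
import subprocess


def plan_branch_update(
    feature_branch: str,
    base_branch: str,
    strategy: str = "merge",
    remote: str = "origin",
) -> list[list[str]]:
    """Return the git commands that bring feature_branch up to date with base_branch."""
    if strategy not in ("merge", "rebase"):
        raise ValueError(f"Unknown update strategy: {strategy!r}")
    commands = [["git", "fetch", remote], ["git", "checkout", feature_branch]]
    if strategy == "rebase":
        # Replay the feature commits on top of the latest base. History is rewritten,
        # so the push must be forced; --force-with-lease refuses to overwrite commits
        # pushed by someone else since our fetch.
        commands.append(["git", "rebase", f"{remote}/{base_branch}"])
        commands.append(
            ["git", "push", "--force-with-lease", remote, f"HEAD:{feature_branch}"]
        )
    else:
        # Merge the base into the feature branch, as repo.merge does
        commands.append(["git", "merge", "--no-edit", f"{remote}/{base_branch}"])
        commands.append(["git", "push", remote, f"HEAD:{feature_branch}"])
    return commands


def run_branch_update(repo_dir: str, commands: list[list[str]]) -> bool:
    """Run the planned commands in repo_dir; on a failed merge/rebase, abort it and return False."""
    for command in commands:
        result = subprocess.run(command, cwd=repo_dir, capture_output=True, text=True)
        if result.returncode != 0:
            if command[1] in ("merge", "rebase"):
                # Leave the working tree clean for the next PR
                subprocess.run(["git", command[1], "--abort"], cwd=repo_dir)
            return False
    return True

# Usage inside update_sweep_prs_v2 (assuming a local clone of the repo in repo_dir):
#     ok = run_branch_update(repo_dir, plan_branch_update(feature_branch, pr.base.ref, strategy))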
def handle_event(request_dict, event):
action = request_dict.get("action")
if repo_full_name := request_dict.get("repository", {}).get("full_name"):
if repo_full_name in DISABLED_REPOS:
logger.warning(f"Repo {repo_full_name} is disabled")
return {"success": False, "error_message": "Repo is disabled"}
with logger.contextualize(tracking_id="main", env=ENV):
match event, action:
case "check_run", "completed":
request = CheckRunCompleted(**request_dict)
_, g = get_github_client(request.installation.id)
repo = g.get_repo(request.repository.full_name)
pull_requests = request.check_run.pull_requests
if pull_requests:
logger.info(pull_requests[0].number)
pr = repo.get_pull(pull_requests[0].number)
if (time.time() - pr.created_at.timestamp()) > 60 * 60 and (
pr.title.startswith("[Sweep Rules]")
or pr.title.startswith("[Sweep GHA Fix]")
):
after_sha = pr.head.sha
commit = repo.get_commit(after_sha)
check_suites = commit.get_check_suites()
for check_suite in check_suites:
if check_suite.conclusion == "failure":
pr.edit(state="closed")
break
if (
not (time.time() - pr.created_at.timestamp()) > 60 * 15
and request.check_run.conclusion == "failure"
and pr.state == "open"
and get_gha_enabled(repo)
and len(
[
comment
for comment in pr.get_issue_comments()
if "Fixing PR" in comment.body
]
)
< 2
and GHA_AUTOFIX_ENABLED
):
# check if the base branch is passing
commits = repo.get_commits(sha=pr.base.ref)
latest_commit: Commit = commits[0]
if all(
status != "failure"
for status in [
status.state for status in latest_commit.get_statuses()
]
): # base branch is passing
logs = download_logs(
request.repository.full_name,
request.check_run.run_id,
request.installation.id,
)
logs, user_message = clean_gh_logs(logs)
attributor = request.sender.login
if attributor.endswith("[bot]"):
attributor = commit.author.login
if attributor.endswith("[bot]"):
attributor = pr.assignee.login
if attributor.endswith("[bot]"):
return {
"success": False,
"error_message": "The PR was created by a bot, so I won't attempt to fix it.",
}
tracking_id = get_hash()
chat_logger = ChatLogger(
data={
"username": attributor,
"title": "[Sweep GHA Fix] Fix the failing GitHub Actions",
}
)
if chat_logger.use_faster_model() and not IS_SELF_HOSTED:
return {
"success": False,
"error_message": "Disabled for free users",
}
stack_pr(
request=f"[Sweep GHA Fix] The GitHub Actions run failed on {request.check_run.head_sha[:7]} ({repo.default_branch}) with the following error logs:\n\n```\n\n{logs}\n\n```",
pr_number=pr.number,
username=attributor,
repo_full_name=repo.full_name,
installation_id=request.installation.id,
tracking_id=tracking_id,
commit_hash=pr.head.sha,
)
elif (
request.check_run.check_suite.head_branch == repo.default_branch
and get_gha_enabled(repo)
and GHA_AUTOFIX_ENABLED
):
if request.check_run.conclusion == "failure":
commit = repo.get_commit(request.check_run.head_sha)
attributor = request.sender.login
if attributor.endswith("[bot]"):
attributor = commit.author.login
if attributor.endswith("[bot]"):
return {
"success": False,
"error_message": "The PR was created by a bot, so I won't attempt to fix it.",
}
logs = download_logs(
request.repository.full_name,
request.check_run.run_id,
request.installation.id,
)
logs, user_message = clean_gh_logs(logs)
chat_logger = ChatLogger(
data={
"username": attributor,
"title": "[Sweep GHA Fix] Fix the failing GitHub Actions",
}
)
if chat_logger.use_faster_model() and not IS_SELF_HOSTED:
return {
"success": False,
"error_message": "Disabled for free users",
}
make_pr(
title=f"[Sweep GHA Fix] Fix the failing GitHub Actions on {request.check_run.head_sha[:7]} ({repo.default_branch})",
repo_description=repo.description,
summary=f"The GitHub Actions run failed with the following error logs:\n\n```\n{logs}\n```",
repo_full_name=request_dict["repository"]["full_name"],
installation_id=request_dict["installation"]["id"],
user_token=None,
use_faster_model=chat_logger.use_faster_model(),
username=attributor,
chat_logger=chat_logger,
)
case "pull_request", "opened":
_, g = get_github_client(request_dict["installation"]["id"])
repo = g.get_repo(request_dict["repository"]["full_name"])
pr = repo.get_pull(request_dict["pull_request"]["number"])
# if the pr already has a comment from sweep bot do nothing
time.sleep(10)
if any(
comment.user.login == GITHUB_BOT_USERNAME
for comment in pr.get_issue_comments()
) or pr.title.startswith("Sweep:"):
return {
"success": True,
"reason": "PR already has a comment from sweep bot",
}
rule_buttons = []
repo_rules = get_rules(repo) or []
if repo_rules != [""] and repo_rules != []:
for rule in repo_rules or []:
if rule:
rule_buttons.append(Button(label=f"{RULES_LABEL} {rule}"))
if len(repo_rules) == 0:
for rule in DEFAULT_RULES:
rule_buttons.append(Button(label=f"{RULES_LABEL} {rule}"))
if rule_buttons:
rules_buttons_list = ButtonList(
buttons=rule_buttons, title=RULES_TITLE
)
pr.create_issue_comment(rules_buttons_list.serialize() + BOT_SUFFIX)
if pr.mergeable is False and MERGE_CONFLICT_ENABLED:
attributor = pr.user.login
if attributor.endswith("[bot]"):
attributor = pr.assignee.login
if attributor.endswith("[bot]"):
return {
"success": False,
"error_message": "The PR was created by a bot, so I won't attempt to fix it.",
}
chat_logger = ChatLogger(
data={
"username": attributor,
"title": "[Sweep GHA Fix] Fix the failing GitHub Actions",
}
)
if chat_logger.use_faster_model() and not IS_SELF_HOSTED:
return {
"success": False,
"error_message": "Disabled for free users",
}
on_merge_conflict(
pr_number=pr.number,
username=attributor,
repo_full_name=request_dict["repository"]["full_name"],
installation_id=request_dict["installation"]["id"],
tracking_id=get_hash(),
)
case "issues", "opened":
request = IssueRequest(**request_dict)
issue_title_lower = request.issue.title.lower()
if (
issue_title_lower.startswith("sweep")
or "sweep:" in issue_title_lower
):
_, g = get_github_client(request.installation.id)
repo = g.get_repo(request.repository.full_name)
labels = repo.get_labels()
label_names = [label.name for label in labels]
if GITHUB_LABEL_NAME not in label_names:
repo.create_label(
name=GITHUB_LABEL_NAME,
color=GITHUB_LABEL_COLOR,
description=GITHUB_LABEL_DESCRIPTION,
)
current_issue = repo.get_issue(number=request.issue.number)
current_issue.add_to_labels(GITHUB_LABEL_NAME)
case "issue_comment", "edited":
request = IssueCommentRequest(**request_dict)
sweep_labeled_issue = GITHUB_LABEL_NAME in [
label.name.lower() for label in request.issue.labels
]
button_title_match = check_button_title_match(
REVERT_CHANGED_FILES_TITLE,
request.comment.body,
request.changes,
) or check_button_title_match(
RULES_TITLE,
request.comment.body,
request.changes,
)
if (
request.comment.user.type == "Bot"
and GITHUB_BOT_USERNAME in request.comment.user.login
and request.changes.body_from is not None
and button_title_match
and request.sender.type == "User"
and request.comment.user.login not in BLACKLISTED_USERS
):
run_on_button_click(request_dict)
restart_sweep = False
if (
request.comment.user.type == "Bot"
and GITHUB_BOT_USERNAME in request.comment.user.login
and request.changes.body_from is not None
and check_button_activated(
RESTART_SWEEP_BUTTON,
request.comment.body,
request.changes,
)
and sweep_labeled_issue
and request.sender.type == "User"
and request.comment.user.login not in BLACKLISTED_USERS
):
# Restart Sweep on this issue
restart_sweep = True
if (
request.issue is not None
and sweep_labeled_issue
and request.comment.user.type == "User"
and request.comment.user.login not in BLACKLISTED_USERS
and not request.comment.user.login.startswith("sweep")
and not (
request.issue.pull_request and request.issue.pull_request.url
)
or restart_sweep
):
logger.info("New issue comment edited")
request.issue.body = request.issue.body or ""
request.repository.description = (
request.repository.description or ""
)
if (
not request.comment.body.strip()
.lower()
.startswith(GITHUB_LABEL_NAME)
and not restart_sweep
):
logger.info("Comment does not start with 'Sweep', passing")
return {
"success": True,
"reason": "Comment does not start with 'Sweep', passing",
}
call_on_ticket(
title=request.issue.title,
summary=request.issue.body,
issue_number=request.issue.number,
issue_url=request.issue.html_url,
username=request.issue.user.login,
repo_full_name=request.repository.full_name,
repo_description=request.repository.description,
installation_id=request.installation.id,
comment_id=request.comment.id if not restart_sweep else None,
edited=True,
)
elif (
request.issue.pull_request
and request.comment.user.type == "User"
and request.comment.user.login not in BLACKLISTED_USERS
): # TODO(sweep): set a limit
logger.info(f"Handling comment on PR: {request.issue.pull_request}")
_, g = get_github_client(request.installation.id)
repo = g.get_repo(request.repository.full_name)
pr = repo.get_pull(request.issue.number)
labels = pr.get_labels()
comment = request.comment.body
if (
comment.lower().startswith("sweep:")
or any(label.name.lower() == "sweep" for label in labels)
) and BOT_SUFFIX not in comment:
pr_change_request = PRChangeRequest(
params={
"comment_type": "comment",
"repo_full_name": request.repository.full_name,
"repo_description": request.repository.description,
"comment": request.comment.body,
"pr_path": None,
"pr_line_position": None,
"username": request.comment.user.login,
"installation_id": request.installation.id,
"pr_number": request.issue.number,
"comment_id": request.comment.id,
},
)
call_on_comment(**pr_change_request.params)
case "issues", "edited":
request = IssueRequest(**request_dict)
if (
GITHUB_LABEL_NAME
in [label.name.lower() for label in request.issue.labels]
and request.sender.type == "User"
and request.sender.login not in BLACKLISTED_USERS
and not request.sender.login.startswith("sweep")
):
logger.info("New issue edited")
call_on_ticket(
title=request.issue.title,
summary=request.issue.body,
issue_number=request.issue.number,
issue_url=request.issue.html_url,
username=request.issue.user.login,
repo_full_name=request.repository.full_name,
repo_description=request.repository.description,
installation_id=request.installation.id,
comment_id=None,
)
else:
logger.info("Issue edited, but not a sweep issue")
case "issues", "labeled":
request = IssueRequest(**request_dict)
if (
any(
label.name.lower() == GITHUB_LABEL_NAME
for label in request.issue.labels
)
and not request.issue.pull_request
):
request.issue.body = request.issue.body or ""
request.repository.description = (
request.repository.description or ""
)
call_on_ticket(
title=request.issue.title,
summary=request.issue.body,
issue_number=request.issue.number,
issue_url=request.issue.html_url,
username=request.issue.user.login,
repo_full_name=request.repository.full_name,
repo_description=request.repository.description,
installation_id=request.installation.id,
comment_id=None,
)
case "issue_comment", "created":
request = IssueCommentRequest(**request_dict)
if (
request.issue is not None
and GITHUB_LABEL_NAME
in [label.name.lower() for label in request.issue.labels]
and request.comment.user.type == "User"
and request.comment.user.login not in BLACKLISTED_USERS
and not (
request.issue.pull_request and request.issue.pull_request.url
)
and BOT_SUFFIX not in request.comment.body
):
request.issue.body = request.issue.body or ""
request.repository.description = (
request.repository.description or ""
)
if (
not request.comment.body.strip()
.lower()
.startswith(GITHUB_LABEL_NAME)
):
logger.info("Comment does not start with 'Sweep', passing")
return {
"success": True,
"reason": "Comment does not start with 'Sweep', passing",
}
call_on_ticket(
title=request.issue.title,
summary=request.issue.body,
issue_number=request.issue.number,
issue_url=request.issue.html_url,
username=request.issue.user.login,
repo_full_name=request.repository.full_name,
repo_description=request.repository.description,
installation_id=request.installation.id,
comment_id=request.comment.id,
)
elif (
request.issue.pull_request
and request.comment.user.type == "User"
and request.comment.user.login not in BLACKLISTED_USERS
and BOT_SUFFIX not in request.comment.body
): # TODO(sweep): set a limit
_, g = get_github_client(request.installation.id)
repo = g.get_repo(request.repository.full_name)
pr = repo.get_pull(request.issue.number)
labels = pr.get_labels()
comment = request.comment.body
if (
(
comment.lower().startswith("sweep:")
or any(label.name.lower() == "sweep" for label in labels)
)
and BOT_SUFFIX not in comment
):
pr_change_request = PRChangeRequest(
params={
"comment_type": "comment",
"repo_full_name": request.repository.full_name,
"repo_description": request.repository.description,
"comment": request.comment.body,
"pr_path": None,
"pr_line_position": None,
"username": request.comment.user.login,
"installation_id": request.installation.id,
"pr_number": request.issue.number,
"comment_id": request.comment.id,
},
)
call_on_comment(**pr_change_request.params)
case "pull_request_review_comment", "created":
request = CommentCreatedRequest(**request_dict)
_, g = get_github_client(request.installation.id)
repo = g.get_repo(request.repository.full_name)
pr = repo.get_pull(request.pull_request.number)
labels = pr.get_labels()
comment = request.comment.body
if (
(
comment.lower().startswith("sweep:")
or any(label.name.lower() == "sweep" for label in labels)
)
and request.comment.user.type == "User"
and request.comment.user.login not in BLACKLISTED_USERS
and BOT_SUFFIX not in comment
):
pr_change_request = PRChangeRequest(
params={
"comment_type": "comment",
"repo_full_name": request.repository.full_name,
"repo_description": request.repository.description,
"comment": request.comment.body,
"pr_path": request.comment.path,
"pr_line_position": request.comment.original_line,
"username": request.comment.user.login,
"installation_id": request.installation.id,
"pr_number": request.pull_request.number,
"comment_id": request.comment.id,
},
)
call_on_comment(**pr_change_request.params)
case "pull_request_review_comment", "edited":
request = CommentCreatedRequest(**request_dict)
_, g = get_github_client(request.installation.id)
repo = g.get_repo(request.repository.full_name)
pr = repo.get_pull(request.pull_request.number)
labels = pr.get_labels()
comment = request.comment.body
if (
(
comment.lower().startswith("sweep:")
or any(label.name.lower() == "sweep" for label in labels)
)
and request.comment.user.type == "User"
and request.comment.user.login not in BLACKLISTED_USERS
and BOT_SUFFIX not in comment
):
pr_change_request = PRChangeRequest(
params={
"comment_type": "comment",
"repo_full_name": request.repository.full_name,
"repo_description": request.repository.description,
"comment": request.comment.body,
"pr_path": request.comment.path,
"pr_line_position": request.comment.original_line,
"username": request.comment.user.login,
"installation_id": request.installation.id,
"pr_number": request.pull_request.number,
"comment_id": request.comment.id,
},
)
call_on_comment(**pr_change_request.params)
case "installation_repositories", "added":
repos_added_request = ReposAddedRequest(**request_dict)
metadata = {
"installation_id": repos_added_request.installation.id,
"repositories": [
repo.full_name
for repo in repos_added_request.repositories_added
],
}
try:
add_config_to_top_repos(
repos_added_request.installation.id,
repos_added_request.installation.account.login,
repos_added_request.repositories_added,
)
except Exception as e:
logger.exception(f"Failed to add config to top repos: {e}")
posthog.capture(
"installation_repositories",
"started",
properties={**metadata},
)
for repo in repos_added_request.repositories_added:
organization, repo_name = repo.full_name.split("/")
posthog.capture(
organization,
"installed_repository",
properties={
"repo_name": repo_name,
"organization": organization,
"repo_full_name": repo.full_name,
},
)
case "installation", "created":
repos_added_request = InstallationCreatedRequest(**request_dict)
try:
add_config_to_top_repos(
repos_added_request.installation.id,
repos_added_request.installation.account.login,
repos_added_request.repositories,
)
except Exception as e:
logger.exception(f"Failed to add config to top repos: {e}")
case "pull_request", "edited":
request = PREdited(**request_dict)
if (
request.pull_request.user.login == GITHUB_BOT_USERNAME
and not request.sender.login.endswith("[bot]")
and DISCORD_FEEDBACK_WEBHOOK_URL is not None
):
good_button = check_button_activated(
SWEEP_GOOD_FEEDBACK,
request.pull_request.body,
request.changes,
)
bad_button = check_button_activated(
SWEEP_BAD_FEEDBACK,
request.pull_request.body,
request.changes,
)
if good_button or bad_button:
emoji = "😕"
if good_button:
emoji = "👍"
elif bad_button:
emoji = "👎"
data = {
"content": f"{emoji} {request.pull_request.html_url} ({request.sender.login})\n{request.pull_request.commits} commits, {request.pull_request.changed_files} files: +{request.pull_request.additions}, -{request.pull_request.deletions}"
}
headers = {"Content-Type": "application/json"}
requests.post(
DISCORD_FEEDBACK_WEBHOOK_URL,
data=json.dumps(data),
headers=headers,
)
# Send feedback to PostHog
posthog.capture(
request.sender.login,
"feedback",
properties={
"repo_name": request.repository.full_name,
"pr_url": request.pull_request.html_url,
"pr_commits": request.pull_request.commits,
"pr_additions": request.pull_request.additions,
"pr_deletions": request.pull_request.deletions,
"pr_changed_files": request.pull_request.changed_files,
"username": request.sender.login,
"good_button": good_button,
"bad_button": bad_button,
},
)
def remove_buttons_from_description(body):
"""
Replace:
### PR Feedback...
...
# (until it hits the next #)
with
### PR Feedback: {emoji}
#
"""
lines = body.split("\n")
if not lines[0].startswith("### PR Feedback"):
return None
# Find when the second # occurs
i = 0
for i, line in enumerate(lines):
if line.startswith("#") and i > 0:
break
return "\n".join(
[
f"### PR Feedback: {emoji}",
*lines[i:],
]
)
# Update PR description to remove buttons
try:
_, g = get_github_client(request.installation.id)
repo = g.get_repo(request.repository.full_name)
pr = repo.get_pull(request.pull_request.number)
new_body = remove_buttons_from_description(
request.pull_request.body
)
if new_body is not None:
pr.edit(body=new_body)
except SystemExit:
raise
except Exception as e:
logger.exception(f"Failed to edit PR description: {e}")
case "pull_request", "closed":
pr_request = PRRequest(**request_dict)
(
organization,
repo_name,
) = pr_request.repository.full_name.split("/")
commit_author = pr_request.pull_request.user.login
merged_by = (
pr_request.pull_request.merged_by.login
if pr_request.pull_request.merged_by
else None
)
if CURRENT_USERNAME == commit_author and merged_by is not None:
event_name = "merged_sweep_pr"
if pr_request.pull_request.title.startswith("[config]"):
event_name = "config_pr_merged"
elif pr_request.pull_request.title.startswith("[Sweep Rules]"):
event_name = "sweep_rules_pr_merged"
_token, g = get_github_client(pr_request.installation.id)
pr = g.get_repo(pr_request.repository.full_name).get_pull(
pr_request.number
)
total_lines_in_commit = 0
total_lines_edited_by_developer = 0
edited_by_developers = False
for commit in pr.get_commits():
lines_modified = commit.stats.additions + commit.stats.deletions
total_lines_in_commit += lines_modified
if commit.author.login != CURRENT_USERNAME:
total_lines_edited_by_developer += lines_modified
# this was edited by a developer if at least 25% of the lines were edited by a developer
edited_by_developers = total_lines_in_commit > 0 and (total_lines_edited_by_developer / total_lines_in_commit) >= 0.25
posthog.capture(
merged_by,
event_name,
properties={
"repo_name": repo_name,
"organization": organization,
"repo_full_name": pr_request.repository.full_name,
"username": merged_by,
"additions": pr_request.pull_request.additions,
"deletions": pr_request.pull_request.deletions,
"total_changes": pr_request.pull_request.additions
+ pr_request.pull_request.deletions,
"edited_by_developers": edited_by_developers,
"total_lines_in_commit": total_lines_in_commit,
"total_lines_edited_by_developer": total_lines_edited_by_developer,
},
)
chat_logger = ChatLogger({"username": merged_by})
case "push", None:
if event != "pull_request" or request_dict["base"]["merged"] is True:
chat_logger = ChatLogger(
{"username": request_dict["pusher"]["name"]}
)
# on merge
call_on_merge(request_dict, chat_logger)
ref = request_dict["ref"] if "ref" in request_dict else ""
if ref.startswith("refs/heads") and not ref.startswith(
"ref/heads/sweep"
):
_, g = get_github_client(request_dict["installation"]["id"])
repo = g.get_repo(request_dict["repository"]["full_name"])
if ref[len("refs/heads/") :] == SweepConfig.get_branch(repo):
update_sweep_prs_v2(
request_dict["repository"]["full_name"],
installation_id=request_dict["installation"]["id"],
)
if ref.startswith("refs/heads"):
branch_name = ref[len("refs/heads/") :]
# Check if the branch has an associated PR
org_name, repo_name = request_dict["repository"][
"full_name"
].split("/")
pulls = repo.get_pulls(
state="open",
sort="created",
head=org_name + ":" + branch_name,
)
for pr in pulls:
logger.info(
f"PR associated with branch {branch_name}: #{pr.number} - {pr.title}"
)
if pr.mergeable is False and MERGE_CONFLICT_ENABLED:
attributor = pr.user.login
if attributor.endswith("[bot]"):
attributor = pr.assignee.login
if attributor.endswith("[bot]"):
return {
"success": False,
"error_message": "The PR was created by a bot, so I won't attempt to fix it.",
}
chat_logger = ChatLogger(
data={
"username": attributor,
"title": "[Sweep GHA Fix] Fix the failing GitHub Actions",
}
)
if (
chat_logger.use_faster_model()
and not IS_SELF_HOSTED
):
return {
"success": False,
"error_message": "Disabled for free users",
}
on_merge_conflict(
pr_number=pr.number,
username=pr.user.login,
repo_full_name=request_dict["repository"][
"full_name"
],
installation_id=request_dict["installation"]["id"],
tracking_id=get_hash(),
)
case "ping", None:
return {"message": "pong"}
case _:

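The `pull_request` / `closed` handler above treats a merged Sweep PR as "edited by developers" when at least 25% of its changed lines (additions plus deletions) come from commits not authored by the bot. Below is a minimal sketch of that calculation as a pure function. It assumes each commit has been reduced to a hypothetical `(author_login, additions, deletions)` tuple instead of a PyGithub `Commit` object.

```python
from typing import Iterable, Optional, Tuple

# Hypothetical simplified commit record: (author_login, additions, deletions).
# author_login may be None when GitHub cannot link the commit to an account.
CommitStats = Tuple[Optional[str], int, int]


def developer_edit_share(
    commits: Iterable[CommitStats], bot_login: str, threshold: float = 0.25
) -> Tuple[int, int, bool]:
    """Return (total_lines, developer_lines, edited_by_developers)."""
    total_lines = 0
    developer_lines = 0
    for author_login, additions, deletions in commits:
        lines_modified = additions + deletions
        total_lines += lines_modified
        # Any commit not authored by the bot (including author None) counts as developer work.
        if author_login != bot_login:
            developer_lines += lines_modified
    # Guard against division by zero for PRs with no changed lines.
    edited = total_lines > 0 and developer_lines / total_lines >= threshold
    return total_lines, developer_lines, edited
```

For example, `developer_edit_share([("sweep-nightly[bot]", 60, 15), ("alice", 20, 5)], "sweep-nightly[bot]")` yields 25 developer lines out of 100, which is exactly the 25% threshold. This sketch counts a missing author as developer work. The handler above reads `commit.author.login` directly instead, and that may fail for commits that are not linked to a GitHub account.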
# Advanced Features: becoming a Power User 🧠
## Usage 📖
### Mention important files
To ensure that Sweep scans a file, mention the file name in your ticket. Sweep searches for relevant files at runtime, but specifying the file helps avoid missing important details.
### Giving Sweep feedback
If Sweep's plan isn't accurate, you can respond to Sweep in three places:
1. **Issue**: Sweep will create a new pull request and close the old one. Alternatively, you can edit the issue description to recreate the pull request.
2. **Pull request**: Sweep will update the PR based on your PR comments.
3. **Code**: Sweep will only update the file that the comment is on.
Whenever Sweep is working on one of your messages, you will see a 👀 emoji. If you don't see it, make sure the PR or issue is open and that you prefixed the message with "sweep:".
Further, when a GitHub Actions run fails, Sweep will update the PR based on the error message.
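The webhook handlers for PR comments decide whether to respond using four checks:

* The comment starts with `sweep:` or the PR carries the `sweep` label.
* The author is a human user.
* The author is not blacklisted.
* The comment is not one of Sweep's own replies, which is detected via `BOT_SUFFIX`.

Below is a minimal sketch of that check as a standalone function. The `BOT_SUFFIX` value is a placeholder assumption, not Sweep's real marker.

```python
from typing import Iterable

# Placeholder assumption: Sweep appends some marker to its own comments; the real value differs.
BOT_SUFFIX = "<!-- sweep-bot -->"


def should_handle_pr_comment(
    comment: str,
    label_names: Iterable[str],
    author_type: str,
    author_login: str,
    blacklisted_users: Iterable[str] = (),
) -> bool:
    # Trigger on an explicit "sweep:" prefix or on a PR labelled "sweep".
    triggered = comment.lower().startswith("sweep:") or any(
        name.lower() == "sweep" for name in label_names
    )
    return (
        triggered
        and author_type == "User"  # ignore bots and apps
        and author_login not in set(blacklisted_users)
        and BOT_SUFFIX not in comment  # never react to Sweep's own replies
    )
```

Grouping the two trigger conditions before the `and` checks matters. Without the grouping, Python's precedence would apply the `BOT_SUFFIX` check only to the label branch.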
### Switch branch
To get Sweep to use a different base branch for one issue, add the following to the issue description.
> branch: BRANCH_NAME
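The issue description is plain text, so a line-anchored regular expression can locate the directive. Below is a minimal sketch under two assumptions: the directive sits on its own line, and it may carry the `> ` quote marker shown above. `parse_branch_override` is a hypothetical helper, not Sweep's actual parser.

```python
import re
from typing import Optional

# Hypothetical pattern: a line of the form "branch: NAME", optionally quoted with ">".
BRANCH_DIRECTIVE = re.compile(
    r"^\s*(?:>\s*)?branch:\s*(\S+)\s*$", re.IGNORECASE | re.MULTILINE
)


def parse_branch_override(issue_body: Optional[str]) -> Optional[str]:
    """Return the requested base branch, or None to fall back to the default branch."""
    if not issue_body:
        return None
    match = BRANCH_DIRECTIVE.search(issue_body)
    return match.group(1) if match else None
```

Returning `None` when the directive is absent lets the caller fall back to the `branch` set in `sweep.yaml`.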
## Configuration 🛠️
### Use GitHub Actions
We highly recommend setting up linters as well as Netlify or Vercel preview builds. Sweep auto-corrects based on linter and build errors. Netlify and Vercel shorten iteration cycles by providing previews of static sites.
### Set up `sweep.yaml`
You can set up `sweep.yaml` to
* Provide up-to-date docs by setting up `docs` (https://docs.sweep.dev/usage/config#docs)
* Set up automated formatting and linting by setting up `sandbox` (https://docs.sweep.dev/usage/config#sandbox). Never have Sweep commit a failing `npm lint` again.
* Give Sweep a high level description of where to find files in your repo by editing the `repo_description` field.
For more on configs, check out https://docs.sweep.dev/usage/config.
## Prompting 🗣️
The amount of prompting you need to give Sweep scales directly with the complexity of the problem.
For harder problems, try to provide the same information a human would need, and for simpler problems, providing a single line and a file name should suffice.
### Prompting formats
A good issue should include **where to look** (file name or entity name), **what to do** ("change the logic to do this"), and **additional context** (there's a bug/we need this feature/there's this dependency). Examples:

`sweepai/cli.py` in the `sweep` repository, lines 1 to 363 at commit `0277fad`:

import datetime
import json
import os
import pickle
import threading
import time
import uuid
from itertools import chain, islice
import typer
from github import Github
from github.Event import Event
from github.IssueEvent import IssueEvent
from github.Repository import Repository
from loguru import logger
from rich.console import Console
from rich.prompt import Prompt
from sweepai.api import handle_request
from sweepai.handlers.on_ticket import on_ticket
from sweepai.utils.event_logger import posthog
from sweepai.utils.github_utils import get_github_client
from sweepai.utils.str_utils import get_hash
from sweepai.web.events import Account, Installation, IssueRequest
app = typer.Typer(
name="sweepai", context_settings={"help_option_names": ["-h", "--help"]}
)
app_dir = typer.get_app_dir("sweepai")
config_path = os.path.join(app_dir, "config.json")
console = Console()
cprint = console.print
def posthog_capture(event_name, properties, *args, **kwargs):
POSTHOG_DISTINCT_ID = os.environ.get("POSTHOG_DISTINCT_ID")
if POSTHOG_DISTINCT_ID:
posthog.capture(POSTHOG_DISTINCT_ID, event_name, properties, *args, **kwargs)
def load_config():
if os.path.exists(config_path):
cprint(f"\nLoading configuration from {config_path}", style="yellow")
with open(config_path, "r") as f:
config = json.load(f)
os.environ["GITHUB_PAT"] = config.get("GITHUB_PAT", "")
os.environ["OPENAI_API_KEY"] = config.get("OPENAI_API_KEY", "")
os.environ["ANTHROPIC_API_KEY"] = config.get("ANTHROPIC_API_KEY", "")
os.environ["VOYAGE_API_KEY"] = config.get("VOYAGE_API_KEY", "")
os.environ["POSTHOG_DISTINCT_ID"] = str(config.get("POSTHOG_DISTINCT_ID", ""))
def fetch_issue_request(issue_url: str, __version__: str = "0"):
(
protocol_name,
_,
_base_url,
org_name,
repo_name,
_issues,
issue_number,
) = issue_url.split("/")
cprint("Fetching installation ID...")
installation_id = -1
cprint("Fetching access token...")
_token, g = get_github_client(installation_id)
g: Github = g
cprint("Fetching repo...")
issue = g.get_repo(f"{org_name}/{repo_name}").get_issue(int(issue_number))
issue_request = IssueRequest(
action="labeled",
issue=IssueRequest.Issue(
title=issue.title,
number=int(issue_number),
html_url=issue_url,
user=IssueRequest.Issue.User(
login=issue.user.login,
type="User",
),
body=issue.body,
labels=[
IssueRequest.Issue.Label(
name="sweep",
),
],
assignees=None,
pull_request=None,
),
repository=IssueRequest.Issue.Repository(
full_name=issue.repository.full_name,
description=issue.repository.description,
),
assignee=IssueRequest.Issue.Assignee(login=issue.user.login),
installation=Installation(
id=installation_id,
account=Account(
id=issue.user.id,
login=issue.user.login,
type="User",
),
),
sender=IssueRequest.Issue.User(
login=issue.user.login,
type="User",
),
)
return issue_request
def pascal_to_snake(name):
return "".join(["_" + i.lower() if i.isupper() else i for i in name]).lstrip("_")
def get_event_type(event: Event | IssueEvent):
if isinstance(event, IssueEvent):
return "issues"
else:
return pascal_to_snake(event.type)[: -len("_event")]
@app.command()
def test():
cprint("Sweep AI is installed correctly and ready to go!", style="yellow")
@app.command()
def watch(
repo_name: str,
debug: bool = False,
record_events: bool = False,
max_events: int = 30,
):
if not os.path.exists(config_path):
cprint(
f"\nConfiguration not found at {config_path}. Please run [green]'sweep init'[/green] to initialize the CLI.\n",
style="yellow",
)
raise ValueError(
"Configuration not found, please run 'sweep init' to initialize the CLI."
)
posthog_capture(
"sweep_watch_started",
{
"repo": repo_name,
"debug": debug,
"record_events": record_events,
"max_events": max_events,
},
)
GITHUB_PAT = os.environ.get("GITHUB_PAT", None)
if GITHUB_PAT is None:
raise ValueError("GITHUB_PAT environment variable must be set")
g = Github(os.environ["GITHUB_PAT"])
repo = g.get_repo(repo_name)
if debug:
logger.debug("Debug mode enabled")
def stream_events(repo: Repository, timeout: int = 2, offset: int = 2 * 60):
processed_event_ids = set()
current_time = time.time() - offset
current_time = datetime.datetime.fromtimestamp(current_time)
local_tz = datetime.datetime.now(datetime.timezone.utc).astimezone().tzinfo
while True:
events_iterator = chain(
islice(repo.get_events(), max_events),
islice(repo.get_issues_events(), max_events),
)
for i, event in enumerate(events_iterator):
if event.id not in processed_event_ids:
local_time = event.created_at.replace(
tzinfo=datetime.timezone.utc
).astimezone(local_tz)
if local_time.timestamp() > current_time.timestamp():
yield event
else:
if debug:
logger.debug(
f"Skipping event {event.id} because it is in the past (local_time={local_time}, current_time={current_time}, i={i})"
)
if debug:
logger.debug(
f"Skipping event {event.id} because it is already handled"
)
processed_event_ids.add(event.id)
time.sleep(timeout)
def handle_event(event: Event | IssueEvent, do_async: bool = True):
if isinstance(event, IssueEvent):
payload = event.raw_data
payload["action"] = payload["event"]
else:
payload = {**event.raw_data, **event.payload}
payload["sender"] = payload.get("sender", payload["actor"])
payload["sender"]["type"] = "User"
payload["pusher"] = payload.get("pusher", payload["actor"])
payload["pusher"]["name"] = payload["pusher"]["login"]
payload["pusher"]["type"] = "User"
payload["after"] = payload.get("after", payload.get("head"))
payload["repository"] = repo.raw_data
payload["installation"] = {"id": -1}
logger.info(str(event) + " " + str(event.created_at))
if record_events:
_type = get_event_type(event) if isinstance(event, Event) else "issue"
with open(
"tests/events/"
+ f"{_type}_{payload.get('action')}_{str(event.id)}.pkl",
"wb",
) as f:
pickle.dump(event, f)
if do_async:
thread = threading.Thread(
target=handle_request, args=(payload, get_event_type(event))
)
thread.start()
return thread
else:
return handle_request(payload, get_event_type(event))
def main():
cprint(
f"\n[bold black on white] Starting server, listening to events from {repo_name}... [/bold black on white]\n",
)
cprint(
f"To create a PR, please create an issue at https://github.com/{repo_name}/issues with a title prefixed with 'Sweep:' or label an existing issue with 'sweep'. The events will be logged here, but there may be a brief delay.\n"
)
for event in stream_events(repo):
handle_event(event)
main()
@app.command()
def init(override: bool = False):
# TODO: Fix telemetry
if not override:
if os.path.exists(config_path):
with open(config_path, "r") as f:
config = json.load(f)
if "OPENAI_API_KEY" in config and "ANTHROPIC_API_KEY" in config and "GITHUB_PAT" in config:
override = typer.confirm(
f"\nConfiguration already exists at {config_path}. Override?",
default=False,
abort=True,
)
cprint(
"\n[bold black on white] Initializing Sweep CLI... [/bold black on white]\n",
)
cprint(
"\nFirstly, let's store your OpenAI API Key. You can get it here: https://platform.openai.com/api-keys\n",
style="yellow",
)
openai_api_key = Prompt.ask("OpenAI API Key", password=True)
assert len(openai_api_key) > 30, "OpenAI API Key must be of length at least 30."
assert openai_api_key.startswith("sk-"), "OpenAI API Key must start with 'sk-'."
cprint(
"\nNext, let's store your Anthropic API key. You can get it here: https://console.anthropic.com/settings/keys.",
style="yellow",
)
anthropic_api_key = Prompt.ask("Anthropic API Key", password=True)
assert len(anthropic_api_key) > 30, "Anthropic API Key must be of length at least 30."
assert anthropic_api_key.startswith("sk-ant-api03-"), "Anthropic API Key must start with 'sk-ant-api03-'."
cprint(
"\nGreat! Next, we'll need just your GitHub PAT. Here's a link with all the permissions pre-filled:\nhttps://github.com/settings/tokens/new?description=Sweep%20Self-hosted&scopes=repo,workflow\n",
style="yellow",
)
github_pat = Prompt.ask("GitHub PAT", password=True)
assert len(github_pat) > 30, "GitHub PAT must be of length at least 30."
assert github_pat.startswith("ghp_"), "GitHub PAT must start with 'ghp_'."
cprint(
"\nAwesome! Lastly, let's get your Voyage AI API key from https://dash.voyageai.com/api-keys. This is optional, but improves code search by about [cyan]5%[/cyan]. You can always return to this later by re-running 'sweep init'.",
style="yellow",
)
voyage_api_key = Prompt.ask("Voyage AI API key", password=True)
if voyage_api_key:
assert len(voyage_api_key) > 30, "Voyage AI API key must be of length at least 30."
assert voyage_api_key.startswith("pa-"), "Voyage API key must start with 'pa-'."
POSTHOG_DISTINCT_ID = None
enable_telemetry = typer.confirm(
"\nEnable usage statistics? This will help us improve the product.",
default=True,
)
if enable_telemetry:
cprint(
"\nThank you for enabling telemetry. We'll collect anonymous usage statistics to improve the product. You can disable this at any time by rerunning 'sweep init'.",
style="yellow",
)
POSTHOG_DISTINCT_ID = uuid.getnode()
posthog.capture(POSTHOG_DISTINCT_ID, "sweep_init", {})
config = {
"GITHUB_PAT": github_pat,
"OPENAI_API_KEY": openai_api_key,
"ANTHROPIC_API_KEY": anthropic_api_key,
"VOYAGE_API_KEY": voyage_api_key,
}
if POSTHOG_DISTINCT_ID:
config["POSTHOG_DISTINCT_ID"] = POSTHOG_DISTINCT_ID
os.makedirs(app_dir, exist_ok=True)
with open(config_path, "w") as f:
json.dump(config, f)
cprint(f"\nConfiguration saved to {config_path}\n", style="yellow")
cprint(
"Installation complete! You can now run [green]'sweep run <issue-url>'[/green][yellow] to run Sweep on an issue, or [/yellow][green]'sweep watch <org-name>/<repo-name>'[/green] to have Sweep listen for and fix newly created GitHub issues.",
style="yellow",
)
@app.command()
def run(issue_url: str):
if not os.path.exists(config_path):
cprint(
f"\nConfiguration not found at {config_path}. Please run [green]'sweep init'[/green] to initialize the CLI.\n",
style="yellow",
)
raise ValueError(
"Configuration not found, please run 'sweep init' to initialize the CLI."
)
cprint(f"\n Running Sweep on issue: {issue_url} \n", style="bold black on white")
posthog_capture("sweep_run_started", {"issue_url": issue_url})
request = fetch_issue_request(issue_url)
try:
cprint(f'\nRunning Sweep to solve "{request.issue.title}"!\n')
on_ticket(
title=request.issue.title,
summary=request.issue.body,
issue_number=request.issue.number,
issue_url=request.issue.html_url,
username=request.sender.login,
repo_full_name=request.repository.full_name,
repo_description=request.repository.description,
installation_id=request.installation.id,
comment_id=None,
edited=False,
tracking_id=get_hash(),
)
except Exception as e:
posthog_capture("sweep_run_fail", {"issue_url": issue_url, "error": str(e)})
else:
posthog_capture("sweep_run_success", {"issue_url": issue_url})
def main():
cprint(
"By using the Sweep CLI, you agree to the Sweep AI Terms of Service at https://sweep.dev/tos.pdf",
style="cyan",
)
load_config()
app()
if __name__ == "__main__":

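`watch` replays repository events through `handle_request`. The `match` arms there use webhook-style names such as `issue_comment` and `pull_request_review_comment`. `get_event_type` derives those names from the PyGithub event type: it converts PascalCase to snake_case and slices off the `_event` suffix. Below is a standalone sketch of the same conversion that runs without a GitHub connection. `webhook_event_name` is a hypothetical name, and unlike the original, it only strips the suffix when the suffix is present.

```python
def pascal_to_snake(name: str) -> str:
    # Prefix every uppercase letter with "_" and lowercase it, then trim the leading "_".
    return "".join("_" + c.lower() if c.isupper() else c for c in name).lstrip("_")


def webhook_event_name(pygithub_type: str) -> str:
    # e.g. "IssueCommentEvent" -> "issue_comment_event" -> "issue_comment"
    snake = pascal_to_snake(pygithub_type)
    suffix = "_event"
    return snake[: -len(suffix)] if snake.endswith(suffix) else snake
```

With this mapping, a PyGithub `PushEvent` reaches the `case "push", None:` arm of the handler.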
# Frequently Asked Questions
<details id="does-sweep-write-tests">
<summary>Does Sweep write tests?</summary>
Yep! The easiest way to have Sweep write tests is by modifying the `description` parameter in your `sweep.yaml`. You can add something like:
“In [your repository], the tests are written in [your format]. If you modify business logic, modify the tests as well using this format.” You can add anything you’d like to the description parameter, including formatting rules (like PEP8), code style, etc!
</details>
<details id="can-we-trust-code-written-by-sweep">
<summary>Can we trust the code written by Sweep?</summary>
You should always review the PR. However, we also test the PR against your existing GitHub Actions to make sure it works.
To get the best performance, add GitHub Actions that lint, test, and validate your code.
</details>
<details id="work-off-another-branch">
<summary>Can I have Sweep work off of another branch besides main?</summary>
Yes! In the `sweep.yaml`, you can set the `branch` parameter to something besides your default branch, and Sweep will use that as a reference.
</details>
<details id="retry-issue-with-sweep">
<summary>How do I retry an issue with Sweep?</summary>
To retry an issue, prefix your issue reply with 'Sweep: '. This will trigger Sweep to retry the issue.
</details>
<details id="give-documentation-to-sweep">
<summary>Can I give documentation to Sweep?</summary>
Yes! In the `sweep.yaml`, you can specify docs. Pick the URL prefix of the documentation site so that we only fetch the docs you need.
Check out the example here: https://github.com/sweepai/sweep/blob/main/sweep.yaml.
</details>
<details id="comment-on-sweeps-prs">
<summary>Can I comment on Sweep’s PRs?</summary>
Yep! You have three options depending on the degree of the change:
1. You can comment on the issue, and Sweep will rewrite the entire pull request. This will use one of your GPT-4 credits.
2. You can comment on the pull request (not a file) and Sweep can make substantial changes to the pull request. Sweep will search the codebase, and is able to modify and create files.
3. You can comment on the file directly, and Sweep will only modify that file. Use this for small single file changes.
</details>

Once Sweep has the reference implementation, it generates the corresponding tests as commits in a [GitHub PR](https://github.com/sweepai/sweep/pull/2378):
```python
def get_file_contents(self, file_path, ref=None):
local_path = os.path.join(self.cache_dir, file_path)
if os.path.exists(local_path):
with open(local_path, "r", encoding="utf-8", errors="replace") as f:
contents = f.read()
return contents
else:
raise FileNotFoundError(f"{local_path} does not exist.")
```
We have Sweep generate mocks for `os.path.join` and `open`. <br></br>
This code looks great!
```python
@patch("os.path.join")
@patch("open")
def test_get_file_contents(self, mock_open, mock_join):
mock_join.return_value = "/tmp/cache/repos/sweepai/sweep/main/file1"
mock_open.return_value.__enter__.return_value.read.return_value = "file content"
content = self.cloned_repo.get_file_contents("file1")
self.assertEqual(content, "file content")
```
We generated mocks for `os.path.join` and `open`, which should return the correct path and file contents. <br></br>
OK, we're done here, right? Can we just commit these tests and leave the rest to the developer?
## 3. **Run the tests.**
Most other AI tools stop here, but it’s not enough. <br></br>
Committing these tests as-is would look great, but you’d end up with a frustrating bug. Here it is:
```bash
File "/usr/lib/python3.10/unittest/mock.py", line 1616, in _get_target
raise TypeError(
TypeError: Need a valid target to patch. You supplied: 'open'
```
Did we really save time for the developer here? It’s frustrating that most other tools don’t fix these issues.
*Unlike every other tool, Sweep actually runs these tests.*
Sweep ran the code, found the issue, and identified the solution: <br></br>
**“Change the target of the patch in the 'test_get_file_contents' method from 'open' to 'builtins.open'. This will correctly patch the built-in 'open' function during the test.”**
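Below is a self-contained sketch of the corrected test. It assumes a stand-in `CachedRepo` class that reproduces the `get_file_contents` method shown earlier; `CachedRepo` is hypothetical, not Sweep's actual class. Besides targeting `builtins.open`, the sketch also patches `os.path.exists`. Without that, the method above raises `FileNotFoundError` whenever the cached file is absent on the test machine.

```python
import os
import unittest
from unittest.mock import mock_open, patch


class CachedRepo:
    """Hypothetical stand-in for the class under test."""

    def __init__(self, cache_dir):
        self.cache_dir = cache_dir

    def get_file_contents(self, file_path, ref=None):
        local_path = os.path.join(self.cache_dir, file_path)
        if os.path.exists(local_path):
            with open(local_path, "r", encoding="utf-8", errors="replace") as f:
                return f.read()
        raise FileNotFoundError(f"{local_path} does not exist.")


class TestGetFileContents(unittest.TestCase):
    def setUp(self):
        self.cloned_repo = CachedRepo("/tmp/cache/repos/sweepai/sweep/main")

    # Patch the built-in open (not a module attribute named "open").
    # Decorators apply bottom-up, so mock_exists is the first mock argument.
    @patch("builtins.open", new_callable=mock_open, read_data="file content")
    @patch("os.path.exists", return_value=True)
    def test_get_file_contents(self, mock_exists, mock_file):
        content = self.cloned_repo.get_file_contents("file1")
        self.assertEqual(content, "file content")
        mock_file.assert_called_once_with(
            "/tmp/cache/repos/sweepai/sweep/main/file1",
            "r",
            encoding="utf-8",
            errors="replace",
        )
```

Save it to a file and run it with `python -m unittest` from that directory.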
Sweep added [this commit](https://github.com/sweepai/sweep/pull/2378/commits/0ded79eab77ca3e511257ff0bf3874893b038e9e):
```python

import base64
import os
from dotenv import load_dotenv
from loguru import logger
logger.print = logger.info
load_dotenv(dotenv_path=".env", override=True, verbose=True)
os.environ["GITHUB_APP_PEM"] = os.environ.get("GITHUB_APP_PEM") or base64.b64decode(
os.environ.get("GITHUB_APP_PEM_BASE64", "")
).decode("utf-8")
if os.environ["GITHUB_APP_PEM"]:
os.environ["GITHUB_APP_ID"] = (
(os.environ.get("GITHUB_APP_ID") or os.environ.get("APP_ID"))
.replace("\\n", "\n")
.strip('"')
)
os.environ["TRANSFORMERS_CACHE"] = os.environ.get(
"TRANSFORMERS_CACHE", "/tmp/cache/model"
) # vector_db.py
os.environ["TIKTOKEN_CACHE_DIR"] = os.environ.get(
"TIKTOKEN_CACHE_DIR", "/tmp/cache/tiktoken"
) # utils.py
SENTENCE_TRANSFORMERS_MODEL = os.environ.get(
"SENTENCE_TRANSFORMERS_MODEL",
"sentence-transformers/all-MiniLM-L6-v2", # "all-mpnet-base-v2"
)
TEST_BOT_NAME = "sweep-nightly[bot]"
ENV = os.environ.get("ENV", "dev")
# ENV = os.environ.get("MODAL_ENVIRONMENT", "dev")
# ENV = PREFIX
# ENVIRONMENT = PREFIX
DB_MODAL_INST_NAME = "db"
DOCS_MODAL_INST_NAME = "docs"
API_MODAL_INST_NAME = "api"
UTILS_MODAL_INST_NAME = "utils"
BOT_TOKEN_NAME = "bot-token"
# goes under Modal 'discord' secret name (optional, can leave env var blank)
DISCORD_WEBHOOK_URL = os.environ.get("DISCORD_WEBHOOK_URL")
DISCORD_MEDIUM_PRIORITY_URL = os.environ.get("DISCORD_MEDIUM_PRIORITY_URL")
DISCORD_LOW_PRIORITY_URL = os.environ.get("DISCORD_LOW_PRIORITY_URL")
DISCORD_FEEDBACK_WEBHOOK_URL = os.environ.get("DISCORD_FEEDBACK_WEBHOOK_URL")
SWEEP_HEALTH_URL = os.environ.get("SWEEP_HEALTH_URL")
DISCORD_STATUS_WEBHOOK_URL = os.environ.get("DISCORD_STATUS_WEBHOOK_URL")
# goes under Modal 'github' secret name
GITHUB_APP_ID = os.environ.get("GITHUB_APP_ID", os.environ.get("APP_ID"))
# deprecated: old logic transfer so upstream can use this
if GITHUB_APP_ID is None:
if ENV == "prod":
GITHUB_APP_ID = "307814"
elif ENV == "dev":
GITHUB_APP_ID = "324098"
elif ENV == "staging":
GITHUB_APP_ID = "327588"
GITHUB_BOT_USERNAME = os.environ.get("GITHUB_BOT_USERNAME")
# deprecated: left to support old logic
if not GITHUB_BOT_USERNAME:
if ENV == "prod":
GITHUB_BOT_USERNAME = "sweep-ai[bot]"
elif ENV == "dev":
GITHUB_BOT_USERNAME = "sweep-nightly[bot]"
elif ENV == "staging":
GITHUB_BOT_USERNAME = "sweep-canary[bot]"
elif not GITHUB_BOT_USERNAME.endswith("[bot]"):
GITHUB_BOT_USERNAME = GITHUB_BOT_USERNAME + "[bot]"
GITHUB_LABEL_NAME = os.environ.get("GITHUB_LABEL_NAME", "sweep")
GITHUB_LABEL_COLOR = os.environ.get("GITHUB_LABEL_COLOR", "9400D3")
GITHUB_LABEL_DESCRIPTION = os.environ.get(
"GITHUB_LABEL_DESCRIPTION", "Sweep your software chores"
)
GITHUB_APP_PEM = os.environ.get("GITHUB_APP_PEM")
GITHUB_APP_PEM = GITHUB_APP_PEM or os.environ.get("PRIVATE_KEY")
if GITHUB_APP_PEM is not None:
GITHUB_APP_PEM = GITHUB_APP_PEM.strip(' \n"') # Remove whitespace and quotes
GITHUB_APP_PEM = GITHUB_APP_PEM.replace("\\n", "\n")
GITHUB_CONFIG_BRANCH = os.environ.get("GITHUB_CONFIG_BRANCH", "sweep/add-sweep-config")
GITHUB_DEFAULT_CONFIG = os.environ.get(
"GITHUB_DEFAULT_CONFIG",
"""# Sweep AI turns bugs & feature requests into code changes (https://sweep.dev)
# For details on our config file, check out our docs at https://docs.sweep.dev/usage/config
# This setting contains a list of rules that Sweep will check for. If any of these rules are broken in a new commit, Sweep will create a pull request to fix the broken rule.
rules:
{additional_rules}
# This is the branch that Sweep will develop from and make pull requests to. Most people use 'main' or 'master' but some users also use 'dev' or 'staging'.
branch: 'main'
# By default Sweep will read the logs and outputs from your existing GitHub Actions. To disable this, set this to false.
gha_enabled: True
# This is the description of your project. It will be used by Sweep when creating PRs. You can tell Sweep what's unique about your project, what frameworks you use, or anything else you want.
#
# Example:
#
# description: sweepai/sweep is a python project. The main api endpoints are in sweepai/api.py. Write code that adheres to PEP8.
description: ''
# This sets whether to create pull requests as drafts. If this is set to True, then all pull requests will be created as drafts and GitHub Actions will not be triggered.
draft: False
# This is a list of directories that Sweep will not be able to edit.
blocked_dirs: []
""",
)
MONGODB_URI = os.environ.get("MONGODB_URI", None)
IS_SELF_HOSTED = os.environ.get("IS_SELF_HOSTED", "true").lower() == "true"
REDIS_URL = os.environ.get("REDIS_URL")
if not REDIS_URL:
    REDIS_URL = os.environ.get("redis_url", "redis://0.0.0.0:6379/0")
ORG_ID = os.environ.get("ORG_ID", None)
POSTHOG_API_KEY = os.environ.get(
"POSTHOG_API_KEY", "phc_CnzwIB0W548wN4wEGeRuxXqidOlEUH2AcyV2sKTku8n"
)
E2B_API_KEY = os.environ.get("E2B_API_KEY")
SUPPORT_COUNTRY = os.environ.get("GDRP_LIST", "").split(",")
WHITELISTED_REPOS = os.environ.get("WHITELISTED_REPOS", "").split(",")
BLACKLISTED_USERS = os.environ.get("BLACKLISTED_USERS", "").split(",")
os.environ["TOKENIZERS_PARALLELISM"] = "false"
ACTIVELOOP_TOKEN = os.environ.get("ACTIVELOOP_TOKEN", None)
VECTOR_EMBEDDING_SOURCE = os.environ.get(
    "VECTOR_EMBEDDING_SOURCE", "openai"
)  # Default is openai; alternatives are huggingface or replicate (set the corresponding env vars below)
BASERUN_API_KEY = os.environ.get("BASERUN_API_KEY", None)
# Huggingface settings, only checked if VECTOR_EMBEDDING_SOURCE == "huggingface"
HUGGINGFACE_URL = os.environ.get("HUGGINGFACE_URL", None)
HUGGINGFACE_TOKEN = os.environ.get("HUGGINGFACE_TOKEN", None)
# Replicate settings, only checked if VECTOR_EMBEDDING_SOURCE == "replicate"
REPLICATE_API_KEY = os.environ.get("REPLICATE_API_KEY", None)
REPLICATE_URL = os.environ.get("REPLICATE_URL", None)
REPLICATE_DEPLOYMENT_URL = os.environ.get("REPLICATE_DEPLOYMENT_URL", None)
# Default OpenAI
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", None)
OPENAI_API_TYPE = os.environ.get("OPENAI_API_TYPE", "anthropic")
assert OPENAI_API_TYPE in ["anthropic", "azure", "openai"], "Invalid OPENAI_API_TYPE"
OPENAI_EMBEDDINGS_API_TYPE = os.environ.get("OPENAI_EMBEDDINGS_API_TYPE", "openai")
AZURE_API_KEY = os.environ.get("AZURE_API_KEY", None)
OPENAI_API_BASE = os.environ.get("OPENAI_API_BASE", None)
OPENAI_API_VERSION = os.environ.get("OPENAI_API_VERSION", None)
AZURE_OPENAI_DEPLOYMENT = os.environ.get("AZURE_OPENAI_DEPLOYMENT", None)
OPENAI_EMBEDDINGS_AZURE_ENDPOINT = os.environ.get(
"OPENAI_EMBEDDINGS_AZURE_ENDPOINT", None
)
OPENAI_EMBEDDINGS_AZURE_API_KEY = os.environ.get(
"OPENAI_EMBEDDINGS_AZURE_API_KEY", None
)
OPENAI_EMBEDDINGS_AZURE_DEPLOYMENT = os.environ.get(
"OPENAI_EMBEDDINGS_AZURE_DEPLOYMENT", None
)
OPENAI_EMBEDDINGS_AZURE_API_VERSION = os.environ.get(
"OPENAI_EMBEDDINGS_AZURE_API_VERSION", None
)
OPENAI_API_ENGINE_GPT35 = os.environ.get("OPENAI_API_ENGINE_GPT35", None)
OPENAI_API_ENGINE_GPT4 = os.environ.get("OPENAI_API_ENGINE_GPT4", None)
OPENAI_API_ENGINE_GPT4_32K = os.environ.get("OPENAI_API_ENGINE_GPT4_32K", None)
MULTI_REGION_CONFIG = os.environ.get("MULTI_REGION_CONFIG", None)
if isinstance(MULTI_REGION_CONFIG, str):
    MULTI_REGION_CONFIG = MULTI_REGION_CONFIG.strip("'").replace("\\n", "\n")
    MULTI_REGION_CONFIG = [item.split(",") for item in MULTI_REGION_CONFIG.split("\n")]
WHITELISTED_USERS = os.environ.get("WHITELISTED_USERS", None)
if WHITELISTED_USERS:
    WHITELISTED_USERS = WHITELISTED_USERS.split(",")
    WHITELISTED_USERS.append(GITHUB_BOT_USERNAME)
DEFAULT_GPT4_32K_MODEL = os.environ.get("DEFAULT_GPT4_32K_MODEL", "gpt-4-0125-preview")
DEFAULT_GPT35_MODEL = os.environ.get("DEFAULT_GPT35_MODEL", "gpt-3.5-turbo-1106")
RESEND_API_KEY = os.environ.get("RESEND_API_KEY", None)
LOKI_URL = None
DEBUG = os.environ.get("DEBUG", "false").lower() == "true"
ENV = "prod" if GITHUB_BOT_USERNAME != TEST_BOT_NAME else "dev"
PROGRESS_BASE_URL = os.environ.get(
"PROGRESS_BASE_URL", "https://progress.sweep.dev"
).rstrip("/")
DISABLED_REPOS = os.environ.get("DISABLED_REPOS", "").split(",")
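# Environment variables are strings, so parse these flags explicitly (otherwise "false" would be truthy)
GHA_AUTOFIX_ENABLED: bool = os.environ.get("GHA_AUTOFIX_ENABLED", "false").lower() == "true"
MERGE_CONFLICT_ENABLED: bool = os.environ.get("MERGE_CONFLICT_ENABLED", "false").lower() == "true"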
INSTALLATION_ID = os.environ.get("INSTALLATION_ID", None)
AWS_ACCESS_KEY = os.environ.get("AWS_ACCESS_KEY")
AWS_SECRET_KEY = os.environ.get("AWS_SECRET_KEY")
AWS_REGION = os.environ.get("AWS_REGION")
ANTHROPIC_AVAILABLE = AWS_ACCESS_KEY and AWS_SECRET_KEY and AWS_REGION
USE_ASSISTANT = os.environ.get("USE_ASSISTANT", "true").lower() == "true"
ANTHROPIC_API_KEY = os.environ.get("ANTHROPIC_API_KEY", None)
VOYAGE_API_KEY = os.environ.get("VOYAGE_API_KEY", None)
VOYAGE_API_AWS_ACCESS_KEY = os.environ.get("VOYAGE_API_AWS_ACCESS_KEY_ID")
VOYAGE_API_AWS_SECRET_KEY = os.environ.get("VOYAGE_API_AWS_SECRET_KEY")
VOYAGE_API_AWS_REGION = os.environ.get("VOYAGE_API_AWS_REGION")
VOYAGE_API_AWS_ENDPOINT_NAME = os.environ.get("VOYAGE_API_AWS_ENDPOINT_NAME", "voyage-code-2")
VOYAGE_API_USE_AWS = VOYAGE_API_AWS_ACCESS_KEY and VOYAGE_API_AWS_SECRET_KEY and VOYAGE_API_AWS_REGION
PAREA_API_KEY = os.environ.get("PAREA_API_KEY", None)
# TODO: we need to make this dynamic + backoff
BATCH_SIZE = int(
os.environ.get("BATCH_SIZE", 32 if VOYAGE_API_KEY else 256) # Voyage only allows 128 items per batch and 120000 tokens per batch
)
DEPLOYMENT_GHA_ENABLED = os.environ.get("DEPLOYMENT_GHA_ENABLED", "true").lower() == "true"
JIRA_USER_NAME = os.environ.get("JIRA_USER_NAME", None)
JIRA_API_TOKEN = os.environ.get("JIRA_API_TOKEN", None)
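Environment variables are always strings, so a bare `os.environ.get(...)` used as a boolean treats any non-empty value, including `"false"`, as true; several flags above guard against this with `.lower() == "true"` (for example `DEBUG` and `USE_ASSISTANT`). A minimal sketch of a shared helper for such flags, assuming a hypothetical `env_flag` function that is not part of `sweepai.config.server`:

```python
import os


def env_flag(name: str, default: bool = False) -> bool:
    """Read a boolean feature flag from the environment.

    Hypothetical helper: accepts common truthy spellings and treats anything
    else (including "false" and "0") as False. Unset variables use `default`.
    """
    raw = os.environ.get(name)
    if raw is None:
        return default
    return raw.strip().lower() in {"1", "true", "yes", "on"}


# Illustrative usage mirroring flags defined in the config above
GHA_AUTOFIX_ENABLED = env_flag("GHA_AUTOFIX_ENABLED")
MERGE_CONFLICT_ENABLED = env_flag("MERGE_CONFLICT_ENABLED")
```

Centralizing the parsing keeps every flag consistent, instead of relying on each assignment to remember the `.lower() == "true"` comparison.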


Step 2: ⌨️ Coding

  • Modify sweepai/handlers/on_merge_conflict.py (commit 32d958c)
Modify sweepai/handlers/on_merge_conflict.py with contents: In the `on_merge_conflict` function:

Replace this code block:

try:
    git_repo.config_writer().set_value("user", "name", "sweep-nightly[bot]").release()
    git_repo.config_writer().set_value("user", "email", "[email protected]").release()  
    git_repo.git.merge("origin/" + pr.base.ref)
except GitCommandError:
    # Assume there are merge conflicts
    pass

with:

try:
    git_repo.config_writer().set_value("user", "name", "sweep-nightly[bot]").release()
    git_repo.config_writer().set_value("user", "email", "[email protected]").release()
    git_repo.git.fetch()
    git_repo.git.rebase("origin/" + pr.base.ref)
except GitCommandError:
    # Assume there are conflicts during rebase
    pass

This rebases the PR branch onto the target branch instead of merging the target branch into it.

Make sure `GitCommandError` is imported from the `git` module at the top of the file (the current version of the file already imports it):

from git import GitCommandError
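One caveat with the plan above: unlike a failed merge, a rebase that hits conflicts stops partway through with HEAD detached. The later `git add` / `git commit` in `on_merge_conflict` would then commit onto the detached HEAD while the branch ref still points at the pre-rebase commit, so pushing the branch would not carry the resolution. A completed rebase also rewrites history, so replacing an existing remote branch needs a force push (ideally `--force-with-lease`). A minimal sketch of the rebase step that reports conflicted paths instead of silently swallowing the error, assuming a hypothetical `rebase_onto_base` helper and plain `git` invoked through `subprocess` rather than GitPython:

```python
import subprocess
from typing import Callable, List, Sequence

# A runner takes git arguments (without the leading "git") and returns the result.
Runner = Callable[[Sequence[str]], subprocess.CompletedProcess]


def run_git(args: Sequence[str]) -> subprocess.CompletedProcess:
    """Run `git <args>` in the current directory and capture its output."""
    return subprocess.run(["git", *args], capture_output=True, text=True)


def rebase_onto_base(
    base_ref: str, run: Runner = run_git, remote: str = "origin"
) -> List[str]:
    """Rebase the checked-out branch onto <remote>/<base_ref>.

    Returns the unmerged (conflicted) paths; an empty list means the rebase
    finished cleanly. On conflict the rebase is left in progress so the files
    can be resolved, staged with `git add`, and continued with
    `git -c core.editor=true rebase --continue`.
    """
    fetched = run(["fetch", remote, base_ref])
    if fetched.returncode != 0:
        raise RuntimeError(f"git fetch failed: {fetched.stderr.strip()}")
    rebased = run(["rebase", f"{remote}/{base_ref}"])
    if rebased.returncode == 0:
        return []
    # --diff-filter=U lists only paths that are still unmerged.
    unmerged = run(["diff", "--name-only", "--diff-filter=U"])
    paths = [line for line in unmerged.stdout.splitlines() if line]
    if not paths:
        # The rebase failed for a reason other than conflicts; don't leave it half-done.
        run(["rebase", "--abort"])
        raise RuntimeError(f"git rebase failed: {rebased.stderr.strip()}")
    return paths
```

Injecting `run` keeps the sketch testable without a real repository. Pushing to a brand-new branch, as the handler currently does, needs no force; only overwriting the original PR branch does.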

Step 3: 🔁 Code Review

I have finished reviewing the code for completeness. I did not find errors for sweep/allow_for_rebase_27da7.


🎉 Latest improvements to Sweep:
  • New dashboard launched for real-time tracking of Sweep issues, covering all stages from search to coding.
  • Integration of OpenAI's latest Assistant API for more efficient and reliable code planning and editing, improving speed by 3x.
  • Use the GitHub issues extension for creating Sweep issues directly from your editor.

💡 To recreate the pull request edit the issue title or description.
Something wrong? Let us know.

This is an automated message generated by Sweep AI.

@sweep-nightly sweep-nightly bot linked a pull request Apr 8, 2024 that will close this issue
Contributor

sweep-nightly bot commented Apr 8, 2024

🚀 Here's the PR! #3501

See Sweep's progress at the progress dashboard!
💎 Sweep Pro: I'm using GPT-4. You have unlimited GPT-4 tickets. (tracking ID: bbba2e9b0c)

Tip

I can email you next time I complete a pull request if you set up your email here!



Step 1: 🔎 Searching

I found the following snippets in your repository. I will now analyze these snippets and come up with a plan.

Some code snippets I think are relevant, in decreasing order of relevance. If some file is missing from here, you can mention the path in the ticket description.

import time
import traceback
from git import GitCommandError
from github.PullRequest import PullRequest
from loguru import logger
from sweepai.config.server import PROGRESS_BASE_URL
from sweepai.core import entities
from sweepai.core.entities import FileChangeRequest
from sweepai.core.sweep_bot import SweepBot
from sweepai.handlers.create_pr import create_pr_changes
from sweepai.handlers.on_ticket import get_branch_diff_text, sweeping_gif
from sweepai.utils.chat_logger import ChatLogger, discord_log_error
from sweepai.utils.diff import generate_diff
from sweepai.utils.event_logger import posthog
from sweepai.utils.github_utils import ClonedRepo, get_github_client
from sweepai.utils.progress import (
    PaymentContext,
    TicketContext,
    TicketProgress,
    TicketProgressStatus,
)
from sweepai.utils.prompt_constructor import HumanMessagePrompt
from sweepai.utils.str_utils import to_branch_name
from sweepai.utils.ticket_utils import center
instructions_format = """Resolve the merge conflicts in the PR by incorporating changes from both branches into the final code.
Title of PR: {title}
Here were the original changes to this file in the head branch:
Commit message: {head_commit_message}
```diff
{head_diff}
```
Here were the original changes to this file in the base branch:
Commit message: {base_commit_message}
```diff
{base_diff}
```
In the analysis_and_identification, first determine what each change does. Then determine what the final code should be. Then, use the keyword_search to find the merge conflict markers <<<<<<< and >>>>>>>. Finally, make the code changes by writing the old_code and the new_code."""
def on_merge_conflict(
    pr_number: int,
    username: str,
    repo_full_name: str,
    installation_id: int,
    tracking_id: str,
):
    # copied from stack_pr
    token, g = get_github_client(installation_id=installation_id)
    try:
        repo = g.get_repo(repo_full_name)
    except Exception as e:
        print("Exception occurred while getting repo", e)
        raise  # nothing below can run without a repo object
    pr: PullRequest = repo.get_pull(pr_number)
    branch = pr.head.ref
    status_message = center(
        f"{sweeping_gif}\n\n"
        + f'Resolving merge conflicts: track the progress <a href="{PROGRESS_BASE_URL}/issues/{tracking_id}">here</a>.'
    )
    header = f"{status_message}\n---\n\nI'm currently resolving the merge conflicts in this PR. I will stack a new PR once I'm done."
    comment = None
    for current_comment in pr.get_issue_comments():
        if (
            current_comment.user.login == "sweep-nightly[bot]"
            and "Resolving merge conflicts: track the progress" in current_comment.body
        ):
            current_comment.edit(body=header)
            comment = current_comment
            break
    if comment is None:
        # only create a new status comment if no existing one was reused above
        comment = pr.create_issue_comment(body=header)
    def edit_comment(body):
        nonlocal comment
        comment.edit(header + "\n\n" + body)
    metadata = {}
    try:
        cloned_repo = ClonedRepo(
            repo_full_name=repo_full_name,
            installation_id=installation_id,
            branch=branch,
            token=token,
        )
        time.time()
        request = f"Sweep: Resolve merge conflicts for PR #{pr_number}: {pr.title}"
        title = request
        if len(title) > 50:
            title = title[:50] + "..."
        chat_logger = ChatLogger(
            data={
                "username": username,
                "metadata": metadata,
                "tracking_id": tracking_id,
            }
        )
        is_paying_user = chat_logger.is_paying_user()
        chat_logger.is_consumer_tier()
        # this logic is partly taken from on_ticket.py, if there is an issue please refer to that file
        if chat_logger:
            use_faster_model = chat_logger.use_faster_model()
        else:
            use_faster_model = False  # ensure the name is defined on this branch too
            is_paying_user = True
        ticket_progress = TicketProgress(
            tracking_id=tracking_id,
            username=username,
            context=TicketContext(
                title=title,
                description="",
                repo_full_name=repo_full_name,
                branch_name="sweep/" + to_branch_name(request),
                issue_number=pr_number,
                is_public=repo.private is False,
                start_time=int(time.time()),
                # mostly copied from on_ticket, if issue please check that file
                payment_context=PaymentContext(
                    use_faster_model=use_faster_model,
                    pro_user=is_paying_user,
                    daily_tickets_used=(
                        chat_logger.get_ticket_count(use_date=True)
                        if chat_logger
                        else 0
                    ),
                    monthly_tickets_used=(
                        chat_logger.get_ticket_count() if chat_logger else 0
                    ),
                ),
            ),
        )
        metadata = {
            "tracking_id": tracking_id,
            "username": username,
            "function": "on_merge_conflict",
            **ticket_progress.context.dict(),
        }
        posthog.capture(
            username,
            "started",
            properties=metadata,
        )
        issue_url = pr.html_url
        edit_comment("Configuring branch...")
        new_pull_request = entities.PullRequest(
            title=title,
            branch_name="sweep/" + branch + "-merge-conflict",
            content="",
        )
        # Making sure name is unique
        for i in range(30):
            try:
                repo.get_branch(new_pull_request.branch_name + "_" + str(i))
            except Exception:
                new_pull_request.branch_name += "_" + str(i)
                break
        # Merge the base branch (pr.base.ref) into a new copy of the PR's head branch
        git_repo = cloned_repo.git_repo
        old_head_branch = git_repo.branches[branch]
        head_branch = git_repo.create_head(
            new_pull_request.branch_name,
            commit=old_head_branch.commit,
        )
        head_branch.checkout()
        try:
            git_repo.config_writer().set_value(
                "user", "name", "sweep-nightly[bot]"
            ).release()
            git_repo.config_writer().set_value(
                "user", "email", "[email protected]"
            ).release()
            git_repo.git.merge("origin/" + pr.base.ref)
        except GitCommandError:
            # Assume there are merge conflicts
            pass
        git_repo.git.add(update=True)
        # -m and message are needed otherwise exception is thrown
        git_repo.git.commit("-m", "Start of Merge Conflict Resolution")
        origin = git_repo.remotes.origin
        new_url = f"https://x-access-token:{token}@github.com/{repo_full_name}.git"
        origin.set_url(new_url)
        git_repo.git.push("--set-upstream", origin, new_pull_request.branch_name)
        last_commit = git_repo.head.commit
        all_files = [item.a_path for item in last_commit.diff("HEAD~1")]
        conflict_files = []
        for file in all_files:
            try:
                contents = open(cloned_repo.repo_dir + "/" + file).read()
                if "\n<<<<<<<" in contents and "\n>>>>>>>" in contents:
                    conflict_files.append(file)
            except UnicodeDecodeError:
                pass
        snippets = []
        for conflict_file in conflict_files:
            contents = open(cloned_repo.repo_dir + "/" + conflict_file).read()
            snippet = entities.Snippet(
                file_path=conflict_file,
                start=0,
                end=len(contents.splitlines()),
                content=contents,
            )
            snippets.append(snippet)
        tree = ""
        ticket_progress.status = TicketProgressStatus.PLANNING
        ticket_progress.save()
        human_message = HumanMessagePrompt(
            repo_name=repo_full_name,
            issue_url=issue_url,
            username=username,
            repo_description=(repo.description or "").strip(),
            title=request,
            summary=request,
            snippets=snippets,
            tree=tree,
        )
        sweep_bot = SweepBot.from_system_message_content(
            human_message=human_message,
            repo=repo,
            ticket_progress=ticket_progress,
            chat_logger=chat_logger,
            cloned_repo=cloned_repo,
            branch=new_pull_request.branch_name,
        )
        # can select more precise snippets
        file_change_requests = []
        base_commits = pr.base.repo.get_commits().get_page(0)
        head_commits = list(pr.get_commits())
        for conflict_file in conflict_files:
            old_code = repo.get_contents(
                conflict_file, ref=head_commits[0].parents[0].sha
            ).decoded_content.decode()
            base_code = repo.get_contents(
                conflict_file, ref=pr.base.ref
            ).decoded_content.decode()
            head_code = repo.get_contents(
                conflict_file, ref=pr.head.ref
            ).decoded_content.decode()
            base_diff = generate_diff(old_code=old_code, new_code=base_code)
            head_diff = generate_diff(old_code=old_code, new_code=head_code)
            base_commit_message = ""
            for commit in base_commits[::-1]:
                if any(
                    commit_file.filename == conflict_file
                    for commit_file in commit.files
                ):
                    base_commit_message = commit.raw_data["commit"]["message"]
                    break
            head_commit_message = ""
            for commit in head_commits[::-1]:
                if any(
                    commit_file.filename == conflict_file
                    for commit_file in commit.files
                ):
                    head_commit_message = commit.raw_data["commit"]["message"]
                    break
            file_change_requests.append(
                FileChangeRequest(
                    filename=conflict_file,
                    instructions=instructions_format.format(
                        title=pr.title,
                        base_commit_message=base_commit_message,
                        base_diff=base_diff,
                        head_commit_message=head_commit_message,
                        head_diff=head_diff,
                    ),
                    change_type="modify",
                )
            )
        ticket_progress.status = TicketProgressStatus.CODING
        ticket_progress.save()
        edit_comment("Resolving merge conflicts...")
        generator = create_pr_changes(
            file_change_requests,
            new_pull_request,
            sweep_bot,
            username,
            installation_id,
            pr_number,
            chat_logger=chat_logger,
            base_branch=new_pull_request.branch_name,
        )
        for item in generator:
            if isinstance(item, dict):
                break
            # create_pr_changes yields 4-tuples: (new_file_contents, changed_file, commit, file_change_requests)
            (
                new_file_contents,
                changed_file,
                commit,
                file_change_requests,
            ) = item
            succeeded = sum(fcr.status == "succeeded" for fcr in file_change_requests)
            logger.info(f"Status: {succeeded}/{len(file_change_requests)} succeeded")
        ticket_progress.status = TicketProgressStatus.COMPLETE
        ticket_progress.save()
        edit_comment("Done creating pull request.")
        get_branch_diff_text(repo, new_pull_request.branch_name)
        new_description = f"This PR resolves the merge conflicts in #{pr_number}. This branch can be directly merged into {pr.base.ref}.\n\nFixes #{pr_number}."
        # Create pull request
        new_pull_request.content = new_description
        github_pull_request = repo.create_pull(
            title=request,
            body=new_description,
            head=new_pull_request.branch_name,
            base=pr.base.ref,
        )
        ticket_progress.context.pr_id = github_pull_request.number
        ticket_progress.context.done_time = time.time()
        ticket_progress.save()
        edit_comment(f"✨ **Created Pull Request:** {github_pull_request.html_url}")
        posthog.capture(
            username,
            "success",
            properties=metadata,
        )
        return {"success": True}
    except Exception as e:
        print(f"Exception occurred: {e}")
        edit_comment(
            f"> [!CAUTION]\n> \nAn error has occurred: {str(e)} (tracking ID: {tracking_id})"
        )
        discord_log_error(
            "Error occurred in on_merge_conflict.py\n"
            + traceback.format_exc()
            + "\n\n"
            + str(e)
            + "\n\n"
            + f"tracking ID: {tracking_id}"
        )
        posthog.capture(
            username,
            "failed",
            properties=metadata,
        )
        return {"success": False}
if __name__ == "__main__":
    on_merge_conflict(
        pr_number=68,
        username="MartinYe1234",
        repo_full_name="MartinYe1234/Chess-Game",
        installation_id=45945746,
        tracking_id="ADD-BOB-2",
    )

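The conflict check in `on_merge_conflict` looks for the substrings `"\n<<<<<<<"` and `"\n>>>>>>>"`, so a conflict whose opening marker sits on the very first line of a file is missed. A minimal, standalone sketch of a line-anchored check that also reports where each conflict block starts and ends; the names `find_conflict_regions` and `has_conflict_markers` are hypothetical and not part of Sweep:

```python
import re
from typing import List, Tuple

# Git writes "<<<<<<< ours" ... "=======" ... ">>>>>>> theirs" at the start of a line.
_START = re.compile(r"^<{7}(?: |$)")
_END = re.compile(r"^>{7}(?: |$)")


def find_conflict_regions(text: str) -> List[Tuple[int, int]]:
    """Return 0-based (start_line, end_line) pairs, one per conflict block."""
    regions: List[Tuple[int, int]] = []
    start = None
    for lineno, line in enumerate(text.splitlines()):
        if _START.match(line):
            start = lineno
        elif _END.match(line) and start is not None:
            regions.append((start, lineno))
            start = None
    return regions


def has_conflict_markers(text: str) -> bool:
    """True if the text contains at least one complete conflict block."""
    return bool(find_conflict_regions(text))
```

The line ranges could also be used to send the model only the conflicted regions instead of whole files, though that is a design choice the handler does not currently make.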
"""
This file contains the on_merge handler, which is called when commits land on the configured Sweep branch (for example, when a pull request is merged).
on_merge is called by sweepai/api.py
"""
import time
from sweepai.config.client import SweepConfig, get_blocked_dirs, get_rules
from sweepai.core.post_merge import PostMerge
from sweepai.handlers.pr_utils import make_pr
from loguru import logger
from sweepai.utils.chat_logger import ChatLogger
from sweepai.utils.event_logger import posthog
from sweepai.utils.github_utils import get_github_client
# change threshold for number of lines changed
CHANGE_BOUNDS = (10, 1500)
# dictionary to map from github repo to the last time a rule was activated
merge_rule_debounce = {}
# debounce time in seconds
DEBOUNCE_TIME = 120
diff_section_prompt = """
<file_diff file="{diff_file_path}">
{diffs}
</file_diff>"""
def comparison_to_diff(comparison, blocked_dirs):
    pr_diffs = []
    for file in comparison.files:
        diff = file.patch
        if (
            file.status == "added"
            or file.status == "modified"
            or file.status == "removed"
        ):
            if any(file.filename.startswith(dir) for dir in blocked_dirs):
                continue
            pr_diffs.append((file.filename, diff))
        else:
            logger.info(
                f"File status {file.status} not recognized"
            )  # TODO(sweep): We don't handle renamed files
    formatted_diffs = []
    for file_name, file_patch in pr_diffs:
        format_diff = diff_section_prompt.format(
            diff_file_path=file_name, diffs=file_patch
        )
        formatted_diffs.append(format_diff)
    return "\n".join(formatted_diffs)
def on_merge(request_dict: dict, chat_logger: ChatLogger):
    before_sha = request_dict["before"]
    after_sha = request_dict["after"]
    commit_author = request_dict["sender"]["login"]
    ref = request_dict["ref"]
    if not ref.startswith("refs/heads/"):
        return
    user_token, g = get_github_client(request_dict["installation"]["id"])
    repo = g.get_repo(
        request_dict["repository"]["full_name"]
    )  # do this after checking ref
    if ref[len("refs/heads/") :] != SweepConfig.get_branch(repo):
        return
    commit = repo.get_commit(after_sha)
    check_suites = commit.get_check_suites()
    for check_suite in check_suites:
        if check_suite.conclusion == "failure":
            return  # if any check suite failed, return
    blocked_dirs = get_blocked_dirs(repo)
    comparison = repo.compare(before_sha, after_sha)
    commits_diff = comparison_to_diff(comparison, blocked_dirs)
    # check if the current repo is in the merge_rule_debounce dictionary
    # and if the difference between the current time and the time stored in the dictionary is less than DEBOUNCE_TIME seconds
    if (
        repo.full_name in merge_rule_debounce
        and time.time() - merge_rule_debounce[repo.full_name] < DEBOUNCE_TIME
    ):
        return
    merge_rule_debounce[repo.full_name] = time.time()
    if not (
        commits_diff.count("\n") >= CHANGE_BOUNDS[0]
        and commits_diff.count("\n") <= CHANGE_BOUNDS[1]
    ):
        return
    rules = get_rules(repo)
    rules = [rule for rule in rules if len(rule) > 0]
    if not rules:
        return
    for rule in rules:
        chat_logger.data["title"] = f"Sweep Rules - {rule}"
        changes_required, issue_title, issue_description = PostMerge(
            chat_logger=chat_logger
        ).check_for_issues(rule=rule, diff=commits_diff)
        if changes_required:
            make_pr(
                title="[Sweep Rules] " + issue_title,
                repo_description=repo.description,
                summary=issue_description,
                repo_full_name=request_dict["repository"]["full_name"],
                installation_id=request_dict["installation"]["id"],
                user_token=user_token,
                use_faster_model=chat_logger.use_faster_model(),
                username=commit_author,
                chat_logger=chat_logger,
                rule=rule,
            )

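`on_merge` only runs rule checks when the repository has not triggered within `DEBOUNCE_TIME` seconds and the diff's newline count falls inside `CHANGE_BOUNDS`. That gating can be pulled into a pure function so it is easy to unit-test. This is a sketch with a hypothetical `should_check_rules` name and an injectable clock, not code from the repository; like the original, it records the timestamp before the size check, so an out-of-bounds push still starts a new debounce window:

```python
import time
from typing import Callable, Dict, Optional, Tuple

CHANGE_BOUNDS: Tuple[int, int] = (10, 1500)  # same values as in on_merge.py
DEBOUNCE_TIME = 120  # seconds, same as in on_merge.py


def should_check_rules(
    repo_full_name: str,
    diff_text: str,
    last_run: Dict[str, float],
    now: Optional[Callable[[], float]] = None,
) -> bool:
    """Mirror on_merge's gating: per-repo debounce, then a diff-size window."""
    current = (now or time.time)()
    previous = last_run.get(repo_full_name)
    if previous is not None and current - previous < DEBOUNCE_TIME:
        return False
    last_run[repo_full_name] = current  # recorded before the size check, as in on_merge
    line_count = diff_text.count("\n")
    return CHANGE_BOUNDS[0] <= line_count <= CHANGE_BOUNDS[1]
```

Because the debounce state lives in a plain dict passed by the caller, the module-level `merge_rule_debounce` could be handed in unchanged.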
"""
create_pr is a function that creates a pull request from a list of file change requests.
It is also responsible for handling Sweep config PR creation.
"""
import datetime
from typing import Any, Generator
import openai
from github.Repository import Repository
from loguru import logger
from sweepai.config.client import DEFAULT_RULES_STRING, SweepConfig, get_blocked_dirs
from sweepai.config.server import (
ENV,
GITHUB_BOT_USERNAME,
GITHUB_CONFIG_BRANCH,
GITHUB_DEFAULT_CONFIG,
GITHUB_LABEL_NAME,
MONGODB_URI,
)
from sweepai.core.entities import (
FileChangeRequest,
MaxTokensExceeded,
Message,
MockPR,
PullRequest,
)
from sweepai.core.sweep_bot import SweepBot
from sweepai.utils.chat_logger import ChatLogger
from sweepai.utils.event_logger import posthog
from sweepai.utils.github_utils import ClonedRepo, get_github_client
from sweepai.utils.str_utils import UPDATES_MESSAGE
num_of_snippets_to_query = 10
max_num_of_snippets = 5
INSTRUCTIONS_FOR_REVIEW = """\
### 💡 To get Sweep to edit this pull request, you can:
* Comment below, and Sweep can edit the entire PR
* Comment on a file, Sweep will only modify the commented file
* Edit the original issue to get Sweep to recreate the PR from scratch"""
def create_pr_changes(
    file_change_requests: list[FileChangeRequest],
    pull_request: PullRequest,
    sweep_bot: SweepBot,
    username: str,
    installation_id: int,
    issue_number: int | None = None,
    chat_logger: ChatLogger = None,
    base_branch: str = None,
    additional_messages: list[Message] = []
) -> Generator[tuple[FileChangeRequest, int, Any], None, dict]:
    # Flow:
    # 1. Get relevant files
    # 2: Get human message
    # 3. Get files to change
    # 4. Get file changes
    # 5. Create PR
    chat_logger = (
        chat_logger
        if chat_logger is not None
        else ChatLogger(
            {
                "username": username,
                "installation_id": installation_id,
                "repo_full_name": sweep_bot.repo.full_name,
                "title": pull_request.title,
                "summary": "",
                "issue_url": "",
            }
        )
        if MONGODB_URI
        else None
    )
    sweep_bot.chat_logger = chat_logger
    organization, repo_name = sweep_bot.repo.full_name.split("/")
    metadata = {
        "repo_full_name": sweep_bot.repo.full_name,
        "organization": organization,
        "repo_name": repo_name,
        "repo_description": sweep_bot.repo.description,
        "username": username,
        "installation_id": installation_id,
        "function": "create_pr",
        "mode": ENV,
        "issue_number": issue_number,
    }
    posthog.capture(username, "started", properties=metadata)
    try:
        logger.info("Making PR...")
        pull_request.branch_name = sweep_bot.create_branch(
            pull_request.branch_name, base_branch=base_branch
        )
        completed_count, fcr_count = 0, len(file_change_requests)
        blocked_dirs = get_blocked_dirs(sweep_bot.repo)
        for (
            new_file_contents,
            changed_file,
            commit,
            file_change_requests,
        ) in sweep_bot.change_files_in_github_iterator(
            file_change_requests,
            pull_request.branch_name,
            blocked_dirs,
            additional_messages=additional_messages
        ):
            completed_count += len(new_file_contents or [])
            logger.info(f"Completed {completed_count}/{fcr_count} files")
            yield new_file_contents, changed_file, commit, file_change_requests
        if completed_count == 0 and fcr_count != 0:
            logger.info("No changes made")
            posthog.capture(
                username,
                "failed",
                properties={
                    "error": "No changes made",
                    "reason": "No changes made",
                    **metadata,
                },
            )
            # If no changes were made, delete branch
            commits = sweep_bot.repo.get_commits(pull_request.branch_name)
            if commits.totalCount == 0:
                branch = sweep_bot.repo.get_git_ref(f"heads/{pull_request.branch_name}")
                branch.delete()
            return
        # Include issue number in PR description
        if issue_number:
            # If the #issue changes, then change on_ticket (f'Fixes #{issue_number}.\n' in pr.body:)
            pr_description = (
                f"{pull_request.content}\n\nFixes"
                f" #{issue_number}.\n\n---\n\n{UPDATES_MESSAGE}\n\n---\n\n{INSTRUCTIONS_FOR_REVIEW}"
            )
        else:
            pr_description = f"{pull_request.content}"
        pr_title = pull_request.title
        if "sweep.yaml" in pr_title:
            pr_title = "[config] " + pr_title
    except MaxTokensExceeded as e:
        logger.error(e)
        posthog.capture(
            username,
            "failed",
            properties={
                "error": str(e),
                "reason": "Max tokens exceeded",
                **metadata,
            },
        )
        raise e
    except openai.BadRequestError as e:
        logger.error(e)
        posthog.capture(
            username,
            "failed",
            properties={
                "error": str(e),
                "reason": "Invalid request error / context length",
                **metadata,
            },
        )
        raise e
    except Exception as e:
        logger.error(e)
        posthog.capture(
            username,
            "failed",
            properties={
                "error": str(e),
                "reason": "Unexpected error",
                **metadata,
            },
        )
        raise e
    posthog.capture(username, "success", properties={**metadata})
    logger.info("create_pr success")
    result = {
        "success": True,
        "pull_request": MockPR(
            file_count=completed_count,
            title=pr_title,
            body=pr_description,
            pr_head=pull_request.branch_name,
            base=sweep_bot.repo.get_branch(
                SweepConfig.get_branch(sweep_bot.repo)
            ).commit,
            head=sweep_bot.repo.get_branch(pull_request.branch_name).commit,
        ),
    }
    yield result  # TODO: refactor this as it doesn't need to be an iterator
    return
def safe_delete_sweep_branch(
    pr,  # Github PullRequest
    repo: Repository,
) -> bool:
    """
    Safely delete Sweep branch
    1. Only edited by Sweep
    2. Prefixed by sweep/
    """
    pr_commits = pr.get_commits()
    pr_commit_authors = set([commit.author.login for commit in pr_commits])
    # Check if only Sweep has edited the PR, and sweep/ prefix
    if (
        len(pr_commit_authors) == 1
        and GITHUB_BOT_USERNAME in pr_commit_authors
        and pr.head.ref.startswith("sweep/")
    ):
        branch = repo.get_git_ref(f"heads/{pr.head.ref}")
        # pr.edit(state='closed')
        branch.delete()
        return True
    else:
        # Failed to delete branch as it was edited by someone else
        return False
def create_config_pr(
    sweep_bot: SweepBot | None, repo: Repository = None, cloned_repo: ClonedRepo = None
):
    if repo is not None:
        # Check if file exists in repo
        try:
            repo.get_contents("sweep.yaml")
            return
        except SystemExit:
            raise SystemExit
        except Exception:
            pass
    title = "Configure Sweep"
    branch_name = GITHUB_CONFIG_BRANCH
    if sweep_bot is not None:
        branch_name = sweep_bot.create_branch(branch_name, retry=False)
        try:
            # commit_history = []
            # if cloned_repo is not None:
            #     commit_history = cloned_repo.get_commit_history(
            #         limit=1000, time_limited=False
            #     )
            # commit_string = "\n".join(commit_history)
            # sweep_yaml_bot = SweepYamlBot()
            # generated_rules = sweep_yaml_bot.get_sweep_yaml_rules(
            #     commit_history=commit_string
            # )
            sweep_bot.repo.create_file(
                "sweep.yaml",
                "Create sweep.yaml",
                GITHUB_DEFAULT_CONFIG.format(
                    branch=sweep_bot.repo.default_branch,
                    additional_rules=DEFAULT_RULES_STRING,
                ),
                branch=branch_name,
            )
            sweep_bot.repo.create_file(
                ".github/ISSUE_TEMPLATE/sweep-template.yml",
                "Create sweep template",
                SWEEP_TEMPLATE,
                branch=branch_name,
            )
        except SystemExit:
            raise SystemExit
        except Exception as e:
            logger.error(e)
    else:
        # Create branch based on default branch
        repo.create_git_ref(
            ref=f"refs/heads/{branch_name}",
            sha=repo.get_branch(repo.default_branch).commit.sha,
        )
        try:
            # commit_history = []
            # if cloned_repo is not None:
            #     commit_history = cloned_repo.get_commit_history(
            #         limit=1000, time_limited=False
            #     )
            # commit_string = "\n".join(commit_history)
            # sweep_yaml_bot = SweepYamlBot()
            # generated_rules = sweep_yaml_bot.get_sweep_yaml_rules(
            #     commit_history=commit_string
            # )
            repo.create_file(
                "sweep.yaml",
                "Create sweep.yaml",
                GITHUB_DEFAULT_CONFIG.format(
                    branch=repo.default_branch, additional_rules=DEFAULT_RULES_STRING
                ),
                branch=branch_name,
            )
            repo.create_file(
                ".github/ISSUE_TEMPLATE/sweep-template.yml",
                "Create sweep template",
                SWEEP_TEMPLATE,
                branch=branch_name,
            )
        except SystemExit:
            raise SystemExit
        except Exception as e:
            logger.error(e)
    repo = sweep_bot.repo if sweep_bot is not None else repo
    # Check if the pull request from this branch to main already exists.
    # If it does, then we don't need to create a new one.
    if repo is not None:
        pull_requests = repo.get_pulls(
            state="open",
            sort="created",
            base=SweepConfig.get_branch(repo)
            if sweep_bot is not None
            else repo.default_branch,
            head=branch_name,
        )
        for pr in pull_requests:
            if pr.title == title:
                return pr
    logger.print("Default branch", repo.default_branch)
    logger.print("New branch", branch_name)
    pr = repo.create_pull(
        title=title,
        body="""🎉 Thank you for installing Sweep! We're thrilled to announce the latest update for Sweep, your AI junior developer on GitHub. This PR creates a `sweep.yaml` config file, allowing you to personalize Sweep's performance according to your project requirements.
## What's new?
- **Sweep is now configurable**.
- To configure Sweep, simply edit the `sweep.yaml` file in the root of your repository.
- If you need help, check out the [Sweep Default Config](https://github.com/sweepai/sweep/blob/main/sweep.yaml) or [Join Our Discord](https://discord.gg/sweep) for help.
If you would like me to stop creating this PR, go to issues and say "Sweep: create an empty `sweep.yaml` file".
Thank you for using Sweep! 🧹""".replace(
            "    ", ""  # strip source indentation only; a single space would delete every space in the body
        ),
        head=branch_name,
        base=SweepConfig.get_branch(repo)
        if sweep_bot is not None
        else repo.default_branch,
    )
    pr.add_to_labels(GITHUB_LABEL_NAME)
    return pr
def add_config_to_top_repos(installation_id, username, repositories, max_repos=3):
user_token, g = get_github_client(installation_id)
repo_activity = {}
for repo_entity in repositories:
repo = g.get_repo(repo_entity.full_name)
# instead of using total count, use the date of the latest commit
commits = repo.get_commits(
author=username,
since=datetime.datetime.now() - datetime.timedelta(days=30),
)
# get latest commit date
commit_date = datetime.datetime.now() - datetime.timedelta(days=30)
for commit in commits:
if commit.commit.author.date > commit_date:
commit_date = commit.commit.author.date
# since_date = datetime.datetime.now() - datetime.timedelta(days=30)
# commits = repo.get_commits(since=since_date, author="lukejagg")
repo_activity[repo] = commit_date
# print(repo, commits.totalCount)
logger.print(repo, commit_date)
sorted_repos = sorted(repo_activity, key=repo_activity.get, reverse=True)
sorted_repos = sorted_repos[:max_repos]
# For each repo, create a branch based on main branch, then create PR to main branch
for repo in sorted_repos:
try:
logger.print("Creating config for", repo.full_name)
create_config_pr(
None,
repo=repo,
cloned_repo=ClonedRepo(
repo_full_name=repo.full_name,
installation_id=installation_id,
token=user_token,
),
)
except SystemExit:
raise
except Exception as e:
logger.print(e)
logger.print("Finished creating configs for top repos")
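`add_config_to_top_repos` ranks repositories by the date of the user's latest commit in the last 30 days and keeps the first `max_repos`. A self-contained sketch of just that ranking step, using plain repo names and `datetime` values in place of PyGithub objects (an assumption for illustration). `heapq.nlargest` is documented as equivalent to `sorted(..., reverse=True)[:n]`, so it does the sort-and-slice in one call.

```python
import datetime
import heapq


def top_active_repos(
    latest_commit_by_repo: dict[str, datetime.datetime], max_repos: int = 3
) -> list[str]:
    # Newest activity first, keeping at most max_repos names.
    return heapq.nlargest(max_repos, latest_commit_by_repo, key=latest_commit_by_repo.get)


now = datetime.datetime(2024, 3, 13)
activity = {
    "org/a": now - datetime.timedelta(days=2),
    "org/b": now - datetime.timedelta(days=20),
    "org/c": now - datetime.timedelta(hours=5),
    "org/d": now - datetime.timedelta(days=29),
}
print(top_active_repos(activity, max_repos=2))  # ['org/c', 'org/a']
```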
def create_gha_pr(g, repo):
# Create a new branch
branch_name = "sweep/gha-enable"
repo.create_git_ref(
ref=f"refs/heads/{branch_name}",
sha=repo.get_branch(repo.default_branch).commit.sha,
)
# Update the sweep.yaml file in this branch to add "gha_enabled: True"
sweep_yaml_file = repo.get_contents("sweep.yaml", ref=branch_name)
sweep_yaml_content = (
sweep_yaml_file.decoded_content.decode()
+ "\ngha_enabled: True"
)
repo.update_file(
"sweep.yaml",
"Enable GitHub Actions",
sweep_yaml_content,
sweep_yaml_file.sha,
branch=branch_name,
)
# Create a PR from this branch to the main branch
pr = repo.create_pull(
title="Enable GitHub Actions",
body="This PR enables GitHub Actions for this repository.",
head=branch_name,
base=repo.default_branch,
)
return pr
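`create_gha_pr` appends `gha_enabled: True` to the end of `sweep.yaml`, so a file that already sets `gha_enabled` ends up with a duplicate key. A minimal sketch, assuming `gha_enabled` only ever appears as a top-level key, that replaces an existing value instead of appending. `set_top_level_key` is a hypothetical helper, and it does a line-based edit rather than real YAML parsing.

```python
import re


def set_top_level_key(yaml_text: str, key: str, value: str) -> str:
    # Only a line that starts with `key:` counts; nested (indented) keys are ignored.
    pattern = re.compile(rf"^{re.escape(key)}:.*$", re.MULTILINE)
    new_line = f"{key}: {value}"
    if pattern.search(yaml_text):
        # A function replacement stops re from interpreting backslashes in `value`.
        return pattern.sub(lambda _match: new_line, yaml_text, count=1)
    # Key missing: append it on its own line.
    separator = "\n" if yaml_text and not yaml_text.endswith("\n") else ""
    return f"{yaml_text}{separator}{new_line}\n"
```

In `create_gha_pr` this would take the decoded file contents and return the new contents to pass to `repo.update_file`.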
SWEEP_TEMPLATE = """\
name: Sweep Issue
title: 'Sweep: '
description: For small bugs, features, refactors, and tests to be handled by Sweep, an AI-powered junior developer.
labels: sweep
body:
- type: textarea
id: description
attributes:
label: Details
description: Tell Sweep where and what to edit, and provide enough context for a developer new to the codebase.
placeholder: |
Unit Tests: Write unit tests for <FILE>. Test each function in the file. Make sure to test edge cases.
Bugs: The bug might be in <FILE>. Here are the logs: ...
Features: the new endpoint should use the ... class from <FILE> because it contains ... logic.
Refactors: We are migrating this function to ... version because ...
- type: input
id: branch
attributes:
label: Branch
description: The branch to work off of (optional)
placeholder: |

import copy
import re
import traceback
from pathlib import Path
from loguru import logger
from sweepai.agents.assistant_wrapper import (
client,
openai_assistant_call,
run_until_complete,
)
from sweepai.core.entities import AssistantRaisedException, FileChangeRequest, Message
from sweepai.logn.cache import file_cache
from sweepai.utils.chat_logger import ChatLogger, discord_log_error
from sweepai.utils.progress import AssistantConversation, TicketProgress
system_message = r""" You are searching through a codebase to guide a junior developer on how to solve the user request. The junior developer will follow your instructions exactly and make the changes.
# User Request
{user_request}
# Guide
## Step 1: Unzip the file into /mnt/data/repo. Then list all root level directories. You must copy the below code verbatim into the file.
```python
import zipfile
import os
zip_path = '{file_path}'
extract_to_path = '/mnt/data/repo'
os.makedirs(extract_to_path, exist_ok=True)
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
zip_ref.extractall(extract_to_path)
zip_contents = zip_ref.namelist()
root_dirs = {{name.split('/')[0] for name in zip_contents}}
print(f'Root directories: {{root_dirs}}')
```
## Step 2: Find the relevant files.
You can search by file name or by keyword search in the contents.
## Step 3: Find relevant lines.
1. Locate the lines of code that contain the identified keywords or are at the specified line number. You can use keyword search or manually look through the file 100 lines at a time.
2. Check the surrounding lines to establish the full context of the code block.
3. Adjust the starting line to include the entire functionality that needs to be refactored or moved.
4. Finally, determine the exact line spans that include a logical and complete section of code to be edited.
```python
def print_lines_with_keyword(content, keywords):
max_matches = 5
context = 10
matches = [i for i, line in enumerate(content.splitlines()) if any(keyword in line.lower() for keyword in keywords)]
print(f"Found {{len(matches)}} matches, but capping at {{max_matches}}")
matches = matches[:max_matches]
expanded_matches = set()
for match in matches:
start = max(0, match - context)
end = min(len(content.splitlines()), match + context + 1)
for i in range(start, end):
expanded_matches.add(i)
for i in sorted(expanded_matches):
print(f"{{i}}: {{content.splitlines()[i]}}")
```
## Step 4: Construct a plan.
Provide the final plan to solve the issue, following these rules:
* DO NOT apply any changes here; they will not be persisted. You must provide the plan and the developer will apply the changes.
* You may only create new files and modify existing files.
* File paths should be relative paths from the root of the repo.
* Use the minimum number of create and modify operations required to solve the issue.
* Start and end lines indicate the exact start and end lines to edit. Expand this to encompass more lines if you're unsure where to make the exact edit.
Respond in the following format:
```xml
<plan>
<create_file file="file_path_1">
* Natural language instructions for creating the new file needed to solve the issue.
* Reference necessary files, imports and entity names.
...
</create_file>
...
<modify_file file="file_path_2" start_line="i" end_line="j">
* Natural language instructions for the modifications needed to solve the issue.
* Be concise and reference necessary files, imports and entity names.
...
</modify_file>
...
</plan>
```"""
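Inside `system_message` the helper's braces are doubled because the prompt goes through `str.format`, which makes it awkward to test directly. As plain Python, the same keyword-with-context search looks roughly like the sketch below. It splits the content once instead of on every use, and it lowercases the keywords as well as the lines (the prompt's version only lowercases the lines). `lines_with_keyword` is a hypothetical name.

```python
def lines_with_keyword(
    content: str, keywords: list[str], max_matches: int = 5, context: int = 10
) -> list[str]:
    lines = content.splitlines()
    lowered = [keyword.lower() for keyword in keywords]
    # Indices of every line that contains at least one keyword.
    matches = [i for i, line in enumerate(lines) if any(k in line.lower() for k in lowered)]
    print(f"Found {len(matches)} matches, but capping at {max_matches}")
    shown: set[int] = set()
    for match in matches[:max_matches]:
        # Add `context` lines on either side, clamped to the start and end of the file.
        shown.update(range(max(0, match - context), min(len(lines), match + context + 1)))
    return [f"{i}: {lines[i]}" for i in sorted(shown)]
```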
@file_cache(ignore_params=["zip_path", "chat_logger", "ticket_progress"])
def new_planning(
request: str,
zip_path: str,
additional_messages: list[Message] = [],
chat_logger: ChatLogger | None = None,
assistant_id: str | None = None,
ticket_progress: TicketProgress | None = None,
) -> list[FileChangeRequest]:
planning_iterations = 3
try:
def save_ticket_progress(assistant_id: str, thread_id: str, run_id: str):
if ticket_progress is None:
return
assistant_conversation = AssistantConversation.from_ids(
assistant_id=assistant_id, run_id=run_id, thread_id=thread_id
)
if not assistant_conversation:
return
ticket_progress.planning_progress.assistant_conversation = (
assistant_conversation
)
ticket_progress.save()
logger.info("Uploading file...")
zip_file_object = client.files.create(file=Path(zip_path), purpose="assistants")
logger.info("Done uploading file.")
zip_file_id = zip_file_object.id
response = openai_assistant_call(
request=request,
assistant_id=assistant_id,
additional_messages=additional_messages,
uploaded_file_ids=[zip_file_id],
chat_logger=chat_logger,
save_ticket_progress=save_ticket_progress
if ticket_progress is not None
else None,
instructions=system_message.format(
user_request=request, file_path=f"mnt/data/{zip_file_id}"
),
)
run_id = response.run_id
thread_id = response.thread_id
messages = response.messages
fcrs = []
fcr_matches = []
for _ in range(planning_iterations):
# save the latest run, which changes after every retry
save_ticket_progress(
assistant_id=response.assistant_id,
thread_id=thread_id,
run_id=run_id,
)
# read the newest reply, including the answer to a retry prompt
final_message = messages.data[0].content[0].text.value
fcr_matches = list(
re.finditer(FileChangeRequest._regex, final_message, re.DOTALL)
)
if len(fcr_matches) > 0:
break
else:
client.beta.threads.messages.create(
thread_id=thread_id,
role="user",
content="A valid plan (within the <plan> tags) was not provided. Please continue working on the plan. If you are stuck, consider starting over.",
)
run = client.beta.threads.runs.create(
thread_id=thread_id,
assistant_id=response.assistant_id,
instructions=system_message.format(
user_request=request, file_path=f"mnt/data/{zip_file_id}"
),
)
run_id = run.id
messages = run_until_complete(
thread_id=thread_id,
run_id=run_id,
assistant_id=response.assistant_id,
)
for match_ in fcr_matches:
group_dict = match_.groupdict()
if group_dict["change_type"] == "create_file":
group_dict["change_type"] = "create"
if group_dict["change_type"] == "modify_file":
group_dict["change_type"] = "modify"
fcr = FileChangeRequest(**group_dict)
fcr.filename = fcr.filename.lstrip("/")
fcr.instructions = fcr.instructions.replace("\n*", "\n•")
fcr.instructions = fcr.instructions.strip("\n")
if fcr.instructions.startswith("*"):
fcr.instructions = "•" + fcr.instructions[1:]
fcrs.append(fcr)
new_file_change_request = copy.deepcopy(fcr)
new_file_change_request.change_type = "check"
new_file_change_request.parent = fcr
fcrs.append(new_file_change_request)
assert len(fcrs) > 0
return fcrs
except AssistantRaisedException as e:
raise e
except Exception as e:
logger.exception(e)
if chat_logger is not None:
discord_log_error(
str(e)
+ "\n\n"
+ traceback.format_exc()
+ "\n\n"
+ str(chat_logger.data)
)
return None
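`new_planning` accepts a reply only once `FileChangeRequest._regex` finds at least one `<create_file>` or `<modify_file>` block, then normalizes each match: the change type loses its `_file` suffix, leading slashes are stripped from the path, and `*` bullets become `•`. The sketch below reproduces that normalization on plain dictionaries. The regex is a hypothetical stand-in written from the `<plan>` format in `system_message` (it ignores `start_line`/`end_line`); Sweep's real `FileChangeRequest._regex` is not shown in this excerpt and may differ.

```python
import re

# Hypothetical stand-in for FileChangeRequest._regex, based on the <plan> format above.
PLAN_ITEM_RE = re.compile(
    r'<(?P<change_type>create_file|modify_file)\s+file="(?P<filename>[^"]+)"[^>]*>'
    r"(?P<instructions>.*?)</(?P=change_type)>",
    re.DOTALL,
)


def parse_plan(final_message: str) -> list[dict[str, str]]:
    changes = []
    for match in PLAN_ITEM_RE.finditer(final_message):
        item = match.groupdict()
        # "create_file" -> "create", "modify_file" -> "modify"
        item["change_type"] = item["change_type"].removesuffix("_file")
        item["filename"] = item["filename"].lstrip("/")
        # Turn the model's "*" bullets into "•" bullets, as new_planning does.
        instructions = item["instructions"].replace("\n*", "\n•").strip("\n")
        if instructions.startswith("*"):
            instructions = "•" + instructions[1:]
        item["instructions"] = instructions
        changes.append(item)
    return changes
```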
if __name__ == "__main__":
request = """## Title: replace the broken tutorial link in installation.md with https://docs.sweep.dev/usage/tutorial\n"""
additional_messages = [
Message(
role="user",
content='<relevant_snippets_in_repo>\n<snippet source="docs/pages/usage/tutorial.mdx:45-60">\n...\n45: Now to be a Sweep power user, check out [Advanced: becoming a Sweep power user](https://docs.sweep.dev/usage/advanced).\n</snippet>\n<snippet source="docs/pages/usage/tutorial.mdx:30-45">\n...\n30: \n31: ![PR Comment](/tutorial/comment.png)\n32: \n33: c. If you have GitHub Actions set up, it will automatically run the linters, build, and tests and will show any failed logs to Sweep to handle. This only works with GitHub Actions and not other CI providers, so unfortunately for Vercel we have to copy paste manually.\n34: \n35: ![GitHub Actions](/tutorial/github_actions.png)\n36: \n37: 6. Once you are happy with the PR, you can merge it and it will be deployed to production via Vercel.\n38: \n39: \n40: ![Final](/tutorial/final.png)\n41: \n42: \n43: You can see the final example at https://github.com/kevinlu1248/docusaurus-2/pull/4 with preview https://docusaurus-2-ql4cskc5o-sweepai.vercel.app/.\n44: \n45: Now to be a Sweep power user, check out [Advanced: becoming a Sweep power user](https://docs.sweep.dev/usage/advanced).\n...\n</snippet>\n<snippet source="docs/installation.md:45-60">\n...\n45: * Provide any additional context that might be helpful, e.g. see "src/App.test.tsx" for an example of a good unit test.\n46: * For more guidance, visit [Advanced](https://docs.sweep.dev/usage/advanced), or watch the following video.\n47: \n48: [![Video](http://img.youtube.com/vi/Qn9vB71R4UM/0.jpg)](http://www.youtube.com/watch?v=Qn9vB71R4UM "Advanced Sweep Tricks and Feedback Tips")\n49: \n50: For configuring Sweep for your repo, see [Config](https://docs.sweep.dev/usage/config), especially for setting up Sweep Rules and Sweep Sweep.\n51: \n52: ## Limitations of Sweep (for now) ⚠️\n53: \n54: * 🗃️ **Gigantic repos**: >5000 files. We have default extensions and directories to exclude but sometimes this doesn\'t catch them all. You may need to block some directories (see [`blocked_dirs`](https://docs.sweep.dev/usage/config#blocked_dirs))\n55: * If Sweep is stuck at 0% for over 30 min and your repo has a few thousand files, let us know.\n56: \n57: * 🏗️ **Large-scale refactors**: >5 files or >300 lines of code changes (we\'re working on this!)\n58: * We can\'t do this - "Refactor entire codebase from Tensorflow to PyTorch"\n59: \n60: * 🖼️ **Editing images** and other non-text assets\n...\n</snippet>\n<snippet source="docs/pages/usage/tutorial.mdx:0-15">\n0: # Tutorial for Getting Started with Sweep\n1: \n2: We recommend using an existing **real project** for Sweep, but if you must start from scratch, we recommend **using a template**. In particular, we recommend Vercel templates and Vercel auto-deploy, since Vercel\'s auto-generated previews make it **easy to review Sweep\'s PRs**\n3: \n4: We\'ll use [Docusaurus](https://vercel.com/templates/next.js/docusaurus-2) since it\'s is the easiest to set up (no backend). To see other templates see https://vercel.com/templates.\n5: \n6: 1. Go to https://vercel.com/templates/next.js/docusaurus-2 (or another template) and click "Deploy".\n7: \n8: ![Deploy](/tutorial/deployment.png)\n9: \n10: 2. Vercel will prompt you to select a GitHub account and click "Clone" after. This will trigger a build and deploy which will take a few minutes. Once the build is done, you will be greeted with a congratulations message.\n11: \n12: ![Congratulations](/tutorial/congratulations.png)\n13: \n14: 3. Go to the [Sweep Installation](https://github.com/apps/sweep-ai) page and click the grey "Configure" button or the green "Install" button. Ensure that that the Vercel template (i.e. Docusaurus) is configured to use Sweep.\n...\n</snippet>\n</relevant_snippets_in_repo>\ndocs/\n installation.md\n docs/pages/\n docs/pages/usage/\n _meta.json\n advanced.mdx\n config.mdx\n extra-self-host.mdx\n sandbox.mdx\n tutorial.mdx',
name=None,
function_call=None,
key=None,
)
]
print(
new_planning(
request,
"/tmp/sweep_archive.zip",
additional_messages=additional_messages,
chat_logger=ChatLogger(
{"username": "kevinlu1248", "title": "Unit test for planning"}
),
ticket_progress=TicketProgress(tracking_id="ed47605a38"),
)
)

import datetime
import difflib
import hashlib
import json
import os
import re
import shutil
import subprocess
import tempfile
import time
import traceback
from dataclasses import dataclass
from functools import cached_property
from typing import Any
import git
import requests
from github import Github, PullRequest, Repository, InputGitTreeElement
from jwt import encode
from loguru import logger
from sweepai.config.client import SweepConfig
from sweepai.config.server import GITHUB_APP_ID, GITHUB_APP_PEM, GITHUB_BOT_USERNAME
from sweepai.utils.tree_utils import DirectoryTree, remove_all_not_included
MAX_FILE_COUNT = 50
def make_valid_string(string: str):
pattern = r"[^\w./-]+"
return re.sub(pattern, "_", string)
def get_jwt():
signing_key = GITHUB_APP_PEM
app_id = GITHUB_APP_ID
payload = {"iat": int(time.time()), "exp": int(time.time()) + 600, "iss": app_id}
return encode(payload, signing_key, algorithm="RS256")
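`get_jwt` signs a GitHub App JWT whose `iat` is the current time and whose `exp` is exactly 600 seconds later. GitHub's GitHub App authentication docs recommend setting `iat` 60 seconds in the past to tolerate clock drift and require `exp` to be no more than 10 minutes in the future. A minimal sketch of building a payload with that margin; `build_app_jwt_payload` is a hypothetical helper, and signing would still go through PyJWT's `encode(..., algorithm="RS256")` as in `get_jwt`.

```python
from __future__ import annotations

import time


def build_app_jwt_payload(app_id: str, now: float | None = None) -> dict[str, int | str]:
    issued_at = int(now if now is not None else time.time())
    return {
        "iat": issued_at - 60,  # backdated to absorb clock drift
        "exp": issued_at + 540,  # 9 minutes ahead, inside GitHub's 10-minute limit
        "iss": app_id,
    }
```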
def get_token(installation_id: int):
if int(installation_id) < 0:
return os.environ["GITHUB_PAT"]
for timeout in [5.5, 5.5, 10.5]:
try:
jwt = get_jwt()
headers = {
"Accept": "application/vnd.github+json",
"Authorization": "Bearer " + jwt,
"X-GitHub-Api-Version": "2022-11-28",
}
response = requests.post(
f"https://api.github.com/app/installations/{int(installation_id)}/access_tokens",
headers=headers,
)
obj = response.json()
if "token" not in obj:
logger.error(obj)
raise Exception("Could not get token")
return obj["token"]
except SystemExit:
raise
except Exception:
time.sleep(timeout)
raise Exception(
"Could not get token, please double check your PRIVATE_KEY and GITHUB_APP_ID in the .env file. Make sure to restart uvicorn after."
)
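`get_token` retries the installation-token request up to three times, sleeping 5.5, 5.5, and 10.5 seconds after failed attempts, and raises a generic error once every attempt has failed. A reusable version of that pattern, sketched with an injectable `sleep` so it can be exercised without waiting. `retry_call` is hypothetical and, unlike `get_token`, chains the last underlying error onto the final exception so the cause is not lost.

```python
from __future__ import annotations

import time
from typing import Callable, Iterable, TypeVar

T = TypeVar("T")


def retry_call(
    func: Callable[[], T],
    delays: Iterable[float] = (5.5, 5.5, 10.5),
    sleep: Callable[[float], None] = time.sleep,
) -> T:
    """Call `func` once per entry in `delays`, sleeping that long after each failure."""
    last_error: Exception | None = None
    for delay in delays:
        try:
            return func()
        except Exception as error:  # SystemExit and KeyboardInterrupt still propagate
            last_error = error
            sleep(delay)
    raise RuntimeError("all attempts failed") from last_error
```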
def get_app():
jwt = get_jwt()
headers = {
"Accept": "application/vnd.github+json",
"Authorization": "Bearer " + jwt,
"X-GitHub-Api-Version": "2022-11-28",
}
response = requests.get("https://api.github.com/app", headers=headers)
return response.json()
def get_github_client(installation_id: int) -> tuple[str, Github]:
if not installation_id:
return os.environ["GITHUB_PAT"], Github(os.environ["GITHUB_PAT"])
token: str = get_token(installation_id)
return token, Github(token)
# fetch installation object
def get_installation(username: str):
jwt = get_jwt()
try:
# Try user
response = requests.get(
f"https://api.github.com/users/{username}/installation",
headers={
"Accept": "application/vnd.github+json",
"Authorization": "Bearer " + jwt,
"X-GitHub-Api-Version": "2022-11-28",
},
)
# a 404 means no user installation, so fall through to the org lookup
response.raise_for_status()
obj = response.json()
return obj
except Exception:
# Try org
response = requests.get(
f"https://api.github.com/orgs/{username}/installation",
headers={
"Accept": "application/vnd.github+json",
"Authorization": "Bearer " + jwt,
"X-GitHub-Api-Version": "2022-11-28",
},
)
try:
obj = response.json()
return obj
except Exception as e:
logger.error(e)
logger.error(response.text)
raise Exception("Could not get installation, probably not installed")
def get_installation_id(username: str) -> str:
jwt = get_jwt()
try:
# Try user
response = requests.get(
f"https://api.github.com/users/{username}/installation",
headers={
"Accept": "application/vnd.github+json",
"Authorization": "Bearer " + jwt,
"X-GitHub-Api-Version": "2022-11-28",
},
)
obj = response.json()
return obj["id"]
except Exception:
# Try org
response = requests.get(
f"https://api.github.com/orgs/{username}/installation",
headers={
"Accept": "application/vnd.github+json",
"Authorization": "Bearer " + jwt,
"X-GitHub-Api-Version": "2022-11-28",
},
)
try:
obj = response.json()
return obj["id"]
except Exception as e:
logger.error(e)
logger.error(response.text)
raise Exception("Could not get installation id, probably not installed")
# commits multiple files in a single commit, returns the commit object
def commit_multi_file_changes(repo: Repository, file_changes: dict[str, str], commit_message: str, branch: str):
blobs_to_commit = []
# convert to blob
for path, content in file_changes.items():
blob = repo.create_git_blob(content, "utf-8")
blobs_to_commit.append(InputGitTreeElement(path=path, mode="100644", type="blob", sha=blob.sha))
latest_commit = repo.get_branch(branch).commit
base_tree = latest_commit.commit.tree
# create new git tree
new_tree = repo.create_git_tree(blobs_to_commit, base_tree=base_tree)
# commit the changes
parent = repo.get_git_commit(latest_commit.sha)
commit = repo.create_git_commit(
commit_message,
new_tree,
[parent],
)
# update ref of branch
ref = f"heads/{branch}"
repo.get_git_ref(ref).edit(sha=commit.sha)
return commit
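`commit_multi_file_changes` builds a commit through GitHub's Git Data API in the same order Git stores objects: a blob per file, a tree that points at the blobs on top of `base_tree`, a commit that points at the new tree and its parent, and finally a branch ref moved to that commit. To make the first step concrete, the sketch below shows how Git derives a blob's ID, which is what the `sha` on each `InputGitTreeElement` refers to: SHA-1 over a `blob <size>\0` header followed by the raw bytes. This applies to the default SHA-1 object format; repositories using SHA-256 object IDs hash differently.

```python
import hashlib


def git_blob_sha(content: bytes) -> str:
    # Git hashes the header "blob <byte length>\0" followed by the raw content.
    header = f"blob {len(content)}\0".encode()
    return hashlib.sha1(header + content).hexdigest()


print(git_blob_sha(b""))  # e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
```

The empty-blob ID printed above is the well-known one `git hash-object` reports for an empty file, which is a quick way to sanity-check the header format.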
REPO_CACHE_BASE_DIR = "/tmp/cache/repos"
@dataclass
class ClonedRepo:
repo_full_name: str
installation_id: str
branch: str | None = None
token: str | None = None
repo: Any | None = None
git_repo: git.Repo | None = None
class Config:
arbitrary_types_allowed = True
@cached_property
def cached_dir(self):
self.repo = (
Github(self.token).get_repo(self.repo_full_name)
if not self.repo
else self.repo
)
self.branch = self.branch or SweepConfig.get_branch(self.repo)
return os.path.join(
REPO_CACHE_BASE_DIR,
self.repo_full_name,
"base",
parse_collection_name(self.branch),
)
@cached_property
def zip_path(self):
logger.info("Zipping repository...")
shutil.make_archive(self.repo_dir, "zip", self.repo_dir)
logger.info("Done zipping")
return f"{self.repo_dir}.zip"
@cached_property
def repo_dir(self):
self.repo = (
Github(self.token).get_repo(self.repo_full_name)
if not self.repo
else self.repo
)
self.branch = self.branch or SweepConfig.get_branch(self.repo)
curr_time_str = str(time.time()).encode("utf-8")
hash_obj = hashlib.sha256(curr_time_str)
hash_hex = hash_obj.hexdigest()
if self.branch:
return os.path.join(
REPO_CACHE_BASE_DIR,
self.repo_full_name,
hash_hex,
parse_collection_name(self.branch),
)
else:
return os.path.join(REPO_CACHE_BASE_DIR, self.repo_full_name, hash_hex)
@property
def clone_url(self):
return (
f"https://x-access-token:{self.token}@github.com/{self.repo_full_name}.git"
)
def clone(self):
if not os.path.exists(self.cached_dir):
logger.info("Cloning repo...")
if self.branch:
repo = git.Repo.clone_from(
self.clone_url, self.cached_dir, branch=self.branch
)
else:
repo = git.Repo.clone_from(self.clone_url, self.cached_dir)
logger.info("Done cloning")
else:
try:
repo = git.Repo(self.cached_dir)
repo.remotes.origin.pull(
kill_after_timeout=60, progress=git.RemoteProgress()
)
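except Exception:
logger.error("Could not pull repo, re-cloning")
shutil.rmtree(self.cached_dir, ignore_errors=True)
if self.branch:
repo = git.Repo.clone_from(
self.clone_url, self.cached_dir, branch=self.branch
)
else:
repo = git.Repo.clone_from(self.clone_url, self.cached_dir)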
logger.info("Repo already cached, copying")
logger.info("Copying repo...")
shutil.copytree(
self.cached_dir, self.repo_dir, symlinks=True, copy_function=shutil.copy
)
logger.info("Done copying")
repo = git.Repo(self.repo_dir)
return repo
def __post_init__(self):
subprocess.run(["git", "config", "--global", "http.postBuffer", "524288000"])
self.token = self.token or get_token(self.installation_id)
self.repo = (
Github(self.token).get_repo(self.repo_full_name)
if not self.repo
else self.repo
)
self.commit_hash = self.repo.get_commits()[0].sha
self.git_repo = self.clone()
self.branch = self.branch or SweepConfig.get_branch(self.repo)
def __del__(self):
try:
shutil.rmtree(self.repo_dir)
os.remove(self.zip_path)
return True
except Exception:
return False
def list_directory_tree(
self,
included_directories=None,
excluded_directories: list[str] = None,
included_files=None,
):
"""Display the directory tree of self.repo_dir.
Arguments:
included_directories -- List of directory paths (relative to the root) to include in the tree. Defaults to None, which includes all directories.
excluded_directories -- List of directory names to exclude from the tree. Defaults to None, which uses SweepConfig.exclude_dirs.
included_files -- Currently unused.
Returns a (directory tree string, DirectoryTree) tuple.
"""
root_directory = self.repo_dir
sweep_config: SweepConfig = SweepConfig()
# Default values if parameters are not provided
if included_directories is None:
included_directories = [] # gets all directories
if excluded_directories is None:
excluded_directories = sweep_config.exclude_dirs
def list_directory_contents(
current_directory: str,
excluded_directories: list[str],
indentation="",
):
"""Recursively list the contents of directories."""
file_and_folder_names = os.listdir(current_directory)
file_and_folder_names.sort()
directory_tree_string = ""
for name in file_and_folder_names[:MAX_FILE_COUNT]:
relative_path = os.path.join(current_directory, name)[
len(root_directory) + 1 :
]
if name in excluded_directories:
continue
complete_path = os.path.join(current_directory, name)
if os.path.isdir(complete_path):
directory_tree_string += f"{indentation}{relative_path}/\n"
directory_tree_string += list_directory_contents(
complete_path,
excluded_directories,
indentation + " ",
)
else:
directory_tree_string += f"{indentation}{name}\n"
# if os.path.isfile(complete_path) and relative_path in included_files:
# # Todo, use these to fetch neighbors
# ctags_str, names = get_ctags_for_file(ctags, complete_path)
# ctags_str = "\n".join([indentation + line for line in ctags_str.splitlines()])
# if ctags_str.strip():
# directory_tree_string += f"{ctags_str}\n"
return directory_tree_string
dir_obj = DirectoryTree()
directory_tree = list_directory_contents(root_directory, excluded_directories)
dir_obj.parse(directory_tree)
if included_directories:
dir_obj = remove_all_not_included(dir_obj, included_directories)
return directory_tree, dir_obj
def get_file_list(self) -> list[str]:
root_directory = self.repo_dir
files = []
sweep_config: SweepConfig = SweepConfig()
def dfs_helper(directory):
nonlocal files
for item in os.listdir(directory):
if item == ".git":
continue
if item in sweep_config.exclude_dirs: # this saves a lot of time
continue
item_path = os.path.join(directory, item)
if os.path.isfile(item_path):
# make sure the item_path is not in one of the banned directories
if not sweep_config.is_file_excluded(item_path):
files.append(item_path) # Add the file to the list
elif os.path.isdir(item_path):
dfs_helper(item_path) # Recursive call to explore subdirectory
dfs_helper(root_directory)
files = [file[len(root_directory) + 1 :] for file in files]
return files
def get_file_contents(self, file_path, ref=None):
local_path = (
f"{self.repo_dir}{file_path}"
if file_path.startswith("/")
else f"{self.repo_dir}/{file_path}"
)
if os.path.exists(local_path):
with open(local_path, "r", encoding="utf-8", errors="replace") as f:
contents = f.read()
return contents
else:
raise FileNotFoundError(f"{local_path} does not exist.")
def get_num_files_from_repo(self):
# subprocess.run(["git", "config", "--global", "http.postBuffer", "524288000"])
self.git_repo.git.checkout(self.branch)
file_list = self.get_file_list()
return len(file_list)
def get_commit_history(
self, username: str = "", limit: int = 200, time_limited: bool = True
):
commit_history = []
try:
if username != "":
commit_list = list(self.git_repo.iter_commits(author=username))
else:
commit_list = list(self.git_repo.iter_commits())
line_count = 0
cut_off_date = datetime.datetime.now() - datetime.timedelta(days=7)
# resolve the branch to diff against once, rather than once per commit
repo = get_github_client(self.installation_id)[1].get_repo(
self.repo_full_name
)
branch = SweepConfig.get_branch(repo)
if branch not in self.git_repo.git.branch():
branch = f"origin/{branch}"
for commit in commit_list:
# must be within a week
if time_limited and commit.authored_datetime.replace(
tzinfo=None
) <= cut_off_date.replace(tzinfo=None):
logger.info("Exceeded cut off date, stopping...")
break
diff = self.git_repo.git.diff(commit, branch, unified=1)
lines = diff.count("\n")
# total diff lines must not exceed `limit`
if lines + line_count > limit:
logger.info(f"Exceeded {limit} lines of diff, stopping...")
break
commit_history.append(
f"<commit>\nAuthor: {commit.author.name}\nMessage: {commit.message}\n{diff}\n</commit>"
)
line_count += lines
except Exception:
logger.error(f"An error occurred: {traceback.format_exc()}")
return commit_history
def get_similar_file_paths(self, file_path: str, limit: int = 10):
from rapidfuzz.fuzz import ratio
# Fuzzy search over file names
file_name = os.path.basename(file_path)
all_file_paths = self.get_file_list()
# filter for matching extensions if both have extensions
if "." in file_name:
all_file_paths = [
file
for file in all_file_paths
if "." in file and file.split(".")[-1] == file_name.split(".")[-1]
]
files_with_matching_name = []
files_without_matching_name = []
for file_path in all_file_paths:
if file_name in file_path:
files_with_matching_name.append(file_path)
else:
files_without_matching_name.append(file_path)
file_path_to_ratio = {file: ratio(file_name, file) for file in all_file_paths}
files_with_matching_name = sorted(
files_with_matching_name,
key=lambda file_path: file_path_to_ratio[file_path],
reverse=True,
)
files_without_matching_name = sorted(
files_without_matching_name,
key=lambda file_path: file_path_to_ratio[file_path],
reverse=True,
)
# this allows 'config.py' to return 'sweepai/config/server.py', 'sweepai/config/client.py', 'sweepai/config/__init__.py' and no more
filtered_files_without_matching_name = list(filter(lambda file_path: file_path_to_ratio[file_path] > 50, files_without_matching_name))
all_files = files_with_matching_name + filtered_files_without_matching_name
return all_files[:limit]
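`get_similar_file_paths` ranks candidates in two tiers: paths that contain the queried file name come first, and the remaining paths are kept only if their fuzzy score clears 50. The sketch below reproduces that logic with `difflib.SequenceMatcher` from the standard library as a stand-in for `rapidfuzz.fuzz.ratio` (both give a 0 to 100 similarity after scaling, but the scores are not identical). `similar_paths` is a hypothetical name.

```python
from difflib import SequenceMatcher


def similar_paths(
    file_name: str, all_paths: list[str], limit: int = 10, threshold: float = 50
) -> list[str]:
    # 0-100 similarity score; a stdlib stand-in for rapidfuzz.fuzz.ratio.
    def score(path: str) -> float:
        return SequenceMatcher(None, file_name, path).ratio() * 100

    # If the query has an extension, only compare against files with the same extension.
    if "." in file_name:
        extension = file_name.rsplit(".", 1)[-1]
        all_paths = [p for p in all_paths if "." in p and p.rsplit(".", 1)[-1] == extension]
    # Tier 1: paths containing the name. Tier 2: other paths above the threshold.
    containing = sorted((p for p in all_paths if file_name in p), key=score, reverse=True)
    others = sorted(
        (p for p in all_paths if file_name not in p and score(p) > threshold),
        key=score,
        reverse=True,
    )
    return (containing + others)[:limit]
```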
# updates a file with new_contents, returns True if successful
def update_file(root_dir: str, file_path: str, new_contents: str):
local_path = os.path.join(root_dir, file_path)
try:
with open(local_path, "w") as f:
f.write(new_contents)
return True
except Exception as e:
logger.error(f"Failed to update file: {e}")
return False
@dataclass
class MockClonedRepo(ClonedRepo):
_repo_dir: str = ""
git_repo: git.Repo | None = None
def __init__(
self,
_repo_dir: str,
repo_full_name: str,
installation_id: str = "",
branch: str | None = None,
token: str | None = None,
repo: Any | None = None,
git_repo: git.Repo | None = None,
):
self._repo_dir = _repo_dir
self.repo_full_name = repo_full_name
self.installation_id = installation_id
self.branch = branch
self.token = token
self.repo = repo
@classmethod
def from_dir(cls, repo_dir: str, **kwargs):
return cls(_repo_dir=repo_dir, **kwargs)
@property
def cached_dir(self):
return self._repo_dir
@property
def repo_dir(self):
return self._repo_dir
@property
def git_repo(self):
return git.Repo(self.repo_dir)
def clone(self):
return git.Repo(self.repo_dir)
def __post_init__(self):
return self
def __del__(self):
return True
@dataclass
class TemporarilyCopiedClonedRepo(MockClonedRepo):
tmp_dir: tempfile.TemporaryDirectory | None = None
def __init__(
self,
_repo_dir: str,
tmp_dir: tempfile.TemporaryDirectory,
repo_full_name: str,
installation_id: str = "",
branch: str | None = None,
token: str | None = None,
repo: Any | None = None,
git_repo: git.Repo | None = None,
):
self._repo_dir = _repo_dir
self.tmp_dir = tmp_dir
self.repo_full_name = repo_full_name
self.installation_id = installation_id
self.branch = branch
self.token = token
self.repo = repo
@classmethod
def copy_from_cloned_repo(cls, cloned_repo: ClonedRepo, **kwargs):
temp_dir = tempfile.TemporaryDirectory()
new_dir = temp_dir.name + "/" + cloned_repo.repo_full_name.split("/")[1]
print("Copying...")
shutil.copytree(cloned_repo.repo_dir, new_dir)
print("Done copying.")
return cls(
_repo_dir=new_dir,
tmp_dir=temp_dir,
repo_full_name=cloned_repo.repo_full_name,
installation_id=cloned_repo.installation_id,
branch=cloned_repo.branch,
token=cloned_repo.token,
repo=cloned_repo.repo,
**kwargs,
)
def __del__(self):
print(f"Dropping {self.tmp_dir.name}...")
shutil.rmtree(self._repo_dir, ignore_errors=True)
self.tmp_dir.cleanup()
print("Done.")
return True
def get_file_names_from_query(query: str) -> list[str]:
query_file_names = re.findall(r"\b[\w\-\.\/]*\w+\.\w{1,6}\b", query)
return [
query_file_name
for query_file_name in query_file_names
if len(query_file_name) > 3
]
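A quick usage check of `get_file_names_from_query`, copied unchanged so it runs standalone. It picks paths that end in a 1 to 6 character extension out of free text and discards matches of 3 characters or fewer. Because the pattern only requires `word.word`, version strings such as `3.11` also come through, which callers may want to filter out.

```python
import re


def get_file_names_from_query(query: str) -> list[str]:
    # Optional path characters, then name.extension with a 1-6 character extension.
    query_file_names = re.findall(r"\b[\w\-\.\/]*\w+\.\w{1,6}\b", query)
    # Drop very short hits such as "x.c" or "e.g".
    return [name for name in query_file_names if len(name) > 3]


print(get_file_names_from_query("Update sweepai/api.py and docs/installation.md, not README"))
# ['sweepai/api.py', 'docs/installation.md']
```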
def get_hunks(a: str, b: str, context=10):
differ = difflib.Differ()
diff = [
line
for line in differ.compare(a.splitlines(), b.splitlines())
if line[0] in ("+", "-", " ")
]
show = set()
hunks = []
for i, line in enumerate(diff):
if line.startswith(("+", "-")):
show.update(range(max(0, i - context), min(len(diff), i + context + 1)))
for i in range(len(diff)):
if i in show:
hunks.append(diff[i])
elif i - 1 in show:
hunks.append("...")
if len(hunks) > 0 and hunks[0] == "...":
hunks = hunks[1:]
if len(hunks) > 0 and hunks[-1] == "...":
hunks = hunks[:-1]
return "\n".join(hunks)
def parse_collection_name(name: str) -> str:
# Replace any character that is not a word character or hyphen with a double hyphen
name = re.sub(r"[^\w-]", "--", name)
# Ensure the name is between 3 and 63 characters and starts/ends with alphanumeric
name = re.sub(r"^(-*\w{0,61}\w)-*$", r"\1", name[:63].ljust(3, "x"))
return name
# set whether or not a pr is a draft, there is no way to do this using pygithub
def convert_pr_draft_field(pr: PullRequest, is_draft: bool = False):
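pr_id = pr.raw_data['node_id']
# GraphQL mutation: convertPullRequestToDraft turns the PR into a draft,
# markPullRequestReadyForReview marks it as ready for review
mutation_name = "convertPullRequestToDraft" if is_draft else "markPullRequestReadyForReview"
mutation = """
mutation SetPRDraftState {
{mutation_name}(input: {pullRequestId: {pull_request_id}}) {
pullRequest {
id
}
}
}
""".replace("{mutation_name}", mutation_name).replace("{pull_request_id}", "\""+pr_id+"\"")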
# GraphQL API URL
url = 'https://api.github.com/graphql'
# Headers
headers={
"Accept": "application/vnd.github+json",
"X-Github-Api-Version": "2022-11-28",
"Authorization": "Bearer " + os.environ["GITHUB_PAT"],
}
# Prepare the JSON payload
json_data = {
'query': mutation,
}
# Make the POST request
response = requests.post(url, headers=headers, data=json.dumps(json_data))
if response.status_code != 200 or "errors" in response.json():
logger.error(f"Failed to convert PR to {'draft' if is_draft else 'open'}")
return False
return True
try:
g = Github(os.environ.get("GITHUB_PAT"))
CURRENT_USERNAME = g.get_user().login
except Exception:
try:
slug = get_app()["slug"]
CURRENT_USERNAME = f"{slug}[bot]"
except Exception:
CURRENT_USERNAME = GITHUB_BOT_USERNAME
if __name__ == "__main__":
try:
organization_name = "sweepai"
sweep_config = SweepConfig()
installation_id = get_installation_id(organization_name)
user_token, g = get_github_client(installation_id)
cloned_repo = ClonedRepo("sweepai/sweep", installation_id, "main")
directory_tree, dir_obj = cloned_repo.list_directory_tree()
commit_history = cloned_repo.get_commit_history()
similar_file_paths = cloned_repo.get_similar_file_paths("config.py")
# ensure no similar file_paths are sweep excluded
assert not any(sweep_config.is_file_excluded(file) for file in similar_file_paths)
print(f"similar_file_paths: {similar_file_paths}")
str1 = "a\nline1\nline2\nline3\nline4\nline5\nline6\ntest\n"
str2 = "a\nline1\nlineTwo\nline3\nline4\nline5\nlineSix\ntset\n"
print(get_hunks(str1, str2, 1))
mocked_repo = MockClonedRepo.from_dir(
cloned_repo.repo_dir,
repo_full_name="sweepai/sweep",
)
temp_repo = TemporarilyCopiedClonedRepo.copy_from_cloned_repo(mocked_repo)
print(f"mocked repo: {mocked_repo}")
except Exception as e:

sweep/sweepai/api.py

Lines 1 to 1178 in 0277fad

from __future__ import annotations
import ctypes
import json
import threading
import time
from typing import Any, Optional
import requests
from fastapi import (
Body,
FastAPI,
Header,
HTTPException,
Path,
Request,
Security,
status,
)
from fastapi.responses import HTMLResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from fastapi.templating import Jinja2Templates
from github.Commit import Commit
from sweepai.config.client import (
DEFAULT_RULES,
RESTART_SWEEP_BUTTON,
REVERT_CHANGED_FILES_TITLE,
RULES_LABEL,
RULES_TITLE,
SWEEP_BAD_FEEDBACK,
SWEEP_GOOD_FEEDBACK,
SweepConfig,
get_gha_enabled,
get_rules,
)
from sweepai.config.server import (
BLACKLISTED_USERS,
DISABLED_REPOS,
DISCORD_FEEDBACK_WEBHOOK_URL,
ENV,
GHA_AUTOFIX_ENABLED,
GITHUB_BOT_USERNAME,
GITHUB_LABEL_COLOR,
GITHUB_LABEL_DESCRIPTION,
GITHUB_LABEL_NAME,
IS_SELF_HOSTED,
MERGE_CONFLICT_ENABLED,
)
from sweepai.core.entities import PRChangeRequest
from sweepai.global_threads import global_threads
from sweepai.handlers.create_pr import ( # type: ignore
add_config_to_top_repos,
create_gha_pr,
)
from sweepai.handlers.on_button_click import handle_button_click
from sweepai.handlers.on_check_suite import ( # type: ignore
clean_gh_logs,
download_logs,
on_check_suite,
)
from sweepai.handlers.on_comment import on_comment
from sweepai.handlers.on_jira_ticket import handle_jira_ticket
from sweepai.handlers.on_merge import on_merge
from sweepai.handlers.on_merge_conflict import on_merge_conflict
from sweepai.handlers.on_ticket import on_ticket
from sweepai.handlers.pr_utils import make_pr
from sweepai.handlers.stack_pr import stack_pr
from sweepai.utils.buttons import (
Button,
ButtonList,
check_button_activated,
check_button_title_match,
)
from sweepai.utils.chat_logger import ChatLogger
from sweepai.utils.event_logger import logger, posthog
from sweepai.utils.github_utils import CURRENT_USERNAME, get_github_client
from sweepai.utils.progress import TicketProgress
from sweepai.utils.safe_pqueue import SafePriorityQueue
from sweepai.utils.str_utils import BOT_SUFFIX, get_hash
from sweepai.web.events import (
CheckRunCompleted,
CommentCreatedRequest,
InstallationCreatedRequest,
IssueCommentRequest,
IssueRequest,
PREdited,
PRRequest,
ReposAddedRequest,
)
from sweepai.web.health import health_check
app = FastAPI()
events = {}
on_ticket_events = {}
security = HTTPBearer()
templates = Jinja2Templates(directory="sweepai/web")
# version_command = r"""git config --global --add safe.directory /app
# timestamp=$(git log -1 --format="%at")
# date -d "@$timestamp" +%y.%m.%d.%H 2>/dev/null || date -r "$timestamp" +%y.%m.%d.%H"""
# try:
# version = subprocess.check_output(version_command, shell=True, text=True).strip()
# except Exception:
version = time.strftime("%y.%m.%d.%H")
logger.bind(application="webhook")
def auth_metrics(credentials: HTTPAuthorizationCredentials = Security(security)):
if credentials.scheme != "Bearer":
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Invalid authentication scheme.",
)
if credentials.credentials != "example_token": # grafana requires authentication
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token."
)
return True
def run_on_ticket(*args, **kwargs):
tracking_id = get_hash()
with logger.contextualize(
**kwargs,
name="ticket_" + kwargs["username"],
tracking_id=tracking_id,
):
return on_ticket(*args, **kwargs, tracking_id=tracking_id)
def run_on_comment(*args, **kwargs):
tracking_id = get_hash()
with logger.contextualize(
**kwargs,
name="comment_" + kwargs["username"],
tracking_id=tracking_id,
):
on_comment(*args, **kwargs, tracking_id=tracking_id)
def run_on_button_click(*args, **kwargs):
thread = threading.Thread(target=handle_button_click, args=args, kwargs=kwargs)
thread.start()
global_threads.append(thread)
def run_on_check_suite(*args, **kwargs):
request = kwargs["request"]
pr_change_request = on_check_suite(request)
if pr_change_request:
call_on_comment(**pr_change_request.params, comment_type="github_action")
logger.info("Done with on_check_suite")
else:
logger.info("Skipping on_check_suite as no pr_change_request was returned")
def terminate_thread(thread):
"""Terminate a python threading.Thread."""
try:
if not thread.is_alive():
return
exc = ctypes.py_object(SystemExit)
res = ctypes.pythonapi.PyThreadState_SetAsyncExc(
ctypes.c_long(thread.ident), exc
)
if res == 0:
raise ValueError("Invalid thread ID")
elif res != 1:
# Calling again with a NULL exception clears the pending async exception so the thread state is cleaned up.
ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(thread.ident), None)
raise SystemError("PyThreadState_SetAsyncExc failed")
except SystemExit:
raise SystemExit
except Exception as e:
logger.exception(f"Failed to terminate thread: {e}")
# def delayed_kill(thread: threading.Thread, delay: int = 60 * 60):
# time.sleep(delay)
# terminate_thread(thread)
def call_on_ticket(*args, **kwargs):
global on_ticket_events
key = f"{kwargs['repo_full_name']}-{kwargs['issue_number']}" # Full name, issue number as key
# Use multithreading
# Check if a previous process exists for the same key, cancel it
e = on_ticket_events.get(key, None)
if e:
logger.info(f"Found previous thread for key {key} and cancelling it")
terminate_thread(e)
thread = threading.Thread(target=run_on_ticket, args=args, kwargs=kwargs)
on_ticket_events[key] = thread
thread.start()
global_threads.append(thread)
def call_on_check_suite(*args, **kwargs):
kwargs["request"].repository.full_name
kwargs["request"].check_run.pull_requests[0].number
thread = threading.Thread(target=run_on_check_suite, args=args, kwargs=kwargs)
thread.start()
global_threads.append(thread)
def call_on_comment(
*args, **kwargs
): # TODO: if it's a GHA comment, delete all previous GHA comments and append this one to the end
def worker():
while not events[key].empty():
task_args, task_kwargs = events[key].get()
run_on_comment(*task_args, **task_kwargs)
global events
repo_full_name = kwargs["repo_full_name"]
pr_id = kwargs["pr_number"]
key = f"{repo_full_name}-{pr_id}" # Full name, PR number as key
comment_type = kwargs["comment_type"]
logger.info(f"Received comment type: {comment_type}")
if key not in events:
events[key] = SafePriorityQueue()
events[key].put(0, (args, kwargs))
# If a thread isn't running, start one
if not any(
thread.name == key and thread.is_alive() for thread in threading.enumerate()
):
thread = threading.Thread(target=worker, name=key)
thread.start()
global_threads.append(thread)
def call_on_merge(*args, **kwargs):
thread = threading.Thread(target=on_merge, args=args, kwargs=kwargs)
thread.start()
global_threads.append(thread)
@app.get("/health")
def redirect_to_health():
return health_check()
@app.get("/", response_class=HTMLResponse)
def home(request: Request):
return templates.TemplateResponse(
name="index.html", context={"version": version, "request": request}
)
@app.get("/ticket_progress/{tracking_id}")
def progress(tracking_id: str = Path(...)):
ticket_progress = TicketProgress.load(tracking_id)
return ticket_progress.dict()
def init_hatchet() -> Any | None:
try:
from hatchet_sdk import Context, Hatchet
hatchet = Hatchet(debug=True)
worker = hatchet.worker("github-worker")
@hatchet.workflow(on_events=["github:webhook"])
class OnGithubEvent:
"""Workflow for handling GitHub events."""
@hatchet.step()
def run(self, context: Context):
event_payload = context.workflow_input()
request_dict = event_payload.get("request")
event = event_payload.get("event")
handle_event(request_dict, event)
workflow = OnGithubEvent()
worker.register_workflow(workflow)
# start worker in the background
thread = threading.Thread(target=worker.start)
thread.start()
global_threads.append(thread)
return hatchet
except Exception as e:
print(f"Failed to initialize Hatchet: {e}, continuing with local mode")
return None
# hatchet = init_hatchet()
def handle_github_webhook(event_payload):
# if hatchet:
# hatchet.client.event.push("github:webhook", event_payload)
# else:
handle_event(event_payload.get("request"), event_payload.get("event"))
def handle_request(request_dict, event=None):
"""So it can be exported to the listen endpoint."""
with logger.contextualize(tracking_id="main", env=ENV):
action = request_dict.get("action")
try:
# Send the event to Hatchet
handle_github_webhook(
{
"request": request_dict,
"event": event,
}
)
except Exception as e:
logger.exception(f"Failed to send event to Hatchet: {e}")
# try:
# worker()
# except Exception as e:
# discord_log_error(str(e), priority=1)
logger.info(f"Done handling {event}, {action}")
return {"success": True}
@app.post("/")
def webhook(
request_dict: dict = Body(...),
x_github_event: Optional[str] = Header(None, alias="X-GitHub-Event"),
):
"""Handle a webhook request from GitHub."""
with logger.contextualize(tracking_id="main", env=ENV):
action = request_dict.get("action", None)
logger.info(f"Received event: {x_github_event}, {action}")
return handle_request(request_dict, event=x_github_event)
@app.post("/jira")
def jira_webhook(
request_dict: dict = Body(...),
) -> None:
def call_jira_ticket(*args, **kwargs):
thread = threading.Thread(target=handle_jira_ticket, args=args, kwargs=kwargs)
thread.start()
call_jira_ticket(event=request_dict)
# Set up cronjob for this
@app.get("/update_sweep_prs_v2")
def update_sweep_prs_v2(repo_full_name: str, installation_id: int):
# Get a Github client
_, g = get_github_client(installation_id)
# Get the repository
repo = g.get_repo(repo_full_name)
config = SweepConfig.get_config(repo)
try:
branch_ttl = int(config.get("branch_ttl", 7))
except Exception:
branch_ttl = 7
branch_ttl = max(branch_ttl, 1)
# Get all open pull requests created by Sweep
pulls = repo.get_pulls(
state="open", head="sweep", sort="updated", direction="desc"
)[:5]
# For each pull request, attempt to merge the changes from the default branch into the pull request branch
try:
for pr in pulls:
try:
# make sure it's a sweep ticket
feature_branch = pr.head.ref
if not feature_branch.startswith(
"sweep/"
) and not feature_branch.startswith("sweep_"):
continue
if "Resolve merge conflicts" in pr.title:
continue
if (
pr.mergeable_state != "clean"
and (time.time() - pr.created_at.timestamp()) > 60 * 60 * 24
and pr.title.startswith("[Sweep Rules]")
):
pr.edit(state="closed")
continue
repo.merge(
feature_branch,
pr.base.ref,
f"Merge {pr.base.ref} into {feature_branch}",
)
# Check if the merged PR is the config PR
if pr.title == "Configure Sweep" and pr.merged:
# Create a new PR to add "gha_enabled: True" to sweep.yaml
create_gha_pr(g, repo)
except Exception as e:
logger.warning(
f"Failed to merge changes from default branch into PR #{pr.number}: {e}"
)
except Exception:
logger.warning("Failed to update sweep PRs")
def handle_event(request_dict, event):
action = request_dict.get("action")
if repo_full_name := request_dict.get("repository", {}).get("full_name"):
if repo_full_name in DISABLED_REPOS:
logger.warning(f"Repo {repo_full_name} is disabled")
return {"success": False, "error_message": "Repo is disabled"}
with logger.contextualize(tracking_id="main", env=ENV):
match event, action:
case "check_run", "completed":
request = CheckRunCompleted(**request_dict)
_, g = get_github_client(request.installation.id)
repo = g.get_repo(request.repository.full_name)
pull_requests = request.check_run.pull_requests
if pull_requests:
logger.info(pull_requests[0].number)
pr = repo.get_pull(pull_requests[0].number)
if (time.time() - pr.created_at.timestamp()) > 60 * 60 and (
pr.title.startswith("[Sweep Rules]")
or pr.title.startswith("[Sweep GHA Fix]")
):
after_sha = pr.head.sha
commit = repo.get_commit(after_sha)
check_suites = commit.get_check_suites()
for check_suite in check_suites:
if check_suite.conclusion == "failure":
pr.edit(state="closed")
break
if (
not (time.time() - pr.created_at.timestamp()) > 60 * 15
and request.check_run.conclusion == "failure"
and pr.state == "open"
and get_gha_enabled(repo)
and len(
[
comment
for comment in pr.get_issue_comments()
if "Fixing PR" in comment.body
]
)
< 2
and GHA_AUTOFIX_ENABLED
):
# check if the base branch is passing
commits = repo.get_commits(sha=pr.base.ref)
latest_commit: Commit = commits[0]
if all(
status != "failure"
for status in [
status.state for status in latest_commit.get_statuses()
]
): # base branch is passing
logs = download_logs(
request.repository.full_name,
request.check_run.run_id,
request.installation.id,
)
logs, user_message = clean_gh_logs(logs)
attributor = request.sender.login
if attributor.endswith("[bot]"):
attributor = commit.author.login
if attributor.endswith("[bot]"):
attributor = pr.assignee.login
if attributor.endswith("[bot]"):
return {
"success": False,
"error_message": "The PR was created by a bot, so I won't attempt to fix it.",
}
tracking_id = get_hash()
chat_logger = ChatLogger(
data={
"username": attributor,
"title": "[Sweep GHA Fix] Fix the failing GitHub Actions",
}
)
if chat_logger.use_faster_model() and not IS_SELF_HOSTED:
return {
"success": False,
"error_message": "Disabled for free users",
}
stack_pr(
request=f"[Sweep GHA Fix] The GitHub Actions run failed on {request.check_run.head_sha[:7]} ({repo.default_branch}) with the following error logs:\n\n```\n\n{logs}\n\n```",
pr_number=pr.number,
username=attributor,
repo_full_name=repo.full_name,
installation_id=request.installation.id,
tracking_id=tracking_id,
commit_hash=pr.head.sha,
)
elif (
request.check_run.check_suite.head_branch == repo.default_branch
and get_gha_enabled(repo)
and GHA_AUTOFIX_ENABLED
):
if request.check_run.conclusion == "failure":
commit = repo.get_commit(request.check_run.head_sha)
attributor = request.sender.login
if attributor.endswith("[bot]"):
attributor = commit.author.login
if attributor.endswith("[bot]"):
return {
"success": False,
"error_message": "The PR was created by a bot, so I won't attempt to fix it.",
}
logs = download_logs(
request.repository.full_name,
request.check_run.run_id,
request.installation.id,
)
logs, user_message = clean_gh_logs(logs)
chat_logger = ChatLogger(
data={
"username": attributor,
"title": "[Sweep GHA Fix] Fix the failing GitHub Actions",
}
)
if chat_logger.use_faster_model() and not IS_SELF_HOSTED:
return {
"success": False,
"error_message": "Disabled for free users",
}
make_pr(
title=f"[Sweep GHA Fix] Fix the failing GitHub Actions on {request.check_run.head_sha[:7]} ({repo.default_branch})",
repo_description=repo.description,
summary=f"The GitHub Actions run failed with the following error logs:\n\n```\n{logs}\n```",
repo_full_name=request_dict["repository"]["full_name"],
installation_id=request_dict["installation"]["id"],
user_token=None,
use_faster_model=chat_logger.use_faster_model(),
username=attributor,
chat_logger=chat_logger,
)
case "pull_request", "opened":
_, g = get_github_client(request_dict["installation"]["id"])
repo = g.get_repo(request_dict["repository"]["full_name"])
pr = repo.get_pull(request_dict["pull_request"]["number"])
# if the pr already has a comment from sweep bot do nothing
time.sleep(10)
if any(
comment.user.login == GITHUB_BOT_USERNAME
for comment in pr.get_issue_comments()
) or pr.title.startswith("Sweep:"):
return {
"success": True,
"reason": "PR already has a comment from sweep bot",
}
rule_buttons = []
repo_rules = get_rules(repo) or []
if repo_rules != [""] and repo_rules != []:
for rule in repo_rules or []:
if rule:
rule_buttons.append(Button(label=f"{RULES_LABEL} {rule}"))
if len(repo_rules) == 0:
for rule in DEFAULT_RULES:
rule_buttons.append(Button(label=f"{RULES_LABEL} {rule}"))
if rule_buttons:
rules_buttons_list = ButtonList(
buttons=rule_buttons, title=RULES_TITLE
)
pr.create_issue_comment(rules_buttons_list.serialize() + BOT_SUFFIX)
if pr.mergeable is False and MERGE_CONFLICT_ENABLED:
attributor = pr.user.login
if attributor.endswith("[bot]"):
attributor = pr.assignee.login
if attributor.endswith("[bot]"):
return {
"success": False,
"error_message": "The PR was created by a bot, so I won't attempt to fix it.",
}
chat_logger = ChatLogger(
data={
"username": attributor,
"title": "[Sweep GHA Fix] Fix the failing GitHub Actions",
}
)
if chat_logger.use_faster_model() and not IS_SELF_HOSTED:
return {
"success": False,
"error_message": "Disabled for free users",
}
on_merge_conflict(
pr_number=pr.number,
username=attributor,
repo_full_name=request_dict["repository"]["full_name"],
installation_id=request_dict["installation"]["id"],
tracking_id=get_hash(),
)
case "issues", "opened":
request = IssueRequest(**request_dict)
issue_title_lower = request.issue.title.lower()
if (
issue_title_lower.startswith("sweep")
or "sweep:" in issue_title_lower
):
_, g = get_github_client(request.installation.id)
repo = g.get_repo(request.repository.full_name)
labels = repo.get_labels()
label_names = [label.name for label in labels]
if GITHUB_LABEL_NAME not in label_names:
repo.create_label(
name=GITHUB_LABEL_NAME,
color=GITHUB_LABEL_COLOR,
description=GITHUB_LABEL_DESCRIPTION,
)
current_issue = repo.get_issue(number=request.issue.number)
current_issue.add_to_labels(GITHUB_LABEL_NAME)
case "issue_comment", "edited":
request = IssueCommentRequest(**request_dict)
sweep_labeled_issue = GITHUB_LABEL_NAME in [
label.name.lower() for label in request.issue.labels
]
button_title_match = check_button_title_match(
REVERT_CHANGED_FILES_TITLE,
request.comment.body,
request.changes,
) or check_button_title_match(
RULES_TITLE,
request.comment.body,
request.changes,
)
if (
request.comment.user.type == "Bot"
and GITHUB_BOT_USERNAME in request.comment.user.login
and request.changes.body_from is not None
and button_title_match
and request.sender.type == "User"
and request.comment.user.login not in BLACKLISTED_USERS
):
run_on_button_click(request_dict)
restart_sweep = False
if (
request.comment.user.type == "Bot"
and GITHUB_BOT_USERNAME in request.comment.user.login
and request.changes.body_from is not None
and check_button_activated(
RESTART_SWEEP_BUTTON,
request.comment.body,
request.changes,
)
and sweep_labeled_issue
and request.sender.type == "User"
and request.comment.user.login not in BLACKLISTED_USERS
):
# Restart Sweep on this issue
restart_sweep = True
if (
request.issue is not None
and sweep_labeled_issue
and request.comment.user.type == "User"
and request.comment.user.login not in BLACKLISTED_USERS
and not request.comment.user.login.startswith("sweep")
and not (
request.issue.pull_request and request.issue.pull_request.url
)
or restart_sweep
):
logger.info("New issue comment edited")
request.issue.body = request.issue.body or ""
request.repository.description = (
request.repository.description or ""
)
if (
not request.comment.body.strip()
.lower()
.startswith(GITHUB_LABEL_NAME)
and not restart_sweep
):
logger.info("Comment does not start with 'Sweep', passing")
return {
"success": True,
"reason": "Comment does not start with 'Sweep', passing",
}
call_on_ticket(
title=request.issue.title,
summary=request.issue.body,
issue_number=request.issue.number,
issue_url=request.issue.html_url,
username=request.issue.user.login,
repo_full_name=request.repository.full_name,
repo_description=request.repository.description,
installation_id=request.installation.id,
comment_id=request.comment.id if not restart_sweep else None,
edited=True,
)
elif (
request.issue.pull_request
and request.comment.user.type == "User"
and request.comment.user.login not in BLACKLISTED_USERS
): # TODO(sweep): set a limit
logger.info(f"Handling comment on PR: {request.issue.pull_request}")
_, g = get_github_client(request.installation.id)
repo = g.get_repo(request.repository.full_name)
pr = repo.get_pull(request.issue.number)
labels = pr.get_labels()
comment = request.comment.body
if (
comment.lower().startswith("sweep:")
or any(label.name.lower() == "sweep" for label in labels)
) and BOT_SUFFIX not in comment:
pr_change_request = PRChangeRequest(
params={
"comment_type": "comment",
"repo_full_name": request.repository.full_name,
"repo_description": request.repository.description,
"comment": request.comment.body,
"pr_path": None,
"pr_line_position": None,
"username": request.comment.user.login,
"installation_id": request.installation.id,
"pr_number": request.issue.number,
"comment_id": request.comment.id,
},
)
call_on_comment(**pr_change_request.params)
case "issues", "edited":
request = IssueRequest(**request_dict)
if (
GITHUB_LABEL_NAME
in [label.name.lower() for label in request.issue.labels]
and request.sender.type == "User"
and request.sender.login not in BLACKLISTED_USERS
and not request.sender.login.startswith("sweep")
):
logger.info("New issue edited")
call_on_ticket(
title=request.issue.title,
summary=request.issue.body,
issue_number=request.issue.number,
issue_url=request.issue.html_url,
username=request.issue.user.login,
repo_full_name=request.repository.full_name,
repo_description=request.repository.description,
installation_id=request.installation.id,
comment_id=None,
)
else:
logger.info("Issue edited, but not a sweep issue")
case "issues", "labeled":
request = IssueRequest(**request_dict)
if (
any(
label.name.lower() == GITHUB_LABEL_NAME
for label in request.issue.labels
)
and not request.issue.pull_request
):
request.issue.body = request.issue.body or ""
request.repository.description = (
request.repository.description or ""
)
call_on_ticket(
title=request.issue.title,
summary=request.issue.body,
issue_number=request.issue.number,
issue_url=request.issue.html_url,
username=request.issue.user.login,
repo_full_name=request.repository.full_name,
repo_description=request.repository.description,
installation_id=request.installation.id,
comment_id=None,
)
case "issue_comment", "created":
request = IssueCommentRequest(**request_dict)
if (
request.issue is not None
and GITHUB_LABEL_NAME
in [label.name.lower() for label in request.issue.labels]
and request.comment.user.type == "User"
and request.comment.user.login not in BLACKLISTED_USERS
and not (
request.issue.pull_request and request.issue.pull_request.url
)
and BOT_SUFFIX not in request.comment.body
):
request.issue.body = request.issue.body or ""
request.repository.description = (
request.repository.description or ""
)
if (
not request.comment.body.strip()
.lower()
.startswith(GITHUB_LABEL_NAME)
):
logger.info("Comment does not start with 'Sweep', passing")
return {
"success": True,
"reason": "Comment does not start with 'Sweep', passing",
}
call_on_ticket(
title=request.issue.title,
summary=request.issue.body,
issue_number=request.issue.number,
issue_url=request.issue.html_url,
username=request.issue.user.login,
repo_full_name=request.repository.full_name,
repo_description=request.repository.description,
installation_id=request.installation.id,
comment_id=request.comment.id,
)
elif (
request.issue.pull_request
and request.comment.user.type == "User"
and request.comment.user.login not in BLACKLISTED_USERS
and BOT_SUFFIX not in request.comment.body
): # TODO(sweep): set a limit
_, g = get_github_client(request.installation.id)
repo = g.get_repo(request.repository.full_name)
pr = repo.get_pull(request.issue.number)
labels = pr.get_labels()
comment = request.comment.body
if (
comment.lower().startswith("sweep:")
or any(label.name.lower() == "sweep" for label in labels)
) and BOT_SUFFIX not in comment:
pr_change_request = PRChangeRequest(
params={
"comment_type": "comment",
"repo_full_name": request.repository.full_name,
"repo_description": request.repository.description,
"comment": request.comment.body,
"pr_path": None,
"pr_line_position": None,
"username": request.comment.user.login,
"installation_id": request.installation.id,
"pr_number": request.issue.number,
"comment_id": request.comment.id,
},
)
call_on_comment(**pr_change_request.params)
case "pull_request_review_comment", "created":
request = CommentCreatedRequest(**request_dict)
_, g = get_github_client(request.installation.id)
repo = g.get_repo(request.repository.full_name)
pr = repo.get_pull(request.pull_request.number)
labels = pr.get_labels()
comment = request.comment.body
if (
(
comment.lower().startswith("sweep:")
or any(label.name.lower() == "sweep" for label in labels)
)
and request.comment.user.type == "User"
and request.comment.user.login not in BLACKLISTED_USERS
and BOT_SUFFIX not in comment
):
pr_change_request = PRChangeRequest(
params={
"comment_type": "comment",
"repo_full_name": request.repository.full_name,
"repo_description": request.repository.description,
"comment": request.comment.body,
"pr_path": request.comment.path,
"pr_line_position": request.comment.original_line,
"username": request.comment.user.login,
"installation_id": request.installation.id,
"pr_number": request.pull_request.number,
"comment_id": request.comment.id,
},
)
call_on_comment(**pr_change_request.params)
case "pull_request_review_comment", "edited":
request = CommentCreatedRequest(**request_dict)
_, g = get_github_client(request.installation.id)
repo = g.get_repo(request.repository.full_name)
pr = repo.get_pull(request.pull_request.number)
labels = pr.get_labels()
comment = request.comment.body
if (
(
comment.lower().startswith("sweep:")
or any(label.name.lower() == "sweep" for label in labels)
)
and request.comment.user.type == "User"
and request.comment.user.login not in BLACKLISTED_USERS
and BOT_SUFFIX not in comment
):
pr_change_request = PRChangeRequest(
params={
"comment_type": "comment",
"repo_full_name": request.repository.full_name,
"repo_description": request.repository.description,
"comment": request.comment.body,
"pr_path": request.comment.path,
"pr_line_position": request.comment.original_line,
"username": request.comment.user.login,
"installation_id": request.installation.id,
"pr_number": request.pull_request.number,
"comment_id": request.comment.id,
},
)
call_on_comment(**pr_change_request.params)
case "installation_repositories", "added":
repos_added_request = ReposAddedRequest(**request_dict)
metadata = {
"installation_id": repos_added_request.installation.id,
"repositories": [
repo.full_name
for repo in repos_added_request.repositories_added
],
}
try:
add_config_to_top_repos(
repos_added_request.installation.id,
repos_added_request.installation.account.login,
repos_added_request.repositories_added,
)
except Exception as e:
logger.exception(f"Failed to add config to top repos: {e}")
posthog.capture(
"installation_repositories",
"started",
properties={**metadata},
)
for repo in repos_added_request.repositories_added:
organization, repo_name = repo.full_name.split("/")
posthog.capture(
organization,
"installed_repository",
properties={
"repo_name": repo_name,
"organization": organization,
"repo_full_name": repo.full_name,
},
)
case "installation", "created":
repos_added_request = InstallationCreatedRequest(**request_dict)
try:
add_config_to_top_repos(
repos_added_request.installation.id,
repos_added_request.installation.account.login,
repos_added_request.repositories,
)
except Exception as e:
logger.exception(f"Failed to add config to top repos: {e}")
case "pull_request", "edited":
request = PREdited(**request_dict)
if (
request.pull_request.user.login == GITHUB_BOT_USERNAME
and not request.sender.login.endswith("[bot]")
and DISCORD_FEEDBACK_WEBHOOK_URL is not None
):
good_button = check_button_activated(
SWEEP_GOOD_FEEDBACK,
request.pull_request.body,
request.changes,
)
bad_button = check_button_activated(
SWEEP_BAD_FEEDBACK,
request.pull_request.body,
request.changes,
)
if good_button or bad_button:
emoji = "😕"
if good_button:
emoji = "👍"
elif bad_button:
emoji = "👎"
data = {
"content": f"{emoji} {request.pull_request.html_url} ({request.sender.login})\n{request.pull_request.commits} commits, {request.pull_request.changed_files} files: +{request.pull_request.additions}, -{request.pull_request.deletions}"
}
headers = {"Content-Type": "application/json"}
requests.post(
DISCORD_FEEDBACK_WEBHOOK_URL,
data=json.dumps(data),
headers=headers,
)
# Send feedback to PostHog
posthog.capture(
request.sender.login,
"feedback",
properties={
"repo_name": request.repository.full_name,
"pr_url": request.pull_request.html_url,
"pr_commits": request.pull_request.commits,
"pr_additions": request.pull_request.additions,
"pr_deletions": request.pull_request.deletions,
"pr_changed_files": request.pull_request.changed_files,
"username": request.sender.login,
"good_button": good_button,
"bad_button": bad_button,
},
)
def remove_buttons_from_description(body):
"""
Replace:
### PR Feedback...
...
# (until it hits the next #)
with
### PR Feedback: {emoji}
#
"""
lines = body.split("\n")
if not lines[0].startswith("### PR Feedback"):
return None
# Find when the second # occurs
i = 0
for i, line in enumerate(lines):
if line.startswith("#") and i > 0:
break
return "\n".join(
[
f"### PR Feedback: {emoji}",
*lines[i:],
]
)
# Update PR description to remove buttons
try:
_, g = get_github_client(request.installation.id)
repo = g.get_repo(request.repository.full_name)
pr = repo.get_pull(request.pull_request.number)
new_body = remove_buttons_from_description(
request.pull_request.body
)
if new_body is not None:
pr.edit(body=new_body)
except SystemExit:
raise SystemExit
except Exception as e:
logger.exception(f"Failed to edit PR description: {e}")
case "pull_request", "closed":
pr_request = PRRequest(**request_dict)
(
organization,
repo_name,
) = pr_request.repository.full_name.split("/")
commit_author = pr_request.pull_request.user.login
merged_by = (
pr_request.pull_request.merged_by.login
if pr_request.pull_request.merged_by
else None
)
if CURRENT_USERNAME == commit_author and merged_by is not None:
event_name = "merged_sweep_pr"
if pr_request.pull_request.title.startswith("[config]"):
event_name = "config_pr_merged"
elif pr_request.pull_request.title.startswith("[Sweep Rules]"):
event_name = "sweep_rules_pr_merged"
edited_by_developers = False
_token, g = get_github_client(pr_request.installation.id)
pr = g.get_repo(pr_request.repository.full_name).get_pull(
pr_request.number
)
total_lines_in_commit = 0
total_lines_edited_by_developer = 0
edited_by_developers = False
for commit in pr.get_commits():
lines_modified = commit.stats.additions + commit.stats.deletions
total_lines_in_commit += lines_modified
if commit.author.login != CURRENT_USERNAME:
total_lines_edited_by_developer += lines_modified
# this was edited by a developer if at least 25% of the lines were edited by a developer
edited_by_developers = total_lines_in_commit > 0 and (total_lines_edited_by_developer / total_lines_in_commit) >= 0.25
posthog.capture(
merged_by,
event_name,
properties={
"repo_name": repo_name,
"organization": organization,
"repo_full_name": pr_request.repository.full_name,
"username": merged_by,
"additions": pr_request.pull_request.additions,
"deletions": pr_request.pull_request.deletions,
"total_changes": pr_request.pull_request.additions
+ pr_request.pull_request.deletions,
"edited_by_developers": edited_by_developers,
"total_lines_in_commit": total_lines_in_commit,
"total_lines_edited_by_developer": total_lines_edited_by_developer,
},
)
chat_logger = ChatLogger({"username": merged_by})
case "push", None:
if event != "pull_request" or request_dict["base"]["merged"] is True:
chat_logger = ChatLogger(
{"username": request_dict["pusher"]["name"]}
)
# on merge
call_on_merge(request_dict, chat_logger)
ref = request_dict["ref"] if "ref" in request_dict else ""
_, g = get_github_client(request_dict["installation"]["id"])
repo = g.get_repo(request_dict["repository"]["full_name"])
if ref.startswith("refs/heads") and not ref.startswith(
"refs/heads/sweep"
):
if ref[len("refs/heads/") :] == SweepConfig.get_branch(repo):
update_sweep_prs_v2(
request_dict["repository"]["full_name"],
installation_id=request_dict["installation"]["id"],
)
if ref.startswith("refs/heads"):
branch_name = ref[len("refs/heads/") :]
# Check if the branch has an associated PR
org_name, repo_name = request_dict["repository"][
"full_name"
].split("/")
pulls = repo.get_pulls(
state="open",
sort="created",
head=org_name + ":" + branch_name,
)
for pr in pulls:
logger.info(
f"PR associated with branch {branch_name}: #{pr.number} - {pr.title}"
)
if pr.mergeable is False and MERGE_CONFLICT_ENABLED:
attributor = pr.user.login
if attributor.endswith("[bot]"):
attributor = pr.assignee.login
if attributor.endswith("[bot]"):
return {
"success": False,
"error_message": "The PR was created by a bot, so I won't attempt to fix it.",
}
chat_logger = ChatLogger(
data={
"username": attributor,
"title": "[Sweep GHA Fix] Fix the failing GitHub Actions",
}
)
if (
chat_logger.use_faster_model()
and not IS_SELF_HOSTED
):
return {
"success": False,
"error_message": "Disabled for free users",
}
on_merge_conflict(
pr_number=pr.number,
username=pr.user.login,
repo_full_name=request_dict["repository"][
"full_name"
],
installation_id=request_dict["installation"]["id"],
tracking_id=get_hash(),
)
case "ping", None:
return {"message": "pong"}
case _:

# Advanced Features: becoming a Power User 🧠
## Usage 📖
### Mention important files
To ensure that Sweep scans a file, mention the file name in your ticket. Sweep searches for relevant files at runtime, but specifying the file helps avoid missing important details.
### Giving Sweep feedback
If Sweep's plan isn't accurate, you can respond to Sweep in three places:
1. **Issue**: Sweep will create a new pull request and close the old one. Alternatively, you can edit the issue description to recreate the pull request.
2. **Pull request**: Sweep will update the PR based on your PR comments
3. **Code**: Sweep will only update the file that the comment is on
Whenever Sweep starts looking at one of your messages, you will see a 👀 emoji. If you don't see it, make sure the PR or issue is open and that you prefixed the message with "sweep:".
Additionally, when a GitHub Actions run fails, Sweep updates the PR based on the error message.
### Switch branch
To get Sweep to use a different base branch for one issue, add the following to the issue description.
> branch: BRANCH_NAME
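To illustrate how such a directive can be picked out of an issue description, here is a minimal sketch. It is not Sweep's actual parser: the function name `extract_branch_override` and the matching rule (a line of the form `branch: NAME`, optionally quoted with `>`) are assumptions for illustration only.

```python
import re
from typing import Optional

# Hypothetical rule: a whole line "branch: NAME", optionally prefixed by "> ".
BRANCH_DIRECTIVE = re.compile(
    r"^\s*(?:>\s*)?branch:\s*(\S+)\s*$", re.MULTILINE | re.IGNORECASE
)


def extract_branch_override(issue_body: Optional[str]) -> Optional[str]:
    """Return the base branch named in the issue body, or None if there is none."""
    if not issue_body:
        return None
    match = BRANCH_DIRECTIVE.search(issue_body)
    return match.group(1) if match else None
```

For example, `extract_branch_override("Fix the login redirect\n\nbranch: staging")` returns `"staging"`. Requiring the branch name to be the only token on the line keeps ordinary prose such as "branch: main is broken" from being treated as a directive.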
## Configuration 🛠️
### Use GitHub Actions
We highly recommend linters, as well as Netlify/Vercel preview builds. Sweep auto-corrects based on linter and build errors, and Netlify and Vercel shorten iteration cycles by providing previews of static sites.
### Set up `sweep.yaml`
You can set up `sweep.yaml` to
* Provide up-to-date docs by setting up `docs` (https://docs.sweep.dev/usage/config#docs)
* Set up automated formatting and linting by setting up `sandbox` (https://docs.sweep.dev/usage/config#sandbox). Never have Sweep commit a failing `npm lint` again.
* Give Sweep a high level description of where to find files in your repo by editing the `repo_description` field.
For more on configs, check out https://docs.sweep.dev/usage/config.
## Prompting 🗣️
The amount of prompting you need to give Sweep directly scales with the complexity of the problem.
For harder problems, try to provide the same information a human would need, and for simpler problems, providing a single line and a file name should suffice.
### Prompting formats
A good issue should include **where to look** (file name or entity name), **what to do** ("change the logic to do this"), and **additional context** (there's a bug/we need this feature/there's this dependency). Examples:

sweep/sweepai/cli.py (lines 1 to 363 at commit 0277fad)

import datetime
import json
import os
import pickle
import threading
import time
import uuid
from itertools import chain, islice
import typer
from github import Github
from github.Event import Event
from github.IssueEvent import IssueEvent
from github.Repository import Repository
from loguru import logger
from rich.console import Console
from rich.prompt import Prompt
from sweepai.api import handle_request
from sweepai.handlers.on_ticket import on_ticket
from sweepai.utils.event_logger import posthog
from sweepai.utils.github_utils import get_github_client
from sweepai.utils.str_utils import get_hash
from sweepai.web.events import Account, Installation, IssueRequest
app = typer.Typer(
name="sweepai", context_settings={"help_option_names": ["-h", "--help"]}
)
app_dir = typer.get_app_dir("sweepai")
config_path = os.path.join(app_dir, "config.json")
console = Console()
cprint = console.print
def posthog_capture(event_name, properties, *args, **kwargs):
POSTHOG_DISTINCT_ID = os.environ.get("POSTHOG_DISTINCT_ID")
if POSTHOG_DISTINCT_ID:
posthog.capture(POSTHOG_DISTINCT_ID, event_name, properties, *args, **kwargs)
def load_config():
if os.path.exists(config_path):
cprint(f"\nLoading configuration from {config_path}", style="yellow")
with open(config_path, "r") as f:
config = json.load(f)
os.environ["GITHUB_PAT"] = config.get("GITHUB_PAT", "")
os.environ["OPENAI_API_KEY"] = config.get("OPENAI_API_KEY", "")
os.environ["ANTHROPIC_API_KEY"] = config.get("ANTHROPIC_API_KEY", "")
os.environ["VOYAGE_API_KEY"] = config.get("VOYAGE_API_KEY", "")
os.environ["POSTHOG_DISTINCT_ID"] = str(config.get("POSTHOG_DISTINCT_ID", ""))
def fetch_issue_request(issue_url: str, __version__: str = "0"):
(
protocol_name,
_,
_base_url,
org_name,
repo_name,
_issues,
issue_number,
) = issue_url.split("/")
cprint("Fetching installation ID...")
installation_id = -1
cprint("Fetching access token...")
_token, g = get_github_client(installation_id)
g: Github = g
cprint("Fetching repo...")
issue = g.get_repo(f"{org_name}/{repo_name}").get_issue(int(issue_number))
issue_request = IssueRequest(
action="labeled",
issue=IssueRequest.Issue(
title=issue.title,
number=int(issue_number),
html_url=issue_url,
user=IssueRequest.Issue.User(
login=issue.user.login,
type="User",
),
body=issue.body,
labels=[
IssueRequest.Issue.Label(
name="sweep",
),
],
assignees=None,
pull_request=None,
),
repository=IssueRequest.Issue.Repository(
full_name=issue.repository.full_name,
description=issue.repository.description,
),
assignee=IssueRequest.Issue.Assignee(login=issue.user.login),
installation=Installation(
id=installation_id,
account=Account(
id=issue.user.id,
login=issue.user.login,
type="User",
),
),
sender=IssueRequest.Issue.User(
login=issue.user.login,
type="User",
),
)
return issue_request
def pascal_to_snake(name):
return "".join(["_" + i.lower() if i.isupper() else i for i in name]).lstrip("_")
def get_event_type(event: Event | IssueEvent):
if isinstance(event, IssueEvent):
return "issues"
else:
return pascal_to_snake(event.type)[: -len("_event")]
@app.command()
def test():
cprint("Sweep AI is installed correctly and ready to go!", style="yellow")
@app.command()
def watch(
repo_name: str,
debug: bool = False,
record_events: bool = False,
max_events: int = 30,
):
if not os.path.exists(config_path):
cprint(
f"\nConfiguration not found at {config_path}. Please run [green]'sweep init'[/green] to initialize the CLI.\n",
style="yellow",
)
raise ValueError(
"Configuration not found, please run 'sweep init' to initialize the CLI."
)
posthog_capture(
"sweep_watch_started",
{
"repo": repo_name,
"debug": debug,
"record_events": record_events,
"max_events": max_events,
},
)
GITHUB_PAT = os.environ.get("GITHUB_PAT", None)
if GITHUB_PAT is None:
raise ValueError("GITHUB_PAT environment variable must be set")
g = Github(os.environ["GITHUB_PAT"])
repo = g.get_repo(repo_name)
if debug:
logger.debug("Debug mode enabled")
def stream_events(repo: Repository, timeout: int = 2, offset: int = 2 * 60):
processed_event_ids = set()
current_time = time.time() - offset
current_time = datetime.datetime.fromtimestamp(current_time)
local_tz = datetime.datetime.now(datetime.timezone.utc).astimezone().tzinfo
while True:
events_iterator = chain(
islice(repo.get_events(), max_events),
islice(repo.get_issues_events(), max_events),
)
for i, event in enumerate(events_iterator):
if event.id not in processed_event_ids:
local_time = event.created_at.replace(
tzinfo=datetime.timezone.utc
).astimezone(local_tz)
if local_time.timestamp() > current_time.timestamp():
yield event
else:
if debug:
logger.debug(
f"Skipping event {event.id} because it is in the past (local_time={local_time}, current_time={current_time}, i={i})"
)
if debug:
logger.debug(
f"Skipping event {event.id} because it is already handled"
)
processed_event_ids.add(event.id)
time.sleep(timeout)
def handle_event(event: Event | IssueEvent, do_async: bool = True):
if isinstance(event, IssueEvent):
payload = event.raw_data
payload["action"] = payload["event"]
else:
payload = {**event.raw_data, **event.payload}
payload["sender"] = payload.get("sender", payload["actor"])
payload["sender"]["type"] = "User"
payload["pusher"] = payload.get("pusher", payload["actor"])
payload["pusher"]["name"] = payload["pusher"]["login"]
payload["pusher"]["type"] = "User"
payload["after"] = payload.get("after", payload.get("head"))
payload["repository"] = repo.raw_data
payload["installation"] = {"id": -1}
logger.info(str(event) + " " + str(event.created_at))
if record_events:
_type = get_event_type(event) if isinstance(event, Event) else "issue"
pickle.dump(
event,
open(
"tests/events/"
+ f"{_type}_{payload.get('action')}_{str(event.id)}.pkl",
"wb",
),
)
if do_async:
thread = threading.Thread(
target=handle_request, args=(payload, get_event_type(event))
)
thread.start()
return thread
else:
return handle_request(payload, get_event_type(event))
def main():
cprint(
f"\n[bold black on white] Starting server, listening to events from {repo_name}... [/bold black on white]\n",
)
cprint(
f"To create a PR, please create an issue at https://github.com/{repo_name}/issues with a title prefixed with 'Sweep:' or label an existing issue with 'sweep'. The events will be logged here, but there may be a brief delay.\n"
)
for event in stream_events(repo):
handle_event(event)
if __name__ == "__main__":
main()
@app.command()
def init(override: bool = False):
# TODO: Fix telemetry
if not override:
if os.path.exists(config_path):
with open(config_path, "r") as f:
config = json.load(f)
if "OPENAI_API_KEY" in config and "ANTHROPIC_API_KEY" in config and "GITHUB_PAT" in config:
override = typer.confirm(
f"\nConfiguration already exists at {config_path}. Override?",
default=False,
abort=True,
)
cprint(
"\n[bold black on white] Initializing Sweep CLI... [/bold black on white]\n",
)
cprint(
"\nFirstly, let's store your OpenAI API Key. You can get it here: https://platform.openai.com/api-keys\n",
style="yellow",
)
openai_api_key = Prompt.ask("OpenAI API Key", password=True)
assert len(openai_api_key) > 30, "OpenAI API Key must be of length at least 30."
assert openai_api_key.startswith("sk-"), "OpenAI API Key must start with 'sk-'."
cprint(
"\nNext, let's store your Anthropic API key. You can get it here: https://console.anthropic.com/settings/keys.",
style="yellow",
)
anthropic_api_key = Prompt.ask("Anthropic API Key", password=True)
assert len(anthropic_api_key) > 30, "Anthropic API Key must be of length at least 30."
assert anthropic_api_key.startswith("sk-ant-api03-"), "Anthropic API Key must start with 'sk-ant-api03-'."
cprint(
"\nGreat! Next, we'll need just your GitHub PAT. Here's a link with all the permissions pre-filled:\nhttps://github.com/settings/tokens/new?description=Sweep%20Self-hosted&scopes=repo,workflow\n",
style="yellow",
)
github_pat = Prompt.ask("GitHub PAT", password=True)
assert len(github_pat) > 30, "GitHub PAT must be of length at least 30."
assert github_pat.startswith("ghp_"), "GitHub PAT must start with 'ghp_'."
cprint(
"\nAwesome! Lastly, let's get your Voyage AI API key from https://dash.voyageai.com/api-keys. This is optional, but improves code search by about [cyan]5%[/cyan]. You can always return to this later by re-running 'sweep init'.",
style="yellow",
)
voyage_api_key = Prompt.ask("Voyage AI API key", password=True)
if voyage_api_key:
assert len(voyage_api_key) > 30, "Voyage AI API key must be of length at least 30."
assert voyage_api_key.startswith("pa-"), "Voyage API key must start with 'pa-'."
POSTHOG_DISTINCT_ID = None
enable_telemetry = typer.confirm(
"\nEnable usage statistics? This will help us improve the product.",
default=True,
)
if enable_telemetry:
cprint(
"\nThank you for enabling telemetry. We'll collect anonymous usage statistics to improve the product. You can disable this at any time by rerunning 'sweep init'.",
style="yellow",
)
POSTHOG_DISTINCT_ID = uuid.getnode()
posthog.capture(POSTHOG_DISTINCT_ID, "sweep_init", {})
config = {
"GITHUB_PAT": github_pat,
"OPENAI_API_KEY": openai_api_key,
"ANTHROPIC_API_KEY": anthropic_api_key,
"VOYAGE_API_KEY": voyage_api_key,
}
if POSTHOG_DISTINCT_ID:
config["POSTHOG_DISTINCT_ID"] = POSTHOG_DISTINCT_ID
os.makedirs(app_dir, exist_ok=True)
with open(config_path, "w") as f:
json.dump(config, f)
cprint(f"\nConfiguration saved to {config_path}\n", style="yellow")
cprint(
"Installation complete! You can now run [green]'sweep run <issue-url>'[/green][yellow] to run Sweep on an issue, or [/yellow][green]'sweep watch <org-name>/<repo-name>'[/green] to have Sweep listen for and fix newly created GitHub issues.",
style="yellow",
)
@app.command()
def run(issue_url: str):
if not os.path.exists(config_path):
cprint(
f"\nConfiguration not found at {config_path}. Please run [green]'sweep init'[/green] to initialize the CLI.\n",
style="yellow",
)
raise ValueError(
"Configuration not found, please run 'sweep init' to initialize the CLI."
)
cprint(f"\n Running Sweep on issue: {issue_url} \n", style="bold black on white")
posthog_capture("sweep_run_started", {"issue_url": issue_url})
request = fetch_issue_request(issue_url)
try:
cprint(f'\nRunning Sweep to solve "{request.issue.title}"!\n')
on_ticket(
title=request.issue.title,
summary=request.issue.body,
issue_number=request.issue.number,
issue_url=request.issue.html_url,
username=request.sender.login,
repo_full_name=request.repository.full_name,
repo_description=request.repository.description,
installation_id=request.installation.id,
comment_id=None,
edited=False,
tracking_id=get_hash(),
)
except Exception as e:
posthog_capture("sweep_run_fail", {"issue_url": issue_url, "error": str(e)})
else:
posthog_capture("sweep_run_success", {"issue_url": issue_url})
def main():
cprint(
"By using the Sweep CLI, you agree to the Sweep AI Terms of Service at https://sweep.dev/tos.pdf",
style="cyan",
)
load_config()
app()
if __name__ == "__main__":

# Frequently Asked Questions
<details id="does-sweep-write-tests">
<summary>Does Sweep write tests?</summary>
Yep! The easiest way to have Sweep write tests is by modifying the `description` parameter in your `sweep.yaml`. You can add something like:
“In [your repository], the tests are written in [your format]. If you modify business logic, modify the tests as well using this format.” You can add anything you’d like to the description parameter, including formatting rules (like PEP8), code style, etc!
</details>
<details id="can-we-trust-code-written-by-sweep">
<summary>Can we trust the code written by Sweep?</summary>
You should always review the PR. However, we also test the PR against your existing GitHub Actions to make sure it works.
To get the best performance, add GitHub Actions that lint, test, and validate your code.
</details>
<details id="work-off-another-branch">
<summary>Can I have Sweep work off of another branch besides main?</summary>
Yes! In the `sweep.yaml`, you can set the `branch` parameter to something besides your default branch, and Sweep will use that as a reference.
</details>
<details id="retry-issue-with-sweep">
<summary>How do I retry an issue with Sweep?</summary>
To retry an issue, prefix your issue reply with 'Sweep: '. This will trigger Sweep to retry the issue.
</details>
<details id="give-documentation-to-sweep">
<summary>Can I give documentation to Sweep?</summary>
Yes! In the `sweep.yaml`, you can specify docs. Use the URL prefix of the documentation site so that Sweep fetches only the pages you need.
Check out the example here: https://github.com/sweepai/sweep/blob/main/sweep.yaml.
</details>
<details id="comment-on-sweeps-prs">
<summary>Can I comment on Sweep’s PRs?</summary>
Yep! You have three options depending on the degree of the change:
1. You can comment on the issue, and Sweep will rewrite the entire pull request. This will use one of your GPT-4 credits.
2. You can comment on the pull request (not a file) and Sweep can make substantial changes to the pull request. Sweep will search the codebase, and is able to modify and create files.
3. You can comment on the file directly, and Sweep will only modify that file. Use this for small single file changes.
</details>

Once Sweep has the reference implementation, Sweep generates the corresponding test as commits in a [GitHub PR](https://github.com/sweepai/sweep/pull/2378):
```python
def get_file_contents(self, file_path, ref=None):
local_path = os.path.join(self.cache_dir, file_path)
if os.path.exists(local_path):
with open(local_path, "r", encoding="utf-8", errors="replace") as f:
contents = f.read()
return contents
else:
raise FileNotFoundError(f"{local_path} does not exist.")
```
Sweep generated mocks for `os.path.join` and `open`. <br></br>
This code looks great!
```python
@patch("os.path.join")
@patch("open")
def test_get_file_contents(self, mock_open, mock_join):
mock_join.return_value = "/tmp/cache/repos/sweepai/sweep/main/file1"
mock_open.return_value.__enter__.return_value.read.return_value = "file content"
content = self.cloned_repo.get_file_contents("file1")
self.assertEqual(content, "file content")
```
We generated mocks for `os.path.join` and `open`, which should return the correct path and file contents. <br></br>
Ok we're done here right? Can we just write these tests and leave the rest to the developer?
## 3. **Run the tests.**
Most other AI tools stop here, but it’s not enough. <br></br>
If you just committed these tests it would be great, but you’d end up with a frustrating bug. Here it is:
```bash
File "/usr/lib/python3.10/unittest/mock.py", line 1616, in _get_target
raise TypeError(
TypeError: Need a valid target to patch. You supplied: 'open'
```
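The error comes from how `unittest.mock.patch` resolves its target: it splits the string at the last dot, imports the part before the dot, and replaces the attribute after it, so a bare `'open'` gives it nothing to import. The built-in function lives in the `builtins` module. A minimal, self-contained sketch of both behaviors, where `read_cached` is a hypothetical stand-in for `ClonedRepo.get_file_contents` rather than Sweep's code:

```python
import os
import unittest
from unittest.mock import mock_open, patch


def read_cached(cache_dir: str, file_path: str) -> str:
    # Hypothetical stand-in for ClonedRepo.get_file_contents.
    local_path = os.path.join(cache_dir, file_path)
    if os.path.exists(local_path):
        with open(local_path, "r", encoding="utf-8", errors="replace") as f:
            return f.read()
    raise FileNotFoundError(f"{local_path} does not exist.")


def bare_open_target_fails() -> bool:
    # patch() needs a dotted "module.attribute" path; "open" has no dot,
    # so patch() raises TypeError before the context manager is entered.
    try:
        with patch("open"):
            pass
    except TypeError:
        return True
    return False


class ReadCachedTest(unittest.TestCase):
    # Decorators apply bottom-up, so the open mock is the first argument.
    @patch("os.path.exists", return_value=True)
    @patch("builtins.open", new_callable=mock_open, read_data="file content")
    def test_read_cached(self, mocked_open, mocked_exists):
        self.assertEqual(read_cached("/tmp/cache", "file1"), "file content")
        mocked_open.assert_called_once_with(
            os.path.join("/tmp/cache", "file1"),
            "r",
            encoding="utf-8",
            errors="replace",
        )
```

Run it with `python -m unittest`. Using `new_callable=mock_open` with `read_data` yields a file handle whose `read()` returns the canned text, so no real file is needed.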
Did we really save time for the developer here? It’s frustrating that most other tools don’t fix these issues.
*Unlike every other tool, Sweep actually runs these tests.*
Sweep ran the code, found the issue, and identified the solution: <br></br>
**“Change the target of the patch in the 'test_get_file_contents' method from 'open' to 'builtins.open'. This will correctly patch the built-in 'open' function during the test.”**
Sweep added [this commit](https://github.com/sweepai/sweep/pull/2378/commits/0ded79eab77ca3e511257ff0bf3874893b038e9e):
```python

import base64
import os
from dotenv import load_dotenv
from loguru import logger
logger.print = logger.info
load_dotenv(dotenv_path=".env", override=True, verbose=True)
os.environ["GITHUB_APP_PEM"] = os.environ.get("GITHUB_APP_PEM") or base64.b64decode(
os.environ.get("GITHUB_APP_PEM_BASE64", "")
).decode("utf-8")
if os.environ["GITHUB_APP_PEM"]:
os.environ["GITHUB_APP_ID"] = (
(os.environ.get("GITHUB_APP_ID") or os.environ.get("APP_ID"))
.replace("\\n", "\n")
.strip('"')
)
os.environ["TRANSFORMERS_CACHE"] = os.environ.get(
"TRANSFORMERS_CACHE", "/tmp/cache/model"
) # vector_db.py
os.environ["TIKTOKEN_CACHE_DIR"] = os.environ.get(
"TIKTOKEN_CACHE_DIR", "/tmp/cache/tiktoken"
) # utils.py
SENTENCE_TRANSFORMERS_MODEL = os.environ.get(
"SENTENCE_TRANSFORMERS_MODEL",
"sentence-transformers/all-MiniLM-L6-v2", # "all-mpnet-base-v2"
)
TEST_BOT_NAME = "sweep-nightly[bot]"
ENV = os.environ.get("ENV", "dev")
# ENV = os.environ.get("MODAL_ENVIRONMENT", "dev")
# ENV = PREFIX
# ENVIRONMENT = PREFIX
DB_MODAL_INST_NAME = "db"
DOCS_MODAL_INST_NAME = "docs"
API_MODAL_INST_NAME = "api"
UTILS_MODAL_INST_NAME = "utils"
BOT_TOKEN_NAME = "bot-token"
# goes under Modal 'discord' secret name (optional, can leave env var blank)
DISCORD_WEBHOOK_URL = os.environ.get("DISCORD_WEBHOOK_URL")
DISCORD_MEDIUM_PRIORITY_URL = os.environ.get("DISCORD_MEDIUM_PRIORITY_URL")
DISCORD_LOW_PRIORITY_URL = os.environ.get("DISCORD_LOW_PRIORITY_URL")
DISCORD_FEEDBACK_WEBHOOK_URL = os.environ.get("DISCORD_FEEDBACK_WEBHOOK_URL")
SWEEP_HEALTH_URL = os.environ.get("SWEEP_HEALTH_URL")
DISCORD_STATUS_WEBHOOK_URL = os.environ.get("DISCORD_STATUS_WEBHOOK_URL")
# goes under Modal 'github' secret name
GITHUB_APP_ID = os.environ.get("GITHUB_APP_ID", os.environ.get("APP_ID"))
# deprecated: old logic transfer so upstream can use this
if GITHUB_APP_ID is None:
if ENV == "prod":
GITHUB_APP_ID = "307814"
elif ENV == "dev":
GITHUB_APP_ID = "324098"
elif ENV == "staging":
GITHUB_APP_ID = "327588"
GITHUB_BOT_USERNAME = os.environ.get("GITHUB_BOT_USERNAME")
# deprecated: left to support old logic
if not GITHUB_BOT_USERNAME:
if ENV == "prod":
GITHUB_BOT_USERNAME = "sweep-ai[bot]"
elif ENV == "dev":
GITHUB_BOT_USERNAME = "sweep-nightly[bot]"
elif ENV == "staging":
GITHUB_BOT_USERNAME = "sweep-canary[bot]"
elif not GITHUB_BOT_USERNAME.endswith("[bot]"):
GITHUB_BOT_USERNAME = GITHUB_BOT_USERNAME + "[bot]"
GITHUB_LABEL_NAME = os.environ.get("GITHUB_LABEL_NAME", "sweep")
GITHUB_LABEL_COLOR = os.environ.get("GITHUB_LABEL_COLOR", "9400D3")
GITHUB_LABEL_DESCRIPTION = os.environ.get(
"GITHUB_LABEL_DESCRIPTION", "Sweep your software chores"
)
GITHUB_APP_PEM = os.environ.get("GITHUB_APP_PEM")
GITHUB_APP_PEM = GITHUB_APP_PEM or os.environ.get("PRIVATE_KEY")
if GITHUB_APP_PEM is not None:
GITHUB_APP_PEM = GITHUB_APP_PEM.strip(' \n"') # Remove whitespace and quotes
GITHUB_APP_PEM = GITHUB_APP_PEM.replace("\\n", "\n")
GITHUB_CONFIG_BRANCH = os.environ.get("GITHUB_CONFIG_BRANCH", "sweep/add-sweep-config")
GITHUB_DEFAULT_CONFIG = os.environ.get(
"GITHUB_DEFAULT_CONFIG",
"""# Sweep AI turns bugs & feature requests into code changes (https://sweep.dev)
# For details on our config file, check out our docs at https://docs.sweep.dev/usage/config
# This setting contains a list of rules that Sweep will check for. If any of these rules are broken in a new commit, Sweep will create a pull request to fix the broken rule.
rules:
{additional_rules}
# This is the branch that Sweep will develop from and make pull requests to. Most people use 'main' or 'master' but some users also use 'dev' or 'staging'.
branch: 'main'
# By default Sweep will read the logs and outputs from your existing GitHub Actions. To disable this, set this to false.
gha_enabled: True
# This is the description of your project. It will be used by sweep when creating PRs. You can tell Sweep what's unique about your project, what frameworks you use, or anything else you want.
#
# Example:
#
# description: sweepai/sweep is a python project. The main api endpoints are in sweepai/api.py. Write code that adheres to PEP8.
description: ''
# This sets whether to create pull requests as drafts. If this is set to True, then all pull requests will be created as drafts and GitHub Actions will not be triggered.
draft: False
# This is a list of directories that Sweep will not be able to edit.
blocked_dirs: []
""",
)
MONGODB_URI = os.environ.get("MONGODB_URI", None)
IS_SELF_HOSTED = os.environ.get("IS_SELF_HOSTED", "true").lower() == "true"
REDIS_URL = os.environ.get("REDIS_URL")
if not REDIS_URL:
REDIS_URL = os.environ.get("redis_url", "redis://0.0.0.0:6379/0")
ORG_ID = os.environ.get("ORG_ID", None)
POSTHOG_API_KEY = os.environ.get(
"POSTHOG_API_KEY", "phc_CnzwIB0W548wN4wEGeRuxXqidOlEUH2AcyV2sKTku8n"
)
E2B_API_KEY = os.environ.get("E2B_API_KEY")
SUPPORT_COUNTRY = os.environ.get("GDRP_LIST", "").split(",")
WHITELISTED_REPOS = os.environ.get("WHITELISTED_REPOS", "").split(",")
BLACKLISTED_USERS = os.environ.get("BLACKLISTED_USERS", "").split(",")
os.environ["TOKENIZERS_PARALLELISM"] = "false"
ACTIVELOOP_TOKEN = os.environ.get("ACTIVELOOP_TOKEN", None)
VECTOR_EMBEDDING_SOURCE = os.environ.get(
"VECTOR_EMBEDDING_SOURCE", "openai"
)  # Alternatives are "huggingface" or "replicate"; set the corresponding env vars below
BASERUN_API_KEY = os.environ.get("BASERUN_API_KEY", None)
# Huggingface settings, only checked if VECTOR_EMBEDDING_SOURCE == "huggingface"
HUGGINGFACE_URL = os.environ.get("HUGGINGFACE_URL", None)
HUGGINGFACE_TOKEN = os.environ.get("HUGGINGFACE_TOKEN", None)
# Replicate settings, only checked if VECTOR_EMBEDDING_SOURCE == "replicate"
REPLICATE_API_KEY = os.environ.get("REPLICATE_API_KEY", None)
REPLICATE_URL = os.environ.get("REPLICATE_URL", None)
REPLICATE_DEPLOYMENT_URL = os.environ.get("REPLICATE_DEPLOYMENT_URL", None)
# Default OpenAI
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", None)
OPENAI_API_TYPE = os.environ.get("OPENAI_API_TYPE", "anthropic")
assert OPENAI_API_TYPE in ["anthropic", "azure", "openai"], "Invalid OPENAI_API_TYPE"
OPENAI_EMBEDDINGS_API_TYPE = os.environ.get("OPENAI_EMBEDDINGS_API_TYPE", "openai")
AZURE_API_KEY = os.environ.get("AZURE_API_KEY", None)
OPENAI_API_BASE = os.environ.get("OPENAI_API_BASE", None)
OPENAI_API_VERSION = os.environ.get("OPENAI_API_VERSION", None)
AZURE_OPENAI_DEPLOYMENT = os.environ.get("AZURE_OPENAI_DEPLOYMENT", None)
OPENAI_EMBEDDINGS_API_TYPE = os.environ.get("OPENAI_EMBEDDINGS_API_TYPE", "openai")
OPENAI_EMBEDDINGS_AZURE_ENDPOINT = os.environ.get(
"OPENAI_EMBEDDINGS_AZURE_ENDPOINT", None
)
OPENAI_EMBEDDINGS_AZURE_API_KEY = os.environ.get(
"OPENAI_EMBEDDINGS_AZURE_API_KEY", None
)
OPENAI_EMBEDDINGS_AZURE_DEPLOYMENT = os.environ.get(
"OPENAI_EMBEDDINGS_AZURE_DEPLOYMENT", None
)
OPENAI_EMBEDDINGS_AZURE_API_VERSION = os.environ.get(
"OPENAI_EMBEDDINGS_AZURE_API_VERSION", None
)
OPENAI_API_ENGINE_GPT35 = os.environ.get("OPENAI_API_ENGINE_GPT35", None)
OPENAI_API_ENGINE_GPT4 = os.environ.get("OPENAI_API_ENGINE_GPT4", None)
OPENAI_API_ENGINE_GPT4_32K = os.environ.get("OPENAI_API_ENGINE_GPT4_32K", None)
MULTI_REGION_CONFIG = os.environ.get("MULTI_REGION_CONFIG", None)
if isinstance(MULTI_REGION_CONFIG, str):
MULTI_REGION_CONFIG = MULTI_REGION_CONFIG.strip("'").replace("\\n", "\n")
MULTI_REGION_CONFIG = [item.split(",") for item in MULTI_REGION_CONFIG.split("\n")]
WHITELISTED_USERS = os.environ.get("WHITELISTED_USERS", None)
if WHITELISTED_USERS:
WHITELISTED_USERS = WHITELISTED_USERS.split(",")
WHITELISTED_USERS.append(GITHUB_BOT_USERNAME)
DEFAULT_GPT4_32K_MODEL = os.environ.get("DEFAULT_GPT4_32K_MODEL", "gpt-4-0125-preview")
DEFAULT_GPT35_MODEL = os.environ.get("DEFAULT_GPT35_MODEL", "gpt-3.5-turbo-1106")
RESEND_API_KEY = os.environ.get("RESEND_API_KEY", None)
LOKI_URL = None
DEBUG = os.environ.get("DEBUG", "false").lower() == "true"
ENV = "prod" if GITHUB_BOT_USERNAME != TEST_BOT_NAME else "dev"
PROGRESS_BASE_URL = os.environ.get(
"PROGRESS_BASE_URL", "https://progress.sweep.dev"
).rstrip("/")
DISABLED_REPOS = os.environ.get("DISABLED_REPOS", "").split(",")
GHA_AUTOFIX_ENABLED: bool = os.environ.get("GHA_AUTOFIX_ENABLED", "false").lower() == "true"
MERGE_CONFLICT_ENABLED: bool = os.environ.get("MERGE_CONFLICT_ENABLED", "false").lower() == "true"
INSTALLATION_ID = os.environ.get("INSTALLATION_ID", None)
AWS_ACCESS_KEY=os.environ.get("AWS_ACCESS_KEY")
AWS_SECRET_KEY=os.environ.get("AWS_SECRET_KEY")
AWS_REGION=os.environ.get("AWS_REGION")
ANTHROPIC_AVAILABLE = AWS_ACCESS_KEY and AWS_SECRET_KEY and AWS_REGION
USE_ASSISTANT = os.environ.get("USE_ASSISTANT", "true").lower() == "true"
ANTHROPIC_API_KEY = os.environ.get("ANTHROPIC_API_KEY", None)
VOYAGE_API_KEY = os.environ.get("VOYAGE_API_KEY", None)
VOYAGE_API_AWS_ACCESS_KEY=os.environ.get("VOYAGE_API_AWS_ACCESS_KEY_ID")
VOYAGE_API_AWS_SECRET_KEY=os.environ.get("VOYAGE_API_AWS_SECRET_KEY")
VOYAGE_API_AWS_REGION=os.environ.get("VOYAGE_API_AWS_REGION")
VOYAGE_API_AWS_ENDPOINT_NAME=os.environ.get("VOYAGE_API_AWS_ENDPOINT_NAME", "voyage-code-2")
VOYAGE_API_USE_AWS = VOYAGE_API_AWS_ACCESS_KEY and VOYAGE_API_AWS_SECRET_KEY and VOYAGE_API_AWS_REGION
PAREA_API_KEY = os.environ.get("PAREA_API_KEY", None)
# TODO: we need to make this dynamic + backoff
BATCH_SIZE = int(
os.environ.get("BATCH_SIZE", 32 if VOYAGE_API_KEY else 256) # Voyage only allows 128 items per batch and 120000 tokens per batch
)
DEPLOYMENT_GHA_ENABLED = os.environ.get("DEPLOYMENT_GHA_ENABLED", "true").lower() == "true"
JIRA_USER_NAME = os.environ.get("JIRA_USER_NAME", None)
JIRA_API_TOKEN = os.environ.get("JIRA_API_TOKEN", None)


Step 2: ⌨️ Coding

  • Modify sweepai/handlers/on_merge_conflict.py (commit 530993b)
Modify sweepai/handlers/on_merge_conflict.py with contents: In the `on_merge_conflict` function:

Replace this code block:

try:
    git_repo.config_writer().set_value("user", "name", "sweep-nightly[bot]").release()
    git_repo.config_writer().set_value("user", "email", "[email protected]").release()  
    git_repo.git.merge("origin/" + pr.base.ref)
except GitCommandError:
    # Assume there are merge conflicts
    pass

with:

try:
    git_repo.config_writer().set_value("user", "name", "sweep-nightly[bot]").release()
    git_repo.config_writer().set_value("user", "email", "[email protected]").release()
    git_repo.git.fetch()
    git_repo.git.rebase("origin/" + pr.base.ref)
except GitCommandError:
    # Assume there are conflicts during rebase
    pass

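This rebases the PR branch onto the target branch instead of merging the target branch into it. Because a rebase rewrites the branch's commits, the result has to be force-pushed (for example with `git push --force-with-lease`) rather than pushed normally.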

Import the GitCommandError exception from the git module at the top of the file:

from git import GitCommandError
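The plan uses GitPython; the same flow can be sketched with the `git` CLI through `subprocess`, which also makes two details explicit that the plan leaves out: aborting a conflicted rebase so the clone is not left mid-rebase, and force-pushing because the branch history changed. `rebase_onto_base` is a hypothetical helper, not Sweep's implementation, and the sketch assumes `git` is on `PATH` and that the caller has already fetched the base branch.

```python
import subprocess


def git(repo_dir: str, *args: str) -> subprocess.CompletedProcess:
    """Run a git command in repo_dir, raising CalledProcessError on failure."""
    return subprocess.run(
        ["git", "-C", repo_dir, *args],
        check=True,
        capture_output=True,
        text=True,
    )


def rebase_onto_base(repo_dir: str, base_ref: str, push: bool = False) -> bool:
    """Rebase the checked-out branch onto base_ref (e.g. "origin/main").

    Returns True if the rebase applied cleanly, False if it hit conflicts.
    """
    try:
        git(repo_dir, "rebase", base_ref)
    except subprocess.CalledProcessError:
        # Undo the half-finished rebase instead of silently ignoring the error,
        # so the branch is back at its original tip.
        subprocess.run(["git", "-C", repo_dir, "rebase", "--abort"], capture_output=True)
        return False
    if push:
        # The rebase rewrote the branch's commits, so a plain push would be
        # rejected as non-fast-forward. --force-with-lease refuses to overwrite
        # the remote branch if it moved since it was last fetched.
        git(repo_dir, "push", "--force-with-lease", "origin", "HEAD")
    return True
```

A `False` return would be the cue to fall back to conflict resolution; that fallback is not shown here.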

Step 3: 🔁 Code Review

I have finished reviewing the code for completeness. I did not find errors for sweep/allow_for_rebase_02da1.


🎉 Latest improvements to Sweep:
  • New dashboard launched for real-time tracking of Sweep issues, covering all stages from search to coding.
  • Integration of OpenAI's latest Assistant API for more efficient and reliable code planning and editing, improving speed by 3x.
  • Use the GitHub issues extension for creating Sweep issues directly from your editor.

💡 To recreate the pull request, edit the issue title or description.
Something wrong? Let us know.

This is an automated message generated by Sweep AI.

@sweep-nightly sweep-nightly bot linked a pull request Apr 8, 2024 that will close this issue
Contributor

sweep-nightly bot commented Apr 8, 2024

🚀 Here's the PR! #3502

See Sweep's progress at the progress dashboard!
💎 Sweep Pro: I'm using GPT-4. You have unlimited GPT-4 tickets. (tracking ID: 1110cfc18d)


Step 1: 🔎 Searching

I found the following snippets in your repository. I will now analyze these snippets and come up with a plan.

Some code snippets I think are relevant in decreasing order of relevance (click to expand). If some file is missing from here, you can mention the path in the ticket description.

import time
import traceback
from git import GitCommandError
from github.PullRequest import PullRequest
from loguru import logger
from sweepai.config.server import PROGRESS_BASE_URL
from sweepai.core import entities
from sweepai.core.entities import FileChangeRequest
from sweepai.core.sweep_bot import SweepBot
from sweepai.handlers.create_pr import create_pr_changes
from sweepai.handlers.on_ticket import get_branch_diff_text, sweeping_gif
from sweepai.utils.chat_logger import ChatLogger, discord_log_error
from sweepai.utils.diff import generate_diff
from sweepai.utils.event_logger import posthog
from sweepai.utils.github_utils import ClonedRepo, get_github_client
from sweepai.utils.progress import (
PaymentContext,
TicketContext,
TicketProgress,
TicketProgressStatus,
)
from sweepai.utils.prompt_constructor import HumanMessagePrompt
from sweepai.utils.str_utils import to_branch_name
from sweepai.utils.ticket_utils import center
instructions_format = """Resolve the merge conflicts in the PR by incorporating changes from both branches into the final code.
Title of PR: {title}
Here were the original changes to this file in the head branch:
Commit message: {head_commit_message}
```diff
{head_diff}
```
Here were the original changes to this file in the base branch:
Commit message: {base_commit_message}
```diff
{base_diff}
```
In the analysis_and_identification, first determine what each change does. Then determine what the final code should be. Then, use the keyword_search to find the merge conflict markers <<<<<<< and >>>>>>>. Finally, make the code changes by writing the old_code and the new_code."""
def on_merge_conflict(
pr_number: int,
username: str,
repo_full_name: str,
installation_id: int,
tracking_id: str,
):
# copied from stack_pr
token, g = get_github_client(installation_id=installation_id)
try:
repo = g.get_repo(repo_full_name)
except Exception as e:
print("Exception occurred while getting repo", e)
pr: PullRequest = repo.get_pull(pr_number)
branch = pr.head.ref
status_message = center(
f"{sweeping_gif}\n\n"
+ f'Resolving merge conflicts: track the progress <a href="{PROGRESS_BASE_URL}/issues/{tracking_id}">here</a>.'
)
header = f"{status_message}\n---\n\nI'm currently resolving the merge conflicts in this PR. I will stack a new PR once I'm done."
comment = None
for current_comment in pr.get_issue_comments():
if (
current_comment.user.login == "sweep-nightly[bot]"
and "Resolving merge conflicts: track the progress" in current_comment.body
):
current_comment.edit(body=header)
comment = current_comment
break
if comment is None:
comment = pr.create_issue_comment(body=header)
def edit_comment(body):
nonlocal comment
comment.edit(header + "\n\n" + body)
metadata = {}
try:
cloned_repo = ClonedRepo(
repo_full_name=repo_full_name,
installation_id=installation_id,
branch=branch,
token=token,
)
time.time()
request = f"Sweep: Resolve merge conflicts for PR #{pr_number}: {pr.title}"
title = request
if len(title) > 50:
title = title[:50] + "..."
chat_logger = ChatLogger(
data={
"username": username,
"metadata": metadata,
"tracking_id": tracking_id,
}
)
is_paying_user = chat_logger.is_paying_user()
chat_logger.is_consumer_tier()
# this logic is partly taken from on_ticket.py, if there is an issue please refer to that file
if chat_logger:
use_faster_model = chat_logger.use_faster_model()
else:
is_paying_user = True
use_faster_model = False
ticket_progress = TicketProgress(
tracking_id=tracking_id,
username=username,
context=TicketContext(
title=title,
description="",
repo_full_name=repo_full_name,
branch_name="sweep/" + to_branch_name(request),
issue_number=pr_number,
is_public=repo.private is False,
start_time=int(time.time()),
# mostly copied from on_ticket, if issue please check that file
payment_context=PaymentContext(
use_faster_model=use_faster_model,
pro_user=is_paying_user,
daily_tickets_used=(
chat_logger.get_ticket_count(use_date=True)
if chat_logger
else 0
),
monthly_tickets_used=(
chat_logger.get_ticket_count() if chat_logger else 0
),
),
),
)
metadata = {
"tracking_id": tracking_id,
"username": username,
"function": "on_merge_conflict",
**ticket_progress.context.dict(),
}
posthog.capture(
username,
"started",
properties=metadata,
)
issue_url = pr.html_url
edit_comment("Configuring branch...")
new_pull_request = entities.PullRequest(
title=title,
branch_name="sweep/" + branch + "-merge-conflict",
content="",
)
# Making sure name is unique
for i in range(30):
try:
repo.get_branch(new_pull_request.branch_name + "_" + str(i))
except Exception:
new_pull_request.branch_name += "_" + str(i)
break
# Create a new branch from the PR head, then merge origin/<base branch> into it
git_repo = cloned_repo.git_repo
old_head_branch = git_repo.branches[branch]
head_branch = git_repo.create_head(
new_pull_request.branch_name,
commit=old_head_branch.commit,
)
head_branch.checkout()
try:
git_repo.config_writer().set_value(
"user", "name", "sweep-nightly[bot]"
).release()
git_repo.config_writer().set_value(
"user", "email", "[email protected]"
).release()
git_repo.git.merge("origin/" + pr.base.ref)
except GitCommandError:
# Assume there are merge conflicts
pass
git_repo.git.add(update=True)
# -m and message are needed otherwise exception is thrown
git_repo.git.commit("-m", "Start of Merge Conflict Resolution")
origin = git_repo.remotes.origin
new_url = f"https://x-access-token:{token}@github.com/{repo_full_name}.git"
origin.set_url(new_url)
git_repo.git.push("--set-upstream", origin, new_pull_request.branch_name)
last_commit = git_repo.head.commit
all_files = [item.a_path for item in last_commit.diff("HEAD~1")]
conflict_files = []
for file in all_files:
try:
contents = open(cloned_repo.repo_dir + "/" + file).read()
if "\n<<<<<<<" in contents and "\n>>>>>>>" in contents:
conflict_files.append(file)
except (UnicodeDecodeError, FileNotFoundError):
pass
snippets = []
for conflict_file in conflict_files:
contents = open(cloned_repo.repo_dir + "/" + conflict_file).read()
snippet = entities.Snippet(
file_path=conflict_file,
start=0,
end=len(contents.splitlines()),
content=contents,
)
snippets.append(snippet)
tree = ""
ticket_progress.status = TicketProgressStatus.PLANNING
ticket_progress.save()
human_message = HumanMessagePrompt(
repo_name=repo_full_name,
issue_url=issue_url,
username=username,
repo_description=(repo.description or "").strip(),
title=request,
summary=request,
snippets=snippets,
tree=tree,
)
sweep_bot = SweepBot.from_system_message_content(
human_message=human_message,
repo=repo,
ticket_progress=ticket_progress,
chat_logger=chat_logger,
cloned_repo=cloned_repo,
branch=new_pull_request.branch_name,
)
# can select more precise snippets
file_change_requests = []
base_commits = pr.base.repo.get_commits().get_page(0)
head_commits = list(pr.get_commits())
for conflict_file in conflict_files:
old_code = repo.get_contents(
conflict_file, ref=head_commits[0].parents[0].sha
).decoded_content.decode()
base_code = repo.get_contents(
conflict_file, ref=pr.base.ref
).decoded_content.decode()
head_code = repo.get_contents(
conflict_file, ref=pr.head.ref
).decoded_content.decode()
base_diff = generate_diff(old_code=old_code, new_code=base_code)
head_diff = generate_diff(old_code=old_code, new_code=head_code)
base_commit_message = ""
for commit in base_commits[::-1]:
if any(
commit_file.filename == conflict_file
for commit_file in commit.files
):
base_commit_message = commit.raw_data["commit"]["message"]
break
head_commit_message = ""
for commit in head_commits[::-1]:
if any(
commit_file.filename == conflict_file
for commit_file in commit.files
):
head_commit_message = commit.raw_data["commit"]["message"]
break
file_change_requests.append(
FileChangeRequest(
filename=conflict_file,
instructions=instructions_format.format(
title=pr.title,
base_commit_message=base_commit_message,
base_diff=base_diff,
head_commit_message=head_commit_message,
head_diff=head_diff,
),
change_type="modify",
)
)
ticket_progress.status = TicketProgressStatus.CODING
ticket_progress.save()
edit_comment("Resolving merge conflicts...")
generator = create_pr_changes(
file_change_requests,
new_pull_request,
sweep_bot,
username,
installation_id,
pr_number,
chat_logger=chat_logger,
base_branch=new_pull_request.branch_name,
)
for item in generator:
if isinstance(item, dict):
break
(
new_file_contents,
changed_file,
commit,
file_change_requests,
) = item
logger.info(f"Statuses: {[fcr.status for fcr in file_change_requests]}")
ticket_progress.status = TicketProgressStatus.COMPLETE
ticket_progress.save()
edit_comment("Done creating pull request.")
get_branch_diff_text(repo, new_pull_request.branch_name)
new_description = f"This PR resolves the merge conflicts in #{pr_number}. This branch can be directly merged into {pr.base.ref}.\n\nFixes #{pr_number}."
# Create pull request
new_pull_request.content = new_description
github_pull_request = repo.create_pull(
title=request,
body=new_description,
head=new_pull_request.branch_name,
base=pr.base.ref,
)
ticket_progress.context.pr_id = github_pull_request.number
ticket_progress.context.done_time = time.time()
ticket_progress.save()
edit_comment(f"✨ **Created Pull Request:** {github_pull_request.html_url}")
posthog.capture(
username,
"success",
properties=metadata,
)
return {"success": True}
except Exception as e:
print(f"Exception occurred: {e}")
edit_comment(
f"> [!CAUTION]\n> \nAn error has occurred: {str(e)} (tracking ID: {tracking_id})"
)
discord_log_error(
"Error occurred in on_merge_conflict.py\n"
+ traceback.format_exc()
+ "\n\n"
+ str(e)
+ "\n\n"
+ f"tracking ID: {tracking_id}"
)
posthog.capture(
username,
"failed",
properties=metadata,
)
return {"success": False}
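The conflict scan in `on_merge_conflict` looks for `"\n<<<<<<<"` and `"\n>>>>>>>"`, so a conflict whose opening marker sits on the first line of a file is not detected. A minimal sketch of a line-based check, assuming Git's default marker style (markers at the start of a line); `has_conflict_markers` is a hypothetical helper, not part of Sweep:

```python
def has_conflict_markers(contents: str) -> bool:
    """Return True if `contents` has both an opening and a closing Git conflict marker.

    Markers only count at the start of a line, which also covers a conflict
    that begins on the very first line of the file.
    """
    lines = contents.splitlines()
    has_open = any(line.startswith("<<<<<<<") for line in lines)
    has_close = any(line.startswith(">>>>>>>") for line in lines)
    return has_open and has_close


conflicted = "<<<<<<< HEAD\nx = 1\n=======\nx = 2\n>>>>>>> origin/main\n"
print(has_conflict_markers(conflicted))  # True
print(has_conflict_markers("x = 1\n"))   # False
```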
if __name__ == "__main__":
on_merge_conflict(
pr_number=68,
username="MartinYe1234",
repo_full_name="MartinYe1234/Chess-Game",
installation_id=45945746,
tracking_id="ADD-BOB-2",
)

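`on_merge_conflict` brings the PR branch up to date with `git_repo.git.merge("origin/" + pr.base.ref)`, which produces the merge commit the user in this issue would rather replace with a rebase. A minimal sketch of how the two strategies differ at the git level, assuming a hypothetical `strategy` option (nothing in this code defines one); `sync_commands` is an illustrative helper, not part of Sweep:

```python
from typing import Literal

Strategy = Literal["merge", "rebase"]


def sync_commands(base_ref: str, branch: str, strategy: Strategy = "merge") -> list[list[str]]:
    """Build the git commands that bring `branch` up to date with origin/<base_ref>.

    "merge" matches the current behaviour (adds a merge commit); "rebase"
    replays the branch on top of the base, which rewrites its history.
    """
    upstream = f"origin/{base_ref}"
    if strategy == "merge":
        return [
            ["git", "merge", upstream],
            ["git", "push", "origin", branch],
        ]
    if strategy == "rebase":
        return [
            ["git", "rebase", upstream],
            # Rewritten history needs a force push; --force-with-lease refuses
            # to overwrite remote commits this clone has not seen.
            ["git", "push", "--force-with-lease", "origin", branch],
        ]
    raise ValueError(f"unknown strategy: {strategy!r}")


for cmd in sync_commands("main", "sweep/fix-typo", strategy="rebase"):
    print(" ".join(cmd))
```

Each command list could be run with `subprocess.run(cmd, cwd=cloned_repo.repo_dir, check=True)`. Note that a rebase with conflicts stops part-way and has to be resolved or abandoned with `git rebase --abort`, so conflict handling would differ from the merge path above.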
"""
This file contains the on_merge handler which is called when a pull request is merged to master.
on_merge is called by sweepai/api.py
"""
import time
from sweepai.config.client import SweepConfig, get_blocked_dirs, get_rules
from sweepai.core.post_merge import PostMerge
from sweepai.handlers.pr_utils import make_pr
from loguru import logger
from sweepai.utils.chat_logger import ChatLogger
from sweepai.utils.event_logger import posthog
from sweepai.utils.github_utils import get_github_client
# Lower and upper bounds on the number of diff lines for which Sweep Rules are checked
CHANGE_BOUNDS = (10, 1500)
# dictionary to map from github repo to the last time a rule was activated
merge_rule_debounce = {}
# debounce time in seconds
DEBOUNCE_TIME = 120
diff_section_prompt = """
<file_diff file="{diff_file_path}">
{diffs}
</file_diff>"""
def comparison_to_diff(comparison, blocked_dirs):
pr_diffs = []
for file in comparison.files:
diff = file.patch
if (
file.status == "added"
or file.status == "modified"
or file.status == "removed"
):
if any(file.filename.startswith(dir) for dir in blocked_dirs):
continue
pr_diffs.append((file.filename, diff))
else:
logger.info(
f"File status {file.status} not recognized"
) # TODO(sweep): We don't handle renamed files
formatted_diffs = []
for file_name, file_patch in pr_diffs:
format_diff = diff_section_prompt.format(
diff_file_path=file_name, diffs=file_patch
)
formatted_diffs.append(format_diff)
return "\n".join(formatted_diffs)
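`comparison_to_diff` skips files with `file.filename.startswith(dir)`, a plain string-prefix test, so a blocked `docs` would also skip `docs_old/...`. If directory-boundary matching is what is intended (an assumption; the blocked-dirs semantics come from `sweepai.config.client`), a component-wise check could look like this; `is_blocked` is a hypothetical helper:

```python
from pathlib import PurePosixPath


def is_blocked(path: str, blocked_dirs: list[str]) -> bool:
    """Return True if `path` lies inside one of `blocked_dirs`.

    Paths are compared component by component, so blocking "docs" skips
    "docs/readme.md" but not "docs_old/readme.md".
    """
    parts = PurePosixPath(path).parts
    for blocked in blocked_dirs:
        blocked_parts = PurePosixPath(blocked.strip("/")).parts
        # An empty entry would otherwise match every path, so ignore it.
        if blocked_parts and parts[: len(blocked_parts)] == blocked_parts:
            return True
    return False


print(is_blocked("docs/readme.md", ["docs"]))      # True
print(is_blocked("docs_old/readme.md", ["docs"]))  # False
```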
def on_merge(request_dict: dict, chat_logger: ChatLogger):
before_sha = request_dict["before"]
after_sha = request_dict["after"]
commit_author = request_dict["sender"]["login"]
ref = request_dict["ref"]
if not ref.startswith("refs/heads/"):
return
user_token, g = get_github_client(request_dict["installation"]["id"])
repo = g.get_repo(
request_dict["repository"]["full_name"]
) # do this after checking ref
if ref[len("refs/heads/") :] != SweepConfig.get_branch(repo):
return
commit = repo.get_commit(after_sha)
check_suites = commit.get_check_suites()
for check_suite in check_suites:
if check_suite.conclusion == "failure":
return # if any check suite failed, return
blocked_dirs = get_blocked_dirs(repo)
comparison = repo.compare(before_sha, after_sha)
commits_diff = comparison_to_diff(comparison, blocked_dirs)
# check if the current repo is in the merge_rule_debounce dictionary
# and if the difference between the current time and the time stored in the dictionary is less than DEBOUNCE_TIME seconds
if (
repo.full_name in merge_rule_debounce
and time.time() - merge_rule_debounce[repo.full_name] < DEBOUNCE_TIME
):
return
merge_rule_debounce[repo.full_name] = time.time()
if not (
commits_diff.count("\n") >= CHANGE_BOUNDS[0]
and commits_diff.count("\n") <= CHANGE_BOUNDS[1]
):
return
rules = get_rules(repo)
rules = [rule for rule in rules if len(rule) > 0]
if not rules:
return
for rule in rules:
chat_logger.data["title"] = f"Sweep Rules - {rule}"
changes_required, issue_title, issue_description = PostMerge(
chat_logger=chat_logger
).check_for_issues(rule=rule, diff=commits_diff)
if changes_required:
make_pr(
title="[Sweep Rules] " + issue_title,
repo_description=repo.description,
summary=issue_description,
repo_full_name=request_dict["repository"]["full_name"],
installation_id=request_dict["installation"]["id"],
user_token=user_token,
use_faster_model=chat_logger.use_faster_model(),
username=commit_author,
chat_logger=chat_logger,
rule=rule,
)
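`on_merge` throttles Sweep Rules per repository through the module-level `merge_rule_debounce` dict: an event is skipped if the same repo passed the check less than `DEBOUNCE_TIME` seconds earlier, and the timestamp is only updated when an event passes. A sketch of the same logic as a small class with an injectable clock, so it can be tested without sleeping; `Debouncer` is hypothetical. Like the original dict, its state lives in a single process, so separate worker processes would each keep their own timestamps.

```python
import time
from typing import Callable


class Debouncer:
    """Let through at most one event per key within `window` seconds."""

    def __init__(self, window: float, clock: Callable[[], float] = time.time) -> None:
        self.window = window
        self.clock = clock
        self.last_seen: dict[str, float] = {}

    def allow(self, key: str) -> bool:
        now = self.clock()
        last = self.last_seen.get(key)
        if last is not None and now - last < self.window:
            return False  # too soon after the previous event for this key
        self.last_seen[key] = now  # only events that pass are recorded
        return True


# Usage with a fake clock instead of real time:
fake_now = [1000.0]
debouncer = Debouncer(window=120, clock=lambda: fake_now[0])
print(debouncer.allow("org/repo"))  # True
fake_now[0] += 60
print(debouncer.allow("org/repo"))  # False, only 60 seconds later
```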

"""
create_pr is a function that creates a pull request from a list of file change requests.
It is also responsible for handling Sweep config PR creation.
"""
import datetime
from typing import Any, Generator
import openai
from github.Repository import Repository
from loguru import logger
from sweepai.config.client import DEFAULT_RULES_STRING, SweepConfig, get_blocked_dirs
from sweepai.config.server import (
ENV,
GITHUB_BOT_USERNAME,
GITHUB_CONFIG_BRANCH,
GITHUB_DEFAULT_CONFIG,
GITHUB_LABEL_NAME,
MONGODB_URI,
)
from sweepai.core.entities import (
FileChangeRequest,
MaxTokensExceeded,
Message,
MockPR,
PullRequest,
)
from sweepai.core.sweep_bot import SweepBot
from sweepai.utils.chat_logger import ChatLogger
from sweepai.utils.event_logger import posthog
from sweepai.utils.github_utils import ClonedRepo, get_github_client
from sweepai.utils.str_utils import UPDATES_MESSAGE
num_of_snippets_to_query = 10
max_num_of_snippets = 5
INSTRUCTIONS_FOR_REVIEW = """\
### 💡 To get Sweep to edit this pull request, you can:
* Comment below, and Sweep can edit the entire PR
* Comment on a file, Sweep will only modify the commented file
* Edit the original issue to get Sweep to recreate the PR from scratch"""
def create_pr_changes(
file_change_requests: list[FileChangeRequest],
pull_request: PullRequest,
sweep_bot: SweepBot,
username: str,
installation_id: int,
issue_number: int | None = None,
chat_logger: ChatLogger = None,
base_branch: str = None,
additional_messages: list[Message] = []
) -> Generator[tuple[Any, Any, Any, list[FileChangeRequest]] | dict, None, None]:
# Flow:
# 1. Get relevant files
# 2. Get human message
# 3. Get files to change
# 4. Get file changes
# 5. Create PR
chat_logger = (
chat_logger
if chat_logger is not None
else ChatLogger(
{
"username": username,
"installation_id": installation_id,
"repo_full_name": sweep_bot.repo.full_name,
"title": pull_request.title,
"summary": "",
"issue_url": "",
}
)
if MONGODB_URI
else None
)
sweep_bot.chat_logger = chat_logger
organization, repo_name = sweep_bot.repo.full_name.split("/")
metadata = {
"repo_full_name": sweep_bot.repo.full_name,
"organization": organization,
"repo_name": repo_name,
"repo_description": sweep_bot.repo.description,
"username": username,
"installation_id": installation_id,
"function": "create_pr",
"mode": ENV,
"issue_number": issue_number,
}
posthog.capture(username, "started", properties=metadata)
try:
logger.info("Making PR...")
pull_request.branch_name = sweep_bot.create_branch(
pull_request.branch_name, base_branch=base_branch
)
completed_count, fcr_count = 0, len(file_change_requests)
blocked_dirs = get_blocked_dirs(sweep_bot.repo)
for (
new_file_contents,
changed_file,
commit,
file_change_requests,
) in sweep_bot.change_files_in_github_iterator(
file_change_requests,
pull_request.branch_name,
blocked_dirs,
additional_messages=additional_messages
):
completed_count += len(new_file_contents or [])
logger.info(f"Completed {completed_count}/{fcr_count} files")
yield new_file_contents, changed_file, commit, file_change_requests
if completed_count == 0 and fcr_count != 0:
logger.info("No changes made")
posthog.capture(
username,
"failed",
properties={
"error": "No changes made",
"reason": "No changes made",
**metadata,
},
)
# If no changes were made, delete branch
commits = sweep_bot.repo.get_commits(pull_request.branch_name)
if commits.totalCount == 0:
branch = sweep_bot.repo.get_git_ref(f"heads/{pull_request.branch_name}")
branch.delete()
return
# Include issue number in PR description
if issue_number:
# If this "Fixes #<issue>" wording changes, update the matching check in on_ticket (f'Fixes #{issue_number}.\n' in pr.body)
pr_description = (
f"{pull_request.content}\n\nFixes"
f" #{issue_number}.\n\n---\n\n{UPDATES_MESSAGE}\n\n---\n\n{INSTRUCTIONS_FOR_REVIEW}"
)
else:
pr_description = f"{pull_request.content}"
pr_title = pull_request.title
if "sweep.yaml" in pr_title:
pr_title = "[config] " + pr_title
except MaxTokensExceeded as e:
logger.error(e)
posthog.capture(
username,
"failed",
properties={
"error": str(e),
"reason": "Max tokens exceeded",
**metadata,
},
)
raise e
except openai.BadRequestError as e:
logger.error(e)
posthog.capture(
username,
"failed",
properties={
"error": str(e),
"reason": "Invalid request error / context length",
**metadata,
},
)
raise e
except Exception as e:
logger.error(e)
posthog.capture(
username,
"failed",
properties={
"error": str(e),
"reason": "Unexpected error",
**metadata,
},
)
raise e
posthog.capture(username, "success", properties={**metadata})
logger.info("create_pr success")
result = {
"success": True,
"pull_request": MockPR(
file_count=completed_count,
title=pr_title,
body=pr_description,
pr_head=pull_request.branch_name,
base=sweep_bot.repo.get_branch(
SweepConfig.get_branch(sweep_bot.repo)
).commit,
head=sweep_bot.repo.get_branch(pull_request.branch_name).commit,
),
}
yield result # TODO: refactor this as it doesn't need to be an iterator
return
def safe_delete_sweep_branch(
pr, # Github PullRequest
repo: Repository,
) -> bool:
"""
Safely delete Sweep branch
1. Only edited by Sweep
2. Prefixed by sweep/
"""
pr_commits = pr.get_commits()
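# commit.author is None when the commit email isn't linked to a GitHub account; keeping None means such PRs never count as Sweep-only
pr_commit_authors = {commit.author.login if commit.author else None for commit in pr_commits}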
# Check if only Sweep has edited the PR, and sweep/ prefix
if (
len(pr_commit_authors) == 1
and GITHUB_BOT_USERNAME in pr_commit_authors
and pr.head.ref.startswith("sweep/")
):
branch = repo.get_git_ref(f"heads/{pr.head.ref}")
# pr.edit(state='closed')
branch.delete()
return True
else:
# Failed to delete branch as it was edited by someone else
return False
def create_config_pr(
sweep_bot: SweepBot | None, repo: Repository = None, cloned_repo: ClonedRepo = None
):
if repo is not None:
# Check if file exists in repo
try:
repo.get_contents("sweep.yaml")
return
except SystemExit:
raise SystemExit
except Exception:
pass
title = "Configure Sweep"
branch_name = GITHUB_CONFIG_BRANCH
if sweep_bot is not None:
branch_name = sweep_bot.create_branch(branch_name, retry=False)
try:
# commit_history = []
# if cloned_repo is not None:
# commit_history = cloned_repo.get_commit_history(
# limit=1000, time_limited=False
# )
# commit_string = "\n".join(commit_history)
# sweep_yaml_bot = SweepYamlBot()
# generated_rules = sweep_yaml_bot.get_sweep_yaml_rules(
# commit_history=commit_string
# )
sweep_bot.repo.create_file(
"sweep.yaml",
"Create sweep.yaml",
GITHUB_DEFAULT_CONFIG.format(
branch=sweep_bot.repo.default_branch,
additional_rules=DEFAULT_RULES_STRING,
),
branch=branch_name,
)
sweep_bot.repo.create_file(
".github/ISSUE_TEMPLATE/sweep-template.yml",
"Create sweep template",
SWEEP_TEMPLATE,
branch=branch_name,
)
except SystemExit:
raise SystemExit
except Exception as e:
logger.error(e)
else:
# Create branch based on default branch
repo.create_git_ref(
ref=f"refs/heads/{branch_name}",
sha=repo.get_branch(repo.default_branch).commit.sha,
)
try:
# commit_history = []
# if cloned_repo is not None:
# commit_history = cloned_repo.get_commit_history(
# limit=1000, time_limited=False
# )
# commit_string = "\n".join(commit_history)
# sweep_yaml_bot = SweepYamlBot()
# generated_rules = sweep_yaml_bot.get_sweep_yaml_rules(
# commit_history=commit_string
# )
repo.create_file(
"sweep.yaml",
"Create sweep.yaml",
GITHUB_DEFAULT_CONFIG.format(
branch=repo.default_branch, additional_rules=DEFAULT_RULES_STRING
),
branch=branch_name,
)
repo.create_file(
".github/ISSUE_TEMPLATE/sweep-template.yml",
"Create sweep template",
SWEEP_TEMPLATE,
branch=branch_name,
)
except SystemExit:
raise SystemExit
except Exception as e:
logger.error(e)
repo = sweep_bot.repo if sweep_bot is not None else repo
# Check if the pull request from this branch to main already exists.
# If it does, then we don't need to create a new one.
if repo is not None:
pull_requests = repo.get_pulls(
state="open",
sort="created",
base=SweepConfig.get_branch(repo)
if sweep_bot is not None
else repo.default_branch,
head=branch_name,
)
for pr in pull_requests:
if pr.title == title:
return pr
logger.info(f"Default branch: {repo.default_branch}")
logger.info(f"New branch: {branch_name}")
pr = repo.create_pull(
title=title,
body="""🎉 Thank you for installing Sweep! We're thrilled to announce the latest update for Sweep, your AI junior developer on GitHub. This PR creates a `sweep.yaml` config file, allowing you to personalize Sweep's performance according to your project requirements.
## What's new?
- **Sweep is now configurable**.
- To configure Sweep, simply edit the `sweep.yaml` file in the root of your repository.
- If you need help, check out the [Sweep Default Config](https://github.com/sweepai/sweep/blob/main/sweep.yaml) or [Join Our Discord](https://discord.gg/sweep) for help.
If you would like me to stop creating this PR, go to issues and say "Sweep: create an empty `sweep.yaml` file".
Thank you for using Sweep! 🧹""".replace(
"    ", ""
),
head=branch_name,
base=SweepConfig.get_branch(repo)
if sweep_bot is not None
else repo.default_branch,
)
pr.add_to_labels(GITHUB_LABEL_NAME)
return pr
def add_config_to_top_repos(installation_id, username, repositories, max_repos=3):
user_token, g = get_github_client(installation_id)
repo_activity = {}
for repo_entity in repositories:
repo = g.get_repo(repo_entity.full_name)
# instead of using total count, use the date of the latest commit
commits = repo.get_commits(
author=username,
since=datetime.datetime.now() - datetime.timedelta(days=30),
)
# get latest commit date
commit_date = datetime.datetime.now() - datetime.timedelta(days=30)
for commit in commits:
if commit.commit.author.date > commit_date:
commit_date = commit.commit.author.date
# since_date = datetime.datetime.now() - datetime.timedelta(days=30)
# commits = repo.get_commits(since=since_date, author="lukejagg")
repo_activity[repo] = commit_date
# print(repo, commits.totalCount)
logger.info(f"{repo.full_name}: latest commit {commit_date}")
sorted_repos = sorted(repo_activity, key=repo_activity.get, reverse=True)
sorted_repos = sorted_repos[:max_repos]
# For each repo, create a branch from the default branch, then open a PR against the default branch
for repo in sorted_repos:
try:
logger.info(f"Creating config for {repo.full_name}")
create_config_pr(
None,
repo=repo,
cloned_repo=ClonedRepo(
repo_full_name=repo.full_name,
installation_id=installation_id,
token=user_token,
),
)
except SystemExit:
raise SystemExit
except Exception as e:
logger.error(e)
logger.info("Finished creating configs for top repos")
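`add_config_to_top_repos` ranks repositories by the date of the user's latest commit in the last 30 days and keeps the first `max_repos`. Stripped of the GitHub calls, the ranking is a sort on that date; a sketch with a hypothetical `top_recent_repos` helper:

```python
import datetime


def top_recent_repos(last_commit: dict[str, datetime.datetime], max_repos: int = 3) -> list[str]:
    """Return up to `max_repos` repository names, most recently active first."""
    return sorted(last_commit, key=lambda name: last_commit[name], reverse=True)[:max_repos]


activity = {
    "org/api": datetime.datetime(2024, 3, 1),
    "org/web": datetime.datetime(2024, 3, 10),
    "org/docs": datetime.datetime(2024, 2, 20),
    "org/old": datetime.datetime(2024, 1, 5),
}
print(top_recent_repos(activity))  # ['org/web', 'org/api', 'org/docs']
```

Python raises `TypeError` when comparing naive and timezone-aware datetimes, so all dates should be of one kind. If the GitHub client returns aware datetimes, the 30-day cutoff built from `datetime.datetime.now()` would need a timezone as well.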
def create_gha_pr(g, repo):
# Create a new branch
branch_name = "sweep/gha-enable"
repo.create_git_ref(
ref=f"refs/heads/{branch_name}",
sha=repo.get_branch(repo.default_branch).commit.sha,
)
# Update the sweep.yaml file in this branch to add "gha_enabled: True"
sweep_yaml_content = (
repo.get_contents("sweep.yaml", ref=branch_name).decoded_content.decode()
+ "\ngha_enabled: True"
)
repo.update_file(
"sweep.yaml",
"Enable GitHub Actions",
sweep_yaml_content,
repo.get_contents("sweep.yaml", ref=branch_name).sha,
branch=branch_name,
)
# Create a PR from this branch to the default branch
pr = repo.create_pull(
title="Enable GitHub Actions",
body="This PR enables GitHub Actions for this repository.",
head=branch_name,
base=repo.default_branch,
)
return pr
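`create_gha_pr` appends `gha_enabled: True` unconditionally, which would produce a duplicate key if `sweep.yaml` already sets it. A text-level sketch that replaces an existing key instead; it is not a YAML parser and assumes the key sits at the top level with no indentation. `enable_gha` is a hypothetical helper:

```python
import re

# Matches a top-level "gha_enabled: ..." line anywhere in the file.
GHA_LINE = re.compile(r"^gha_enabled:.*$", re.MULTILINE)


def enable_gha(sweep_yaml: str) -> str:
    """Return `sweep_yaml` with `gha_enabled: True`, replacing an existing
    top-level key rather than appending a duplicate one."""
    if GHA_LINE.search(sweep_yaml):
        return GHA_LINE.sub("gha_enabled: True", sweep_yaml)
    return sweep_yaml.rstrip("\n") + "\ngha_enabled: True\n"


print(enable_gha("branch: main\ngha_enabled: False\n"), end="")
```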
SWEEP_TEMPLATE = """\
name: Sweep Issue
title: 'Sweep: '
description: For small bugs, features, refactors, and tests to be handled by Sweep, an AI-powered junior developer.
labels: sweep
body:
- type: textarea
id: description
attributes:
label: Details
description: Tell Sweep where and what to edit and provide enough context for a new developer to the codebase
placeholder: |
Unit Tests: Write unit tests for <FILE>. Test each function in the file. Make sure to test edge cases.
Bugs: The bug might be in <FILE>. Here are the logs: ...
Features: the new endpoint should use the ... class from <FILE> because it contains ... logic.
Refactors: We are migrating this function to ... version because ...
- type: input
id: branch
attributes:
label: Branch
description: The branch to work off of (optional)
placeholder: |

import copy
import re
import traceback
from pathlib import Path
from loguru import logger
from sweepai.agents.assistant_wrapper import (
client,
openai_assistant_call,
run_until_complete,
)
from sweepai.core.entities import AssistantRaisedException, FileChangeRequest, Message
from sweepai.logn.cache import file_cache
from sweepai.utils.chat_logger import ChatLogger, discord_log_error
from sweepai.utils.progress import AssistantConversation, TicketProgress
system_message = r""" You are searching through a codebase to guide a junior developer on how to solve the user request. The junior developer will follow your instructions exactly and make the changes.
# User Request
{user_request}
# Guide
## Step 1: Unzip the file into /mnt/data/repo. Then list all root level directories. You must copy the below code verbatim into the file.
```python
import zipfile
import os
zip_path = '{file_path}'
extract_to_path = 'mnt/data/repo'
os.makedirs(extract_to_path, exist_ok=True)
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
zip_ref.extractall(extract_to_path)
zip_contents = zip_ref.namelist()
root_dirs = {{name.split('/')[0] for name in zip_contents}}
print(f'Root directories: {{root_dirs}}')
```
## Step 2: Find the relevant files.
You can search by file name or by keyword search in the contents.
## Step 3: Find relevant lines.
1. Locate the lines of code that contain the identified keywords or are at the specified line number. You can use keyword search or manually look through the file 100 lines at a time.
2. Check the surrounding lines to establish the full context of the code block.
3. Adjust the starting line to include the entire functionality that needs to be refactored or moved.
4. Finally determine the exact line spans that include a logical and complete section of code to be edited.
```python
def print_lines_with_keyword(content, keywords):
max_matches=5
context = 10
matches = [i for i, line in enumerate(content.splitlines()) if any(keyword.lower() in line.lower() for keyword in keywords)]
print(f"Found {{len(matches)}} matches, but capping at {{max_matches}}")
matches = matches[:max_matches]
expanded_matches = set()
for match in matches:
start = max(0, match - context)
end = min(len(content.splitlines()), match + context + 1)
for i in range(start, end):
expanded_matches.add(i)
for i in sorted(expanded_matches):
print(f"{{i}}: {{content.splitlines()[i]}}")
```
## Step 4: Construct a plan.
Provide the final plan to solve the issue, following these rules:
* DO NOT apply any changes here, they will not be persisted. You must provide the plan and the developer will apply the changes.
* You may only create new files and modify existing files.
* File paths should be relative paths from the root of the repo.
* Use the minimum number of create and modify operations required to solve the issue.
* Start and end lines indicate the exact start and end lines to edit. Expand this to encompass more lines if you're unsure where to make the exact edit.
Respond in the following format:
```xml
<plan>
<create_file file="file_path_1">
* Natural language instructions for creating the new file needed to solve the issue.
* Reference necessary files, imports and entity names.
...
</create_file>
...
<modify_file file="file_path_2" start_line="i" end_line="j">
* Natural language instructions for the modifications needed to solve the issue.
* Be concise and reference necessary files, imports and entity names.
...
</modify_file>
...
</plan>
```"""
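The keyword-search helper embedded in the prompt prints each match with `context` lines around it, capped at `max_matches` hits. The same windowing logic, returning line indices instead of printing, could look like this (a sketch; `keyword_context_lines` is hypothetical and lowercases both the keywords and the lines so the match is case-insensitive):

```python
def keyword_context_lines(
    content: str, keywords: list[str], context: int = 10, max_matches: int = 5
) -> list[int]:
    """Return sorted line indices within `context` lines of the first `max_matches` keyword hits."""
    lines = content.splitlines()
    lowered = [keyword.lower() for keyword in keywords]
    matches = [i for i, line in enumerate(lines) if any(k in line.lower() for k in lowered)]
    selected: set[int] = set()
    for match in matches[:max_matches]:
        start = max(0, match - context)
        end = min(len(lines), match + context + 1)
        selected.update(range(start, end))
    return sorted(selected)


text = "\n".join(f"line {i}" for i in range(30))
print(keyword_context_lines(text, ["LINE 15"], context=2))  # [13, 14, 15, 16, 17]
```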
@file_cache(ignore_params=["zip_path", "chat_logger", "ticket_progress"])
def new_planning(
request: str,
zip_path: str,
additional_messages: list[Message] = [],
chat_logger: ChatLogger | None = None,
assistant_id: str = None,
ticket_progress: TicketProgress | None = None,
) -> list[FileChangeRequest]:
planning_iterations = 3
try:
def save_ticket_progress(assistant_id: str, thread_id: str, run_id: str):
if ticket_progress is None:
return
assistant_conversation = AssistantConversation.from_ids(
assistant_id=assistant_id, run_id=run_id, thread_id=thread_id
)
if not assistant_conversation:
return
ticket_progress.planning_progress.assistant_conversation = (
assistant_conversation
)
ticket_progress.save()
logger.info("Uploading file...")
zip_file_object = client.files.create(file=Path(zip_path), purpose="assistants")
logger.info("Done uploading file.")
zip_file_id = zip_file_object.id
response = openai_assistant_call(
request=request,
assistant_id=assistant_id,
additional_messages=additional_messages,
uploaded_file_ids=[zip_file_id],
chat_logger=chat_logger,
save_ticket_progress=save_ticket_progress
if ticket_progress is not None
else None,
instructions=system_message.format(
user_request=request, file_path=f"mnt/data/{zip_file_id}"
),
)
run_id = response.run_id
thread_id = response.thread_id
# Start from the first run's messages; retries replace them with the result of run_until_complete
messages = response.messages
for _ in range(planning_iterations):
save_ticket_progress(
assistant_id=response.assistant_id,
thread_id=thread_id,
run_id=run_id,
)
final_message = messages.data[0].content[0].text.value
fcrs = []
fcr_matches = list(
re.finditer(FileChangeRequest._regex, final_message, re.DOTALL)
)
if len(fcr_matches) > 0:
break
else:
client.beta.threads.messages.create(
thread_id=thread_id,
role="user",
content="A valid plan (within the <plan> tags) was not provided. Please continue working on the plan. If you are stuck, consider starting over.",
)
run = client.beta.threads.runs.create(
thread_id=response.thread_id,
assistant_id=response.assistant_id,
instructions=system_message.format(
user_request=request, file_path=f"mnt/data/{zip_file_id}"
),
)
run_id = run.id
messages = run_until_complete(
thread_id=thread_id,
run_id=run_id,
assistant_id=response.assistant_id,
)
for match_ in fcr_matches:
group_dict = match_.groupdict()
if group_dict["change_type"] == "create_file":
group_dict["change_type"] = "create"
if group_dict["change_type"] == "modify_file":
group_dict["change_type"] = "modify"
fcr = FileChangeRequest(**group_dict)
fcr.filename = fcr.filename.lstrip("/")
fcr.instructions = fcr.instructions.replace("\n*", "\n•")
fcr.instructions = fcr.instructions.strip("\n")
if fcr.instructions.startswith("*"):
fcr.instructions = "•" + fcr.instructions[1:]
fcrs.append(fcr)
new_file_change_request = copy.deepcopy(fcr)
new_file_change_request.change_type = "check"
new_file_change_request.parent = fcr
fcrs.append(new_file_change_request)
assert len(fcrs) > 0
return fcrs
except AssistantRaisedException as e:
raise e
except Exception as e:
logger.exception(e)
if chat_logger is not None:
discord_log_error(
str(e)
+ "\n\n"
+ traceback.format_exc()
+ "\n\n"
+ str(chat_logger.data)
)
return None
if __name__ == "__main__":
request = """## Title: replace the broken tutorial link in installation.md with https://docs.sweep.dev/usage/tutorial\n"""
additional_messages = [
Message(
role="user",
content='<relevant_snippets_in_repo>\n<snippet source="docs/pages/usage/tutorial.mdx:45-60">\n...\n45: Now to be a Sweep power user, check out [Advanced: becoming a Sweep power user](https://docs.sweep.dev/usage/advanced).\n</snippet>\n<snippet source="docs/pages/usage/tutorial.mdx:30-45">\n...\n30: \n31: ![PR Comment](/tutorial/comment.png)\n32: \n33: c. If you have GitHub Actions set up, it will automatically run the linters, build, and tests and will show any failed logs to Sweep to handle. This only works with GitHub Actions and not other CI providers, so unfortunately for Vercel we have to copy paste manually.\n34: \n35: ![GitHub Actions](/tutorial/github_actions.png)\n36: \n37: 6. Once you are happy with the PR, you can merge it and it will be deployed to production via Vercel.\n38: \n39: \n40: ![Final](/tutorial/final.png)\n41: \n42: \n43: You can see the final example at https://github.com/kevinlu1248/docusaurus-2/pull/4 with preview https://docusaurus-2-ql4cskc5o-sweepai.vercel.app/.\n44: \n45: Now to be a Sweep power user, check out [Advanced: becoming a Sweep power user](https://docs.sweep.dev/usage/advanced).\n...\n</snippet>\n<snippet source="docs/installation.md:45-60">\n...\n45: * Provide any additional context that might be helpful, e.g. see "src/App.test.tsx" for an example of a good unit test.\n46: * For more guidance, visit [Advanced](https://docs.sweep.dev/usage/advanced), or watch the following video.\n47: \n48: [![Video](http://img.youtube.com/vi/Qn9vB71R4UM/0.jpg)](http://www.youtube.com/watch?v=Qn9vB71R4UM "Advanced Sweep Tricks and Feedback Tips")\n49: \n50: For configuring Sweep for your repo, see [Config](https://docs.sweep.dev/usage/config), especially for setting up Sweep Rules and Sweep Sweep.\n51: \n52: ## Limitations of Sweep (for now) ⚠️\n53: \n54: * 🗃️ **Gigantic repos**: >5000 files. We have default extensions and directories to exclude but sometimes this doesn\'t catch them all. You may need to block some directories (see [`blocked_dirs`](https://docs.sweep.dev/usage/config#blocked_dirs))\n55: * If Sweep is stuck at 0% for over 30 min and your repo has a few thousand files, let us know.\n56: \n57: * 🏗️ **Large-scale refactors**: >5 files or >300 lines of code changes (we\'re working on this!)\n58: * We can\'t do this - "Refactor entire codebase from Tensorflow to PyTorch"\n59: \n60: * 🖼️ **Editing images** and other non-text assets\n...\n</snippet>\n<snippet source="docs/pages/usage/tutorial.mdx:0-15">\n0: # Tutorial for Getting Started with Sweep\n1: \n2: We recommend using an existing **real project** for Sweep, but if you must start from scratch, we recommend **using a template**. In particular, we recommend Vercel templates and Vercel auto-deploy, since Vercel\'s auto-generated previews make it **easy to review Sweep\'s PRs**\n3: \n4: We\'ll use [Docusaurus](https://vercel.com/templates/next.js/docusaurus-2) since it\'s is the easiest to set up (no backend). To see other templates see https://vercel.com/templates.\n5: \n6: 1. Go to https://vercel.com/templates/next.js/docusaurus-2 (or another template) and click "Deploy".\n7: \n8: ![Deploy](/tutorial/deployment.png)\n9: \n10: 2. Vercel will prompt you to select a GitHub account and click "Clone" after. This will trigger a build and deploy which will take a few minutes. Once the build is done, you will be greeted with a congratulations message.\n11: \n12: ![Congratulations](/tutorial/congratulations.png)\n13: \n14: 3. Go to the [Sweep Installation](https://github.com/apps/sweep-ai) page and click the grey "Configure" button or the green "Install" button. Ensure that that the Vercel template (i.e. Docusaurus) is configured to use Sweep.\n...\n</snippet>\n</relevant_snippets_in_repo>\ndocs/\n installation.md\n docs/pages/\n docs/pages/usage/\n _meta.json\n advanced.mdx\n config.mdx\n extra-self-host.mdx\n sandbox.mdx\n tutorial.mdx',
name=None,
function_call=None,
key=None,
)
]
print(
new_planning(
request,
"/tmp/sweep_archive.zip",
chat_logger=ChatLogger(
{"username": "kevinlu1248", "title": "Unit test for planning"}
),
ticket_progress=TicketProgress(tracking_id="ed47605a38"),
)
)

import datetime
import difflib
import hashlib
import json
import os
import re
import shutil
import subprocess
import tempfile
import time
import traceback
from dataclasses import dataclass
from functools import cached_property
from typing import Any
import git
import requests
from github import Github, PullRequest, Repository, InputGitTreeElement
from jwt import encode
from loguru import logger
from sweepai.config.client import SweepConfig
from sweepai.config.server import GITHUB_APP_ID, GITHUB_APP_PEM, GITHUB_BOT_USERNAME
from sweepai.utils.tree_utils import DirectoryTree, remove_all_not_included
MAX_FILE_COUNT = 50
def make_valid_string(string: str):
pattern = r"[^\w./-]+"
return re.sub(pattern, "_", string)
def get_jwt():
signing_key = GITHUB_APP_PEM
app_id = GITHUB_APP_ID
payload = {"iat": int(time.time()), "exp": int(time.time()) + 600, "iss": app_id}
return encode(payload, signing_key, algorithm="RS256")
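`get_jwt` above sets `iat` to the current time. GitHub's documentation for GitHub App JWTs recommends backdating `iat` by about 60 seconds to absorb clock drift, and it requires `exp` to be no more than 10 minutes in the future.

The sketch below builds that payload. It assumes a hypothetical `build_app_jwt_payload` helper, which is not part of Sweep. Its result would be passed to `jwt.encode(payload, signing_key, algorithm="RS256")` in the same way `get_jwt` does.

```python
import time
from typing import Optional


def build_app_jwt_payload(app_id: str, now: Optional[float] = None) -> dict:
    """Build GitHub App JWT claims with iat backdated to tolerate clock drift."""
    current = int(time.time() if now is None else now)
    return {
        "iat": current - 60,  # issued 60 seconds in the past
        "exp": current + 540,  # 600 s after iat, i.e. 9 minutes from now (GitHub allows at most 10)
        "iss": app_id,  # the GitHub App ID
    }
```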
def get_token(installation_id: int):
if int(installation_id) < 0:
return os.environ["GITHUB_PAT"]
for timeout in [5.5, 5.5, 10.5]:
try:
jwt = get_jwt()
headers = {
"Accept": "application/vnd.github+json",
"Authorization": "Bearer " + jwt,
"X-GitHub-Api-Version": "2022-11-28",
}
response = requests.post(
f"https://api.github.com/app/installations/{int(installation_id)}/access_tokens",
headers=headers,
)
obj = response.json()
if "token" not in obj:
logger.error(obj)
raise Exception("Could not get token")
return obj["token"]
except SystemExit:
raise SystemExit
except Exception:
time.sleep(timeout)
raise Exception(
"Could not get token, please double check your PRIVATE_KEY and GITHUB_APP_ID in the .env file. Make sure to restart uvicorn after."
)
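`get_token` retries the installation-token request on a fixed schedule and explicitly re-raises `SystemExit`. That clause is redundant: `SystemExit` (like `KeyboardInterrupt`) derives from `BaseException` rather than `Exception`, so an `except Exception` clause already lets it propagate.

The same loop can be factored into a small helper. `retry` below is a hypothetical name, not something Sweep defines. Unlike `get_token`, it chains the last failure onto the final error instead of discarding it.

```python
import time
from typing import Callable, Iterable, Optional, TypeVar

T = TypeVar("T")


def retry(fn: Callable[[], T], delays: Iterable[float], sleep=time.sleep) -> T:
    """Call fn; after each failure wait the next delay and try again."""
    last_error: Optional[Exception] = None
    for delay in delays:
        try:
            return fn()
        except Exception as e:  # SystemExit/KeyboardInterrupt are not Exceptions, so they propagate
            last_error = e
            sleep(delay)
    raise RuntimeError("all retries failed") from last_error
```

With `delays=[5.5, 5.5, 10.5]`, this reproduces the attempt count and sleep schedule of `get_token`.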
def get_app():
jwt = get_jwt()
headers = {
"Accept": "application/vnd.github+json",
"Authorization": "Bearer " + jwt,
"X-GitHub-Api-Version": "2022-11-28",
}
response = requests.get("https://api.github.com/app", headers=headers)
return response.json()
def get_github_client(installation_id: int) -> tuple[str, Github]:
if not installation_id:
return os.environ["GITHUB_PAT"], Github(os.environ["GITHUB_PAT"])
token: str = get_token(installation_id)
return token, Github(token)
# fetch installation object
def get_installation(username: str):
jwt = get_jwt()
try:
# Try user
response = requests.get(
f"https://api.github.com/users/{username}/installation",
headers={
"Accept": "application/vnd.github+json",
"Authorization": "Bearer " + jwt,
"X-GitHub-Api-Version": "2022-11-28",
},
)
obj = response.json()
return obj
except Exception:
# Try org
response = requests.get(
f"https://api.github.com/orgs/{username}/installation",
headers={
"Accept": "application/vnd.github+json",
"Authorization": "Bearer " + jwt,
"X-GitHub-Api-Version": "2022-11-28",
},
)
try:
obj = response.json()
return obj["id"]
except Exception as e:
logger.error(e)
logger.error(response.text)
raise Exception("Could not get installation, probably not installed")
def get_installation_id(username: str) -> str:
jwt = get_jwt()
try:
# Try user
response = requests.get(
f"https://api.github.com/users/{username}/installation",
headers={
"Accept": "application/vnd.github+json",
"Authorization": "Bearer " + jwt,
"X-GitHub-Api-Version": "2022-11-28",
},
)
obj = response.json()
return obj["id"]
except Exception:
# Try org
response = requests.get(
f"https://api.github.com/orgs/{username}/installation",
headers={
"Accept": "application/vnd.github+json",
"Authorization": "Bearer " + jwt,
"X-GitHub-Api-Version": "2022-11-28",
},
)
try:
obj = response.json()
return obj["id"]
except Exception as e:
logger.error(e)
logger.error(response.text)
raise Exception("Could not get installation id, probably not installed")
# commits multiple files in a single commit, returns the commit object
def commit_multi_file_changes(repo: Repository, file_changes: dict[str, str], commit_message: str, branch: str):
blobs_to_commit = []
# convert to blob
for path, content in file_changes.items():
blob = repo.create_git_blob(content, "utf-8")
blobs_to_commit.append(InputGitTreeElement(path=path, mode="100644", type="blob", sha=blob.sha))
latest_commit = repo.get_branch(branch).commit
base_tree = latest_commit.commit.tree
# create new git tree
new_tree = repo.create_git_tree(blobs_to_commit, base_tree=base_tree)
# commit the changes
parent = repo.get_git_commit(latest_commit.sha)
commit = repo.create_git_commit(
commit_message,
new_tree,
[parent],
)
# update ref of branch
ref = f"heads/{branch}"
repo.get_git_ref(ref).edit(sha=commit.sha)
return commit
REPO_CACHE_BASE_DIR = "/tmp/cache/repos"
@dataclass
class ClonedRepo:
repo_full_name: str
installation_id: str
branch: str | None = None
token: str | None = None
repo: Any | None = None
git_repo: git.Repo | None = None
class Config:
arbitrary_types_allowed = True
@cached_property
def cached_dir(self):
self.repo = (
Github(self.token).get_repo(self.repo_full_name)
if not self.repo
else self.repo
)
self.branch = self.branch or SweepConfig.get_branch(self.repo)
return os.path.join(
REPO_CACHE_BASE_DIR,
self.repo_full_name,
"base",
parse_collection_name(self.branch),
)
@cached_property
def zip_path(self):
logger.info("Zipping repository...")
shutil.make_archive(self.repo_dir, "zip", self.repo_dir)
logger.info("Done zipping")
return f"{self.repo_dir}.zip"
@cached_property
def repo_dir(self):
self.repo = (
Github(self.token).get_repo(self.repo_full_name)
if not self.repo
else self.repo
)
self.branch = self.branch or SweepConfig.get_branch(self.repo)
curr_time_str = str(time.time()).encode("utf-8")
hash_obj = hashlib.sha256(curr_time_str)
hash_hex = hash_obj.hexdigest()
if self.branch:
return os.path.join(
REPO_CACHE_BASE_DIR,
self.repo_full_name,
hash_hex,
parse_collection_name(self.branch),
)
else:
return os.path.join("/tmp/cache/repos", self.repo_full_name, hash_hex)
@property
def clone_url(self):
return (
f"https://x-access-token:{self.token}@github.com/{self.repo_full_name}.git"
)
def clone(self):
if not os.path.exists(self.cached_dir):
logger.info("Cloning repo...")
if self.branch:
repo = git.Repo.clone_from(
self.clone_url, self.cached_dir, branch=self.branch
)
else:
repo = git.Repo.clone_from(self.clone_url, self.cached_dir)
logger.info("Done cloning")
else:
try:
repo = git.Repo(self.cached_dir)
repo.remotes.origin.pull(
kill_after_timeout=60, progress=git.RemoteProgress()
)
except Exception:
logger.error("Could not pull repo")
shutil.rmtree(self.cached_dir, ignore_errors=True)
repo = git.Repo.clone_from(self.clone_url, self.cached_dir)
logger.info("Repo already cached, copying")
logger.info("Copying repo...")
shutil.copytree(
self.cached_dir, self.repo_dir, symlinks=True, copy_function=shutil.copy
)
logger.info("Done copying")
repo = git.Repo(self.repo_dir)
return repo
def __post_init__(self):
subprocess.run(["git", "config", "--global", "http.postBuffer", "524288000"])
self.token = self.token or get_token(self.installation_id)
self.repo = (
Github(self.token).get_repo(self.repo_full_name)
if not self.repo
else self.repo
)
self.commit_hash = self.repo.get_commits()[0].sha
self.git_repo = self.clone()
self.branch = self.branch or SweepConfig.get_branch(self.repo)
def __del__(self):
try:
shutil.rmtree(self.repo_dir)
os.remove(self.zip_path)
return True
except Exception:
return False
def list_directory_tree(
self,
included_directories=None,
excluded_directories: list[str] = None,
included_files=None,
):
"""Display the directory tree rooted at self.repo_dir.
Arguments:
included_directories -- List of directory paths (relative to the root) to include in the tree. Defaults to None (all directories).
excluded_directories -- List of directory names to exclude from the tree. Defaults to None (uses SweepConfig().exclude_dirs).
included_files -- Currently unused.
Returns a tuple of the rendered tree string and the parsed DirectoryTree.
"""
root_directory = self.repo_dir
sweep_config: SweepConfig = SweepConfig()
# Default values if parameters are not provided
if included_directories is None:
included_directories = [] # gets all directories
if excluded_directories is None:
excluded_directories = sweep_config.exclude_dirs
def list_directory_contents(
current_directory: str,
excluded_directories: list[str],
indentation="",
):
"""Recursively list the contents of directories."""
file_and_folder_names = os.listdir(current_directory)
file_and_folder_names.sort()
directory_tree_string = ""
for name in file_and_folder_names[:MAX_FILE_COUNT]:
relative_path = os.path.join(current_directory, name)[
len(root_directory) + 1 :
]
if name in excluded_directories:
continue
complete_path = os.path.join(current_directory, name)
if os.path.isdir(complete_path):
directory_tree_string += f"{indentation}{relative_path}/\n"
directory_tree_string += list_directory_contents(
complete_path,
excluded_directories,
indentation + " ",
)
else:
directory_tree_string += f"{indentation}{name}\n"
# if os.path.isfile(complete_path) and relative_path in included_files:
# # Todo, use these to fetch neighbors
# ctags_str, names = get_ctags_for_file(ctags, complete_path)
# ctags_str = "\n".join([indentation + line for line in ctags_str.splitlines()])
# if ctags_str.strip():
# directory_tree_string += f"{ctags_str}\n"
return directory_tree_string
dir_obj = DirectoryTree()
directory_tree = list_directory_contents(root_directory, excluded_directories)
dir_obj.parse(directory_tree)
if included_directories:
dir_obj = remove_all_not_included(dir_obj, included_directories)
return directory_tree, dir_obj
def get_file_list(self) -> str:
root_directory = self.repo_dir
files = []
sweep_config: SweepConfig = SweepConfig()
def dfs_helper(directory):
nonlocal files
for item in os.listdir(directory):
if item == ".git":
continue
if item in sweep_config.exclude_dirs: # this saves a lot of time
continue
item_path = os.path.join(directory, item)
if os.path.isfile(item_path):
# make sure the item_path is not in one of the banned directories
if not sweep_config.is_file_excluded(item_path):
files.append(item_path) # Add the file to the list
elif os.path.isdir(item_path):
dfs_helper(item_path) # Recursive call to explore subdirectory
dfs_helper(root_directory)
files = [file[len(root_directory) + 1 :] for file in files]
return files
def get_file_contents(self, file_path, ref=None):
local_path = (
f"{self.repo_dir}{file_path}"
if file_path.startswith("/")
else f"{self.repo_dir}/{file_path}"
)
if os.path.exists(local_path):
with open(local_path, "r", encoding="utf-8", errors="replace") as f:
contents = f.read()
return contents
else:
raise FileNotFoundError(f"{local_path} does not exist.")
def get_num_files_from_repo(self):
# subprocess.run(["git", "config", "--global", "http.postBuffer", "524288000"])
self.git_repo.git.checkout(self.branch)
file_list = self.get_file_list()
return len(file_list)
def get_commit_history(
self, username: str = "", limit: int = 200, time_limited: bool = True
):
commit_history = []
try:
if username != "":
commit_list = list(self.git_repo.iter_commits(author=username))
else:
commit_list = list(self.git_repo.iter_commits())
line_count = 0
cut_off_date = datetime.datetime.now() - datetime.timedelta(days=7)
for commit in commit_list:
# must be within a week
if time_limited and commit.authored_datetime.replace(
tzinfo=None
) <= cut_off_date.replace(tzinfo=None):
logger.info("Exceeded cut off date, stopping...")
break
repo = get_github_client(self.installation_id)[1].get_repo(
self.repo_full_name
)
branch = SweepConfig.get_branch(repo)
if branch not in self.git_repo.git.branch():
branch = f"origin/{branch}"
diff = self.git_repo.git.diff(commit, branch, unified=1)
lines = diff.count("\n")
# total diff lines must not exceed `limit`
if lines + line_count > limit:
logger.info(f"Exceeded {limit} lines of diff, stopping...")
break
commit_history.append(
f"<commit>\nAuthor: {commit.author.name}\nMessage: {commit.message}\n{diff}\n</commit>"
)
line_count += lines
except Exception:
logger.error(f"An error occurred: {traceback.format_exc()}")
return commit_history
def get_similar_file_paths(self, file_path: str, limit: int = 10):
from rapidfuzz.fuzz import ratio
# Fuzzy search over file names
file_name = os.path.basename(file_path)
all_file_paths = self.get_file_list()
# filter for matching extensions if both have extensions
if "." in file_name:
all_file_paths = [
file
for file in all_file_paths
if "." in file and file.split(".")[-1] == file_name.split(".")[-1]
]
files_with_matching_name = []
files_without_matching_name = []
for file_path in all_file_paths:
if file_name in file_path:
files_with_matching_name.append(file_path)
else:
files_without_matching_name.append(file_path)
file_path_to_ratio = {file: ratio(file_name, file) for file in all_file_paths}
files_with_matching_name = sorted(
files_with_matching_name,
key=lambda file_path: file_path_to_ratio[file_path],
reverse=True,
)
files_without_matching_name = sorted(
files_without_matching_name,
key=lambda file_path: file_path_to_ratio[file_path],
reverse=True,
)
# this allows 'config.py' to return 'sweepai/config/server.py', 'sweepai/config/client.py', 'sweepai/config/__init__.py' and no more
filtered_files_without_matching_name = list(filter(lambda file_path: file_path_to_ratio[file_path] > 50, files_without_matching_name))
all_files = files_with_matching_name + filtered_files_without_matching_name
return all_files[:limit]
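`get_file_list` walks the checkout recursively and skips any entry whose name is in `SweepConfig.exclude_dirs` before descending into it. That early skip is what the `# this saves a lot of time` comment refers to. The same pruning can be expressed with `os.walk` by editing `dirnames` in place, which stops the walk from entering those directories.

This is a sketch under stated assumptions:
- The exclusion list and the file filter are passed in as arguments.
- `list_repo_files` is a hypothetical helper, not part of Sweep.
- Unlike the original, it returns the paths sorted.

```python
import os
from typing import Callable, Iterable, List


def list_repo_files(
    root: str,
    excluded_names: Iterable[str],
    is_file_excluded: Callable[[str], bool] = lambda path: False,
) -> List[str]:
    """Return repo-relative file paths, pruning .git and excluded names."""
    skip = set(excluded_names) | {".git"}
    found = []
    for dirpath, dirnames, filenames in os.walk(root):
        # Editing dirnames in place prevents os.walk from descending into them
        dirnames[:] = [d for d in dirnames if d not in skip]
        for name in filenames:
            path = os.path.join(dirpath, name)
            if name in skip or is_file_excluded(path):
                continue
            found.append(os.path.relpath(path, root))
    return sorted(found)
```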
# updates a file with new_contents, returns True if successful
def update_file(root_dir: str, file_path: str, new_contents: str):
local_path = os.path.join(root_dir, file_path)
try:
with open(local_path, "w") as f:
f.write(new_contents)
return True
except Exception as e:
logger.error(f"Failed to update file: {e}")
return False
@dataclass
class MockClonedRepo(ClonedRepo):
_repo_dir: str = ""
git_repo: git.Repo | None = None
def __init__(
self,
_repo_dir: str,
repo_full_name: str,
installation_id: str = "",
branch: str | None = None,
token: str | None = None,
repo: Any | None = None,
git_repo: git.Repo | None = None,
):
self._repo_dir = _repo_dir
self.repo_full_name = repo_full_name
self.installation_id = installation_id
self.branch = branch
self.token = token
self.repo = repo
@classmethod
def from_dir(cls, repo_dir: str, **kwargs):
return cls(_repo_dir=repo_dir, **kwargs)
@property
def cached_dir(self):
return self._repo_dir
@property
def repo_dir(self):
return self._repo_dir
@property
def git_repo(self):
return git.Repo(self.repo_dir)
def clone(self):
return git.Repo(self.repo_dir)
def __post_init__(self):
return self
def __del__(self):
return True
@dataclass
class TemporarilyCopiedClonedRepo(MockClonedRepo):
tmp_dir: tempfile.TemporaryDirectory | None = None
def __init__(
self,
_repo_dir: str,
tmp_dir: tempfile.TemporaryDirectory,
repo_full_name: str,
installation_id: str = "",
branch: str | None = None,
token: str | None = None,
repo: Any | None = None,
git_repo: git.Repo | None = None,
):
self._repo_dir = _repo_dir
self.tmp_dir = tmp_dir
self.repo_full_name = repo_full_name
self.installation_id = installation_id
self.branch = branch
self.token = token
self.repo = repo
@classmethod
def copy_from_cloned_repo(cls, cloned_repo: ClonedRepo, **kwargs):
temp_dir = tempfile.TemporaryDirectory()
new_dir = temp_dir.name + "/" + cloned_repo.repo_full_name.split("/")[1]
print("Copying...")
shutil.copytree(cloned_repo.repo_dir, new_dir)
print("Done copying.")
return cls(
_repo_dir=new_dir,
tmp_dir=temp_dir,
repo_full_name=cloned_repo.repo_full_name,
installation_id=cloned_repo.installation_id,
branch=cloned_repo.branch,
token=cloned_repo.token,
repo=cloned_repo.repo,
**kwargs,
)
def __del__(self):
print(f"Dropping {self.tmp_dir.name}...")
shutil.rmtree(self._repo_dir, ignore_errors=True)
self.tmp_dir.cleanup()
print("Done.")
return True
def get_file_names_from_query(query: str) -> list[str]:
query_file_names = re.findall(r"\b[\w\-\.\/]*\w+\.\w{1,6}\b", query)
return [
query_file_name
for query_file_name in query_file_names
if len(query_file_name) > 3
]
def get_hunks(a: str, b: str, context=10):
differ = difflib.Differ()
diff = [
line
for line in differ.compare(a.splitlines(), b.splitlines())
if line[0] in ("+", "-", " ")
]
show = set()
hunks = []
for i, line in enumerate(diff):
if line.startswith(("+", "-")):
show.update(range(max(0, i - context), min(len(diff), i + context + 1)))
for i in range(len(diff)):
if i in show:
hunks.append(diff[i])
elif i - 1 in show:
hunks.append("...")
if len(hunks) > 0 and hunks[0] == "...":
hunks = hunks[1:]
if len(hunks) > 0 and hunks[-1] == "...":
hunks = hunks[:-1]
return "\n".join(hunks)
def parse_collection_name(name: str) -> str:
# Replace any non-alphanumeric characters with hyphens
name = re.sub(r"[^\w-]", "--", name)
# Ensure the name is between 3 and 63 characters and starts/ends with alphanumeric
name = re.sub(r"^(-*\w{0,61}\w)-*$", r"\1", name[:63].ljust(3, "x"))
return name
# set whether or not a pr is a draft, there is no way to do this using pygithub
def convert_pr_draft_field(pr: PullRequest, is_draft: bool = False):
    pr_id = pr.raw_data["node_id"]
    # GraphQL mutation that converts the PR to a draft or marks it ready for review
    mutation_name = (
        "convertPullRequestToDraft" if is_draft else "markPullRequestReadyForReview"
    )
    mutation = (
        f"mutation($id: ID!) {{ {mutation_name}(input: {{pullRequestId: $id}}) "
        "{ pullRequest { id } } }"
    )
    # GraphQL API URL
    url = "https://api.github.com/graphql"
    # Headers
    headers = {
        "Accept": "application/vnd.github+json",
        "X-Github-Api-Version": "2022-11-28",
        "Authorization": "Bearer " + os.environ["GITHUB_PAT"],
    }
    # Prepare the JSON payload; the node id is passed as a variable, not spliced into the query
    json_data = {
        "query": mutation,
        "variables": {"id": pr_id},
    }
    # Make the POST request
    response = requests.post(url, headers=headers, data=json.dumps(json_data))
    # GraphQL reports most failures with HTTP 200 and an "errors" field in the body
    if response.status_code != 200 or "errors" in response.json():
        logger.error(f"Failed to convert PR to {'draft' if is_draft else 'open'}")
        return False
    return True
try:
g = Github(os.environ.get("GITHUB_PAT"))
CURRENT_USERNAME = g.get_user().login
except Exception:
try:
slug = get_app()["slug"]
CURRENT_USERNAME = f"{slug}[bot]"
except Exception:
CURRENT_USERNAME = GITHUB_BOT_USERNAME
if __name__ == "__main__":
try:
organization_name = "sweepai"
sweep_config = SweepConfig()
installation_id = get_installation_id(organization_name)
user_token, g = get_github_client(installation_id)
cloned_repo = ClonedRepo("sweepai/sweep", installation_id, "main")
directory_tree, dir_obj = cloned_repo.list_directory_tree()
commit_history = cloned_repo.get_commit_history()
similar_file_paths = cloned_repo.get_similar_file_paths("config.py")
# ensure no similar file_paths are sweep excluded
assert(not any([file for file in similar_file_paths if sweep_config.is_file_excluded(file)]))
print(f"similar_file_paths: {similar_file_paths}")
str1 = "a\nline1\nline2\nline3\nline4\nline5\nline6\ntest\n"
str2 = "a\nline1\nlineTwo\nline3\nline4\nline5\nlineSix\ntset\n"
print(get_hunks(str1, str2, 1))
mocked_repo = MockClonedRepo.from_dir(
cloned_repo.repo_dir,
repo_full_name="sweepai/sweep",
)
temp_repo = TemporarilyCopiedClonedRepo.copy_from_cloned_repo(mocked_repo)
print(f"mocked repo: {mocked_repo}")
except Exception as e:

sweep/sweepai/api.py (lines 1 to 1178 at commit 0277fad)

from __future__ import annotations
import ctypes
import json
import threading
import time
from typing import Any, Optional
import requests
from fastapi import (
Body,
FastAPI,
Header,
HTTPException,
Path,
Request,
Security,
status,
)
from fastapi.responses import HTMLResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from fastapi.templating import Jinja2Templates
from github.Commit import Commit
from sweepai.config.client import (
DEFAULT_RULES,
RESTART_SWEEP_BUTTON,
REVERT_CHANGED_FILES_TITLE,
RULES_LABEL,
RULES_TITLE,
SWEEP_BAD_FEEDBACK,
SWEEP_GOOD_FEEDBACK,
SweepConfig,
get_gha_enabled,
get_rules,
)
from sweepai.config.server import (
BLACKLISTED_USERS,
DISABLED_REPOS,
DISCORD_FEEDBACK_WEBHOOK_URL,
ENV,
GHA_AUTOFIX_ENABLED,
GITHUB_BOT_USERNAME,
GITHUB_LABEL_COLOR,
GITHUB_LABEL_DESCRIPTION,
GITHUB_LABEL_NAME,
IS_SELF_HOSTED,
MERGE_CONFLICT_ENABLED,
)
from sweepai.core.entities import PRChangeRequest
from sweepai.global_threads import global_threads
from sweepai.handlers.create_pr import ( # type: ignore
add_config_to_top_repos,
create_gha_pr,
)
from sweepai.handlers.on_button_click import handle_button_click
from sweepai.handlers.on_check_suite import ( # type: ignore
clean_gh_logs,
download_logs,
on_check_suite,
)
from sweepai.handlers.on_comment import on_comment
from sweepai.handlers.on_jira_ticket import handle_jira_ticket
from sweepai.handlers.on_merge import on_merge
from sweepai.handlers.on_merge_conflict import on_merge_conflict
from sweepai.handlers.on_ticket import on_ticket
from sweepai.handlers.pr_utils import make_pr
from sweepai.handlers.stack_pr import stack_pr
from sweepai.utils.buttons import (
Button,
ButtonList,
check_button_activated,
check_button_title_match,
)
from sweepai.utils.chat_logger import ChatLogger
from sweepai.utils.event_logger import logger, posthog
from sweepai.utils.github_utils import CURRENT_USERNAME, get_github_client
from sweepai.utils.progress import TicketProgress
from sweepai.utils.safe_pqueue import SafePriorityQueue
from sweepai.utils.str_utils import BOT_SUFFIX, get_hash
from sweepai.web.events import (
CheckRunCompleted,
CommentCreatedRequest,
InstallationCreatedRequest,
IssueCommentRequest,
IssueRequest,
PREdited,
PRRequest,
ReposAddedRequest,
)
from sweepai.web.health import health_check
app = FastAPI()
events = {}
on_ticket_events = {}
security = HTTPBearer()
templates = Jinja2Templates(directory="sweepai/web")
# version_command = r"""git config --global --add safe.directory /app
# timestamp=$(git log -1 --format="%at")
# date -d "@$timestamp" +%y.%m.%d.%H 2>/dev/null || date -r "$timestamp" +%y.%m.%d.%H"""
# try:
# version = subprocess.check_output(version_command, shell=True, text=True).strip()
# except Exception:
version = time.strftime("%y.%m.%d.%H")
logger.bind(application="webhook")
def auth_metrics(credentials: HTTPAuthorizationCredentials = Security(security)):
if credentials.scheme != "Bearer":
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Invalid authentication scheme.",
)
if credentials.credentials != "example_token": # grafana requires authentication
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token."
)
return True
def run_on_ticket(*args, **kwargs):
tracking_id = get_hash()
with logger.contextualize(
**kwargs,
name="ticket_" + kwargs["username"],
tracking_id=tracking_id,
):
return on_ticket(*args, **kwargs, tracking_id=tracking_id)
def run_on_comment(*args, **kwargs):
tracking_id = get_hash()
with logger.contextualize(
**kwargs,
name="comment_" + kwargs["username"],
tracking_id=tracking_id,
):
on_comment(*args, **kwargs, tracking_id=tracking_id)
def run_on_button_click(*args, **kwargs):
thread = threading.Thread(target=handle_button_click, args=args, kwargs=kwargs)
thread.start()
global_threads.append(thread)
def run_on_check_suite(*args, **kwargs):
request = kwargs["request"]
pr_change_request = on_check_suite(request)
if pr_change_request:
call_on_comment(**pr_change_request.params, comment_type="github_action")
logger.info("Done with on_check_suite")
else:
logger.info("Skipping on_check_suite as no pr_change_request was returned")
def terminate_thread(thread):
    """Terminate a python threading.Thread."""
    try:
        if not thread.is_alive():
            return
        exc = ctypes.py_object(SystemExit)
        res = ctypes.pythonapi.PyThreadState_SetAsyncExc(
            ctypes.c_long(thread.ident), exc
        )
        if res == 0:
            raise ValueError("Invalid thread ID")
        elif res != 1:
            # More than one thread state was modified: call again with a NULL
            # exception (None) to clear the pending exception.
            ctypes.pythonapi.PyThreadState_SetAsyncExc(
                ctypes.c_long(thread.ident), None
            )
            raise SystemError("PyThreadState_SetAsyncExc failed")
    except SystemExit:
        raise SystemExit
    except Exception as e:
        logger.exception(f"Failed to terminate thread: {e}")
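`terminate_thread` injects `SystemExit` into another thread with `PyThreadState_SetAsyncExc`. CPython only raises that exception when the target thread next executes Python bytecode. A thread blocked inside a long C-level call, such as a socket read, therefore keeps running until that call returns.

A commonly used alternative is cooperative cancellation, where the worker polls a `threading.Event` between steps. This technique is swapped in here and is not what Sweep does. The sketch assumes a hypothetical `cancellable_worker` loop; `on_ticket` would need explicit checkpoints for this approach to apply.

```python
import threading


def cancellable_worker(stop: threading.Event, results: list) -> None:
    """Do work in small steps and exit promptly once stop is set."""
    step = 0
    while not stop.is_set():
        results.append(step)  # stand-in for one unit of real work
        step += 1
        # Waits up to 10 ms, but returns immediately if stop is set meanwhile
        stop.wait(0.01)


# Usage: request cancellation, then wait for the current step to finish
stop = threading.Event()
worker = threading.Thread(target=cancellable_worker, args=(stop, []))
worker.start()
stop.set()
worker.join()
```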
# def delayed_kill(thread: threading.Thread, delay: int = 60 * 60):
# time.sleep(delay)
# terminate_thread(thread)
def call_on_ticket(*args, **kwargs):
global on_ticket_events
key = f"{kwargs['repo_full_name']}-{kwargs['issue_number']}" # Full name, issue number as key
# Use multithreading
# Check if a previous process exists for the same key, cancel it
e = on_ticket_events.get(key, None)
if e:
logger.info(f"Found previous thread for key {key} and cancelling it")
terminate_thread(e)
thread = threading.Thread(target=run_on_ticket, args=args, kwargs=kwargs)
on_ticket_events[key] = thread
thread.start()
global_threads.append(thread)
def call_on_check_suite(*args, **kwargs):
kwargs["request"].repository.full_name
kwargs["request"].check_run.pull_requests[0].number
thread = threading.Thread(target=run_on_check_suite, args=args, kwargs=kwargs)
thread.start()
global_threads.append(thread)
def call_on_comment(
*args, **kwargs
): # TODO: if its a GHA delete all previous GHA and append to the end
def worker():
while not events[key].empty():
task_args, task_kwargs = events[key].get()
run_on_comment(*task_args, **task_kwargs)
global events
repo_full_name = kwargs["repo_full_name"]
pr_id = kwargs["pr_number"]
key = f"{repo_full_name}-{pr_id}" # Full name, PR number as key
comment_type = kwargs["comment_type"]
logger.info(f"Received comment type: {comment_type}")
if key not in events:
events[key] = SafePriorityQueue()
events[key].put(0, (args, kwargs))
# If a thread isn't running, start one
if not any(
thread.name == key and thread.is_alive() for thread in threading.enumerate()
):
thread = threading.Thread(target=worker, name=key)
thread.start()
global_threads.append(thread)
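`call_on_comment` serializes comment handling per `repo-PR` key. It enqueues the task and starts a worker named after the key, unless a live thread with that name already exists.

The queue creation, the liveness check, and the worker's exit test (`while not events[key].empty()`) do not share a lock. As a result:
- A comment that arrives just as the worker is finishing may sit in the queue until the next comment.
- Two simultaneous webhooks could start two workers for the same key.

Below is a lock-guarded sketch of the same idea. `KeyedSerialExecutor` is a hypothetical name. It uses a plain FIFO `queue.Queue` rather than Sweep's `SafePriorityQueue`.

```python
import queue
import threading
import traceback
from typing import Callable, Dict


class KeyedSerialExecutor:
    """Run tasks in FIFO order per key, with at most one worker thread per key."""

    def __init__(self) -> None:
        self._queues: Dict[str, queue.Queue] = {}
        self._workers: Dict[str, threading.Thread] = {}
        self._lock = threading.Lock()

    def submit(self, key: str, task: Callable[[], None]) -> threading.Thread:
        with self._lock:
            self._queues.setdefault(key, queue.Queue()).put(task)
            worker = self._workers.get(key)
            if worker is None:
                worker = threading.Thread(target=self._drain, args=(key,), name=key)
                self._workers[key] = worker
                worker.start()
            return worker

    def _drain(self, key: str) -> None:
        tasks = self._queues[key]
        while True:
            with self._lock:
                try:
                    task = tasks.get_nowait()
                except queue.Empty:
                    # Deregister under the lock so the next submit() starts a new worker
                    del self._workers[key]
                    return
            try:
                task()
            except Exception:
                # One failing task must not strand the rest of this key's queue
                traceback.print_exc()
```

Tasks for the same PR run one at a time in arrival order, while different PRs are handled by separate threads in parallel.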
def call_on_merge(*args, **kwargs):
thread = threading.Thread(target=on_merge, args=args, kwargs=kwargs)
thread.start()
global_threads.append(thread)
@app.get("/health")
def redirect_to_health():
return health_check()
@app.get("/", response_class=HTMLResponse)
def home(request: Request):
return templates.TemplateResponse(
name="index.html", context={"version": version, "request": request}
)
@app.get("/ticket_progress/{tracking_id}")
def progress(tracking_id: str = Path(...)):
ticket_progress = TicketProgress.load(tracking_id)
return ticket_progress.dict()
def init_hatchet() -> Any | None:
try:
from hatchet_sdk import Context, Hatchet
hatchet = Hatchet(debug=True)
worker = hatchet.worker("github-worker")
@hatchet.workflow(on_events=["github:webhook"])
class OnGithubEvent:
"""Workflow for handling GitHub events."""
@hatchet.step()
def run(self, context: Context):
event_payload = context.workflow_input()
request_dict = event_payload.get("request")
event = event_payload.get("event")
handle_event(request_dict, event)
workflow = OnGithubEvent()
worker.register_workflow(workflow)
# start worker in the background
thread = threading.Thread(target=worker.start)
thread.start()
global_threads.append(thread)
return hatchet
except Exception as e:
print(f"Failed to initialize Hatchet: {e}, continuing with local mode")
return None
# hatchet = init_hatchet()
def handle_github_webhook(event_payload):
# if hatchet:
# hatchet.client.event.push("github:webhook", event_payload)
# else:
handle_event(event_payload.get("request"), event_payload.get("event"))
def handle_request(request_dict, event=None):
"""So it can be exported to the listen endpoint."""
with logger.contextualize(tracking_id="main", env=ENV):
action = request_dict.get("action")
try:
# Send the event to Hatchet
handle_github_webhook(
{
"request": request_dict,
"event": event,
}
)
except Exception as e:
logger.exception(f"Failed to send event to Hatchet: {e}")
# try:
# worker()
# except Exception as e:
# discord_log_error(str(e), priority=1)
logger.info(f"Done handling {event}, {action}")
return {"success": True}
@app.post("/")
def webhook(
request_dict: dict = Body(...),
x_github_event: Optional[str] = Header(None, alias="X-GitHub-Event"),
):
"""Handle a webhook request from GitHub."""
with logger.contextualize(tracking_id="main", env=ENV):
action = request_dict.get("action", None)
logger.info(f"Received event: {x_github_event}, {action}")
return handle_request(request_dict, event=x_github_event)
@app.post("/jira")
def jira_webhook(
request_dict: dict = Body(...),
) -> None:
def call_jira_ticket(*args, **kwargs):
thread = threading.Thread(target=handle_jira_ticket, args=args, kwargs=kwargs)
thread.start()
call_jira_ticket(event=request_dict)
# Set up cronjob for this
@app.get("/update_sweep_prs_v2")
def update_sweep_prs_v2(repo_full_name: str, installation_id: int):
    # Get a Github client
    _, g = get_github_client(installation_id)
    # Get the repository
    repo = g.get_repo(repo_full_name)
    config = SweepConfig.get_config(repo)
    try:
        branch_ttl = int(config.get("branch_ttl", 7))
    except Exception:
        branch_ttl = 7
    branch_ttl = max(branch_ttl, 1)
    # Get all open pull requests created by Sweep
    pulls = repo.get_pulls(
        state="open", head="sweep", sort="updated", direction="desc"
    )[:5]
    # For each pull request, attempt to merge the changes from its base branch into the pull request branch
    try:
        for pr in pulls:
            try:
                # make sure it's a sweep ticket
                feature_branch = pr.head.ref
                if not feature_branch.startswith(
                    "sweep/"
                ) and not feature_branch.startswith("sweep_"):
                    continue
                if "Resolve merge conflicts" in pr.title:
                    continue
                if (
                    pr.mergeable_state != "clean"
                    and (time.time() - pr.created_at.timestamp()) > 60 * 60 * 24
                    and pr.title.startswith("[Sweep Rules]")
                ):
                    pr.edit(state="closed")
                    continue
                # Merge the base branch into the feature branch (this creates a merge commit)
                repo.merge(
                    feature_branch,
                    pr.base.ref,
                    f"Merge {pr.base.ref} into {feature_branch}",
                )
                # Check if the merged PR is the config PR
                # (only open PRs are listed above, so pr.merged is normally False here)
                if pr.title == "Configure Sweep" and pr.merged:
                    # Create a new PR to add "gha_enabled: True" to sweep.yaml
                    create_gha_pr(g, repo)
            except Exception as e:
                logger.warning(
                    f"Failed to merge changes from base branch into PR #{pr.number}: {e}"
                )
    except Exception:
        logger.warning("Failed to update sweep PRs")
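The issue asks for a rebase instead of the `repo.merge(...)` call above, which always adds a merge commit. PyGithub's `Repository.merge` calls GitHub's "merge a branch" REST endpoint, which only produces merge commits. A rebase therefore needs a local clone. One candidate is the `ClonedRepo.repo_dir` checkout, whose `origin` remote already embeds the installation token.

The sketch below is an assumption about how this could look, not Sweep's implementation:
- `run_git` and `rebase_feature_branch` are hypothetical helpers that shell out to the `git` CLI.
- The feature branch is rebased onto `origin/<base>` and pushed with `--force-with-lease`.
- On conflicts the rebase is aborted, so the caller can fall back to the existing merge.
- The clone needs a committer identity (`user.name`, `user.email`) because rebasing rewrites commits.

```python
import subprocess


def run_git(repo_dir: str, *args: str) -> str:
    """Run a git command inside repo_dir and return its stripped stdout."""
    result = subprocess.run(
        ["git", *args], cwd=repo_dir, check=True, capture_output=True, text=True
    )
    return result.stdout.strip()


def rebase_feature_branch(repo_dir: str, feature_branch: str, base_branch: str) -> bool:
    """Rebase feature_branch onto origin/base_branch and force-push it.

    Returns False (leaving the remote branch untouched) if the rebase fails.
    """
    # Refresh both remote-tracking branches from origin
    run_git(repo_dir, "fetch", "origin", base_branch, feature_branch)
    # Create or reset the local branch to the remote PR branch and check it out
    run_git(repo_dir, "checkout", "-B", feature_branch, f"origin/{feature_branch}")
    try:
        run_git(repo_dir, "rebase", f"origin/{base_branch}")
    except subprocess.CalledProcessError:
        # Most likely a conflict: undo the partial rebase so the caller can merge instead
        run_git(repo_dir, "rebase", "--abort")
        return False
    # --force-with-lease refuses to push if origin's branch moved since the fetch above
    run_git(repo_dir, "push", "--force-with-lease", "origin", feature_branch)
    return True
```

Force-pushing rewrites the PR branch's history. For that reason a repository setting could keep merging as the default and make rebasing opt-in; any such `sweep.yaml` key would be new and is not something Sweep reads today.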
def handle_event(request_dict, event):
action = request_dict.get("action")
if repo_full_name := request_dict.get("repository", {}).get("full_name"):
if repo_full_name in DISABLED_REPOS:
logger.warning(f"Repo {repo_full_name} is disabled")
return {"success": False, "error_message": "Repo is disabled"}
    with logger.contextualize(tracking_id="main", env=ENV):
        match event, action:
            case "check_run", "completed":
                request = CheckRunCompleted(**request_dict)
                _, g = get_github_client(request.installation.id)
                repo = g.get_repo(request.repository.full_name)
                pull_requests = request.check_run.pull_requests
                if pull_requests:
                    logger.info(pull_requests[0].number)
                    pr = repo.get_pull(pull_requests[0].number)
                    if (time.time() - pr.created_at.timestamp()) > 60 * 60 and (
                        pr.title.startswith("[Sweep Rules]")
                        or pr.title.startswith("[Sweep GHA Fix]")
                    ):
                        after_sha = pr.head.sha
                        commit = repo.get_commit(after_sha)
                        check_suites = commit.get_check_suites()
                        for check_suite in check_suites:
                            if check_suite.conclusion == "failure":
                                pr.edit(state="closed")
                                break
                    if (
                        not (time.time() - pr.created_at.timestamp()) > 60 * 15
                        and request.check_run.conclusion == "failure"
                        and pr.state == "open"
                        and get_gha_enabled(repo)
                        and len(
                            [
                                comment
                                for comment in pr.get_issue_comments()
                                if "Fixing PR" in comment.body
                            ]
                        )
                        < 2
                        and GHA_AUTOFIX_ENABLED
                    ):
                        # check if the base branch is passing
                        commits = repo.get_commits(sha=pr.base.ref)
                        latest_commit: Commit = commits[0]
                        if all(
                            status != "failure"
                            for status in [
                                status.state for status in latest_commit.get_statuses()
                            ]
                        ):  # base branch is passing
                            logs = download_logs(
                                request.repository.full_name,
                                request.check_run.run_id,
                                request.installation.id,
                            )
                            logs, user_message = clean_gh_logs(logs)
                            attributor = request.sender.login
                            if attributor.endswith("[bot]"):
                                # `commit` is only bound in the branch above (PRs older
                                # than an hour), so look up the PR head commit here
                                attributor = repo.get_commit(pr.head.sha).author.login
                            if attributor.endswith("[bot]"):
                                attributor = pr.assignee.login
                            if attributor.endswith("[bot]"):
                                return {
                                    "success": False,
                                    "error_message": "The PR was created by a bot, so I won't attempt to fix it.",
                                }
                            tracking_id = get_hash()
                            chat_logger = ChatLogger(
                                data={
                                    "username": attributor,
                                    "title": "[Sweep GHA Fix] Fix the failing GitHub Actions",
                                }
                            )
                            if chat_logger.use_faster_model() and not IS_SELF_HOSTED:
                                return {
                                    "success": False,
                                    "error_message": "Disabled for free users",
                                }
                            stack_pr(
                                request=f"[Sweep GHA Fix] The GitHub Actions run failed on {request.check_run.head_sha[:7]} ({repo.default_branch}) with the following error logs:\n\n```\n\n{logs}\n\n```",
                                pr_number=pr.number,
                                username=attributor,
                                repo_full_name=repo.full_name,
                                installation_id=request.installation.id,
                                tracking_id=tracking_id,
                                commit_hash=pr.head.sha,
                            )
                elif (
                    request.check_run.check_suite.head_branch == repo.default_branch
                    and get_gha_enabled(repo)
                    and GHA_AUTOFIX_ENABLED
                ):
                    if request.check_run.conclusion == "failure":
                        commit = repo.get_commit(request.check_run.head_sha)
                        attributor = request.sender.login
                        if attributor.endswith("[bot]"):
                            attributor = commit.author.login
                        if attributor.endswith("[bot]"):
                            return {
                                "success": False,
                                "error_message": "The PR was created by a bot, so I won't attempt to fix it.",
                            }
                        logs = download_logs(
                            request.repository.full_name,
                            request.check_run.run_id,
                            request.installation.id,
                        )
                        logs, user_message = clean_gh_logs(logs)
                        chat_logger = ChatLogger(
                            data={
                                "username": attributor,
                                "title": "[Sweep GHA Fix] Fix the failing GitHub Actions",
                            }
                        )
                        if chat_logger.use_faster_model() and not IS_SELF_HOSTED:
                            return {
                                "success": False,
                                "error_message": "Disabled for free users",
                            }
                        make_pr(
                            title=f"[Sweep GHA Fix] Fix the failing GitHub Actions on {request.check_run.head_sha[:7]} ({repo.default_branch})",
                            repo_description=repo.description,
                            summary=f"The GitHub Actions run failed with the following error logs:\n\n```\n{logs}\n```",
                            repo_full_name=request_dict["repository"]["full_name"],
                            installation_id=request_dict["installation"]["id"],
                            user_token=None,
                            use_faster_model=chat_logger.use_faster_model(),
                            username=attributor,
                            chat_logger=chat_logger,
                        )
            case "pull_request", "opened":
                _, g = get_github_client(request_dict["installation"]["id"])
                repo = g.get_repo(request_dict["repository"]["full_name"])
                pr = repo.get_pull(request_dict["pull_request"]["number"])
                # if the pr already has a comment from sweep bot do nothing
                time.sleep(10)
                if any(
                    comment.user.login == GITHUB_BOT_USERNAME
                    for comment in pr.get_issue_comments()
                ) or pr.title.startswith("Sweep:"):
                    return {
                        "success": True,
                        "reason": "PR already has a comment from sweep bot",
                    }
                rule_buttons = []
                repo_rules = get_rules(repo) or []
                if repo_rules != [""] and repo_rules != []:
                    for rule in repo_rules or []:
                        if rule:
                            rule_buttons.append(Button(label=f"{RULES_LABEL} {rule}"))
                    if len(repo_rules) == 0:
                        for rule in DEFAULT_RULES:
                            rule_buttons.append(Button(label=f"{RULES_LABEL} {rule}"))
                    if rule_buttons:
                        rules_buttons_list = ButtonList(
                            buttons=rule_buttons, title=RULES_TITLE
                        )
                        pr.create_issue_comment(rules_buttons_list.serialize() + BOT_SUFFIX)
                if pr.mergeable is False and MERGE_CONFLICT_ENABLED:
                    attributor = pr.user.login
                    if attributor.endswith("[bot]"):
                        attributor = pr.assignee.login
                    if attributor.endswith("[bot]"):
                        return {
                            "success": False,
                            "error_message": "The PR was created by a bot, so I won't attempt to fix it.",
                        }
                    chat_logger = ChatLogger(
                        data={
                            "username": attributor,
                            "title": "[Sweep GHA Fix] Fix the failing GitHub Actions",
                        }
                    )
                    if chat_logger.use_faster_model() and not IS_SELF_HOSTED:
                        return {
                            "success": False,
                            "error_message": "Disabled for free users",
                        }
                    on_merge_conflict(
                        pr_number=pr.number,
                        username=attributor,
                        repo_full_name=request_dict["repository"]["full_name"],
                        installation_id=request_dict["installation"]["id"],
                        tracking_id=get_hash(),
                    )
            case "issues", "opened":
                request = IssueRequest(**request_dict)
                issue_title_lower = request.issue.title.lower()
                if (
                    issue_title_lower.startswith("sweep")
                    or "sweep:" in issue_title_lower
                ):
                    _, g = get_github_client(request.installation.id)
                    repo = g.get_repo(request.repository.full_name)
                    labels = repo.get_labels()
                    label_names = [label.name for label in labels]
                    if GITHUB_LABEL_NAME not in label_names:
                        repo.create_label(
                            name=GITHUB_LABEL_NAME,
                            color=GITHUB_LABEL_COLOR,
                            description=GITHUB_LABEL_DESCRIPTION,
                        )
                    current_issue = repo.get_issue(number=request.issue.number)
                    current_issue.add_to_labels(GITHUB_LABEL_NAME)
            case "issue_comment", "edited":
                request = IssueCommentRequest(**request_dict)
                sweep_labeled_issue = GITHUB_LABEL_NAME in [
                    label.name.lower() for label in request.issue.labels
                ]
                button_title_match = check_button_title_match(
                    REVERT_CHANGED_FILES_TITLE,
                    request.comment.body,
                    request.changes,
                ) or check_button_title_match(
                    RULES_TITLE,
                    request.comment.body,
                    request.changes,
                )
                if (
                    request.comment.user.type == "Bot"
                    and GITHUB_BOT_USERNAME in request.comment.user.login
                    and request.changes.body_from is not None
                    and button_title_match
                    and request.sender.type == "User"
                    and request.comment.user.login not in BLACKLISTED_USERS
                ):
                    run_on_button_click(request_dict)
                restart_sweep = False
                if (
                    request.comment.user.type == "Bot"
                    and GITHUB_BOT_USERNAME in request.comment.user.login
                    and request.changes.body_from is not None
                    and check_button_activated(
                        RESTART_SWEEP_BUTTON,
                        request.comment.body,
                        request.changes,
                    )
                    and sweep_labeled_issue
                    and request.sender.type == "User"
                    and request.comment.user.login not in BLACKLISTED_USERS
                ):
                    # Restart Sweep on this issue
                    restart_sweep = True
                if (
                    request.issue is not None
                    and sweep_labeled_issue
                    and request.comment.user.type == "User"
                    and request.comment.user.login not in BLACKLISTED_USERS
                    and not request.comment.user.login.startswith("sweep")
                    and not (
                        request.issue.pull_request and request.issue.pull_request.url
                    )
                    or restart_sweep
                ):
                    logger.info("New issue comment edited")
                    request.issue.body = request.issue.body or ""
                    request.repository.description = (
                        request.repository.description or ""
                    )
                    if (
                        not request.comment.body.strip()
                        .lower()
                        .startswith(GITHUB_LABEL_NAME)
                        and not restart_sweep
                    ):
                        logger.info("Comment does not start with 'Sweep', passing")
                        return {
                            "success": True,
                            "reason": "Comment does not start with 'Sweep', passing",
                        }
                    call_on_ticket(
                        title=request.issue.title,
                        summary=request.issue.body,
                        issue_number=request.issue.number,
                        issue_url=request.issue.html_url,
                        username=request.issue.user.login,
                        repo_full_name=request.repository.full_name,
                        repo_description=request.repository.description,
                        installation_id=request.installation.id,
                        comment_id=request.comment.id if not restart_sweep else None,
                        edited=True,
                    )
                elif (
                    request.issue.pull_request
                    and request.comment.user.type == "User"
                    and request.comment.user.login not in BLACKLISTED_USERS
                ):  # TODO(sweep): set a limit
                    logger.info(f"Handling comment on PR: {request.issue.pull_request}")
                    _, g = get_github_client(request.installation.id)
                    repo = g.get_repo(request.repository.full_name)
                    pr = repo.get_pull(request.issue.number)
                    labels = pr.get_labels()
                    comment = request.comment.body
                    if (
                        comment.lower().startswith("sweep:")
                        or any(label.name.lower() == "sweep" for label in labels)
                    ) and BOT_SUFFIX not in comment:
                        pr_change_request = PRChangeRequest(
                            params={
                                "comment_type": "comment",
                                "repo_full_name": request.repository.full_name,
                                "repo_description": request.repository.description,
                                "comment": request.comment.body,
                                "pr_path": None,
                                "pr_line_position": None,
                                "username": request.comment.user.login,
                                "installation_id": request.installation.id,
                                "pr_number": request.issue.number,
                                "comment_id": request.comment.id,
                            },
                        )
                        call_on_comment(**pr_change_request.params)
            case "issues", "edited":
                request = IssueRequest(**request_dict)
                if (
                    GITHUB_LABEL_NAME
                    in [label.name.lower() for label in request.issue.labels]
                    and request.sender.type == "User"
                    # an "issues" event carries no comment, so check the sender
                    and request.sender.login not in BLACKLISTED_USERS
                    and not request.sender.login.startswith("sweep")
                ):
                    logger.info("New issue edited")
                    call_on_ticket(
                        title=request.issue.title,
                        summary=request.issue.body,
                        issue_number=request.issue.number,
                        issue_url=request.issue.html_url,
                        username=request.issue.user.login,
                        repo_full_name=request.repository.full_name,
                        repo_description=request.repository.description,
                        installation_id=request.installation.id,
                        comment_id=None,
                    )
                else:
                    logger.info("Issue edited, but not a sweep issue")
            case "issues", "labeled":
                request = IssueRequest(**request_dict)
                if (
                    any(
                        label.name.lower() == GITHUB_LABEL_NAME
                        for label in request.issue.labels
                    )
                    and not request.issue.pull_request
                ):
                    request.issue.body = request.issue.body or ""
                    request.repository.description = (
                        request.repository.description or ""
                    )
                    call_on_ticket(
                        title=request.issue.title,
                        summary=request.issue.body,
                        issue_number=request.issue.number,
                        issue_url=request.issue.html_url,
                        username=request.issue.user.login,
                        repo_full_name=request.repository.full_name,
                        repo_description=request.repository.description,
                        installation_id=request.installation.id,
                        comment_id=None,
                    )
            case "issue_comment", "created":
                request = IssueCommentRequest(**request_dict)
                if (
                    request.issue is not None
                    and GITHUB_LABEL_NAME
                    in [label.name.lower() for label in request.issue.labels]
                    and request.comment.user.type == "User"
                    and request.comment.user.login not in BLACKLISTED_USERS
                    and not (
                        request.issue.pull_request and request.issue.pull_request.url
                    )
                    and BOT_SUFFIX not in request.comment.body
                ):
                    request.issue.body = request.issue.body or ""
                    request.repository.description = (
                        request.repository.description or ""
                    )
                    if (
                        not request.comment.body.strip()
                        .lower()
                        .startswith(GITHUB_LABEL_NAME)
                    ):
                        logger.info("Comment does not start with 'Sweep', passing")
                        return {
                            "success": True,
                            "reason": "Comment does not start with 'Sweep', passing",
                        }
                    call_on_ticket(
                        title=request.issue.title,
                        summary=request.issue.body,
                        issue_number=request.issue.number,
                        issue_url=request.issue.html_url,
                        username=request.issue.user.login,
                        repo_full_name=request.repository.full_name,
                        repo_description=request.repository.description,
                        installation_id=request.installation.id,
                        comment_id=request.comment.id,
                    )
                elif (
                    request.issue.pull_request
                    and request.comment.user.type == "User"
                    and request.comment.user.login not in BLACKLISTED_USERS
                    and BOT_SUFFIX not in request.comment.body
                ):  # TODO(sweep): set a limit
                    _, g = get_github_client(request.installation.id)
                    repo = g.get_repo(request.repository.full_name)
                    pr = repo.get_pull(request.issue.number)
                    labels = pr.get_labels()
                    comment = request.comment.body
                    # group the "or" explicitly, matching the "edited" handler
                    if (
                        comment.lower().startswith("sweep:")
                        or any(label.name.lower() == "sweep" for label in labels)
                    ) and BOT_SUFFIX not in comment:
                        pr_change_request = PRChangeRequest(
                            params={
                                "comment_type": "comment",
                                "repo_full_name": request.repository.full_name,
                                "repo_description": request.repository.description,
                                "comment": request.comment.body,
                                "pr_path": None,
                                "pr_line_position": None,
                                "username": request.comment.user.login,
                                "installation_id": request.installation.id,
                                "pr_number": request.issue.number,
                                "comment_id": request.comment.id,
                            },
                        )
                        call_on_comment(**pr_change_request.params)
            case "pull_request_review_comment", "created":
                request = CommentCreatedRequest(**request_dict)
                _, g = get_github_client(request.installation.id)
                repo = g.get_repo(request.repository.full_name)
                pr = repo.get_pull(request.pull_request.number)
                labels = pr.get_labels()
                comment = request.comment.body
                if (
                    (
                        comment.lower().startswith("sweep:")
                        or any(label.name.lower() == "sweep" for label in labels)
                    )
                    and request.comment.user.type == "User"
                    and request.comment.user.login not in BLACKLISTED_USERS
                    and BOT_SUFFIX not in comment
                ):
                    pr_change_request = PRChangeRequest(
                        params={
                            "comment_type": "comment",
                            "repo_full_name": request.repository.full_name,
                            "repo_description": request.repository.description,
                            "comment": request.comment.body,
                            "pr_path": request.comment.path,
                            "pr_line_position": request.comment.original_line,
                            "username": request.comment.user.login,
                            "installation_id": request.installation.id,
                            "pr_number": request.pull_request.number,
                            "comment_id": request.comment.id,
                        },
                    )
                    call_on_comment(**pr_change_request.params)
            case "pull_request_review_comment", "edited":
                request = CommentCreatedRequest(**request_dict)
                _, g = get_github_client(request.installation.id)
                repo = g.get_repo(request.repository.full_name)
                pr = repo.get_pull(request.pull_request.number)
                labels = pr.get_labels()
                comment = request.comment.body
                if (
                    (
                        comment.lower().startswith("sweep:")
                        or any(label.name.lower() == "sweep" for label in labels)
                    )
                    and request.comment.user.type == "User"
                    and request.comment.user.login not in BLACKLISTED_USERS
                    and BOT_SUFFIX not in comment
                ):
                    pr_change_request = PRChangeRequest(
                        params={
                            "comment_type": "comment",
                            "repo_full_name": request.repository.full_name,
                            "repo_description": request.repository.description,
                            "comment": request.comment.body,
                            "pr_path": request.comment.path,
                            "pr_line_position": request.comment.original_line,
                            "username": request.comment.user.login,
                            "installation_id": request.installation.id,
                            "pr_number": request.pull_request.number,
                            "comment_id": request.comment.id,
                        },
                    )
                    call_on_comment(**pr_change_request.params)
            case "installation_repositories", "added":
                repos_added_request = ReposAddedRequest(**request_dict)
                metadata = {
                    "installation_id": repos_added_request.installation.id,
                    "repositories": [
                        repo.full_name
                        for repo in repos_added_request.repositories_added
                    ],
                }
                try:
                    add_config_to_top_repos(
                        repos_added_request.installation.id,
                        repos_added_request.installation.account.login,
                        repos_added_request.repositories_added,
                    )
                except Exception as e:
                    logger.exception(f"Failed to add config to top repos: {e}")
                posthog.capture(
                    "installation_repositories",
                    "started",
                    properties={**metadata},
                )
                for repo in repos_added_request.repositories_added:
                    organization, repo_name = repo.full_name.split("/")
                    posthog.capture(
                        organization,
                        "installed_repository",
                        properties={
                            "repo_name": repo_name,
                            "organization": organization,
                            "repo_full_name": repo.full_name,
                        },
                    )
            case "installation", "created":
                repos_added_request = InstallationCreatedRequest(**request_dict)
                try:
                    add_config_to_top_repos(
                        repos_added_request.installation.id,
                        repos_added_request.installation.account.login,
                        repos_added_request.repositories,
                    )
                except Exception as e:
                    logger.exception(f"Failed to add config to top repos: {e}")
            case "pull_request", "edited":
                request = PREdited(**request_dict)
                if (
                    request.pull_request.user.login == GITHUB_BOT_USERNAME
                    and not request.sender.login.endswith("[bot]")
                    and DISCORD_FEEDBACK_WEBHOOK_URL is not None
                ):
                    good_button = check_button_activated(
                        SWEEP_GOOD_FEEDBACK,
                        request.pull_request.body,
                        request.changes,
                    )
                    bad_button = check_button_activated(
                        SWEEP_BAD_FEEDBACK,
                        request.pull_request.body,
                        request.changes,
                    )
                    if good_button or bad_button:
                        emoji = "😕"
                        if good_button:
                            emoji = "👍"
                        elif bad_button:
                            emoji = "👎"
                        data = {
                            "content": f"{emoji} {request.pull_request.html_url} ({request.sender.login})\n{request.pull_request.commits} commits, {request.pull_request.changed_files} files: +{request.pull_request.additions}, -{request.pull_request.deletions}"
                        }
                        headers = {"Content-Type": "application/json"}
                        requests.post(
                            DISCORD_FEEDBACK_WEBHOOK_URL,
                            data=json.dumps(data),
                            headers=headers,
                        )
                        # Send feedback to PostHog
                        posthog.capture(
                            request.sender.login,
                            "feedback",
                            properties={
                                "repo_name": request.repository.full_name,
                                "pr_url": request.pull_request.html_url,
                                "pr_commits": request.pull_request.commits,
                                "pr_additions": request.pull_request.additions,
                                "pr_deletions": request.pull_request.deletions,
                                "pr_changed_files": request.pull_request.changed_files,
                                "username": request.sender.login,
                                "good_button": good_button,
                                "bad_button": bad_button,
                            },
                        )

                        def remove_buttons_from_description(body):
                            """
                            Replace:
                            ### PR Feedback...
                            ...
                            # (until it hits the next #)
                            with
                            ### PR Feedback: {emoji}
                            #
                            """
                            lines = body.split("\n")
                            if not lines[0].startswith("### PR Feedback"):
                                return None
                            # Find when the second # occurs
                            i = 0
                            for i, line in enumerate(lines):
                                if line.startswith("#") and i > 0:
                                    break
                            return "\n".join(
                                [
                                    f"### PR Feedback: {emoji}",
                                    *lines[i:],
                                ]
                            )

                        # Update PR description to remove buttons
                        try:
                            _, g = get_github_client(request.installation.id)
                            repo = g.get_repo(request.repository.full_name)
                            pr = repo.get_pull(request.pull_request.number)
                            new_body = remove_buttons_from_description(
                                request.pull_request.body
                            )
                            if new_body is not None:
                                pr.edit(body=new_body)
                        except SystemExit:
                            # re-raise the original exception so its exit code is kept
                            raise
                        except Exception as e:
                            logger.exception(f"Failed to edit PR description: {e}")
            case "pull_request", "closed":
                pr_request = PRRequest(**request_dict)
                (
                    organization,
                    repo_name,
                ) = pr_request.repository.full_name.split("/")
                commit_author = pr_request.pull_request.user.login
                merged_by = (
                    pr_request.pull_request.merged_by.login
                    if pr_request.pull_request.merged_by
                    else None
                )
                if CURRENT_USERNAME == commit_author and merged_by is not None:
                    event_name = "merged_sweep_pr"
                    if pr_request.pull_request.title.startswith("[config]"):
                        event_name = "config_pr_merged"
                    elif pr_request.pull_request.title.startswith("[Sweep Rules]"):
                        event_name = "sweep_rules_pr_merged"
                    edited_by_developers = False
                    _token, g = get_github_client(pr_request.installation.id)
                    pr = g.get_repo(pr_request.repository.full_name).get_pull(
                        pr_request.number
                    )
                    total_lines_in_commit = 0
                    total_lines_edited_by_developer = 0
                    edited_by_developers = False
                    for commit in pr.get_commits():
                        lines_modified = commit.stats.additions + commit.stats.deletions
                        total_lines_in_commit += lines_modified
                        if commit.author.login != CURRENT_USERNAME:
                            total_lines_edited_by_developer += lines_modified
                    # this was edited by a developer if at least 25% of the lines were edited by a developer
                    edited_by_developers = total_lines_in_commit > 0 and (
                        total_lines_edited_by_developer / total_lines_in_commit
                    ) >= 0.25
                    posthog.capture(
                        merged_by,
                        event_name,
                        properties={
                            "repo_name": repo_name,
                            "organization": organization,
                            "repo_full_name": pr_request.repository.full_name,
                            "username": merged_by,
                            "additions": pr_request.pull_request.additions,
                            "deletions": pr_request.pull_request.deletions,
                            "total_changes": pr_request.pull_request.additions
                            + pr_request.pull_request.deletions,
                            "edited_by_developers": edited_by_developers,
                            "total_lines_in_commit": total_lines_in_commit,
                            "total_lines_edited_by_developer": total_lines_edited_by_developer,
                        },
                    )
                chat_logger = ChatLogger({"username": merged_by})
            case "push", None:
                if event != "pull_request" or request_dict["base"]["merged"] is True:
                    chat_logger = ChatLogger(
                        {"username": request_dict["pusher"]["name"]}
                    )
                    # on merge
                    call_on_merge(request_dict, chat_logger)
                    ref = request_dict["ref"] if "ref" in request_dict else ""
                    if ref.startswith("refs/heads") and not ref.startswith(
                        "ref/heads/sweep"
                    ):
                        _, g = get_github_client(request_dict["installation"]["id"])
                        repo = g.get_repo(request_dict["repository"]["full_name"])
                        if ref[len("refs/heads/") :] == SweepConfig.get_branch(repo):
                            update_sweep_prs_v2(
                                request_dict["repository"]["full_name"],
                                installation_id=request_dict["installation"]["id"],
                            )
                        if ref.startswith("refs/heads"):
                            branch_name = ref[len("refs/heads/") :]
                            # Check if the branch has an associated PR
                            org_name, repo_name = request_dict["repository"][
                                "full_name"
                            ].split("/")
                            pulls = repo.get_pulls(
                                state="open",
                                sort="created",
                                head=org_name + ":" + branch_name,
                            )
                            for pr in pulls:
                                logger.info(
                                    f"PR associated with branch {branch_name}: #{pr.number} - {pr.title}"
                                )
                                if pr.mergeable is False and MERGE_CONFLICT_ENABLED:
                                    attributor = pr.user.login
                                    if attributor.endswith("[bot]"):
                                        attributor = pr.assignee.login
                                    if attributor.endswith("[bot]"):
                                        return {
                                            "success": False,
                                            "error_message": "The PR was created by a bot, so I won't attempt to fix it.",
                                        }
                                    chat_logger = ChatLogger(
                                        data={
                                            "username": attributor,
                                            "title": "[Sweep GHA Fix] Fix the failing GitHub Actions",
                                        }
                                    )
                                    if (
                                        chat_logger.use_faster_model()
                                        and not IS_SELF_HOSTED
                                    ):
                                        return {
                                            "success": False,
                                            "error_message": "Disabled for free users",
                                        }
                                    on_merge_conflict(
                                        pr_number=pr.number,
                                        username=pr.user.login,
                                        repo_full_name=request_dict["repository"][
                                            "full_name"
                                        ],
                                        installation_id=request_dict["installation"]["id"],
                                        tracking_id=get_hash(),
                                    )
            case "ping", None:
                return {"message": "pong"}
            case _:

import base64
import os
from dotenv import load_dotenv
from loguru import logger
logger.print = logger.info
load_dotenv(dotenv_path=".env", override=True, verbose=True)
os.environ["GITHUB_APP_PEM"] = os.environ.get("GITHUB_APP_PEM") or base64.b64decode(
    os.environ.get("GITHUB_APP_PEM_BASE64", "")
).decode("utf-8")

if os.environ["GITHUB_APP_PEM"]:
    os.environ["GITHUB_APP_ID"] = (
        (os.environ.get("GITHUB_APP_ID") or os.environ.get("APP_ID"))
        .replace("\\n", "\n")
        .strip('"')
    )

os.environ["TRANSFORMERS_CACHE"] = os.environ.get(
    "TRANSFORMERS_CACHE", "/tmp/cache/model"
)  # vector_db.py
os.environ["TIKTOKEN_CACHE_DIR"] = os.environ.get(
    "TIKTOKEN_CACHE_DIR", "/tmp/cache/tiktoken"
)  # utils.py

SENTENCE_TRANSFORMERS_MODEL = os.environ.get(
    "SENTENCE_TRANSFORMERS_MODEL",
    "sentence-transformers/all-MiniLM-L6-v2",  # "all-mpnet-base-v2"
)
TEST_BOT_NAME = "sweep-nightly[bot]"
ENV = os.environ.get("ENV", "dev")
# ENV = os.environ.get("MODAL_ENVIRONMENT", "dev")
# ENV = PREFIX
# ENVIRONMENT = PREFIX
DB_MODAL_INST_NAME = "db"
DOCS_MODAL_INST_NAME = "docs"
API_MODAL_INST_NAME = "api"
UTILS_MODAL_INST_NAME = "utils"
BOT_TOKEN_NAME = "bot-token"
# goes under Modal 'discord' secret name (optional, can leave env var blank)
DISCORD_WEBHOOK_URL = os.environ.get("DISCORD_WEBHOOK_URL")
DISCORD_MEDIUM_PRIORITY_URL = os.environ.get("DISCORD_MEDIUM_PRIORITY_URL")
DISCORD_LOW_PRIORITY_URL = os.environ.get("DISCORD_LOW_PRIORITY_URL")
DISCORD_FEEDBACK_WEBHOOK_URL = os.environ.get("DISCORD_FEEDBACK_WEBHOOK_URL")
SWEEP_HEALTH_URL = os.environ.get("SWEEP_HEALTH_URL")
DISCORD_STATUS_WEBHOOK_URL = os.environ.get("DISCORD_STATUS_WEBHOOK_URL")
# goes under Modal 'github' secret name
GITHUB_APP_ID = os.environ.get("GITHUB_APP_ID", os.environ.get("APP_ID"))
# deprecated: hard-coded fallback IDs kept from the old logic so upstream can still use them
if GITHUB_APP_ID is None:
    if ENV == "prod":
        GITHUB_APP_ID = "307814"
    elif ENV == "dev":
        GITHUB_APP_ID = "324098"
    elif ENV == "staging":
        GITHUB_APP_ID = "327588"

GITHUB_BOT_USERNAME = os.environ.get("GITHUB_BOT_USERNAME")

# deprecated: left to support old logic
if not GITHUB_BOT_USERNAME:
    if ENV == "prod":
        GITHUB_BOT_USERNAME = "sweep-ai[bot]"
    elif ENV == "dev":
        GITHUB_BOT_USERNAME = "sweep-nightly[bot]"
    elif ENV == "staging":
        GITHUB_BOT_USERNAME = "sweep-canary[bot]"
elif not GITHUB_BOT_USERNAME.endswith("[bot]"):
    GITHUB_BOT_USERNAME = GITHUB_BOT_USERNAME + "[bot]"
GITHUB_LABEL_NAME = os.environ.get("GITHUB_LABEL_NAME", "sweep")
GITHUB_LABEL_COLOR = os.environ.get("GITHUB_LABEL_COLOR", "9400D3")
GITHUB_LABEL_DESCRIPTION = os.environ.get(
    "GITHUB_LABEL_DESCRIPTION", "Sweep your software chores"
)

GITHUB_APP_PEM = os.environ.get("GITHUB_APP_PEM")
GITHUB_APP_PEM = GITHUB_APP_PEM or os.environ.get("PRIVATE_KEY")
if GITHUB_APP_PEM is not None:
    GITHUB_APP_PEM = GITHUB_APP_PEM.strip(' \n"')  # Remove whitespace and quotes
    GITHUB_APP_PEM = GITHUB_APP_PEM.replace("\\n", "\n")
GITHUB_CONFIG_BRANCH = os.environ.get("GITHUB_CONFIG_BRANCH", "sweep/add-sweep-config")
GITHUB_DEFAULT_CONFIG = os.environ.get(
    "GITHUB_DEFAULT_CONFIG",
    """# Sweep AI turns bugs & feature requests into code changes (https://sweep.dev)
# For details on our config file, check out our docs at https://docs.sweep.dev/usage/config

# This setting contains a list of rules that Sweep will check for. If any of these rules are broken in a new commit, Sweep will create a pull request to fix the broken rule.
rules:
{additional_rules}

# This is the branch that Sweep will develop from and make pull requests to. Most people use 'main' or 'master' but some users also use 'dev' or 'staging'.
branch: 'main'

# By default Sweep will read the logs and outputs from your existing GitHub Actions. To disable this, set this to false.
gha_enabled: True

# This is the description of your project. It will be used by Sweep when creating PRs. You can tell Sweep what's unique about your project, what frameworks you use, or anything else you want.
#
# Example:
#
# description: sweepai/sweep is a python project. The main api endpoints are in sweepai/api.py. Write code that adheres to PEP8.
description: ''

# This sets whether to create pull requests as drafts. If this is set to True, then all pull requests will be created as drafts and GitHub Actions will not be triggered.
draft: False

# This is a list of directories that Sweep will not be able to edit.
blocked_dirs: []
""",
)
MONGODB_URI = os.environ.get("MONGODB_URI", None)
IS_SELF_HOSTED = os.environ.get("IS_SELF_HOSTED", "true").lower() == "true"
REDIS_URL = os.environ.get("REDIS_URL")
if not REDIS_URL:
    REDIS_URL = os.environ.get("redis_url", "redis://0.0.0.0:6379/0")
ORG_ID = os.environ.get("ORG_ID", None)
POSTHOG_API_KEY = os.environ.get(
    "POSTHOG_API_KEY", "phc_CnzwIB0W548wN4wEGeRuxXqidOlEUH2AcyV2sKTku8n"
)
E2B_API_KEY = os.environ.get("E2B_API_KEY")
SUPPORT_COUNTRY = os.environ.get("GDRP_LIST", "").split(",")
WHITELISTED_REPOS = os.environ.get("WHITELISTED_REPOS", "").split(",")
BLACKLISTED_USERS = os.environ.get("BLACKLISTED_USERS", "").split(",")
os.environ["TOKENIZERS_PARALLELISM"] = "false"
ACTIVELOOP_TOKEN = os.environ.get("ACTIVELOOP_TOKEN", None)
VECTOR_EMBEDDING_SOURCE = os.environ.get(
    "VECTOR_EMBEDDING_SOURCE", "openai"
)  # Options: "openai" (default), "huggingface" or "replicate"; set the matching env vars below
BASERUN_API_KEY = os.environ.get("BASERUN_API_KEY", None)
# Huggingface settings, only checked if VECTOR_EMBEDDING_SOURCE == "huggingface"
HUGGINGFACE_URL = os.environ.get("HUGGINGFACE_URL", None)
HUGGINGFACE_TOKEN = os.environ.get("HUGGINGFACE_TOKEN", None)
# Replicate settings, only checked if VECTOR_EMBEDDING_SOURCE == "replicate"
REPLICATE_API_KEY = os.environ.get("REPLICATE_API_KEY", None)
REPLICATE_URL = os.environ.get("REPLICATE_URL", None)
REPLICATE_DEPLOYMENT_URL = os.environ.get("REPLICATE_DEPLOYMENT_URL", None)
# Default OpenAI
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", None)
OPENAI_API_TYPE = os.environ.get("OPENAI_API_TYPE", "anthropic")
assert OPENAI_API_TYPE in ["anthropic", "azure", "openai"], "Invalid OPENAI_API_TYPE"
OPENAI_EMBEDDINGS_API_TYPE = os.environ.get("OPENAI_EMBEDDINGS_API_TYPE", "openai")
AZURE_API_KEY = os.environ.get("AZURE_API_KEY", None)
OPENAI_API_BASE = os.environ.get("OPENAI_API_BASE", None)
OPENAI_API_VERSION = os.environ.get("OPENAI_API_VERSION", None)
AZURE_OPENAI_DEPLOYMENT = os.environ.get("AZURE_OPENAI_DEPLOYMENT", None)
OPENAI_EMBEDDINGS_AZURE_ENDPOINT = os.environ.get(
    "OPENAI_EMBEDDINGS_AZURE_ENDPOINT", None
)
OPENAI_EMBEDDINGS_AZURE_API_KEY = os.environ.get(
    "OPENAI_EMBEDDINGS_AZURE_API_KEY", None
)
OPENAI_EMBEDDINGS_AZURE_DEPLOYMENT = os.environ.get(
    "OPENAI_EMBEDDINGS_AZURE_DEPLOYMENT", None
)
OPENAI_EMBEDDINGS_AZURE_API_VERSION = os.environ.get(
    "OPENAI_EMBEDDINGS_AZURE_API_VERSION", None
)
OPENAI_API_ENGINE_GPT35 = os.environ.get("OPENAI_API_ENGINE_GPT35", None)
OPENAI_API_ENGINE_GPT4 = os.environ.get("OPENAI_API_ENGINE_GPT4", None)
OPENAI_API_ENGINE_GPT4_32K = os.environ.get("OPENAI_API_ENGINE_GPT4_32K", None)
MULTI_REGION_CONFIG = os.environ.get("MULTI_REGION_CONFIG", None)
if isinstance(MULTI_REGION_CONFIG, str):
    MULTI_REGION_CONFIG = MULTI_REGION_CONFIG.strip("'").replace("\\n", "\n")
    MULTI_REGION_CONFIG = [item.split(",") for item in MULTI_REGION_CONFIG.split("\n")]
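
# Illustrative sketch (hypothetical helper, not used by Sweep): the parsing above
# expects MULTI_REGION_CONFIG to hold one comma-separated record per line,
# optionally wrapped in single quotes and using literal "\n" escapes as line
# breaks. The values in the example below are placeholders; this file does not
# say what each column means.
def parse_multi_region_config(raw: str) -> list[list[str]]:
    # Drop wrapping single quotes and turn escaped "\n" sequences into real newlines
    cleaned = raw.strip("'").replace("\\n", "\n")
    # One record per line, one field per comma-separated value
    return [item.split(",") for item in cleaned.split("\n")]


# Example: parse_multi_region_config("'east,key-a\\nwest,key-b'")
# returns [["east", "key-a"], ["west", "key-b"]]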
WHITELISTED_USERS = os.environ.get("WHITELISTED_USERS", None)
if WHITELISTED_USERS:
    WHITELISTED_USERS = WHITELISTED_USERS.split(",")
    WHITELISTED_USERS.append(GITHUB_BOT_USERNAME)
DEFAULT_GPT4_32K_MODEL = os.environ.get("DEFAULT_GPT4_32K_MODEL", "gpt-4-0125-preview")
DEFAULT_GPT35_MODEL = os.environ.get("DEFAULT_GPT35_MODEL", "gpt-3.5-turbo-1106")
RESEND_API_KEY = os.environ.get("RESEND_API_KEY", None)
LOKI_URL = None
DEBUG = os.environ.get("DEBUG", "false").lower() == "true"
ENV = "prod" if GITHUB_BOT_USERNAME != TEST_BOT_NAME else "dev"
PROGRESS_BASE_URL = os.environ.get(
    "PROGRESS_BASE_URL", "https://progress.sweep.dev"
).rstrip("/")
DISABLED_REPOS = os.environ.get("DISABLED_REPOS", "").split(",")
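# Environment values are strings, so compare explicitly instead of relying on truthiness
GHA_AUTOFIX_ENABLED: bool = (
    os.environ.get("GHA_AUTOFIX_ENABLED", "false").lower() == "true"
)
MERGE_CONFLICT_ENABLED: bool = (
    os.environ.get("MERGE_CONFLICT_ENABLED", "false").lower() == "true"
)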
INSTALLATION_ID = os.environ.get("INSTALLATION_ID", None)
AWS_ACCESS_KEY = os.environ.get("AWS_ACCESS_KEY")
AWS_SECRET_KEY = os.environ.get("AWS_SECRET_KEY")
AWS_REGION = os.environ.get("AWS_REGION")
ANTHROPIC_AVAILABLE = AWS_ACCESS_KEY and AWS_SECRET_KEY and AWS_REGION
USE_ASSISTANT = os.environ.get("USE_ASSISTANT", "true").lower() == "true"
ANTHROPIC_API_KEY = os.environ.get("ANTHROPIC_API_KEY", None)
VOYAGE_API_KEY = os.environ.get("VOYAGE_API_KEY", None)
VOYAGE_API_AWS_ACCESS_KEY = os.environ.get("VOYAGE_API_AWS_ACCESS_KEY_ID")
VOYAGE_API_AWS_SECRET_KEY = os.environ.get("VOYAGE_API_AWS_SECRET_KEY")
VOYAGE_API_AWS_REGION = os.environ.get("VOYAGE_API_AWS_REGION")
VOYAGE_API_AWS_ENDPOINT_NAME = os.environ.get(
    "VOYAGE_API_AWS_ENDPOINT_NAME", "voyage-code-2"
)
VOYAGE_API_USE_AWS = (
    VOYAGE_API_AWS_ACCESS_KEY and VOYAGE_API_AWS_SECRET_KEY and VOYAGE_API_AWS_REGION
)
PAREA_API_KEY = os.environ.get("PAREA_API_KEY", None)
# TODO: we need to make this dynamic + backoff
BATCH_SIZE = int(
    # Voyage only allows 128 items per batch and 120000 tokens per batch
    os.environ.get("BATCH_SIZE", 32 if VOYAGE_API_KEY else 256)
)
DEPLOYMENT_GHA_ENABLED = os.environ.get("DEPLOYMENT_GHA_ENABLED", "true").lower() == "true"
JIRA_USER_NAME = os.environ.get("JIRA_USER_NAME", None)
JIRA_API_TOKEN = os.environ.get("JIRA_API_TOKEN", None)
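
# Illustrative sketch (hypothetical helper, not used by Sweep): environment
# variables are always strings, so a pattern like os.environ.get("FLAG", False)
# yields the truthy string "false" when FLAG=false is set. Comparing against
# "true", as IS_SELF_HOSTED and DEBUG do above, avoids that pitfall. The
# `environ` parameter is an assumption added only so the helper is easy to test.
import os
from collections.abc import Mapping


def env_flag(
    name: str, default: str = "false", environ: Mapping[str, str] = os.environ
) -> bool:
    # Only the case-insensitive string "true" (surrounding spaces ignored) enables the flag
    return environ.get(name, default).strip().lower() == "true"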

import re
from dataclasses import dataclass
from functools import lru_cache
from rapidfuzz import fuzz
from tqdm import tqdm
from sweepai.logn import file_cache
from loguru import logger


@lru_cache()
def score_line(str1: str, str2: str) -> float:
    if str1 == str2:
        return 100

    if str1.lstrip() == str2.lstrip():
        whitespace_ratio = abs(len(str1) - len(str2)) / (len(str1) + len(str2))
        score = 90 - whitespace_ratio * 10
        return max(score, 0)

    if str1.strip() == str2.strip():
        whitespace_ratio = abs(len(str1) - len(str2)) / (len(str1) + len(str2))
        score = 80 - whitespace_ratio * 10
        return max(score, 0)

    levenshtein_ratio = fuzz.ratio(str1, str2)

    score = 85 * (levenshtein_ratio / 100)
    return max(score, 0)
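

# Illustrative sketch (hypothetical helper, not used by Sweep): both whitespace
# tiers in score_line apply the same penalty, ten times the relative length
# difference, to a base score of 90 (only leading whitespace differs) or 80
# (trailing whitespace differs as well). For "    return x" against "return x"
# the lengths are 12 and 8, so the penalty is 10 * 4 / 20 = 2 and the score is 88.
def whitespace_tier_score(base: float, str1: str, str2: str) -> float:
    # Relative length difference; assumes at least one string is non-empty
    whitespace_ratio = abs(len(str1) - len(str2)) / (len(str1) + len(str2))
    return max(base - whitespace_ratio * 10, 0)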


def match_without_whitespace(str1: str, str2: str) -> bool:
    return str1.strip() == str2.strip()


def line_cost(line: str) -> float:
    if line.strip() == "":
        return 50
    if line.strip().startswith("#") or line.strip().startswith("//"):
        return 50 + len(line) / (len(line) + 1) * 30
    return len(line) / (len(line) + 1) * 100


def score_multiline(query: list[str], target: list[str]) -> float:
    # TODO: add weighting on first and last lines
    q, t = 0, 0  # indices for query and target
    scores: list[tuple[float, float]] = []
    skipped_comments = 0

    def get_weight(q: int) -> float:
        # Prefers lines at beginning and end of query
        # Sequence: 1, 2/3, 1/2, 2/5...
        index = min(q, len(query) - q)
        return 100 / (index / 2 + 1)

    while q < len(query) and t < len(target):
        q_line = query[q]
        t_line = target[t]
        weight = get_weight(q)

        if match_without_whitespace(q_line, t_line):
            # Case 1: lines match
            scores.append((score_line(q_line, t_line), weight))
            q += 1
            t += 1
        elif q_line.strip().startswith("...") or q_line.strip().endswith("..."):
            # Case 3: ellipsis wildcard
            t += 1
            if q + 1 == len(query):
                scores.append((100 - (len(target) - t), weight))
                q += 1
                t = len(target)
                break
            max_score = 0
            # Radix optimization
            indices = [
                t + i
                for i, line in enumerate(target[t:])
                if match_without_whitespace(line, query[q + 1])
            ]
            if not indices:
                # logger.warning(f"Could not find whitespace match, using brute force")
                indices = range(t, len(target))
            for i in indices:
                score, weight = score_multiline(query[q + 1 :], target[i:]), (
                    100 - (i - t) / len(target) * 10
                )
                new_scores = scores + [(score, weight)]
                total_score = sum(
                    [value * weight for value, weight in new_scores]
                ) / sum([weight for _, weight in new_scores])
                max_score = max(max_score, total_score)
            return max_score
        elif (
            t_line.strip() == ""
            or t_line.strip().startswith("#")
            or t_line.strip().startswith("//")
            or t_line.strip().startswith("print")
            or t_line.strip().startswith("logger")
            or t_line.strip().startswith("console.")
        ):
            # Case 2: skipped comment
            skipped_comments += 1
            t += 1
            scores.append((90, weight))
        else:
            break

    if q < len(query):
        scores.extend(
            (100 - line_cost(line), get_weight(index))
            for index, line in enumerate(query[q:])
        )
    if t < len(target):
        scores.extend(
            (100 - line_cost(line), 100) for index, line in enumerate(target[t:])
        )

    final_score = (
        sum([value * weight for value, weight in scores])
        / sum([weight for _, weight in scores])
        if scores
        else 0
    )
    final_score *= 1 - 0.05 * skipped_comments

    return final_score
@dataclass
class Match:
    start: int
    end: int
    score: float
    indent: str = ""
    def __gt__(self, other):
        return self.score > other.score
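def get_indent_type(content: str):
    # Vote between two-space and four-space indentation by counting lines
    # indented by exactly that many spaces; ties default to four spaces.
    two_spaces = len(re.findall(r"\n {2}[^ ]", content))
    four_spaces = len(re.findall(r"\n {4}[^ ]", content))
    return "  " if two_spaces > four_spaces else "    "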
def get_max_indent(content: str, indent_type: str):
    return max(len(line) - len(line.lstrip()) for line in content.split("\n")) // len(
        indent_type
    )
@file_cache()
def find_best_match(query: str, code_file: str):
    best_match = Match(-1, -1, 0)
    code_file_lines = code_file.split("\n")
    query_lines = query.split("\n")
    if len(query_lines) > 0 and query_lines[-1].strip() == "...":
        query_lines = query_lines[:-1]
    if len(query_lines) > 0 and query_lines[0].strip() == "...":
        query_lines = query_lines[1:]
    indent = get_indent_type(code_file)
    max_indents = get_max_indent(code_file, indent)
    top_matches = []
    if len(query_lines) == 1:
        for i, line in enumerate(code_file_lines):
            score = score_line(line, query_lines[0])
            if score > best_match.score:
                best_match = Match(i, i + 1, score)
        return best_match
    truncate = min(40, len(code_file_lines) // 5)
    if truncate < 1:
        truncate = len(code_file_lines)
    indent_array = [i for i in range(0, max(min(max_indents + 1, 20), 1))]
    if max_indents > 3:
        indent_array = [3, 2, 4, 0, 1] + list(range(5, max_indents + 1))
    for num_indents in indent_array:
        indented_query_lines = [indent * num_indents + line for line in query_lines]
        start_pairs = [
            (i, score_line(line, indented_query_lines[0]))
            for i, line in enumerate(code_file_lines)
        ]
        start_pairs.sort(key=lambda x: x[1], reverse=True)
        start_pairs = start_pairs[:truncate]
        start_indices = [i for i, _ in start_pairs]
        for i in tqdm(
            start_indices,
            position=0,
            desc=f"Indent {num_indents}/{max_indents}",
            leave=False,
        ):
            end_pairs = [
                (j, score_line(line, indented_query_lines[-1]))
                for j, line in enumerate(code_file_lines[i:], start=i)
            ]
            end_pairs.sort(key=lambda x: x[1], reverse=True)
            end_pairs = end_pairs[:truncate]
            end_indices = [j for j, _ in end_pairs]
            for j in tqdm(
                end_indices, position=1, leave=False, desc=f"Starting line {i}"
            ):
                candidate = code_file_lines[i : j + 1]
                raw_score = score_multiline(indented_query_lines, candidate)
                score = raw_score * (1 - num_indents * 0.01)
                current_match = Match(i, j + 1, score, indent * num_indents)
                if raw_score >= 99.99:  # early exit, 99.99 for floating point error
                    logger.info(f"Exact match found! Returning: {current_match}")
                    return current_match
                top_matches.append(current_match)
                if score > best_match.score:
                    best_match = current_match
    unique_top_matches: list[Match] = []
    unique_spans = set()
    for top_match in sorted(top_matches, reverse=True):
        if (top_match.start, top_match.end) not in unique_spans:
            unique_top_matches.append(top_match)
            unique_spans.add((top_match.start, top_match.end))
    for top_match in unique_top_matches[:5]:
        logger.print(top_match)
    # Todo: on_comment file comments able to modify multiple files
    return unique_top_matches[0] if unique_top_matches else Match(-1, -1, 0)
def split_ellipses(query: str) -> list[str]:
    queries = []
    current_query = ""
    for line in query.split("\n"):
        if line.strip() == "...":
            queries.append(current_query.strip("\n"))
            current_query = ""
        else:
            current_query += line + "\n"
    queries.append(current_query.strip("\n"))
    return queries
def match_indent(generated: str, original: str) -> str:
    indent_type = "\t" if "\t" in original[:5] else " "
    generated_indents = len(generated) - len(generated.lstrip())
    target_indents = len(original) - len(original.lstrip())
    diff_indents = target_indents - generated_indents
    if diff_indents > 0:
        generated = indent_type * diff_indents + generated.replace(
            "\n", "\n" + indent_type * diff_indents
        )
    return generated
old_code = """
\"\"\"
on_ticket is the main function that is called when a new issue is created.
It is only called by the webhook handler in sweepai/api.py.
\"\"\"
# TODO: Add file validation
import math
import re
import traceback
from time import time
import openai
import requests
from github import BadCredentialsException
from logtail import LogtailHandler
from loguru import logger
from requests.exceptions import Timeout
from tabulate import tabulate
from tqdm import tqdm"""
new_code = """
\"\"\"
on_ticket is the main function that is called when a new issue is created.
It is only called by the webhook handler in sweepai/api.py.
\"\"\"
# TODO: Add file validation
import math
import re
import traceback
from time import time
import hashlib
import openai
import requests
from github import BadCredentialsException
from logtail import LogtailHandler
from loguru import logger
from requests.exceptions import Timeout
from tabulate import tabulate
from tqdm import tqdm"""
# print(match_indent(new_code, old_code))
test_code = """\
def naive_euclidean_profile(X, q, mask):
r\"\"\"
Compute a euclidean distance profile in a brute force way.
A distance profile between a (univariate) time series :math:`X_i = {x_1, ..., x_m}`
and a query :math:`Q = {q_1, ..., q_m}` is defined as a vector of size :math:`m-(
l-1)`, such as :math:`P(X_i, Q) = {d(C_1, Q), ..., d(C_m-(l-1), Q)}` with d the
Euclidean distance, and :math:`C_j = {x_j, ..., x_{j+(l-1)}}` the j-th candidate
subsequence of size :math:`l` in :math:`X_i`.
\"\"\"
return _naive_euclidean_profile(X, q, mask)
"""
if __name__ == "__main__":
# for section in split_ellipses(test_code):
# print(section)
code_file = r"""
from loguru import logger
from github.Repository import Repository
from sweepai.config.client import RESET_FILE, REVERT_CHANGED_FILES_TITLE, RULES_LABEL, RULES_TITLE, get_rules
from sweepai.utils.event_logger import posthog
from sweepai.core.post_merge import PostMerge
from sweepai.core.sweep_bot import SweepBot
from sweepai.events import IssueCommentRequest
from sweepai.handlers.on_merge import comparison_to_diff
from sweepai.handlers.pr_utils import make_pr
from sweepai.utils.buttons import ButtonList, check_button_title_match
from sweepai.utils.chat_logger import ChatLogger
from sweepai.utils.github_utils import get_github_client
def handle_button_click(request_dict):
request = IssueCommentRequest(**request_dict)
user_token, gh_client = get_github_client(request_dict["installation"]["id"])
button_list = ButtonList.deserialize(request_dict["comment"]["body"])
selected_buttons = [button.label for button in button_list.get_clicked_buttons()]
repo = gh_client.get_repo(request_dict["repository"]["full_name"]) # do this after checking ref
comment_id = request.comment.id
pr = repo.get_pull(request_dict["issue"]["number"])
comment = pr.get_issue_comment(comment_id)
if check_button_title_match(REVERT_CHANGED_FILES_TITLE, request.comment.body, request.changes):
revert_files = []
for button_text in selected_buttons:
revert_files.append(button_text.split(f"{RESET_FILE} ")[-1].strip())
handle_revert(revert_files, request_dict["issue"]["number"], repo)
comment.edit(
body=ButtonList(
buttons=[
button
for button in button_list.buttons
if button.label not in selected_buttons
],
title = REVERT_CHANGED_FILES_TITLE,
).serialize()
)
"""
# Sample target snippet
target = """
from loguru import logger
from github.Repository import Repository
from sweepai.config.client import RESET_FILE, REVERT_CHANGED_FILES_TITLE, RULES_LABEL, RULES_TITLE, get_rules
from sweepai.utils.event_logger import posthog
from sweepai.core.post_merge import PostMerge
from sweepai.core.sweep_bot import SweepBot
from sweepai.events import IssueCommentRequest
from sweepai.handlers.on_merge import comparison_to_diff
from sweepai.handlers.pr_utils import make_pr
from sweepai.utils.buttons import ButtonList, check_button_title_match
from sweepai.utils.chat_logger import ChatLogger
from sweepai.utils.github_utils import get_github_client
def handle_button_click(request_dict):
request = IssueCommentRequest(**request_dict)
user_token, gh_client = get_github_client(request_dict["installation"]["id"])
button_list = ButtonList.deserialize(request_dict["comment"]["body"])
selected_buttons = [button.label for button in button_list.get_clicked_buttons()]
repo = gh_client.get_repo(request_dict["repository"]["full_name"]) # do this after checking ref
comment_id = request.comment.id
pr = repo.get_pull(request_dict["issue"]["number"])
comment = pr.get_issue_comment(comment_id)
...
""".strip(
"\n"
)
# Find the best match
# best_span = find_best_match(target, code_file)
best_span = find_best_match("a\nb", "a\nb")
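The positional weighting inside `score_multiline` is easy to lose in the nested loop above. A minimal, self-contained sketch of just that part follows, assuming each query line already has a similarity score in the range 0 to 100 (the real code obtains it from `score_line`, which is not shown here). `positional_weight` and `weighted_average` are illustrative names, not functions from the codebase.

```python
def positional_weight(q: int, n: int) -> float:
    """Same formula as get_weight: lines near either end of an n-line query weigh more."""
    index = min(q, n - q)
    return 100 / (index / 2 + 1)


def weighted_average(scores: list[tuple[float, float]]) -> float:
    """Combine (score, weight) pairs the way score_multiline computes final_score."""
    total_weight = sum(weight for _, weight in scores)
    if total_weight == 0:
        return 0.0
    return sum(score * weight for score, weight in scores) / total_weight


# Hypothetical per-line similarity scores for a 6-line query; line 2 matches poorly.
line_scores = [100, 100, 40, 100, 100, 100]
pairs = [(s, positional_weight(q, len(line_scores))) for q, s in enumerate(line_scores)]

print([round(w, 1) for _, w in pairs])  # [100.0, 66.7, 50.0, 40.0, 50.0, 66.7]
print(round(weighted_average(pairs), 2))  # 91.96
```

Because the formula uses `min(q, len(query) - q)`, the last query line gets index 1 rather than 0, so it weighs 66.7 while the first line weighs 100; the source does not say whether that asymmetry is intended.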

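`find_best_match` retries the query at several indentation depths, using `get_indent_type` to guess the indent unit and `get_max_indent` to bound the search. The standalone sketch below reproduces those two heuristics under illustrative names (`detect_indent_unit` and `max_indent_level` are not part of the codebase) and shows how a four-space file and a two-space file are classified.

```python
import re


def detect_indent_unit(content: str) -> str:
    # Count lines indented by exactly two vs. exactly four spaces
    # (each regex requires a non-space character right after the indent).
    two_spaces = len(re.findall(r"\n {2}[^ ]", content))
    four_spaces = len(re.findall(r"\n {4}[^ ]", content))
    return "  " if two_spaces > four_spaces else "    "


def max_indent_level(content: str, indent_unit: str) -> int:
    # Deepest leading-whitespace run in the file, expressed in indent units.
    deepest = max(len(line) - len(line.lstrip()) for line in content.split("\n"))
    return deepest // len(indent_unit)


four_space_source = "def f(x):\n    if x:\n        return 1\n    return 0\n"
two_space_source = "def g():\n  return 0\n"

unit = detect_indent_unit(four_space_source)
print(repr(unit), max_indent_level(four_space_source, unit))  # '    ' 2
print(repr(detect_indent_unit(two_space_source)))  # '  '
```

A file with no indented lines ties at zero votes and falls back to four spaces, matching the `else` branch of the original heuristic.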
import hashlib
import time
from github.Repository import Repository
from loguru import logger
from sweepai.config.client import (
RESET_FILE,
REVERT_CHANGED_FILES_TITLE,
RULES_LABEL,
RULES_TITLE,
get_blocked_dirs,
)
from sweepai.config.server import MONGODB_URI
from sweepai.core.post_merge import PostMerge
from sweepai.handlers.on_merge import comparison_to_diff
from sweepai.handlers.stack_pr import stack_pr
from sweepai.utils.buttons import ButtonList, check_button_title_match
from sweepai.utils.chat_logger import ChatLogger
from sweepai.utils.event_logger import posthog
from sweepai.utils.github_utils import get_github_client
from sweepai.utils.str_utils import BOT_SUFFIX
from sweepai.web.events import IssueCommentRequest
def handle_button_click(request_dict):
    request = IssueCommentRequest(**request_dict)
    user_token, gh_client = get_github_client(request_dict["installation"]["id"])
    button_list = ButtonList.deserialize(request_dict["comment"]["body"])
    selected_buttons = [button.label for button in button_list.get_clicked_buttons()]
    repo = gh_client.get_repo(
        request_dict["repository"]["full_name"]
    )  # do this after checking ref
    comment_id = request.comment.id
    pr = repo.get_pull(request_dict["issue"]["number"])
    comment = pr.get_issue_comment(comment_id)
    if check_button_title_match(
        REVERT_CHANGED_FILES_TITLE, request.comment.body, request.changes
    ):
        revert_files = []
        for button_text in selected_buttons:
            revert_files.append(button_text.split(f"{RESET_FILE} ")[-1].strip())
        handle_revert(
            file_paths=revert_files,
            pr_number=request_dict["issue"]["number"],
            repo=repo,
        )
        comment.edit(
            ButtonList(
                buttons=[
                    button
                    for button in button_list.buttons
                    if button.label not in selected_buttons
                ],
                title=REVERT_CHANGED_FILES_TITLE,
            ).serialize()
        )
    if check_button_title_match(RULES_TITLE, request.comment.body, request.changes):
        rules = []
        for button_text in selected_buttons:
            rules.append(button_text.split(f"{RULES_LABEL} ")[-1].strip())
        handle_rules(
            request_dict=request_dict,
            rules=rules,
            user_token=user_token,
            repo=repo,
            gh_client=gh_client,
        )
        comment.edit(
            ButtonList(
                buttons=[
                    button
                    for button in button_list.buttons
                    if button.label not in selected_buttons
                ],
                title=RULES_TITLE,
            ).serialize()
            + BOT_SUFFIX
        )
        if not rules:
            try:
                comment.delete()
            except Exception as e:
                logger.error(f"Error deleting comment: {e}")
def handle_revert(file_paths, pr_number, repo: Repository):
    pr = repo.get_pull(pr_number)
    # pr_number is always set here (it was just passed to get_pull), and
    # PyGithub's PullRequest has no `pr_head` attribute, so use the head ref.
    branch_name = pr.head.ref
    def get_contents_with_fallback(
        repo: Repository, file_path: str, branch: str = None
    ):
        try:
            if branch:
                return repo.get_contents(file_path, ref=branch)
            return repo.get_contents(file_path)
        except Exception:
            return None
    old_file_contents = [
        get_contents_with_fallback(repo, file_path) for file_path in file_paths
    ]
    for file_path, old_file_content in zip(file_paths, old_file_contents):
        try:
            current_content = repo.get_contents(file_path, ref=branch_name)
            if old_file_content:
                repo.update_file(
                    file_path,
                    f"Revert {file_path}",
                    old_file_content.decoded_content,
                    sha=current_content.sha,
                    branch=branch_name,
                )
            else:
                repo.delete_file(
                    file_path,
                    f"Delete {file_path}",
                    sha=current_content.sha,
                    branch=branch_name,
                )
        except Exception:
            pass  # file may not exist and this is expected
def handle_rules(request_dict, rules, user_token, repo: Repository, gh_client):
    pr = repo.get_pull(request_dict["issue"]["number"])
    chat_logger = (
        ChatLogger(
            {"username": request_dict["sender"]["login"]},
        )
        if MONGODB_URI
        else None
    )
    blocked_dirs = get_blocked_dirs(repo)
    comparison = repo.compare(pr.base.sha, pr.head.sha)  # head is the most recent
    commits_diff = comparison_to_diff(comparison, blocked_dirs)
    for rule in rules:
        changes_required, issue_title, issue_description = PostMerge(
            chat_logger=chat_logger
        ).check_for_issues(rule=rule, diff=commits_diff)
        tracking_id = hashlib.sha256(str(time.time()).encode()).hexdigest()[:10]
        if changes_required:
            stack_pr(
                request=issue_description
                + "\n\nThis issue was created to address the following rule:\n"
                + rule,
                pr_number=request_dict["issue"]["number"],
                username=request_dict["sender"]["login"],
                repo_full_name=request_dict["repository"]["full_name"],
                installation_id=request_dict["installation"]["id"],
                tracking_id=tracking_id,
            )

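`handle_revert` restores each selected file to its version on the repository's default branch (`repo.get_contents(file_path)` is called without a `ref`), deletes files that exist only on the PR branch, and silently skips files missing from the PR branch. The decision logic can be sketched without GitHub by letting plain dicts stand in for the two branches; `plan_revert` and the dict-based inputs are hypothetical and not part of PyGithub or Sweep.

```python
from typing import Optional


def plan_revert(
    file_paths: list[str],
    base_files: dict[str, str],
    branch_files: dict[str, str],
) -> list[tuple[str, str, Optional[str]]]:
    """Return (action, path, contents) tuples mirroring handle_revert's branches."""
    plan: list[tuple[str, str, Optional[str]]] = []
    for path in file_paths:
        if path not in branch_files:
            # handle_revert skips these: get_contents raises and the
            # exception is swallowed by the bare `except Exception: pass`.
            continue
        if path in base_files:
            # File existed on the default branch: restore the old contents.
            plan.append(("update", path, base_files[path]))
        else:
            # File was introduced by the PR: delete it from the PR branch.
            plan.append(("delete", path, None))
    return plan


base = {"app.py": "print('old')\n"}
branch = {"app.py": "print('new')\n", "new_module.py": "x = 1\n"}
print(plan_revert(["app.py", "new_module.py", "missing.py"], base, branch))
```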
import re
import traceback
from typing import TypeVar
from sweepai.config.server import DEFAULT_GPT4_32K_MODEL
from sweepai.core.chat import ChatGPT
from sweepai.core.entities import Message, RegexMatchableBaseModel
from loguru import logger
system_prompt = """You are a brilliant and meticulous engineer assigned to review the following commit diffs and make sure the file conforms to the user's rules.
If the diffs do not conform to the rules, we should create a GitHub issue telling the user what changes should be made.
Provide your response in the following format:
<rule_analysis>
- Analysis of each file_diff and whether it breaks the rule
...
</rule_analysis>
<changes_required>
Output "True" if the rule is broken, "False" otherwise
</changes_required>
<issue_title>
Write an issue title describing what file and rule to fix.
</issue_title>
<issue_description>
GitHub issue description for what we want to solve. Give general instructions on how to solve it. Mention files to take a look at and other code pointers.
</issue_description>"""
user_message = """Review the following diffs and make sure they conform to the rules:
{diff}
The rule is: {rule}
Provide your response in the following format:
<rule_analysis>
- Analysis of code diff 1 and whether it breaks the rule
- Analysis of code diff 2 and whether it breaks the rule
...
</rule_analysis>
<changes_required>
Output "True" if the rule is broken, "False" otherwise
</changes_required>
<issue_title>
Write an issue title describing what file and rule to fix.
</issue_title>
<issue_description>
GitHub issue description for what we want to solve. Give general instructions on how to solve it. Mention files to take a look at and other code pointers.
</issue_description>"""
Self = TypeVar("Self", bound="RegexMatchableBaseModel")
class IssueTitleAndDescription(RegexMatchableBaseModel):
    changes_required: bool = False
    issue_title: str
    issue_description: str
    @classmethod
    def from_string(
        cls: type["IssueTitleAndDescription"], string: str, **kwargs
    ) -> "IssueTitleAndDescription":
        changes_required_pattern = (
            r"""<changes_required>(\n)?(?P<changes_required>.*)</changes_required>"""
        )
        changes_required_match = re.search(changes_required_pattern, string, re.DOTALL)
        changes_required = (
            changes_required_match.groupdict()["changes_required"].strip()
            if changes_required_match
            else None
        )
        if changes_required and "true" in changes_required.lower():
            changes_required = True
        else:
            changes_required = False
        issue_title_pattern = r"""<issue_title>(\n)?(?P<issue_title>.*)</issue_title>"""
        issue_title_match = re.search(issue_title_pattern, string, re.DOTALL)
        issue_title = (
            issue_title_match.groupdict()["issue_title"].strip()
            if issue_title_match
            else ""
        )
        issue_description_pattern = (
            r"""<issue_description>(\n)?(?P<issue_description>.*)</issue_description>"""
        )
        issue_description_match = re.search(
            issue_description_pattern, string, re.DOTALL
        )
        issue_description = (
            issue_description_match.groupdict()["issue_description"].strip()
            if issue_description_match
            else ""
        )
        return cls(
            changes_required=changes_required,
            issue_title=issue_title,
            issue_description=issue_description,
        )
class PostMerge(ChatGPT):
    def check_for_issues(self, rule, diff) -> tuple[bool, str, str]:
        try:
            self.messages = [
                Message(
                    role="system",
                    content=system_prompt.format(rule=rule),
                    key="system",
                )
            ]
            if self.chat_logger and not self.chat_logger.is_paying_user():
                raise ValueError("User is not a paying user")
            self.model = DEFAULT_GPT4_32K_MODEL
            response = self.chat(
                user_message.format(
                    rule=rule,
                    diff=diff,
                )
            )
            issue_title_and_description = IssueTitleAndDescription.from_string(response)
            return (
                issue_title_and_description.changes_required,
                issue_title_and_description.issue_title,
                issue_title_and_description.issue_description,
            )
        except SystemExit:
            # Re-raise the original SystemExit so its exit code is preserved.
            raise
        except Exception:
            # traceback.print_exc() prints to stderr and returns None;
            # format_exc() returns the traceback text so it reaches the log.
            logger.error(f"An error occurred: {traceback.format_exc()}")
            return False, "", ""
if __name__ == "__main__":
changes_required_response = """<rule_analysis>
- Analysis of code diff 1 and whether it breaks the rule
The code diff 1 does not break the rule. There are no docstrings or comments that need to be updated.
- Analysis of code diff 2 and whether it breaks the rule
The code diff 2 breaks the rule. There is a commented out code block that should be removed.
</rule_analysis>
<changes_required>
True if the rule is broken, False otherwise
True
</changes_required>
<issue_title>
Outdated Commented Code Block in plan-list.blade.php
</issue_title>
<issue_description>
There is an outdated commented out code block in the file `resources/views/livewire/plan-list.blade.php` that should be removed. The code block starts at line 104 and ends at line 110. Please remove this code block as it is no longer needed.
Please refer to the file `resources/views/livewire/plan-list.blade.php` and remove the commented out code block starting at line 104 and ending at line 110.
</issue_description>"""
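`IssueTitleAndDescription.from_string` treats the answer as `True` whenever the substring `"true"` appears anywhere inside `<changes_required>`. The sample response above shows the model echoing the instruction line ("True if the rule is broken, False otherwise") before its answer, and that echoed line alone would satisfy the substring check even if the final answer were `False`. One possible stricter approach, sketched here with the standard library only, reads just the last non-empty line; `extract_tag` and `parse_changes_required` are illustrative names, not existing Sweep helpers.

```python
import re


def extract_tag(text: str, tag: str) -> str:
    """Return the stripped contents of the first <tag>...</tag> block, or ""."""
    match = re.search(rf"<{tag}>(.*?)</{tag}>", text, re.DOTALL)
    return match.group(1).strip() if match else ""


def parse_changes_required(text: str) -> bool:
    """Interpret only the final non-empty line of <changes_required> as the answer."""
    body = extract_tag(text, "changes_required")
    lines = [line.strip() for line in body.splitlines() if line.strip()]
    return bool(lines) and lines[-1].lower() == "true"


response = (
    "<changes_required>\n"
    "True if the rule is broken, False otherwise\n"
    "False\n"
    "</changes_required>\n"
    "<issue_title>\nNo changes needed\n</issue_title>"
)

body = extract_tag(response, "changes_required")
print("true" in body.lower())  # True: the substring check is fooled by the echoed line
print(parse_changes_required(response))  # False
print(extract_tag(response, "issue_title"))  # No changes needed
```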

from typing import Any, Dict, Literal
from pydantic import BaseModel
class Changes(BaseModel):
    body: Dict[str, str]
    @property
    def body_from(self):
        return self.body.get("from")
class Account(BaseModel):
    id: int
    login: str
    type: str
class Installation(BaseModel):
    id: Any | None = None
    account: Account | None = None
class PREdited(BaseModel):
    class Repository(BaseModel):
        full_name: str
    class PullRequest(BaseModel):
        class User(BaseModel):
            login: str
        html_url: str
        title: str
        body: str
        number: int
        user: User
        commits: int = 0
        additions: int = 0
        deletions: int = 0
        changed_files: int = 0
    class Sender(BaseModel):
        login: str
    changes: Changes
    pull_request: PullRequest
    sender: Sender
    repository: Repository
    installation: Installation
class InstallationCreatedRequest(BaseModel):
    class Repository(BaseModel):
        full_name: str
    repositories: list[Repository]
    installation: Installation
class ReposAddedRequest(BaseModel):
    class Repository(BaseModel):
        full_name: str
    repositories_added: list[Repository]
    installation: Installation
class CommentCreatedRequest(BaseModel):
    class Comment(BaseModel):
        class User(BaseModel):
            login: str
            type: str
        body: str | None
        original_line: int
        path: str
        diff_hunk: str
        user: User
        id: int
    class PullRequest(BaseModel):
        class Head(BaseModel):
            ref: str
        number: int
        body: str | None
        state: str  # "closed" or "open"
        head: Head
        title: str
    class Repository(BaseModel):
        full_name: str
        description: str | None
    class Sender(BaseModel):
        pass
    action: str
    comment: Comment
    pull_request: PullRequest
    repository: Repository
    sender: Sender
    installation: Installation
class IssueRequest(BaseModel):
    class Issue(BaseModel):
        class User(BaseModel):
            login: str
            type: str
        class Assignee(BaseModel):
            login: str
        class Repository(BaseModel):
            # TODO(sweep): Move this out
            full_name: str
            description: str | None
        class Label(BaseModel):
            name: str
        class PullRequest(BaseModel):
            url: str | None
        title: str
        number: int
        html_url: str
        user: User
        body: str | None
        labels: list[Label]
        assignees: list[Assignee] | None = None
        pull_request: PullRequest | None = None
    action: str
    issue: Issue
    repository: Issue.Repository
    assignee: Issue.Assignee | None = None
    installation: Installation | None = None
    sender: Issue.User
class IssueCommentRequest(IssueRequest):
    class Comment(BaseModel):
        class User(BaseModel):
            login: str
            type: Literal["User", "Bot"]
        user: User
        id: int
        body: str
    comment: Comment
    sender: Comment.User
    changes: Changes | None = None
class PRRequest(BaseModel):
    class PullRequest(BaseModel):
        class User(BaseModel):
            login: str
        title: str
        class MergedBy(BaseModel):
            login: str
        user: User
        merged_by: MergedBy | None
        additions: int = 0
        deletions: int = 0
    class Repository(BaseModel):
        full_name: str
    pull_request: PullRequest
    repository: Repository
    number: int
    installation: Installation
class CheckRunCompleted(BaseModel):
    class CheckRun(BaseModel):
        class PullRequest(BaseModel):
            number: int
        class CheckSuite(BaseModel):
            head_branch: str | None
        conclusion: str
        html_url: str
        pull_requests: list[PullRequest]
        completed_at: str
        check_suite: CheckSuite
        head_sha: str
        @property
        def run_id(self):
            # format is like https://github.com/ORG/REPO_NAME/actions/runs/RUN_ID/jobs/JOB_ID
            return self.html_url.split("/")[-3]
    class Repository(BaseModel):
        full_name: str
        description: str | None
    class Sender(BaseModel):
        login: str
    check_run: CheckRun
    installation: Installation
    repository: Repository
    sender: Sender
class GithubRequest(IssueRequest):
    class Sender(BaseModel):
        login: str

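`CheckRun.run_id` relies on `html_url` having exactly the shape given in its comment and takes the third-from-last path segment. A slightly more defensive sketch, assuming the same `.../actions/runs/RUN_ID/jobs/JOB_ID` layout, looks for the segment that follows `actions/runs` instead, so a trailing slash does not shift the index. `run_id_from_url` is a hypothetical helper, not part of the models above.

```python
from typing import Optional
from urllib.parse import urlparse


def run_id_from_url(html_url: str) -> Optional[str]:
    # Non-empty path segments, e.g.
    # ["ORG", "REPO_NAME", "actions", "runs", "RUN_ID", "jobs", "JOB_ID"].
    segments = [part for part in urlparse(html_url).path.split("/") if part]
    # Find the "actions", "runs" pair and return the segment right after it.
    for i in range(1, len(segments) - 1):
        if segments[i - 1] == "actions" and segments[i] == "runs":
            return segments[i + 1]
    return None


url = "https://github.com/ORG/REPO_NAME/actions/runs/123/jobs/456"
print(run_id_from_url(url))  # 123
print(url.split("/")[-3])  # 123: the original expression agrees on this exact shape
print((url + "/").split("/")[-3], run_id_from_url(url + "/"))  # jobs 123
```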
sweep/probot/app.yml (lines 1 to 146 at commit 0277fad)

# This is a GitHub App Manifest. These settings will be used by default when
# initially configuring your GitHub App.
#
# NOTE: changing this file will not update your GitHub App settings.
# You must visit github.com/settings/apps/your-app-name to edit them.
#
# Read more about configuring your GitHub App:
# https://probot.github.io/docs/development/#configuring-a-github-app
#
# Read more about GitHub App Manifests:
# https://developer.github.com/apps/building-github-apps/creating-github-apps-from-a-manifest/
# The list of events the GitHub App subscribes to.
# Uncomment the event names below to enable them.
default_events:
- check_run
- check_suite
- commit_comment
# - delete
# - deployment
# - deployment_status
# - fork
# - gollum
- issue_comment
- issues
- label
# - milestone
# - member
# - membership
# - org_block
# - organization
# - page_build
# - project
# - project_card
# - project_column
# - public
- pull_request
- pull_request_review
- pull_request_review_comment
- pull_request_review_thread
- push
# - release
# - repository
# - repository_import
- status
# - team
# - team_add
# - watch
- workflow_job
- workflow_run
# The set of permissions needed by the GitHub App. The format of the object uses
# the permission name for the key (for example, issues) and the access type for
# the value (for example, write).
# Valid values are `read`, `write`, and `none`
default_permissions:
  # Repository creation, deletion, settings, teams, and collaborators.
  # https://developer.github.com/v3/apps/permissions/#permission-on-administration
  administration: read
  actions: read
  # Checks on code.
  # https://developer.github.com/v3/apps/permissions/#permission-on-checks
  checks: read
  # Repository contents, commits, branches, downloads, releases, and merges.
  # https://developer.github.com/v3/apps/permissions/#permission-on-contents
  contents: write
  # Deployments and deployment statuses.
  # https://developer.github.com/v3/apps/permissions/#permission-on-deployments
  # deployments: read
  # Issues and related comments, assignees, labels, and milestones.
  # https://developer.github.com/v3/apps/permissions/#permission-on-issues
  issues: write
  # Search repositories, list collaborators, and access repository metadata.
  # https://developer.github.com/v3/apps/permissions/#metadata-permissions
  metadata: read
  # Retrieve Pages statuses, configuration, and builds, as well as create new builds.
  # https://developer.github.com/v3/apps/permissions/#permission-on-pages
  # pages: read
  # Pull requests and related comments, assignees, labels, milestones, and merges.
  # https://developer.github.com/v3/apps/permissions/#permission-on-pull-requests
  pull_requests: write
  # Manage the post-receive hooks for a repository.
  # https://developer.github.com/v3/apps/permissions/#permission-on-repository-hooks
  # repository_hooks: read
  # Manage repository projects, columns, and cards.
  # https://developer.github.com/v3/apps/permissions/#permission-on-repository-projects
  # repository_projects: read
  # Retrieve security vulnerability alerts.
  # https://developer.github.com/v4/object/repositoryvulnerabilityalert/
  # vulnerability_alerts: read
  # Commit statuses.
  # https://developer.github.com/v3/apps/permissions/#permission-on-statuses
  statuses: read
  # Organization members and teams.
  # https://developer.github.com/v3/apps/permissions/#permission-on-members
  # members: read
  # View and manage users blocked by the organization.
  # https://developer.github.com/v3/apps/permissions/#permission-on-organization-user-blocking
  # organization_user_blocking: read
  # Manage organization projects, columns, and cards.
  # https://developer.github.com/v3/apps/permissions/#permission-on-organization-projects
  # organization_projects: read
  # Manage team discussions and related comments.
  # https://developer.github.com/v3/apps/permissions/#permission-on-team-discussions
  # team_discussions: read
  # Manage the post-receive hooks for an organization.
  # https://developer.github.com/v3/apps/permissions/#permission-on-organization-hooks
  # organization_hooks: read
  # Get notified of, and update, content references.
  # https://developer.github.com/v3/apps/permissions/
  # organization_administration: read
  workflows: write
# The name of the GitHub App. Defaults to the name specified in package.json
name: sweep-example-name
# The homepage of your GitHub App.
url: https://docs.sweep.dev/usage/deployment
# A description of the GitHub App.
description: Self-hosted Sweep, an AI-powered junior developer
# Set to true when your GitHub App is available to the public or false when it is only accessible to the owner of the app.
# Default: true
public: true
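The manifest comments state that each permission takes one of `read`, `write`, or `none`. A small sanity check over the `default_permissions` mapping can catch typos before the manifest is used. This sketch assumes the YAML has already been parsed into a Python dict (for example with a YAML library, not shown); `invalid_permissions` is a hypothetical helper.

```python
VALID_ACCESS = {"read", "write", "none"}


def invalid_permissions(permissions: dict[str, str]) -> dict[str, str]:
    """Return the entries whose access level is not read, write, or none."""
    return {
        name: access
        for name, access in permissions.items()
        if access not in VALID_ACCESS
    }


# The uncommented default_permissions entries from app.yml above.
default_permissions = {
    "administration": "read",
    "actions": "read",
    "checks": "read",
    "contents": "write",
    "issues": "write",
    "metadata": "read",
    "pull_requests": "write",
    "statuses": "read",
    "workflows": "write",
}

print(invalid_permissions(default_permissions))  # {}
print(invalid_permissions({"issues": "admin"}))  # {'issues': 'admin'}
```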

## Contributing
[fork]: /fork
[pr]: /compare
[code-of-conduct]: CODE_OF_CONDUCT.md
Hi there! We're thrilled that you'd like to contribute to this project. Your help is essential for keeping it great.
Please note that this project is released with a [Contributor Code of Conduct][code-of-conduct]. By participating in this project you agree to abide by its terms.
## Issues and PRs
If you have suggestions for how this project could be improved, or want to report a bug, open an issue! We welcome any and all contributions, and we'd love to hear your questions, too.
We'd also love PRs. If you're planning a large PR, though, please open an issue first to discuss it. See the links below if you're not sure how to open a PR.
## Submitting a pull request
1. [Fork][fork] and clone the repository.
1. Configure and install the dependencies: `npm install`.
1. Make sure the tests pass on your machine: `npm test`. Note: these tests also run the linter, so there's no need to lint separately.
1. Create a new branch: `git checkout -b my-branch-name`.
1. Make your change, add tests, and make sure the tests still pass.
1. Push to your fork and [submit a pull request][pr].
1. Pat yourself on the back and wait for your pull request to be reviewed and merged.
Here are a few things you can do that will increase the likelihood of your pull request being accepted:
- Write and update tests.
- Keep your changes as focused as possible. If there are multiple changes you would like to make that are not dependent upon each other, consider submitting them as separate pull requests.
- Write a [good commit message](http://tbaggery.com/2008/04/19/a-note-about-git-commit-messages.html).
Work-in-progress pull requests are also welcome, whether to get early feedback or because something is blocking you.
## Resources
- [How to Contribute to Open Source](https://opensource.guide/how-to-contribute/)
- [Using Pull Requests](https://help.github.com/articles/about-pull-requests/)

import { FileChangeRequest, Snippet } from "../lib/types";
const fcrEqual = (a: FileChangeRequest, b: FileChangeRequest) => {
return (
a.snippet.file === b.snippet.file &&
a.snippet.start === b.snippet.start &&
a.snippet.end === b.snippet.end
);
};
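const undefinedCheck = (variable: any) => {
  if (typeof variable === "undefined") {
    throw new Error("Variable is undefined");
  }
  // Array.prototype.findIndex returns -1 (never undefined) when no element
  // matches, so treat -1 as "not found"; otherwise the slice-based updates
  // below would corrupt the array.
  if (variable === -1) {
    throw new Error("FileChangeRequest not found");
  }
};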
export const setIsLoading = (
newIsLoading: boolean,
fcr: FileChangeRequest,
fcrs: FileChangeRequest[],
setFCRs: any,
) => {
try {
const fcrIndex = fcrs.findIndex((fileChangeRequest: FileChangeRequest) =>
fcrEqual(fileChangeRequest, fcr),
);
undefinedCheck(fcrIndex);
setFCRs((prev: FileChangeRequest[]) => {
return [
...prev.slice(0, fcrIndex),
{
...prev[fcrIndex],
isLoading: newIsLoading,
},
...prev.slice(fcrIndex + 1),
];
});
} catch (error) {
console.error("Error in setIsLoading: ", error);
}
};
export const setStatusForFCR = (
newStatus: "queued" | "in-progress" | "done" | "error" | "idle",
fcr: FileChangeRequest,
fcrs: FileChangeRequest[],
setFCRs: any,
) => {
try {
const fcrIndex = fcrs.findIndex((fileChangeRequest: FileChangeRequest) =>
fcrEqual(fileChangeRequest, fcr),
);
undefinedCheck(fcrIndex);
setFCRs((prev: FileChangeRequest[]) => {
return [
...prev.slice(0, fcrIndex),
{
...prev[fcrIndex],
status: newStatus,
},
...prev.slice(fcrIndex + 1),
];
});
} catch (error) {
console.error("Error in setStatus: ", error);
}
};
export const setStatusForAll = (
newStatus: "queued" | "in-progress" | "done" | "error" | "idle",
setFCRs: any,
) => {
setFCRs((prev: FileChangeRequest[]) => {
return prev.map((fileChangeRequest) => {
return {
...fileChangeRequest,
status: newStatus,
};
});
});
};
export const setFileForFCR = (
newFile: string,
fcr: FileChangeRequest,
fcrs: FileChangeRequest[],
setFCRs: any,
) => {
try {
const fcrIndex = fcrs.findIndex((fileChangeRequest: FileChangeRequest) =>
fcrEqual(fileChangeRequest, fcr),
);
undefinedCheck(fcrIndex);
setFCRs((prev: FileChangeRequest[]) => {
return [
...prev.slice(0, fcrIndex),
{
...prev[fcrIndex],
newContents: newFile,
},
...prev.slice(fcrIndex + 1),
];
});
} catch (error) {
console.error("Error in setFileForFCR: ", error);
}
};
export const setOldFileForFCR = (
newOldFile: string,
fcr: FileChangeRequest,
fcrs: FileChangeRequest[],
setFCRs: any,
) => {
try {
const fcrIndex = fcrs.findIndex((fileChangeRequest: FileChangeRequest) =>
fcrEqual(fileChangeRequest, fcr),
);
undefinedCheck(fcrIndex);
setFCRs((prev: FileChangeRequest[]) => {
return [
...prev.slice(0, fcrIndex),
{
...prev[fcrIndex],
snippet: {
...prev[fcrIndex].snippet,
entireFile: newOldFile,
},
},
...prev.slice(fcrIndex + 1),
];
});
} catch (error) {
console.error("Error in setOldFileForFCR: ", error);
}
};
export const removeFileChangeRequest = (
fcr: FileChangeRequest,
fcrs: FileChangeRequest[],
setFCRs: any,
) => {
try {
const fcrIndex = fcrs.findIndex((fileChangeRequest: FileChangeRequest) =>
fcrEqual(fileChangeRequest, fcr),
);
undefinedCheck(fcrIndex);
setFCRs((prev: FileChangeRequest[]) => {
return [...prev.slice(0, fcrIndex), ...prev.slice(fcrIndex! + 1)];
});
} catch (error) {
console.error("Error in removeFileChangeRequest: ", error);
}
};
export const setHideMergeAll = (newHideMerge: boolean, setFCRs: any) => {
setFCRs((prev: FileChangeRequest[]) => {
return prev.map((fileChangeRequest) => {
return {
...fileChangeRequest,
hideMerge: newHideMerge,
};
});
});
};
// updates readOnlySnippets for a certain fcr then updates entire fileChangeRequests array
export const setReadOnlySnippetForFCR = (
fcr: FileChangeRequest,
readOnlySnippet: Snippet,
fcrs: FileChangeRequest[],
setFCRs: any,
) => {
try {
const fcrIndex = fcrs.findIndex((fileChangeRequest: FileChangeRequest) =>
fcrEqual(fileChangeRequest, fcr),
);
undefinedCheck(fcrIndex);
setFCRs((prev: FileChangeRequest[]) => {
const updatedFcr = {
...prev[fcrIndex],
readOnlySnippets: {
...prev[fcrIndex].readOnlySnippets,
[readOnlySnippet.file]: readOnlySnippet,
},
};
return [
...prev.slice(0, fcrIndex),
updatedFcr,
...prev.slice(fcrIndex + 1),
];
});
} catch (error) {
console.error("Error in setReadOnlySnippetForFCR: ", error);
}
};
export const removeReadOnlySnippetForFCR = (
fcr: FileChangeRequest,
snippetFile: string,
fcrs: FileChangeRequest[],
setFCRs: any,
) => {
try {
const fcrIndex = fcrs.findIndex((fileChangeRequest: FileChangeRequest) =>
fcrEqual(fileChangeRequest, fcr),
);
if (fcrIndex === -1) {
throw new Error("FileChangeRequest not found");
}
setFCRs((prev: FileChangeRequest[]) => {
const { [snippetFile]: _, ...restOfSnippets } =
prev[fcrIndex].readOnlySnippets;
const updatedFCR = {
...prev[fcrIndex],
readOnlySnippets: restOfSnippets,
};
return [
...prev.slice(0, fcrIndex),
updatedFCR,
...prev.slice(fcrIndex + 1),
];
});
} catch (error) {
console.error("Error in removeReadOnlySnippetForFCR: ", error);
}
};
export const setDiffForFCR = (
newDiff: string,
fcr: FileChangeRequest,
fcrs: FileChangeRequest[],
setFCRs: any,
) => {
try {
const fcrIndex = fcrs.findIndex((fileChangeRequest: FileChangeRequest) =>
fcrEqual(fileChangeRequest, fcr),
);
if (fcrIndex === -1) {
throw new Error("FileChangeRequest not found");
}
setFCRs((prev: FileChangeRequest[]) => {
return [
...prev.slice(0, fcrIndex),
{
...prev[fcrIndex],
diff: newDiff,
},
...prev.slice(fcrIndex + 1),
];
});
} catch (error) {
console.error("Error in setDiffForFCR: ", error);
}
};

"use client";
import React, { memo, useCallback } from "react";
import { vscodeDark } from "@uiw/codemirror-theme-vscode";
import { javascript } from "@codemirror/lang-javascript";
import { java } from "@codemirror/lang-java";
import { python } from "@codemirror/lang-python";
import { html } from "@codemirror/lang-html";
import { elm } from "@codemirror/legacy-modes/mode/elm";
import CodeMirror, {
EditorState,
EditorView,
keymap,
} from "@uiw/react-codemirror";
import CodeMirrorMerge from "react-codemirror-merge";
import { indentWithTab } from "@codemirror/commands";
import { indentUnit, LanguageSupport, StreamLanguage } from "@codemirror/language";
import { FaArrowLeft } from "react-icons/fa6";
const getLanguage = (ext: string) => {
const languageMap: { [key: string]: any } = {
js: javascript(),
jsx: javascript({ jsx: true }),
ts: javascript({ typescript: true }),
tsx: javascript({ typescript: true, jsx: true }),
html: html(),
ejs: html(),
erb: html(),
py: python(),
kt: java(),
elm: new LanguageSupport(StreamLanguage.define(elm)),
};
return languageMap[ext] || javascript();
};
const Original = CodeMirrorMerge.Original;
const Modified = CodeMirrorMerge.Modified;
const FileSelector = memo(function FileSelector({
filePath,
file,
setFile,
hideMerge,
oldFile,
setOldFile,
}: {
filePath: string;
file: string;
setFile: (newFile: string) => void;
hideMerge: boolean;
oldFile: string;
setOldFile: (newOldFile: string) => void;
}) {
const placeholderText =
"Your code will be displayed here once you select a Repository and add a file to modify.";
const onChange = useCallback(
(val: any, viewUpdate: any) => {
setFile(val);
},
[setFile],
);
const onOldChange = setOldFile;
const ext = filePath?.split(".").pop() || "js";
const languageExtension = getLanguage(ext);
const extensions = [
languageExtension,
EditorView.lineWrapping,
keymap.of([indentWithTab]),
indentUnit.of(" "),
];
return (
<>
<div className="flex flex-row mb-2">
<span className="border rounded grow p-2 mr-2 font-mono">
{filePath === "" || filePath === undefined
? "No files selected"
: filePath}
</span>
</div>
{hideMerge || hideMerge === undefined ? (
<CodeMirror
value={file}
extensions={extensions}
onChange={onChange}
theme={vscodeDark}
style={{ overflow: "auto" }}
placeholder={placeholderText}
className="ph-no-capture"
/>
) : (
<CodeMirrorMerge
theme={vscodeDark}
style={{ overflow: "auto" }}
className="ph-no-capture"
revertControls="b-to-a"
collapseUnchanged={{
margin: 3,
minSize: 6,
}}
>
<Original
value={oldFile}
extensions={[...extensions, EditorState.readOnly.of(true)]}
onChange={onOldChange}
placeholder={placeholderText}
/>
<Modified
value={file}
extensions={extensions}
onChange={onChange}
placeholder={placeholderText}
/>
</CodeMirrorMerge>
)}
</>
);
});
export default FileSelector;

"use client";
import {
ResizableHandle,
ResizablePanel,
ResizablePanelGroup,
} from "../ui/resizable";
import { Textarea } from "../ui/textarea";
import React, { useCallback, useEffect, useState } from "react";
import FileSelector from "./FileSelector";
import DashboardActions from "./DashboardActions";
import { useLocalStorage } from "usehooks-ts";
import { Label } from "../ui/label";
import { Button } from "../ui/button";
import { FileChangeRequest, fcrEqual } from "../../lib/types";
import getFiles, { getFile, writeFile } from "../../lib/api.service";
import { usePostHog } from "posthog-js/react";
import { posthogMetadataScript } from "../../lib/posthog";
import { FaArrowsRotate, FaCheck } from "react-icons/fa6";
import { toast } from "sonner";
import { FileChangeRequestsState } from "../../state/fcrAtoms";
import { useRecoilState } from "recoil";
import {
setStatusForFCR,
setFileForFCR,
setOldFileForFCR,
removeFileChangeRequest,
setStatusForAll,
setHideMergeAll,
} from "../../state/fcrStateHelpers";
const blockedPaths = [
".git",
"node_modules",
"venv",
"__pycache__",
".next",
"cache",
"logs",
"sweep",
"install_assistant.sh",
];
const versionScript = `timestamp=$(git log -1 --format="%at")
[[ "$OSTYPE" == "linux-gnu"* ]] && date -d @$timestamp +%y.%m.%d.%H || date -r $timestamp +%y.%m.%d.%H
`;
const DashboardDisplay = () => {
const [streamData, setStreamData] = useState("");
const [outputToggle, setOutputToggle] = useState("script");
const [scriptOutput = "" as string, setScriptOutput] = useLocalStorage(
"scriptOutput",
"",
);
const [repoName, setRepoName] = useLocalStorage("repoName", "");
const [fileLimit, setFileLimit] = useLocalStorage<number>("fileLimit", 10000);
const [blockedGlobs, setBlockedGlobs] = useLocalStorage(
"blockedGlobs",
blockedPaths.join(", "),
);
const [fileChangeRequests, setFileChangeRequests] = useRecoilState(
FileChangeRequestsState,
);
const [currentFileChangeRequestIndex, setCurrentFileChangeRequestIndex] =
useLocalStorage("currentFileChangeRequestIndex", 0);
const [versionNumber, setVersionNumber] = useState("");
const [files = [], setFiles] = useLocalStorage<
{ label: string; name: string }[]
>("files", []);
const [directories = [], setDirectories] = useLocalStorage<
{ label: string; name: string }[]
>("directories", []);
const [loadingMessage = "", setLoadingMessage] = useState("" as string);
const filePath =
fileChangeRequests[currentFileChangeRequestIndex]?.snippet.file;
const oldFile =
fileChangeRequests[currentFileChangeRequestIndex]?.snippet.entireFile;
const file = fileChangeRequests[currentFileChangeRequestIndex]?.newContents;
const hideMerge =
fileChangeRequests[currentFileChangeRequestIndex]?.hideMerge;
const posthog = usePostHog();
const undefinedCheck = (variable: any) => {
if (typeof variable === "undefined") {
throw new Error("Variable is undefined");
}
};
const setHideMerge = useCallback(
(newHideMerge: boolean, fcr: FileChangeRequest) => {
try {
const fcrIndex = fileChangeRequests.findIndex(
(fileChangeRequest: FileChangeRequest) =>
fcrEqual(fileChangeRequest, fcr),
);
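// findIndex returns -1 (never undefined) when no matching request exists
if (fcrIndex === -1) {
throw new Error("FileChangeRequest not found");
}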
setFileChangeRequests((prev) => {
return [
...prev.slice(0, fcrIndex),
{
...prev[fcrIndex],
hideMerge: newHideMerge,
},
...prev.slice(fcrIndex + 1),
];
});
} catch (error) {
console.error("Error in setHideMerge: ", error);
}
},
[fileChangeRequests],
);
const setOldFile = useCallback((newOldFile: string) => {
setCurrentFileChangeRequestIndex((index) => {
setFileChangeRequests((newFileChangeRequests) => {
return [
...newFileChangeRequests.slice(0, index),
{
...newFileChangeRequests[index],
snippet: {
...newFileChangeRequests[index].snippet,
entireFile: newOldFile,
},
},
...newFileChangeRequests.slice(index + 1),
];
});
return index;
});
}, []);
const setFile = useCallback((newFile: string) => {
setCurrentFileChangeRequestIndex((index) => {
setFileChangeRequests((newFileChangeRequests) => {
return [
...newFileChangeRequests.slice(0, index),
{
...newFileChangeRequests[index],
newContents: newFile,
},
...newFileChangeRequests.slice(index + 1),
];
});
return index;
});
}, []);
useEffect(() => {
(async () => {
const filesAndDirectories = await getFiles(
repoName,
blockedGlobs,
fileLimit,
);
let newFiles = filesAndDirectories.sortedFiles;
let directories = filesAndDirectories.directories;
newFiles = newFiles.map((file: string) => {
return { value: file, label: file };
});
directories = directories.map((directory: string) => {
return { value: directory + "/", label: directory + "/" };
});
setFiles(newFiles);
setDirectories(directories);
})();
}, [repoName, blockedGlobs, fileLimit]);
useEffect(() => {
const textarea = document.getElementById(
"llm-output",
) as HTMLTextAreaElement | null;
if (!textarea) {
return;
}
// Auto-scroll only when the user is already within 50px of the bottom,
// so scrolling back through earlier output is not interrupted.
const delta = 50;
if (
Math.abs(
textarea.scrollHeight - textarea.scrollTop - textarea.clientHeight,
) < delta
) {
textarea.scrollTop = textarea.scrollHeight;
}
}, [streamData]);
useEffect(() => {
(async () => {
const body = {
repo: repoName,
filePath,
script: versionScript,
};
const result = await fetch("/api/run?", {
method: "POST",
body: JSON.stringify(body),
});
const object = await result.json();
const versionNumberString = object.stdout;
setVersionNumber("v" + versionNumberString);
})();
}, []);
useEffect(() => {
(async () => {
const body = { repo: repoName, filePath, script: posthogMetadataScript };
const result = await fetch("/api/run?", {
method: "POST",
body: JSON.stringify(body),
});
const object = await result.json();
const metadata = JSON.parse(object.stdout);
posthog?.identify(
metadata.email !== "N/A"
? metadata.email
: `${metadata.whoami}@${metadata.hostname}`,
metadata,
);
})();
}, [posthog]);
return (
<>
{loadingMessage && (
<div
className="p-2 fixed bottom-12 right-12 text-center z-10 flex flex-col items-center"
style={{
borderRadius: "50%",
background:
"radial-gradient(circle, rgb(40, 40, 40) 0%, rgba(0, 0, 0, 0) 75%)",
}}
>
<img
className="rounded-full border-zinc-800 border"
src="https://raw.githubusercontent.com/sweepai/sweep/main/.assets/sweeping.gif"
alt="Sweeping"
height={75}
width={75}
/>
<p className="mt-2">{loadingMessage}</p>
</div>
)}
<h1 className="font-bold text-xl">Sweep Assistant</h1>
<h3 className="text-zinc-400">{versionNumber}</h3>
<ResizablePanelGroup className="min-h-[80vh] pt-0" direction="horizontal">
<DashboardActions
filePath={filePath}
setScriptOutput={setScriptOutput}
file={file}
fileLimit={fileLimit}
setFileLimit={setFileLimit}
blockedGlobs={blockedGlobs}
setBlockedGlobs={setBlockedGlobs}
hideMerge={hideMerge}
setHideMerge={setHideMerge}
repoName={repoName}
setRepoName={setRepoName}
setStreamData={setStreamData}
files={files}
directories={directories}
currentFileChangeRequestIndex={currentFileChangeRequestIndex}
setCurrentFileChangeRequestIndex={setCurrentFileChangeRequestIndex}
setOutputToggle={setOutputToggle}
setLoadingMessage={setLoadingMessage}
/>
<ResizableHandle withHandle />
<ResizablePanel defaultSize={75}>
<ResizablePanelGroup direction="vertical">
<ResizablePanel defaultSize={75} className="flex flex-col mb-4">
<FileSelector
filePath={filePath}
file={file}
setFile={setFile}
hideMerge={hideMerge}
oldFile={oldFile}
setOldFile={setOldFile}
></FileSelector>
</ResizablePanel>
<ResizableHandle withHandle />
<ResizablePanel className="mt-2" defaultSize={25}>
<div className="flex flex-row items-center">
<Label className="mr-2">Toggle outputs:</Label>
<Button
className={`mr-2 ${outputToggle === "script" ? "bg-blue-800 hover:bg-blue-900 text-white" : ""}`}
size="sm"
variant="secondary"
onClick={() => {
setOutputToggle("script");
}}
>
Validation Output
</Button>
<Button
className={`${outputToggle === "llm" ? "bg-blue-800 hover:bg-blue-900 text-white" : ""}`}
size="sm"
variant="secondary"
onClick={() => {
setOutputToggle("llm");
}}
>
Debug Logs
</Button>
<div className="grow"></div>
<Button
className="mr-2"
size="sm"
variant="secondary"
onClick={async () => {
const fcr =
fileChangeRequests[currentFileChangeRequestIndex];
const response = await getFile(repoName, fcr.snippet.file);
setFileForFCR(
response.contents,
fcr,
fileChangeRequests,
setFileChangeRequests,
);
setOldFileForFCR(
response.contents,
fcr,
fileChangeRequests,
setFileChangeRequests,
);
toast.success("File synced from storage!", {
action: { label: "Dismiss", onClick: () => {} },
});
setCurrentFileChangeRequestIndex(
currentFileChangeRequestIndex,
);
setHideMerge(true, fcr);
setStatusForFCR(
"idle",
fcr,
fileChangeRequests,
setFileChangeRequests,
);
}}
disabled={
fileChangeRequests.length === 0 ||
fileChangeRequests[currentFileChangeRequestIndex]?.isLoading
}
>
<FaArrowsRotate />
</Button>
<Button
size="sm"
className="mr-2 bg-green-600 hover:bg-green-700"
onClick={async () => {
const fcr =
fileChangeRequests[currentFileChangeRequestIndex];
setOldFileForFCR(
fcr.newContents,
fcr,
fileChangeRequests,
setFileChangeRequests,
);
setHideMerge(true, fcr);
await writeFile(
repoName,
fcr.snippet.file,
fcr.newContents,
);
toast.success("Successfully saved file!", {
action: { label: "Dismiss", onClick: () => {} },
});
}}
disabled={
fileChangeRequests.length === 0 ||
fileChangeRequests[currentFileChangeRequestIndex]
?.isLoading ||
fileChangeRequests[currentFileChangeRequestIndex]?.hideMerge
}
>
<FaCheck />
</Button>
</div>
<Textarea
className={`mt-4 grow font-mono h-4/5 ${scriptOutput.trim().startsWith("Error") ? "text-red-600" : "text-green-600"}`}
value={scriptOutput}
id="script-output"
placeholder="Your script output will be displayed here"
readOnly
hidden={outputToggle !== "script"}
></Textarea>
<Textarea
className={`mt-4 grow font-mono h-4/5`}
id="llm-output"
value={streamData}
placeholder="ChatGPT's output will be displayed here."
readOnly
hidden={outputToggle !== "llm"}
></Textarea>
</ResizablePanel>
</ResizablePanelGroup>
</ResizablePanel>
</ResizablePanelGroup>
</>
);
};
export default DashboardDisplay;

import { useLocalStorage } from "usehooks-ts";
import { Label } from "../ui/label";
import { ReactNode, useEffect, useRef, useState } from "react";
import CodeMirror, {
EditorView,
keymap,
lineNumbers,
} from "@uiw/react-codemirror";
import { FileChangeRequest, Snippet } from "../../lib/types";
import { Button } from "../ui/button";
import { indentWithTab } from "@codemirror/commands";
import { indentUnit } from "@codemirror/language";
import { xml } from "@codemirror/lang-xml";
import { vscodeDark } from "@uiw/codemirror-theme-vscode";
import { Switch } from "../ui/switch";
import { getFile } from "../../lib/api.service";
import Markdown from "react-markdown";
import { Mention, MentionsInput, SuggestionDataItem } from "react-mentions";
import { Badge } from "../ui/badge";
import { FaTimes } from "react-icons/fa";
import { Prism as SyntaxHighlighter } from "react-syntax-highlighter";
import { vscDarkPlus } from "react-syntax-highlighter/dist/esm/styles/prism";
import { toast } from "sonner";
import { useRecoilState } from "recoil";
import { FileChangeRequestsState } from "../../state/fcrAtoms";
const codeStyle = {
...vscDarkPlus,
'code[class*="language-"]': {
...vscDarkPlus['code[class*="language-"]'],
whiteSpace: "pre-wrap",
},
};
const systemMessagePrompt = `You are a brilliant and meticulous engineer assigned to plan code changes for the following user's concerns. Take into account the current repository's language, frameworks, and dependencies.`;
const userMessagePromptOld = `Here are relevant read-only files:
<read_only_files>
{readOnlyFiles}
</read_only_files>
Here is the user's request:
<user_request>
{userRequest}
</user_request>
# Task:
Analyze the snippets, repo, and user request to break down the requested change and propose a plan to address the user's request. Mention all changes required to solve the request.
Provide a plan to solve the issue, following these rules:
* You may only create new files and modify existing files but may not necessarily need both.
* Include the full path (e.g. src/main.py and not just main.py), using the snippets for reference.
* Use natural language instructions on what to modify regarding business logic.
* Be concrete with instructions and do not write "identify x" or "ensure y is done". Instead write "add x" or "change y to z".
* Refer to the user as "you".
You MUST follow the following format with XML tags:
# Contextual Request Analysis:
<contextual_request_analysis>
Concisely outline the minimal plan that solves the user request by referencing the snippets, names of entities, and any other necessary files/directories.
</contextual_request_analysis>
# Plan:
<plan>
<create file="file_path_1" relevant_files="space-separated list of ALL files relevant for creating file_path_1">
* Concise natural language instructions for creating the new file needed to solve the issue.
* Reference necessary files, imports and entity names.
...
</create>
...
<modify file="file_path_2" start_line="i" end_line="j" relevant_files="space-separated list of ALL files relevant for modifying file_path_2">
* Concise natural language instructions for the modifications needed to solve the issue.
* Reference necessary files, imports and entity names.
...
</modify>
...
</plan>`;
const userMessagePrompt = `Here are relevant read-only files:
<read_only_files>
{readOnlyFiles}
</read_only_files>
Here is the user's request:
<user_request>
{userRequest}
</user_request>
# Task:
Analyze the snippets, repo, and user request to break down the requested change and propose a plan to address the user's request. Mention all changes required to solve the request.
Provide a plan to solve the issue, following these rules:
* You may only create new files and modify existing files but may not necessarily need both.
* Include the full path (e.g. src/main.py and not just main.py), using the snippets for reference.
* Use natural language instructions on what to modify regarding business logic.
* Be concrete with instructions and do not write "identify x" or "ensure y is done". Instead write "add x" or "change y to z".
* Refer to the user as "you".
# Plan:
<plan>
<create file="file_path_1" relevant_files="space-separated list of ALL files relevant for creating file_path_1">
* Concise natural language instructions for creating the new file needed to solve the issue.
* Reference necessary files, imports and entity names.
...
</create>
...
<modify file="file_path_2" start_line="i" end_line="j" relevant_files="space-separated list of ALL files relevant for modifying file_path_2">
* Concise natural language instructions for the modifications needed to solve the issue.
* Reference necessary files, imports and entity names.
...
</modify>
...
</plan>`;
const readOnlyFileFormat = `<read_only_file file="{file}" start_line="{start_line}" end_line="{end_line}">
{contents}
</read_only_file>`;
const fileChangeRequestPattern =
/<create file="(?<cFile>.*?)" relevant_files="(?<relevant_files>.*?)">(?<cInstructions>[\s\S]*?)($|<\/create>)|<modify file="(?<mFile>.*?)" start_line="(?<startLine>.*?)" end_line="(?<endLine>.*?)" relevant_files="(.*?)">(?<mInstructions>[\s\S]*?)($|<\/modify>)/gs;
const capitalize = (s: string) => {
return s.charAt(0).toUpperCase() + s.slice(1);
};
const DashboardPlanning = ({
repoName,
files,
setLoadingMessage,
setCurrentTab,
}: {
repoName: string;
files: { label: string; name: string }[];
setLoadingMessage: React.Dispatch<React.SetStateAction<string>>;
setCurrentTab: React.Dispatch<React.SetStateAction<"planning" | "coding">>;
}) => {
const [instructions = "", setInstructions] = useLocalStorage(
"globalInstructions",
"" as string,
);
const [snippets = {}, setSnippets] = useLocalStorage(
"globalSnippets",
{} as { [key: string]: Snippet },
);
const [rawResponse = "", setRawResponse] = useLocalStorage(
"planningRawResponse",
"" as string,
);
const [currentFileChangeRequests = [], setCurrentFileChangeRequests] =
useLocalStorage("globalFileChangeRequests", [] as FileChangeRequest[]);
const [debugLogToggle = false, setDebugLogToggle] = useState<boolean>(false);
const [isLoading = false, setIsLoading] = useState<boolean>(false);
const [fileChangeRequests, setFileChangeRequests] = useRecoilState(
FileChangeRequestsState,
);
const instructionsRef = useRef<HTMLTextAreaElement>(null);
const thoughtsRef = useRef<HTMLDivElement>(null);
const planRef = useRef<HTMLDivElement>(null);
const extensions = [
xml(),
EditorView.lineWrapping,
keymap.of([indentWithTab]),
indentUnit.of(" "),
];
useEffect(() => {
if (instructionsRef.current) {
instructionsRef.current.focus();
}
}, []);
const generatePlan = async () => {
setIsLoading(true);
setLoadingMessage("Queued...");
try {
setCurrentFileChangeRequests([]);
const response = await fetch("/api/openai/edit", {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
// Function replacers stop String.replace from expanding "$&", "$$", etc.
// that may appear in the user's instructions or in file contents.
userMessage: userMessagePrompt
.replace("{userRequest}", () => instructions)
.replace("{readOnlyFiles}", () =>
Object.keys(snippets)
.map((filePath) =>
readOnlyFileFormat
.replace("{file}", () => snippets[filePath].file)
.replace("{start_line}", () => snippets[filePath].start.toString())
.replace("{end_line}", () => snippets[filePath].end.toString())
.replace("{contents}", () => snippets[filePath].content),
)
.join("\n"),
),
systemMessagePrompt,
}),
});
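if (!response.ok) {
throw new Error(`Planning request failed with status ${response.status}`);
}
setLoadingMessage("Planning...");
const reader = response.body?.getReader();
if (!reader) {
throw new Error("The planning response has no body to stream.");
}
const decoder = new TextDecoder("utf-8");
let rawText = "";
while (true) {
const { done, value } = await reader.read();
if (done) {
break;
}
// stream: true keeps multi-byte characters that are split across chunks intact
const text = decoder.decode(value, { stream: true });
rawText += text;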
setRawResponse(rawText);
if (thoughtsRef.current) {
thoughtsRef.current.scrollTop = thoughtsRef.current.scrollHeight || 0;
}
const fileChangeRequestMatches = rawText.matchAll(
fileChangeRequestPattern,
);
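// Rebuilt from the full response on every chunk (shadows the Recoil state of the same name)
const fileChangeRequests: FileChangeRequest[] = [];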
for (const match of fileChangeRequestMatches) {
const file: string = match.groups?.cFile || match.groups?.mFile || "";
const relevantFiles: string = match.groups?.relevant_files || "";
const instructions: string =
match.groups?.cInstructions || match.groups?.mInstructions || "";
const changeType: "create" | "modify" = match.groups?.cInstructions
? "create"
: "modify";
const contents: string =
(await getFile(repoName, file)).contents || "";
const startLine: string | undefined = match.groups?.startLine;
const start: number =
startLine === undefined ? 0 : parseInt(startLine);
const endLine: string | undefined = match.groups?.endLine;
const end: number = Math.max(
endLine === undefined
? contents.split("\n").length
: parseInt(endLine),
start + 10,
);
fileChangeRequests.push({
snippet: {
start,
end,
file: file,
entireFile: contents,
content: contents.split("\n").slice(start, end).join("\n"),
},
newContents: contents,
changeType,
hideMerge: true,
instructions: instructions.trim(),
isLoading: false,
readOnlySnippets: {},
status: "idle",
} as FileChangeRequest);
}
setCurrentFileChangeRequests(fileChangeRequests);
if (planRef.current) {
const delta = 50; // auto-scroll only if the plan view is already within 50px of the bottom
if (
Math.abs(
planRef.current.scrollHeight -
planRef.current.scrollTop -
planRef.current.clientHeight,
) < delta
) {
planRef.current.scrollTop = planRef.current.scrollHeight || 0;
}
}
}
} catch (e) {
console.error(e);
toast.error("An error occurred while generating the plan.");
} finally {
setIsLoading(false);
setLoadingMessage("");
}
};
const setUserSuggestion = (
suggestion: SuggestionDataItem,
search: string,
highlightedDisplay: ReactNode,
index: number,
focused: boolean,
) => {
const maxLength = 50;
const suggestedFileName =
suggestion.display!.length < maxLength
? suggestion.display
: "..." +
suggestion.display!.slice(
suggestion.display!.length - maxLength,
suggestion.display!.length,
);
if (index > 10) {
return null;
}
return (
<div
className={`user ${focused ? "bg-zinc-800" : "bg-zinc-900"} p-2 text-sm hover:text-white`}
>
{suggestedFileName}
</div>
);
};
return (
<div className="flex flex-col h-full">
<div className="flex flex-row justify-between items-center mb-2">
<Label className="mr-2">Instructions</Label>
{/* <Button variant="secondary">
Search
</Button> */}
</div>
<MentionsInput
className="min-h-[100px] w-full rounded-md border border-input bg-background MentionsInput mb-2"
placeholder="Describe the changes you want to make here."
value={instructions}
onKeyDown={(e: any) => {
if (e.key === "Enter" && e.ctrlKey) {
e.preventDefault();
generatePlan();
}
}}
onChange={(e: any) => setInstructions(e.target.value)}
onBlur={(e: any) => setInstructions(e.target.value)}
inputRef={instructionsRef}
autoFocus
>
<Mention
trigger="@"
data={files.map((file) => ({ id: file.label, display: file.label }))}
renderSuggestion={setUserSuggestion}
onAdd={async (currentValue) => {
const contents = (await getFile(repoName, currentValue.toString()))
.contents;
const newSnippet = {
file: currentValue,
start: 0,
end: contents.split("\n").length,
entireFile: contents,
content: contents,
} as Snippet;
setSnippets((newSnippets) => {
return {
...newSnippets,
[currentValue]: newSnippet,
};
});
}}
appendSpaceOnAdd={true}
/>
</MentionsInput>
<div hidden={Object.keys(snippets).length === 0} className="mb-4">
{Object.keys(snippets).map((snippetFile: string, index: number) => (
<Badge
variant="secondary"
key={snippetFile}
className="bg-zinc-800 text-zinc-300 mr-1"
>
{snippetFile.split("/")[snippetFile.split("/").length - 1]}
<FaTimes
key={String(index) + "-remove"}
className="bg-zinc-800 cursor-pointer ml-1"
onClick={() => {
setSnippets((snippets: { [key: string]: Snippet }) => {
const { [snippetFile]: _, ...newSnippets } = snippets;
return newSnippets;
});
}}
/>
</Badge>
))}
</div>
{Object.keys(snippets).length === 0 && (
<div className="text-xs px-2 text-zinc-400">
No files added yet. Type @ to add a file.
</div>
)}
<div className="text-right mb-2">
<Button
className="mb-2 mt-2"
variant="secondary"
onClick={generatePlan}
disabled={
isLoading ||
instructions === "" ||
Object.keys(snippets).length === 0
}
>
Generate Plan
</Button>
</div>
<div className="flex flex-row mb-2 items-center">
<Label className="mb-0">Sweep&apos;s Plan</Label>
<div className="grow"></div>
<Label className="text-zinc-400 mb-0">Debug mode</Label>
<Switch
className="ml-2"
checked={debugLogToggle}
onClick={() => setDebugLogToggle((debugLogToggle) => !debugLogToggle)}
>
Debug mode
</Switch>
</div>
<div className="overflow-y-auto" ref={planRef}>
{debugLogToggle ? (
<CodeMirror
value={rawResponse}
extensions={extensions}
// onChange={onChange}
theme={vscodeDark}
style={{ overflow: "auto" }}
placeholder="Empty file"
className="ph-no-capture"
/>
) : (
<>
{currentFileChangeRequests.map((fileChangeRequest, index) => {
const filePath = fileChangeRequest.snippet.file;
let path = filePath.split("/");
const fileName = path.pop();
if (path.length > 2) {
path = path.slice(0, 1).concat(["..."]);
}
return (
<div className="rounded border p-3 mb-2" key={index}>
<div className="flex flex-row justify-between mb-2 p-2">
{fileChangeRequest.changeType === "create" ? (
<div className="font-mono">
<span className="text-zinc-400">{path.join("/")}/</span>
<span>{fileName}</span>
</div>
) : (
<div className="font-mono">
<span className="text-zinc-400">{path.join("/")}/</span>
<span>{fileName}</span>
<span className="text-zinc-400">
:{fileChangeRequest.snippet.start}-
{fileChangeRequest.snippet.end}
</span>
</div>
)}
<span className="font-mono text-zinc-400">
{capitalize(fileChangeRequest.changeType)}
</span>
</div>
<Markdown
className="react-markdown mb-2"
components={{
code(props) {
const { children, className, node, ...rest } = props;
const match = /language-(\w+)/.exec(className || "");
return match ? (
// @ts-ignore
<SyntaxHighlighter
{...rest}
PreTag="div"
language={match[1]}
style={codeStyle}
>
{String(children).replace(/\n$/, "")}
</SyntaxHighlighter>
) : (
<code {...rest} className={className}>
{children}
</code>
);
},
}}
>
{fileChangeRequest.instructions}
</Markdown>
{fileChangeRequest.changeType === "modify" && (
<>
<Label>Snippet Preview</Label>
<CodeMirror
value={fileChangeRequest.snippet.content}
extensions={[
...extensions,
lineNumbers({
formatNumber: (num: number) => {
return (
num + fileChangeRequest.snippet.start
).toString();
},
}),
]}
theme={vscodeDark}
style={{ overflow: "auto" }}
placeholder={"No plan generated yet."}
className="ph-no-capture"
maxHeight="150px"
/>
</>
)}
</div>
);
})}
{currentFileChangeRequests.length === 0 && (
<div className="text-zinc-500">No plan generated yet.</div>
)}
</>
)}
</div>
<div className="grow"></div>
<div className="text-right">
<Button
variant="secondary"
className="bg-blue-800 hover:bg-blue-900 mt-4"
onClick={async (e) => {
setFileChangeRequests((prev: FileChangeRequest[]) => {
return [...prev, ...currentFileChangeRequests];
});
setCurrentTab("coding");
}}
disabled={isLoading}
>
Accept Plan
</Button>
</div>
</div>
);
};
export default DashboardPlanning;

"use client";
import { Input } from "../ui/input";
import { ResizablePanel } from "../ui/resizable";
import { Textarea } from "../ui/textarea";
import React, { useEffect, useRef, useState } from "react";
import { Button } from "../ui/button";
import getFiles, { getFile, runScript, writeFile } from "../../lib/api.service";
import { toast } from "sonner";
import { FaPlay } from "react-icons/fa6";
import { useLocalStorage } from "usehooks-ts";
import { Label } from "../ui/label";
import { CaretSortIcon } from "@radix-ui/react-icons";
import { Collapsible, CollapsibleContent } from "../ui/collapsible";
import { Snippet } from "../../lib/search";
import DashboardInstructions from "./DashboardInstructions";
import { FileChangeRequest, Message, fcrEqual } from "../../lib/types";
import {
AlertDialog,
AlertDialogCancel,
AlertDialogContent,
AlertDialogDescription,
AlertDialogFooter,
AlertDialogHeader,
AlertDialogTitle,
} from "../ui/alert-dialog";
import { FaCog, FaQuestion } from "react-icons/fa";
import { Switch } from "../ui/switch";
import { usePostHog } from "posthog-js/react";
import { Dialog, DialogContent } from "../ui/dialog";
import { Tabs, TabsContent, TabsList, TabsTrigger } from "../ui/tabs";
import DashboardPlanning from "./DashboardPlanning";
import { useRecoilState } from "recoil";
import { FileChangeRequestsState } from "../../state/fcrAtoms";
import {
parseRegexFromOpenAICreate,
parseRegexFromOpenAIModify,
} from "../../lib/patchUtils";
import {
setIsLoading,
setFileForFCR,
setOldFileForFCR,
setStatusForFCR,
setDiffForFCR,
} from "../../state/fcrStateHelpers";
const Diff = require("diff");
const systemMessagePromptCreate = `You are creating a file of code in order to solve a user's request. You will follow the request under "# Request" and respond based on the format under "# Format".
# Request
file_name: "{filename}"
{instructions}`;
const changesMadePrompt = `The following changes have already been made as part of this task in unified diff format:
<changes_made>
{changesMade}
</changes_made>`;
const userMessagePromptCreate = `Here are relevant read-only files:
<read_only_files>
{readOnlyFiles}
</read_only_files>
Your job is to create a new code file in order to complete the user's request:
<user_request>
{prompt}
</user_request>`;
const userMessagePrompt = `Here are relevant read-only files:
<read_only_files>
{readOnlyFiles}
</read_only_files>
Here are the file's current contents:
<file_to_modify>
{fileContents}
</file_to_modify>
Your job is to modify the current code file in order to complete the user's request:
<user_request>
{prompt}
</user_request>`;
const readOnlyFileFormat = `<read_only_file file="{file}" start_line="{start_line}" end_line="{end_line}">
{contents}
</read_only_file>`;
const retryChangesMadePrompt = `The following changes have been made:
<changes_made>
{changes_made}
</changes_made>
However, the following error occurred while editing the code:
<error_message>
{error_message}
</error_message>
Please identify the error and how to correct the error. Then rewrite the diff hunks with the corrections to continue to modify the code.`;
const retryPrompt = `The following error occurred while generating the code:
<error_message>
{error_message}
</error_message>
Please identify the error and how to correct the error. Then rewrite the diff hunks with the corrections to continue to modify the code.`;
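Both retry templates feed the previous attempt's error back to the model. In `getFileChanges` below, the loop runs `i` from `0` through `maxIterations` inclusive. With `maxIterations = 3`, the model therefore gets up to four attempts in total. The sketch below reduces that control flow to its essentials. `attempt` is a hypothetical callback that stands in for the streaming request plus patch parsing and validation; it is not an API of this codebase.

```typescript
// Hypothetical result of one generation attempt: `error` is empty on success.
interface AttemptResult {
  output: string;
  error: string;
}

type Attempt = (feedback: string) => Promise<AttemptResult>;

// Calls `attempt` up to maxIterations + 1 times, passing the previous error
// back in as feedback, and resolves with the first error-free output.
async function generateWithRetries(
  attempt: Attempt,
  maxIterations: number,
): Promise<{ output: string; attempts: number }> {
  let feedback = "";
  for (let i = 0; i <= maxIterations; i++) {
    const { output, error } = await attempt(feedback);
    if (!error) {
      return { output, attempts: i + 1 };
    }
    feedback = error; // plays the role of {error_message} in the retry prompt
  }
  throw new Error(`gave up after ${maxIterations + 1} attempts`);
}
```

Keeping the attempt count derived from `maxIterations` avoids hard-coding "3" or "4" in user-facing messages.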
const createPatch = (filePath: string, oldFile: string, newFile: string) => {
if (oldFile === newFile) {
return "";
}
return Diff.createPatch(filePath, oldFile, newFile);
};
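`createPatch` delegates to the `diff` package. The success toast in `getFileChanges`, by contrast, reports `Math.abs` of the difference in line and character counts, which is a net change: rewriting three lines in place reports zero. If a per-line count is wanted, one option (an assumption, not what this dashboard currently does) is to count added and removed lines from a longest-common-subsequence table. `countLineChanges` is a hypothetical helper, and its O(n·m) table makes it a sketch for small files only.

```typescript
// Counts lines added and removed between two texts via an LCS table.
function countLineChanges(
  oldText: string,
  newText: string,
): { added: number; removed: number } {
  const a = oldText.split("\n");
  const b = newText.split("\n");
  // lcs[i][j] is the LCS length of a.slice(i) and b.slice(j).
  const lcs: number[][] = Array.from({ length: a.length + 1 }, () =>
    new Array<number>(b.length + 1).fill(0),
  );
  for (let i = a.length - 1; i >= 0; i--) {
    for (let j = b.length - 1; j >= 0; j--) {
      lcs[i][j] =
        a[i] === b[j]
          ? lcs[i + 1][j + 1] + 1
          : Math.max(lcs[i + 1][j], lcs[i][j + 1]);
    }
  }
  // Lines outside the common subsequence were either added or removed.
  const common = lcs[0][0];
  return { added: b.length - common, removed: a.length - common };
}

// An in-place edit: the net line difference is 0, but one line changed.
console.log(countLineChanges("a\nb\nc", "a\nB\nc")); // { added: 1, removed: 1 }
```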
const formatUserMessage = (
request: string,
fileContents: string,
snippets: Snippet[],
patches: string,
changeType: string,
) => {
const patchesSection =
patches.trim().length > 0
? changesMadePrompt.replace("{changesMade}", patches.trimEnd()) + "\n\n"
: "";
let basePrompt = userMessagePrompt;
if (changeType == "create") {
basePrompt = userMessagePromptCreate;
}
const userMessage =
patchesSection +
basePrompt
.replace("{prompt}", request)
.replace("{fileContents}", fileContents)
.replace(
"{readOnlyFiles}",
snippets
.map((snippet) =>
readOnlyFileFormat
.replace("{file}", snippet.file)
.replace("{start_line}", snippet.start.toString())
.replace("{end_line}", snippet.end.toString())
.replace("{contents}", snippet.content),
)
.join("\n"),
);
return userMessage;
};
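`formatUserMessage` fills its templates with chained `String.prototype.replace` calls that use string replacements. Two JavaScript behaviors matter here. First, a string replacement interprets `$&`, `` $` ``, `$'` and `$$` specially, so file contents containing those sequences are silently altered. Second, each later `replace` also scans text inserted by earlier ones, so a prompt that happens to contain `{fileContents}` gets filled in. The following is a minimal single-pass alternative that uses a function replacer, so values are inserted verbatim. `fillTemplate` is a hypothetical helper, not part of this codebase. Unlike a string-pattern `replace`, it fills every occurrence of a placeholder, not just the first.

```typescript
// Replaces every "{key}" placeholder in one pass. Placeholders without a
// matching key are left untouched, and inserted values are never rescanned.
function fillTemplate(
  template: string,
  values: Record<string, string>,
): string {
  return template.replace(/\{(\w+)\}/g, (match: string, key: string) =>
    // A function replacer inserts the value verbatim, so "$&" etc. survive.
    Object.prototype.hasOwnProperty.call(values, key) ? values[key] : match,
  );
}

const contents = "price = $&"; // e.g. a file that happens to contain "$&"
// A string replacement expands "$&" to the matched placeholder:
console.log("<f>{fileContents}</f>".replace("{fileContents}", contents));
// <f>price = {fileContents}</f>
console.log(fillTemplate("<f>{fileContents}</f>", { fileContents: contents }));
// <f>price = $&</f>
```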
const DashboardActions = ({
filePath,
setScriptOutput,
file,
fileLimit,
setFileLimit,
blockedGlobs,
setBlockedGlobs,
hideMerge,
setHideMerge,
repoName,
setRepoName,
setStreamData,
files,
directories,
currentFileChangeRequestIndex,
setCurrentFileChangeRequestIndex,
setOutputToggle,
setLoadingMessage,
}: {
filePath: string;
setScriptOutput: React.Dispatch<React.SetStateAction<string>>;
file: string;
fileLimit: number;
setFileLimit: React.Dispatch<React.SetStateAction<number>>;
blockedGlobs: string;
setBlockedGlobs: React.Dispatch<React.SetStateAction<string>>;
hideMerge: boolean;
setHideMerge: (newHideMerge: boolean, fcr: FileChangeRequest) => void;
repoName: string;
setRepoName: React.Dispatch<React.SetStateAction<string>>;
setStreamData: React.Dispatch<React.SetStateAction<string>>;
files: { label: string; name: string }[];
directories: { label: string; name: string }[];
currentFileChangeRequestIndex: number;
setCurrentFileChangeRequestIndex: React.Dispatch<
React.SetStateAction<number>
>;
setOutputToggle: (newOutputToggle: string) => void;
setLoadingMessage: React.Dispatch<React.SetStateAction<string>>;
}) => {
const [fileChangeRequests, setFileChangeRequests] = useRecoilState(
FileChangeRequestsState,
);
const posthog = usePostHog();
const validationScriptPlaceholder = `Example: python3 -m py_compile $FILE_PATH\npython3 -m pylint $FILE_PATH --errors-only`;
const testScriptPlaceholder = `Example: python3 -m pytest $FILE_PATH`;
const [validationScript, setValidationScript] = useLocalStorage(
"validationScript",
"",
);
const [testScript, setTestScript] = useLocalStorage("testScript", "");
const [currentRepoName, setCurrentRepoName] = useState(repoName);
const [currentBlockedGlobs, setCurrentBlockedGlobs] = useState(blockedGlobs);
const [repoNameCollapsibleOpen, setRepoNameCollapsibleOpen] = useLocalStorage(
"repoNameCollapsibleOpen",
false,
);
const [validationScriptCollapsibleOpen, setValidationScriptCollapsibleOpen] =
useLocalStorage("validationScriptCollapsibleOpen", false);
const [doValidate, setDoValidate] = useLocalStorage("doValidation", true);
const isRunningRef = useRef(false);
const [alertDialogOpen, setAlertDialogOpen] = useState(false);
const [currentTab = "coding", setCurrentTab] = useLocalStorage(
"currentTab",
"planning" as "planning" | "coding",
);
const refreshFiles = async () => {
try {
let { directories, sortedFiles } = await getFiles(
currentRepoName,
blockedGlobs,
fileLimit,
);
if (sortedFiles.length === 0) {
throw new Error("No files found in the repository");
}
toast.success(
`Successfully fetched ${sortedFiles.length} files from the repository!`,
{ action: { label: "Dismiss", onClick: () => {} } },
);
setCurrentRepoName((currentRepoName: string) => {
setRepoName(currentRepoName);
return currentRepoName;
});
} catch (e) {
console.error(e);
toast.error("An Error Occurred", {
description: "Please enter a valid repository name.",
action: { label: "Dismiss", onClick: () => {} },
});
}
};
useEffect(() => {
setRepoNameCollapsibleOpen(repoName === "");
}, [repoName]);
const runScriptWrapper = async (newFile: string) => {
const response = await runScript(
repoName,
filePath,
validationScript + "\n" + testScript,
newFile,
);
const { code } = response;
let scriptOutput = response.stdout + "\n" + response.stderr;
if (code != 0) {
scriptOutput = `Error (exit code ${code}):\n` + scriptOutput;
}
if (code !== 0) {
toast.error("An Error Occurred", {
description: [
<div key="stdout">{(response.stdout || "").slice(0, 800)}</div>,
<div className="text-red-500" key="stderr">
{(response.stderr || "").slice(0, 800)}
</div>,
],
action: { label: "Dismiss", onClick: () => {} },
});
} else {
toast.success("The script ran successfully", {
description: [
<div key="stdout">{(response.stdout || "").slice(0, 800)}</div>,
<div key="stderr">{(response.stderr || "").slice(0, 800)}</div>,
],
action: { label: "Dismiss", onClick: () => {} },
});
}
setScriptOutput(scriptOutput);
};
const checkCode = async (sourceCode: string, filePath: string) => {
const response = await fetch(
"/api/files/check?" +
new URLSearchParams({ filePath, sourceCode }).toString(),
);
return await response.text();
};
const checkForErrors = async (
filePath: string,
oldFile: string,
newFile: string,
) => {
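if (!doValidate) {
return "";
}
setLoadingMessage("Validating...");
// checkCode is async: await both results before comparing them, and only
// report a parse error that the edit introduced.
const [parsingErrorMessageOld, parsingErrorMessage] = await Promise.all([
checkCode(oldFile, filePath),
checkCode(newFile, filePath),
]);
if (!parsingErrorMessageOld && parsingErrorMessage) {
return parsingErrorMessage;
}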
// The original file's validation result is not compared yet (see TODO below).
await runScript(repoName, filePath, validationScript, oldFile);
const { stdout, stderr, code } = await runScript(
repoName,
filePath,
validationScript,
newFile,
);
// TODO: add diff
setScriptOutput(stdout + "\n" + stderr);
return code !== 0 ? stdout + "\n" + stderr : "";
};
// modify an existing file or create a new file
const getFileChanges = async (fcr: FileChangeRequest, index: number) => {
console.log("getting file changes")
var validationOutput = "";
const patches = fileChangeRequests
.slice(0, index)
.map((fcr: FileChangeRequest) => {
return fcr.diff;
})
.join("\n\n");
setIsLoading(true, fcr, fileChangeRequests, setFileChangeRequests);
setStatusForFCR(
"in-progress",
fcr,
fileChangeRequests,
setFileChangeRequests,
);
setOutputToggle("llm");
setLoadingMessage("Queued...");
const changeType = fcr.changeType;
// by default we modify file
let url = "/api/openai/edit";
let prompt = fcr.instructions;
if (changeType === "create") {
url = "/api/openai/create";
prompt = systemMessagePromptCreate
.replace("{instructions}", fcr.instructions)
.replace("{filename}", fcr.snippet.file);
}
const body = {
prompt: prompt,
snippets: Object.values(fcr.readOnlySnippets),
};
const additionalMessages: Message[] = [];
var currentIterationContents = fcr.snippet.entireFile || "";
let errorMessage = "";
let userMessage = formatUserMessage(
fcr.instructions,
currentIterationContents,
Object.values(fcr.readOnlySnippets),
patches,
changeType,
);
if (changeType === "create") {
userMessage =
systemMessagePromptCreate
.replace("{instructions}", fcr.instructions)
.replace("{filename}", fcr.snippet.file) + userMessage;
}
isRunningRef.current = true;
setScriptOutput(validationOutput);
setStreamData("");
if (!hideMerge) {
setFileChangeRequests((prev: FileChangeRequest[]) => {
setHideMerge(true, fcr);
setFileForFCR(
prev[index].snippet.entireFile,
fcr,
fileChangeRequests,
setFileChangeRequests,
);
return prev;
});
}
const maxIterations = 3;
for (let i = 0; i <= maxIterations; i++) {
if (!isRunningRef.current) {
setIsLoading(false, fcr, fileChangeRequests, setFileChangeRequests);
return;
}
if (i !== 0) {
var retryMessage = "";
// Only include the changes made so far when the file actually differs.
if (fcr.snippet.entireFile !== currentIterationContents) {
retryMessage = retryChangesMadePrompt.replace(
"{changes_made}",
createPatch(
fcr.snippet.file,
fcr.snippet.entireFile,
currentIterationContents,
),
);
} else {
retryMessage = retryPrompt;
}
retryMessage = retryMessage.replace(
"{error_message}",
errorMessage.trim(),
);
userMessage = retryMessage;
}
setLoadingMessage("Queued...");
const response = await fetch(url, {
method: "POST",
body: JSON.stringify({
...body,
fileContents: currentIterationContents,
additionalMessages,
userMessage,
}),
});
setLoadingMessage("Generating code...");
additionalMessages.push({ role: "user", content: userMessage });
errorMessage = "";
var currentContents = currentIterationContents;
const updateIfChanged = (newContents: string) => {
if (newContents !== currentIterationContents) {
setFileForFCR(
newContents,
fcr,
fileChangeRequests,
setFileChangeRequests,
);
currentContents = newContents;
}
};
try {
const reader = response.body!.getReader();
const decoder = new TextDecoder("utf-8");
let rawText = String.raw``;
setHideMerge(false, fcr);
var j = 0;
let globalUpdatedFile = ""; // this is really jank and bad but a quick fix because of the async nature of setters in react
while (isRunningRef.current) {
var { done, value } = await reader?.read();
if (done) {
let updatedFile = "";
let patchingErrors = "";
if (changeType == "modify") {
let [newUpdatedFile, newPatchingErrors] =
parseRegexFromOpenAIModify(
rawText || "",
currentIterationContents,
);
updatedFile = newUpdatedFile;
patchingErrors = newPatchingErrors;
} else if (changeType == "create") {
let [newUpdatedFile, newPatchingErrors] =
parseRegexFromOpenAICreate(
rawText || "",
currentIterationContents,
);
updatedFile = newUpdatedFile;
patchingErrors = newPatchingErrors;
}
if (patchingErrors) {
errorMessage += patchingErrors;
} else {
errorMessage = await checkForErrors(
fcr.snippet.file,
fcr.snippet.entireFile,
updatedFile,
);
}
additionalMessages.push({ role: "assistant", content: rawText });
updateIfChanged(updatedFile);
globalUpdatedFile = updatedFile;
rawText += "\n\n";
setStreamData((prev) => prev + "\n\n");
break;
}
const text = decoder.decode(value);
rawText += text;
setStreamData((prev: string) => prev + text);
try {
let updatedFile = "";
let _ = "";
if (changeType == "modify") {
[updatedFile, _] = parseRegexFromOpenAIModify(
rawText,
currentIterationContents,
);
} else if (changeType == "create") {
[updatedFile, _] = parseRegexFromOpenAICreate(
rawText,
currentIterationContents,
);
}
if (j % 3 == 0) {
updateIfChanged(updatedFile);
}
j += 1;
} catch (e) {
console.error(e);
}
}
if (!isRunningRef.current) {
setIsLoading(false, fcr, fileChangeRequests, setFileChangeRequests);
setLoadingMessage("");
setStatusForFCR(
"idle",
fcr,
fileChangeRequests,
setFileChangeRequests,
);
return;
}
setHideMerge(false, fcr);
const changeLineCount = Math.abs(
fcr.snippet.entireFile.split("\n").length -
globalUpdatedFile.split("\n").length,
);
const changeCharCount = Math.abs(
fcr.snippet.entireFile.length - globalUpdatedFile.length,
);
if (errorMessage.length > 0) {
console.error("errorMessage in loop", errorMessage);
toast.error(
"An error occurred while generating your code." +
(i < maxIterations
? " Retrying..."
: ` Tried ${maxIterations + 1} times, giving up.`),
{
description: errorMessage.slice(0, 800),
action: { label: "Dismiss", onClick: () => {} },
},
);
validationOutput += "\n\n" + errorMessage;
setScriptOutput(validationOutput);
setIsLoading(false, fcr, fileChangeRequests, setFileChangeRequests);
setStatusForFCR(
"in-progress",
fcr,
fileChangeRequests,
setFileChangeRequests,
);
setLoadingMessage("Retrying...");
} else {
toast.success(`Successfully modified file!`, {
description: [
<div key="stdout">{`Net change: ${changeLineCount} lines and ${changeCharCount} characters.`}</div>,
],
action: { label: "Dismiss", onClick: () => {} },
});
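// Diff against the freshly generated contents under this request's own file
// path; `fcr` was captured before generation, so fcr.newContents may be stale.
const newDiff = Diff.createPatch(
fcr.snippet.file,
fcr.snippet.entireFile,
globalUpdatedFile,
);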
setIsLoading(false, fcr, fileChangeRequests, setFileChangeRequests);
setDiffForFCR(
newDiff,
fcr,
fileChangeRequests,
setFileChangeRequests,
);
isRunningRef.current = false;
break;
}
} catch (e: any) {
console.error("error in except block", e, errorMessage);
toast.error("An error occurred while generating your code.", {
description: e instanceof Error ? e.message : String(e),
action: { label: "Dismiss", onClick: () => {} },
});
if (i === maxIterations) {
setIsLoading(false, fcr, fileChangeRequests, setFileChangeRequests);
setStatusForFCR(
"error",
fcr,
fileChangeRequests,
setFileChangeRequests,
);
isRunningRef.current = false;
setLoadingMessage("");
return;
}
}
}
setIsLoading(false, fcr, fileChangeRequests, setFileChangeRequests);
setStatusForFCR("done", fcr, fileChangeRequests, setFileChangeRequests);
isRunningRef.current = false;
setLoadingMessage("");
return;
};
// sequentially modify/create all files, awaiting each before starting the next
const getAllFileChanges = async (fcrs: FileChangeRequest[]) => {
for (let index = 0; index < fcrs.length; index++) {
const fcr = fcrs[index];
await getFileChanges(fcr, index);
}
};
const saveAllFiles = async (fcrs: FileChangeRequest[]) => {
for (const fcr of fcrs) {
setOldFileForFCR(
fcr.newContents,
fcr,
fileChangeRequests,
setFileChangeRequests,
);
setHideMerge(true, fcr);
await writeFile(repoName, fcr.snippet.file, fcr.newContents);
}
toast.success(`Successfully saved ${fcrs.length} files!`, {
action: { label: "Dismiss", onClick: () => {} },
});
};
const syncAllFiles = async () => {
// Await every fetch so callers know when the refresh has finished;
// forEach would discard the promises returned by the async callback.
await Promise.all(
fileChangeRequests.map(async (fcr: FileChangeRequest) => {
const response = await getFile(repoName, fcr.snippet.file);
setFileForFCR(
response.contents,
fcr,
fileChangeRequests,
setFileChangeRequests,
);
setOldFileForFCR(
response.contents,
fcr,
fileChangeRequests,
setFileChangeRequests,
);
setIsLoading(false, fcr, fileChangeRequests, setFileChangeRequests);
}),
);
};
return (
<ResizablePanel defaultSize={35} className="p-6 h-[90vh]">
<Tabs
defaultValue="coding"
className="h-full w-full"
value={currentTab}
onValueChange={(value) => setCurrentTab(value as "planning" | "coding")}
>
<div className="flex flex-row justify-between">
<div className="flex flex-row">
<TabsList>
<TabsTrigger value="planning">Planning</TabsTrigger>
<TabsTrigger value="coding">Coding</TabsTrigger>
</TabsList>
</div>
<div>
<Dialog
defaultOpen={repoName === ""}
open={repoNameCollapsibleOpen}
onOpenChange={(open) => setRepoNameCollapsibleOpen(open)}
>
<Button
variant="secondary"
className={`${repoName === "" ? "bg-blue-800 hover:bg-blue-900" : ""} h-full`}
size="sm"
onClick={() => setRepoNameCollapsibleOpen((open) => !open)}
>
<FaCog />
&nbsp;&nbsp;Repository Settings
<span className="sr-only">Toggle</span>
</Button>
<DialogContent className="CollapsibleContent">
<div>
<Label className="mb-2">Repository Path</Label>
<Input
id="name"
placeholder="/Users/sweep/path/to/repo"
value={currentRepoName}
className="col-span-4 w-full"
onChange={(e) => setCurrentRepoName(e.target.value)}
onBlur={refreshFiles}
/>
<p className="text-sm text-muted-foreground mb-4">
Absolute path to your repository.
</p>
<Label className="mb-2">Blocked Keywords</Label>
<Input
className="mb-4"
value={currentBlockedGlobs}
onChange={(e) => {
setCurrentBlockedGlobs(e.target.value);
}}
onBlur={() => {
setBlockedGlobs(currentBlockedGlobs);
}}
placeholder="node_modules, .log, build"
/>
<Label className="mb-2">File Limit</Label>
<Input
value={fileLimit}
onChange={(e) => {
setFileLimit(parseInt(e.target.value));
}}
placeholder="10000"
type="number"
/>
</div>
</DialogContent>
</Dialog>
</div>
</div>
<TabsContent
value="planning"
className="rounded-xl border h-full p-4 h-[95%]"
>
<DashboardPlanning
repoName={repoName}
files={files}
setLoadingMessage={setLoadingMessage}
setCurrentTab={setCurrentTab}
/>
</TabsContent>
<TabsContent value="coding" className="h-full">
<div className="flex flex-col h-[95%]">
<DashboardInstructions
filePath={filePath}
repoName={repoName}
files={files}
directories={directories}
currentFileChangeRequestIndex={currentFileChangeRequestIndex}
setCurrentFileChangeRequestIndex={
setCurrentFileChangeRequestIndex
}
getFileChanges={getFileChanges}
isRunningRef={isRunningRef}
syncAllFiles={syncAllFiles}
getAllFileChanges={() => getAllFileChanges(fileChangeRequests)}
setCurrentTab={setCurrentTab}
/>
<Collapsible
open={validationScriptCollapsibleOpen}
className="border-2 rounded p-4"
>
<div className="flex flex-row justify-between items-center">
<Label className="mb-0 flex flex-row items-center">
Checks&nbsp;
<AlertDialog open={alertDialogOpen}>
<Button
variant="secondary"
size="sm"
className="rounded-lg ml-1 mr-2"
onClick={() => setAlertDialogOpen(true)}
>
<FaQuestion style={{ fontSize: 12 }} />
</Button>
<Switch
checked={doValidate}
onClick={() => setDoValidate(!doValidate)}
disabled={fileChangeRequests.some(
(fcr: FileChangeRequest) => fcr.isLoading,
)}
/>
<AlertDialogContent className="p-12">
<AlertDialogHeader>
<AlertDialogTitle className="text-5xl mb-2">
Test and Validation Scripts
</AlertDialogTitle>
<AlertDialogDescription className="text-md pt-4">
<p>
We highly recommend setting up the validation script so
that Sweep can iterate against static analysis tools and
generate valid code. You can turn this off with the
switch.
</p>
<h2 className="text-2xl mt-4 mb-2 text-zinc-100">
Validation Script
</h2>
<p>
Sweep runs validation after every edit, and will try
to auto-fix any errors.
<br />
<br />
We recommend a syntax checker (a formatter suffices)
and a linter. We also recommend running them in your
local environment so that your own dependencies are used.
<br />
<br />
For example, for Python you can use:
<pre className="py-4">
<code>
python -m py_compile $FILE_PATH
<br />
pylint $FILE_PATH --errors-only
</code>
</pre>
And for JavaScript you can use:
<pre className="py-4">
<code>
prettier $FILE_PATH
<br />
eslint $FILE_PATH
</code>
</pre>
</p>
<h2 className="text-2xl mt-4 mb-2 text-zinc-100">
Test Script
</h2>
<p>
You can run tests after all the files have been
edited by Sweep.
<br />
<br />
For example, for Python you can use:
<pre className="py-4">pytest $FILE_PATH</pre>
And for JavaScript you can use:
<pre className="py-4">jest $FILE_PATH</pre>
</p>
</AlertDialogDescription>
</AlertDialogHeader>
<AlertDialogFooter>
<AlertDialogCancel
onClick={() => setAlertDialogOpen(false)}
>
Close
</AlertDialogCancel>
</AlertDialogFooter>
</AlertDialogContent>
</AlertDialog>
</Label>
<div className="grow"></div>
<Button
variant="secondary"
onClick={async () => {
posthog.capture("run_tests", {
name: "Run Tests",
repoName: repoName,
filePath: filePath,
validationScript: validationScript,
testScript: testScript,
});
await runScriptWrapper(file);
}}
disabled={
fileChangeRequests.some(
(fcr: FileChangeRequest) => fcr.isLoading,
) || !doValidate
}
size="sm"
className="mr-2"
>
<FaPlay />
&nbsp;&nbsp;Run Tests
</Button>
<Button
variant="secondary"
size="sm"
onClick={() =>
setValidationScriptCollapsibleOpen((open: boolean) => !open)
}
>
{!validationScriptCollapsibleOpen ? "Expand" : "Collapse"}
&nbsp;&nbsp;
<CaretSortIcon className="h-4 w-4" />
<span className="sr-only">Toggle</span>
</Button>
</div>
<CollapsibleContent className="pt-2 CollapsibleContent">
<Label className="mb-0">Validation Script&nbsp;</Label>
<Textarea
id="script-input"
placeholder={validationScriptPlaceholder}
className="col-span-4 w-full font-mono height-fit-content"
value={validationScript}
onChange={(e) => setValidationScript(e.target.value)}
disabled={
fileChangeRequests.some(
(fcr: FileChangeRequest) => fcr.isLoading,
) || !doValidate
}
></Textarea>
<Label className="mb-0">Test Script</Label>
<Textarea
id="script-input"
placeholder={testScriptPlaceholder}
className="col-span-4 w-full font-mono height-fit-content"
value={testScript}
onChange={(e) => setTestScript(e.target.value)}
disabled={
fileChangeRequests.some(
(fcr: FileChangeRequest) => fcr.isLoading,
) || !doValidate
}
></Textarea>
<p className="text-sm text-muted-foreground mb-4">
Use $FILE_PATH to refer to the file you selected. E.g. `python
$FILE_PATH`.
</p>
</CollapsibleContent>
</Collapsible>
</div>
</TabsContent>
</Tabs>
</ResizablePanel>
);
};
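`checkCode` is `async`, so its results are promises until awaited, and a `Promise` object is always truthy. The intended rule in `checkForErrors` appears to be: report a parse error only if the original file parsed cleanly and the edited one does not. The following minimal sketch implements that rule with both checks awaited in parallel. `Checker` and `newParseError` are hypothetical names. The stub checker only looks for unbalanced parentheses rather than calling `/api/files/check`.

```typescript
// Resolves to an error message, or "" when the source parses.
type Checker = (source: string) => Promise<string>;

// Reports a parse error only when the edit introduced it.
async function newParseError(
  check: Checker,
  oldFile: string,
  newFile: string,
): Promise<string> {
  // Await both checks; comparing the unawaited promises would always see
  // truthy values.
  const [oldError, newError] = await Promise.all([
    check(oldFile),
    check(newFile),
  ]);
  return !oldError && newError ? newError : "";
}

// Stub checker for illustration: flags unbalanced parentheses.
const parenChecker: Checker = async (source) => {
  let depth = 0;
  for (const ch of source) {
    if (ch === "(") depth++;
    if (ch === ")") depth--;
    if (depth < 0) return "unexpected ')'";
  }
  return depth === 0 ? "" : "missing ')'";
};
```

An error that already existed before the edit is not blamed on the model, which keeps the retry loop from chasing problems it did not cause.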

import { CheckIcon } from "@radix-ui/react-icons";
import { Popover, PopoverTrigger, PopoverContent } from "../../ui/popover";
import {
Command,
CommandInput,
CommandEmpty,
CommandGroup,
CommandItem,
} from "../../ui/command";
import React, { useState } from "react";
import { getFile } from "../../../lib/api.service";
import { Snippet } from "../../../lib/search";
import { cn } from "../../../lib/utils";
import { Button } from "../../ui/button";
import { FileChangeRequest } from "../../../lib/types";
import { FaPlus } from "react-icons/fa6";
import { FileChangeRequestsState } from "../../../state/fcrAtoms";
import { useRecoilState } from "recoil";
const CreationPanel = ({
filePath,
repoName,
files,
directories,
setCurrentTab,
}: {
filePath: string;
repoName: string;
files: { label: string; name: string }[];
directories: { label: string; name: string }[];
setCurrentTab: React.Dispatch<React.SetStateAction<"planning" | "coding">>;
}) => {
const [hidePanel, setHidePanel] = useState(true);
const [openModify, setOpenModify] = useState(false);
const [openCreate, setOpenCreate] = useState(false);
const [fileChangeRequests, setFileChangeRequests] = useRecoilState(
FileChangeRequestsState,
);
return (
<div
id="creation-panel-wrapper"
onMouseEnter={() => setHidePanel(false)}
onMouseLeave={() => setHidePanel(true)}
>
<div id="creation-panel-plus-sign-wraper" hidden={!hidePanel}>
<div className="flex flex-row w-full h-[80px] overflow-auto items-center mb-4">
<Button
variant="outline"
className="w-full h-full justify-center overflow-hidden bg-zinc-800 hover:bg-zinc-900 items-center"
>
<FaPlus className="mr-2" />
</Button>
</div>
</div>
<div id="creation-panel-actions-panel-wrapper" hidden={hidePanel}>
<div
id="creation-panel-actions-panel"
className="flex flex-row w-full h-[80px] mb-4 border rounded items-center"
>
<Popover open={openModify} onOpenChange={setOpenModify}>
<div className="w-full h-full overflow-auto">
<PopoverTrigger asChild>
<Button
variant="outline"
role="combobox"
aria-expanded={openModify}
className="border rounded-none w-full h-full bg-zinc-800 hover:bg-zinc-900 text-lg"
>
Modify file
</Button>
</PopoverTrigger>
</div>
<PopoverContent className="w-full p-0 text-left">
<Command>
<CommandInput
placeholder="Search for a file to modify..."
className="h-9"
/>
<CommandEmpty>No file found.</CommandEmpty>
<CommandGroup>
{files.map((file: any) => (
<CommandItem
key={file.value}
value={file.value}
onSelect={async (currentValue) => {
// ensure file is not already included
if (
fileChangeRequests.some(
(fcr: FileChangeRequest) =>
fcr.snippet.file === file.value,
)
) {
return;
}
const contents =
(await getFile(repoName, file.value)).contents || "";
setFileChangeRequests((prev: FileChangeRequest[]) => {
let snippet = {
file: file.value,
start: 0,
end: contents.split("\n").length,
entireFile: contents,
content: contents, // the slice based on start and end; remember to change this
} as Snippet;
return [
...prev,
{
snippet,
changeType: "modify",
newContents: contents,
hideMerge: true,
instructions: "",
isLoading: false,
readOnlySnippets: {},
diff: "",
status: "idle",
} as FileChangeRequest,
];
});
setOpenModify(false);
}}
>
{file.label}
<CheckIcon
className={cn(
"ml-auto h-4 w-4",
filePath === file.value ? "opacity-100" : "opacity-0",
)}
/>
</CommandItem>
))}
</CommandGroup>
</Command>
</PopoverContent>
</Popover>
<Popover open={openCreate} onOpenChange={setOpenCreate}>
<div className="w-full h-full overflow-auto">
<PopoverTrigger asChild>
<Button
variant="outline"
role="combobox"
aria-expanded={openCreate}
className="border-2 border-black-900 rounded-none w-full h-full bg-zinc-800 hover:bg-zinc-900 text-lg"
>
Create file
</Button>
</PopoverTrigger>
</div>
<PopoverContent className="w-full p-0 text-left">
<Command>
<CommandInput
placeholder="Search for a directory..."
className="h-9"
/>
<CommandEmpty>No directory found.</CommandEmpty>
<CommandGroup>
{directories.map((dir: any) => (
<CommandItem
key={dir.value}
value={dir.value}
onSelect={async (currentValue) => {
setFileChangeRequests((prev: FileChangeRequest[]) => {
let snippet = {
file: dir.value,
start: 0,
end: 0,
entireFile: "",
content: "",
} as Snippet;
return [
...prev,
{
snippet,
changeType: "create",
newContents: "",
hideMerge: true,
instructions: "",
isLoading: false,
readOnlySnippets: {},
diff: "",
status: "idle",
} as FileChangeRequest,
];
});
setOpenCreate(false);
}}
>
{dir.label}
<CheckIcon
className={cn(
"ml-auto h-4 w-4",
filePath === dir.value ? "opacity-100" : "opacity-0",
)}
/>
</CommandItem>
))}
</CommandGroup>
</Command>
</PopoverContent>
</Popover>
<div className="w-full h-full overflow-auto">
<Button
variant="outline"
role="combobox"
className="border rounded-none w-full h-full bg-zinc-800 hover:bg-zinc-900 text-lg"
onClick={() => {
setCurrentTab("planning");
}}
>
Create plan
</Button>
</div>
</div>
</div>
</div>
);
};
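The *Modify file* and *Create file* handlers above build nearly identical `FileChangeRequest` literals that differ only in the snippet contents and `changeType`. Only the modify handler checks for duplicates. One way to keep the defaults in one place is a small factory plus a duplicate check. `makeFileChangeRequest` and `addIfAbsent` are hypothetical helpers, and the interfaces are copies of the `Snippet` and `FileChangeRequest` types declared elsewhere in this codebase.

```typescript
interface Snippet {
  file: string;
  start: number;
  end: number;
  entireFile: string;
  content: string;
}

interface FileChangeRequest {
  snippet: Snippet;
  instructions: string;
  newContents: string;
  changeType: "create" | "modify";
  hideMerge: boolean;
  isLoading: boolean;
  readOnlySnippets: { [key: string]: Snippet };
  diff: string;
  status: "queued" | "in-progress" | "done" | "error" | "idle";
}

// Builds a request with the same defaults the two handlers use.
function makeFileChangeRequest(
  file: string,
  changeType: "create" | "modify",
  contents = "",
): FileChangeRequest {
  return {
    snippet: {
      file,
      start: 0,
      end: changeType === "create" ? 0 : contents.split("\n").length,
      entireFile: contents,
      content: contents,
    },
    changeType,
    newContents: contents,
    hideMerge: true,
    instructions: "",
    isLoading: false,
    readOnlySnippets: {},
    diff: "",
    status: "idle",
  };
}

// Appends `fcr` unless a request for the same file is already queued.
function addIfAbsent(
  prev: FileChangeRequest[],
  fcr: FileChangeRequest,
): FileChangeRequest[] {
  return prev.some((p) => p.snippet.file === fcr.snippet.file)
    ? prev
    : [...prev, fcr];
}
```

Used as `setFileChangeRequests((prev) => addIfAbsent(prev, makeFileChangeRequest(file.value, "modify", contents)))`, the duplicate check runs inside the state updater. As a result, it sees the latest list rather than the snapshot captured at render time.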

import React, { useState } from "react";
import { Button } from "../../ui/button";
import { FaArrowRotateLeft } from "react-icons/fa6";
import { setStatusForAll } from "../../../state/fcrStateHelpers";
import { FileChangeRequestsState } from "../../../state/fcrAtoms";
import { useRecoilState } from "recoil";
const ModifyOrCreate = ({
filePath,
repoName,
files,
directories,
syncAllFiles,
}: {
filePath: string;
repoName: string;
files: { label: string; name: string }[];
directories: { label: string; name: string }[];
syncAllFiles: () => Promise<void>;
}) => {
const [openModify, setOpenModify] = useState(false);
const [openCreate, setOpenCreate] = useState(false);
const [fileChangeRequests, setFileChangeRequests] = useRecoilState(
FileChangeRequestsState,
);
return (
<div className="flex flex-row mb-4">
{/* <Popover open={openModify} onOpenChange={setOpenModify}>
<div className="flex flex-row mb-4 overflow-auto">
<PopoverTrigger asChild>
<Button
variant="outline"
role="combobox"
aria-expanded={openModify}
className="w-full justify-between overflow-hidden mr-2 bg-blue-800 hover:bg-blue-900"
>
Modify file
<CaretSortIcon className="ml-2 h-4 w-4 shrink-0 opacity-50" />
</Button>
</PopoverTrigger>
</div>
<PopoverContent className="w-full p-0 text-left">
<Command>
<CommandInput placeholder="Search for a file to modify..." className="h-9" />
<CommandEmpty>No file found.</CommandEmpty>
<CommandGroup>
{files.map((file: any) => (
<CommandItem
key={file.value}
value={file.value}
onSelect={async (currentValue) => {
// ensure file is not already included
if (
fileChangeRequests.some(
(fcr: FileChangeRequest) =>
fcr.snippet.file === file.value,
)
) {
return;
}
const contents = (await getFile(repoName, file.value))
.contents || "";
setFileChangeRequests((prev: FileChangeRequest[]) => {
let snippet = {
file: file.value,
start: 0,
end: contents.split("\n").length,
entireFile: contents,
content: contents, // the slice based on start and end; remember to change this
} as Snippet;
return [
...prev,
{
snippet,
changeType: "modify",
newContents: contents,
hideMerge: true,
instructions: "",
isLoading: false,
readOnlySnippets: {},
status: "idle"
} as FileChangeRequest,
];
});
setOpenModify(false);
}}
>
{file.label}
<CheckIcon
className={cn(
"ml-auto h-4 w-4",
filePath === file.value ? "opacity-100" : "opacity-0",
)}
/>
</CommandItem>
))}
</CommandGroup>
</Command>
</PopoverContent>
</Popover>
<Popover open={openCreate} onOpenChange={setOpenCreate}>
<div className="flex flex-row mb-4 overflow-auto">
<PopoverTrigger asChild>
<Button
variant="outline"
role="combobox"
aria-expanded={openCreate}
className="w-full justify-between overflow-hidden mr-2 bg-blue-800 hover:bg-blue-900"
>
Create file
<CaretSortIcon className="ml-2 h-4 w-4 shrink-0 opacity-50" />
</Button>
</PopoverTrigger>
</div>
<PopoverContent className="w-full p-0 text-left">
<Command>
<CommandInput placeholder="Search for a directory..." className="h-9" />
<CommandEmpty>No directory found.</CommandEmpty>
<CommandGroup>
{directories.map((dir: any) => (
<CommandItem
key={dir.value}
value={dir.value}
onSelect={async (currentValue) => {
setFileChangeRequests((prev: FileChangeRequest[]) => {
let snippet = {
file: dir.value,
start: 0,
end: 0,
entireFile: "",
content: "",
} as Snippet;
return [
...prev,
{
snippet,
changeType: "create",
newContents: "",
hideMerge: true,
instructions: "",
isLoading: false,
readOnlySnippets: {},
status: "idle"
} as FileChangeRequest,
];
});
setOpenCreate(false);
}}
>
{dir.label}
<CheckIcon
className={cn(
"ml-auto h-4 w-4",
filePath === dir.value ? "opacity-100" : "opacity-0",
)}
/>
</CommandItem>
))}
</CommandGroup>
</Command>
</PopoverContent>
</Popover> */}
<div className="grow"></div>
<Button
onClick={() => {
syncAllFiles();
setStatusForAll("idle", setFileChangeRequests);
}}
variant="secondary"
>
<FaArrowRotateLeft style={{ marginTop: -3, fontSize: 12 }} />
&nbsp;&nbsp;Refresh files
</Button>
</div>
);
};
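*Refresh files* calls `syncAllFiles`, which is typed as returning `Promise<void>`. Awaiting that promise only means something if the function waits for each per-file fetch. An `async` callback passed to `Array.prototype.forEach` returns a promise that `forEach` discards. The sketch below contrasts the two shapes, using a hypothetical `slowFetch` stub in place of `getFile`.

```typescript
type FetchFile = (path: string) => Promise<string>;

// Resolves before any fetch has finished: forEach ignores the promises
// returned by its async callback.
async function refreshWithForEach(
  paths: string[],
  fetchFile: FetchFile,
  out: Map<string, string>,
): Promise<void> {
  paths.forEach(async (path) => {
    out.set(path, await fetchFile(path));
  });
}

// Resolves once every fetch has finished (fetches run concurrently) and
// rejects as soon as any one of them fails.
async function refreshAll(
  paths: string[],
  fetchFile: FetchFile,
): Promise<Map<string, string>> {
  const entries = await Promise.all(
    paths.map(async (path) => [path, await fetchFile(path)] as const),
  );
  return new Map(entries);
}

// Stub standing in for getFile: resolves after a short delay.
const slowFetch: FetchFile = (path) =>
  new Promise((resolve) => setTimeout(() => resolve(`contents of ${path}`), 5));
```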

interface File {
name: string;
path: string;
isDirectory: boolean;
content?: string;
snippets?: Snippet[];
}
interface Snippet {
file: string;
start: number;
end: number;
entireFile: string;
content: string;
}
interface FileChangeRequest {
snippet: Snippet;
instructions: string;
newContents: string;
changeType: "create" | "modify";
hideMerge: boolean;
isLoading: boolean;
readOnlySnippets: { [key: string]: Snippet };
diff: string;
status: "queued" | "in-progress" | "done" | "error" | "idle";
}
const fcrEqual = (a: FileChangeRequest, b: FileChangeRequest) => {
return (
a.snippet.file === b.snippet.file &&
a.snippet.start === b.snippet.start &&
a.snippet.end === b.snippet.end
);
};
const snippetKey = (snippet: Snippet) => {
return `${snippet.file}:${snippet.start || 0}-${snippet.end || 0}`;
};
interface Message {
role: "user" | "system" | "assistant";
content: string;
}
export { fcrEqual, snippetKey };
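`snippetKey` turns a snippet into a `file:start-end` string, which makes it a natural key for the `readOnlySnippets` record on `FileChangeRequest`. Assuming it is used that way (this part of the codebase does not show the call site), adding the same range twice overwrites one entry instead of creating a duplicate. The following is a self-contained sketch. `addReadOnlySnippet` is a hypothetical helper, and `snippetKey` is copied from above.

```typescript
interface Snippet {
  file: string;
  start: number;
  end: number;
  entireFile: string;
  content: string;
}

// Same format as snippetKey above: "path:start-end", missing bounds become 0.
const snippetKey = (snippet: Snippet): string =>
  `${snippet.file}:${snippet.start || 0}-${snippet.end || 0}`;

// Returns a new record with `snippet` stored under its key; re-adding the
// same range replaces the entry instead of duplicating it.
function addReadOnlySnippet(
  snippets: { [key: string]: Snippet },
  snippet: Snippet,
): { [key: string]: Snippet } {
  return { ...snippets, [snippetKey(snippet)]: snippet };
}

const s: Snippet = { file: "src/app.ts", start: 10, end: 20, entireFile: "", content: "" };
const once = addReadOnlySnippet({}, s);
const twice = addReadOnlySnippet(once, { ...s });
console.log(Object.keys(twice)); // [ 'src/app.ts:10-20' ]
```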

import { type ClassValue, clsx } from "clsx";
import { twMerge } from "tailwind-merge";
export function cn(...inputs: ClassValue[]) {
return twMerge(clsx(inputs));
}

@tailwind base;
@tailwind components;
@tailwind utilities;
@layer base {
:root {
--background: 0 0% 100%;
--foreground: 240 10% 3.9%;
--card: 0 0% 100%;
--card-foreground: 240 10% 3.9%;
--popover: 0 0% 100%;
--popover-foreground: 240 10% 3.9%;
--primary: 240 5.9% 10%;
--primary-foreground: 0 0% 98%;
--secondary: 240 4.8% 95.9%;
--secondary-foreground: 240 5.9% 10%;
--muted: 240 4.8% 95.9%;
--muted-foreground: 240 3.8% 46.1%;
--accent: 240 4.8% 95.9%;
--accent-foreground: 240 5.9% 10%;
--destructive: 0 84.2% 60.2%;
--destructive-foreground: 0 0% 98%;
--border: 240 5.9% 90%;
--input: 240 5.9% 90%;
--ring: 240 10% 3.9%;
--radius: 0.5rem;
}
.dark {
--background: 240 10% 3.9%;
--foreground: 0 0% 98%;
--card: 240 10% 3.9%;
--card-foreground: 0 0% 98%;
--popover: 240 10% 3.9%;
--popover-foreground: 0 0% 98%;
--primary: 0 0% 98%;
--primary-foreground: 240 5.9% 10%;
--secondary: 240 3.7% 15.9%;
--secondary-foreground: 0 0% 98%;
--muted: 240 3.7% 15.9%;
--muted-foreground: 240 5% 64.9%;
--accent: 240 3.7% 15.9%;
--accent-foreground: 0 0% 98%;
--destructive: 0 62.8% 30.6%;
--destructive-foreground: 0 0% 98%;
--border: 240 3.7% 15.9%;
--input: 240 3.7% 15.9%;
--ring: 240 4.9% 83.9%;
}
}
@layer base {
* {
@apply border-border;
font-size: 14px;
}
body {
@apply bg-background text-foreground;
font-size: 14px;
}
div[cmdk-group-items] > div:nth-child(n + 21) {
display: none;
}
}
.CollapsibleContent {
overflow: hidden;
}
.CollapsibleContent[data-state="open"] {
animation: slideDown 300ms ease-out;
}
.CollapsibleContent[data-state="closed"] {
animation: slideUp 300ms ease-out;
}
@keyframes slideDown {
from {
height: 0;
}
to {
height: var(--radix-collapsible-content-height);
}
}
@keyframes slideUp {
from {
height: var(--radix-collapsible-content-height);
}
to {
height: 0;
}
}
.cm-mergeViewEditor::-webkit-scrollbar {
display: none;
}
.cm-mergeViewEditor {
-ms-overflow-style: none;
scrollbar-width: none;
}
.cm-editor::-webkit-scrollbar {
display: none;
}
.cm-editor {
-ms-overflow-style: none;
scrollbar-width: none;
}
.MentionsInput > div > textarea {
padding-top: 0.5rem;
padding-bottom: 0.5rem;
padding-left: 0.75rem;
padding-right: 0.75rem;
font-size: 0.875rem !important;
}
.MentionsInput > div:has(ul) {
margin-top: 25px !important;
border-radius: 5px;
}
.react-markdown {
font-size: 0.75rem;
}
.react-markdown li {
list-style-type: disc;
margin-left: 2rem;
}
.react-markdown code {
color: #e83e8c;
}
.react-markdown pre {
margin-top: 1rem;
margin-bottom: 1rem;
background-color: #1e1e1e;
color: #afafaf;
border-radius: 0.5rem;
overflow-x: auto;
padding: 1rem;
white-space: pre-wrap;
}
.react-markdown pre code {
color: #f4f4f4;
}
.SyntaxHighlighter {
white-space: pre-wrap !important;
}
.SyntaxHighlighter code {
white-space: pre-wrap !important;
}

import { useState, useEffect } from "react";
import parse from "parse-diff";
import { ShowMore } from "./ShowMore";
import { BiGitMerge } from "react-icons/bi";
import { FiCornerDownRight } from "react-icons/fi";
export function PRPreview({ repoName, prId }) {
const [prData, setPrData] = useState(null)
const [issueData, setIssueData] = useState(null)
const [diffData, setDiffData] = useState(null)
const key = `prData-${repoName}-${prId}-v0`;
const herokuAnywhere = "https://sweep-examples-cors-143adb2b6ffb.herokuapp.com/"
const headers = {}
useEffect(() => {
const fetchPRData = async () => {
try {
const url = `https://api.github.com/repos/${repoName}/pulls/${prId}`;
console.log(url)
const response = await fetch(url, {headers});
console.log(response)
const data = await response.json();
console.log("pr data", data)
setPrData(data);
if (!data.body) {
return;
}
// match() returns null when the body has no "Fixes #N", so guard before indexing
const match = data.body.match(/Fixes #(\d+)/)
const issueId = match && match[1]
if (!issueId) {
return;
}
const issuesUrl = `https://api.github.com/repos/${repoName}/issues/${issueId}`
const issueResponse = await fetch(issuesUrl, {headers});
const issueData = await issueResponse.json();
setIssueData(issueData);
console.log("issueData", issueData)
if (!data.diff_url) {
return;
}
// Proxy the diff through a CORS relay (a better CORS solution is still needed)
const diffResponse = await fetch(`${herokuAnywhere}${data.diff_url}`);
const diffText = await diffResponse.text();
setDiffData(diffText);
} catch (error) {
console.error("Error fetching PR data:", error);
}
};
console.log(localStorage);
if (localStorage) {
try {
const cacheHit = localStorage.getItem(key)
if (cacheHit) {
const { prData, diffData, issueData, timestamp } = JSON.parse(cacheHit)
if (prData && diffData && issueData && timestamp && new Date() - new Date(timestamp) < 1000 * 60 * 60 * 24) {
console.log("cache hit")
setPrData(prData)
setDiffData(diffData)
setIssueData(issueData)
return
}
}
} catch (error) {
console.error("Error parsing cache hit:", error);
}
}
console.log("cache miss")
fetchPRData();
}, [repoName, prId]);
useEffect(() => {
console.log(localStorage);
if (localStorage && prData && diffData && issueData) {
const data = {
prData,
diffData,
issueData,
timestamp: new Date(),
}
localStorage.setItem(key, JSON.stringify(data))
}
}, [prData, diffData, issueData]);
if (!prData || !prData.user) {
return <div>{`https://github.com/${repoName}/pull/${prId}`}. Loading...</div>;
}
const numberDaysAgoMerged = Math.max(Math.round((new Date() - new Date(prData.merged_at)) / (1000 * 60 * 60 * 24)), 71)
const parsedDiff = parse(diffData)
var issueTitle = issueData != null ? issueData.title.replace("Sweep: ", "") : ""
issueTitle = issueTitle.charAt(0).toUpperCase() + issueTitle.slice(1);
console.log("parsedDiff", parsedDiff)
return (
<>
<style>
{`
.hoverEffect:hover {
/* background-color: #222; */
}
h5 ::after {
display: none;
}
.clickable {
cursor: pointer;
}
.clickable:hover {
text-decoration: underline;
}
`}
</style>
<div
className="hoverEffect"
style={{
border: "1px solid darkgrey",
borderRadius: 5,
marginTop: 32,
padding: 10,
}}
>
<div style={{display: "flex"}}>
<h5
className="clickable" style={{marginTop: 0, fontWeight: "bold", fontSize: 18}}
onClick={() => window.open(prData.html_url, "_blank")}
>
{prData.title}
</h5>
<span style={{color: "#815b9e", marginTop: 2, display: "flex"}}>
&nbsp;&nbsp;
<BiGitMerge style={{marginTop: 3}}/>
&nbsp;Merged
</span>
</div>
{
prData && (
<div style={{display: "flex", color: "#666"}}>
#{prId} •&nbsp;<span className="clickable" onClick={() => window.open("https://github.com/apps/sweep-ai")}>{prData.user && prData.user.login}</span>&nbsp;•&nbsp;<BiGitMerge style={{marginTop: 3}}/>&nbsp;Merged {numberDaysAgoMerged} days ago by&nbsp;<span className="clickable" onClick={() => window.open(`https://github.com/${prData.merged_by && prData.merged_by.login || "wwzeng1"}`, "_blank")}>{prData.merged_by && prData.merged_by.login || "wwzeng1"}</span>
</div>
)
}
<div style={{display: "flex", marginTop: 15, color: "darkgrey"}}>
<FiCornerDownRight style={{marginTop: 3 }} />&nbsp;{issueData && <p className="clickable">Fixes #{issueData.number}: {issueTitle}</p>}
</div>
{diffData && (
<>
<hr style={{borderColor: "darkgrey", margin: 20}}/>
<ShowMore>
<div
className="codeBlocks"
style={{
borderRadius: 5,
padding: 10,
transition: "background-color 0.2s linear",
}}
>
{parsedDiff.map(({chunks, from, oldStart}) => (
from !== "/dev/null" && from !== "sweep.yaml" &&
<>
<p style={{
marginTop: 0,
marginBottom: 0,
fontFamily: "ui-monospace,SFMono-Regular,Menlo,Monaco,Consolas,Liberation Mono,Courier New,monospace"
}}>{from}</p>
{chunks.map(({changes}) =>
<pre style={{
backgroundColor: "#161717",
borderRadius: 10,
whiteSpace: "pre-wrap",
}}>
{changes.map(({content, type}) =>
<>
{type === "add" && content && <div style={{backgroundColor: "#12261e", width: "100%", padding: 4}}>{content}</div>}
{type === "del" && content && <div style={{backgroundColor: "#25171c", width: "100%", padding: 4}}>{content}</div>}
{type === "normal" && content && <div style={{padding: 4}}>{content}</div>}
</>
)}
</pre>
)}
</>
))}
</div>
</ShowMore>
</>
)}
</div>
</>
)
}

# 🔒 Privacy Policy
Last updated: July 19th
This Privacy Policy describes how Sweep AI ("we", "us", "our") collects, uses, and discloses your Personal Information when you use our app Sweep AI ("the App").
By using the App, you agree to the collection and use of information in accordance with this policy.
## Information Collection and Use
Data Collection
While using our App, we may collect information related to your codebase, commits, GitHub tickets, pull requests (PRs), and PR comments ("Data"). This Data is used to understand how you interact with our App and to improve our services.
- The logs from Sweep (which contain snippets of code) are kept for debugging purposes and are stored for only 30 days. We send this data to OpenAI to generate code.
- OpenAI has an agreement stating they will not train on this data and will persist it for 30 days to monitor trust and safety.
- We store the codebase as embeddings, which are not readable as plaintext.
- We use PostHog to log telemetry in order to understand how Sweep is being used. This includes the number of tickets created, comments created, and PRs merged.
- No data being used will be sold to third parties.
## Use of Data
We use the collected data for various purposes:
- To automatically generate and modify PRs, which are integral to the service we provide
- To provide analysis or valuable information so that we can improve the App
- To monitor the usage of the App
- To detect, prevent and address technical issues
## Transfer of Data
Your information, including the Data, may be transferred to — and maintained on — computers located outside of your state, province, country, or other governmental jurisdiction where the data protection laws may differ from those of your jurisdiction.
## Disclosure of Data
We may disclose your Data in the good faith belief that such action is necessary to:
- Comply with a legal obligation
- Protect and defend the rights or property of Sweep AI
- Prevent or investigate possible wrongdoing in connection with the App
- Protect the personal safety of users of the App or the public
- Protect against legal liability

No data being used will be sold to third parties.
## Security of Data
The security of your data is important to us, but remember that no method of transmission over the Internet or method of electronic storage is 100% secure. While we strive to use commercially acceptable means to protect your Data, we cannot guarantee its absolute security.
## Changes to This Privacy Policy
We may update our Privacy Policy from time to time. We will notify you of any changes by posting the new Privacy Policy on this page.
You are advised to review this Privacy Policy periodically for any changes. Changes to this Privacy Policy are effective when they are posted on this page.
## Contact Us
If you have any questions about this Privacy Policy, please contact us:
By email: [[email protected]](mailto:[email protected])
By visiting this page on our website: [https://sweep.dev](https://sweep.dev/)
This Privacy Policy does not apply to the practices of third parties that we do not own or control, including GitHub or any third party services you access through GitHub. You can reference the GitHub Privacy Policy to learn more about its privacy practices.

# Advanced Features: becoming a Power User 🧠
## Usage 📖
### Mention important files
To ensure that Sweep scans a file, mention the file name in your ticket. Sweep searches for relevant files at runtime, but specifying the file helps avoid missing important details.
### Giving Sweep feedback
If Sweep's plan isn't accurate, you can respond to Sweep in three places:
1. **Issue**: Sweep will create a new pull request and close the old one. Alternatively, you can edit the issue description to recreate the pull request.
2. **Pull request**: Sweep will update the PR based on your PR comments
3. **Code**: Sweep will only update the file that the comment is on
Whenever Sweep picks up one of your messages, it reacts with a 👀 emoji. If you don't see it, make sure the PR/issue is open and that you prefixed the message with "sweep:".
Further, when a GitHub Actions run fails, Sweep will update the PR based on the error message.
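The three feedback locations and the "sweep:" prefix rule can be summarized in a small routing sketch. `CommentLocation`, `is_addressed_to_sweep`, and `route_feedback` are hypothetical names for illustration, not Sweep's actual handlers, and the sketch assumes the prefix is required in all three places as the paragraph above suggests.

```python
from enum import Enum
from typing import Optional


class CommentLocation(Enum):
    # Hypothetical: the three places feedback can be left, as listed above.
    ISSUE = "issue"
    PULL_REQUEST = "pull_request"
    CODE = "code"


def is_addressed_to_sweep(message: str) -> bool:
    """True if the message starts with "sweep:" (case-insensitive, leading spaces ignored)."""
    return message.lstrip().lower().startswith("sweep:")


def route_feedback(
    location: CommentLocation, message: str, file_path: Optional[str] = None
) -> Optional[str]:
    """Describe the action a bot would take, or return None to ignore the message."""
    if not is_addressed_to_sweep(message):
        return None  # No prefix: no reaction, nothing happens.
    if location is CommentLocation.ISSUE:
        # Issue feedback: open a new PR and close the old one.
        return "recreate pull request"
    if location is CommentLocation.PULL_REQUEST:
        # PR-level feedback: update the existing PR.
        return "update pull request"
    if file_path is None:
        raise ValueError("code comments must name the file they are attached to")
    # Code-level feedback: only the commented file may change.
    return f"update {file_path}"
```

Keeping the prefix check separate makes it easy to reuse for the retry flow, where a reply prefixed with "Sweep: " re-triggers the issue.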
### Switch branch
To get Sweep to use a different base branch for one issue, add the following to the issue description.
> branch: BRANCH_NAME
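As an illustration of how such an override could be read from an issue body, here is a minimal sketch. It assumes the override sits on its own line as `branch: NAME` (an optional leading `>` is tolerated); `extract_branch` and the `"main"` default are hypothetical, not Sweep's actual parser.

```python
import re
from typing import Optional

# One line of the form "branch: NAME"; [ \t]* keeps the match from spanning lines.
BRANCH_LINE = re.compile(
    r"^[ \t]*>?[ \t]*branch:[ \t]*(\S+)[ \t]*$", re.IGNORECASE | re.MULTILINE
)


def extract_branch(issue_body: Optional[str], default: str = "main") -> str:
    """Return the base branch requested in the issue body, or `default` if none is given."""
    match = BRANCH_LINE.search(issue_body or "")
    return match.group(1) if match else default
```

For example, `extract_branch("Add a dark mode toggle.\n\nbranch: develop\n")` yields `"develop"`, while a body without the line falls back to the default.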
## Configuration 🛠️
### Use GitHub Actions
We highly recommend linters, as well as Netlify/Vercel preview builds. Sweep auto-corrects based on linter and build errors, and Netlify and Vercel preview deployments of static sites help shorten iteration cycles.
### Set up `sweep.yaml`
You can set up `sweep.yaml` to
* Provide up-to-date docs by setting up `docs` (https://docs.sweep.dev/usage/config#docs)
* Set up automated formatting and linting by setting up `sandbox` (https://docs.sweep.dev/usage/config#sandbox). Never have Sweep commit a failing `npm lint` again.
* Give Sweep a high level description of where to find files in your repo by editing the `repo_description` field.
For more on configs, check out https://docs.sweep.dev/usage/config.
## Prompting 🗣️
The amount of prompting you need to give Sweep directly scales with the complexity of the problem.
For harder problems, try to provide the same information a human would need, and for simpler problems, providing a single line and a file name should suffice.
### Prompting formats
A good issue should include **where to look** (file name or entity name), **what to do** ("change the logic to do this"), and **additional context** (there's a bug/we need this feature/there's this dependency). Examples:

`sweep/sweepai/cli.py` (lines 1 to 363 at commit `0277fad`):

import datetime
import json
import os
import pickle
import threading
import time
import uuid
from itertools import chain, islice
import typer
from github import Github
from github.Event import Event
from github.IssueEvent import IssueEvent
from github.Repository import Repository
from loguru import logger
from rich.console import Console
from rich.prompt import Prompt
from sweepai.api import handle_request
from sweepai.handlers.on_ticket import on_ticket
from sweepai.utils.event_logger import posthog
from sweepai.utils.github_utils import get_github_client
from sweepai.utils.str_utils import get_hash
from sweepai.web.events import Account, Installation, IssueRequest
app = typer.Typer(
    name="sweepai", context_settings={"help_option_names": ["-h", "--help"]}
)
app_dir = typer.get_app_dir("sweepai")
config_path = os.path.join(app_dir, "config.json")
console = Console()
cprint = console.print
def posthog_capture(event_name, properties, *args, **kwargs):
    POSTHOG_DISTINCT_ID = os.environ.get("POSTHOG_DISTINCT_ID")
    if POSTHOG_DISTINCT_ID:
        posthog.capture(POSTHOG_DISTINCT_ID, event_name, properties, *args, **kwargs)
def load_config():
    if os.path.exists(config_path):
        cprint(f"\nLoading configuration from {config_path}", style="yellow")
        with open(config_path, "r") as f:
            config = json.load(f)
        os.environ["GITHUB_PAT"] = config.get("GITHUB_PAT", "")
        os.environ["OPENAI_API_KEY"] = config.get("OPENAI_API_KEY", "")
        os.environ["ANTHROPIC_API_KEY"] = config.get("ANTHROPIC_API_KEY", "")
        os.environ["VOYAGE_API_KEY"] = config.get("VOYAGE_API_KEY", "")
        os.environ["POSTHOG_DISTINCT_ID"] = str(config.get("POSTHOG_DISTINCT_ID", ""))
def fetch_issue_request(issue_url: str, __version__: str = "0"):
    (
        protocol_name,
        _,
        _base_url,
        org_name,
        repo_name,
        _issues,
        issue_number,
    ) = issue_url.split("/")
    cprint("Fetching installation ID...")
    installation_id = -1
    cprint("Fetching access token...")
    _token, g = get_github_client(installation_id)
    g: Github = g
    cprint("Fetching repo...")
    issue = g.get_repo(f"{org_name}/{repo_name}").get_issue(int(issue_number))
    issue_request = IssueRequest(
        action="labeled",
        issue=IssueRequest.Issue(
            title=issue.title,
            number=int(issue_number),
            html_url=issue_url,
            user=IssueRequest.Issue.User(
                login=issue.user.login,
                type="User",
            ),
            body=issue.body,
            labels=[
                IssueRequest.Issue.Label(
                    name="sweep",
                ),
            ],
            assignees=None,
            pull_request=None,
        ),
        repository=IssueRequest.Issue.Repository(
            full_name=issue.repository.full_name,
            description=issue.repository.description,
        ),
        assignee=IssueRequest.Issue.Assignee(login=issue.user.login),
        installation=Installation(
            id=installation_id,
            account=Account(
                id=issue.user.id,
                login=issue.user.login,
                type="User",
            ),
        ),
        sender=IssueRequest.Issue.User(
            login=issue.user.login,
            type="User",
        ),
    )
    return issue_request
def pascal_to_snake(name):
    return "".join(["_" + i.lower() if i.isupper() else i for i in name]).lstrip("_")
def get_event_type(event: Event | IssueEvent):
    if isinstance(event, IssueEvent):
        return "issues"
    else:
        return pascal_to_snake(event.type)[: -len("_event")]
@app.command()
def test():
    cprint("Sweep AI is installed correctly and ready to go!", style="yellow")
@app.command()
def watch(
    repo_name: str,
    debug: bool = False,
    record_events: bool = False,
    max_events: int = 30,
):
    if not os.path.exists(config_path):
        cprint(
            f"\nConfiguration not found at {config_path}. Please run [green]'sweep init'[/green] to initialize the CLI.\n",
            style="yellow",
        )
        raise ValueError(
            "Configuration not found, please run 'sweep init' to initialize the CLI."
        )
    posthog_capture(
        "sweep_watch_started",
        {
            "repo": repo_name,
            "debug": debug,
            "record_events": record_events,
            "max_events": max_events,
        },
    )
    # load_config() stores "" for a missing key, so check for emptiness, not just None.
    GITHUB_PAT = os.environ.get("GITHUB_PAT")
    if not GITHUB_PAT:
        raise ValueError("GITHUB_PAT environment variable must be set")
    g = Github(GITHUB_PAT)
    repo = g.get_repo(repo_name)
    if debug:
        logger.debug("Debug mode enabled")
    def stream_events(repo: Repository, timeout: int = 2, offset: int = 2 * 60):
        processed_event_ids = set()
        current_time = time.time() - offset
        current_time = datetime.datetime.fromtimestamp(current_time)
        local_tz = datetime.datetime.now(datetime.timezone.utc).astimezone().tzinfo
        while True:
            events_iterator = chain(
                islice(repo.get_events(), max_events),
                islice(repo.get_issues_events(), max_events),
            )
            for i, event in enumerate(events_iterator):
                if event.id not in processed_event_ids:
                    local_time = event.created_at.replace(
                        tzinfo=datetime.timezone.utc
                    ).astimezone(local_tz)
                    if local_time.timestamp() > current_time.timestamp():
                        yield event
                    else:
                        if debug:
                            logger.debug(
                                f"Skipping event {event.id} because it is in the past (local_time={local_time}, current_time={current_time}, i={i})"
                            )
                elif debug:
                    # Only events seen in an earlier poll reach this branch.
                    logger.debug(
                        f"Skipping event {event.id} because it is already handled"
                    )
                processed_event_ids.add(event.id)
            time.sleep(timeout)
    def handle_event(event: Event | IssueEvent, do_async: bool = True):
        if isinstance(event, IssueEvent):
            payload = event.raw_data
            payload["action"] = payload["event"]
        else:
            payload = {**event.raw_data, **event.payload}
        payload["sender"] = payload.get("sender", payload["actor"])
        payload["sender"]["type"] = "User"
        payload["pusher"] = payload.get("pusher", payload["actor"])
        payload["pusher"]["name"] = payload["pusher"]["login"]
        payload["pusher"]["type"] = "User"
        payload["after"] = payload.get("after", payload.get("head"))
        payload["repository"] = repo.raw_data
        payload["installation"] = {"id": -1}
        logger.info(str(event) + " " + str(event.created_at))
        if record_events:
            _type = get_event_type(event) if isinstance(event, Event) else "issue"
            # Use a context manager so the pickle file is closed after writing.
            with open(
                "tests/events/"
                + f"{_type}_{payload.get('action')}_{str(event.id)}.pkl",
                "wb",
            ) as f:
                pickle.dump(event, f)
        if do_async:
            thread = threading.Thread(
                target=handle_request, args=(payload, get_event_type(event))
            )
            thread.start()
            return thread
        else:
            return handle_request(payload, get_event_type(event))
    def main():
        cprint(
            f"\n[bold black on white] Starting server, listening to events from {repo_name}... [/bold black on white]\n",
        )
        cprint(
            f"To create a PR, please create an issue at https://github.com/{repo_name}/issues with a title prefixed with 'Sweep:' or label an existing issue with 'sweep'. The events will be logged here, but there may be a brief delay.\n"
        )
        for event in stream_events(repo):
            handle_event(event)
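    # Start polling unconditionally: when the CLI runs through its console-script
    # entry point, __name__ is not "__main__", so a guard here would skip main().
    main()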
@app.command()
def init(override: bool = False):
    # TODO: Fix telemetry
    if not override:
        if os.path.exists(config_path):
            with open(config_path, "r") as f:
                config = json.load(f)
            if "OPENAI_API_KEY" in config and "ANTHROPIC_API_KEY" in config and "GITHUB_PAT" in config:
                override = typer.confirm(
                    f"\nConfiguration already exists at {config_path}. Override?",
                    default=False,
                    abort=True,
                )
    cprint(
        "\n[bold black on white] Initializing Sweep CLI... [/bold black on white]\n",
    )
    cprint(
        "\nFirstly, let's store your OpenAI API Key. You can get it here: https://platform.openai.com/api-keys\n",
        style="yellow",
    )
    openai_api_key = Prompt.ask("OpenAI API Key", password=True)
    assert len(openai_api_key) > 30, "OpenAI API Key must be longer than 30 characters."
    assert openai_api_key.startswith("sk-"), "OpenAI API Key must start with 'sk-'."
    cprint(
        "\nNext, let's store your Anthropic API key. You can get it here: https://console.anthropic.com/settings/keys.",
        style="yellow",
    )
    anthropic_api_key = Prompt.ask("Anthropic API Key", password=True)
    assert len(anthropic_api_key) > 30, "Anthropic API Key must be longer than 30 characters."
    assert anthropic_api_key.startswith("sk-ant-api03-"), "Anthropic API Key must start with 'sk-ant-api03-'."
    cprint(
        "\nGreat! Next, we'll need just your GitHub PAT. Here's a link with all the permissions pre-filled:\nhttps://github.com/settings/tokens/new?description=Sweep%20Self-hosted&scopes=repo,workflow\n",
        style="yellow",
    )
    github_pat = Prompt.ask("GitHub PAT", password=True)
    assert len(github_pat) > 30, "GitHub PAT must be longer than 30 characters."
    assert github_pat.startswith("ghp_"), "GitHub PAT must start with 'ghp_'."
    cprint(
        "\nAwesome! Lastly, let's get your Voyage AI API key from https://dash.voyageai.com/api-keys. This is optional, but improves code search by about [cyan]5%[/cyan]. You can always return to this later by re-running 'sweep init'.",
        style="yellow",
    )
    voyage_api_key = Prompt.ask("Voyage AI API key", password=True)
    if voyage_api_key:
        assert len(voyage_api_key) > 30, "Voyage AI API key must be longer than 30 characters."
        assert voyage_api_key.startswith("pa-"), "Voyage API key must start with 'pa-'."
    POSTHOG_DISTINCT_ID = None
    enable_telemetry = typer.confirm(
        "\nEnable usage statistics? This will help us improve the product.",
        default=True,
    )
    if enable_telemetry:
        cprint(
            "\nThank you for enabling telemetry. We'll collect anonymous usage statistics to improve the product. You can disable this at any time by rerunning 'sweep init'.",
            style="yellow",
        )
        POSTHOG_DISTINCT_ID = uuid.getnode()
        posthog.capture(POSTHOG_DISTINCT_ID, "sweep_init", {})
    config = {
        "GITHUB_PAT": github_pat,
        "OPENAI_API_KEY": openai_api_key,
        "ANTHROPIC_API_KEY": anthropic_api_key,
        "VOYAGE_API_KEY": voyage_api_key,
    }
    if POSTHOG_DISTINCT_ID:
        config["POSTHOG_DISTINCT_ID"] = POSTHOG_DISTINCT_ID
    os.makedirs(app_dir, exist_ok=True)
    with open(config_path, "w") as f:
        json.dump(config, f)
    cprint(f"\nConfiguration saved to {config_path}\n", style="yellow")
    cprint(
        "Installation complete! You can now run [green]'sweep run <issue-url>'[/green][yellow] to run Sweep on an issue, or [/yellow][green]'sweep watch <org-name>/<repo-name>'[/green] to have Sweep listen for and fix newly created GitHub issues.",
        style="yellow",
    )
@app.command()
def run(issue_url: str):
    if not os.path.exists(config_path):
        cprint(
            f"\nConfiguration not found at {config_path}. Please run [green]'sweep init'[/green] to initialize the CLI.\n",
            style="yellow",
        )
        raise ValueError(
            "Configuration not found, please run 'sweep init' to initialize the CLI."
        )
    cprint(f"\n Running Sweep on issue: {issue_url} \n", style="bold black on white")
    posthog_capture("sweep_run_started", {"issue_url": issue_url})
    request = fetch_issue_request(issue_url)
    try:
        cprint(f'\nRunning Sweep to solve "{request.issue.title}"!\n')
        on_ticket(
            title=request.issue.title,
            summary=request.issue.body,
            issue_number=request.issue.number,
            issue_url=request.issue.html_url,
            username=request.sender.login,
            repo_full_name=request.repository.full_name,
            repo_description=request.repository.description,
            installation_id=request.installation.id,
            comment_id=None,
            edited=False,
            tracking_id=get_hash(),
        )
    except Exception as e:
        posthog_capture("sweep_run_fail", {"issue_url": issue_url, "error": str(e)})
        # Re-raise so the failure is visible to the user instead of silently swallowed.
        raise
    else:
        posthog_capture("sweep_run_success", {"issue_url": issue_url})
def main():
    cprint(
        "By using the Sweep CLI, you agree to the Sweep AI Terms of Service at https://sweep.dev/tos.pdf",
        style="cyan",
    )
    load_config()
    app()
if __name__ == "__main__":
    main()

# Frequently Asked Questions
<details id="does-sweep-write-tests">
<summary>Does Sweep write tests?</summary>
Yep! The easiest way to have Sweep write tests is by modifying the `description` parameter in your `sweep.yaml`. You can add something like:
“In [your repository], the tests are written in [your format]. If you modify business logic, modify the tests as well using this format.” You can add anything you’d like to the description parameter, including formatting rules (like PEP8), code style, etc!
</details>
<details id="can-we-trust-code-written-by-sweep">
<summary>Can we trust the code written by Sweep?</summary>
You should always review the PR. However, we also test the PR using your existing GitHub Actions.
To get the best performance, add GitHub Actions that lint, test, and validate your code.
</details>
<details id="work-off-another-branch">
<summary>Can I have Sweep work off of another branch besides main?</summary>
Yes! In the `sweep.yaml`, you can set the `branch` parameter to something besides your default branch, and Sweep will use that as a reference.
</details>
<details id="retry-issue-with-sweep">
<summary>How do I retry an issue with Sweep?</summary>
To retry an issue, prefix your issue reply with 'Sweep: '. This will trigger Sweep to retry the issue.
</details>
<details id="give-documentation-to-sweep">
<summary>Can I give documentation to Sweep?</summary>
Yes! In the `sweep.yaml`, you can specify docs. Be sure to pick the URL prefix of the docs site, which lets us fetch only the docs you need.
Check out the example here: https://github.com/sweepai/sweep/blob/main/sweep.yaml.
</details>
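To illustrate why the prefix matters, here is a minimal sketch of prefix-based filtering: only pages under the configured prefix are kept, so unrelated sections of the same site (a blog, for example) are skipped. `filter_doc_urls` and the `docs.example.com` URLs are hypothetical; Sweep's real crawler may normalize URLs differently.

```python
from typing import Iterable, List
from urllib.parse import urlsplit


def _normalize(url: str) -> str:
    # Drop query strings and fragments so "page?x=1#y" compares equal to "page".
    parts = urlsplit(url)
    return f"{parts.scheme}://{parts.netloc}{parts.path}"


def filter_doc_urls(urls: Iterable[str], prefix: str) -> List[str]:
    """Keep only URLs under `prefix`, preserving order and dropping duplicates."""
    base = _normalize(prefix)
    seen = set()
    kept = []
    for url in urls:
        normalized = _normalize(url)
        if normalized.startswith(base) and normalized not in seen:
            seen.add(normalized)
            kept.append(normalized)
    return kept
```

A more specific prefix such as `https://docs.example.com/guide/` therefore means fewer pages fetched and less irrelevant context.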
<details id="comment-on-sweeps-prs">
<summary>Can I comment on Sweep’s PRs?</summary>
Yep! You have three options depending on the degree of the change:
1. You can comment on the issue, and Sweep will rewrite the entire pull request. This will use one of your GPT-4 credits.
2. You can comment on the pull request (not a file) and Sweep can make substantial changes to the pull request. Sweep will search the codebase, and is able to modify and create files.
3. You can comment on the file directly, and Sweep will only modify that file. Use this for small single file changes.
</details>

Once Sweep has the reference implementation, Sweep generates the corresponding test as commits in a [GitHub PR](https://github.com/sweepai/sweep/pull/2378):
```python
def get_file_contents(self, file_path, ref=None):
    local_path = os.path.join(self.cache_dir, file_path)
    if os.path.exists(local_path):
        with open(local_path, "r", encoding="utf-8", errors="replace") as f:
            contents = f.read()
        return contents
    else:
        raise FileNotFoundError(f"{local_path} does not exist.")
```
We had Sweep generate mocks for `os.path.join` and `open`. <br></br>
This code looks great!
```python
@patch("os.path.join")
@patch("open")
def test_get_file_contents(self, mock_open, mock_join):
    mock_join.return_value = "/tmp/cache/repos/sweepai/sweep/main/file1"
    mock_open.return_value.__enter__.return_value.read.return_value = "file content"
    content = self.cloned_repo.get_file_contents("file1")
    self.assertEqual(content, "file content")
```
We generated mocks for `os.path.join` and `open`, which should return the correct path and file contents. <br></br>
Ok we're done here right? Can we just write these tests and leave the rest to the developer?
## 3. **Run the tests.**
Most other AI tools stop here, but it’s not enough. <br></br>
If you just committed these tests, you'd end up with a frustrating bug. Here it is:
```bash
  File "/usr/lib/python3.10/unittest/mock.py", line 1616, in _get_target
    raise TypeError(
TypeError: Need a valid target to patch. You supplied: 'open'
```
Did we really save time for the developer here? It’s frustrating that most other tools don’t fix these issues.
*Unlike every other tool, Sweep actually runs these tests.*
Sweep ran the code, found the issue, and identified the solution: <br></br>
**“Change the target of the patch in the 'test_get_file_contents' method from 'open' to 'builtins.open'. This will correctly patch the built-in 'open' function during the test.”**
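The fix works because `patch` needs an importable dotted path such as `module.attribute`, and a bare `'open'` is not one; the built-in lives at `builtins.open`. Below is a self-contained sketch of a test with that fix applied. `ClonedRepo` is a minimal stand-in built from the `get_file_contents` method shown earlier, not Sweep's real class. Unlike the generated test, this sketch uses the standard `unittest.mock.mock_open` helper and also patches `os.path.exists`, since the method checks for the file before opening it.

```python
import os
import unittest
from unittest.mock import mock_open, patch


class ClonedRepo:
    # Stand-in for the real class: just enough to exercise get_file_contents.
    def __init__(self, cache_dir):
        self.cache_dir = cache_dir

    def get_file_contents(self, file_path, ref=None):
        local_path = os.path.join(self.cache_dir, file_path)
        if os.path.exists(local_path):
            with open(local_path, "r", encoding="utf-8", errors="replace") as f:
                return f.read()
        raise FileNotFoundError(f"{local_path} does not exist.")


class TestGetFileContents(unittest.TestCase):
    def setUp(self):
        self.cloned_repo = ClonedRepo("/tmp/cache/repos/sweepai/sweep/main")

    # Decorators apply bottom-up, so the mocks arrive as (mock_file, mock_exists).
    @patch("os.path.exists", return_value=True)
    @patch("builtins.open", new_callable=mock_open, read_data="file content")
    def test_get_file_contents(self, mock_file, mock_exists):
        content = self.cloned_repo.get_file_contents("file1")
        self.assertEqual(content, "file content")
        # os.path.join is real here, so the exact path passed to open() can be checked.
        mock_file.assert_called_once_with(
            "/tmp/cache/repos/sweepai/sweep/main/file1",
            "r",
            encoding="utf-8",
            errors="replace",
        )

    @patch("os.path.exists", return_value=False)
    def test_missing_file_raises(self, mock_exists):
        with self.assertRaises(FileNotFoundError):
            self.cloned_repo.get_file_contents("missing")
```

Run it with `python -m unittest`. Leaving `os.path.join` unpatched keeps the test closer to real behavior while still avoiding any disk access.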
Sweep added [this commit](https://github.com/sweepai/sweep/pull/2378/commits/0ded79eab77ca3e511257ff0bf3874893b038e9e), which changes the patch target to `builtins.open`.


Step 2: ⌨️ Coding

Modify sweepai/handlers/on_merge.py with contents: In the `on_merge` function:
• Locate the code block that performs the merge operation when the target branch is updated.
• Replace the merge logic with a rebase operation:
  - Use `git.repo.git.rebase()` to rebase the Sweep-created branch onto the updated target branch.
  - Handle any rebase conflicts that may occur.
  - Force push the rebased branch to update the remote branch.
• Update any error handling and logging statements to reflect the rebase operation.
• Import the necessary `git` module at the top of the file.
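A minimal sketch of the planned rebase flow follows. It calls the plain `git` CLI through `subprocess` instead of the GitPython call named in the plan. `rebase_onto_base` and its injectable `run` parameter are hypothetical names, not part of Sweep. On conflicts the sketch aborts the rebase and reports failure rather than trying to resolve anything, and it pushes with `--force-with-lease` so commits pushed to the Sweep branch by someone else since the fetch are not silently overwritten.

```python
import subprocess
from typing import Callable, List, Sequence

# A runner takes an argv list and returns the finished process (injectable for testing).
Runner = Callable[[Sequence[str]], subprocess.CompletedProcess]


def default_runner(args: Sequence[str]) -> subprocess.CompletedProcess:
    # Capture git's output for logging; never raise on a non-zero exit code.
    return subprocess.run(list(args), capture_output=True, text=True, check=False)


def rebase_onto_base(
    branch: str, base: str, remote: str = "origin", run: Runner = default_runner
) -> bool:
    """Rebase `branch` onto `remote/base` and force-push it.

    Returns True on success, False if the rebase hit conflicts (the rebase is aborted).
    """
    setup: List[List[str]] = [
        # Update the remote-tracking refs for both branches.
        ["git", "fetch", remote, base, branch],
        # Reset the local branch to the remote copy so we rebase what is on the remote.
        ["git", "checkout", "-B", branch, f"{remote}/{branch}"],
    ]
    for args in setup:
        result = run(args)
        if result.returncode != 0:
            raise RuntimeError(f"{' '.join(args)} failed: {result.stderr}")

    if run(["git", "rebase", f"{remote}/{base}"]).returncode != 0:
        # Conflicts (or any other rebase failure): restore the pre-rebase state.
        run(["git", "rebase", "--abort"])
        return False

    # Rebasing rewrites history, so a plain push would be rejected.
    push = run(["git", "push", "--force-with-lease", remote, branch])
    if push.returncode != 0:
        raise RuntimeError(f"git push failed: {push.stderr}")
    return True
```

Returning `False` on conflicts leaves the decision to the caller, for example falling back to the existing merge behavior or commenting on the PR.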

Step 3: 🔁 Code Review

I have finished reviewing the code for completeness. I did not find errors for sweep/allow_for_rebase_1a484.


🎉 Latest improvements to Sweep:
  • New dashboard launched for real-time tracking of Sweep issues, covering all stages from search to coding.
  • Integration of OpenAI's latest Assistant API for more efficient and reliable code planning and editing, improving speed by 3x.
  • Use the GitHub issues extension for creating Sweep issues directly from your editor.

💡 To recreate the pull request, edit the issue title or description.
Something wrong? Let us know.

This is an automated message generated by Sweep AI.

@sweep-nightly sweep-nightly bot linked a pull request Apr 8, 2024 that will close this issue