From 1b3d1424ab7a6c9735a1b4851fc009a5cd6e3b03 Mon Sep 17 00:00:00 2001 From: rongzhang Date: Fri, 18 Oct 2024 01:44:00 +0000 Subject: [PATCH] feat: github integration backend --- querybook/server/clients/github_client.py | 102 ++++++++++++++++ querybook/server/datasources/github.py | 92 +++++++++++++- .../lib/github_integration/serializers.py | 41 +++++++ querybook/server/logic/datadoc.py | 40 +++++++ querybook/server/logic/github.py | 44 +++++++ querybook/server/models/__init__.py | 1 + querybook/server/models/github.py | 41 +++++++ .../test_github_client.py | 113 ++++++++++++++++++ .../test_serializers.py | 68 +++++++++++ requirements/extra.txt | 3 + requirements/github_integration/github.txt | 1 + 11 files changed, 544 insertions(+), 2 deletions(-) create mode 100644 querybook/server/clients/github_client.py create mode 100644 querybook/server/lib/github_integration/serializers.py create mode 100644 querybook/server/logic/github.py create mode 100644 querybook/server/models/github.py create mode 100644 querybook/tests/test_lib/test_github_integration/test_github_client.py create mode 100644 querybook/tests/test_lib/test_github_integration/test_serializers.py create mode 100644 requirements/github_integration/github.txt diff --git a/querybook/server/clients/github_client.py b/querybook/server/clients/github_client.py new file mode 100644 index 000000000..558ed89f8 --- /dev/null +++ b/querybook/server/clients/github_client.py @@ -0,0 +1,102 @@ +from flask import session as flask_session +from github import Github, GithubException, Auth +from typing import List, Dict + +from lib.github_integration.serializers import ( + deserialize_datadoc, + serialize_datadoc, +) +from lib.logger import get_logger +from models.datadoc import DataDoc +from models.github import GitHubLink + +LOG = get_logger(__file__) + + +class GitHubClient: + def __init__(self, github_link: GitHubLink): + """ + Initialize the GitHub client with an access token from the session. + Raises an exception if the token is not found. + """ + access_token = flask_session.get("github_access_token") + if not access_token: + raise Exception("GitHub OAuth token not found in session") + auth = Auth.Token(access_token) + self.client = Github(auth=auth) + self.user = self.client.get_user() + self.github_link = github_link + self.repo = self.client.get_repo(github_link.repo_url) + + def commit_datadoc(self, datadoc: DataDoc): + """ + Commit a DataDoc to the repository. + Args: + datadoc (DataDoc): The DataDoc object to commit. + Raises: + Exception: If committing the DataDoc fails. + """ + file_path = self.github_link.file_path + content = serialize_datadoc(datadoc) + commit_message = f"Update DataDoc {datadoc.id}: {datadoc.title}" + + try: + contents = self.repo.get_contents(file_path, ref=self.github_link.branch) + # Update file + self.repo.update_file( + path=contents.path, + message=commit_message, + content=content, + sha=contents.sha, + branch=self.github_link.branch, + ) + LOG.info(f"Updated file {file_path} in repository.") + except GithubException as e: + if e.status == 404: + # Create new file + self.repo.create_file( + path=file_path, + message=commit_message, + content=content, + branch=self.github_link.branch, + ) + LOG.info(f"Created file {file_path} in repository.") + else: + LOG.error(f"GitHubException: {e}") + raise Exception(f"Failed to commit DataDoc: {e}") + + def get_datadoc_versions(self, datadoc: DataDoc) -> List[Dict]: + """ + Get the versions of a DataDoc. + Args: + datadoc (DataDoc): The DataDoc object. + Returns: + List[Dict]: A list of commit dictionaries. + """ + file_path = self.github_link.file_path + try: + commits = self.repo.get_commits(path=file_path, ref=self.github_link.branch) + return [commit.raw_data for commit in commits] + except GithubException as e: + LOG.error(f"GitHubException: {e}") + return [] + + def get_datadoc_at_commit(self, datadoc_id: int, commit_sha: str) -> DataDoc: + """ + Get a DataDoc at a specific commit. + Args: + datadoc_id (int): The DataDoc ID. + commit_sha (str): The commit SHA. + Returns: + DataDoc: The DataDoc object at the specified commit. + Raises: + Exception: If getting the DataDoc at the commit fails. + """ + file_path = self.github_link.file_path + try: + file_contents = self.repo.get_contents(path=file_path, ref=commit_sha) + json_content = file_contents.decoded_content.decode("utf-8") + return deserialize_datadoc(json_content) + except GithubException as e: + LOG.error(f"GitHubException: {e}") + raise Exception(f"Failed to get DataDoc at commit {commit_sha}: {e}") diff --git a/querybook/server/datasources/github.py b/querybook/server/datasources/github.py index 82420604c..90e6dc854 100644 --- a/querybook/server/datasources/github.py +++ b/querybook/server/datasources/github.py @@ -1,6 +1,26 @@ -from app.datasource import register +from app.datasource import register, api_assert +from app.db import DBSession from lib.github_integration.github_integration import get_github_manager -from typing import Dict +from clients.github_client import GitHubClient +from functools import wraps +from typing import List, Dict +from logic import datadoc as datadoc_logic +from logic import github as logic +from const.datasources import RESOURCE_NOT_FOUND_STATUS_CODE +from logic.datadoc_permission import assert_can_read, assert_can_write +from app.auth.permission import verify_data_doc_permission +from flask_login import current_user + + +def with_github_client(f): + @wraps(f) + def decorated_function(*args, **kwargs): + datadoc_id = kwargs.get("datadoc_id") + github_link = logic.get_repo_link(datadoc_id) + github_client = GitHubClient(github_link) + return f(github_client, *args, **kwargs) + + return decorated_function @register("/github/auth/", methods=["GET"]) @@ -14,3 +34,71 @@ def is_github_authenticated() -> str: github_manager = get_github_manager() is_authenticated = github_manager.get_github_token() is not None return {"is_authenticated": is_authenticated} + + +@register("/github/datadocs//link/", methods=["POST"]) +def link_datadoc_to_github( + datadoc_id: int, + repo_url: str, + branch: str, + file_path: str, +) -> Dict: + return logic.create_repo_link( + datadoc_id=datadoc_id, + user_id=current_user.id, + repo_url=repo_url, + branch=branch, + file_path=file_path, + ) + + +@register("/github/datadocs//commit/", methods=["POST"]) +@with_github_client +def commit_datadoc( + github_client: GitHubClient, + datadoc_id: int, +) -> Dict: + with DBSession() as session: + datadoc = datadoc_logic.get_data_doc_by_id(datadoc_id, session=session) + api_assert( + datadoc is not None, + "DataDoc not found", + status_code=RESOURCE_NOT_FOUND_STATUS_CODE, + ) + assert_can_write(datadoc_id, session=session) + verify_data_doc_permission(datadoc_id, session=session) + github_client.commit_datadoc(datadoc) + return {"message": "DataDoc committed successfully"} + + +@register("/github/datadocs//versions/", methods=["GET"]) +@with_github_client +def get_datadoc_versions(github_client: GitHubClient, datadoc_id: int) -> List[Dict]: + datadoc = datadoc_logic.get_data_doc_by_id(datadoc_id) + api_assert( + datadoc is not None, + "DataDoc not found", + status_code=RESOURCE_NOT_FOUND_STATUS_CODE, + ) + assert_can_read(datadoc_id) + verify_data_doc_permission(datadoc_id) + versions = github_client.get_datadoc_versions(datadoc) + return versions + + +@register("/github/datadocs//restore/", methods=["POST"]) +@with_github_client +def restore_datadoc_version( + github_client: GitHubClient, datadoc_id: int, commit_sha: str +) -> Dict: + datadoc = datadoc_logic.get_data_doc_by_id(datadoc_id) + api_assert( + datadoc is not None, + "DataDoc not found", + status_code=RESOURCE_NOT_FOUND_STATUS_CODE, + ) + assert_can_write(datadoc_id) + verify_data_doc_permission(datadoc_id) + restored_datadoc = github_client.get_datadoc_at_commit(datadoc.id, commit_sha) + saved_datadoc = datadoc_logic.restore_data_doc(restored_datadoc) + return saved_datadoc.to_dict(with_cells=True) diff --git a/querybook/server/lib/github_integration/serializers.py b/querybook/server/lib/github_integration/serializers.py new file mode 100644 index 000000000..93ab0f9f9 --- /dev/null +++ b/querybook/server/lib/github_integration/serializers.py @@ -0,0 +1,41 @@ +import json +from models.datadoc import DataDoc, DataCell +from const.data_doc import DataCellType + + +def serialize_datadoc(datadoc: DataDoc) -> str: + datadoc_dict = datadoc.to_dict(with_cells=True) + return json.dumps(datadoc_dict, indent=4, default=str) + + +def deserialize_datadoc(json_content: str) -> DataDoc: + datadoc_dict = json.loads(json_content) + datadoc = DataDoc( + id=datadoc_dict.get("id"), + environment_id=datadoc_dict.get("environment_id"), + public=datadoc_dict.get("public", True), + archived=datadoc_dict.get("archived", False), + owner_uid=datadoc_dict.get("owner_uid"), + created_at=datadoc_dict.get("created_at"), + updated_at=datadoc_dict.get("updated_at"), + title=datadoc_dict.get("title", ""), + ) + + # Need to set the meta attribute directly + datadoc.meta = datadoc_dict.get("meta") + + # Deserialize cells + cells_data = datadoc_dict.get("cells", []) + cells = [] + for cell_dict in cells_data: + cell = DataCell( + id=cell_dict.get("id"), + cell_type=DataCellType[cell_dict.get("cell_type")], + context=cell_dict.get("context"), + meta=cell_dict.get("meta"), + created_at=cell_dict.get("created_at"), + updated_at=cell_dict.get("updated_at"), + ) + cells.append(cell) + datadoc.cells = cells + return datadoc diff --git a/querybook/server/logic/datadoc.py b/querybook/server/logic/datadoc.py index 37ff2ddbf..baca40740 100644 --- a/querybook/server/logic/datadoc.py +++ b/querybook/server/logic/datadoc.py @@ -248,6 +248,46 @@ def clone_data_doc(id, owner_uid, commit=True, session=None): return new_data_doc +@with_session +def restore_data_doc(restored_datadoc: DataDoc, commit=True, session=None) -> DataDoc: + # Update the DataDoc fields + updated_datadoc = update_data_doc( + id=restored_datadoc.id, + commit=False, + session=session, + **{ + "public": restored_datadoc.public, + "archived": restored_datadoc.archived, + "owner_uid": restored_datadoc.owner_uid, + "title": restored_datadoc.title, + "meta": restored_datadoc.meta, + }, + ) + + # Update each DataCell + for restored_cell in restored_datadoc.cells: + update_data_cell( + id=restored_cell.id, + commit=False, + session=session, + **{ + "context": restored_cell.context, + "meta": restored_cell.meta, + }, + ) + + if commit: + session.commit() + update_es_data_doc_by_id(updated_datadoc.id) + update_es_queries_by_datadoc_id(updated_datadoc.id) + else: + session.flush() + + session.refresh(updated_datadoc) + + return updated_datadoc + + """ ---------------------------------------------------------------------------------------------------------- DATA CELL diff --git a/querybook/server/logic/github.py b/querybook/server/logic/github.py new file mode 100644 index 000000000..4fcacaf3a --- /dev/null +++ b/querybook/server/logic/github.py @@ -0,0 +1,44 @@ +from app.db import with_session +from models.github import GitHubLink +from models.datadoc import DataDoc + + +@with_session +def create_repo_link( + datadoc_id: int, + user_id: int, + repo_url: str, + branch: str, + file_path: str, + commit=True, + session=None, +): + datadoc = DataDoc.get(id=datadoc_id, session=session) + assert datadoc is not None, f"DataDoc with id {datadoc_id} not found" + + github_link = GitHubLink.get(datadoc_id=datadoc_id, session=session) + assert ( + github_link is None + ), f"GitHub link for DataDoc with id {datadoc_id} already exists" + + github_link = GitHubLink.create( + { + "datadoc_id": datadoc_id, + "user_id": user_id, + "repo_url": repo_url, + "branch": branch, + "file_path": file_path, + }, + commit=commit, + session=session, + ) + return github_link + + +@with_session +def get_repo_link(datadoc_id: int, session=None): + github_link = GitHubLink.get(datadoc_id=datadoc_id, session=session) + assert ( + github_link is not None + ), f"GitHub link for DataDoc with id {datadoc_id} not found" + return github_link diff --git a/querybook/server/models/__init__.py b/querybook/server/models/__init__.py index cf3dce9f2..6550df625 100644 --- a/querybook/server/models/__init__.py +++ b/querybook/server/models/__init__.py @@ -15,3 +15,4 @@ from .data_element import * from .comment import * from .survey import * +from .github import * diff --git a/querybook/server/models/github.py b/querybook/server/models/github.py new file mode 100644 index 000000000..bc8157ed6 --- /dev/null +++ b/querybook/server/models/github.py @@ -0,0 +1,41 @@ +import sqlalchemy as sql +from sqlalchemy.sql import func +from lib.sqlalchemy import CRUDMixin +from sqlalchemy.orm import backref, relationship +from app import db + +Base = db.Base + + +class GitHubLink(Base, CRUDMixin): + __tablename__ = "github_link" + id = sql.Column(sql.Integer, primary_key=True, autoincrement=True) + datadoc_id = sql.Column( + sql.Integer, sql.ForeignKey("data_doc.id"), nullable=False, unique=True + ) + user_id = sql.Column(sql.Integer, sql.ForeignKey("user.id"), nullable=False) + repo_url = sql.Column(sql.String(255), nullable=False) + branch = sql.Column(sql.String(255), nullable=False) + file_path = sql.Column(sql.String(255), nullable=False) + created_at = sql.Column(sql.DateTime, server_default=func.now(), nullable=False) + updated_at = sql.Column( + sql.DateTime, server_default=func.now(), onupdate=func.now(), nullable=False + ) + + datadoc = relationship( + "DataDoc", + backref=backref("github_link", uselist=False, cascade="all, delete-orphan"), + ) + user = relationship("User", backref=backref("github_link", uselist=False)) + + def to_dict(self): + return { + "id": self.id, + "datadoc_id": self.datadoc_id, + "user_id": self.user_id, + "repo_url": self.repo_url, + "branch": self.branch, + "file_path": self.file_path, + "created_at": self.created_at, + "updated_at": self.updated_at, + } diff --git a/querybook/tests/test_lib/test_github_integration/test_github_client.py b/querybook/tests/test_lib/test_github_integration/test_github_client.py new file mode 100644 index 000000000..dda0c1a08 --- /dev/null +++ b/querybook/tests/test_lib/test_github_integration/test_github_client.py @@ -0,0 +1,113 @@ +import pytest +from unittest.mock import MagicMock +from clients.github_client import GitHubClient +from models.datadoc import DataDoc +from models.github import GitHubLink +from github import GithubException + + +@pytest.fixture +def mock_flask_session(monkeypatch): + session = {} + monkeypatch.setattr("clients.github_client.flask_session", session) + return session + + +@pytest.fixture +def mock_github(monkeypatch): + mock_github = MagicMock() + monkeypatch.setattr("clients.github_client.Github", mock_github) + return mock_github + + +@pytest.fixture +def mock_github_link(): + return GitHubLink( + datadoc_id=1, + user_id=1, + repo_url="user/repo", + branch="main", + file_path="path/to/datadoc.json", + ) + + +@pytest.fixture +def mock_repo(): + return MagicMock() + + +def test_initialization(mock_flask_session, mock_github, mock_github_link, mock_repo): + mock_flask_session["github_access_token"] = "fake_token" + mock_github_instance = mock_github.return_value + mock_github_instance.get_repo.return_value = mock_repo + + client = GitHubClient(mock_github_link) + assert client.client is not None + assert client.user is not None + assert client.repo is not None + + +def test_initialization_no_token(mock_flask_session, mock_github_link): + with pytest.raises(Exception) as excinfo: + GitHubClient(mock_github_link) + assert "GitHub OAuth token not found in session" in str(excinfo.value) + + +def test_commit_datadoc_update( + mock_flask_session, mock_github, mock_github_link, mock_repo +): + mock_flask_session["github_access_token"] = "fake_token" + mock_github_instance = mock_github.return_value + mock_github_instance.get_repo.return_value = mock_repo + mock_repo.get_contents.return_value = MagicMock(sha="fake_sha") + + client = GitHubClient(mock_github_link) + datadoc = DataDoc(id=1, title="Test Doc") + client.commit_datadoc(datadoc) + mock_repo.update_file.assert_called_once() + + +def test_commit_datadoc_create( + mock_flask_session, mock_github, mock_github_link, mock_repo +): + mock_flask_session["github_access_token"] = "fake_token" + mock_github_instance = mock_github.return_value + mock_github_instance.get_repo.return_value = mock_repo + mock_repo.get_contents.side_effect = GithubException(404, "Not Found", None) + + client = GitHubClient(mock_github_link) + datadoc = DataDoc(id=1, title="Test Doc") + client.commit_datadoc(datadoc) + mock_repo.create_file.assert_called_once() + + +def test_get_datadoc_versions( + mock_flask_session, mock_github, mock_github_link, mock_repo +): + mock_flask_session["github_access_token"] = "fake_token" + mock_github_instance = mock_github.return_value + mock_github_instance.get_repo.return_value = mock_repo + mock_commit = MagicMock() + mock_commit.raw_data = {"sha": "123"} + mock_repo.get_commits.return_value = [mock_commit] + + client = GitHubClient(mock_github_link) + datadoc = DataDoc(id=1, title="Test Doc") + versions = client.get_datadoc_versions(datadoc) + assert len(versions) == 1 + assert versions[0]["sha"] == "123" + + +def test_get_datadoc_at_commit( + mock_flask_session, mock_github, mock_github_link, mock_repo +): + mock_flask_session["github_access_token"] = "fake_token" + mock_github_instance = mock_github.return_value + mock_github_instance.get_repo.return_value = mock_repo + mock_contents = mock_repo.get_contents.return_value + mock_contents.decoded_content = b'{"id": 1, "title": "Test Doc", "meta": {}}' + + client = GitHubClient(mock_github_link) + datadoc = client.get_datadoc_at_commit(1, "commit_sha") + assert datadoc.id == 1 + assert datadoc.title == "Test Doc" diff --git a/querybook/tests/test_lib/test_github_integration/test_serializers.py b/querybook/tests/test_lib/test_github_integration/test_serializers.py new file mode 100644 index 000000000..ac415ec6f --- /dev/null +++ b/querybook/tests/test_lib/test_github_integration/test_serializers.py @@ -0,0 +1,68 @@ +import pytest +import json +from const.data_doc import DataCellType +from lib.github_integration.serializers import serialize_datadoc, deserialize_datadoc +from models.datadoc import DataCell, DataDoc + + +@pytest.fixture +def mock_datadoc(): + cells = [ + DataCell( + id=1, + cell_type=DataCellType.query, + context="SELECT * FROM table;", + created_at="2023-01-01T00:00:00Z", + updated_at="2023-01-01T00:00:00Z", + ), + DataCell( + id=2, + cell_type=DataCellType.text, + context="This is a text cell.", + created_at="2023-01-01T00:00:00Z", + updated_at="2023-01-01T00:00:00Z", + ), + ] + datadoc = DataDoc( + id=1, + environment_id=1, + public=True, + archived=False, + owner_uid="user1", + created_at="2023-01-01T00:00:00Z", + updated_at="2023-01-01T00:00:00Z", + title="Test DataDoc", + cells=cells, + ) + return datadoc + + +def test_serialize_datadoc(mock_datadoc): + serialized = serialize_datadoc(mock_datadoc) + datadoc_dict = mock_datadoc.to_dict(with_cells=True) + expected_serialized = json.dumps(datadoc_dict, indent=4, default=str) + assert serialized == expected_serialized + + +def test_deserialize_datadoc(mock_datadoc): + datadoc_dict = mock_datadoc.to_dict(with_cells=True) + json_content = json.dumps(datadoc_dict, indent=4, default=str) + deserialized = deserialize_datadoc(json_content) + + assert deserialized.id == mock_datadoc.id + assert deserialized.environment_id == mock_datadoc.environment_id + assert deserialized.public == mock_datadoc.public + assert deserialized.archived == mock_datadoc.archived + assert deserialized.owner_uid == mock_datadoc.owner_uid + assert deserialized.created_at == mock_datadoc.created_at + assert deserialized.updated_at == mock_datadoc.updated_at + assert deserialized.title == mock_datadoc.title + assert deserialized.meta == mock_datadoc.meta + assert len(deserialized.cells) == len(mock_datadoc.cells) + for d_cell, m_cell in zip(deserialized.cells, mock_datadoc.cells): + assert d_cell.id == m_cell.id + assert d_cell.cell_type == m_cell.cell_type + assert d_cell.context == m_cell.context + assert d_cell.meta == m_cell.meta + assert d_cell.created_at == m_cell.created_at + assert d_cell.updated_at == m_cell.updated_at diff --git a/requirements/extra.txt b/requirements/extra.txt index db5edd23e..04b77c8af 100644 --- a/requirements/extra.txt +++ b/requirements/extra.txt @@ -28,3 +28,6 @@ # AI Assistant -r ai/langchain.txt + +# Github +-r github_integration/github.txt diff --git a/requirements/github_integration/github.txt b/requirements/github_integration/github.txt new file mode 100644 index 000000000..9879283d5 --- /dev/null +++ b/requirements/github_integration/github.txt @@ -0,0 +1 @@ +pygithub==2.4.0