-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #8 from shankarpandala/feature/dev
Feature/dev
- Loading branch information
Showing
7 changed files
with
78 additions
and
48 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,30 +1,43 @@ | ||
import json | ||
from langchain.prompts import ChatPromptTemplate | ||
from langchain.output_parsers import ResponseSchema, StructuredOutputParser | ||
import re | ||
|
||
from langchain.chains import ConversationalRetrievalChain | ||
from langchain.agents import Tool | ||
from langchain.tools import DuckDuckGoSearchRun | ||
from langchain.agents import initialize_agent | ||
from lazygitgpt.llms import chat_model | ||
from lazygitgpt.datasources.repos import read_repository_contents | ||
from lazygitgpt.git.operations import update_files | ||
from lazygitgpt.retrievers.retrievalqa import retriever | ||
from lazygitgpt.memory.memory import memory | ||
|
||
search = DuckDuckGoSearchRun() | ||
|
||
def generate_response(prompt): | ||
inputs = {'chat_history': '', 'question': prompt} | ||
qa = ConversationalRetrievalChain.from_llm(chat_model, retriever=retriever, memory=memory) | ||
result = qa(inputs) | ||
return result["answer"] | ||
|
||
# tools = [ | ||
# Tool( | ||
# name='DuckDuckGo Search', | ||
# func= search.run, | ||
# description="Useful for when you need to do a search on the internet to find information that another tool can't find. be specific with your input." | ||
# ), | ||
# Tool( | ||
# name='Conversational Retrieval', | ||
# func=generate_response, | ||
# description="This is Conversational Retrieval chain which has content of the entire repository." | ||
# ) | ||
# ] | ||
|
||
output_schema = ResponseSchema(name='filename', description='contents', type='string') | ||
output_parser = StructuredOutputParser(response_schemas=[output_schema]) | ||
format_instructions = output_parser.get_format_instructions() | ||
template_string = """You are an expert programmer. | ||
You are reviewing a code repository. | ||
Read the code and make changes to the code as per the user requirements. | ||
user requirements: {user_requirements} | ||
code repository: {code_repository} | ||
Output the contents of the file that you changed as per the format instructions : {format_instructions} | ||
""" | ||
# zero_shot_agent = initialize_agent( | ||
# agent="zero-shot-react-description", | ||
# tools=tools, | ||
# llm=chat_model, | ||
# verbose=True, | ||
# max_iterations=30, | ||
# retriever=retriever | ||
# ) | ||
|
||
def generate_response(prompt, sources=read_repository_contents()): | ||
sources_str = json.dumps(sources, indent=4) | ||
prompt_template = ChatPromptTemplate.from_template(template_string) | ||
messages = prompt_template.format_messages(user_requirements = prompt, | ||
code_repository = sources_str, | ||
format_instructions=format_instructions) | ||
response = chat_model(messages) | ||
response_json = response.to_json() | ||
data = response_json['kwargs']['content'] | ||
return data | ||
# def run(prompt): | ||
# reponse = zero_shot_agent.run(prompt) | ||
# return reponse |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,26 +1,30 @@ | ||
import os | ||
import glob | ||
import json | ||
from git import Repo | ||
from langchain.document_loaders import GitLoader | ||
from langchain.document_loaders.generic import GenericLoader | ||
from langchain.document_loaders.parsers import LanguageParser | ||
from langchain.text_splitter import Language | ||
from langchain.text_splitter import RecursiveCharacterTextSplitter | ||
|
||
def read_repository_contents(directory_path=os.getcwd(), file_pattern="*"): | ||
""" | ||
Reads all files in the specified directory matching the file pattern, | ||
and creates a JSON object with file names and their contents. | ||
def read_repository_contents(): | ||
repo_path = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | ||
Repo.git_dir=repo_path | ||
repo = Repo(Repo.git_dir) | ||
branch = repo.head.reference | ||
|
||
Args: | ||
directory_path (str): Path to the directory containing the files. | ||
file_pattern (str): Pattern to match files. Defaults to '*' (all files). | ||
Returns: | ||
str: A JSON string containing the file names and their contents. | ||
""" | ||
data = {} | ||
for file_path in glob.glob(f"{directory_path}/{file_pattern}"): | ||
if os.path.isfile(file_path): | ||
try: | ||
with open(file_path, 'r', encoding='utf-8') as file: | ||
data[file_path] = file.read() | ||
except Exception as e: | ||
print(f"Error reading file: {file_path} - {e}") | ||
|
||
return json.dumps(data, indent=4) | ||
loader = GitLoader(repo_path, branch=branch) | ||
docs = loader.load() | ||
loader = GenericLoader.from_filesystem( | ||
repo_path, | ||
glob="**/*", | ||
suffixes=[".py"], | ||
parser=LanguageParser(language=Language.PYTHON, parser_threshold=500), | ||
) | ||
documents = loader.load() | ||
python_splitter = RecursiveCharacterTextSplitter.from_language( | ||
language=Language.PYTHON, chunk_size=2000, chunk_overlap=200 | ||
) | ||
texts = python_splitter.split_documents(documents) | ||
return texts |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .memory import memory |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from langchain.memory import ConversationBufferMemory | ||
|
||
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
from langchain.embeddings.openai import OpenAIEmbeddings | ||
from langchain.vectorstores import Chroma | ||
from lazygitgpt.datasources.repos import read_repository_contents | ||
|
||
db = Chroma.from_documents(read_repository_contents(), OpenAIEmbeddings(disallowed_special=())) | ||
retriever = db.as_retriever( | ||
search_type="mmr", # Also test "similarity" | ||
search_kwargs={"k": 1000}, | ||
) |
Empty file.