-
Notifications
You must be signed in to change notification settings - Fork 8
/
main.py
176 lines (131 loc) · 6.41 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
import numpy as np
import math
import openai
import torch
import os
import sys
import argparse
import traceback
import multiprocessing
import logging
import functools
import models
import config
from lang_sam import LangSAM
from multiprocessing import Process, Pipe
from io import StringIO
from contextlib import redirect_stdout
from api import API
from env import run_simulation_environment
from prompts.main_prompt import MAIN_PROMPT
from prompts.error_correction_prompt import ERROR_CORRECTION_PROMPT
from prompts.print_output_prompt import PRINT_OUTPUT_PROMPT
from prompts.task_failure_prompt import TASK_FAILURE_PROMPT
from prompts.task_summary_prompt import TASK_SUMMARY_PROMPT
from config import OK, PROGRESS, FAIL, ENDC
# Make XMem's package-internal absolute imports (e.g. "model.network")
# resolvable before importing from it below — the append must stay ahead
# of the `from XMem...` import.
sys.path.append("./XMem/")
# Shadow the builtin print so every print in this module — including prints
# captured from exec'd, LLM-generated code — is flushed immediately rather
# than buffered (important when stdout is redirected/piped).
print = functools.partial(print, flush=True)
from XMem.model.network import XMem
if __name__ == "__main__":

    # OpenAI credentials come from the environment; os.getenv returns None
    # if unset, which would only surface later as an API auth error.
    openai.api_key = os.getenv("OPENAI_API_KEY")

    # Parse args: which chat model to drive the loop with, which robot the
    # simulated environment should load, and the run mode.
    parser = argparse.ArgumentParser(description="Main Program.")
    parser.add_argument("-lm", "--language_model", choices=["gpt-4", "gpt-4-32k", "gpt-3.5-turbo", "gpt-3.5-turbo-16k"], default="gpt-4", help="select language model")
    parser.add_argument("-r", "--robot", choices=["sawyer", "franka"], default="sawyer", help="select robot")
    parser.add_argument("-m", "--mode", choices=["default", "debug"], default="default", help="select mode to run")
    args = parser.parse_args()

    # Logging: multiprocessing's stderr logger is shared with the child
    # environment process started below, so both log to one stream.
    logger = multiprocessing.log_to_stderr()
    logger.setLevel(logging.INFO)

    # Device selection; this script only runs inference, so gradients are
    # disabled globally.
    if torch.cuda.is_available():
        logger.info("Using GPU.")
        device = torch.device("cuda")
    else:
        logger.info("CUDA not available. Please connect to a GPU instance if possible.")
        device = torch.device("cpu")
    torch.set_grad_enabled(False)

    # Load models: LangSAM for open-vocabulary detection/segmentation and
    # XMem (weights from ./XMem/saves/XMem.pth) for video object tracking.
    langsam_model = LangSAM()
    xmem_model = XMem(config.xmem_config, "./XMem/saves/XMem.pth", device).eval().to(device)

    # API set-up. The API methods are deliberately re-bound as bare
    # module-level names: LLM-generated code exec'd below calls
    # detect_object(...), open_gripper(), etc. directly, and exec resolves
    # those names from this scope. Do not rename or move these bindings.
    main_connection, env_connection = Pipe()
    api = API(args, main_connection, logger, langsam_model, xmem_model, device)
    detect_object = api.detect_object
    execute_trajectory = api.execute_trajectory
    open_gripper = api.open_gripper
    close_gripper = api.close_gripper
    task_completed = api.task_completed

    # Start the simulation environment in a child process; it talks back
    # over the other end of the Pipe. The recv() blocks until the
    # environment signals readiness with a single-element message
    # (unpacked with the [x] = ... form). NOTE(review): env_process is
    # never joined or terminated; the script presumably relies on process
    # teardown at exit — confirm.
    env_process = Process(target=run_simulation_environment, name="EnvProcess", args=[args, env_connection, logger])
    env_process.start()
    [env_connection_message] = main_connection.recv()
    logger.info(env_connection_message)

    # User input: the natural-language task the LLM should accomplish.
    command = input("Enter a command: ")
    api.command = command

    # ChatGPT: seed the conversation with the main prompt, filling in the
    # end-effector start position and the user's task.
    logger.info(PROGRESS + "STARTING TASK..." + ENDC)
    messages = []
    error = False
    new_prompt = MAIN_PROMPT.replace("[INSERT EE POSITION]", str(config.ee_start_position)).replace("[INSERT TASK]", command)
    logger.info(PROGRESS + "Generating ChatGPT output..." + ENDC)
    messages = models.get_chatgpt_output(args.language_model, new_prompt, messages, "system")
    logger.info(OK + "Finished generating ChatGPT output!" + ENDC)

    # Outer loop: one iteration per user task. Inner loop: chat rounds
    # until the exec'd code calls task_completed() (api.completed_task).
    while True:
        while not api.completed_task:
            new_prompt = ""
            # Extract every ```python ... ``` fenced block from the last
            # model reply: split on the opening fence, then take the text
            # before the closing fence in each fragment.
            if len(messages[-1]["content"].split("```python")) > 1:
                code_block = messages[-1]["content"].split("```python")
                block_number = 0
                for block in code_block:
                    if len(block.split("```")) > 1:
                        code = block.split("```")[0]
                        block_number += 1
                        # SECURITY: exec runs LLM-generated code with full
                        # interpreter privileges; acceptable only because
                        # the model output is treated as trusted here.
                        # stdout is captured so print output can be fed
                        # back to the model.
                        try:
                            f = StringIO()
                            with redirect_stdout(f):
                                exec(code)
                        except Exception:
                            # Feed the traceback back to the model, tagged
                            # with which fenced block failed.
                            error_message = traceback.format_exc()
                            new_prompt += ERROR_CORRECTION_PROMPT.replace("[INSERT BLOCK NUMBER]", str(block_number)).replace("[INSERT ERROR MESSAGE]", error_message)
                            new_prompt += "\n"
                            error = True
                        else:
                            s = f.getvalue()
                            error = False
                            # Non-empty (and not huge) print output is also
                            # returned to the model. `error` is reused here
                            # to force another chat round even though
                            # nothing actually failed.
                            if s != "" and len(s) < 2000:
                                new_prompt += PRINT_OUTPUT_PROMPT.replace("[INSERT PRINT STATEMENT OUTPUT]", s)
                                new_prompt += "\n"
                                error = True
            # Any pending feedback overrides completion/failure signalled
            # by the exec'd code: clear both flags and keep iterating.
            if error:
                api.completed_task = False
                api.failed_task = False
            if not api.completed_task:
                if api.failed_task:
                    # Failure path: first ask the model to summarise the
                    # attempt, then restart the conversation from scratch
                    # with the main prompt plus that summary.
                    logger.info(FAIL + "FAILED TASK! Generating summary of the task execution attempt..." + ENDC)
                    new_prompt += TASK_SUMMARY_PROMPT
                    new_prompt += "\n"
                    logger.info(PROGRESS + "Generating ChatGPT output..." + ENDC)
                    messages = models.get_chatgpt_output(args.language_model, new_prompt, messages, "user")
                    logger.info(OK + "Finished generating ChatGPT output!" + ENDC)
                    logger.info(PROGRESS + "RETRYING TASK..." + ENDC)
                    new_prompt = MAIN_PROMPT.replace("[INSERT EE POSITION]", str(config.ee_start_position)).replace("[INSERT TASK]", command)
                    new_prompt += "\n"
                    # messages[-1] is the summary just generated above.
                    new_prompt += TASK_FAILURE_PROMPT.replace("[INSERT TASK SUMMARY]", messages[-1]["content"])
                    messages = []
                    error = False
                    logger.info(PROGRESS + "Generating ChatGPT output..." + ENDC)
                    messages = models.get_chatgpt_output(args.language_model, new_prompt, messages, "system")
                    logger.info(OK + "Finished generating ChatGPT output!" + ENDC)
                    api.failed_task = False
                else:
                    # Normal continuation: send accumulated feedback (error
                    # tracebacks and/or print output) as the user turn.
                    logger.info(PROGRESS + "Generating ChatGPT output..." + ENDC)
                    messages = models.get_chatgpt_output(args.language_model, new_prompt, messages, "user")
                    logger.info(OK + "Finished generating ChatGPT output!" + ENDC)
        # Task done: prompt for the next command in the SAME conversation
        # (messages is not reset here) and loop again.
        logger.info(OK + "FINISHED TASK!" + ENDC)
        new_prompt = input("Enter a command: ")
        logger.info(PROGRESS + "Generating ChatGPT output..." + ENDC)
        messages = models.get_chatgpt_output(args.language_model, new_prompt, messages, "user")
        logger.info(OK + "Finished generating ChatGPT output!" + ENDC)
        api.completed_task = False