Massive PR on many parts #54

Merged
merged 31 commits into from
Oct 21, 2023
31 commits
7630e04
Merge remote-tracking branch 'origin/bug-in-current-viewport-gitlab' …
shuyanzhou Sep 18, 2023
1fef526
add huggingface model support
shuyanzhou Sep 22, 2023
507659a
better rendering of typing action
shuyanzhou Sep 22, 2023
2b15f20
multi threading auto login; auto login per example
shuyanzhou Sep 22, 2023
e84910d
better error message for env config
shuyanzhou Sep 22, 2023
493294b
fix statictext bounding box bug
shuyanzhou Sep 22, 2023
57d2067
add support to evaluate by trace
shuyanzhou Sep 22, 2023
16f2592
support generation retry when the parsing of the action failed
shuyanzhou Sep 22, 2023
9f3e4ac
ignore cache
shuyanzhou Sep 22, 2023
0e7bcda
add prompts
shuyanzhou Sep 23, 2023
741292e
fix force_prefix missing bug
shuyanzhou Sep 23, 2023
1ee1ea4
fix typo
shuyanzhou Sep 23, 2023
c1ae73c
add script to check inference failures
shuyanzhou Sep 23, 2023
6fdbd92
add parallel running script
shuyanzhou Sep 23, 2023
cd7d593
fix annotation errors based on human trajectories
shuyanzhou Sep 26, 2023
b6e0b22
change reddit vote related posts to absolute urls
shuyanzhou Sep 26, 2023
f8d636a
update URL matching, fix typos
shuyanzhou Sep 26, 2023
b4c917d
update must_include tokenization condition; upate url match
shuyanzhou Sep 26, 2023
a7c475b
remove unused evaluators
shuyanzhou Sep 26, 2023
50e2c43
remove exact from evalutor names
shuyanzhou Sep 26, 2023
db063c7
update test example due to html escape
shuyanzhou Sep 26, 2023
6ab7fd2
update fuzzy match prompt
shuyanzhou Sep 27, 2023
58061ee
reduce coordinate precision; fix template 67 annotations
shuyanzhou Sep 27, 2023
4b86d43
fix locator for product; add prep action; fix url for promo rules
shuyanzhou Oct 20, 2023
df87757
add options to renew cookie for selected sites
shuyanzhou Oct 20, 2023
3d3d837
print unfinished examples
shuyanzhou Oct 20, 2023
7730a85
reduce openai max retry
shuyanzhou Oct 20, 2023
f91eb5b
minor
shuyanzhou Oct 20, 2023
4cec5ac
minor
shuyanzhou Oct 21, 2023
7a1f8d6
Merge remote-tracking branch 'origin/main' into new_eval
shuyanzhou Oct 21, 2023
9f0900f
fix type errors
shuyanzhou Oct 21, 2023
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
@@ -34,7 +34,7 @@ jobs:
           mypy --version
           # Run this mypy instance against our main package.
           mypy --install-types --non-interactive .
-          mypy --strict .
+          mypy --strict . --exclude scripts
       - name: Enviroment prepare
         run: |
           bash prepare.sh
23 changes: 12 additions & 11 deletions .gitignore
@@ -141,18 +141,19 @@ run.sh

 # trajectory visualization
 render_cache/*
+cache/*

 # TMP IGNORE
 agent/prompts/jsons/*
 log_files/
-config_files/*0.json
-config_files/*1.json
-config_files/*2.json
-config_files/*3.json
-config_files/*4.json
-config_files/*5.json
-config_files/*6.json
-config_files/*7.json
-config_files/*8.json
-config_files/*9.json
-config_files/test.json
+config_files*/*0.json
+config_files*/*1.json
+config_files*/*2.json
+config_files*/*3.json
+config_files*/*4.json
+config_files*/*5.json
+config_files*/*6.json
+config_files*/*7.json
+config_files*/*8.json
+config_files*/*9.json
+config_files*/test.json
95 changes: 33 additions & 62 deletions agent/agent.py
@@ -15,11 +15,14 @@
     create_playwright_action,
 )
 from browser_env.utils import Observation, StateInfo
-from llms import lm_config
-from llms.providers.openai_utils import (
+from llms import (
+    call_llm,
+    generate_from_huggingface_completion,
     generate_from_openai_chat_completion,
     generate_from_openai_completion,
+    lm_config,
 )
+from llms.tokenizers import Tokenizer


 class Agent:
@@ -120,82 +123,50 @@ def next_action(
             trajectory, intent, meta_data
         )
         lm_config = self.lm_config
-        if lm_config.provider == "openai":
-            if lm_config.mode == "chat":
-                response = generate_from_openai_chat_completion(
-                    messages=prompt,
-                    model=lm_config.model,
-                    temperature=lm_config.gen_config["temperature"],
-                    top_p=lm_config.gen_config["top_p"],
-                    context_length=lm_config.gen_config["context_length"],
-                    max_tokens=lm_config.gen_config["max_tokens"],
-                    stop_token=None,
-                )
-            elif lm_config.mode == "completion":
-                response = generate_from_openai_completion(
-                    prompt=prompt,
-                    engine=lm_config.model,
-                    temperature=lm_config.gen_config["temperature"],
-                    max_tokens=lm_config.gen_config["max_tokens"],
-                    top_p=lm_config.gen_config["top_p"],
-                    stop_token=lm_config.gen_config["stop_token"],
-                )
-            else:
-                raise ValueError(
-                    f"OpenAI models do not support mode {lm_config.mode}"
+        n = 0
+        while True:
+            response = call_llm(lm_config, prompt)
+            force_prefix = self.prompt_constructor.instruction[
+                "meta_data"
+            ].get("force_prefix", "")
+            response = f"{force_prefix}{response}"
+            n += 1
+            try:
+                parsed_response = self.prompt_constructor.extract_action(
+                    response
                 )
-        else:
-            raise NotImplementedError(
-                f"Provider {lm_config.provider} not implemented"
-            )
-
-        try:
-            parsed_response = self.prompt_constructor.extract_action(response)
-            if self.action_set_tag == "id_accessibility_tree":
-                action = create_id_based_action(parsed_response)
-            elif self.action_set_tag == "playwright":
-                action = create_playwright_action(parsed_response)
-            else:
-                raise ValueError(f"Unknown action type {self.action_set_tag}")
-
-            action["raw_prediction"] = response
-
-        except ActionParsingError as e:
-            action = create_none_action()
-            action["raw_prediction"] = response
+                if self.action_set_tag == "id_accessibility_tree":
+                    action = create_id_based_action(parsed_response)
+                elif self.action_set_tag == "playwright":
+                    action = create_playwright_action(parsed_response)
+                else:
+                    raise ValueError(
+                        f"Unknown action type {self.action_set_tag}"
+                    )
+                action["raw_prediction"] = response
+                break
+            except ActionParsingError as e:
+                if n >= lm_config.gen_config["max_retry"]:
+                    action = create_none_action()
+                    action["raw_prediction"] = response
+                    break

         return action

     def reset(self, test_config_file: str) -> None:
         pass


-def construct_llm_config(args: argparse.Namespace) -> lm_config.LMConfig:
-    llm_config = lm_config.LMConfig(
-        provider=args.provider, model=args.model, mode=args.mode
-    )
-    if args.provider == "openai":
-        llm_config.gen_config["temperature"] = args.temperature
-        llm_config.gen_config["top_p"] = args.top_p
-        llm_config.gen_config["context_length"] = args.context_length
-        llm_config.gen_config["max_tokens"] = args.max_tokens
-        llm_config.gen_config["stop_token"] = args.stop_token
-        llm_config.gen_config["max_obs_length"] = args.max_obs_length
-    else:
-        raise NotImplementedError(f"provider {args.provider} not implemented")
-    return llm_config
-
-
 def construct_agent(args: argparse.Namespace) -> Agent:
-    llm_config = construct_llm_config(args)
+    llm_config = lm_config.construct_llm_config(args)

     agent: Agent
     if args.agent_type == "teacher_forcing":
         agent = TeacherForcingAgent()
     elif args.agent_type == "prompt":
         with open(args.instruction_path) as f:
             constructor_type = json.load(f)["meta_data"]["prompt_constructor"]
-        tokenizer = tiktoken.encoding_for_model(llm_config.model)
+        tokenizer = Tokenizer(args.provider, args.model)
         prompt_constructor = eval(constructor_type)(
             args.instruction_path, lm_config=llm_config, tokenizer=tokenizer
         )
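The `next_action` change above replaces the single parse attempt with a loop that regenerates until the action parses or `max_retry` attempts are used up, at which point it falls back to a none action. A minimal standalone sketch of that control flow follows. The `extract_action` parser, the `|` delimiter, the `generate` callable, and the dict-shaped action are hypothetical stand-ins for illustration, not the project's API.

```python
from typing import Callable


class ActionParsingError(Exception):
    """Raised when no action can be extracted from a model response."""


def extract_action(response: str) -> str:
    # Hypothetical parser: the action is whatever sits between two "|" marks.
    parts = response.split("|")
    if len(parts) < 3:
        raise ActionParsingError(f"Cannot parse action from response {response}")
    return parts[1].strip()


def next_action(generate: Callable[[], str], max_retry: int) -> dict[str, str]:
    n = 0
    while True:
        response = generate()
        n += 1
        try:
            parsed = extract_action(response)
            return {"action": parsed, "raw_prediction": response}
        except ActionParsingError:
            if n >= max_retry:
                # Out of attempts: fall back to a no-op action,
                # keeping the last raw response for inspection.
                return {"action": "none", "raw_prediction": response}


# First response is unparseable, second one succeeds.
responses = iter(["no action here", "I will |click [12]| now"])
result = next_action(lambda: next(responses), max_retry=3)
```

Because the counter is incremented before parsing, `max_retry` bounds the total number of generations rather than the number of extra ones.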
58 changes: 45 additions & 13 deletions agent/prompts/prompt_constructor.py
@@ -3,14 +3,12 @@
 from pathlib import Path
 from typing import Any, TypedDict

-import tiktoken
-
 from browser_env import Action, ActionParsingError, Trajectory
 from browser_env.env_config import URL_MAPPINGS
 from browser_env.utils import StateInfo
 from llms import lm_config
-
-APIInput = str | list[Any] | dict[str, Any]
+from llms.tokenizers import Tokenizer
+from llms.utils import APIInput


 class Instruction(TypedDict):
@@ -27,12 +25,12 @@ def __init__(
         self,
         instruction_path: str | Path,
         lm_config: lm_config.LMConfig,
-        tokenizer: tiktoken.core.Encoding,
+        tokenizer: Tokenizer,
     ):
-        self.instrction_path = Path(instruction_path)
+        self.instruction_path = Path(instruction_path)
         self.obs_modality = "text"
         self.lm_config = lm_config
-        instruction = json.load(open(self.instrction_path))
+        instruction = json.load(open(self.instruction_path))
         instruction["examples"] = [tuple(e) for e in instruction["examples"]]
         self.instruction: Instruction = instruction
         self.tokenizer = tokenizer
@@ -77,6 +75,37 @@ def get_lm_api_input(
                 raise ValueError(
                     f"OpenAI models do not support mode {self.lm_config.mode}"
                 )
+        elif "huggingface" in self.lm_config.provider:
+            # https://huggingface.co/blog/llama2#how-to-prompt-llama-2
+            # https://github.com/facebookresearch/llama/blob/main/llama/generation.py#L320
+            if "Llama-2" in self.lm_config.model:
+                if self.lm_config.mode == "chat":
+                    B_INST, E_INST = "[INST]", "[/INST]"
+                    B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
+                    BOS, EOS = "<s>", "</s>"
+                    # adding the system message to be the starting of the first example
+                    examples = [
+                        (
+                            B_SYS + intro + E_SYS + examples[0][0],
+                            examples[0][1],
+                        )
+                    ] + examples[1:]
+                    message = "".join(
+                        [
+                            f"{BOS}{B_INST} {x.strip()} {E_INST} {y.strip()} {EOS}"
+                            for (x, y) in examples
+                        ]
+                    )
+                    # add the current observation
+                    message += f"{BOS}{B_INST} {current.strip()} {E_INST} {self.instruction['meta_data'].get('force_prefix', '')}"
+
+                    return message
+                else:
+                    raise ValueError("Only chat mode is supported for Llama-2")
+            else:
+                raise ValueError(
+                    f"Huggingface models do not support model_tag {self.lm_config.gen_config['model_tag']}"
+                )
         else:
             raise NotImplementedError(
                 f"Provider {self.lm_config.provider} not implemented"
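The Llama-2 branch above flattens the system message, the few-shot examples, and the current observation into one string: the system message is folded into the first user turn, every example becomes a closed `[INST] ... [/INST]` turn, and the last turn is left open for the model. A standalone sketch of that assembly, reusing the delimiters from the diff, is below. The function name `build_llama2_chat_prompt` and the sample strings are hypothetical, and like the diff it assumes at least one example.

```python
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
BOS, EOS = "<s>", "</s>"


def build_llama2_chat_prompt(
    intro: str,
    examples: list[tuple[str, str]],
    current: str,
    force_prefix: str = "",
) -> str:
    # Fold the system message into the first user turn
    # (assumes examples is non-empty, as the code in the diff does).
    first_user, first_reply = examples[0]
    turns = [(B_SYS + intro + E_SYS + first_user, first_reply)] + examples[1:]
    message = "".join(
        f"{BOS}{B_INST} {user.strip()} {E_INST} {reply.strip()} {EOS}"
        for user, reply in turns
    )
    # The final turn stays open so the model writes the reply;
    # force_prefix, if any, seeds the start of that reply.
    message += f"{BOS}{B_INST} {current.strip()} {E_INST} {force_prefix}"
    return message


prompt_text = build_llama2_chat_prompt(
    intro="You are a web agent.",
    examples=[("OBSERVATION: a", "click [1]")],
    current="OBSERVATION: b",
)
```

Since `force_prefix` is already part of the prompt, the model does not repeat it, which is presumably why `next_action` prepends it to the response before parsing.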
@@ -102,6 +131,9 @@ def map_url_to_local(self, url: str) -> str:
         for i, j in URL_MAPPINGS.items():
             if j in url:
                 url = url.replace(j, i)
+            # https
+            if j.replace("http", "https") in url:
+                url = url.replace(j.replace("http", "https"), i)
         return url

     def _extract_action(self, response: str) -> str:
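The `map_url_to_local` change above also rewrites the `https` spelling of each mapped address, which the plain substring check misses because `http://site` is not a substring of `https://site`. A small sketch of the same logic as a free function follows. The single `URL_MAPPINGS` entry is a hypothetical example; the real mapping lives in `browser_env.env_config`.

```python
# Hypothetical mapping: local deployment address -> public address.
URL_MAPPINGS = {
    "http://localhost:7770": "http://onestopmarket.com",
}


def map_url_to_local(url: str) -> str:
    for local, public in URL_MAPPINGS.items():
        if public in url:
            url = url.replace(public, local)
        # Also catch the https spelling of the same public address.
        https_public = public.replace("http", "https")
        if https_public in url:
            url = url.replace(https_public, local)
    return url


from_http = map_url_to_local("http://onestopmarket.com/cart")
from_https = map_url_to_local("https://onestopmarket.com/cart")
untouched = map_url_to_local("http://example.org/")
```

The sketch keeps the diff's `replace("http", "https")` as is, so it assumes every mapped public address uses plain `http` and contains the substring `http` only once.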
@@ -120,7 +152,7 @@ def __init__(
         self,
         instruction_path: str | Path,
         lm_config: lm_config.LMConfig,
-        tokenizer: tiktoken.core.Encoding,
+        tokenizer: Tokenizer,
     ):
         super().__init__(instruction_path, lm_config, tokenizer)

@@ -161,10 +193,10 @@ def construct(

     def _extract_action(self, response: str) -> str:
         action_splitter = self.instruction["meta_data"]["action_splitter"]
-        pattern = rf"{action_splitter}(.*?){action_splitter}"
+        pattern = rf"{action_splitter}((.|\n)*?){action_splitter}"
         match = re.search(pattern, response)
         if match:
-            return match.group(1)
+            return match.group(1).strip()
         else:
             raise ActionParsingError(
                 f"Cannot parse action from response {response}"
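The pattern change in `_extract_action` matters when the model puts the action on its own line between the splitters: `.` does not match a newline by default, so the old pattern finds nothing, while `(.|\n)` does. The sketch below compares both patterns on a made-up response. The splitter is built with `"`" * 3` only to keep literal triple backticks out of the example.

```python
import re

splitter = "`" * 3  # the three-backtick action_splitter used by the prompts
response = (
    "In summary, the next action I will perform is "
    f"{splitter}\nclick [1234]\n{splitter}"
)

old_pattern = rf"{splitter}(.*?){splitter}"
new_pattern = rf"{splitter}((.|\n)*?){splitter}"

# The old pattern cannot cross the newlines around the action.
old_match = re.search(old_pattern, response)

# The new pattern captures the action with its surrounding newlines,
# which the added .strip() then removes.
new_match = re.search(new_pattern, response)
action = new_match.group(1).strip() if new_match else None
```

Keeping `(.*?)` and passing `re.DOTALL` to `re.search` would match the same text; the diff takes the explicit alternation instead.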
@@ -178,7 +210,7 @@ def __init__(
         self,
         instruction_path: str | Path,
         lm_config: lm_config.LMConfig,
-        tokenizer: tiktoken.core.Encoding,
+        tokenizer: Tokenizer,
     ):
         super().__init__(instruction_path, lm_config, tokenizer)
         self.answer_phrase = self.instruction["meta_data"]["answer_phrase"]
@@ -218,10 +250,10 @@ def construct(
     def _extract_action(self, response: str) -> str:
         # find the first occurence of action
         action_splitter = self.instruction["meta_data"]["action_splitter"]
-        pattern = rf"{action_splitter}(.*?){action_splitter}"
+        pattern = rf"{action_splitter}((.|\n)*?){action_splitter}"
         match = re.search(pattern, response)
         if match:
-            return match.group(1)
+            return match.group(1).strip()
         else:
             raise ActionParsingError(
                 f'Cannot find the answer phrase "{self.answer_phrase}" in "{response}"'
82 changes: 82 additions & 0 deletions agent/prompts/raw/p_cot_id_actree_2s_no_na.py
@@ -0,0 +1,82 @@
prompt = {
"intro": """You are an autonomous intelligent agent tasked with navigating a web browser. You will be given web-based tasks. These tasks will be accomplished through the use of specific actions you can issue.

Here's the information you'll have:
The user's objective: This is the task you're trying to complete.
The current web page's accessibility tree: This is a simplified representation of the webpage, providing key information.
The current web page's URL: This is the page you're currently navigating.
The open tabs: These are the tabs you have open.
The previous action: This is the action you just performed. It may be helpful to track your progress.

The actions you can perform fall into several categories:

Page Operation Actions:
`click [id]`: This action clicks on an element with a specific id on the webpage.
`type [id] [content] [press_enter_after=0|1]`: Use this to type the content into the field with id. By default, the "Enter" key is pressed after typing unless press_enter_after is set to 0.
`hover [id]`: Hover over an element with id.
`press [key_comb]`: Simulates the pressing of a key combination on the keyboard (e.g., Ctrl+v).
`scroll [direction=down|up]`: Scroll the page up or down.

Tab Management Actions:
`new_tab`: Open a new, empty browser tab.
`tab_focus [tab_index]`: Switch the browser's focus to a specific tab using its index.
`close_tab`: Close the currently active tab.

URL Navigation Actions:
`goto [url]`: Navigate to a specific URL.
`go_back`: Navigate to the previously viewed page.
`go_forward`: Navigate to the next page (if a previous 'go_back' action was performed).

Completion Action:
`stop [answer]`: Issue this action when you believe the task is complete. If the objective is to find a text-based answer, provide the answer in the bracket.

Homepage:
If you want to visit other websites, check out the homepage at http://homepage.com. It has a list of websites you can visit.
http://homepage.com/password.html lists all the account name and password for the websites. You can use them to log in to the websites.

To be successful, it is very important to follow the following rules:
1. You should only issue an action that is valid given the current observation
2. You should only issue one action at a time.
3. You should follow the examples to reason step by step and then issue the next action.
4. Generate the action in the correct format. Start with a "In summary, the next action I will perform is" phrase, followed by action inside ``````. For example, "In summary, the next action I will perform is ```click [1234]```".
5. Issue stop action when you think you have achieved the objective. Don't generate anything after stop.""",
"examples": [
(
"""OBSERVATION:
[1744] link 'HP CB782A#ABA 640 Inkjet Fax Machine (Renewed)'
[1749] StaticText '$279.49'
[1757] button 'Add to Cart'
[1760] button 'Add to Wish List'
[1761] button 'Add to Compare'
URL: http://onestopmarket.com/office-products/office-electronics.html
OBJECTIVE: What is the price of HP Inkjet Fax Machine
PREVIOUS ACTION: None""",
"Let's think step-by-step. This page list the information of HP Inkjet Fax Machine, which is the product identified in the objective. Its price is $279.49. I think I have achieved the objective. I will issue the stop action with the answer. In summary, the next action I will perform is ```stop [$279.49]```",
),
(
"""OBSERVATION:
[164] textbox 'Search' focused: True required: False
[171] button 'Go'
[174] link 'Find directions between two points'
[212] heading 'Search Results'
[216] button 'Close'
URL: http://openstreetmap.org
OBJECTIVE: Show me the restaurants near CMU
PREVIOUS ACTION: None""",
"Let's think step-by-step. This page has a search box whose ID is [164]. According to the nominatim rule of openstreetmap, I can search for the restaurants near a location by \"restaurants near\". I can submit my typing by pressing the Enter afterwards. In summary, the next action I will perform is ```type [164] [restaurants near CMU] [1]```",
),
],
"template": """OBSERVATION:
{observation}
URL: {url}
OBJECTIVE: {objective}
PREVIOUS ACTION: {previous_action}""",
"meta_data": {
"observation": "accessibility_tree",
"action_type": "id_accessibility_tree",
"keywords": ["url", "objective", "observation", "previous_action"],
"prompt_constructor": "CoTPromptConstructor",
"answer_phrase": "In summary, the next action I will perform is",
"action_splitter": "```"
},
}
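The `template` and `meta_data["keywords"]` entries of this prompt file name the four fields that are filled in for each step. How `CoTPromptConstructor` consumes them is not part of this diff, so the following is only a sketch of one plausible way to fill the template with `str.format`. The `values` dict is hypothetical sample data.

```python
template = """OBSERVATION:
{observation}
URL: {url}
OBJECTIVE: {objective}
PREVIOUS ACTION: {previous_action}"""

keywords = ["url", "objective", "observation", "previous_action"]

# Hypothetical values for one step of an episode.
values = {
    "observation": "[1757] button 'Add to Cart'",
    "url": "http://onestopmarket.com",
    "objective": "Add the fax machine to the cart",
    "previous_action": "None",
}

# Check that every keyword the prompt declares has a value before formatting.
missing = [k for k in keywords if k not in values]
current = template.format(**values)
```

Checking `keywords` first turns a missing field into an explicit list rather than a `KeyError` raised from inside `format`.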