From 1fef526178f585bbf770d1a6db412ac331b9b44b Mon Sep 17 00:00:00 2001 From: alexisxy Date: Thu, 21 Sep 2023 23:22:17 -0400 Subject: [PATCH 01/29] add huggingface model support --- agent/agent.py | 22 ++++++++++++- agent/prompts/prompt_constructor.py | 51 ++++++++++++++++++++++++----- llms/providers/hf_utils.py | 21 ++++++++++++ llms/providers/openai_utils.py | 5 +++ llms/tokenizers.py | 17 ++++++++-- run.py | 16 +++++++-- 6 files changed, 117 insertions(+), 15 deletions(-) create mode 100644 llms/providers/hf_utils.py diff --git a/agent/agent.py b/agent/agent.py index 240ce0b..5229101 100644 --- a/agent/agent.py +++ b/agent/agent.py @@ -16,10 +16,12 @@ ) from browser_env.utils import Observation, StateInfo from llms import lm_config +from llms.providers.hf_utils import generate_from_huggingface_completion from llms.providers.openai_utils import ( generate_from_openai_chat_completion, generate_from_openai_completion, ) +from llms.tokenizers import Tokenizer class Agent: @@ -144,6 +146,15 @@ def next_action( raise ValueError( f"OpenAI models do not support mode {lm_config.mode}" ) + elif lm_config.provider == "huggingface": + response = generate_from_huggingface_completion( + prompt=prompt, + model_endpoint=lm_config.gen_config["model_endpoint"], + temperature=lm_config.gen_config["temperature"], + top_p=lm_config.gen_config["top_p"], + stop_sequences=lm_config.gen_config["stop_sequences"], + max_new_tokens=lm_config.gen_config["max_new_tokens"], + ) else: raise NotImplementedError( f"Provider {lm_config.provider} not implemented" @@ -181,6 +192,15 @@ def construct_llm_config(args: argparse.Namespace) -> lm_config.LMConfig: llm_config.gen_config["max_tokens"] = args.max_tokens llm_config.gen_config["stop_token"] = args.stop_token llm_config.gen_config["max_obs_length"] = args.max_obs_length + elif args.provider == "huggingface": + llm_config.gen_config["temperature"] = args.temperature + llm_config.gen_config["top_p"] = args.top_p + 
llm_config.gen_config["max_new_tokens"] = args.max_tokens + llm_config.gen_config["stop_sequences"] = ( + [args.stop_token] if args.stop_token else None + ) + llm_config.gen_config["max_obs_length"] = args.max_obs_length + llm_config.gen_config["model_endpoint"] = args.model_endpoint else: raise NotImplementedError(f"provider {args.provider} not implemented") return llm_config @@ -195,7 +215,7 @@ def construct_agent(args: argparse.Namespace) -> Agent: elif args.agent_type == "prompt": with open(args.instruction_path) as f: constructor_type = json.load(f)["meta_data"]["prompt_constructor"] - tokenizer = tiktoken.encoding_for_model(llm_config.model) + tokenizer = Tokenizer(args.provider, args.model) prompt_constructor = eval(constructor_type)( args.instruction_path, lm_config=llm_config, tokenizer=tokenizer ) diff --git a/agent/prompts/prompt_constructor.py b/agent/prompts/prompt_constructor.py index 6e2d3cb..575236e 100644 --- a/agent/prompts/prompt_constructor.py +++ b/agent/prompts/prompt_constructor.py @@ -3,12 +3,11 @@ from pathlib import Path from typing import Any, TypedDict -import tiktoken - from browser_env import Action, ActionParsingError, Trajectory from browser_env.env_config import URL_MAPPINGS from browser_env.utils import StateInfo from llms import lm_config +from llms.tokenizers import Tokenizer APIInput = str | list[Any] | dict[str, Any] @@ -27,7 +26,7 @@ def __init__( self, instruction_path: str | Path, lm_config: lm_config.LMConfig, - tokenizer: tiktoken.core.Encoding, + tokenizer: Tokenizer, ): self.instrction_path = Path(instruction_path) self.obs_modality = "text" @@ -77,6 +76,37 @@ def get_lm_api_input( raise ValueError( f"OpenAI models do not support mode {self.lm_config.mode}" ) + elif "huggingface" in self.lm_config.provider: + # https://huggingface.co/blog/llama2#how-to-prompt-llama-2 + # https://github.com/facebookresearch/llama/blob/main/llama/generation.py#L320 + if "Llama-2" in self.lm_config.model: + if self.lm_config.mode == "chat": 
+ B_INST, E_INST = "[INST]", "[/INST]" + B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n" + BOS, EOS = "<s>", "</s>" + # adding the system message to be the starting of the first example + examples = [ + ( + B_SYS + intro + E_SYS + examples[0][0], + examples[0][1], + ) + ] + examples[1:] + message = "".join( + [ + f"{BOS}{B_INST} {x.strip()} {E_INST} {y.strip()} {EOS}" + for (x, y) in examples + ] + ) + # add the current observation + message += f"{BOS}{B_INST} {current.strip()} {E_INST} {self.instruction['meta_data'].get('force_prefix', '')}" + + return message + else: + raise ValueError("Only chat mode is supported for Llama-2") + else: + raise ValueError( + f"Huggingface models do not support model_tag {self.lm_config.gen_config['model_tag']}" + ) else: raise NotImplementedError( f"Provider {self.lm_config.provider} not implemented" @@ -102,6 +132,9 @@ def map_url_to_local(self, url: str) -> str: for i, j in URL_MAPPINGS.items(): if j in url: url = url.replace(j, i) + # https + if j.replace("http", "https") in url: + url = url.replace(j.replace("http", "https"), i) return url def _extract_action(self, response: str) -> str: @@ -120,7 +153,7 @@ def __init__( self, instruction_path: str | Path, lm_config: lm_config.LMConfig, - tokenizer: tiktoken.core.Encoding, + tokenizer: Tokenizer, ): super().__init__(instruction_path, lm_config, tokenizer) @@ -161,10 +194,10 @@ def construct( def _extract_action(self, response: str) -> str: action_splitter = self.instruction["meta_data"]["action_splitter"] - pattern = rf"{action_splitter}(.*?){action_splitter}" + pattern = rf"{action_splitter}((.|\n)*?){action_splitter}" match = re.search(pattern, response) if match: - return match.group(1) + return match.group(1).strip() else: raise ActionParsingError( f"Cannot parse action from response {response}" @@ -178,7 +211,7 @@ def __init__( self, instruction_path: str | Path, lm_config: lm_config.LMConfig, - tokenizer: tiktoken.core.Encoding, + tokenizer: Tokenizer, ): super().__init__(instruction_path, 
lm_config, tokenizer) self.answer_phrase = self.instruction["meta_data"]["answer_phrase"] @@ -218,10 +251,10 @@ def construct( def _extract_action(self, response: str) -> str: # find the first occurence of action action_splitter = self.instruction["meta_data"]["action_splitter"] - pattern = rf"{action_splitter}(.*?){action_splitter}" + pattern = rf"{action_splitter}((.|\n)*?){action_splitter}" match = re.search(pattern, response) if match: - return match.group(1) + return match.group(1).strip() else: raise ActionParsingError( f'Cannot find the answer phrase "{self.answer_phrase}" in "{response}"' diff --git a/llms/providers/hf_utils.py b/llms/providers/hf_utils.py new file mode 100644 index 0000000..c5a3f11 --- /dev/null +++ b/llms/providers/hf_utils.py @@ -0,0 +1,21 @@ +from text_generation import Client + + +def generate_from_huggingface_completion( + prompt: str, + model_endpoint: str, + temperature: float, + top_p: float, + max_new_tokens: int, + stop_sequences: list[str] | None = None, +) -> str: + client = Client(model_endpoint, timeout=60) + generation = client.generate( + prompt=prompt, + temperature=temperature, + top_p=top_p, + max_new_tokens=max_new_tokens, + stop_sequences=stop_sequences, + ).generated_text + + return generation diff --git a/llms/providers/openai_utils.py b/llms/providers/openai_utils.py index 75d03ee..05887f4 100644 --- a/llms/providers/openai_utils.py +++ b/llms/providers/openai_utils.py @@ -115,6 +115,7 @@ async def agenerate_from_openai_completion( "OPENAI_API_KEY environment variable must be set when using OpenAI API." ) openai.api_key = os.environ["OPENAI_API_KEY"] + openai.organization = os.environ.get("OPENAI_ORGANIZATION", "") limiter = aiolimiter.AsyncLimiter(requests_per_minute) async_responses = [ @@ -147,6 +148,7 @@ def generate_from_openai_completion( "OPENAI_API_KEY environment variable must be set when using OpenAI API." 
) openai.api_key = os.environ["OPENAI_API_KEY"] + openai.organization = os.environ.get("OPENAI_ORGANIZATION", "") response = openai.Completion.create( # type: ignore prompt=prompt, engine=engine, @@ -218,6 +220,7 @@ async def agenerate_from_openai_chat_completion( "OPENAI_API_KEY environment variable must be set when using OpenAI API." ) openai.api_key = os.environ["OPENAI_API_KEY"] + openai.organization = os.environ.get("OPENAI_ORGANIZATION", "") limiter = aiolimiter.AsyncLimiter(requests_per_minute) async_responses = [ @@ -250,6 +253,7 @@ def generate_from_openai_chat_completion( "OPENAI_API_KEY environment variable must be set when using OpenAI API." ) openai.api_key = os.environ["OPENAI_API_KEY"] + openai.organization = os.environ.get("OPENAI_ORGANIZATION", "") response = openai.ChatCompletion.create( # type: ignore model=model, @@ -279,5 +283,6 @@ def fake_generate_from_openai_chat_completion( "OPENAI_API_KEY environment variable must be set when using OpenAI API." ) openai.api_key = os.environ["OPENAI_API_KEY"] + openai.organization = os.environ.get("OPENAI_ORGANIZATION", "") answer = "Let's think step-by-step. This page shows a list of links and buttons. There is a search box with the label 'Search query'. I will click on the search box to type the query. So the action I will perform is \"click [60]\"." 
return answer diff --git a/llms/tokenizers.py b/llms/tokenizers.py index 24763a6..67aa231 100644 --- a/llms/tokenizers.py +++ b/llms/tokenizers.py @@ -1,14 +1,27 @@ from typing import Any import tiktoken +from transformers import LlamaTokenizer class Tokenizer(object): - def __init__(self, model_name: str) -> None: - if model_name in ["gpt-4", "gpt-turbo-3.5"]: + def __init__(self, provider: str, model_name: str) -> None: + if provider == "openai": self.tokenizer = tiktoken.encoding_for_model(model_name) + elif provider == "huggingface": + self.tokenizer = LlamaTokenizer.from_pretrained(model_name) + # turn off adding special tokens automatically + self.tokenizer.add_special_tokens = False + self.tokenizer.add_bos_token = False + self.tokenizer.add_eos_token = False else: raise NotImplementedError + def encode(self, text: str) -> list[int]: + return self.tokenizer.encode(text) + + def decode(self, ids: list[int]) -> str: + return self.tokenizer.decode(ids) + def __call__(self, text: str) -> list[int]: return self.tokenizer.encode(text) diff --git a/run.py b/run.py index c4781c2..7d3d648 100644 --- a/run.py +++ b/run.py @@ -5,6 +5,8 @@ import logging import os import random +import subprocess +import tempfile import time from pathlib import Path @@ -26,6 +28,7 @@ create_stop_action, ) from browser_env.actions import is_equivalent +from browser_env.auto_login import get_site_comb_from_filepath from browser_env.helper_functions import ( RenderHelper, get_action_description, @@ -122,6 +125,12 @@ def config() -> argparse.Namespace: help="when not zero, will truncate the observation to this length before feeding to the model", default=1920, ) + parser.add_argument( + "--model_endpoint", + help="huggingface model endpoint", + type=str, + default="", + ) # example config parser.add_argument("--test_start_idx", type=int, default=0) @@ -376,7 +385,7 @@ def dump_config(args: argparse.Namespace) -> None: if __name__ == "__main__": args = config() - args.sleep_after_execution = 
2.5 + args.sleep_after_execution = 2.0 prepare(args) test_file_list = [] @@ -384,9 +393,10 @@ def dump_config(args: argparse.Namespace) -> None: ed_idx = args.test_end_idx for i in range(st_idx, ed_idx): test_file_list.append(f"config_files/{i}.json") - test_file_list = get_unfinished(test_file_list, args.result_dir) + if "debug" not in args.result_dir: + test_file_list = get_unfinished(test_file_list, args.result_dir) print(f"Total {len(test_file_list)} tasks left") - args.render = True + args.render = False args.render_screenshot = True args.save_trace_enabled = True From 507659a1820026321c49348a61bd9ac8f0b37a4a Mon Sep 17 00:00:00 2001 From: alexisxy Date: Thu, 21 Sep 2023 23:25:09 -0400 Subject: [PATCH 02/29] better rendering of typing action --- browser_env/actions.py | 1 + 1 file changed, 1 insertion(+) diff --git a/browser_env/actions.py b/browser_env/actions.py index 950eeb1..eb3a859 100644 --- a/browser_env/actions.py +++ b/browser_env/actions.py @@ -125,6 +125,7 @@ def action2str( action_str = f"click [{element_id}] where [{element_id}] is {semantic_element}" case ActionTypes.TYPE: text = "".join([_id2key[i] for i in action["text"]]) + text = text.replace("\n", " ") action_str = f"type [{element_id}] [{text}] where [{element_id}] is {semantic_element}" case ActionTypes.HOVER: action_str = f"hover [{element_id}] where [{element_id}] is {semantic_element}" From 2b15f206d25363c0b49391cc2f89f0f859c71128 Mon Sep 17 00:00:00 2001 From: alexisxy Date: Thu, 21 Sep 2023 23:26:52 -0400 Subject: [PATCH 03/29] multi threading auto login; auto login per example --- browser_env/auto_login.py | 96 +++++++++++++++++++++++++-------------- run.py | 22 +++++++++ 2 files changed, 85 insertions(+), 33 deletions(-) diff --git a/browser_env/auto_login.py b/browser_env/auto_login.py index d466603..7602deb 100644 --- a/browser_env/auto_login.py +++ b/browser_env/auto_login.py @@ -1,5 +1,8 @@ """Script to automatically login each website""" +import argparse import glob +import os 
+from concurrent.futures import ThreadPoolExecutor from itertools import combinations from pathlib import Path @@ -17,6 +20,17 @@ SLOW_MO = 0 +SITES = ["gitlab", "shopping", "shopping_admin", "reddit"] +URLS = [ + f"{GITLAB}/-/profile", + f"{SHOPPING}/wishlist/", + f"{SHOPPING_ADMIN}/dashboard", + f"{REDDIT}/user/{ACCOUNTS['reddit']['username']}/account", +] +EXACT_MATCH = [True, True, True, True] +KEYWORDS = ["", "", "Dashboard", "Delete"] + + def is_expired( storage_state: Path, url: str, keyword: str, url_exact: bool = True ) -> bool: @@ -42,7 +56,7 @@ def is_expired( return url not in d_url -def renew_comb(comb: list[str]) -> None: +def renew_comb(comb: list[str], auth_folder: str = "./.auth") -> None: context_manager = sync_playwright() playwright = context_manager.__enter__() browser = playwright.chromium.launch(headless=HEADLESS) @@ -83,42 +97,58 @@ def renew_comb(comb: list[str]) -> None: page.get_by_test_id("password-field").fill(password) page.get_by_test_id("sign-in-button").click() - context.storage_state(path=f"./.auth/{'.'.join(comb)}_state.json") + context.storage_state(path=f"{auth_folder}/{'.'.join(comb)}_state.json") context_manager.__exit__() -def main() -> None: - sites = ["gitlab", "shopping", "shopping_admin", "reddit"] - urls = [ - f"{GITLAB}/-/profile", - f"{SHOPPING}/wishlist/", - f"{SHOPPING_ADMIN}/dashboard", - f"{REDDIT}/user/{ACCOUNTS['reddit']['username']}/account", - ] - exact_match = [True, True, True, True] - keywords = ["", "", "Dashboard", "Delete"] - - pairs = list(combinations(sites, 2)) - for pair in pairs: - # TODO[shuyanzh] auth don't work on these two sites - if "reddit" in pair and ( - "shopping" in pair or "shopping_admin" in pair - ): - continue - renew_comb(list(sorted(pair))) - - for site in sites: - renew_comb([site]) - - for c_file in glob.glob("./.auth/*.json"): - comb = c_file.split("/")[-1].rsplit("_", 1)[0].split(".") - for cur_site in comb: - url = urls[sites.index(cur_site)] - keyword = 
keywords[sites.index(cur_site)] - match = exact_match[sites.index(cur_site)] - assert not is_expired(Path(c_file), url, keyword, match) +def get_site_comb_from_filepath(file_path: str) -> list[str]: + comb = os.path.basename(file_path).rsplit("_", 1)[0].split(".") + return comb + + +def main(auth_folder: str = "./.auth") -> None: + pairs = list(combinations(SITES, 2)) + + max_workers = 8 + with ThreadPoolExecutor(max_workers=max_workers) as executor: + for pair in pairs: + # TODO[shuyanzh] auth don't work on these two sites + if "reddit" in pair and ( + "shopping" in pair or "shopping_admin" in pair + ): + continue + executor.submit( + renew_comb, list(sorted(pair)), auth_folder=auth_folder + ) + + for site in SITES: + executor.submit(renew_comb, [site], auth_folder=auth_folder) + + futures = [] + cookie_files = list(glob.glob(f"{auth_folder}/*.json")) + with ThreadPoolExecutor(max_workers=max_workers) as executor: + for c_file in cookie_files: + comb = get_site_comb_from_filepath(c_file) + for cur_site in comb: + url = URLS[SITES.index(cur_site)] + keyword = KEYWORDS[SITES.index(cur_site)] + match = EXACT_MATCH[SITES.index(cur_site)] + future = executor.submit( + is_expired, Path(c_file), url, keyword, match + ) + futures.append(future) + + for i, future in enumerate(futures): + assert not future.result(), f"Cookie {cookie_files[i]} expired." 
if __name__ == "__main__": - main() + parser = argparse.ArgumentParser() + parser.add_argument("--site_list", nargs="+", default=[]) + parser.add_argument("--auth_folder", type=str, default="./.auth") + args = parser.parse_args() + if not args.site_list: + main() + else: + renew_comb(args.site_list, auth_folder=args.auth_folder) diff --git a/run.py b/run.py index 7d3d648..d4766cd 100644 --- a/run.py +++ b/run.py @@ -245,6 +245,28 @@ def test( _c = json.load(f) intent = _c["intent"] task_id = _c["task_id"] + # automatically login + if _c["storage_state"]: + cookie_file_name = os.path.basename(_c["storage_state"]) + comb = get_site_comb_from_filepath(cookie_file_name) + temp_dir = tempfile.mkdtemp() + # subprocess to renew the cookie + subprocess.run( + [ + "python", + "browser_env/auto_login.py", + "--auth_folder", + temp_dir, + "--site_list", + *comb, + ] + ) + _c["storage_state"] = f"{temp_dir}/{cookie_file_name}" + assert os.path.exists(_c["storage_state"]) + # update the config file + config_file = f"{temp_dir}/{os.path.basename(config_file)}" + with open(config_file, "w") as f: + json.dump(_c, f) logger.info(f"[Config file]: {config_file}") logger.info(f"[Intent]: {intent}") From e84910dd36bd726745aa5596921cadfa406a7b8d Mon Sep 17 00:00:00 2001 From: alexisxy Date: Thu, 21 Sep 2023 23:27:31 -0400 Subject: [PATCH 04/29] better error message for env config --- browser_env/env_config.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/browser_env/env_config.py b/browser_env/env_config.py index e3eac6a..81cf52d 100644 --- a/browser_env/env_config.py +++ b/browser_env/env_config.py @@ -18,14 +18,14 @@ and MAP and HOMEPAGE ), ( - f"Please setup the URLs to each site. Current: " - + f"Reddit: {REDDIT}" - + f"Shopping: {SHOPPING}" - + f"Shopping Admin: {SHOPPING_ADMIN}" - + f"Gitlab: {GITLAB}" - + f"Wikipedia: {WIKIPEDIA}" - + f"Map: {MAP}" - + f"Homepage: {HOMEPAGE}" + f"Please setup the URLs to each site. 
Current: \n" + + f"Reddit: {REDDIT}\n" + + f"Shopping: {SHOPPING}\n" + + f"Shopping Admin: {SHOPPING_ADMIN}\n" + + f"Gitlab: {GITLAB}\n" + + f"Wikipedia: {WIKIPEDIA}\n" + + f"Map: {MAP}\n" + + f"Homepage: {HOMEPAGE}\n" ) From 493294bde9ca07734170911e3a169e3a9b2b56a6 Mon Sep 17 00:00:00 2001 From: alexisxy Date: Thu, 21 Sep 2023 23:29:04 -0400 Subject: [PATCH 05/29] fix statictext bounding box bug --- browser_env/processors.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/browser_env/processors.py b/browser_env/processors.py index d4de787..0ab9c64 100644 --- a/browser_env/processors.py +++ b/browser_env/processors.py @@ -121,7 +121,15 @@ def get_bounding_client_rect( "objectId": remote_object_id, "functionDeclaration": """ function() { - return this.getBoundingClientRect().toJSON(); + if (this.nodeType == 3) { + var range = document.createRange(); + range.selectNode(this); + var rect = range.getBoundingClientRect().toJSON(); + range.detach(); + return rect; + } else { + return this.getBoundingClientRect().toJSON(); + } } """, "returnByValue": True, @@ -231,8 +239,6 @@ def fetch_page_html( # get the bound if cur_node["parentId"] == "-1": cur_node["union_bound"] = [0.0, 0.0, 10.0, 10.0] - elif cur_node["nodeName"] == "#text": - todo_nodes[node_idx] = int(cur_node["parentId"]) else: response = self.get_bounding_client_rect( client, cur_node["backendNodeId"] @@ -392,8 +398,6 @@ def fetch_page_accessibility_tree( if node["role"]["value"] == "RootWebArea": # always inside the viewport node["union_bound"] = [0.0, 0.0, 10.0, 10.0] - elif node["role"]["value"] == "StaticText": - todo_nodes[cursor] = node["parentId"] else: response = self.get_bounding_client_rect( client, backend_node_id From 57d206748c31454ecae6bbcd3b8f5f64c2b55a1e Mon Sep 17 00:00:00 2001 From: alexisxy Date: Thu, 21 Sep 2023 23:31:05 -0400 Subject: [PATCH 06/29] add support to evaluate by trace --- evaluation_harness/evaluate_by_trace.py | 66 +++++++++++++++++++++++++ 
evaluation_harness/evaluators.py | 21 ++++---- evaluation_harness/helper_functions.py | 13 +++++ 3 files changed, 91 insertions(+), 9 deletions(-) create mode 100644 evaluation_harness/evaluate_by_trace.py diff --git a/evaluation_harness/evaluate_by_trace.py b/evaluation_harness/evaluate_by_trace.py new file mode 100644 index 0000000..3820789 --- /dev/null +++ b/evaluation_harness/evaluate_by_trace.py @@ -0,0 +1,66 @@ +"""Evaluate by using the traces.zip files saved""" +import argparse +import json +import os +import sys +import tempfile +import zipfile + +from playwright.sync_api import Page, sync_playwright + +from evaluation_harness import evaluator_router +from evaluation_harness.helper_functions import PseudoPage + + +def eval_trace(trace_path: str, task_id: int, config_file_folder: str): + # load the config file + config_file = f"{config_file_folder}/{task_id}.json" + with open(config_file, "r") as f: + config = json.load(f) + + if "string_match" in config["eval"]["eval_types"]: + raise ValueError( + "string_match is not supported in this evaluation script" + ) + + # extract the last url from the trace file + temp_dir = tempfile.TemporaryDirectory() + with zipfile.ZipFile(trace_path, "r") as zip_ref: + zip_ref.extractall(temp_dir.name) + with open(f"{temp_dir.name}/trace.trace", "r") as f: + trace = [] + for line in f: + trace.append(json.loads(line)) + last_url = "" + for step in trace[::-1]: + if step.get("type", None) == "frame-snapshot": + last_url = step["snapshot"]["frameUrl"] + break + if not last_url: + raise ValueError("Cannot find the last url in the trace file") + + # start the playwright + context_manager = sync_playwright() + playwright = context_manager.__enter__() + browser = playwright.chromium.launch(headless=True) + context = browser.new_context() + page = context.new_page() + page.goto("https://trace.playwright.dev/") + with page.expect_file_chooser() as fc_info: + page.get_by_role("button", name="Select file(s)").click() + file_chooser = 
fc_info.value + file_chooser.set_files(trace_path) + with page.expect_popup() as page1_info: + page.get_by_role("button", name="").click() + page1 = page1_info.value + + pseudo_page = PseudoPage(page1, last_url) + evaluator = evaluator_router(config_file) + + score = evaluator( + trajectory=[], + config_file=config_file, + page=pseudo_page, + client=pseudo_page.context.new_cdp_session(pseudo_page), + ) + print(score) diff --git a/evaluation_harness/evaluators.py b/evaluation_harness/evaluators.py index 6a4eb5a..30c3a5c 100644 --- a/evaluation_harness/evaluators.py +++ b/evaluation_harness/evaluators.py @@ -16,6 +16,7 @@ from browser_env.actions import Action from browser_env.utils import StateInfo from evaluation_harness.helper_functions import ( + PseudoPage, gitlab_get_project_memeber_role, llm_fuzzy_match, reddit_get_post_url, @@ -36,7 +37,7 @@ def __call__( self, trajectory: Trajectory, config_file: Path | str, - page: Page, + page: Page | PseudoPage, client: CDPSession, ) -> float: raise NotImplementedError @@ -112,7 +113,7 @@ def __call__( self, trajectory: Trajectory, config_file: Path | str, - page: Page | None = None, + page: Page | PseudoPage | None = None, client: CDPSession | None = None, ) -> float: with open(config_file, "r") as f: @@ -148,7 +149,7 @@ def __call__( self, trajectory: Trajectory, config_file: Path | str, - page: Page | None = None, + page: Page | PseudoPage | None = None, client: CDPSession | None = None, ) -> float: with open(config_file, "r") as f: @@ -171,7 +172,7 @@ def __call__( self, trajectory: Trajectory, config_file: Path | str, - page: Page, + page: Page | PseudoPage, client: CDPSession | None = None, ) -> float: with open(config_file, "r") as f: @@ -209,7 +210,7 @@ def __call__( self, trajectory: Trajectory, config_file: Path | str, - page: Page, + page: Page | PseudoPage, client: CDPSession | None = None, ) -> float: with open(config_file, "r") as f: @@ -236,7 +237,9 @@ def __call__( if not locator.strip(): selected_element 
= page.content() # use JS to select the element - elif locator.startswith("document."): + elif locator.startswith("document.") or locator.startswith( + "[...document." + ): try: selected_element = page.evaluate(f"() => {locator}") if not selected_element: @@ -295,7 +298,7 @@ def __call__( self, trajectory: Trajectory, config_file: Path | str, - page: Page, + page: Page | PseudoPage, client: CDPSession, ) -> float: raise NotImplementedError @@ -308,7 +311,7 @@ def __call__( self, trajectory: Trajectory, config_file: Path | str, - page: Page, + page: Page | PseudoPage, client: CDPSession, ) -> float: with open(config_file, "r") as f: @@ -355,7 +358,7 @@ def __call__( self, trajectory: Trajectory, config_file: Path | str, - page: Page, + page: Page | PseudoPage, client: CDPSession, ) -> float: diff --git a/evaluation_harness/helper_functions.py b/evaluation_harness/helper_functions.py index 915ef1f..6df22e4 100644 --- a/evaluation_harness/helper_functions.py +++ b/evaluation_harness/helper_functions.py @@ -170,3 +170,16 @@ def llm_fuzzy_match(pred: str, reference: str, question: str) -> float: return 1.0 else: return 0.0 + + +class PseudoPage: + def __init__(self, original_page: Page, url: str): + self.url = url + self.original_page = original_page + + def __getattr__(self, attr: str) -> any: + # Delegate attribute access to the original page object + if attr not in ["url"]: + return getattr(self.original_page, attr) + else: + return getattr(self, attr) From 16f25921a2fc934e616705d271f9898b1cc44e90 Mon Sep 17 00:00:00 2001 From: alexisxy Date: Fri, 22 Sep 2023 17:28:12 -0400 Subject: [PATCH 07/29] support generation retry when the parsing of the action failed --- agent/agent.py | 107 +++++++--------------------- agent/prompts/prompt_constructor.py | 3 +- llms/__init__.py | 13 ++++ llms/lm_config.py | 28 ++++++++ llms/utils.py | 56 +++++++++++++++ run.py | 6 ++ 6 files changed, 131 insertions(+), 82 deletions(-) create mode 100644 llms/utils.py diff --git 
a/agent/agent.py b/agent/agent.py index 5229101..90e3692 100644 --- a/agent/agent.py +++ b/agent/agent.py @@ -15,11 +15,12 @@ create_playwright_action, ) from browser_env.utils import Observation, StateInfo -from llms import lm_config -from llms.providers.hf_utils import generate_from_huggingface_completion -from llms.providers.openai_utils import ( +from llms import ( + call_llm, + generate_from_huggingface_completion, generate_from_openai_chat_completion, generate_from_openai_completion, + lm_config, ) from llms.tokenizers import Tokenizer @@ -122,58 +123,29 @@ def next_action( trajectory, intent, meta_data ) lm_config = self.lm_config - if lm_config.provider == "openai": - if lm_config.mode == "chat": - response = generate_from_openai_chat_completion( - messages=prompt, - model=lm_config.model, - temperature=lm_config.gen_config["temperature"], - top_p=lm_config.gen_config["top_p"], - context_length=lm_config.gen_config["context_length"], - max_tokens=lm_config.gen_config["max_tokens"], - stop_token=None, - ) - elif lm_config.mode == "completion": - response = generate_from_openai_completion( - prompt=prompt, - engine=lm_config.model, - temperature=lm_config.gen_config["temperature"], - max_tokens=lm_config.gen_config["max_tokens"], - top_p=lm_config.gen_config["top_p"], - stop_token=lm_config.gen_config["stop_token"], - ) - else: - raise ValueError( - f"OpenAI models do not support mode {lm_config.mode}" + n = 0 + while True: + response = call_llm(lm_config, prompt) + n += 1 + try: + parsed_response = self.prompt_constructor.extract_action( + response ) - elif lm_config.provider == "huggingface": - response = generate_from_huggingface_completion( - prompt=prompt, - model_endpoint=lm_config.gen_config["model_endpoint"], - temperature=lm_config.gen_config["temperature"], - top_p=lm_config.gen_config["top_p"], - stop_sequences=lm_config.gen_config["stop_sequences"], - max_new_tokens=lm_config.gen_config["max_new_tokens"], - ) - else: - raise NotImplementedError( - 
f"Provider {lm_config.provider} not implemented" - ) - - try: - parsed_response = self.prompt_constructor.extract_action(response) - if self.action_set_tag == "id_accessibility_tree": - action = create_id_based_action(parsed_response) - elif self.action_set_tag == "playwright": - action = create_playwright_action(parsed_response) - else: - raise ValueError(f"Unknown action type {self.action_set_tag}") - - action["raw_prediction"] = response - - except ActionParsingError as e: - action = create_none_action() - action["raw_prediction"] = response + if self.action_set_tag == "id_accessibility_tree": + action = create_id_based_action(parsed_response) + elif self.action_set_tag == "playwright": + action = create_playwright_action(parsed_response) + else: + raise ValueError( + f"Unknown action type {self.action_set_tag}" + ) + action["raw_prediction"] = response + break + except ActionParsingError as e: + if n >= lm_config.gen_config["max_retry"]: + action = create_none_action() + action["raw_prediction"] = response + break return action @@ -181,33 +153,8 @@ def reset(self, test_config_file: str) -> None: pass -def construct_llm_config(args: argparse.Namespace) -> lm_config.LMConfig: - llm_config = lm_config.LMConfig( - provider=args.provider, model=args.model, mode=args.mode - ) - if args.provider == "openai": - llm_config.gen_config["temperature"] = args.temperature - llm_config.gen_config["top_p"] = args.top_p - llm_config.gen_config["context_length"] = args.context_length - llm_config.gen_config["max_tokens"] = args.max_tokens - llm_config.gen_config["stop_token"] = args.stop_token - llm_config.gen_config["max_obs_length"] = args.max_obs_length - elif args.provider == "huggingface": - llm_config.gen_config["temperature"] = args.temperature - llm_config.gen_config["top_p"] = args.top_p - llm_config.gen_config["max_new_tokens"] = args.max_tokens - llm_config.gen_config["stop_sequences"] = ( - [args.stop_token] if args.stop_token else None - ) - 
llm_config.gen_config["max_obs_length"] = args.max_obs_length - llm_config.gen_config["model_endpoint"] = args.model_endpoint - else: - raise NotImplementedError(f"provider {args.provider} not implemented") - return llm_config - - def construct_agent(args: argparse.Namespace) -> Agent: - llm_config = construct_llm_config(args) + llm_config = lm_config.construct_llm_config(args) agent: Agent if args.agent_type == "teacher_forcing": diff --git a/agent/prompts/prompt_constructor.py b/agent/prompts/prompt_constructor.py index 575236e..9f991e5 100644 --- a/agent/prompts/prompt_constructor.py +++ b/agent/prompts/prompt_constructor.py @@ -8,8 +8,7 @@ from browser_env.utils import StateInfo from llms import lm_config from llms.tokenizers import Tokenizer - -APIInput = str | list[Any] | dict[str, Any] +from llms.utils import APIInput class Instruction(TypedDict): diff --git a/llms/__init__.py b/llms/__init__.py index 8dd1547..7a8c942 100644 --- a/llms/__init__.py +++ b/llms/__init__.py @@ -1 +1,14 @@ """This module is adapt from https://github.com/zeno-ml/zeno-build""" +from .providers.hf_utils import generate_from_huggingface_completion +from .providers.openai_utils import ( + generate_from_openai_chat_completion, + generate_from_openai_completion, +) +from .utils import call_llm + +__all__ = [ + "generate_from_openai_completion", + "generate_from_openai_chat_completion", + "generate_from_huggingface_completion", + "call_llm", +] diff --git a/llms/lm_config.py b/llms/lm_config.py index 6d67579..2156ef9 100644 --- a/llms/lm_config.py +++ b/llms/lm_config.py @@ -2,6 +2,7 @@ from __future__ import annotations +import argparse import dataclasses from dataclasses import dataclass from typing import Any @@ -27,3 +28,30 @@ class LMConfig: tokenizer_cls: type | None = None mode: str | None = None gen_config: dict[str, Any] = dataclasses.field(default_factory=dict) + + +def construct_llm_config(args: argparse.Namespace) -> LMConfig: + llm_config = LMConfig( + 
provider=args.provider, model=args.model, mode=args.mode + ) + if args.provider == "openai": + llm_config.gen_config["temperature"] = args.temperature + llm_config.gen_config["top_p"] = args.top_p + llm_config.gen_config["context_length"] = args.context_length + llm_config.gen_config["max_tokens"] = args.max_tokens + llm_config.gen_config["stop_token"] = args.stop_token + llm_config.gen_config["max_obs_length"] = args.max_obs_length + llm_config.gen_config["max_retry"] = args.max_retry + elif args.provider == "huggingface": + llm_config.gen_config["temperature"] = args.temperature + llm_config.gen_config["top_p"] = args.top_p + llm_config.gen_config["max_new_tokens"] = args.max_tokens + llm_config.gen_config["stop_sequences"] = ( + [args.stop_token] if args.stop_token else None + ) + llm_config.gen_config["max_obs_length"] = args.max_obs_length + llm_config.gen_config["model_endpoint"] = args.model_endpoint + llm_config.gen_config["max_retry"] = args.max_retry + else: + raise NotImplementedError(f"provider {args.provider} not implemented") + return llm_config diff --git a/llms/utils.py b/llms/utils.py new file mode 100644 index 0000000..54b57e0 --- /dev/null +++ b/llms/utils.py @@ -0,0 +1,56 @@ +import argparse +from typing import Any + +from llms import ( + generate_from_huggingface_completion, + generate_from_openai_chat_completion, + generate_from_openai_completion, + lm_config, +) + +APIInput = str | list[Any] | dict[str, Any] + + +def call_llm( + lm_config: lm_config.LMConfig, + prompt: list[Any] | str, +) -> APIInput: + if lm_config.provider == "openai": + if lm_config.mode == "chat": + response = generate_from_openai_chat_completion( + messages=prompt, + model=lm_config.model, + temperature=lm_config.gen_config["temperature"], + top_p=lm_config.gen_config["top_p"], + context_length=lm_config.gen_config["context_length"], + max_tokens=lm_config.gen_config["max_tokens"], + stop_token=None, + ) + elif lm_config.mode == "completion": + response = 
generate_from_openai_completion( + prompt=prompt, + engine=lm_config.model, + temperature=lm_config.gen_config["temperature"], + max_tokens=lm_config.gen_config["max_tokens"], + top_p=lm_config.gen_config["top_p"], + stop_token=lm_config.gen_config["stop_token"], + ) + else: + raise ValueError( + f"OpenAI models do not support mode {lm_config.mode}" + ) + elif lm_config.provider == "huggingface": + response = generate_from_huggingface_completion( + prompt=prompt, + model_endpoint=lm_config.gen_config["model_endpoint"], + temperature=lm_config.gen_config["temperature"], + top_p=lm_config.gen_config["top_p"], + stop_sequences=lm_config.gen_config["stop_sequences"], + max_new_tokens=lm_config.gen_config["max_new_tokens"], + ) + else: + raise NotImplementedError( + f"Provider {lm_config.provider} not implemented" + ) + + return response diff --git a/run.py b/run.py index d4766cd..010bc54 100644 --- a/run.py +++ b/run.py @@ -119,6 +119,12 @@ def config() -> argparse.Namespace: parser.add_argument("--context_length", type=int, default=0) parser.add_argument("--max_tokens", type=int, default=384) parser.add_argument("--stop_token", type=str, default=None) + parser.add_argument( + "--max_retry", + type=int, + help="max retry times to perform generations when parsing fails", + default=1, + ) parser.add_argument( "--max_obs_length", type=int, From 9f3e4ac4cce7f487fec53f14cfdc5791236dada8 Mon Sep 17 00:00:00 2001 From: alexisxy Date: Fri, 22 Sep 2023 17:30:02 -0400 Subject: [PATCH 08/29] ignore cache --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 1da6709..1cc64e2 100644 --- a/.gitignore +++ b/.gitignore @@ -141,6 +141,7 @@ run.sh # trajectory visualization render_cache/* +cache/* # TMP IGNORE agent/prompts/jsons/* From 0e7bcda0bad31b8f7f5b3eb5e2b0bc67ffbb572c Mon Sep 17 00:00:00 2001 From: alexisxy Date: Sat, 23 Sep 2023 00:15:22 -0400 Subject: [PATCH 09/29] add prompts --- agent/prompts/raw/p_cot_id_actree_2s_no_na.py | 82 
++++++++++++++++++ .../raw/p_direct_id_actree_2s_no_na.py | 81 ++++++++++++++++++ .../raw/p_direct_id_actree_3s_llama.py | 83 +++++++++++++++++++ 3 files changed, 246 insertions(+) create mode 100644 agent/prompts/raw/p_cot_id_actree_2s_no_na.py create mode 100644 agent/prompts/raw/p_direct_id_actree_2s_no_na.py create mode 100644 agent/prompts/raw/p_direct_id_actree_3s_llama.py diff --git a/agent/prompts/raw/p_cot_id_actree_2s_no_na.py b/agent/prompts/raw/p_cot_id_actree_2s_no_na.py new file mode 100644 index 0000000..945cd95 --- /dev/null +++ b/agent/prompts/raw/p_cot_id_actree_2s_no_na.py @@ -0,0 +1,82 @@ +prompt = { + "intro": """You are an autonomous intelligent agent tasked with navigating a web browser. You will be given web-based tasks. These tasks will be accomplished through the use of specific actions you can issue. + +Here's the information you'll have: +The user's objective: This is the task you're trying to complete. +The current web page's accessibility tree: This is a simplified representation of the webpage, providing key information. +The current web page's URL: This is the page you're currently navigating. +The open tabs: These are the tabs you have open. +The previous action: This is the action you just performed. It may be helpful to track your progress. + +The actions you can perform fall into several categories: + +Page Operation Actions: +`click [id]`: This action clicks on an element with a specific id on the webpage. +`type [id] [content] [press_enter_after=0|1]`: Use this to type the content into the field with id. By default, the "Enter" key is pressed after typing unless press_enter_after is set to 0. +`hover [id]`: Hover over an element with id. +`press [key_comb]`: Simulates the pressing of a key combination on the keyboard (e.g., Ctrl+v). +`scroll [direction=down|up]`: Scroll the page up or down. + +Tab Management Actions: +`new_tab`: Open a new, empty browser tab. 
+`tab_focus [tab_index]`: Switch the browser's focus to a specific tab using its index. +`close_tab`: Close the currently active tab. + +URL Navigation Actions: +`goto [url]`: Navigate to a specific URL. +`go_back`: Navigate to the previously viewed page. +`go_forward`: Navigate to the next page (if a previous 'go_back' action was performed). + +Completion Action: +`stop [answer]`: Issue this action when you believe the task is complete. If the objective is to find a text-based answer, provide the answer in the bracket. + +Homepage: +If you want to visit other websites, check out the homepage at http://homepage.com. It has a list of websites you can visit. +http://homepage.com/password.html lists all the account name and password for the websites. You can use them to log in to the websites. + +To be successful, it is very important to follow the following rules: +1. You should only issue an action that is valid given the current observation +2. You should only issue one action at a time. +3. You should follow the examples to reason step by step and then issue the next action. +4. Generate the action in the correct format. Start with a "In summary, the next action I will perform is" phrase, followed by action inside ``````. For example, "In summary, the next action I will perform is ```click [1234]```". +5. Issue stop action when you think you have achieved the objective. Don't generate anything after stop.""", + "examples": [ + ( + """OBSERVATION: +[1744] link 'HP CB782A#ABA 640 Inkjet Fax Machine (Renewed)' + [1749] StaticText '$279.49' + [1757] button 'Add to Cart' + [1760] button 'Add to Wish List' + [1761] button 'Add to Compare' +URL: http://onestopmarket.com/office-products/office-electronics.html +OBJECTIVE: What is the price of HP Inkjet Fax Machine +PREVIOUS ACTION: None""", + "Let's think step-by-step. This page list the information of HP Inkjet Fax Machine, which is the product identified in the objective. Its price is $279.49. 
I think I have achieved the objective. I will issue the stop action with the answer. In summary, the next action I will perform is ```stop [$279.49]```", + ), + ( + """OBSERVATION: +[164] textbox 'Search' focused: True required: False +[171] button 'Go' +[174] link 'Find directions between two points' +[212] heading 'Search Results' +[216] button 'Close' +URL: http://openstreetmap.org +OBJECTIVE: Show me the restaurants near CMU +PREVIOUS ACTION: None""", + "Let's think step-by-step. This page has a search box whose ID is [164]. According to the nominatim rule of openstreetmap, I can search for the restaurants near a location by \"restaurants near\". I can submit my typing by pressing the Enter afterwards. In summary, the next action I will perform is ```type [164] [restaurants near CMU] [1]```", + ), + ], + "template": """OBSERVATION: +{observation} +URL: {url} +OBJECTIVE: {objective} +PREVIOUS ACTION: {previous_action}""", + "meta_data": { + "observation": "accessibility_tree", + "action_type": "id_accessibility_tree", + "keywords": ["url", "objective", "observation", "previous_action"], + "prompt_constructor": "CoTPromptConstructor", + "answer_phrase": "In summary, the next action I will perform is", + "action_splitter": "```" + }, +} diff --git a/agent/prompts/raw/p_direct_id_actree_2s_no_na.py b/agent/prompts/raw/p_direct_id_actree_2s_no_na.py new file mode 100644 index 0000000..c399454 --- /dev/null +++ b/agent/prompts/raw/p_direct_id_actree_2s_no_na.py @@ -0,0 +1,81 @@ +prompt = { + "intro": """You are an autonomous intelligent agent tasked with navigating a web browser. You will be given web-based tasks. These tasks will be accomplished through the use of specific actions you can issue. + +Here's the information you'll have: +The user's objective: This is the task you're trying to complete. +The current web page's accessibility tree: This is a simplified representation of the webpage, providing key information. 
+The current web page's URL: This is the page you're currently navigating. +The open tabs: These are the tabs you have open. +The previous action: This is the action you just performed. It may be helpful to track your progress. + +The actions you can perform fall into several categories: + +Page Operation Actions: +`click [id]`: This action clicks on an element with a specific id on the webpage. +`type [id] [content] [press_enter_after=0|1]`: Use this to type the content into the field with id. By default, the "Enter" key is pressed after typing unless press_enter_after is set to 0. +`hover [id]`: Hover over an element with id. +`press [key_comb]`: Simulates the pressing of a key combination on the keyboard (e.g., Ctrl+v). +`scroll [direction=down|up]`: Scroll the page up or down. + +Tab Management Actions: +`new_tab`: Open a new, empty browser tab. +`tab_focus [tab_index]`: Switch the browser's focus to a specific tab using its index. +`close_tab`: Close the currently active tab. + +URL Navigation Actions: +`goto [url]`: Navigate to a specific URL. +`go_back`: Navigate to the previously viewed page. +`go_forward`: Navigate to the next page (if a previous 'go_back' action was performed). + +Completion Action: +`stop [answer]`: Issue this action when you believe the task is complete. If the objective is to find a text-based answer, provide the answer in the bracket. + +Homepage: +If you want to visit other websites, check out the homepage at http://homepage.com. It has a list of websites you can visit. +http://homepage.com/password.html lists all the account name and password for the websites. You can use them to log in to the websites. + +To be successful, it is very important to follow the following rules: +1. You should only issue an action that is valid given the current observation +2. You should only issue one action at a time. +4. Generate the action in the correct format, wrap the action inside ``````. For example, ```click [1234]```". +5. 
Issue stop action when you think you have achieved the objective.""", + "examples": [ + ( + """OBSERVATION: +[1744] link 'HP CB782A#ABA 640 Inkjet Fax Machine (Renewed)' + [1749] StaticText '$279.49' + [1757] button 'Add to Cart' + [1760] button 'Add to Wish List' + [1761] button 'Add to Compare' +URL: http://onestopmarket.com/office-products/office-electronics.html +OBJECTIVE: What is the price of HP Inkjet Fax Machine +PREVIOUS ACTION: None""", + "```stop [$279.49]```", + ), + ( + """OBSERVATION: +[164] textbox 'Search' focused: True required: False +[171] button 'Go' +[174] link 'Find directions between two points' +[212] heading 'Search Results' +[216] button 'Close' +URL: http://openstreetmap.org +OBJECTIVE: Show me the restaurants near CMU +PREVIOUS ACTION: None""", + "```type [164] [restaurants near CMU] [1]```", + ), + ], + "template": """OBSERVATION: +{observation} +URL: {url} +OBJECTIVE: {objective} +PREVIOUS ACTION: {previous_action}""", + "meta_data": { + "observation": "accessibility_tree", + "action_type": "id_accessibility_tree", + "keywords": ["url", "objective", "observation", "previous_action"], + "prompt_constructor": "CoTPromptConstructor", + "answer_phrase": "In summary, the next action I will perform is", + "action_splitter": "```" + }, +} diff --git a/agent/prompts/raw/p_direct_id_actree_3s_llama.py b/agent/prompts/raw/p_direct_id_actree_3s_llama.py new file mode 100644 index 0000000..6278d2b --- /dev/null +++ b/agent/prompts/raw/p_direct_id_actree_3s_llama.py @@ -0,0 +1,83 @@ +prompt = { + "intro": """You are an autonomous intelligent agent tasked with navigating a web browser. The actions you can perform fall into several categories: + +Page Operation Actions: +`click [id]`: This action clicks on an element with a specific id on the webpage. +`type [id] [content] [press_enter_after=0|1]`: Use this to type the content into the field with id. By default, the "Enter" key is pressed after typing unless press_enter_after is set to 0. 
+`hover [id]`: Hover over an element with id. +`press [key_comb]`: Simulates the pressing of a key combination on the keyboard (e.g., Ctrl+v). +`scroll [direction=down|up]`: Scroll the page up or down. + +Tab Management Actions: +`new_tab`: Open a new, empty browser tab. +`tab_focus [tab_index]`: Switch the browser's focus to a specific tab using its index. +`close_tab`: Close the currently active tab. + +URL Navigation Actions: +`goto [url]`: Navigate to a specific URL. +`go_back`: Navigate to the previously viewed page. +`go_forward`: Navigate to the next page (if a previous 'go_back' action was performed). + +Completion Action: +`stop [answer]`: Issue this action when you believe the task is complete. If the objective is to find a text-based answer, provide the answer in the bracket. + +Homepage: +If you want to visit other websites, check out the homepage at http://homepage.com. It has a list of websites you can visit. + +You can only issue one action at a time""", + + "examples": [ + ( + """Observation: +[1744] link 'HP CB782A#ABA 640 Inkjet Fax Machine (Renewed)' + [1749] StaticText '$279.49' + [1757] button 'Add to Cart' + [1760] button 'Add to Wish List' + [1761] button 'Add to Compare' +URL: http://onestopmarket.com/office-products/office-electronics.html +Objective: What is the price of HP Inkjet Fax Machine +Previous action: None""", + "```stop [$279.49]```", + ), + ( + """Observation: +[164] textbox 'Search' focused: True required: False +[171] button 'Go' +[174] link 'Find directions between two points' +[212] heading 'Search Results' +[216] button 'Close' +URL: http://openstreetmap.org +Objective: Show me the restaurants near CMU +Previous action: None""", + "```type [164] [restaurants near CMU] [1]```", + ), + ( + """Observation: +[2036] button 'Sort by: New' hasPopup: menu expanded: False + [587] link 'US Marine’s adoption of Afghan war orphan voided' + [989] time 'March 30, 2023 at 15:03:48 AM UTC' + [602] link 'York student uses AI chatbot to get 
parking fine revoked' + [1025] time 'March 15, 2023 at 7:48:34 AM UTC' + [617] link 'Loveland parents furious after teachers leave, communication lagged during school threat investigation' + [1025] time 'March 2, 2023 at 3:46:01 AM UTC' +URL: http://reddit.com/f/news/new +Objective: Open the most recent post that was published prior to March 1st. +Previous action: None""", + "```scroll [down]```", + ) + ], + "template": """Observation: +{observation} +URL: {url} +Objective: {objective} +Previous action: {previous_action}""", + "meta_data": { + "observation": "accessibility_tree", + "action_type": "id_accessibility_tree", + "keywords": ["url", "objective", "observation", "previous_action"], + "prompt_constructor": "DirectPromptConstructor", + "answer_phrase": "In summary, the next action I will perform is", + "action_splitter": "```", + "force_prefix": "```" + }, +} From 741292e1d26ef3ede5316e947bb4ebdf4af510ea Mon Sep 17 00:00:00 2001 From: alexisxy Date: Sat, 23 Sep 2023 00:16:18 -0400 Subject: [PATCH 10/29] fix force_prefix missing bug --- agent/agent.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/agent/agent.py b/agent/agent.py index 90e3692..923ebce 100644 --- a/agent/agent.py +++ b/agent/agent.py @@ -126,6 +126,10 @@ def next_action( n = 0 while True: response = call_llm(lm_config, prompt) + force_prefix = self.prompt_constructor.instruction[ + "meta_data" + ].get("force_prefix", "") + response = f"{force_prefix}{response}" n += 1 try: parsed_response = self.prompt_constructor.extract_action( From 1ee1ea48007f0d6d4c5121086757beebd119eebb Mon Sep 17 00:00:00 2001 From: alexisxy Date: Sat, 23 Sep 2023 00:16:36 -0400 Subject: [PATCH 11/29] fix typo --- agent/prompts/prompt_constructor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/agent/prompts/prompt_constructor.py b/agent/prompts/prompt_constructor.py index 9f991e5..a0ca408 100644 --- a/agent/prompts/prompt_constructor.py +++ b/agent/prompts/prompt_constructor.py @@ -27,10 
+27,10 @@ def __init__( lm_config: lm_config.LMConfig, tokenizer: Tokenizer, ): - self.instrction_path = Path(instruction_path) + self.instruction_path = Path(instruction_path) self.obs_modality = "text" self.lm_config = lm_config - instruction = json.load(open(self.instrction_path)) + instruction = json.load(open(self.instruction_path)) instruction["examples"] = [tuple(e) for e in instruction["examples"]] self.instruction: Instruction = instruction self.tokenizer = tokenizer From c1ae73c960dddc375b0f52b06421092502693355 Mon Sep 17 00:00:00 2001 From: alexisxy Date: Sat, 23 Sep 2023 00:17:19 -0400 Subject: [PATCH 12/29] add script to check inference failures --- scripts/check_error_runs.py | 144 ++++++++++++++++++++++++++++++++++++ 1 file changed, 144 insertions(+) create mode 100644 scripts/check_error_runs.py diff --git a/scripts/check_error_runs.py b/scripts/check_error_runs.py new file mode 100644 index 0000000..2fb4247 --- /dev/null +++ b/scripts/check_error_runs.py @@ -0,0 +1,144 @@ +"""Some executions may fail. +This script checks the recordings and prints the task ids.
+It deletes the recordings if needed.""" +import argparse +import glob +import os +import shutil +import sys + + +def merge_logs(result_folder: str, args: argparse.Namespace) -> str: + if not os.path.exists(f"{result_folder}/log_files.txt"): + sys.exit(1) + + with open(f"{result_folder}/log_files.txt", "r") as f: + log_files = f.readlines() + + merged_results = {} + for file in log_files: + with open(file.strip(), "r") as f: + lines = f.readlines() + + cur_log = [] + index = None + for line in lines: + if "[Config file]" in line: + if ( + cur_log + and index + and os.path.exists(f"{result_folder}/render_{index}.html") + ): + merged_results[index] = cur_log + # update index and log + index = line.split("/")[-1].split(".")[0] + cur_log = [line] + else: + cur_log.append(line) + + if os.path.exists(f"{result_folder}/render_{index}.html"): + merged_results[index] = cur_log + + # sort by the key + merged_results = dict( + sorted(merged_results.items(), key=lambda x: int(x[0])) + ) + + merged_log_path = f"{result_folder}/tmp_merged_log.txt" + with open(merged_log_path, "w") as f: + for k, v in merged_results.items(): + for line in v: + f.write(line) + print(f"Number of examples: {len(merged_results)}") + + unlog_examples = [] + for i in range(812): + if ( + os.path.exists(f"{result_folder}/render_{i}.html") + and str(i) not in merged_results + ): + unlog_examples.append(i) + + print(f"Number of unlogged examples: {len(unlog_examples)}") + print(unlog_examples) + if ( + args.delete_errors + or input("Do you want to delete these examples? 
(y/n)") == "y" + ): + for idx in unlog_examples: + os.remove(f"{args.result_folder}/render_{idx}.html") + + return merged_log_path + + +def check_unhandled_errors(args: argparse.Namespace) -> int: + log_path = merge_logs(args.result_folder, args) + with open(log_path, "r") as f: + logs = f.read() + + error_examples = [] + for line in logs.split("\n"): + if "[Config file]" in line: + example_idx = line.split("/")[-1].split(".")[0] + if "[Unhandled Error]" in line or "[OpenAI Error]" in line: + error_examples.append(int(example_idx)) + + num_errors = len(error_examples) + print(f"Number of unhandled errors: {len(error_examples)}") + print(error_examples) + if ( + args.delete_errors + or input("Do you want to delete these examples? (y/n)") == "y" + ): + for idx in error_examples: + if os.path.exists(f"{args.result_folder}/render_{idx}.html"): + os.remove(f"{args.result_folder}/render_{idx}.html") + return num_errors + + +def check_unexpected_logout(args: argparse.Namespace) -> int: + target_strings = set( + [ + "Creating an account has many benefits: check out faster", + "Welcome, please sign in", + "Username or email", + "Keep me logged in", + ] + ) + + error_examples = [] + for render_file in glob.glob(f"{args.result_folder}/render_*.html"): + with open(render_file, "r") as f: + contents = f.read() + if any([s in contents for s in target_strings]): + task_id = int( + render_file.split("/")[-1].split(".")[0].split("_")[-1] + ) + error_examples.append(task_id) + print(f"Number of unexpected logout: {len(error_examples)}") + print(error_examples) + num_errors = len(error_examples) + if ( + args.delete_errors + or input("Do you want to delete these examples? 
(y/n)") == "y" + ): + for idx in error_examples: + if os.path.exists(f"{args.result_folder}/render_{idx}.html"): + os.remove(f"{args.result_folder}/render_{idx}.html") + + return num_errors + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("result_folder", type=str) + parser.add_argument("--delete_errors", action="store_true") + parser.add_argument("--tolerance", type=int, default=0) + + args = parser.parse_args() + n1 = check_unhandled_errors(args) + n2 = check_unexpected_logout(args) + if n1 + n2 > args.tolerance: + sys.exit(1) + else: + sys.exit(0) From 6fdbd92bd57de34c6042e2fea61d2cf346ad3324 Mon Sep 17 00:00:00 2001 From: alexisxy Date: Sat, 23 Sep 2023 00:17:38 -0400 Subject: [PATCH 13/29] add parallel running script --- parallel_run.sh | 73 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 73 insertions(+) create mode 100644 parallel_run.sh diff --git a/parallel_run.sh b/parallel_run.sh new file mode 100644 index 0000000..fb56cc3 --- /dev/null +++ b/parallel_run.sh @@ -0,0 +1,73 @@ +#!/bin/bash + +result_dir="cache/919_gpt35_16k_cot_na" +model="gpt-3.5-turbo-16k-0613" +instruction_path="agent/prompts/jsons/p_cot_id_actree_2s.json" + +SERVER="" +OPENAI_API_KEY="" +OPENAI_ORGANIZATION="" +CONDA_ENV_NAME="webarena" +ENV_VARIABLES="export SHOPPING='http://${SERVER}:7770';export SHOPPING_ADMIN='http://${SERVER}:7780/admin';export REDDIT='http://${SERVER}:9999';export GITLAB='http://${SERVER}:8023';export MAP='http://miniserver1875.asuscomm.com:3000';export WIKIPEDIA='http://${SERVER}:8888/wikipedia_en_all_maxi_2022-05/A/User:The_other_Kiwix_guy/Landing';export HOMEPAGE='http://${SERVER}:4399';export OPENAI_API_KEY=${OPENAI_API_KEY};export OPENAI_ORGANIZATION=${OPENAI_ORGANIZATION}" + +# get the number of tmux panes +num_panes=$(tmux list-panes | wc -l) + +# calculate how many panes need to be created +let "panes_to_create = 5 - num_panes" + +# array of tmux commands to create each pane +tmux_commands=( + 
'tmux split-window -h' + 'tmux split-window -v' + 'tmux select-pane -t 0; tmux split-window -v' + 'tmux split-window -v' + 'tmux select-pane -t 3; tmux split-window -v' +) + +# create panes up to 5 +for ((i=0; i<$panes_to_create; i++)); do + eval ${tmux_commands[$i]} +done + +#!/bin/bash + +# Function to run a job +run_job() { + tmux select-pane -t $1 + tmux send-keys "conda activate ${CONDA_ENV_NAME}; ${ENV_VARIABLES}; until python run.py --test_start_idx $2 --test_end_idx $3 --model ${model} --instruction_path ${instruction_path} --result_dir ${result_dir}; do echo 'crashed' >&2; sleep 1; done" C-m + sleep 3 +} + +TOLERANCE=2 +run_batch() { + args=("$@") # save all arguments in an array + num_jobs=${#args[@]} # get number of arguments + + for ((i=1; i<$num_jobs; i++)); do + run_job $i ${args[i-1]} ${args[i]} + done + + # Wait for all jobs to finish + while tmux list-panes -F "#{pane_pid} #{pane_current_command}" | grep -q python; do + sleep 100 # wait for 100 seconds before checking again + done + + # Run checker + while ! python scripts/check_error_runs.py ${result_dir} --delete_errors --tolerance ${TOLERANCE}; do + echo "Check failed, rerunning jobs..."
+ for ((i=1; i<$num_jobs; i++)); do + run_job $i ${args[i-1]} ${args[i]} + done + + # Wait for all jobs to finish + while tmux list-panes -F "#{pane_pid} #{pane_current_command}" | grep -q python; do + sleep 100 # wait for 100 seconds before checking again + done + done + +} + +run_batch 0 100 200 300 380 +run_batch 380 480 580 680 770 +run_batch 770 812 From cd7d5934a0d8d0bb0f6546ae9bdab5e6b2433522 Mon Sep 17 00:00:00 2001 From: alexisxy Date: Mon, 25 Sep 2023 23:59:08 -0400 Subject: [PATCH 14/29] fix annotation errors based on human trajectories --- config_files/test.raw.json | 200 ++++++++++++++++++------------------- 1 file changed, 98 insertions(+), 102 deletions(-) diff --git a/config_files/test.raw.json b/config_files/test.raw.json index d4806fb..641d8bf 100644 --- a/config_files/test.raw.json +++ b/config_files/test.raw.json @@ -11649,7 +11649,7 @@ "url_match" ], "reference_answers": null, - "reference_url": "__SHOPPING_ADMIN__/admin/system_design_theme/edit/id/1/key/", + "reference_url": "__SHOPPING_ADMIN__/admin/system_design_theme/edit/id/1", "program_html": [], "url_note": "GOLD in PRED" }, @@ -12482,7 +12482,7 @@ "program_html": [ { "url": "__REDDIT__/user/MarvelsGrantMan136", - "locator": "document.querySelector(\".profile-user-bio\").outerText", + "locator": "document.querySelector(\".user-bio__biography\").outerText", "required_contents": { "exact_match": "I am a robot" } @@ -12515,7 +12515,7 @@ "program_html": [ { "url": "__REDDIT__/user/MarvelsGrantMan136", - "locator": "document.querySelector(\".profile-user-bio\").outerText", + "locator": "document.querySelector(\".user-bio__biography\").outerText", "required_contents": { "exact_match": "Pro Python Developer with 20 years of Experience" } @@ -12548,7 +12548,7 @@ "program_html": [ { "url": "__REDDIT__/user/MarvelsGrantMan136", - "locator": "document.querySelector(\".profile-user-bio\").outerText", + "locator": "document.querySelector(\".user-bio__biography\").outerText", "required_contents": {
"exact_match": "Seeking SDE positions" } @@ -12581,7 +12581,7 @@ "program_html": [ { "url": "__REDDIT__/user/MarvelsGrantMan136", - "locator": "document.querySelector(\".profile-user-bio\").outerText", + "locator": "document.querySelector(\".user-bio__biography\").outerText", "required_contents": { "exact_match": "Freelance Web Developer" } @@ -12614,7 +12614,7 @@ "program_html": [ { "url": "__REDDIT__/user/MarvelsGrantMan136", - "locator": "document.querySelector(\".profile-user-bio\").outerText", + "locator": "document.querySelector(\".user-bio__biography\").outerText", "required_contents": { "exact_match": "Awesome Prompt Artist" } @@ -12751,7 +12751,7 @@ "reference_url": "", "program_html": [ { - "url": "__REDDIT__/f/books/deeplearning", + "url": "__REDDIT__/f/deeplearning", "locator": "document.querySelectorAll('div.submission__vote')[0].querySelector('form').getAttribute('class')", "required_contents": { "must_include": [ @@ -12786,7 +12786,7 @@ "reference_url": "", "program_html": [ { - "url": "__REDDIT__/f/books/explainlikeimfive", + "url": "__REDDIT__/f/explainlikeimfive", "locator": "document.querySelectorAll('div.submission__vote')[0].querySelector('form').getAttribute('class')", "required_contents": { "must_include": [ @@ -13135,7 +13135,7 @@ "program_html": [ { "url": "__GITLAB__/byteblaze", - "locator": "document.querySelector('.profile-user-bio').outerText", + "locator": "document.querySelector('.cover-status').outerText", "required_contents": { "exact_match": "Busy" } @@ -13168,7 +13168,7 @@ "program_html": [ { "url": "__GITLAB__/byteblaze", - "locator": "document.querySelector('.profile-user-bio').outerText", + "locator": "document.querySelector('.cover-status').outerText", "required_contents": { "exact_match": "Enjoying life" } @@ -13201,7 +13201,7 @@ "program_html": [ { "url": "__GITLAB__/byteblaze", - "locator": "document.querySelector('.profile-user-bio').outerText", + "locator": "document.querySelector('.cover-status').outerText", 
"required_contents": { "exact_match": "Playing Badminton" } @@ -13234,7 +13234,7 @@ "program_html": [ { "url": "__GITLAB__/byteblaze", - "locator": "document.querySelector('.profile-user-bio').outerText", + "locator": "document.querySelector('.cover-status').outerText", "required_contents": { "exact_match": "Resting due to leg injury" } @@ -13267,7 +13267,7 @@ "program_html": [ { "url": "__GITLAB__/byteblaze", - "locator": "document.querySelector('.profile-user-bio').outerText", + "locator": "document.querySelector('.cover-status').outerText", "required_contents": { "exact_match": "Out of Office" } @@ -14560,7 +14560,7 @@ "task_id": 460, "require_login": true, "storage_state": "./.auth/shopping_admin_state.json", - "start_url": "__SHOPPING_ADMIN__/catalog/product/edit/id/237/", + "start_url": "__SHOPPING_ADMIN__/catalog/product/edit/id/418/", "geolocation": null, "intent_template": "{{action}} the price of this product by {{amount}}", "instantiation_dict": { @@ -14577,10 +14577,10 @@ "reference_url": "", "program_html": [ { - "url": "__SHOPPING_ADMIN__/catalog/product/edit/id/237/", + "url": "__SHOPPING_ADMIN__/catalog/product/edit/id/418/", "locator": "document.querySelector('[name=\"product[price]\"').value", "required_contents": { - "exact_match": "58.65" + "exact_match": "38.25" } } ] @@ -14594,7 +14594,7 @@ "task_id": 461, "require_login": true, "storage_state": "./.auth/shopping_admin_state.json", - "start_url": "__SHOPPING_ADMIN__/catalog/product/edit/id/1481/", + "start_url": "__SHOPPING_ADMIN__/catalog/product/edit/id/721/", "geolocation": null, "intent_template": "{{action}} the price of this product by {{amount}}", "instantiation_dict": { @@ -14611,10 +14611,10 @@ "reference_url": "", "program_html": [ { - "url": "__SHOPPING_ADMIN__/catalog/product/edit/id/1481/", + "url": "__SHOPPING_ADMIN__/catalog/product/edit/id/721/", "locator": "document.querySelector('[name=\"product[price]\"').value", "required_contents": { - "exact_match": "43.50" + 
"exact_match": "29.50" } } ] @@ -14628,7 +14628,7 @@ "task_id": 462, "require_login": true, "storage_state": "./.auth/shopping_admin_state.json", - "start_url": "__SHOPPING_ADMIN__/catalog/product/edit/id/237/", + "start_url": "__SHOPPING_ADMIN__/catalog/product/edit/id/710/", "geolocation": null, "intent_template": "{{action}} the price of this product by {{amount}}", "instantiation_dict": { @@ -14645,10 +14645,10 @@ "reference_url": "", "program_html": [ { - "url": "__SHOPPING_ADMIN__/catalog/product/edit/id/237/", + "url": "__SHOPPING_ADMIN__/catalog/product/edit/id/710/", "locator": "document.querySelector('[name=\"product[price]\"').value", "required_contents": { - "exact_match": "75.90" + "exact_match": "19.80" } } ] @@ -14662,7 +14662,7 @@ "task_id": 463, "require_login": true, "storage_state": "./.auth/shopping_admin_state.json", - "start_url": "__SHOPPING_ADMIN__/catalog/product/edit/id/237/", + "start_url": "__SHOPPING_ADMIN__/catalog/product/edit/id/996/", "geolocation": null, "intent_template": "{{action}} the price of this product by {{amount}}", "instantiation_dict": { @@ -14679,10 +14679,10 @@ "reference_url": "", "program_html": [ { - "url": "__SHOPPING_ADMIN__/catalog/product/edit/id/237/", + "url": "__SHOPPING_ADMIN__/catalog/product/edit/id/996/", "locator": "document.querySelector('[name=\"product[price]\"').value", "required_contents": { - "exact_match": "79.35" + "exact_match": "36.80" } } ] @@ -14713,10 +14713,11 @@ "program_html": [ { "url": "__SHOPPING_ADMIN__/antonia-racer-tank.html", - "locator": "document.querySelector('.product.info.detailed').outerText", + "locator": "document.querySelector('.data.item.content').outerText + (document.querySelector('.product.attribute.overview [itemprop=\"description\"]')?.outerText || '')", "required_contents": { "must_include": [ - "This is in regular rotation at the gym. Its colorful and looks kinda cute under my exercise tanks." 
+ "This is in regular rotation at the gym", + "Its colorful and looks kinda cute under my exercise tanks" ] } } @@ -15640,17 +15641,10 @@ "eval_types": [ "program_html" ], - "reference_answers": null, + "reference_answers": {"exact_match": "N/A"}, "reference_url": "", - "program_html": [ - { - "url": "__SHOPPING_ADMIN__/sales/order/view/order_id/299/", - "locator": "document.querySelector(\"#order_history_block\").querySelector(\".note-list\").firstElementChild.querySelector(\".note-list-comment\").outerText", - "required_contents": { - "exact_match": "the order is ready to be shipped soon!" - } - } - ] + "program_html": [], + "string_note": "System message: We cannot add order history." }, "intent_template_id": 280 }, @@ -16802,8 +16796,7 @@ "AndroidAsync", "Pytorch-GAN", "administrate", - "keycloak", - "openapi-generator" + "keycloak" ] } } @@ -17544,11 +17537,13 @@ "program_html": [ { "url": "__SHOPPING_ADMIN__/bella-tank.html", - "locator": "document.querySelector('.product.info.detailed').outerText", + "locator": "document.querySelector('.data.item.content').outerText + (document.querySelector('.product.attribute.overview [itemprop=\"description\"]')?.outerText || '')", "required_contents": { "must_include": [ - "Good choice for working out and stylin' enough to wear when I'm hanging with friends on hot days. Also washes really well!", - "Always a sweet n sporty look for the gym! Keeps me cool and the seams don't rub up against me like some of my other tanks." 
+ "Good choice for working out and stylin' enough to wear when I'm hanging with friends on hot days", + "Also washes really well", + "Always a sweet n sporty look for the gym", + "Keeps me cool and the seams don't rub up against me like some of my other tanks" ] } } @@ -19105,7 +19100,7 @@ "program_html": [ { "url": "__REDDIT__/f/sci_fi/edit", - "locator": "document.querySelector(\"#forum_description\").outerText", + "locator": "document.querySelector(\"#forum_description\").value", "required_contents": { "must_include": [ "A wild place for sci-fi enthusiasts" @@ -19114,7 +19109,7 @@ }, { "url": "__REDDIT__/f/sci_fi/edit", - "locator": "document.querySelector(\"#forum_sidebar\").outerText", + "locator": "document.querySelector(\"#forum_sidebar\").value", "required_contents": { "must_include": [ "New", @@ -19159,7 +19154,7 @@ "program_html": [ { "url": "__REDDIT__/f/cmu_lti/edit", - "locator": "document.querySelector(\"#forum_description\").outerText", + "locator": "document.querySelector(\"#forum_description\").value", "required_contents": { "must_include": [ "Language Technologies Institute at Carnegie Mellon University" @@ -19168,7 +19163,7 @@ }, { "url": "__REDDIT__/f/cmu_lti/edit", - "locator": "document.querySelector(\"#forum_sidebar\").outerText", + "locator": "document.querySelector(\"#forum_sidebar\").value", "required_contents": { "must_include": [ "announcement", @@ -19212,7 +19207,7 @@ "program_html": [ { "url": "__REDDIT__/f/Cyberpunk/edit", - "locator": "document.querySelector(\"#forum_description\").outerText", + "locator": "document.querySelector(\"#forum_description\").value", "required_contents": { "must_include": [ "Welcome to the future" @@ -19221,7 +19216,7 @@ }, { "url": "__REDDIT__/f/Cyberpunk/edit", - "locator": "document.querySelector(\"#forum_sidebar\").outerText", + "locator": "document.querySelector(\"#forum_sidebar\").value", "required_contents": { "must_include": [ "Games", @@ -19266,7 +19261,7 @@ "program_html": [ { "url": 
"__REDDIT__/f/PlantsForCatParents/edit", - "locator": "document.querySelector(\"#forum_description\").outerText", + "locator": "document.querySelector(\"#forum_description\").value", "required_contents": { "must_include": [ "Cat parents & plan lovers" @@ -19275,7 +19270,7 @@ }, { "url": "__REDDIT__/f/PlantsForCatParents/edit", - "locator": "document.querySelector(\"#forum_sidebar\").outerText", + "locator": "document.querySelector(\"#forum_sidebar\").value", "required_contents": { "must_include": [ "Cat friendly", @@ -19318,7 +19313,7 @@ "program_html": [ { "url": "__REDDIT__/f/Karaoke", - "locator": "document.querySelector(\"#forum_description\").outerText", + "locator": "document.querySelector(\"#forum_description\").value", "required_contents": { "must_include": [ "Place for Karaoke lovers" @@ -19327,7 +19322,7 @@ }, { "url": "__REDDIT__/f/Karaoke", - "locator": "document.querySelector(\"#forum_sidebar\").outerText", + "locator": "document.querySelector(\"#forum_sidebar\").value", "required_contents": { "must_include": [ "devices", @@ -20030,7 +20025,7 @@ "locator": "document.querySelector(\"#sidebar > section\").outerText", "required_contents": { "must_include": [ - "machine learning" + "MachineLearning" ] } } @@ -21679,11 +21674,11 @@ "geolocation": null, "intent_template": "Post a notice on a virtual meetup for {{interest}} enthusiasts on {{date}} in the {{subreddit}} subreddit", "instantiation_dict": { - "interest": "Harry Poter", + "interest": "Harry Potter", "date": "July 8th", "subreddit": "books" }, - "intent": "Post a notice on a virtual meetup for Harry Poter enthusiasts on July 8th in the books subreddit", + "intent": "Post a notice on a virtual meetup for Harry Potter enthusiasts on July 8th in the books subreddit", "require_reset": false, "eval": { "eval_types": [ @@ -21698,7 +21693,7 @@ "locator": "document.querySelector('.submission__inner').outerText", "required_contents": { "must_include": [ - "Harry Poter", + "Harry Potter", "July 8th", 
"virtual meetup" ] @@ -22152,7 +22147,7 @@ "program_html" ], "reference_answers": null, - "reference_url": "__SHOPPING__/contact/", + "reference_url": "__SHOPPING__/contact", "program_html": [ { "url": "last", @@ -22167,7 +22162,7 @@ } } ], - "url_note": "EXACT" + "url_note": "GOLD in PRED" }, "intent_template_id": 153 }, @@ -22193,7 +22188,7 @@ "program_html" ], "reference_answers": null, - "reference_url": "__SHOPPING__/contact/", + "reference_url": "__SHOPPING__/contact", "program_html": [ { "url": "last", @@ -22208,7 +22203,7 @@ } } ], - "url_note": "EXACT" + "url_note": "GOLD in PRED" }, "intent_template_id": 153 }, @@ -22234,7 +22229,7 @@ "program_html" ], "reference_answers": null, - "reference_url": "__SHOPPING__/contact/", + "reference_url": "__SHOPPING__/contact", "program_html": [ { "url": "last", @@ -22249,7 +22244,7 @@ } } ], - "url_note": "EXACT" + "url_note": "GOLD in PRED" }, "intent_template_id": 153 }, @@ -22275,7 +22270,7 @@ "program_html" ], "reference_answers": null, - "reference_url": "__SHOPPING__/contact/", + "reference_url": "__SHOPPING__/contact", "program_html": [ { "url": "last", @@ -22290,7 +22285,7 @@ } } ], - "url_note": "EXACT" + "url_note": "GOLD in PRED" }, "intent_template_id": 153 }, @@ -22316,7 +22311,7 @@ "program_html" ], "reference_answers": null, - "reference_url": "__SHOPPING__/contact/", + "reference_url": "__SHOPPING__/contact", "program_html": [ { "url": "last", @@ -22331,7 +22326,7 @@ } } ], - "url_note": "EXACT" + "url_note": "GOLD in PRED" }, "intent_template_id": 153 }, @@ -23066,9 +23061,10 @@ "required_contents": { "must_include": [ "Unable to set neutral steering", - "Doesn\u2019t work with PC.", - "Crazy problems in automatic mode; then pedals stopped working", - "Only works with certain games." 
+ "Doesn\u2019t work with PC", + "Crazy problems in automatic mode", + "pedals stopped working", + "Only works with certain games" ] } } @@ -23696,7 +23692,7 @@ "program_html" ], "reference_answers": null, - "reference_url": "__SHOPPING__/contact/", + "reference_url": "__SHOPPING__/contact", "program_html": [ { "url": "last", @@ -23709,7 +23705,7 @@ } } ], - "url_note": "EXACT" + "url_note": "GOLD in PRED" }, "intent_template_id": 163 }, @@ -23734,7 +23730,7 @@ "program_html" ], "reference_answers": null, - "reference_url": "__SHOPPING__/contact/", + "reference_url": "__SHOPPING__/contact", "program_html": [ { "url": "last", @@ -23747,7 +23743,7 @@ } } ], - "url_note": "EXACT" + "url_note": "GOLD in PRED" }, "intent_template_id": 163 }, @@ -23772,7 +23768,7 @@ "program_html" ], "reference_answers": null, - "reference_url": "__SHOPPING__/contact/", + "reference_url": "__SHOPPING__/contact", "program_html": [ { "url": "last", @@ -23785,7 +23781,7 @@ } } ], - "url_note": "EXACT" + "url_note": "GOLD in PRED" }, "intent_template_id": 163 }, @@ -23810,7 +23806,7 @@ "program_html" ], "reference_answers": null, - "reference_url": "__SHOPPING__/contact/", + "reference_url": "__SHOPPING__/contact", "program_html": [ { "url": "last", @@ -23823,7 +23819,7 @@ } } ], - "url_note": "EXACT" + "url_note": "GOLD in PRED" }, "intent_template_id": 163 }, @@ -23848,7 +23844,7 @@ "program_html" ], "reference_answers": null, - "reference_url": "__SHOPPING__/contact/", + "reference_url": "__SHOPPING__/contact", "program_html": [ { "url": "last", @@ -23861,7 +23857,7 @@ } } ], - "url_note": "EXACT" + "url_note": "GOLD in PRED" }, "intent_template_id": 163 }, @@ -24644,14 +24640,14 @@ "url": "last", "locator": "document.querySelector('[id=\"sales_report_from\"').value", "required_contents": { - "exact_match": "02/1/2023" + "exact_match": "2/1/2023" } }, { "url": "last", "locator": "document.querySelector('[id=\"sales_report_to\"').value", "required_contents": { - "exact_match": "02/28/2023" 
+ "exact_match": "2/28/2023" } } ], @@ -24687,14 +24683,14 @@ "url": "last", "locator": "document.querySelector('[id=\"sales_report_from\"').value", "required_contents": { - "exact_match": "01/29/2023" + "exact_match": "1/29/2023" } }, { "url": "last", "locator": "document.querySelector('[id=\"sales_report_to\"').value", "required_contents": { - "exact_match": "03/15/2023" + "exact_match": "3/15/2023" } } ], @@ -24724,20 +24720,20 @@ "program_html" ], "reference_answers": null, - "reference_url": "__SHOPPING_ADMIN__/reports/report_sales/refunded/", + "reference_url": "__SHOPPING_ADMIN__/reports/report_sales/refunded", "program_html": [ { "url": "last", "locator": "document.querySelector('[id=\"sales_report_from\"').value", "required_contents": { - "exact_match": "01/1/2023" + "exact_match": "1/1/2023" } }, { "url": "last", "locator": "document.querySelector('[id=\"sales_report_to\"').value", "required_contents": { - "exact_match": "03/31/2023" + "exact_match": "3/31/2023" } } ], @@ -24773,7 +24769,7 @@ "url": "last", "locator": "document.querySelector('[id=\"sales_report_from\"').value", "required_contents": { - "exact_match": "01/1/2022" + "exact_match": "1/1/2022" } }, { @@ -24816,7 +24812,7 @@ "url": "last", "locator": "document.querySelector('[id=\"sales_report_from\"').value", "required_contents": { - "exact_match": "01/1/2023" + "exact_match": "1/1/2023" } }, { @@ -24860,14 +24856,14 @@ "url": "last", "locator": "document.querySelector('[id=\"sales_report_from\"').value", "required_contents": { - "exact_match": "05/1/2021" + "exact_match": "5/1/2021" } }, { "url": "last", "locator": "document.querySelector('[id=\"sales_report_to\"').value", "required_contents": { - "exact_match": "03/31/2022" + "exact_match": "3/31/2022" } } ], @@ -24904,14 +24900,14 @@ "url": "last", "locator": "document.querySelector('[id=\"sales_report_from\"').value", "required_contents": { - "exact_match": "08/5/2022" + "exact_match": "8/5/2022" } }, { "url": "last", "locator": 
"document.querySelector('[id=\"sales_report_to\"').value", "required_contents": { - "exact_match": "03/1/2023" + "exact_match": "3/1/2023" } } ], @@ -24948,14 +24944,14 @@ "url": "last", "locator": "document.querySelector('[id=\"sales_report_from\"').value", "required_contents": { - "exact_match": "07/5/2021" + "exact_match": "7/5/2021" } }, { "url": "last", "locator": "document.querySelector('[id=\"sales_report_to\"').value", "required_contents": { - "exact_match": "05/31/2023" + "exact_match": "5/31/2023" } } ], @@ -24992,14 +24988,14 @@ "url": "last", "locator": "document.querySelector('[id=\"sales_report_from\"').value", "required_contents": { - "exact_match": "05/1/2021" + "exact_match": "5/1/2021" } }, { "url": "last", "locator": "document.querySelector('[id=\"sales_report_to\"').value", "required_contents": { - "exact_match": "05/15/2023" + "exact_match": "5/15/2023" } } ], @@ -25036,14 +25032,14 @@ "url": "last", "locator": "document.querySelector('[id=\"sales_report_from\"').value", "required_contents": { - "exact_match": "05/1/2022" + "exact_match": "5/1/2022" } }, { "url": "last", "locator": "document.querySelector('[id=\"sales_report_to\"').value", "required_contents": { - "exact_match": "05/31/2023" + "exact_match": "5/31/2023" } } ], @@ -25865,7 +25861,7 @@ "reference_url": "", "program_html": [ { - "url": "__REDDIT__/user/Hrekires/submissions", + "url": "__REDDIT__/user/AdamCannon/submissions", "locator": "document.querySelectorAll('div.submission__vote')[7].querySelector('form').getAttribute('class')", "required_contents": { "must_include": [ @@ -26479,7 +26475,7 @@ "locator": "document.querySelector('[name=\"route_to\"').value", "required_contents": { "must_include": [ - "150, Causeway Street", + "TD Garden", "Boston", "Massachusetts" ] @@ -27538,7 +27534,7 @@ "required_contents": { "must_include": [ "Carnegie Hall", - "West 56th Street", + "West 57th Street", "Manhattan", "New York" ] @@ -28273,7 +28269,7 @@ "url": 
"__SHOPPING_ADMIN__/catalog/product/edit/id/123/", "locator": "document.querySelector('[name=\"product[price]\"').value", "required_contents": { - "exact_match": "47" + "exact_match": "47.00" } } ] From b6e0b22ade85c27b6b98dfea63e0a0f331e6123a Mon Sep 17 00:00:00 2001 From: alexisxy Date: Tue, 26 Sep 2023 00:07:44 -0400 Subject: [PATCH 15/29] change reddit vote related posts to absolute urls --- config_files/test.raw.json | 236 +++++++++++++++++++------------------ 1 file changed, 119 insertions(+), 117 deletions(-) diff --git a/config_files/test.raw.json b/config_files/test.raw.json index 641d8bf..710c798 100644 --- a/config_files/test.raw.json +++ b/config_files/test.raw.json @@ -12646,8 +12646,8 @@ "reference_url": "", "program_html": [ { - "url": "__REDDIT__/f/books/new", - "locator": "document.querySelectorAll('div.submission__vote')[0].querySelector('form').getAttribute('class')", + "url": "__REDDIT__/f/books/124260/adults-reading-to-each-other-out-loud", + "locator": "document.querySelector('div.submission__vote').querySelector('form').getAttribute('class')", "required_contents": { "must_include": [ "vote vote--user-upvoted" @@ -12681,8 +12681,8 @@ "reference_url": "", "program_html": [ { - "url": "__REDDIT__/f/diy/new", - "locator": "document.querySelectorAll('div.submission__vote')[0].querySelector('form').getAttribute('class')", + "url": "__REDDIT__/f/DIY/119019/how-can-i-bring-an-hdmi-cable-from-my-pc-downstairs-to-my-tv", + "locator": "document.querySelector('div.submission__vote').querySelector('form').getAttribute('class')", "required_contents": { "must_include": [ "vote vote--user-upvoted" @@ -12716,8 +12716,8 @@ "reference_url": "", "program_html": [ { - "url": "__REDDIT__/f/futurology/new", - "locator": "document.querySelectorAll('div.submission__vote')[0].querySelector('form').getAttribute('class')", + "url": "__REDDIT__/f/Futurology/119517/openai-ceo-it-s-not-funny-that-i-m-afraid-of-the-ai-we-re", + "locator": 
"document.querySelector('div.submission__vote').querySelector('form').getAttribute('class')", "required_contents": { "must_include": [ "vote vote--user-upvoted" @@ -12751,8 +12751,8 @@ "reference_url": "", "program_html": [ { - "url": "__REDDIT__/f/deeplearning", - "locator": "document.querySelectorAll('div.submission__vote')[0].querySelector('form').getAttribute('class')", + "url": "__REDDIT__/f/deeplearning/124993/meta-s-llama-weights-leaked-on-torrent-and-the-best-thing", + "locator": "document.querySelector('div.submission__vote').querySelector('form').getAttribute('class')", "required_contents": { "must_include": [ "vote vote--user-upvoted" @@ -12786,8 +12786,8 @@ "reference_url": "", "program_html": [ { - "url": "__REDDIT__/f/explainlikeimfive", - "locator": "document.querySelectorAll('div.submission__vote')[0].querySelector('form').getAttribute('class')", + "url": "__REDDIT__/f/explainlikeimfive/39244/eli5-how-does-pinching-a-ribbon-and-sliding-your-finger", + "locator": "document.querySelector('div.submission__vote').querySelector('form').getAttribute('class')", "required_contents": { "must_include": [ "vote vote--user-upvoted" @@ -15641,7 +15641,9 @@ "eval_types": [ "program_html" ], - "reference_answers": {"exact_match": "N/A"}, + "reference_answers": { + "exact_match": "N/A" + }, "reference_url": "", "program_html": [], "string_note": "System message: We cannot add order history." 
@@ -25071,8 +25073,8 @@ "reference_url": "", "program_html": [ { - "url": "__REDDIT__/f/gadgets/top?t=all", - "locator": "document.querySelectorAll('div.submission__vote')[0].querySelector('form').getAttribute('class')", + "url": "__REDDIT__/f/gadgets/19459/a-custom-gaming-pc-built-inside-a-vintage-1940s-motorola", + "locator": "document.querySelector('div.submission__vote').querySelector('form').getAttribute('class')", "required_contents": { "must_include": [ "vote vote--user-downvoted" @@ -25107,8 +25109,8 @@ "reference_url": "", "program_html": [ { - "url": "__REDDIT__/f/history/top?t=all", - "locator": "document.querySelectorAll('div.submission__vote')[0].querySelector('form').getAttribute('class')", + "url": "__REDDIT__/f/history/84338/the-scientist-who-discovered-sperm-was-so-grossed-out-he", + "locator": "document.querySelector('div.submission__vote').querySelector('form').getAttribute('class')", "required_contents": { "must_include": [ "vote vote--user-downvoted" @@ -25116,8 +25118,8 @@ } }, { - "url": "__REDDIT__/f/history/top?t=all", - "locator": "document.querySelectorAll('div.submission__vote')[1].querySelector('form').getAttribute('class')", + "url": "__REDDIT__/f/history/105990/4-500-year-old-sumerian-temple-dedicated-to-mighty-thunder", + "locator": "document.querySelector('div.submission__vote').querySelector('form').getAttribute('class')", "required_contents": { "must_include": [ "vote vote--user-downvoted" @@ -25152,8 +25154,8 @@ "reference_url": "", "program_html": [ { - "url": "__REDDIT__/f/books/top?t=all", - "locator": "document.querySelectorAll('div.submission__vote')[0].querySelector('form').getAttribute('class')", + "url": "__REDDIT__/f/books/81371/the-letters-of-t-s-eliot-to-emily-hale-that-were-kept-sealed", + "locator": "document.querySelector('div.submission__vote').querySelector('form').getAttribute('class')", "required_contents": { "must_include": [ "vote vote--user-downvoted" @@ -25161,8 +25163,8 @@ } }, { - "url": 
"__REDDIT__/f/books/top?t=all", - "locator": "document.querySelectorAll('div.submission__vote')[1].querySelector('form').getAttribute('class')", + "url": "__REDDIT__/f/books/59421/friendly-reminder-bookshop-org-exists", + "locator": "document.querySelector('div.submission__vote').querySelector('form').getAttribute('class')", "required_contents": { "must_include": [ "vote vote--user-downvoted" @@ -25170,8 +25172,8 @@ } }, { - "url": "__REDDIT__/f/books/top?t=all", - "locator": "document.querySelectorAll('div.submission__vote')[2].querySelector('form').getAttribute('class')", + "url": "__REDDIT__/f/books/59447/appalachian-prison-book-project-seeks-notebook-donations-the", + "locator": "document.querySelector('div.submission__vote').querySelector('form').getAttribute('class')", "required_contents": { "must_include": [ "vote vote--user-downvoted" @@ -25206,8 +25208,8 @@ "reference_url": "", "program_html": [ { - "url": "__REDDIT__/f/movies/top?t=all", - "locator": "document.querySelectorAll('div.submission__vote')[0].querySelector('form').getAttribute('class')", + "url": "__REDDIT__/f/movies/86174/who-will-win-the-oscar-for-actress-in-a-supporting-role", + "locator": "document.querySelector('div.submission__vote').querySelector('form').getAttribute('class')", "required_contents": { "must_include": [ "vote vote--user-downvoted" @@ -25215,8 +25217,8 @@ } }, { - "url": "__REDDIT__/f/movies/top?t=all", - "locator": "document.querySelectorAll('div.submission__vote')[1].querySelector('form').getAttribute('class')", + "url": "__REDDIT__/f/movies/86029/who-will-win-the-oscar-for-film-editing", + "locator": "document.querySelector('div.submission__vote').querySelector('form').getAttribute('class')", "required_contents": { "must_include": [ "vote vote--user-downvoted" @@ -25224,8 +25226,8 @@ } }, { - "url": "__REDDIT__/f/movies/top?t=all", - "locator": "document.querySelectorAll('div.submission__vote')[2].querySelector('form').getAttribute('class')", + "url": 
"__REDDIT__/f/movies/86055/cindy-williams-dies-laverne-amp-shirley-star-who-appeared-in", + "locator": "document.querySelector('div.submission__vote').querySelector('form').getAttribute('class')", "required_contents": { "must_include": [ "vote vote--user-downvoted" @@ -25233,8 +25235,8 @@ } }, { - "url": "__REDDIT__/f/movies/top?t=all", - "locator": "document.querySelectorAll('div.submission__vote')[3].querySelector('form').getAttribute('class')", + "url": "__REDDIT__/f/movies/42682/michelle-yeoh-to-receive-palm-springs-film-festival-s", + "locator": "document.querySelector('div.submission__vote').querySelector('form').getAttribute('class')", "required_contents": { "must_include": [ "vote vote--user-downvoted" @@ -25269,8 +25271,8 @@ "reference_url": "", "program_html": [ { - "url": "__REDDIT__/f/technology/top?t=all", - "locator": "document.querySelectorAll('div.submission__vote')[0].querySelector('form').getAttribute('class')", + "url": "__REDDIT__/f/technology/48670/brain-cancer-vaccine-succeeds-at-prolonging-survival-in", + "locator": "document.querySelector('div.submission__vote').querySelector('form').getAttribute('class')", "required_contents": { "must_include": [ "vote vote--user-downvoted" @@ -25278,8 +25280,8 @@ } }, { - "url": "__REDDIT__/f/technology/top?t=all", - "locator": "document.querySelectorAll('div.submission__vote')[1].querySelector('form').getAttribute('class')", + "url": "__REDDIT__/f/technology/134696/india-cuts-internet-for-27-million-people-amid-search-for", + "locator": "document.querySelector('div.submission__vote').querySelector('form').getAttribute('class')", "required_contents": { "must_include": [ "vote vote--user-downvoted" @@ -25287,8 +25289,8 @@ } }, { - "url": "__REDDIT__/f/technology/top?t=all", - "locator": "document.querySelectorAll('div.submission__vote')[2].querySelector('form').getAttribute('class')", + "url": "__REDDIT__/f/technology/48785/us-judge-orders-amazon-to-cease-and-desist-anti-union", + "locator": 
"document.querySelector('div.submission__vote').querySelector('form').getAttribute('class')", "required_contents": { "must_include": [ "vote vote--user-downvoted" @@ -25296,8 +25298,8 @@ } }, { - "url": "__REDDIT__/f/technology/top?t=all", - "locator": "document.querySelectorAll('div.submission__vote')[3].querySelector('form').getAttribute('class')", + "url": "__REDDIT__/f/technology/70354/activision-s-boston-studio-workers-announce-unionization", + "locator": "document.querySelector('div.submission__vote').querySelector('form').getAttribute('class')", "required_contents": { "must_include": [ "vote vote--user-downvoted" @@ -25305,8 +25307,8 @@ } }, { - "url": "__REDDIT__/f/technology/top?t=all", - "locator": "document.querySelectorAll('div.submission__vote')[4].querySelector('form').getAttribute('class')", + "url": "__REDDIT__/f/technology/70233/social-media-influencers-are-charged-with-feeding-followers", + "locator": "document.querySelector('div.submission__vote').querySelector('form').getAttribute('class')", "required_contents": { "must_include": [ "vote vote--user-downvoted" @@ -25341,8 +25343,8 @@ "reference_url": "", "program_html": [ { - "url": "__REDDIT__/user/ThetaGang_wsb/submissions", - "locator": "document.querySelectorAll('div.submission__vote')[0].querySelector('form').getAttribute('class')", + "url": "__REDDIT__/f/wallstreetbets/29478/how-will-airbnb-close-following-their-earnings-report-on", + "locator": "document.querySelector('div.submission__vote').querySelector('form').getAttribute('class')", "required_contents": { "must_include": [ "vote vote--user-upvoted" @@ -25350,8 +25352,8 @@ } }, { - "url": "__REDDIT__/user/ThetaGang_wsb/submissions", - "locator": "document.querySelectorAll('div.submission__vote')[1].querySelector('form').getAttribute('class')", + "url": "__REDDIT__/f/wallstreetbets/29458/how-much-will-the-federal-reserve-raise-interest-rates-in", + "locator": 
"document.querySelector('div.submission__vote').querySelector('form').getAttribute('class')", "required_contents": { "must_include": [ "vote vote--user-upvoted" @@ -25386,8 +25388,8 @@ "reference_url": "", "program_html": [ { - "url": "__REDDIT__/user/CameronKelsey/submissions", - "locator": "document.querySelectorAll('div.submission__vote')[0].querySelector('form').getAttribute('class')", + "url": "__REDDIT__/f/EarthPorn/98332/my-favorite-place-on-the-planet-henry-s-fork-of-the-snake", + "locator": "document.querySelector('div.submission__vote').querySelector('form').getAttribute('class')", "required_contents": { "must_include": [ "vote vote--user-upvoted" @@ -25395,8 +25397,8 @@ } }, { - "url": "__REDDIT__/user/CameronKelsey/submissions", - "locator": "document.querySelectorAll('div.submission__vote')[1].querySelector('form').getAttribute('class')", + "url": "__REDDIT__/f/EarthPorn/98297/2-years-later-this-is-still-one-of-the-most-incredible", + "locator": "document.querySelector('div.submission__vote').querySelector('form').getAttribute('class')", "required_contents": { "must_include": [ "vote vote--user-upvoted" @@ -25404,8 +25406,8 @@ } }, { - "url": "__REDDIT__/user/CameronKelsey/submissions", - "locator": "document.querySelectorAll('div.submission__vote')[2].querySelector('form').getAttribute('class')", + "url": "__REDDIT__/f/EarthPorn/98256/i-can-t-wait-for-all-this-green-to-start-coming-back-little", + "locator": "document.querySelector('div.submission__vote').querySelector('form').getAttribute('class')", "required_contents": { "must_include": [ "vote vote--user-upvoted" @@ -25440,8 +25442,8 @@ "reference_url": "", "program_html": [ { - "url": "__REDDIT__/user/UniversityofBath/submissions", - "locator": "document.querySelectorAll('div.submission__vote')[0].querySelector('form').getAttribute('class')", + "url": "__REDDIT__/f/IAmA/119742/hi-i-m-vienne-a-doctoral-student-at-the-university-of-bath-i", + "locator": 
"document.querySelector('div.submission__vote').querySelector('form').getAttribute('class')", "required_contents": { "must_include": [ "vote vote--user-upvoted" @@ -25449,8 +25451,8 @@ } }, { - "url": "__REDDIT__/user/UniversityofBath/submissions", - "locator": "document.querySelectorAll('div.submission__vote')[1].querySelector('form').getAttribute('class')", + "url": "__REDDIT__/f/IAmA/119719/hello-reddit-i-m-nazia-mehrban-a-lecturer-in-biotechnology", + "locator": "document.querySelector('div.submission__vote').querySelector('form').getAttribute('class')", "required_contents": { "must_include": [ "vote vote--user-upvoted" @@ -25458,8 +25460,8 @@ } }, { - "url": "__REDDIT__/user/UniversityofBath/submissions", - "locator": "document.querySelectorAll('div.submission__vote')[2].querySelector('form').getAttribute('class')", + "url": "__REDDIT__/f/IAmA/119714/i-m-ellie-jarvis-she-her-a-2nd-year-phd-student-in-the", + "locator": "document.querySelector('div.submission__vote').querySelector('form').getAttribute('class')", "required_contents": { "must_include": [ "vote vote--user-upvoted" @@ -25467,8 +25469,8 @@ } }, { - "url": "__REDDIT__/user/UniversityofBath/submissions", - "locator": "document.querySelectorAll('div.submission__vote')[3].querySelector('form').getAttribute('class')", + "url": "__REDDIT__/f/IAmA/55155/hi-i-m-dr-lucy-maddox-from-bath-university-uk-i-m-a-clinical", + "locator": "document.querySelector('div.submission__vote').querySelector('form').getAttribute('class')", "required_contents": { "must_include": [ "vote vote--user-upvoted" @@ -25476,8 +25478,8 @@ } }, { - "url": "__REDDIT__/user/UniversityofBath/submissions", - "locator": "document.querySelectorAll('div.submission__vote')[4].querySelector('form').getAttribute('class')", + "url": "__REDDIT__/f/IAmA/55142/we-re-sadeka-nujhat-hannah-leese-and-sandhya-moise-from-the", + "locator": "document.querySelector('div.submission__vote').querySelector('form').getAttribute('class')", "required_contents": { 
"must_include": [ "vote vote--user-upvoted" @@ -25485,8 +25487,8 @@ } }, { - "url": "__REDDIT__/user/UniversityofBath/submissions", - "locator": "document.querySelectorAll('div.submission__vote')[5].querySelector('form').getAttribute('class')", + "url": "__REDDIT__/f/IAmA/34032/we-re-sandhya-moise-david-phillips-and-chan-lee-from-the", + "locator": "document.querySelector('div.submission__vote').querySelector('form').getAttribute('class')", "required_contents": { "must_include": [ "vote vote--user-upvoted" @@ -25494,8 +25496,8 @@ } }, { - "url": "__REDDIT__/user/UniversityofBath/submissions", - "locator": "document.querySelectorAll('div.submission__vote')[6].querySelector('form').getAttribute('class')", + "url": "__REDDIT__/f/IAmA/13175/hi-i-m-kit-yates-i-m-a-mathematical-biologist-at-the", + "locator": "document.querySelector('div.submission__vote').querySelector('form').getAttribute('class')", "required_contents": { "must_include": [ "vote vote--user-upvoted" @@ -25503,8 +25505,8 @@ } }, { - "url": "__REDDIT__/user/UniversityofBath/submissions", - "locator": "document.querySelectorAll('div.submission__vote')[7].querySelector('form').getAttribute('class')", + "url": "__REDDIT__/f/IAmA/13170/hello-i-m-dr-sara-fontani-from-the-university-of", + "locator": "document.querySelector('div.submission__vote').querySelector('form').getAttribute('class')", "required_contents": { "must_include": [ "vote vote--user-upvoted" @@ -25539,8 +25541,8 @@ "reference_url": "", "program_html": [ { - "url": "__REDDIT__/user/Don_Gato1/submissions", - "locator": "document.querySelectorAll('div.submission__vote')[1].querySelector('form').getAttribute('class')", + "url": "__REDDIT__/f/nyc/44650/fox-news-hosts-cast-new-york-as-crime-ridden-and-chaotic", + "locator": "document.querySelector('div.submission__vote').querySelector('form').getAttribute('class')", "required_contents": { "must_include": [ "vote vote--user-upvoted" @@ -25606,8 +25608,8 @@ "reference_url": "", "program_html": [ { - 
"url": "__REDDIT__/user/Hrekires/submissions", - "locator": "document.querySelectorAll('div.submission__vote')[0].querySelector('form').getAttribute('class')", + "url": "__REDDIT__/f/news/129816/gov-whitmer-signs-bills-to-repeal-right-to-work-restore", + "locator": "document.querySelector('div.submission__vote').querySelector('form').getAttribute('class')", "required_contents": { "must_include": [ "vote vote--user-upvoted" @@ -25615,8 +25617,8 @@ } }, { - "url": "__REDDIT__/user/Hrekires/submissions", - "locator": "document.querySelectorAll('div.submission__vote')[1].querySelector('form').getAttribute('class')", + "url": "__REDDIT__/f/news/129808/disney-world-deal-with-union-will-raise-minimum-wage-to-18", + "locator": "document.querySelector('div.submission__vote').querySelector('form').getAttribute('class')", "required_contents": { "must_include": [ "vote vote--user-upvoted" @@ -25624,8 +25626,8 @@ } }, { - "url": "__REDDIT__/user/Hrekires/submissions", - "locator": "document.querySelectorAll('div.submission__vote')[2].querySelector('form').getAttribute('class')", + "url": "__REDDIT__/f/news/129794/judge-halts-wyoming-abortion-ban-days-after-it-took-effect", + "locator": "document.querySelector('div.submission__vote').querySelector('form').getAttribute('class')", "required_contents": { "must_include": [ "vote vote--user-upvoted" @@ -25633,8 +25635,8 @@ } }, { - "url": "__REDDIT__/user/Hrekires/submissions", - "locator": "document.querySelectorAll('div.submission__vote')[3].querySelector('form').getAttribute('class')", + "url": "__REDDIT__/f/news/129783/don-t-say-gay-lawmaker-pleads-guilty-to-covid-relief-fraud", + "locator": "document.querySelector('div.submission__vote').querySelector('form').getAttribute('class')", "required_contents": { "must_include": [ "vote vote--user-upvoted" @@ -25642,8 +25644,8 @@ } }, { - "url": "__REDDIT__/user/Hrekires/submissions", - "locator": 
"document.querySelectorAll('div.submission__vote')[4].querySelector('form').getAttribute('class')", + "url": "__REDDIT__/f/news/129594/arizona-gov-katie-hobbs-refuses-to-proceed-with-execution", + "locator": "document.querySelector('div.submission__vote').querySelector('form').getAttribute('class')", "required_contents": { "must_include": [ "vote vote--user-upvoted" @@ -25651,8 +25653,8 @@ } }, { - "url": "__REDDIT__/user/Hrekires/submissions", - "locator": "document.querySelectorAll('div.submission__vote')[5].querySelector('form').getAttribute('class')", + "url": "__REDDIT__/f/news/129508/tennessee-governor-oks-bill-to-cut-nashville-council-in-half", + "locator": "document.querySelector('div.submission__vote').querySelector('form').getAttribute('class')", "required_contents": { "must_include": [ "vote vote--user-upvoted" @@ -25660,8 +25662,8 @@ } }, { - "url": "__REDDIT__/user/Hrekires/submissions", - "locator": "document.querySelectorAll('div.submission__vote')[7].querySelector('form').getAttribute('class')", + "url": "__REDDIT__/f/news/43839/philadelphia-da-larry-krasner-impeached-by-pa-house", + "locator": "document.querySelector('div.submission__vote').querySelector('form').getAttribute('class')", "required_contents": { "must_include": [ "vote vote--user-upvoted" @@ -25669,8 +25671,8 @@ } }, { - "url": "__REDDIT__/user/Hrekires/submissions", - "locator": "document.querySelectorAll('div.submission__vote')[8].querySelector('form').getAttribute('class')", + "url": "__REDDIT__/f/news/43781/crypto-giant-ftx-to-file-for-bankruptcy-ceo-sam-bankman", + "locator": "document.querySelector('div.submission__vote').querySelector('form').getAttribute('class')", "required_contents": { "must_include": [ "vote vote--user-upvoted" @@ -25678,8 +25680,8 @@ } }, { - "url": "__REDDIT__/user/Hrekires/submissions", - "locator": "document.querySelectorAll('div.submission__vote')[9].querySelector('form').getAttribute('class')", + "url": 
"__REDDIT__/f/news/43572/sec-doj-investigating-crypto-platform-ftx", + "locator": "document.querySelector('div.submission__vote').querySelector('form').getAttribute('class')", "required_contents": { "must_include": [ "vote vote--user-upvoted" @@ -25687,8 +25689,8 @@ } }, { - "url": "__REDDIT__/user/Hrekires/submissions", - "locator": "document.querySelectorAll('div.submission__vote')[10].querySelector('form').getAttribute('class')", + "url": "__REDDIT__/f/news/43558/kansas-gov-laura-kelly-wins-re-election-defeating-gop", + "locator": "document.querySelector('div.submission__vote').querySelector('form').getAttribute('class')", "required_contents": { "must_include": [ "vote vote--user-upvoted" @@ -25723,8 +25725,8 @@ "reference_url": "", "program_html": [ { - "url": "__REDDIT__/user/RickyDontLoseThat/submissions", - "locator": "document.querySelectorAll('div.submission__vote')[0].querySelector('form').getAttribute('class')", + "url": "__REDDIT__/f/massachusetts/84954/the-last-of-lincoln", + "locator": "document.querySelector('div.submission__vote').querySelector('form').getAttribute('class')", "required_contents": { "must_include": [ "vote vote--user-downvoted" @@ -25789,8 +25791,8 @@ "reference_url": "", "program_html": [ { - "url": "__REDDIT__/user/PatientBuilder499/submissions", - "locator": "document.querySelectorAll('div.submission__vote')[7].querySelector('form').getAttribute('class')", + "url": "__REDDIT__/f/videos/115139/hundreds-of-civilian-turkish-volunteers-waiting-to-be-sent", + "locator": "document.querySelector('div.submission__vote').querySelector('form').getAttribute('class')", "required_contents": { "must_include": [ "vote vote--user-downvoted" @@ -25825,8 +25827,8 @@ "reference_url": "", "program_html": [ { - "url": "__REDDIT__/user/sirbarani/submissions", - "locator": "document.querySelectorAll('div.submission__vote')[3].querySelector('form').getAttribute('class')", + "url": 
"__REDDIT__/f/sports/48303/iran-football-legend-daei-will-not-attend-world-cup-amid", + "locator": "document.querySelector('div.submission__vote').querySelector('form').getAttribute('class')", "required_contents": { "must_include": [ "vote vote--user-downvoted" @@ -25861,8 +25863,8 @@ "reference_url": "", "program_html": [ { - "url": "__REDDIT__/user/AdamCannon/submissions", - "locator": "document.querySelectorAll('div.submission__vote')[7].querySelector('form').getAttribute('class')", + "url": "__REDDIT__/f/UpliftingNews/16087/same-sex-marriage-is-now-legal-in-all-of-mexico-s-states", + "locator": "document.querySelector('div.submission__vote').querySelector('form').getAttribute('class')", "required_contents": { "must_include": [ "vote vote--user-downvoted" @@ -25897,8 +25899,8 @@ "reference_url": "", "program_html": [ { - "url": "__REDDIT__/user/Hrekires/submissions", - "locator": "document.querySelectorAll('div.submission__vote')[0].querySelector('form').getAttribute('class')", + "url": "__REDDIT__/f/news/129816/gov-whitmer-signs-bills-to-repeal-right-to-work-restore", + "locator": "document.querySelector('div.submission__vote').querySelector('form').getAttribute('class')", "required_contents": { "must_include": [ "vote vote--user-downvoted" @@ -25906,8 +25908,8 @@ } }, { - "url": "__REDDIT__/user/Hrekires/submissions", - "locator": "document.querySelectorAll('div.submission__vote')[1].querySelector('form').getAttribute('class')", + "url": "__REDDIT__/f/news/129808/disney-world-deal-with-union-will-raise-minimum-wage-to-18", + "locator": "document.querySelector('div.submission__vote').querySelector('form').getAttribute('class')", "required_contents": { "must_include": [ "vote vote--user-downvoted" @@ -25915,8 +25917,8 @@ } }, { - "url": "__REDDIT__/user/Hrekires/submissions", - "locator": "document.querySelectorAll('div.submission__vote')[2].querySelector('form').getAttribute('class')", + "url": 
"__REDDIT__/f/news/129794/judge-halts-wyoming-abortion-ban-days-after-it-took-effect", + "locator": "document.querySelector('div.submission__vote').querySelector('form').getAttribute('class')", "required_contents": { "must_include": [ "vote vote--user-downvoted" @@ -25924,8 +25926,8 @@ } }, { - "url": "__REDDIT__/user/Hrekires/submissions", - "locator": "document.querySelectorAll('div.submission__vote')[3].querySelector('form').getAttribute('class')", + "url": "__REDDIT__/f/news/129783/don-t-say-gay-lawmaker-pleads-guilty-to-covid-relief-fraud", + "locator": "document.querySelector('div.submission__vote').querySelector('form').getAttribute('class')", "required_contents": { "must_include": [ "vote vote--user-downvoted" @@ -25933,8 +25935,8 @@ } }, { - "url": "__REDDIT__/user/Hrekires/submissions", - "locator": "document.querySelectorAll('div.submission__vote')[4].querySelector('form').getAttribute('class')", + "url": "__REDDIT__/f/news/129594/arizona-gov-katie-hobbs-refuses-to-proceed-with-execution", + "locator": "document.querySelector('div.submission__vote').querySelector('form').getAttribute('class')", "required_contents": { "must_include": [ "vote vote--user-downvoted" @@ -25942,8 +25944,8 @@ } }, { - "url": "__REDDIT__/user/Hrekires/submissions", - "locator": "document.querySelectorAll('div.submission__vote')[5].querySelector('form').getAttribute('class')", + "url": "__REDDIT__/f/news/129508/tennessee-governor-oks-bill-to-cut-nashville-council-in-half", + "locator": "document.querySelector('div.submission__vote').querySelector('form').getAttribute('class')", "required_contents": { "must_include": [ "vote vote--user-downvoted" @@ -25951,8 +25953,8 @@ } }, { - "url": "__REDDIT__/user/Hrekires/submissions", - "locator": "document.querySelectorAll('div.submission__vote')[7].querySelector('form').getAttribute('class')", + "url": "__REDDIT__/f/news/43839/philadelphia-da-larry-krasner-impeached-by-pa-house", + "locator": 
"document.querySelector('div.submission__vote').querySelector('form').getAttribute('class')", "required_contents": { "must_include": [ "vote vote--user-downvoted" @@ -25960,8 +25962,8 @@ } }, { - "url": "__REDDIT__/user/Hrekires/submissions", - "locator": "document.querySelectorAll('div.submission__vote')[8].querySelector('form').getAttribute('class')", + "url": "__REDDIT__/f/news/43781/crypto-giant-ftx-to-file-for-bankruptcy-ceo-sam-bankman", + "locator": "document.querySelector('div.submission__vote').querySelector('form').getAttribute('class')", "required_contents": { "must_include": [ "vote vote--user-downvoted" @@ -25969,8 +25971,8 @@ } }, { - "url": "__REDDIT__/user/Hrekires/submissions", - "locator": "document.querySelectorAll('div.submission__vote')[9].querySelector('form').getAttribute('class')", + "url": "__REDDIT__/f/news/43572/sec-doj-investigating-crypto-platform-ftx", + "locator": "document.querySelector('div.submission__vote').querySelector('form').getAttribute('class')", "required_contents": { "must_include": [ "vote vote--user-downvoted" @@ -25978,8 +25980,8 @@ } }, { - "url": "__REDDIT__/user/Hrekires/submissions", - "locator": "document.querySelectorAll('div.submission__vote')[10].querySelector('form').getAttribute('class')", + "url": "__REDDIT__/f/news/43558/kansas-gov-laura-kelly-wins-re-election-defeating-gop", + "locator": "document.querySelector('div.submission__vote').querySelector('form').getAttribute('class')", "required_contents": { "must_include": [ "vote vote--user-downvoted" From f8d636aec3969c9113a8205227d8c6eb103a81b9 Mon Sep 17 00:00:00 2001 From: alexisxy Date: Tue, 26 Sep 2023 15:21:04 -0400 Subject: [PATCH 16/29] update URL matching, fix typos --- config_files/test.raw.json | 264 +++++++++++++++++++------------------ 1 file changed, 133 insertions(+), 131 deletions(-) diff --git a/config_files/test.raw.json b/config_files/test.raw.json index 710c798..735919a 100644 --- a/config_files/test.raw.json +++ 
b/config_files/test.raw.json @@ -1077,7 +1077,7 @@ "reference_answers": { "must_include": [ "DoubleTree by Hilton Hotel Pittsburgh Airport", - "2.0km" + "1.4km" ] }, "reference_url": "", @@ -1395,7 +1395,7 @@ "must_include": [ "hollister", "Joust Bag", - "Antonia Race Tank" + "Antonia Racer Tank" ] }, "reference_url": "", @@ -1425,7 +1425,7 @@ "reference_answers": null, "reference_url": "__GITLAB__/dashboard/todos", "program_html": [], - "url_note": "EXACT" + "url_note": "GOLD in PRED" }, "intent_template_id": 303 }, @@ -1449,7 +1449,7 @@ "reference_answers": null, "reference_url": "__GITLAB__/a11yproject/a11yproject.com/-/issues/?sort=created_asc&state=opened", "program_html": [], - "url_note": "EXACT" + "url_note": "GOLD in PRED" }, "intent_template_id": 300 }, @@ -1473,7 +1473,7 @@ "reference_answers": null, "reference_url": "__GITLAB__/primer/design/-/issues/?sort=created_date&state=opened", "program_html": [], - "url_note": "EXACT" + "url_note": "GOLD in PRED" }, "intent_template_id": 300 }, @@ -3288,7 +3288,7 @@ "url_match" ], "reference_answers": null, - "reference_url": "__GITLAB__/byteblaze/a11y-syntax-highlighting/-/issues/?sort=priority_desc&state=opened&label_name%5B%5D=help%20wanted", + "reference_url": "__GITLAB__/byteblaze/a11y-syntax-highlighting/-/issues/?label_name%5B%5D=help%20wanted", "program_html": [], "url_note": "GOLD in PRED" }, @@ -3315,7 +3315,7 @@ "url_match" ], "reference_answers": null, - "reference_url": "__GITLAB__/kkroening/ffmpeg-python/-/issues/?sort=priority_desc&state=opened&label_name%5B%5D=question", + "reference_url": "__GITLAB__/kkroening/ffmpeg-python/-/issues/?label_name%5B%5D=question", "program_html": [], "url_note": "GOLD in PRED" }, @@ -3342,7 +3342,7 @@ "url_match" ], "reference_answers": null, - "reference_url": "__GITLAB__/keycloak/keycloak/-/issues/?sort=priority_desc&state=opened&label_name%5B%5D=flaky-test", + "reference_url": "__GITLAB__/keycloak/keycloak/-/issues/?label_name%5B%5D=flaky-test", "program_html": 
[], "url_note": "GOLD in PRED" }, @@ -3369,7 +3369,7 @@ "url_match" ], "reference_answers": null, - "reference_url": "__GITLAB__/OpenAPITools/openapi-generator/-/issues/?sort=priority_desc&state=opened&label_name%5B%5D=OpenAPI%20Generator%20CLI", + "reference_url": "__GITLAB__/OpenAPITools/openapi-generator/-/issues/?label_name%5B%5D=OpenAPI%20Generator%20CLI", "program_html": [], "url_note": "GOLD in PRED" }, @@ -3396,7 +3396,7 @@ "url_match" ], "reference_answers": null, - "reference_url": "__GITLAB__/umano/AndroidSlidingUpPanel/-/issues/?sort=priority_desc&state=opened&label_name%5B%5D=BUG", + "reference_url": "__GITLAB__/umano/AndroidSlidingUpPanel/-/issues/?label_name%5B%5D=BUG", "program_html": [], "url_note": "GOLD in PRED" }, @@ -3711,12 +3711,12 @@ "string_match" ], "reference_answers": { - "exact_match": "Teofila" + "exact_match": "N/A" }, "reference_url": "", "program_html": [], - "string_note": "", - "reference_answer_raw_annotation": "Teofila" + "string_note": "There is no negative review for Chloe tank", + "reference_answer_raw_annotation": "" }, "intent_template_id": 245 }, @@ -5026,7 +5026,7 @@ "reference_answers": null, "reference_url": "__GITLAB__/dashboard/merge_requests?assignee_username=byteblaze", "program_html": [], - "url_note": "EXACT" + "url_note": "GOLD in PRED" }, "intent_template_id": 290 }, @@ -5076,7 +5076,7 @@ "reference_answers": null, "reference_url": "__SHOPPING__/heiying-game-card-case-for-nintendo-switch-switch-oled-game-card-or-micro-sd-memory-cards-portable-switch-game-memory-card-storage-with-24-game-card-slots-and-24-micro-sd-card-slots-black.html", "program_html": [], - "url_note": "EXACT" + "url_note": "GOLD in PRED" }, "intent_template_id": 171 }, @@ -5102,7 +5102,7 @@ "reference_answers": null, "reference_url": "__SHOPPING__/game-card-holder-storage-case-for-nintendo-switch-games-or-ps-vita-game-case-or-sd-memory-cards-black.html", "program_html": [], - "url_note": "EXACT" + "url_note": "GOLD in PRED" }, 
"intent_template_id": 171 }, @@ -5128,7 +5128,7 @@ "reference_answers": null, "reference_url": "__SHOPPING__/heiying-game-card-case-for-nintendo-switch-switch-oled-game-card-or-micro-sd-memory-cards-portable-switch-game-memory-card-storage-with-24-game-card-slots-and-24-micro-sd-card-slots-black.html", "program_html": [], - "url_note": "EXACT" + "url_note": "GOLD in PRED" }, "intent_template_id": 171 }, @@ -5154,7 +5154,7 @@ "reference_answers": null, "reference_url": "__SHOPPING__/heiying-game-card-case-for-nintendo-switch-switch-oled-game-card-or-micro-sd-memory-cards-portable-switch-game-memory-card-storage-with-24-game-card-slots-and-24-micro-sd-card-slots-black.html", "program_html": [], - "url_note": "EXACT" + "url_note": "GOLD in PRED" }, "intent_template_id": 171 }, @@ -5180,7 +5180,7 @@ "reference_answers": null, "reference_url": "__SHOPPING__/game-card-holder-storage-case-for-nintendo-switch-games-or-ps-vita-game-case-or-sd-memory-cards-black.html", "program_html": [], - "url_note": "EXACT" + "url_note": "GOLD in PRED" }, "intent_template_id": 171 }, @@ -5542,7 +5542,8 @@ "reference_url": "__GITLAB__/byteblaze/empathy-prompts/-/issues/8", "program_html": [], "reference_answer_raw_annotation": "Not closed", - "string_note": "" + "string_note": "", + "url_note": "GOLD in PRED" }, "intent_template_id": 310 }, @@ -7463,23 +7464,21 @@ "geolocation": null, "intent_template": "Get the order number of my most recent {{status}} order ", "instantiation_dict": { - "status": "" + "status": "under delivery" }, - "intent": "Get the order number of my most recent order ", + "intent": "Get the order number of my most recent under delivery order ", "require_reset": false, "eval": { "eval_types": [ "string_match" ], "reference_answers": { - "must_include": [ - "136" - ] + "exact_match": "N/A" }, "reference_url": "", "program_html": [], "string_note": "", - "reference_answer_raw_annotation": "000000136" + "reference_answer_raw_annotation": "There is no under delivery order" 
}, "intent_template_id": 213 }, @@ -7578,7 +7577,7 @@ "reference_answers": null, "reference_url": "__SHOPPING__/astro-gaming-a50-wireless-headset-base-station-gen-4-compatible-with-ps5-ps4-pc-mac-black-silver.html", "program_html": [], - "url_note": "EXACT" + "url_note": "GOLD in PRED" }, "intent_template_id": 138 }, @@ -7604,7 +7603,7 @@ "reference_answers": null, "reference_url": "__SHOPPING__/kellogg-s-special-k-protein-meal-bars-chocolate-caramel-12-7oz-6-count.html", "program_html": [], - "url_note": "EXACT" + "url_note": "GOLD in PRED" }, "intent_template_id": 138 }, @@ -7630,7 +7629,7 @@ "reference_answers": null, "reference_url": "__SHOPPING__/women-cross-flower-beachwear-tankini-bandeau-bandage-bikini-set-push-up-swimwear-bathing-suit-two-pieces-swimsuits.html", "program_html": [], - "url_note": "EXACT" + "url_note": "GOLD in PRED" }, "intent_template_id": 138 }, @@ -7656,7 +7655,7 @@ "reference_answers": null, "reference_url": "__SHOPPING__/professional-medi-spa-scar-stretch-mark-reduction-system.html", "program_html": [], - "url_note": "EXACT" + "url_note": "GOLD in PRED" }, "intent_template_id": 138 }, @@ -7682,7 +7681,7 @@ "reference_answers": null, "reference_url": "__SHOPPING__/lynx-battery-12v-200ah-lithium-iron-phosphate-lifepo4-prismatic-deep-cell-battery-set-of-4-3-2v-cells-with-3-bus-bars-and-8-lug-nuts-for-rv-solar-marine-off-grid-applications.html", "program_html": [], - "url_note": "EXACT" + "url_note": "GOLD in PRED" }, "intent_template_id": 138 }, @@ -8168,7 +8167,7 @@ "reference_answers": null, "reference_url": "__GITLAB__/explore", "program_html": [], - "url_note": "EXACT" + "url_note": "GOLD in PRED" }, "intent_template_id": 325 }, @@ -8221,7 +8220,7 @@ "reference_answers": null, "reference_url": "__SHOPPING__/video-games.html", "program_html": [], - "url_note": "EXACT" + "url_note": "GOLD in PRED" }, "intent_template_id": 211 }, @@ -8247,7 +8246,7 @@ "reference_answers": null, "reference_url": "__SHOPPING__/electronics/headphones.html", 
"program_html": [], - "url_note": "EXACT" + "url_note": "GOLD in PRED" }, "intent_template_id": 211 }, @@ -8273,7 +8272,7 @@ "reference_answers": null, "reference_url": "__SHOPPING__/clothing-shoes-jewelry/men/shoes.html", "program_html": [], - "url_note": "EXACT" + "url_note": "GOLD in PRED" }, "intent_template_id": 211 }, @@ -8299,7 +8298,7 @@ "reference_answers": null, "reference_url": "__SHOPPING__/clothing-shoes-jewelry/women/clothing.html", "program_html": [], - "url_note": "EXACT" + "url_note": "GOLD in PRED" }, "intent_template_id": 211 }, @@ -8325,7 +8324,7 @@ "reference_answers": null, "reference_url": "__SHOPPING__/office-products/office-furniture-lighting/cabinets-racks-shelves.html", "program_html": [], - "url_note": "EXACT" + "url_note": "GOLD in PRED" }, "intent_template_id": 211 }, @@ -8485,7 +8484,7 @@ "reference_answers": null, "reference_url": "__SHOPPING__/clothing-shoes-jewelry/women/shoes.html?price=0-25", "program_html": [], - "url_note": "EXACT" + "url_note": "GOLD in PRED" }, "intent_template_id": 139 }, @@ -8512,7 +8511,7 @@ "reference_answers": null, "reference_url": "__SHOPPING__/clothing-shoes-jewelry/men/shoes.html?price=0-30", "program_html": [], - "url_note": "EXACT" + "url_note": "GOLD in PRED" }, "intent_template_id": 139 }, @@ -8539,7 +8538,7 @@ "reference_answers": null, "reference_url": "__SHOPPING__/beauty-personal-care/makeup/makeup-remover.html?price=0-46.99", "program_html": [], - "url_note": "EXACT" + "url_note": "GOLD in PRED" }, "intent_template_id": 139 }, @@ -8566,7 +8565,7 @@ "reference_answers": null, "reference_url": "__SHOPPING__/beauty-personal-care/oral-care/children-s-dental-care.html?price=0-78", "program_html": [], - "url_note": "EXACT" + "url_note": "GOLD in PRED" }, "intent_template_id": 139 }, @@ -8593,7 +8592,7 @@ "reference_answers": null, "reference_url": "__SHOPPING__/home-kitchen/furniture/accent-furniture.html?price=0-199", "program_html": [], - "url_note": "EXACT" + "url_note": "GOLD in PRED" }, 
"intent_template_id": 139 }, @@ -8619,7 +8618,7 @@ "reference_answers": null, "reference_url": "__SHOPPING__/catalogsearch/result/?q=usb+wifi", "program_html": [], - "url_note": "EXACT" + "url_note": "GOLD in PRED" }, "intent_template_id": 212 }, @@ -8645,7 +8644,7 @@ "reference_answers": null, "reference_url": "__SHOPPING__/catalogsearch/result/?q=xbox", "program_html": [], - "url_note": "EXACT" + "url_note": "GOLD in PRED" }, "intent_template_id": 212 }, @@ -8671,7 +8670,7 @@ "reference_answers": null, "reference_url": "__SHOPPING__/catalogsearch/result/?q=switch+accessories", "program_html": [], - "url_note": "EXACT" + "url_note": "GOLD in PRED" }, "intent_template_id": 212 }, @@ -8697,7 +8696,7 @@ "reference_answers": null, "reference_url": "__SHOPPING__/catalogsearch/result/?q=iphone+13", "program_html": [], - "url_note": "EXACT" + "url_note": "GOLD in PRED" }, "intent_template_id": 212 }, @@ -8723,7 +8722,7 @@ "reference_answers": null, "reference_url": "__SHOPPING__/catalogsearch/result/?q=green+tea+bag+for+weight+loss", "program_html": [], - "url_note": "EXACT" + "url_note": "GOLD in PRED" }, "intent_template_id": 212 }, @@ -8902,7 +8901,7 @@ "reference_answers": null, "reference_url": "__SHOPPING__/microsoft-xbox-controller-carbon-black-for-series-x-series-s-xbox-one-windows-10-android-ios-bundled-with-dual-port-charging-dock-xbox-controller-skin-voucher-premgear-cloth.html", "program_html": [], - "url_note": "EXACT" + "url_note": "GOLD in PRED" }, "intent_template_id": 210 }, @@ -8929,7 +8928,7 @@ "reference_answers": null, "reference_url": "__SHOPPING__/onlyeasy-over-the-door-shoe-storage-organizer-hanging-shoe-rack-holder-with-24-large-fabric-pockets-22-1-x-61-4-herringbone-grey-mxrodsb1p.html", "program_html": [], - "url_note": "EXACT" + "url_note": "GOLD in PRED" }, "intent_template_id": 207 }, @@ -8956,7 +8955,7 @@ "reference_answers": null, "reference_url": 
"__SHOPPING__/game-card-holder-storage-case-for-nintendo-switch-games-or-ps-vita-game-case-or-sd-memory-cards-black.html", "program_html": [], - "url_note": "EXACT" + "url_note": "GOLD in PRED" }, "intent_template_id": 207 }, @@ -8983,7 +8982,7 @@ "reference_answers": null, "reference_url": "__SHOPPING__/external-hard-drive-2tb-ultra-thin-external-hard-drive-2000gb-ultra-high-speed-portable-3-1-type-c-storage-drive-compatible-with-pc-laptop-and-mac-2tb-a1.html", "program_html": [], - "url_note": "EXACT" + "url_note": "GOLD in PRED" }, "intent_template_id": 207 }, @@ -9341,7 +9340,7 @@ "reference_answers": null, "reference_url": "__SHOPPING__/sales/order/view/order_id/180/", "program_html": [], - "url_note": "EXACT" + "url_note": "GOLD in PRED" }, "intent_template_id": 180 }, @@ -9367,7 +9366,7 @@ "reference_answers": null, "reference_url": "__SHOPPING__/sales/order/view/order_id/170/", "program_html": [], - "url_note": "EXACT" + "url_note": "GOLD in PRED" }, "intent_template_id": 180 }, @@ -9393,7 +9392,7 @@ "reference_answers": null, "reference_url": "__SHOPPING__/sales/order/view/order_id/189/", "program_html": [], - "url_note": "EXACT" + "url_note": "GOLD in PRED" }, "intent_template_id": 180 }, @@ -9414,12 +9413,12 @@ "require_reset": false, "eval": { "eval_types": [ - "url_match" + "string_match" ], - "reference_answers": null, - "reference_url": "NA", + "reference_answers": {"exact_match": "N/A"}, + "reference_url": "", "program_html": [], - "url_note": "EXACT" + "string_note": "there is no order in processing" }, "intent_template_id": 180 }, @@ -9440,12 +9439,12 @@ "require_reset": false, "eval": { "eval_types": [ - "url_match" + "string_match" ], - "reference_answers": null, - "reference_url": "NA", + "reference_answers": {"exact_match": "N/A"}, + "reference_url": "", "program_html": [], - "url_note": "EXACT" + "string_note": "there is no order in processing" }, "intent_template_id": 180 }, @@ -10129,7 +10128,7 @@ "reference_answers": null, "reference_url": 
"__SHOPPING__/catalogsearch/result/index/?product_list_order=price&q=chairs&product_list_dir=asc", "program_html": [], - "url_note": "EXACT" + "url_note": "GOLD in PRED" }, "intent_template_id": 208 }, @@ -10156,7 +10155,7 @@ "reference_answers": null, "reference_url": "__SHOPPING__/catalogsearch/result/index/?q=mouth%20night%20guard%20&product_list_order=price", "program_html": [], - "url_note": "EXACT" + "url_note": "GOLD in PRED" }, "intent_template_id": 208 }, @@ -10183,7 +10182,7 @@ "reference_answers": null, "reference_url": "__SHOPPING__/catalogsearch/result/?q=Canon+photo+printer", "program_html": [], - "url_note": "EXACT" + "url_note": "GOLD in PRED" }, "intent_template_id": 208 }, @@ -10210,7 +10209,7 @@ "reference_answers": null, "reference_url": "__SHOPPING__/catalogsearch/result/index/?q=%20iphone%2012%20phone%20case&product_list_order=name", "program_html": [], - "url_note": "EXACT" + "url_note": "GOLD in PRED" }, "intent_template_id": 208 }, @@ -10237,7 +10236,7 @@ "reference_answers": null, "reference_url": "__SHOPPING__/catalogsearch/result/index/?product_list_order=price&q=%20iphone%2012%20phone%20case", "program_html": [], - "url_note": "EXACT" + "url_note": "GOLD in PRED" }, "intent_template_id": 208 }, @@ -10575,7 +10574,7 @@ "reference_answers": null, "reference_url": "__GITLAB__/a11yproject/a11yproject.com/-/issues/?label_name%5B%5D=bug", "program_html": [], - "url_note": "EXACT" + "url_note": "GOLD in PRED" }, "intent_template_id": 299 }, @@ -10601,7 +10600,7 @@ "reference_answers": null, "reference_url": "__GITLAB__/primer/design/-/issues/?label_name%5B%5D=type%3A%20bug%20%F0%9F%90%9E", "program_html": [], - "url_note": "EXACT" + "url_note": "GOLD in PRED" }, "intent_template_id": 299 }, @@ -10627,7 +10626,7 @@ "reference_answers": null, "reference_url": "__GITLAB__/root/metaseq/-/issues/?label_name%5B%5D=enhancement", "program_html": [], - "url_note": "EXACT" + "url_note": "GOLD in PRED" }, "intent_template_id": 299 }, @@ -10651,9 +10650,9 
@@ "url_match" ], "reference_answers": null, - "reference_url": "__GITLAB__/root/metaseq/-/issues/?search=OPT&sort=priority_desc&state=opened&label_name%5B%5D=question&first_page_size=20", + "reference_url": "__GITLAB__/root/metaseq/-/issues/?search=OPT&label_name%5B%5D=question", "program_html": [], - "url_note": "EXACT" + "url_note": "GOLD in PRED" }, "intent_template_id": 299 }, @@ -10677,9 +10676,9 @@ "url_match" ], "reference_answers": null, - "reference_url": "__GITLAB__/root/metaseq/-/issues/?sort=priority_desc&state=opened&label_name%5B%5D=None&first_page_size=20", + "reference_url": "__GITLAB__/root/metaseq/-/issues/?label_name%5B%5D=None", "program_html": [], - "url_note": "EXACT" + "url_note": "GOLD in PRED" }, "intent_template_id": 299 }, @@ -10921,7 +10920,7 @@ "reference_answers": null, "reference_url": "__SHOPPING__/video-games/playstation-4/accessories.html?product_list_order=price", "program_html": [], - "url_note": "EXACT" + "url_note": "GOLD in PRED" }, "intent_template_id": 137 }, @@ -10948,7 +10947,7 @@ "reference_answers": null, "reference_url": "__SHOPPING__/health-household/diet-sports-nutrition/nutrition-bars-drinks.html?product_list_order=price", "program_html": [], - "url_note": "EXACT" + "url_note": "GOLD in PRED" }, "intent_template_id": 137 }, @@ -10975,7 +10974,7 @@ "reference_answers": null, "reference_url": "__SHOPPING__/clothing-shoes-jewelry/sport-specific-clothing/competitive-swimwear.html?product_list_order=price", "program_html": [], - "url_note": "EXACT" + "url_note": "GOLD in PRED" }, "intent_template_id": 137 }, @@ -11002,7 +11001,7 @@ "reference_answers": null, "reference_url": "__SHOPPING__/home-kitchen/furniture/living-room-furniture.html?product_list_order=price&product_list_dir=desc", "program_html": [], - "url_note": "EXACT" + "url_note": "GOLD in PRED" }, "intent_template_id": 137 }, @@ -11029,7 +11028,7 @@ "reference_answers": null, "reference_url": 
"__SHOPPING__/home-kitchen/bedding/kids-bedding.html?product_list_dir=desc", "program_html": [], - "url_note": "EXACT" + "url_note": "GOLD in PRED" }, "intent_template_id": 137 }, @@ -11104,7 +11103,7 @@ "reference_answers": null, "reference_url": "__GITLAB__/dashboard/merge_requests?reviewer_username=byteblaze", "program_html": [], - "url_note": "EXACT" + "url_note": "GOLD in PRED" }, "intent_template_id": 291 }, @@ -11731,7 +11730,7 @@ "url_match" ], "reference_answers": null, - "reference_url": "__MAP__search?query=restaurants%20near%20CMU%20ArtPark%20Lab", + "reference_url": "__MAP__/search?query=restaurants%20near%20CMU%20ArtPark%20Lab", "program_html": [], "url_note": "GOLD in PRED" }, @@ -11758,7 +11757,7 @@ "url_match" ], "reference_answers": null, - "reference_url": "__MAP__search?query=parking%20near%20carnegie%20mellon%20university", + "reference_url": "__MAP__/search?query=parking%20near%20carnegie%20mellon%20university", "program_html": [], "url_note": "GOLD in PRED" }, @@ -11785,7 +11784,7 @@ "url_match" ], "reference_answers": null, - "reference_url": "__MAP__search?query=hotels%20near%20carnegie%20mellon%20university", + "reference_url": "__MAP__/search?query=hotels%20near%20carnegie%20mellon%20university", "program_html": [], "url_note": "GOLD in PRED" }, @@ -11812,7 +11811,7 @@ "url_match" ], "reference_answers": null, - "reference_url": "__MAP__search?query=bars%20near%20Carnegie%20Music%20Hall", + "reference_url": "__MAP__/search?query=bars%20near%20Carnegie%20Music%20Hall", "program_html": [], "url_note": "GOLD in PRED" }, @@ -11839,7 +11838,7 @@ "url_match" ], "reference_answers": null, - "reference_url": "__MAP__search?query=hotels%20near%20Carnegie%20Music%20Hall", + "reference_url": "__MAP__/search?query=hotels%20near%20Carnegie%20Music%20Hall", "program_html": [], "url_note": "GOLD in PRED" }, @@ -11992,7 +11991,7 @@ ], "reference_answers": { "must_include": [ - "3" + "65 |OR| 3" ] }, "reference_url": "", @@ -13135,7 +13134,7 @@ 
"program_html": [ { "url": "__GITLAB__/byteblaze", - "locator": "document.querySelector('.cover-status').outerText", + "locator": "document.querySelector('.cover-status').lastChild.textContent", "required_contents": { "exact_match": "Busy" } @@ -13168,7 +13167,7 @@ "program_html": [ { "url": "__GITLAB__/byteblaze", - "locator": "document.querySelector('.cover-status').outerText", + "locator": "document.querySelector('.cover-status').lastChild.textContent", "required_contents": { "exact_match": "Enjoying life" } @@ -13201,7 +13200,7 @@ "program_html": [ { "url": "__GITLAB__/byteblaze", - "locator": "document.querySelector('.cover-status').outerText", + "locator": "document.querySelector('.cover-status').lastChild.textContent", "required_contents": { "exact_match": "Playing Badminton" } @@ -13234,7 +13233,7 @@ "program_html": [ { "url": "__GITLAB__/byteblaze", - "locator": "document.querySelector('.cover-status').outerText", + "locator": "document.querySelector('.cover-status').lastChild.textContent", "required_contents": { "exact_match": "Resting due to leg injury" } @@ -13267,7 +13266,7 @@ "program_html": [ { "url": "__GITLAB__/byteblaze", - "locator": "document.querySelector('.cover-status').outerText", + "locator": "document.querySelector('.cover-status').lastChild.textContent", "required_contents": { "exact_match": "Out of Office" } @@ -14712,12 +14711,13 @@ "reference_url": "", "program_html": [ { - "url": "__SHOPPING_ADMIN__/antonia-racer-tank.html", + "url": "__SHOPPING_ADMIN__/../antonia-racer-tank.html", "locator": "document.querySelector('.data.item.content').outerText + (document.querySelector('.product.attribute.overview [itemprop=\"description\"]')?.outerText || '')", "required_contents": { "must_include": [ "This is in regular rotation at the gym", - "Its colorful and looks kinda cute under my exercise tanks" + "Its colorful and looks kinda cute under my exercise tanks", + "it's very stylish for yoga or something else low impact" ] } } @@ -15639,7 
+15639,7 @@ "require_reset": false, "eval": { "eval_types": [ - "program_html" + "string_match" ], "reference_answers": { "exact_match": "N/A" @@ -17538,7 +17538,7 @@ "reference_url": "", "program_html": [ { - "url": "__SHOPPING_ADMIN__/bella-tank.html", + "url": "__SHOPPING_ADMIN__/../bella-tank.html", "locator": "document.querySelector('.data.item.content').outerText + (document.querySelector('.product.attribute.overview [itemprop=\"description\"]')?.outerText || '')", "required_contents": { "must_include": [ @@ -17576,14 +17576,15 @@ "reference_url": "", "program_html": [ { - "url": "__SHOPPING_ADMIN__/selene-yoga-hoodie.html", - "locator": "document.querySelector('.product.info.detailed').outerText", + "url": "__SHOPPING_ADMIN__/../selene-yoga-hoodie.html", + "locator": "document.querySelector('.data.item.content').outerText + (document.querySelector('.product.attribute.overview [itemprop=\"description\"]')?.outerText || '')", "required_contents": { "must_include": [ "I was super cold and it did the job.", "The sleeves are definitely thicker than you realize, which is a good thing", "really quite substantial", - "m planning on buying another one of these in another color. the best hoodie ive ever owned." + "planning on buying another one of these in another color", + "the best hoodie ive ever owned" ] } } @@ -17614,15 +17615,16 @@ "reference_url": "", "program_html": [ { - "url": "__SHOPPING_ADMIN__/radiant-tee.html", - "locator": "document.querySelector('.product.info.detailed').outerText", + "url": "__SHOPPING_ADMIN__/../radiant-tee.html", + "locator": "document.querySelector('.data.item.content').outerText + (document.querySelector('.product.attribute.overview [itemprop=\"description\"]')?.outerText || '')", "required_contents": { "must_include": [ "What I rally love here is that it does the job of keeping me cool and dry", - "I'm a big guy and sweat A LOT! 
Even after a day of gulf, I'm still dry and comfortable", - "What a versatile shirt!", - "Not only does it feel very soft compared to my old worn out polos, but it also does the job promised.", - "I like going out after my game for drinks so I look good then too and don't need to change into something fresh." + "I'm a big guy and sweat A LOT", + "Even after a day of gulf, I'm still dry and comfortable", + "What a versatile shirt", + "Not only does it feel very soft compared to my old worn out polos, but it also does the job promised", + "I like going out after my game for drinks so I look good then too and don't need to change into something fresh" ] } } @@ -17653,11 +17655,11 @@ "reference_url": "", "program_html": [ { - "url": "__SHOPPING_ADMIN__/affirm-water-bottle.html", - "locator": "document.querySelector('.product.info.detailed').outerText", + "url": "__SHOPPING_ADMIN__/../affirm-water-bottle.html", + "locator": "document.querySelector('.data.item.content').outerText + (document.querySelector('.product.attribute.overview [itemprop=\"description\"]')?.outerText || '')", "required_contents": { "must_include": [ - "Wide mouth opening makes it easy to clean!" 
+ "Wide mouth opening makes it easy to clean" ] } } @@ -24306,7 +24308,7 @@ "program_html" ], "reference_answers": null, - "reference_url": "__SHOPPING_ADMIN__/sales_rule/promo_quote/new/", + "reference_url": "__SHOPPING_ADMIN__/sales_rule/promo_quote", "program_html": [ { "url": "last", @@ -24346,7 +24348,7 @@ } } ], - "url_note": "EXACT" + "url_note": "GOLD in PRED" }, "intent_template_id": 258 }, @@ -24372,7 +24374,7 @@ "program_html" ], "reference_answers": null, - "reference_url": "__SHOPPING_ADMIN__/sales_rule/promo_quote/new/", + "reference_url": "__SHOPPING_ADMIN__/sales_rule/promo_quote", "program_html": [ { "url": "last", @@ -24412,7 +24414,7 @@ } } ], - "url_note": "EXACT" + "url_note": "GOLD in PRED" }, "intent_template_id": 258 }, @@ -24438,7 +24440,7 @@ "program_html" ], "reference_answers": null, - "reference_url": "__SHOPPING_ADMIN__/sales_rule/promo_quote/new/", + "reference_url": "__SHOPPING_ADMIN__/sales_rule/promo_quote", "program_html": [ { "url": "last", @@ -24478,7 +24480,7 @@ } } ], - "url_note": "EXACT" + "url_note": "GOLD in PRED" }, "intent_template_id": 258 }, @@ -24504,7 +24506,7 @@ "program_html" ], "reference_answers": null, - "reference_url": "__SHOPPING_ADMIN__/sales_rule/promo_quote/new/", + "reference_url": "__SHOPPING_ADMIN__/sales_rule/promo_quote", "program_html": [ { "url": "last", @@ -24544,7 +24546,7 @@ } } ], - "url_note": "EXACT" + "url_note": "GOLD in PRED" }, "intent_template_id": 258 }, @@ -24570,7 +24572,7 @@ "program_html" ], "reference_answers": null, - "reference_url": "__SHOPPING_ADMIN__/sales_rule/promo_quote/new/", + "reference_url": "__SHOPPING_ADMIN__/sales_rule/promo_quote", "program_html": [ { "url": "last", @@ -24610,7 +24612,7 @@ } } ], - "url_note": "EXACT" + "url_note": "GOLD in PRED" }, "intent_template_id": 258 }, @@ -24642,14 +24644,14 @@ "url": "last", "locator": "document.querySelector('[id=\"sales_report_from\"').value", "required_contents": { - "exact_match": "2/1/2023" + "exact_match": "2/1/23" } 
}, { "url": "last", "locator": "document.querySelector('[id=\"sales_report_to\"').value", "required_contents": { - "exact_match": "2/28/2023" + "exact_match": "2/28/23" } } ], @@ -24685,14 +24687,14 @@ "url": "last", "locator": "document.querySelector('[id=\"sales_report_from\"').value", "required_contents": { - "exact_match": "1/29/2023" + "exact_match": "1/29/23" } }, { "url": "last", "locator": "document.querySelector('[id=\"sales_report_to\"').value", "required_contents": { - "exact_match": "3/15/2023" + "exact_match": "3/15/23" } } ], @@ -24728,14 +24730,14 @@ "url": "last", "locator": "document.querySelector('[id=\"sales_report_from\"').value", "required_contents": { - "exact_match": "1/1/2023" + "exact_match": "1/1/23" } }, { "url": "last", "locator": "document.querySelector('[id=\"sales_report_to\"').value", "required_contents": { - "exact_match": "3/31/2023" + "exact_match": "3/31/23" } } ], @@ -24902,14 +24904,14 @@ "url": "last", "locator": "document.querySelector('[id=\"sales_report_from\"').value", "required_contents": { - "exact_match": "8/5/2022" + "exact_match": "8/5/22" } }, { "url": "last", "locator": "document.querySelector('[id=\"sales_report_to\"').value", "required_contents": { - "exact_match": "3/1/2023" + "exact_match": "3/1/23" } } ], @@ -24946,14 +24948,14 @@ "url": "last", "locator": "document.querySelector('[id=\"sales_report_from\"').value", "required_contents": { - "exact_match": "7/5/2021" + "exact_match": "7/5/21" } }, { "url": "last", "locator": "document.querySelector('[id=\"sales_report_to\"').value", "required_contents": { - "exact_match": "5/31/2023" + "exact_match": "5/31/23" } } ], @@ -24990,14 +24992,14 @@ "url": "last", "locator": "document.querySelector('[id=\"sales_report_from\"').value", "required_contents": { - "exact_match": "5/1/2021" + "exact_match": "5/1/21" } }, { "url": "last", "locator": "document.querySelector('[id=\"sales_report_to\"').value", "required_contents": { - "exact_match": "5/15/2023" + "exact_match": 
"5/15/23" } } ], @@ -25034,14 +25036,14 @@ "url": "last", "locator": "document.querySelector('[id=\"sales_report_from\"').value", "required_contents": { - "exact_match": "5/1/2022" + "exact_match": "5/1/22" } }, { "url": "last", "locator": "document.querySelector('[id=\"sales_report_to\"').value", "required_contents": { - "exact_match": "5/31/2023" + "exact_match": "5/31/23" } } ], From b4c917dd45825ff33019dd345d95ee5e1bb624c7 Mon Sep 17 00:00:00 2001 From: alexisxy Date: Tue, 26 Sep 2023 15:27:07 -0400 Subject: [PATCH 17/29] update must_include tokenization condition; upate url match --- evaluation_harness/evaluators.py | 99 +++++++++++++++++++++++--------- 1 file changed, 73 insertions(+), 26 deletions(-) diff --git a/evaluation_harness/evaluators.py b/evaluation_harness/evaluators.py index 30c3a5c..793e8c7 100644 --- a/evaluation_harness/evaluators.py +++ b/evaluation_harness/evaluators.py @@ -1,5 +1,7 @@ """base class for evaluation""" # answer string match +import collections +import html import importlib import json import time @@ -77,6 +79,7 @@ class StringEvaluator(Evaluator): @staticmethod @beartype def clean_answer(answer: str) -> str: + answer = answer.strip() if answer.startswith("'") and answer.endswith("'"): answer = answer[1:-1] elif answer.startswith('"') and answer.endswith('"'): @@ -93,12 +96,16 @@ def exact_match(ref: str, pred: str) -> float: @staticmethod @beartype - def must_include(ref: str, pred: str) -> float: + def must_include(ref: str, pred: str, tokenize=False) -> float: clean_ref = StringEvaluator.clean_answer(ref) clean_pred = StringEvaluator.clean_answer(pred) # tokenize the answer if the ref is a single word # prevent false positive (e.g, 0) - if len(word_tokenize(clean_ref)) == 1: + if ( + tokenize + and len(clean_ref) == 1 + and len(word_tokenize(clean_ref)) == 1 + ): tok_pred = word_tokenize(clean_pred) return float(clean_ref in tok_pred) else: @@ -130,7 +137,11 @@ def __call__( case "must_include": assert isinstance(value, list) 
for must_value in value: - score *= self.must_include(ref=must_value, pred=pred) + score *= self.must_include( + ref=must_value, + pred=pred, + tokenize=(len(value) == 1), + ) case "fuzzy_match": intent = configs["intent"] assert isinstance(value, list) @@ -165,7 +176,7 @@ def __call__( class URLExactEvaluator(Evaluator): - """Check whether the URL is exactly the same as of the reference URLs""" + """Check URL matching""" @beartype def __call__( @@ -180,27 +191,60 @@ def __call__( def clean_url(url: str) -> str: url = str(url) - if url.endswith("/"): - url = url[:-1] + url = url.rstrip("/") return url + def parse_url(url: str) -> tuple[str, dict[str, list[str]]]: + """Parse a URL into its base, path, and query components.""" + parsed_url = urllib.parse.urlparse(url) + base_path = parsed_url.netloc + parsed_url.path + query = urllib.parse.parse_qs(parsed_url.query) + return base_path, query + + def parse_urls( + urls: list[str], + ) -> tuple[list[str], list[str], dict[str, set[str]]]: + """Parse a list of URLs.""" + base_paths = [] + queries = collections.defaultdict(set) + for url in urls: + base_path, query = parse_url(url) + base_paths.append(base_path) + for k, v in query.items(): + queries[k].update(v) + return base_paths, queries + pred = clean_url(page.url) ref_urls = configs["eval"]["reference_url"].split(" |OR| ") ref_urls = [clean_url(url) for url in ref_urls] - matching_rule = configs["eval"].get("url_note", "EXACT") - if matching_rule == "EXACT": - if pred in ref_urls: - return 1.0 - else: - return 0.0 - elif matching_rule == "GOLD in PRED": - if any([ref in pred for ref in ref_urls]): - return 1.0 - else: - return 0.0 + matching_rule = configs["eval"].get("url_note", "GOLD in PRED") + if matching_rule == "GOLD in PRED": + ref_base_paths, ref_queries = parse_urls(ref_urls) + pred_base_paths, pred_query = parse_url(pred) + + base_score = float( + any( + [ + ref_base_path in pred_base_paths + for ref_base_path in ref_base_paths + ] + ) + ) + query_score = 
1.0 + for k, possible_values in ref_queries.items(): + query_score *= float( + any( + possible_ref_value in pred_query.get(k, []) + for possible_ref_value in possible_values + ) + ) + score = base_score * query_score + else: raise ValueError(f"Unknown matching rule: {matching_rule}") + return score + class HTMLContentExactEvaluator(Evaluator): """Check whether the contents appear in the page""" @@ -241,10 +285,9 @@ def __call__( "[...document." ): try: - selected_element = page.evaluate(f"() => {locator}") + selected_element = str(page.evaluate(f"() => {locator}")) if not selected_element: selected_element = "" - selected_element = str(selected_element) except Exception: # the page is wrong, return empty selected_element = "" @@ -256,29 +299,34 @@ def __call__( else: raise ValueError(f"Unknown locator: {locator}") + selected_element = html.unescape(selected_element) + if "exact_match" in target["required_contents"]: required_contents = target["required_contents"]["exact_match"] - score *= StringEvaluator.exact_match( + cur_score = StringEvaluator.exact_match( ref=required_contents, pred=selected_element ) + score *= float(cur_score) elif "must_include" in target["required_contents"]: required_contents = target["required_contents"]["must_include"] assert isinstance(required_contents, list) for content in required_contents: content_or = content.split(" |OR| ") - score *= any( + cur_score = any( [ StringEvaluator.must_include( - ref=content, pred=selected_element + ref=content, + pred=selected_element, + tokenize=False, ) for content in content_or ] ) + score *= float(cur_score) else: raise ValueError( f"Unknown required_contents: {target['required_contents'].keys()}" ) - return score @@ -358,15 +406,14 @@ def __call__( self, trajectory: Trajectory, config_file: Path | str, - page: Page | PseudoPage, - client: CDPSession, + page: Page | PseudoPage | None = None, + client: CDPSession | None = None, ) -> float: score = 1.0 for evaluator in self.evaluators: cur_score = 
evaluator(trajectory, config_file, page, client) score *= cur_score - return score From a7c475b575041eba314a54208ecac77b57fa5e5d Mon Sep 17 00:00:00 2001 From: alexisxy Date: Tue, 26 Sep 2023 15:29:16 -0400 Subject: [PATCH 18/29] remove unused evaluators --- evaluation_harness/evaluators.py | 90 -------------------------------- 1 file changed, 90 deletions(-) diff --git a/evaluation_harness/evaluators.py b/evaluation_harness/evaluators.py index 793e8c7..5c80238 100644 --- a/evaluation_harness/evaluators.py +++ b/evaluation_harness/evaluators.py @@ -152,29 +152,6 @@ def __call__( return score -class StringSoftEvaluator(Evaluator): - """Use text generation metrics such as BLEU, ROUGE, etc. to evaluate the answer""" - - @beartype - def __call__( - self, - trajectory: Trajectory, - config_file: Path | str, - page: Page | PseudoPage | None = None, - client: CDPSession | None = None, - ) -> float: - with open(config_file, "r") as f: - configs = json.load(f) - - last_action = self.get_last_action(trajectory) - pred = last_action["answer"] - ref = configs["eval"]["reference_answers"] - # rouge - m = evaluate.load("rouge") - rouge = m.compute(predictions=[pred], references=[ref]) - return float(rouge["rouge1"]) - - class URLExactEvaluator(Evaluator): """Check URL matching""" @@ -330,73 +307,6 @@ def __call__( return score -###### -# soft matches. -# mainly for partial scores -# !!under development!! 
-# TODO[shuyanzh] -###### - - -class EvaluatorPartial(Evaluator): - def __init__(self) -> None: - raise NotImplementedError - - def __call__( - self, - trajectory: Trajectory, - config_file: Path | str, - page: Page | PseudoPage, - client: CDPSession, - ) -> float: - raise NotImplementedError - - -class URLSoftEvaluator(EvaluatorPartial): - """Parse the URL and compare the domain and parameters""" - - def __call__( - self, - trajectory: Trajectory, - config_file: Path | str, - page: Page | PseudoPage, - client: CDPSession, - ) -> float: - with open(config_file, "r") as f: - configs = json.load(f) - - last_state = self.get_last_state(trajectory) - pred = last_state["info"]["page"].url - ref = configs["eval"]["reference_url"] - - # parse url to get domain, parameters, etc. - parsed_pred = urllib.parse.urlparse(pred) - parsed_ref = urllib.parse.urlparse(ref) - - # check domain - domain_match = int(parsed_pred.netloc == parsed_ref.netloc) - - def get_param_set(query: dict[str, list[str]]) -> set[str]: - param_set = set() - for k, v in query.items(): - for vv in v: - param_set.add(f"{k}={vv}") - return param_set - - # calculate parameter f1 - param_set_ref = get_param_set(urllib.parse.parse_qs(parsed_ref.query)) - param_set_pred = get_param_set( - urllib.parse.parse_qs(parsed_pred.query) - ) - r = len(param_set_ref & param_set_pred) / len(param_set_ref) - p = len(param_set_ref & param_set_pred) / len(param_set_pred) - f1 = 2 * r * p / (r + p) if r + p > 0 else 1.0 - - score = domain_match * f1 # domain match is a must - - return score - - class EvaluatorComb: def __init__(self, evaluators: list[Evaluator]) -> None: self.evaluators = evaluators From 50e2c430b46e0a0fffdf12027e7ad684e9e4d248 Mon Sep 17 00:00:00 2001 From: alexisxy Date: Tue, 26 Sep 2023 15:42:29 -0400 Subject: [PATCH 19/29] remove exact from evalutor names --- evaluation_harness/evaluators.py | 10 +++---- ...exact_evaluators.py => test_evaluators.py} | 28 +++++++++---------- 2 files changed, 18 
insertions(+), 20 deletions(-) rename tests/test_evaluation_harness/{test_exact_evaluators.py => test_evaluators.py} (94%) diff --git a/evaluation_harness/evaluators.py b/evaluation_harness/evaluators.py index 5c80238..e210fa8 100644 --- a/evaluation_harness/evaluators.py +++ b/evaluation_harness/evaluators.py @@ -152,7 +152,7 @@ def __call__( return score -class URLExactEvaluator(Evaluator): +class URLEvaluator(Evaluator): """Check URL matching""" @beartype @@ -223,7 +223,7 @@ def parse_urls( return score -class HTMLContentExactEvaluator(Evaluator): +class HTMLContentEvaluator(Evaluator): """Check whether the contents appear in the page""" @beartype @@ -334,15 +334,15 @@ def evaluator_router(config_file: Path | str) -> EvaluatorComb: configs = json.load(f) eval_types = configs["eval"]["eval_types"] - evaluators: list[Evaluator | EvaluatorPartial] = [] + evaluators: list[Evaluator] = [] for eval_type in eval_types: match eval_type: case "string_match": evaluators.append(StringEvaluator()) case "url_match": - evaluators.append(URLExactEvaluator()) + evaluators.append(URLEvaluator()) case "program_html": - evaluators.append(HTMLContentExactEvaluator()) + evaluators.append(HTMLContentEvaluator()) case _: raise ValueError(f"eval_type {eval_type} is not supported") diff --git a/tests/test_evaluation_harness/test_exact_evaluators.py b/tests/test_evaluation_harness/test_evaluators.py similarity index 94% rename from tests/test_evaluation_harness/test_exact_evaluators.py rename to tests/test_evaluation_harness/test_evaluators.py index 9715ccf..bef0db6 100644 --- a/tests/test_evaluation_harness/test_exact_evaluators.py +++ b/tests/test_evaluation_harness/test_evaluators.py @@ -12,9 +12,9 @@ from browser_env import ActionTypes, ScriptBrowserEnv from browser_env.env_config import * from evaluation_harness import ( - HTMLContentExactEvaluator, + HTMLContentEvaluator, StringEvaluator, - URLExactEvaluator, + URLEvaluator, ) from evaluation_harness.evaluators import EvaluatorComb 
@@ -99,7 +99,7 @@ def test_url_exact_match_success(script_browser_env: ScriptBrowserEnv) -> None: trajectory = tf_roll_out(agent, env, config_file) - evalutor = URLExactEvaluator() + evalutor = URLEvaluator() score = evalutor( trajectory, config_file, env.page, env.get_page_client(env.page) ) @@ -119,7 +119,7 @@ def test_url_exact_match_fail(script_browser_env: ScriptBrowserEnv) -> None: trajectory = tf_roll_out(agent, env, config_file) - evalutor = URLExactEvaluator() + evalutor = URLEvaluator() score = evalutor( trajectory, config_file, env.page, env.get_page_client(env.page) ) @@ -143,7 +143,7 @@ def test_html_content_match_success( trajectory = tf_roll_out(agent, env, config_file) - evalutor = HTMLContentExactEvaluator() + evalutor = HTMLContentEvaluator() score = evalutor( trajectory, config_file, env.page, env.get_page_client(env.page) ) @@ -164,7 +164,7 @@ def test_html_content_match_fail(script_browser_env: ScriptBrowserEnv) -> None: trajectory = tf_roll_out(agent, env, config_file) - evalutor = HTMLContentExactEvaluator() + evalutor = HTMLContentEvaluator() score = evalutor( trajectory, config_file, env.page, env.get_page_client(env.page) ) @@ -189,7 +189,7 @@ def test_html_content_element_match_success( trajectory = tf_roll_out(agent, env, config_file) - evalutor = HTMLContentExactEvaluator() + evalutor = HTMLContentEvaluator() score = evalutor( trajectory, config_file, env.page, env.get_page_client(env.page) ) @@ -214,7 +214,7 @@ def test_html_content_element_match_fail( trajectory = tf_roll_out(agent, env, config_file) - evalutor = HTMLContentExactEvaluator() + evalutor = HTMLContentEvaluator() score = evalutor( trajectory, config_file, env.page, env.get_page_client(env.page) ) @@ -239,9 +239,7 @@ def test_html_content_url_comb_success( trajectory = tf_roll_out(agent, env, config_file) - evaluators = EvaluatorComb( - [URLExactEvaluator(), HTMLContentExactEvaluator()] - ) + evaluators = EvaluatorComb([URLEvaluator(), HTMLContentEvaluator()]) score = 
evaluators( trajectory, config_file, env.page, env.get_page_client(env.page) ) @@ -264,7 +262,7 @@ def test_func_success( env = script_browser_env trajectory = tf_roll_out(agent, env, config_file) - evalutor = HTMLContentExactEvaluator() + evalutor = HTMLContentEvaluator() score = evalutor( trajectory, config_file, env.page, env.get_page_client(env.page) ) @@ -287,7 +285,7 @@ def test_func_fail( env = script_browser_env trajectory = tf_roll_out(agent, env, config_file) - evalutor = HTMLContentExactEvaluator() + evalutor = HTMLContentEvaluator() score = evalutor( trajectory, config_file, env.page, env.get_page_client(env.page) ) @@ -308,7 +306,7 @@ def test_func_url_func_last_success( env = script_browser_env trajectory = tf_roll_out(agent, env, config_file) - evalutor = HTMLContentExactEvaluator() + evalutor = HTMLContentEvaluator() score = evalutor( trajectory, config_file, env.page, env.get_page_client(env.page) ) @@ -341,7 +339,7 @@ def test_func_url_func_page_success( env = script_browser_env trajectory = tf_roll_out(agent, env, tmp_config) - evalutor = HTMLContentExactEvaluator() + evalutor = HTMLContentEvaluator() score = evalutor( trajectory, tmp_config, env.page, env.get_page_client(env.page) ) From db063c77425cb703aacd4654d49d5d3b94bcab7a Mon Sep 17 00:00:00 2001 From: alexisxy Date: Tue, 26 Sep 2023 15:42:54 -0400 Subject: [PATCH 20/29] update test example due to html escape --- tests/test_evaluation_harness/configs/func_url_func_1.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_evaluation_harness/configs/func_url_func_1.json b/tests/test_evaluation_harness/configs/func_url_func_1.json index 7dbd8a2..993a246 100644 --- a/tests/test_evaluation_harness/configs/func_url_func_1.json +++ b/tests/test_evaluation_harness/configs/func_url_func_1.json @@ -17,7 +17,7 @@ { "url": "func:reddit_get_post_url('__last_url__')", "locator": "document.querySelector('.submission__inner').outerText", - "required_contents": {"must_include": 
["​"]} + "required_contents": {"must_include": ["How will SPY close on Monday 11/28"]} } ] } From 6ab7fd2ce7287a3665acb887e872df9465c8b08a Mon Sep 17 00:00:00 2001 From: alexisxy Date: Wed, 27 Sep 2023 16:29:45 -0400 Subject: [PATCH 21/29] update fuzzy match prompt --- evaluation_harness/helper_functions.py | 37 +++++++++++++------------- 1 file changed, 18 insertions(+), 19 deletions(-) diff --git a/evaluation_harness/helper_functions.py b/evaluation_harness/helper_functions.py index 6df22e4..535dfcf 100644 --- a/evaluation_harness/helper_functions.py +++ b/evaluation_harness/helper_functions.py @@ -146,30 +146,29 @@ def gitlab_get_project_memeber_role(page: Page, account_name: str) -> str: def llm_fuzzy_match(pred: str, reference: str, question: str) -> float: """Check whether the prediction matches the reference with GPT-3.5""" messages: list[dict[str, Any]] = [] - messages.append( - {"role": "system", "content": "You are a helpful assistant"} - ) - - messages.append( - { - "role": "user", - "content": f'Given the statement "{pred}", would it be correct to infer "{reference}"? Yes or No', - } - ) + # construct the question to ask + message = "Help a teacher to grade the answer of a student given a question. Keep in mind that the student may use different phrasing or wording to answer the question. The goal is to evaluate whether the answer is semantically equivalent to the reference answer.\n" + message += f"question: {question}\n" + message += f"reference answer: {reference}\n" + message += "all the string 'N/A' that you see is a special sequence that means 'not achievable'\n" + message += f"student answer: {pred}\n" + message += "Conclude the judgement by correct/incorrect/partially correct." 
+ messages = [ + {"role": "system", "content": "You are a helpful assistant"}, + {"role": "user", "content": message}, + ] response = generate_from_openai_chat_completion( + model="gpt-4", messages=messages, - model="gpt-3.5-turbo", temperature=0, - top_p=1, - context_length=0, - max_tokens=16, - stop_token=None, - ) - if "Yes" in response: - return 1.0 - else: + max_tokens=768, + ).lower() + if "partially correct" in response or "incorrect" in response: return 0.0 + else: + assert "correct" in response + return 1.0 class PseudoPage: From 58061ee914243b07756f578e03e0dc568573a7b5 Mon Sep 17 00:00:00 2001 From: alexisxy Date: Wed, 27 Sep 2023 16:35:44 -0400 Subject: [PATCH 22/29] reduce coordinate precision; fix template 67 annotations --- config_files/test.raw.json | 77 +++++++++++++++++++------------------- 1 file changed, 39 insertions(+), 38 deletions(-) diff --git a/config_files/test.raw.json b/config_files/test.raw.json index 735919a..de29c86 100644 --- a/config_files/test.raw.json +++ b/config_files/test.raw.json @@ -1182,7 +1182,7 @@ "string_match" ], "reference_answers": { - "exact_match": "Yes" + "must_include": ["Yes"] }, "reference_url": "", "program_html": [], @@ -1212,7 +1212,7 @@ "string_match" ], "reference_answers": { - "exact_match": "Yes" + "must_include": ["Yes"] }, "reference_url": "", "program_html": [], @@ -1242,7 +1242,7 @@ "string_match" ], "reference_answers": { - "exact_match": "Yes" + "must_include": ["Yes"] }, "reference_url": "", "program_html": [], @@ -1272,7 +1272,7 @@ "string_match" ], "reference_answers": { - "exact_match": "Yes" + "must_include": ["Yes"] }, "reference_url": "", "program_html": [], @@ -1302,7 +1302,7 @@ "string_match" ], "reference_answers": { - "exact_match": "Yes" + "must_include": ["Yes"] }, "reference_url": "", "program_html": [], @@ -2859,14 +2859,13 @@ "must_include": [ "Rhode Island", "Massachusetts", - "New York", - "New Jersey" + "New York" ] }, "reference_url": "", "program_html": [], "string_note": "", - 
"reference_answer_raw_annotation": "Rhode Island, Massachusetts, New York, New Jersey" + "reference_answer_raw_annotation": "Rhode Island, Massachusetts, New York" }, "intent_template_id": 67 }, @@ -2894,13 +2893,15 @@ "Ohio", "Maryland", "New York", - "Virginia" + "New Jersey", + "Delaware", + "West Virginia" ] }, "reference_url": "", "program_html": [], "string_note": "", - "reference_answer_raw_annotation": "Ohio, Maryland, New York, Virginia" + "reference_answer_raw_annotation": "Ohio, Maryland, New York, New Jersey, Delaware, West Virginia" }, "intent_template_id": 67 }, @@ -5537,7 +5538,7 @@ "url_match" ], "reference_answers": { - "exact_match": "No" + "fuzzy_match": ["No, it is open"] }, "reference_url": "__GITLAB__/byteblaze/empathy-prompts/-/issues/8", "program_html": [], @@ -5568,7 +5569,7 @@ "url_match" ], "reference_answers": { - "exact_match": "No" + "fuzzy_match": ["No, it is open"] }, "reference_url": "__GITLAB__/byteblaze/a11y-webring.club/-/issues/71", "program_html": [], @@ -5598,7 +5599,7 @@ "url_match" ], "reference_answers": { - "exact_match": "No" + "fuzzy_match": ["No, it is open"] }, "reference_url": "__GITLAB__/byteblaze/empathy-prompts/-/issues/18", "program_html": [], @@ -5628,7 +5629,7 @@ "url_match" ], "reference_answers": { - "exact_match": "No" + "fuzzy_match": ["No, it is open"] }, "reference_url": "__GITLAB__/byteblaze/a11y-syntax-highlighting/-/issues/1", "program_html": [], @@ -5658,7 +5659,7 @@ "url_match" ], "reference_answers": { - "exact_match": "Yes" + "fuzzy_match": ["Yes, it is closed"] }, "reference_url": "__GITLAB__/a11yproject/a11yproject.com/-/issues/719", "program_html": [], @@ -7856,8 +7857,8 @@ ], "reference_answers": { "must_include": [ - "40.4424191", - "-79.9397388" + "40.442", + "-79.939" ] }, "reference_url": "", @@ -7888,8 +7889,8 @@ ], "reference_answers": { "must_include": [ - "40.46076", - "-79.94666" + "40.460", + "-79.946" ] }, "reference_url": "", @@ -7920,8 +7921,8 @@ ], "reference_answers": { 
"must_include": [ - "40.4511693", - "-79.9334241" + "40.451", + "-79.933" ] }, "reference_url": "", @@ -7952,8 +7953,8 @@ ], "reference_answers": { "must_include": [ - "40.4443", - "-79.94889" + "40.444", + "-79.948" ] }, "reference_url": "", @@ -7984,8 +7985,8 @@ ], "reference_answers": { "must_include": [ - "40.45761", - "-79.92934" + "40.457", + "-79.929" ] }, "reference_url": "", @@ -26794,11 +26795,11 @@ "geolocation": null, "intent_template": "Start a private project {{project_name}} with {{template}} template and add {{account_list}} as members", "instantiation_dict": { - "project_name": "web_agent_android", + "project_name": "web_agent_android_xl", "template": "Android", "account_list": "primer, convexegg, abishek" }, - "intent": "Start a private project web_agent_android with Android template and add primer, convexegg, abishek as members", + "intent": "Start a private project web_agent_android_xl with Android template and add primer, convexegg, abishek as members", "require_reset": false, "eval": { "eval_types": [ @@ -26808,7 +26809,7 @@ "reference_url": "", "program_html": [ { - "url": "__GITLAB__/byteblaze/web_agent_android", + "url": "__GITLAB__/byteblaze/web_agent_android_xl", "locator": "document.querySelector('.visibility-icon').getAttribute('title')", "required_contents": { "must_include": [ @@ -26817,7 +26818,7 @@ } }, { - "url": "__GITLAB__/byteblaze/web_agent_android/-/commits", + "url": "__GITLAB__/byteblaze/web_agent_android_xl/-/commits", "locator": "", "required_contents": { "must_include": [ @@ -26826,7 +26827,7 @@ } }, { - "url": "__GITLAB__/byteblaze/web_agent_android/-/project_members", + "url": "__GITLAB__/byteblaze/web_agent_android_xl/-/project_members", "locator": "", "required_contents": { "must_include": [ @@ -26931,7 +26932,7 @@ } }, { - "url": "__GITLAB__/byteblaze/agi_index/-/commits", + "url": "__GITLAB__/byteblaze/agi_index", "locator": "document.querySelector('.home-panel-description-markdown').outerText", "required_contents": 
{ "must_include": [ @@ -26986,7 +26987,7 @@ } }, { - "url": "__GITLAB__/byteblaze/AGISite/-/commits", + "url": "__GITLAB__/byteblaze/AGISite", "locator": "document.querySelector('.home-panel-description-markdown').outerText", "required_contents": { "must_include": [ @@ -27064,10 +27065,10 @@ "geolocation": null, "intent_template": "Create a private {{template}} repository called \"{{project_name}}\" using the right template to speed up development.", "instantiation_dict": { - "project_name": "web_agent_android", + "project_name": "web_agent_android_xs", "template": "Android" }, - "intent": "Create a private Android repository called \"web_agent_android\" using the right template to speed up development.", + "intent": "Create a private Android repository called \"web_agent_android_xs\" using the right template to speed up development.", "require_reset": false, "eval": { "eval_types": [ @@ -27077,7 +27078,7 @@ "reference_url": "", "program_html": [ { - "url": "__GITLAB__/byteblaze/web_agent_android", + "url": "__GITLAB__/byteblaze/web_agent_android_xs", "locator": "document.querySelector('.visibility-icon').getAttribute('title')", "required_contents": { "must_include": [ @@ -27086,7 +27087,7 @@ } }, { - "url": "__GITLAB__/byteblaze/web_agent_android/-/commits", + "url": "__GITLAB__/byteblaze/web_agent_android_xs/-/commits", "locator": "", "required_contents": { "must_include": [ @@ -27176,7 +27177,7 @@ } }, { - "url": "__GITLAB__/byteblaze/agi_index/-/commits", + "url": "__GITLAB__/byteblaze/web_agent_index", "locator": "document.querySelector('.home-panel-description-markdown').outerText", "required_contents": { "must_include": [ @@ -27221,7 +27222,7 @@ } }, { - "url": "__GITLAB__/byteblaze/AGISite/-/commits", + "url": "__GITLAB__/byteblaze/11711_gitlab", "locator": "document.querySelector('.home-panel-description-markdown').outerText", "required_contents": { "must_include": [ From 4b86d435b9576a6a66ac380ea7407af7a24aeb5e Mon Sep 17 00:00:00 2001 From: alexisxy 
Date: Fri, 20 Oct 2023 19:28:08 -0400 Subject: [PATCH 23/29] fix locator for product; add prep action; fix url for promo rules --- config_files/test.raw.json | 30 ++++++++++++++++++++---------- evaluation_harness/evaluators.py | 8 ++++++++ 2 files changed, 28 insertions(+), 10 deletions(-) diff --git a/config_files/test.raw.json b/config_files/test.raw.json index de29c86..d196943 100644 --- a/config_files/test.raw.json +++ b/config_files/test.raw.json @@ -23891,7 +23891,7 @@ "program_html" ], "reference_answers": null, - "reference_url": "__SHOPPING_ADMIN__/catalog/product/edit/id", + "reference_url": "__SHOPPING_ADMIN__/catalog/product", "program_html": [ { "url": "last", @@ -23902,7 +23902,7 @@ }, { "url": "last", - "locator": "document.querySelector('[name=\"product[name]\"').outerText", + "locator": "document.querySelector('[name=\"product[name]\"').value", "required_contents": { "must_include": [ "Energy-Bulk Women Shirt" @@ -23978,7 +23978,7 @@ "program_html" ], "reference_answers": null, - "reference_url": "__SHOPPING_ADMIN__/catalog/product/edit/id", + "reference_url": "__SHOPPING_ADMIN__/catalog/product", "program_html": [ { "url": "last", @@ -23989,7 +23989,7 @@ }, { "url": "last", - "locator": "document.querySelector('[name=\"product[name]\"').outerText", + "locator": "document.querySelector('[name=\"product[name]\"').value", "required_contents": { "must_include": [ "Energy-Bulk Man Yoga Pant" @@ -24065,11 +24065,11 @@ "program_html" ], "reference_answers": null, - "reference_url": "__SHOPPING_ADMIN__/catalog/product/edit/id", + "reference_url": "__SHOPPING_ADMIN__/catalog/product", "program_html": [ { "url": "last", - "locator": "document.querySelector('[name=\"product[name]\"').outerText", + "locator": "document.querySelector('[name=\"product[name]\"').value", "required_contents": { "must_include": [ "FancyBoy Man Causal Jeans" @@ -24152,11 +24152,11 @@ "program_html" ], "reference_answers": null, - "reference_url": 
"__SHOPPING_ADMIN__/catalog/product/edit/id", + "reference_url": "__SHOPPING_ADMIN__/catalog/product", "program_html": [ { "url": "last", - "locator": "document.querySelector('[name=\"product[name]\"').outerText", + "locator": "document.querySelector('[name=\"product[name]\"').value", "required_contents": { "must_include": [ "Swaatch Smart Watch" @@ -24232,11 +24232,11 @@ "program_html" ], "reference_answers": null, - "reference_url": "__SHOPPING_ADMIN__/catalog/product/edit/id", + "reference_url": "__SHOPPING_ADMIN__/catalog/product", "program_html": [ { "url": "last", - "locator": "document.querySelector('[name=\"product[name]\"').outerText", + "locator": "document.querySelector('[name=\"product[name]\"').value", "required_contents": { "must_include": [ "Lelelumon Yoga Mat" @@ -24337,6 +24337,7 @@ { "url": "last", "locator": "document.querySelector('[name=\"simple_action\"').value", + "prep_actions": ["document.querySelector('[data-index=\"actions\"]').querySelector('.admin__collapsible-title').click()"], "required_contents": { "exact_match": "by_percent" } @@ -24344,6 +24345,7 @@ { "url": "last", "locator": "document.querySelector('[name=\"discount_amount\"').value", + "prep_actions": ["document.querySelector('[data-index=\"actions\"]').querySelector('.admin__collapsible-title').click()"], "required_contents": { "exact_match": "20" } @@ -24403,6 +24405,7 @@ { "url": "last", "locator": "document.querySelector('[name=\"simple_action\"').value", + "prep_actions": ["document.querySelector('[data-index=\"actions\"]').querySelector('.admin__collapsible-title').click()"], "required_contents": { "exact_match": "cart_fixed" } @@ -24410,6 +24413,7 @@ { "url": "last", "locator": "document.querySelector('[name=\"discount_amount\"').value", + "prep_actions": ["document.querySelector('[data-index=\"actions\"]').querySelector('.admin__collapsible-title').click()"], "required_contents": { "exact_match": "10" } @@ -24469,6 +24473,7 @@ { "url": "last", "locator": 
"document.querySelector('[name=\"simple_action\"').value", + "prep_actions": ["document.querySelector('[data-index=\"actions\"]').querySelector('.admin__collapsible-title').click()"], "required_contents": { "exact_match": "cart_fixed" } @@ -24476,6 +24481,7 @@ { "url": "last", "locator": "document.querySelector('[name=\"discount_amount\"').value", + "prep_actions": ["document.querySelector('[data-index=\"actions\"]').querySelector('.admin__collapsible-title').click()"], "required_contents": { "exact_match": "15" } @@ -24535,6 +24541,7 @@ { "url": "last", "locator": "document.querySelector('[name=\"simple_action\"').value", + "prep_actions": ["document.querySelector('[data-index=\"actions\"]').querySelector('.admin__collapsible-title').click()"], "required_contents": { "exact_match": "by_percent" } @@ -24542,6 +24549,7 @@ { "url": "last", "locator": "document.querySelector('[name=\"discount_amount\"').value", + "prep_actions": ["document.querySelector('[data-index=\"actions\"]').querySelector('.admin__collapsible-title').click()"], "required_contents": { "exact_match": "45" } @@ -24601,6 +24609,7 @@ { "url": "last", "locator": "document.querySelector('[name=\"simple_action\"').value", + "prep_actions": ["document.querySelector('[data-index=\"actions\"]').querySelector('.admin__collapsible-title').click()"], "required_contents": { "exact_match": "cart_fixed" } @@ -24608,6 +24617,7 @@ { "url": "last", "locator": "document.querySelector('[name=\"discount_amount\"').value", + "prep_actions": ["document.querySelector('[data-index=\"actions\"]').querySelector('.admin__collapsible-title').click()"], "required_contents": { "exact_match": "40" } diff --git a/evaluation_harness/evaluators.py b/evaluation_harness/evaluators.py index e210fa8..ccfc3bc 100644 --- a/evaluation_harness/evaluators.py +++ b/evaluation_harness/evaluators.py @@ -261,6 +261,12 @@ def __call__( elif locator.startswith("document.") or locator.startswith( "[...document." 
): + if "prep_actions" in target: + try: + for prep_action in target["prep_actions"]: + page.evaluate(f"() => {prep_action}") + except Exception: + pass try: selected_element = str(page.evaluate(f"() => {locator}")) if not selected_element: @@ -284,6 +290,7 @@ def __call__( ref=required_contents, pred=selected_element ) score *= float(cur_score) + # print(f"[exact match] {cur_score}, selected element: {selected_element}, required contents: {required_contents}") elif "must_include" in target["required_contents"]: required_contents = target["required_contents"]["must_include"] assert isinstance(required_contents, list) @@ -300,6 +307,7 @@ def __call__( ] ) score *= float(cur_score) + # print(f"[must include] {cur_score}, selected element: {selected_element}, required contents: {content_or}") else: raise ValueError( f"Unknown required_contents: {target['required_contents'].keys()}" From df87757d47883cdb0a048ce389bb2b7cb45a35bd Mon Sep 17 00:00:00 2001 From: alexisxy Date: Fri, 20 Oct 2023 19:30:24 -0400 Subject: [PATCH 24/29] add options to renew cookie for selected sites --- browser_env/auto_login.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/browser_env/auto_login.py b/browser_env/auto_login.py index 7602deb..1354a21 100644 --- a/browser_env/auto_login.py +++ b/browser_env/auto_login.py @@ -2,6 +2,7 @@ import argparse import glob import os +import time from concurrent.futures import ThreadPoolExecutor from itertools import combinations from pathlib import Path @@ -40,10 +41,11 @@ def is_expired( context_manager = sync_playwright() playwright = context_manager.__enter__() - browser = playwright.chromium.launch(headless=HEADLESS, slow_mo=SLOW_MO) + browser = playwright.chromium.launch(headless=True, slow_mo=SLOW_MO) context = browser.new_context(storage_state=storage_state) page = context.new_page() page.goto(url) + time.sleep(1) d_url = page.url content = page.content() context_manager.__exit__() @@ -151,4 +153,7 @@ def 
main(auth_folder: str = "./.auth") -> None: if not args.site_list: main() else: - renew_comb(args.site_list, auth_folder=args.auth_folder) + if "all" in args.site_list: + main(auth_folder=args.auth_folder) + else: + renew_comb(args.site_list, auth_folder=args.auth_folder) From 3d3d837771303daa3639b009864f63db5ee7fc4e Mon Sep 17 00:00:00 2001 From: alexisxy Date: Fri, 20 Oct 2023 19:31:54 -0400 Subject: [PATCH 25/29] print unfinished examples --- scripts/check_error_runs.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/scripts/check_error_runs.py b/scripts/check_error_runs.py index 2fb4247..4b90153 100644 --- a/scripts/check_error_runs.py +++ b/scripts/check_error_runs.py @@ -28,6 +28,7 @@ def merge_logs(result_folder: str, args: argparse.Namespace) -> str: cur_log and index and os.path.exists(f"{result_folder}/render_{index}.html") + and len(cur_log) >= 3 ): merged_results[index] = cur_log # update index and log @@ -36,7 +37,13 @@ def merge_logs(result_folder: str, args: argparse.Namespace) -> str: else: cur_log.append(line) - if os.path.exists(f"{result_folder}/render_{index}.html"): + if ( + cur_log + and index + and os.path.exists(f"{result_folder}/render_{index}.html") + and len(cur_log) >= 3 + ): + merged_results[index] = cur_log # sort by the key @@ -68,6 +75,12 @@ def merge_logs(result_folder: str, args: argparse.Namespace) -> str: for idx in unlog_examples: os.remove(f"{args.result_folder}/render_{idx}.html") + unifinished_examples = [ + i for i in range(0, 812) if str(i) not in merged_results + ] + print(f"Number of unfinished examples: {len(unifinished_examples)}") + print(unifinished_examples) + return merged_log_path From 7730a85191f949334dc5d8fb3a6e05c714c34667 Mon Sep 17 00:00:00 2001 From: alexisxy Date: Fri, 20 Oct 2023 19:32:37 -0400 Subject: [PATCH 26/29] reduce openai max retry --- llms/providers/openai_utils.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git 
a/llms/providers/openai_utils.py b/llms/providers/openai_utils.py index 05887f4..94a676e 100644 --- a/llms/providers/openai_utils.py +++ b/llms/providers/openai_utils.py @@ -19,8 +19,8 @@ def retry_with_exponential_backoff( # type: ignore initial_delay: float = 1, exponential_base: float = 2, jitter: bool = True, - max_retries: int = 10, - errors: tuple[Any] = (openai.error.RateLimitError,), + max_retries: int = 3, + errors: tuple[Any] = (openai.error.RateLimitError), ): """Retry a function with exponential backoff.""" @@ -32,9 +32,7 @@ def wrapper(*args, **kwargs): # type: ignore # Loop until a successful response or max_retries is hit or an exception is raised while True: try: - return func(*args, **kwargs) - # Retry on specified errors except errors as e: # Increment retries @@ -48,7 +46,7 @@ def wrapper(*args, **kwargs): # type: ignore # Increment the delay delay *= exponential_base * (1 + jitter * random.random()) - + print(f"Retrying in {delay} seconds.") # Sleep for the delay time.sleep(delay) From f91eb5bbdff45701461e3f4af85ae2fcb5017a50 Mon Sep 17 00:00:00 2001 From: alexisxy Date: Fri, 20 Oct 2023 19:36:17 -0400 Subject: [PATCH 27/29] minor --- evaluation_harness/evaluate_by_trace.py | 66 ------------------------- evaluation_harness/helper_functions.py | 2 + run.py | 20 +++++--- 3 files changed, 14 insertions(+), 74 deletions(-) delete mode 100644 evaluation_harness/evaluate_by_trace.py diff --git a/evaluation_harness/evaluate_by_trace.py b/evaluation_harness/evaluate_by_trace.py deleted file mode 100644 index 3820789..0000000 --- a/evaluation_harness/evaluate_by_trace.py +++ /dev/null @@ -1,66 +0,0 @@ -"""Evaluate by using the traces.zip files saved""" -import argparse -import json -import os -import sys -import tempfile -import zipfile - -from playwright.sync_api import Page, sync_playwright - -from evaluation_harness import evaluator_router -from evaluation_harness.helper_functions import PseudoPage - - -def eval_trace(trace_path: str, task_id: int, 
config_file_folder: str): - # load the config file - config_file = f"{config_file_folder}/{task_id}.json" - with open(config_file, "r") as f: - config = json.load(f) - - if "string_match" in config["eval"]["eval_types"]: - raise ValueError( - "string_match is not supported in this evaluation script" - ) - - # extract the last url from the trace file - temp_dir = tempfile.TemporaryDirectory() - with zipfile.ZipFile(trace_path, "r") as zip_ref: - zip_ref.extractall(temp_dir.name) - with open(f"{temp_dir.name}/trace.trace", "r") as f: - trace = [] - for line in f: - trace.append(json.loads(line)) - last_url = "" - for step in trace[::-1]: - if step.get("type", None) == "frame-snapshot": - last_url = step["snapshot"]["frameUrl"] - break - if not last_url: - raise ValueError("Cannot find the last url in the trace file") - - # start the playwright - context_manager = sync_playwright() - playwright = context_manager.__enter__() - browser = playwright.chromium.launch(headless=True) - context = browser.new_context() - page = context.new_page() - page.goto("https://trace.playwright.dev/") - with page.expect_file_chooser() as fc_info: - page.get_by_role("button", name="Select file(s)").click() - file_chooser = fc_info.value - file_chooser.set_files(trace_path) - with page.expect_popup() as page1_info: - page.get_by_role("button", name="").click() - page1 = page1_info.value - - pseudo_page = PseudoPage(page1, last_url) - evaluator = evaluator_router(config_file) - - score = evaluator( - trajectory=[], - config_file=config_file, - page=pseudo_page, - client=pseudo_page.context.new_cdp_session(pseudo_page), - ) - print(score) diff --git a/evaluation_harness/helper_functions.py b/evaluation_harness/helper_functions.py index 535dfcf..5baf466 100644 --- a/evaluation_harness/helper_functions.py +++ b/evaluation_harness/helper_functions.py @@ -163,6 +163,8 @@ def llm_fuzzy_match(pred: str, reference: str, question: str) -> float: messages=messages, temperature=0, max_tokens=768, + 
top_p=1.0, + context_length=0, ).lower() if "partially correct" in response or "incorrect" in response: return 0.0 diff --git a/run.py b/run.py index 010bc54..cee3c98 100644 --- a/run.py +++ b/run.py @@ -423,13 +423,17 @@ def dump_config(args: argparse.Namespace) -> None: test_file_list.append(f"config_files/{i}.json") if "debug" not in args.result_dir: test_file_list = get_unfinished(test_file_list, args.result_dir) - print(f"Total {len(test_file_list)} tasks left") - args.render = False - args.render_screenshot = True - args.save_trace_enabled = True - args.current_viewport_only = True - dump_config(args) + if len(test_file_list) == 0: + logger.info("No task left to run") + else: + print(f"Total {len(test_file_list)} tasks left") + args.render = False + args.render_screenshot = True + args.save_trace_enabled = True + + args.current_viewport_only = True + dump_config(args) - agent = construct_agent(args) - test(args, agent, test_file_list) + agent = construct_agent(args) + test(args, agent, test_file_list) From 4cec5acab851bdcc2f5a56dab99e83433d8d9f1e Mon Sep 17 00:00:00 2001 From: alexisxy Date: Fri, 20 Oct 2023 21:07:12 -0400 Subject: [PATCH 28/29] minor --- .gitignore | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/.gitignore b/.gitignore index 1cc64e2..54703d6 100644 --- a/.gitignore +++ b/.gitignore @@ -146,14 +146,14 @@ cache/* # TMP IGNORE agent/prompts/jsons/* log_files/ -config_files/*0.json -config_files/*1.json -config_files/*2.json -config_files/*3.json -config_files/*4.json -config_files/*5.json -config_files/*6.json -config_files/*7.json -config_files/*8.json -config_files/*9.json -config_files/test.json +config_files*/*0.json +config_files*/*1.json +config_files*/*2.json +config_files*/*3.json +config_files*/*4.json +config_files*/*5.json +config_files*/*6.json +config_files*/*7.json +config_files*/*8.json +config_files*/*9.json +config_files*/test.json From 9f0900f506ad6e49c6931efc21fb52f89df804b9 Mon Sep 
17 00:00:00 2001 From: alexisxy Date: Sat, 21 Oct 2023 00:20:30 -0400 Subject: [PATCH 29/29] fix type errors --- .github/workflows/tests.yml | 2 +- evaluation_harness/evaluators.py | 10 ++++------ evaluation_harness/helper_functions.py | 2 +- llms/providers/hf_utils.py | 4 ++-- llms/providers/openai_utils.py | 2 +- llms/tokenizers.py | 8 ++++---- llms/utils.py | 8 ++++++-- requirements.txt | 2 ++ scripts/check_error_runs.py | 2 +- 9 files changed, 22 insertions(+), 18 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 9ce3602..79be870 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -34,7 +34,7 @@ jobs: mypy --version # Run this mypy instance against our main package. mypy --install-types --non-interactive . - mypy --strict . + mypy --strict . --exclude scripts - name: Enviroment prepare run: | bash prepare.sh diff --git a/evaluation_harness/evaluators.py b/evaluation_harness/evaluators.py index cd96a0d..df20431 100644 --- a/evaluation_harness/evaluators.py +++ b/evaluation_harness/evaluators.py @@ -9,9 +9,7 @@ from pathlib import Path from typing import Any, Tuple, Union -import evaluate # type: ignore[import] from beartype import beartype -from beartype.door import is_bearable from nltk.tokenize import word_tokenize # type: ignore from playwright.sync_api import CDPSession, Page @@ -96,7 +94,7 @@ def exact_match(ref: str, pred: str) -> float: @staticmethod @beartype - def must_include(ref: str, pred: str, tokenize=False) -> float: + def must_include(ref: str, pred: str, tokenize: bool = False) -> float: clean_ref = StringEvaluator.clean_answer(ref) clean_pred = StringEvaluator.clean_answer(pred) # tokenize the answer if the ref is a single word @@ -180,7 +178,7 @@ def parse_url(url: str) -> tuple[str, dict[str, list[str]]]: def parse_urls( urls: list[str], - ) -> tuple[list[str], list[str], dict[str, set[str]]]: + ) -> tuple[list[str], dict[str, set[str]]]: """Parse a list of URLs.""" base_paths 
= [] queries = collections.defaultdict(set) @@ -324,8 +322,8 @@ def __call__( self, trajectory: Trajectory, config_file: Path | str, - page: Page | PseudoPage | None = None, - client: CDPSession | None = None, + page: Page | PseudoPage, + client: CDPSession, ) -> float: score = 1.0 diff --git a/evaluation_harness/helper_functions.py b/evaluation_harness/helper_functions.py index 5baf466..3906240 100644 --- a/evaluation_harness/helper_functions.py +++ b/evaluation_harness/helper_functions.py @@ -178,7 +178,7 @@ def __init__(self, original_page: Page, url: str): self.url = url self.original_page = original_page - def __getattr__(self, attr: str) -> any: + def __getattr__(self, attr: str) -> Any: # Delegate attribute access to the original page object if attr not in ["url"]: return getattr(self.original_page, attr) diff --git a/llms/providers/hf_utils.py b/llms/providers/hf_utils.py index c5a3f11..b5e8987 100644 --- a/llms/providers/hf_utils.py +++ b/llms/providers/hf_utils.py @@ -1,4 +1,4 @@ -from text_generation import Client +from text_generation import Client # type: ignore def generate_from_huggingface_completion( @@ -10,7 +10,7 @@ def generate_from_huggingface_completion( stop_sequences: list[str] | None = None, ) -> str: client = Client(model_endpoint, timeout=60) - generation = client.generate( + generation: str = client.generate( prompt=prompt, temperature=temperature, top_p=top_p, diff --git a/llms/providers/openai_utils.py b/llms/providers/openai_utils.py index 94a676e..4dcdad2 100644 --- a/llms/providers/openai_utils.py +++ b/llms/providers/openai_utils.py @@ -20,7 +20,7 @@ def retry_with_exponential_backoff( # type: ignore exponential_base: float = 2, jitter: bool = True, max_retries: int = 3, - errors: tuple[Any] = (openai.error.RateLimitError), + errors: tuple[Any] = (openai.error.RateLimitError,), ): """Retry a function with exponential backoff.""" diff --git a/llms/tokenizers.py b/llms/tokenizers.py index 67aa231..8e45ccf 100644 --- 
a/llms/tokenizers.py +++ b/llms/tokenizers.py @@ -1,7 +1,7 @@ from typing import Any import tiktoken -from transformers import LlamaTokenizer +from transformers import LlamaTokenizer # type: ignore class Tokenizer(object): @@ -11,9 +11,9 @@ def __init__(self, provider: str, model_name: str) -> None: elif provider == "huggingface": self.tokenizer = LlamaTokenizer.from_pretrained(model_name) # turn off adding special tokens automatically - self.tokenizer.add_special_tokens = False - self.tokenizer.add_bos_token = False - self.tokenizer.add_eos_token = False + self.tokenizer.add_special_tokens = False # type: ignore[attr-defined] + self.tokenizer.add_bos_token = False # type: ignore[attr-defined] + self.tokenizer.add_eos_token = False # type: ignore[attr-defined] else: raise NotImplementedError diff --git a/llms/utils.py b/llms/utils.py index 54b57e0..ea91a10 100644 --- a/llms/utils.py +++ b/llms/utils.py @@ -13,10 +13,12 @@ def call_llm( lm_config: lm_config.LMConfig, - prompt: list[Any] | str, -) -> APIInput: + prompt: APIInput, +) -> str: + response: str if lm_config.provider == "openai": if lm_config.mode == "chat": + assert isinstance(prompt, list) response = generate_from_openai_chat_completion( messages=prompt, model=lm_config.model, @@ -27,6 +29,7 @@ def call_llm( stop_token=None, ) elif lm_config.mode == "completion": + assert isinstance(prompt, str) response = generate_from_openai_completion( prompt=prompt, engine=lm_config.model, @@ -40,6 +43,7 @@ def call_llm( f"OpenAI models do not support mode {lm_config.mode}" ) elif lm_config.provider == "huggingface": + assert isinstance(prompt, str) response = generate_from_huggingface_completion( prompt=prompt, model_endpoint=lm_config.gen_config["model_endpoint"], diff --git a/requirements.txt b/requirements.txt index 2567aa5..b2f109b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,3 +9,5 @@ aiolimiter beartype==0.12.0 flask nltk +text-generation +transformers diff --git a/scripts/check_error_runs.py 
b/scripts/check_error_runs.py index 4b90153..0039b56 100644 --- a/scripts/check_error_runs.py +++ b/scripts/check_error_runs.py @@ -20,7 +20,7 @@ def merge_logs(result_folder: str, args: argparse.Namespace) -> str: with open(file.strip(), "r") as f: lines = f.readlines() - cur_log = [] + cur_log: list[str] = [] index = None for line in lines: if "[Config file]" in line: