diff --git a/README.md b/README.md index cf532cd..35d7ea0 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,72 @@ -# LLM-argumentation +# Exploring the Potential of Large Language Models in Computational Argumentation -To be updated soon. +This repo contains the data and codes for our paper ["Exploring the Potential of Large Language Models in Computational Argumentation"](https://aclanthology.org/2024.acl-long.126/) in ACL 2024. + +### Abstract + +Computational argumentation has become an essential tool in various domains, including law, public policy, and artificial intelligence. It is an emerging research field in natural language processing that attracts increasing attention. Research on computational argumentation mainly involves two types of tasks: argument mining and argument generation. As large language models (LLMs) have demonstrated impressive capabilities in understanding context and generating natural language, it is worthwhile to evaluate the performance of LLMs on diverse computational argumentation tasks. This work aims to embark on an assessment of LLMs, such as ChatGPT, Flan models, and LLaMA2 models, in both zero-shot and few-shot settings. We organize existing tasks into six main categories and standardize the format of fourteen openly available datasets. In addition, we present a new benchmark dataset on counter speech generation that aims to holistically evaluate the end-to-end performance of LLMs on argument mining and argument generation. Extensive experiments show that LLMs exhibit commendable performance across most of the datasets, demonstrating their capabilities in the field of argumentation. Our analysis offers valuable suggestions for evaluating computational argumentation and its integration with LLMs in future research endeavors. + +### Setup + +To install dependencies: + +``` +conda create -n llm-am python=3.9 -y +conda activate llm-am +pip install -r requirements.txt +``` + +To run OpenAI models, insert your [OpenAI key](https://platform.openai.com/account/api-keys) and model version in [openai_info.json](openai_info.json): + +``` +{ + "engine": "gpt-3.5-turbo-0301", + "key": "YOUR API KEY" +} +``` + +### Example Usage + +To run Flan-T5-XL on the ibm_claims dataset using 5-shot demonstrations: + +``` +python main.py \ +--model_name flan_t5_xl \ +--path_model google/flan-t5-xl \ +--task claim_detection \ +--data_name ibm_claims \ +--num_train 5 +``` + +The results will be printed as + +``` +{'accuracy': 0.74, 'f1': 0.7909496513561132} +``` + +(Note that some variance is possible) + + +Run using your own prompts by modifying the prompts in [prompting.py](prompting.py). + +### Citation +``` +@inproceedings{chen-etal-2024-exploring-potential, + title = "Exploring the Potential of Large Language Models in Computational Argumentation", + author = "Chen, Guizhen and + Cheng, Liying and + Luu, Anh Tuan and + Bing, Lidong", + editor = "Ku, Lun-Wei and + Martins, Andre and + Srikumar, Vivek", + booktitle = "Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)", + month = aug, + year = "2024", + address = "Bangkok, Thailand", + publisher = "Association for Computational Linguistics", + url = "https://aclanthology.org/2024.acl-long.126", + pages = "2309--2330", + abstract = "Computational argumentation has become an essential tool in various domains, including law, public policy, and artificial intelligence. It is an emerging research field in natural language processing that attracts increasing attention. Research on computational argumentation mainly involves two types of tasks: argument mining and argument generation. As large language models (LLMs) have demonstrated impressive capabilities in understanding context and generating natural language, it is worthwhile to evaluate the performance of LLMs on diverse computational argumentation tasks. This work aims to embark on an assessment of LLMs, such as ChatGPT, Flan models, and LLaMA2 models, in both zero-shot and few-shot settings. We organize existing tasks into six main categories and standardize the format of fourteen openly available datasets. In addition, we present a new benchmark dataset on counter speech generation that aims to holistically evaluate the end-to-end performance of LLMs on argument mining and argument generation. Extensive experiments show that LLMs exhibit commendable performance across most of the datasets, demonstrating their capabilities in the field of argumentation. Our analysis offers valuable suggestions for evaluating computational argumentation and its integration with LLMs in future research endeavors.", +} +``` \ No newline at end of file diff --git a/data_loading.py b/data_loading.py new file mode 100644 index 0000000..be13cdb --- /dev/null +++ b/data_loading.py @@ -0,0 +1,81 @@ +import json +import os + +from typing import List +from fire import Fire + +from pydantic import BaseModel + + +class ArgumentSample(BaseModel): + src: str = "" + tgt: str = "" + prompt: str = "" + raw: str = "" + pred: str = "" + + @classmethod + def default_format(cls, src: str, tgt: str): + src = src.strip() + tgt = tgt.strip() + src = src.replace('\n', ' ') + tgt = tgt.replace('\n', ' ') + + return cls(src=src, tgt=tgt) + + +class ArgumentData(BaseModel): + samples: List[ArgumentSample] + + @classmethod + def load_from_paths(cls, src_path: str, tgt_path: str): + with open(src_path) as f: + raw_src = [line for line in f] + with open(tgt_path) as f: + raw_tgt = [line for line in f] + + assert len(raw_src) == len(raw_tgt) + + return cls(samples=[ArgumentSample.default_format(line_src, line_tgt) for line_src, line_tgt in zip(raw_src, raw_tgt)]) + + @classmethod + def load_train(cls, task: str, data_name: str, num_train: int, seed: int): + train_folder = f"sampled_data/{task}/{data_name}/train/{num_train}_shot/seed_{seed}" + if not os.path.isdir(train_folder): + return cls(samples=[]) + else: + train_src_path = f"{train_folder}/source.txt" + train_tgt_path = f"{train_folder}/target.txt" + return cls.load_from_paths(train_src_path, train_tgt_path) + + @classmethod + def load_test(cls, task: str, data_name: str): + test_folder = f"sampled_data/{task}/{data_name}/test" + test_src_path = f"{test_folder}/source.txt" + test_tgt_path = f"{test_folder}/target.txt" + return cls.load_from_paths(test_src_path, test_tgt_path) + + @classmethod + def load(cls, task: str, data_name: str, num_train: int, seed: int): + data_train = cls.load_train(task, data_name, num_train, seed) + data_test = cls.load_test(task, data_name) + return data_train, data_test + + @classmethod + def load_outputs(cls, output_path: str): + samples = [] + with open(output_path) as f: + for line in f: + samples.append(ArgumentSample(**json.loads(line.strip()))) + return cls(samples=samples) + + +def test_data(task: str, data_name: str, num_train: int, seed: int): + data_train, data_test = ArgumentData.load(task, data_name, num_train, seed) + print(data_test.samples[0]) + print("num train: ", len(data_train.samples)) + print("num test: ", len(data_test.samples)) + + +if __name__ == "__main__": + Fire() diff --git a/main.py b/main.py new file mode 100644 index 0000000..6d1f55a --- /dev/null +++ b/main.py @@ -0,0 +1,90 @@ +import os +import json +import time +import numpy as np +import pandas as pd +from pathlib import Path +from typing import TextIO + +from fire import Fire +from tqdm import tqdm + +from data_loading import ArgumentSample, ArgumentData +from modeling import select_model, EvalModel +from prompting import select_prompter, Prompter +from scoring import select_scorer + + +def inference( + model: EvalModel, + data_train: ArgumentData, + data_test: ArgumentData, + prompter: Prompter, + file: TextIO, +): + + progress = tqdm(data_test.samples) + sample: ArgumentSample + + targets = [] + predictions = [] + for _, sample in enumerate(progress): + k = int(len(data_train.samples)) + prompt = prompter.run(data_train, sample) + # handle prompt length + while not model.check_valid_length(prompt) and k > 0: + k -= 1 + data_train.samples = data_train.samples[:k] + prompt = prompter.run(data_train, sample) + + if not model.check_valid_length(prompt): + prompt = model.truncate_input(prompt) + + # predict + sample.prompt = prompt + sample.raw = model.run(prompt) + sample.pred = prompter.get_answer(sample.raw) + print(sample.model_dump_json(), file=file) + + targets.append(sample.tgt) + predictions.append(sample.pred) + + return predictions, targets + +def main( + task: str = "conclugen", + data_name: str = "base", + num_train: int = 5, + seed: int = 0, + **kwargs +): + # load model + model = select_model(**kwargs) + print(locals()) + + # select prompter + prompter = select_prompter(task, data_name) + + # load data + data_train, data_test = ArgumentData.load(task, data_name, num_train, seed) + + # set path + output_folder = f"output/{task}/{data_name}/{num_train}_shot/seed_{seed}" + if not os.path.isdir(output_folder): + os.makedirs(output_folder) + model_name = Path(model.path_model).stem + output_path = f"{output_folder}/{model_name}.json" + + # infer + Path(output_path).parent.mkdir(exist_ok=True, parents=True) + with open(output_path, "w") as file: + targets, predictions = inference(model, data_train, data_test, prompter, file) + + # score + scorer = select_scorer(task) + scores = scorer.run(predictions, targets) + print(scores) + + +if __name__ == "__main__": + Fire(main) diff --git a/modeling.py b/modeling.py new file mode 100644 index 0000000..2b7b60f --- /dev/null +++ b/modeling.py @@ -0,0 +1,215 @@ +import time +import json +from typing import Optional + +import torch +import openai +import tiktoken +from fire import Fire +from pydantic import BaseModel +from transformers import ( + PreTrainedModel, + PreTrainedTokenizer, + AutoModelForSeq2SeqLM, + AutoTokenizer, +) + + +class DummyImport: + LLM = None + SamplingParams = None + + +try: + import vllm + from vllm.lora.request import LoRARequest +except ImportError: + print("vLLM not installed") + vllm = DummyImport() + LoRARequest = lambda *args: args + + +class EvalModel(BaseModel, arbitrary_types_allowed=True): + path_model: str + max_input_length: int = 512 + max_output_length: int = 512 + + def run(self, prompt: str) -> str: + raise NotImplementedError + + +class VLLMModel(EvalModel): + path_model: str + model: vllm.LLM = None + quantization: Optional[str] = None + tokenizer: Optional[PreTrainedTokenizer] = None + tensor_parallel_size: int = 1 + + def load(self): + if self.model is None: + self.model = vllm.LLM( + model=self.path_model, + trust_remote_code=True, + quantization=self.quantization, + tensor_parallel_size=self.tensor_parallel_size, + ) + if self.tokenizer is None: + self.tokenizer = AutoTokenizer.from_pretrained(self.path_model) + + def format_prompt(self, prompt: str) -> str: + self.load() + prompt = prompt.rstrip(" ") + return prompt + + def make_kwargs(self, do_sample: bool, **kwargs) -> dict: + params = vllm.SamplingParams( + temperature=0.5 if do_sample else 0.0, + max_tokens=self.max_output_length, + **kwargs + ) + outputs = dict(sampling_params=params, use_tqdm=False) + return outputs + + def run(self, prompt: str) -> str: + prompt = self.format_prompt(prompt) + outputs = self.model.generate([prompt], **self.make_kwargs(do_sample=False)) + pred = outputs[0].outputs[0].text + pred = pred.split("<|endoftext|>")[0] + return pred + + def check_valid_length(self, text: str) -> bool: + self.load() + inputs = self.tokenizer(text) + return len(inputs.input_ids) <= self.max_input_length + + def truncate_input(self, input) -> str: + return self.tokenizer.decode(self.tokenizer(input).input_ids[:self.max_input_length]) + + +class SeqToSeqModel(EvalModel): + path_model: str + model: Optional[PreTrainedModel] = None + tokenizer: Optional[PreTrainedTokenizer] = None + device: str = "cuda" + load_8bit: bool = False + fp16: bool = False + + def load(self): + if "flan-ul2" in self.path_model.lower(): + self.max_input_length = 2048 + if self.model is None: + args = {} + if self.load_8bit: + args.update(device_map="auto", load_in_8bit=True) + elif self.fp16: + args.update(device_map="auto", torch_dtype=torch.float16) + self.model = AutoModelForSeq2SeqLM.from_pretrained(self.path_model, **args) + if self.fp16 or self.load_8bit: + self.model.eval() + else: + self.model.to(self.device) + if self.tokenizer is None: + self.tokenizer = AutoTokenizer.from_pretrained(self.path_model) + + def run(self, prompt: str) -> str: + self.load() + inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device) + outputs = self.model.generate(**inputs, max_new_tokens=self.max_output_length) + return self.tokenizer.decode(outputs[0], skip_special_tokens=True) + + def check_valid_length(self, text: str) -> bool: + self.load() + inputs = self.tokenizer(text) + return len(inputs.input_ids) <= self.max_input_length + + def truncate_input(self, input) -> str: + return self.tokenizer.decode(self.tokenizer(input).input_ids[:self.max_input_length]) + + +class OpenAIModel(EvalModel): + path_model: str + tokenizer: Optional[tiktoken.Encoding] + temperature: float = 0.0 + max_input_length: int = 3996 # to allow 100 tokens for response + + def load(self): + if self.tokenizer is None: + self.tokenizer = tiktoken.get_encoding("cl100k_base") # chatgpt/gpt-4 + + with open(self.path_model) as f: + info = json.load(f) + openai.api_key = info["key"] + self.model = info["model"] + + def run(self, prompt: str) -> str: + self.load() + while True: + try: + messages = [{"role": "user", "content": prompt}] + response = openai.ChatCompletion.create( + model=self.model, + messages=messages, + temperature=self.temperature, + ) + output = response.choices[0].message["content"] + break + except Exception as e: + print(e) + time.sleep(5) + continue + return output + + def check_valid_length(self, prompt: str) -> bool: + self.load() + tokens_per_message = 4 + tokens_per_name = -1 + messages = [{"role": "user", "content": prompt}] + num_tokens = 0 + for message in messages: + num_tokens += tokens_per_message + for key, value in message.items(): + num_tokens += len(self.tokenizer.encode(value)) + if key == "name": + num_tokens += tokens_per_name + num_tokens += 3 + + return num_tokens <= self.max_input_length + + def truncate_input(self, input) -> str: + return self.tokenizer.decode(self.tokenizer.encode(input)[:self.max_input_length-8]) + + +def select_model(model_name: str, **kwargs) -> EvalModel: + model_map = dict( + flan_t5_xl=SeqToSeqModel, + flan_t5_xxl = SeqToSeqModel, + flan_ul2 = SeqToSeqModel, + openai=OpenAIModel, + llama2_7b=VLLMModel, + llama2_13b=VLLMModel, + ) + model_class = model_map.get(model_name) + if model_class is None: + raise ValueError(f"{model_name}. Choose from {list(model_map.keys())}") + return model_class(**kwargs) + + +def test_model( + prompt: str = "Identify the stance of the given sentence. Choose from 'support', 'attack', or 'neutral'.\nSentence: Menace II Society is a motion picture.\nLabel: ", + model_name: str = "flan_t5_xl", + path_model: str = "google/flan-t5-xl", + **kwargs, +): + + model = select_model(model_name, path_model=path_model, **kwargs) + print(locals()) + print(model.check_valid_length(prompt)) + if not model.check_valid_length(prompt): + prompt = model.truncate_input(prompt) + print(f"Truncated prompt: {prompt}\n Length:{model.max_input_length}") + + print(model.run(prompt)) + + +if __name__ == "__main__": + Fire() diff --git a/openai_info.json b/openai_info.json new file mode 100644 index 0000000..c5feb2d --- /dev/null +++ b/openai_info.json @@ -0,0 +1 @@ +{"engine": "gpt-3.5-turbo-0301", "key": "YOUR_KEY_HERE"} diff --git a/prompting.py b/prompting.py new file mode 100644 index 0000000..f443c66 --- /dev/null +++ b/prompting.py @@ -0,0 +1,423 @@ +from typing import Optional, List + +from fire import Fire +from pydantic import BaseModel + +from data_loading import ArgumentSample, ArgumentData + + +class Prompter(BaseModel): + + def run(self, data_train: ArgumentSample, sample_test: ArgumentSample) -> str: + raise NotImplementedError + + def get_answer(self, text: str) -> str: + raise NotImplementedError + + +class ConclugenBasePrompter(Prompter): + def run(self, data_train: ArgumentData, sample_test: ArgumentSample) -> str: + prompt = "Based on the evidence presented, what is the most logical and justifiable stance to take on the issue at hand?\n" + for sample in data_train.samples: + prompt += f"Argument: {sample.src}\n" + prompt += f"Conclusion: {sample.tgt}\n\n" + + prompt += f"Argument: {sample_test.src}\n" + prompt += "Conclusion: " + return prompt + + def get_answer(self, text: str) -> str: + text = text.split('\n\n')[0] + return text.strip() + + +class ConclugenTopicPrompter(Prompter): + def run(self, data_train: ArgumentData, sample_test: ArgumentSample) -> str: + prompt = "Based on the evidence presented, what is the most logical and justifiable stance to take on the issue at hand?\n" + for sample in data_train.samples: + topic = sample.src.split('<|TOPIC|>')[1].split('<|ARGUMENT|>')[0] + argument = sample.src.split('<|ARGUMENT|>')[1].split('<|CONCLUSION|>')[0] + prompt += f"Topic: {topic}\n" + prompt += f"Argument: {argument}\n" + prompt += f"Conclusion: {sample.tgt}\n\n" + + topic = sample_test.src.split('<|TOPIC|>')[1].split('<|ARGUMENT|>')[0] + argument = sample_test.src.split('<|ARGUMENT|>')[1].split('<|CONCLUSION|>')[0] + prompt += f"Topic: {topic}\n" + prompt += f"Argument: {argument}\n" + prompt += "Conclusion: " + return prompt + + def get_answer(self, text: str) -> str: + text = text.split('\n\n')[0] + return text.strip() + + +class ConclugenAspectsPrompter(Prompter): + def run(self, data_train: ArgumentData, sample_test: ArgumentSample) -> str: + prompt = "Based on the evidence presented, what is the most logical and justifiable stance to take on the issue at hand?\n" + for sample in data_train.samples: + topic = sample.src.split('<|TOPIC|>')[1].split('<|ARGUMENT|>')[0] + argument = sample.src.split('<|ARGUMENT|>')[1].split('<|ASPECTS|>')[0] + aspects = sample.src.split('<|ASPECTS|>')[1].split('<|CONCLUSION|>')[0] + prompt += f"Topic: {topic}\n" + prompt += f"Argument: {argument}\n" + prompt += f"Aspects: {aspects}\n" + prompt += f"Conclusion: {sample.tgt}\n\n" + + topic = sample_test.src.split('<|TOPIC|>')[1].split('<|ARGUMENT|>')[0] + argument = sample_test.src.split('<|ARGUMENT|>')[1].split('<|ASPECTS|>')[0] + aspects = sample_test.src.split('<|ASPECTS|>')[1].split('<|CONCLUSION|>')[0] + prompt += f"Topic: {topic}\n" + prompt += f"Argument: {argument}\n" + prompt += f"Aspects: {aspects}\n" + prompt += "Conclusion: " + return prompt + + def get_answer(self, text: str) -> str: + text = text.split('\n\n')[0] + return text.strip() + + +class ConclugenTargetsPrompter(Prompter): + def run(self, data_train: ArgumentData, sample_test: ArgumentSample) -> str: + prompt = "Based on the evidence presented, what is the most logical and justifiable stance to take on the issue at hand?\n" + for sample in data_train.samples: + topic = sample.src.split('<|TOPIC|>')[1].split('<|ARGUMENT|>')[0] + argument = sample.src.split('<|ARGUMENT|>')[1].split('<|TARGETS|>')[0] + targets = sample.src.split('<|TARGETS|>')[1].split('<|CONCLUSION|>')[0] + prompt += f"Topic: {topic}\n" + prompt += f"Argument: {argument}\n" + prompt += f"Targets: {targets}\n" + prompt += f"Conclusion: {sample.tgt}\n\n" + + topic = sample_test.src.split('<|TOPIC|>')[1].split('<|ARGUMENT|>')[0] + argument = sample_test.src.split('<|ARGUMENT|>')[1].split('<|TARGETS|>')[0] + targets = sample_test.src.split('<|TARGETS|>')[1].split('<|CONCLUSION|>')[0] + prompt += f"Topic: {topic}\n" + prompt += f"Argument: {argument}\n" + prompt += f"Targets: {targets}\n" + prompt += "Conclusion: " + return prompt + + def get_answer(self, text: str) -> str: + text = text.split('\n\n')[0] + return text.strip() + + +class DebatesumAbstractivePrompter(Prompter): + def run(self, data_train: ArgumentData, sample_test: ArgumentSample) -> str: + prompt = "What is the main idea or argument presented in the document?\n" + for sample in data_train.samples: + prompt += f"Document: {sample.src}\n" + prompt += f"Abstractive Summary: {sample.tgt}\n\n" + + prompt += f"Document: {sample_test.src}\n" + prompt += "Abstractive Summary: " + return prompt + + def get_answer(self, text: str) -> str: + text = text.split('\n\n')[0] + return text.strip() + + +class DebatesumExtractivePrompter(Prompter): + def run(self, data_train: ArgumentData, sample_test: ArgumentSample) -> str: + prompt = "Identify the main points and supporting evidence in the document that support the argument being made.\n" + for sample in data_train.samples: + prompt += f"Document: {sample.src}\n" + prompt += f"Extractive Summary: {sample.tgt}\n\n" + + prompt += f"Document: {sample_test.src}\n" + prompt += "Extractive Summary: " + return prompt + + def get_answer(self, text: str) -> str: + text = text.split('\n\n')[0] + return text.strip() + + +class CounterargumentPremisesPrompter(Prompter): + def run(self, data_train: ArgumentData, sample_test: ArgumentSample) -> str: + prompt = "Identify a premise for the claim and come up with a counter-argument that challenges the validity of that premise.\n" + for sample in data_train.samples: + claim, premises = sample.src.split('\t') + prompt += f"Claim: {claim}\n" + prompt += f"Premises: {premises}\n" + prompt += f"Counter Argument: {sample.tgt}\n\n" + + claim, premises = sample_test.src.split('\t') + prompt += f"Claim: {claim}\n" + prompt += f"Premises: {premises}\n" + prompt += "Counter Argument: " + return prompt + + def get_answer(self, text: str) -> str: + text = text.split('\n\n')[0] + return text.strip() + + +class CounterargumentWeakPremisesPrompter(Prompter): + def run(self, data_train: ArgumentData, sample_test: ArgumentSample) -> str: + prompt = "Identify a weak premise for the claim and formulate a counter-argument to challenge it.\n" + for sample in data_train.samples: + claim, premises = sample.src.split('\t') + prompt += f"Claim: {claim}\n" + prompt += f"Weak Premises: {premises}\n" + prompt += f"Counter Argument: {sample.tgt}\n\n" + + claim, premises = sample_test.src.split('\t') + prompt += f"Claim: {claim}\n" + prompt += f"Weak Premises: {premises}\n" + prompt += "Counter Argument: " + return prompt + + def get_answer(self, text: str) -> str: + text = text.split('\n\n')[0] + return text.strip() + + +class ClaimDetectionPrompter(Prompter): + def run(self, data_train: ArgumentData, sample_test: ArgumentSample) -> str: + prompt = "Identify whether the given sentence is a claim towards the given topic. Choose from 'claim' or 'non claim'.\n" + for sample in data_train.samples: + topic, sentence = sample.src.split('\t') + prompt += f"Topic: {topic}\n" + prompt += f"Sentence: {sentence}\n" + prompt += f"Label: {sample.tgt}\n\n" + + topic, sentence = sample_test.src.split('\t') + prompt += f"Topic: {topic}\n" + prompt += f"Sentence: {sentence}\n" + prompt += "Label: " + return prompt + + def get_answer(self, text: str) -> str: + text = text.split('\n')[0] + text = text.replace('-', ' ') + return text.strip().lower() + + +class ArgumentDetectionPrompter(Prompter): + def run(self, data_train: ArgumentData, sample_test: ArgumentSample) -> str: + prompt = "Identify whether the given sentence is an argument towards the given topic. Choose from 'argument' or 'non argument'.\n" + for sample in data_train.samples: + topic, sentence = sample.src.split('\t') + prompt += f"Topic: {topic}\n" + prompt += f"Sentence: {sentence}\n" + prompt += f"Label: {sample.tgt}\n\n" + + topic, sentence = sample_test.src.split('\t') + prompt += f"Topic: {topic}\n" + prompt += f"Sentence: {sentence}\n" + prompt += "Label: " + return prompt + + def get_answer(self, text: str) -> str: + text = text.split('\n')[0] + text = text.replace('-', ' ') + return text.strip().lower() + + +class EvidenceDetectionIAMPrompter(Prompter): + def run(self, data_train: ArgumentData, sample_test: ArgumentSample) -> str: + prompt = "Identify whether the given sentence is a piece of evidence towards the given claim. Choose from 'evidence' or 'non evidence'.\n" + for sample in data_train.samples: + claim, sentence = sample.src.split('\t') + prompt += f"Claim: {claim}\n" + prompt += f"Sentence: {sentence}\n" + prompt += f"Label: {sample.tgt}\n\n" + + claim, sentence = sample_test.src.split('\t') + prompt += f"Claim: {claim}\n" + prompt += f"Sentence: {sentence}\n" + prompt += "Label: " + return prompt + + def get_answer(self, text: str) -> str: + text = text.split('\n')[0] + text = text.replace('-', ' ') + return text.strip().lower() + + +class EvidenceDetectionIBMPrompter(Prompter): + def run(self, data_train: ArgumentData, sample_test: ArgumentSample) -> str: + prompt = "Identify whether the given sentence is a piece of evidence towards the given topic. Choose from 'evidence' or 'non evidence'.\n" + for sample in data_train.samples: + topic, sentence = sample.src.split('\t') + prompt += f"Topic: {topic}\n" + prompt += f"Sentence: {sentence}\n" + prompt += f"Label: {sample.tgt}\n\n" + + topic, sentence = sample_test.src.split('\t') + prompt += f"Topic: {topic}\n" + prompt += f"Sentence: {sentence}\n" + prompt += "Label: " + return prompt + + def get_answer(self, text: str) -> str: + text = text.split('\n')[0] + text = text.replace('-', ' ') + return text.strip().lower() + + +class StanceDetectionPrompter(Prompter): + def run(self, data_train: ArgumentData, sample_test: ArgumentSample) -> str: + prompt = "Identify the stance of the given sentence towards the given topic. Choose from 'support' or 'attack'.\n" + for sample in data_train.samples: + topic, sentence = sample.src.split('\t') + prompt += f"Topic: {topic}\n" + prompt += f"Sentence: {sentence}\n" + prompt += f"Label: {sample.tgt}\n\n" + + topic, sentence = sample_test.src.split('\t') + prompt += f"Topic: {topic}\n" + prompt += f"Sentence: {sentence}\n" + prompt += "Label: " + return prompt + + def get_answer(self, text: str) -> str: + text = text.split('\n')[0] + return text.strip().lower() + + +class StanceDetectionFeverPrompter(Prompter): + def run(self, data_train: ArgumentData, sample_test: ArgumentSample) -> str: + prompt = "Identify the stance of the given sentence. Choose from 'support', 'attack', or 'neutral'.\n" + for sample in data_train.samples: + prompt += f"Sentence: {sample.src}\n" + prompt += f"Label: {sample.tgt}\n\n" + + prompt += f"Sentence: {sample_test.src}\n" + prompt += "Label: " + return prompt + + def get_answer(self, text: str) -> str: + text = text.split('\n')[0] + return text.strip().lower() + + +class StanceDetectionMTSDPrompter(Prompter): + def run(self, data_train: ArgumentData, sample_test: ArgumentSample) -> str: + prompt = "Identify the stance of the given sentence towards each given target. Choose from 'support', 'attack', or 'neutral' for each target in the target pair. Format the output as a label pair: label1, label2.\n" + for sample in data_train.samples: + sentence, target1, target2 = sample.src.split('\t') + label1, label2 = sample.tgt.split('\t') + prompt += f"Sentence: {sentence}\n" + prompt += f"Target Pair: {target1}, {target2}\n" + prompt += f"Label Pair: {label1}, {label2}\n\n" + + sentence, target1, target2 = sample_test.src.split('\t') + prompt += f"Sentence: {sentence}\n" + prompt += f"Target Pair: {target1}, {target2}\n" + prompt += "Label Pair: " + return prompt + + def get_answer(self, text: str) -> str: + text = text.split('\n')[0] + text = text.replace(', ', '\t') + return text.strip().lower() + + +class EvidenceClassificationIBMPrompter(Prompter): + def run(self, data_train: ArgumentData, sample_test: ArgumentSample) -> str: + prompt = "Identify the evidence type of the given sentence. Choose from 'study', 'anecdotal' or 'expert'.\n" + for sample in data_train.samples: + prompt += f"Sentence: {sample.src}\n" + prompt += f"Evidence Type: {sample.tgt}\n\n" + + prompt += f"Sentence: {sample_test.src}\n" + prompt += "Evidence Type: " + return prompt + + def get_answer(self, text: str) -> str: + text = text.split('\n')[0] + return text.strip().lower() + + +class EvidenceClassificationAQEPrompter(Prompter): + def run(self, data_train: ArgumentData, sample_test: ArgumentSample) -> str: + prompt = "Identify the evidence type of the given sentence. Choose from 'research', 'case', 'expert', 'explanation' or 'others'.\n" + for sample in data_train.samples: + prompt += f"Sentence: {sample.src}\n" + prompt += f"Evidence Type: {sample.tgt}\n\n" + + prompt += f"Sentence: {sample_test.src}\n" + prompt += "Evidence Type: " + return prompt + + + def get_answer(self, text: str) -> str: + text = text.split('\n')[0] + return text.strip().lower() + + +def select_prompter(task: str, data_name: str) -> Prompter: + if task == "conclugen": + if data_name == "base": + return ConclugenBasePrompter() + elif data_name == "aspects": + return ConclugenAspectsPrompter() + elif data_name == "targets": + return ConclugenTargetsPrompter() + elif data_name == "topic": + return ConclugenTopicPrompter() + else: + raise ValueError(f"Invalid data name: {data_name}") + elif task == "debatesum": + if data_name == "abstract": + return DebatesumAbstractivePrompter() + elif data_name == "extract": + return DebatesumExtractivePrompter() + else: + raise ValueError(f"Invalid data name: {data_name}") + elif task == "counter_arg_gen": + if data_name == "weak_premises": + return CounterargumentWeakPremisesPrompter() + elif data_name == "premises": + return CounterargumentPremisesPrompter() + else: + raise ValueError(f"Invalid data name: {data_name}") + elif task == "claim_detection": + if data_name == "ibm_argument": + return ArgumentDetectionPrompter() + elif data_name in ["iam_claims", "ibm_claims"]: + return ClaimDetectionPrompter() + else: + raise ValueError(f"Invalid data name: {data_name}") + elif task == "evidence_detection": + if data_name == "iam_evidence": + return EvidenceDetectionIAMPrompter() + elif data_name == "ibm_evidence": + return EvidenceDetectionIBMPrompter() + else: + raise ValueError(f"Invalid data name: {data_name}") + elif task == "stance_detection": + if data_name == "fever": + return StanceDetectionFeverPrompter() + elif data_name == "mtsd": + return StanceDetectionMTSDPrompter() + elif data_name in ["ibm_stance", "iam_stance"]: + return StanceDetectionPrompter() + else: + raise ValueError(f"Invalid data name: {data_name}") + elif task == "evidence_classification": + if data_name == "aqe_type": + return EvidenceClassificationAQEPrompter() + elif data_name == "ibm_type": + return EvidenceClassificationIBMPrompter() + else: + raise ValueError(f"Invalid data name: {data_name}") + else: + raise ValueError(f"Invalid task: {task}") + + +def test_prompt(task: str, data_name: str, num_train: int, seed: int): + data_train, data_test = ArgumentData.load(task, data_name, num_train, seed) + prompter = select_prompter(task, data_name) + prompt = prompter.run(data_train, data_test.samples[0]) + print(prompt) + + +if __name__ == "__main__": + Fire() \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..5007e5d --- /dev/null +++ b/requirements.txt @@ -0,0 +1,23 @@ +numpy==1.26.4 +fire==0.5.0 +pydantic==2.0 +torch==2.1.2 +tqdm==4.65.0 +sentencepiece==0.1.97 +datasets==2.10.1 +bitsandbytes==0.41.1 +accelerate==0.27.2 +editdistance==0.6.2 +tiktoken==0.6.0 +openai==0.27.7 +transformers==4.39.1 +nltk==3.8.1 +scikit-learn==1.2.2 +evaluate==0.4.0 +bert_score==0.3.13 +rouge_score==0.1.2 +protobuf==4.24.2 +einops==0.6.1 +transformers_stream_generator==0.0.4 +ninja==1.11.1 +vllm==0.4.0.post1 \ No newline at end of file diff --git a/scoring.py b/scoring.py new file mode 100644 index 0000000..6714f58 --- /dev/null +++ b/scoring.py @@ -0,0 +1,59 @@ +from fire import Fire +from typing import List +from pydantic import BaseModel +import evaluate +from sklearn.metrics import accuracy_score, f1_score + +from data_loading import ArgumentData + +class AMScorer(BaseModel): + @staticmethod + def run(predictions: List[str], targets: List[str]) -> float: + accuracy = accuracy_score(targets, predictions) + f1 = f1_score(targets, predictions, average="weighted") + return dict(accuracy=accuracy, f1=f1) + +class AGScorer(BaseModel): + @staticmethod + def run(predictions: List[str], references: List[str]) -> dict: + # bert score + bertscore = evaluate.load("bertscore") + results = bertscore.compute(predictions=predictions, references=references, lang="en") + bertscore = sum(results["f1"])/len(results["f1"]) + + # rouge score + rouge = evaluate.load('rouge') + results = rouge.compute(predictions=predictions,references=references) + rouge1 = results["rouge1"] + rouge2 = results["rouge2"] + rougeL = results["rougeL"] + + # meteor + meteor = evaluate.load('meteor') + results = meteor.compute(predictions=predictions, references=references) + meteor = results["meteor"] + + return dict(bertscore=bertscore, rouge1=rouge1, rouge2=rouge2, rougeL=rougeL, meteor=meteor) + + +def select_scorer(task: str): + if task in ["claim_detection", "evidence_detection", "stance_detection", "evidence_classification"]: + return AMScorer() + elif task in ["counter_arg_gen", "conclugen", "debatesum"]: + return AGScorer() + else: + raise NotImplementedError + + +def test_scorer(task: str, output_path: str): + data = ArgumentData.load_outputs(output_path) + predictions = [sample.pred for sample in data.samples] + targets = [sample.tgt for sample in data.samples] + + scorer = select_scorer(task) + scores = scorer.run(predictions, targets) + print(scores) + + +if __name__ == "__main__": + Fire()