
Commit 7ca0d43

added creative writing category to category.py and config (#3584)

1 parent dd90e21 commit 7ca0d43

File tree: 4 files changed (+177 −0 lines)

New file (+60 lines)

## Download dataset

We have pre-generated several category classifier benchmarks and ground truths. You can download them (with [`git-lfs`](https://git-lfs.com) installed) to the directory `classify/` by running

```console
> git clone https://huggingface.co/datasets/lmarena-ai/categories-benchmark-eval
# cd into classify/ and then copy the label_bench directory to the current directory
> cp -r categories-benchmark-eval/label_bench .
```

Your `label_bench` directory should follow the structure:

```markdown
├── label_bench/
│   ├── creative_writing_bench/
│   │   ├── data/
│   │   │   └── llama-v3p1-70b-instruct.json
│   │   └── test.json
│   ├── ...
│   ├── your_bench_name/
│   │   ├── data/
│   │   │   ├── your_classifier_data_1.json
│   │   │   ├── your_classifier_data_2.json
│   │   │   └── ...
│   │   └── test.json (your ground truth)
└── ...
```
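
Judging from how `display_score.py` (see below) reads the ground truth, each record in `test.json` needs at least a `question_id`, a boolean `label`, and a `conversation_a` list whose first message holds the user prompt. A minimal sketch of how such a file could be assembled; the ids, prompts, and the `records` orientation are illustrative, not a documented format:

```python
import pandas as pd

# Illustrative ground-truth records; field names mirror what display_score.py
# accesses: test.question_id, test.label, test.conversation_a[0]["content"].
records = [
    {
        "question_id": "q_0001",  # hypothetical id
        "label": True,            # ground truth: this prompt is creative writing
        "conversation_a": [{"role": "user", "content": "Write a sonnet about autumn."}],
    },
    {
        "question_id": "q_0002",
        "label": False,
        "conversation_a": [{"role": "user", "content": "Explain how a TCP handshake works."}],
    },
]

# pd.read_json in display_score.py can parse a list-of-records JSON file.
pd.DataFrame(records).to_json("label_bench/your_bench_name/test.json", orient="records")
```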

## How to evaluate your category classifier

To test a classifier for a new category, first make sure you have created the corresponding category child class in `category.py`. Then, to generate classification labels, make the necessary edits in `config.yaml` and run

```console
python label.py --config config.yaml --testing
```
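
For orientation, a category child class looks roughly like the sketch below. This is a hypothetical example patterned on the `CategoryCreativeWriting` class in `category.py`; the class name, tag, prompt, and output key are placeholders, and `Category` is the base class defined in `category.py`.

```python
import re

# Sketch of a new child class to add to category.py (all names are placeholders).
class CategoryYourNewTag(Category):
    def __init__(self):
        super().__init__()
        self.name_tag = "your_category_v0.1"  # must match config.yaml and create_category()
        self.pattern = re.compile(r"<decision>(\w+)</decision>")
        self.system_prompt = (
            "Decide whether the user prompt belongs to <your category>. "
            'Reply "yes" or "no" inside <decision></decision> tags. Do NOT explain.'
        )
        self.prompt_template = "<user_prompt>\n{PROMPT}\n</user_prompt>"

    def pre_process(self, prompt):
        # Build the conversation sent to the labeling model.
        return [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": self.prompt_template.format(PROMPT=prompt)},
        ]

    def get_score(self, judgment):
        # Keep the verdict only if the judge produced exactly one consistent decision.
        matches = [m for m in self.pattern.findall(judgment.lower()) if m != ""]
        return matches[0] if len(set(matches)) == 1 else None

    def post_process(self, judgment):
        score = self.get_score(judgment=judgment)
        return {"your_tag": bool(score == "yes") if score else False}
```

Remember to also register the new name in `create_category()`, as the `category.py` diff below does for `creative_writing_v0.1`.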

Next, add your new category bench to `tag_names` in `display_score.py`. Once you also have a correctly formatted ground-truth `test.json` file, you can report the performance of your classifier by running

```console
python display_score.py --bench <your_bench>
```
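
Each `tag_names` entry maps a bench directory name to a `(name_tag, output_key)` pair: the first element is the `name_tag` set in your category class, the second is the key returned by its `post_process`. A hypothetical entry (the `your_*` names are placeholders):

```python
# tag_names in display_score.py, with a hypothetical bench appended:
tag_names = {
    "if_bench": ("if_v0.1", "if"),
    "math_bench": ("math_v0.1", "math"),
    "hard_bench": ("criteria_v0.1", "hard"),
    "creative_writing_bench": ("creative_writing_v0.1", "creative_writing"),
    "your_bench_name": ("your_category_v0.1", "your_tag"),
}
```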

To inspect the conflicts between your classifier's predictions and the ground truth, add `--display-conflict`:

```console
python display_score.py --bench <your_bench> --display-conflict
```

Example output:

```console
> python display_score.py --bench if_bench --display-conflict
Model: gpt-4o-mini-2024-07-18
Accuracy: 0.967
Precision: 0.684
Recall: 0.918

###### CONFLICT ######

Ground Truth = True; Pred = False
####################
...

Ground Truth = False; Pred = True
####################
...
```

fastchat/serve/monitor/classify/category.py (+40 lines)

```diff
@@ -24,6 +24,8 @@ def create_category(name):
         return CategoryIF()
     elif name == "math_v0.1":
         return CategoryMath()
+    elif name == "creative_writing_v0.1":
+        return CategoryCreativeWriting()
 
     raise Exception(f"Category name is incorrect: {name}")
@@ -134,3 +136,41 @@ def pre_process(self, prompt):
     def post_process(self, judgment):
         score = self.get_score(judgment=judgment)
         return {"math": bool(score == "yes") if score else False}
+
+
+class CategoryCreativeWriting(Category):
+    def __init__(self):
+        super().__init__()
+        self.name_tag = "creative_writing_v0.1"
+        self.pattern = re.compile(r"<decision>(\w+)<\/decision>")
+        self.system_prompt = 'You are tasked with determining whether a given user prompt is asking for creative writing. Creative writing is defined as any form of writing that goes beyond standard professional, journalistic, academic, or technical literature. It typically involves imagination, originality, and expression of thoughts and emotions. Creative writing can include, but is not limited to, the following formats:\n- Fiction (e.g., short stories, novels)\n- Poetry (e.g., sonnets, free verse)\n- Dramatic writing (e.g., screenplays, monologues, scripts)\n- Personal essays (focusing on subjective experiences or narrative storytelling)\n- Songs and lyrics\n\nCarefully analyze the user prompt and consider whether it primarily requires creative writing. Think about the following aspects:\n1. Does the prompt ask for fictional content, speculative scenarios, or the use of imagination to construct narratives?\n2. Does it encourage the expression of thoughts, emotions, or personal experiences beyond mere factual reporting or analysis?\n3. Is it asking for writing in a specific creative format (e.g., story, poem, script, etc)?\n4. Is the primary purpose of the prompt to foster creative expression or originality rather than information delivery, technical documentation, or analytical reasoning?\n5. Does the prompt request stylistic or rhetorical elements often associated with creative writing, such as metaphor, imagery, dialogue, etc?\n6. Does the prompt expect a response in natural language (e.g., sentences, paragraphs) rather than visual, mathematical, or non-linguistic output?\n\nOutput your verdict as either "yes" or "no" in the following format:\n<decision>\n[yes/no]\n</decision>. Do NOT explain.'
+        self.prompt_template = "<user_prompt>\n{PROMPT}\n</user_prompt>"
+
+    def get_score(self, judgment):
+        matches = self.pattern.findall(
+            judgment.replace("\n", "")
+            .replace("[", "")
+            .replace("]", "")
+            .replace(" ", "")
+            .lower()
+        )
+        matches = [m for m in matches if m != ""]
+        if len(set(matches)) == 0:
+            return None
+        elif len(set(matches)) == 1:
+            return matches[0]
+        else:
+            return None
+
+    def pre_process(self, prompt):
+        args = {"PROMPT": prompt}
+        conv = [
+            {"role": "system", "content": self.system_prompt},
+            {"role": "user", "content": self.prompt_template.format(**args)},
+        ]
+        return conv
+
+    def post_process(self, judgment):
+        score = self.get_score(judgment=judgment)
+        bool_score = bool(score == "yes") if score else False
+        return {"creative_writing": bool_score, "score": score}
```

fastchat/serve/monitor/classify/config.yaml (+1 line)

```diff
@@ -10,6 +10,7 @@ task_name:
     - criteria_v0.1
     - if_v0.1
     - math_v0.1
+    - creative_writing_v0.1
 
 model_name: null
 name: llama-3-70b-instruct
```
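
The `task_name` entries line up with the names accepted by `create_category()` in `category.py`, so the labeling script presumably instantiates one category per listed task. A rough sketch of that wiring, under that assumption (the file handling and loop below are illustrative, not the actual `label.py` code):

```python
import yaml  # PyYAML

from category import create_category

with open("config.yaml") as f:
    config = yaml.safe_load(f)

# With this commit, "creative_writing_v0.1" resolves to CategoryCreativeWriting();
# an unrecognized name still raises "Category name is incorrect: ...".
categories = [create_category(name) for name in config["task_name"]]
```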

fastchat/serve/monitor/classify/display_score.py (new file, +76 lines)

```python
import pandas as pd
import argparse
import os
from glob import glob
from sklearn.metrics import recall_score, precision_score

tag_names = {
    "if_bench": ("if_v0.1", "if"),
    "math_bench": ("math_v0.1", "math"),
    "hard_bench": ("criteria_v0.1", "hard"),
    "creative_writing_bench": ("creative_writing_v0.1", "creative_writing"),
}


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--bench", type=str, default="if_bench")
    parser.add_argument("--display-conflict", action="store_true")
    args = parser.parse_args()
    assert args.bench in tag_names, "Not valid bench argument, add bench if needed."

    test = pd.read_json(os.path.join("label_bench", args.bench, "test.json"))

    for file in glob(os.path.join("label_bench", args.bench, "data", "*.json")):
        output = pd.read_json(file)

        tag_map = (
            output[["question_id", "category_tag"]]
            .set_index("question_id")
            .to_dict("index")
        )

        tag_1, tag_2 = tag_names[args.bench]
        test["pred"] = test.question_id.map(
            lambda id: tag_map[id]["category_tag"][tag_1][tag_2]
        )

        accuracy = (test.label == test.pred).mean()
        recall = recall_score(y_pred=test.pred, y_true=test.label)
        precision = precision_score(y_pred=test.pred, y_true=test.label)

        print(f"Model: {output.model[0]}")
        print(f"Accuracy: {round(accuracy, 3)}")
        print(f"Precision: {round(precision, 3)}")
        print(f"Recall: {round(recall, 3)}")

        if args.display_conflict:
            print()
            print("###### CONFLICT ######")
            print()
            conflict = test[test.label & ~test.pred]
            print("Ground Truth = True; Pred = False")
            prompts = (
                conflict.conversation_a.map(lambda x: x[0]["content"])
                .sample(n=5)
                .tolist()
            )
            for prompt in prompts:
                print("####################")
                print(prompt)
                print()
            print()

            conflict = test[~test.label & test.pred]
            print("Ground Truth = False; Pred = True")
            prompts = (
                conflict.conversation_a.map(lambda x: x[0]["content"])
                .sample(n=5)
                .tolist()
            )
            for prompt in prompts:
                print("####################")
                print(prompt)
                print()
            print()

        print()
```
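
For reference, the per-model JSON files under `data/` need at least `question_id`, `model`, and a nested `category_tag` column, since the script reads `output.model[0]` and `tag_map[id]["category_tag"][tag_1][tag_2]`. A sketch of one record for the creative-writing bench (values are illustrative; the real files are produced by the labeling step):

```python
# Illustrative record shape only; actual files come from label.py.
example_output_record = {
    "question_id": "q_0001",
    "model": "llama-v3p1-70b-instruct",
    "category_tag": {
        "creative_writing_v0.1": {      # tag_1 in tag_names
            "creative_writing": True,   # tag_2, the boolean prediction
            "score": "yes",             # raw decision kept by post_process
        }
    },
}
```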
