Skip to content

Commit ad0f7ba

Browse files
committed
Update llm_lib to include untuned examples
- Allow untuned examples to be used as part of the chat prompt - Updates format goldens to generate the untuned examples subset - Minor adjustments to prompt
1 parent 361914c commit ad0f7ba

File tree

4 files changed

+141
-48
lines changed

4 files changed

+141
-48
lines changed

ai/src/ai/common/llm_lib.py

+93-32
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,15 @@
1+
import json
12
import re
23
from os import getenv
34
from typing import NamedTuple
45

6+
from dotenv import load_dotenv
57
from openai import OpenAI
68
from openai.types.chat import (
79
ChatCompletionMessageParam,
810
)
911

10-
SYSTEM_INSTRUCTION_PART_1_PATH = "src/ai/prompts/mesop_overview.txt"
11-
# SYSTEM_INSTRUCTION_PART_2_PATH = "src/ai/prompts/mini_docs.txt"
12-
13-
with open(SYSTEM_INSTRUCTION_PART_1_PATH) as f:
14-
SYSTEM_INSTRUCTION_PART_1 = f.read()
15-
16-
# with open(SYSTEM_INSTRUCTION_PART_2_PATH) as f:
17-
# SYSTEM_INSTRUCTION_PART_2 = f.read()
18-
19-
# Intentionally skip the more extensive system instruction with docs for now.
20-
SYSTEM_INSTRUCTION = SYSTEM_INSTRUCTION_PART_1 # + SYSTEM_INSTRUCTION_PART_2
21-
PROMPT_PATH = "src/ai/prompts/revise_prompt.txt"
22-
23-
with open(PROMPT_PATH) as f:
24-
REVISE_APP_BASE_PROMPT = f.read().strip()
12+
load_dotenv()
2513

2614
EDIT_HERE_MARKER = " # <--- EDIT HERE"
2715

@@ -31,6 +19,11 @@ class ApplyPatchResult(NamedTuple):
3119
result: str
3220

3321

22+
def read_file(filepath: str) -> str:
  """Read a text file and return its contents with surrounding whitespace stripped."""
  with open(filepath) as file:
    contents = file.read()
  return contents.strip()
25+
26+
3427
def apply_patch(original_code: str, patch: str) -> ApplyPatchResult:
3528
# Extract the diff content
3629
diff_pattern = r"<<<<<<< ORIGINAL(.*?)=======\n(.*?)>>>>>>> UPDATED"
@@ -64,24 +57,87 @@ def apply_patch(original_code: str, patch: str) -> ApplyPatchResult:
6457
)
6558

6659

60+
class MessageFormatter:
  """Builds the chat messages for a revise-app request.

  Holds a system instruction and a revise-app prompt template containing the
  `<APP_CODE>` and `<APP_CHANGES>` placeholders.
  """

  def __init__(self, system_instruction: str, revise_app_prompt: str):
    self.system_instruction = system_instruction
    self.revise_app_prompt = revise_app_prompt

  def format_messages(
    self, code: str, user_input: str, line_number: int | None
  ) -> list[ChatCompletionMessageParam]:
    """Returns a [system, user] message pair for the chat API.

    When `line_number` (1-indexed) is given and in range, the corresponding
    line of `code` is tagged with EDIT_HERE_MARKER so the model knows where
    the user is editing.
    """
    if line_number is not None:
      lines = code.splitlines()
      if 1 <= line_number <= len(lines):
        lines[line_number - 1] = lines[line_number - 1] + EDIT_HERE_MARKER
      code = "\n".join(lines)

    prompt = self.revise_app_prompt
    prompt = prompt.replace("<APP_CODE>", code)
    prompt = prompt.replace("<APP_CHANGES>", user_input)

    system_message: ChatCompletionMessageParam = {
      "role": "system",
      "content": self.system_instruction,
    }
    user_message: ChatCompletionMessageParam = {
      "role": "user",
      "content": prompt,
    }
    return [system_message, user_message]
83+
84+
85+
def MakeDefaultMessageFormatter():
  """Creates the default MessageFormatter.

  The full revise instructions (base + shorter prompt) are delivered in the
  user message; the system message carries only the Mesop overview.
  Prompt file paths are hard-coded relative to the ai/ working directory.
  """
  overview = read_file("src/ai/prompts/mesop_overview.txt")
  revise_prompt = "\n\n".join(
    [
      read_file("src/ai/prompts/revise_prompt_base.txt"),
      read_file("src/ai/prompts/revise_prompt_shorter.txt"),
    ]
  )
  return MessageFormatter(overview, revise_prompt)
90+
91+
92+
def MakeMessageFormatterShorterUserMsg():
  """Creates a MessageFormatter that keeps the user message short.

  We use a shorter user prompt since goldens that have not been fine-tuned
  yet will be included in the chat prompt; this lets new training data be
  tested without re-running fine-tuning each time.

  To compensate, the full revise instructions are moved out of the user
  message and bundled into the system instruction instead.
  """
  overview = read_file("src/ai/prompts/mesop_overview.txt")
  revise_instructions = read_file("src/ai/prompts/revise_prompt_base.txt")
  short_prompt = read_file("src/ai/prompts/revise_prompt_shorter.txt")
  system = overview + "\n\n" + revise_instructions
  return MessageFormatter(system, short_prompt)
108+
109+
110+
def load_unused_goldens():
  """Loads not-yet-fine-tuned golden examples to splice into chat prompts.

  Reads the JSONL dataset produced by `format_goldens.py
  --skip_fine_tuned_goldens`; each row is `{"messages": [system, user,
  assistant]}`. The user/assistant pairs are collected as few-shot examples.

  Best-effort: if the dataset file is missing, prints the error and returns
  an empty list so the caller falls back to a plain prompt.
  """
  goldens_path = "ft/gen/formatted_dataset_for_prompting.jsonl"
  new_goldens = []
  num_rows = 0
  try:
    with open(goldens_path) as f:
      for row in f:
        num_rows += 1
        messages = json.loads(row)["messages"]
        new_goldens.append(messages[1])
        new_goldens.append(messages[2])
      # Guard against an empty dataset file: pop(0) on an empty list would
      # raise IndexError, which is not caught below.
      if new_goldens:
        # NOTE(review): the comment in the original said this removes "the
        # redundant system instruction", but messages[0] (the system message)
        # is never appended — this actually drops the first *user* message,
        # leaving the examples assistant-first. Confirm against the JSONL
        # schema whether this is intended.
        new_goldens.pop(0)
      print(f"Adding {num_rows} additional examples to prompt.")
  except FileNotFoundError as e:
    print(e)

  return new_goldens
127+
128+
129+
# Opt-in via environment variable: when MESOP_AI_INCLUDE_NEW_GOLDENS is set
# (to any non-empty value), use the shorter user-message formatter and splice
# the not-yet-fine-tuned golden examples into every chat prompt.
if getenv("MESOP_AI_INCLUDE_NEW_GOLDENS"):
  message_formatter = MakeMessageFormatterShorterUserMsg()
  # NOTE(review): despite the name, `goldens_path` holds a *list of example
  # messages*, not a filesystem path — consider renaming (e.g. extra_goldens).
  # Renaming would also require updating adjust_mesop_app_stream/_blocking.
  goldens_path = load_unused_goldens()
else:
  message_formatter = MakeDefaultMessageFormatter()
  goldens_path = []
135+
136+
67137
def format_messages(
  code: str, user_input: str, line_number: int | None
) -> list[ChatCompletionMessageParam]:
  """Module-level convenience wrapper around the configured formatter.

  Delegates to the module-global `message_formatter` selected at import time
  (default vs. shorter-user-message, depending on
  MESOP_AI_INCLUDE_NEW_GOLDENS). Kept so existing callers of
  `format_messages` keep working unchanged.
  """
  return message_formatter.format_messages(code, user_input, line_number)
85141

86142

87143
def adjust_mesop_app_stream(
@@ -97,9 +153,12 @@ def adjust_mesop_app_stream(
97153
"""
98154
messages = format_messages(code, user_input, line_number)
99155

156+
if goldens_path:
157+
messages = [messages[0], *goldens_path, messages[1]]
158+
100159
return client.chat.completions.create(
101160
model=model,
102-
max_tokens=10_000,
161+
max_tokens=16_384,
103162
messages=messages,
104163
stream=True,
105164
)
@@ -117,10 +176,12 @@ def adjust_mesop_app_blocking(
117176
Returns the code diff.
118177
"""
119178
messages = format_messages(code, user_input, line_number)
179+
if goldens_path:
180+
messages = [messages[0], *goldens_path, messages[1]]
120181

121182
response = client.chat.completions.create(
122183
model=model,
123-
max_tokens=10_000,
184+
max_tokens=16_384,
124185
messages=messages,
125186
stream=False,
126187
)

ai/src/ai/prompts/revise_prompt.txt renamed to ai/src/ai/prompts/revise_prompt_base.txt

+1-11
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ Your task is to modify a Mesop app given the code and a description.
33
Make sure to remember these rules when making modifications:
44
1. For the @me.page decorator, keep it the same as the original *unless* you need to modify on_load.
55
2. Event handler functions cannot use lambdas. You must use functions.
6-
3. Event handle functions only pass in the event type. They do not accept extra parameters.
6+
3. Event handler functions only pass in the event type. They do not accept extra parameters.
77
4. For padding, make sure to use the `me.Padding` object rather than a string or int.
88
5. For margin, make sure to use the `me.Margin` object rather than a string or int.
99
6. For border, make sure to use the `me.Border` and `me.BorderSide` objects rather than a string.
@@ -143,13 +143,3 @@ def page():
143143
>>>>>>> UPDATED
144144

145145
OK, now that I've shown you an example, let's do this for real.
146-
147-
Existing app code:
148-
```
149-
<APP_CODE>
150-
```
151-
152-
User instructions:
153-
<APP_CHANGES>
154-
155-
Diff output:
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
Existing app code:
2+
```
3+
<APP_CODE>
4+
```
5+
6+
User instructions:
7+
<APP_CHANGES>
8+
9+
Diff output:

ai/src/format_goldens.py

+38-5
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,22 @@
22
Formats the golden dataset for the fine-tuning process.
33
"""
44

5+
import argparse
56
import json
67
import os
8+
from datetime import datetime
79
from typing import Any
810

9-
from ai.common.llm_lib import format_messages
11+
from ai.common.llm_lib import (
12+
MakeDefaultMessageFormatter,
13+
MakeMessageFormatterShorterUserMsg,
14+
)
1015

16+
# TODO: Allow this to be configurable
17+
FINE_TUNING_CUTOFF = datetime(2024, 8, 2, 0, 0)
1118

12-
def process_goldens():
19+
20+
def process_goldens(skip_fine_tuned_goldens: bool = False):
1321
dataset: list[dict[str, Any]] = []
1422
outputs_dir = "ft/goldens"
1523

@@ -20,6 +28,13 @@ def process_goldens():
2028
diff_path = os.path.join(dir_path, "diff.txt")
2129
line_number: int | None = None
2230
meta_path = os.path.join(dir_path, "metadata.json")
31+
32+
_, timestamp = os.path.basename(dir_path).rsplit("_", 1)
33+
creation_date = datetime.strptime(timestamp, "%Y%m%d%H%M")
34+
print(creation_date)
35+
if skip_fine_tuned_goldens and creation_date < FINE_TUNING_CUTOFF:
36+
continue
37+
2338
if os.path.exists(meta_path):
2439
with open(meta_path) as meta_file:
2540
meta = json.load(meta_file)
@@ -39,10 +54,15 @@ def process_goldens():
3954
else:
4055
code = ""
4156

57+
if skip_fine_tuned_goldens:
58+
formatter = MakeMessageFormatterShorterUserMsg()
59+
else:
60+
formatter = MakeDefaultMessageFormatter()
61+
4262
dataset.append(
4363
{
4464
"messages": [
45-
*format_messages(code, prompt, line_number),
65+
*formatter.format_messages(code, prompt, line_number),
4666
{
4767
"role": "assistant",
4868
"content": diff,
@@ -55,11 +75,24 @@ def process_goldens():
5575

5676

5777
if __name__ == "__main__":
58-
formatted_dataset = process_goldens()
78+
parser = argparse.ArgumentParser()
79+
parser.add_argument(
80+
"--skip_fine_tuned_goldens",
81+
action="store_true",
82+
help="Generates a formatted dataset with goldens that have not been fine tuned.",
83+
)
84+
args = parser.parse_args()
85+
86+
formatted_dataset = process_goldens(args.skip_fine_tuned_goldens)
5987
print(f"Processed {len(formatted_dataset)} samples.")
6088
# create gen dir if it doesn't exist
6189
os.makedirs("ft/gen", exist_ok=True)
62-
full_path = os.path.join("ft/gen/formatted_dataset.jsonl")
90+
91+
if args.skip_fine_tuned_goldens:
92+
full_path = os.path.join("ft/gen/formatted_dataset_for_prompting.jsonl")
93+
else:
94+
full_path = os.path.join("ft/gen/formatted_dataset.jsonl")
95+
6396
# Append each sample as a JSON object on a separate line to a file
6497
with open(full_path, "w") as f:
6598
for sample in formatted_dataset:

0 commit comments

Comments
 (0)