
Commit

Add the latest code
jodie-kang committed May 28, 2024
1 parent 310cdb2 commit 7a18d5b
Showing 16 changed files with 1,849 additions and 101 deletions.
25 changes: 17 additions & 8 deletions ch06/01_main-chapter-code/ch06.ipynb
@@ -7,10 +7,19 @@
"id": "c024bfa4-1a7a-4751-b5a1-827225a3478b"
},
"source": [
"<font size=\"1\">\n",
"Supplementary code for \"Build a Large Language Model From Scratch\": <a href=\"https://www.manning.com/books/build-a-large-language-model-from-scratch\">https://www.manning.com/books/build-a-large-language-model-from-scratch</a> by <a href=\"https://sebastianraschka.com\">Sebastian Raschka</a><br>\n",
"Code repository: <a href=\"https://github.com/rasbt/LLMs-from-scratch\">https://github.com/rasbt/LLMs-from-scratch</a>\n",
"</font>"
"<table style=\"width:100%\">\n",
"<tr>\n",
"<td style=\"vertical-align:middle; text-align:left;\">\n",
"<font size=\"2\">\n",
"Supplementary code for the <a href=\"http://mng.bz/orYv\">Build a Large Language Model From Scratch</a> book by <a href=\"https://sebastianraschka.com\">Sebastian Raschka</a><br>\n",
"<br>Code repository: <a href=\"https://github.com/rasbt/LLMs-from-scratch\">https://github.com/rasbt/LLMs-from-scratch</a>\n",
"</font>\n",
"</td>\n",
"<td style=\"vertical-align:middle; text-align:left;\">\n",
"<a href=\"http://mng.bz/orYv\"><img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/cover-small.webp\" width=\"100px\"></a>\n",
"</td>\n",
"</tr>\n",
"</table>"
]
},
{
@@ -907,7 +916,7 @@
"id": "ab8e056c-abe0-415f-b34d-df686204259e",
"metadata": {},
"source": [
"- To ensure that the model was loaded corrected, let's double-check that it generates coherent text"
"- 为了确保模型加载正确,让我们仔细检查它是否生成连贯的文本。"
]
},
{
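
The markdown cell above is the notebook's sanity check that the loaded weights produce coherent text. A minimal sketch of such a check, assuming the generate_text_simple, text_to_token_ids, and token_ids_to_text helpers from the book's earlier chapters and the model and BASE_CONFIG objects set up in the notebook:

import tiktoken
from previous_chapters import generate_text_simple, text_to_token_ids, token_ids_to_text

tokenizer = tiktoken.get_encoding("gpt2")  # standard GPT-2 BPE tokenizer

# Generate a short continuation; garbled output would point to a weight-loading problem
token_ids = generate_text_simple(
    model=model,
    idx=text_to_token_ids("Every effort moves you", tokenizer),
    max_new_tokens=15,
    context_size=BASE_CONFIG["context_length"],
)
print(token_ids_to_text(token_ids, tokenizer))
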
@@ -951,7 +960,7 @@
"id": "69162550-6a02-4ece-8db1-06c71d61946f",
"metadata": {},
"source": [
"- Before we finetune the model as a classifier, let's see if the model can perhaps already classify spam messages via prompting"
"- 在我们将模型微调为分类器之前,让我们看看模型是否已经可以通过提示对垃圾邮件进行分类。"
]
},
{
@@ -991,8 +1000,8 @@
"id": "1ce39ed0-2c77-410d-8392-dd15d4b22016",
"metadata": {},
"source": [
"- As we can see, the model is not very good at following instructions\n",
"- This is expected, since it has only been pretrained and not instruction-finetuned (instruction finetuning will be covered in the next chapter)"
"- 正如我们所看到的,该模型不太擅长遵循指令\n",
"- 这是预料之中的,因为它只经过了预训练,没有进行指令微调(指令微调将在下一章中介绍)"
]
},
{
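
The two markdown cells above describe trying zero-shot prompting before any finetuning. A hedged sketch of such a prompting check, reusing the generate_text_simple, text_to_token_ids, and token_ids_to_text helpers and the tokenizer from the sketch above (the prompt wording is illustrative, not necessarily the exact prompt used in the chapter):

prompt = (
    "Is the following text 'spam'? Answer with 'yes' or 'no': "
    "'You are a winner you have been specially selected to receive $1000 cash.'"
)

token_ids = generate_text_simple(
    model=model,
    idx=text_to_token_ids(prompt, tokenizer),
    max_new_tokens=23,
    context_size=BASE_CONFIG["context_length"],
)
# A merely pretrained GPT-2 typically rambles or echoes the prompt here instead of answering "yes" or "no"
print(token_ids_to_text(token_ids, tokenizer))
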
17 changes: 13 additions & 4 deletions ch06/01_main-chapter-code/exercise-solutions.ipynb
@@ -5,10 +5,19 @@
"id": "ba450fb1-8a26-4894-ab7a-5d7bfefe90ce",
"metadata": {},
"source": [
"<font size=\"1\">\n",
"Supplementary code for \"Build a Large Language Model From Scratch\": <a href=\"https://www.manning.com/books/build-a-large-language-model-from-scratch\">https://www.manning.com/books/build-a-large-language-model-from-scratch</a> by <a href=\"https://sebastianraschka.com\">Sebastian Raschka</a><br>\n",
"Code repository: <a href=\"https://github.com/rasbt/LLMs-from-scratch\">https://github.com/rasbt/LLMs-from-scratch</a>\n",
"</font>"
"<table style=\"width:100%\">\n",
"<tr>\n",
"<td style=\"vertical-align:middle; text-align:left;\">\n",
"<font size=\"2\">\n",
"Supplementary code for the <a href=\"http://mng.bz/orYv\">Build a Large Language Model From Scratch</a> book by <a href=\"https://sebastianraschka.com\">Sebastian Raschka</a><br>\n",
"<br>Code repository: <a href=\"https://github.com/rasbt/LLMs-from-scratch\">https://github.com/rasbt/LLMs-from-scratch</a>\n",
"</font>\n",
"</td>\n",
"<td style=\"vertical-align:middle; text-align:left;\">\n",
"<a href=\"http://mng.bz/orYv\"><img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/cover-small.webp\" width=\"100px\"></a>\n",
"</td>\n",
"</tr>\n",
"</table>"
]
},
{
44 changes: 33 additions & 11 deletions ch06/01_main-chapter-code/gpt-class-finetune.py
@@ -21,15 +21,34 @@
from previous_chapters import GPTModel, load_weights_into_gpt


def download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path):
def download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path, test_mode=False):
    if data_file_path.exists():
        print(f"{data_file_path} already exists. Skipping download and extraction.")
        return

    # Downloading the file
    with urllib.request.urlopen(url) as response:
        with open(zip_path, "wb") as out_file:
            out_file.write(response.read())
    if test_mode:  # Try multiple times since CI sometimes has connectivity issues
        max_retries = 5
        delay = 5  # delay between retries in seconds
        for attempt in range(max_retries):
            try:
                # Downloading the file
                with urllib.request.urlopen(url, timeout=10) as response:
                    with open(zip_path, "wb") as out_file:
                        out_file.write(response.read())
                break  # if download is successful, break out of the loop
            except urllib.error.URLError as e:
                print(f"Attempt {attempt + 1} failed: {e}")
                if attempt < max_retries - 1:
                    time.sleep(delay)  # wait before retrying
                else:
                    print("Failed to download file after several attempts.")
                    return  # exit if all retries fail

    else:  # Code as it appears in the chapter
        # Downloading the file
        with urllib.request.urlopen(url) as response:
            with open(zip_path, "wb") as out_file:
                out_file.write(response.read())

    # Unzipping the file
    with zipfile.ZipFile(zip_path, "r") as zip_ref:
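
For orientation, a minimal, hedged usage sketch of the updated helper; the URL and paths below are illustrative placeholders standing in for the url, zip_path, extracted_path, and data_file_path variables the script defines further down, and the retry branch additionally relies on the time module being imported at the top of the script:

from pathlib import Path

url = "https://example.com/sms_spam_collection.zip"  # placeholder; the script uses the real dataset URL
zip_path = "sms_spam_collection.zip"
extracted_path = "sms_spam_collection"
data_file_path = Path(extracted_path) / "SMSSpamCollection.tsv"

# test_mode=True retries a flaky download up to max_retries times (useful in CI);
# the default test_mode=False keeps the single-attempt behavior from the chapter.
download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path, test_mode=True)
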
@@ -238,6 +257,7 @@ def plot_values(epochs_seen, examples_seen, train_values, val_values, label="los
    )
    parser.add_argument(
        "--test_mode",
        default=False,
        action="store_true",
        help=("This flag runs the model in test mode for internal testing purposes. "
              "Otherwise, it runs the model as it is used in the chapter (recommended).")
@@ -253,7 +273,7 @@ def plot_values(epochs_seen, examples_seen, train_values, val_values, label="los
    extracted_path = "sms_spam_collection"
    data_file_path = Path(extracted_path) / "SMSSpamCollection.tsv"

    download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path)
    download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path, test_mode=args.test_mode)
    df = pd.read_csv(data_file_path, sep="\t", header=None, names=["Label", "Text"])
    balanced_df = create_balanced_dataset(df)
    balanced_df["Label"] = balanced_df["Label"].map({"ham": 0, "spam": 1})
@@ -330,9 +350,7 @@ def plot_values(epochs_seen, examples_seen, train_values, val_values, label="los
        }
        model = GPTModel(BASE_CONFIG)
        model.eval()

        device = "cpu"
        model.to(device)

    # Code as it is used in the main chapter
    else:
@@ -355,15 +373,18 @@ def plot_values(epochs_seen, examples_seen, train_values, val_values, label="los

        BASE_CONFIG.update(model_configs[CHOOSE_MODEL])

        assert train_dataset.max_length <= BASE_CONFIG["context_length"], (
            f"Dataset length {train_dataset.max_length} exceeds model's context "
            f"length {BASE_CONFIG['context_length']}. Reinitialize data sets with "
            f"`max_length={BASE_CONFIG['context_length']}`"
        )

        model_size = CHOOSE_MODEL.split(" ")[-1].lstrip("(").rstrip(")")
        settings, params = download_and_load_gpt2(model_size=model_size, models_dir="gpt2")

        model = GPTModel(BASE_CONFIG)
        load_weights_into_gpt(model, params)
        model.eval()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    ########################################
    # Modify pretrained model
@@ -376,6 +397,7 @@ def plot_values(epochs_seen, examples_seen, train_values, val_values, label="los

    num_classes = 2
    model.out_head = torch.nn.Linear(in_features=BASE_CONFIG["emb_dim"], out_features=num_classes)
    model.to(device)

    for param in model.trf_blocks[-1].parameters():
        param.requires_grad = True
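
The lines above swap the language-model output head for a two-class layer and keep the last transformer block trainable. As a hedged reminder of how such a head is then used for prediction, classifying from the logits of the last token (model and device as set up above; the token IDs are made-up examples), a minimal sketch:

import torch

example_inputs = torch.tensor([[5211, 345, 423, 640]])  # made-up token IDs for one short message

model.eval()
with torch.no_grad():
    logits = model(example_inputs.to(device))[:, -1, :]  # last-token logits, shape (batch_size, 2)

predicted = torch.argmax(logits, dim=-1).item()  # 0 = "ham", 1 = "spam", matching the label mapping above
print("spam" if predicted == 1 else "ham")
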
2 changes: 1 addition & 1 deletion ch06/01_main-chapter-code/gpt_download.py
@@ -96,4 +96,4 @@ def load_gpt2_params_from_tf_ckpt(ckpt_path, settings):
        last_key = variable_name_parts[-1]
        target_dict[last_key] = variable_array

    return params
    return params
