[WIP] [TT-Train] TTNN Training #16617

Draft · wants to merge 4 commits into base: main
27 changes: 27 additions & 0 deletions tt-train/configs/training_ttnn_gpt2s.yaml
@@ -0,0 +1,27 @@
training_config:
  project_name: "tt_train_nano_gpt"
  seed: 5489
  model_save_interval: 100
  model_path: "ttnn_gpt2s_bpe.msgpack"
  batch_size: 16
  num_epochs: 1
  max_steps: 5000 # 2.5B tokens used for training
  learning_rate: 0.0003
  weight_decay: 0.01
  use_moreh_adamw: true
  use_kahan_summation: false
  gradient_accumulation_steps: 8
  tokenizer_path: "data/train_ttnn/tokenizer.json"
  data_path: "data/train_ttnn/data.txt"
  scheduler_type: "warmup_linear"
  # tokenizer_type: bpe
  transformer_config:
    runner_type: memory_efficient
    num_heads: 12
    embedding_dim: 768
    dropout_prob: 0.2
    num_blocks: 12
    vocab_size: 224
    max_sequence_length: 1024
    experimental:
      use_composite_layernorm: false
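
For reference, with this config each optimizer step consumes batch_size × gradient_accumulation_steps × max_sequence_length = 16 × 8 × 1024 = 131,072 tokens, i.e. an effective batch of 128 sequences per optimizer step.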
40 changes: 23 additions & 17 deletions tt-train/sources/examples/nano_gpt/main.cpp
@@ -25,6 +25,10 @@
#include "ttnn_fixed/trivial_ttnn_ops.hpp"
#include "utils.hpp"

namespace {
constexpr auto gpt2_tokenizer_file_name = "/gpt2-tokenizer.json";
}

/* WANDB blocks this signal.
Control+C didn't work.
*/
@@ -47,7 +51,7 @@ using DataLoader = ttml::datasets::DataLoader<
uint32_t sample(std::span<const float> log_softmax) {
auto probabilities_vector = std::vector<float>(log_softmax.size());
std::transform(log_softmax.begin(), log_softmax.end(), probabilities_vector.begin(), [](float value) {
return std::exp(value);
return std::exp(value / 0.8F);
});
auto distribution = std::discrete_distribution<uint32_t>(probabilities_vector.begin(), probabilities_vector.end());
return distribution(ttml::autograd::ctx().get_generator());
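
The added `/ 0.8F` applies temperature scaling: the inputs are log-probabilities, so exp(log p / T) ∝ p^(1/T), and T = 0.8 < 1 sharpens the distribution toward high-probability tokens. A minimal sketch of the same idea with the hard-coded constant lifted into a parameter (`sample_with_temperature` is a hypothetical generalization, not part of this PR):

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <random>
    #include <span>
    #include <vector>

    // Sample a token index from log-softmax output with a temperature knob.
    // T < 1 sharpens the distribution, T > 1 flattens it, T = 1 leaves it unchanged.
    uint32_t sample_with_temperature(std::span<const float> log_softmax, float temperature, std::mt19937 &gen) {
        std::vector<float> weights(log_softmax.size());
        std::transform(log_softmax.begin(), log_softmax.end(), weights.begin(), [temperature](float log_p) {
            return std::exp(log_p / temperature);
        });
        std::discrete_distribution<uint32_t> distribution(weights.begin(), weights.end());
        return distribution(gen);
    }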
@@ -76,7 +80,7 @@ void generate(

auto pad_token_id = 0U;

auto vocab_size = tokenizer.get_vocab_size();
auto vocab_size = round_up_to_tile(tokenizer.get_vocab_size());

std::vector<float> mask;
mask.reserve(static_cast<size_t>(max_sequence_length * max_sequence_length * num_heads));
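
`round_up_to_tile` is not shown in this diff; assuming it pads the vocabulary size up to the 32-element tile boundary (matching the `pad_tile` helper added in tt_tensor_utils.cpp below), a minimal sketch:

    // Round val up to the next multiple of the 32-wide tile dimension.
    constexpr uint32_t round_up_to_tile(uint32_t val, uint32_t tile_dim = 32) {
        return (val + tile_dim - 1) / tile_dim * tile_dim;
    }

This keeps the logits tile-aligned: the config's vocab_size of 224 is already a multiple of 32, while e.g. a GPT-2 vocabulary of 50257 would round up to 50272.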
@@ -146,7 +150,7 @@ struct TrainingConfig {
std::string data_path;
std::string tokenizer_type = "char";
std::string scheduler_type = "identity";

std::string tokenizer_path = std::string(DATA_FOLDER) + gpt2_tokenizer_file_name;
ttml::models::gpt2::TransformerConfig transformer_config;
};

@@ -169,7 +173,7 @@ TrainingConfig parse_config(const YAML::Node &yaml_config) {
config.data_path = training_config["data_path"].as<std::string>(std::string(DATA_FOLDER) + "/shakespeare.txt");
config.tokenizer_type = training_config["tokenizer_type"].as<std::string>(config.tokenizer_type);
config.scheduler_type = training_config["scheduler_type"].as<std::string>(config.scheduler_type);

config.tokenizer_path = training_config["tokenizer_path"].as<std::string>(config.tokenizer_path);
config.transformer_config = ttml::models::gpt2::read_config(training_config["transformer_config"]);
return config;
}
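
Each lookup here relies on yaml-cpp's `as<T>(fallback)` overload, which returns the fallback when the key is absent, so older configs without a tokenizer_path entry still resolve to the bundled GPT-2 tokenizer default.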
@@ -249,19 +253,21 @@ int main(int argc, char **argv) {
fmt::print("Seed {}\n", ttml::autograd::ctx().get_seed());
auto sequence_length = config.transformer_config.max_sequence_length;

auto create_dataset_and_tokenizer = [](const auto &text, const auto sequence_length, const auto &tokenizer_type) {
if (tokenizer_type == "char") {
return ttml::datasets::create_in_memory_token_dataset<ttml::tokenizers::CharTokenizer>(
text, sequence_length);
} else if (tokenizer_type == "bpe") {
return ttml::datasets::create_in_memory_token_dataset<ttml::tokenizers::BPETokenizer>(
text, sequence_length);
} else {
throw std::runtime_error("Unknown tokenizer type: " + tokenizer_type);
}
};
auto create_dataset_and_tokenizer =
[](const auto &text, const auto sequence_length, const auto &tokenizer_path, const auto &tokenizer_type) {
if (tokenizer_type == "char") {
return ttml::datasets::create_in_memory_token_dataset<ttml::tokenizers::CharTokenizer>(
text, sequence_length);
} else if (tokenizer_type == "bpe") {
return ttml::datasets::create_in_memory_token_dataset<ttml::tokenizers::BPETokenizer>(
text, sequence_length, tokenizer_path);
} else {
throw std::runtime_error("Unknown tokenizer type: " + tokenizer_type);
}
};

auto [dataset, tokenizer] = create_dataset_and_tokenizer(text, sequence_length, config.tokenizer_type);
auto [dataset, tokenizer] =
create_dataset_and_tokenizer(text, sequence_length, config.tokenizer_path, config.tokenizer_type);
fmt::print("Dataset size: {}\n", dataset.get_size());
fmt::print("Vocab size: {}\n", tokenizer->get_vocab_size());
fmt::print("Tokenizer type: {}\n", config.tokenizer_type);
@@ -393,7 +399,7 @@ int main(int argc, char **argv) {
fmt::print("Step: {}, Loss: {}\n", global_step, gradient_accumulator_helper.average_loss());
loss_meter.update(gradient_accumulator_helper.average_loss());

if (enable_wandb && global_step % 10 == 0) {
if (enable_wandb && global_step % 1 == 0) {
wandbcpp::log(
{{"Step", (int)global_step},
{"Samples", (int)get_samples_count(global_step)},
18 changes: 8 additions & 10 deletions tt-train/sources/ttml/core/tt_tensor_utils.cpp
@@ -202,19 +202,17 @@ tt::tt_metal::Tensor from_vector<float, DataType::BFLOAT16>(
// remove possible paddings from the shape (it conflicts with ROW MAJOR)
auto output = tt::tt_metal::Tensor(OwnedStorage{owned_buffer}, logical_shape, data_type, Layout::ROW_MAJOR);

const size_t MAX_TILE_DIMENSION = 16384;
const size_t MAX_TILE_DIMENSION = 16384 * 4;
// Temporary workaround for the issue with tilize for large size
// https://github.com/tenstorrent/tt-metal/issues/15950
if (logical_shape[-1] >= MAX_TILE_DIMENSION && layout == Layout::TILE) {
output = ttnn::to_layout(output, Layout::TILE, std::nullopt, output_mem_config, device);
output = ttnn::to_device(output, device, output_mem_config);
} else {
output = ttnn::to_device(output, device, output_mem_config);
if (layout == Layout::TILE) {
output = ttnn::tilize_with_zero_padding(output, output_mem_config, std::nullopt, /* multicore */ true);
}
}
auto pad_tile = [](uint32_t val) { return (val + 32 - 1) / 32 * 32; };
auto padded_shape = pad_tile(logical_shape[-1]) * pad_tile(logical_shape[-2]);
bool multicore = padded_shape <= MAX_TILE_DIMENSION;
output = ttnn::to_device(output, device, output_mem_config);

if (layout == Layout::TILE) {
output = ttnn::tilize_with_zero_padding(output, output_mem_config, std::nullopt, /* multicore */ multicore);
}
return output;
}
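
The rewritten tail replaces the special-cased layout ordering with one decision: `padded_shape` is the tile-padded area of the last two dimensions, and multicore tilize is used only when that area fits within MAX_TILE_DIMENSION (now 16384 × 4 = 65,536). For example, a tensor whose last two dims are [32, 1024] pads to 32 × 1024 = 32,768 and tilizes multicore, while [1024, 768] pads to 786,432 and falls back to single-core, per the workaround for issue #15950.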

10 changes: 4 additions & 6 deletions tt-train/sources/ttml/datasets/utils.cpp
@@ -9,14 +9,12 @@
#include "tokenizers/char_tokenizer_trainer.hpp"
#include "tokenizers/tokenizer_base.hpp"

namespace {
constexpr auto gpt2_tokenizer_file_name = "/gpt2-tokenizer.json";
}
namespace ttml::datasets {

template <>
std::tuple<InMemoryTokenDataset, std::unique_ptr<tokenizers::TokenizerBase>>
create_in_memory_token_dataset<tokenizers::CharTokenizer>(const std::string &text, uint32_t seq_length) {
create_in_memory_token_dataset<tokenizers::CharTokenizer>(
const std::string &text, uint32_t seq_length, [[maybe_unused]] const std::string &json_file_path) {
std::unique_ptr<tokenizers::TokenizerBase> tokenizer = tokenizers::CharTokenizerTrainer::train(text);

std::vector<uint32_t> tokenized_text = tokenizer->encode(text);
@@ -26,8 +24,8 @@ create_in_memory_token_dataset<tokenizers::CharTokenizer>(const std::string &tex

template <>
std::tuple<InMemoryTokenDataset, std::unique_ptr<tokenizers::TokenizerBase>>
create_in_memory_token_dataset<tokenizers::BPETokenizer>(const std::string &text, uint32_t seq_length) {
auto json_file_path = std::string(TOKENIZERS_DATA_PATH) + gpt2_tokenizer_file_name;
create_in_memory_token_dataset<tokenizers::BPETokenizer>(
const std::string &text, uint32_t seq_length, const std::string &json_file_path) {
std::unique_ptr<tokenizers::TokenizerBase> tokenizer = std::make_unique<tokenizers::BPETokenizer>(json_file_path);

const std::vector<uint32_t> tokenized_text = tokenizer->encode(text);
3 changes: 2 additions & 1 deletion tt-train/sources/ttml/datasets/utils.hpp
@@ -6,6 +6,7 @@
#include <numeric>
#include <random>
#include <span>
#include <string>

#include "autograd/auto_context.hpp"
#include "dataset_subset.hpp"
@@ -16,7 +17,7 @@ namespace ttml::datasets {

template <typename Tokenizer>
std::tuple<InMemoryTokenDataset, std::unique_ptr<tokenizers::TokenizerBase>> create_in_memory_token_dataset(
const std::string& text, uint32_t seq_length);
const std::string& text, uint32_t seq_length, const std::string& json_file_path = "");
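
Because json_file_path is defaulted, existing call sites compile unchanged: the CharTokenizer specialization ignores the argument (it is marked [[maybe_unused]]), while the BPETokenizer specialization forwards it to the BPETokenizer constructor.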

template <typename DatasetType>
std::vector<DatasetSubset<DatasetType>> random_split(
8 changes: 6 additions & 2 deletions tt-train/tests/3rd_party/tokenizers_test.cpp
@@ -57,7 +57,11 @@ void test_tokenizer(std::unique_ptr<Tokenizer> tok, bool check_id_back = true) {
} // namespace

TEST(HuggingFaceTokenizer, ExampleUsage) {
auto blob = load_bytes_from_file(get_test_data_dir() + "/tokenizer.json");
auto blob = load_bytes_from_file(get_test_data_dir() + "/train_ttnn/tokenizer.json");
auto tok = Tokenizer::FromBlobJSON(blob);
test_tokenizer(std::move(tok), true);
// test_tokenizer(std::move(tok), true);
std::string prompt = "What is the capital of Canada?";
std::vector<int> ids = tok->Encode(prompt);
std::string decoded_prompt = tok->Decode(ids);
EXPECT_EQ(decoded_prompt, prompt);
}
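
The round-trip assertion is expected to hold because the tokenizer is byte-level with no normalizer (see tokenize_folder.py below): byte-level pre-tokenization is lossless, so Decode(Encode(s)) reproduces s exactly, provided the prompt's bytes are covered by the trained vocabulary.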
103 changes: 103 additions & 0 deletions tt-train/utils/tokenize_folder.py
@@ -0,0 +1,103 @@
#!/usr/bin/env python3

import os
import argparse
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import ByteLevel
from tokenizers.decoders import ByteLevel as ByteLevelDecoder


def gather_source_files(folder):
    """
    Recursively walk `folder`, yielding paths to files with
    extensions in ('.hpp', '.cpp', '.h', '.c').
    """
    valid_exts = {".hpp", ".cpp", ".h", ".c"}
    for root, _, files in os.walk(folder):
        for fname in files:
            _, ext = os.path.splitext(fname)
            if ext.lower() in valid_exts:
                yield os.path.join(root, fname)


def merge_into_one_file(file_paths, merged_file_path="merged.txt"):
    """
    Merges the content of all files in `file_paths` into a single text
    file `merged_file_path`, collapsing runs of whitespace within each
    line and stripping leading/trailing whitespace.
    """
    with open(merged_file_path, "w", encoding="utf-8") as writer:
        for path in file_paths:
            try:
                with open(path, "r", encoding="utf-8") as reader:
                    for line in reader:
                        # Collapse whitespace runs and strip the line's ends
                        line = " ".join(line.split())
                        writer.write(line + "\n")
            except Exception as e:
                print(f"Warning: Could not read file {path}: {e}")
    return merged_file_path


def train_bpe_tokenizer(text_file, output_tokenizer="tokenizer.json", vocab_size=32000):
    """
    Trains a byte-level BPE tokenizer on `text_file` and saves it to `output_tokenizer`.
    """
    tokenizer = Tokenizer(BPE(unk_token=None))

    # 1. No normalizer: GPT-2-style BPE works at the byte level, so we skip
    #    lowercasing and accent-stripping.

    # 2. Byte-level pre-tokenizer + byte-level decoder
    tokenizer.pre_tokenizer = ByteLevel()
    tokenizer.decoder = ByteLevelDecoder()

    # 3. Set up the BPE trainer with the desired vocab size + special tokens
    trainer = BpeTrainer(vocab_size=vocab_size, special_tokens=["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"])

    # 4. Train on the merged text file
    tokenizer.train([text_file], trainer)

    # 5. Save the tokenizer as a single JSON
    tokenizer.save(output_tokenizer)
    print(f"Saved tokenizer to {output_tokenizer}")


def main():
    parser = argparse.ArgumentParser(
        description="Recursively gather .hpp/.cpp/.h/.c files, merge them, then train a BPE tokenizer."
    )
    parser.add_argument("--folder", type=str, required=True, help="Root folder to recursively find source files.")
    parser.add_argument(
        "--merged_txt",
        type=str,
        default="merged.txt",
        help="Path to the merged output text file (default: merged.txt).",
    )
    parser.add_argument(
        "--tokenizer_output",
        type=str,
        default="tokenizer.json",
        help="Path to save the trained tokenizer JSON (default: tokenizer.json).",
    )
    parser.add_argument("--vocab_size", type=int, default=32000, help="Desired vocabulary size (default: 32000).")
    args = parser.parse_args()

    # 1. Gather source files
    file_paths = list(gather_source_files(args.folder))
    if not file_paths:
        print("No .hpp, .cpp, .h, or .c files found. Exiting.")
        return

    # 2. Merge into one file
    merged_path = merge_into_one_file(file_paths, args.merged_txt)

    # 3. Train tokenizer
    train_bpe_tokenizer(merged_path, args.tokenizer_output, args.vocab_size)


if __name__ == "__main__":
    main()
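
Presumably the assets referenced by training_ttnn_gpt2s.yaml were produced along these lines (the exact invocation is not part of this PR; paths reconstructed from the config):

    python tt-train/utils/tokenize_folder.py --folder tt-train/sources --merged_txt data/train_ttnn/data.txt --tokenizer_output data/train_ttnn/tokenizer.json --vocab_size 224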