Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,283 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Fine-tune Qwen3-1.7B for Chess Move Prediction\n",
"\n",
"## Introduction\n",
"\n",
"In this notebook, we showcase how to fine-tune the Qwen3-1.7B model on AWS Trainium using the Hugging Face Optimum Neuron library.\n",
"The goal of this task is chess move prediction — training the model to analyze chess positions in FEN format and select the best moves.\n",
"\n",
"We will fine-tune the model using `optimum.neuron`, save the trained checkpoint, and then deploy it for inference with Optimum-Neuron[vllm], enabling high-performance, low-latency chess move prediction.\n",
"\n",
"By the end of this notebook, you'll have a fine-tuned, Trainium-optimized Qwen3 model ready for deployment and real-time inference. This workflow demonstrates how to leverage the Optimum Neuron toolchain to efficiently train and serve large language models on AWS Neuron devices.\n",
"\n",
"For this module, you will be using the [aicrowd/ChessExplained](https://huggingface.co/datasets/aicrowd/ChessExplained) dataset which consists of thousands of chess positions with expert analysis and move selections.\n",
"\n",
"## About the Dataset\n",
"\n",
"The dataset contains chess positions in FEN (Forsyth-Edwards Notation) format along with:\n",
"- Visual board representations\n",
"- List of legal moves\n",
"- Expert reasoning (in `<think>` tags)\n",
"- Best move selection (in `<uci_move>` tags)\n",
"\n",
"**Dataset example:**\n",
"\n",
"*Position (FEN):* `rnbq1rk1/ppp1bpp1/4pn1p/3p4/2PP4/2N1PN2/PP1B1PPP/R2QKB1R b KQ - 0 7`\n",
"\n",
"*Legal moves:* `['g8h8', 'g8h7', 'f8e8', 'd8e8', 'c7c5', 'b7b6', 'a7a6', ...]`\n",
"\n",
"*Expert analysis:* \n",
"```\n",
"<think>\n",
"After Pawn moves to c5, this causes Black to attacks the pawn on d4. So c5 is the most logical. Position is drawish.\n",
"</think>\n",
"\n",
"<uci_move>c7c5</uci_move>\n",
"```\n",
"\n",
"By fine-tuning the model over several thousand of these chess examples, the model will learn to analyze positions and generate both reasoning and optimal moves.\n",
"\n",
"This chess move prediction use case was selected so you can successfully fine-tune your model in a reasonably short amount of time (~25 minutes) which is appropriate for this workshop. The same techniques can be applied to more complex reasoning tasks such as strategic game playing, multi-step planning, and expert decision-making.\n",
"\n",
"## Install requirements\n",
"This notebook uses [Hugging Face Optimum Neuron](https://github.com/huggingface/optimum-neuron) which works like an interface between the Hugging Face Transformers library and AWS Accelerators including AWS Trainium and AWS Inferentia. We will also install some other libraries like peft, trl etc.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%cd /home/ubuntu/environment/FineTuning/HuggingFaceExample/01_finetuning/assets\n",
"%pip install -r requirements.txt\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Fine-tuning\n",
"\n",
"In this section, we fine-tune the Qwen3-1.7B model on the chess move prediction task using Hugging Face Optimum Neuron. Here are the parameters we are going to pass - \n",
"\n",
"1. `--nnodes`:\tNumber of nodes (1 = single node)\n",
"2. `--nproc_per_node`: \tProcesses per node (usually equals number of devices).\n",
"3. `--model_id, --tokenizer_id`:\tModel and tokenizer identifiers (from Hugging Face or local path).\n",
"4. `--output_dir`:\tDirectory for saving checkpoints and logs.\n",
"5. `--bf16`:\tEnables bfloat16 precision for faster, memory-efficient training.\n",
"6. `--gradient_checkpointing`:\tSaves memory by recomputing activations during backprop.\n",
"7. `--gradient_accumulation_steps`:\tSteps to accumulate gradients before optimizer update.\n",
"8. `--learning_rate`:\tInitial training learning rate.\n",
"9. `--max_steps`:\tTotal number of training steps.\n",
"10. `--per_device_train_batch_size`:\tBatch size per device.\n",
"11. `--tensor_parallel_size`:\tNumber of devices for tensor parallelism.\n",
"12. `--lora_r, --lora_alpha, --lora_dropout`:\tLoRA hyperparameters — rank, scaling, and dropout rate.\n",
"13. `--dataloader_drop_last`:\tDrops last incomplete batch.\n",
"14. `--disable_tqdm`: Disables progress bar.\n",
"15. `--logging_steps`:\tLog interval (in steps).\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!torchrun \\\n",
" --nnodes 1 \\\n",
" --nproc_per_node 2 \\\n",
" finetune_chess_model.py \\\n",
" --model_id Qwen/Qwen3-1.7B \\\n",
" --tokenizer_id Qwen/Qwen3-1.7B \\\n",
" --output_dir ~/environment/ml/qwen-chess \\\n",
" --bf16 True \\\n",
" --gradient_checkpointing True \\\n",
" --gradient_accumulation_steps 1 \\\n",
" --learning_rate 5e-5 \\\n",
" --max_steps 1000 \\\n",
" --per_device_train_batch_size 2 \\\n",
" --tensor_parallel_size 2 \\\n",
" --lora_r 16 \\\n",
" --lora_alpha 32 \\\n",
" --lora_dropout 0.05 \\\n",
" --dataloader_drop_last True \\\n",
" --disable_tqdm True \\\n",
" --logging_steps 10\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Compilation\n",
"\n",
"After completing the fine-tuning process, the next step is to compile the trained model for AWS Trainium inference using the Hugging Face Optimum Neuron toolchain.\n",
"Neuron compilation optimizes the model graph and converts it into a Neuron Executable File Format (NEFF), enabling efficient execution on NeuronCores.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!optimum-cli export neuron \\\n",
" --model /home/ubuntu/environment/ml/qwen-chess/merged_model \\\n",
" --task text-generation \\\n",
" --sequence_length 2048 \\\n",
" --batch_size 1 \\\n",
" /home/ubuntu/environment/ml/qwen-chess/compiled_model\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Inference\n",
"\n",
"We will install the Optimum Neuron vllm library. Then, run inference using the compiled model.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%pip install optimum-neuron[vllm]\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"from vllm import LLM, SamplingParams\n",
"\n",
"llm = LLM(\n",
" model=\"/home/ubuntu/environment/ml/qwen-chess/compiled_model\", #local compiled model\n",
" max_num_seqs=1,\n",
" max_model_len=2048,\n",
" device=\"neuron\",\n",
" tensor_parallel_size=2,\n",
" override_neuron_config={})\n",
"\n",
"example1=\"\"\"\n",
"<|im_start|>user\n",
"You are an expert chess player looking at the following position in FEN format:\n",
"\n",
"rnbq1rk1/ppp1bpp1/4pn1p/3p4/2PP4/2N1PN2/PP1B1PPP/R2QKB1R b KQ - 0 7\n",
"\n",
"Briefly, FEN describes chess pieces by single letters [PNBRKQ] for white and [pnbrkq] for black. The pieces found in each rank are specified, starting at the top of the board (a8..h8) and describing all eight ranks.\n",
"\n",
"Here is an additional visualization of the board (♔♕♖♗♘♙ = White pieces, ♚♛♜♝♞♟ = Black pieces):\n",
"\n",
"a b c d e f g h\n",
"+---------------+\n",
"8 | ♜ ♞ ♝ ♛ · ♜ ♚ · | 8\n",
"7 | ♟ ♟ ♟ · ♝ ♟ ♟ · | 7\n",
"6 | · · · · ♟ ♞ · ♟ | 6\n",
"5 | · · · ♟ · · · · | 5\n",
"4 | · · ♙ ♙ · · · · | 4\n",
"3 | · · ♘ · ♙ ♘ · · | 3\n",
"2 | ♙ ♙ · ♗ · ♙ ♙ ♙ | 2\n",
"1 | ♖ · · ♕ ♔ ♗ · ♖ | 1\n",
"+---------------+\n",
"a b c d e f g h\n",
"\n",
"The current side to move is black.\n",
"The possible legal moves for the side to move are: ['g8h8', 'g8h7', 'f8e8', 'd8e8', 'c7c5', 'b7b6', 'a7a6', 'h6h5', 'e6e5', 'g7g5', 'c7c6', 'b7b5', 'a7a5'].\n",
"\n",
"Your task is to select the best move for the side to move. Output your thinking in <think> tags and the move in <uci_move> tags.<|im_end|>\n",
"<|im_start|>assistant\n",
"\"\"\"\n",
"\n",
"example2=\"\"\"\n",
"<|im_start|>user\n",
"You are an expert chess player. Analyze this position in FEN format:\n",
"\n",
"r1bqkbnr/pppp1ppp/2n5/4p3/4P3/5N2/PPPP1PPP/RNBQKB1R w KQkq - 2 3\n",
"\n",
"Here is the board visualization:\n",
"\n",
"a b c d e f g h\n",
"+---------------+\n",
"8 | ♜ · ♝ ♛ ♚ ♝ ♞ ♜ | 8\n",
"7 | ♟ ♟ ♟ ♟ · ♟ ♟ ♟ | 7\n",
"6 | · · ♞ · · · · · | 6\n",
"5 | · · · · ♟ · · · | 5\n",
"4 | · · · · ♙ · · · | 4\n",
"3 | · · · · · ♘ · · | 3\n",
"2 | ♙ ♙ ♙ ♙ · ♙ ♙ ♙ | 2\n",
"1 | ♖ ♘ ♗ ♕ ♔ ♗ · ♖ | 1\n",
"+---------------+\n",
"a b c d e f g h\n",
"\n",
"The current side to move is white.\n",
"Select the best move from: ['d2d4', 'f1c4', 'f1b5', 'b1c3', 'd2d3']\n",
"\n",
"Output your analysis in <think> tags and your move choice in <uci_move> tags.<|im_end|>\n",
"<|im_start|>assistant\n",
"\"\"\"\n",
"\n",
"example3=\"\"\"\n",
"<|im_start|>user\n",
"Analyze this chess position in FEN format:\n",
"\n",
"r2qkb1r/ppp2ppp/2n2n2/3pp1B1/1b1PP3/2N2N2/PPP2PPP/R2QKB1R w KQkq - 0 6\n",
"\n",
"Board visualization:\n",
"\n",
"a b c d e f g h\n",
"+---------------+\n",
"8 | ♜ · · ♛ ♚ ♝ · ♜ | 8\n",
"7 | ♟ ♟ ♟ · · ♟ ♟ ♟ | 7\n",
"6 | · · ♞ · · ♞ · · | 6\n",
"5 | · · · ♟ ♟ · ♗ · | 5\n",
"4 | · ♝ · ♙ ♙ · · · | 4\n",
"3 | · · ♘ · · ♘ · · | 3\n",
"2 | ♙ ♙ ♙ · · ♙ ♙ ♙ | 2\n",
"1 | ♖ · · ♕ ♔ ♗ · ♖ | 1\n",
"+---------------+\n",
"a b c d e f g h\n",
"\n",
"White to move. Legal moves: ['g5f6', 'g5e7', 'g5h6', 'g5d2', 'c3b5', 'c3d5', 'f3d4', 'f3e5', 'd1d3', 'd1d2', 'e1d2']\n",
"\n",
"Provide your reasoning and best move.<|im_end|>\n",
"<|im_start|>assistant\n",
"\"\"\"\n",
"\n",
"prompts = [\n",
" example1,\n",
" example2,\n",
" example3\n",
"]\n",
"\n",
"sampling_params = SamplingParams(max_tokens=2048, temperature=0.8)\n",
"outputs = llm.generate(prompts, sampling_params)\n",
"\n",
"print(\"#########################################################\")\n",
"\n",
"for output in outputs:\n",
" prompt = output.prompt\n",
" generated_text = output.outputs[0].text\n",
" print(f\"Prompt: {prompt!r}, \\n\\n Generated text: {generated_text!r} \\n\")\n"
]
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Loading