diff --git a/docs/online-inference-with-maxtext-engine.md b/docs/online-inference-with-maxtext-engine.md
index 5fa2d00b..90044b9e 100644
--- a/docs/online-inference-with-maxtext-engine.md
+++ b/docs/online-inference-with-maxtext-engine.md
@@ -108,7 +108,7 @@ export ICI_AUTOREGRESSIVE_PARALLELISM=-1
 export ICI_TENSOR_PARALLELISM=1
 export SCAN_LAYERS=false
 export WEIGHT_DTYPE=bfloat16
-export PER_DEVICE_BATCH_SIZE=4
+export PER_DEVICE_BATCH_SIZE=11
 ```
 
 #### Create Llama2-7b environment variables for server flags
@@ -126,7 +126,7 @@ export ICI_AUTOREGRESSIVE_PARALLELISM=-1
 export ICI_TENSOR_PARALLELISM=1
 export SCAN_LAYERS=false
 export WEIGHT_DTYPE=bfloat16
-export PER_DEVICE_BATCH_SIZE=4
+export PER_DEVICE_BATCH_SIZE=11
 ```
 
 #### Create Llama2-13b environment variables for server flags
@@ -146,7 +146,7 @@ export ICI_AUTOREGRESSIVE_PARALLELISM=-1
 export ICI_TENSOR_PARALLELISM=1
 export SCAN_LAYERS=false
 export WEIGHT_DTYPE=bfloat16
-export PER_DEVICE_BATCH_SIZE=2
+export PER_DEVICE_BATCH_SIZE=4
 ```
 
 ### Run the following command to start the JetStream MaxText server
@@ -182,7 +182,7 @@ python MaxText/maxengine_server.py \
 * ici\_autoregressive\_parallelism: The number of shards for autoregressive parallelism
 * ici\_tensor\_parallelism: The number of shards for tensor parallelism
 * weight\_dtype: Weight data type (e.g. bfloat16)
-* scan\_layers: Scan layers boolean flag
+* scan\_layers: Scan layers boolean flag (set to `false` for inference)
 
 Note: these flags are from [MaxText config](https://github.com/google/maxtext/blob/f9e04cdc1eec74a0e648411857c09403c3358461/MaxText/configs/base.yml)
 
@@ -200,7 +200,7 @@ python JetStream/jetstream/tools/requester.py --tokenizer maxtext/assets/tokeniz
 The output will be similar to the following:
 
 ```bash
-Sending request to: dns:///[::1]:9000
+Sending request to: 0.0.0.0:9000
 Prompt: Today is a good day
 Response: to be a fan
 ```
@@ -253,7 +253,7 @@ wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/r
 # run benchmark with the downloaded dataset and the tokenizer in maxtext
 # You can control the qps by setting `--request-rate`, the default value is inf.
 python JetStream/benchmarks/benchmark_serving.py \
---tokenizer maxtext/assets/tokenizer.gemma \
+--tokenizer maxtext/assets/tokenizer.gemma \
 --num-prompts 1000 \
 --dataset sharegpt \
 --dataset-path ~/ShareGPT_V3_unfiltered_cleaned_split.json \
diff --git a/jetstream/tools/maxtext/model_ckpt_conversion.sh b/jetstream/tools/maxtext/model_ckpt_conversion.sh
index c81ac6a1..19a62b74 100644
--- a/jetstream/tools/maxtext/model_ckpt_conversion.sh
+++ b/jetstream/tools/maxtext/model_ckpt_conversion.sh
@@ -71,17 +71,17 @@ else
 fi
 echo "Written MaxText compatible checkpoint to ${MODEL_BUCKET}/${MODEL}/${MODEL_VARIATION}/${idx}"
 
-# We define `CONVERTED_CHECKPOINT` to refer to the checkpoint subdirectory.
-export CONVERTED_CHECKPOINT=${MODEL_BUCKET}/${MODEL}/${MODEL_VARIATION}/${idx}/0/items
+# We define `SCANNED_CKPT_PATH` to refer to the checkpoint subdirectory.
+export SCANNED_CKPT_PATH=${MODEL_BUCKET}/${MODEL}/${MODEL_VARIATION}/${idx}/0/items
 
 # Covert MaxText compatible checkpoints to unscanned checkpoints.
-# Note that the `CONVERTED_CHECKPOINT` is in a `scanned` format which is great for training but for efficient decoding performance we want the checkpoint in an `unscanned` format.
+# Note that the `SCANNED_CKPT_PATH` is in a `scanned` format which is great for training but for efficient decoding performance we want the checkpoint in an `unscanned` format.
 export RUN_NAME=${MODEL_NAME}_unscanned_chkpt_${idx}
 
 JAX_PLATFORMS=cpu python MaxText/generate_param_only_checkpoint.py \
 MaxText/configs/base.yml \
 base_output_directory=${BASE_OUTPUT_DIRECTORY} \
-load_parameters_path=${CONVERTED_CHECKPOINT} \
+load_parameters_path=${SCANNED_CKPT_PATH} \
 run_name=${RUN_NAME} \
 model_name=${MODEL_NAME} \
 force_unroll=true
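
For reference (not part of the patch above): a minimal sketch of how the unscanned checkpoint produced by `generate_param_only_checkpoint.py` is typically picked up afterwards. The `UNSCANNED_CKPT_PATH` variable name and the `checkpoints/0/items` suffix are assumptions based on MaxText's default output layout, not something this diff defines.

```bash
# Sketch only -- assumes the param-only checkpoint lands under
# ${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints/0/items (MaxText's default layout).
export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints/0/items

# This path is what the JetStream MaxText server consumes via load_parameters_path,
# together with the server flags documented in the docs hunk above.
echo "Unscanned checkpoint available at ${UNSCANNED_CKPT_PATH}"
```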