Conversation

@JingyaHuang (Contributor) commented Jul 4, 2022

What does this PR do?

Fixes onnxruntime issue #11279.

Context

Optimum users reported that mixed-precision training of GPT-2 with optimum.onnxruntime.ORTTrainer has been broken since transformers>4.16.0. After investigation, the breakage comes from the removal of the float() cast in the GPT-2 modeling code in PR #14321.

Reproduction

Run the Optimum onnxruntime training example run_glue.py with:

python run_glue.py \
    --model_name_or_path gpt2 \
    --task_name sst2 \
    --do_train \
    --do_eval \
    --fp16 \
    --output_dir /tmp/ort-gpt2-sst2/

Error Message

RuntimeError: /onnxruntime_src/orttraining/orttraining/python/orttraining_pybind_state.cc:752 
onnxruntime::python::addObjectMethodsForTraining(pybind11::module&, 
onnxruntime::python::ExecutionProviderRegistrationFn)::<lambda(onnxruntime::training::OrtModuleGraphBuilder*, const 
pybind11::bytes&, const onnxruntime::training::OrtModuleGraphBuilderConfiguration&)> [ONNXRuntimeError] : 1 : FAIL : 
Type Error: Type parameter (T) of Optype (Where) bound to different types (tensor(float) and tensor(float16) in node 
(Where_201).

As the error message indicates, the forward pass with the onnxruntime InferenceSession fails on a Where node in the graph, which corresponds to the torch.where call in the GPT-2 modeling code.

The problem is that, with the float() cast removed, the two value inputs of Where end up with different dtypes during fp16 training (one fp32, one fp16), which violates the ONNX operator definition and causes the failure.
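
For reference, here is a paraphrased sketch of the GPT2Attention._attn pattern that produces this Where node (names follow the modeling code, but this is an illustration, not the exact source):

import torch

def masked_scaled_attn(attn_weights, value, causal_mask):
    # Paraphrased sketch of GPT2Attention._attn around the failing node.
    attn_weights = attn_weights / (value.size(-1) ** 0.5)  # Python-float divisor
    mask_value = torch.finfo(attn_weights.dtype).min
    mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
    # The ONNX Where schema binds both value inputs to the same type parameter T,
    # so the exported node becomes invalid once one side is fp32 and the other fp16.
    return torch.where(causal_mask, attn_weights, mask_value)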

Who can review?

@michaelbenayoun, @patrickvonplaten, @LysandreJik

Fix

Ensure attn_weights and value have the same dtype in the exported ONNX IR.
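
One way to do this (a sketch of the approach, assuming the culprit is the attn_weights scaling line quoted later in this thread; not necessarily the exact merged diff) is to build the divisor as a tensor in attn_weights' dtype instead of a Python float, so no stray fp32 constant is baked into the exported graph:

import torch

def scale_attn_weights(attn_weights, value):
    # Sketch only: create a 0-d divisor tensor in attn_weights' dtype, so the
    # exported Div (and everything downstream, including Where) can stay in one
    # dtype after onnxruntime's mixed-precision cast pass.
    divisor = torch.full(
        [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
    )
    return attn_weights / divisor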

@HuggingFaceDocBuilderDev commented Jul 4, 2022

The documentation is not available anymore as the PR was closed or merged.

@michaelbenayoun (Member) left a comment

LGTM (once all the tests pass)

@ydshieh (Collaborator) commented Jul 7, 2022

Hello @JingyaHuang. Looking at these two lines:

mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
attn_weights = torch.where(causal_mask, attn_weights, mask_value)

mask_value is already of type attn_weights.dtype. Does the issue only occur with ONNX (i.e., does the model work when run in plain PyTorch with FP16)? This seems strange. Do you happen to know which argument gets fp32 and which one gets fp16?

@JingyaHuang (Contributor, Author) commented Jul 7, 2022

Hi @ydshieh, yes, this issue only occurs with ONNX. When exporting the ONNX IR, mask_value is exported as a constant initializer (the fp32 minimum) with dtype float32. Thus, during mixed-precision training with onnxruntime, attn_weights is in fp16 while mask_value, as a constant, stays fp32 -> two inputs with different dtypes -> training fails.
Here is the ONNX IR which illustrates what happened with the Where op:
[screenshot: exported ONNX graph around the Where node]

[EDIT] I made a mistake here: according to the training graph, mask_value is actually successfully cast to fp16, but attn_weights is not. See the locally exported IR below.
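
As a side note, one way to confirm this kind of mismatch on an exported graph is to dump the dtypes of the initializers feeding each Where node. A hypothetical inspection sketch (the file name is illustrative, not from this PR):

import onnx

# Load an exported (training) graph and report, for each Where node, which
# inputs are constant initializers and what their element types are.
model = onnx.load("gpt2_training.onnx")
init_dtypes = {
    init.name: onnx.TensorProto.DataType.Name(init.data_type)
    for init in model.graph.initializer
}
for node in model.graph.node:
    if node.op_type == "Where":
        print(node.name, [init_dtypes.get(name, "<activation>") for name in node.input])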

@JingyaHuang (Contributor, Author)

And if we run the model with the PyTorch backend, there is no problem with tracing or op definitions; it works fine.

@ydshieh (Collaborator) commented Jul 7, 2022

@JingyaHuang Thank you!

@JingyaHuang (Contributor, Author) commented Jul 7, 2022

Hi @ydshieh, I've just double-checked the debug export of the training ONNX graph. mask_value is in fact cast to fp16 before the Where node; it was attn_weights that stayed in fp32, and the fix inserts another Cast op to cast attn_weights from fp32 to fp16.

The IR below corresponds to this line:

attn_weights = attn_weights / (value.size(-1) ** 0.5)

The IR before the fix:
[screenshot: exported training graph before the fix]
The IR after the fix:
[screenshot: exported training graph after the fix]

So this is exactly what we want for fp16 training.

@JingyaHuang requested review from LysandreJik and patrickvonplaten and removed the request for LysandreJik and patrickvonplaten, July 13, 2022 16:24
@JingyaHuang (Contributor, Author)

Gently pinging @patrickvonplaten and @LysandreJik for a review.

@LysandreJik (Member) left a comment

Thanks, merging!

@LysandreJik merged commit 2844c5d into huggingface:main, Jul 26, 2022
@TXacs commented Jul 28, 2022

I got the same error, but I use torch.fx and AMP to train the GPT-2 model. I fixed the error by adding attn_weights.to(attn_weights.dtype) inside torch.where. I don't know why this fixes it, but it does.
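
A sketch of what that workaround looks like (hedged; the exact line and placement depend on the transformers version being traced):

import torch

def masked_where_workaround(causal_mask, attn_weights, mask_value):
    # Workaround as described above (sketch only): the explicit .to() on
    # attn_weights is a no-op in eager mode; why it unblocks the traced
    # fp16 graph is not explained in the comment above.
    return torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)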

@ydshieh (Collaborator) commented Jul 28, 2022

Hi @TXacs, which version did you try? Could you try installing the latest version from main:

pip install git+https://github.com/huggingface/accelerate

and see if you still have the issue (without your fix)? Thanks!

Development

Successfully merging this pull request may close these issues.

Type Error when training Hugging Face Transformers GPT2 with fp16 enabled
