
Conversation

@JingyaHuang
Contributor

What does this PR do?

With PR #20061, tracing fails during mixed-precision training because the inputs of a where node no longer share the same dtype, which makes the exported ONNX model invalid when it is reused for inference.

The failing node (a dtype-consistent sketch follows the error message below):

attn_weights = torch.where(causal_mask, attn_weights, mask_value)

Error message:

======================================================================
ERROR: test_ort_trainer (__main__.TestORTTrainer) (model_name='gpt2', dataset_name='sst2', inference_with_ort=False)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "test_onnxruntime_train.py", line 131, in test_ort_trainer
    train_result = trainer.train()
  File "/workspace/optimum/onnxruntime/trainer.py", line 349, in train
    return inner_training_loop(
  File "/workspace/optimum/onnxruntime/trainer.py", line 615, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/usr/local/lib/python3.8/dist-packages/transformers/trainer.py", line 2523, in training_step
    loss = self.compute_loss(model, inputs)
  File "/usr/local/lib/python3.8/dist-packages/transformers/trainer.py", line 2555, in compute_loss
    outputs = model(**inputs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/onnxruntime/training/ortmodule/_utils.py", line 371, in _forward
    return ortmodule._torch_module.forward(*inputs, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/onnxruntime/training/ortmodule/_utils.py", line 351, in _forward
    return torch_module_ort._execution_manager(torch_module_ort.is_training()).forward(*inputs, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/onnxruntime/training/ortmodule/_training_manager.py", line 273, in forward
    self._fallback_manager.handle_exception(
  File "/usr/local/lib/python3.8/dist-packages/onnxruntime/training/ortmodule/_fallback.py", line 162, in handle_exception
    raise exception
  File "/usr/local/lib/python3.8/dist-packages/onnxruntime/training/ortmodule/_training_manager.py", line 210, in forward
    self._initialize_graph_builder()
  File "/usr/local/lib/python3.8/dist-packages/onnxruntime/training/ortmodule/_graph_execution_manager.py", line 478, in _initialize_graph_builder
    self._graph_builder.initialize(self._onnx_models.exported_model.SerializeToString(), grad_builder_config)
RuntimeError: /onnxruntime_src/orttraining/orttraining/python/orttraining_pybind_state.cc:731 onnxruntime::python::addObjectMethodsForTraining(pybind11::module&, onnxruntime::python::ExecutionProviderRegistrationFn)::<lambda(onnxruntime::training::OrtModuleGraphBuilder*, const pybind11::bytes&, const onnxruntime::training::OrtModuleGraphBuilderConfiguration&)> [ONNXRuntimeError] : 1 : FAIL : Type Error: Type parameter (T) of Optype (Where) bound to different types (tensor(float) and tensor(float16) in node (Where_223).
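For illustration, here is a minimal sketch of one way to keep both value inputs of the Where node in the same dtype under mixed precision. The helper name masked_attn_weights is hypothetical, and this is not necessarily the exact change made in this PR:

import torch

def masked_attn_weights(attn_weights, causal_mask):
    # Build the fill value with the same dtype/device as attn_weights so that
    # both value inputs of torch.where share one dtype. Under fp16 autocast,
    # attn_weights is float16, while a plain float32 constant would otherwise
    # be traced into a Where node with mismatched input types.
    mask_value = torch.finfo(attn_weights.dtype).min
    mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
    return torch.where(causal_mask, attn_weights, mask_value)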

@JingyaHuang
Contributor Author

JingyaHuang commented Dec 7, 2022

A bit more context on the issue: I previously fixed the tracing problem in #18017, but that fix hurt performance due to host<->device synchronization. PR #20061 addressed the performance issue, but it made tracing fail again.

It seems we can't guarantee both tracing correctness and PyTorch inference performance with the same line of code, which is why this PR distinguishes two cases (a hypothetical sketch follows the list):

  • Case 1: Tracing
  • Case 2: Inference with PyTorch
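
Purely as an illustration of the two-case idea (not the code ultimately kept in this PR, since the branch was removed after the review below), such a split could look like the following, assuming causal_mask is a boolean tensor and torch.jit.is_tracing() is used to detect the tracing path:

import torch

def apply_causal_mask(attn_weights, causal_mask, mask_value):
    # Hypothetical sketch of the two-case branch described above.
    if torch.jit.is_tracing():
        # Case 1 (tracing / ONNX export): keep every input of torch.where in
        # attn_weights' dtype so the exported graph stays valid.
        fill = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
        return torch.where(causal_mask, attn_weights, fill)
    # Case 2 (eager PyTorch inference): pass the Python scalar directly to
    # masked_fill, avoiding the construction of a device tensor from a
    # host-side float.
    return attn_weights.masked_fill(~causal_mask, mask_value)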

@JingyaHuang
Contributor Author

Also @michaelbenayoun, I saw this: #18017 (comment). Will the current modeling have any issue with mixed-precision training for torch.fx?

Collaborator

@sgugger sgugger left a comment

This is the kind of if/else we try to avoid in the modeling code as it will become completely unreadable if we add support for all optimizations/exports like this. Let's forego the optimized path here and only do what works for ONNX/tracing.

@JingyaHuang
Contributor Author

Feel the same; if/else removed!

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Dec 7, 2022

The documentation is not available anymore as the PR was closed or merged.

Collaborator

@sgugger sgugger left a comment

Thanks! Let's just wait for @michaelbenayoun and then we can merge!

Member

@michaelbenayoun michaelbenayoun left a comment

LGTM

@sgugger sgugger merged commit 521da65 into huggingface:main Dec 8, 2022
mpierrau pushed a commit to mpierrau/transformers that referenced this pull request Dec 15, 2022