[torchbench] speech_transformer fails to run on dynamo. #6831

Closed
ysiraichi opened this issue Mar 27, 2024 · 7 comments
ysiraichi commented Mar 27, 2024

🐛 Bug

Running the upstreamed benchmarking scripts with the following command results in an unexpected error.

python xla/benchmarks/experiment_runner.py \
       --suite-name torchbench \
       --accelerator cuda \
       --xla PJRT \
       --dynamo openxla \
       --test train --test eval \
       --repeat 8 --iterations-per-run 1 \
       --print-subprocess \
       --no-resume -k speech_transformer
Traceback (most recent call last):
  File "xla/benchmarks/experiment_runner.py", line 945, in <module>
    main()
  File "xla/benchmarks/experiment_runner.py", line 941, in main
    runner.run()
  File "xla/benchmarks/experiment_runner.py", line 61, in run
    self.run_single_config()
  File "xla/benchmarks/experiment_runner.py", line 256, in run_single_config
    metrics, last_output = self.run_once_and_gather_metrics(
  File "xla/benchmarks/experiment_runner.py", line 345, in run_once_and_gather_metrics
    output, _ = loop(iter_fn=self._default_iter_fn)
  File "xla/benchmarks/experiment_runner.py", line 302, in loop
    output, timing, trace = iter_fn(benchmark_experiment, benchmark_model,
  File "xla/benchmarks/experiment_runner.py", line 218, in _default_iter_fn
    output = benchmark_model.model_iter_fn(
  File "torch/_dynamo/eval_frame.py", line 390, in _fn
    return fn(*args, **kwargs)
  File "xla/benchmarks/benchmark_model.py", line 170, in eval
    pred = self.module(*inputs)
  File "torch/nn/modules/module.py", line 1527, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "torch/nn/modules/module.py", line 1536, in _call_impl
    return forward_call(*args, **kwargs)
  File "benchmark/torchbenchmark/models/speech_transformer/speech_transformer/transformer/transformer.py", line 28, in forward
    encoder_padded_outputs, *_ = self.encoder(padded_input, input_lengths)
  File "torch/nn/modules/module.py", line 1527, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "torch/nn/modules/module.py", line 1536, in _call_impl
    return forward_call(*args, **kwargs)
  File "benchmark/torchbenchmark/models/speech_transformer/speech_transformer/transformer/encoder.py", line 48, in forward
    non_pad_mask = get_non_pad_mask(padded_input, input_lengths=input_lengths)
  File "benchmark/torchbenchmark/models/speech_transformer/speech_transformer/transformer/encoder.py", line 50, in torch_dynamo_resume_in_forward_at_48
    slf_attn_mask = get_attn_pad_mask(padded_input, input_lengths, length)
  File "benchmark/torchbenchmark/models/speech_transformer/speech_transformer/transformer/encoder.py", line 50, in torch_dynamo_resume_in_forward_at_50
    slf_attn_mask = get_attn_pad_mask(padded_input, input_lengths, length)
  File "torch/_dynamo/eval_frame.py", line 390, in _fn
    return fn(*args, **kwargs)
  File "torch/_dynamo/external_utils.py", line 36, in inner
    return fn(*args, **kwargs)
  File "torch/_functorch/aot_autograd.py", line 917, in forward
    return compiled_fn(full_args)
  File "torch/_functorch/_aot_autograd/utils.py", line 89, in g
    return f(*args)
  File "torch/_functorch/_aot_autograd/runtime_wrappers.py", line 107, in runtime_wrapper
    all_outs = call_func_at_runtime_with_args(
  File "torch/_functorch/_aot_autograd/utils.py", line 113, in call_func_at_runtime_with_args
    out = normalize_as_list(f(args))
  File "torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 181, in rng_functionalization_wrapper
    return compiled_fw(args)
  File "torch/_functorch/_aot_autograd/utils.py", line 89, in g
    return f(*args)
  File "torch/_dynamo/backends/torchxla.py", line 36, in fwd
    compiled_graph = bridge.extract_compiled_graph(model, args)
  File "xla/torch_xla/core/dynamo_bridge.py", line 696, in extract_compiled_graph
    extract_internal(fused_module), node.args, None)
  File "xla/torch_xla/core/dynamo_bridge.py", line 432, in extract_internal
    xm.mark_step()
  File "xla/torch_xla/core/xla_model.py", line 1056, in mark_step
    torch_xla._XLAC._xla_step_marker(
RuntimeError: Bad StatusOr access: INTERNAL: during context [Unknown]: Seen floating point types of different precisions in %batch-norm-training.210 = (f64[1,2040,512]{2,1,0}, f64[2040]{0}, f64[2040]{0}) batch-norm-training(f64[1,2040,512]{2,1,0} %reshape.209, bf16[2040]{0} %broadcast.141, bf16[2040]{0} %broadcast.136), epsilon=1e-05, feature_index=1, but mixed precision is disallowed.

Environment

  • PyTorch Commit: a52b4e22571507abc35c2d47de138497190d2e0a
  • PyTorch/XLA Commit: 84e7feb
  • PyTorch/benchmark Commit: d6015d42d9a1834bc7595c4bd6852562fb80b30b

cc @miladm @JackCaoG @vanbasten23 @zpcore @frgossen @golechwierowicz @cota

zpcore commented Mar 27, 2024

We noticed that these model floating-point precision mismatch failures are related to #6669, where / is implemented using XLA:DIV. Either we type-cast to match the precision when lowering the op, or we fix the model's data type. Let's check with @wonjoolee95 to see what the best solution is.
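
For context, a minimal eager-mode sketch of the promotion rule at play (added here for illustration, not part of the original report): dividing a bf16 tensor by an f64 tensor promotes the result to f64, so any downstream bf16 consumer sees a precision mismatch unless a cast back is inserted.

import torch

# bf16 combined with f64 promotes to f64 under PyTorch's type-promotion rules.
a = torch.rand(4, dtype=torch.bfloat16)
b = torch.rand(4, dtype=torch.float64)

print(torch.promote_types(torch.bfloat16, torch.float64))  # torch.float64
print((a / b).dtype)                                       # torch.float64

# A plain Python-float divisor does not change the dtype category:
print((a / 2.0).dtype)                                     # torch.bfloat16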

@wonjoolee95

Thanks for reporting the issue. We should type-cast to match the precision in the op lowering itself. @bhavya01, we probably need to do an explicit floating-point type check and promotion in the op. Could you take a look when you get the time?
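
As a rough Python-level illustration of that suggestion (a hypothetical helper for exposition only, not the actual C++ lowering in PyTorch/XLA), the explicit check and promotion could look like:

import torch

def div_with_explicit_promotion(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Hypothetical sketch: if both operands are floating point but their dtypes
    # differ, promote both to the common dtype before dividing, so the lowered
    # graph never mixes precisions.
    if x.is_floating_point() and y.is_floating_point() and x.dtype != y.dtype:
        common = torch.promote_types(x.dtype, y.dtype)
        x, y = x.to(common), y.to(common)
    return x / y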

bhavya01 commented Mar 28, 2024

@ysiraichi @zpcore Is it possible for you to pinpoint where this is happening in the code? AFAIK, we promote types in the div op. A simple script like the one below shows that this should work:

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()

# Dividing a double (f64) tensor by a bfloat16 tensor should promote both
# operands to a common dtype before the division is lowered.
x = torch.rand((10, 10), dtype=torch.double, device=device)
y = torch.rand((10, 10), dtype=torch.bfloat16, device=device)
print(torch.div(x, y))

zpcore commented Mar 28, 2024

Thanks @bhavya01 for the update. I retested with the following command:

XLA_SAVE_TENSORS_FMT="hlo" XLA_SAVE_TENSORS_FILE="/tmp/save1.hlo"  XLA_IR_DEBUG=1 XLA_FLAGS=--xla_dump_to=/tmp/piz_hlo python xla/benchmarks/experiment_runner.py  \
--suite-name torchbench    \
--accelerator cuda  \
--xla PJRT   \
--dynamo openxla   \
--test eval  \
--repeat 2  \
--iterations-per-run 1   \
--print-subprocess \
--no-resume \
-k speech_transformer \
--dump-hlo

I then compared the HLO difference between 03/12 (the last version with no issue) and 03/28. Some findings:

HLO for 03/12 (Pass):

%batch-norm-training.207 = (bf16[1,2040,512]{2,1,0}, bf16[2040]{0}, bf16[2040]{0}) 
batch-norm-training(
bf16[1,2040,512]{2,1,0} %reshape.206, 
bf16[2040]{0} %broadcast.142, bf16[2040]{0} %broadcast.137), 
epsilon=1e-05, 
feature_index=1, metadata=...

HLO for 03/28 (ERROR):

%batch-norm-training.210 = (f64[1,2040,512]{2,1,0}, f64[2040]{0}, f64[2040]{0}) 
batch-norm-training(
f64[1,2040,512]{2,1,0} %reshape.209, 
bf16[2040]{0} %broadcast.141, 
bf16[2040]{0} %broadcast.136), 
epsilon=1e-05, 
feature_index=1, metadata=...

We can see that the batch-normalization input tensor changed from bf16 to f64 while the scale and bias parameters stayed bf16, so the operands now have mismatched precisions.

Checking further, I noticed that we are missing an xla__cast op after the aten__div call. In the HLO for the 03/12 (Pass) run, we see:

%divide.95 = f64[80,204,204]{2,1,0} 
divide(
f64[80,204,204]{2,1,0} %convert.93, 
f64[80,204,204]{2,1,0} %broadcast.94), 
metadata={op_type="aten__div" op_name="aten__div" ...}

%convert.96 = bf16[80,204,204]{2,1,0} 
convert(
f64[80,204,204]{2,1,0} %divide.95), 
metadata={op_type="xla__cast" op_name="xla__cast" , ...

%broadcast.171 = bf16[80,204,204]{2,1,0} 
broadcast(bf16[80,204,204]{2,1,0} %convert.96), 
dimensions={0,1,2}, 
metadata={op_type="aten__masked_fill" op_name="aten__masked_fill"...}

However, in 03/28 (ERROR), we only have:

%divide.95 = f64[80,204,204]{2,1,0} 
divide(
f64[80,204,204]{2,1,0} %convert.93, 
f64[80,204,204]{2,1,0} %broadcast.94), 
metadata={op_type="aten__div" op_name="aten__div" ...

%broadcast.166 = f64[80,204,204]{2,1,0} 
broadcast(f64[80,204,204]{2,1,0} %divide.95),
dimensions={0,1,2}, 
metadata={op_type="aten__masked_fill" op_name="aten__masked_fill"...}

Basically, every time we called divide with data type f64, an xla__cast was automatically appended after the aten__div operation to cast the result back to bf16.

I don't know why, with the native / operator, we originally also cast back to bf16. I think @bhavya01 is doing it correctly. Do we need to update the output shape of the DIV operator to bf16?
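
One way to narrow this down outside the full benchmark is to dump the HLO for a standalone divide and check whether a convert back to bf16 follows it. This is only a sketch, under two assumptions: that the mismatch originates from a bf16/f64 divide like the one in the dumps, and that the debug helper torch_xla._XLAC._get_xla_tensors_hlo used in PyTorch/XLA's tests is available in this build.

import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()

# Hypothetical standalone version of the pattern in the dumps: a bf16 tensor
# divided by an f64 tensor. The result promotes to f64; the question is whether
# the emitted HLO also contains a convert (xla__cast) back to bf16 after the divide.
scores = torch.rand((80, 204, 204), dtype=torch.bfloat16, device=device)
scale = torch.rand((80, 204, 204), dtype=torch.float64, device=device)
out = scores / scale

print(torch_xla._XLAC._get_xla_tensors_hlo([out]))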

bhavya01 commented Apr 2, 2024

I think that #6873 should fix the issue. This PR fixes the shapes in the XLA node for the div op.

zpcore commented Apr 2, 2024

Thanks @bhavya01 for the heads-up. I notice that the PR calls promoteType, which will force the two operands to the same type. That probably explains where the xla__cast is coming from. Thanks @lsy323 for the fix. Let's see if this is fixed after the merge.

zpcore commented Apr 4, 2024

Manually tested with #6873 and the test passed. Will close the issue for now. Thanks!
