[RUNTIME] Fix the function lookup problem for CUDA 11 driver (#4335)

Merged

ptillet merged 1 commit into main from keren/fix-driver on Jul 17, 2024
Conversation

Jokeren (Contributor) commented Jul 17, 2024

There was a function pointer lookup missing in the previous patch. f9f2960
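The actual fix lives in Triton's C driver code; purely as an illustration of the pattern involved, the sketch below shows in Python/`ctypes` what a "look up the versioned symbol first, fall back to the unversioned one" lookup looks like, so that older drivers (which lack the newer entry points) still resolve a usable function pointer. The `cos` lookup against libm is a stand-in, not the driver code, and the missing fallback of this kind is what the previous patch lacked:

```python
import ctypes
import ctypes.util

def lookup_first(lib, names):
    """Return the first function pointer found among `names`.

    Tries the newer (versioned) symbol name first and falls back to
    the older one; fails loudly if neither exists. The bug fixed in
    this PR was a missing lookup/fallback of this kind.
    """
    for name in names:
        try:
            return getattr(lib, name)
        except AttributeError:
            continue
    raise RuntimeError(f"none of {names} found")

# Stand-in library: libm's `cos` plays the role of a driver entry point.
libm = ctypes.CDLL(ctypes.util.find_library("m"))
cos = lookup_first(libm, ["cos_v2", "cos"])  # "cos_v2" is absent; falls back to "cos"
cos.restype = ctypes.c_double
cos.argtypes = [ctypes.c_double]
print(cos(0.0))  # 1.0
```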

@Jokeren Jokeren marked this pull request as ready for review July 17, 2024 01:13
@Jokeren Jokeren requested a review from ptillet as a code owner July 17, 2024 01:13
Jokeren (Contributor, Author) commented Jul 17, 2024

@atalman

malfet (Collaborator) commented Jul 17, 2024

Thank you for the follow-up fix. We'll try to automate testing for these going forward.

atalman (Collaborator) commented Jul 17, 2024

Looks like we got past the original error, but it is still failing further down the line:
Executing: https://github.com/pytorch/builder/blob/main/test/smoke_test/smoke_test.py

Looks like it could be related to this comment: pytorch/pytorch#130684 (comment)

root@08fe3942a639:/builder/test/smoke_test# python smoke_test.py --package torchonly --runtime-error-check disabled
torch: 2.4.0+cu118
ATen/Parallel:
	at::get_num_threads() : 8
	at::get_num_interop_threads() : 16
OpenMP 201511 (a.k.a. OpenMP 4.5)
	omp_get_max_threads() : 8
Intel(R) oneAPI Math Kernel Library Version 2022.2-Product Build 20220804 for Intel(R) 64 architecture applications
	mkl_get_max_threads() : 8
Intel(R) MKL-DNN v3.4.2 (Git Hash 1137e04ec0b5251ca2b4400a4fd3c667ce843d67)
std::thread::hardware_concurrency() : 16
Environment variables:
	OMP_NUM_THREADS : [not set]
	MKL_NUM_THREADS : [not set]
ATen parallel backend: OpenMP

Skip version check for channel None as stable version is None
Testing smoke_test_conv2d
Testing smoke_test_linalg on cpu
Testing smoke_test_compile for cuda and torch.float16
Traceback (most recent call last):
  File "/builder/test/smoke_test/smoke_test.py", line 335, in <module>
    main()
  File "/builder/test/smoke_test/smoke_test.py", line 331, in main
    smoke_test_cuda(options.package, options.runtime_error_check, options.torch_compile_check)
  File "/builder/test/smoke_test/smoke_test.py", line 168, in smoke_test_cuda
    smoke_test_compile("cuda" if torch.cuda.is_available() else "cpu")
  File "/builder/test/smoke_test/smoke_test.py", line 252, in smoke_test_compile
    x_pt2 = torch.compile(foo)(x)
  File "/usr/local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 433, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 1116, in __call__
    return self._torchdynamo_orig_callable(
  File "/usr/local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 948, in __call__
    result = self._inner_convert(
  File "/usr/local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 472, in __call__
    return _compile(
  File "/usr/local/lib/python3.10/site-packages/torch/_utils_internal.py", line 84, in wrapper_function
    return StrobelightCompileTimeProfiler.profile_compile_time(
  File "/usr/local/lib/python3.10/site-packages/torch/_strobelight/compile_time_profiler.py", line 129, in profile_compile_time
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/usr/local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 817, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/usr/local/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 231, in time_wrapper
    r = func(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 636, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/usr/local/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1185, in transform_code_object
    transformations(instructions, code_options)
  File "/usr/local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 178, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 582, in transform
    tracer.run()
  File "/usr/local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2451, in run
    super().run()
  File "/usr/local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 893, in run
    while self.step():
  File "/usr/local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 805, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/usr/local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2642, in RETURN_VALUE
    self._return(inst)
  File "/usr/local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2627, in _return
    self.output.compile_subgraph(
  File "/usr/local/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1098, in compile_subgraph
    self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
  File "/usr/local/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/usr/local/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1318, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "/usr/local/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 231, in time_wrapper
    r = func(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1409, in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
  File "/usr/local/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1390, in call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
  File "/usr/local/lib/python3.10/site-packages/torch/_dynamo/repro/after_dynamo.py", line 129, in __call__
    compiled_gm = compiler_fn(gm, example_inputs)
  File "/usr/local/lib/python3.10/site-packages/torch/__init__.py", line 1951, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
  File "/usr/local/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/usr/local/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1505, in compile_fx
    return aot_autograd(
  File "/usr/local/lib/python3.10/site-packages/torch/_dynamo/backends/common.py", line 69, in __call__
    cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 954, in aot_module_simplified
    compiled_fn, _ = create_aot_dispatcher_function(
  File "/usr/local/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 231, in time_wrapper
    r = func(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 687, in create_aot_dispatcher_function
    compiled_fn, fw_metadata = compiler_fn(
  File "/usr/local/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 168, in aot_dispatch_base
    compiled_fw = compiler(fw_module, updated_flat_args)
  File "/usr/local/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 231, in time_wrapper
    r = func(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1410, in fw_compiler_base
    return inner_compile(
  File "/usr/local/lib/python3.10/site-packages/torch/_dynamo/repro/after_aot.py", line 84, in debug_wrapper
    inner_compiled_fn = compiler_fn(gm, example_inputs)
  File "/usr/local/lib/python3.10/site-packages/torch/_inductor/debug.py", line 304, in inner
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/usr/local/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/usr/local/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 231, in time_wrapper
    r = func(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 527, in compile_fx_inner
    compiled_graph = fx_codegen_and_compile(
  File "/usr/local/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/usr/local/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 831, in fx_codegen_and_compile
    compiled_fn = graph.compile_to_fn()
  File "/usr/local/lib/python3.10/site-packages/torch/_inductor/graph.py", line 1749, in compile_to_fn
    return self.compile_to_module().call
  File "/usr/local/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 231, in time_wrapper
    r = func(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/_inductor/graph.py", line 1699, in compile_to_module
    mod = PyCodeCache.load_by_key_path(
  File "/usr/local/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 3062, in load_by_key_path
    mod = _reload_python_module(key, path)
  File "/usr/local/lib/python3.10/site-packages/torch/_inductor/runtime/compile_tasks.py", line 45, in _reload_python_module
    exec(code, mod.__dict__, mod.__dict__)
  File "/tmp/torchinductor_root/oq/coq63wn5k6yhq2rx3scb4s3emokvlh4jebbs3t7gdncmhpqgv56z.py", line 71, in <module>
    async_compile.wait(globals())
  File "/usr/local/lib/python3.10/site-packages/torch/_inductor/async_compile.py", line 247, in wait
    scope[key] = result.result()
  File "/usr/local/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 3424, in result
    self.kernel.precompile()
  File "/usr/local/lib/python3.10/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 232, in precompile
    compiled_binary, launcher = self._precompile_config(
  File "/usr/local/lib/python3.10/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 440, in _precompile_config
    binary._init_handles()
  File "/triton/python/triton/compiler/compiler.py", line 386, in _init_handles
    self.module, self.function, self.n_regs, self.n_spills = driver.active.utils.load_binary(
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
RuntimeError: Triton Error [CUDA]: device kernel image is invalid

Nvidia-SMI

root@08fe3942a639:/builder/test/smoke_test# nvidia-smi
Wed Jul 17 02:18:53 2024       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.161.03   Driver Version: 470.161.03   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA A10G         Off  | 00000000:00:1E.0 Off |                    0 |
|  0%   35C    P0    56W / 300W |      0MiB / 22731MiB |      1%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+

@ptillet ptillet merged commit 3aeb223 into main Jul 17, 2024
@ptillet ptillet deleted the keren/fix-driver branch July 17, 2024 04:34
malfet (Collaborator) commented Jul 17, 2024

@atalman one is expected to set TRITON_PTXAS_PATH to work around this issue, and torch should do it automatically; see https://github.com/pytorch/pytorch/blob/7c45476d38176c8d5b19fb379fc073dc21beba64/torch/_inductor/runtime/compile_tasks.py#L51 (and PTXAS is supposed to be packaged into the torch + cuda-11.8 wheel, see pytorch/pytorch#119750).

In the worst case, something like the following should work (though it's ugly, and we should have set it up ourselves):

TRITON_PTXAS_PATH=/usr/local/lib/python3.10/site-packages/torch/bin/ptxas  python smoke_test.py --package torchonly
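The automatic setup torch is meant to perform can be sketched roughly as below. The `torch/bin/ptxas` location and the helper's name/signature here are assumptions based on the cu118 wheel layout described above, not the exact upstream code:

```python
import os
from pathlib import Path

def set_triton_ptxas_path(torch_dir: str) -> bool:
    """If the torch wheel bundles a ptxas binary, export TRITON_PTXAS_PATH
    so Triton compiles kernels with it instead of a mismatched system ptxas.
    Returns True if the variable was set by this call.

    NOTE: `torch_dir / "bin" / "ptxas"` is an assumed wheel layout.
    """
    if os.environ.get("TRITON_PTXAS_PATH"):
        return False  # user already chose a ptxas; respect it
    ptxas = Path(torch_dir) / "bin" / "ptxas"
    if ptxas.is_file() and os.access(ptxas, os.X_OK):
        os.environ["TRITON_PTXAS_PATH"] = str(ptxas)
        return True
    return False
```

With something like this run at import time, the manual `TRITON_PTXAS_PATH=... python smoke_test.py` invocation above would be unnecessary.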

atalman pushed a commit to atalman/triton that referenced this pull request Jul 17, 2024
…lang#4335)

There was a function pointer lookup missing in the previous patch.
triton-lang@f9f2960
ptillet pushed a commit that referenced this pull request Jul 17, 2024
Cherry-pick of:
#4330
#4335

to the release/3.0.x branch

---------

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Co-authored-by: Keren Zhou <kerenzhou@openai.com>
bertmaher pushed a commit to bertmaher/triton that referenced this pull request Dec 10, 2024
…lang#4335)

There was a function pointer lookup missing in the previous patch.
triton-lang@f9f2960