This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Is it necessary to wait for the CUDA stream when calling WaitToRead or WaitToWrite? #12823

Closed
wkcn opened this issue Oct 14, 2018 · 8 comments

Comments

@wkcn
Member

wkcn commented Oct 14, 2018

Description

Hi there.
I found a problem with asynchronous execution.
In the two functions NDArray::WaitToRead and NDArray::WaitToWrite, there is no statement that waits for the CUDA stream to finish.
This means that a task pushed before calling either function may still be pending or running on the CUDA stream after the function returns, even though such a task should have finished executing by the time the function returns.

In the PR [MXNET-779] Add DLPack Transformation API that I submitted:
[Code] python/mxnet/ndarray/ndarray.py#L3980

def to_dlpack_for_write(data):
    """Returns a reference view of NDArray that represents as DLManagedTensor until
       all previous read/write operations on the current array are finished.
    Parameters
    ----------
    data: NDArray
        input data.
    Returns
    -------
    PyCapsule (the pointer of DLManagedTensor)
        a reference view of NDArray that represents as DLManagedTensor.
    """
    check_call(_LIB.MXNDArrayWaitToWrite(data.handle))
    dlpack = DLPackHandle()
    check_call(_LIB.MXNDArrayToDLPack(data.handle, ctypes.byref(dlpack)))
    return ctypes.pythonapi.PyCapsule_New(dlpack, _c_str_dltensor, _c_dlpack_deleter)

After MXNDArrayWaitToWrite returns, there may still be pending tasks on the CUDA stream, because WaitToWrite and WaitToRead do not wait for the CUDA stream to finish. As a result, the data exposed through the DLPack tensor may be wrong.
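For illustration, here is a minimal sketch (mine, not part of the PR) of the scenario, together with a hypothetical workaround that synchronizes the whole device before the DLPack capsule is handed to a consumer running on the NULL stream. The library name "libcudart.so" and the extra cudaDeviceSynchronize call are assumptions, not part of the MXNet API.

import ctypes
import mxnet as mx
from mxnet.ndarray import to_dlpack_for_write

x = mx.nd.zeros((1024,), ctx=mx.gpu(0))
x += 1                                   # asynchronous: the kernel is only queued on MXNet's stream

pack = to_dlpack_for_write(x)            # calls MXNDArrayWaitToWrite internally

# Hypothetical extra synchronization: without it, a consumer on the NULL stream
# may read the buffer before MXNet's queued kernel has finished.
cudart = ctypes.CDLL("libcudart.so")     # assumption: CUDA runtime is loadable by this name
cudart.cudaDeviceSynchronize()

# Only now (under these assumptions) is it safe to pass `pack` to external CUDA code.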

Environment info (Required)


Package used (Python/R/Scala/Julia):
Python


Build info (Required if built from source)

MXNet commit hash:
efa7d3a



@piyushghai
Contributor

Thank you for posting your issue. We will look into this.

@mxnet-label-bot [ Bug, Cuda]

@wkcn
Member Author

wkcn commented Oct 14, 2018

@piyushghai Thank you!

@eric-haibin-lin
Member

@wkcn
Member Author

wkcn commented Oct 16, 2018

@eric-haibin-lin
Thank you!

I call MXNDArrayWaitToWrite first, then call MXNDArrayGetData to get the data pointer of an NDArray.
Then I pass the data pointer into a CUDA kernel.
In the CUDA kernel, there are some assignments through the data pointer.
The CUDA kernel runs on the NULL stream (the CUDA code lives outside MXNet, so I cannot obtain the CUDA stream from MXNet's RunContext).
It works sometimes, but it randomly triggers the CUDA error "illegal memory access".
I have not found the location that triggers the error.

I will check it.
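For reference, here is a minimal sketch of the workflow described above. It assumes a hypothetical external CUDA library libmykernel.so exposing a launcher my_kernel_launch(void *data, int n) that runs its kernel on the NULL stream; the MXNet C API calls are the ones named above, the rest is illustrative.

import ctypes
import mxnet as mx
from mxnet.base import _LIB, check_call

x = mx.nd.ones((1024,), ctx=mx.gpu(0))
x *= 2                                                 # queued asynchronously on MXNet's stream

check_call(_LIB.MXNDArrayWaitToWrite(x.handle))        # wait for the engine to mark the write as done

pdata = ctypes.c_void_p()
check_call(_LIB.MXNDArrayGetData(x.handle, ctypes.byref(pdata)))  # raw device pointer

mylib = ctypes.CDLL("libmykernel.so")                  # hypothetical external CUDA library
mylib.my_kernel_launch(pdata, ctypes.c_int(x.size))    # kernel launches on the NULL stream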

@eric-haibin-lin
Member

Did you make sure the reference to the original NDArray is kept and the memory is not freed?

@wkcn
Member Author

wkcn commented Oct 17, 2018

@eric-haibin-lin
Yes. I found that the CUDA kernel I wrote sometimes fails to run correctly.
I'm looking into the reason.

Thank you!

@wkcn
Member Author

wkcn commented Oct 22, 2018

Sorry, there is a bug in my project.
I use Python code to call cudaSetDevice and the kernel function, like this:

class CFuncDef:
    [...]
    def __call__(self, arg_datas, arg_types, dev_id):
        if dev_id is None:
            ctx = 'cpu'
        else:
            set_device(dev_id)
            ctx = gpu_ctx_name
        # function loader
        func = self.loader(self, arg_types, ctx, **self.loader_kwargs)
        return func(*arg_datas)

func is a ctypes reference to a CUDA kernel function.
CFuncDef.__call__ is called inside an MXNet custom operator.
However, the current device ID may be changed between the calls to set_device and func(*arg_datas) because of asynchronous execution, and this is what causes the illegal memory access.

Calling WaitToRead or WaitToWrite is enough.

Solved it.
Thank you!
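For illustration, here is a minimal sketch (my own wording, not the project's actual fix) of one way to narrow the window described above: re-bind the device immediately before the launch, inside a single helper, so that as little as possible can happen between cudaSetDevice and the kernel call. The library name libcudart.so and the helper launch_on_device are assumptions.

import ctypes

_cudart = ctypes.CDLL("libcudart.so")   # assumption: CUDA runtime is loadable by this name

def launch_on_device(func, arg_datas, dev_id):
    """Pin the CUDA device and launch the ctypes kernel wrapper back-to-back."""
    if dev_id is not None:
        _cudart.cudaSetDevice(ctypes.c_int(dev_id))   # set the device right before the launch
    return func(*arg_datas)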

@wkcn wkcn closed this as completed Oct 22, 2018
@eric-haibin-lin
Member

Good to know it's resolved.
