Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -426,14 +426,10 @@ cdef void python_cb_wrapper_temp_create_context(void** memory_context,
cdef PyObject * ret_memory_context = NULL
with gil:
wrapped_global_context = <GlobalContextWrapper> <PyObject *> global_context
python_fn = wrapped_global_context.temp_create_context_fn
python_global_context = wrapped_global_context.temp_global_context
args = PyTuple_New(1)
Py_INCREF(<object> python_global_context)
PyTuple_SetItem(args, 0, <object> python_global_context)
py_memory_context = PyObject_CallObject(<object> python_fn, <object> args)
fn = <object> wrapped_global_context.temp_create_context_fn
ctx = <object> wrapped_global_context.temp_global_context
py_memory_context = fn(ctx)
ret_memory_context = <PyObject *> py_memory_context
Py_DECREF(args)
Py_INCREF(ret_memory_context)
(<PyObject **> memory_context)[0] = ret_memory_context
return
Expand All @@ -442,15 +438,10 @@ cdef void python_cb_wrapper_temp_destroy_context(void * memory_context,
void * global_context) nogil:
with gil:
wrapped_global_context = <GlobalContextWrapper> <PyObject *> global_context
python_fn = wrapped_global_context.temp_destroy_context_fn
python_global_context = wrapped_global_context.temp_global_context
args = PyTuple_New(2)
Py_INCREF(<object> <PyObject *> memory_context)
PyTuple_SetItem(args, 0, <object> <PyObject *> memory_context)
Py_INCREF(<object> python_global_context)
PyTuple_SetItem(args, 1, <object> python_global_context)
PyObject_CallObject(<object> python_fn, <object> args)
Py_DECREF(args)
fn = <object> wrapped_global_context.temp_destroy_context_fn
ctx = <object> wrapped_global_context.temp_global_context
mem_ctx = <object> <PyObject *> memory_context
fn(mem_ctx, ctx)
Py_DECREF(<PyObject *> memory_context)
return

Expand All @@ -461,38 +452,23 @@ cdef void * python_cb_wrapper_temp_malloc(wholememory_tensor_description_t * ten
cdef int64_t res_ptr = 0
with gil:
wrapped_global_context = <GlobalContextWrapper> <PyObject *> global_context
py_tensor_desc = PyWholeMemoryTensorDescription()
py_tensor_desc.set_by_tensor_desc(tensor_desc)
py_malloc_type = PyMemoryAllocType()
py_malloc_type.set_type(malloc_type)
python_fn = wrapped_global_context.temp_malloc_fn
python_global_context = wrapped_global_context.temp_global_context
args = PyTuple_New(4)
Py_INCREF(py_tensor_desc)
PyTuple_SetItem(args, 0, <object> py_tensor_desc)
Py_INCREF(py_malloc_type)
PyTuple_SetItem(args, 1, <object> py_malloc_type)
Py_INCREF(<object> <PyObject *> memory_context)
PyTuple_SetItem(args, 2, <object> <PyObject *> memory_context)
Py_INCREF(<object> <PyObject *> python_global_context)
PyTuple_SetItem(args, 3, <object> <PyObject *> python_global_context)
res_ptr = PyLong_AsLongLong(PyObject_CallObject(<object> python_fn, <object> args))
Py_DECREF(args)
py_shape = tuple([tensor_desc.sizes[i] for i in range(tensor_desc.dim)])
py_dtype = int(tensor_desc.dtype)
py_malloc_type_int = int(malloc_type)
fn = <object> wrapped_global_context.temp_malloc_fn
ctx = <object> wrapped_global_context.temp_global_context
mem_ctx = <object> <PyObject *> memory_context
res_ptr = fn(py_shape, py_dtype, py_malloc_type_int, mem_ctx, ctx)
return <void *> res_ptr

cdef void python_cb_wrapper_temp_free(void * memory_context,
void * global_context) nogil:
with gil:
wrapped_global_context = <GlobalContextWrapper> <PyObject *> global_context
python_fn = wrapped_global_context.temp_free_fn
python_global_context = wrapped_global_context.temp_global_context
args = PyTuple_New(2)
Py_INCREF(<object> <PyObject *> memory_context)
PyTuple_SetItem(args, 0, <object> <PyObject *> memory_context)
Py_INCREF(<object> python_global_context)
PyTuple_SetItem(args, 1, <object> python_global_context)
PyObject_CallObject(<object> python_fn, <object> args)
Py_DECREF(args)
fn = <object> wrapped_global_context.temp_free_fn
ctx = <object> wrapped_global_context.temp_global_context
mem_ctx = <object> <PyObject *> memory_context
fn(mem_ctx, ctx)
return

cdef void * python_cb_wrapper_output_malloc(wholememory_tensor_description_t * tensor_desc,
Expand All @@ -502,38 +478,23 @@ cdef void * python_cb_wrapper_output_malloc(wholememory_tensor_description_t * t
cdef int64_t res_ptr = 0
with gil:
wrapped_global_context = <GlobalContextWrapper> <PyObject *> global_context
py_tensor_desc = PyWholeMemoryTensorDescription()
py_tensor_desc.set_by_tensor_desc(tensor_desc)
py_malloc_type = PyMemoryAllocType()
py_malloc_type.set_type(malloc_type)
python_fn = wrapped_global_context.output_malloc_fn
python_global_context = wrapped_global_context.output_global_context
args = PyTuple_New(4)
Py_INCREF(py_tensor_desc)
PyTuple_SetItem(args, 0, <object> <PyObject *> py_tensor_desc)
Py_INCREF(py_malloc_type)
PyTuple_SetItem(args, 1, <object> <PyObject *> py_malloc_type)
Py_INCREF(<object> <PyObject *> memory_context)
PyTuple_SetItem(args, 2, <object> <PyObject *> memory_context)
Py_INCREF(<object> <PyObject *> python_global_context)
PyTuple_SetItem(args, 3, <object> <PyObject *> python_global_context)
res_ptr = PyLong_AsLongLong(PyObject_CallObject(<object> python_fn, <object> args))
Py_DECREF(args)
py_shape = tuple([tensor_desc.sizes[i] for i in range(tensor_desc.dim)])
py_dtype = int(tensor_desc.dtype)
py_malloc_type_int = int(malloc_type)
fn = <object> wrapped_global_context.output_malloc_fn
ctx = <object> wrapped_global_context.output_global_context
mem_ctx = <object> <PyObject *> memory_context
res_ptr = fn(py_shape, py_dtype, py_malloc_type_int, mem_ctx, ctx)
return <void *> res_ptr

cdef void python_cb_wrapper_output_free(void * memory_context,
void * global_context) nogil:
with gil:
wrapped_global_context = <GlobalContextWrapper> <PyObject *> global_context
python_fn = wrapped_global_context.output_free_fn
python_global_context = wrapped_global_context.output_global_context
args = PyTuple_New(2)
Py_INCREF(<object> <PyObject *> memory_context)
PyTuple_SetItem(args, 0, <object> <PyObject *> memory_context)
Py_INCREF(<object> python_global_context)
PyTuple_SetItem(args, 1, <object> python_global_context)
PyObject_CallObject(<object> python_fn, <object> args)
Py_DECREF(args)
fn = <object> wrapped_global_context.output_free_fn
ctx = <object> wrapped_global_context.output_global_context
mem_ctx = <object> <PyObject *> memory_context
fn(mem_ctx, ctx)
return


Expand Down
14 changes: 7 additions & 7 deletions python/pylibwholegraph/pylibwholegraph/torch/wholegraph_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,23 +88,23 @@ def torch_destroy_memory_context_env_fn(


def torch_malloc_env_fn(
tensor_desc: wmb.PyWholeMemoryTensorDescription,
malloc_type: wmb.PyMemoryAllocType,
shape: tuple,
dtype_int: int,
malloc_type_int: int,
memory_context: TorchMemoryContext,
global_context: TorchEmptyGlobalContext,
) -> int:
pinned = False
device = None
if malloc_type.get_type() == wmb.WholeMemoryMemoryAllocType.MatDevice:
if malloc_type_int == int(wmb.WholeMemoryMemoryAllocType.MatDevice):
device = torch.device("cuda")
elif malloc_type.get_type() == wmb.WholeMemoryMemoryAllocType.MatHost:
elif malloc_type_int == int(wmb.WholeMemoryMemoryAllocType.MatHost):
device = torch.device("cpu")
else:
assert malloc_type.get_type() == wmb.WholeMemoryMemoryAllocType.MatPinned
assert malloc_type_int == int(wmb.WholeMemoryMemoryAllocType.MatPinned)
device = torch.device("cpu")
pinned = True
shape = tensor_desc.shape
dtype = wholememory_dtype_to_torch_dtype(tensor_desc.dtype)
dtype = wholememory_dtype_to_torch_dtype(wmb.WholeMemoryDataType(dtype_int))
t = torch.empty(shape, dtype=dtype, device=device, pin_memory=pinned)
memory_context.set_tensor(t)
return t.data_ptr()
Expand Down
Loading