diff --git a/python/pylibwholegraph/pylibwholegraph/binding/wholememory_binding.pyx b/python/pylibwholegraph/pylibwholegraph/binding/wholememory_binding.pyx
index ccbd0552..3fbaf13a 100644
--- a/python/pylibwholegraph/pylibwholegraph/binding/wholememory_binding.pyx
+++ b/python/pylibwholegraph/pylibwholegraph/binding/wholememory_binding.pyx
@@ -426,14 +426,10 @@ cdef void python_cb_wrapper_temp_create_context(void** memory_context,
     cdef PyObject * ret_memory_context = NULL
     with gil:
         wrapped_global_context = <object> global_context
-        python_fn = wrapped_global_context.temp_create_context_fn
-        python_global_context = wrapped_global_context.temp_global_context
-        args = PyTuple_New(1)
-        Py_INCREF(<object> python_global_context)
-        PyTuple_SetItem(args, 0, <object> python_global_context)
-        py_memory_context = PyObject_CallObject(<object> python_fn, args)
+        fn = wrapped_global_context.temp_create_context_fn
+        ctx = wrapped_global_context.temp_global_context
+        py_memory_context = fn(ctx)
         ret_memory_context = <PyObject *> py_memory_context
-        Py_DECREF(args)
         Py_INCREF(<object> ret_memory_context)
         (<PyObject **> memory_context)[0] = ret_memory_context
         return
@@ -442,15 +438,10 @@ cdef void python_cb_wrapper_temp_destroy_context(void * memory_context,
                                                 void * global_context) nogil:
     with gil:
         wrapped_global_context = <object> global_context
-        python_fn = wrapped_global_context.temp_destroy_context_fn
-        python_global_context = wrapped_global_context.temp_global_context
-        args = PyTuple_New(2)
-        Py_INCREF(<object> memory_context)
-        PyTuple_SetItem(args, 0, <object> memory_context)
-        Py_INCREF(<object> python_global_context)
-        PyTuple_SetItem(args, 1, <object> python_global_context)
-        PyObject_CallObject(<object> python_fn, args)
-        Py_DECREF(args)
+        fn = wrapped_global_context.temp_destroy_context_fn
+        ctx = wrapped_global_context.temp_global_context
+        mem_ctx = <object> memory_context
+        fn(mem_ctx, ctx)
         Py_DECREF(<object> memory_context)
         return

@@ -461,38 +452,23 @@ cdef void * python_cb_wrapper_temp_malloc(wholememory_tensor_description_t * ten
     cdef int64_t res_ptr = 0
     with gil:
         wrapped_global_context = <object> global_context
-        py_tensor_desc = PyWholeMemoryTensorDescription()
-        py_tensor_desc.set_by_tensor_desc(tensor_desc)
-        py_malloc_type = PyMemoryAllocType()
-        py_malloc_type.set_type(malloc_type)
-        python_fn = wrapped_global_context.temp_malloc_fn
-        python_global_context = wrapped_global_context.temp_global_context
-        args = PyTuple_New(4)
-        Py_INCREF(py_tensor_desc)
-        PyTuple_SetItem(args, 0, py_tensor_desc)
-        Py_INCREF(py_malloc_type)
-        PyTuple_SetItem(args, 1, py_malloc_type)
-        Py_INCREF(<object> memory_context)
-        PyTuple_SetItem(args, 2, <object> memory_context)
-        Py_INCREF(<object> python_global_context)
-        PyTuple_SetItem(args, 3, <object> python_global_context)
-        res_ptr = PyLong_AsLongLong(PyObject_CallObject(<object> python_fn, args))
-        Py_DECREF(args)
+        py_shape = tuple([tensor_desc.sizes[i] for i in range(tensor_desc.dim)])
+        py_dtype = int(tensor_desc.dtype)
+        py_malloc_type_int = int(malloc_type)
+        fn = wrapped_global_context.temp_malloc_fn
+        ctx = wrapped_global_context.temp_global_context
+        mem_ctx = <object> memory_context
+        res_ptr = fn(py_shape, py_dtype, py_malloc_type_int, mem_ctx, ctx)
     return <void *> res_ptr


 cdef void python_cb_wrapper_temp_free(void * memory_context, void * global_context) nogil:
     with gil:
         wrapped_global_context = <object> global_context
-        python_fn = wrapped_global_context.temp_free_fn
-        python_global_context = wrapped_global_context.temp_global_context
-        args = PyTuple_New(2)
-        Py_INCREF(<object> memory_context)
-        PyTuple_SetItem(args, 0, <object> memory_context)
-        Py_INCREF(<object> python_global_context)
-        PyTuple_SetItem(args, 1, <object> python_global_context)
-        PyObject_CallObject(<object> python_fn, args)
-        Py_DECREF(args)
+        fn = wrapped_global_context.temp_free_fn
+        ctx = wrapped_global_context.temp_global_context
+        mem_ctx = <object> memory_context
+        fn(mem_ctx, ctx)
     return

 cdef void * python_cb_wrapper_output_malloc(wholememory_tensor_description_t * tensor_desc,
@@ -502,38 +478,23 @@ cdef void * python_cb_wrapper_output_malloc(wholememory_tensor_description_t * t
     cdef int64_t res_ptr = 0
     with gil:
         wrapped_global_context = <object> global_context
-        py_tensor_desc = PyWholeMemoryTensorDescription()
-        py_tensor_desc.set_by_tensor_desc(tensor_desc)
-        py_malloc_type = PyMemoryAllocType()
-        py_malloc_type.set_type(malloc_type)
-        python_fn = wrapped_global_context.output_malloc_fn
-        python_global_context = wrapped_global_context.output_global_context
-        args = PyTuple_New(4)
-        Py_INCREF(py_tensor_desc)
-        PyTuple_SetItem(args, 0, py_tensor_desc)
-        Py_INCREF(py_malloc_type)
-        PyTuple_SetItem(args, 1, py_malloc_type)
-        Py_INCREF(<object> memory_context)
-        PyTuple_SetItem(args, 2, <object> memory_context)
-        Py_INCREF(<object> python_global_context)
-        PyTuple_SetItem(args, 3, <object> python_global_context)
-        res_ptr = PyLong_AsLongLong(PyObject_CallObject(<object> python_fn, args))
-        Py_DECREF(args)
+        py_shape = tuple([tensor_desc.sizes[i] for i in range(tensor_desc.dim)])
+        py_dtype = int(tensor_desc.dtype)
+        py_malloc_type_int = int(malloc_type)
+        fn = wrapped_global_context.output_malloc_fn
+        ctx = wrapped_global_context.output_global_context
+        mem_ctx = <object> memory_context
+        res_ptr = fn(py_shape, py_dtype, py_malloc_type_int, mem_ctx, ctx)
     return <void *> res_ptr


 cdef void python_cb_wrapper_output_free(void * memory_context, void * global_context) nogil:
     with gil:
         wrapped_global_context = <object> global_context
-        python_fn = wrapped_global_context.output_free_fn
-        python_global_context = wrapped_global_context.output_global_context
-        args = PyTuple_New(2)
-        Py_INCREF(<object> memory_context)
-        PyTuple_SetItem(args, 0, <object> memory_context)
-        Py_INCREF(<object> python_global_context)
-        PyTuple_SetItem(args, 1, <object> python_global_context)
-        PyObject_CallObject(<object> python_fn, args)
-        Py_DECREF(args)
+        fn = wrapped_global_context.output_free_fn
+        ctx = wrapped_global_context.output_global_context
+        mem_ctx = <object> memory_context
+        fn(mem_ctx, ctx)
     return


diff --git a/python/pylibwholegraph/pylibwholegraph/torch/wholegraph_env.py b/python/pylibwholegraph/pylibwholegraph/torch/wholegraph_env.py
index d9c90a5e..0a9e1fae 100644
--- a/python/pylibwholegraph/pylibwholegraph/torch/wholegraph_env.py
+++ b/python/pylibwholegraph/pylibwholegraph/torch/wholegraph_env.py
@@ -88,23 +88,23 @@ def torch_destroy_memory_context_env_fn(



 def torch_malloc_env_fn(
-    tensor_desc: wmb.PyWholeMemoryTensorDescription,
-    malloc_type: wmb.PyMemoryAllocType,
+    shape: tuple,
+    dtype_int: int,
+    malloc_type_int: int,
     memory_context: TorchMemoryContext,
     global_context: TorchEmptyGlobalContext,
 ) -> int:
     pinned = False
     device = None
-    if malloc_type.get_type() == wmb.WholeMemoryMemoryAllocType.MatDevice:
+    if malloc_type_int == int(wmb.WholeMemoryMemoryAllocType.MatDevice):
         device = torch.device("cuda")
-    elif malloc_type.get_type() == wmb.WholeMemoryMemoryAllocType.MatHost:
+    elif malloc_type_int == int(wmb.WholeMemoryMemoryAllocType.MatHost):
         device = torch.device("cpu")
     else:
-        assert malloc_type.get_type() == wmb.WholeMemoryMemoryAllocType.MatPinned
+        assert malloc_type_int == int(wmb.WholeMemoryMemoryAllocType.MatPinned)
         device = torch.device("cpu")
         pinned = True
-    shape = tensor_desc.shape
-    dtype = wholememory_dtype_to_torch_dtype(tensor_desc.dtype)
+    dtype = wholememory_dtype_to_torch_dtype(wmb.WholeMemoryDataType(dtype_int))
     t = torch.empty(shape, dtype=dtype, device=device, pin_memory=pinned)
     memory_context.set_tensor(t)
     return t.data_ptr()