Skip to content

Commit

Permalink
A simple example of returning an array and converting it to a tensor
Browse files Browse the repository at this point in the history
  • Loading branch information
Sopel97 committed Nov 7, 2020
1 parent 4157291 commit 836308c
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 4 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
__pycache__/
env/
build/
logs/
7 changes: 7 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,13 @@ cmake_minimum_required(VERSION 3.0)

project(data_loader)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED 17)

add_library(data_loader SHARED data_loader.cpp)

find_package(Threads REQUIRED)

target_link_libraries(data_loader Threads::Threads)

install(TARGETS data_loader RUNTIME DESTINATION .)
17 changes: 17 additions & 0 deletions data_loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,21 @@ extern "C" {
std::cout<< "test successful\n";
}

struct TestDataCollection
{
int size;
int* data;
};

EXPORT TestDataCollection* CDECL create_data_collection()
{
return new TestDataCollection{ 10, new int[10]{} };
}

EXPORT void CDECL destroy_data_collection(TestDataCollection* ptr)
{
delete ptr->data;
delete ptr;
}

}
44 changes: 40 additions & 4 deletions test_dll_call.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,46 @@
from ctypes import *
from ctypes.util import find_library
import ctypes
import numpy as np

filename = 'data_loader.dll'
dll = CDLL(filename)
dll = ctypes.CDLL('./data_loader.dll')
print(dll)
print(dll.test)
dll.test()

import torch

print(dll.create_data_collection)
print(dll.destroy_data_collection)

class TestDataCollection(ctypes.Structure):
_fields_ = [
('size', ctypes.c_int),
('data', ctypes.POINTER(ctypes.c_int))
]

def __str__(self):
return str(self.size) + ' ' + str([self.data[i] for i in range(self.size)])

def get_tensor(self):
return torch.LongTensor(np.ctypeslib.as_array(self.data, shape=(self.size,)))

TestDataCollectionPtr = ctypes.POINTER(TestDataCollection)

create_data_collection = dll.create_data_collection
create_data_collection.restype = TestDataCollectionPtr

destroy_data_collection = dll.destroy_data_collection
destroy_data_collection.argtypes = [TestDataCollectionPtr]

v = dll.create_data_collection()
print(v)
vobj = v.contents
print(vobj)

tensor = vobj.get_tensor();
print(tensor)

destroy_data_collection(v)

for i in range(1000000):
v = create_data_collection()
destroy_data_collection(v)

0 comments on commit 836308c

Please sign in to comment.