Skip to content

Commit 87c929f

Browse files
yajiedesigntqchen
authored andcommitted
add msvc in cc (#531)
1 parent 85c545c commit 87c929f

File tree

7 files changed

+92
-12
lines changed

7 files changed

+92
-12
lines changed

python/tvm/contrib/cc.py

Lines changed: 68 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@
44
import sys
55
import subprocess
66

7+
import os
8+
from .util import tempdir
9+
10+
711
def create_shared(output,
812
objects,
913
options=None,
@@ -24,26 +28,85 @@ def create_shared(output,
2428
cc : str, optional
2529
The compile string.
2630
"""
31+
if sys.platform == "darwin" or sys.platform.startswith('linux'):
32+
_linux_shared(output, objects, options, cc)
33+
elif sys.platform == "win32":
34+
_windows_shared(output, objects, options)
35+
else:
36+
raise ValueError("Unsupported platform")
37+
38+
39+
def _linux_shared(output, objects, options, cc="g++"):
2740
cmd = [cc]
2841
cmd += ["-shared", "-fPIC"]
29-
3042
if sys.platform == "darwin":
3143
cmd += ["-undefined", "dynamic_lookup"]
3244
cmd += ["-o", output]
33-
3445
if isinstance(objects, str):
3546
cmd += [objects]
3647
else:
3748
cmd += objects
38-
3949
if options:
4050
cmd += options
41-
4251
proc = subprocess.Popen(
4352
cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
4453
(out, _) = proc.communicate()
54+
if proc.returncode != 0:
55+
msg = "Compilation error:\n"
56+
msg += str(out)
57+
raise RuntimeError(msg)
58+
59+
60+
def _windows_shared(output, objects, options):
61+
cl_cmd = ["cl"]
62+
cl_cmd += ["-c"]
63+
if isinstance(objects, str):
64+
objects = [objects]
65+
cl_cmd += objects
66+
if options:
67+
cl_cmd += options
68+
69+
temp = tempdir()
70+
dllmain_path = temp.relpath("dllmain.cc")
71+
with open(dllmain_path, "w") as dllmain_obj:
72+
dllmain_obj.write('#include <windows.h>\
73+
BOOL APIENTRY DllMain( HMODULE hModule,\
74+
DWORD ul_reason_for_call,\
75+
LPVOID lpReserved)\
76+
{return TRUE;}')
77+
78+
cl_cmd += [dllmain_path]
79+
80+
temp_path = dllmain_path.replace("dllmain.cc", "")
81+
cl_cmd += ["-Fo:" + temp_path]
82+
83+
proc = subprocess.Popen(
84+
cl_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
85+
(out, _) = proc.communicate()
86+
if proc.returncode != 0:
87+
msg = "Compilation error:\n"
88+
msg += str(out)
89+
raise RuntimeError(msg)
90+
link_cmd = ["link"]
91+
link_cmd += ["-dll", "-FORCE:MULTIPLE"]
92+
93+
for obj in objects:
94+
if obj.endswith(".cc"):
95+
(_, temp_file_name) = os.path.split(obj)
96+
(shot_name, _) = os.path.splitext(temp_file_name)
97+
link_cmd += [os.path.join(temp_path, shot_name + ".obj")]
98+
if obj.endswith(".o"):
99+
link_cmd += [obj]
100+
101+
link_cmd += ["-EXPORT:__tvm_main__"]
102+
link_cmd += [temp_path + "dllmain.obj"]
103+
link_cmd += ["-out:" + output]
45104

105+
proc = subprocess.Popen(
106+
link_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
107+
(out, _) = proc.communicate()
46108
if proc.returncode != 0:
47109
msg = "Compilation error:\n"
48-
msg += out
110+
msg += str(out)
111+
49112
raise RuntimeError(msg)

python/tvm/contrib/util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def __init__(self):
1616
def remove(self):
1717
"""Remote the tmp dir"""
1818
if self.temp_dir:
19-
self._rmtree(self.temp_dir)
19+
self._rmtree(self.temp_dir, ignore_errors=True)
2020
self.temp_dir = None
2121

2222
def __del__(self):

python/tvm/module.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,17 @@
22
from __future__ import absolute_import as _abs
33

44
from collections import namedtuple
5+
56
from ._ffi.function import ModuleBase, _set_class_module
67
from ._ffi.function import _init_api
78
from .contrib import cc as _cc, tar as _tar, util as _util
89

910
ProfileResult = namedtuple("ProfileResult", ["mean"])
1011

12+
1113
class Module(ModuleBase):
1214
"""Module container of all TVM generated functions"""
15+
1316
def __repr__(self):
1417
return "Module(%s, %x)" % (self.type_key, self.handle.value)
1518

@@ -135,11 +138,13 @@ def time_evaluator(self, func_name, ctx, number):
135138
try:
136139
feval = _RPCTimeEvaluator(
137140
self, func_name, ctx.device_type, ctx.device_id, number)
141+
138142
def evaluator(*args):
139143
"""Internal wrapped evaluator."""
140144
# Wrap feval so we can add more stats in future.
141145
mean = feval(*args)
142146
return ProfileResult(mean=mean)
147+
143148
return evaluator
144149
except NameError:
145150
raise NameError("time_evaluate is only supported when RPC is enabled")

src/codegen/llvm/codegen_cpu.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,7 @@ llvm::GlobalVariable* CodeGenCPU::InitContextPtr(
226226
name);
227227
gv->setAlignment(data_layout_->getTypeAllocSize(p_type));
228228
gv->setInitializer(llvm::Constant::getNullValue(p_type));
229+
gv->setDLLStorageClass(llvm::GlobalValue::DLLStorageClassTypes::DLLExportStorageClass);
229230
return gv;
230231
}
231232

src/codegen/llvm/codegen_llvm.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ void CodeGenLLVM::AddFunctionInternal(const LoweredFunc& f, bool ret_void) {
117117
ftype, llvm::Function::ExternalLinkage,
118118
f->name, module_.get());
119119
function_->setCallingConv(llvm::CallingConv::C);
120+
function_->setDLLStorageClass(llvm::GlobalValue::DLLStorageClassTypes::DLLExportStorageClass);
120121
// set var map and align information
121122
auto arg_it = function_->arg_begin();
122123
for (size_t i = 0; i < f->args.size(); ++i, ++arg_it) {

src/codegen/llvm/llvm_common.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
#include <llvm/Target/TargetMachine.h>
4242
#include <llvm/Target/TargetOptions.h>
4343
#include <llvm/IRReader/IRReader.h>
44+
#include <llvm/CodeGen/TargetLoweringObjectFileImpl.h>
4445

4546
#include <utility>
4647
#include <string>

tests/python/unittest/test_module_load.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from tvm.contrib import cc, util
33
import ctypes
44
import os
5+
import sys
56
import numpy as np
67
import subprocess
78

@@ -88,18 +89,25 @@ def check_device(device):
8889
return
8990
temp = util.tempdir()
9091
name = "myadd_%s" % device
91-
f = tvm.build(s, [A, B], device, "llvm -system-lib", name=name)
92+
if sys.platform == "darwin" or sys.platform.startswith('linux'):
93+
f = tvm.build(s, [A, B], device, "llvm -system-lib", name=name)
94+
elif sys.platform == "win32":
95+
f = tvm.build(s, [A, B], device, "llvm", name=name)
96+
else:
97+
raise ValueError("Unsupported platform")
98+
9299
path_dso = temp.relpath("dev_lib.so")
93100
f.export_library(path_dso)
94101

95102
f1 = tvm.module.load(path_dso)
96103
a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx)
97104
b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx)
98105
f1(a, b)
99-
f2 = tvm.module.system_lib()
100-
np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
101-
f2[name](a, b)
102106
np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
107+
if sys.platform != "win32":
108+
f2 = tvm.module.system_lib()
109+
f2[name](a, b)
110+
np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
103111

104112
check_device("cuda")
105113
check_device("opencl")
@@ -164,8 +172,9 @@ def check_system_lib():
164172
np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
165173
mm['myadd2'](a, b)
166174
np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
167-
168-
check_system_lib()
175+
176+
if sys.platform != "win32":
177+
check_system_lib()
169178
check_llvm()
170179

171180

0 commit comments

Comments
 (0)