Skip to content

Commit e922bee

Browse files
committed
[FFI][REFACTOR][ABI] Rename NDArray to Tensor
This PR Updates the NDArray => Tensor. Both tensor and ndarray are commonly used terms. Because the term Tensor is getting more common in the context of ML, we do the rename to stay more aligned with torch.Tensor and DLTensor.
1 parent e1700e1 commit e922bee

File tree

484 files changed

+3689
-3668
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

484 files changed

+3689
-3668
lines changed

apps/android_rpc/app/src/main/jni/tvm_runtime.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,16 +43,15 @@
4343
#include "../ffi/src/ffi/extra/module.cc"
4444
#include "../ffi/src/ffi/extra/testing.cc"
4545
#include "../ffi/src/ffi/function.cc"
46-
#include "../ffi/src/ffi/ndarray.cc"
4746
#include "../ffi/src/ffi/object.cc"
47+
#include "../ffi/src/ffi/tensor.cc"
4848
#include "../ffi/src/ffi/traceback.cc"
4949
#include "../src/runtime/cpu_device_api.cc"
5050
#include "../src/runtime/device_api.cc"
5151
#include "../src/runtime/file_utils.cc"
5252
#include "../src/runtime/logging.cc"
5353
#include "../src/runtime/memory/memory_manager.cc"
5454
#include "../src/runtime/minrpc/minrpc_logger.cc"
55-
#include "../src/runtime/ndarray.cc"
5655
#include "../src/runtime/profiling.cc"
5756
#include "../src/runtime/registry.cc"
5857
#include "../src/runtime/rpc/rpc_channel.cc"
@@ -63,6 +62,7 @@
6362
#include "../src/runtime/rpc/rpc_server_env.cc"
6463
#include "../src/runtime/rpc/rpc_session.cc"
6564
#include "../src/runtime/rpc/rpc_socket_impl.cc"
65+
#include "../src/runtime/tensor.cc"
6666
#include "../src/runtime/thread_pool.cc"
6767
#include "../src/runtime/threading_backend.cc"
6868
#include "../src/runtime/workspace_pool.cc"

apps/android_rpc/tests/android_rpc_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,8 @@ def test_rpc_module():
7272
dev = remote.cl(0)
7373
remote.upload(path_dso_cl)
7474
f1 = remote.load_module("dev_lib_cl.so")
75-
a = tvm.nd.array(a_np, dev)
76-
b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), dev)
75+
a = tvm.runtime.tensor(a_np, dev)
76+
b = tvm.runtime.tensor(np.zeros(1024, dtype=A.dtype), dev)
7777
time_f = f1.time_evaluator(f1.entry_name, dev, number=10)
7878
cost = time_f(a, b).mean
7979
print("%g secs/op\n" % cost)

apps/hexagon_launcher/launcher_core.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
#include <tvm/ffi/function.h>
2626
#include <tvm/runtime/data_type.h>
2727
#include <tvm/runtime/module.h>
28-
#include <tvm/runtime/ndarray.h>
28+
#include <tvm/runtime/tensor.h>
2929

3030
#include <string>
3131
#include <vector>

apps/hexagon_launcher/launcher_hexagon.cc

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ AEEResult __QAIC_HEADER(launcher_rpc_set_input)(remote_handle64 handle, int inpu
137137
};
138138
DLManagedTensor managed{tensor, /*manager_ctx*/ nullptr, /*deleter*/ nullptr};
139139

140-
auto input = tvm::runtime::NDArray::FromDLPack(&managed);
140+
auto input = tvm::runtime::Tensor::FromDLPack(&managed);
141141

142142
tvm::ffi::Function set_input = get_module_func(TheModel->model_executor, "set_input");
143143
set_input(input_idx, input);
@@ -172,17 +172,17 @@ AEEResult __QAIC_HEADER(launcher_rpc_get_output)(remote_handle64 handle, int out
172172
}
173173

174174
tvm::ffi::Function get_output = get_module_func(TheModel->model_executor, "get_output");
175-
tvm::runtime::NDArray output = get_output(output_idx);
175+
tvm::runtime::Tensor output = get_output(output_idx);
176176

177177
std::vector<int64_t> shape_vec{output->shape, output->shape + output->ndim};
178178

179-
auto* container = new tvm::runtime::NDArray::Container(
180-
static_cast<void*>(output_value), shape_vec, output->dtype, Model::external());
179+
auto* container = new tvm::runtime::Tensor::Container(static_cast<void*>(output_value), shape_vec,
180+
output->dtype, Model::external());
181181
container->SetDeleter([](tvm::Object* container) {
182-
delete static_cast<tvm::runtime::NDArray::Container*>(container);
182+
delete static_cast<tvm::runtime::Tensor::Container*>(container);
183183
});
184184

185-
tvm::runtime::NDArray host_output(tvm::runtime::GetObjectPtr<tvm::runtime::Object>(container));
185+
tvm::runtime::Tensor host_output(tvm::runtime::GetObjectPtr<tvm::runtime::Object>(container));
186186

187187
if (meta_size != 0) {
188188
auto* meta = reinterpret_cast<tensor_meta*>(output_meta);

apps/ios_rpc/tests/ios_rpc_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,8 @@ def test_rpc_module(host, port, key, mode):
7272
dev = remote.metal(0)
7373
f1 = remote.load_module("dev_lib.dylib")
7474
a_np = np.random.uniform(size=1024).astype(A.dtype)
75-
a = tvm.nd.array(a_np, dev)
76-
b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), dev)
75+
a = tvm.runtime.tensor(a_np, dev)
76+
b = tvm.runtime.tensor(np.zeros(1024, dtype=A.dtype), dev)
7777
time_f = f1.time_evaluator(f1.entry_name, dev, number=10)
7878
cost = time_f(a, b).mean
7979
print("Metal: %g secs/op" % cost)

docs/arch/index.rst

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ The main goal of TVM's runtime is to provide a minimal API for loading and execu
133133
import tvm
134134
# Example runtime execution program in python, with type annotated
135135
mod: tvm.runtime.Module = tvm.runtime.load_module("compiled_artifact.so")
136-
arr: tvm.runtime.NDArray = tvm.nd.array([1, 2, 3], device=tvm.cuda(0))
136+
arr: tvm.runtime.Tensor = tvm.runtime.tensor([1, 2, 3], device=tvm.cuda(0))
137137
fun: tvm.runtime.PackedFunc = mod["addone"]
138138
fun(arr)
139139
print(arr.numpy())
@@ -142,7 +142,7 @@ The main goal of TVM's runtime is to provide a minimal API for loading and execu
142142
:py:class:`tvm.runtime.Module` encapsulates the result of compilation. A runtime.Module contains a GetFunction method to obtain PackedFuncs by name.
143143

144144
:py:class:`tvm.runtime.PackedFunc` is a type-erased function interface for both the generated functions. A runtime.PackedFunc can take arguments and return values with the
145-
following types: POD types(int, float), string, runtime.PackedFunc, runtime.Module, runtime.NDArray, and other sub-classes of runtime.Object.
145+
following types: POD types(int, float), string, runtime.PackedFunc, runtime.Module, runtime.Tensor, and other sub-classes of runtime.Object.
146146

147147
:py:class:`tvm.runtime.Module` and :py:class:`tvm.runtime.PackedFunc` are powerful mechanisms to modularize the runtime. For example, to get the above `addone` function on CUDA, we can use LLVM to generate the host-side code to compute the launching parameters(e.g. size of the thread groups) and then call into another PackedFunc from a CUDAModule that is backed by the CUDA driver API. The same mechanism can be used for OpenCL kernels.
148148

@@ -155,7 +155,7 @@ The above example only deals with a simple `addone` function. The code snippet b
155155
factory: tvm.runtime.Module = tvm.runtime.load_module("resnet18.so")
156156
# Create a stateful graph execution module for resnet18 on cuda(0)
157157
gmod: tvm.runtime.Module = factory["resnet18"](tvm.cuda(0))
158-
data: tvm.runtime.NDArray = get_input_data()
158+
data: tvm.runtime.Tensor = get_input_data()
159159
# set input
160160
gmod["set_input"](0, data)
161161
# execute the model

docs/deep_dive/tensor_ir/tutorials/tir_creation.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -204,9 +204,9 @@ def mm_relu(a: T.handle, b: T.handle, c: T.handle):
204204

205205

206206
def evaluate_dynamic_shape(lib: tvm.runtime.Module, m: int, n: int, k: int):
207-
A = tvm.nd.array(np.random.uniform(size=(m, k)).astype("float32"))
208-
B = tvm.nd.array(np.random.uniform(size=(k, n)).astype("float32"))
209-
C = tvm.nd.array(np.zeros((m, n), dtype="float32"))
207+
A = tvm.runtime.tensor(np.random.uniform(size=(m, k)).astype("float32"))
208+
B = tvm.runtime.tensor(np.random.uniform(size=(k, n)).astype("float32"))
209+
C = tvm.runtime.tensor(np.zeros((m, n), dtype="float32"))
210210
lib(A, B, C)
211211
return C.numpy()
212212

docs/deep_dive/tensor_ir/tutorials/tir_transformation.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,9 @@ def main(
7272
b_np = np.random.uniform(size=(128, 128)).astype("float32")
7373
c_np = a_np @ b_np
7474

75-
a_nd = tvm.nd.array(a_np)
76-
b_nd = tvm.nd.array(b_np)
77-
c_nd = tvm.nd.array(np.zeros((128, 128), dtype="float32"))
75+
a_nd = tvm.runtime.tensor(a_np)
76+
b_nd = tvm.runtime.tensor(b_np)
77+
c_nd = tvm.runtime.tensor(np.zeros((128, 128), dtype="float32"))
7878

7979

8080
def evaluate(mod: tvm.IRModule):

docs/get_started/tutorials/ir_module.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ def main(
237237
vm = relax.VirtualMachine(exec, dev)
238238

239239
raw_data = np.random.rand(1, 784).astype("float32")
240-
data = tvm.nd.array(raw_data, dev)
240+
data = tvm.runtime.tensor(raw_data, dev)
241241
cpu_out = vm["main"](data, *params_from_torch["main"]).numpy()
242242
print(cpu_out)
243243

@@ -267,8 +267,8 @@ def main(
267267
dev = tvm.device("cuda", 0)
268268
vm = relax.VirtualMachine(exec, dev)
269269
# Need to allocate data and params on GPU device
270-
data = tvm.nd.array(raw_data, dev)
271-
gpu_params = [tvm.nd.array(p, dev) for p in params_from_torch["main"]]
270+
data = tvm.runtime.tensor(raw_data, dev)
271+
gpu_params = [tvm.runtime.tensor(p, dev) for p in params_from_torch["main"]]
272272
gpu_out = vm["main"](data, *gpu_params).numpy()
273273
print(gpu_out)
274274

docs/get_started/tutorials/quick_start.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -141,9 +141,9 @@ def forward(self, x):
141141
device = tvm.cpu()
142142
vm = relax.VirtualMachine(ex, device)
143143
data = np.random.rand(1, 784).astype("float32")
144-
tvm_data = tvm.nd.array(data, device=device)
144+
tvm_data = tvm.runtime.tensor(data, device=device)
145145
params = [np.random.rand(*param.shape).astype("float32") for _, param in param_spec]
146-
params = [tvm.nd.array(param, device=device) for param in params]
146+
params = [tvm.runtime.tensor(param, device=device) for param in params]
147147
print(vm["forward"](tvm_data, *params).numpy())
148148

149149
################################################################################
@@ -158,14 +158,14 @@ def forward(self, x):
158158
# prefill_logits = vm["prefill"](inputs, weight, kv_cache)
159159
# decoded_logits = vm["decode"](inputs, weight, kv_cache)
160160
#
161-
# - TVM runtime comes with native data structures, such as NDArray, can also have zero
161+
# - TVM runtime comes with native data structures, such as Tensor, can also have zero
162162
# copy exchange with existing ecosystem (DLPack exchange with PyTorch)
163163
#
164164
# .. code-block:: Python
165165
#
166-
# # Convert PyTorch tensor to TVM NDArray
167-
# x_tvm = tvm.nd.from_dlpack(x_torch.to_dlpack())
168-
# # Convert TVM NDArray to PyTorch tensor
166+
# # Convert PyTorch tensor to TVM Tensor
167+
# x_tvm = tvm.runtime.from_dlpack(x_torch.to_dlpack())
168+
# # Convert TVM Tensor to PyTorch tensor
169169
# x_torch = torch.from_dlpack(x_tvm.to_dlpack())
170170
#
171171
# - TVM runtime works in non-python environments, so it works on settings such as mobile
@@ -175,14 +175,14 @@ def forward(self, x):
175175
# // C++ snippet
176176
# runtime::Module vm = ex.GetFunction("load_executable")();
177177
# vm.GetFunction("init")(...);
178-
# NDArray out = vm.GetFunction("prefill")(data, weight, kv_cache);
178+
# Tensor out = vm.GetFunction("prefill")(data, weight, kv_cache);
179179
#
180180
# .. code-block:: Java
181181
#
182182
# // Java snippet
183183
# Module vm = ex.getFunction("load_executable").invoke();
184184
# vm.getFunction("init").pushArg(...).invoke;
185-
# NDArray out = vm.getFunction("prefill").pushArg(data).pushArg(weight).pushArg(kv_cache).invoke();
185+
# Tensor out = vm.getFunction("prefill").pushArg(data).pushArg(weight).pushArg(kv_cache).invoke();
186186
#
187187

188188
################################################################################

0 commit comments

Comments
 (0)