[VTA][OpenCL] intelfocl #6126

Merged: 7 commits, Apr 20, 2021
2 changes: 1 addition & 1 deletion 3rdparty/vta-hw
8 changes: 8 additions & 0 deletions cmake/modules/VTA.cmake
@@ -104,6 +104,10 @@ elseif(PYTHON)
find_library(__cma_lib NAMES cma PATH /usr/lib)
elseif(${VTA_TARGET} STREQUAL "de10nano") # DE10-Nano rules
file(GLOB FPGA_RUNTIME_SRCS ${VTA_HW_PATH}/src/de10nano/*.cc ${VTA_HW_PATH}/src/*.cc)
elseif(${VTA_TARGET} STREQUAL "intelfocl") # Intel OpenCL for FPGA rules
file(GLOB FOCL_SRC ${VTA_HW_PATH}/src/oclfpga/*.cc)
list(APPEND FPGA_RUNTIME_SRCS ${FOCL_SRC})
list(APPEND FPGA_RUNTIME_SRCS ${VTA_HW_PATH}/src/vmem/virtual_memory.cc ${VTA_HW_PATH}/src/vmem/virtual_memory.h)
endif()
# Target lib: vta
add_library(vta SHARED ${FPGA_RUNTIME_SRCS})
@@ -123,6 +127,10 @@ elseif(PYTHON)
target_include_directories(vta SYSTEM PUBLIC 3rdparty)
target_include_directories(vta SYSTEM PUBLIC
"/usr/local/intelFPGA_lite/18.1/embedded/ds-5/sw/gcc/arm-linux-gnueabihf/include")
elseif(${VTA_TARGET} STREQUAL "intelfocl") # Intel OpenCL for FPGA rules
target_include_directories(vta PUBLIC 3rdparty)
set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17")
target_link_libraries(vta -lOpenCL)
endif()
endif()

11 changes: 7 additions & 4 deletions python/tvm/autotvm/task/topi_integration.py
@@ -227,7 +227,7 @@ def _decorate(topi_schedule):
@_register_task_schedule(task_name)
def wrapper(outs, *args, **kwargs):
"""wrapper function for topi schedule"""
workload = get_workload(outs)
workload = get_workload(outs, task_name)
if workload is None:
raise RuntimeError("Cannot find workload in attribute of this schedule")
tgt = Target.current()
@@ -241,18 +241,21 @@ def wrapper(outs, *args, **kwargs):
return _decorate


def get_workload(outs):
def get_workload(outs, task_name=None):
"""Retrieve the workload from outputs"""

def traverse(tensors):
"""traverse all ops to find attached workload"""
for t in tensors:
op = t.op
if "workload" in op.attrs:
return args_to_workload(op.attrs["workload"])
wkl = traverse(op.input_tensors)
if wkl:
return wkl

Contributor:
do you mind explaining the changes made to this file?

Contributor Author:
The original code fails if there are multiple workloads in one schedule. For example, in fused_nn_conv2d_add_add_right_shift_clip_cast_31, both the conv2d and the add may carry workload attrs, so we have to pick the correct workload by comparing against task_name.

Previously this worked fine, because add was not a tunable op. But now we also want to offload middle ALU-only nodes (residual blocks), such as fused_cast_cast_add_nn_relu_clip_cast_3, to VTA, so we create a VTA schedule for add (see add.alu).
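
To make the scenario concrete, here is a toy sketch of the workload lookup using mock objects (illustrative names only; the real helper works on TVM tensors via args_to_workload, and its exact structure may differ):

class MockOp:
    def __init__(self, attrs=None, input_tensors=()):
        self.attrs = attrs or {}
        self.input_tensors = list(input_tensors)

class MockTensor:
    def __init__(self, op):
        self.op = op

def get_workload(outs, task_name=None):
    """Simplified lookup: return the attached workload whose task name matches."""
    def traverse(tensors):
        for t in tensors:
            op = t.op
            if "workload" in op.attrs:
                wkl = op.attrs["workload"]
                if task_name is None or wkl[0] == task_name:
                    return wkl
            wkl = traverse(op.input_tensors)
            if wkl:
                return wkl
        return None
    return traverse(outs)

# In a fused op like fused_nn_conv2d_add_..., both conv2d and add carry workloads.
conv2d = MockTensor(MockOp(attrs={"workload": ("conv2d_packed.vta", "conv2d args")}))
add = MockTensor(MockOp(attrs={"workload": ("add.alu", "add args")}, input_tensors=[conv2d]))

assert get_workload([add], "add.alu")[0] == "add.alu"
assert get_workload([add], "conv2d_packed.vta")[0] == "conv2d_packed.vta"
assert get_workload([add])[0] == "add.alu"  # without task_name, the first hit wins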

Contributor:
Thanks for clarifying. How do we guard against extracting add as a standalone op for other backends in this case?

if "workload" in op.attrs:
ret = args_to_workload(op.attrs["workload"])
if task_name is None or ret[0] == task_name:
return ret
return None

outs = [outs] if isinstance(outs, tensor.Tensor) else outs
9 changes: 9 additions & 0 deletions python/tvm/relay/op/strategy/generic.py
@@ -53,6 +53,15 @@ def wrapper(attrs, outs, target):
return wrapper


def wrap_topi_compute(topi_compute):
"""Wrap TOPI compute which doesn't use attrs"""

def wrapper(attrs, inputs, out_type):
return [topi_compute(*inputs)]

return wrapper


def get_conv2d_in_channels(data_shape, data_layout):
"""Get conv2d input channels"""
data_shape = get_const_tuple(data_shape)
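For context, a small self-contained sketch of what the new wrap_topi_compute helper above does (topi.add is used here only as an example of a compute that ignores attrs):

from tvm import te, topi

def wrap_topi_compute(topi_compute):
    """Adapt a tensors-only TOPI compute to the (attrs, inputs, out_type) signature."""
    def wrapper(attrs, inputs, out_type):
        return [topi_compute(*inputs)]
    return wrapper

a = te.placeholder((4,), name="a")
b = te.placeholder((4,), name="b")
fcompute = wrap_topi_compute(topi.add)
(out,) = fcompute(None, [a, b], None)  # attrs and out_type are simply ignored
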
2 changes: 1 addition & 1 deletion python/tvm/relay/testing/tf.py
@@ -32,7 +32,7 @@

try:
tf_compat_v1 = tf.compat.v1
except ImportError:
except (ImportError, AttributeError):
tf_compat_v1 = tf

######################################################################
2 changes: 1 addition & 1 deletion python/tvm/topi/x86/bitserial_dense.py
@@ -122,7 +122,7 @@ def bitserial_dense(
return matmul


@autotvm.register_topi_schedule("biserial_dense.x86")
@autotvm.register_topi_schedule("bitserial_dense.x86")
def schedule_bitserial_dense(cfg, outs):
"""Schedule for bitserial_dense.

4 changes: 2 additions & 2 deletions src/relay/backend/compile_engine.cc
@@ -251,7 +251,7 @@ class ScheduleGetter : public backend::MemoizedExprTranslator<Array<te::Tensor>>
<< "Cannot apply TOPI schedule to a primitive function with two complicated ops"
<< " anchor=" << anchor_op_ << " current=" << op;
}
if (op_pattern >= anchor_op_pattern_) {
if (op_pattern > anchor_op_pattern_) {
anchor_op_ = op;
anchor_attrs_ = call_node->attrs;
anchor_op_pattern_ = op_pattern;
@@ -309,7 +309,7 @@ class ScheduleGetter : public backend::MemoizedExprTranslator<Array<te::Tensor>>
tvm::Target target_;
Op anchor_op_;
Attrs anchor_attrs_;
int anchor_op_pattern_{0};
int anchor_op_pattern_{-1};
OpImplementation anchor_implementation_;
std::ostringstream readable_name_stream_;
Array<te::Operation> scalars_;
1 change: 0 additions & 1 deletion src/runtime/workspace_pool.cc
@@ -115,7 +115,6 @@ class WorkspacePool::Pool {
}
// Release all resources
void Release(Device dev, DeviceAPI* device) {
ICHECK_EQ(allocated_.size(), 1);
for (size_t i = 1; i < free_list_.size(); ++i) {
device->FreeDataSpace(dev, free_list_[i].data);
}
10 changes: 0 additions & 10 deletions src/tir/transforms/lower_tvm_builtin.cc
@@ -88,16 +88,6 @@ class BuiltinLower : public StmtExprMutator {
op = stmt.as<AllocateNode>();
// Get constant allocation bound.
int64_t nbytes = GetVectorBytes(op->dtype);
Contributor:
do you mind explaining the reasoning behind this deletion?

Contributor Author:
This removes the special handling for kDLCPU. Otherwise, it causes an LLVM parameter-mismatch error:

Traceback (most recent call last):
  File "vta/tutorials/frontend/deploy_classification.py", line 210, in <module>
    params=params, target_host=env.target_host)
  File "/4pd/home/zhanghao/workspace/tvm-2/tvm/python/tvm/relay/build_module.py", line 251, in build
    graph_json, mod, params = bld_mod.build(mod, target, target_host, params)
  File "/4pd/home/zhanghao/workspace/tvm-2/tvm/python/tvm/relay/build_module.py", line 120, in build
    self._build(mod, target, target_host)
  File "tvm/_ffi/_cython/./packed_func.pxi", line 321, in tvm._ffi._cy3.core.PackedFuncBase.__call__
  File "tvm/_ffi/_cython/./packed_func.pxi", line 256, in tvm._ffi._cy3.core.FuncCall
  File "tvm/_ffi/_cython/./packed_func.pxi", line 245, in tvm._ffi._cy3.core.FuncCall3
  File "tvm/_ffi/_cython/./base.pxi", line 160, in tvm._ffi._cy3.core.CALL
tvm._ffi.base.TVMError: Traceback (most recent call last):
  [bt] (8) /4pd/home/zhanghao/workspace/tvm-2/tvm/build/libtvm.so(TVMFuncCall+0x4c) [0x7f385ac9bc1c]
  [bt] (7) /4pd/home/zhanghao/workspace/tvm-2/tvm/build/libtvm.so(std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), tvm::relay::backend::RelayBuildModule::GetFunction(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#3}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)+0x316) [0x7f385ab2a566]
  [bt] (6) /4pd/home/zhanghao/workspace/tvm-2/tvm/build/libtvm.so(tvm::relay::backend::RelayBuildModule::BuildRelay(tvm::IRModule, std::unordered_map<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::NDArray, std::hash<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::equal_to<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::allocator<std::pair<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const, tvm::runtime::NDArray> > > const&)+0xe31) [0x7f385ab29c11]
  [bt] (5) /4pd/home/zhanghao/workspace/tvm-2/tvm/build/libtvm.so(tvm::build(tvm::Map<tvm::runtime::String, tvm::IRModule, void, void> const&, tvm::Target const&)+0x3c4) [0x7f385a4322d4]
  [bt] (4) /4pd/home/zhanghao/workspace/tvm-2/tvm/build/libtvm.so(tvm::build(tvm::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target const&)+0x326) [0x7f385a4318c6]
  [bt] (3) /4pd/home/zhanghao/workspace/tvm-2/tvm/build/libtvm.so(tvm::codegen::Build(tvm::IRModule, tvm::Target const&)+0x67a) [0x7f385a74f68a]
  [bt] (2) /4pd/home/zhanghao/workspace/tvm-2/tvm/build/libtvm.so(+0x1277ea1) [0x7f385ac7eea1]
  [bt] (1) /4pd/home/zhanghao/workspace/tvm-2/tvm/build/libtvm.so(tvm::codegen::LLVMModuleNode::Init(tvm::IRModule const&, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)+0x1388) [0x7f385ac82c68]
  [bt] (0) /4pd/home/zhanghao/workspace/tvm-2/tvm/build/libtvm.so(+0x1276a57) [0x7f385ac7da57]
  File "/4pd/home/zhanghao/workspace/tvm-2/tvm/src/target/llvm/llvm_module.cc", line 230
TVMError: LLVM module verification failed with the following errors: 
Call parameter type does not match function signature!
  %.sub = getelementptr inbounds [4 x <8 x float>], [4 x <8 x float>]* %3, i64 0, i64 0
 i8*  %34 = call i8* @VTABufferCPUPtr(i8* %17, <8 x float>* nonnull %.sub)
Call parameter type does not match function signature!
  %.sub = getelementptr inbounds [8 x float], [8 x float]* %3, i64 0, i64 0
 i8*  %31 = call i8* @VTABufferCPUPtr(i8* %14, float* nonnull %.sub)

The error is raised by this LLVM code (lib/IR/Verifier.cpp):

2598   // Verify that all arguments to the call match the function type.                                                                                                                                            
2599   for (unsigned i = 0, e = FTy->getNumParams(); i != e; ++i)                                                                                                                                                   
2600     Assert(CS.getArgument(i)->getType() == FTy->getParamType(i),                                                                                                                                               
2601            "Call parameter type does not match function signature!",                                                                                                                                           
2602            CS.getArgument(i), FTy->getParamType(i), I); 

This error is raised only when the special handling for kDLCPU is present. I think it is because the signature generated for the AllocateNode is not consistent with the call parameter. Any ideas about an alternative fix?

Contributor:
@tqchen perhaps you'd have some input on why this code was needed in the first place?

Contributor:
@tmoreau89 @tqchen @zhanghaohit

I've been tracking down a bug introduced by this PR that somehow doesn't show up in CI. I've tested it locally with the docker image and still see the failure.

Anyway, if I run python/tests/onnx/test_forward.py:test_loop on main locally it fails. If I revert the change to this file, it passes.

I'm tempted to revert this PR until we can find a better way to fix this for VTA; do you guys have a better suggestion here?

Contributor:
@mbrookhart I am in favor of reverting the changes applied to this file; in a separate PR we can ensure that the error encountered by @zhanghaohit is resolved while making sure that python/tests/onnx/test_forward.py:test_loop passes.

Contributor:
I don't think we need to revert the entire PR; rather, introduce a PR that just reverts the changes applied to this file. Given that the intelfocl backend is not CI-tested, this won't break unit tests in TVM.

Contributor Author:
It seems that the error comes from _concatenate_shape_func, where an Any is checked against a concrete shape.

The changes in this PR may introduce an Any shape, thus triggering this bug.

One quick fix is to remove the assert out[i] == inputs[j][i], "Dims mismatch in the inputs of concatenate.", since there may be an Any shape in the inputs, which should be allowed.
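
For reference, the assert in question lives in the concatenate shape function. A rough Python sketch of its logic follows (the real implementation is a hybrid-script _concatenate_shape_func, so names and details here are simplified):

def concatenate_shape(input_shapes, axis):
    """Compute the output shape of concatenate from the input shapes."""
    out = list(input_shapes[0])
    for shape in input_shapes[1:]:
        out[axis] += shape[axis]
        for i, dim in enumerate(shape):
            if i == axis:
                continue
            # This is the check that trips when a dim is not a concrete value:
            assert out[i] == dim, "Dims mismatch in the inputs of concatenate."
    return out

print(concatenate_shape([[2, 3], [4, 3]], axis=0))  # [6, 3]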

Contributor:
@jwfromm took a look at the IR generated with and without this code snippet; this is what he got:

it seems like it's trying to figure out how to allocate for the shape function, and without that simplifying if statement it gets into a weird recursive let construction

correct IR with fix:
allocate(v_copy_shape_func: Pointer(int64), int64, [1]) {
  attr [0] "extern_scope" = 0 {
    v_expand_dim_shape_func: Pointer(int64)[0] = 1i64
    v_expand_dim_shape_func[1] = (int64*)v_copy_shape_func[0]
  }
  attr [0] "extern_scope" = 0 {
    v_concatenate_shape_func: Pointer(int64)[0] = 0i64
    v_concatenate_shape_func[0] = ((int64*)v_concatenate_shape_func[0] + (int64*)placeholder: Pointer(int64)[0])
    v_concatenate_shape_func[0] = ((int64*)v_concatenate_shape_func[0] + (int64*)v_expand_dim_shape_func[0])
    v_concatenate_shape_func[1] = (int64*)placeholder[1]
    assert(((int64*)v_concatenate_shape_func[1] == (int64*)v_expand_dim_shape_func[1]), "Dims mismatch in the inputs of concatenate.")
    0
  }
}


naughty IR:
attr [v_expand_dim_shape_func: Pointer(int64)] "storage_alignment" = 128 {
  let v_expand_dim_shape_func = @tir.TVMBackendAllocWorkspace(1, dev_id: int32, 16u64, 0, 64, dtype=handle)
   {
    if @tir.isnullptr(v_expand_dim_shape_func, dtype=bool) {
      @tir.tvm_throw_last_error(, dtype=int32)
    }
    attr [v_copy_shape_func: Pointer(int64)] "storage_scope" = "global";
    attr [v_copy_shape_func] "storage_alignment" = 128 {
      let v_copy_shape_func = @tir.TVMBackendAllocWorkspace(1, dev_id, 8u64, 0, 64, dtype=handle)
       {
        if @tir.isnullptr(v_copy_shape_func, dtype=bool) {
          @tir.tvm_throw_last_error(, dtype=int32)
        }
         {
          attr [0] "extern_scope" = 0 {
            v_expand_dim_shape_func[0] = 1i64
            v_expand_dim_shape_func[1] = (int64*)v_copy_shape_func[0]
          }
          attr [0] "extern_scope" = 0 {
            v_concatenate_shape_func: Pointer(int64)[0] = 0i64
            v_concatenate_shape_func[0] = ((int64*)v_concatenate_shape_func[0] + (int64*)placeholder: Pointer(int64)[0])
            v_concatenate_shape_func[0] = ((int64*)v_concatenate_shape_func[0] + (int64*)v_expand_dim_shape_func[0])
            v_concatenate_shape_func[1] = (int64*)placeholder[1]
            assert(((int64*)v_concatenate_shape_func[1] == (int64*)v_expand_dim_shape_func[1]), "Dims mismatch in the inputs of concatenate.")
            0
          }
        }
      }
      if (@tir.TVMBackendFreeWorkspace(1, dev_id, v_copy_shape_func, dtype=int32) != 0) {
        @tir.tvm_throw_last_error(, dtype=int32)
      }
    }
  }
  if (@tir.TVMBackendFreeWorkspace(1, dev_id, v_expand_dim_shape_func, dtype=int32) != 0) {
    @tir.tvm_throw_last_error(, dtype=int32)
  }
}
I have a feeling this could cause some non-trivial performance regressions.

Contributor:
It looks like this is fundamentally changing how shape functions get lowered; I don't think that just removing the assert is the right way to go about it.

Contributor:
It seems that the error comes from _concatenate_shape_func, where an Any is checked against a concrete shape.

The changes in this PR may introduce an Any shape, thus triggering this bug.

One quick fix is to remove the assert out[i] == inputs[j][i], "Dims mismatch in the inputs of concatenate.", since there may be an Any shape in the inputs, which should be allowed.

:/ It should not be possible to pass Any to a shape function; everything should be concrete in the VM before the shape_func is called. I'm not sure I buy this.

if (device_type_.defined()) {
  if (const auto* dev_type = device_type_.as<IntImmNode>()) {
    if (dev_type->value == kDLCPU) {
      int32_t constant_size = op->constant_allocation_size();
      if (constant_size > 0 && constant_size * nbytes < runtime::kMaxStackAlloca) {
        return stmt;
      }
    }
  }
}
PrimExpr total_bytes = make_const(op->extents[0].dtype(), nbytes);
for (size_t i = 0; i < op->extents.size(); ++i) {
  total_bytes = total_bytes * op->extents[i];
2 changes: 1 addition & 1 deletion vta/python/vta/autotvm.py
@@ -46,7 +46,7 @@ def reprogram_fpga(remote, _build_result):
_build_result : tvm.autotvm.measure.measure_methods.BuildResult
Artifact from the build phase, unused here.
"""
rpc_client.program_bitstream(remote, bitstream)
rpc_client.program_fpga(remote, bitstream)
rpc_client.reconfig_runtime(remote)

return default_module_loader(reprogram_fpga)
4 changes: 3 additions & 1 deletion vta/python/vta/environment.py
@@ -66,11 +66,13 @@ class DevContext(object):
MEM_ID_INP = 2
MEM_ID_ACC = 3
MEM_ID_OUT = 4
MEM_ID_ACC_8BIT = 5
# VTA ALU Opcodes
ALU_OPCODE_MIN = 0
ALU_OPCODE_MAX = 1
ALU_OPCODE_ADD = 2
ALU_OPCODE_SHR = 3
ALU_OPCODE_MUL = 4
# Task queue id (pipeline stage)
QID_LOAD_INP = 1
QID_LOAD_WGT = 1
@@ -232,7 +234,7 @@ def target_host(self):
return "llvm -mtriple=armv7-none-linux-gnueabihf"
if self.TARGET == "ultra96":
return "llvm -mtriple=aarch64-linux-gnu"
if self.TARGET in ["sim", "tsim"]:
if self.TARGET in ["sim", "tsim", "intelfocl"]:
return "llvm"
raise ValueError("Unknown target %s" % self.TARGET)

14 changes: 13 additions & 1 deletion vta/python/vta/program_bitstream.py
@@ -57,14 +57,26 @@ def de10nano_bitstream_program(bitstream_path):
program(bitstream_path)


def bitstream_program(target, bitstream):
def intelfocl_bitstream_program(bitstream_path, mem_size=4 * 1024 * 1024 * 1024):
# pylint: disable=import-outside-toplevel
from tvm import get_global_func

program = get_global_func("vta.oclfpga.program")
program(bitstream_path, mem_size)


def bitstream_program(target, bitstream, *args):
"""program bitstream to devices"""

if target in ["pynq", "ultra96"]:
pynq_bitstream_program(bitstream)
elif target in ["de10nano"]:
de10nano_bitstream_program(bitstream)
elif target in ["sim", "tsim"]:
# In simulation, bit stream programming is a no-op
return
elif target in ["intelfocl"]:
intelfocl_bitstream_program(bitstream, *args)
else:
raise RuntimeError("Unknown target {}".format(target))
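
For reference, a minimal usage sketch of the new intelfocl path (the bitstream path below is a made-up placeholder; the memory-size argument is optional and defaults to 4 GB):

from vta import program_bitstream

# Dispatches to intelfocl_bitstream_program, which looks up the
# "vta.oclfpga.program" global function and programs the device.
program_bitstream.bitstream_program("intelfocl", "/path/to/vta.aocx", 4 * 1024 * 1024 * 1024)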

14 changes: 10 additions & 4 deletions vta/python/vta/rpc_client.py
@@ -17,6 +17,8 @@
"""VTA RPC client function"""
import os

from tvm import rpc
from vta import program_bitstream
from .environment import get_env
from .bitstream import download_bitstream, get_bitstream_path

@@ -45,16 +47,20 @@ def program_fpga(remote, bitstream=None):
bitstream : str, optional
Path to a local bistream file. If unset, tries to download from cache server.
"""
env = get_env()

if bitstream:
assert os.path.isfile(bitstream)
else:
bitstream = get_bitstream_path()
if not os.path.isfile(bitstream):
env = get_env()
if env.TARGET == "de10nano":
return
download_bitstream()

fprogram = remote.get_function("tvm.contrib.vta.init")
remote.upload(bitstream)
fprogram(os.path.basename(bitstream))
if isinstance(remote, rpc.LocalSession):
program_bitstream.bitstream_program(env.TARGET, bitstream)
else:
fprogram = remote.get_function("tvm.contrib.vta.init")
remote.upload(bitstream)
fprogram(os.path.basename(bitstream))
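
A brief usage sketch of the updated client flow, assuming the intelfocl target and a local session (the bitstream path is a placeholder):

from tvm import rpc
from vta import rpc_client

# With a LocalSession, program_fpga now programs the bitstream locally via
# bitstream_program instead of calling the tvm.contrib.vta.init RPC function.
remote = rpc.LocalSession()
rpc_client.program_fpga(remote, bitstream="/path/to/vta.aocx")
rpc_client.reconfig_runtime(remote)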
8 changes: 7 additions & 1 deletion vta/python/vta/testing/simulator.py
@@ -27,7 +27,13 @@ def _load_sw():
"""Load hardware library for simulator."""

env = get_env()
lib_driver_name = "libvta_tsim" if env.TARGET == "tsim" else "libvta_fsim"
lib_driver_name = (
"libvta_tsim"
if env.TARGET == "tsim"
else "libvta"
if env.TARGET == "intelfocl"
else "libvta_fsim"
)
require_sim = env.TARGET in ("sim", "tsim")
libs = []

2 changes: 1 addition & 1 deletion vta/python/vta/testing/utils.py
@@ -32,7 +32,7 @@ def run(run_func):
"""
env = get_env()

if env.TARGET in ["sim", "tsim"]:
if env.TARGET in ["sim", "tsim", "intelfocl"]:
# Talk to local RPC if necessary to debug RPC server.
# Compile vta on your host with make at the root.
# Make sure TARGET is set to "sim" in the config.json file.
2 changes: 1 addition & 1 deletion vta/python/vta/top/graphpack.py
@@ -423,7 +423,7 @@ def visit_call(self, call):
self.start_pack and call.op == op.op.get("cast") and input_types[0].dtype == "int32"
):
cast = relay.Call(op.op.get("cast"), [args[0]], call.attrs)
return relay.Call(op.op.get("copy"), [cast])
return cast
elif call.op == self.pad:
pad_width = call.attrs.pad_width
if len(pad_width) == 6: