update folder name of DIOPI
xintian-514 committed Jul 6, 2023
2 parents e726f6d + 62ca03f commit 3dbbc1c
Showing 12 changed files with 874 additions and 17 deletions.
14 changes: 1 addition & 13 deletions README.md
@@ -68,7 +68,7 @@ This part of DIPU mainly deals with PyTorch's ``c10`` and ``c10d`` related interfaces

The ATen operators adapted in DIPU's C++ layer correspond to the lowest-level (*backend*-layer) operators in the dispatch process, or to operators in the *composite* layer that are effectively *backend* operators.

There is some flexibility here. Take the ``Linear`` operator as an example: on PyTorch's ``cpu/cuda`` devices it is implemented as a ``composite`` operator, and the actual *backend*-layer operators are ``addmm`` (or the lower-level ``mm``) called inside the composite operator. On the DIPU (``privateuse1``) device, a ``Linear`` operator (DIOPI provides this operator) is currently registered to replace the composite operator, so dispatch goes directly to the new *backend*-layer ``Linear`` operator instead of calling the original ``addmm/mm``. However, if the corresponding device's DIOPI/impl operator library implements the ``mm`` operator rather than ``diopiLinear``, the ``Linear`` call path still works.
There is some flexibility here. Take the ``Linear`` operator as an example: on PyTorch's ``cpu/cuda`` devices it is implemented as a ``composite`` operator, and the actual *backend*-layer operators are ``addmm`` (or the lower-level ``mm``) called inside the composite operator. On the DIPU (``privateuse1``) device, a ``Linear`` operator (DIOPI provides this operator) is currently registered to replace the composite operator, so dispatch goes directly to the new *backend*-layer ``Linear`` operator instead of calling the original ``addmm/mm``. However, if the corresponding device's DIOPI IMPL operator library implements the ``mm`` operator rather than ``diopiLinear``, the ``Linear`` call path still works.
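
As a rough illustration of the paragraph above, here is a minimal sketch (not the project's documented API: the ``"dipu"`` device string and the import side effects are assumptions) of how a linear call reaches the registered backend-level kernel instead of decomposing into ``addmm``/``mm``:

```python
# Minimal sketch, assuming a DIPU build where `import torch_dipu` registers the
# privateuse1 backend and "dipu" is its device string (both are assumptions here).
import torch
import torch_dipu  # noqa: F401  -- importing applies the DIPU registrations/patches

x = torch.randn(4, 8, device="dipu")
w = torch.randn(16, 8, device="dipu")
b = torch.randn(16, device="dipu")

# On cpu/cuda this call decomposes into addmm/mm; on the DIPU device it is
# dispatched to the registered backend-level Linear kernel (diopiLinear) when
# the vendor implements it, and can otherwise be routed through an mm-based path.
y = torch.nn.functional.linear(x, w, b)
print(y.shape)  # expected: torch.Size([4, 16])
```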

### Non-intrusive PyTorch extension package:
DIPU does not modify PyTorch's source code; instead it plugs the new device in out-of-tree. See the [reference documentation](https://pytorch.org/tutorials/advanced/extend_dispatcher.html).
@@ -93,18 +93,6 @@ This part of DIPU mainly deals with PyTorch's ``c10`` and ``c10d`` related interfaces

For more information, see [dipu/tests](https://github.com/DeepLink-org/dipu/tree/main/tests).


## Application industries
#### 1. For the hardware industry:

Decouples software from hardware and fundamentally removes ecosystem barriers. By adapting only the operators, a vendor can use the many capabilities of the latest PyTorch releases.

#### 2. For application industries:

Efficiently adapts mainstream frameworks to chips, greatly lowering the barrier to using compute power and unlocking demand for it.
Joins the broad Torch software ecosystem and its wealth of excellent algorithm frameworks, making model training easier. For example, with DIPU the ``mmcv`` family of models can run on different domestic devices without modifying the models; in principle, other models can also run on a variety of domestic hardware with minimal changes.


## Learn More

* [Usage / device integration tutorial](https://github.com/DeepLink-org/dipu/blob/main/SOP.md)
6 changes: 4 additions & 2 deletions scripts/autogen_diopi_wrapper/diopi_functions.yaml
@@ -1239,7 +1239,8 @@
size[0] = std::floor(self.size(-2) * scales_h.value_or(1.0));
size[1] = std::floor(self.size(-1) * scales_w.value_or(1.0));
}
interface: diopiUpsampleLinear(ctx, out, self, size, align_corners, "bilinear");
const char* mode = "bilinear";
interface: diopiUpsampleLinear(ctx, out, self, size, align_corners, mode);

- schema: "upsample_nearest2d_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!)"
size_attr: [size]
@@ -1265,7 +1266,8 @@
size[0] = std::floor((*(input_sizeVector.rbegin() + 1)) * scales_h.value_or(1.0));
size[1] = std::floor((*(input_sizeVector.rbegin())) * scales_w.value_or(1.0));
}
interface: diopiUpsampleLinearBackward(ctx, grad_input, grad_output, size, input_size, align_corners, "bilinear")
const char* mode = "bilinear";
interface: diopiUpsampleLinearBackward(ctx, grad_input, grad_output, size, input_size, align_corners, mode)
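
For context, a hedged Python-side sketch (not part of this diff) of how these autogenerated wrappers would typically be reached, assuming the bilinear upsample schemas above are registered for the DIPU device and ``"dipu"`` is the device string:

```python
# Hedged usage sketch: assumes the bilinear upsample ops route through the wrappers
# generated from this YAML, which call diopiUpsampleLinear / diopiUpsampleLinearBackward
# with mode "bilinear"; the "dipu" device string is an assumption.
import torch
import torch.nn.functional as F
import torch_dipu  # noqa: F401

x = torch.randn(1, 3, 16, 16, device="dipu", requires_grad=True)
y = F.interpolate(x, scale_factor=2.0, mode="bilinear", align_corners=False)
y.sum().backward()  # the backward pass goes through the *_backward wrapper above
print(y.shape, x.grad.shape)  # expected: (1, 3, 32, 32) and (1, 3, 16, 16)
```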

- schema: "sin(Tensor self) -> Tensor"
custom_code_at_the_beginning: |
5 changes: 4 additions & 1 deletion scripts/autogen_diopi_wrapper/diopi_wrapper_template.py
@@ -12,6 +12,7 @@
#include "csrc_dipu/aten/DIPUATenFunctions.h"
#include "csrc_dipu/aten/RegisterDIPU.hpp"
#include "csrc_dipu/diopirt/diopirt_impl.h"
#include "csrc_dipu/profiler/profiler.h"
#include "CustomFallbackFunctions.hpp"
$header_include_code
@@ -221,7 +222,9 @@
$custom_code_before_call_diopi
dipu::profile::RecordBlockCreator dipuRecorder("$diopi_fun_call_code");
::diopiError_t ret = $diopi_fun_call_code
dipuRecorder.end();
if (checkDiopiReturnValue()) {
TORCH_CHECK(ret == ::diopiSuccess, __FILE__, ":", __LINE__, R"($diopi_fun_call_code)", " error, error code is ", ret, "error message is ", diopiGetLastErrorString());
}
@@ -290,4 +293,4 @@ class $autograd_function_name : public torch::autograd::Function<$autograd_funct
$result_compare_code
}
"""
"""
8 changes: 8 additions & 0 deletions torch_dipu/__init__.py
@@ -22,6 +22,7 @@
from .dipu.device import _get_device_index
from .dipu.distributed import apply_dist_patch
from .dipu.tensor import apply_tensor_type_patch
from .profiler.profiler import dipu_profiler, dipu_kineto_available

def validate_dipu_device(location):
device = _get_device_index(location, True)
@@ -186,12 +187,19 @@ def _settitem_wrapper(self, indices: Union[None, _int, slice, Tensor, List, Tupl
torch.Tensor. __setitem__ = get_itemop_wrapper(torch.Tensor.__setitem__)


def apply_profiler_patch():
setattr(torch.profiler, 'kineto_available', dipu_kineto_available)
setattr(torch.autograd.profiler, 'kineto_available', dipu_kineto_available)
torch.profiler.profile = dipu_profiler


def apply_patches():
apply_tensor_method_patch()
apply_torch_function_patch()
apply_temp_patch()
apply_dist_patch()
apply_tensor_type_patch()
apply_profiler_patch()


apply_patches()
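
A hedged usage sketch of the patch above: after ``import torch_dipu``, ``torch.profiler.profile`` points at ``dipu_profiler`` and ``kineto_available`` is overridden, so existing profiling code keeps working unchanged. The exact arguments accepted by ``dipu_profiler`` are not shown in this diff, so the call below assumes it mirrors the stock ``torch.profiler.profile`` interface:

```python
# Hedged sketch: assumes dipu_profiler is a drop-in context manager for the stock
# torch.profiler.profile (its exact signature is not shown in this diff).
import torch
import torch_dipu  # importing runs apply_patches(), which includes apply_profiler_patch()

print(torch.profiler.kineto_available())  # now resolves to dipu_kineto_available

x = torch.randn(64, 64, device="dipu")    # "dipu" device string is an assumption
with torch.profiler.profile() as prof:    # torch.profiler.profile is now dipu_profiler
    for _ in range(10):
        x = x @ x
# Inspecting `prof` depends on what dipu_profiler returns, which this diff does not show.
```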
2 changes: 2 additions & 0 deletions torch_dipu/csrc_dipu/CMakeLists.txt
@@ -27,6 +27,7 @@ file(GLOB OP_SRC_FILES aten/RegisterDIPU.cpp
)

file(GLOB DIOPI_RT_FILES diopirt/*.cpp)
file(GLOB PROFILER_FILES profiler/*.cpp)

# vendor src
add_subdirectory(vendor/${UsedVendor})
@@ -44,6 +45,7 @@ set(SOURCE_FILES
${OP_SRC_FILES}
${DIOPI_RT_FILES}
${VENDOR_FILES}
${PROFILER_FILES}
)

add_library(${DIPU_LIB} SHARED ${SOURCE_FILES})
18 changes: 18 additions & 0 deletions torch_dipu/csrc_dipu/aten/RegisterDIPU.cpp
@@ -5,6 +5,7 @@
#include <c10/util/Exception.h>

#include <csrc_dipu/common.h>
#include <csrc_dipu/profiler/profiler.h>

static std::string force_fallback_operators_list = []()-> std::string {
std::ifstream stream(".dipu_force_fallback_op_list.config", std::ios_base::in | std::ios::binary);
@@ -82,58 +83,69 @@ namespace {
c10::optional<at::Layout> layout_opt,
c10::optional<at::Device> device_opt, c10::optional<bool> pin_memory_opt,
c10::optional<at::MemoryFormat> memory_format_opt) {
dipu::profile::RecordBlockCreator dipu_recorder(__FUNCTION__);
const DeviceGuard device_guard(device_or_default(device_opt));
return dnative::empty(size, dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt);
}

at::Tensor wrapper_empty_strided(at::IntArrayRef size, at::IntArrayRef stride, c10::optional<at::ScalarType> dtype_opt,
c10::optional<at::Layout> layout_opt, c10::optional<at::Device> device_opt, c10::optional<bool> pin_memory_opt) {
dipu::profile::RecordBlockCreator dipu_recorder(__FUNCTION__);
const DeviceGuard device_guard(device_or_default(device_opt));
return dnative::empty_strided(size, stride, dtype_opt, layout_opt, device_opt, pin_memory_opt);
}

at::Tensor& wrapper_copy_(at::Tensor& self, const at::Tensor& src, bool non_blocking) {
dipu::profile::RecordBlockCreator dipu_recorder(__FUNCTION__);
return dnative::copy_(self, src, non_blocking);
}

at::Tensor wrapper_DIPU___reshape_alias(const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride) {
dipu::profile::RecordBlockCreator dipu_recorder(__FUNCTION__);
return at::native::_reshape_alias(self, C10_AS_INTARRAYREF_SLOW(size), C10_AS_INTARRAYREF_SLOW(stride));
}

// only used by cpu_fallback.
at::Tensor wrapper_DIPU___copy_from_and_resize(const at::Tensor & self, const at::Tensor& dst) {
dipu::profile::RecordBlockCreator dipu_recorder(__FUNCTION__);
dst.resize_as_(self).copy_(self);
return dst;
}

const at::Tensor& wrapper_resize_(const at::Tensor& self, at::IntArrayRef size, c10::optional<at::MemoryFormat> memory_format) {
// add guard for device switch.
dipu::profile::RecordBlockCreator dipu_recorder(__FUNCTION__);
return dnative::resize_(self, size, memory_format);
}

at::Tensor wrapper_DIPU__as_strided(const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, c10::optional<c10::SymInt> storage_offset) {
// No device check
// DeviceGuard omitted
dipu::profile::RecordBlockCreator dipu_recorder(__FUNCTION__);
return at::native::as_strided_tensorimpl(self, C10_AS_INTARRAYREF_SLOW(size), C10_AS_INTARRAYREF_SLOW(stride), storage_offset.has_value() ? c10::make_optional(storage_offset->expect_int()) : c10::nullopt);
}

at::Tensor wrapper_DIPU__view(const at::Tensor & self, c10::SymIntArrayRef size) {
// No device check
// DeviceGuard omitted
dipu::profile::RecordBlockCreator dipu_recorder(__FUNCTION__);
return at::native::view(self, C10_AS_INTARRAYREF_SLOW(size));
}

at::Tensor wrapper_DIPU__view_as_real(const at::Tensor & self) {
// DeviceGuard omitted
dipu::profile::RecordBlockCreator dipu_recorder(__FUNCTION__);
return at::native::view_as_real(self);
}

at::Tensor wrapper_DIPU__view_as_complex(const at::Tensor & self) {
// DeviceGuard omitted
dipu::profile::RecordBlockCreator dipu_recorder(__FUNCTION__);
return at::native::view_as_complex(self);
}

at::Tensor & wrapper_DIPU__zero_(at::Tensor & self) {
dipu::profile::RecordBlockCreator dipu_recorder(__FUNCTION__);
const OptionalDeviceGuard device_guard(device_of(self));
return at::native::zero_(self);
}
@@ -143,15 +155,18 @@ namespace {
at::Tensor wrapper_DIPU__unfold(const at::Tensor & self, int64_t dimension, int64_t size, int64_t step) {
// No device check
// DeviceGuard omitted
dipu::profile::RecordBlockCreator dipu_recorder(__FUNCTION__);
return at::native::unfold(self, dimension, size, step);
}

at::Scalar wrapper_DIPU___local_scalar_dense(const at::Tensor & self) {
dipu::profile::RecordBlockCreator dipu_recorder(__FUNCTION__);
const OptionalDeviceGuard device_guard(device_of(self));
return dnative::_local_scalar_dense_dipu(self);
}

bool wrapper_BackendSelect_is_pinned(const at::Tensor& self, c10::optional<at::Device> device) {
dipu::profile::RecordBlockCreator dipu_recorder(__FUNCTION__);
// Only CPU tensors can be pinned
if (!self.is_cpu()) {
return false;
@@ -162,17 +177,20 @@ namespace {
}

at::Tensor wrapper_BackendSelect__pin_memory(const at::Tensor& self, c10::optional<at::Device> device) {
dipu::profile::RecordBlockCreator dipu_recorder(__FUNCTION__);
TORCH_CHECK(self.device().is_cpu(), "cannot pin '", self.toString(), "' only dense CPU tensors can be pinned");
c10::DispatchKeySet dk = c10::DispatchKeySet(c10::computeDispatchKey(c10::nullopt, self.layout(), device.value_or(dipu::DIPU_DEVICE_TYPE)));
return at::_ops::_pin_memory::redispatch(dk, self, device);
}

bool wrapper_DIPU_is_pinned(const at::Tensor& self, c10::optional<at::Device> device) {
dipu::profile::RecordBlockCreator dipu_recorder(__FUNCTION__);
const OptionalDeviceGuard device_guard(device_of(self));
return dnative::is_pinned(self, device);
}

at::Tensor wrapper_DIPU__pin_memory(const at::Tensor& self, c10::optional<at::Device> device) {
dipu::profile::RecordBlockCreator dipu_recorder(__FUNCTION__);
const OptionalDeviceGuard device_guard(device_of(self));
return dnative::_pin_memory(self, device);
}
28 changes: 28 additions & 0 deletions torch_dipu/csrc_dipu/binding/ExportProfiler.cpp
@@ -0,0 +1,28 @@

// Copyright (c) 2023, DeepLink.
#include <torch/csrc/utils/pybind.h>
#include <pybind11/chrono.h>

#include "exportapi.h"
#include <csrc_dipu/profiler/profiler.h>

namespace py = pybind11;

namespace dipu {

void exportProfiler(PyObject* module) {
auto m = py::handle(module).cast<py::module>();

m.def("profile_start", &dipu::profile::startProfile);
m.def("profile_end", &dipu::profile::endProfile);
m.def("profiler_flush", &dipu::profile::FlushAllRecords);
py::class_<dipu::profile::Record>(m, "_DIPUProfilerRecord")
.def_readonly("name", &dipu::profile::Record::name)
.def_readonly("opid", &dipu::profile::Record::opId)
.def_readonly("begin", &dipu::profile::Record::begin)
.def_readonly("end", &dipu::profile::Record::end)
.def_readonly("thread_idx", &dipu::profile::Record::threadIdx);
m.def("get_record", &dipu::profile::getRecordList);
}

} // namespace dipu
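
The bindings above expose a record-level profiling API to Python. A minimal, hedged sketch of driving it directly, assuming the functions are bound onto torch_dipu's C extension module (the exact module path is not shown in this diff):

```python
# Hedged sketch: `_C` as the extension module name is hypothetical; only the function
# and attribute names come from the pybind11 bindings added in this commit.
import torch
import torch_dipu
from torch_dipu import _C  # hypothetical location of the exported symbols

_C.profile_start()
x = torch.randn(32, 32, device="dipu")  # "dipu" device string is an assumption
y = (x @ x).sum()
_C.profile_end()
_C.profiler_flush()

for rec in _C.get_record():  # iterable of _DIPUProfilerRecord
    print(rec.name, rec.opid, rec.begin, rec.end, rec.thread_idx)
```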
3 changes: 2 additions & 1 deletion torch_dipu/csrc_dipu/binding/exportapi.h
@@ -7,4 +7,5 @@
namespace dipu {
DIPU_API PyMethodDef* exportTensorFunctions();
DIPU_API void exportDIPURuntime(PyObject* module);
}
DIPU_API void exportProfiler(PyObject* module);
} // namespace dipu