Skip to content

Commit 6530815

Browse files
authored
[BACKEND] Add aipu(arm npu) backend (#9)
Merge pull request #9 from FlagTree/armnpu
2 parents ec33e82 + 160dd86 commit 6530815

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

53 files changed

+3302
-67
lines changed
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
name: AIPU-Build-And-Test
2+
3+
on:
4+
push:
5+
branches: [ "triton_v3.3.x" ]
6+
pull_request:
7+
branches: [ "triton_v3.3.x" ]
8+
9+
concurrency:
10+
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
11+
cancel-in-progress: true
12+
13+
jobs:
14+
aipu-build-and-test:
15+
runs-on: aipu
16+
steps:
17+
- name: Checkout code (attempt 1)
18+
id: checkout1
19+
uses: actions/checkout@v4
20+
continue-on-error: true
21+
22+
- name: Sleep before checkout2
23+
if: steps.checkout1.outcome == 'failure'
24+
run: |
25+
echo "First checkout attempt failed. Sleeping for 120 seconds before retry..."
26+
sleep 120
27+
28+
- name: Checkout code (attempt 2)
29+
id: checkout2
30+
if: steps.checkout1.outcome == 'failure'
31+
uses: actions/checkout@v4
32+
continue-on-error: true
33+
34+
- name: Sleep before final checkout
35+
if: steps.checkout1.outcome == 'failure' && steps.checkout2.outcome == 'failure'
36+
run: |
37+
echo "Second checkout attempt failed. Sleeping for 180 seconds before final retry..."
38+
sleep 180
39+
40+
- name: Checkout code (final attempt)
41+
if: steps.checkout1.outcome == 'failure' && steps.checkout2.outcome == 'failure'
42+
uses: actions/checkout@v4
43+
44+
- name: Verify checkout success
45+
if: success()
46+
run: echo "Checkout completed successfully"
47+
48+
- name: FlagTree Build on AIPU
49+
shell: bash
50+
run: |
51+
source ~/env.sh
52+
source ~/env_setup.sh
53+
export FLAGTREE_BACKEND=aipu
54+
cd python
55+
python3.10 -m pip install . --no-build-isolation -v
56+
57+
- name: FlagTree Test on AIPU
58+
shell: bash
59+
run: |
60+
source ~/env_setup.sh
61+
python3.10 third_party/aipu/python/test/test_01_vector_add.py
62+
python3.10 third_party/aipu/python/test/test_02_fused_softmax.py

.github/workflows/code-format-check-master.yml

Lines changed: 0 additions & 21 deletions
This file was deleted.

.github/workflows/code-format-check.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@ on:
44
schedule:
55
- cron: '0 21 * * *'
66
push:
7-
branches: [ "main" ]
7+
branches: [ "main", "triton_v3.3.x" ]
88
pull_request:
9-
branches: [ "main" ]
9+
branches: [ "main", "triton_v3.3.x" ]
1010

1111
concurrency:
1212
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}

.github/workflows/iluvatar-build-and-test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ jobs:
5151
export FLAGTREE_BACKEND=iluvatar
5252
source ~/env.sh
5353
cd python
54-
MAX_JOBS=20 pip3 install . --no-build-isolation
54+
MAX_JOBS=32 pip3 install . --no-build-isolation
5555
5656
- name: FlagTree Test on Iluvatar
5757
shell: bash

.github/workflows/metax-build-and-test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ jobs:
2020
source ~/env.sh
2121
export FLAGTREE_BACKEND=metax
2222
cd python
23-
MAX_JOBS=20 pip3 install . --no-build-isolation
23+
MAX_JOBS=32 pip3 install . --no-build-isolation
2424
2525
- name: FlagTree Test on Metax
2626
shell: bash

.github/workflows/mthreads-build-and-test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ jobs:
2020
source ~/env.sh
2121
export FLAGTREE_BACKEND=mthreads
2222
cd python
23-
MAX_JOBS=20 pip3 install . --no-build-isolation
23+
MAX_JOBS=32 pip3 install . --no-build-isolation
2424
2525
- name: FlagTree Test on Mthreads
2626
shell: bash

.github/workflows/nv-build-and-test.yml

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@ on:
44
schedule:
55
- cron: '0 21 * * *'
66
push:
7-
branches: [ "main" ]
7+
branches: [ "main", "triton_v3.3.x" ]
88
pull_request:
9-
branches: [ "main" ]
9+
branches: [ "main", "triton_v3.3.x" ]
1010

1111
concurrency:
1212
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
@@ -19,14 +19,34 @@ jobs:
1919
- name: Checkout code
2020
uses: actions/checkout@v4
2121

22-
- name: FlagTree Build on NVIDIA-A100
22+
- name: Detect Target Branch
23+
shell: bash
24+
run: |
25+
if [ "${{ github.event_name }}" = "pull_request" ]; then
26+
TARGET_BRANCH="${{ github.base_ref }}"
27+
else
28+
TARGET_BRANCH="${{ github.ref_name }}"
29+
fi
30+
echo "TARGET_BRANCH=$TARGET_BRANCH" >> $GITHUB_ENV
31+
echo "TARGET_BRANCH=$TARGET_BRANCH"
32+
33+
- name: FlagTree Build (Main branch)
34+
if: ${{ env.TARGET_BRANCH == 'main' }}
2335
shell: bash
2436
run: |
2537
source ~/env.sh
2638
cd python
27-
MAX_JOBS=20 pip3.11 install . --no-build-isolation
39+
MAX_JOBS=32 pip3.11 install . --no-build-isolation
40+
41+
- name: FlagTree Build (triton_v3.3.x branch)
42+
if: ${{ env.TARGET_BRANCH == 'triton_v3.3.x' }}
43+
shell: bash
44+
run: |
45+
source ~/env-3.3.sh
46+
cd python
47+
MAX_JOBS=32 pip3.11 install . --no-build-isolation
2848
29-
- name: FlagTree Test on NVIDIA-A100
49+
- name: FlagTree Test
3050
shell: bash
3151
run: |
3252
pytest -s python/test/unit

CMakeLists.txt

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ elseif(FLAGTREE_BACKEND STREQUAL "mthreads")
2626
set(CMAKE_C_COMPILER clang)
2727
set(CMAKE_CXX_COMPILER clang++)
2828
set(ENV{FLAGTREE_PLUGIN} $ENV{FLAGTREE_BACKEND})
29+
elseif(FLAGTREE_BACKEND STREQUAL "aipu")
30+
add_definitions(-D__NVIDIA__)
31+
add_definitions(-D__AMD__)
2932
endif()
3033
set(FLAGTREE_PLUGIN "$ENV{FLAGTREE_PLUGIN}")
3134
if(FLAGTREE_PLUGIN)
@@ -201,7 +204,7 @@ include_directories(${PROJECT_SOURCE_DIR}/third_party)
201204
include_directories(${PROJECT_BINARY_DIR}/third_party) # Tablegen'd files
202205

203206
# link_directories(${LLVM_LIBRARY_DIR})
204-
if (FLAGTREE_BACKEND STREQUAL "cambricon")
207+
if (FLAGTREE_BACKEND MATCHES "^(cambricon|aipu)$")
205208
include_directories(${PROJECT_SOURCE_DIR}/include)
206209
include_directories(${PROJECT_BINARY_DIR}/include) # Tablegen'd files
207210
add_subdirectory(include)
@@ -263,10 +266,10 @@ if(TRITON_BUILD_PYTHON_MODULE)
263266
if (TRITON_BUILD_PROTON)
264267
add_definitions(-D__PROTON__)
265268
add_subdirectory(third_party/proton)
266-
# We always build proton dialect
267-
list(APPEND TRITON_PLUGIN_NAMES "proton")
268-
add_subdirectory(third_party/proton/dialect)
269269
endif()
270+
# We always build proton dialect
271+
list(APPEND TRITON_PLUGIN_NAMES "proton")
272+
add_subdirectory(third_party/proton/dialect)
270273

271274
get_property(triton_libs GLOBAL PROPERTY TRITON_LIBS)
272275
get_property(triton_plugins GLOBAL PROPERTY TRITON_PLUGINS)
@@ -443,7 +446,7 @@ find_package(Threads REQUIRED)
443446

444447
add_subdirectory(third_party/f2reduce)
445448

446-
if(NOT FLAGTREE_BACKEND)
449+
if(NOT FLAGTREE_BACKEND OR FLAGTREE_BACKEND STREQUAL "aipu")
447450
add_subdirectory(bin)
448451
add_subdirectory(test)
449452
endif()

include/triton/Dialect/Triton/IR/TritonAttrDefs.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#define TRITON_ATTR_DEFS
33

44
include "mlir/IR/EnumAttr.td"
5+
include "mlir/IR/AttrTypeBase.td"
56

67
// Attributes for LoadOp and StoreOp
78
def TT_CacheModifierAttr : I32EnumAttr<

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ include "mlir/Interfaces/ControlFlowInterfaces.td" // BranchOpInterface
1414
include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
1515
include "mlir/Interfaces/CallInterfaces.td" // CallOpInterface
1616
include "triton/Dialect/Triton/IR/TritonOpInterfaces.td"
17+
include "mlir/IR/BuiltinAttributes.td"
1718

1819

1920
//
@@ -248,13 +249,33 @@ def TT_LoadOp : TT_Op<"load", [
248249
OptionalAttr<TT_PaddingOptionAttr>:$padding,
249250
DefaultValuedAttr<TT_CacheModifierAttr, "::mlir::triton::CacheModifier::NONE">:$cache,
250251
DefaultValuedAttr<TT_EvictionPolicyAttr, "::mlir::triton::EvictionPolicy::NORMAL">:$evict,
251-
DefaultValuedAttr<BoolAttr, "false">:$isVolatile
252+
DefaultValuedAttr<BoolAttr, "false">:$isVolatile,
253+
// TODO: now flagtree_hints is string, default value of an empty string (""), needed redesign
254+
DefaultValuedAttr<StrAttr, "\"\"">:$flagtree_hints
252255
);
253256

254257
let results = (outs TT_Type:$result);
255258

256259
let builders = [
257260
// A tensor of pointers or a pointer to a scalar
261+
OpBuilder<(ins "Value":$ptr, "triton::CacheModifier":$cache,
262+
"triton::EvictionPolicy":$evict, "bool":$isVolatile, "mlir::StringAttr":$flagtree_hints)>,
263+
// A tensor pointer with boundary check and padding
264+
OpBuilder<(ins "Value":$ptr, "ArrayRef<int32_t>":$boundaryCheck,
265+
"std::optional<triton::PaddingOption>":$padding, "triton::CacheModifier":$cache,
266+
"triton::EvictionPolicy":$evict, "bool":$isVolatile, "mlir::StringAttr":$flagtree_hints)>,
267+
// A tensor of pointers or a pointer to a scalar with mask
268+
OpBuilder<(ins "Value":$ptr, "Value":$mask, "triton::CacheModifier":$cache,
269+
"triton::EvictionPolicy":$evict, "bool":$isVolatile, "mlir::StringAttr":$flagtree_hints)>,
270+
// A tensor of pointers or a pointer to a scalar with mask and other
271+
OpBuilder<(ins "Value":$ptr, "Value":$mask, "Value":$other, "triton::CacheModifier":$cache,
272+
"triton::EvictionPolicy":$evict, "bool":$isVolatile, "mlir::StringAttr":$flagtree_hints)>,
273+
// A utility function to build the operation with all attributes
274+
OpBuilder<(ins "Value":$ptr, "Value":$mask, "Value":$other,
275+
"ArrayRef<int32_t>":$boundaryCheck,
276+
"std::optional<triton::PaddingOption>":$padding, "triton::CacheModifier":$cache,
277+
"triton::EvictionPolicy":$evict, "bool":$isVolatile, "mlir::StringAttr":$flagtree_hints)>,
278+
// A tensor of pointers or a pointer to a scalar
258279
OpBuilder<(ins "Value":$ptr, "triton::CacheModifier":$cache,
259280
"triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
260281
// A tensor pointer with boundary check and padding

0 commit comments

Comments
 (0)