Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ None
- Support for NHWGC (2D and 3D) grouped convolution backward weight (#769 #804)
- Support for bf16/f32/f16 and NHWGC (2D and 3D) grouped convolution backward data (#757 #799)
- Support for Batched Gemm DL (#732)
- Introduce wrapper sublibrary (limited functionality) (#1071)

### Changes
- Changed the grouped convolution API to maintain consistency with other convolution kernels (#817)
Expand Down
2 changes: 0 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -420,8 +420,6 @@ IF(IS_DIRECTORY "${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu
ENDIF()
ENDFOREACH()

add_subdirectory(include/ck)

add_custom_target(instances DEPENDS utility;${CK_DEVICE_INSTANCES} SOURCES ${INSTANCE_FILES})
add_subdirectory(library)

Expand Down
4 changes: 2 additions & 2 deletions client_example/25_tensor_transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
add_executable(client_tensor_transform tensor_transform.cpp)
target_include_directories(client_tensor_transform INTERFACE composable_kernel::wrapper)
target_link_libraries(client_tensor_transform PRIVATE composable_kernel::device_other_operations)
add_executable(client_tensor_transform_using_wrapper tensor_transform_using_wrapper.cpp)
target_include_directories(client_tensor_transform_using_wrapper INTERFACE composable_kernel::wrapper)
target_link_libraries(client_tensor_transform_using_wrapper PRIVATE composable_kernel::device_other_operations)
2 changes: 1 addition & 1 deletion docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ The current CK library is structured into 4 layers:
* "Templated Tile Operators" layer
* "Templated Kernel and Invoker" layer
* "Instantiated Kernel and Invoker" layer
* "Wrapper for tensor tranforms operations"
* "Wrapper for tensor transform operations"
* "Client API" layer

.. image:: data/ck_layer.png
Expand Down
8 changes: 5 additions & 3 deletions docs/wrapper.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@ Wrapper
-------------------------------------
Description
-------------------------------------
Note: Wrapper is currently under development. At the moment, its functionality
is limited.

.. note::

The wrapper is under development and its functionality is limited.


CK provides a lightweight wrapper for more complex operations implemented in
the library. It allows indexing of nested layouts using a simple interface
Expand All @@ -33,7 +36,6 @@ Example:
Output::

dims:4,(2,4) strides:2,(1,8)
Print2d
0 1 8 9 16 17 24 25
2 3 10 11 18 19 26 27
4 5 12 13 20 21 28 29
Expand Down
1 change: 0 additions & 1 deletion include/ck/CMakeLists.txt

This file was deleted.

9 changes: 0 additions & 9 deletions include/ck/wrapper/CMakeLists.txt

This file was deleted.

13 changes: 6 additions & 7 deletions include/ck/wrapper/layout.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ struct Layout
// Iterate over shape tuple elements:
// 1. If the corresponding idx element is a tuple, return it (it will be unrolled).
// 2. If not, pack it in a tuple. It will be restored during unroll.
auto unrolled_shape_via_idx = generate_tuple(
auto aligned_shape = generate_tuple(
[&](auto i) {
if constexpr(is_detected<is_tuple,
tuple_element_t<i, Tuple<IdxDims...>>>::value)
Expand All @@ -124,7 +124,7 @@ struct Layout
Number<Tuple<IdxDims...>::Size()>{});

// Unroll and process next step
return AlignShapeToIdx(UnrollNestedTuple<0, 1>(unrolled_shape_via_idx),
return AlignShapeToIdx(UnrollNestedTuple<0, 1>(aligned_shape),
UnrollNestedTuple<0, 1>(idx));
}
}
Expand All @@ -144,7 +144,7 @@ struct Layout
desc, make_tuple(make_merge_transform(merge_elems)), lower_dims, upper_dims);
}

// Merge nested shape dims
// Merge nested shape dims. Merge nested shape dims when idx is also nested.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// Merge nested shape dims. Merge nested shape dims when idx is also nested.
// Merge nested shape dims when corresponding index is also nested.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will be fixed in next PR, thanks

// Input desc shape: 2, 2, 2, 2, 2, 2
// Example idx: 1, 1, 1, 1
// Example shape: 2, (2, 2), 2, (2, 2)
Expand Down Expand Up @@ -205,10 +205,9 @@ struct Layout
static_assert(Tuple<ShapeDims...>::Size() == Tuple<IdxDims...>::Size(),
"Idx rank and Shape rank must be the same (except 1d).");
// Unroll while IdxDims is nested
const auto unrolled_shape_via_idx = AlignShapeToIdx(shape, idx);
const auto aligned_shape = AlignShapeToIdx(shape, idx);
// Transform the descriptor using the aligned form of the shape
return CreateMergedDescriptor(
unrolled_shape_via_idx, UnrollNestedTuple(idx), descriptor_);
return CreateMergedDescriptor(aligned_shape, UnrollNestedTuple(idx), descriptor_);
}
}

Expand All @@ -224,7 +223,7 @@ struct Layout
}

public:
// If stride not passed, deduce from GenerateColumnMajorPackedStrides
// If the stride is not passed, you can infer it from `GenerateColumnMajorPackedStrides`.
using DeducedStrides =
std::conditional_t<is_same_v<Strides, Tuple<>>,
remove_cvref_t<decltype(GenerateColumnMajorPackedStrides(Shape{}))>,
Expand Down
2 changes: 1 addition & 1 deletion test/wrapper/test_layout.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ TEST(TestLayoutHelpers, SizeAndGet)
EXPECT_EQ(ck::wrapper::size<1>(layout_runtime), d1 * d0);
EXPECT_EQ(ck::wrapper::size<1>(layout_compiletime), d1 * d0);

// Acces via new layout (using get on layout)
// Access through new layout (using get with layout object)
EXPECT_EQ(ck::wrapper::size<0>(ck::wrapper::get<0>(layout_runtime)), d4 * d3);
EXPECT_EQ(ck::wrapper::size<0>(ck::wrapper::get<0>(layout_compiletime)), d4 * d3);
EXPECT_EQ(ck::wrapper::size<1>(ck::wrapper::get<0>(layout_runtime)), d2);
Expand Down