Skip to content

Commit cd55a0b

Browse files
committed
fix macro
1 parent 9621f28 commit cd55a0b

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

Diff for: csrc/gpu/fused_rotary_position_encoding.cu

+1-1
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ void FusedRotaryPositionEncoding(
103103

104104
dim3 grid(num_tokens);
105105
dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
106-
PD_DISPATCH_FLOATING_TYPES(
106+
PD_DISPATCH_FLOATING_AND_HALF_TYPES(
107107
query.dtype(), "apply_rotary_embedding_kernel", [&] {
108108
if (is_neox) {
109109
apply_rotary_embedding_kernel<data_t, true>

0 commit comments

Comments
 (0)