diff --git a/flash_attn/cute/utils.py b/flash_attn/cute/utils.py index c361e347949..2c5bc242a43 100644 --- a/flash_attn/cute/utils.py +++ b/flash_attn/cute/utils.py @@ -359,7 +359,14 @@ def elem_pointer_i64(x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None) -> flat_stride = cute.flatten_to_tuple(x.stride) assert len(flat_coord_i64) == len(flat_stride), "Coordinate and stride must have the same length" offset = sum(c * s for c, s in zip(flat_coord_i64, flat_stride)) - return x.iterator + offset + # HACK: we assume that applying the offset does not change the pointer alignment + byte_offset = offset * x.element_type.width // 8 + return cute.make_ptr( + x.element_type, + x.iterator.toint() + byte_offset, + x.memspace, + assumed_align=x.iterator.alignment, + ) @cute.jit