Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 70 additions & 15 deletions mypyc/irbuild/specialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,12 @@
tokenizer_format_call,
)
from mypyc.primitives.bytearray_ops import isinstance_bytearray
from mypyc.primitives.bytes_ops import isinstance_bytes
from mypyc.primitives.bytes_ops import (
bytes_adjust_index_op,
bytes_get_item_unsafe_op,
bytes_range_check_op,
isinstance_bytes,
)
from mypyc.primitives.dict_ops import (
dict_items_op,
dict_keys_op,
Expand Down Expand Up @@ -1207,30 +1212,50 @@ def translate_object_setattr(builder: IRBuilder, expr: CallExpr, callee: RefExpr
return builder.call_c(generic_setattr, [self_reg, name_reg, value], expr.line)


@specialize_dunder("__getitem__", bytes_writer_rprimitive)
def translate_bytes_writer_get_item(
builder: IRBuilder, base_expr: Expression, args: list[Expression], ctx_expr: Expression
def translate_getitem_with_bounds_check(
builder: IRBuilder,
base_expr: Expression,
args: list[Expression],
ctx_expr: Expression,
adjust_index_op: PrimitiveDescription,
range_check_op: PrimitiveDescription,
get_item_unsafe_op: PrimitiveDescription,
) -> Value | None:
"""Optimized BytesWriter.__getitem__ implementation with bounds checking."""
"""Shared helper for optimized __getitem__ with bounds checking.

This implements the common pattern of:
1. Adjusting negative indices
2. Checking if index is in valid range
3. Raising IndexError if out of range
4. Getting the item if in range

Args:
builder: The IR builder
base_expr: The base object expression
args: The arguments to __getitem__ (should be length 1)
ctx_expr: The context expression for line numbers
adjust_index_op: Primitive op to adjust negative indices
range_check_op: Primitive op to check if index is in valid range
get_item_unsafe_op: Primitive op to get item (no bounds checking)

Returns:
The result value, or None if optimization doesn't apply
"""
# Check that we have exactly one argument
if len(args) != 1:
return None

# Get the BytesWriter object
# Get the object
obj = builder.accept(base_expr)

# Get the index argument
index = builder.accept(args[0])

# Adjust the index (handle negative indices)
adjusted_index = builder.primitive_op(
bytes_writer_adjust_index_op, [obj, index], ctx_expr.line
)
adjusted_index = builder.primitive_op(adjust_index_op, [obj, index], ctx_expr.line)

# Check if the adjusted index is in valid range
range_check = builder.primitive_op(
bytes_writer_range_check_op, [obj, adjusted_index], ctx_expr.line
)
range_check = builder.primitive_op(range_check_op, [obj, adjusted_index], ctx_expr.line)

# Create blocks for branching
valid_block = BasicBlock()
Expand All @@ -1247,13 +1272,27 @@ def translate_bytes_writer_get_item(

# Handle valid index - get the item
builder.activate_block(valid_block)
result = builder.primitive_op(
bytes_writer_get_item_unsafe_op, [obj, adjusted_index], ctx_expr.line
)
result = builder.primitive_op(get_item_unsafe_op, [obj, adjusted_index], ctx_expr.line)

return result


@specialize_dunder("__getitem__", bytes_writer_rprimitive)
def translate_bytes_writer_get_item(
    builder: IRBuilder, base_expr: Expression, args: list[Expression], ctx_expr: Expression
) -> Value | None:
    """Specialize BytesWriter.__getitem__ into a bounds-checked primitive sequence.

    All of the work is delegated to translate_getitem_with_bounds_check; this
    wrapper only supplies the BytesWriter primitives for index adjustment,
    range checking and unchecked item access.

    Returns whatever the shared helper returns, which is None when the
    specialization does not apply (e.g. the call does not have exactly one
    argument).
    """
    return translate_getitem_with_bounds_check(
        builder,
        base_expr,
        args,
        ctx_expr,
        adjust_index_op=bytes_writer_adjust_index_op,
        range_check_op=bytes_writer_range_check_op,
        get_item_unsafe_op=bytes_writer_get_item_unsafe_op,
    )


@specialize_dunder("__setitem__", bytes_writer_rprimitive)
def translate_bytes_writer_set_item(
builder: IRBuilder, base_expr: Expression, args: list[Expression], ctx_expr: Expression
Expand Down Expand Up @@ -1300,3 +1339,19 @@ def translate_bytes_writer_set_item(
)

return builder.none()


@specialize_dunder("__getitem__", bytes_rprimitive)
def translate_bytes_get_item(
    builder: IRBuilder, base_expr: Expression, args: list[Expression], ctx_expr: Expression
) -> Value | None:
    """Specialize bytes.__getitem__ into a bounds-checked primitive sequence.

    All of the work is delegated to translate_getitem_with_bounds_check; this
    wrapper only supplies the bytes primitives for index adjustment, range
    checking and unchecked item access.

    Returns whatever the shared helper returns, which is None when the
    specialization does not apply (e.g. the call does not have exactly one
    argument).
    """
    return translate_getitem_with_bounds_check(
        builder,
        base_expr,
        args,
        ctx_expr,
        adjust_index_op=bytes_adjust_index_op,
        range_check_op=bytes_range_check_op,
        get_item_unsafe_op=bytes_get_item_unsafe_op,
    )
21 changes: 21 additions & 0 deletions mypyc/lib-rt/bytes_extra_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,30 @@
#define MYPYC_BYTES_EXTRA_OPS_H

#include <Python.h>
#include <stdint.h>
#include "CPy.h"

// Optimized bytes translate operation
PyObject *CPyBytes_Translate(PyObject *bytes, PyObject *table);

// Optimized bytes.__getitem__ operations

// Normalize a possibly negative index by adding the object's size to it.
// Non-negative indices are returned unchanged. No range checking is done
// here, so the result may still be out of bounds.
static inline int64_t CPyBytes_AdjustIndex(PyObject *obj, int64_t index) {
    return index < 0 ? index + Py_SIZE(obj) : index;
}

// Return true iff index lies in the half-open range [0, Py_SIZE(obj)).
// The index is expected to have been adjusted already; a negative value is
// simply reported as out of range.
static inline bool CPyBytes_RangeCheck(PyObject *obj, int64_t index) {
    if (index < 0) {
        return false;
    }
    return index < Py_SIZE(obj);
}

// Read the byte at index without any bounds checking; the caller must have
// validated the index first. The byte is read as unsigned (0..255) and
// returned as a CPyTagged, i.e. the value shifted left by one bit.
static inline CPyTagged CPyBytes_GetItemUnsafe(PyObject *obj, int64_t index) {
    const char *data = PyBytes_AS_STRING(obj);
    uint8_t byte = (uint8_t)data[index];
    return (CPyTagged)byte << 1;
}

#endif
37 changes: 37 additions & 0 deletions mypyc/primitives/bytes_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
c_int_rprimitive,
c_pyssize_t_rprimitive,
dict_rprimitive,
int64_rprimitive,
int_rprimitive,
list_rprimitive,
object_rprimitive,
Expand All @@ -21,6 +22,7 @@
ERR_NEG_INT,
binary_op,
custom_op,
custom_primitive_op,
function_op,
load_address_op,
method_op,
Expand Down Expand Up @@ -148,3 +150,38 @@
c_function_name="CPyBytes_Ord",
error_kind=ERR_MAGIC,
)

# Optimized bytes.__getitem__ operations

# bytes index adjustment - convert negative index to positive.
# Wraps the C helper CPyBytes_AdjustIndex: takes a bytes object and an int64
# index and yields an int64. The helper performs no range checking, so the
# result still has to be validated separately before it is used.
# NOTE(review): ERR_NEVER presumably means the op cannot raise - confirm.
bytes_adjust_index_op = custom_primitive_op(
name="bytes_adjust_index",
arg_types=[bytes_rprimitive, int64_rprimitive],
return_type=int64_rprimitive,
c_function_name="CPyBytes_AdjustIndex",
error_kind=ERR_NEVER,
experimental=True,
dependencies=[BYTES_EXTRA_OPS],
)

# bytes range check - check if index is in valid range.
# Wraps the C helper CPyBytes_RangeCheck: takes a bytes object and an int64
# index and yields a bool. The index is expected to be already adjusted for
# negative values; the op itself does not adjust it.
# NOTE(review): ERR_NEVER presumably means the op cannot raise - confirm.
bytes_range_check_op = custom_primitive_op(
name="bytes_range_check",
arg_types=[bytes_rprimitive, int64_rprimitive],
return_type=bool_rprimitive,
c_function_name="CPyBytes_RangeCheck",
error_kind=ERR_NEVER,
experimental=True,
dependencies=[BYTES_EXTRA_OPS],
)

# bytes.__getitem__() - get byte at index (no bounds checking).
# Wraps the C helper CPyBytes_GetItemUnsafe: takes a bytes object and an int64
# index and yields an int. Because no bounds check is done here, it must only
# be emitted after the index has passed the range check.
# NOTE(review): ERR_NEVER presumably means the op cannot raise - confirm.
bytes_get_item_unsafe_op = custom_primitive_op(
name="bytes_get_item_unsafe",
arg_types=[bytes_rprimitive, int64_rprimitive],
return_type=int_rprimitive,
c_function_name="CPyBytes_GetItemUnsafe",
error_kind=ERR_NEVER,
experimental=True,
dependencies=[BYTES_EXTRA_OPS],
)
19 changes: 15 additions & 4 deletions mypyc/test-data/irbuild-bytes.test
Original file line number Diff line number Diff line change
Expand Up @@ -100,15 +100,26 @@ L0:
return r0

[case testBytesIndex]
def f(a: bytes, i: int) -> int:
from mypy_extensions import i64

def f(a: bytes, i: i64) -> int:
return a[i]
[out]
def f(a, i):
a :: bytes
i, r0 :: int
i, r0 :: i64
r1, r2 :: bool
r3 :: int
L0:
r0 = CPyBytes_GetItem(a, i)
return r0
r0 = CPyBytes_AdjustIndex(a, i)
r1 = CPyBytes_RangeCheck(a, r0)
if r1 goto L2 else goto L1 :: bool
L1:
r2 = raise IndexError('index out of range')
unreachable
L2:
r3 = CPyBytes_GetItemUnsafe(a, r0)
return r3

[case testBytesConcat]
def f(a: bytes, b: bytes) -> bytes:
Expand Down