diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index bc61862..3019926 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -17,6 +17,15 @@ mutable struct TracedRArray{T,N} <: RArray{T,N} end end +function Base.setproperty!(x::TracedRArray, f::Symbol, v) + if f === :shape && !isnothing(v) + @assert v isa MLIR.IR.Value + mlir_size = size(MLIR.IR.type(v)) + @assert mlir_size == x.shape "Expected shape $(x.shape), got $(mlir_size)" + end + return setfield!(x, f, v) +end + const WrappedTracedRArray{T,N} = WrappedArray{T,N,TracedRArray,TracedRArray{T,N}} const AnyTracedRArray{T,N} = Union{TracedRArray{T,N},WrappedTracedRArray{T,N}} const AnyTracedRScalar{T} = AnyTracedRArray{T,0}