diff --git a/compiler/src/iree/compiler/Codegen/Common/IREEComprehensiveBufferizePass.cpp b/compiler/src/iree/compiler/Codegen/Common/IREEComprehensiveBufferizePass.cpp index 7ca1f240f490..90f728912761 100644 --- a/compiler/src/iree/compiler/Codegen/Common/IREEComprehensiveBufferizePass.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/IREEComprehensiveBufferizePass.cpp @@ -148,16 +148,13 @@ static IREEOneShotBufferizationOptions getBufferizationOptions() { // This type converter converts tensor types to memref types when no exact // memref type can be inferred from the context. - options.unknownTypeConverterFn = [](Value value, Attribute memorySpace, + options.unknownTypeConverterFn = [](TensorType tensorType, + Attribute memorySpace, const BufferizationOptions &options) { - auto tensorType = llvm::cast(value.getType()); - - // Special rule for ConstantOps: These always lower to some memref with a - // static identity layout. - if (value.getDefiningOp()) + if (tensorType.hasStaticShape()) { return bufferization::getMemRefTypeWithStaticIdentityLayout(tensorType, memorySpace); - + } // Default case: Fully dynamic layout map for best compatibility. return bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType, memorySpace); diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp index cf66ac2b5134..95261048cd84 100644 --- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp @@ -886,16 +886,13 @@ static IREEOneShotBufferizationOptions getBufferizationOptions() { // This type converter converts tensor types to memref types when no exact // memref type can be inferred from the context. - options.unknownTypeConverterFn = [](Value value, Attribute memorySpace, + options.unknownTypeConverterFn = [](TensorType tensorType, + Attribute memorySpace, const BufferizationOptions &options) { - auto tensorType = llvm::cast(value.getType()); - - // Special rule for ConstantOps: These always lower to some memref with a - // static identity layout. - if (value.getDefiningOp()) + if (tensorType.hasStaticShape()) { return bufferization::getMemRefTypeWithStaticIdentityLayout(tensorType, memorySpace); - + } // Default case: Fully dynamic layout map for best compatibility. return bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType, memorySpace); diff --git a/third_party/llvm-project b/third_party/llvm-project index 6325c62aa92b..20a3487ee988 160000 --- a/third_party/llvm-project +++ b/third_party/llvm-project @@ -1 +1 @@ -Subproject commit 6325c62aa92b95cb9c3df1c04de71b610012b792 +Subproject commit 20a3487ee98805e675e032ecc026609c53c88830