_scaled_dot_product_flash_attention_for_cpu can't be rewritten in ReduceOpVariantsPass.
This is because the op is matched by a name string, not by a pattern:
"torch.aten._scaled_dot_product_flash_attention_for_cpu",
I found out the reason is:
return createMlirOperationAtEnd(
Modifying the code as follows makes it work:
return createMlirOperationAtEnd(
appendToBlock, "torch.operator", loc, resultTypes, operands,
toMlirNamedAttribute(
"name", mlirStringAttrGet(context, toMlirStringRef(opName))));