Skip to content

Commit

Permalink
feat: Simp.Config.implicitDefEqProofs (leanprover#4595)
Browse files Browse the repository at this point in the history
This PR implements `Simp.Config.implicitDefEqsProofs`. When `true`
(default: `true`), `simp` will **not** create a proof term for a
rewriting rule associated with an `rfl`-theorem. Rewriting rules are
provided by users by annotating theorems with the attribute `@[simp]`.
If the proof of the theorem is just `rfl` (reflexivity), and
`implicitDefEqProofs := true`, `simp` will **not** create a proof term
which is an application of the annotated theorem.

The default setting does change the existing behavior. Users can use
`simp -implicitDefEqProofs` to force `simp` to create a proof term for
`rfl`-theorems. This can positively impact proof checking time in the
kernel.

This PR also fixes an issue in the `split` tactic that has been exposed
by this feature. It was looking for `split` candidates in proofs and
implicit arguments. See new test for issue exposed by the previous
feature.

---------

Co-authored-by: Kim Morrison <[email protected]>
  • Loading branch information
leodemoura and kim-em authored Nov 29, 2024
1 parent 3752241 commit 27df5e9
Show file tree
Hide file tree
Showing 15 changed files with 571 additions and 54 deletions.
3 changes: 2 additions & 1 deletion src/Init/MetaTypes.lean
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,8 @@ structure Config where
-/
index : Bool := true
/--
This option does not have any effect (yet).
If `implicitDefEqProofs := true`, `simp` does not create proof terms when the
input and output terms are definitionally equal.
-/
implicitDefEqProofs : Bool := true
deriving Inhabited, BEq
Expand Down
2 changes: 1 addition & 1 deletion src/Lean/Meta/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ structure ParamInfo where
hasFwdDeps : Bool := false
/-- `backDeps` contains the backwards dependencies. That is, the (0-indexed) position of previous parameters that this one depends on. -/
backDeps : Array Nat := #[]
/-- `isProp` is true if the parameter is always a proposition. -/
/-- `isProp` is true if the parameter type is always a proposition. -/
isProp : Bool := false
/--
`isDecInst` is true if the parameter's type is of the form `Decidable ...`.
Expand Down
8 changes: 7 additions & 1 deletion src/Lean/Meta/Tactic/Simp/Rewrite.lean
Original file line number Diff line number Diff line change
Expand Up @@ -108,13 +108,19 @@ where
trace[Meta.Tactic.simp.discharge] "{← ppOrigin thmId}, failed to synthesize instance{indentExpr type}"
return false

private def useImplicitDefEqProof (thm : SimpTheorem) : SimpM Bool := do
if thm.rfl then
return (← getConfig).implicitDefEqProofs
else
return false

private def tryTheoremCore (lhs : Expr) (xs : Array Expr) (bis : Array BinderInfo) (val : Expr) (type : Expr) (e : Expr) (thm : SimpTheorem) (numExtraArgs : Nat) : SimpM (Option Result) := do
recordTriedSimpTheorem thm.origin
let rec go (e : Expr) : SimpM (Option Result) := do
if (← isDefEq lhs e) then
unless (← synthesizeArgs thm.origin bis xs) do
return none
let proof? ← if thm.rfl then
let proof? ← if (← useImplicitDefEqProof thm) then
pure none
else
let proof ← instantiateMVars (mkAppN val xs)
Expand Down
37 changes: 4 additions & 33 deletions src/Lean/Meta/Tactic/Split.lean
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ def mkDiscrGenErrorMsg (e : Expr) : MessageData :=
def throwDiscrGenError (e : Expr) : MetaM α :=
throwError (mkDiscrGenErrorMsg e)

def splitMatch (mvarId : MVarId) (e : Expr) : MetaM (List MVarId) := do
def splitMatch (mvarId : MVarId) (e : Expr) : MetaM (List MVarId) := mvarId.withContext do
let some app ← matchMatcherApp? e | throwError "internal error in `split` tactic: match application expected{indentExpr e}\nthis error typically occurs when the `split` tactic internal functions have been used in a new meta-program"
let matchEqns ← Match.getEquationsFor app.matcherName
let mvarIds ← applyMatchSplitter mvarId app.matcherName app.matcherLevels app.params app.discrs
Expand All @@ -278,43 +278,14 @@ def splitMatch (mvarId : MVarId) (e : Expr) : MetaM (List MVarId) := do
return (i+1, mvarId::mvarIds)
return mvarIds.reverse

/-- Return an `if-then-else` or `match-expr` to split. -/
partial def findSplit? (env : Environment) (e : Expr) (splitIte := true) (exceptionSet : ExprSet := {}) : Option Expr :=
go e
where
go (e : Expr) : Option Expr :=
if let some target := e.find? isCandidate then
if e.isIte || e.isDIte then
let cond := target.getArg! 1 5
-- Try to find a nested `if` in `cond`
go cond |>.getD target
else
some target
else
none

isCandidate (e : Expr) : Bool := Id.run do
if exceptionSet.contains e then
false
else if splitIte && (e.isIte || e.isDIte) then
!(e.getArg! 1 5).hasLooseBVars
else if let some info := isMatcherAppCore? env e then
let args := e.getAppArgs
for i in [info.getFirstDiscrPos : info.getFirstDiscrPos + info.numDiscrs] do
if args[i]!.hasLooseBVars then
return false
return true
else
false

end Split

open Split

partial def splitTarget? (mvarId : MVarId) (splitIte := true) : MetaM (Option (List MVarId)) := commitWhenSome? do
partial def splitTarget? (mvarId : MVarId) (splitIte := true) : MetaM (Option (List MVarId)) := commitWhenSome? do mvarId.withContext do
let target ← instantiateMVars (← mvarId.getType)
let rec go (badCases : ExprSet) : MetaM (Option (List MVarId)) := do
if let some e := findSplit? (← getEnv) target splitIte badCases then
if let some e findSplit? target (if splitIte then .both else .match) badCases then
if e.isIte || e.isDIte then
return (← splitIfTarget? mvarId).map fun (s₁, s₂) => [s₁.mvarId, s₂.mvarId]
else
Expand All @@ -333,7 +304,7 @@ partial def splitTarget? (mvarId : MVarId) (splitIte := true) : MetaM (Option (L

def splitLocalDecl? (mvarId : MVarId) (fvarId : FVarId) : MetaM (Option (List MVarId)) := commitWhenSome? do
mvarId.withContext do
if let some e := findSplit? (← getEnv) (← instantiateMVars (← inferType (mkFVar fvarId))) then
if let some e findSplit? (← instantiateMVars (← inferType (mkFVar fvarId))) then
if e.isIte || e.isDIte then
return (← splitIfLocalDecl? mvarId fvarId).map fun (mvarId₁, mvarId₂) => [mvarId₁, mvarId₂]
else
Expand Down
134 changes: 122 additions & 12 deletions src/Lean/Meta/Tactic/SplitIf.lean
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,124 @@ import Lean.Meta.Tactic.Cases
import Lean.Meta.Tactic.Simp.Main

namespace Lean.Meta

inductive SplitKind where
| ite | match | both

def SplitKind.considerIte : SplitKind → Bool
| .ite | .both => true
| _ => false

def SplitKind.considerMatch : SplitKind → Bool
| .match | .both => true
| _ => false

namespace FindSplitImpl

structure Context where
exceptionSet : ExprSet := {}
kind : SplitKind := .both

unsafe abbrev FindM := ReaderT Context $ StateT (PtrSet Expr) MetaM

/--
Checks whether `e` is a candidate for `split`.
Returns `some e'` if a prefix is a candidate.
Example: suppose `e` is `(if b then f else g) x`, then
the result is `some e'` where `e'` is the subterm `(if b then f else g)`
-/
private def isCandidate? (env : Environment) (ctx : Context) (e : Expr) : Option Expr := Id.run do
let ret (e : Expr) : Option Expr :=
if ctx.exceptionSet.contains e then none else some e
if ctx.kind.considerIte then
if e.isAppOf ``ite || e.isAppOf ``dite then
let numArgs := e.getAppNumArgs
if numArgs >= 5 && !(e.getArg! 1 5).hasLooseBVars then
return ret (e.getBoundedAppFn (numArgs - 5))
if ctx.kind.considerMatch then
if let some info := isMatcherAppCore? env e then
let args := e.getAppArgs
for i in [info.getFirstDiscrPos : info.getFirstDiscrPos + info.numDiscrs] do
if args[i]!.hasLooseBVars then
return none
return ret (e.getBoundedAppFn (args.size - info.arity))
return none

@[inline] unsafe def checkVisited (e : Expr) : OptionT FindM Unit := do
if (← get).contains e then
failure
modify fun s => s.insert e

unsafe def visit (e : Expr) : OptionT FindM Expr := do
checkVisited e
if let some e := isCandidate? (← getEnv) (← read) e then
return e
else
-- We do not look for split candidates in proofs.
unless e.hasLooseBVars do
if (← isProof e) then
failure
match e with
| .lam _ _ b _ | .proj _ _ b -- We do not look for split candidates in the binder of lambdas.
| .mdata _ b => visit b
| .forallE _ d b _ => visit d <|> visit b -- We want to look for candidates at `A → B`
| .letE _ _ v b _ => visit v <|> visit b
| .app .. => visitApp? e
| _ => failure
where
visitApp? (e : Expr) : FindM (Option Expr) :=
e.withApp fun f args => do
-- See comment at `Canonicalizer.lean` regarding the case where
-- `f` has loose bound variables.
let info ← if f.hasLooseBVars then
pure {}
else
getFunInfo f
for u : i in [0:args.size] do
let arg := args[i]
if h : i < info.paramInfo.size then
let info := info.paramInfo[i]
unless info.isProp do
if info.isExplicit then
let some found ← visit arg | pure ()
return found
else
let some found ← visit arg | pure ()
return found
visit f

end FindSplitImpl

/-- Return an `if-then-else` or `match-expr` to split. -/
partial def findSplit? (e : Expr) (kind : SplitKind := .both) (exceptionSet : ExprSet := {}) : MetaM (Option Expr) := do
go (← instantiateMVars e)
where
go (e : Expr) : MetaM (Option Expr) := do
if let some target ← find? e then
if target.isIte || target.isDIte then
let cond := target.getArg! 1 5
-- Try to find a nested `if` in `cond`
return (← go cond).getD target
else
return some target
else
return none

find? (e : Expr) : MetaM (Option Expr) := do
let some candidate ← unsafe FindSplitImpl.visit e { kind, exceptionSet } |>.run' mkPtrSet
| return none
trace[split.debug] "candidate:{indentExpr candidate}"
return some candidate

/-- Return the condition and decidable instance of an `if` expression to case split. -/
private partial def findIfToSplit? (e : Expr) : MetaM (Option (Expr × Expr)) := do
if let some iteApp ← findSplit? e .ite then
let cond := iteApp.getArg! 1 5
let dec := iteApp.getArg! 2 5
return (cond, dec)
else
return none

namespace SplitIf

/--
Expand Down Expand Up @@ -62,19 +180,9 @@ private def discharge? (numIndices : Nat) (useDecide : Bool) : Simp.Discharge :=
def mkDischarge? (useDecide := false) : MetaM Simp.Discharge :=
return discharge? (← getLCtx).numIndices useDecide

/-- Return the condition and decidable instance of an `if` expression to case split. -/
private partial def findIfToSplit? (e : Expr) : Option (Expr × Expr) :=
if let some iteApp := e.find? fun e => (e.isIte || e.isDIte) && !(e.getArg! 1 5).hasLooseBVars then
let cond := iteApp.getArg! 1 5
let dec := iteApp.getArg! 2 5
-- Try to find a nested `if` in `cond`
findIfToSplit? cond |>.getD (cond, dec)
else
none

def splitIfAt? (mvarId : MVarId) (e : Expr) (hName? : Option Name) : MetaM (Option (ByCasesSubgoal × ByCasesSubgoal)) := do
def splitIfAt? (mvarId : MVarId) (e : Expr) (hName? : Option Name) : MetaM (Option (ByCasesSubgoal × ByCasesSubgoal)) := mvarId.withContext do
let e ← instantiateMVars e
if let some (cond, decInst) := findIfToSplit? e then
if let some (cond, decInst) findIfToSplit? e then
let hName ← match hName? with
| none => mkFreshUserName `h
| some hName => pure hName
Expand Down Expand Up @@ -106,6 +214,7 @@ def splitIfTarget? (mvarId : MVarId) (hName? : Option Name := none) : MetaM (Opt
let mvarId₁ ← simpIfTarget s₁.mvarId
let mvarId₂ ← simpIfTarget s₂.mvarId
if s₁.mvarId == mvarId₁ && s₂.mvarId == mvarId₂ then
trace[split.failure] "`split` tactic failed to simplify target using new hypotheses Goals:\n{mvarId₁}\n{mvarId₂}"
return none
else
return some ({ s₁ with mvarId := mvarId₁ }, { s₂ with mvarId := mvarId₂ })
Expand All @@ -118,6 +227,7 @@ def splitIfLocalDecl? (mvarId : MVarId) (fvarId : FVarId) (hName? : Option Name
let mvarId₁ ← simpIfLocalDecl s₁.mvarId fvarId
let mvarId₂ ← simpIfLocalDecl s₂.mvarId fvarId
if s₁.mvarId == mvarId₁ && s₂.mvarId == mvarId₂ then
trace[split.failure] "`split` tactic failed to simplify target using new hypotheses Goals:\n{mvarId₁}\n{mvarId₂}"
return none
else
return some (mvarId₁, mvarId₂)
Expand Down
1 change: 0 additions & 1 deletion tests/compiler/uset.lean
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,3 @@ structure Point where

def main : IO Unit :=
IO.println (Point.right ⟨0, 0⟩).x

2 changes: 1 addition & 1 deletion tests/lean/1113.lean
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ def foo: {n: Nat} → Fin n → Nat

theorem t3 {f: Fin (n+1)}:
foo f = 0 := by
simp only [←Nat.succ_eq_add_one n] at f
dsimp only [←Nat.succ_eq_add_one n] at f -- use `dsimp` to ensure we don't copy `f`
trace_state
simp only [←Nat.succ_eq_add_one n, foo]

Expand Down
2 changes: 1 addition & 1 deletion tests/lean/rfl_simp_thm.lean
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@ def inc (x : Nat) := x + 1
@[simp] theorem inc_eq : inc x = x + 1 := rfl

theorem ex (a b : Fin (inc n)) (h : a = b) : b = a := by
simp only [inc_eq] at a
simp +implicitDefEqProofs only [inc_eq] at a
trace_state
exact h.symm
Loading

0 comments on commit 27df5e9

Please sign in to comment.