Skip to content

Commit

Permalink
fix(sem): non-boolean when conditions being ignored (#1447)
Browse files Browse the repository at this point in the history
## Summary

Fix `when` branches where the condition has an incorrect type being
silently ignored without an error.

## Details

* a type mismatch error was already created, but not properly
  propagated; now it is
* fix error propagation for `when nimvm` statements
* simplify the control-flow in `semWhen` by handling `when nimvm`
  separately
* refactor `semWhen` to not modify the input AST
* pass the analysis flags along to `semWhen`, and integrate the
  `semCheck` parameter into the flags (now indicated by `efNoSemCheck`)
  • Loading branch information
zerbina authored Sep 4, 2024
1 parent 73b9136 commit 64aa3b5
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 52 deletions.
2 changes: 1 addition & 1 deletion compiler/sem/sem.nim
Original file line number Diff line number Diff line change
Expand Up @@ -561,7 +561,7 @@ proc paramsTypeCheck(c: PContext, typ: PType) {.inline.} =
allowedFlags: {})))

proc semDirectOp(c: PContext, n: PNode, flags: TExprFlags): PNode
proc semWhen(c: PContext, n: PNode, semCheck: bool = true): PNode
proc semWhen(c: PContext, n: PNode, flags: TExprFlags): PNode
proc semTemplateExpr(c: PContext, n: PNode, s: PSym,
flags: TExprFlags = {}): PNode
proc semMacroExpr(c: PContext, n: PNode, sym: PSym,
Expand Down
105 changes: 55 additions & 50 deletions compiler/sem/semexprs.nim
Original file line number Diff line number Diff line change
Expand Up @@ -2894,68 +2894,81 @@ proc semMagic(c: PContext, n: PNode, s: PSym, flags: TExprFlags): PNode =
else:
result = semDirectOp(c, n, flags)

proc semWhen(c: PContext, n: PNode, semCheck = true): PNode =
# If semCheck is set to false, `when` will return the verbatim AST of
# the correct branch. Otherwise the AST will be passed through semStmt.
addInNimDebugUtils(c.config, "semWhen", n, result)

result = nil
proc semWhen(c: PContext, n: PNode, flags: TExprFlags): PNode =
## Types and checks an ``nkWhenStmt`` AST. If ``efNoSemCheck`` is part of
## `flags`, the verbatim AST of the correct branch is returned. Otherwise the
## AST will be passed through ``semExpr``.
addInNimDebugUtils(c.config, "semWhen", n, result, flags)

template setResult(e: untyped) =
if semCheck: result = semExpr(c, e) # do not open a new scope!
else: result = e
if efNoSemCheck in flags:
result = e
else:
result = semExpr(c, e, flags) # do not open a new scope!

# Check if the node is "when nimvm"
# when nimvm:
# ...
# else:
# ...
var whenNimvm = false
var typ = commonTypeBegin
if n.len == 2 and n[0].kind == nkElifBranch and
n[1].kind == nkElse:
let exprNode = n[0][0]
if exprNode.kind == nkIdent:
whenNimvm = lookUp(c, exprNode).magic == mNimvm
elif exprNode.kind == nkSym:
whenNimvm = exprNode.sym.magic == mNimvm
if whenNimvm: n.flags.incl nfLL

for i in 0..<n.len:
var it = n[i]
case it.kind
of nkElifBranch, nkElifExpr:
checkSonsLen(it, 2, c.config)
if whenNimvm:
if semCheck:
it[1] = semExpr(c, it[1])
typ = commonType(c, typ, it[1].typ)
result = n # when nimvm is not elimited until codegen
else:
if whenNimvm:
result = shallowCopy(n)
result.flags.incl nfLL # disable lambda-lifting

result[0] = copyNodeWithKids(n[0])
result[1] = copyNodeWithKids(n[1])
checkSonsLen(n[0], 2, c.config)
checkSonsLen(n[1], 1, c.config)

if efNoSemCheck notin flags:
# there are always only two branches
result[0][1] = semExpr(c, n[0][1], flags)
result[1][0] = semExpr(c, n[1][0], flags)

# assign the common type to the ``when`` expression/statement
# XXX: fitting the branches is missing
result.typ = commonType(c, result[0][1].typ, result[1][0].typ)
if nkError in {result[0][1].kind, result[1][0].kind}:
result = c.config.wrapError(result)

else:
# pick the first branch where the condition evaluates to true
for i in 0..<n.len:
let it = n[i]
case it.kind
of nkElifBranch, nkElifExpr:
checkSonsLen(it, 2, c.config)
let e = forceBool(c, semConstExpr(c, it[0]))
if e.kind != nkIntLit:
# can happen for cascading errors, assume false
# InternalError(n.info, "semWhen")
discard
elif e.intVal != 0 and result == nil:
if e.kind == nkError:
# error in the condition expression; wrap and return
result = copyNodeWithKids(n)
result[i] = copyNodeWithKids(it)
result[i][0] = e
result = c.config.wrapError(result)
break
elif e.intVal != 0:
setResult(it[1])
return # we're not in nimvm and we already have a result
of nkElse, nkElseExpr:
checkSonsLen(it, 1, c.config)
if result == nil or whenNimvm:
if semCheck:
it[0] = semExpr(c, it[0])
typ = commonType(c, typ, it[0].typ)
if result == nil:
result = it[0]
else:
semReportIllformedAst(c.config, n, {
nkElse, nkElseExpr, nkElifBranch, nkElifExpr})
break
of nkElse, nkElseExpr:
checkSonsLen(it, 1, c.config)
# no earlier branch was picked -> the else branch is the correct one
setResult(it[0])
else:
semReportIllformedAst(c.config, it, {
nkElse, nkElseExpr, nkElifBranch, nkElifExpr})

if result == nil:
if result.isNil:
# no branch was picked
result = newNodeI(nkEmpty, n.info)
if whenNimvm:
result.typ = typ

proc semSetConstr(c: PContext, n: PNode): PNode =
## Analyses and types a set construction expression (``nkCurly``). Produces
Expand Down Expand Up @@ -3737,15 +3750,7 @@ proc semExpr(c: PContext, n: PNode, flags: TExprFlags = {}): PNode =
hoistParamsUsedInDefault(c, result, hoistedParams, result[i])
result = newTreeIT(nkStmtListExpr, result.info, result.typ, hoistedParams, result)
of nkWhen:
if efWantStmt in flags:
result = semWhen(c, n, true)
else:
result = semWhen(c, n, false)
if result == n:
# This is a "when nimvm" stmt.
result = semWhen(c, n, true)
else:
result = semExpr(c, result, flags)
result = semWhen(c, n, flags)
of nkBracketExpr:
checkMinSonsLen(n, 1, c.config)
result = semArrayAccess(c, n, flags)
Expand Down
2 changes: 1 addition & 1 deletion compiler/sem/semtypes.nim
Original file line number Diff line number Diff line change
Expand Up @@ -2188,7 +2188,7 @@ proc semTypeNode(c: PContext, n: PNode, prev: PType): PType =
else:
result = semTypeExpr(c, n, prev)
of nkWhenStmt:
var whenResult = semWhen(c, n, false)
var whenResult = semWhen(c, n, {efNoSemCheck})
if whenResult.kind == nkStmtList:
whenResult.transitionSonsKind(nkStmtListExpr)
result = semTypeNode(c, whenResult, prev)
Expand Down
11 changes: 11 additions & 0 deletions tests/lang_stmts/whenstmt/twhen_condition_must_be_bool.nim
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
discard """
description: '''
Ensure that an error is reported for a non-boolean expression used as the
condition of a `when`
'''
errormsg: "type mismatch: got <int literal(2)> but expected 'bool'"
line: 10
"""

when 1 + 1:
discard

0 comments on commit 64aa3b5

Please sign in to comment.