diff --git a/src/lcode.cpp b/src/lcode.cpp index ab5c6b808b..d6b2a6b009 100644 --- a/src/lcode.cpp +++ b/src/lcode.cpp @@ -1971,3 +1971,10 @@ void luaK_exp2reg (FuncState *fs, expdesc *e, int reg) { freeexp(fs, e); exp2reg(fs, e, reg); } + + +void luaK_invertcond (FuncState *fs, int list) { + expdesc e; + e.u.info = list; + negatecondition(fs, &e); +} diff --git a/src/lcode.h b/src/lcode.h index fe9a079a0a..ac41e3506a 100644 --- a/src/lcode.h +++ b/src/lcode.h @@ -97,3 +97,4 @@ LUAI_FUNC void luaK_setlist (FuncState *fs, int base, int nelems, int tostore); LUAI_FUNC void luaK_finish (FuncState *fs); [[noreturn]] LUAI_FUNC void luaK_semerror (LexState *ls, const char *msg); LUAI_FUNC void luaK_exp2reg (FuncState *fs, expdesc *e, int reg); +LUAI_FUNC void luaK_invertcond (FuncState *fs, int list); diff --git a/src/lparser.cpp b/src/lparser.cpp index 144787d656..46078cd400 100644 --- a/src/lparser.cpp +++ b/src/lparser.cpp @@ -3261,9 +3261,27 @@ inline bool testnext2 (LexState *ls, int token1, int token2) { } -static void casecond (LexState *ls, int case_line, expdesc& lcase) { - expr(ls, &lcase, nullptr, E_NO_COLON); +static std::vector casecond (LexState *ls, const expdesc& control) { + std::vector jumps{}; + FuncState *fs = ls->fs; + const auto case_line = ls->getLineNumber(); + + expdesc e, cmpval; + e = control; + luaK_infix(fs, OPR_EQ, &e); + expr(ls, &cmpval, nullptr, E_NO_COLON); + luaK_posfix(fs, OPR_EQ, &e, &cmpval, case_line); + jumps.emplace_back(e.u.info); + while (testnext(ls, ',')) { + e = control; + luaK_infix(fs, OPR_EQ, &e); + expr(ls, &cmpval, nullptr, E_NO_COLON); + luaK_posfix(fs, OPR_EQ, &e, &cmpval, case_line); + jumps.emplace_back(e.u.info); + } checknext(ls, ':'); + + return jumps; } @@ -3277,11 +3295,11 @@ static void switchstat (LexState *ls, int line) { luaX_next(ls); // Skip switch statement. testnext(ls, '('); - FuncState* fs = ls->fs; + FuncState *fs = ls->fs; BlockCnt sbl; enterblock(fs, &sbl, 1); - expdesc crtl, save, first; + expdesc crtl, save; expr(ls, &crtl); luaK_exp2nextreg(ls->fs, &crtl); init_exp(&save, VLOCAL, crtl.u.info); @@ -3290,27 +3308,19 @@ static void switchstat (LexState *ls, int line) { new_localvarliteral(ls, "(switch control value)"); // Save control value into a local. adjustlocalvars(ls, 1); + std::vector first{}; TString* const begin_switch = luaS_newliteral(ls->L, "pluto_begin_switch"); TString* const end_switch = luaS_newliteral(ls->L, "pluto_end_switch"); TString* default_case = nullptr; - int default_pc; + int first_pc, default_pc; if (gett(ls) == TK_CASE) { - int case_line = ls->getLineNumber(); - luaX_next(ls); /* Skip 'case' */ - - first = save; - - luaK_infix(fs, OPR_NE, &first); - expdesc lcase; - casecond(ls, case_line, lcase); - luaK_posfix(fs, OPR_NE, &first, &lcase, ls->getLineNumber()); - + first = casecond(ls, save); + first_pc = luaK_getlabel(fs); caselist(ls); } else { - first.k = VVOID; newgotoentry(ls, begin_switch, ls->getLineNumber(), luaK_jump(fs)); // goto begin_switch } @@ -3325,7 +3335,7 @@ static void switchstat (LexState *ls, int line) { throwerr(ls, "switch statement already has a default case", "second default case", case_line); default_case = luaS_newliteral(ls->L, "pluto_default_case"); default_pc = luaK_getlabel(fs); - createlabel(ls, default_case, ls->getLineNumber(), block_follow(ls, 0)); + createlabel(ls, default_case, ls->getLineNumber(), false); caselist(ls); } else { @@ -3340,11 +3350,15 @@ static void switchstat (LexState *ls, int line) { /* handle possible fallthrough, don't loop infinitely */ newgotoentry(ls, end_switch, ls->getLineNumber(), luaK_jump(fs)); // goto end_switch - if (first.k != VVOID) { - luaK_patchtohere(fs, first.u.info); + if (!first.empty()) { + for (int i = 0; i != first.size() - 1; ++i) { + luaK_patchlist(fs, first.at(i), first_pc); + } + luaK_invertcond(fs, first.back()); + luaK_patchtohere(fs, first.back()); } else { - createlabel(ls, begin_switch, ls->getLineNumber(), block_follow(ls, 0)); // ::begin_switch:: + createlabel(ls, begin_switch, ls->getLineNumber(), false); // ::begin_switch:: } /* prune cases that lead to default case */ @@ -3359,23 +3373,19 @@ static void switchstat (LexState *ls, int line) { } } - expdesc test; for (auto& c : cases) { - test = save; - luaK_infix(fs, OPR_EQ, &test); auto pos = luaX_getpos(ls); luaX_setpos(ls, c.tidx); - expdesc cc; - casecond(ls, ls->getLineNumber(), cc); + for (const auto& j : casecond(ls, save)) { + luaK_patchlist(fs, j, c.pc); + } luaX_setpos(ls, pos); - luaK_posfix(fs, OPR_EQ, &test, &cc, ls->getLineNumber()); - luaK_patchlist(fs, test.u.info, c.pc); } if (default_case != nullptr) lgoto(ls, default_case); - createlabel(ls, end_switch, ls->getLineNumber(), block_follow(ls, 0)); // ::end_switch:: + createlabel(ls, end_switch, ls->getLineNumber(), true); // ::end_switch:: check_match(ls, TK_END, switchToken, line); leaveblock(fs); diff --git a/tests/basic.pluto b/tests/basic.pluto index 9a0d97df4a..337af2d275 100644 --- a/tests/basic.pluto +++ b/tests/basic.pluto @@ -600,6 +600,23 @@ do case y + 1: -- x == y + 1 end end +do + local function getResponse(word) + switch word do + case "hi", "hello": + return "Greetings!" + default: + return "Unrecognised word" + case "bye", "goodbye": + return "Farewell!" + end + end + assert(getResponse("hi") == "Greetings!") + assert(getResponse("hello") == "Greetings!") + assert(getResponse("bye") == "Farewell!") + assert(getResponse("goodbye") == "Farewell!") + assert(getResponse("deez") == "Unrecognised word") +end print "Testing table freezing." do