Skip to content

Commit 5c5b4ea

Browse files
committed
safer interface for ExprLambda's formals
1 parent 5db63f3 commit 5c5b4ea

File tree

12 files changed

+101
-82
lines changed

12 files changed

+101
-82
lines changed

src/libexpr-tests/primops.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -771,7 +771,7 @@ TEST_F(PrimOpTest, derivation)
771771
ASSERT_EQ(v.type(), nFunction);
772772
ASSERT_TRUE(v.isLambda());
773773
ASSERT_NE(v.lambda().fun, nullptr);
774-
ASSERT_TRUE(v.lambda().fun->hasFormals);
774+
ASSERT_TRUE(v.lambda().fun->getFormals());
775775
}
776776

777777
TEST_F(PrimOpTest, currentTime)

src/libexpr-tests/value/print.cc

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -110,9 +110,8 @@ TEST_F(ValuePrintingTests, vLambda)
110110
PosTable::Origin origin = state.positions.addOrigin(std::monostate(), 1);
111111
auto posIdx = state.positions.add(origin, 0);
112112
auto body = ExprInt(0);
113-
auto formals = Formals{};
114113

115-
ExprLambda eLambda(state.mem.exprs.alloc, posIdx, createSymbol("a"), formals, &body);
114+
ExprLambda eLambda(state.mem.exprs.alloc, posIdx, createSymbol("a"), &body);
116115

117116
Value vLambda;
118117
vLambda.mkLambda(&env, &eLambda);
@@ -500,9 +499,8 @@ TEST_F(ValuePrintingTests, ansiColorsLambda)
500499
PosTable::Origin origin = state.positions.addOrigin(std::monostate(), 1);
501500
auto posIdx = state.positions.add(origin, 0);
502501
auto body = ExprInt(0);
503-
auto formals = Formals{};
504502

505-
ExprLambda eLambda(state.mem.exprs.alloc, posIdx, createSymbol("a"), formals, &body);
503+
ExprLambda eLambda(state.mem.exprs.alloc, posIdx, createSymbol("a"), &body);
506504

507505
Value vLambda;
508506
vLambda.mkLambda(&env, &eLambda);

src/libexpr/eval.cc

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1496,15 +1496,13 @@ void EvalState::callFunction(Value & fun, std::span<Value *> args, Value & vRes,
14961496

14971497
ExprLambda & lambda(*vCur.lambda().fun);
14981498

1499-
auto size = (!lambda.arg ? 0 : 1) + (lambda.hasFormals ? lambda.getFormals().size() : 0);
1499+
auto size = (!lambda.arg ? 0 : 1) + (lambda.getFormals() ? lambda.getFormals()->formals.size() : 0);
15001500
Env & env2(mem.allocEnv(size));
15011501
env2.up = vCur.lambda().env;
15021502

15031503
Displacement displ = 0;
15041504

1505-
if (!lambda.hasFormals)
1506-
env2.values[displ++] = args[0];
1507-
else {
1505+
if (auto formals = lambda.getFormals()) {
15081506
try {
15091507
forceAttrs(*args[0], lambda.pos, "while evaluating the value passed for the lambda argument");
15101508
} catch (Error & e) {
@@ -1520,7 +1518,7 @@ void EvalState::callFunction(Value & fun, std::span<Value *> args, Value & vRes,
15201518
there is no matching actual argument but the formal
15211519
argument has a default, use the default. */
15221520
size_t attrsUsed = 0;
1523-
for (auto & i : lambda.getFormals()) {
1521+
for (auto & i : formals->formals) {
15241522
auto j = args[0]->attrs()->get(i.name);
15251523
if (!j) {
15261524
if (!i.def) {
@@ -1542,13 +1540,13 @@ void EvalState::callFunction(Value & fun, std::span<Value *> args, Value & vRes,
15421540

15431541
/* Check that each actual argument is listed as a formal
15441542
argument (unless the attribute match specifies a `...'). */
1545-
if (!lambda.ellipsis && attrsUsed != args[0]->attrs()->size()) {
1543+
if (!formals->ellipsis && attrsUsed != args[0]->attrs()->size()) {
15461544
/* Nope, so show the first unexpected argument to the
15471545
user. */
15481546
for (auto & i : *args[0]->attrs())
1549-
if (!lambda.hasFormal(i.name)) {
1547+
if (!formals->has(i.name)) {
15501548
StringSet formalNames;
1551-
for (auto & formal : lambda.getFormals())
1549+
for (auto & formal : formals->formals)
15521550
formalNames.insert(std::string(symbols[formal.name]));
15531551
auto suggestions = Suggestions::bestMatches(formalNames, symbols[i.name]);
15541552
error<TypeError>(
@@ -1564,6 +1562,9 @@ void EvalState::callFunction(Value & fun, std::span<Value *> args, Value & vRes,
15641562
unreachable();
15651563
}
15661564
}
1565+
else {
1566+
env2.values[displ++] = args[0];
1567+
}
15671568

15681569
nrFunctionCalls++;
15691570
if (countCalls)
@@ -1747,22 +1748,23 @@ void EvalState::autoCallFunction(const Bindings & args, Value & fun, Value & res
17471748
}
17481749
}
17491750

1750-
if (!fun.isLambda() || !fun.lambda().fun->hasFormals) {
1751+
if (!fun.isLambda() || !fun.lambda().fun->getFormals()) {
17511752
res = fun;
17521753
return;
17531754
}
1755+
auto formals = fun.lambda().fun->getFormals();
17541756

1755-
auto attrs = buildBindings(std::max(static_cast<uint32_t>(fun.lambda().fun->nFormals), args.size()));
1757+
auto attrs = buildBindings(std::max(static_cast<uint32_t>(formals->formals.size()), args.size()));
17561758

1757-
if (fun.lambda().fun->ellipsis) {
1759+
if (formals->ellipsis) {
17581760
// If the formals have an ellipsis (eg the function accepts extra args) pass
17591761
// all available automatic arguments (which includes arguments specified on
17601762
// the command line via --arg/--argstr)
17611763
for (auto & v : args)
17621764
attrs.insert(v);
17631765
} else {
17641766
// Otherwise, only pass the arguments that the function accepts
1765-
for (auto & i : fun.lambda().fun->getFormals()) {
1767+
for (auto & i : formals->formals) {
17661768
auto j = args.get(i.name);
17671769
if (j) {
17681770
attrs.insert(*j);
@@ -1782,6 +1784,7 @@ values, or passed explicitly with '--arg' or '--argstr'. See
17821784
}
17831785

17841786
callFunction(fun, allocValue()->mkAttrs(attrs), res, pos);
1787+
17851788
}
17861789

17871790
void ExprWith::eval(EvalState & state, Env & env, Value & v)

src/libexpr/include/nix/expr/nixexpr.hh

Lines changed: 44 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -460,7 +460,7 @@ struct Formal
460460
Expr * def;
461461
};
462462

463-
struct Formals
463+
struct FormalsBuilder
464464
{
465465
typedef std::vector<Formal> Formals_;
466466
/**
@@ -477,26 +477,63 @@ struct Formals
477477
}
478478
};
479479

480+
struct Formals {
481+
std::span<Formal> formals;
482+
bool ellipsis;
483+
484+
Formals(std::span<Formal> formals, bool ellipsis)
485+
: formals(formals)
486+
, ellipsis(ellipsis) {};
487+
488+
bool has(Symbol arg) const
489+
{
490+
auto it = std::lower_bound(
491+
formals.begin(), formals.end(), arg, [](const Formal & f, const Symbol & sym) { return f.name < sym; });
492+
return it != formals.end() && it->name == arg;
493+
}
494+
495+
std::vector<Formal> lexicographicOrder(const SymbolTable & symbols) const
496+
{
497+
std::vector<Formal> result(formals.begin(), formals.end());
498+
std::sort(result.begin(), result.end(), [&](const Formal & a, const Formal & b) {
499+
std::string_view sa = symbols[a.name], sb = symbols[b.name];
500+
return sa < sb;
501+
});
502+
return result;
503+
}
504+
};
505+
506+
480507
struct ExprLambda : Expr
481508
{
482509
PosIdx pos;
483510
Symbol name;
484511
Symbol arg;
485512

486-
bool ellipsis;
513+
private:
487514
bool hasFormals;
515+
bool ellipsis;
488516
uint16_t nFormals;
489517
Formal * formalsStart;
518+
public:
519+
520+
std::optional<Formals> getFormals() const
521+
{
522+
if (hasFormals)
523+
return Formals{{formalsStart, nFormals}, ellipsis};
524+
else
525+
return std::nullopt;
526+
}
490527

491528
Expr * body;
492529
DocComment docComment;
493530

494531
ExprLambda(
495-
std::pmr::polymorphic_allocator<char> & alloc, PosIdx pos, Symbol arg, const Formals & formals, Expr * body)
532+
std::pmr::polymorphic_allocator<char> & alloc, PosIdx pos, Symbol arg, const FormalsBuilder & formals, Expr * body)
496533
: pos(pos)
497534
, arg(arg)
498-
, ellipsis(formals.ellipsis)
499535
, hasFormals(true)
536+
, ellipsis(formals.ellipsis)
500537
, nFormals(formals.formals.size())
501538
, formalsStart(alloc.allocate_object<Formal>(nFormals))
502539
, body(body)
@@ -508,44 +545,22 @@ struct ExprLambda : Expr
508545
: pos(pos)
509546
, arg(arg)
510547
, hasFormals(false)
548+
, ellipsis(false)
549+
, nFormals(0)
511550
, formalsStart(nullptr)
512551
, body(body) {};
513552

514-
ExprLambda(std::pmr::polymorphic_allocator<char> & alloc, PosIdx pos, Formals formals, Expr * body)
553+
ExprLambda(std::pmr::polymorphic_allocator<char> & alloc, PosIdx pos, FormalsBuilder formals, Expr * body)
515554
: ExprLambda(alloc, pos, Symbol(), formals, body) {};
516555

517-
bool hasFormal(Symbol arg) const
518-
{
519-
auto formals = getFormals();
520-
auto it = std::lower_bound(
521-
formals.begin(), formals.end(), arg, [](const Formal & f, const Symbol & sym) { return f.name < sym; });
522-
return it != formals.end() && it->name == arg;
523-
}
524-
525556
void setName(Symbol name) override;
526557
std::string showNamePos(const EvalState & state) const;
527558

528-
std::vector<Formal> getFormalsLexicographic(const SymbolTable & symbols) const
529-
{
530-
std::vector<Formal> result(getFormals().begin(), getFormals().end());
531-
std::sort(result.begin(), result.end(), [&](const Formal & a, const Formal & b) {
532-
std::string_view sa = symbols[a.name], sb = symbols[b.name];
533-
return sa < sb;
534-
});
535-
return result;
536-
}
537-
538559
PosIdx getPos() const override
539560
{
540561
return pos;
541562
}
542563

543-
std::span<Formal> getFormals() const
544-
{
545-
assert(hasFormals);
546-
return {formalsStart, nFormals};
547-
}
548-
549564
virtual void setDocComment(DocComment docComment) override;
550565
COMMON_METHODS
551566
};

src/libexpr/include/nix/expr/parser-state.hh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ struct ParserState
9393
void addAttr(
9494
ExprAttrs * attrs, AttrPath && attrPath, const ParserLocation & loc, Expr * e, const ParserLocation & exprLoc);
9595
void addAttr(ExprAttrs * attrs, AttrPath & attrPath, const Symbol & symbol, ExprAttrs::AttrDef && def);
96-
void validateFormals(Formals & formals, PosIdx pos = noPos, Symbol arg = {});
96+
void validateFormals(FormalsBuilder & formals, PosIdx pos = noPos, Symbol arg = {});
9797
Expr * stripIndentation(const PosIdx pos, std::vector<std::pair<PosIdx, std::variant<Expr *, StringToken>>> && es);
9898
PosIdx at(const ParserLocation & loc);
9999
};
@@ -213,7 +213,7 @@ ParserState::addAttr(ExprAttrs * attrs, AttrPath & attrPath, const Symbol & symb
213213
}
214214
}
215215

216-
inline void ParserState::validateFormals(Formals & formals, PosIdx pos, Symbol arg)
216+
inline void ParserState::validateFormals(FormalsBuilder & formals, PosIdx pos, Symbol arg)
217217
{
218218
std::sort(formals.formals.begin(), formals.formals.end(), [](const auto & a, const auto & b) {
219219
return std::tie(a.name, a.pos) < std::tie(b.name, b.pos);

src/libexpr/nixexpr.cc

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -154,13 +154,13 @@ void ExprList::show(const SymbolTable & symbols, std::ostream & str) const
154154
void ExprLambda::show(const SymbolTable & symbols, std::ostream & str) const
155155
{
156156
str << "(";
157-
if (hasFormals) {
157+
if (auto formals = getFormals()) {
158158
str << "{ ";
159159
bool first = true;
160160
// the natural Symbol ordering is by creation time, which can lead to the
161161
// same expression being printed in two different ways depending on its
162162
// context. always use lexicographic ordering to avoid this.
163-
for (auto & i : getFormalsLexicographic(symbols)) {
163+
for (auto & i : formals->lexicographicOrder(symbols)) {
164164
if (first)
165165
first = false;
166166
else
@@ -451,20 +451,20 @@ void ExprLambda::bindVars(EvalState & es, const std::shared_ptr<const StaticEnv>
451451
if (es.debugRepl)
452452
es.exprEnvs.insert(std::make_pair(this, env));
453453

454-
auto newEnv = std::make_shared<StaticEnv>(nullptr, env, (hasFormals ? getFormals().size() : 0) + (!arg ? 0 : 1));
454+
auto newEnv = std::make_shared<StaticEnv>(nullptr, env, (getFormals() ? getFormals()->formals.size() : 0) + (!arg ? 0 : 1));
455455

456456
Displacement displ = 0;
457457

458458
if (arg)
459459
newEnv->vars.emplace_back(arg, displ++);
460460

461-
if (hasFormals) {
462-
for (auto & i : getFormals())
461+
if (auto formals = getFormals()) {
462+
for (auto & i : formals->formals)
463463
newEnv->vars.emplace_back(i.name, displ++);
464464

465465
newEnv->sort();
466466

467-
for (auto & i : getFormals())
467+
for (auto & i : formals->formals)
468468
if (i.def)
469469
i.def->bindVars(es, newEnv);
470470
}

src/libexpr/parser.y

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ static Expr * makeCall(PosIdx pos, Expr * fn, Expr * arg) {
131131
%type <nix::Expr *> expr_pipe_from expr_pipe_into
132132
%type <nix::ExprList *> expr_list
133133
%type <nix::ExprAttrs *> binds binds1
134-
%type <nix::Formals> formals formal_set
134+
%type <nix::FormalsBuilder> formals formal_set
135135
%type <nix::Formal> formal
136136
%type <std::vector<nix::AttrName>> attrpath
137137
%type <std::vector<std::pair<nix::AttrName, nix::PosIdx>>> attrs

src/libexpr/primops.cc

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3363,21 +3363,20 @@ static void prim_functionArgs(EvalState & state, const PosIdx pos, Value ** args
33633363
if (!args[0]->isLambda())
33643364
state.error<TypeError>("'functionArgs' requires a function").atPos(pos).debugThrow();
33653365

3366-
if (!args[0]->lambda().fun->hasFormals) {
3366+
if (const auto & formals = args[0]->lambda().fun->getFormals()) {
3367+
auto attrs = state.buildBindings(formals->formals.size());
3368+
for (auto & i : formals->formals)
3369+
attrs.insert(i.name, state.getBool(i.def), i.pos);
3370+
/* Optimization: avoid sorting bindings. `formals` must already be sorted according to
3371+
(std::tie(a.name, a.pos) < std::tie(b.name, b.pos)) predicate, so the following assertion
3372+
always holds:
3373+
assert(std::is_sorted(attrs.alreadySorted()->begin(), attrs.alreadySorted()->end()));
3374+
.*/
3375+
v.mkAttrs(attrs.alreadySorted());
3376+
} else {
33673377
v.mkAttrs(&Bindings::emptyBindings);
33683378
return;
33693379
}
3370-
3371-
const auto & formals = args[0]->lambda().fun->getFormals();
3372-
auto attrs = state.buildBindings(formals.size());
3373-
for (auto & i : formals)
3374-
attrs.insert(i.name, state.getBool(i.def), i.pos);
3375-
/* Optimization: avoid sorting bindings. `formals` must already be sorted according to
3376-
(std::tie(a.name, a.pos) < std::tie(b.name, b.pos)) predicate, so the following assertion
3377-
always holds:
3378-
assert(std::is_sorted(attrs.alreadySorted()->begin(), attrs.alreadySorted()->end()));
3379-
.*/
3380-
v.mkAttrs(attrs.alreadySorted());
33813380
}
33823381

33833382
static RegisterPrimOp primop_functionArgs({

src/libexpr/value-to-xml.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -145,14 +145,14 @@ static void printValueAsXML(
145145
posToXML(state, xmlAttrs, state.positions[v.lambda().fun->pos]);
146146
XMLOpenElement _(doc, "function", xmlAttrs);
147147

148-
if (v.lambda().fun->hasFormals) {
148+
if (auto formals = v.lambda().fun->getFormals()) {
149149
XMLAttrs attrs;
150150
if (v.lambda().fun->arg)
151151
attrs["name"] = state.symbols[v.lambda().fun->arg];
152-
if (v.lambda().fun->ellipsis)
152+
if (formals->ellipsis)
153153
attrs["ellipsis"] = "1";
154154
XMLOpenElement _(doc, "attrspat", attrs);
155-
for (auto & i : v.lambda().fun->getFormalsLexicographic(state.symbols))
155+
for (auto & i : formals->lexicographicOrder(state.symbols))
156156
doc.writeEmptyElement("attr", singletonAttrs("name", state.symbols[i.name]));
157157
} else
158158
doc.writeEmptyElement("varpat", singletonAttrs("name", state.symbols[v.lambda().fun->arg]));

src/libflake/flake.cc

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -281,12 +281,14 @@ static Flake readFlake(
281281
if (auto outputs = vInfo.attrs()->get(sOutputs)) {
282282
expectType(state, nFunction, *outputs->value, outputs->pos);
283283

284-
if (outputs->value->isLambda() && outputs->value->lambda().fun->hasFormals) {
285-
for (auto & formal : outputs->value->lambda().fun->getFormals()) {
286-
if (formal.name != state.s.self)
287-
flake.inputs.emplace(
288-
state.symbols[formal.name],
289-
FlakeInput{.ref = parseFlakeRef(state.fetchSettings, std::string(state.symbols[formal.name]))});
284+
if (outputs->value->isLambda()) {
285+
if (auto formals = outputs->value->lambda().fun->getFormals()) {
286+
for (auto & formal : formals->formals) {
287+
if (formal.name != state.s.self)
288+
flake.inputs.emplace(
289+
state.symbols[formal.name],
290+
FlakeInput{.ref = parseFlakeRef(state.fetchSettings, std::string(state.symbols[formal.name]))});
291+
}
290292
}
291293
}
292294

0 commit comments

Comments
 (0)