Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ impl<'local, 'context> Interpreter<'local, 'context> {
"expr_as_if" => expr_as_if(interner, arguments, return_type, location),
"expr_as_index" => expr_as_index(interner, arguments, return_type, location),
"expr_as_integer" => expr_as_integer(interner, arguments, return_type, location),
"expr_as_lambda" => expr_as_lambda(interner, arguments, return_type, location),
"expr_as_let" => expr_as_let(interner, arguments, return_type, location),
"expr_as_member_access" => {
expr_as_member_access(interner, arguments, return_type, location)
Expand Down Expand Up @@ -1612,6 +1613,68 @@ fn expr_as_integer(
})
}

// fn as_lambda(self) -> Option<([(Expr, Option<UnresolvedType>)], Option<UnresolvedType>, Expr)>
fn expr_as_lambda(
interner: &NodeInterner,
arguments: Vec<(Value, Location)>,
return_type: Type,
location: Location,
) -> IResult<Value> {
expr_as(interner, arguments, return_type.clone(), location, |expr| {
if let ExprValue::Expression(ExpressionKind::Lambda(lambda)) = expr {
// ([(Expr, Option<UnresolvedType>)], Option<UnresolvedType>, Expr)
let option_type = extract_option_generic_type(return_type);
let Type::Tuple(mut tuple_types) = option_type else {
panic!("Expected the return type option generic arg to be a tuple");
};
assert_eq!(tuple_types.len(), 3);

// Expr
tuple_types.pop().unwrap();

// Option<UnresolvedType>
let option_unresolved_type = tuple_types.pop().unwrap();

let parameters = lambda
.parameters
.into_iter()
.map(|(pattern, typ)| {
let pattern = Value::pattern(pattern);
let typ = if let UnresolvedTypeData::Unspecified = typ.typ {
None
} else {
Some(Value::UnresolvedType(typ.typ))
};
let typ = option(option_unresolved_type.clone(), typ).unwrap();
Value::Tuple(vec![pattern, typ])
})
.collect();
let parameters = Value::Slice(
parameters,
Type::Slice(Box::new(Type::Tuple(vec![
Type::Quoted(QuotedType::Expr),
Type::Quoted(QuotedType::UnresolvedType),
]))),
);

let return_type = lambda.return_type.typ;
let return_type = if let UnresolvedTypeData::Unspecified = return_type {
None
} else {
Some(return_type)
};
let return_type = return_type.map(Value::UnresolvedType);
let return_type = option(option_unresolved_type, return_type).ok()?;

let body = Value::expression(lambda.body.kind);

Some(Value::Tuple(vec![parameters, return_type, body]))
} else {
None
}
})
}

// fn as_let(self) -> Option<(Expr, Option<UnresolvedType>, Expr)>
fn expr_as_let(
interner: &NodeInterner,
Expand Down
6 changes: 6 additions & 0 deletions docs/docs/noir/standard_library/meta/expr.md
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,12 @@ array and the index.
If this expression is an integer literal, return the integer as a field
as well as whether the integer is negative (true) or not (false).

### as_lambda

#include_code as_lambda noir_stdlib/src/meta/expr.nr rust

If this expression is a lambda, returns the parameters, return type and body.

### as_let

#include_code as_let noir_stdlib/src/meta/expr.nr rust
Expand Down
43 changes: 43 additions & 0 deletions noir_stdlib/src/meta/expr.nr
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,12 @@ impl Expr {
comptime fn as_integer(self) -> Option<(Field, bool)> {}
// docs:end:as_integer

/// If this expression is a lambda, returns the parameters, return type and body.
#[builtin(expr_as_lambda)]
// docs:start:as_lambda
comptime fn as_lambda(self) -> Option<([(Expr, Option<UnresolvedType>)], Option<UnresolvedType>, Expr)> {}
// docs:end:as_lambda

/// If this expression is a let statement, returns the let pattern as an `Expr`,
/// the optional type annotation, and the assigned expression.
#[builtin(expr_as_let)]
Expand Down Expand Up @@ -234,6 +240,7 @@ impl Expr {
let result = result.or_else(|| modify_index(self, f));
let result = result.or_else(|| modify_for(self, f));
let result = result.or_else(|| modify_for_range(self, f));
let result = result.or_else(|| modify_lambda(self, f));
let result = result.or_else(|| modify_let(self, f));
let result = result.or_else(|| modify_function_call(self, f));
let result = result.or_else(|| modify_member_access(self, f));
Expand Down Expand Up @@ -427,6 +434,17 @@ comptime fn modify_for_range<Env>(expr: Expr, f: fn[Env](Expr) -> Option<Expr>)
)
}

comptime fn modify_lambda<Env>(expr: Expr, f: fn[Env](Expr) -> Option<Expr>) -> Option<Expr> {
expr.as_lambda().map(
|expr: ([(Expr, Option<UnresolvedType>)], Option<UnresolvedType>, Expr)| {
let (params, return_type, body) = expr;
let params = params.map(|param: (Expr, Option<UnresolvedType>)| (param.0.modify(f), param.1));
let body = body.modify(f);
new_lambda(params, return_type, body)
}
)
}

comptime fn modify_let<Env>(expr: Expr, f: fn[Env](Expr) -> Option<Expr>) -> Option<Expr> {
expr.as_let().map(
|expr: (Expr, Option<UnresolvedType>, Expr)| {
Expand Down Expand Up @@ -599,6 +617,31 @@ comptime fn new_index(object: Expr, index: Expr) -> Expr {
quote { $object[$index] }.as_expr().unwrap()
}

comptime fn new_lambda(
params: [(Expr, Option<UnresolvedType>)],
return_type: Option<UnresolvedType>,
body: Expr
) -> Expr {
let params = params.map(
|param: (Expr, Option<UnresolvedType>)| {
let (name, typ) = param;
if typ.is_some() {
let typ = typ.unwrap();
quote { $name: $typ }
} else {
quote { $name }
}
}
).join(quote { , });

if return_type.is_some() {
let return_type = return_type.unwrap();
quote { |$params| -> $return_type { $body } }.as_expr().unwrap()
} else {
quote { |$params| { $body } }.as_expr().unwrap()
}
}

comptime fn new_let(pattern: Expr, typ: Option<UnresolvedType>, expr: Expr) -> Expr {
if typ.is_some() {
let typ = typ.unwrap();
Expand Down
42 changes: 42 additions & 0 deletions test_programs/noir_test_success/comptime_expr/src/main.nr
Original file line number Diff line number Diff line change
Expand Up @@ -635,6 +635,48 @@ mod tests {
}
}

#[test]
fn test_expr_as_lambda() {
comptime
{
let expr = quote { |x: Field| -> Field { 1 } }.as_expr().unwrap();
let (params, return_type, body) = expr.as_lambda().unwrap();
assert_eq(params.len(), 1);
assert(params[0].1.unwrap().is_field());
assert(return_type.unwrap().is_field());
assert_eq(body.as_block().unwrap()[0].as_integer().unwrap(), (1, false));

let expr = quote { |x| { 1 } }.as_expr().unwrap();
let (params, return_type, body) = expr.as_lambda().unwrap();
assert_eq(params.len(), 1);
assert(params[0].1.is_none());
assert(return_type.is_none());
assert_eq(body.as_block().unwrap()[0].as_integer().unwrap(), (1, false));
}
}

#[test]
fn test_expr_modify_lambda() {
comptime
{
let expr = quote { |x: Field| -> Field { 1 } }.as_expr().unwrap();
let expr = expr.modify(times_two);
let (params, return_type, body) = expr.as_lambda().unwrap();
assert_eq(params.len(), 1);
assert(params[0].1.unwrap().is_field());
assert(return_type.unwrap().is_field());
assert_eq(body.as_block().unwrap()[0].as_block().unwrap()[0].as_integer().unwrap(), (2, false));

let expr = quote { |x| { 1 } }.as_expr().unwrap();
let expr = expr.modify(times_two);
let (params, return_type, body) = expr.as_lambda().unwrap();
assert_eq(params.len(), 1);
assert(params[0].1.is_none());
assert(return_type.is_none());
assert_eq(body.as_block().unwrap()[0].as_block().unwrap()[0].as_integer().unwrap(), (2, false));
}
}

#[test]
fn test_expr_as_let() {
comptime
Expand Down