Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(experimental): Parse match expressions #7243

Merged
merged 7 commits into from
Feb 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions compiler/noirc_frontend/src/ast/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ pub enum ExpressionKind {
Cast(Box<CastExpression>),
Infix(Box<InfixExpression>),
If(Box<IfExpression>),
Match(Box<MatchExpression>),
Variable(Path),
Tuple(Vec<Expression>),
Lambda(Box<Lambda>),
Expand Down Expand Up @@ -465,6 +466,12 @@ pub struct IfExpression {
pub alternative: Option<Expression>,
}

#[derive(Debug, PartialEq, Eq, Clone)]
pub struct MatchExpression {
pub expression: Expression,
pub rules: Vec<(/*pattern*/ Expression, /*branch*/ Expression)>,
}

#[derive(Debug, PartialEq, Eq, Clone)]
pub struct Lambda {
pub parameters: Vec<(Pattern, UnresolvedType)>,
Expand Down Expand Up @@ -612,6 +619,7 @@ impl Display for ExpressionKind {
Cast(cast) => cast.fmt(f),
Infix(infix) => infix.fmt(f),
If(if_expr) => if_expr.fmt(f),
Match(match_expr) => match_expr.fmt(f),
Variable(path) => path.fmt(f),
Constructor(constructor) => constructor.fmt(f),
MemberAccess(access) => access.fmt(f),
Expand Down Expand Up @@ -790,6 +798,16 @@ impl Display for IfExpression {
}
}

impl Display for MatchExpression {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
writeln!(f, "match {} {{", self.expression)?;
for (pattern, branch) in &self.rules {
writeln!(f, " {pattern} -> {branch},")?;
}
write!(f, "}}")
}
}

impl Display for Lambda {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let parameters = vecmap(&self.parameters, |(name, r#type)| format!("{name}: {type}"));
Expand Down
25 changes: 24 additions & 1 deletion compiler/noirc_frontend/src/ast/visitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use crate::{

use super::{
ForBounds, FunctionReturnType, GenericTypeArgs, IntegerBitSize, ItemVisibility,
NoirEnumeration, Pattern, Signedness, TraitBound, TraitImplItemKind, TypePath,
MatchExpression, NoirEnumeration, Pattern, Signedness, TraitBound, TraitImplItemKind, TypePath,
UnresolvedGenerics, UnresolvedTraitConstraint, UnresolvedType, UnresolvedTypeData,
UnresolvedTypeExpression,
};
Expand Down Expand Up @@ -222,6 +222,10 @@ pub trait Visitor {
true
}

fn visit_match_expression(&mut self, _: &MatchExpression, _: Span) -> bool {
true
}

fn visit_tuple(&mut self, _: &[Expression], _: Span) -> bool {
true
}
Expand Down Expand Up @@ -866,6 +870,9 @@ impl Expression {
ExpressionKind::If(if_expression) => {
if_expression.accept(self.span, visitor);
}
ExpressionKind::Match(match_expression) => {
match_expression.accept(self.span, visitor);
}
ExpressionKind::Tuple(expressions) => {
if visitor.visit_tuple(expressions, self.span) {
visit_expressions(expressions, visitor);
Expand Down Expand Up @@ -1073,6 +1080,22 @@ impl IfExpression {
}
}

impl MatchExpression {
pub fn accept(&self, span: Span, visitor: &mut impl Visitor) {
if visitor.visit_match_expression(self, span) {
self.accept_children(visitor);
}
}

pub fn accept_children(&self, visitor: &mut impl Visitor) {
self.expression.accept(visitor);
for (pattern, branch) in &self.rules {
pattern.accept(visitor);
branch.accept(visitor);
}
}
}

impl Lambda {
pub fn accept(&self, span: Span, visitor: &mut impl Visitor) {
if visitor.visit_lambda(self, span) {
Expand Down
11 changes: 8 additions & 3 deletions compiler/noirc_frontend/src/elaborator/expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ use crate::{
ast::{
ArrayLiteral, BlockExpression, CallExpression, CastExpression, ConstructorExpression,
Expression, ExpressionKind, Ident, IfExpression, IndexExpression, InfixExpression,
ItemVisibility, Lambda, Literal, MemberAccessExpression, MethodCallExpression, Path,
PathSegment, PrefixExpression, StatementKind, UnaryOp, UnresolvedTypeData,
UnresolvedTypeExpression,
ItemVisibility, Lambda, Literal, MatchExpression, MemberAccessExpression,
MethodCallExpression, Path, PathSegment, PrefixExpression, StatementKind, UnaryOp,
UnresolvedTypeData, UnresolvedTypeExpression,
},
hir::{
comptime::{self, InterpreterError},
Expand Down Expand Up @@ -51,6 +51,7 @@ impl<'context> Elaborator<'context> {
ExpressionKind::Cast(cast) => self.elaborate_cast(*cast, expr.span),
ExpressionKind::Infix(infix) => return self.elaborate_infix(*infix, expr.span),
ExpressionKind::If(if_) => self.elaborate_if(*if_),
ExpressionKind::Match(match_) => self.elaborate_match(*match_),
ExpressionKind::Variable(variable) => return self.elaborate_variable(variable),
ExpressionKind::Tuple(tuple) => self.elaborate_tuple(tuple),
ExpressionKind::Lambda(lambda) => self.elaborate_lambda(*lambda, None),
Expand Down Expand Up @@ -926,6 +927,10 @@ impl<'context> Elaborator<'context> {
(HirExpression::If(if_expr), ret_type)
}

fn elaborate_match(&mut self, _match_expr: MatchExpression) -> (HirExpression, Type) {
(HirExpression::Error, Type::Error)
}

fn elaborate_tuple(&mut self, tuple: Vec<Expression>) -> (HirExpression, Type) {
let mut element_ids = Vec::with_capacity(tuple.len());
let mut element_types = Vec::with_capacity(tuple.len());
Expand Down
15 changes: 12 additions & 3 deletions compiler/noirc_frontend/src/hir/comptime/display.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ use crate::{
ArrayLiteral, AsTraitPath, AssignStatement, BlockExpression, CallExpression,
CastExpression, ConstrainStatement, ConstructorExpression, Expression, ExpressionKind,
ForBounds, ForLoopStatement, ForRange, GenericTypeArgs, IfExpression, IndexExpression,
InfixExpression, LValue, Lambda, LetStatement, Literal, MemberAccessExpression,
MethodCallExpression, Pattern, PrefixExpression, Statement, StatementKind, UnresolvedType,
UnresolvedTypeData,
InfixExpression, LValue, Lambda, LetStatement, Literal, MatchExpression,
MemberAccessExpression, MethodCallExpression, Pattern, PrefixExpression, Statement,
StatementKind, UnresolvedType, UnresolvedTypeData,
},
hir_def::traits::TraitConstraint,
node_interner::{InternedStatementKind, NodeInterner},
Expand Down Expand Up @@ -241,6 +241,7 @@ impl<'interner> TokenPrettyPrinter<'interner> {
| Token::GreaterEqual
| Token::Equal
| Token::NotEqual
| Token::FatArrow
| Token::Arrow => write!(f, " {token} "),
Token::Assign => {
if last_was_op {
Expand Down Expand Up @@ -602,6 +603,14 @@ fn remove_interned_in_expression_kind(
.alternative
.map(|alternative| remove_interned_in_expression(interner, alternative)),
})),
ExpressionKind::Match(match_expr) => ExpressionKind::Match(Box::new(MatchExpression {
expression: remove_interned_in_expression(interner, match_expr.expression),
rules: vecmap(match_expr.rules, |(pattern, branch)| {
let pattern = remove_interned_in_expression(interner, pattern);
let branch = remove_interned_in_expression(interner, branch);
(pattern, branch)
}),
})),
ExpressionKind::Variable(_) => expr,
ExpressionKind::Tuple(expressions) => ExpressionKind::Tuple(vecmap(expressions, |expr| {
remove_interned_in_expression(interner, expr)
Expand Down
13 changes: 12 additions & 1 deletion compiler/noirc_frontend/src/lexer/lexer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -215,8 +215,19 @@ impl<'a> Lexer<'a> {
Ok(prev_token.into_single_span(start))
}
}
Token::Assign => {
let start = self.position;
if self.peek_char_is('=') {
self.next_char();
Ok(Token::Equal.into_span(start, start + 1))
} else if self.peek_char_is('>') {
self.next_char();
Ok(Token::FatArrow.into_span(start, start + 1))
} else {
Ok(prev_token.into_single_span(start))
}
}
Token::Bang => self.single_double_peek_token('=', prev_token, Token::NotEqual),
Token::Assign => self.single_double_peek_token('=', prev_token, Token::Equal),
Token::Minus => self.single_double_peek_token('>', prev_token, Token::Arrow),
Token::Colon => self.single_double_peek_token(':', prev_token, Token::DoubleColon),
Token::Slash => {
Expand Down
6 changes: 6 additions & 0 deletions compiler/noirc_frontend/src/lexer/token.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ pub enum BorrowedToken<'input> {
RightBracket,
/// ->
Arrow,
/// =>
FatArrow,
/// |
Pipe,
/// #
Expand Down Expand Up @@ -212,6 +214,8 @@ pub enum Token {
RightBracket,
/// ->
Arrow,
/// =>
FatArrow,
/// |
Pipe,
/// #
Expand Down Expand Up @@ -296,6 +300,7 @@ pub fn token_to_borrowed_token(token: &Token) -> BorrowedToken<'_> {
Token::LeftBracket => BorrowedToken::LeftBracket,
Token::RightBracket => BorrowedToken::RightBracket,
Token::Arrow => BorrowedToken::Arrow,
Token::FatArrow => BorrowedToken::FatArrow,
Token::Pipe => BorrowedToken::Pipe,
Token::Pound => BorrowedToken::Pound,
Token::Comma => BorrowedToken::Comma,
Expand Down Expand Up @@ -473,6 +478,7 @@ impl fmt::Display for Token {
Token::LeftBracket => write!(f, "["),
Token::RightBracket => write!(f, "]"),
Token::Arrow => write!(f, "->"),
Token::FatArrow => write!(f, "=>"),
Token::Pipe => write!(f, "|"),
Token::Pound => write!(f, "#"),
Token::Comma => write!(f, ","),
Expand Down
52 changes: 49 additions & 3 deletions compiler/noirc_frontend/src/parser/parser/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use noirc_errors::Span;
use crate::{
ast::{
ArrayLiteral, BlockExpression, CallExpression, CastExpression, ConstructorExpression,
Expression, ExpressionKind, Ident, IfExpression, IndexExpression, Literal,
Expression, ExpressionKind, Ident, IfExpression, IndexExpression, Literal, MatchExpression,
MemberAccessExpression, MethodCallExpression, Statement, TypePath, UnaryOp, UnresolvedType,
},
parser::{labels::ParsingRuleLabel, parser::parse_many::separated_by_comma, ParserErrorReason},
Expand Down Expand Up @@ -91,8 +91,7 @@ impl<'a> Parser<'a> {
}

/// AtomOrUnaryRightExpression
/// = Atom
/// | UnaryRightExpression
/// = Atom UnaryRightExpression*
fn parse_atom_or_unary_right(&mut self, allow_constructors: bool) -> Option<Expression> {
let start_span = self.current_token_span;
let mut atom = self.parse_atom(allow_constructors)?;
Expand Down Expand Up @@ -311,6 +310,10 @@ impl<'a> Parser<'a> {
return Some(kind);
}

if let Some(kind) = self.parse_match_expr() {
return Some(kind);
}

if let Some(kind) = self.parse_lambda() {
return Some(kind);
}
Expand Down Expand Up @@ -518,6 +521,49 @@ impl<'a> Parser<'a> {
Some(ExpressionKind::If(Box::new(IfExpression { condition, consequence, alternative })))
}

/// MatchExpression = 'match' ExpressionExceptConstructor '{' MatchRule* '}'
pub(super) fn parse_match_expr(&mut self) -> Option<ExpressionKind> {
let start_span = self.current_token_span;
if !self.eat_keyword(Keyword::Match) {
return None;
}

let expression = self.parse_expression_except_constructor_or_error();

self.eat_left_brace();

let rules = self.parse_many(
"match cases",
without_separator().until(Token::RightBrace),
Self::parse_match_rule,
);

self.push_error(ParserErrorReason::ExperimentalFeature("Match expressions"), start_span);
Some(ExpressionKind::Match(Box::new(MatchExpression { expression, rules })))
}

/// MatchRule = Expression '->' (Block ','?) | (Expression ',')
fn parse_match_rule(&mut self) -> Option<(Expression, Expression)> {
let pattern = self.parse_expression()?;
self.eat_or_error(Token::FatArrow);

let start_span = self.current_token_span;
let branch = match self.parse_block() {
Some(block) => {
let span = self.span_since(start_span);
let block = Expression::new(ExpressionKind::Block(block), span);
self.eat_comma(); // comma is optional if we have a block
block
}
None => {
let branch = self.parse_expression_or_error();
self.eat_or_error(Token::Comma);
branch
}
};
Some((pattern, branch))
}

/// ComptimeExpression = 'comptime' Block
fn parse_comptime_expr(&mut self) -> Option<ExpressionKind> {
if !self.eat_keyword(Keyword::Comptime) {
Expand Down
11 changes: 7 additions & 4 deletions compiler/noirc_frontend/src/parser/parser/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,10 +162,13 @@ impl<'a> Parser<'a> {
}

if let Some(kind) = self.parse_if_expr() {
return Some(StatementKind::Expression(Expression {
kind,
span: self.span_since(start_span),
}));
let span = self.span_since(start_span);
return Some(StatementKind::Expression(Expression { kind, span }));
}

if let Some(kind) = self.parse_match_expr() {
let span = self.span_since(start_span);
return Some(StatementKind::Expression(Expression { kind, span }));
}

if let Some(block) = self.parse_block() {
Expand Down
1 change: 1 addition & 0 deletions tooling/lsp/src/requests/inlay_hint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -590,6 +590,7 @@ fn get_expression_name(expression: &Expression) -> Option<String> {
| ExpressionKind::InternedStatement(..)
| ExpressionKind::Literal(..)
| ExpressionKind::Unsafe(..)
| ExpressionKind::Match(_)
| ExpressionKind::Error => None,
}
}
Expand Down
Loading
Loading