Skip to content

Commit

Permalink
Exhaustively handle expressions in patterns
Browse files Browse the repository at this point in the history
  • Loading branch information
oli-obk committed Dec 12, 2024
1 parent 637ded1 commit 37cf874
Show file tree
Hide file tree
Showing 31 changed files with 392 additions and 185 deletions.
51 changes: 28 additions & 23 deletions compiler/rustc_ast_lowering/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,17 +103,7 @@ impl<'hir> LoweringContext<'_, 'hir> {

let kind = match &e.kind {
ExprKind::Array(exprs) => hir::ExprKind::Array(self.lower_exprs(exprs)),
ExprKind::ConstBlock(c) => {
let c = self.with_new_scopes(c.value.span, |this| {
let def_id = this.local_def_id(c.id);
hir::ConstBlock {
def_id,
hir_id: this.lower_node_id(c.id),
body: this.lower_const_body(c.value.span, Some(&c.value)),
}
});
hir::ExprKind::ConstBlock(c)
}
ExprKind::ConstBlock(c) => hir::ExprKind::ConstBlock(self.lower_const_block(c)),
ExprKind::Repeat(expr, count) => {
let expr = self.lower_expr(expr);
let count = self.lower_array_length_to_const_arg(count);
Expand Down Expand Up @@ -154,18 +144,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
let ohs = self.lower_expr(ohs);
hir::ExprKind::Unary(op, ohs)
}
ExprKind::Lit(token_lit) => {
let lit_kind = match LitKind::from_token_lit(*token_lit) {
Ok(lit_kind) => lit_kind,
Err(err) => {
let guar =
report_lit_error(&self.tcx.sess.psess, err, *token_lit, e.span);
LitKind::Err(guar)
}
};
let lit = self.arena.alloc(respan(self.lower_span(e.span), lit_kind));
hir::ExprKind::Lit(lit)
}
ExprKind::Lit(token_lit) => hir::ExprKind::Lit(self.lower_lit(token_lit, e.span)),
ExprKind::IncludedBytes(bytes) => {
let lit = self.arena.alloc(respan(
self.lower_span(e.span),
Expand Down Expand Up @@ -396,6 +375,32 @@ impl<'hir> LoweringContext<'_, 'hir> {
})
}

pub(crate) fn lower_const_block(&mut self, c: &AnonConst) -> hir::ConstBlock {
self.with_new_scopes(c.value.span, |this| {
let def_id = this.local_def_id(c.id);
hir::ConstBlock {
def_id,
hir_id: this.lower_node_id(c.id),
body: this.lower_const_body(c.value.span, Some(&c.value)),
}
})
}

pub(crate) fn lower_lit(
&mut self,
token_lit: &token::Lit,
span: Span,
) -> &'hir Spanned<LitKind> {
let lit_kind = match LitKind::from_token_lit(*token_lit) {
Ok(lit_kind) => lit_kind,
Err(err) => {
let guar = report_lit_error(&self.tcx.sess.psess, err, *token_lit, span);
LitKind::Err(guar)
}
};
self.arena.alloc(respan(self.lower_span(span), lit_kind))
}

fn lower_unop(&mut self, u: UnOp) -> hir::UnOp {
match u {
UnOp::Deref => hir::UnOp::Deref,
Expand Down
8 changes: 8 additions & 0 deletions compiler/rustc_ast_lowering/src/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,14 @@ impl<'a, 'hir> Visitor<'hir> for NodeCollector<'a, 'hir> {
});
}

fn visit_pat_lit(&mut self, lit: &'hir PatLit<'hir>) {
self.insert(lit.span, lit.hir_id, Node::PatLit(lit));

self.with_parent(lit.hir_id, |this| {
intravisit::walk_pat_lit(this, lit);
});
}

fn visit_pat_field(&mut self, field: &'hir PatField<'hir>) {
self.insert(field.span, field.hir_id, Node::PatField(field));
self.with_parent(field.hir_id, |this| {
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_ast_lowering/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#![doc(rust_logo)]
#![feature(assert_matches)]
#![feature(box_patterns)]
#![feature(if_let_guard)]
#![feature(let_chains)]
#![feature(rustdoc_internals)]
#![warn(unreachable_pub)]
Expand Down
51 changes: 38 additions & 13 deletions compiler/rustc_ast_lowering/src/pat.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
use std::sync::Arc;

use rustc_ast::ptr::P;
use rustc_ast::*;
use rustc_data_structures::stack::ensure_sufficient_stack;
use rustc_hir as hir;
use rustc_hir::def::Res;
use rustc_middle::span_bug;
use rustc_span::Span;
use rustc_span::source_map::Spanned;
use rustc_span::source_map::{Spanned, respan};
use rustc_span::symbol::Ident;

use super::errors::{
Expand Down Expand Up @@ -360,24 +363,46 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
// }
// m!(S);
// ```
fn lower_expr_within_pat(&mut self, expr: &Expr, allow_paths: bool) -> &'hir hir::Expr<'hir> {
match &expr.kind {
ExprKind::Lit(..)
| ExprKind::ConstBlock(..)
| ExprKind::IncludedBytes(..)
| ExprKind::Err(_)
| ExprKind::Dummy => {}
ExprKind::Path(..) if allow_paths => {}
ExprKind::Unary(UnOp::Neg, inner) if matches!(inner.kind, ExprKind::Lit(_)) => {}
fn lower_expr_within_pat(&mut self, expr: &Expr, allow_paths: bool) -> &'hir hir::PatLit<'hir> {
let err = |guar| hir::PatLitKind::Lit {
lit: self.arena.alloc(respan(self.lower_span(expr.span), LitKind::Err(guar))),
negated: false,
};
let kind = match &expr.kind {
ExprKind::Lit(lit) => {
hir::PatLitKind::Lit { lit: self.lower_lit(lit, expr.span), negated: false }
}
ExprKind::ConstBlock(c) => hir::PatLitKind::ConstBlock(self.lower_const_block(c)),
ExprKind::IncludedBytes(bytes) => hir::PatLitKind::Lit {
lit: self.arena.alloc(respan(
self.lower_span(expr.span),
LitKind::ByteStr(Arc::clone(bytes), StrStyle::Cooked),
)),
negated: false,
},
ExprKind::Err(guar) => err(*guar),
ExprKind::Dummy => span_bug!(expr.span, "lowered ExprKind::Dummy"),
ExprKind::Path(qself, path) if allow_paths => hir::PatLitKind::Path(self.lower_qpath(
expr.id,
qself,
path,
ParamMode::Optional,
AllowReturnTypeNotation::No,
ImplTraitContext::Disallowed(ImplTraitPosition::Path),
None,
)),
ExprKind::Unary(UnOp::Neg, inner) if let ExprKind::Lit(lit) = &inner.kind => {
hir::PatLitKind::Lit { lit: self.lower_lit(lit, expr.span), negated: true }
}
_ => {
let pattern_from_macro = expr.is_approximately_pattern();
let guar = self.dcx().emit_err(ArbitraryExpressionInPattern {
span: expr.span,
pattern_from_macro_note: pattern_from_macro,
});
return self.arena.alloc(self.expr_err(expr.span, guar));
err(guar)
}
}
self.lower_expr(expr)
};
self.arena.alloc(hir::PatLit { hir_id: self.lower_node_id(expr.id), span: expr.span, kind })
}
}
29 changes: 27 additions & 2 deletions compiler/rustc_hir/src/hir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1270,6 +1270,26 @@ impl fmt::Debug for DotDotPos {
}
}

#[derive(Debug, Clone, Copy, HashStable_Generic)]
pub struct PatLit<'hir> {
pub hir_id: HirId,
pub span: Span,
pub kind: PatLitKind<'hir>,
}

#[derive(Debug, Clone, Copy, HashStable_Generic)]
pub enum PatLitKind<'hir> {
Lit {
lit: &'hir Lit,
// FIXME: move this into `Lit` and handle negated literal expressions
// once instead of matching on unop neg expressions everywhere.
negated: bool,
},
ConstBlock(ConstBlock),
/// A path pattern for a unit struct/variant or a (maybe-associated) constant.
Path(QPath<'hir>),
}

#[derive(Debug, Clone, Copy, HashStable_Generic)]
pub enum PatKind<'hir> {
/// Represents a wildcard pattern (i.e., `_`).
Expand Down Expand Up @@ -1315,10 +1335,10 @@ pub enum PatKind<'hir> {
Ref(&'hir Pat<'hir>, Mutability),

/// A literal.
Lit(&'hir Expr<'hir>),
Lit(&'hir PatLit<'hir>),

/// A range pattern (e.g., `1..=2` or `1..2`).
Range(Option<&'hir Expr<'hir>>, Option<&'hir Expr<'hir>>, RangeEnd),
Range(Option<&'hir PatLit<'hir>>, Option<&'hir PatLit<'hir>>, RangeEnd),

/// A slice pattern, `[before_0, ..., before_n, (slice, after_0, ..., after_n)?]`.
///
Expand Down Expand Up @@ -3833,6 +3853,10 @@ pub enum Node<'hir> {
OpaqueTy(&'hir OpaqueTy<'hir>),
Pat(&'hir Pat<'hir>),
PatField(&'hir PatField<'hir>),
/// Needed as its own node with its own HirId for tracking
/// the unadjusted type of literals within patterns
/// (e.g. byte str literals not being of slice type).
PatLit(&'hir PatLit<'hir>),
Arm(&'hir Arm<'hir>),
Block(&'hir Block<'hir>),
LetStmt(&'hir LetStmt<'hir>),
Expand Down Expand Up @@ -3889,6 +3913,7 @@ impl<'hir> Node<'hir> {
| Node::Block(..)
| Node::Ctor(..)
| Node::Pat(..)
| Node::PatLit(..)
| Node::Arm(..)
| Node::LetStmt(..)
| Node::Crate(..)
Expand Down
18 changes: 15 additions & 3 deletions compiler/rustc_hir/src/intravisit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,9 @@ pub trait Visitor<'v>: Sized {
fn visit_pat_field(&mut self, f: &'v PatField<'v>) -> Self::Result {
walk_pat_field(self, f)
}
fn visit_pat_lit(&mut self, lit: &'v PatLit<'v>) -> Self::Result {
walk_pat_lit(self, lit)
}
fn visit_anon_const(&mut self, c: &'v AnonConst) -> Self::Result {
walk_anon_const(self, c)
}
Expand Down Expand Up @@ -686,10 +689,10 @@ pub fn walk_pat<'v, V: Visitor<'v>>(visitor: &mut V, pattern: &'v Pat<'v>) -> V:
try_visit!(visitor.visit_ident(ident));
visit_opt!(visitor, visit_pat, optional_subpattern);
}
PatKind::Lit(ref expression) => try_visit!(visitor.visit_expr(expression)),
PatKind::Lit(lit) => try_visit!(visitor.visit_pat_lit(lit)),
PatKind::Range(ref lower_bound, ref upper_bound, _) => {
visit_opt!(visitor, visit_expr, lower_bound);
visit_opt!(visitor, visit_expr, upper_bound);
visit_opt!(visitor, visit_pat_lit, lower_bound);
visit_opt!(visitor, visit_pat_lit, upper_bound);
}
PatKind::Never | PatKind::Wild | PatKind::Err(_) => (),
PatKind::Slice(prepatterns, ref slice_pattern, postpatterns) => {
Expand All @@ -707,6 +710,15 @@ pub fn walk_pat_field<'v, V: Visitor<'v>>(visitor: &mut V, field: &'v PatField<'
visitor.visit_pat(field.pat)
}

pub fn walk_pat_lit<'v, V: Visitor<'v>>(visitor: &mut V, lit: &'v PatLit<'v>) -> V::Result {
try_visit!(visitor.visit_id(lit.hir_id));
match &lit.kind {
PatLitKind::Lit { .. } => V::Result::output(),
PatLitKind::ConstBlock(c) => visitor.visit_inline_const(c),
PatLitKind::Path(qpath) => visitor.visit_qpath(qpath, lit.hir_id, lit.span),
}
}

pub fn walk_anon_const<'v, V: Visitor<'v>>(visitor: &mut V, constant: &'v AnonConst) -> V::Result {
try_visit!(visitor.visit_id(constant.hir_id));
visitor.visit_nested_body(constant.body)
Expand Down
42 changes: 20 additions & 22 deletions compiler/rustc_hir_analysis/src/hir_ty_lowering/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2430,17 +2430,11 @@ impl<'tcx> dyn HirTyLowerer<'tcx> + '_ {
Ty::new_error(tcx, err)
}
hir::PatKind::Range(start, end, include_end) => {
let expr_to_const = |expr: &'tcx hir::Expr<'tcx>| -> ty::Const<'tcx> {
let (expr, neg) = match expr.kind {
hir::ExprKind::Unary(hir::UnOp::Neg, negated) => {
(negated, Some((expr.hir_id, expr.span)))
}
_ => (expr, None),
};
let (c, c_ty) = match &expr.kind {
hir::ExprKind::Lit(lit) => {
let expr_to_const = |expr: &'tcx hir::PatLit<'tcx>| -> ty::Const<'tcx> {
let (c, c_ty) = match expr.kind {
hir::PatLitKind::Lit { lit, negated } => {
let lit_input =
LitToConstInput { lit: &lit.node, ty, neg: neg.is_some() };
LitToConstInput { lit: &lit.node, ty, neg: negated };
let ct = match tcx.lit_to_const(lit_input) {
Ok(c) => c,
Err(LitToConstError::Reported(err)) => {
Expand All @@ -2451,23 +2445,30 @@ impl<'tcx> dyn HirTyLowerer<'tcx> + '_ {
(ct, ty)
}

hir::ExprKind::Path(hir::QPath::Resolved(
hir::PatLitKind::Path(hir::QPath::Resolved(
_,
path @ &hir::Path {
res: Res::Def(DefKind::ConstParam, def_id),
..
},
)) => {
let _ = self.prohibit_generic_args(
match self.prohibit_generic_args(
path.segments.iter(),
GenericsArgsErrExtend::Param(def_id),
);
let ty = tcx
.type_of(def_id)
.no_bound_vars()
.expect("const parameter types cannot be generic");
let ct = self.lower_const_param(def_id, expr.hir_id);
(ct, ty)
) {
Ok(()) => {
let ty = tcx
.type_of(def_id)
.no_bound_vars()
.expect("const parameter types cannot be generic");
let ct = self.lower_const_param(def_id, expr.hir_id);
(ct, ty)
}
Err(guar) => (
ty::Const::new_error(tcx, guar),
Ty::new_error(tcx, guar),
),
}
}

_ => {
Expand All @@ -2478,9 +2479,6 @@ impl<'tcx> dyn HirTyLowerer<'tcx> + '_ {
}
};
self.record_ty(expr.hir_id, c_ty, expr.span);
if let Some((id, span)) = neg {
self.record_ty(id, c_ty, span);
}
c
};

Expand Down
20 changes: 17 additions & 3 deletions compiler/rustc_hir_pretty/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ impl<'a> State<'a> {
Node::OpaqueTy(o) => self.print_opaque_ty(o),
Node::Pat(a) => self.print_pat(a),
Node::PatField(a) => self.print_patfield(a),
Node::PatLit(a) => self.print_pat_lit(a),
Node::Arm(a) => self.print_arm(a),
Node::Infer(_) => self.word("_"),
Node::PreciseCapturingNonLifetimeArg(param) => self.print_ident(param.ident),
Expand Down Expand Up @@ -1715,6 +1716,19 @@ impl<'a> State<'a> {
}
}

fn print_pat_lit(&mut self, lit: &hir::PatLit<'_>) {
match &lit.kind {
hir::PatLitKind::Lit { lit, negated } => {
if *negated {
self.word("-");
}
self.print_literal(lit);
}
hir::PatLitKind::ConstBlock(c) => self.print_inline_const(c),
hir::PatLitKind::Path(qpath) => self.print_qpath(qpath, true),
}
}

fn print_pat(&mut self, pat: &hir::Pat<'_>) {
self.maybe_print_comment(pat.span.lo());
self.ann.pre(self, AnnNode::Pat(pat));
Expand Down Expand Up @@ -1832,17 +1846,17 @@ impl<'a> State<'a> {
self.pclose();
}
}
PatKind::Lit(e) => self.print_expr(e),
PatKind::Lit(e) => self.print_pat_lit(e),
PatKind::Range(begin, end, end_kind) => {
if let Some(expr) = begin {
self.print_expr(expr);
self.print_pat_lit(expr);
}
match end_kind {
RangeEnd::Included => self.word("..."),
RangeEnd::Excluded => self.word(".."),
}
if let Some(expr) = end {
self.print_expr(expr);
self.print_pat_lit(expr);
}
}
PatKind::Slice(before, slice, after) => {
Expand Down
Loading

0 comments on commit 37cf874

Please sign in to comment.