Skip to content

Commit

Permalink
print type annotations of patterns
Browse files Browse the repository at this point in the history
  • Loading branch information
javra committed Feb 25, 2025
1 parent e954250 commit f44ca0c
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 24 deletions.
34 changes: 13 additions & 21 deletions src/sail_lean_backend/pretty_print_lean.ml
Original file line number Diff line number Diff line change
Expand Up @@ -435,30 +435,26 @@ let rec update_ctx_pat (ctx : context) (P_aux (p, (l, annot)) as pat) =
List.fold_left update_ctx_pat ctx pats
| _ -> ctx

let rec doc_pat ?(in_vector = false) (P_aux (p, (l, annot)) as pat) =
let rec doc_pat ctx ?(in_match = false) ?(in_vector = false) (P_aux (p, (l, annot)) as pat) =
match p with
| P_wild -> underscore
| P_lit lit when in_vector -> doc_vec_lit lit
| P_lit lit -> doc_lit lit
| P_typ (Typ_aux (Typ_id (Id_aux (Id "bit", _)), _), p) when in_vector -> doc_pat p ^^ string ":1"
| P_typ (Typ_aux (Typ_app (Id_aux (Id id, _), [A_aux (A_nexp (Nexp_aux (Nexp_constant i, _)), _)]), _), p)
when in_vector && (id = "bits" || id = "bitvector") ->
doc_pat p ^^ string ":" ^^ doc_big_int i
| P_typ (ptyp, p) -> doc_pat p
| P_typ (ptyp, p) -> if in_match then doc_pat ctx p else flow space [doc_pat ctx p; colon; doc_typ ctx ptyp]
| P_id id -> fixup_match_id id |> doc_id_ctor
| P_tuple pats -> separate (string ", ") (List.map doc_pat pats) |> parens
| P_list pats -> separate (string ", ") (List.map doc_pat pats) |> brackets
| P_vector pats -> concat (List.map (doc_pat ~in_vector:true) pats)
| P_vector_concat pats -> separate (string ",") (List.map (doc_pat ~in_vector:true) pats) |> brackets
| P_tuple pats -> separate (string ", ") (List.map (doc_pat ctx) pats) |> parens
| P_list pats -> separate (string ", ") (List.map (doc_pat ctx) pats) |> brackets
| P_vector pats -> concat (List.map (doc_pat ctx ~in_vector:true) pats)
| P_vector_concat pats -> separate (string ",") (List.map (doc_pat ctx ~in_vector:true) pats) |> brackets
| P_app (Id_aux (Id "None", _), p) -> string "none"
| P_app (cons, pats) ->
string "." ^^ doc_id_ctor (fixup_match_id cons) ^^ space ^^ separate_map (string ", ") doc_pat pats
| P_var (p, _) -> doc_pat p
| P_as (pat, id) -> doc_pat pat
string "." ^^ doc_id_ctor (fixup_match_id cons) ^^ space ^^ separate_map (string ", ") (doc_pat ctx) pats
| P_var (p, _) -> doc_pat ctx p
| P_as (pat, id) -> doc_pat ctx pat
| P_struct (pats, _) ->
let pats = List.map (fun (id, pat) -> separate space [doc_id_ctor id; coloneq; doc_pat pat]) pats in
let pats = List.map (fun (id, pat) -> separate space [doc_id_ctor id; coloneq; doc_pat ctx pat]) pats in
braces (space ^^ separate (comma ^^ space) pats ^^ space)
| P_cons (hd_pat, tl_pat) -> parens (separate space [doc_pat hd_pat; string "::"; doc_pat tl_pat])
| P_cons (hd_pat, tl_pat) -> parens (separate space [doc_pat ctx hd_pat; string "::"; doc_pat ctx tl_pat])
| _ -> failwith ("Doc Pattern " ^ string_of_pat_con pat ^ " " ^ string_of_pat pat ^ " not translatable yet.")

(* Copied from the Coq PP *)
Expand Down Expand Up @@ -546,7 +542,7 @@ let rec list_any (l : 'a list) (f : 'a -> bool) = match l with t :: q -> f t ||
let rec doc_match_clause (as_monadic : bool) ctx (Pat_aux (cl, l)) =
match cl with
| Pat_exp (pat, branch) ->
group (nest 2 (string "| " ^^ doc_pat pat ^^ string " =>" ^^ break 1 ^^ doc_exp as_monadic ctx branch))
group (nest 2 (string "| " ^^ doc_pat ctx ~in_match:true pat ^^ string " =>" ^^ break 1 ^^ doc_exp as_monadic ctx branch))
| Pat_when (pat, when_, branch) -> failwith "The Lean backend does not support 'when' clauses in patterns"

and doc_exp (as_monadic : bool) ctx (E_aux (e, (l, annot)) as full_exp) =
Expand Down Expand Up @@ -708,11 +704,7 @@ and doc_exp (as_monadic : bool) ctx (E_aux (e, (l, annot)) as full_exp) =
| E_tuple es -> wrap_with_pure as_monadic (parens (separate_map (comma ^^ space) (d_of_arg ctx) es))
| E_let (LB_aux (LB_val (lpat, lexp), _), e') | E_internal_plet (lpat, lexp, e') ->
let arrow = match e with E_let _ -> string "" | _ -> string "← do" in
let id_typ =
match pat_is_plain_binder env lpat with
| Some (_, Some typ) -> doc_pat lpat ^^ space ^^ colon ^^ space ^^ doc_typ ctx typ
| _ -> doc_pat lpat
in
let id_typ = doc_pat ctx lpat in
let ctx = update_ctx_pat ctx lpat in
let pp_let_line_f l = group (nest 2 (flow (break 1) l)) in
let pp_let_line =
Expand Down
2 changes: 1 addition & 1 deletion test/lean/SailTinyArm.expected.lean
Original file line number Diff line number Diff line change
Expand Up @@ -1890,7 +1890,7 @@ def execute_StoreRegister (t : Nat) (n : Nat) (m : Nat) : SailM Unit := do
let base_addr ← do (rX n)
let offset ← do (rX m)
let addr := (base_addr + offset)
let _ := (wMem_Addr addr)
let _ : Unit := (wMem_Addr addr)
let data ← do (rX t)
(wMem addr data)

Expand Down
2 changes: 1 addition & 1 deletion test/lean/loop.expected.lean
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def foreachloopboth (m : Nat) (n : Nat) : SailM Int := do
def foreachloopmultiplevar (m : Nat) (n : Nat) : Int :=
let res : Int := 0
let mult : Int := 1
let (mult, res) :=
let (mult, res) : (Int × Int) :=
let loop_i_lower := m
let loop_i_upper := n
foreach_ loop_i_lower loop_i_upper 1 (mult, res)
Expand Down
2 changes: 1 addition & 1 deletion test/lean/match.expected.lean
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def const32 (_ : Unit) : ((BitVec 32) × Bool) :=

/-- Type quantifiers: k_n : Nat, k_n ≥ 0 -/
def match_width (x : (BitVec k_n)) : (BitVec (2 * k_n)) :=
let (foo, _) :=
let (foo, _) : ((BitVec k_n) × Bool) :=
match (Sail.BitVec.length x) with
| 16 => (const16 ())
| 32 => (const32 ())
Expand Down

0 comments on commit f44ca0c

Please sign in to comment.