This repository was archived by the owner on Jul 12, 2023. It is now read-only.
forked from polyglot-compiler/JLang
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathJLangSwitchExt.java
178 lines (152 loc) · 5.88 KB
/
JLangSwitchExt.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
//Copyright (C) 2018 Cornell University
package jlang.extension;
import polyglot.ast.*;
import polyglot.util.InternalCompilerError;
import polyglot.util.Position;
import polyglot.util.SerialVersionUID;
import java.lang.Override;
import java.util.*;
import jlang.ast.JLangExt;
import jlang.visit.DesugarLocally;
import jlang.visit.LLVMTranslator;
import static org.bytedeco.javacpp.LLVM.*;
public class JLangSwitchExt extends JLangExt {
    private static final long serialVersionUID = SerialVersionUID.generate();

    /**
     * Desugars a switch whose expression is a {@code String} into a switch on an
     * integer case index (see {@link #desugarStringSwitch}). Every other switch is
     * handed to the superclass unchanged.
     */
    @Override
    public Node desugar(DesugarLocally v) {
        Switch n = (Switch) node();
        if (n.expr().type().isSubtype(v.ts.String()))
            return desugarStringSwitch(n, v);
        return super.desugar(v);
    }

    /**
     * Rewrites a switch on a string into a switch on an {@code int} case index.
     *
     * <pre>
     * switch (strExpr) {
     *   case s1:
     *     ...
     *   case sn:
     *     ...
     *   default:
     *     ...
     * }
     *
     * --->
     *
     * let s = strExpr;
     * if (s == null)
     *   throw new NullPointerException();
     *
     * var i = 0;
     * if (s.equals(s1))
     *   i = 1;
     * if (s.equals(sn))
     *   i = n;
     *
     * switch (i) {
     *   case 1:
     *     ...
     *   case n:
     *     ...
     *   default:
     *     ...
     * }
     * </pre>
     *
     * Non-default cases are numbered 1..n in source order. The index starts at 0,
     * which no case uses, so a string matching no label reaches {@code default}
     * (or skips the switch body when there is none).
     *
     * @param n the switch statement; its expression must have a String type
     * @param v the local desugaring visitor (supplies node/type factories)
     * @return a block holding the saved expression, the null check, the index
     *         computation, and the rewritten integer switch
     */
    protected Node desugarStringSwitch(Switch n, DesugarLocally v) {
        assert n.expr().type().isSubtype(v.ts.String());
        Position pos = n.position();
        List<Stmt> stmts = new ArrayList<>();

        // Save the switch expression so it is evaluated exactly once.
        LocalDecl strDecl = v.tnf.TempSSA("str", n.expr());
        Local str = v.tnf.Local(pos, strDecl);
        stmts.add(strDecl);

        // Explicit null check: a switch with no non-default cases would never
        // call equals(), yet a null switch expression must still throw.
        Expr nil = v.nf.NullLit(pos).type(v.ts.Null());
        Expr check = v.nf.Binary(pos, copy(str), Binary.EQ, nil).type(v.ts.Boolean());
        Stmt throwExn = v.tnf.Throw(pos, v.ts.NullPointerException(), Collections.emptyList());
        stmts.add(v.tnf.If(check, throwExn));

        // Declare the case index, initialized to 0 (= "no case matched").
        Expr zero = v.nf.IntLit(pos, IntLit.INT, 0).type(v.ts.Int());
        LocalDecl idxDecl = v.tnf.TempVar(pos, "idx", v.ts.Int(), zero);
        Local idx = v.tnf.Local(pos, idxDecl);
        stmts.add(idxDecl);

        // Switch on the case index instead of the string value.
        n = n.expr(copy(idx));

        // Assign the case index from the string value, renumbering each label.
        int counter = 0;
        List<SwitchElement> elems = new ArrayList<>(n.elements());
        for (int i = 0; i < elems.size(); ++i) {
            SwitchElement e = elems.get(i);
            if (e instanceof Case) {
                Case c = (Case) e;
                if (c.isDefault())
                    continue;

                // idx = k if str.equals(label).
                assert c.expr() != null;
                Expr equal = v.tnf.Call(
                        pos, copy(str), "equals", v.ts.String(), v.ts.Boolean(), c.expr());
                IntLit val = (IntLit) v.nf.IntLit(pos, IntLit.INT, ++counter).type(v.ts.Int());
                stmts.add(v.tnf.If(equal, v.tnf.EvalAssign(copy(idx), copy(val))));

                // Replace the string label with its integer index.
                c = c.expr(copy(val)).value(val.value());
                elems.set(i, c);
            }
        }
        n = n.elements(elems);
        stmts.add(n);

        return v.nf.Block(pos, stmts);
    }

    /**
     * Translates a switch statement into an LLVM {@code switch} instruction.
     *
     * Each {@link SwitchBlock} gets its own basic block, laid out in source order.
     * A block that does not end in a terminator branches to the next one, which
     * implements fall-through. Each case label maps to the block that follows it.
     * The default target is the block after the {@code default} label, or the end
     * block when there is no default. The switch instruction itself goes into the
     * block that was current after the switch expression was translated. On
     * return, the builder is positioned at the end of the {@code switch.end} block.
     *
     * @param parent the parent node (unused here)
     * @param v      the LLVM translator
     * @return the switch node, unchanged
     */
    @Override
    public Node overrideTranslateLLVM(Node parent, LLVMTranslator v) {
        Switch n = (Switch) node();

        // Translate switch expression.
        n.expr().visit(v);
        LLVMValueRef exprRef = v.getTranslation(n.expr());
        LLVMBasicBlockRef headBlock = LLVMGetInsertBlock(v.builder);

        // Create a basic block for each switch block.
        List<LLVMBasicBlockRef> blocks = new ArrayList<>();
        for (SwitchElement e : n.elements())
            if (e instanceof SwitchBlock)
                blocks.add(v.utils.buildBlock("switch.case"));

        // Build the end block and append it, so that "the block after the last
        // switch block" and labels with no following block both resolve to it.
        LLVMBasicBlockRef end = v.utils.buildBlock("switch.end");
        blocks.add(end);

        // Build switch blocks and case-to-block mappings.
        v.pushSwitch(end); // Allows break statements to jump to end.
        Map<Case, LLVMBasicBlockRef> blockMap = new LinkedHashMap<>(); // Excludes default case.
        LLVMBasicBlockRef defaultBlock = end;
        int nextBlockIdx = 0;
        for (SwitchElement elem : n.elements()) {
            if (elem instanceof Case) {
                // A label targets the next switch block that follows it.
                Case c = (Case) elem;
                if (c.isDefault()) {
                    defaultBlock = blocks.get(nextBlockIdx);
                } else {
                    blockMap.put(c, blocks.get(nextBlockIdx));
                }
            }
            else if (elem instanceof SwitchBlock) {
                // Build switch block and implement fall-through.
                LLVMPositionBuilderAtEnd(v.builder, blocks.get(nextBlockIdx));
                elem.visit(v);
                ++nextBlockIdx;
                v.utils.branchUnlessTerminated(blocks.get(nextBlockIdx));
            }
            else {
                throw new InternalCompilerError("Unhandled switch element");
            }
        }
        v.popSwitch();

        // Build switch.
        LLVMPositionBuilderAtEnd(v.builder, headBlock);
        LLVMValueRef switchRef = LLVMBuildSwitch(v.builder, exprRef, defaultBlock, blockMap.size());

        // LLVM requires every case constant to have exactly the type of the switch
        // condition. A label's own type may be narrower or wider, e.g. an int
        // literal label in a char/byte/short switch. So build the constants with
        // the switch expression's type. Case constants are assignment-compatible
        // with that type (JLS 14.11), so their values fit.
        LLVMTypeRef caseType = v.utils.toLL(n.expr().type());

        // Add all cases.
        for (Map.Entry<Case, LLVMBasicBlockRef> e : blockMap.entrySet()) {
            Case c = e.getKey();
            LLVMBasicBlockRef block = e.getValue();
            assert !c.isDefault() : "The default case should be handled separately";
            LLVMValueRef label = LLVMConstInt(caseType, c.value(), /*sign-extend*/ 0);
            LLVMAddCase(switchRef, label, block);
        }

        LLVMPositionBuilderAtEnd(v.builder, end);
        return n;
    }
}