-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathtrim.go
253 lines (222 loc) · 5.75 KB
/
trim.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
package discover
import "go/ast"
// Trim prunes, in place, the AST rooted at node using the coverage profile,
// removing irrelevant and unreached parts of the program such as non-function
// declarations, uncovered functions and untaken branches.
//
// When node is an *ast.File, its comment list is rebuilt as well: an
// ast.CommentMap captured before the walk is filtered against the trimmed
// file, so comment groups associated with removed nodes are dropped.
func (p *Profile) Trim(node ast.Node) {
	visitor := &trimVisitor{p: p}
	file, isFile := node.(*ast.File)
	if !isFile {
		ast.Walk(visitor, node)
		return
	}
	comments := ast.NewCommentMap(p.Fset, file, file.Comments)
	ast.Walk(visitor, file)
	file.Comments = comments.Filter(file).Comments()
}
// trimVisitor is the ast.Visitor used by Profile.Trim; it prunes nodes as the
// walk reaches them. Its only field, p, is the coverage profile whose Funcs
// and Stmts sets decide what is kept.
type trimVisitor struct{ p *Profile }
// Visit implements ast.Visitor. Pruning happens here, before ast.Walk
// descends into node's children, so the walk only continues into what was
// kept:
//   - for an *ast.File, every declaration except function declarations that
//     the profile's Funcs set marks as covered is removed;
//   - for a node owning a statement list (*ast.BlockStmt, *ast.CaseClause,
//     *ast.CommClause), the list is rebuilt by passing each statement
//     through replaceStmt.
//
// It always returns v itself, so the walk continues into every remaining child.
func (v *trimVisitor) Visit(node ast.Node) ast.Visitor {
	var stmts *[]ast.Stmt
	switch n := node.(type) {
	case *ast.BlockStmt:
		stmts = &n.List
	case *ast.CaseClause:
		stmts = &n.Body
	case *ast.CommClause:
		stmts = &n.Body
	case *ast.File:
		var kept []ast.Decl
		for _, decl := range n.Decls {
			fn, isFunc := decl.(*ast.FuncDecl)
			if !isFunc || !v.p.Funcs[fn] {
				continue // non-function declarations and uncovered funcs
			}
			kept = append(kept, decl)
		}
		n.Decls = kept
	}
	if stmts == nil {
		return v
	}
	var rebuilt []ast.Stmt
	for _, s := range *stmts {
		rebuilt = append(rebuilt, v.replaceStmt(s)...)
	}
	*stmts = rebuilt
	return v
}
// replaceStmt returns the statements that should take the place of stmt in
// its enclosing statement list. Most statements are kept unchanged. When an
// untaken loop, branch or switch is removed, the first function call found in
// each of its header parts is pulled out as a standalone expression
// statement, so a single stmt may expand into zero, one or several statements.
//
// Pulling header calls out does not check whether the removed construct
// itself was visited; it assumes the construct was reached.
//
// Rules by statement kind:
//   - *ast.RangeStmt: kept if its body was visited; otherwise the call in the
//     range expression X is pulled out.
//   - *ast.ForStmt: kept if its body was visited; otherwise calls in Init and
//     Cond are pulled out. Post is skipped: it only runs after an iteration
//     of the body, so an unvisited body means Post never ran.
//   - *ast.IfStmt: if the body was visited the statement is kept, minus an
//     unvisited else branch; otherwise calls in Init and Cond are pulled out
//     and, if the else branch was visited, its trimmed contents follow (a
//     plain else block is flattened to avoid a redundant nested block).
//   - *ast.SelectStmt: kept, with only its visited comm clauses.
//   - *ast.SwitchStmt, *ast.TypeSwitchStmt: kept with only the clauses that
//     were visited and contain statements that matter; if none remain, the
//     statement is replaced by the calls in its header (Init and Tag, or
//     Init and Assign).
//   - anything else is kept unchanged; a nil stmt yields nothing.
func (v *trimVisitor) replaceStmt(stmt ast.Stmt) []ast.Stmt {
	switch stmt := stmt.(type) {
	case nil:
		return nil
	case *ast.RangeStmt:
		if v.visited(stmt.Body) {
			return []ast.Stmt{stmt}
		}
		return callStmts(v.findCall(stmt.X))
	case *ast.ForStmt:
		if v.visited(stmt.Body) {
			return []ast.Stmt{stmt}
		}
		// Post is deliberately omitted: it never ran if the body never ran.
		return callStmts(v.findCall(stmt.Init), v.findCall(stmt.Cond))
	case *ast.IfStmt:
		if v.visited(stmt.Body) {
			// The body ran; drop the else branch unless it ran too.
			if !v.visited(stmt.Else) {
				stmt.Else = nil
			}
			return []ast.Stmt{stmt}
		}
		// The body never ran: keep calls from the header, then splice in the
		// else branch if it was reached.
		result := callStmts(v.findCall(stmt.Init), v.findCall(stmt.Cond))
		if !v.visited(stmt.Else) {
			return result
		}
		if block, ok := stmt.Else.(*ast.BlockStmt); ok {
			// Add the block's statements individually so we don't end up
			// with an unnecessary nested block.
			for _, s := range block.List {
				result = append(result, v.replaceStmt(s)...)
			}
			return result
		}
		// An "else if" chain: trim the nested if recursively.
		return append(result, v.replaceStmt(stmt.Else)...)
	case *ast.SelectStmt:
		stmt.Body.List = keepStmts(stmt.Body.List, v.visited)
		return []ast.Stmt{stmt}
	case *ast.SwitchStmt:
		clauses := keepStmts(stmt.Body.List, v.visitedAndMatters)
		if len(clauses) == 0 {
			// No clause did anything that matters: drop the switch, but keep
			// any call made while evaluating its header.
			return callStmts(v.findCall(stmt.Init), v.findCall(stmt.Tag))
		}
		stmt.Body.List = clauses
		return []ast.Stmt{stmt}
	case *ast.TypeSwitchStmt:
		clauses := keepStmts(stmt.Body.List, v.visitedAndMatters)
		if len(clauses) == 0 {
			// Same as for SwitchStmt; the guard expression lives in Assign.
			return callStmts(v.findCall(stmt.Init), v.findCall(stmt.Assign))
		}
		stmt.Body.List = clauses
		return []ast.Stmt{stmt}
	default:
		// Keep original.
		return []ast.Stmt{stmt}
	}
}

// callStmts wraps each non-nil call in an *ast.ExprStmt, preserving order.
// It returns nil when every call is nil.
func callStmts(calls ...*ast.CallExpr) []ast.Stmt {
	var stmts []ast.Stmt
	for _, call := range calls {
		if call != nil {
			stmts = append(stmts, &ast.ExprStmt{X: call})
		}
	}
	return stmts
}

// keepStmts returns, in order, a new slice holding the statements of list for
// which keep reports true. It returns nil when none are kept.
func keepStmts(list []ast.Stmt, keep func(ast.Stmt) bool) []ast.Stmt {
	var kept []ast.Stmt
	for _, s := range list {
		if keep(s) {
			kept = append(kept, s)
		}
	}
	return kept
}
// visited reports whether the profile's Stmts set records stmt as executed.
// A nil statement (for example an absent IfStmt.Else) is never visited,
// which lets callers pass optional child statements directly.
func (v *trimVisitor) visited(stmt ast.Stmt) bool {
	return stmt != nil && v.p.Stmts[stmt]
}
// visitedAndMatters reports whether stmt was visited and also has an effect.
// Unvisited statements and *ast.EmptyStmt never matter. A *ast.BlockStmt or
// *ast.CaseClause matters only if at least one of its child statements
// (checked recursively) matters, so a visited but empty block or case is
// treated as irrelevant. Every other visited statement matters.
func (v *trimVisitor) visitedAndMatters(stmt ast.Stmt) bool {
	if !v.visited(stmt) {
		return false
	}
	var children []ast.Stmt
	switch s := stmt.(type) {
	case *ast.EmptyStmt:
		return false
	case *ast.BlockStmt:
		children = s.List
	case *ast.CaseClause:
		children = s.Body
	default:
		return true
	}
	for _, child := range children {
		if v.visitedAndMatters(child) {
			return true
		}
	}
	return false
}
// findCall returns the first *ast.CallExpr met in a depth-first, pre-order
// walk of the tree rooted at node, or nil if node is nil or contains no call.
// It is used to "pull out" a call from a statement or expression whose
// surrounding construct is being trimmed away.
//
// The bodies of function literals are not searched: evaluating an expression
// that merely contains a func literal does not execute the literal's body, so
// a call found there would misreport what ran. A literal invoked on the spot
// is still found, because its enclosing CallExpr is visited before the
// literal itself.
func (v *trimVisitor) findCall(node ast.Node) *ast.CallExpr {
	if node == nil { // for convenience with optional fields like ForStmt.Init
		return nil
	}
	var call *ast.CallExpr
	ast.Inspect(node, func(n ast.Node) bool {
		if call != nil {
			return false // already found; don't descend anywhere else
		}
		switch n := n.(type) {
		case *ast.CallExpr:
			call = n
			return false
		case *ast.FuncLit:
			return false // body is not executed by evaluating the literal
		}
		return true
	})
	return call
}