match.go raw
1 package pattern
2
3 import (
4 "fmt"
5 "go/ast"
6 "go/token"
7 "go/types"
8 "reflect"
9
10 "golang.org/x/tools/go/ast/astutil"
11 )
12
13 var tokensByString = map[string]Token{
14 "INT": Token(token.INT),
15 "FLOAT": Token(token.FLOAT),
16 "IMAG": Token(token.IMAG),
17 "CHAR": Token(token.CHAR),
18 "STRING": Token(token.STRING),
19 "+": Token(token.ADD),
20 "-": Token(token.SUB),
21 "*": Token(token.MUL),
22 "/": Token(token.QUO),
23 "%": Token(token.REM),
24 "&": Token(token.AND),
25 "|": Token(token.OR),
26 "^": Token(token.XOR),
27 "<<": Token(token.SHL),
28 ">>": Token(token.SHR),
29 "&^": Token(token.AND_NOT),
30 "+=": Token(token.ADD_ASSIGN),
31 "-=": Token(token.SUB_ASSIGN),
32 "*=": Token(token.MUL_ASSIGN),
33 "/=": Token(token.QUO_ASSIGN),
34 "%=": Token(token.REM_ASSIGN),
35 "&=": Token(token.AND_ASSIGN),
36 "|=": Token(token.OR_ASSIGN),
37 "^=": Token(token.XOR_ASSIGN),
38 "<<=": Token(token.SHL_ASSIGN),
39 ">>=": Token(token.SHR_ASSIGN),
40 "&^=": Token(token.AND_NOT_ASSIGN),
41 "&&": Token(token.LAND),
42 "||": Token(token.LOR),
43 "<-": Token(token.ARROW),
44 "++": Token(token.INC),
45 "--": Token(token.DEC),
46 "==": Token(token.EQL),
47 "<": Token(token.LSS),
48 ">": Token(token.GTR),
49 "=": Token(token.ASSIGN),
50 "!": Token(token.NOT),
51 "!=": Token(token.NEQ),
52 "<=": Token(token.LEQ),
53 ">=": Token(token.GEQ),
54 ":=": Token(token.DEFINE),
55 "...": Token(token.ELLIPSIS),
56 "IMPORT": Token(token.IMPORT),
57 "VAR": Token(token.VAR),
58 "TYPE": Token(token.TYPE),
59 "CONST": Token(token.CONST),
60 "BREAK": Token(token.BREAK),
61 "CONTINUE": Token(token.CONTINUE),
62 "GOTO": Token(token.GOTO),
63 "FALLTHROUGH": Token(token.FALLTHROUGH),
64 }
65
66 func maybeToken(node Node) (Node, bool) {
67 if node, ok := node.(String); ok {
68 if tok, ok := tokensByString[string(node)]; ok {
69 return tok, true
70 }
71 return node, false
72 }
73 return node, false
74 }
75
76 func isNil(v interface{}) bool {
77 if v == nil {
78 return true
79 }
80 if _, ok := v.(Nil); ok {
81 return true
82 }
83 return false
84 }
85
86 type matcher interface {
87 Match(*Matcher, interface{}) (interface{}, bool)
88 }
89
// State holds the values captured by named bindings during a match, keyed by
// binding name.
type State = map[string]interface{}
91
92 type Matcher struct {
93 TypesInfo *types.Info
94 State State
95
96 bindingsMapping []string
97
98 setBindings []uint64
99 }
100
101 func (m *Matcher) set(b Binding, value interface{}) {
102 m.State[b.Name] = value
103 m.setBindings[len(m.setBindings)-1] |= 1 << b.idx
104 }
105
106 func (m *Matcher) push() {
107 m.setBindings = append(m.setBindings, 0)
108 }
109
110 func (m *Matcher) pop() {
111 set := m.setBindings[len(m.setBindings)-1]
112 if set != 0 {
113 for i := 0; i < len(m.bindingsMapping); i++ {
114 if (set & (1 << i)) != 0 {
115 key := m.bindingsMapping[i]
116 delete(m.State, key)
117 }
118 }
119 }
120 m.setBindings = m.setBindings[:len(m.setBindings)-1]
121 }
122
123 func (m *Matcher) merge() {
124 m.setBindings = m.setBindings[:len(m.setBindings)-1]
125 }
126
127 func (m *Matcher) Match(a Pattern, b ast.Node) bool {
128 m.bindingsMapping = a.Bindings
129 m.State = State{}
130 m.push()
131 _, ok := match(m, a.Root, b)
132 m.merge()
133 if len(m.setBindings) != 0 {
134 panic(fmt.Sprintf("%d entries left on the stack, expected none", len(m.setBindings)))
135 }
136 return ok
137 }
138
139 func Match(a Pattern, b ast.Node) (*Matcher, bool) {
140 m := &Matcher{}
141 ret := m.Match(a, b)
142 return m, ret
143 }
144
145 // Match two items, which may be (Node, AST) or (AST, AST)
146 func match(m *Matcher, l, r interface{}) (interface{}, bool) {
147 if _, ok := r.(Node); ok {
148 panic("Node mustn't be on right side of match")
149 }
150
151 switch l := l.(type) {
152 case *ast.ParenExpr:
153 return match(m, l.X, r)
154 case *ast.ExprStmt:
155 return match(m, l.X, r)
156 case *ast.DeclStmt:
157 return match(m, l.Decl, r)
158 case *ast.LabeledStmt:
159 return match(m, l.Stmt, r)
160 case *ast.BlockStmt:
161 return match(m, l.List, r)
162 case *ast.FieldList:
163 if l == nil {
164 return match(m, nil, r)
165 } else {
166 return match(m, l.List, r)
167 }
168 }
169
170 switch r := r.(type) {
171 case *ast.ParenExpr:
172 return match(m, l, r.X)
173 case *ast.ExprStmt:
174 return match(m, l, r.X)
175 case *ast.DeclStmt:
176 return match(m, l, r.Decl)
177 case *ast.LabeledStmt:
178 return match(m, l, r.Stmt)
179 case *ast.BlockStmt:
180 if r == nil {
181 return match(m, l, nil)
182 }
183 return match(m, l, r.List)
184 case *ast.FieldList:
185 if r == nil {
186 return match(m, l, nil)
187 }
188 return match(m, l, r.List)
189 case *ast.BasicLit:
190 if r == nil {
191 return match(m, l, nil)
192 }
193 }
194
195 if l, ok := l.(matcher); ok {
196 return l.Match(m, r)
197 }
198
199 if l, ok := l.(Node); ok {
200 // Matching of pattern with concrete value
201 return matchNodeAST(m, l, r)
202 }
203
204 if l == nil || r == nil {
205 return nil, l == r
206 }
207
208 {
209 ln, ok1 := l.(ast.Node)
210 rn, ok2 := r.(ast.Node)
211 if ok1 && ok2 {
212 return matchAST(m, ln, rn)
213 }
214 }
215
216 {
217 obj, ok := l.(types.Object)
218 if ok {
219 switch r := r.(type) {
220 case *ast.Ident:
221 return obj, obj == m.TypesInfo.ObjectOf(r)
222 case *ast.SelectorExpr:
223 return obj, obj == m.TypesInfo.ObjectOf(r.Sel)
224 default:
225 return obj, false
226 }
227 }
228 }
229
230 // TODO(dh): the three blocks handling slices can be combined into a single block if we use reflection
231
232 {
233 ln, ok1 := l.([]ast.Expr)
234 rn, ok2 := r.([]ast.Expr)
235 if ok1 || ok2 {
236 if ok1 && !ok2 {
237 cast, ok := r.(ast.Expr)
238 if !ok {
239 return nil, false
240 }
241 rn = []ast.Expr{cast}
242 } else if !ok1 && ok2 {
243 cast, ok := l.(ast.Expr)
244 if !ok {
245 return nil, false
246 }
247 ln = []ast.Expr{cast}
248 }
249
250 if len(ln) != len(rn) {
251 return nil, false
252 }
253 for i, ll := range ln {
254 if _, ok := match(m, ll, rn[i]); !ok {
255 return nil, false
256 }
257 }
258 return r, true
259 }
260 }
261
262 {
263 ln, ok1 := l.([]ast.Stmt)
264 rn, ok2 := r.([]ast.Stmt)
265 if ok1 || ok2 {
266 if ok1 && !ok2 {
267 cast, ok := r.(ast.Stmt)
268 if !ok {
269 return nil, false
270 }
271 rn = []ast.Stmt{cast}
272 } else if !ok1 && ok2 {
273 cast, ok := l.(ast.Stmt)
274 if !ok {
275 return nil, false
276 }
277 ln = []ast.Stmt{cast}
278 }
279
280 if len(ln) != len(rn) {
281 return nil, false
282 }
283 for i, ll := range ln {
284 if _, ok := match(m, ll, rn[i]); !ok {
285 return nil, false
286 }
287 }
288 return r, true
289 }
290 }
291
292 {
293 ln, ok1 := l.([]*ast.Field)
294 rn, ok2 := r.([]*ast.Field)
295 if ok1 || ok2 {
296 if ok1 && !ok2 {
297 cast, ok := r.(*ast.Field)
298 if !ok {
299 return nil, false
300 }
301 rn = []*ast.Field{cast}
302 } else if !ok1 && ok2 {
303 cast, ok := l.(*ast.Field)
304 if !ok {
305 return nil, false
306 }
307 ln = []*ast.Field{cast}
308 }
309
310 if len(ln) != len(rn) {
311 return nil, false
312 }
313 for i, ll := range ln {
314 if _, ok := match(m, ll, rn[i]); !ok {
315 return nil, false
316 }
317 }
318 return r, true
319 }
320 }
321
322 return nil, false
323 }
324
325 // Match a Node with an AST node
326 func matchNodeAST(m *Matcher, a Node, b interface{}) (interface{}, bool) {
327 switch b := b.(type) {
328 case []ast.Stmt:
329 // 'a' is not a List or we'd be using its Match
330 // implementation.
331
332 if len(b) != 1 {
333 return nil, false
334 }
335 return match(m, a, b[0])
336 case []ast.Expr:
337 // 'a' is not a List or we'd be using its Match
338 // implementation.
339
340 if len(b) != 1 {
341 return nil, false
342 }
343 return match(m, a, b[0])
344 case []*ast.Field:
345 // 'a' is not a List or we'd be using its Match
346 // implementation
347 if len(b) != 1 {
348 return nil, false
349 }
350 return match(m, a, b[0])
351 case ast.Node:
352 ra := reflect.ValueOf(a)
353 rb := reflect.ValueOf(b).Elem()
354
355 if ra.Type().Name() != rb.Type().Name() {
356 return nil, false
357 }
358
359 for i := 0; i < ra.NumField(); i++ {
360 af := ra.Field(i)
361 fieldName := ra.Type().Field(i).Name
362 bf := rb.FieldByName(fieldName)
363 if (bf == reflect.Value{}) {
364 panic(fmt.Sprintf("internal error: could not find field %s in type %t when comparing with %T", fieldName, b, a))
365 }
366 ai := af.Interface()
367 bi := bf.Interface()
368 if ai == nil {
369 return b, bi == nil
370 }
371 if _, ok := match(m, ai.(Node), bi); !ok {
372 return b, false
373 }
374 }
375 return b, true
376 case nil:
377 return nil, a == Nil{}
378 case string, token.Token:
379 // 'a' can't be a String, Token, or Binding or we'd be using their Match implementations.
380 return nil, false
381 default:
382 panic(fmt.Sprintf("unhandled type %T", b))
383 }
384 }
385
386 // Match two AST nodes
387 func matchAST(m *Matcher, a, b ast.Node) (interface{}, bool) {
388 ra := reflect.ValueOf(a)
389 rb := reflect.ValueOf(b)
390
391 if ra.Type() != rb.Type() {
392 return nil, false
393 }
394 if ra.IsNil() || rb.IsNil() {
395 return rb, ra.IsNil() == rb.IsNil()
396 }
397
398 ra = ra.Elem()
399 rb = rb.Elem()
400 for i := 0; i < ra.NumField(); i++ {
401 af := ra.Field(i)
402 bf := rb.Field(i)
403 if af.Type() == rtTokPos || af.Type() == rtObject || af.Type() == rtCommentGroup {
404 continue
405 }
406
407 switch af.Kind() {
408 case reflect.Slice:
409 if af.Len() != bf.Len() {
410 return nil, false
411 }
412 for j := 0; j < af.Len(); j++ {
413 if _, ok := match(m, af.Index(j).Interface().(ast.Node), bf.Index(j).Interface().(ast.Node)); !ok {
414 return nil, false
415 }
416 }
417 case reflect.String:
418 if af.String() != bf.String() {
419 return nil, false
420 }
421 case reflect.Int:
422 if af.Int() != bf.Int() {
423 return nil, false
424 }
425 case reflect.Bool:
426 if af.Bool() != bf.Bool() {
427 return nil, false
428 }
429 case reflect.Ptr, reflect.Interface:
430 if _, ok := match(m, af.Interface(), bf.Interface()); !ok {
431 return nil, false
432 }
433 default:
434 panic(fmt.Sprintf("internal error: unhandled kind %s (%T)", af.Kind(), af.Interface()))
435 }
436 }
437 return b, true
438 }
439
440 func (b Binding) Match(m *Matcher, node interface{}) (interface{}, bool) {
441 if isNil(b.Node) {
442 v, ok := m.State[b.Name]
443 if ok {
444 // Recall value
445 return match(m, v, node)
446 }
447 // Matching anything
448 b.Node = Any{}
449 }
450
451 // Store value
452 if _, ok := m.State[b.Name]; ok {
453 panic(fmt.Sprintf("binding already created: %s", b.Name))
454 }
455 new, ret := match(m, b.Node, node)
456 if ret {
457 m.set(b, new)
458 }
459 return new, ret
460 }
461
462 func (Any) Match(m *Matcher, node interface{}) (interface{}, bool) {
463 return node, true
464 }
465
466 func (l List) Match(m *Matcher, node interface{}) (interface{}, bool) {
467 v := reflect.ValueOf(node)
468 if v.Kind() == reflect.Slice {
469 if isNil(l.Head) {
470 return node, v.Len() == 0
471 }
472 if v.Len() == 0 {
473 return nil, false
474 }
475 // OPT(dh): don't check the entire tail if head didn't match
476 _, ok1 := match(m, l.Head, v.Index(0).Interface())
477 _, ok2 := match(m, l.Tail, v.Slice(1, v.Len()).Interface())
478 return node, ok1 && ok2
479 }
480 // Our empty list does not equal an untyped Go nil. This way, we can
481 // tell apart an if with no else and an if with an empty else.
482 return nil, false
483 }
484
485 func (s String) Match(m *Matcher, node interface{}) (interface{}, bool) {
486 switch o := node.(type) {
487 case token.Token:
488 if tok, ok := maybeToken(s); ok {
489 return match(m, tok, node)
490 }
491 return nil, false
492 case string:
493 return o, string(s) == o
494 case types.TypeAndValue:
495 return o, o.Value != nil && o.Value.String() == string(s)
496 default:
497 return nil, false
498 }
499 }
500
501 func (tok Token) Match(m *Matcher, node interface{}) (interface{}, bool) {
502 o, ok := node.(token.Token)
503 if !ok {
504 return nil, false
505 }
506 return o, token.Token(tok) == o
507 }
508
509 func (Nil) Match(m *Matcher, node interface{}) (interface{}, bool) {
510 if isNil(node) {
511 return nil, true
512 }
513 v := reflect.ValueOf(node)
514 switch v.Kind() {
515 case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Pointer, reflect.Slice:
516 return nil, v.IsNil()
517 default:
518 return nil, false
519 }
520 }
521
522 func (builtin Builtin) Match(m *Matcher, node interface{}) (interface{}, bool) {
523 r, ok := match(m, Ident(builtin), node)
524 if !ok {
525 return nil, false
526 }
527 ident := r.(*ast.Ident)
528 obj := m.TypesInfo.ObjectOf(ident)
529 if obj != types.Universe.Lookup(ident.Name) {
530 return nil, false
531 }
532 return ident, true
533 }
534
535 func (obj Object) Match(m *Matcher, node interface{}) (interface{}, bool) {
536 r, ok := match(m, Ident(obj), node)
537 if !ok {
538 return nil, false
539 }
540 ident := r.(*ast.Ident)
541
542 id := m.TypesInfo.ObjectOf(ident)
543 _, ok = match(m, obj.Name, ident.Name)
544 return id, ok
545 }
546
547 func (fn Symbol) Match(m *Matcher, node interface{}) (interface{}, bool) {
548 var name string
549 var obj types.Object
550
551 base := []Node{
552 Ident{Any{}},
553 SelectorExpr{Any{}, Any{}},
554 }
555 p := Or{
556 Nodes: append(base,
557 IndexExpr{Or{Nodes: base}, Any{}},
558 IndexListExpr{Or{Nodes: base}, Any{}})}
559
560 r, ok := match(m, p, node)
561 if !ok {
562 return nil, false
563 }
564
565 fun := r.(ast.Expr)
566 switch idx := fun.(type) {
567 case *ast.IndexExpr:
568 fun = idx.X
569 case *ast.IndexListExpr:
570 fun = idx.X
571 }
572 fun = astutil.Unparen(fun)
573
574 switch fun := fun.(type) {
575 case *ast.Ident:
576 obj = m.TypesInfo.ObjectOf(fun)
577 case *ast.SelectorExpr:
578 obj = m.TypesInfo.ObjectOf(fun.Sel)
579 default:
580 panic("unreachable")
581 }
582 switch obj := obj.(type) {
583 case *types.Func:
584 // OPT(dh): optimize this similar to code.FuncName
585 name = obj.FullName()
586 case *types.Builtin:
587 name = obj.Name()
588 case *types.TypeName:
589 origObj := obj
590 for {
591 if obj.Parent() != obj.Pkg().Scope() {
592 return nil, false
593 }
594 name = types.TypeString(obj.Type(), nil)
595 _, ok = match(m, fn.Name, name)
596 if ok || !obj.IsAlias() {
597 return origObj, ok
598 } else {
599 // FIXME(dh): we should peel away one layer of alias at a time; this is blocked on
600 // github.com/golang/go/issues/66559
601 switch typ := types.Unalias(obj.Type()).(type) {
602 case interface{ Obj() *types.TypeName }:
603 obj = typ.Obj()
604 case *types.Basic:
605 return match(m, fn.Name, typ.Name())
606 default:
607 return nil, false
608 }
609 }
610 }
611 case *types.Const, *types.Var:
612 if obj.Pkg() == nil {
613 return nil, false
614 }
615 if obj.Parent() != obj.Pkg().Scope() {
616 return nil, false
617 }
618 name = fmt.Sprintf("%s.%s", obj.Pkg().Path(), obj.Name())
619 default:
620 return nil, false
621 }
622
623 _, ok = match(m, fn.Name, name)
624 return obj, ok
625 }
626
627 func (or Or) Match(m *Matcher, node interface{}) (interface{}, bool) {
628 for _, opt := range or.Nodes {
629 m.push()
630 if ret, ok := match(m, opt, node); ok {
631 m.merge()
632 return ret, true
633 } else {
634 m.pop()
635 }
636 }
637 return nil, false
638 }
639
640 func (not Not) Match(m *Matcher, node interface{}) (interface{}, bool) {
641 _, ok := match(m, not.Node, node)
642 if ok {
643 return nil, false
644 }
645 return node, true
646 }
647
648 var integerLiteralQ = MustParse(`(Or (BasicLit "INT" _) (UnaryExpr (Or "+" "-") (IntegerLiteral _)))`)
649
650 func (lit IntegerLiteral) Match(m *Matcher, node interface{}) (interface{}, bool) {
651 matched, ok := match(m, integerLiteralQ.Root, node)
652 if !ok {
653 return nil, false
654 }
655 tv, ok := m.TypesInfo.Types[matched.(ast.Expr)]
656 if !ok {
657 return nil, false
658 }
659 if tv.Value == nil {
660 return nil, false
661 }
662 _, ok = match(m, lit.Value, tv)
663 return matched, ok
664 }
665
666 func (texpr TrulyConstantExpression) Match(m *Matcher, node interface{}) (interface{}, bool) {
667 expr, ok := node.(ast.Expr)
668 if !ok {
669 return nil, false
670 }
671 tv, ok := m.TypesInfo.Types[expr]
672 if !ok {
673 return nil, false
674 }
675 if tv.Value == nil {
676 return nil, false
677 }
678 truly := true
679 ast.Inspect(expr, func(node ast.Node) bool {
680 if _, ok := node.(*ast.Ident); ok {
681 truly = false
682 return false
683 }
684 return true
685 })
686 if !truly {
687 return nil, false
688 }
689 _, ok = match(m, texpr.Value, tv)
690 return expr, ok
691 }
692
// Types of fields in go/ast structs that matchAST skips when comparing two
// ASTs: positions, comment groups and ast.Object links.
var (
	rtTokPos       = reflect.TypeOf(token.Pos(0))
	rtCommentGroup = reflect.TypeOf((*ast.CommentGroup)(nil))
	//lint:ignore SA1019 It's deprecated, but we still want to skip the field.
	rtObject = reflect.TypeOf((*ast.Object)(nil))
)
700
701 var (
702 _ matcher = Binding{}
703 _ matcher = Any{}
704 _ matcher = List{}
705 _ matcher = String("")
706 _ matcher = Token(0)
707 _ matcher = Nil{}
708 _ matcher = Builtin{}
709 _ matcher = Object{}
710 _ matcher = Symbol{}
711 _ matcher = Or{}
712 _ matcher = Not{}
713 _ matcher = IntegerLiteral{}
714 _ matcher = TrulyConstantExpression{}
715 )
716