00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040 from antlr3.constants import INVALID_TOKEN_TYPE
00041 from antlr3.tokens import CommonToken
00042 from antlr3.tree import CommonTree, CommonTreeAdaptor
00043
00044
00045
00046
00047
00048
00049
def computeTokenTypes(tokenNames):
    """Build a mapping from token name to token type.

    A token's type is its position in *tokenNames*.  If a name occurs more
    than once, the last position is the one that is kept.

    :param tokenNames: sequence of token names, or None.
    :returns: dict mapping each name to its index; an empty dict when
        *tokenNames* is None.
    """
    mapping = {}
    if tokenNames is not None:
        for index, tokenName in enumerate(tokenNames):
            mapping[tokenName] = index

    return mapping
00056
00057
00058
# Token types returned by TreePatternLexer.nextToken().
EOF = -1      # end of pattern; also stored in TreePatternLexer.c past the end
BEGIN = 1     # '('
END = 2       # ')'
ID = 3        # identifier: letter or '_' followed by letters, digits, '_'
ARG = 4       # '[...]' text argument
PERCENT = 5   # '%'
COLON = 6     # ':'
DOT = 7       # '.'


class TreePatternLexer(object):
    """Tokenizer for tree pattern strings such as ``(A %x:B[text] .)``.

    Call :meth:`nextToken` repeatedly; it returns one of the module-level
    token type constants.  For ID and ARG tokens the matched text is left
    in ``sval``.  ``error`` is set to True when the input cannot be
    tokenized (unexpected character or unterminated ``[...]`` argument).
    """

    def __init__(self, pattern):
        # The pattern text being tokenized.
        self.pattern = pattern

        # Index of the current character within the pattern.
        self.p = -1

        # Current character, or EOF once the pattern is exhausted.
        self.c = None

        # Number of characters in the pattern.
        self.n = len(pattern)

        # Text of the most recently scanned ID or ARG token.
        self.sval = None

        # True once an invalid character or unterminated argument was seen.
        self.error = False

        # Load the first character.
        self.consume()

    __idStartChar = frozenset(
        'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ_'
    )
    __idChar = __idStartChar | frozenset('0123456789')

    # Single-character tokens and the type each one produces.
    __punctuation = {
        '(': BEGIN,
        ')': END,
        '%': PERCENT,
        ':': COLON,
        '.': DOT,
    }

    def nextToken(self):
        """Scan and return the type of the next token.

        Whitespace is skipped.  ``sval`` is reset to ``""`` and then holds
        the token text for ID and ARG tokens.

        :returns: one of BEGIN, END, ID, ARG, PERCENT, COLON, DOT or EOF.
            EOF is returned both at the end of the pattern and on a lexical
            error; in the latter case ``error`` is set to True.
        """
        self.sval = ""
        while self.c != EOF:
            if self.c in (' ', '\n', '\r', '\t'):
                self.consume()
                continue

            if self.c in self.__idStartChar:
                chars = [self.c]
                self.consume()
                while self.c in self.__idChar:
                    chars.append(self.c)
                    self.consume()

                self.sval = "".join(chars)
                return ID

            if self.c in self.__punctuation:
                ttype = self.__punctuation[self.c]
                self.consume()
                return ttype

            if self.c == '[':
                return self._scanArg()

            # Anything else is not part of the pattern language.
            self.consume()
            self.error = True
            return EOF

        return EOF

    def _scanArg(self):
        """Scan a ``[...]`` argument; the current character is the ``[``.

        ``\\]`` stands for a literal ``]``; a backslash before any other
        character is kept verbatim.

        :returns: ARG with the argument text in ``sval``, or EOF with
            ``error`` set when the pattern ends before the closing ``]``.
        """
        self.consume()  # skip '['
        chars = []
        while self.c != ']':
            if self.c == '\\':
                self.consume()
                if self.c != ']':
                    chars.append('\\')

            if self.c == EOF:
                # Unterminated argument: self.c is the int EOF marker here,
                # so it must not be appended to the text.
                self.sval = "".join(chars)
                self.error = True
                return EOF

            chars.append(self.c)
            self.consume()

        self.consume()  # skip ']'
        self.sval = "".join(chars)
        return ARG

    def consume(self):
        """Advance to the next character, storing EOF in ``c`` at the end."""
        self.p += 1
        if self.p >= self.n:
            self.c = EOF

        else:
            self.c = self.pattern[self.p]
00163
00164
class TreePatternParser(object):
    """Recursive-descent parser that builds a tree from pattern tokens.

    Tokens come from *tokenizer* (via ``nextToken()`` and ``sval``), token
    names are resolved through ``wizard.getTokenType()``, and nodes are
    created with *adaptor*.
    """

    def __init__(self, tokenizer, wizard, adaptor):
        self.tokenizer = tokenizer
        self.wizard = wizard
        self.adaptor = adaptor
        # One token of lookahead.
        self.ttype = tokenizer.nextToken()

    def pattern(self):
        """Parse a complete pattern.

        :returns: the root node, or None when the pattern does not start
            with ``(`` or an identifier, or when a single-node pattern is
            followed by further tokens.
        """
        if self.ttype == BEGIN:
            return self.parseTree()

        elif self.ttype == ID:
            node = self.parseNode()
            if self.ttype == EOF:
                return node

            # Extra tokens after a lone node.
            return None

        return None

    def parseTree(self):
        """Parse ``( root child* )``.

        :returns: the root node with its children attached, or None on any
            syntax error, including an error inside a nested subtree.
        """
        if self.ttype != BEGIN:
            return None

        self.ttype = self.tokenizer.nextToken()
        root = self.parseNode()
        if root is None:
            return None

        while self.ttype in (BEGIN, ID, PERCENT, DOT):
            if self.ttype == BEGIN:
                subtree = self.parseTree()
                if subtree is None:
                    # Propagate the failure instead of handing None to the
                    # adaptor and carrying on with a truncated tree.
                    return None

                self.adaptor.addChild(root, subtree)

            else:
                child = self.parseNode()
                if child is None:
                    return None

                self.adaptor.addChild(root, child)

        if self.ttype != END:
            return None

        self.ttype = self.tokenizer.nextToken()
        return root

    def parseNode(self):
        """Parse ``[%label:] ( '.' | ID [ARG] )``.

        :returns: a WildcardTreePattern for ``.``, ``adaptor.nil()`` for the
            identifier ``nil``, otherwise a node created with
            ``adaptor.createFromType()``; None on a syntax error or when
            the token name is unknown to the wizard.
        """
        # Optional "%label:" prefix.
        label = None

        if self.ttype == PERCENT:
            self.ttype = self.tokenizer.nextToken()
            if self.ttype != ID:
                return None

            label = self.tokenizer.sval
            self.ttype = self.tokenizer.nextToken()
            if self.ttype != COLON:
                return None

            # Move on to the token that follows the colon.
            self.ttype = self.tokenizer.nextToken()

        # Wildcard.
        if self.ttype == DOT:
            self.ttype = self.tokenizer.nextToken()
            wildcardPayload = CommonToken(0, ".")
            node = WildcardTreePattern(wildcardPayload)
            if label is not None:
                node.label = label
            return node

        # Anything else must be a token name.
        if self.ttype != ID:
            return None

        tokenName = self.tokenizer.sval
        self.ttype = self.tokenizer.nextToken()

        if tokenName == "nil":
            return self.adaptor.nil()

        # Node text defaults to the token name unless a [text] arg follows.
        text = tokenName

        arg = None
        if self.ttype == ARG:
            arg = self.tokenizer.sval
            text = arg
            self.ttype = self.tokenizer.nextToken()

        # Resolve the name and create the node.
        treeNodeType = self.wizard.getTokenType(tokenName)
        if treeNodeType == INVALID_TOKEN_TYPE:
            return None

        node = self.adaptor.createFromType(treeNodeType, text)
        # Label and text-arg flag only exist on TreePattern nodes.
        if label is not None and isinstance(node, TreePattern):
            node.label = label

        if arg is not None and isinstance(node, TreePattern):
            node.hasTextArg = True

        return node
00271
00272
00273
00274
00275
00276
00277
class TreePattern(CommonTree):
    """A CommonTree node that also carries pattern metadata.

    ``label`` holds the name given with ``%name:`` in a pattern, and
    ``hasTextArg`` is set to True when the pattern node was written with a
    ``[text]`` argument.  Both start out as None.
    """

    def __init__(self, payload):
        CommonTree.__init__(self, payload)

        self.label = None
        self.hasTextArg = None

    def toString(self):
        """Return the CommonTree text, prefixed with ``%label:`` if labelled."""
        text = CommonTree.toString(self)
        if self.label is None:
            return text

        return '%' + self.label + ':' + text
00293
00294
class WildcardTreePattern(TreePattern):
    """Marker subclass of TreePattern used for the ``.`` wildcard node."""
00297
00298
00299
00300
class TreePatternTreeAdaptor(CommonTreeAdaptor):
    """A CommonTreeAdaptor whose payload-based factory builds TreePattern nodes."""

    def createWithPayload(self, payload):
        """Wrap *payload* in a new TreePattern node and return it."""
        patternNode = TreePattern(payload)
        return patternNode
00305
00306
00307
00308
00309
00310
00311
00312
00313
00314
00315
00316
00317
00318
00319
00320
00321
00322
00323
00324
class TreeWizard(object):
    """Build, index, search and compare trees using token types or patterns.

    Trees are accessed only through ``adaptor``.  Patterns are strings such
    as ``(A %x:B[text] .)`` that are parsed with TreePatternLexer and
    TreePatternParser; token names in a pattern are resolved through
    ``tokenNameToTypeMap``.
    """

    # ``long`` and ``basestring`` exist only on Python 2.  Fall back to the
    # Python 3 equivalents so find() and visit() work on either interpreter
    # instead of raising NameError.
    try:
        _integerTypes = (int, long)  # noqa: F821
    except NameError:
        _integerTypes = (int,)

    try:
        _stringTypes = (basestring,)  # noqa: F821
    except NameError:
        _stringTypes = (str,)

    def __init__(self, adaptor=None, tokenNames=None, typeMap=None):
        """Create a wizard.

        :param adaptor: tree adaptor used to inspect and create nodes.
        :param tokenNames: sequence of token names; a name's index is its
            token type.
        :param typeMap: ready-made mapping from token name to token type.
        :raises ValueError: if both *tokenNames* and *typeMap* are given.
        """
        self.adaptor = adaptor
        if typeMap is None:
            self.tokenNameToTypeMap = computeTokenTypes(tokenNames)

        else:
            if tokenNames is not None:
                raise ValueError("Can't have both tokenNames and typeMap")

            self.tokenNameToTypeMap = typeMap

    def getTokenType(self, tokenName):
        """Return the token type for *tokenName*, or INVALID_TOKEN_TYPE."""
        try:
            return self.tokenNameToTypeMap[tokenName]
        except KeyError:
            return INVALID_TOKEN_TYPE

    def create(self, pattern):
        """Build a tree from *pattern* using this wizard's own adaptor.

        :returns: the root node, or None if the pattern cannot be parsed.
        """
        tokenizer = TreePatternLexer(pattern)
        parser = TreePatternParser(tokenizer, self, self.adaptor)
        return parser.pattern()

    def index(self, tree):
        """Group every node of *tree* by token type.

        :returns: dict mapping token type to the list of nodes of that
            type, in pre-order.
        """
        m = {}
        self._index(tree, m)
        return m

    def _index(self, t, m):
        """Add *t* and all of its descendants to the index *m*."""
        if t is None:
            return

        ttype = self.adaptor.getType(t)
        m.setdefault(ttype, []).append(t)
        for i in range(self.adaptor.getChildCount(t)):
            child = self.adaptor.getChild(t, i)
            self._index(child, m)

    def find(self, tree, what):
        """Return the nodes or subtrees of *tree* selected by *what*.

        :param what: an integer token type, or a pattern string.
        :returns: list of matches; None when *what* is a pattern that is
            malformed, nil-rooted or a bare wildcard.
        :raises TypeError: if *what* is neither an integer nor a string.
        """
        if isinstance(what, self._integerTypes):
            return self._findTokenType(tree, what)

        elif isinstance(what, self._stringTypes):
            return self._findPattern(tree, what)

        else:
            raise TypeError("'what' must be string or integer")

    def _findTokenType(self, t, ttype):
        """Return all nodes under *t* (inclusive) whose type is *ttype*."""
        nodes = []

        def visitor(tree, parent, childIndex, labels):
            nodes.append(tree)

        self.visit(t, ttype, visitor)

        return nodes

    def _findPattern(self, t, pattern):
        """Return all subtrees under *t* that match *pattern*."""
        tpattern = self._compilePattern(pattern)
        if not self._isSearchable(tpattern):
            return None

        subtrees = []
        rootTokenType = tpattern.getType()

        def visitor(tree, parent, childIndex, label):
            if self._parse(tree, tpattern, None):
                subtrees.append(tree)

        # Only nodes with the pattern's root type can possibly match.
        self.visit(t, rootTokenType, visitor)

        return subtrees

    def visit(self, tree, what, visitor):
        """Call *visitor* for every node or subtree selected by *what*.

        The visitor is called as ``visitor(node, parent, childIndex,
        labels)``.  For an integer *what*, ``labels`` is None; for a pattern
        string it is a dict mapping pattern labels to the matched nodes.
        The node passed in as *tree* is reported with parent None and
        childIndex 0.

        :raises TypeError: if *what* is neither an integer nor a string.
        """
        if isinstance(what, self._integerTypes):
            self._visitType(tree, None, 0, what, visitor)

        elif isinstance(what, self._stringTypes):
            self._visitPattern(tree, what, visitor)

        else:
            raise TypeError("'what' must be string or integer")

    def _visitType(self, t, parent, childIndex, ttype, visitor):
        """Pre-order walk calling *visitor* on each node of type *ttype*."""
        if t is None:
            return

        if self.adaptor.getType(t) == ttype:
            visitor(t, parent, childIndex, None)

        for i in range(self.adaptor.getChildCount(t)):
            child = self.adaptor.getChild(t, i)
            self._visitType(child, t, i, ttype, visitor)

    def _visitPattern(self, tree, pattern, visitor):
        """Call *visitor* on each subtree of *tree* matching *pattern*.

        Does nothing when the pattern is malformed, nil-rooted or a bare
        wildcard.
        """
        tpattern = self._compilePattern(pattern)
        if not self._isSearchable(tpattern):
            return

        rootTokenType = tpattern.getType()

        def rootvisitor(tree, parent, childIndex, labels):
            # Each candidate gets its own fresh label dict.
            labels = {}
            if self._parse(tree, tpattern, labels):
                visitor(tree, parent, childIndex, labels)

        self.visit(tree, rootTokenType, rootvisitor)

    def parse(self, t, pattern, labels=None):
        """Test whether tree *t* matches *pattern* exactly.

        :param labels: optional dict that receives ``label -> node`` for
            labelled pattern nodes.  Entries are recorded as nodes are
            matched, so it may be partially filled when the result is False.
        :returns: True on a match; False otherwise, including when the
            pattern cannot be parsed.
        """
        tpattern = self._compilePattern(pattern)
        return self._parse(t, tpattern, labels)

    def _parse(self, t1, tpattern, labels):
        """Recursively match tree *t1* against pattern tree *tpattern*."""
        # Both sides must exist.
        if t1 is None or tpattern is None:
            return False

        # A wildcard node matches any type and text.
        if not isinstance(tpattern, WildcardTreePattern):
            if self.adaptor.getType(t1) != tpattern.getType():
                return False

            # Text is compared only when the pattern gave a [text] argument.
            if (tpattern.hasTextArg
                    and self.adaptor.getText(t1) != tpattern.getText()):
                return False

        if tpattern.label is not None and labels is not None:
            labels[tpattern.label] = t1

        # Children must match one-to-one.
        n1 = self.adaptor.getChildCount(t1)
        n2 = tpattern.getChildCount()
        if n1 != n2:
            return False

        for i in range(n1):
            child1 = self.adaptor.getChild(t1, i)
            child2 = tpattern.getChild(i)
            if not self._parse(child1, child2, labels):
                return False

        return True

    def equals(self, t1, t2, adaptor=None):
        """Compare two trees by type, text and structure.

        :param adaptor: adaptor used to inspect both trees; defaults to
            this wizard's adaptor.
        :returns: True if the trees are equal; False if they differ or if
            either one is None.
        """
        if adaptor is None:
            adaptor = self.adaptor

        return self._equals(t1, t2, adaptor)

    def _equals(self, t1, t2, adaptor):
        """Recursive worker for :meth:`equals`."""
        if t1 is None or t2 is None:
            return False

        # Compare the roots.
        if adaptor.getType(t1) != adaptor.getType(t2):
            return False

        if adaptor.getText(t1) != adaptor.getText(t2):
            return False

        # Compare the children pairwise.
        n1 = adaptor.getChildCount(t1)
        n2 = adaptor.getChildCount(t2)
        if n1 != n2:
            return False

        for i in range(n1):
            child1 = adaptor.getChild(t1, i)
            child2 = adaptor.getChild(t2, i)
            if not self._equals(child1, child2, adaptor):
                return False

        return True

    def _compilePattern(self, pattern):
        """Parse *pattern* into a TreePattern tree; None if it is malformed."""
        tokenizer = TreePatternLexer(pattern)
        parser = TreePatternParser(tokenizer, self, TreePatternTreeAdaptor())
        return parser.pattern()

    @staticmethod
    def _isSearchable(tpattern):
        """Return True if *tpattern* has a concrete root type to search for."""
        return not (tpattern is None or tpattern.isNil()
                    or isinstance(tpattern, WildcardTreePattern))
00633
00634