generate_pyi.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644
  1. #!/usr/bin/env python
  2. """
  3. This program will generate .pyi files for all the VTK modules
  4. in the "vtkmodules" package (or whichever package you specify).
  5. These files are used for type checking and autocompletion in
  6. some Python IDEs.
  7. The VTK modules must be in Python's path when you run this script.
  8. Options are as follows:
  9. -p PACKAGE The package to generate .pyi files for [vtkmodules]
  10. -o OUTPUT The output directory [default is the package directory]
  11. -e EXT The file suffix [.pyi]
  12. -i IMPORTER The static module importer (for static builds only)
  13. -h HELP
  14. With no arguments, the script runs with the defaults (the .pyi files
  15. are put inside the existing vtkmodules package). This is equivalent
  16. to the following:
  17. python -m vtkmodules.generate_pyi -p vtkmodules
  18. To put the pyi files somewhere else, perhaps with a different suffix:
  19. python -m vtkmodules.generate_pyi -o /path/to/vtkmodules -e .pyi
  20. To generate pyi files for just one or two modules:
  21. python -m vtkmodules.generate_pyi -p vtkmodules vtkCommonCore vtkCommonDataModel
  22. To generate pyi files for your own modules in your own package:
  23. python -m vtkmodules.generate_pyi -p mypackage mymodule [mymodule2 ...]
  24. """
  25. from vtkmodules.vtkCommonCore import vtkObject, vtkSOADataArrayTemplate
  26. import sys
  27. import os
  28. import re
  29. import ast
  30. import argparse
  31. import builtins
  32. import inspect
  33. import importlib.util
  34. # ==== For type inspection ====
  35. # list expected non-vtk type names
  36. types = set()
  37. for m,o in builtins.__dict__.items():
  38. if isinstance(o, type):
  39. types.add(m)
  40. for m in ['Any', 'Buffer', 'Callback', 'None', 'Pointer', 'Template', 'Union']:
  41. types.add(m)
  42. # basic type checking methods
# "method" here means anything that inspect.isroutine() accepts
ismethod = inspect.isroutine
# plain alias for inspect.isclass()
isclass = inspect.isclass

# VTK methods have a special type, captured here from vtkObject.IsA
vtkmethod = type(vtkObject.IsA)
# the type of vtkSOADataArrayTemplate itself; typename() reports objects
# of this type as "Template"
template = type(vtkSOADataArrayTemplate)
  48. def isvtkmethod(m):
  49. """Check for VTK's custom method descriptor"""
  50. return (type(m) == vtkmethod)
  51. def isnamespace(m):
  52. """Check for namespaces within a module"""
  53. # until vtkmodules.vtkCommonCore.namespace is directly accessible
  54. return (str(type(m)) == "<class 'vtkmodules.vtkCommonCore.namespace'>")
  55. def isenum(m):
  56. """Check for enums (currently derived from int)"""
  57. return (isclass(m) and issubclass(m, int))
  58. def typename(o):
  59. """Generate a typename that can be used for annotation."""
  60. if o is None:
  61. return "None"
  62. elif type(o) == template:
  63. return "Template"
  64. else:
  65. return type(o).__name__
  66. def typename_forward(o):
  67. """Generate a typename, or if necessary, a forward reference."""
  68. name = typename(o)
  69. if name not in types:
  70. # do forward reference by adding quotes
  71. name = '\'' + name + '\''
  72. return name
  73. # ==== For the topological sort ====
  74. class Graph:
  75. """A graph for topological sorting."""
  76. def __init__(self):
  77. self.nodes = {}
  78. def __getitem__(self, name):
  79. return self.nodes[name]
  80. def __setitem__(self, name, node):
  81. self.nodes[name] = node
  82. class Node:
  83. """A node for the graph."""
  84. def __init__(self, o, d):
  85. self.obj = o
  86. self.deps = d
  87. def build_graph(d):
  88. """Build a graph from a module's dictionary."""
  89. graph = Graph()
  90. items = sorted(d.items())
  91. for m,o in items:
  92. if isclass(o):
  93. if m == o.__name__:
  94. # a class definition
  95. bases = [b.__name__ for b in o.__bases__]
  96. graph[m] = Node(o, bases)
  97. else:
  98. # a class alias
  99. graph[m] = Node(o, [o.__name__])
  100. elif ismethod(o):
  101. graph[m] = Node(o, [])
  102. else:
  103. graph[m] = Node(o, [typename(o)])
  104. return graph
  105. def sorted_graph_helper(graph, m, visited, items):
  106. """Helper for topological sorting."""
  107. visited.add(m)
  108. try:
  109. node = graph[m]
  110. except KeyError:
  111. return
  112. for dep in node.deps:
  113. if dep not in visited:
  114. sorted_graph_helper(graph, dep, visited, items)
  115. items.append((m, node.obj))
  116. def sorted_graph(graph):
  117. """Sort a graph and return the sorted items."""
  118. items = []
  119. visited = set()
  120. for m in graph.nodes:
  121. if m not in visited:
  122. sorted_graph_helper(graph, m, visited, items)
  123. return items
  124. def topologically_sorted_items(d):
  125. """Return the items from a module's dictionary, topologically sorted."""
  126. return sorted_graph(build_graph(d))
  127. # ==== For parsing docstrings ====
  128. # regular expressions for parsing
# a double-quoted or single-quoted string literal, allowing backslash escapes
string = re.compile(r"""("([^\\"]|\\.)*"|'([^\\']|\\.)*')""")
# an identifier, optionally dotted (each dot must be followed by a letter or '_')
identifier = re.compile(r"""([A-Za-z_]([A-Za-z0-9_]|[.][A-Za-z_])*)""")
# a run of spaces/tabs that is followed by a non-whitespace character
indent = re.compile(r"[ \t]+(?=\S)")
# "(self" directly followed by "," or ")"
has_self = re.compile(r"[(]self[,)]")

# important characters for rapidly parsing code
# (quotes, all three kinds of brackets, and newline)
keychar = re.compile(r"[\'\"{}\[\]()\n]")
  135. def parse_error(message, text, begin, pos):
  136. """Print a parse error, syntax or otherwise.
  137. """
  138. end = text.find('\n', pos)
  139. if end == -1:
  140. end = len(text)
  141. sys.stderr.write("Error: " + message + ":\n")
  142. sys.stderr.write(text[begin:end] + "\n");
  143. sys.stderr.write('-' * (pos - begin) + "^\n")
  144. def annotation_text(a, text, is_return):
  145. """Return the new text to be used for an annotation.
  146. """
  147. if isinstance(a, ast.Name):
  148. name = a.id
  149. if name not in types:
  150. # quote the type, in case it isn't yet defined
  151. text = '\'' + name + '\''
  152. elif isinstance(a, (ast.Tuple, ast.List)):
  153. size = len(a.elts)
  154. e = a.elts[0]
  155. offset = a.col_offset
  156. old_name = text[e.col_offset - offset:e.end_col_offset - offset]
  157. name = annotation_text(e, old_name, is_return)
  158. if is_return:
  159. # use concrete types for return values
  160. if isinstance(a, ast.Tuple):
  161. text = 'Tuple[' + ', '.join([name]*size) + ']'
  162. else:
  163. text = 'List[' + name + ']'
  164. else:
  165. # use generic sequence types for args
  166. if isinstance(a, ast.Tuple):
  167. text = 'Sequence[' + name + ']'
  168. else:
  169. text = 'MutableSequence[' + name + ']'
  170. return text
  171. def fix_annotations(signature):
  172. """Fix the annotations in a method definition.
  173. The signature must be a single-line function def, no decorators.
  174. """
  175. # get the FunctionDef object from the parse tree
  176. definition = ast.parse(signature).body[0]
  177. annotations = [arg.annotation for arg in definition.args.args]
  178. return_i = len(annotations) # index of annotation for return
  179. annotations.append(definition.returns)
  180. # create a list of changes to apply to the annotations
  181. changes = []
  182. for i,a in enumerate(annotations):
  183. if a is not None:
  184. old_text = signature[a.col_offset:a.end_col_offset]
  185. text = annotation_text(a, old_text, (i == return_i))
  186. if text != old_text:
  187. changes.append((a.col_offset, a.end_col_offset, text))
  188. # apply changes to generate a new signature
  189. if changes:
  190. newsig = ""
  191. lastpos = 0
  192. for begin,end,text in changes:
  193. newsig += signature[lastpos:begin]
  194. newsig += text
  195. lastpos = end
  196. newsig += signature[lastpos:]
  197. signature = newsig
  198. return signature
  199. def push_signature(o, l, signature):
  200. """Process a method signature and add it to the list.
  201. """
  202. # eliminate newlines and indents
  203. signature = re.sub(r"\s+", " ", signature)
  204. # no space after opening delimiter or ':' or '='
  205. signature = re.sub(r"([({\[:=]) ", "\\1", signature)
  206. if signature.startswith('C++:'):
  207. # the C++ method signatures are unused
  208. pass
  209. elif signature.startswith(o.__name__ + "("):
  210. # make it into a python method definition
  211. signature = "def " + signature + ': ...'
  212. if sys.hexversion >= 0x3080000:
  213. # XXX(Python 3.8) uses ast features from 3.8
  214. signature = fix_annotations(signature)
  215. if signature not in l:
  216. l.append(signature)
  217. def get_signatures(o):
  218. """Return a list of method signatures found in the docstring.
  219. """
  220. doc = o.__doc__
  221. signatures = [] # output method signatures
  222. if doc is None:
  223. return signatures
  224. # variables used for parsing the docstrings
  225. begin = 0 # beginning of current signature
  226. pos = 0 # current position in docstring
  227. delim_stack = [] # keep track of bracket depth
  228. # loop through docstring using longest strides possible
  229. # (this will go line-by-line or until first ( ) { } [ ] " ')
  230. while pos < len(doc):
  231. # look for the next "character of interest" in docstring
  232. match = keychar.search(doc, pos)
  233. # did we find a match before the end of docstring?
  234. if match:
  235. # get new position
  236. pos,end = match.span()
  237. # take different action, depending on char
  238. c = match.group()
  239. if c in '\"\'':
  240. # skip over a string literal
  241. m = string.match(doc, pos)
  242. if m:
  243. pos,end = m.span()
  244. else:
  245. parse_error("Unterminated string", doc, begin, pos)
  246. break
  247. elif c in '{[(':
  248. # descend into a bracketed expression (push stack)
  249. delim_stack.append({'{':'}','[':']','(':')'}[c])
  250. elif c in '}])':
  251. # ascend out of a bracketed expression (pop stack)
  252. if not delim_stack or c != delim_stack.pop():
  253. parse_error("Unmatched bracket", doc, begin, pos)
  254. break
  255. elif c == '\n' and not (delim_stack or indent.match(doc, end)):
  256. # a newline not followed by an indent marks end of signature,
  257. # except for within brackets
  258. signature = doc[begin:pos].strip()
  259. if signature:
  260. push_signature(o, signatures, signature)
  261. begin = end
  262. else:
  263. # blank line means no more signatures in docstring
  264. break
  265. else:
  266. # reached the end of the docstring
  267. end = len(doc)
  268. if not delim_stack:
  269. signature = doc[begin:pos].strip()
  270. if signature:
  271. push_signature(o, signatures, signature)
  272. else:
  273. parse_error("Unmatched bracket", doc, begin, pos)
  274. break
  275. # advance position within docstring and return to head of loop
  276. pos = end
  277. return signatures
  278. def get_constructors(c):
  279. """Get constructors from the class documentation.
  280. """
  281. constructors = []
  282. name = c.__name__
  283. doc = c.__doc__
  284. if not doc or not doc.startswith(name + "("):
  285. return constructors
  286. signatures = get_signatures(c)
  287. for signature in signatures:
  288. if signature.startswith("def " + name + "("):
  289. signature = re.sub("-> \'?" + name + "\'?", "-> None", signature)
  290. if signature.startswith("def " + name + "()"):
  291. constructors.append(re.sub(name + r"\(", "__init__(self", signature, 1))
  292. else:
  293. constructors.append(re.sub(name + r"\(", "__init__(self, ", signature, 1))
  294. return constructors
  295. def handle_static(o, signature):
  296. """If method has no "self", add @static decorator."""
  297. if isvtkmethod(o) and not has_self.search(signature):
  298. return "@staticmethod\n" + signature
  299. else:
  300. return signature
  301. def add_indent(s, indent):
  302. """Add the given indent before every line in the string.
  303. """
  304. return indent + re.sub(r"\n(?=([^\n]))", "\n" + indent, s)
  305. def namespace_pyi(c, mod):
  306. """Fake a namespace by creating a dummy class.
  307. """
  308. base = "namespace"
  309. if mod.__name__ != 'vtkmodules.vtkCommonCore':
  310. base = 'vtkmodules.vtkCommonCore.' + base
  311. out = "class " + c.__name__ + "(" + base + "):\n"
  312. count = 0
  313. # do all nested classes (these will be enum types)
  314. items = topologically_sorted_items(c.__dict__)
  315. others = []
  316. for m,o in items:
  317. if isenum(o) and m == o.__name__:
  318. out += add_indent(class_pyi(o), " ")
  319. count += 1
  320. else:
  321. others.append((m, o))
  322. # do all constants
  323. items = others
  324. others = []
  325. for m,o in items:
  326. if not m.startswith("__") and not ismethod(o) and not isclass(o):
  327. out += " " + m + ":" + typename_forward(o) + "\n"
  328. count += 1
  329. else:
  330. others.append((m,o))
  331. if count == 0:
  332. out = out[0:-1] + " ...\n"
  333. return out
  334. def class_pyi(c):
  335. """Generate all the method stubs for a class.
  336. """
  337. bases = []
  338. for b in c.__bases__:
  339. if b.__module__ in (c.__module__, 'builtins'):
  340. bases.append(b.__name__)
  341. else:
  342. bases.append(b.__module__ + "." + b.__name__)
  343. out = "class " + c.__name__ + "(" + ", ".join(bases) + "):\n"
  344. count = 0
  345. # do all nested classes (these are usually enum types)
  346. items = topologically_sorted_items(c.__dict__)
  347. others = []
  348. for m,o in items:
  349. if isclass(o) and m == o.__name__:
  350. out += add_indent(class_pyi(o), " ")
  351. count += 1
  352. else:
  353. others.append((m, o))
  354. # do all constants
  355. items = others
  356. others = []
  357. for m,o in items:
  358. if not m.startswith("__") and not ismethod(o) and not isclass(o):
  359. out += " " + m + ":" + typename_forward(o) + "\n"
  360. count += 1
  361. else:
  362. others.append((m,o))
  363. # do the __init__ methods
  364. constructors = get_constructors(c)
  365. if len(constructors) == 0:
  366. #if hasattr(c, "__init__") and not issubclass(c, int):
  367. # out += " def __init__() -> None: ...\n"
  368. # count += 1
  369. pass
  370. else:
  371. count += 1
  372. if len(constructors) == 1:
  373. out += add_indent(constructors[0], " ") + "\n"
  374. else:
  375. for overload in constructors:
  376. out += add_indent("@overload\n" + overload, " ") + "\n"
  377. # do the methods
  378. items = others
  379. others = []
  380. for m,o in items:
  381. if ismethod(o):
  382. signatures = get_signatures(o)
  383. if len(signatures) == 0:
  384. continue
  385. count += 1
  386. if len(signatures) == 1:
  387. signature = handle_static(o, signatures[0])
  388. out += add_indent(signature, " ") + "\n"
  389. continue
  390. for overload in signatures:
  391. signature = handle_static(o, overload)
  392. out += add_indent("@overload\n" + signature, " ") + "\n"
  393. else:
  394. others.append((m, o))
  395. if count == 0:
  396. out = out[0:-1] + " ...\n"
  397. return out
  398. def module_pyi(mod, output):
  399. """Generate the contents of a .pyi file for a VTK module.
  400. """
  401. # needed stuff from typing module
  402. output.write("from typing import overload, Any, Callable, TypeVar, Union\n")
  403. output.write("from typing import Tuple, List, Sequence, MutableSequence\n")
  404. output.write("\n")
  405. output.write("Callback = Union[Callable[..., None], None]\n")
  406. output.write("Buffer = TypeVar('Buffer')\n")
  407. output.write("Pointer = TypeVar('Pointer')\n")
  408. output.write("Template = TypeVar('Template')\n")
  409. output.write("\n")
  410. if mod.__name__ == 'vtkmodules.vtkCommonCore':
  411. # dummy superclass for namespaces
  412. output.write("class namespace: pass\n")
  413. output.write("\n")
  414. # all the modules this module depends on
  415. depends = set(['vtkmodules.vtkCommonCore'])
  416. for m,o in mod.__dict__.items():
  417. if isclass(o) and m == o.__name__:
  418. for base in o.__bases__:
  419. depends.add(base.__module__)
  420. depends.discard(mod.__name__)
  421. depends.discard("builtins")
  422. for depend in sorted(depends):
  423. output.write("import " + depend + "\n")
  424. if depends:
  425. output.write("\n")
  426. # sort the dict according to dependency
  427. items = topologically_sorted_items(mod.__dict__)
  428. # do all namespaces
  429. others = []
  430. for m,o in items:
  431. if isnamespace(o) and m == o.__name__:
  432. output.write(namespace_pyi(o, mod))
  433. output.write("\n")
  434. else:
  435. others.append((m, o))
  436. # do all enum types
  437. items = others
  438. others = []
  439. for m,o in items:
  440. if isenum(o) and m == o.__name__:
  441. output.write(class_pyi(o))
  442. output.write("\n")
  443. else:
  444. others.append((m, o))
  445. # do all enum aliases
  446. items = others
  447. others = []
  448. for m,o in items:
  449. if isenum(o) and m != o.__name__:
  450. output.write(m + " = " + o.__name__ + "\n")
  451. else:
  452. others.append((m, o))
  453. # do all constants
  454. items = others
  455. others = []
  456. for m,o in items:
  457. if not m.startswith("__") and not ismethod(o) and not isclass(o):
  458. output.write(m + ":" + typename_forward(o) + "\n")
  459. else:
  460. others.append((m,o))
  461. if len(items) > len(others):
  462. output.write("\n")
  463. # do all classes
  464. items = others
  465. others = []
  466. for m,o in items:
  467. if isclass(o) and m == o.__name__:
  468. output.write(class_pyi(o))
  469. output.write("\n")
  470. else:
  471. others.append((m, o))
  472. # do all class aliases
  473. items = others
  474. others = []
  475. for m,o in items:
  476. if isclass(o) and m != o.__name__:
  477. output.write(m + " = " + o.__name__ + "\n")
  478. else:
  479. others.append((m, o))
  480. def main(argv=sys.argv):
  481. # for error messages etcetera
  482. progname = os.path.basename(argv[0])
  483. # parse the program arguments
  484. parser = argparse.ArgumentParser(
  485. prog=argv[0],
  486. usage="python " + progname + " [-p package] [-o output_dir]",
  487. description="A .pyi generator for the VTK python wrappers.")
  488. parser.add_argument('-p', '--package', type=str, default="vtkmodules",
  489. help="Package name [vtkmodules].")
  490. parser.add_argument('-i', '--importer', type=str,
  491. help="Static module importer [].")
  492. parser.add_argument('-o', '--output', type=str,
  493. help="Output directory [package directory].")
  494. parser.add_argument('-e', '--ext', type=str, default=".pyi",
  495. help="Output file suffix [.pyi].")
  496. parser.add_argument('--test', action='count', default=0,
  497. help="Test .pyi files instead of creating them.")
  498. parser.add_argument('modules', type=str, nargs='*',
  499. help="Modules to process [all].")
  500. args = parser.parse_args(argv[1:])
  501. # for convenience
  502. packagename = args.package
  503. modules = args.modules
  504. basedir = args.output
  505. ext = args.ext
  506. # if static module importer is needed, it must be handled first
  507. if args.importer:
  508. if len(modules) == 0:
  509. sys.stderr.write(progname + ": when using '-i', all modules " +
  510. "in the package must be listed on the command line.\n")
  511. return 1
  512. # check that the modules aren't already present as builtins
  513. # (we replace '.' separators with underscores for static importers)
  514. module_exemplar = (packagename + '.' + modules[0]).replace('.', '_')
  515. if module_exemplar not in sys.builtin_module_names:
  516. importlib.import_module(args.importer)
  517. # get information about the package
  518. if basedir is None or len(modules) == 0:
  519. mod = importlib.import_module(packagename)
  520. if basedir is None:
  521. filename = getattr(mod, '__file__', None)
  522. if filename is None or os.path.basename(filename) != '__init__.py':
  523. sys.stderr.write(progname + ": " + packagename + " has no __init__.py\n")
  524. return 1
  525. basedir = os.path.dirname(filename)
  526. if len(modules) == 0:
  527. for modname in mod.__all__:
  528. # only generate .pyi files for the extension modules in __all__
  529. try:
  530. spec = importlib.util.find_spec(packagename + "." + modname)
  531. except ValueError:
  532. spec = None
  533. if not errflag:
  534. errflag = True
  535. sys.stderr.write(progname + ": couldn't get loader for " + modname + "\n")
  536. if spec is None:
  537. continue
  538. if not isinstance(spec.loader, importlib.machinery.ExtensionFileLoader):
  539. continue
  540. # the module is definitely an extension module
  541. modules.append(modname)
  542. # iterate through the modules in the package
  543. errflag = False
  544. for modname in modules:
  545. pyifile = os.path.join(basedir, modname + ext)
  546. if args.test:
  547. # test the syntax of the .pyi file
  548. flags = ast.PyCF_TYPE_COMMENTS if sys.hexversion >= 0x3080000 else 0
  549. with open(pyifile, 'r') as f:
  550. compile(f.read(), pyifile, 'exec', flags)
  551. else:
  552. # generate the .pyi file for the module
  553. mod = importlib.import_module(packagename + "." + modname)
  554. with open(pyifile, "w") as f:
  555. module_pyi(mod, f)
# when run as a script: call main() and, if it returned a value,
# use that value as the exit status
if __name__ == '__main__':
    result = main(sys.argv)
    if result is not None:
        sys.exit(result)