etree.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478
"""Shim module exporting the same ElementTree API for the lxml and
xml.etree backends.

When lxml is installed, it is automatically preferred over the built-in
xml.etree module.
On Python 2.7, the cElementTree module is preferred over the pure-python
ElementTree module.

Besides exporting a unified interface, this module also defines extra
functions, and subclasses of the built-in ElementTree classes, that add
features otherwise only available in lxml, such as ordered attributes
(OrderedDict), pretty_print and iterwalk.
"""
  12. from fontTools.misc.textTools import tostr
  13. XML_DECLARATION = """<?xml version='1.0' encoding='%s'?>"""
  14. __all__ = [
  15. # public symbols
  16. "Comment",
  17. "dump",
  18. "Element",
  19. "ElementTree",
  20. "fromstring",
  21. "fromstringlist",
  22. "iselement",
  23. "iterparse",
  24. "parse",
  25. "ParseError",
  26. "PI",
  27. "ProcessingInstruction",
  28. "QName",
  29. "SubElement",
  30. "tostring",
  31. "tostringlist",
  32. "TreeBuilder",
  33. "XML",
  34. "XMLParser",
  35. "register_namespace",
  36. ]
  37. try:
  38. from lxml.etree import *
  39. _have_lxml = True
  40. except ImportError:
  41. try:
  42. from xml.etree.cElementTree import *
  43. # the cElementTree version of XML function doesn't support
  44. # the optional 'parser' keyword argument
  45. from xml.etree.ElementTree import XML
  46. except ImportError: # pragma: no cover
  47. from xml.etree.ElementTree import *
  48. _have_lxml = False
  49. import sys
  50. # dict is always ordered in python >= 3.6 and on pypy
  51. PY36 = sys.version_info >= (3, 6)
  52. try:
  53. import __pypy__
  54. except ImportError:
  55. __pypy__ = None
  56. _dict_is_ordered = bool(PY36 or __pypy__)
  57. del PY36, __pypy__
  58. if _dict_is_ordered:
  59. _Attrib = dict
  60. else:
  61. from collections import OrderedDict as _Attrib
  62. if isinstance(Element, type):
  63. _Element = Element
  64. else:
  65. # in py27, cElementTree.Element cannot be subclassed, so
  66. # we need to import the pure-python class
  67. from xml.etree.ElementTree import Element as _Element
  68. class Element(_Element):
  69. """Element subclass that keeps the order of attributes."""
  70. def __init__(self, tag, attrib=_Attrib(), **extra):
  71. super(Element, self).__init__(tag)
  72. self.attrib = _Attrib()
  73. if attrib:
  74. self.attrib.update(attrib)
  75. if extra:
  76. self.attrib.update(extra)
  77. def SubElement(parent, tag, attrib=_Attrib(), **extra):
  78. """Must override SubElement as well otherwise _elementtree.SubElement
  79. fails if 'parent' is a subclass of Element object.
  80. """
  81. element = parent.__class__(tag, attrib, **extra)
  82. parent.append(element)
  83. return element
  84. def _iterwalk(element, events, tag):
  85. include = tag is None or element.tag == tag
  86. if include and "start" in events:
  87. yield ("start", element)
  88. for e in element:
  89. for item in _iterwalk(e, events, tag):
  90. yield item
  91. if include:
  92. yield ("end", element)
  93. def iterwalk(element_or_tree, events=("end",), tag=None):
  94. """A tree walker that generates events from an existing tree as
  95. if it was parsing XML data with iterparse().
  96. Drop-in replacement for lxml.etree.iterwalk.
  97. """
  98. if iselement(element_or_tree):
  99. element = element_or_tree
  100. else:
  101. element = element_or_tree.getroot()
  102. if tag == "*":
  103. tag = None
  104. for item in _iterwalk(element, events, tag):
  105. yield item
  106. _ElementTree = ElementTree
  107. class ElementTree(_ElementTree):
  108. """ElementTree subclass that adds 'pretty_print' and 'doctype'
  109. arguments to the 'write' method.
  110. Currently these are only supported for the default XML serialization
  111. 'method', and not also for "html" or "text", for these are delegated
  112. to the base class.
  113. """
  114. def write(
  115. self,
  116. file_or_filename,
  117. encoding=None,
  118. xml_declaration=False,
  119. method=None,
  120. doctype=None,
  121. pretty_print=False,
  122. ):
  123. if method and method != "xml":
  124. # delegate to super-class
  125. super(ElementTree, self).write(
  126. file_or_filename,
  127. encoding=encoding,
  128. xml_declaration=xml_declaration,
  129. method=method,
  130. )
  131. return
  132. if encoding is not None and encoding.lower() == "unicode":
  133. if xml_declaration:
  134. raise ValueError(
  135. "Serialisation to unicode must not request an XML declaration"
  136. )
  137. write_declaration = False
  138. encoding = "unicode"
  139. elif xml_declaration is None:
  140. # by default, write an XML declaration only for non-standard encodings
  141. write_declaration = encoding is not None and encoding.upper() not in (
  142. "ASCII",
  143. "UTF-8",
  144. "UTF8",
  145. "US-ASCII",
  146. )
  147. else:
  148. write_declaration = xml_declaration
  149. if encoding is None:
  150. encoding = "ASCII"
  151. if pretty_print:
  152. # NOTE this will modify the tree in-place
  153. _indent(self._root)
  154. with _get_writer(file_or_filename, encoding) as write:
  155. if write_declaration:
  156. write(XML_DECLARATION % encoding.upper())
  157. if pretty_print:
  158. write("\n")
  159. if doctype:
  160. write(_tounicode(doctype))
  161. if pretty_print:
  162. write("\n")
  163. qnames, namespaces = _namespaces(self._root)
  164. _serialize_xml(write, self._root, qnames, namespaces)
  165. import io
  166. def tostring(
  167. element,
  168. encoding=None,
  169. xml_declaration=None,
  170. method=None,
  171. doctype=None,
  172. pretty_print=False,
  173. ):
  174. """Custom 'tostring' function that uses our ElementTree subclass, with
  175. pretty_print support.
  176. """
  177. stream = io.StringIO() if encoding == "unicode" else io.BytesIO()
  178. ElementTree(element).write(
  179. stream,
  180. encoding=encoding,
  181. xml_declaration=xml_declaration,
  182. method=method,
  183. doctype=doctype,
  184. pretty_print=pretty_print,
  185. )
  186. return stream.getvalue()
  187. # serialization support
  188. import re
  189. # Valid XML strings can include any Unicode character, excluding control
  190. # characters, the surrogate blocks, FFFE, and FFFF:
  191. # Char ::= #x9 | #xA | #xD | [#x20-#xD7FF] | [#xE000-#xFFFD] | [#x10000-#x10FFFF]
  192. # Here we reversed the pattern to match only the invalid characters.
  193. # For the 'narrow' python builds supporting only UCS-2, which represent
  194. # characters beyond BMP as UTF-16 surrogate pairs, we need to pass through
  195. # the surrogate block. I haven't found a more elegant solution...
  196. UCS2 = sys.maxunicode < 0x10FFFF
  197. if UCS2:
  198. _invalid_xml_string = re.compile(
  199. "[\u0000-\u0008\u000B-\u000C\u000E-\u001F\uFFFE-\uFFFF]"
  200. )
  201. else:
  202. _invalid_xml_string = re.compile(
  203. "[\u0000-\u0008\u000B-\u000C\u000E-\u001F\uD800-\uDFFF\uFFFE-\uFFFF]"
  204. )
  205. def _tounicode(s):
  206. """Test if a string is valid user input and decode it to unicode string
  207. using ASCII encoding if it's a bytes string.
  208. Reject all bytes/unicode input that contains non-XML characters.
  209. Reject all bytes input that contains non-ASCII characters.
  210. """
  211. try:
  212. s = tostr(s, encoding="ascii", errors="strict")
  213. except UnicodeDecodeError:
  214. raise ValueError(
  215. "Bytes strings can only contain ASCII characters. "
  216. "Use unicode strings for non-ASCII characters."
  217. )
  218. except AttributeError:
  219. _raise_serialization_error(s)
  220. if s and _invalid_xml_string.search(s):
  221. raise ValueError(
  222. "All strings must be XML compatible: Unicode or ASCII, "
  223. "no NULL bytes or control characters"
  224. )
  225. return s
  226. import contextlib
  227. @contextlib.contextmanager
  228. def _get_writer(file_or_filename, encoding):
  229. # returns text write method and release all resources after using
  230. try:
  231. write = file_or_filename.write
  232. except AttributeError:
  233. # file_or_filename is a file name
  234. f = open(
  235. file_or_filename,
  236. "w",
  237. encoding="utf-8" if encoding == "unicode" else encoding,
  238. errors="xmlcharrefreplace",
  239. )
  240. with f:
  241. yield f.write
  242. else:
  243. # file_or_filename is a file-like object
  244. # encoding determines if it is a text or binary writer
  245. if encoding == "unicode":
  246. # use a text writer as is
  247. yield write
  248. else:
  249. # wrap a binary writer with TextIOWrapper
  250. detach_buffer = False
  251. if isinstance(file_or_filename, io.BufferedIOBase):
  252. buf = file_or_filename
  253. elif isinstance(file_or_filename, io.RawIOBase):
  254. buf = io.BufferedWriter(file_or_filename)
  255. detach_buffer = True
  256. else:
  257. # This is to handle passed objects that aren't in the
  258. # IOBase hierarchy, but just have a write method
  259. buf = io.BufferedIOBase()
  260. buf.writable = lambda: True
  261. buf.write = write
  262. try:
  263. # TextIOWrapper uses this methods to determine
  264. # if BOM (for UTF-16, etc) should be added
  265. buf.seekable = file_or_filename.seekable
  266. buf.tell = file_or_filename.tell
  267. except AttributeError:
  268. pass
  269. wrapper = io.TextIOWrapper(
  270. buf,
  271. encoding=encoding,
  272. errors="xmlcharrefreplace",
  273. newline="\n",
  274. )
  275. try:
  276. yield wrapper.write
  277. finally:
  278. # Keep the original file open when the TextIOWrapper and
  279. # the BufferedWriter are destroyed
  280. wrapper.detach()
  281. if detach_buffer:
  282. buf.detach()
  283. from xml.etree.ElementTree import _namespace_map
  284. def _namespaces(elem):
  285. # identify namespaces used in this tree
  286. # maps qnames to *encoded* prefix:local names
  287. qnames = {None: None}
  288. # maps uri:s to prefixes
  289. namespaces = {}
  290. def add_qname(qname):
  291. # calculate serialized qname representation
  292. try:
  293. qname = _tounicode(qname)
  294. if qname[:1] == "{":
  295. uri, tag = qname[1:].rsplit("}", 1)
  296. prefix = namespaces.get(uri)
  297. if prefix is None:
  298. prefix = _namespace_map.get(uri)
  299. if prefix is None:
  300. prefix = "ns%d" % len(namespaces)
  301. else:
  302. prefix = _tounicode(prefix)
  303. if prefix != "xml":
  304. namespaces[uri] = prefix
  305. if prefix:
  306. qnames[qname] = "%s:%s" % (prefix, tag)
  307. else:
  308. qnames[qname] = tag # default element
  309. else:
  310. qnames[qname] = qname
  311. except TypeError:
  312. _raise_serialization_error(qname)
  313. # populate qname and namespaces table
  314. for elem in elem.iter():
  315. tag = elem.tag
  316. if isinstance(tag, QName):
  317. if tag.text not in qnames:
  318. add_qname(tag.text)
  319. elif isinstance(tag, str):
  320. if tag not in qnames:
  321. add_qname(tag)
  322. elif tag is not None and tag is not Comment and tag is not PI:
  323. _raise_serialization_error(tag)
  324. for key, value in elem.items():
  325. if isinstance(key, QName):
  326. key = key.text
  327. if key not in qnames:
  328. add_qname(key)
  329. if isinstance(value, QName) and value.text not in qnames:
  330. add_qname(value.text)
  331. text = elem.text
  332. if isinstance(text, QName) and text.text not in qnames:
  333. add_qname(text.text)
  334. return qnames, namespaces
  335. def _serialize_xml(write, elem, qnames, namespaces, **kwargs):
  336. tag = elem.tag
  337. text = elem.text
  338. if tag is Comment:
  339. write("<!--%s-->" % _tounicode(text))
  340. elif tag is ProcessingInstruction:
  341. write("<?%s?>" % _tounicode(text))
  342. else:
  343. tag = qnames[_tounicode(tag) if tag is not None else None]
  344. if tag is None:
  345. if text:
  346. write(_escape_cdata(text))
  347. for e in elem:
  348. _serialize_xml(write, e, qnames, None)
  349. else:
  350. write("<" + tag)
  351. if namespaces:
  352. for uri, prefix in sorted(
  353. namespaces.items(), key=lambda x: x[1]
  354. ): # sort on prefix
  355. if prefix:
  356. prefix = ":" + prefix
  357. write(' xmlns%s="%s"' % (prefix, _escape_attrib(uri)))
  358. attrs = elem.attrib
  359. if attrs:
  360. # try to keep existing attrib order
  361. if len(attrs) <= 1 or type(attrs) is _Attrib:
  362. items = attrs.items()
  363. else:
  364. # if plain dict, use lexical order
  365. items = sorted(attrs.items())
  366. for k, v in items:
  367. if isinstance(k, QName):
  368. k = _tounicode(k.text)
  369. else:
  370. k = _tounicode(k)
  371. if isinstance(v, QName):
  372. v = qnames[_tounicode(v.text)]
  373. else:
  374. v = _escape_attrib(v)
  375. write(' %s="%s"' % (qnames[k], v))
  376. if text is not None or len(elem):
  377. write(">")
  378. if text:
  379. write(_escape_cdata(text))
  380. for e in elem:
  381. _serialize_xml(write, e, qnames, None)
  382. write("</" + tag + ">")
  383. else:
  384. write("/>")
  385. if elem.tail:
  386. write(_escape_cdata(elem.tail))
  387. def _raise_serialization_error(text):
  388. raise TypeError("cannot serialize %r (type %s)" % (text, type(text).__name__))
  389. def _escape_cdata(text):
  390. # escape character data
  391. try:
  392. text = _tounicode(text)
  393. # it's worth avoiding do-nothing calls for short strings
  394. if "&" in text:
  395. text = text.replace("&", "&amp;")
  396. if "<" in text:
  397. text = text.replace("<", "&lt;")
  398. if ">" in text:
  399. text = text.replace(">", "&gt;")
  400. return text
  401. except (TypeError, AttributeError):
  402. _raise_serialization_error(text)
  403. def _escape_attrib(text):
  404. # escape attribute value
  405. try:
  406. text = _tounicode(text)
  407. if "&" in text:
  408. text = text.replace("&", "&amp;")
  409. if "<" in text:
  410. text = text.replace("<", "&lt;")
  411. if ">" in text:
  412. text = text.replace(">", "&gt;")
  413. if '"' in text:
  414. text = text.replace('"', "&quot;")
  415. if "\n" in text:
  416. text = text.replace("\n", "&#10;")
  417. return text
  418. except (TypeError, AttributeError):
  419. _raise_serialization_error(text)
  420. def _indent(elem, level=0):
  421. # From http://effbot.org/zone/element-lib.htm#prettyprint
  422. i = "\n" + level * " "
  423. if len(elem):
  424. if not elem.text or not elem.text.strip():
  425. elem.text = i + " "
  426. if not elem.tail or not elem.tail.strip():
  427. elem.tail = i
  428. for elem in elem:
  429. _indent(elem, level + 1)
  430. if not elem.tail or not elem.tail.strip():
  431. elem.tail = i
  432. else:
  433. if level and (not elem.tail or not elem.tail.strip()):
  434. elem.tail = i