_quoting_py.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198
  1. import codecs
  2. import re
  3. from string import ascii_letters, ascii_lowercase, digits
  4. from typing import Optional, cast
  5. BASCII_LOWERCASE = ascii_lowercase.encode("ascii")
  6. BPCT_ALLOWED = {"%{:02X}".format(i).encode("ascii") for i in range(256)}
  7. GEN_DELIMS = ":/?#[]@"
  8. SUB_DELIMS_WITHOUT_QS = "!$'()*,"
  9. SUB_DELIMS = SUB_DELIMS_WITHOUT_QS + "+&=;"
  10. RESERVED = GEN_DELIMS + SUB_DELIMS
  11. UNRESERVED = ascii_letters + digits + "-._~"
  12. ALLOWED = UNRESERVED + SUB_DELIMS_WITHOUT_QS
  13. _IS_HEX = re.compile(b"[A-Z0-9][A-Z0-9]")
  14. _IS_HEX_STR = re.compile("[A-Fa-f0-9][A-Fa-f0-9]")
  15. utf8_decoder = codecs.getincrementaldecoder("utf-8")
  16. class _Quoter:
  17. def __init__(
  18. self,
  19. *,
  20. safe: str = "",
  21. protected: str = "",
  22. qs: bool = False,
  23. requote: bool = True
  24. ) -> None:
  25. self._safe = safe
  26. self._protected = protected
  27. self._qs = qs
  28. self._requote = requote
  29. def __call__(self, val: Optional[str]) -> Optional[str]:
  30. if val is None:
  31. return None
  32. if not isinstance(val, str):
  33. raise TypeError("Argument should be str")
  34. if not val:
  35. return ""
  36. bval = cast(str, val).encode("utf8", errors="ignore")
  37. ret = bytearray()
  38. pct = bytearray()
  39. safe = self._safe
  40. safe += ALLOWED
  41. if not self._qs:
  42. safe += "+&=;"
  43. safe += self._protected
  44. bsafe = safe.encode("ascii")
  45. idx = 0
  46. while idx < len(bval):
  47. ch = bval[idx]
  48. idx += 1
  49. if pct:
  50. if ch in BASCII_LOWERCASE:
  51. ch = ch - 32 # convert to uppercase
  52. pct.append(ch)
  53. if len(pct) == 3: # pragma: no branch # peephole optimizer
  54. buf = pct[1:]
  55. if not _IS_HEX.match(buf):
  56. ret.extend(b"%25")
  57. pct.clear()
  58. idx -= 2
  59. continue
  60. try:
  61. unquoted = chr(int(pct[1:].decode("ascii"), base=16))
  62. except ValueError:
  63. ret.extend(b"%25")
  64. pct.clear()
  65. idx -= 2
  66. continue
  67. if unquoted in self._protected:
  68. ret.extend(pct)
  69. elif unquoted in safe:
  70. ret.append(ord(unquoted))
  71. else:
  72. ret.extend(pct)
  73. pct.clear()
  74. # special case, if we have only one char after "%"
  75. elif len(pct) == 2 and idx == len(bval):
  76. ret.extend(b"%25")
  77. pct.clear()
  78. idx -= 1
  79. continue
  80. elif ch == ord("%") and self._requote:
  81. pct.clear()
  82. pct.append(ch)
  83. # special case if "%" is last char
  84. if idx == len(bval):
  85. ret.extend(b"%25")
  86. continue
  87. if self._qs:
  88. if ch == ord(" "):
  89. ret.append(ord("+"))
  90. continue
  91. if ch in bsafe:
  92. ret.append(ch)
  93. continue
  94. ret.extend(("%{:02X}".format(ch)).encode("ascii"))
  95. ret2 = ret.decode("ascii")
  96. if ret2 == val:
  97. return val
  98. return ret2
  99. class _Unquoter:
  100. def __init__(self, *, unsafe: str = "", qs: bool = False) -> None:
  101. self._unsafe = unsafe
  102. self._qs = qs
  103. self._quoter = _Quoter()
  104. self._qs_quoter = _Quoter(qs=True)
  105. def __call__(self, val: Optional[str]) -> Optional[str]:
  106. if val is None:
  107. return None
  108. if not isinstance(val, str):
  109. raise TypeError("Argument should be str")
  110. if not val:
  111. return ""
  112. decoder = cast(codecs.BufferedIncrementalDecoder, utf8_decoder())
  113. ret = []
  114. idx = 0
  115. while idx < len(val):
  116. ch = val[idx]
  117. idx += 1
  118. if ch == "%" and idx <= len(val) - 2:
  119. pct = val[idx : idx + 2]
  120. if _IS_HEX_STR.fullmatch(pct):
  121. b = bytes([int(pct, base=16)])
  122. idx += 2
  123. try:
  124. unquoted = decoder.decode(b)
  125. except UnicodeDecodeError:
  126. start_pct = idx - 3 - len(decoder.buffer) * 3
  127. ret.append(val[start_pct : idx - 3])
  128. decoder.reset()
  129. try:
  130. unquoted = decoder.decode(b)
  131. except UnicodeDecodeError:
  132. ret.append(val[idx - 3 : idx])
  133. continue
  134. if not unquoted:
  135. continue
  136. if self._qs and unquoted in "+=&;":
  137. to_add = self._qs_quoter(unquoted)
  138. if to_add is None: # pragma: no cover
  139. raise RuntimeError("Cannot quote None")
  140. ret.append(to_add)
  141. elif unquoted in self._unsafe:
  142. to_add = self._quoter(unquoted)
  143. if to_add is None: # pragma: no cover
  144. raise RuntimeError("Cannot quote None")
  145. ret.append(to_add)
  146. else:
  147. ret.append(unquoted)
  148. continue
  149. if decoder.buffer:
  150. start_pct = idx - 1 - len(decoder.buffer) * 3
  151. ret.append(val[start_pct : idx - 1])
  152. decoder.reset()
  153. if ch == "+":
  154. if not self._qs or ch in self._unsafe:
  155. ret.append("+")
  156. else:
  157. ret.append(" ")
  158. continue
  159. if ch in self._unsafe:
  160. ret.append("%")
  161. h = hex(ord(ch)).upper()[2:]
  162. for ch in h:
  163. ret.append(ch)
  164. continue
  165. ret.append(ch)
  166. if decoder.buffer:
  167. ret.append(val[-len(decoder.buffer) * 3 :])
  168. ret2 = "".join(ret)
  169. if ret2 == val:
  170. return val
  171. return ret2