codec.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. from .core import encode, decode, alabel, ulabel, IDNAError
  2. import codecs
  3. import re
  4. from typing import Tuple, Optional
# Matches any one of the four code points this module treats as a label
# separator: U+002E FULL STOP, U+3002 IDEOGRAPHIC FULL STOP,
# U+FF0E FULLWIDTH FULL STOP and U+FF61 HALFWIDTH IDEOGRAPHIC FULL STOP.
# Used by the incremental encoder/decoder to split input into labels.
_unicode_dots_re = re.compile('[\u002e\u3002\uff0e\uff61]')
  6. class Codec(codecs.Codec):
  7. def encode(self, data, errors='strict'):
  8. # type: (str, str) -> Tuple[bytes, int]
  9. if errors != 'strict':
  10. raise IDNAError('Unsupported error handling \"{}\"'.format(errors))
  11. if not data:
  12. return b"", 0
  13. return encode(data), len(data)
  14. def decode(self, data, errors='strict'):
  15. # type: (bytes, str) -> Tuple[str, int]
  16. if errors != 'strict':
  17. raise IDNAError('Unsupported error handling \"{}\"'.format(errors))
  18. if not data:
  19. return '', 0
  20. return decode(data), len(data)
  21. class IncrementalEncoder(codecs.BufferedIncrementalEncoder):
  22. def _buffer_encode(self, data, errors, final): # type: ignore
  23. # type: (str, str, bool) -> Tuple[str, int]
  24. if errors != 'strict':
  25. raise IDNAError('Unsupported error handling \"{}\"'.format(errors))
  26. if not data:
  27. return "", 0
  28. labels = _unicode_dots_re.split(data)
  29. trailing_dot = ''
  30. if labels:
  31. if not labels[-1]:
  32. trailing_dot = '.'
  33. del labels[-1]
  34. elif not final:
  35. # Keep potentially unfinished label until the next call
  36. del labels[-1]
  37. if labels:
  38. trailing_dot = '.'
  39. result = []
  40. size = 0
  41. for label in labels:
  42. result.append(alabel(label))
  43. if size:
  44. size += 1
  45. size += len(label)
  46. # Join with U+002E
  47. result_str = '.'.join(result) + trailing_dot # type: ignore
  48. size += len(trailing_dot)
  49. return result_str, size
  50. class IncrementalDecoder(codecs.BufferedIncrementalDecoder):
  51. def _buffer_decode(self, data, errors, final): # type: ignore
  52. # type: (str, str, bool) -> Tuple[str, int]
  53. if errors != 'strict':
  54. raise IDNAError('Unsupported error handling \"{}\"'.format(errors))
  55. if not data:
  56. return ('', 0)
  57. labels = _unicode_dots_re.split(data)
  58. trailing_dot = ''
  59. if labels:
  60. if not labels[-1]:
  61. trailing_dot = '.'
  62. del labels[-1]
  63. elif not final:
  64. # Keep potentially unfinished label until the next call
  65. del labels[-1]
  66. if labels:
  67. trailing_dot = '.'
  68. result = []
  69. size = 0
  70. for label in labels:
  71. result.append(ulabel(label))
  72. if size:
  73. size += 1
  74. size += len(label)
  75. result_str = '.'.join(result) + trailing_dot
  76. size += len(trailing_dot)
  77. return (result_str, size)
  78. class StreamWriter(Codec, codecs.StreamWriter):
  79. pass
  80. class StreamReader(Codec, codecs.StreamReader):
  81. pass
  82. def getregentry():
  83. # type: () -> codecs.CodecInfo
  84. # Compatibility as a search_function for codecs.register()
  85. return codecs.CodecInfo(
  86. name='idna',
  87. encode=Codec().encode, # type: ignore
  88. decode=Codec().decode, # type: ignore
  89. incrementalencoder=IncrementalEncoder,
  90. incrementaldecoder=IncrementalDecoder,
  91. streamwriter=StreamWriter,
  92. streamreader=StreamReader,
  93. )