sslproto.py 31 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922
  1. import collections
  2. import enum
  3. import warnings
  4. try:
  5. import ssl
  6. except ImportError: # pragma: no cover
  7. ssl = None
  8. from . import constants
  9. from . import exceptions
  10. from . import protocols
  11. from . import transports
  12. from .log import logger
  # SSL exceptions that this module treats as "not finished yet" rather than
  # as fatal: the handshake, shutdown, write and read paths catch them and
  # either flush pending output or simply stop and wait for more I/O.
  # Only defined when the optional ``ssl`` import above succeeded.
  if ssl is not None:
      SSLAgainErrors = (
          ssl.SSLWantReadError,
          ssl.SSLSyscallError,
      )
  class SSLProtocolState(enum.Enum):
      """States of the SSLProtocol state machine.

      SSLProtocol._set_state() only permits the forward chain
      UNWRAPPED -> DO_HANDSHAKE -> WRAPPED -> FLUSHING -> SHUTDOWN,
      plus a jump back to UNWRAPPED from any state.
      """

      UNWRAPPED = "UNWRAPPED"        # initial state; re-entered on abort/loss
      DO_HANDSHAKE = "DO_HANDSHAKE"  # SSL handshake in progress
      WRAPPED = "WRAPPED"            # handshake done, application data flows
      FLUSHING = "FLUSHING"          # closing: drain pending reads first
      SHUTDOWN = "SHUTDOWN"          # closing: SSL unwrap in progress
  class AppProtocolState(enum.Enum):
      """Lifecycle of the application protocol (https://git.io/fj59P).

      Progression::

          INIT -cm-> CON_MADE [-dr*->] [-er-> EOF?] -cl-> CON_LOST

      where cm = connection_made(), dr = data_received(),
      er = eof_received() and cl = connection_lost().
      """

      STATE_INIT = "STATE_INIT"
      STATE_CON_MADE = "STATE_CON_MADE"
      STATE_EOF = "STATE_EOF"
      STATE_CON_LOST = "STATE_CON_LOST"
  34. def _create_transport_context(server_side, server_hostname):
  35. if server_side:
  36. raise ValueError('Server side SSL needs a valid SSLContext')
  37. # Client side may pass ssl=True to use a default
  38. # context; in that case the sslcontext passed is None.
  39. # The default is secure for client connections.
  40. # Python 3.4+: use up-to-date strong settings.
  41. sslcontext = ssl.create_default_context()
  42. if not server_hostname:
  43. sslcontext.check_hostname = False
  44. return sslcontext
  def add_flowcontrol_defaults(high, low, kb):
      """Fill in missing flow-control water marks and validate them.

      Args:
          high: High-water mark, or None to derive it.
          low: Low-water mark, or None to derive it.
          kb: Default high-water mark in KiB, used when both are None.

      Returns:
          A ``(high, low)`` tuple.  A missing high mark becomes ``4 * low``
          (or ``kb * 1024`` if *low* is also missing); a missing low mark
          becomes ``high // 4``.

      Raises:
          ValueError: unless ``high >= low >= 0``.
      """
      if high is not None:
          hi = high
      elif low is not None:
          hi = 4 * low
      else:
          hi = kb * 1024

      lo = hi // 4 if low is None else low

      if not hi >= lo >= 0:
          raise ValueError('high (%r) must be >= low (%r) must be >= 0' %
                           (hi, lo))
      return hi, lo
  class _SSLProtocolTransport(transports._FlowControlMixin,
                              transports.Transport):
      """Application-facing transport backed by an SSLProtocol.

      Every call is forwarded to the owning SSLProtocol.  That protocol
      encrypts application writes and hands decrypted data to the
      application protocol.
      """

      _start_tls_compatible = True
      _sendfile_compatible = constants._SendfileMode.FALLBACK

      def __init__(self, loop, ssl_protocol):
          self._loop = loop
          self._ssl_protocol = ssl_protocol
          self._closed = False

      def get_extra_info(self, name, default=None):
          """Get optional transport information."""
          return self._ssl_protocol._get_extra_info(name, default)

      def set_protocol(self, protocol):
          """Replace the application protocol."""
          self._ssl_protocol._set_app_protocol(protocol)

      def get_protocol(self):
          """Return the current application protocol."""
          return self._ssl_protocol._app_protocol

      def is_closing(self):
          """Return True if the transport is closing or closed."""
          return self._closed

      def close(self):
          """Close the transport.

          Buffered data will be flushed asynchronously. No more data
          will be received. After all buffered data is flushed, the
          protocol's connection_lost() method will (eventually) be called
          with None as its argument.
          """
          if not self._closed:
              self._closed = True
              self._ssl_protocol._start_shutdown()
          else:
              # Already closing: only drop the protocol reference
              # (NOTE(review): presumably to break the reference cycle).
              self._ssl_protocol = None

      def __del__(self, _warnings=warnings):
          if not self._closed:
              self._closed = True
              _warnings.warn(
                  "unclosed transport <asyncio._SSLProtocolTransport "
                  "object>", ResourceWarning)

      def is_reading(self):
          """Return True unless reading was paused with pause_reading()."""
          return not self._ssl_protocol._app_reading_paused

      def pause_reading(self):
          """Pause the receiving end.

          No data will be passed to the protocol's data_received()
          method until resume_reading() is called.
          """
          self._ssl_protocol._pause_reading()

      def resume_reading(self):
          """Resume the receiving end.

          Data received will once again be passed to the protocol's
          data_received() method.
          """
          self._ssl_protocol._resume_reading()

      def set_write_buffer_limits(self, high=None, low=None):
          """Set the high- and low-water limits for write flow control.

          These two values control when to call the protocol's
          pause_writing() and resume_writing() methods. If specified,
          the low-water limit must be less than or equal to the
          high-water limit. Neither value can be negative.

          The defaults are implementation-specific. If only the
          high-water limit is given, the low-water limit defaults to an
          implementation-specific value less than or equal to the
          high-water limit. Setting high to zero forces low to zero as
          well, and causes pause_writing() to be called whenever the
          buffer becomes non-empty. Setting low to zero causes
          resume_writing() to be called only once the buffer is empty.
          Use of zero for either limit is generally sub-optimal as it
          reduces opportunities for doing I/O and computation
          concurrently.
          """
          self._ssl_protocol._set_write_buffer_limits(high, low)
          # Re-evaluate pause/resume right away against the new limits.
          self._ssl_protocol._control_app_writing()

      def get_write_buffer_limits(self):
          """Return the write flow-control limits as ``(low, high)``."""
          return (self._ssl_protocol._outgoing_low_water,
                  self._ssl_protocol._outgoing_high_water)

      def get_write_buffer_size(self):
          """Return the current size of the write buffers."""
          return self._ssl_protocol._get_write_buffer_size()

      def set_read_buffer_limits(self, high=None, low=None):
          """Set the high- and low-water limits for read flow control.

          These two values control when to call the upstream transport's
          pause_reading() and resume_reading() methods. If specified,
          the low-water limit must be less than or equal to the
          high-water limit. Neither value can be negative.

          The defaults are implementation-specific. If only the
          high-water limit is given, the low-water limit defaults to an
          implementation-specific value less than or equal to the
          high-water limit. Setting high to zero forces low to zero as
          well, and causes pause_reading() to be called whenever the
          buffer becomes non-empty. Setting low to zero causes
          resume_reading() to be called only once the buffer is empty.
          Use of zero for either limit is generally sub-optimal as it
          reduces opportunities for doing I/O and computation
          concurrently.
          """
          self._ssl_protocol._set_read_buffer_limits(high, low)
          # Re-evaluate pause/resume right away against the new limits.
          self._ssl_protocol._control_ssl_reading()

      def get_read_buffer_limits(self):
          """Return the read flow-control limits as ``(low, high)``."""
          return (self._ssl_protocol._incoming_low_water,
                  self._ssl_protocol._incoming_high_water)

      def get_read_buffer_size(self):
          """Return the current size of the read buffer."""
          return self._ssl_protocol._get_read_buffer_size()

      @property
      def _protocol_paused(self):
          # Required for sendfile fallback pause_writing/resume_writing logic
          return self._ssl_protocol._app_writing_paused

      def write(self, data):
          """Write some data bytes to the transport.

          This does not block; it buffers the data and arranges for it
          to be sent out asynchronously.

          Raises:
              TypeError: if *data* is not bytes, bytearray or memoryview.
          """
          if not isinstance(data, (bytes, bytearray, memoryview)):
              raise TypeError(f"data: expecting a bytes-like instance, "
                              f"got {type(data).__name__}")
          if not data:
              return
          self._ssl_protocol._write_appdata((data,))

      def writelines(self, list_of_data):
          """Write an iterable of bytes-like objects to the transport.

          Unlike write(), the items are not type-checked and empty items
          are not skipped here. The iterable is handed unchanged to the
          SSL protocol's _write_appdata().
          """
          self._ssl_protocol._write_appdata(list_of_data)

      def write_eof(self):
          """Close the write end after flushing buffered data.

          This raises :exc:`NotImplementedError` right now.
          """
          raise NotImplementedError

      def can_write_eof(self):
          """Return True if this transport supports write_eof(), False if not."""
          return False

      def abort(self):
          """Close the transport immediately.

          Buffered data will be lost. No more data will be received.
          The protocol's connection_lost() method will (eventually) be
          called with None as its argument.
          """
          self._closed = True
          if self._ssl_protocol is not None:
              self._ssl_protocol._abort()

      def _force_close(self, exc):
          """Close immediately, handing *exc* to the SSL protocol's _abort()."""
          self._closed = True
          # A repeated close() may already have dropped the protocol
          # reference; guard like abort() does instead of raising
          # AttributeError.
          if self._ssl_protocol is not None:
              self._ssl_protocol._abort(exc)

      def _test__append_write_backlog(self, data):
          # for test only
          self._ssl_protocol._write_backlog.append(data)
          self._ssl_protocol._write_buffer_size += len(data)
  class SSLProtocol(protocols.BufferedProtocol):
      """SSL layer between a low-level transport and an application protocol.

      Bytes from the low-level transport are written into an incoming
      ``ssl.MemoryBIO`` and processed by an ``SSLObject`` from
      ``SSLContext.wrap_bio()``. Decrypted data goes to the application
      protocol. Application writes are encrypted into an outgoing MemoryBIO,
      which is flushed to the low-level transport. Progress is tracked with
      ``SSLProtocolState`` (transitions enforced by ``_set_state``). The
      application protocol's lifecycle is tracked with ``AppProtocolState``.
      """

      max_size = 256 * 1024   # Buffer size passed to read()

      _handshake_start_time = None
      _handshake_timeout_handle = None
      _shutdown_timeout_handle = None

      def __init__(self, loop, app_protocol, sslcontext, waiter,
                   server_side=False, server_hostname=None,
                   call_connection_made=True,
                   ssl_handshake_timeout=None,
                   ssl_shutdown_timeout=None):
          """Set up buffers, BIOs, the SSLObject and the app transport.

          Raises:
              RuntimeError: if the ssl module is unavailable.
              ValueError: if a timeout is given but is not positive, or if
                  server_side is true and no sslcontext was supplied.
          """
          if ssl is None:
              raise RuntimeError("stdlib ssl module not available")

          # Reusable receive buffer handed out by get_buffer().
          self._ssl_buffer = bytearray(self.max_size)
          self._ssl_buffer_view = memoryview(self._ssl_buffer)

          if ssl_handshake_timeout is None:
              ssl_handshake_timeout = constants.SSL_HANDSHAKE_TIMEOUT
          elif ssl_handshake_timeout <= 0:
              raise ValueError(
                  f"ssl_handshake_timeout should be a positive number, "
                  f"got {ssl_handshake_timeout}")
          if ssl_shutdown_timeout is None:
              ssl_shutdown_timeout = constants.SSL_SHUTDOWN_TIMEOUT
          elif ssl_shutdown_timeout <= 0:
              raise ValueError(
                  f"ssl_shutdown_timeout should be a positive number, "
                  f"got {ssl_shutdown_timeout}")

          if not sslcontext:
              sslcontext = _create_transport_context(
                  server_side, server_hostname)

          self._server_side = server_side
          # server_hostname is only passed to wrap_bio() on the client side.
          if server_hostname and not server_side:
              self._server_hostname = server_hostname
          else:
              self._server_hostname = None
          self._sslcontext = sslcontext
          # SSL-specific extra info. More info are set when the handshake
          # completes.
          self._extra = dict(sslcontext=sslcontext)

          # App data write buffering
          self._write_backlog = collections.deque()
          self._write_buffer_size = 0

          self._waiter = waiter
          self._loop = loop
          self._set_app_protocol(app_protocol)
          self._app_transport = None
          self._app_transport_created = False
          # transport, ex: SelectorSocketTransport
          self._transport = None
          self._ssl_handshake_timeout = ssl_handshake_timeout
          self._ssl_shutdown_timeout = ssl_shutdown_timeout
          # SSL and state machine
          self._incoming = ssl.MemoryBIO()
          self._outgoing = ssl.MemoryBIO()
          self._state = SSLProtocolState.UNWRAPPED
          # Incremented by connection_lost() and by writes attempted after
          # the connection started closing.
          self._conn_lost = 0
          if call_connection_made:
              self._app_state = AppProtocolState.STATE_INIT
          else:
              self._app_state = AppProtocolState.STATE_CON_MADE
          self._sslobj = self._sslcontext.wrap_bio(
              self._incoming, self._outgoing,
              server_side=self._server_side,
              server_hostname=self._server_hostname)

          # Flow Control

          # Set by pause_writing() from the low-level transport.
          self._ssl_writing_paused = False

          # Set by the app transport's pause_reading().
          self._app_reading_paused = False

          # Whether we paused the low-level transport's reading.
          self._ssl_reading_paused = False
          self._incoming_high_water = 0
          self._incoming_low_water = 0
          self._set_read_buffer_limits()
          self._eof_received = False

          # Whether we called the app protocol's pause_writing().
          self._app_writing_paused = False
          self._outgoing_high_water = 0
          self._outgoing_low_water = 0
          self._set_write_buffer_limits()
          self._get_app_transport()

      def _set_app_protocol(self, app_protocol):
          """Install *app_protocol*, caching buffered-protocol hooks."""
          self._app_protocol = app_protocol
          # Make fast hasattr check first
          if (hasattr(app_protocol, 'get_buffer') and
                  isinstance(app_protocol, protocols.BufferedProtocol)):
              self._app_protocol_get_buffer = app_protocol.get_buffer
              self._app_protocol_buffer_updated = app_protocol.buffer_updated
              self._app_protocol_is_buffer = True
          else:
              self._app_protocol_is_buffer = False

      def _wakeup_waiter(self, exc=None):
          """Resolve the waiter future once, with *exc* or a None result."""
          if self._waiter is None:
              return
          if not self._waiter.cancelled():
              if exc is not None:
                  self._waiter.set_exception(exc)
              else:
                  self._waiter.set_result(None)
          self._waiter = None

      def _get_app_transport(self):
          """Return the app transport, creating it at most once.

          Raises:
              RuntimeError: if it was created before and since released.
          """
          if self._app_transport is None:
              if self._app_transport_created:
                  raise RuntimeError('Creating _SSLProtocolTransport twice')
              self._app_transport = _SSLProtocolTransport(self._loop, self)
              self._app_transport_created = True
          return self._app_transport

      def connection_made(self, transport):
          """Called when the low-level connection is made.

          Start the SSL handshake.
          """
          self._transport = transport
          self._start_handshake()

      def connection_lost(self, exc):
          """Called when the low-level connection is lost or closed.

          The argument is an exception object or None (the latter
          meaning a regular EOF is received or the connection was
          aborted or closed).
          """
          self._write_backlog.clear()
          self._outgoing.read()
          self._conn_lost += 1

          # Just mark the app transport as closed so that its __dealloc__
          # doesn't complain.
          if self._app_transport is not None:
              self._app_transport._closed = True

          if self._state != SSLProtocolState.DO_HANDSHAKE:
              if (
                  self._app_state == AppProtocolState.STATE_CON_MADE or
                  self._app_state == AppProtocolState.STATE_EOF
              ):
                  self._app_state = AppProtocolState.STATE_CON_LOST
                  self._loop.call_soon(self._app_protocol.connection_lost, exc)
          self._set_state(SSLProtocolState.UNWRAPPED)
          self._transport = None
          self._app_transport = None
          self._app_protocol = None
          self._wakeup_waiter(exc)

          if self._shutdown_timeout_handle:
              self._shutdown_timeout_handle.cancel()
              self._shutdown_timeout_handle = None
          if self._handshake_timeout_handle:
              self._handshake_timeout_handle.cancel()
              self._handshake_timeout_handle = None

      def get_buffer(self, n):
          """Return a writable view for the low-level transport to fill.

          Sizes <= 0 or above max_size are clamped to max_size.
          """
          want = n
          if want <= 0 or want > self.max_size:
              want = self.max_size
          if len(self._ssl_buffer) < want:
              self._ssl_buffer = bytearray(want)
              self._ssl_buffer_view = memoryview(self._ssl_buffer)
          return self._ssl_buffer_view

      def buffer_updated(self, nbytes):
          """Feed *nbytes* of received data to the SSL state machine."""
          self._incoming.write(self._ssl_buffer_view[:nbytes])

          if self._state == SSLProtocolState.DO_HANDSHAKE:
              self._do_handshake()

          elif self._state == SSLProtocolState.WRAPPED:
              self._do_read()

          elif self._state == SSLProtocolState.FLUSHING:
              self._do_flush()

          elif self._state == SSLProtocolState.SHUTDOWN:
              self._do_shutdown()

      def eof_received(self):
          """Called when the other end of the low-level stream
          is half-closed.

          If this returns a false value (including None), the transport
          will close itself. If it returns a true value, closing the
          transport is up to the protocol.
          """
          self._eof_received = True
          try:
              if self._loop.get_debug():
                  logger.debug("%r received EOF", self)

              if self._state == SSLProtocolState.DO_HANDSHAKE:
                  self._on_handshake_complete(ConnectionResetError)

              elif self._state == SSLProtocolState.WRAPPED:
                  self._set_state(SSLProtocolState.FLUSHING)
                  if self._app_reading_paused:
                      # Keep the transport open until reading resumes.
                      return True
                  else:
                      self._do_flush()

              elif self._state == SSLProtocolState.FLUSHING:
                  self._do_write()
                  self._set_state(SSLProtocolState.SHUTDOWN)
                  self._do_shutdown()

              elif self._state == SSLProtocolState.SHUTDOWN:
                  self._do_shutdown()

          except Exception:
              self._transport.close()
              raise

      def _get_extra_info(self, name, default=None):
          """Look up *name* in SSL extra info, then on the low-level transport."""
          if name in self._extra:
              return self._extra[name]
          elif self._transport is not None:
              return self._transport.get_extra_info(name, default)
          else:
              return default

      def _set_state(self, new_state):
          """Switch to *new_state* if the transition is allowed.

          Allowed: any state -> UNWRAPPED, and the forward chain
          UNWRAPPED -> DO_HANDSHAKE -> WRAPPED -> FLUSHING -> SHUTDOWN.

          Raises:
              RuntimeError: for any other transition.
          """
          allowed = False

          if new_state == SSLProtocolState.UNWRAPPED:
              allowed = True

          elif (
              self._state == SSLProtocolState.UNWRAPPED and
              new_state == SSLProtocolState.DO_HANDSHAKE
          ):
              allowed = True

          elif (
              self._state == SSLProtocolState.DO_HANDSHAKE and
              new_state == SSLProtocolState.WRAPPED
          ):
              allowed = True

          elif (
              self._state == SSLProtocolState.WRAPPED and
              new_state == SSLProtocolState.FLUSHING
          ):
              allowed = True

          elif (
              self._state == SSLProtocolState.FLUSHING and
              new_state == SSLProtocolState.SHUTDOWN
          ):
              allowed = True

          if allowed:
              self._state = new_state

          else:
              raise RuntimeError(
                  'cannot switch state from {} to {}'.format(
                      self._state, new_state))

      # Handshake flow

      def _start_handshake(self):
          """Enter DO_HANDSHAKE, arm the handshake timeout and begin."""
          if self._loop.get_debug():
              logger.debug("%r starts SSL handshake", self)
              self._handshake_start_time = self._loop.time()
          else:
              self._handshake_start_time = None

          self._set_state(SSLProtocolState.DO_HANDSHAKE)

          # start handshake timeout count down
          self._handshake_timeout_handle = \
              self._loop.call_later(self._ssl_handshake_timeout,
                                    lambda: self._check_handshake_timeout())

          self._do_handshake()

      def _check_handshake_timeout(self):
          """Abort the connection if the handshake is still in progress."""
          if self._state == SSLProtocolState.DO_HANDSHAKE:
              msg = (
                  f"SSL handshake is taking longer than "
                  f"{self._ssl_handshake_timeout} seconds: "
                  f"aborting the connection"
              )
              self._fatal_error(ConnectionAbortedError(msg))

      def _do_handshake(self):
          """Advance the handshake; flush output if it needs more input."""
          try:
              self._sslobj.do_handshake()
          except SSLAgainErrors:
              self._process_outgoing()
          except ssl.SSLError as exc:
              self._on_handshake_complete(exc)
          else:
              self._on_handshake_complete(None)

      def _on_handshake_complete(self, handshake_exc):
          """Finish the handshake: fail the connection or go WRAPPED.

          On success, records peer/cipher info in the extra dict, calls the
          app protocol's connection_made() if still pending, resolves the
          waiter and starts reading.
          """
          if self._handshake_timeout_handle is not None:
              self._handshake_timeout_handle.cancel()
              self._handshake_timeout_handle = None

          sslobj = self._sslobj
          try:
              if handshake_exc is None:
                  self._set_state(SSLProtocolState.WRAPPED)
              else:
                  raise handshake_exc

              peercert = sslobj.getpeercert()
          except Exception as exc:
              self._set_state(SSLProtocolState.UNWRAPPED)
              if isinstance(exc, ssl.CertificateError):
                  msg = 'SSL handshake failed on verifying the certificate'
              else:
                  msg = 'SSL handshake failed'
              self._fatal_error(exc, msg)
              self._wakeup_waiter(exc)
              return

          # The start time is only recorded if debug mode was on when the
          # handshake started; debug may have been enabled since.
          if (self._loop.get_debug() and
                  self._handshake_start_time is not None):
              dt = self._loop.time() - self._handshake_start_time
              logger.debug("%r: SSL handshake took %.1f ms", self, dt * 1e3)

          # Add extra info that becomes available after handshake.
          self._extra.update(peercert=peercert,
                             cipher=sslobj.cipher(),
                             compression=sslobj.compression(),
                             ssl_object=sslobj)
          if self._app_state == AppProtocolState.STATE_INIT:
              self._app_state = AppProtocolState.STATE_CON_MADE
              self._app_protocol.connection_made(self._get_app_transport())
          self._wakeup_waiter()
          self._do_read()

      # Shutdown flow

      def _start_shutdown(self):
          """Begin closing; a no-op if already closing or unwrapped.

          During the handshake the connection is aborted. Otherwise it enters
          FLUSHING, arms the shutdown timeout and flushes.
          """
          if (
              self._state in (
                  SSLProtocolState.FLUSHING,
                  SSLProtocolState.SHUTDOWN,
                  SSLProtocolState.UNWRAPPED
              )
          ):
              return
          if self._app_transport is not None:
              self._app_transport._closed = True
          if self._state == SSLProtocolState.DO_HANDSHAKE:
              self._abort()
          else:
              self._set_state(SSLProtocolState.FLUSHING)
              self._shutdown_timeout_handle = self._loop.call_later(
                  self._ssl_shutdown_timeout,
                  lambda: self._check_shutdown_timeout()
              )
              self._do_flush()

      def _check_shutdown_timeout(self):
          """Force-close the low-level transport if shutdown is still pending."""
          if (
              self._state in (
                  SSLProtocolState.FLUSHING,
                  SSLProtocolState.SHUTDOWN
              )
          ):
              self._transport._force_close(
                  exceptions.TimeoutError('SSL shutdown timed out'))

      def _do_flush(self):
          """Drain pending reads, then move to SHUTDOWN and unwrap."""
          self._do_read()
          self._set_state(SSLProtocolState.SHUTDOWN)
          self._do_shutdown()

      def _do_shutdown(self):
          """Perform the SSL unwrap (skipped after a low-level EOF)."""
          try:
              if not self._eof_received:
                  self._sslobj.unwrap()
          except SSLAgainErrors:
              self._process_outgoing()
          except ssl.SSLError as exc:
              self._on_shutdown_complete(exc)
          else:
              self._process_outgoing()
              self._call_eof_received()
              self._on_shutdown_complete(None)

      def _on_shutdown_complete(self, shutdown_exc):
          """Cancel the shutdown timeout, then fail or close the transport."""
          if self._shutdown_timeout_handle is not None:
              self._shutdown_timeout_handle.cancel()
              self._shutdown_timeout_handle = None

          if shutdown_exc:
              self._fatal_error(shutdown_exc)
          else:
              self._loop.call_soon(self._transport.close)

      def _abort(self, exc=None):
          """Drop the connection immediately, without an SSL shutdown.

          Args:
              exc: Optional cause. When given, it is passed to the low-level
                  transport's _force_close(), as _fatal_error() does. When
                  None, the transport is simply aborted.
          """
          self._set_state(SSLProtocolState.UNWRAPPED)
          if self._transport is not None:
              if exc is None:
                  self._transport.abort()
              else:
                  self._transport._force_close(exc)

      # Outgoing flow

      def _write_appdata(self, list_of_data):
          """Queue application data and encrypt it if the session is up.

          Writes after closing started are dropped. A warning is logged once
          enough of them have accumulated.
          """
          if (
              self._state in (
                  SSLProtocolState.FLUSHING,
                  SSLProtocolState.SHUTDOWN,
                  SSLProtocolState.UNWRAPPED
              )
          ):
              if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES:
                  logger.warning('SSL connection is closed')
              self._conn_lost += 1
              return

          for data in list_of_data:
              self._write_backlog.append(data)
              self._write_buffer_size += len(data)

          try:
              if self._state == SSLProtocolState.WRAPPED:
                  self._do_write()

          except Exception as ex:
              self._fatal_error(ex, 'Fatal error on SSL protocol')

      def _do_write(self):
          """Encrypt as much of the backlog as possible, then flush output."""
          try:
              while self._write_backlog:
                  data = self._write_backlog[0]
                  count = self._sslobj.write(data)
                  data_len = len(data)
                  if count < data_len:
                      # Partial write: keep the unwritten tail at the front.
                      self._write_backlog[0] = data[count:]
                      self._write_buffer_size -= count
                  else:
                      del self._write_backlog[0]
                      self._write_buffer_size -= data_len
          except SSLAgainErrors:
              pass
          self._process_outgoing()

      def _process_outgoing(self):
          """Send encrypted bytes unless the low-level transport is paused."""
          if not self._ssl_writing_paused:
              data = self._outgoing.read()
              if len(data):
                  self._transport.write(data)
          self._control_app_writing()

      # Incoming flow

      def _do_read(self):
          """Deliver decrypted data to the app and service pending writes."""
          if (
              self._state not in (
                  SSLProtocolState.WRAPPED,
                  SSLProtocolState.FLUSHING,
              )
          ):
              return
          try:
              if not self._app_reading_paused:
                  if self._app_protocol_is_buffer:
                      self._do_read__buffered()
                  else:
                      self._do_read__copied()
                  if self._write_backlog:
                      self._do_write()
                  else:
                      self._process_outgoing()
              self._control_ssl_reading()
          except Exception as ex:
              self._fatal_error(ex, 'Fatal error on SSL protocol')

      def _do_read__buffered(self):
          """Read directly into the buffered app protocol's buffer."""
          offset = 0
          count = 1

          buf = self._app_protocol_get_buffer(self._get_read_buffer_size())
          wants = len(buf)

          try:
              count = self._sslobj.read(wants, buf)

              if count > 0:
                  offset = count
                  while offset < wants:
                      count = self._sslobj.read(wants - offset, buf[offset:])
                      if count > 0:
                          offset += count
                      else:
                          break
                  else:
                      # Buffer filled up: more data may be pending, so
                      # schedule another read pass.
                      self._loop.call_soon(lambda: self._do_read())
          except SSLAgainErrors:
              pass
          if offset > 0:
              self._app_protocol_buffer_updated(offset)
          if not count:
              # close_notify
              self._call_eof_received()
              self._start_shutdown()

      def _do_read__copied(self):
          """Read chunks and pass them to data_received() in one call."""
          chunk = b'1'
          zero = True
          one = False

          try:
              while True:
                  chunk = self._sslobj.read(self.max_size)
                  if not chunk:
                      break
                  if zero:
                      zero = False
                      one = True
                      first = chunk
                  elif one:
                      one = False
                      data = [first, chunk]
                  else:
                      data.append(chunk)
          except SSLAgainErrors:
              pass
          if one:
              self._app_protocol.data_received(first)
          elif not zero:
              self._app_protocol.data_received(b''.join(data))
          if not chunk:
              # close_notify
              self._call_eof_received()
              self._start_shutdown()

      def _call_eof_received(self):
          """Call the app protocol's eof_received() at most once."""
          try:
              if self._app_state == AppProtocolState.STATE_CON_MADE:
                  self._app_state = AppProtocolState.STATE_EOF
                  keep_open = self._app_protocol.eof_received()
                  if keep_open:
                      logger.warning('returning true from eof_received() '
                                     'has no effect when using ssl')
          except (KeyboardInterrupt, SystemExit):
              raise
          except BaseException as ex:
              self._fatal_error(ex, 'Error calling eof_received()')

      # Flow control for writes from APP socket

      def _control_app_writing(self):
          """Pause or resume the app protocol based on write buffer size."""
          size = self._get_write_buffer_size()
          if size >= self._outgoing_high_water and not self._app_writing_paused:
              self._app_writing_paused = True
              try:
                  self._app_protocol.pause_writing()
              except (KeyboardInterrupt, SystemExit):
                  raise
              except BaseException as exc:
                  self._loop.call_exception_handler({
                      'message': 'protocol.pause_writing() failed',
                      'exception': exc,
                      'transport': self._app_transport,
                      'protocol': self,
                  })
          elif size <= self._outgoing_low_water and self._app_writing_paused:
              self._app_writing_paused = False
              try:
                  self._app_protocol.resume_writing()
              except (KeyboardInterrupt, SystemExit):
                  raise
              except BaseException as exc:
                  self._loop.call_exception_handler({
                      'message': 'protocol.resume_writing() failed',
                      'exception': exc,
                      'transport': self._app_transport,
                      'protocol': self,
                  })

      def _get_write_buffer_size(self):
          """Return encrypted bytes pending plus unencrypted backlog size."""
          return self._outgoing.pending + self._write_buffer_size

      def _set_write_buffer_limits(self, high=None, low=None):
          """Set the write water marks (defaults from constants)."""
          high, low = add_flowcontrol_defaults(
              high, low, constants.FLOW_CONTROL_HIGH_WATER_SSL_WRITE)
          self._outgoing_high_water = high
          self._outgoing_low_water = low

      # Flow control for reads to APP socket

      def _pause_reading(self):
          """Stop delivering decrypted data to the app protocol."""
          self._app_reading_paused = True

      def _resume_reading(self):
          """Resume delivery, re-driving the state machine on the next tick."""
          if self._app_reading_paused:
              self._app_reading_paused = False

              def resume():
                  if self._state == SSLProtocolState.WRAPPED:
                      self._do_read()
                  elif self._state == SSLProtocolState.FLUSHING:
                      self._do_flush()
                  elif self._state == SSLProtocolState.SHUTDOWN:
                      self._do_shutdown()
              self._loop.call_soon(resume)

      # Flow control for reads from SSL socket

      def _control_ssl_reading(self):
          """Pause or resume the low-level transport based on incoming data."""
          size = self._get_read_buffer_size()
          if size >= self._incoming_high_water and not self._ssl_reading_paused:
              self._ssl_reading_paused = True
              self._transport.pause_reading()
          elif size <= self._incoming_low_water and self._ssl_reading_paused:
              self._ssl_reading_paused = False
              self._transport.resume_reading()

      def _set_read_buffer_limits(self, high=None, low=None):
          """Set the read water marks (defaults from constants)."""
          high, low = add_flowcontrol_defaults(
              high, low, constants.FLOW_CONTROL_HIGH_WATER_SSL_READ)
          self._incoming_high_water = high
          self._incoming_low_water = low

      def _get_read_buffer_size(self):
          """Return the number of received bytes not yet processed."""
          return self._incoming.pending

      # Flow control for writes to SSL socket

      def pause_writing(self):
          """Called when the low-level transport's buffer goes over
          the high-water mark.
          """
          assert not self._ssl_writing_paused
          self._ssl_writing_paused = True

      def resume_writing(self):
          """Called when the low-level transport's buffer drains below
          the low-water mark.
          """
          assert self._ssl_writing_paused
          self._ssl_writing_paused = False
          self._process_outgoing()

      def _fatal_error(self, exc, message='Fatal error on transport'):
          """Force-close the low-level transport and report *exc*.

          OSErrors are only logged in debug mode. CancelledError is not
          reported. Anything else goes to the loop's exception handler.
          """
          if self._transport:
              self._transport._force_close(exc)

          if isinstance(exc, OSError):
              if self._loop.get_debug():
                  logger.debug("%r: %s", self, message, exc_info=True)
          elif not isinstance(exc, exceptions.CancelledError):
              self._loop.call_exception_handler({
                  'message': message,
                  'exception': exc,
                  'transport': self._transport,
                  'protocol': self,
              })