web_ws.py 16 KB


  1. import asyncio
  2. import base64
  3. import binascii
  4. import hashlib
  5. import json
  6. from typing import Any, Iterable, Optional, Tuple
  7. import async_timeout
  8. import attr
  9. from multidict import CIMultiDict
  10. from . import hdrs
  11. from .abc import AbstractStreamWriter
  12. from .helpers import call_later, set_result
  13. from .http import (
  14. WS_CLOSED_MESSAGE,
  15. WS_CLOSING_MESSAGE,
  16. WS_KEY,
  17. WebSocketError,
  18. WebSocketReader,
  19. WebSocketWriter,
  20. WSMessage,
  21. WSMsgType as WSMsgType,
  22. ws_ext_gen,
  23. ws_ext_parse,
  24. )
  25. from .log import ws_logger
  26. from .streams import EofStream, FlowControlDataQueue
  27. from .typedefs import JSONDecoder, JSONEncoder
  28. from .web_exceptions import HTTPBadRequest, HTTPException
  29. from .web_request import BaseRequest
  30. from .web_response import StreamResponse
  31. __all__ = (
  32. "WebSocketResponse",
  33. "WebSocketReady",
  34. "WSMsgType",
  35. )
  36. THRESHOLD_CONNLOST_ACCESS = 5
  37. @attr.s(auto_attribs=True, frozen=True, slots=True)
  38. class WebSocketReady:
  39. ok: bool
  40. protocol: Optional[str]
  41. def __bool__(self) -> bool:
  42. return self.ok
  43. class WebSocketResponse(StreamResponse):
  44. _length_check = False
  45. def __init__(
  46. self,
  47. *,
  48. timeout: float = 10.0,
  49. receive_timeout: Optional[float] = None,
  50. autoclose: bool = True,
  51. autoping: bool = True,
  52. heartbeat: Optional[float] = None,
  53. protocols: Iterable[str] = (),
  54. compress: bool = True,
  55. max_msg_size: int = 4 * 1024 * 1024,
  56. ) -> None:
  57. super().__init__(status=101)
  58. self._protocols = protocols
  59. self._ws_protocol = None # type: Optional[str]
  60. self._writer = None # type: Optional[WebSocketWriter]
  61. self._reader = None # type: Optional[FlowControlDataQueue[WSMessage]]
  62. self._closed = False
  63. self._closing = False
  64. self._conn_lost = 0
  65. self._close_code = None # type: Optional[int]
  66. self._loop = None # type: Optional[asyncio.AbstractEventLoop]
  67. self._waiting = None # type: Optional[asyncio.Future[bool]]
  68. self._exception = None # type: Optional[BaseException]
  69. self._timeout = timeout
  70. self._receive_timeout = receive_timeout
  71. self._autoclose = autoclose
  72. self._autoping = autoping
  73. self._heartbeat = heartbeat
  74. self._heartbeat_cb = None
  75. if heartbeat is not None:
  76. self._pong_heartbeat = heartbeat / 2.0
  77. self._pong_response_cb = None
  78. self._compress = compress
  79. self._max_msg_size = max_msg_size
  80. def _cancel_heartbeat(self) -> None:
  81. if self._pong_response_cb is not None:
  82. self._pong_response_cb.cancel()
  83. self._pong_response_cb = None
  84. if self._heartbeat_cb is not None:
  85. self._heartbeat_cb.cancel()
  86. self._heartbeat_cb = None
  87. def _reset_heartbeat(self) -> None:
  88. self._cancel_heartbeat()
  89. if self._heartbeat is not None:
  90. self._heartbeat_cb = call_later(
  91. self._send_heartbeat, self._heartbeat, self._loop
  92. )
  93. def _send_heartbeat(self) -> None:
  94. if self._heartbeat is not None and not self._closed:
  95. # fire-and-forget a task is not perfect but maybe ok for
  96. # sending ping. Otherwise we need a long-living heartbeat
  97. # task in the class.
  98. self._loop.create_task(self._writer.ping()) # type: ignore
  99. if self._pong_response_cb is not None:
  100. self._pong_response_cb.cancel()
  101. self._pong_response_cb = call_later(
  102. self._pong_not_received, self._pong_heartbeat, self._loop
  103. )
  104. def _pong_not_received(self) -> None:
  105. if self._req is not None and self._req.transport is not None:
  106. self._closed = True
  107. self._close_code = 1006
  108. self._exception = asyncio.TimeoutError()
  109. self._req.transport.close()
  110. async def prepare(self, request: BaseRequest) -> AbstractStreamWriter:
  111. # make pre-check to don't hide it by do_handshake() exceptions
  112. if self._payload_writer is not None:
  113. return self._payload_writer
  114. protocol, writer = self._pre_start(request)
  115. payload_writer = await super().prepare(request)
  116. assert payload_writer is not None
  117. self._post_start(request, protocol, writer)
  118. await payload_writer.drain()
  119. return payload_writer
  120. def _handshake(
  121. self, request: BaseRequest
  122. ) -> Tuple["CIMultiDict[str]", str, bool, bool]:
  123. headers = request.headers
  124. if "websocket" != headers.get(hdrs.UPGRADE, "").lower().strip():
  125. raise HTTPBadRequest(
  126. text=(
  127. "No WebSocket UPGRADE hdr: {}\n Can "
  128. '"Upgrade" only to "WebSocket".'
  129. ).format(headers.get(hdrs.UPGRADE))
  130. )
  131. if "upgrade" not in headers.get(hdrs.CONNECTION, "").lower():
  132. raise HTTPBadRequest(
  133. text="No CONNECTION upgrade hdr: {}".format(
  134. headers.get(hdrs.CONNECTION)
  135. )
  136. )
  137. # find common sub-protocol between client and server
  138. protocol = None
  139. if hdrs.SEC_WEBSOCKET_PROTOCOL in headers:
  140. req_protocols = [
  141. str(proto.strip())
  142. for proto in headers[hdrs.SEC_WEBSOCKET_PROTOCOL].split(",")
  143. ]
  144. for proto in req_protocols:
  145. if proto in self._protocols:
  146. protocol = proto
  147. break
  148. else:
  149. # No overlap found: Return no protocol as per spec
  150. ws_logger.warning(
  151. "Client protocols %r don’t overlap server-known ones %r",
  152. req_protocols,
  153. self._protocols,
  154. )
  155. # check supported version
  156. version = headers.get(hdrs.SEC_WEBSOCKET_VERSION, "")
  157. if version not in ("13", "8", "7"):
  158. raise HTTPBadRequest(text=f"Unsupported version: {version}")
  159. # check client handshake for validity
  160. key = headers.get(hdrs.SEC_WEBSOCKET_KEY)
  161. try:
  162. if not key or len(base64.b64decode(key)) != 16:
  163. raise HTTPBadRequest(text=f"Handshake error: {key!r}")
  164. except binascii.Error:
  165. raise HTTPBadRequest(text=f"Handshake error: {key!r}") from None
  166. accept_val = base64.b64encode(
  167. hashlib.sha1(key.encode() + WS_KEY).digest()
  168. ).decode()
  169. response_headers = CIMultiDict( # type: ignore
  170. {
  171. hdrs.UPGRADE: "websocket", # type: ignore
  172. hdrs.CONNECTION: "upgrade",
  173. hdrs.SEC_WEBSOCKET_ACCEPT: accept_val,
  174. }
  175. )
  176. notakeover = False
  177. compress = 0
  178. if self._compress:
  179. extensions = headers.get(hdrs.SEC_WEBSOCKET_EXTENSIONS)
  180. # Server side always get return with no exception.
  181. # If something happened, just drop compress extension
  182. compress, notakeover = ws_ext_parse(extensions, isserver=True)
  183. if compress:
  184. enabledext = ws_ext_gen(
  185. compress=compress, isserver=True, server_notakeover=notakeover
  186. )
  187. response_headers[hdrs.SEC_WEBSOCKET_EXTENSIONS] = enabledext
  188. if protocol:
  189. response_headers[hdrs.SEC_WEBSOCKET_PROTOCOL] = protocol
  190. return (response_headers, protocol, compress, notakeover) # type: ignore
  191. def _pre_start(self, request: BaseRequest) -> Tuple[str, WebSocketWriter]:
  192. self._loop = request._loop
  193. headers, protocol, compress, notakeover = self._handshake(request)
  194. self.set_status(101)
  195. self.headers.update(headers)
  196. self.force_close()
  197. self._compress = compress
  198. transport = request._protocol.transport
  199. assert transport is not None
  200. writer = WebSocketWriter(
  201. request._protocol, transport, compress=compress, notakeover=notakeover
  202. )
  203. return protocol, writer
  204. def _post_start(
  205. self, request: BaseRequest, protocol: str, writer: WebSocketWriter
  206. ) -> None:
  207. self._ws_protocol = protocol
  208. self._writer = writer
  209. self._reset_heartbeat()
  210. loop = self._loop
  211. assert loop is not None
  212. self._reader = FlowControlDataQueue(request._protocol, 2 ** 16, loop=loop)
  213. request.protocol.set_parser(
  214. WebSocketReader(self._reader, self._max_msg_size, compress=self._compress)
  215. )
  216. # disable HTTP keepalive for WebSocket
  217. request.protocol.keep_alive(False)
  218. def can_prepare(self, request: BaseRequest) -> WebSocketReady:
  219. if self._writer is not None:
  220. raise RuntimeError("Already started")
  221. try:
  222. _, protocol, _, _ = self._handshake(request)
  223. except HTTPException:
  224. return WebSocketReady(False, None)
  225. else:
  226. return WebSocketReady(True, protocol)
  227. @property
  228. def closed(self) -> bool:
  229. return self._closed
  230. @property
  231. def close_code(self) -> Optional[int]:
  232. return self._close_code
  233. @property
  234. def ws_protocol(self) -> Optional[str]:
  235. return self._ws_protocol
  236. @property
  237. def compress(self) -> bool:
  238. return self._compress
  239. def exception(self) -> Optional[BaseException]:
  240. return self._exception
  241. async def ping(self, message: bytes = b"") -> None:
  242. if self._writer is None:
  243. raise RuntimeError("Call .prepare() first")
  244. await self._writer.ping(message)
  245. async def pong(self, message: bytes = b"") -> None:
  246. # unsolicited pong
  247. if self._writer is None:
  248. raise RuntimeError("Call .prepare() first")
  249. await self._writer.pong(message)
  250. async def send_str(self, data: str, compress: Optional[bool] = None) -> None:
  251. if self._writer is None:
  252. raise RuntimeError("Call .prepare() first")
  253. if not isinstance(data, str):
  254. raise TypeError("data argument must be str (%r)" % type(data))
  255. await self._writer.send(data, binary=False, compress=compress)
  256. async def send_bytes(self, data: bytes, compress: Optional[bool] = None) -> None:
  257. if self._writer is None:
  258. raise RuntimeError("Call .prepare() first")
  259. if not isinstance(data, (bytes, bytearray, memoryview)):
  260. raise TypeError("data argument must be byte-ish (%r)" % type(data))
  261. await self._writer.send(data, binary=True, compress=compress)
  262. async def send_json(
  263. self,
  264. data: Any,
  265. compress: Optional[bool] = None,
  266. *,
  267. dumps: JSONEncoder = json.dumps,
  268. ) -> None:
  269. await self.send_str(dumps(data), compress=compress)
  270. async def write_eof(self) -> None: # type: ignore
  271. if self._eof_sent:
  272. return
  273. if self._payload_writer is None:
  274. raise RuntimeError("Response has not been started")
  275. await self.close()
  276. self._eof_sent = True
  277. async def close(self, *, code: int = 1000, message: bytes = b"") -> bool:
  278. if self._writer is None:
  279. raise RuntimeError("Call .prepare() first")
  280. self._cancel_heartbeat()
  281. reader = self._reader
  282. assert reader is not None
  283. # we need to break `receive()` cycle first,
  284. # `close()` may be called from different task
  285. if self._waiting is not None and not self._closed:
  286. reader.feed_data(WS_CLOSING_MESSAGE, 0)
  287. await self._waiting
  288. if not self._closed:
  289. self._closed = True
  290. try:
  291. await self._writer.close(code, message)
  292. writer = self._payload_writer
  293. assert writer is not None
  294. await writer.drain()
  295. except (asyncio.CancelledError, asyncio.TimeoutError):
  296. self._close_code = 1006
  297. raise
  298. except Exception as exc:
  299. self._close_code = 1006
  300. self._exception = exc
  301. return True
  302. if self._closing:
  303. return True
  304. reader = self._reader
  305. assert reader is not None
  306. try:
  307. with async_timeout.timeout(self._timeout, loop=self._loop):
  308. msg = await reader.read()
  309. except asyncio.CancelledError:
  310. self._close_code = 1006
  311. raise
  312. except Exception as exc:
  313. self._close_code = 1006
  314. self._exception = exc
  315. return True
  316. if msg.type == WSMsgType.CLOSE:
  317. self._close_code = msg.data
  318. return True
  319. self._close_code = 1006
  320. self._exception = asyncio.TimeoutError()
  321. return True
  322. else:
  323. return False
    async def receive(self, timeout: Optional[float] = None) -> WSMessage:
        """Wait for and return the next message from the peer.

        Message handling:

        - With ``autoping`` on, PINGs are answered with PONG, and PING/PONG
          messages are not returned.
        - With ``autoclose`` on, a received CLOSE triggers ``close()``; the
          CLOSE message is still returned.
        - EOF returns a CLOSED message.
        - Reader errors return an ERROR message after closing.

        :param timeout: seconds to wait; falls back to ``receive_timeout``
            when falsy.
        :raises RuntimeError: before ``prepare()``, on a concurrent call, or
            once a closed socket has been read ``THRESHOLD_CONNLOST_ACCESS``
            times.
        :raises asyncio.TimeoutError: when the wait times out (close code is
            set to 1006 first).
        """
        if self._reader is None:
            raise RuntimeError("Call .prepare() first")

        loop = self._loop
        assert loop is not None
        while True:
            # Only one receive() may be pending; close() uses self._waiting
            # to wake and wait for it.
            if self._waiting is not None:
                raise RuntimeError("Concurrent call to receive() is not allowed")

            if self._closed:
                # Count reads after close; past the threshold raise instead of
                # returning CLOSED again (presumably to stop callers looping on
                # a dead socket).
                self._conn_lost += 1
                if self._conn_lost >= THRESHOLD_CONNLOST_ACCESS:
                    raise RuntimeError("WebSocket connection is closed.")
                return WS_CLOSED_MESSAGE
            elif self._closing:
                return WS_CLOSING_MESSAGE

            try:
                # close() awaits this future to know receive() has returned.
                self._waiting = loop.create_future()
                try:
                    with async_timeout.timeout(
                        timeout or self._receive_timeout, loop=self._loop
                    ):
                        msg = await self._reader.read()
                    # A received message restarts the heartbeat, which also
                    # cancels any pending pong timeout.
                    self._reset_heartbeat()
                finally:
                    # Release a close() waiting on this call, even on error.
                    waiter = self._waiting
                    set_result(waiter, True)
                    self._waiting = None
            except (asyncio.CancelledError, asyncio.TimeoutError):
                self._close_code = 1006
                raise
            except EofStream:
                # End of stream: record code 1000 and close our side.
                self._close_code = 1000
                await self.close()
                return WSMessage(WSMsgType.CLOSED, None, None)
            except WebSocketError as exc:
                # Protocol error: close with the error's own code and report it.
                self._close_code = exc.code
                await self.close(code=exc.code)
                return WSMessage(WSMsgType.ERROR, exc, None)
            except Exception as exc:
                # Unexpected failure. _closing makes close() return right
                # after sending its frame instead of waiting for the peer.
                self._exception = exc
                self._closing = True
                self._close_code = 1006
                await self.close()
                return WSMessage(WSMsgType.ERROR, exc, None)

            if msg.type == WSMsgType.CLOSE:
                self._closing = True
                self._close_code = msg.data
                if not self._closed and self._autoclose:
                    await self.close()
            elif msg.type == WSMsgType.CLOSING:
                self._closing = True
            elif msg.type == WSMsgType.PING and self._autoping:
                await self.pong(msg.data)
                continue
            elif msg.type == WSMsgType.PONG and self._autoping:
                continue

            return msg
  381. async def receive_str(self, *, timeout: Optional[float] = None) -> str:
  382. msg = await self.receive(timeout)
  383. if msg.type != WSMsgType.TEXT:
  384. raise TypeError(
  385. "Received message {}:{!r} is not WSMsgType.TEXT".format(
  386. msg.type, msg.data
  387. )
  388. )
  389. return msg.data
  390. async def receive_bytes(self, *, timeout: Optional[float] = None) -> bytes:
  391. msg = await self.receive(timeout)
  392. if msg.type != WSMsgType.BINARY:
  393. raise TypeError(f"Received message {msg.type}:{msg.data!r} is not bytes")
  394. return msg.data
  395. async def receive_json(
  396. self, *, loads: JSONDecoder = json.loads, timeout: Optional[float] = None
  397. ) -> Any:
  398. data = await self.receive_str(timeout=timeout)
  399. return loads(data)
  400. async def write(self, data: bytes) -> None:
  401. raise RuntimeError("Cannot call .write() for websocket")
  402. def __aiter__(self) -> "WebSocketResponse":
  403. return self
  404. async def __anext__(self) -> WSMessage:
  405. msg = await self.receive()
  406. if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.CLOSED):
  407. raise StopAsyncIteration
  408. return msg
  409. def _cancel(self, exc: BaseException) -> None:
  410. if self._reader is not None:
  411. self._reader.set_exception(exc)