client.py 43 KB


  1. """HTTP Client for asyncio."""
  2. import asyncio
  3. import base64
  4. import hashlib
  5. import json
  6. import os
  7. import sys
  8. import traceback
  9. import warnings
  10. from types import SimpleNamespace, TracebackType
  11. from typing import (
  12. Any,
  13. Awaitable,
  14. Callable,
  15. Coroutine,
  16. FrozenSet,
  17. Generator,
  18. Generic,
  19. Iterable,
  20. List,
  21. Mapping,
  22. Optional,
  23. Set,
  24. Tuple,
  25. Type,
  26. TypeVar,
  27. Union,
  28. )
  29. import attr
  30. from multidict import CIMultiDict, MultiDict, MultiDictProxy, istr
  31. from yarl import URL
  32. from . import hdrs, http, payload
  33. from .abc import AbstractCookieJar
  34. from .client_exceptions import (
  35. ClientConnectionError as ClientConnectionError,
  36. ClientConnectorCertificateError as ClientConnectorCertificateError,
  37. ClientConnectorError as ClientConnectorError,
  38. ClientConnectorSSLError as ClientConnectorSSLError,
  39. ClientError as ClientError,
  40. ClientHttpProxyError as ClientHttpProxyError,
  41. ClientOSError as ClientOSError,
  42. ClientPayloadError as ClientPayloadError,
  43. ClientProxyConnectionError as ClientProxyConnectionError,
  44. ClientResponseError as ClientResponseError,
  45. ClientSSLError as ClientSSLError,
  46. ContentTypeError as ContentTypeError,
  47. InvalidURL as InvalidURL,
  48. ServerConnectionError as ServerConnectionError,
  49. ServerDisconnectedError as ServerDisconnectedError,
  50. ServerFingerprintMismatch as ServerFingerprintMismatch,
  51. ServerTimeoutError as ServerTimeoutError,
  52. TooManyRedirects as TooManyRedirects,
  53. WSServerHandshakeError as WSServerHandshakeError,
  54. )
  55. from .client_reqrep import (
  56. ClientRequest as ClientRequest,
  57. ClientResponse as ClientResponse,
  58. Fingerprint as Fingerprint,
  59. RequestInfo as RequestInfo,
  60. _merge_ssl_params,
  61. )
  62. from .client_ws import ClientWebSocketResponse as ClientWebSocketResponse
  63. from .connector import (
  64. BaseConnector as BaseConnector,
  65. NamedPipeConnector as NamedPipeConnector,
  66. TCPConnector as TCPConnector,
  67. UnixConnector as UnixConnector,
  68. )
  69. from .cookiejar import CookieJar
  70. from .helpers import (
  71. DEBUG,
  72. PY_36,
  73. BasicAuth,
  74. CeilTimeout,
  75. TimeoutHandle,
  76. get_running_loop,
  77. proxies_from_env,
  78. sentinel,
  79. strip_auth_from_url,
  80. )
  81. from .http import WS_KEY, HttpVersion, WebSocketReader, WebSocketWriter
  82. from .http_websocket import WSHandshakeError, WSMessage, ws_ext_gen, ws_ext_parse
  83. from .streams import FlowControlDataQueue
  84. from .tracing import Trace, TraceConfig
  85. from .typedefs import JSONEncoder, LooseCookies, LooseHeaders, StrOrURL
# Public API of this module: names re-exported from the sibling modules
# (grouped by origin below) plus the session-level entry points defined here.
__all__ = (
    # client_exceptions
    "ClientConnectionError",
    "ClientConnectorCertificateError",
    "ClientConnectorError",
    "ClientConnectorSSLError",
    "ClientError",
    "ClientHttpProxyError",
    "ClientOSError",
    "ClientPayloadError",
    "ClientProxyConnectionError",
    "ClientResponseError",
    "ClientSSLError",
    "ContentTypeError",
    "InvalidURL",
    "ServerConnectionError",
    "ServerDisconnectedError",
    "ServerFingerprintMismatch",
    "ServerTimeoutError",
    "TooManyRedirects",
    "WSServerHandshakeError",
    # client_reqrep
    "ClientRequest",
    "ClientResponse",
    "Fingerprint",
    "RequestInfo",
    # connector
    "BaseConnector",
    "TCPConnector",
    "UnixConnector",
    "NamedPipeConnector",
    # client_ws
    "ClientWebSocketResponse",
    # client
    "ClientSession",
    "ClientTimeout",
    "request",
)

# ``SSLContext`` is only needed for the type annotations in this module
# (e.g. the ``ssl``/``ssl_context`` parameters).  On interpreters built
# without the ``ssl`` module, fall back to ``object`` so those annotations
# still evaluate at import time.
try:
    from ssl import SSLContext
except ImportError:  # pragma: no cover
    SSLContext = object  # type: ignore
  128. @attr.s(auto_attribs=True, frozen=True, slots=True)
  129. class ClientTimeout:
  130. total: Optional[float] = None
  131. connect: Optional[float] = None
  132. sock_read: Optional[float] = None
  133. sock_connect: Optional[float] = None
  134. # pool_queue_timeout: Optional[float] = None
  135. # dns_resolution_timeout: Optional[float] = None
  136. # socket_connect_timeout: Optional[float] = None
  137. # connection_acquiring_timeout: Optional[float] = None
  138. # new_connection_timeout: Optional[float] = None
  139. # http_header_timeout: Optional[float] = None
  140. # response_body_timeout: Optional[float] = None
  141. # to create a timeout specific for a single request, either
  142. # - create a completely new one to overwrite the default
  143. # - or use http://www.attrs.org/en/stable/api.html#attr.evolve
  144. # to overwrite the defaults
  145. # 5 Minute default read timeout
  146. DEFAULT_TIMEOUT = ClientTimeout(total=5 * 60)
  147. _RetType = TypeVar("_RetType")
  148. class ClientSession:
  149. """First-class interface for making HTTP requests."""
  150. ATTRS = frozenset(
  151. [
  152. "_source_traceback",
  153. "_connector",
  154. "requote_redirect_url",
  155. "_loop",
  156. "_cookie_jar",
  157. "_connector_owner",
  158. "_default_auth",
  159. "_version",
  160. "_json_serialize",
  161. "_requote_redirect_url",
  162. "_timeout",
  163. "_raise_for_status",
  164. "_auto_decompress",
  165. "_trust_env",
  166. "_default_headers",
  167. "_skip_auto_headers",
  168. "_request_class",
  169. "_response_class",
  170. "_ws_response_class",
  171. "_trace_configs",
  172. "_read_bufsize",
  173. ]
  174. )
  175. _source_traceback = None
  176. def __init__(
  177. self,
  178. *,
  179. connector: Optional[BaseConnector] = None,
  180. loop: Optional[asyncio.AbstractEventLoop] = None,
  181. cookies: Optional[LooseCookies] = None,
  182. headers: Optional[LooseHeaders] = None,
  183. skip_auto_headers: Optional[Iterable[str]] = None,
  184. auth: Optional[BasicAuth] = None,
  185. json_serialize: JSONEncoder = json.dumps,
  186. request_class: Type[ClientRequest] = ClientRequest,
  187. response_class: Type[ClientResponse] = ClientResponse,
  188. ws_response_class: Type[ClientWebSocketResponse] = ClientWebSocketResponse,
  189. version: HttpVersion = http.HttpVersion11,
  190. cookie_jar: Optional[AbstractCookieJar] = None,
  191. connector_owner: bool = True,
  192. raise_for_status: bool = False,
  193. read_timeout: Union[float, object] = sentinel,
  194. conn_timeout: Optional[float] = None,
  195. timeout: Union[object, ClientTimeout] = sentinel,
  196. auto_decompress: bool = True,
  197. trust_env: bool = False,
  198. requote_redirect_url: bool = True,
  199. trace_configs: Optional[List[TraceConfig]] = None,
  200. read_bufsize: int = 2 ** 16,
  201. ) -> None:
  202. if loop is None:
  203. if connector is not None:
  204. loop = connector._loop
  205. loop = get_running_loop(loop)
  206. if connector is None:
  207. connector = TCPConnector(loop=loop)
  208. if connector._loop is not loop:
  209. raise RuntimeError("Session and connector has to use same event loop")
  210. self._loop = loop
  211. if loop.get_debug():
  212. self._source_traceback = traceback.extract_stack(sys._getframe(1))
  213. if cookie_jar is None:
  214. cookie_jar = CookieJar(loop=loop)
  215. self._cookie_jar = cookie_jar
  216. if cookies is not None:
  217. self._cookie_jar.update_cookies(cookies)
  218. self._connector = connector # type: Optional[BaseConnector]
  219. self._connector_owner = connector_owner
  220. self._default_auth = auth
  221. self._version = version
  222. self._json_serialize = json_serialize
  223. if timeout is sentinel:
  224. self._timeout = DEFAULT_TIMEOUT
  225. if read_timeout is not sentinel:
  226. warnings.warn(
  227. "read_timeout is deprecated, " "use timeout argument instead",
  228. DeprecationWarning,
  229. stacklevel=2,
  230. )
  231. self._timeout = attr.evolve(self._timeout, total=read_timeout)
  232. if conn_timeout is not None:
  233. self._timeout = attr.evolve(self._timeout, connect=conn_timeout)
  234. warnings.warn(
  235. "conn_timeout is deprecated, " "use timeout argument instead",
  236. DeprecationWarning,
  237. stacklevel=2,
  238. )
  239. else:
  240. self._timeout = timeout # type: ignore
  241. if read_timeout is not sentinel:
  242. raise ValueError(
  243. "read_timeout and timeout parameters "
  244. "conflict, please setup "
  245. "timeout.read"
  246. )
  247. if conn_timeout is not None:
  248. raise ValueError(
  249. "conn_timeout and timeout parameters "
  250. "conflict, please setup "
  251. "timeout.connect"
  252. )
  253. self._raise_for_status = raise_for_status
  254. self._auto_decompress = auto_decompress
  255. self._trust_env = trust_env
  256. self._requote_redirect_url = requote_redirect_url
  257. self._read_bufsize = read_bufsize
  258. # Convert to list of tuples
  259. if headers:
  260. real_headers = CIMultiDict(headers) # type: CIMultiDict[str]
  261. else:
  262. real_headers = CIMultiDict()
  263. self._default_headers = real_headers # type: CIMultiDict[str]
  264. if skip_auto_headers is not None:
  265. self._skip_auto_headers = frozenset([istr(i) for i in skip_auto_headers])
  266. else:
  267. self._skip_auto_headers = frozenset()
  268. self._request_class = request_class
  269. self._response_class = response_class
  270. self._ws_response_class = ws_response_class
  271. self._trace_configs = trace_configs or []
  272. for trace_config in self._trace_configs:
  273. trace_config.freeze()
  274. def __init_subclass__(cls: Type["ClientSession"]) -> None:
  275. warnings.warn(
  276. "Inheritance class {} from ClientSession "
  277. "is discouraged".format(cls.__name__),
  278. DeprecationWarning,
  279. stacklevel=2,
  280. )
  281. if DEBUG:
  282. def __setattr__(self, name: str, val: Any) -> None:
  283. if name not in self.ATTRS:
  284. warnings.warn(
  285. "Setting custom ClientSession.{} attribute "
  286. "is discouraged".format(name),
  287. DeprecationWarning,
  288. stacklevel=2,
  289. )
  290. super().__setattr__(name, val)
  291. def __del__(self, _warnings: Any = warnings) -> None:
  292. if not self.closed:
  293. if PY_36:
  294. kwargs = {"source": self}
  295. else:
  296. kwargs = {}
  297. _warnings.warn(
  298. f"Unclosed client session {self!r}", ResourceWarning, **kwargs
  299. )
  300. context = {"client_session": self, "message": "Unclosed client session"}
  301. if self._source_traceback is not None:
  302. context["source_traceback"] = self._source_traceback
  303. self._loop.call_exception_handler(context)
  304. def request(
  305. self, method: str, url: StrOrURL, **kwargs: Any
  306. ) -> "_RequestContextManager":
  307. """Perform HTTP request."""
  308. return _RequestContextManager(self._request(method, url, **kwargs))
    async def _request(
        self,
        method: str,
        str_or_url: StrOrURL,
        *,
        params: Optional[Mapping[str, str]] = None,
        data: Any = None,
        json: Any = None,
        cookies: Optional[LooseCookies] = None,
        headers: Optional[LooseHeaders] = None,
        skip_auto_headers: Optional[Iterable[str]] = None,
        auth: Optional[BasicAuth] = None,
        allow_redirects: bool = True,
        max_redirects: int = 10,
        compress: Optional[str] = None,
        chunked: Optional[bool] = None,
        expect100: bool = False,
        raise_for_status: Optional[bool] = None,
        read_until_eof: bool = True,
        proxy: Optional[StrOrURL] = None,
        proxy_auth: Optional[BasicAuth] = None,
        timeout: Union[ClientTimeout, object] = sentinel,
        verify_ssl: Optional[bool] = None,
        fingerprint: Optional[bytes] = None,
        ssl_context: Optional[SSLContext] = None,
        ssl: Optional[Union[SSLContext, bool, Fingerprint]] = None,
        proxy_headers: Optional[LooseHeaders] = None,
        trace_request_ctx: Optional[SimpleNamespace] = None,
        read_bufsize: Optional[int] = None,
    ) -> ClientResponse:
        """Send ``method`` to ``str_or_url`` and return the final response.

        This is the engine behind ``request()``/``get()``/``post()`` etc.
        In order, it:

        * merges ``headers`` and ``proxy_headers`` with the session defaults;
        * applies auth from the argument, the URL, or the session default;
        * adds session and per-request cookies;
        * picks a proxy, taken from the environment when ``trust_env`` is set
          and none was given;
        * connects under the ``connect`` timeout;
        * sends the request and starts the response.

        Redirects (301/302/303/307/308) are followed while
        ``allow_redirects`` is true.

        ``timeout`` may be a ``ClientTimeout``, any other value (wrapped as
        ``ClientTimeout(total=timeout)``), or omitted to use the session
        timeout.  The ``total`` budget covers the whole call, redirects
        included.

        Raises:
            RuntimeError: the session is closed.
            ValueError: both ``data`` and ``json`` given; ``auth`` combined
                with URL credentials or an explicit Authorization header; a
                redirect to a non-HTTP(S) scheme.
            InvalidURL: ``str_or_url``, ``proxy`` or a redirect target cannot
                be parsed.
            ServerTimeoutError: obtaining a connection exceeded ``connect``.
            ClientOSError: an ``OSError`` while sending or starting the
                response (``ClientError`` subclasses propagate unchanged).
            TooManyRedirects: the number of redirect responses reached a
                non-zero ``max_redirects``.
            Whatever ``resp.raise_for_status()`` raises, when
            ``raise_for_status`` (or the session default) is enabled.
        """

        # NOTE: timeout clamps existing connect and read timeouts.  We cannot
        # set the default to None because we need to detect if the user wants
        # to use the existing timeouts by setting timeout to None.
        # (In the code below, ``sentinel`` means "use the session timeout";
        # every other non-ClientTimeout value becomes ClientTimeout(total=...).)

        if self.closed:
            raise RuntimeError("Session is closed")

        # Collapse the legacy verify_ssl/ssl_context/fingerprint arguments
        # into the single ``ssl`` value.
        ssl = _merge_ssl_params(ssl, verify_ssl, ssl_context, fingerprint)

        if data is not None and json is not None:
            raise ValueError(
                "data and json parameters can not be used at the same time"
            )
        elif json is not None:
            data = payload.JsonPayload(json, dumps=self._json_serialize)

        # Any non-bool, non-None ``chunked`` is the deprecated chunk-size form.
        if not isinstance(chunked, bool) and chunked is not None:
            warnings.warn("Chunk size is deprecated #1615", DeprecationWarning)

        redirects = 0
        history = []
        version = self._version

        # Merge with default headers and transform to CIMultiDict
        headers = self._prepare_headers(headers)
        proxy_headers = self._prepare_headers(proxy_headers)

        try:
            url = URL(str_or_url)
        except ValueError as e:
            raise InvalidURL(str_or_url) from e

        skip_headers = set(self._skip_auto_headers)
        if skip_auto_headers is not None:
            for i in skip_auto_headers:
                skip_headers.add(istr(i))

        if proxy is not None:
            try:
                proxy = URL(proxy)
            except ValueError as e:
                raise InvalidURL(proxy) from e

        if timeout is sentinel:
            real_timeout = self._timeout  # type: ClientTimeout
        else:
            if not isinstance(timeout, ClientTimeout):
                real_timeout = ClientTimeout(total=timeout)  # type: ignore
            else:
                real_timeout = timeout
        # timeout is cumulative for all request operations
        # (request, redirects, responses, data consuming)
        tm = TimeoutHandle(self._loop, real_timeout.total)
        handle = tm.start()

        if read_bufsize is None:
            read_bufsize = self._read_bufsize

        # One Trace per session TraceConfig, each with its own context object.
        traces = [
            Trace(
                self,
                trace_config,
                trace_config.trace_config_ctx(trace_request_ctx=trace_request_ctx),
            )
            for trace_config in self._trace_configs
        ]

        for trace in traces:
            await trace.send_request_start(method, url, headers)

        timer = tm.timer()
        try:
            with timer:
                # Each iteration sends one request; redirects ``continue``,
                # the final response ``break``s out.
                while True:
                    url, auth_from_url = strip_auth_from_url(url)
                    if auth and auth_from_url:
                        raise ValueError(
                            "Cannot combine AUTH argument with "
                            "credentials encoded in URL"
                        )

                    # Precedence: explicit argument, then URL credentials,
                    # then the session default.
                    if auth is None:
                        auth = auth_from_url
                    if auth is None:
                        auth = self._default_auth
                    # It would be confusing if we support explicit
                    # Authorization header with auth argument
                    if (
                        headers is not None
                        and auth is not None
                        and hdrs.AUTHORIZATION in headers
                    ):
                        raise ValueError(
                            "Cannot combine AUTHORIZATION header "
                            "with AUTH argument or credentials "
                            "encoded in URL"
                        )

                    all_cookies = self._cookie_jar.filter_cookies(url)

                    # Per-request cookies go through a throwaway jar so they
                    # are filtered for this URL without touching the session
                    # jar.
                    if cookies is not None:
                        tmp_cookie_jar = CookieJar()
                        tmp_cookie_jar.update_cookies(cookies)
                        req_cookies = tmp_cookie_jar.filter_cookies(url)
                        if req_cookies:
                            all_cookies.load(req_cookies)

                    # Once a proxy is chosen (explicitly or from the
                    # environment) it is kept for any subsequent redirects.
                    if proxy is not None:
                        proxy = URL(proxy)
                    elif self._trust_env:
                        for scheme, proxy_info in proxies_from_env().items():
                            if scheme == url.scheme:
                                proxy = proxy_info.proxy
                                proxy_auth = proxy_info.proxy_auth
                                break

                    req = self._request_class(
                        method,
                        url,
                        params=params,
                        headers=headers,
                        skip_auto_headers=skip_headers,
                        data=data,
                        cookies=all_cookies,
                        auth=auth,
                        version=version,
                        compress=compress,
                        chunked=chunked,
                        expect100=expect100,
                        loop=self._loop,
                        response_class=self._response_class,
                        proxy=proxy,
                        proxy_auth=proxy_auth,
                        timer=timer,
                        session=self,
                        ssl=ssl,
                        proxy_headers=proxy_headers,
                        traces=traces,
                    )

                    # connection timeout
                    try:
                        with CeilTimeout(real_timeout.connect, loop=self._loop):
                            assert self._connector is not None
                            conn = await self._connector.connect(
                                req, traces=traces, timeout=real_timeout
                            )
                    except asyncio.TimeoutError as exc:
                        raise ServerTimeoutError(
                            "Connection timeout " "to host {}".format(url)
                        ) from exc

                    assert conn.transport is not None

                    assert conn.protocol is not None
                    conn.protocol.set_response_params(
                        timer=timer,
                        skip_payload=method.upper() == "HEAD",
                        read_until_eof=read_until_eof,
                        auto_decompress=self._auto_decompress,
                        read_timeout=real_timeout.sock_read,
                        read_bufsize=read_bufsize,
                    )

                    # On any failure close the response (if created) and the
                    # connection; then translate bare OSErrors into
                    # ClientOSError while letting ClientErrors pass through.
                    try:
                        try:
                            resp = await req.send(conn)
                            try:
                                await resp.start(conn)
                            except BaseException:
                                resp.close()
                                raise
                        except BaseException:
                            conn.close()
                            raise
                    except ClientError:
                        raise
                    except OSError as exc:
                        raise ClientOSError(*exc.args) from exc

                    self._cookie_jar.update_cookies(resp.cookies, resp.url)

                    # redirects
                    if resp.status in (301, 302, 303, 307, 308) and allow_redirects:

                        for trace in traces:
                            await trace.send_request_redirect(
                                method, url, headers, resp
                            )

                        redirects += 1
                        history.append(resp)
                        # max_redirects == 0 disables the limit.
                        if max_redirects and redirects >= max_redirects:
                            resp.close()
                            raise TooManyRedirects(
                                history[0].request_info, tuple(history)
                            )

                        # For 301 and 302, mimic IE, now changed in RFC
                        # https://github.com/kennethreitz/requests/pull/269
                        if (resp.status == 303 and resp.method != hdrs.METH_HEAD) or (
                            resp.status in (301, 302) and resp.method == hdrs.METH_POST
                        ):
                            method = hdrs.METH_GET
                            data = None
                            if headers.get(hdrs.CONTENT_LENGTH):
                                headers.pop(hdrs.CONTENT_LENGTH)

                        r_url = resp.headers.get(hdrs.LOCATION) or resp.headers.get(
                            hdrs.URI
                        )
                        if r_url is None:
                            # see github.com/aio-libs/aiohttp/issues/2022
                            break
                        else:
                            # reading from correct redirection
                            # response is forbidden
                            resp.release()

                        try:
                            parsed_url = URL(
                                r_url, encoded=not self._requote_redirect_url
                            )

                        except ValueError as e:
                            raise InvalidURL(r_url) from e

                        scheme = parsed_url.scheme
                        if scheme not in ("http", "https", ""):
                            resp.close()
                            raise ValueError("Can redirect only to http or https")
                        elif not scheme:
                            # Relative Location: resolve against current URL.
                            parsed_url = url.join(parsed_url)

                        # Do not leak credentials to a different origin.
                        if url.origin() != parsed_url.origin():
                            auth = None
                            headers.pop(hdrs.AUTHORIZATION, None)

                        url = parsed_url
                        # Query params were already applied to the first URL.
                        params = None
                        resp.release()
                        continue

                    break

            # check response status
            if raise_for_status is None:
                raise_for_status = self._raise_for_status
            if raise_for_status:
                resp.raise_for_status()

            # register connection
            # Hand the total-timeout handle's cancellation to the connection's
            # callbacks (presumably run on release), or cancel it right away
            # when there is no connection left.
            if handle is not None:
                if resp.connection is not None:
                    resp.connection.add_callback(handle.cancel)
                else:
                    handle.cancel()

            resp._history = tuple(history)

            for trace in traces:
                await trace.send_request_end(method, url, headers, resp)
            return resp

        except BaseException as e:
            # cleanup timer
            tm.close()
            if handle:
                handle.cancel()
                handle = None

            for trace in traces:
                await trace.send_request_exception(method, url, headers, e)
            raise
  572. def ws_connect(
  573. self,
  574. url: StrOrURL,
  575. *,
  576. method: str = hdrs.METH_GET,
  577. protocols: Iterable[str] = (),
  578. timeout: float = 10.0,
  579. receive_timeout: Optional[float] = None,
  580. autoclose: bool = True,
  581. autoping: bool = True,
  582. heartbeat: Optional[float] = None,
  583. auth: Optional[BasicAuth] = None,
  584. origin: Optional[str] = None,
  585. headers: Optional[LooseHeaders] = None,
  586. proxy: Optional[StrOrURL] = None,
  587. proxy_auth: Optional[BasicAuth] = None,
  588. ssl: Union[SSLContext, bool, None, Fingerprint] = None,
  589. verify_ssl: Optional[bool] = None,
  590. fingerprint: Optional[bytes] = None,
  591. ssl_context: Optional[SSLContext] = None,
  592. proxy_headers: Optional[LooseHeaders] = None,
  593. compress: int = 0,
  594. max_msg_size: int = 4 * 1024 * 1024,
  595. ) -> "_WSRequestContextManager":
  596. """Initiate websocket connection."""
  597. return _WSRequestContextManager(
  598. self._ws_connect(
  599. url,
  600. method=method,
  601. protocols=protocols,
  602. timeout=timeout,
  603. receive_timeout=receive_timeout,
  604. autoclose=autoclose,
  605. autoping=autoping,
  606. heartbeat=heartbeat,
  607. auth=auth,
  608. origin=origin,
  609. headers=headers,
  610. proxy=proxy,
  611. proxy_auth=proxy_auth,
  612. ssl=ssl,
  613. verify_ssl=verify_ssl,
  614. fingerprint=fingerprint,
  615. ssl_context=ssl_context,
  616. proxy_headers=proxy_headers,
  617. compress=compress,
  618. max_msg_size=max_msg_size,
  619. )
  620. )
  621. async def _ws_connect(
  622. self,
  623. url: StrOrURL,
  624. *,
  625. method: str = hdrs.METH_GET,
  626. protocols: Iterable[str] = (),
  627. timeout: float = 10.0,
  628. receive_timeout: Optional[float] = None,
  629. autoclose: bool = True,
  630. autoping: bool = True,
  631. heartbeat: Optional[float] = None,
  632. auth: Optional[BasicAuth] = None,
  633. origin: Optional[str] = None,
  634. headers: Optional[LooseHeaders] = None,
  635. proxy: Optional[StrOrURL] = None,
  636. proxy_auth: Optional[BasicAuth] = None,
  637. ssl: Union[SSLContext, bool, None, Fingerprint] = None,
  638. verify_ssl: Optional[bool] = None,
  639. fingerprint: Optional[bytes] = None,
  640. ssl_context: Optional[SSLContext] = None,
  641. proxy_headers: Optional[LooseHeaders] = None,
  642. compress: int = 0,
  643. max_msg_size: int = 4 * 1024 * 1024,
  644. ) -> ClientWebSocketResponse:
  645. if headers is None:
  646. real_headers = CIMultiDict() # type: CIMultiDict[str]
  647. else:
  648. real_headers = CIMultiDict(headers)
  649. default_headers = {
  650. hdrs.UPGRADE: "websocket",
  651. hdrs.CONNECTION: "upgrade",
  652. hdrs.SEC_WEBSOCKET_VERSION: "13",
  653. }
  654. for key, value in default_headers.items():
  655. real_headers.setdefault(key, value)
  656. sec_key = base64.b64encode(os.urandom(16))
  657. real_headers[hdrs.SEC_WEBSOCKET_KEY] = sec_key.decode()
  658. if protocols:
  659. real_headers[hdrs.SEC_WEBSOCKET_PROTOCOL] = ",".join(protocols)
  660. if origin is not None:
  661. real_headers[hdrs.ORIGIN] = origin
  662. if compress:
  663. extstr = ws_ext_gen(compress=compress)
  664. real_headers[hdrs.SEC_WEBSOCKET_EXTENSIONS] = extstr
  665. ssl = _merge_ssl_params(ssl, verify_ssl, ssl_context, fingerprint)
  666. # send request
  667. resp = await self.request(
  668. method,
  669. url,
  670. headers=real_headers,
  671. read_until_eof=False,
  672. auth=auth,
  673. proxy=proxy,
  674. proxy_auth=proxy_auth,
  675. ssl=ssl,
  676. proxy_headers=proxy_headers,
  677. )
  678. try:
  679. # check handshake
  680. if resp.status != 101:
  681. raise WSServerHandshakeError(
  682. resp.request_info,
  683. resp.history,
  684. message="Invalid response status",
  685. status=resp.status,
  686. headers=resp.headers,
  687. )
  688. if resp.headers.get(hdrs.UPGRADE, "").lower() != "websocket":
  689. raise WSServerHandshakeError(
  690. resp.request_info,
  691. resp.history,
  692. message="Invalid upgrade header",
  693. status=resp.status,
  694. headers=resp.headers,
  695. )
  696. if resp.headers.get(hdrs.CONNECTION, "").lower() != "upgrade":
  697. raise WSServerHandshakeError(
  698. resp.request_info,
  699. resp.history,
  700. message="Invalid connection header",
  701. status=resp.status,
  702. headers=resp.headers,
  703. )
  704. # key calculation
  705. r_key = resp.headers.get(hdrs.SEC_WEBSOCKET_ACCEPT, "")
  706. match = base64.b64encode(hashlib.sha1(sec_key + WS_KEY).digest()).decode()
  707. if r_key != match:
  708. raise WSServerHandshakeError(
  709. resp.request_info,
  710. resp.history,
  711. message="Invalid challenge response",
  712. status=resp.status,
  713. headers=resp.headers,
  714. )
  715. # websocket protocol
  716. protocol = None
  717. if protocols and hdrs.SEC_WEBSOCKET_PROTOCOL in resp.headers:
  718. resp_protocols = [
  719. proto.strip()
  720. for proto in resp.headers[hdrs.SEC_WEBSOCKET_PROTOCOL].split(",")
  721. ]
  722. for proto in resp_protocols:
  723. if proto in protocols:
  724. protocol = proto
  725. break
  726. # websocket compress
  727. notakeover = False
  728. if compress:
  729. compress_hdrs = resp.headers.get(hdrs.SEC_WEBSOCKET_EXTENSIONS)
  730. if compress_hdrs:
  731. try:
  732. compress, notakeover = ws_ext_parse(compress_hdrs)
  733. except WSHandshakeError as exc:
  734. raise WSServerHandshakeError(
  735. resp.request_info,
  736. resp.history,
  737. message=exc.args[0],
  738. status=resp.status,
  739. headers=resp.headers,
  740. ) from exc
  741. else:
  742. compress = 0
  743. notakeover = False
  744. conn = resp.connection
  745. assert conn is not None
  746. conn_proto = conn.protocol
  747. assert conn_proto is not None
  748. transport = conn.transport
  749. assert transport is not None
  750. reader = FlowControlDataQueue(
  751. conn_proto, 2 ** 16, loop=self._loop
  752. ) # type: FlowControlDataQueue[WSMessage]
  753. conn_proto.set_parser(WebSocketReader(reader, max_msg_size), reader)
  754. writer = WebSocketWriter(
  755. conn_proto,
  756. transport,
  757. use_mask=True,
  758. compress=compress,
  759. notakeover=notakeover,
  760. )
  761. except BaseException:
  762. resp.close()
  763. raise
  764. else:
  765. return self._ws_response_class(
  766. reader,
  767. writer,
  768. protocol,
  769. resp,
  770. timeout,
  771. autoclose,
  772. autoping,
  773. self._loop,
  774. receive_timeout=receive_timeout,
  775. heartbeat=heartbeat,
  776. compress=compress,
  777. client_notakeover=notakeover,
  778. )
  779. def _prepare_headers(self, headers: Optional[LooseHeaders]) -> "CIMultiDict[str]":
  780. """Add default headers and transform it to CIMultiDict"""
  781. # Convert headers to MultiDict
  782. result = CIMultiDict(self._default_headers)
  783. if headers:
  784. if not isinstance(headers, (MultiDictProxy, MultiDict)):
  785. headers = CIMultiDict(headers)
  786. added_names = set() # type: Set[str]
  787. for key, value in headers.items():
  788. if key in added_names:
  789. result.add(key, value)
  790. else:
  791. result[key] = value
  792. added_names.add(key)
  793. return result
  794. def get(
  795. self, url: StrOrURL, *, allow_redirects: bool = True, **kwargs: Any
  796. ) -> "_RequestContextManager":
  797. """Perform HTTP GET request."""
  798. return _RequestContextManager(
  799. self._request(hdrs.METH_GET, url, allow_redirects=allow_redirects, **kwargs)
  800. )
  801. def options(
  802. self, url: StrOrURL, *, allow_redirects: bool = True, **kwargs: Any
  803. ) -> "_RequestContextManager":
  804. """Perform HTTP OPTIONS request."""
  805. return _RequestContextManager(
  806. self._request(
  807. hdrs.METH_OPTIONS, url, allow_redirects=allow_redirects, **kwargs
  808. )
  809. )
  810. def head(
  811. self, url: StrOrURL, *, allow_redirects: bool = False, **kwargs: Any
  812. ) -> "_RequestContextManager":
  813. """Perform HTTP HEAD request."""
  814. return _RequestContextManager(
  815. self._request(
  816. hdrs.METH_HEAD, url, allow_redirects=allow_redirects, **kwargs
  817. )
  818. )
  819. def post(
  820. self, url: StrOrURL, *, data: Any = None, **kwargs: Any
  821. ) -> "_RequestContextManager":
  822. """Perform HTTP POST request."""
  823. return _RequestContextManager(
  824. self._request(hdrs.METH_POST, url, data=data, **kwargs)
  825. )
  826. def put(
  827. self, url: StrOrURL, *, data: Any = None, **kwargs: Any
  828. ) -> "_RequestContextManager":
  829. """Perform HTTP PUT request."""
  830. return _RequestContextManager(
  831. self._request(hdrs.METH_PUT, url, data=data, **kwargs)
  832. )
  833. def patch(
  834. self, url: StrOrURL, *, data: Any = None, **kwargs: Any
  835. ) -> "_RequestContextManager":
  836. """Perform HTTP PATCH request."""
  837. return _RequestContextManager(
  838. self._request(hdrs.METH_PATCH, url, data=data, **kwargs)
  839. )
  840. def delete(self, url: StrOrURL, **kwargs: Any) -> "_RequestContextManager":
  841. """Perform HTTP DELETE request."""
  842. return _RequestContextManager(self._request(hdrs.METH_DELETE, url, **kwargs))
  843. async def close(self) -> None:
  844. """Close underlying connector.
  845. Release all acquired resources.
  846. """
  847. if not self.closed:
  848. if self._connector is not None and self._connector_owner:
  849. await self._connector.close()
  850. self._connector = None
  851. @property
  852. def closed(self) -> bool:
  853. """Is client session closed.
  854. A readonly property.
  855. """
  856. return self._connector is None or self._connector.closed
  857. @property
  858. def connector(self) -> Optional[BaseConnector]:
  859. """Connector instance used for the session."""
  860. return self._connector
  861. @property
  862. def cookie_jar(self) -> AbstractCookieJar:
  863. """The session cookies."""
  864. return self._cookie_jar
  865. @property
  866. def version(self) -> Tuple[int, int]:
  867. """The session HTTP protocol version."""
  868. return self._version
  869. @property
  870. def requote_redirect_url(self) -> bool:
  871. """Do URL requoting on redirection handling."""
  872. return self._requote_redirect_url
  873. @requote_redirect_url.setter
  874. def requote_redirect_url(self, val: bool) -> None:
  875. """Do URL requoting on redirection handling."""
  876. warnings.warn(
  877. "session.requote_redirect_url modification " "is deprecated #2778",
  878. DeprecationWarning,
  879. stacklevel=2,
  880. )
  881. self._requote_redirect_url = val
  882. @property
  883. def loop(self) -> asyncio.AbstractEventLoop:
  884. """Session's loop."""
  885. warnings.warn(
  886. "client.loop property is deprecated", DeprecationWarning, stacklevel=2
  887. )
  888. return self._loop
  889. @property
  890. def timeout(self) -> Union[object, ClientTimeout]:
  891. """Timeout for the session."""
  892. return self._timeout
  893. @property
  894. def headers(self) -> "CIMultiDict[str]":
  895. """The default headers of the client session."""
  896. return self._default_headers
  897. @property
  898. def skip_auto_headers(self) -> FrozenSet[istr]:
  899. """Headers for which autogeneration should be skipped"""
  900. return self._skip_auto_headers
  901. @property
  902. def auth(self) -> Optional[BasicAuth]:
  903. """An object that represents HTTP Basic Authorization"""
  904. return self._default_auth
  905. @property
  906. def json_serialize(self) -> JSONEncoder:
  907. """Json serializer callable"""
  908. return self._json_serialize
  909. @property
  910. def connector_owner(self) -> bool:
  911. """Should connector be closed on session closing"""
  912. return self._connector_owner
  913. @property
  914. def raise_for_status(
  915. self,
  916. ) -> Union[bool, Callable[[ClientResponse], Awaitable[None]]]:
  917. """
  918. Should `ClientResponse.raise_for_status()`
  919. be called for each response
  920. """
  921. return self._raise_for_status
  922. @property
  923. def auto_decompress(self) -> bool:
  924. """Should the body response be automatically decompressed"""
  925. return self._auto_decompress
  926. @property
  927. def trust_env(self) -> bool:
  928. """
  929. Should get proxies information
  930. from HTTP_PROXY / HTTPS_PROXY environment variables
  931. or ~/.netrc file if present
  932. """
  933. return self._trust_env
  934. @property
  935. def trace_configs(self) -> List[TraceConfig]:
  936. """A list of TraceConfig instances used for client tracing"""
  937. return self._trace_configs
  938. def detach(self) -> None:
  939. """Detach connector from session without closing the former.
  940. Session is switched to closed state anyway.
  941. """
  942. self._connector = None
  943. def __enter__(self) -> None:
  944. raise TypeError("Use async with instead")
  945. def __exit__(
  946. self,
  947. exc_type: Optional[Type[BaseException]],
  948. exc_val: Optional[BaseException],
  949. exc_tb: Optional[TracebackType],
  950. ) -> None:
  951. # __exit__ should exist in pair with __enter__ but never executed
  952. pass # pragma: no cover
  953. async def __aenter__(self) -> "ClientSession":
  954. return self
  955. async def __aexit__(
  956. self,
  957. exc_type: Optional[Type[BaseException]],
  958. exc_val: Optional[BaseException],
  959. exc_tb: Optional[TracebackType],
  960. ) -> None:
  961. await self.close()
  962. class _BaseRequestContextManager(Coroutine[Any, Any, _RetType], Generic[_RetType]):
  963. __slots__ = ("_coro", "_resp")
  964. def __init__(self, coro: Coroutine["asyncio.Future[Any]", None, _RetType]) -> None:
  965. self._coro = coro
  966. def send(self, arg: None) -> "asyncio.Future[Any]":
  967. return self._coro.send(arg)
  968. def throw(self, arg: BaseException) -> None: # type: ignore
  969. self._coro.throw(arg)
  970. def close(self) -> None:
  971. return self._coro.close()
  972. def __await__(self) -> Generator[Any, None, _RetType]:
  973. ret = self._coro.__await__()
  974. return ret
  975. def __iter__(self) -> Generator[Any, None, _RetType]:
  976. return self.__await__()
  977. async def __aenter__(self) -> _RetType:
  978. self._resp = await self._coro
  979. return self._resp
  980. class _RequestContextManager(_BaseRequestContextManager[ClientResponse]):
  981. async def __aexit__(
  982. self,
  983. exc_type: Optional[Type[BaseException]],
  984. exc: Optional[BaseException],
  985. tb: Optional[TracebackType],
  986. ) -> None:
  987. # We're basing behavior on the exception as it can be caused by
  988. # user code unrelated to the status of the connection. If you
  989. # would like to close a connection you must do that
  990. # explicitly. Otherwise connection error handling should kick in
  991. # and close/recycle the connection as required.
  992. self._resp.release()
  993. class _WSRequestContextManager(_BaseRequestContextManager[ClientWebSocketResponse]):
  994. async def __aexit__(
  995. self,
  996. exc_type: Optional[Type[BaseException]],
  997. exc: Optional[BaseException],
  998. tb: Optional[TracebackType],
  999. ) -> None:
  1000. await self._resp.close()
  1001. class _SessionRequestContextManager:
  1002. __slots__ = ("_coro", "_resp", "_session")
  1003. def __init__(
  1004. self,
  1005. coro: Coroutine["asyncio.Future[Any]", None, ClientResponse],
  1006. session: ClientSession,
  1007. ) -> None:
  1008. self._coro = coro
  1009. self._resp = None # type: Optional[ClientResponse]
  1010. self._session = session
  1011. async def __aenter__(self) -> ClientResponse:
  1012. try:
  1013. self._resp = await self._coro
  1014. except BaseException:
  1015. await self._session.close()
  1016. raise
  1017. else:
  1018. return self._resp
  1019. async def __aexit__(
  1020. self,
  1021. exc_type: Optional[Type[BaseException]],
  1022. exc: Optional[BaseException],
  1023. tb: Optional[TracebackType],
  1024. ) -> None:
  1025. assert self._resp is not None
  1026. self._resp.close()
  1027. await self._session.close()
  1028. def request(
  1029. method: str,
  1030. url: StrOrURL,
  1031. *,
  1032. params: Optional[Mapping[str, str]] = None,
  1033. data: Any = None,
  1034. json: Any = None,
  1035. headers: Optional[LooseHeaders] = None,
  1036. skip_auto_headers: Optional[Iterable[str]] = None,
  1037. auth: Optional[BasicAuth] = None,
  1038. allow_redirects: bool = True,
  1039. max_redirects: int = 10,
  1040. compress: Optional[str] = None,
  1041. chunked: Optional[bool] = None,
  1042. expect100: bool = False,
  1043. raise_for_status: Optional[bool] = None,
  1044. read_until_eof: bool = True,
  1045. proxy: Optional[StrOrURL] = None,
  1046. proxy_auth: Optional[BasicAuth] = None,
  1047. timeout: Union[ClientTimeout, object] = sentinel,
  1048. cookies: Optional[LooseCookies] = None,
  1049. version: HttpVersion = http.HttpVersion11,
  1050. connector: Optional[BaseConnector] = None,
  1051. read_bufsize: Optional[int] = None,
  1052. loop: Optional[asyncio.AbstractEventLoop] = None,
  1053. ) -> _SessionRequestContextManager:
  1054. """Constructs and sends a request. Returns response object.
  1055. method - HTTP method
  1056. url - request url
  1057. params - (optional) Dictionary or bytes to be sent in the query
  1058. string of the new request
  1059. data - (optional) Dictionary, bytes, or file-like object to
  1060. send in the body of the request
  1061. json - (optional) Any json compatible python object
  1062. headers - (optional) Dictionary of HTTP Headers to send with
  1063. the request
  1064. cookies - (optional) Dict object to send with the request
  1065. auth - (optional) BasicAuth named tuple represent HTTP Basic Auth
  1066. auth - aiohttp.helpers.BasicAuth
  1067. allow_redirects - (optional) If set to False, do not follow
  1068. redirects
  1069. version - Request HTTP version.
  1070. compress - Set to True if request has to be compressed
  1071. with deflate encoding.
  1072. chunked - Set to chunk size for chunked transfer encoding.
  1073. expect100 - Expect 100-continue response from server.
  1074. connector - BaseConnector sub-class instance to support
  1075. connection pooling.
  1076. read_until_eof - Read response until eof if response
  1077. does not have Content-Length header.
  1078. loop - Optional event loop.
  1079. timeout - Optional ClientTimeout settings structure, 5min
  1080. total timeout by default.
  1081. Usage::
  1082. >>> import aiohttp
  1083. >>> resp = await aiohttp.request('GET', 'http://python.org/')
  1084. >>> resp
  1085. <ClientResponse(python.org/) [200]>
  1086. >>> data = await resp.read()
  1087. """
  1088. connector_owner = False
  1089. if connector is None:
  1090. connector_owner = True
  1091. connector = TCPConnector(loop=loop, force_close=True)
  1092. session = ClientSession(
  1093. loop=loop,
  1094. cookies=cookies,
  1095. version=version,
  1096. timeout=timeout,
  1097. connector=connector,
  1098. connector_owner=connector_owner,
  1099. )
  1100. return _SessionRequestContextManager(
  1101. session._request(
  1102. method,
  1103. url,
  1104. params=params,
  1105. data=data,
  1106. json=json,
  1107. headers=headers,
  1108. skip_auto_headers=skip_auto_headers,
  1109. auth=auth,
  1110. allow_redirects=allow_redirects,
  1111. max_redirects=max_redirects,
  1112. compress=compress,
  1113. chunked=chunked,
  1114. expect100=expect100,
  1115. raise_for_status=raise_for_status,
  1116. read_until_eof=read_until_eof,
  1117. proxy=proxy,
  1118. proxy_auth=proxy_auth,
  1119. read_bufsize=read_bufsize,
  1120. ),
  1121. session,
  1122. )