# cookiejar.py (12 KB)

# [source-viewer line-number gutter (lines 1-382) omitted]
  1. import asyncio
  2. import datetime
  3. import os # noqa
  4. import pathlib
  5. import pickle
  6. import re
  7. from collections import defaultdict
  8. from http.cookies import BaseCookie, Morsel, SimpleCookie
  9. from typing import ( # noqa
  10. DefaultDict,
  11. Dict,
  12. Iterable,
  13. Iterator,
  14. Mapping,
  15. Optional,
  16. Set,
  17. Tuple,
  18. Union,
  19. cast,
  20. )
  21. from yarl import URL
  22. from .abc import AbstractCookieJar
  23. from .helpers import is_ip_address, next_whole_second
  24. from .typedefs import LooseCookies, PathLike
  25. __all__ = ("CookieJar", "DummyCookieJar")
  26. CookieItem = Union[str, "Morsel[str]"]
  27. class CookieJar(AbstractCookieJar):
  28. """Implements cookie storage adhering to RFC 6265."""
  29. DATE_TOKENS_RE = re.compile(
  30. r"[\x09\x20-\x2F\x3B-\x40\x5B-\x60\x7B-\x7E]*"
  31. r"(?P<token>[\x00-\x08\x0A-\x1F\d:a-zA-Z\x7F-\xFF]+)"
  32. )
  33. DATE_HMS_TIME_RE = re.compile(r"(\d{1,2}):(\d{1,2}):(\d{1,2})")
  34. DATE_DAY_OF_MONTH_RE = re.compile(r"(\d{1,2})")
  35. DATE_MONTH_RE = re.compile(
  36. "(jan)|(feb)|(mar)|(apr)|(may)|(jun)|(jul)|" "(aug)|(sep)|(oct)|(nov)|(dec)",
  37. re.I,
  38. )
  39. DATE_YEAR_RE = re.compile(r"(\d{2,4})")
  40. MAX_TIME = datetime.datetime.max.replace(tzinfo=datetime.timezone.utc)
  41. MAX_32BIT_TIME = datetime.datetime.utcfromtimestamp(2 ** 31 - 1)
  42. def __init__(
  43. self,
  44. *,
  45. unsafe: bool = False,
  46. quote_cookie: bool = True,
  47. loop: Optional[asyncio.AbstractEventLoop] = None
  48. ) -> None:
  49. super().__init__(loop=loop)
  50. self._cookies = defaultdict(
  51. SimpleCookie
  52. ) # type: DefaultDict[str, SimpleCookie[str]]
  53. self._host_only_cookies = set() # type: Set[Tuple[str, str]]
  54. self._unsafe = unsafe
  55. self._quote_cookie = quote_cookie
  56. self._next_expiration = next_whole_second()
  57. self._expirations = {} # type: Dict[Tuple[str, str], datetime.datetime]
  58. # #4515: datetime.max may not be representable on 32-bit platforms
  59. self._max_time = self.MAX_TIME
  60. try:
  61. self._max_time.timestamp()
  62. except OverflowError:
  63. self._max_time = self.MAX_32BIT_TIME
  64. def save(self, file_path: PathLike) -> None:
  65. file_path = pathlib.Path(file_path)
  66. with file_path.open(mode="wb") as f:
  67. pickle.dump(self._cookies, f, pickle.HIGHEST_PROTOCOL)
  68. def load(self, file_path: PathLike) -> None:
  69. file_path = pathlib.Path(file_path)
  70. with file_path.open(mode="rb") as f:
  71. self._cookies = pickle.load(f)
  72. def clear(self) -> None:
  73. self._cookies.clear()
  74. self._host_only_cookies.clear()
  75. self._next_expiration = next_whole_second()
  76. self._expirations.clear()
  77. def __iter__(self) -> "Iterator[Morsel[str]]":
  78. self._do_expiration()
  79. for val in self._cookies.values():
  80. yield from val.values()
  81. def __len__(self) -> int:
  82. return sum(1 for i in self)
  83. def _do_expiration(self) -> None:
  84. now = datetime.datetime.now(datetime.timezone.utc)
  85. if self._next_expiration > now:
  86. return
  87. if not self._expirations:
  88. return
  89. next_expiration = self._max_time
  90. to_del = []
  91. cookies = self._cookies
  92. expirations = self._expirations
  93. for (domain, name), when in expirations.items():
  94. if when <= now:
  95. cookies[domain].pop(name, None)
  96. to_del.append((domain, name))
  97. self._host_only_cookies.discard((domain, name))
  98. else:
  99. next_expiration = min(next_expiration, when)
  100. for key in to_del:
  101. del expirations[key]
  102. try:
  103. self._next_expiration = next_expiration.replace(
  104. microsecond=0
  105. ) + datetime.timedelta(seconds=1)
  106. except OverflowError:
  107. self._next_expiration = self._max_time
  108. def _expire_cookie(self, when: datetime.datetime, domain: str, name: str) -> None:
  109. self._next_expiration = min(self._next_expiration, when)
  110. self._expirations[(domain, name)] = when
  111. def update_cookies(self, cookies: LooseCookies, response_url: URL = URL()) -> None:
  112. """Update cookies."""
  113. hostname = response_url.raw_host
  114. if not self._unsafe and is_ip_address(hostname):
  115. # Don't accept cookies from IPs
  116. return
  117. if isinstance(cookies, Mapping):
  118. cookies = cookies.items()
  119. for name, cookie in cookies:
  120. if not isinstance(cookie, Morsel):
  121. tmp = SimpleCookie() # type: SimpleCookie[str]
  122. tmp[name] = cookie # type: ignore
  123. cookie = tmp[name]
  124. domain = cookie["domain"]
  125. # ignore domains with trailing dots
  126. if domain.endswith("."):
  127. domain = ""
  128. del cookie["domain"]
  129. if not domain and hostname is not None:
  130. # Set the cookie's domain to the response hostname
  131. # and set its host-only-flag
  132. self._host_only_cookies.add((hostname, name))
  133. domain = cookie["domain"] = hostname
  134. if domain.startswith("."):
  135. # Remove leading dot
  136. domain = domain[1:]
  137. cookie["domain"] = domain
  138. if hostname and not self._is_domain_match(domain, hostname):
  139. # Setting cookies for different domains is not allowed
  140. continue
  141. path = cookie["path"]
  142. if not path or not path.startswith("/"):
  143. # Set the cookie's path to the response path
  144. path = response_url.path
  145. if not path.startswith("/"):
  146. path = "/"
  147. else:
  148. # Cut everything from the last slash to the end
  149. path = "/" + path[1 : path.rfind("/")]
  150. cookie["path"] = path
  151. max_age = cookie["max-age"]
  152. if max_age:
  153. try:
  154. delta_seconds = int(max_age)
  155. try:
  156. max_age_expiration = datetime.datetime.now(
  157. datetime.timezone.utc
  158. ) + datetime.timedelta(seconds=delta_seconds)
  159. except OverflowError:
  160. max_age_expiration = self._max_time
  161. self._expire_cookie(max_age_expiration, domain, name)
  162. except ValueError:
  163. cookie["max-age"] = ""
  164. else:
  165. expires = cookie["expires"]
  166. if expires:
  167. expire_time = self._parse_date(expires)
  168. if expire_time:
  169. self._expire_cookie(expire_time, domain, name)
  170. else:
  171. cookie["expires"] = ""
  172. self._cookies[domain][name] = cookie
  173. self._do_expiration()
  174. def filter_cookies(
  175. self, request_url: URL = URL()
  176. ) -> Union["BaseCookie[str]", "SimpleCookie[str]"]:
  177. """Returns this jar's cookies filtered by their attributes."""
  178. self._do_expiration()
  179. request_url = URL(request_url)
  180. filtered: Union["SimpleCookie[str]", "BaseCookie[str]"] = (
  181. SimpleCookie() if self._quote_cookie else BaseCookie()
  182. )
  183. hostname = request_url.raw_host or ""
  184. is_not_secure = request_url.scheme not in ("https", "wss")
  185. for cookie in self:
  186. name = cookie.key
  187. domain = cookie["domain"]
  188. # Send shared cookies
  189. if not domain:
  190. filtered[name] = cookie.value
  191. continue
  192. if not self._unsafe and is_ip_address(hostname):
  193. continue
  194. if (domain, name) in self._host_only_cookies:
  195. if domain != hostname:
  196. continue
  197. elif not self._is_domain_match(domain, hostname):
  198. continue
  199. if not self._is_path_match(request_url.path, cookie["path"]):
  200. continue
  201. if is_not_secure and cookie["secure"]:
  202. continue
  203. # It's critical we use the Morsel so the coded_value
  204. # (based on cookie version) is preserved
  205. mrsl_val = cast("Morsel[str]", cookie.get(cookie.key, Morsel()))
  206. mrsl_val.set(cookie.key, cookie.value, cookie.coded_value)
  207. filtered[name] = mrsl_val
  208. return filtered
  209. @staticmethod
  210. def _is_domain_match(domain: str, hostname: str) -> bool:
  211. """Implements domain matching adhering to RFC 6265."""
  212. if hostname == domain:
  213. return True
  214. if not hostname.endswith(domain):
  215. return False
  216. non_matching = hostname[: -len(domain)]
  217. if not non_matching.endswith("."):
  218. return False
  219. return not is_ip_address(hostname)
  220. @staticmethod
  221. def _is_path_match(req_path: str, cookie_path: str) -> bool:
  222. """Implements path matching adhering to RFC 6265."""
  223. if not req_path.startswith("/"):
  224. req_path = "/"
  225. if req_path == cookie_path:
  226. return True
  227. if not req_path.startswith(cookie_path):
  228. return False
  229. if cookie_path.endswith("/"):
  230. return True
  231. non_matching = req_path[len(cookie_path) :]
  232. return non_matching.startswith("/")
  233. @classmethod
  234. def _parse_date(cls, date_str: str) -> Optional[datetime.datetime]:
  235. """Implements date string parsing adhering to RFC 6265."""
  236. if not date_str:
  237. return None
  238. found_time = False
  239. found_day = False
  240. found_month = False
  241. found_year = False
  242. hour = minute = second = 0
  243. day = 0
  244. month = 0
  245. year = 0
  246. for token_match in cls.DATE_TOKENS_RE.finditer(date_str):
  247. token = token_match.group("token")
  248. if not found_time:
  249. time_match = cls.DATE_HMS_TIME_RE.match(token)
  250. if time_match:
  251. found_time = True
  252. hour, minute, second = [int(s) for s in time_match.groups()]
  253. continue
  254. if not found_day:
  255. day_match = cls.DATE_DAY_OF_MONTH_RE.match(token)
  256. if day_match:
  257. found_day = True
  258. day = int(day_match.group())
  259. continue
  260. if not found_month:
  261. month_match = cls.DATE_MONTH_RE.match(token)
  262. if month_match:
  263. found_month = True
  264. assert month_match.lastindex is not None
  265. month = month_match.lastindex
  266. continue
  267. if not found_year:
  268. year_match = cls.DATE_YEAR_RE.match(token)
  269. if year_match:
  270. found_year = True
  271. year = int(year_match.group())
  272. if 70 <= year <= 99:
  273. year += 1900
  274. elif 0 <= year <= 69:
  275. year += 2000
  276. if False in (found_day, found_month, found_year, found_time):
  277. return None
  278. if not 1 <= day <= 31:
  279. return None
  280. if year < 1601 or hour > 23 or minute > 59 or second > 59:
  281. return None
  282. return datetime.datetime(
  283. year, month, day, hour, minute, second, tzinfo=datetime.timezone.utc
  284. )
  285. class DummyCookieJar(AbstractCookieJar):
  286. """Implements a dummy cookie storage.
  287. It can be used with the ClientSession when no cookie processing is needed.
  288. """
  289. def __init__(self, *, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
  290. super().__init__(loop=loop)
  291. def __iter__(self) -> "Iterator[Morsel[str]]":
  292. while False:
  293. yield None
  294. def __len__(self) -> int:
  295. return 0
  296. def clear(self) -> None:
  297. pass
  298. def update_cookies(self, cookies: LooseCookies, response_url: URL = URL()) -> None:
  299. pass
  300. def filter_cookies(self, request_url: URL) -> "BaseCookie[str]":
  301. return SimpleCookie()