123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382 |
- import asyncio
- import datetime
- import os # noqa
- import pathlib
- import pickle
- import re
- from collections import defaultdict
- from http.cookies import BaseCookie, Morsel, SimpleCookie
- from typing import ( # noqa
- DefaultDict,
- Dict,
- Iterable,
- Iterator,
- Mapping,
- Optional,
- Set,
- Tuple,
- Union,
- cast,
- )
- from yarl import URL
- from .abc import AbstractCookieJar
- from .helpers import is_ip_address, next_whole_second
- from .typedefs import LooseCookies, PathLike
__all__ = (
    "CookieJar",
    "DummyCookieJar",
)

# Type alias for a single cookie value: either a plain string or an
# already-built ``http.cookies.Morsel``.
CookieItem = Union[str, "Morsel[str]"]
class CookieJar(AbstractCookieJar):
    """Implements cookie storage adhering to RFC 6265.

    State:
        ``_cookies``: maps a cookie domain (leading dot removed) to a
            ``SimpleCookie`` holding that domain's cookies by name; cookies
            stored under the empty domain are "shared" and are sent with
            every request by ``filter_cookies``.
        ``_host_only_cookies``: ``(domain, name)`` pairs whose host-only flag
            is set, i.e. cookies stored without a Domain attribute.
        ``_expirations``: maps ``(domain, name)`` to the timezone-aware UTC
            datetime at which that cookie is dropped.

    Keyword arguments:
        unsafe: if true, cookies are accepted from, and domain cookies are
            sent to, hosts that are IP addresses (refused by default).
        quote_cookie: selects ``SimpleCookie`` (true) or ``BaseCookie``
            (false) as the container returned by ``filter_cookies``.
        loop: passed through to ``AbstractCookieJar.__init__``.
    """

    # Cookie-date tokenizer from RFC 6265 section 5.1.1: skip delimiter
    # characters, then capture one run of non-delimiter characters.
    DATE_TOKENS_RE = re.compile(
        r"[\x09\x20-\x2F\x3B-\x40\x5B-\x60\x7B-\x7E]*"
        r"(?P<token>[\x00-\x08\x0A-\x1F\d:a-zA-Z\x7F-\xFF]+)"
    )

    DATE_HMS_TIME_RE = re.compile(r"(\d{1,2}):(\d{1,2}):(\d{1,2})")

    DATE_DAY_OF_MONTH_RE = re.compile(r"(\d{1,2})")

    # One group per month, in order, so ``lastindex`` is the month number.
    DATE_MONTH_RE = re.compile(
        "(jan)|(feb)|(mar)|(apr)|(may)|(jun)|(jul)|" "(aug)|(sep)|(oct)|(nov)|(dec)",
        re.I,
    )

    DATE_YEAR_RE = re.compile(r"(\d{2,4})")

    # Upper bound for expiration times; timezone-aware so it can be compared
    # with the aware UTC datetimes stored in ``_expirations``.
    MAX_TIME = datetime.datetime.max.replace(tzinfo=datetime.timezone.utc)

    # Fallback upper bound where MAX_TIME has no timestamp (#4515). It must be
    # aware like MAX_TIME: a naive value here made the comparisons in
    # _do_expiration/_expire_cookie raise TypeError on such platforms.
    MAX_32BIT_TIME = datetime.datetime.fromtimestamp(
        2 ** 31 - 1, datetime.timezone.utc
    )

    def __init__(
        self,
        *,
        unsafe: bool = False,
        quote_cookie: bool = True,
        loop: Optional[asyncio.AbstractEventLoop] = None
    ) -> None:
        super().__init__(loop=loop)
        self._cookies = defaultdict(
            SimpleCookie
        )  # type: DefaultDict[str, SimpleCookie[str]]
        self._host_only_cookies = set()  # type: Set[Tuple[str, str]]
        self._unsafe = unsafe
        self._quote_cookie = quote_cookie
        # Until this time is reached, _do_expiration returns immediately.
        self._next_expiration = next_whole_second()
        self._expirations = {}  # type: Dict[Tuple[str, str], datetime.datetime]
        # #4515: datetime.max may not be representable on 32-bit platforms
        self._max_time = self.MAX_TIME
        try:
            self._max_time.timestamp()
        except OverflowError:
            self._max_time = self.MAX_32BIT_TIME

    def save(self, file_path: PathLike) -> None:
        """Pickle the domain -> cookies mapping to *file_path*.

        Host-only flags and expiration times are not written.
        """
        file_path = pathlib.Path(file_path)
        with file_path.open(mode="wb") as f:
            pickle.dump(self._cookies, f, pickle.HIGHEST_PROTOCOL)

    def load(self, file_path: PathLike) -> None:
        """Replace the stored cookies with the mapping unpickled from *file_path*.

        Host-only flags and expiration times currently held by the jar are
        left as they are; the file does not contain them.
        """
        file_path = pathlib.Path(file_path)
        with file_path.open(mode="rb") as f:
            # SECURITY NOTE(review): pickle.load can execute arbitrary code.
            # Only load files produced by ``save`` from a trusted location.
            self._cookies = pickle.load(f)

    def clear(self) -> None:
        """Remove every cookie together with its host-only flag and expiry."""
        self._cookies.clear()
        self._host_only_cookies.clear()
        self._next_expiration = next_whole_second()
        self._expirations.clear()

    def __iter__(self) -> "Iterator[Morsel[str]]":
        """Yield every non-expired cookie across all domains."""
        self._do_expiration()
        for val in self._cookies.values():
            yield from val.values()

    def __len__(self) -> int:
        """Number of non-expired cookies in the jar."""
        return sum(1 for _ in self)

    def _do_expiration(self) -> None:
        """Drop every cookie whose expiration time has passed.

        This is a no-op until ``_next_expiration`` is reached. After that it
        scans ``_expirations``, removes expired cookies together with their
        host-only flags, and recomputes ``_next_expiration`` (rounded up to
        the next whole second).
        """
        now = datetime.datetime.now(datetime.timezone.utc)
        if self._next_expiration > now:
            return
        if not self._expirations:
            return
        next_expiration = self._max_time
        to_del = []
        cookies = self._cookies
        expirations = self._expirations
        for (domain, name), when in expirations.items():
            if when <= now:
                # .get() rather than [] so the defaultdict does not grow an
                # empty bucket for a domain that holds no cookies (e.g. after
                # load() replaced the mapping).
                domain_cookies = cookies.get(domain)
                if domain_cookies is not None:
                    domain_cookies.pop(name, None)
                to_del.append((domain, name))
                self._host_only_cookies.discard((domain, name))
            else:
                next_expiration = min(next_expiration, when)
        for key in to_del:
            del expirations[key]

        try:
            self._next_expiration = next_expiration.replace(
                microsecond=0
            ) + datetime.timedelta(seconds=1)
        except OverflowError:
            # next_expiration was already at the maximum representable time.
            self._next_expiration = self._max_time

    def _expire_cookie(self, when: datetime.datetime, domain: str, name: str) -> None:
        """Record that cookie *name* of *domain* expires at *when*.

        ``_next_expiration`` is pulled forward if *when* is earlier.
        """
        self._next_expiration = min(self._next_expiration, when)
        self._expirations[(domain, name)] = when

    def update_cookies(self, cookies: LooseCookies, response_url: URL = URL()) -> None:
        """Update cookies.

        Store *cookies* received in a response from *response_url*. *cookies*
        is a mapping, or an iterable of ``(name, value)`` pairs, whose values
        are ``str`` or ``Morsel``.

        The following rules apply:

        * Nothing is stored when the response host is an IP address, unless
          the jar is ``unsafe``.
        * Cookies whose Domain does not domain-match the response host are
          skipped.
        * A missing or relative Path defaults to the directory of the
          response path.
        * Max-Age takes precedence over Expires. An unparsable value of the
          attribute in use is cleared, and the cookie is kept without an
          expiry.

        ``str`` values are wrapped in new Morsels. ``Morsel`` values are
        modified in place (domain, path and expiry attributes).
        """
        hostname = response_url.raw_host

        if not self._unsafe and is_ip_address(hostname):
            # Don't accept cookies from IPs
            return

        if isinstance(cookies, Mapping):
            cookies = cookies.items()

        for name, cookie in cookies:
            if not isinstance(cookie, Morsel):
                tmp = SimpleCookie()  # type: SimpleCookie[str]
                tmp[name] = cookie  # type: ignore
                cookie = tmp[name]

            domain = cookie["domain"]

            # ignore domains with trailing dots
            if domain.endswith("."):
                domain = ""
                # Keep the attribute present but empty. filter_cookies reads
                # cookie["domain"], which would raise KeyError if the key were
                # deleted and no hostname were available to refill it below.
                cookie["domain"] = ""

            host_only = not domain and hostname is not None
            if host_only:
                # Set the cookie's domain to the response hostname
                # (its host-only flag is recorded below)
                domain = cookie["domain"] = hostname

            if domain.startswith("."):
                # Remove leading dot
                domain = domain[1:]
                cookie["domain"] = domain

            if hostname and not self._is_domain_match(domain, hostname):
                # Setting cookies for different domains is not allowed
                continue

            key = (domain, name)
            # This cookie replaces any stored cookie with the same key, so
            # its host-only flag and expiry must come from the new cookie
            # alone (RFC 6265 section 5.3). Otherwise a re-sent session
            # cookie would still expire at the old cookie's deadline.
            if host_only:
                self._host_only_cookies.add(key)
            else:
                self._host_only_cookies.discard(key)
            self._expirations.pop(key, None)

            path = cookie["path"]
            if not path or not path.startswith("/"):
                # Set the cookie's path to the response path
                path = response_url.path
                if not path.startswith("/"):
                    path = "/"
                else:
                    # Cut everything from the last slash to the end
                    path = "/" + path[1 : path.rfind("/")]
                cookie["path"] = path

            max_age = cookie["max-age"]
            if max_age:
                try:
                    delta_seconds = int(max_age)
                except ValueError:
                    cookie["max-age"] = ""
                else:
                    try:
                        max_age_expiration = datetime.datetime.now(
                            datetime.timezone.utc
                        ) + datetime.timedelta(seconds=delta_seconds)
                    except OverflowError:
                        # Too far in the future to represent: clamp.
                        max_age_expiration = self._max_time
                    self._expire_cookie(max_age_expiration, domain, name)
            else:
                expires = cookie["expires"]
                if expires:
                    expire_time = self._parse_date(expires)
                    if expire_time:
                        self._expire_cookie(expire_time, domain, name)
                    else:
                        cookie["expires"] = ""

            self._cookies[domain][name] = cookie

        self._do_expiration()

    def filter_cookies(
        self, request_url: URL = URL()
    ) -> Union["BaseCookie[str]", "SimpleCookie[str]"]:
        """Returns this jar's cookies filtered by their attributes.

        Shared cookies (empty domain) are always included. Any other cookie
        is included only when all of the following hold:

        * the request host is not an IP address (unless ``unsafe``);
        * its domain matches the host (exactly, for host-only cookies);
        * its path matches the request path;
        * it is not ``secure``, or the scheme is ``https``/``wss``.
        """
        self._do_expiration()
        request_url = URL(request_url)
        filtered: Union["SimpleCookie[str]", "BaseCookie[str]"] = (
            SimpleCookie() if self._quote_cookie else BaseCookie()
        )
        hostname = request_url.raw_host or ""
        is_not_secure = request_url.scheme not in ("https", "wss")
        # Loop-invariant: domain-scoped cookies are withheld from IP hosts.
        skip_domain_cookies = not self._unsafe and is_ip_address(hostname)

        for cookie in self:
            name = cookie.key
            domain = cookie["domain"]

            # Send shared cookies
            if not domain:
                filtered[name] = cookie.value
                continue

            if skip_domain_cookies:
                continue

            if (domain, name) in self._host_only_cookies:
                if domain != hostname:
                    continue
            elif not self._is_domain_match(domain, hostname):
                continue

            if not self._is_path_match(request_url.path, cookie["path"]):
                continue

            if is_not_secure and cookie["secure"]:
                continue

            # It's critical we use the Morsel so the coded_value
            # (based on cookie version) is preserved. A fresh Morsel is what
            # the former ``cookie.get(cookie.key, Morsel())`` always produced:
            # a Morsel's keys are reserved attribute names, which Morsel.set
            # refuses as cookie names.
            mrsl_val: "Morsel[str]" = Morsel()
            mrsl_val.set(cookie.key, cookie.value, cookie.coded_value)
            filtered[name] = mrsl_val

        return filtered

    @staticmethod
    def _is_domain_match(domain: str, hostname: str) -> bool:
        """Implements domain matching adhering to RFC 6265.

        True when *hostname* equals *domain*, or ends with ``"." + domain``
        and is not an IP address.
        """
        if hostname == domain:
            return True

        if not hostname.endswith(domain):
            return False

        non_matching = hostname[: -len(domain)]

        if not non_matching.endswith("."):
            return False

        return not is_ip_address(hostname)

    @staticmethod
    def _is_path_match(req_path: str, cookie_path: str) -> bool:
        """Implements path matching adhering to RFC 6265.

        A request path that does not start with "/" is treated as "/". It
        matches when it equals *cookie_path*, or starts with it and the
        remainder begins at a "/" boundary.
        """
        if not req_path.startswith("/"):
            req_path = "/"

        if req_path == cookie_path:
            return True

        if not req_path.startswith(cookie_path):
            return False

        if cookie_path.endswith("/"):
            return True

        non_matching = req_path[len(cookie_path) :]

        return non_matching.startswith("/")

    @classmethod
    def _parse_date(cls, date_str: str) -> Optional[datetime.datetime]:
        """Implements date string parsing adhering to RFC 6265.

        Parse a cookie Expires value with the algorithm of RFC 6265 section
        5.1.1 and return an aware UTC datetime.

        Return ``None`` in any of these cases:

        * the string is empty;
        * the time, day, month or year is missing;
        * a component is out of range;
        * the date does not exist in the calendar (e.g. 31 February).
        """
        if not date_str:
            return None

        found_time = False
        found_day = False
        found_month = False
        found_year = False

        hour = minute = second = 0
        day = 0
        month = 0
        year = 0

        for token_match in cls.DATE_TOKENS_RE.finditer(date_str):

            token = token_match.group("token")

            # Each token is tried, in this order, against the components
            # that have not been found yet; the first match claims it.
            if not found_time:
                time_match = cls.DATE_HMS_TIME_RE.match(token)
                if time_match:
                    found_time = True
                    hour, minute, second = [int(s) for s in time_match.groups()]
                    continue

            if not found_day:
                day_match = cls.DATE_DAY_OF_MONTH_RE.match(token)
                if day_match:
                    found_day = True
                    day = int(day_match.group())
                    continue

            if not found_month:
                month_match = cls.DATE_MONTH_RE.match(token)
                if month_match:
                    found_month = True
                    assert month_match.lastindex is not None
                    month = month_match.lastindex
                    continue

            if not found_year:
                year_match = cls.DATE_YEAR_RE.match(token)
                if year_match:
                    found_year = True
                    year = int(year_match.group())

        # Two-digit years: 70-99 -> 19xx, 00-69 -> 20xx.
        if 70 <= year <= 99:
            year += 1900
        elif 0 <= year <= 69:
            year += 2000

        if False in (found_day, found_month, found_year, found_time):
            return None

        if not 1 <= day <= 31:
            return None

        if year < 1601 or hour > 23 or minute > 59 or second > 59:
            return None

        try:
            return datetime.datetime(
                year, month, day, hour, minute, second, tzinfo=datetime.timezone.utc
            )
        except ValueError:
            # The day does not exist in that month (e.g. "31 Feb"). RFC 6265
            # section 5.1.1 says to fail to parse, not to raise.
            return None
class DummyCookieJar(AbstractCookieJar):
    """A cookie jar that never stores anything.

    Intended for use with the ClientSession when no cookie processing is
    needed: incoming cookies are discarded and filtering always produces an
    empty cookie collection.
    """

    def __init__(self, *, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
        super().__init__(loop=loop)

    def __iter__(self) -> "Iterator[Morsel[str]]":
        # Still a generator function, but one that yields nothing.
        yield from ()

    def __len__(self) -> int:
        return 0

    def clear(self) -> None:
        """Nothing is stored, so there is nothing to clear."""

    def update_cookies(self, cookies: LooseCookies, response_url: URL = URL()) -> None:
        """Discard *cookies*; this jar keeps no state."""

    def filter_cookies(self, request_url: URL) -> "BaseCookie[str]":
        nothing: "BaseCookie[str]" = SimpleCookie()
        return nothing
|