Source code for aiohttp.cookiejar

import asyncio
import contextlib
import datetime
import os  # noqa
import pathlib
import pickle
import re
import warnings
from collections import defaultdict
from http.cookies import BaseCookie, Morsel, SimpleCookie
from typing import (  # noqa
    DefaultDict,
    Dict,
    Iterable,
    Iterator,
    List,
    Mapping,
    Optional,
    Set,
    Tuple,
    Union,
    cast,
)

from yarl import URL

from .abc import AbstractCookieJar, ClearCookiePredicate
from .helpers import is_ip_address, next_whole_second
from .typedefs import LooseCookies, PathLike, StrOrURL

# Public names exported by ``from aiohttp.cookiejar import *``.
__all__ = (
    "CookieJar",
    "DummyCookieJar",
)

# Alias for a cookie given either as a plain string or as a ``Morsel``.
CookieItem = Union[str, "Morsel[str]"]

[docs]class CookieJar(AbstractCookieJar): """Implements cookie storage adhering to RFC 6265.""" DATE_TOKENS_RE = re.compile( r"[\x09\x20-\x2F\x3B-\x40\x5B-\x60\x7B-\x7E]*" r"(?P<token>[\x00-\x08\x0A-\x1F\d:a-zA-Z\x7F-\xFF]+)" ) DATE_HMS_TIME_RE = re.compile(r"(\d{1,2}):(\d{1,2}):(\d{1,2})") DATE_DAY_OF_MONTH_RE = re.compile(r"(\d{1,2})") DATE_MONTH_RE = re.compile( "(jan)|(feb)|(mar)|(apr)|(may)|(jun)|(jul)|" "(aug)|(sep)|(oct)|(nov)|(dec)", re.I, ) DATE_YEAR_RE = re.compile(r"(\d{2,4})") MAX_TIME = datetime.datetime.max.replace(tzinfo=datetime.timezone.utc) MAX_32BIT_TIME = datetime.datetime.utcfromtimestamp(2 ** 31 - 1) def __init__( self, *, unsafe: bool = False, quote_cookie: bool = True, treat_as_secure_origin: Union[StrOrURL, List[StrOrURL], None] = None ) -> None: self._loop = asyncio.get_running_loop() self._cookies = defaultdict( SimpleCookie ) # type: DefaultDict[str, SimpleCookie[str]] self._host_only_cookies = set() # type: Set[Tuple[str, str]] self._unsafe = unsafe self._quote_cookie = quote_cookie if treat_as_secure_origin is None: treat_as_secure_origin = [] elif isinstance(treat_as_secure_origin, URL): treat_as_secure_origin = [treat_as_secure_origin.origin()] elif isinstance(treat_as_secure_origin, str): treat_as_secure_origin = [URL(treat_as_secure_origin).origin()] else: treat_as_secure_origin = [ URL(url).origin() if isinstance(url, str) else url.origin() for url in treat_as_secure_origin ] self._treat_as_secure_origin = treat_as_secure_origin self._next_expiration = next_whole_second() self._expirations = {} # type: Dict[Tuple[str, str], datetime.datetime] # #4515: datetime.max may not be representable on 32-bit platforms self._max_time = self.MAX_TIME try: self._max_time.timestamp() except OverflowError: self._max_time = self.MAX_32BIT_TIME
[docs] def save(self, file_path: PathLike) -> None: file_path = pathlib.Path(file_path) with"wb") as f: pickle.dump(self._cookies, f, pickle.HIGHEST_PROTOCOL)
[docs] def load(self, file_path: PathLike) -> None: file_path = pathlib.Path(file_path) with"rb") as f: self._cookies = pickle.load(f)
[docs] def clear(self, predicate: Optional[ClearCookiePredicate] = None) -> None: if predicate is None: self._next_expiration = next_whole_second() self._cookies.clear() self._host_only_cookies.clear() self._expirations.clear() return to_del = [] now = for domain, cookie in self._cookies.items(): for name, morsel in cookie.items(): key = (domain, name) if ( key in self._expirations and self._expirations[key] <= now ) or predicate(morsel): to_del.append(key) for domain, name in to_del: key = (domain, name) self._host_only_cookies.discard(key) if key in self._expirations: del self._expirations[(domain, name)] self._cookies[domain].pop(name, None) next_expiration = min(self._expirations.values(), default=self._max_time) try: self._next_expiration = next_expiration.replace( microsecond=0 ) + datetime.timedelta(seconds=1) except OverflowError: self._next_expiration = self._max_time
[docs] def clear_domain(self, domain: str) -> None: self.clear(lambda x: self._is_domain_match(domain, x["domain"]))
def __iter__(self) -> "Iterator[Morsel[str]]": self._do_expiration() for val in self._cookies.values(): yield from val.values() def __len__(self) -> int: return sum(1 for i in self) def _do_expiration(self) -> None: self.clear(lambda x: False) def _expire_cookie(self, when: datetime.datetime, domain: str, name: str) -> None: self._next_expiration = min(self._next_expiration, when) self._expirations[(domain, name)] = when
[docs] def update_cookies(self, cookies: LooseCookies, response_url: URL = URL()) -> None: """Update cookies.""" hostname = response_url.raw_host if not self._unsafe and is_ip_address(hostname): # Don't accept cookies from IPs return if isinstance(cookies, Mapping): cookies = cookies.items() for name, cookie in cookies: if not isinstance(cookie, Morsel): tmp = SimpleCookie() # type: SimpleCookie[str] tmp[name] = cookie # type: ignore[assignment] cookie = tmp[name] domain = cookie["domain"] # ignore domains with trailing dots if domain.endswith("."): domain = "" del cookie["domain"] if not domain and hostname is not None: # Set the cookie's domain to the response hostname # and set its host-only-flag self._host_only_cookies.add((hostname, name)) domain = cookie["domain"] = hostname if domain.startswith("."): # Remove leading dot domain = domain[1:] cookie["domain"] = domain if hostname and not self._is_domain_match(domain, hostname): # Setting cookies for different domains is not allowed continue path = cookie["path"] if not path or not path.startswith("/"): # Set the cookie's path to the response path path = response_url.path if not path.startswith("/"): path = "/" else: # Cut everything from the last slash to the end path = "/" + path[1 : path.rfind("/")] cookie["path"] = path max_age = cookie["max-age"] if max_age: try: delta_seconds = int(max_age) try: max_age_expiration = datetime.timezone.utc ) + datetime.timedelta(seconds=delta_seconds) except OverflowError: max_age_expiration = self._max_time self._expire_cookie(max_age_expiration, domain, name) except ValueError: cookie["max-age"] = "" else: expires = cookie["expires"] if expires: expire_time = self._parse_date(expires) if expire_time: self._expire_cookie(expire_time, domain, name) else: cookie["expires"] = "" self._cookies[domain][name] = cookie self._do_expiration()
[docs] def filter_cookies( self, request_url: URL = URL() ) -> Union["BaseCookie[str]", "SimpleCookie[str]"]: """Returns this jar's cookies filtered by their attributes.""" self._do_expiration() if not isinstance(request_url, URL): warnings.warn( "The method accepts yarl.URL instances only, got {}".format( type(request_url) ), DeprecationWarning, ) request_url = URL(request_url) filtered: Union["SimpleCookie[str]", "BaseCookie[str]"] = ( SimpleCookie() if self._quote_cookie else BaseCookie() ) hostname = request_url.raw_host or "" request_origin = URL() with contextlib.suppress(ValueError): request_origin = request_url.origin() is_not_secure = ( request_url.scheme not in ("https", "wss") and request_origin not in self._treat_as_secure_origin ) for cookie in self: name = cookie.key domain = cookie["domain"] # Send shared cookies if not domain: filtered[name] = cookie.value continue if not self._unsafe and is_ip_address(hostname): continue if (domain, name) in self._host_only_cookies: if domain != hostname: continue elif not self._is_domain_match(domain, hostname): continue if not self._is_path_match(request_url.path, cookie["path"]): continue if is_not_secure and cookie["secure"]: continue # It's critical we use the Morsel so the coded_value # (based on cookie version) is preserved mrsl_val = cast("Morsel[str]", cookie.get(cookie.key, Morsel())) mrsl_val.set(cookie.key, cookie.value, cookie.coded_value) filtered[name] = mrsl_val return filtered
@staticmethod def _is_domain_match(domain: str, hostname: str) -> bool: """Implements domain matching adhering to RFC 6265.""" if hostname == domain: return True if not hostname.endswith(domain): return False non_matching = hostname[: -len(domain)] if not non_matching.endswith("."): return False return not is_ip_address(hostname) @staticmethod def _is_path_match(req_path: str, cookie_path: str) -> bool: """Implements path matching adhering to RFC 6265.""" if not req_path.startswith("/"): req_path = "/" if req_path == cookie_path: return True if not req_path.startswith(cookie_path): return False if cookie_path.endswith("/"): return True non_matching = req_path[len(cookie_path) :] return non_matching.startswith("/") @classmethod def _parse_date(cls, date_str: str) -> Optional[datetime.datetime]: """Implements date string parsing adhering to RFC 6265.""" if not date_str: return None found_time = False found_day = False found_month = False found_year = False hour = minute = second = 0 day = 0 month = 0 year = 0 for token_match in cls.DATE_TOKENS_RE.finditer(date_str): token ="token") if not found_time: time_match = cls.DATE_HMS_TIME_RE.match(token) if time_match: found_time = True hour, minute, second = (int(s) for s in time_match.groups()) continue if not found_day: day_match = cls.DATE_DAY_OF_MONTH_RE.match(token) if day_match: found_day = True day = int( continue if not found_month: month_match = cls.DATE_MONTH_RE.match(token) if month_match: found_month = True assert month_match.lastindex is not None month = month_match.lastindex continue if not found_year: year_match = cls.DATE_YEAR_RE.match(token) if year_match: found_year = True year = int( if 70 <= year <= 99: year += 1900 elif 0 <= year <= 69: year += 2000 if False in (found_day, found_month, found_year, found_time): return None if not 1 <= day <= 31: return None if year < 1601 or hour > 23 or minute > 59 or second > 59: return None return datetime.datetime( year, month, day, hour, minute, second, 
tzinfo=datetime.timezone.utc )
class DummyCookieJar(AbstractCookieJar):
    """A cookie jar that never stores anything.

    Useful with ClientSession when cookie processing is not wanted: every
    mutating method is a no-op and every query reports an empty jar.
    """

    def __iter__(self) -> "Iterator[Morsel[str]]":
        # Generator function that yields nothing, so iteration stops at once.
        yield from ()

    def __len__(self) -> int:
        return 0

    def clear(self, predicate: Optional[ClearCookiePredicate] = None) -> None:
        """Do nothing: the jar holds no cookies."""

    def clear_domain(self, domain: str) -> None:
        """Do nothing: the jar holds no cookies."""

    def update_cookies(self, cookies: LooseCookies, response_url: URL = URL()) -> None:
        """Discard *cookies* without storing them."""

    def filter_cookies(self, request_url: URL) -> "BaseCookie[str]":
        """Return a new, empty ``SimpleCookie``."""
        return SimpleCookie()