Coverage for src/httpx/_pool.py: 95%
206 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-06-09 10:38 +0100
1import contextlib
2import ssl
3import time
4import typing
5import types
7import h11
9from ._content import Content
10from ._headers import Headers
11from ._network import Lock, NetworkBackend, Semaphore, NetworkStream
12from ._response import Response
13from ._request import Request
14from ._streams import IterByteStream, Stream
15from ._urls import URL
# Public API of this module: transports, the pool, and their factories.
__all__ = [
    "Transport", "ConnectionPool", "Connection",
    "open_connection_pool", "open_connection",
]
class Transport:
    """Base class for anything that can send a ``Request``.

    Subclasses provide ``send`` (and optionally ``close``); the ``request``
    and ``stream`` helpers are thin conveniences layered on top of ``send``.
    """

    @contextlib.contextmanager
    def send(self, request: Request) -> typing.Iterator[Response]:
        """Send *request* and yield the ``Response``. Subclasses must override."""
        raise NotImplementedError()
        yield  # never reached; makes this a generator for ``contextmanager``

    def close(self):
        """Release resources held by the transport. Does nothing by default."""

    def request(
        self,
        method: str,
        url: URL | str,
        headers: Headers | dict[str, str] | None = None,
        content: Content | Stream | bytes | None = None,
    ) -> Response:
        """Send a request and return its response with the body already read."""
        outgoing = Request(method, url, headers=headers, content=content)
        with self.send(outgoing) as resp:
            resp.read()
            return resp

    @contextlib.contextmanager
    def stream(
        self,
        method: str,
        url: URL | str,
        headers: Headers | dict[str, str] | None = None,
        content: Content | Stream | bytes | None = None,
    ) -> typing.Iterator[Response]:
        """Send a request and yield the response without reading its body."""
        outgoing = Request(method, url, headers=headers, content=content)
        with self.send(outgoing) as resp:
            yield resp
class ConnectionPool(Transport):
    """A pool of HTTP connections that are reused per origin when idle.

    Concurrent ``send`` calls are bounded by a ``Semaphore(100)``. Idle,
    unexpired connections to the same origin are reused; otherwise a new
    connection is opened with the pool's SSL context and network backend.
    """

    def __init__(self, ssl_context: ssl.SSLContext, backend: NetworkBackend):
        self._connections: list[Connection] = []
        self._ssl_context = ssl_context
        self._network_backend = backend
        self._limit_concurrency = Semaphore(100)
        self._closed = False

    # Public API...
    @contextlib.contextmanager
    def send(self, request: Request) -> typing.Iterator[Response]:
        """Send *request* over a pooled connection, yielding the response.

        Raises:
            RuntimeError: if the pool has already been closed.
        """
        if self._closed:
            raise RuntimeError("ConnectionPool is closed.")

        with self._limit_concurrency:
            try:
                connection = self._get_connection(request)
                with connection.send(request) as response:
                    yield response
            finally:
                # Regardless of outcome, prune connections that have expired
                # or that were closed during this request/response cycle.
                self._close_expired_connections()
                self._remove_closed_connections()

    def close(self):
        """Mark the pool closed and close every connection it holds."""
        self._closed = True
        closing = list(self._connections)
        self._connections = []
        for conn in closing:
            conn.close()

    # Create or reuse connections as required...
    def _get_connection(self, request: Request) -> "Connection":
        """Return an idle connection for the request's origin, or open one."""
        # Attempt to reuse an existing connection.
        url = request.url
        origin = URL(scheme=url.scheme, host=url.host, port=url.port)
        now = time.monotonic()
        for conn in self._connections:
            if conn.origin() == origin and conn.is_idle() and not conn.is_expired(now):
                return conn

        # Or else create a new connection. The Host header is used as the TLS
        # server name, but it may carry a port, which must not be sent as SNI.
        conn = open_connection(
            origin,
            hostname=self._server_hostname(request.headers["Host"]),
            ssl_context=self._ssl_context,
            backend=self._network_backend,
        )
        self._connections.append(conn)
        return conn

    @staticmethod
    def _server_hostname(host: str) -> str:
        """Return the host part of a ``Host`` header value, without any port.

        Handles ``name``, ``name:port``, ``[v6addr]`` and ``[v6addr]:port``.
        A bare value containing several colons (unbracketed IPv6) is returned
        unchanged rather than being mangled.
        """
        if host.startswith("["):
            end = host.find("]")
            return host[1:end] if end != -1 else host
        name, sep, port = host.rpartition(":")
        if sep and port.isdigit() and ":" not in name:
            return name
        return host

    # Connection pool management...
    def _close_expired_connections(self) -> None:
        """Close every pooled connection whose keep-alive period has elapsed."""
        now = time.monotonic()
        # Iterate over a copy: closing may be slow and the list may change.
        for conn in list(self._connections):
            if conn.is_expired(now):
                conn.close()

    def _remove_closed_connections(self) -> None:
        """Drop closed connections from the pool."""
        for conn in list(self._connections):
            if conn.is_closed():
                self._connections.remove(conn)

    @property
    def connections(self) -> list["Connection"]:
        """A snapshot copy of the connections currently held by the pool."""
        return list(self._connections)

    def description(self) -> str:
        """Summarize connection states, e.g. ``"0 active, 2 idle"``.

        ``active`` is always listed first, even when its count is zero.
        """
        counts = {"active": 0}
        for status in [c.description() for c in self._connections]:
            counts[status] = counts.get(status, 0) + 1
        return ", ".join(f"{count} {status}" for status, count in counts.items())

    # Builtins...
    def __repr__(self) -> str:
        return f"<ConnectionPool [{self.description()}]>"

    def __del__(self):
        # ``getattr`` guards against ``__init__`` having failed before
        # ``_closed`` was assigned, which would otherwise raise here.
        if not getattr(self, "_closed", True):
            import warnings
            warnings.warn("ConnectionPool was garbage collected without being closed.")

    def __enter__(self) -> "ConnectionPool":
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None = None,
        exc_value: BaseException | None = None,
        traceback: types.TracebackType | None = None,
    ) -> None:
        self.close()
def open_connection_pool(
    ssl_context: ssl.SSLContext | None = None,
    backend: NetworkBackend | None = None
) -> ConnectionPool:
    """Create a ``ConnectionPool``, filling in defaults for omitted arguments.

    A missing *ssl_context* becomes a ``truststore`` client context (imported
    lazily), and a missing *backend* becomes a fresh ``NetworkBackend()``.
    """
    context = ssl_context
    if context is None:
        import truststore
        context = truststore.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    network = NetworkBackend() if backend is None else backend
    return ConnectionPool(ssl_context=context, backend=network)
class Connection(Transport):
    """A single HTTP/1.1 client connection to one origin, driven by ``h11``.

    Request/response cycles are serialized by a lock. After a clean cycle the
    connection returns to idle and counts as expired once
    ``_keepalive_duration`` seconds (5.0) have passed without reuse.
    """

    def __init__(self, stream: "NetworkStream", origin: URL | str):
        self._stream = stream
        self._origin = URL(origin)
        self._state = h11.Connection(our_role=h11.CLIENT)
        self._keepalive_duration = 5.0
        self._idle_expiry = time.monotonic() + self._keepalive_duration
        self._request_lock = Lock()
        # Set once ``_close`` has closed the stream. h11 only permits moving
        # to CLOSED from some states, so its state alone is not reliable.
        self._closed = False

    # API for connection pool management...
    def origin(self) -> URL:
        """Return the origin URL this connection was opened for."""
        return self._origin

    def is_idle(self) -> bool:
        """Return True if the connection is open and ready for a new request."""
        return not self._closed and self._state.our_state is h11.IDLE

    def is_expired(self, when: float) -> bool:
        """Return True if idle and the keep-alive deadline is earlier than *when*."""
        return self._state.our_state is h11.IDLE and when > self._idle_expiry

    def is_closed(self) -> bool:
        """Return True if the stream was closed or the h11 state is closed/error."""
        return self._closed or self._state.our_state in (h11.CLOSED, h11.ERROR)

    def description(self) -> str:
        """Return a short status label such as ``idle``, ``active`` or ``closed``."""
        state = self._state.our_state
        if self._closed and state is not h11.ERROR:
            # e.g. closed after a failed upload, while h11 still says SEND_BODY.
            return "closed"
        return {
            h11.IDLE: "idle",
            h11.SEND_BODY: "active",
            h11.DONE: "active",
            h11.MUST_CLOSE: "closing",
            h11.CLOSED: "closed",
            h11.ERROR: "error",
            h11.MIGHT_SWITCH_PROTOCOL: "upgrading",
            h11.SWITCHED_PROTOCOL: "upgraded",
        }[state]

    # API entry points...
    @contextlib.contextmanager
    def send(self, request: Request) -> typing.Iterator[Response]:
        """Run one request/response cycle, yielding a response with a lazy body.

        When the context exits, the connection is reset for reuse if both
        sides finished the cycle cleanly, and closed otherwise.
        """
        with self._request_lock:
            try:
                self._send_head(request)
                self._send_body(request)
                code, headers = self._recv_head()
                stream = IterByteStream(self._recv_body())
                yield Response(code, headers=headers, content=stream)
            finally:
                self._cycle_complete()

    def close(self) -> None:
        """Close the connection, taking the request lock so no ``send`` interleaves."""
        with self._request_lock:
            self._close()

    # Top-level API for working directly with a connection.
    def request(
        self,
        method: str,
        url: URL | str,
        headers: Headers | typing.Mapping[str, str] | None = None,
        content: Content | Stream | bytes | None = None,
    ) -> Response:
        """Send a request (URL resolved against the origin) and read the body."""
        url = self._origin.join(url)
        request = Request(method, url, headers=headers, content=content)
        with self.send(request) as response:
            response.read()
        return response

    @contextlib.contextmanager
    def stream(
        self,
        method: str,
        url: URL | str,
        headers: Headers | typing.Mapping[str, str] | None = None,
        content: Content | Stream | bytes | None = None,
    ) -> typing.Iterator[Response]:
        """Send a request (URL resolved against the origin), yielding the response."""
        url = self._origin.join(url)
        request = Request(method, url, headers=headers, content=content)
        with self.send(request) as response:
            yield response

    # Send the request...
    def _send_head(self, request: Request) -> None:
        """Send the request line and headers."""
        event = h11.Request(
            method=request.method,
            target=request.url.target,
            headers=list(request.headers.items()),
        )
        self._send_event(event)

    def _send_body(self, request: Request) -> None:
        """Send each body chunk as ``h11.Data``, then ``h11.EndOfMessage``."""
        for data in request.stream:
            self._send_event(h11.Data(data=data))
        self._send_event(h11.EndOfMessage())

    def _send_event(self, event: h11.Event) -> None:
        """Serialize *event* with h11 and write any resulting bytes."""
        data = self._state.send(event)
        if data is not None:
            self._stream.write(data)

    # Receive the response...
    def _recv_head(self) -> tuple[int, Headers]:
        """Read events until the final ``h11.Response``; other events are skipped."""
        while True:
            event = self._recv_event()
            if isinstance(event, h11.Response):
                code = event.status_code
                headers = Headers([
                    (k.decode("latin-1"), v.decode("latin-1")) for k, v in event.headers
                ])
                return (code, headers)

    def _recv_body(self) -> typing.Iterator[bytes]:
        """Yield body chunks until ``EndOfMessage`` or ``PAUSED``."""
        while True:
            event = self._recv_event()
            if isinstance(event, h11.Data):
                yield bytes(event.data)
            elif isinstance(event, h11.EndOfMessage) or event is h11.PAUSED:
                break

    def _recv_event(self) -> h11.Event | type[h11.PAUSED]:
        """Return the next h11 event, reading from the stream while h11 needs data."""
        while True:
            event = self._state.next_event()

            if event is h11.NEED_DATA:
                data = self._stream.read()
                self._state.receive_data(data)
            else:
                return event  # type: ignore[return-value]

    # Request / response cycle complete...
    def _cycle_complete(self) -> None:
        """Reset for keep-alive after a clean cycle; otherwise close."""
        if self._state.our_state is h11.DONE and self._state.their_state is h11.DONE:
            self._state.start_next_cycle()
            self._idle_expiry = time.monotonic() + self._keepalive_duration
        else:
            self._close()

    def _close(self) -> None:
        """Close the stream once; later calls are no-ops.

        h11 only accepts ``ConnectionClosed`` from DONE, IDLE or MUST_CLOSE,
        so in other states (e.g. a failed upload in SEND_BODY) we rely on
        ``_closed`` to record that the connection is gone.
        """
        if self._closed:
            return
        self._closed = True
        if self._state.our_state in (h11.DONE, h11.IDLE, h11.MUST_CLOSE):
            event = h11.ConnectionClosed()
            self._state.send(event)

        self._stream.close()

    # Builtins...
    def __repr__(self) -> str:
        return f"<Connection [{self._origin} {self.description()}]>"

    def __enter__(self) -> "Connection":
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None = None,
        exc_value: BaseException | None = None,
        traceback: types.TracebackType | None = None,
    ) -> None:
        self.close()
def open_connection(
    url: URL | str,
    hostname: str = '',
    ssl_context: ssl.SSLContext | None = None,
    backend: NetworkBackend | None = None,
    ) -> Connection:
    """Open a new connection to the origin given by *url*.

    Args:
        url: An ``http://`` or ``https://`` URL identifying the origin.
        hostname: Server name for TLS; defaults to the URL's host.
        ssl_context: Context used for ``https``; a ``truststore`` client
            context is created when omitted.
        backend: Network backend used to connect; ``NetworkBackend()`` when
            omitted.

    Raises:
        ValueError: if the URL scheme is not ``http`` or ``https``.
    """
    if isinstance(url, str):
        url = URL(url)

    if url.scheme not in ("http", "https"):
        raise ValueError("URL scheme must be 'http://' or 'https://'.")
    if backend is None:
        backend = NetworkBackend()

    host = url.host
    port = url.port or {"http": 80, "https": 443}[url.scheme]
    hostname = hostname or url.host

    # Resolve the TLS context before connecting so a failure here (e.g. a
    # missing truststore) cannot leave an open socket behind.
    if url.scheme == "https" and ssl_context is None:
        import truststore
        ssl_context = truststore.SSLContext(ssl.PROTOCOL_TLS_CLIENT)

    stream = backend.connect(host, port)
    if url.scheme == "https":
        try:
            stream.start_tls(ssl_context, hostname=hostname)
        except BaseException:
            # Don't leak the connected stream if the TLS handshake fails.
            stream.close()
            raise

    return Connection(stream, url)