Coverage for src/httpx/_pool.py: 95%

206 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-06-09 10:38 +0100

1import contextlib 

2import ssl 

3import time 

4import typing 

5import types 

6 

7import h11 

8 

9from ._content import Content 

10from ._headers import Headers 

11from ._network import Lock, NetworkBackend, Semaphore, NetworkStream 

12from ._response import Response 

13from ._request import Request 

14from ._streams import IterByteStream, Stream 

15from ._urls import URL 

16 

17 

# Public names exported by this module (``from ._pool import *``).
__all__ = [
    "Transport",
    "ConnectionPool",
    "Connection",
    "open_connection_pool",
    "open_connection",
]

25 

26 

class Transport:
    """Base interface for objects that can send HTTP requests.

    Subclasses override :meth:`send`. :meth:`request` and :meth:`stream` are
    convenience wrappers that build a :class:`Request` and delegate to it.
    """

    @contextlib.contextmanager
    def send(self, request: Request) -> typing.Iterator[Response]:
        """Send *request* and yield the corresponding :class:`Response`.

        The base implementation raises ``NotImplementedError`` as soon as the
        context is entered; subclasses must override it.
        """
        raise NotImplementedError()
        yield  # Unreachable; makes this function a generator for @contextmanager.

    def close(self):
        """Release any resources held by the transport. No-op by default."""
        pass

    def request(
        self,
        method: str,
        url: URL | str,
        headers: Headers | typing.Mapping[str, str] | None = None,
        content: Content | Stream | bytes | None = None,
    ) -> Response:
        """Send a request and return the response with its body fully read.

        The body is read inside the ``send()`` context, so the returned
        response can be used after the underlying exchange has completed.
        """
        request = Request(method, url, headers=headers, content=content)
        with self.send(request) as response:
            response.read()
        return response

    @contextlib.contextmanager
    def stream(
        self,
        method: str,
        url: URL | str,
        headers: Headers | typing.Mapping[str, str] | None = None,
        content: Content | Stream | bytes | None = None,
    ) -> typing.Iterator[Response]:
        """Send a request and yield the response without reading its body.

        The body may be consumed inside the ``with`` block; the exchange is
        finished when the block exits.
        """
        request = Request(method, url, headers=headers, content=content)
        with self.send(request) as response:
            yield response

59 

60 

class ConnectionPool(Transport):
    """A :class:`Transport` that keeps and reuses HTTP/1.1 connections.

    Connections are matched by origin (scheme, host, port). At most 100
    requests are sent concurrently (enforced by a semaphore); there is no
    explicit limit on the number of pooled connections.
    """

    def __init__(self, ssl_context: ssl.SSLContext, backend: NetworkBackend):
        self._connections: list[Connection] = []
        self._ssl_context = ssl_context
        self._network_backend = backend
        self._limit_concurrency = Semaphore(100)
        self._closed = False

    # Public API...
    @contextlib.contextmanager
    def send(self, request: Request) -> typing.Iterator[Response]:
        """Send *request* over a pooled connection and yield the response.

        Raises:
            RuntimeError: if the pool has already been closed.
        """
        if self._closed:
            raise RuntimeError("ConnectionPool is closed.")

        with self._limit_concurrency:
            try:
                connection = self._get_connection(request)
                with connection.send(request) as response:
                    yield response
            finally:
                # Housekeeping after every request, successful or not.
                self._close_expired_connections()
                self._remove_closed_connections()

    def close(self):
        """Mark the pool closed and close every connection it holds."""
        self._closed = True
        closing = list(self._connections)
        self._connections = []
        for conn in closing:
            conn.close()

    # Create or reuse connections as required...
    def _get_connection(self, request: Request) -> "Connection":
        """Return an idle, unexpired connection for the request's origin, or open one."""
        url = request.url
        origin = URL(scheme=url.scheme, host=url.host, port=url.port)
        now = time.monotonic()
        # NOTE(review): nothing reserves the connection returned here, so two
        # concurrent send() calls may both pick the same idle connection.
        for conn in self._connections:
            if conn.origin() == origin and conn.is_idle() and not conn.is_expired(now):
                return conn

        # Or else create a new connection. The Host header may carry a port
        # (or IPv6 brackets); TLS hostname verification needs the bare host.
        conn = open_connection(
            origin,
            hostname=self._tls_hostname(request.headers["Host"]),
            ssl_context=self._ssl_context,
            backend=self._network_backend,
        )
        self._connections.append(conn)
        return conn

    @staticmethod
    def _tls_hostname(host_header: str) -> str:
        """Return the host part of a ``Host`` header value, without any port.

        ``"example.com:8443"`` -> ``"example.com"``; ``"[::1]:8443"`` -> ``"::1"``.
        Values without a port are returned unchanged.
        """
        if host_header.startswith("["):
            # Bracketed IPv6 literal, with or without a trailing ":port".
            return host_header[1:].partition("]")[0]
        host, sep, port = host_header.rpartition(":")
        if sep and port.isdigit():
            return host
        return host_header

    # Connection pool management...
    def _close_expired_connections(self) -> None:
        """Close connections that have sat idle past their keep-alive window."""
        now = time.monotonic()
        for conn in list(self._connections):
            if conn.is_expired(now):
                conn.close()

    def _remove_closed_connections(self) -> None:
        """Drop closed (or failed) connections from the pool."""
        for conn in list(self._connections):
            if conn.is_closed():
                self._connections.remove(conn)

    @property
    def connections(self) -> typing.List['Connection']:
        """A snapshot copy of the pooled connections."""
        return list(self._connections)

    def description(self) -> str:
        """Summarise connection states, e.g. ``"1 active, 2 idle"``.

        ``active`` is always listed first, even when its count is zero.
        """
        counts = {"active": 0}
        for conn in self._connections:
            status = conn.description()
            counts[status] = counts.get(status, 0) + 1
        return ", ".join(f"{count} {status}" for status, count in counts.items())

    # Builtins...
    def __repr__(self) -> str:
        return f"<ConnectionPool [{self.description()}]>"

    def __del__(self):
        # getattr: if __init__ failed early, _closed may not exist and there is
        # nothing to warn about.
        if not getattr(self, "_closed", True):
            import warnings
            warnings.warn("ConnectionPool was garbage collected without being closed.")

    def __enter__(self) -> "ConnectionPool":
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None = None,
        exc_value: BaseException | None = None,
        traceback: types.TracebackType | None = None,
    ) -> None:
        self.close()

152 

153 

def open_connection_pool(
    ssl_context: ssl.SSLContext | None = None,
    backend: NetworkBackend | None = None
) -> ConnectionPool:
    """Create a :class:`ConnectionPool`, filling in defaults for omitted arguments.

    A missing *ssl_context* becomes a ``truststore`` client context (system
    trust store); a missing *backend* becomes a fresh ``NetworkBackend()``.
    """
    context = ssl_context
    if context is None:
        import truststore
        context = truststore.SSLContext(ssl.PROTOCOL_TLS_CLIENT)

    network = NetworkBackend() if backend is None else backend
    return ConnectionPool(ssl_context=context, backend=network)

165 

166 

class Connection(Transport):
    """A single HTTP/1.1 client connection, driven by an ``h11`` state machine.

    Requests are serialised by ``_request_lock``, which is held from sending
    the request until the response context exits. After a clean cycle the
    connection returns to IDLE and is considered expired once it has been
    idle for ``_keepalive_duration`` seconds.
    """

    def __init__(self, stream: "NetworkStream", origin: URL | str):
        self._stream = stream
        self._origin = URL(origin)
        self._state = h11.Connection(our_role=h11.CLIENT)
        # Seconds an idle connection may be kept before it counts as expired.
        self._keepalive_duration = 5.0
        self._idle_expiry = time.monotonic() + self._keepalive_duration
        self._request_lock = Lock()

    # API for connection pool management...
    def origin(self) -> URL:
        """Return the origin URL this connection was opened for."""
        return self._origin

    def is_idle(self) -> bool:
        """Return True when a new request may be started on this connection."""
        return self._state.our_state is h11.IDLE

    def is_expired(self, when: float) -> bool:
        """Return True if idle and the keep-alive window ended before *when*.

        *when* is compared against ``time.monotonic()``-based expiry times.
        """
        return self._state.our_state is h11.IDLE and when > self._idle_expiry

    def is_closed(self) -> bool:
        """Return True once the connection has been closed or has failed."""
        return self._state.our_state in (h11.CLOSED, h11.ERROR)

    def description(self) -> str:
        """Return a short status string derived from our h11 state."""
        return {
            h11.IDLE: "idle",
            h11.SEND_BODY: "active",
            h11.DONE: "active",
            h11.MUST_CLOSE: "closing",
            h11.CLOSED: "closed",
            h11.ERROR: "error",
            h11.MIGHT_SWITCH_PROTOCOL: "upgrading",
            h11.SWITCHED_PROTOCOL: "upgraded",
        }[self._state.our_state]

    # API entry points...
    @contextlib.contextmanager
    def send(self, request: Request) -> typing.Iterator[Response]:
        """Send *request* and yield the response; its body is read lazily.

        On exit the connection is reset for reuse if both sides finished the
        exchange cleanly, and closed otherwise.
        """
        with self._request_lock:
            try:
                self._send_head(request)
                self._send_body(request)
                code, headers = self._recv_head()
                stream = IterByteStream(self._recv_body())
                yield Response(code, headers=headers, content=stream)
            finally:
                self._cycle_complete()

    def close(self) -> None:
        """Close the connection, first waiting for any in-flight request."""
        with self._request_lock:
            self._close()

    # Top-level API for working directly with a connection.
    def request(
        self,
        method: str,
        url: URL | str,
        headers: Headers | typing.Mapping[str, str] | None = None,
        content: Content | Stream | bytes | None = None,
    ) -> Response:
        """Send a request (*url* resolved against the origin); return it fully read."""
        url = self._origin.join(url)
        request = Request(method, url, headers=headers, content=content)
        with self.send(request) as response:
            response.read()
        return response

    @contextlib.contextmanager
    def stream(
        self,
        method: str,
        url: URL | str,
        headers: Headers | typing.Mapping[str, str] | None = None,
        content: Content | Stream | bytes | None = None,
    ) -> typing.Iterator[Response]:
        """Send a request (*url* resolved against the origin); yield it unread."""
        url = self._origin.join(url)
        request = Request(method, url, headers=headers, content=content)
        with self.send(request) as response:
            yield response

    # Send the request...
    def _send_head(self, request: Request) -> None:
        """Serialise and write the request line and headers."""
        event = h11.Request(
            method=request.method,
            target=request.url.target,
            headers=list(request.headers.items()),
        )
        self._send_event(event)

    def _send_body(self, request: Request) -> None:
        """Write every chunk of the request body, then the end-of-message marker."""
        for data in request.stream:
            self._send_event(h11.Data(data=data))
        self._send_event(h11.EndOfMessage())

    def _send_event(self, event: h11.Event) -> None:
        """Feed *event* to h11 and write any resulting bytes to the stream."""
        data = self._state.send(event)
        if data is not None:
            self._stream.write(data)

    # Receive the response...
    def _recv_head(self) -> tuple[int, Headers]:
        """Read events until the response head arrives; return (status, headers).

        Interim 1xx responses are skipped, except ``101 Switching Protocols``:
        after a 101 h11 only returns PAUSED, so waiting for a final response
        would loop forever. The 101 head is returned as the response instead.
        """
        while True:
            event = self._recv_event()
            is_final = isinstance(event, h11.Response)
            is_switch = (
                isinstance(event, h11.InformationalResponse)
                and event.status_code == 101
            )
            if is_final or is_switch:
                code = event.status_code
                headers = Headers([
                    (k.decode("latin-1"), v.decode("latin-1")) for k, v in event.headers
                ])
                return (code, headers)

    def _recv_body(self) -> typing.Iterator[bytes]:
        """Yield response body chunks until the message ends or h11 pauses."""
        while True:
            event = self._recv_event()
            if isinstance(event, h11.Data):
                yield bytes(event.data)
            elif event is h11.PAUSED or isinstance(event, h11.EndOfMessage):
                break

    def _recv_event(self) -> h11.Event | type[h11.PAUSED]:
        """Return the next h11 event, reading from the stream whenever h11 needs data."""
        while True:
            event = self._state.next_event()

            if event is h11.NEED_DATA:
                data = self._stream.read()
                self._state.receive_data(data)
            else:
                return event  # type: ignore[return-value]

    # Request / response cycle complete...
    def _cycle_complete(self) -> None:
        """Reset for the next request after a clean exchange; otherwise close."""
        if self._state.our_state is h11.DONE and self._state.their_state is h11.DONE:
            self._state.start_next_cycle()
            self._idle_expiry = time.monotonic() + self._keepalive_duration
        else:
            self._close()

    def _close(self) -> None:
        """Move the h11 state to CLOSED or ERROR, then close the stream.

        h11 only allows a clean ConnectionClosed from IDLE/DONE/MUST_CLOSE. In
        any other live state (e.g. SEND_BODY after a failed write, or after a
        protocol switch) the state machine is marked failed. This makes
        ``is_closed()`` report True, so a pool drops the connection instead of
        keeping a dead "active" entry forever.
        """
        our_state = self._state.our_state
        if our_state in (h11.DONE, h11.IDLE, h11.MUST_CLOSE):
            self._state.send(h11.ConnectionClosed())
        elif our_state is not h11.CLOSED:
            self._state.send_failed()

        self._stream.close()

    # Builtins...
    def __repr__(self) -> str:
        return f"<Connection [{self._origin} {self.description()}]>"

    def __enter__(self) -> "Connection":
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None = None,
        exc_value: BaseException | None = None,
        traceback: types.TracebackType | None = None,
    ) -> None:
        self.close()

322 

323 

def open_connection(
    url: URL | str,
    hostname: str = '',
    ssl_context: ssl.SSLContext | None = None,
    backend: NetworkBackend | None = None,
) -> Connection:
    """Open a new :class:`Connection` to the host and port of *url*.

    Args:
        url: An ``http://`` or ``https://`` URL (or its string form).
        hostname: Name used for TLS hostname verification; defaults to the
            URL host.
        ssl_context: TLS context for ``https``; defaults to a ``truststore``
            client context.
        backend: Network backend used to connect; defaults to ``NetworkBackend()``.

    Raises:
        ValueError: if the URL scheme is not ``http`` or ``https``.
    """
    if isinstance(url, str):
        url = URL(url)

    if url.scheme not in ("http", "https"):
        raise ValueError("URL scheme must be 'http://' or 'https://'.")
    if backend is None:
        backend = NetworkBackend()

    host = url.host
    port = url.port or {"http": 80, "https": 443}[url.scheme]
    hostname = hostname or url.host

    # Build the default TLS context before opening the socket, so a failure
    # here (e.g. truststore unavailable) cannot leave a stream open.
    if url.scheme == "https" and ssl_context is None:
        import truststore
        ssl_context = truststore.SSLContext(ssl.PROTOCOL_TLS_CLIENT)

    stream = backend.connect(host, port)
    if url.scheme == "https":
        try:
            stream.start_tls(ssl_context, hostname=hostname)
        except BaseException:
            # e.g. certificate verification failure: don't leak the socket.
            stream.close()
            raise

    return Connection(stream, url)