Coverage for src/httpx/_server.py: 89%

153 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-06-16 18:32 +0100

1import concurrent.futures 

2import contextlib 

3import logging 

4import select 

5import socket 

6import threading 

7import time 

8 

9import h11 

10import httpx 

11 

12from ._streams import IterByteStream 

13 

# Names exported by ``from ... import *``: the two server context managers.
__all__ = ["serve_http", "serve_tcp"]

# Module-level logger; all records from this module are emitted under the
# "httpx.server" logger name.
logger = logging.getLogger("httpx.server")

20 

21 

class ConnectionClosed(Exception):
    """Signal that the client closed the HTTP connection."""

24 

25 

class HTTPConnection:
    """Serve HTTP/1.1 on a single byte stream using an h11 server connection.

    Each parsed request is turned into an ``httpx.Request`` and passed to
    ``endpoint``. The ``httpx.Response`` it returns is serialized back onto
    the stream.
    """

    def __init__(self, stream, endpoint):
        """
        Args:
            stream: byte stream exposing ``read()``, ``write(data)`` and
                ``close()``.
            endpoint: callable taking an ``httpx.Request`` and returning an
                ``httpx.Response``.
        """
        self._stream = stream
        self._endpoint = endpoint
        self._state = h11.Connection(our_role=h11.SERVER)
        self._keepalive_duration = 5.0
        # NOTE(review): refreshed in _cycle_complete but never read in this
        # class, so no idle timeout is actually enforced here.
        self._idle_expiry = time.monotonic() + self._keepalive_duration

    # API entry points...
    def handle_requests(self):
        """Read one request, dispatch it to the endpoint, send the response.

        If building the request or calling the endpoint raises, the error is
        logged and a ``500 Internal Server Error`` is sent instead. If the
        peer closes the connection before sending a request, this returns
        quietly. On exit, the h11 state is either reset for another cycle or
        the connection is closed (see ``_cycle_complete``).
        """
        # NOTE(review): despite the plural name, only one request is served
        # per call; nothing loops back after _cycle_complete resets h11.
        try:
            method, target, headers = self._recv_head()
            body = IterByteStream(self._recv_body())
            request = None
            try:
                request = httpx.Request(method, target, headers=headers, content=body)
                response = self._endpoint(request)
            except Exception:
                logger.exception("Internal Server Error")
                content = httpx.Text("Internal Server Error")
                response = httpx.Response(code=500, content=content)
                self._send_head(response)
                self._send_body(response)
            else:
                try:
                    self._send_head(response)
                    self._send_body(response)
                except Exception:
                    # Part of the response may already be written, so a 500
                    # can't be substituted. h11 is left mid-response, which
                    # makes _cycle_complete close the connection.
                    logger.exception("Internal Server Error")
            finally:
                # `request` stays None if httpx.Request() itself raised. Fall
                # back to the raw request line rather than hitting an
                # UnboundLocalError that would mask the logged error.
                if request is not None:
                    method, target = request.method, request.url.target
                status_line = f"{method} {target} [{response.code} {response.reason_phrase}]"
                logger.info(status_line)
        except ConnectionClosed:
            # The peer closed the connection before sending a request.
            pass
        finally:
            self._cycle_complete()

    def close(self):
        """Close the connection, notifying h11 first when its state allows."""
        # Only from these states is a ConnectionClosed event sent to h11.
        # Otherwise (e.g. mid-response) the stream is closed directly.
        if self._state.our_state in (h11.DONE, h11.IDLE, h11.MUST_CLOSE):
            event = h11.ConnectionClosed()
            self._state.send(event)

        self._stream.close()

    # Receive the request...
    def _recv_head(self) -> tuple[str, str, list[tuple[str, str]]]:
        """Return ``(method, target, headers)`` for the next request.

        Header names and values are decoded as latin-1; method and target as
        ASCII.

        Raises:
            ConnectionClosed: if the peer closes before a request arrives.
        """
        while True:
            event = self._recv_event()
            if isinstance(event, h11.Request):
                method = event.method.decode('ascii')
                target = event.target.decode('ascii')
                headers = [
                    (k.decode('latin-1'), v.decode('latin-1'))
                    for k, v in event.headers.raw_items()
                ]
                return (method, target, headers)
            elif isinstance(event, h11.ConnectionClosed):
                raise ConnectionClosed()

    def _recv_body(self):
        """Yield the request body as ``bytes`` chunks until the message ends."""
        while True:
            event = self._recv_event()
            if isinstance(event, h11.Data):
                yield bytes(event.data)
            elif isinstance(event, (h11.EndOfMessage, h11.PAUSED)):
                break

    def _recv_event(self) -> h11.Event | type[h11.PAUSED]:
        """Return the next h11 event, reading from the stream as needed."""
        while True:
            event = self._state.next_event()

            if event is h11.NEED_DATA:
                # NOTE(review): assumes read() returns b"" at EOF, which h11
                # interprets as the peer closing the connection.
                data = self._stream.read()
                self._state.receive_data(data)
            else:
                return event  # type: ignore[return-value]

    # Return the response...
    def _send_head(self, response: httpx.Response):
        """Write the status line and headers of ``response``."""
        event = h11.Response(
            status_code=response.code,
            headers=list(response.headers.items())
        )
        self._send_event(event)

    def _send_body(self, response: httpx.Response):
        """Write every chunk of ``response.stream`` followed by end-of-message."""
        for data in response.stream:
            self._send_event(h11.Data(data=data))
        self._send_event(h11.EndOfMessage())

    def _send_event(self, event: h11.Event) -> None:
        """Serialize ``event`` through h11 and write any resulting bytes."""
        data = self._state.send(event)
        if data is not None:
            self._stream.write(data)

    # Start it all over again...
    def _cycle_complete(self):
        """Reset for the next request if both sides finished; else close."""
        if self._state.our_state is h11.DONE and self._state.their_state is h11.DONE:
            self._state.start_next_cycle()
            self._idle_expiry = time.monotonic() + self._keepalive_duration
        else:
            self.close()

128 

129 

class HTTPServer:
    """Handle describing a running HTTP server: exposes its base ``url``."""

    def __init__(self, host, port):
        base = f"http://{host}:{port}/"
        self.url = base

    def wait(self):
        """Block the calling thread indefinitely, waking once per second."""
        while True:
            time.sleep(1)

137 

138 

class TCPServer:
    """Accept TCP connections on a background thread and hand each one to
    ``handler`` on a worker from a small thread pool.

    Use as a context manager. Entering binds and starts listening. Exiting
    stops the accept loop, closes the listening and client sockets, and
    waits for the workers to finish.
    """

    def __init__(self, handler, host: str = "127.0.0.1", port: int = 8080):
        """
        Args:
            handler: callable invoked on a worker thread with an
                ``httpx.NetworkStream`` for each accepted connection.
            host: IPv4 address to bind.
            port: TCP port to bind. Pass ``0`` to let the OS pick a free port;
                ``self.port`` is updated to the bound port on entry.
        """
        self.handler = handler
        self.host = host
        self.port = port
        self._max_workers = 5
        self._server_socket = None
        self._client_sockets: list[socket.socket] = []
        self._executor = None
        self._thread = None
        self._shutdown = threading.Event()

    def __enter__(self):
        server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            server_socket.bind((self.host, self.port))
            server_socket.listen(5)
            server_socket.setblocking(False)
        except BaseException:
            # __exit__ is not called when __enter__ raises; don't leak the fd.
            server_socket.close()
            raise
        self._server_socket = server_socket
        # Report the port actually bound; this only differs from the
        # requested port when it was 0 (OS-assigned).
        self.port = server_socket.getsockname()[1]

        self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=self._max_workers)
        self._thread = threading.Thread(target=self._serve_loop, daemon=True)
        self._thread.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._shutdown.set()
        self._thread.join()
        self._server_socket.close()
        for client_socket in list(self._client_sockets):
            # Per the socket docs, shutdown() before close() ends the
            # connection promptly. close() alone may not wake a worker
            # blocked in a read (OS-dependent), which could hang the
            # executor shutdown below. Errors mean it is already gone.
            with contextlib.suppress(OSError):
                client_socket.shutdown(socket.SHUT_RDWR)
            client_socket.close()
        self._executor.shutdown(wait=True)

    def _serve_loop(self):
        """Accept connections until shutdown, polling every 0.1s."""
        while not self._shutdown.is_set():
            readable, _, _ = select.select([self._server_socket], [], [], 0.1)
            if not readable:
                continue
            try:
                client_socket, _ = self._server_socket.accept()
                # A socket accepted from a non-blocking listener may itself be
                # non-blocking on some platforms. Force blocking mode so
                # handlers see consistent I/O semantics everywhere.
                client_socket.setblocking(True)
            except OSError:
                # e.g. the pending connection was reset before accept().
                logger.debug("accept() failed", exc_info=True)
                continue
            self._executor.submit(self._handler, client_socket)

    def _handler(self, client_socket):
        """Run the user handler for one accepted connection, then clean up."""
        self._client_sockets.append(client_socket)
        stream = None
        try:
            stream = httpx.NetworkStream(client_socket)
            self.handler(stream)
        except Exception:
            # The Future returned by submit() is never inspected, so without
            # this the exception would be silently discarded.
            logger.exception("Unhandled exception in connection handler")
        finally:
            self._client_sockets.remove(client_socket)
            if stream is not None:
                stream.close()
            else:
                # Wrapping the socket failed; close the raw socket instead.
                client_socket.close()

189 

190 

191@contextlib.contextmanager 

192def serve_http(endpoint): 

193 def handler(stream): 

194 connection = HTTPConnection(stream, endpoint) 

195 connection.handle_requests() 

196 

197 logging.basicConfig( 

198 format="%(levelname)s [%(asctime)s] %(name)s - %(message)s", 

199 datefmt="%Y-%m-%d %H:%M:%S", 

200 level=logging.DEBUG 

201 ) 

202 

203 with TCPServer(handler) as server: 

204 server = HTTPServer(server.host, server.port) 

205 logger.info(f"Serving on {server.url}") 

206 yield server 

207 

208 

@contextlib.contextmanager
def serve_tcp(handler):
    """Run ``handler`` for each accepted TCP connection on a background server.

    Yields the running ``TCPServer``; it is shut down when the ``with`` block
    exits.
    """
    tcp_server = TCPServer(handler)
    with tcp_server:
        yield tcp_server