Coverage for src/httpx/_network.py: 95%

79 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-05-28 11:39 +0100

1import contextlib 

2import contextvars 

3import socket 

4import ssl 

5import threading 

6import time 

7import types 

8import typing 

9 

10_timeout_stack: contextvars.ContextVar[list[float]] = contextvars.ContextVar("timeout_context", default=[]) 

11 

# Public API of this module.
__all__ = [
    "NetworkBackend",
    "NetworkStream",
    "timeout",
]

13 

14 

@contextlib.contextmanager
def timeout(duration: float) -> typing.Iterator[None]:
    """
    A context managed timeout API.

    Pushes a deadline ``duration`` seconds from now for the duration of the
    ``with`` block; ``get_current_timeout()`` reports time remaining until the
    earliest pushed deadline, so nested timeouts can only shorten the budget.

        with timeout(1.0):
            ...
    """
    deadline = time.monotonic() + duration
    outer_deadlines = _timeout_stack.get()
    # Build a fresh list rather than mutating the (possibly shared) current one.
    reset_token = _timeout_stack.set([deadline, *outer_deadlines])
    try:
        yield
    finally:
        _timeout_stack.reset(reset_token)

32 

33 

def get_current_timeout() -> float | None:
    """
    Return the number of seconds left before the earliest active deadline.

    Returns ``None`` when no ``timeout()`` context is active. Raises
    ``TimeoutError`` when the earliest deadline has already been reached,
    rather than returning a zero/negative value.
    """
    deadlines = _timeout_stack.get()
    if deadlines:
        remaining = min(deadlines) - time.monotonic()
        if remaining <= 0.0:
            raise TimeoutError()
        return remaining
    return None

44 

45 

46class NetworkStream: 

47 def __init__(self, sock: socket.socket) -> None: 

48 peername = sock.getpeername() 

49 

50 self._socket = sock 

51 self._address = f"{peername[0]}:{peername[1]}" 

52 self._is_tls = False 

53 self._is_closed = False 

54 

55 def read(self, max_bytes: int = 64 * 1024) -> bytes: 

56 timeout = get_current_timeout() 

57 self._socket.settimeout(timeout) 

58 content = self._socket.recv(max_bytes) 

59 return content 

60 

61 def write(self, buffer: bytes) -> None: 

62 while buffer: 

63 timeout = get_current_timeout() 

64 self._socket.settimeout(timeout) 

65 n = self._socket.send(buffer) 

66 buffer = buffer[n:] 

67 

68 def start_tls(self, ctx: ssl.SSLContext, hostname: str | None = None) -> None: 

69 self._socket = ctx.wrap_socket(self._socket, server_hostname=hostname) 

70 self._is_tls = True 

71 

72 def close(self) -> None: 

73 timeout = get_current_timeout() 

74 self._socket.settimeout(timeout) 

75 self._socket.close() 

76 self._is_closed = True 

77 

78 def __repr__(self): 

79 description = "" 

80 description += " TLS" if self._is_tls else "" 

81 description += " CLOSED" if self._is_closed else "" 

82 return f"<NetworkStream [{self._address!r}{description}]>" 

83 

84 def __del__(self): 

85 if not self._is_closed: 

86 import warnings 

87 warnings.warn("NetworkStream was garbage collected without being closed.") 

88 

89 def __enter__(self) -> "NetworkStream": 

90 return self 

91 

92 def __exit__( 

93 self, 

94 exc_type: type[BaseException] | None = None, 

95 exc_value: BaseException | None = None, 

96 traceback: types.TracebackType | None = None, 

97 ): 

98 self.close() 

99 

100 

class NetworkBackend:
    """
    Synchronous network backend built on the standard ``socket`` module.
    """

    def connect(self, host: str, port: int) -> NetworkStream:
        """
        Connect to the given address, returning a NetworkStream instance.

        The connection attempt is bounded by the current ``timeout()``
        deadline, if any. Raises ``TimeoutError`` if that deadline has already
        passed, or whatever ``socket.create_connection`` raises.
        """
        address = (host, port)
        timeout = get_current_timeout()
        sock = socket.create_connection(address, timeout=timeout)
        try:
            return NetworkStream(sock)
        except BaseException:
            # e.g. getpeername() fails if the peer reset the connection right
            # after connecting; close the socket instead of leaking its fd.
            sock.close()
            raise

    def __repr__(self):
        return "<NetworkBackend [threaded]>"

113 

114 

# Synchronization primitives re-exported for use alongside this blocking backend.
Semaphore, Lock = threading.Semaphore, threading.Lock