diff --git a/Lib/socket.py b/Lib/socket.py index 42ee1307732..35d87eff34d 100644 --- a/Lib/socket.py +++ b/Lib/socket.py @@ -306,7 +306,8 @@ def makefile(self, mode="r", buffering=None, *, """makefile(...) -> an I/O stream connected to the socket The arguments are as for io.open() after the filename, except the only - supported mode values are 'r' (default), 'w' and 'b'. + supported mode values are 'r' (default), 'w', 'b', or a combination of + those. """ # XXX refactor to share code? if not set(mode) <= {"r", "w", "b"}: @@ -591,16 +592,65 @@ def fromshare(info): return socket(0, 0, 0, info) __all__.append("fromshare") -if hasattr(_socket, "socketpair"): +# Origin: https://gist.github.com/4325783, by Geert Jansen. Public domain. +# This is used if _socket doesn't natively provide socketpair. It's +# always defined so that it can be patched in for testing purposes. +def _fallback_socketpair(family=AF_INET, type=SOCK_STREAM, proto=0): + if family == AF_INET: + host = _LOCALHOST + elif family == AF_INET6: + host = _LOCALHOST_V6 + else: + raise ValueError("Only AF_INET and AF_INET6 socket address families " + "are supported") + if type != SOCK_STREAM: + raise ValueError("Only SOCK_STREAM socket type is supported") + if proto != 0: + raise ValueError("Only protocol zero is supported") + + # We create a connected TCP socket. Note the trick with + # setblocking(False) that prevents us from having to create a thread. + lsock = socket(family, type, proto) + try: + lsock.bind((host, 0)) + lsock.listen() + # On IPv6, ignore flow_info and scope_id + addr, port = lsock.getsockname()[:2] + csock = socket(family, type, proto) + try: + csock.setblocking(False) + try: + csock.connect((addr, port)) + except (BlockingIOError, InterruptedError): + pass + csock.setblocking(True) + ssock, _ = lsock.accept() + except: + csock.close() + raise + finally: + lsock.close() - def socketpair(family=None, type=SOCK_STREAM, proto=0): - """socketpair([family[, type[, proto]]]) -> (socket object, socket object) + # Authenticating avoids using a connection from something else + # able to connect to {host}:{port} instead of us. + # We expect only AF_INET and AF_INET6 families. + try: + if ( + ssock.getsockname() != csock.getpeername() + or csock.getsockname() != ssock.getpeername() + ): + raise ConnectionError("Unexpected peer connection") + except: + # getsockname() and getpeername() can fail + # if either socket isn't connected. + ssock.close() + csock.close() + raise - Create a pair of socket objects from the sockets returned by the platform - socketpair() function. - The arguments are the same as for socket() except the default family is - AF_UNIX if defined on the platform; otherwise, the default is AF_INET. - """ + return (ssock, csock) + +if hasattr(_socket, "socketpair"): + def socketpair(family=None, type=SOCK_STREAM, proto=0): if family is None: try: family = AF_UNIX @@ -612,44 +662,7 @@ def socketpair(family=None, type=SOCK_STREAM, proto=0): return a, b else: - - # Origin: https://gist.github.com/4325783, by Geert Jansen. Public domain. - def socketpair(family=AF_INET, type=SOCK_STREAM, proto=0): - if family == AF_INET: - host = _LOCALHOST - elif family == AF_INET6: - host = _LOCALHOST_V6 - else: - raise ValueError("Only AF_INET and AF_INET6 socket address families " - "are supported") - if type != SOCK_STREAM: - raise ValueError("Only SOCK_STREAM socket type is supported") - if proto != 0: - raise ValueError("Only protocol zero is supported") - - # We create a connected TCP socket. Note the trick with - # setblocking(False) that prevents us from having to create a thread. - lsock = socket(family, type, proto) - try: - lsock.bind((host, 0)) - lsock.listen() - # On IPv6, ignore flow_info and scope_id - addr, port = lsock.getsockname()[:2] - csock = socket(family, type, proto) - try: - csock.setblocking(False) - try: - csock.connect((addr, port)) - except (BlockingIOError, InterruptedError): - pass - csock.setblocking(True) - ssock, _ = lsock.accept() - except: - csock.close() - raise - finally: - lsock.close() - return (ssock, csock) + socketpair = _fallback_socketpair __all__.append("socketpair") socketpair.__doc__ = """socketpair([family[, type[, proto]]]) -> (socket object, socket object) @@ -702,16 +715,15 @@ def readinto(self, b): self._checkReadable() if self._timeout_occurred: raise OSError("cannot read from timed out object") - while True: - try: - return self._sock.recv_into(b) - except timeout: - self._timeout_occurred = True - raise - except error as e: - if e.errno in _blocking_errnos: - return None - raise + try: + return self._sock.recv_into(b) + except timeout: + self._timeout_occurred = True + raise + except error as e: + if e.errno in _blocking_errnos: + return None + raise def write(self, b): """Write the given bytes or bytearray object *b* to the socket @@ -919,7 +931,9 @@ def create_server(address, *, family=AF_INET, backlog=None, reuse_port=False, # Fail later on bind(), for platforms which may not # support this option. pass - if reuse_port: + # Since Linux 6.12.9, SO_REUSEPORT is not allowed + # on other address families than AF_INET/AF_INET6. + if reuse_port and family in (AF_INET, AF_INET6): sock.setsockopt(SOL_SOCKET, SO_REUSEPORT, 1) if has_ipv6 and family == AF_INET6: if dualstack_ipv6: diff --git a/Lib/test/test_socket.py b/Lib/test/test_socket.py index d5a35a3253e..87479171b5c 100644 --- a/Lib/test/test_socket.py +++ b/Lib/test/test_socket.py @@ -1,9 +1,9 @@ import unittest +from unittest import mock from test import support -from test.support import os_helper -from test.support import socket_helper -from test.support import threading_helper - +from test.support import ( + is_apple, os_helper, refleak_helper, socket_helper, threading_helper +) import _thread as thread import array import contextlib @@ -28,6 +28,7 @@ import threading import time import traceback +import warnings from weakref import proxy try: import multiprocessing @@ -37,6 +38,10 @@ import fcntl except ImportError: fcntl = None +try: + import _testcapi +except ImportError: + _testcapi = None support.requires_working_socket(module=True) @@ -44,6 +49,7 @@ # test unicode string and carriage return MSG = 'Michael Gilfix was here\u1234\r\n'.encode('utf-8') +VMADDR_CID_LOCAL = 1 VSOCKPORT = 1234 AIX = platform.system() == "AIX" WSL = "microsoft-standard-WSL" in platform.release() @@ -53,6 +59,35 @@ except ImportError: _socket = None +def skipForRefleakHuntinIf(condition, issueref): + if not condition: + def decorator(f): + f.client_skip = lambda f: f + return f + + else: + def decorator(f): + @contextlib.wraps(f) + def wrapper(*args, **kwds): + if refleak_helper.hunting_for_refleaks(): + raise unittest.SkipTest(f"ignore while hunting for refleaks, see {issueref}") + + return f(*args, **kwds) + + def client_skip(f): + @contextlib.wraps(f) + def wrapper(*args, **kwds): + if refleak_helper.hunting_for_refleaks(): + return + + return f(*args, **kwds) + + return wrapper + wrapper.client_skip = client_skip + return wrapper + + return decorator + def get_cid(): if fcntl is None: return None @@ -128,8 +163,8 @@ def _have_socket_qipcrtr(): def _have_socket_vsock(): """Check whether AF_VSOCK sockets are supported on this host.""" - ret = get_cid() is not None - return ret + cid = get_cid() + return (cid is not None) def _have_socket_bluetooth(): @@ -145,6 +180,17 @@ def _have_socket_bluetooth(): return True +def _have_socket_bluetooth_l2cap(): + """Check whether BTPROTO_L2CAP sockets are supported on this host.""" + try: + s = socket.socket(socket.AF_BLUETOOTH, socket.SOCK_SEQPACKET, socket.BTPROTO_L2CAP) + except (AttributeError, OSError): + return False + else: + s.close() + return True + + def _have_socket_hyperv(): """Check whether AF_HYPERV sockets are supported on this host.""" try: @@ -166,6 +212,24 @@ def socket_setdefaulttimeout(timeout): socket.setdefaulttimeout(old_timeout) +@contextlib.contextmanager +def downgrade_malformed_data_warning(): + # This warning happens on macos and win, but does not always happen on linux. + if sys.platform not in {"win32", "darwin"}: + yield + return + + with warnings.catch_warnings(): + # TODO: gh-110012, we should investigate why this warning is happening + # and fix it properly. + warnings.filterwarnings( + action="always", + message="received malformed or improperly-truncated ancillary data", + category=RuntimeWarning, + ) + yield + + HAVE_SOCKET_CAN = _have_socket_can() HAVE_SOCKET_CAN_ISOTP = _have_socket_can_isotp() @@ -180,10 +244,15 @@ def socket_setdefaulttimeout(timeout): HAVE_SOCKET_VSOCK = _have_socket_vsock() -HAVE_SOCKET_UDPLITE = hasattr(socket, "IPPROTO_UDPLITE") +# Older Android versions block UDPLITE with SELinux. +HAVE_SOCKET_UDPLITE = ( + hasattr(socket, "IPPROTO_UDPLITE") + and not (support.is_android and platform.android_ver().api_level < 29)) HAVE_SOCKET_BLUETOOTH = _have_socket_bluetooth() +HAVE_SOCKET_BLUETOOTH_L2CAP = _have_socket_bluetooth_l2cap() + HAVE_SOCKET_HYPERV = _have_socket_hyperv() # Size in bytes of the int type @@ -485,8 +554,8 @@ def clientTearDown(self): @unittest.skipIf(WSL, 'VSOCK does not work on Microsoft WSL') @unittest.skipUnless(HAVE_SOCKET_VSOCK, 'VSOCK sockets required for this test.') -@unittest.skipUnless(get_cid() != 2, - "This test can only be run on a virtual guest.") +@unittest.skipUnless(get_cid() != 2, # VMADDR_CID_HOST + "This test can only be run on a virtual guest.") class ThreadedVSOCKSocketStreamTest(unittest.TestCase, ThreadableTest): def __init__(self, methodName='runTest'): @@ -508,10 +577,16 @@ def clientSetUp(self): self.cli = socket.socket(socket.AF_VSOCK, socket.SOCK_STREAM) self.addCleanup(self.cli.close) cid = get_cid() + if cid in (socket.VMADDR_CID_HOST, socket.VMADDR_CID_ANY): + # gh-119461: Use the local communication address (loopback) + cid = VMADDR_CID_LOCAL self.cli.connect((cid, VSOCKPORT)) def testStream(self): - msg = self.conn.recv(1024) + try: + msg = self.conn.recv(1024) + except PermissionError as exc: + self.skipTest(repr(exc)) self.assertEqual(msg, MSG) def _testStream(self): @@ -556,19 +631,27 @@ class SocketPairTest(unittest.TestCase, ThreadableTest): def __init__(self, methodName='runTest'): unittest.TestCase.__init__(self, methodName=methodName) ThreadableTest.__init__(self) + self.cli = None + self.serv = None + + def socketpair(self): + # To be overridden by some child classes. + return socket.socketpair() def setUp(self): - self.serv, self.cli = socket.socketpair() + self.serv, self.cli = self.socketpair() def tearDown(self): - self.serv.close() + if self.serv: + self.serv.close() self.serv = None def clientSetUp(self): pass def clientTearDown(self): - self.cli.close() + if self.cli: + self.cli.close() self.cli = None ThreadableTest.clientTearDown(self) @@ -821,9 +904,9 @@ def requireSocket(*args): class GeneralModuleTests(unittest.TestCase): - # TODO: RUSTPYTHON - @unittest.expectedFailure @unittest.skipUnless(_socket is not None, 'need _socket module') + # TODO: RUSTPYTHON; gc.is_tracked not implemented + @unittest.expectedFailure def test_socket_type(self): self.assertTrue(gc.is_tracked(_socket.socket)) with self.assertRaisesRegex(TypeError, "immutable"): @@ -886,7 +969,7 @@ def testSocketError(self): with self.assertRaises(OSError, msg=msg % 'socket.gaierror'): raise socket.gaierror - # TODO: RUSTPYTHON + # TODO: RUSTPYTHON; error message format differs @unittest.expectedFailure def testSendtoErrors(self): # Testing that sendto doesn't mask failures. See #10169. @@ -973,8 +1056,6 @@ def test_socket_methods(self): if not hasattr(socket.socket, name): self.fail(f"socket method {name} is missing") - # TODO: RUSTPYTHON - @unittest.expectedFailure @unittest.skipUnless(sys.platform == 'darwin', 'macOS specific test') @unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 required for this test') def test3542SocketOptions(self): @@ -1090,7 +1171,10 @@ def testInterfaceNameIndex(self): @unittest.skipUnless(hasattr(socket, 'if_indextoname'), 'socket.if_indextoname() not available.') def testInvalidInterfaceIndexToName(self): - self.assertRaises(OSError, socket.if_indextoname, 0) + with self.assertRaises(OSError) as cm: + socket.if_indextoname(0) + self.assertIsNotNone(cm.exception.errno) + self.assertRaises(OverflowError, socket.if_indextoname, -1) self.assertRaises(OverflowError, socket.if_indextoname, 2**1000) self.assertRaises(TypeError, socket.if_indextoname, '_DEADBEEF') @@ -1109,8 +1193,11 @@ def testInvalidInterfaceIndexToName(self): @unittest.skipUnless(hasattr(socket, 'if_nametoindex'), 'socket.if_nametoindex() not available.') def testInvalidInterfaceNameToIndex(self): + with self.assertRaises(OSError) as cm: + socket.if_nametoindex("_DEADBEEF") + self.assertIsNotNone(cm.exception.errno) + self.assertRaises(TypeError, socket.if_nametoindex, 0) - self.assertRaises(OSError, socket.if_nametoindex, '_DEADBEEF') @unittest.skipUnless(hasattr(sys, 'getrefcount'), 'test needs sys.getrefcount()') @@ -1147,6 +1234,7 @@ def testNtoH(self): self.assertRaises(OverflowError, func, 1<<34) @support.cpython_only + @unittest.skipIf(_testcapi is None, "requires _testcapi") def testNtoHErrors(self): import _testcapi s_good_values = [0, 1, 2, 0xffff] @@ -1175,8 +1263,11 @@ def testGetServBy(self): # Find one service that exists, then check all the related interfaces. # I've ordered this by protocols that have both a tcp and udp # protocol, at least for modern Linuxes. - if (sys.platform.startswith(('freebsd', 'netbsd', 'gnukfreebsd')) - or sys.platform in ('linux', 'darwin')): + if ( + sys.platform.startswith( + ('linux', 'android', 'freebsd', 'netbsd', 'gnukfreebsd')) + or is_apple + ): # avoid the 'echo' service on this platform, as there is an # assumption breaking non-standard port/protocol entry services = ('daytime', 'qotd', 'domain') @@ -1191,9 +1282,8 @@ def testGetServBy(self): else: raise OSError # Try same call with optional protocol omitted - # Issue #26936: Android getservbyname() was broken before API 23. - if (not hasattr(sys, 'getandroidapilevel') or - sys.getandroidapilevel() >= 23): + # Issue gh-71123: this fails on Android before API level 23. + if not (support.is_android and platform.android_ver().api_level < 23): port2 = socket.getservbyname(service) eq(port, port2) # Try udp, but don't barf if it doesn't exist @@ -1204,8 +1294,9 @@ def testGetServBy(self): else: eq(udpport, port) # Now make sure the lookup by port returns the same service name - # Issue #26936: Android getservbyport() is broken. - if not support.is_android: + # Issue #26936: when the protocol is omitted, this fails on Android + # before API level 28. + if not (support.is_android and platform.android_ver().api_level < 28): eq(socket.getservbyport(port2), service) eq(socket.getservbyport(port, 'tcp'), service) if udpport is not None: @@ -1503,8 +1594,6 @@ def test_getsockaddrarg(self): break @unittest.skipUnless(os.name == "nt", "Windows specific") - # TODO: RUSTPYTHON, windows ioctls - @unittest.expectedFailure def test_sock_ioctl(self): self.assertTrue(hasattr(socket.socket, 'ioctl')) self.assertTrue(hasattr(socket, 'SIO_RCVALL')) @@ -1519,8 +1608,6 @@ def test_sock_ioctl(self): @unittest.skipUnless(os.name == "nt", "Windows specific") @unittest.skipUnless(hasattr(socket, 'SIO_LOOPBACK_FAST_PATH'), 'Loopback fast path support required for this test') - # TODO: RUSTPYTHON, AttributeError: 'socket' object has no attribute 'ioctl' - @unittest.expectedFailure def test_sio_loopback_fast_path(self): s = socket.socket() self.addCleanup(s.close) @@ -1554,9 +1641,8 @@ def testGetaddrinfo(self): socket.getaddrinfo('::1', 80) # port can be a string service name such as "http", a numeric # port number or None - # Issue #26936: Android getaddrinfo() was broken before API level 23. - if (not hasattr(sys, 'getandroidapilevel') or - sys.getandroidapilevel() >= 23): + # Issue #26936: this fails on Android before API level 23. + if not (support.is_android and platform.android_ver().api_level < 23): socket.getaddrinfo(HOST, "http") socket.getaddrinfo(HOST, 80) socket.getaddrinfo(HOST, None) @@ -1602,8 +1688,7 @@ def testGetaddrinfo(self): flags=socket.AI_PASSIVE) self.assertEqual(a, b) # Issue #6697. - # XXX RUSTPYTHON TODO: surrogates in str - # self.assertRaises(UnicodeEncodeError, socket.getaddrinfo, 'localhost', '\uD800') + self.assertRaises(UnicodeEncodeError, socket.getaddrinfo, 'localhost', '\uD800') # Issue 17269: test workaround for OS X platform bug segfault if hasattr(socket, 'AI_NUMERICSERV'): @@ -1615,8 +1700,7 @@ def testGetaddrinfo(self): except socket.gaierror: pass - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.skipIf(_testcapi is None, "requires _testcapi") def test_getaddrinfo_int_port_overflow(self): # gh-74895: Test that getaddrinfo does not raise OverflowError on port. # @@ -1634,7 +1718,7 @@ def test_getaddrinfo_int_port_overflow(self): try: socket.getaddrinfo(None, ULONG_MAX + 1, type=socket.SOCK_STREAM) except OverflowError: - # Platforms differ as to what values consitute a getaddrinfo() error + # Platforms differ as to what values constitute a getaddrinfo() error # return. Some fail for LONG_MAX+1, others ULONG_MAX+1, and Windows # silently accepts such huge "port" aka "service" numeric values. self.fail("Either no error or socket.gaierror expected.") @@ -1669,7 +1753,6 @@ def test_getnameinfo(self): # only IP addresses are allowed self.assertRaises(OSError, socket.getnameinfo, ('mail.python.org',0), 0) - @unittest.skip("TODO: RUSTPYTHON: flaky on CI?") @unittest.skipUnless(support.is_resource_enabled('network'), 'network is not enabled') def test_idna(self): @@ -1724,8 +1807,6 @@ def test_sendall_interrupted(self): def test_sendall_interrupted_with_timeout(self): self.check_sendall_interrupted(True) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_dealloc_warn(self): sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) r = repr(sock) @@ -1813,6 +1894,7 @@ def test_listen_backlog(self): srv.listen() @support.cpython_only + @unittest.skipIf(_testcapi is None, "requires _testcapi") def test_listen_backlog_overflow(self): # Issue 15989 import _testcapi @@ -1905,7 +1987,6 @@ def test_str_for_enums(self): self.assertEqual(str(s.family), str(s.family.value)) self.assertEqual(str(s.type), str(s.type.value)) - @unittest.expectedFailureIf(sys.platform.startswith("linux"), "TODO: RUSTPYTHON, AssertionError: 526337 != ") def test_socket_consistent_sock_type(self): SOCK_NONBLOCK = getattr(socket, 'SOCK_NONBLOCK', 0) SOCK_CLOEXEC = getattr(socket, 'SOCK_CLOEXEC', 0) @@ -2098,8 +2179,6 @@ def testCrucialConstants(self): @unittest.skipUnless(hasattr(socket, "CAN_BCM"), 'socket.CAN_BCM required for this test.') - # TODO: RUSTPYTHON, AttributeError: module 'socket' has no attribute 'CAN_BCM_TX_SETUP' - @unittest.expectedFailure def testBCMConstants(self): socket.CAN_BCM @@ -2140,16 +2219,12 @@ def testCreateBCMSocket(self): with socket.socket(socket.PF_CAN, socket.SOCK_DGRAM, socket.CAN_BCM) as s: pass - # TODO: RUSTPYTHON, OSError: bind(): bad family - @unittest.expectedFailure def testBindAny(self): with socket.socket(socket.PF_CAN, socket.SOCK_RAW, socket.CAN_RAW) as s: address = ('', ) s.bind(address) self.assertEqual(s.getsockname(), address) - # TODO: RUSTPYTHON, AssertionError: "interface name too long" does not match "bind(): bad family" - @unittest.expectedFailure def testTooLongInterfaceName(self): # most systems limit IFNAMSIZ to 16, take 1024 to be sure with socket.socket(socket.PF_CAN, socket.SOCK_RAW, socket.CAN_RAW) as s: @@ -2292,16 +2367,12 @@ def testCreateISOTPSocket(self): with socket.socket(socket.PF_CAN, socket.SOCK_DGRAM, socket.CAN_ISOTP) as s: pass - # TODO: RUSTPYTHON, OSError: bind(): bad family - @unittest.expectedFailure def testTooLongInterfaceName(self): # most systems limit IFNAMSIZ to 16, take 1024 to be sure with socket.socket(socket.PF_CAN, socket.SOCK_DGRAM, socket.CAN_ISOTP) as s: with self.assertRaisesRegex(OSError, 'interface name too long'): s.bind(('x' * 1024, 1, 2)) - # TODO: RUSTPYTHON, OSError: bind(): bad family - @unittest.expectedFailure def testBind(self): try: with socket.socket(socket.PF_CAN, socket.SOCK_DGRAM, socket.CAN_ISOTP) as s: @@ -2323,7 +2394,7 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.interface = "vcan0" - # TODO: RUSTPYTHON + # TODO: RUSTPYTHON - J1939 constants not fully implemented @unittest.expectedFailure @unittest.skipUnless(hasattr(socket, "CAN_J1939"), 'socket.CAN_J1939 required for this test.') @@ -2366,7 +2437,7 @@ def testCreateJ1939Socket(self): with socket.socket(socket.PF_CAN, socket.SOCK_DGRAM, socket.CAN_J1939) as s: pass - # TODO: RUSTPYTHON + # TODO: RUSTPYTHON - AF_CAN J1939 address format not fully implemented @unittest.expectedFailure def testBind(self): try: @@ -2579,6 +2650,143 @@ def testCreateScoSocket(self): with socket.socket(socket.AF_BLUETOOTH, socket.SOCK_SEQPACKET, socket.BTPROTO_SCO) as s: pass + @unittest.skipUnless(HAVE_SOCKET_BLUETOOTH_L2CAP, 'Bluetooth L2CAP sockets required for this test') + def testBindBrEdrL2capSocket(self): + with socket.socket(socket.AF_BLUETOOTH, socket.SOCK_SEQPACKET, socket.BTPROTO_L2CAP) as f: + # First user PSM in BR/EDR L2CAP + psm = 0x1001 + f.bind((socket.BDADDR_ANY, psm)) + addr = f.getsockname() + self.assertEqual(addr, (socket.BDADDR_ANY, psm)) + + @unittest.skipUnless(HAVE_SOCKET_BLUETOOTH_L2CAP, 'Bluetooth L2CAP sockets required for this test') + def testBadL2capAddr(self): + with socket.socket(socket.AF_BLUETOOTH, socket.SOCK_SEQPACKET, socket.BTPROTO_L2CAP) as f: + with self.assertRaises(OSError): + f.bind((socket.BDADDR_ANY, 0, 0)) + with self.assertRaises(OSError): + f.bind((socket.BDADDR_ANY,)) + with self.assertRaises(OSError): + f.bind(socket.BDADDR_ANY) + with self.assertRaises(OSError): + f.bind((socket.BDADDR_ANY.encode(), 0x1001)) + with self.assertRaises(OSError): + f.bind(('\ud812', 0x1001)) + + def testBindRfcommSocket(self): + with socket.socket(socket.AF_BLUETOOTH, socket.SOCK_STREAM, socket.BTPROTO_RFCOMM) as s: + channel = 0 + try: + s.bind((socket.BDADDR_ANY, channel)) + except OSError as err: + if sys.platform == 'win32' and err.winerror == 10050: + self.skipTest(str(err)) + raise + addr = s.getsockname() + self.assertEqual(addr, (mock.ANY, channel)) + self.assertRegex(addr[0], r'(?i)[0-9a-f]{2}(?::[0-9a-f]{2}){4}') + if sys.platform != 'win32': + self.assertEqual(addr, (socket.BDADDR_ANY, channel)) + with socket.socket(socket.AF_BLUETOOTH, socket.SOCK_STREAM, socket.BTPROTO_RFCOMM) as s: + s.bind(addr) + addr2 = s.getsockname() + self.assertEqual(addr2, addr) + + def testBadRfcommAddr(self): + with socket.socket(socket.AF_BLUETOOTH, socket.SOCK_STREAM, socket.BTPROTO_RFCOMM) as s: + channel = 0 + with self.assertRaises(OSError): + s.bind((socket.BDADDR_ANY.encode(), channel)) + with self.assertRaises(OSError): + s.bind((socket.BDADDR_ANY,)) + with self.assertRaises(OSError): + s.bind((socket.BDADDR_ANY, channel, 0)) + with self.assertRaises(OSError): + s.bind((socket.BDADDR_ANY + '\0', channel)) + with self.assertRaises(OSError): + s.bind('\ud812') + with self.assertRaises(OSError): + s.bind(('invalid', channel)) + + @unittest.skipUnless(hasattr(socket, 'BTPROTO_HCI'), 'Bluetooth HCI sockets required for this test') + def testBindHciSocket(self): + with socket.socket(socket.AF_BLUETOOTH, socket.SOCK_RAW, socket.BTPROTO_HCI) as s: + if sys.platform.startswith(('netbsd', 'dragonfly', 'freebsd')): + s.bind(socket.BDADDR_ANY) + addr = s.getsockname() + self.assertEqual(addr, socket.BDADDR_ANY) + else: + dev = 0 + try: + s.bind((dev,)) + except OSError as err: + if err.errno in (errno.EINVAL, errno.ENODEV): + self.skipTest(str(err)) + raise + addr = s.getsockname() + self.assertEqual(addr, dev) + + @unittest.skipUnless(hasattr(socket, 'BTPROTO_HCI'), 'Bluetooth HCI sockets required for this test') + def testBadHciAddr(self): + with socket.socket(socket.AF_BLUETOOTH, socket.SOCK_RAW, socket.BTPROTO_HCI) as s: + if sys.platform.startswith(('netbsd', 'dragonfly', 'freebsd')): + with self.assertRaises(OSError): + s.bind(socket.BDADDR_ANY.encode()) + with self.assertRaises(OSError): + s.bind((socket.BDADDR_ANY,)) + with self.assertRaises(OSError): + s.bind(socket.BDADDR_ANY + '\0') + with self.assertRaises((ValueError, OSError)): + s.bind(socket.BDADDR_ANY + ' '*100) + with self.assertRaises(OSError): + s.bind('\ud812') + with self.assertRaises(OSError): + s.bind('invalid') + with self.assertRaises(OSError): + s.bind(b'invalid') + else: + dev = 0 + with self.assertRaises(OSError): + s.bind(()) + with self.assertRaises(OSError): + s.bind((dev, 0)) + with self.assertRaises(OSError): + s.bind(dev) + with self.assertRaises(OSError): + s.bind(socket.BDADDR_ANY) + with self.assertRaises(OSError): + s.bind(socket.BDADDR_ANY.encode()) + + @unittest.skipUnless(hasattr(socket, 'BTPROTO_SCO'), 'Bluetooth SCO sockets required for this test') + def testBindScoSocket(self): + with socket.socket(socket.AF_BLUETOOTH, socket.SOCK_SEQPACKET, socket.BTPROTO_SCO) as s: + s.bind(socket.BDADDR_ANY) + addr = s.getsockname() + self.assertEqual(addr, socket.BDADDR_ANY) + + with socket.socket(socket.AF_BLUETOOTH, socket.SOCK_SEQPACKET, socket.BTPROTO_SCO) as s: + s.bind(socket.BDADDR_ANY.encode()) + addr = s.getsockname() + self.assertEqual(addr, socket.BDADDR_ANY) + + @unittest.skipUnless(hasattr(socket, 'BTPROTO_SCO'), 'Bluetooth SCO sockets required for this test') + def testBadScoAddr(self): + with socket.socket(socket.AF_BLUETOOTH, socket.SOCK_SEQPACKET, socket.BTPROTO_SCO) as s: + with self.assertRaises(OSError): + s.bind((socket.BDADDR_ANY,)) + with self.assertRaises(OSError): + s.bind((socket.BDADDR_ANY.encode(),)) + with self.assertRaises(ValueError): + s.bind(socket.BDADDR_ANY + '\0') + with self.assertRaises(ValueError): + s.bind(socket.BDADDR_ANY.encode() + b'\0') + with self.assertRaises(UnicodeEncodeError): + s.bind('\ud812') + with self.assertRaises(OSError): + s.bind('invalid') + with self.assertRaises(OSError): + s.bind(b'invalid') + @unittest.skipUnless(HAVE_SOCKET_HYPERV, 'Hyper-V sockets required for this test.') @@ -2709,22 +2917,29 @@ def testDup(self): def _testDup(self): self.serv_conn.send(MSG) - def testShutdown(self): - # Testing shutdown() + def check_shutdown(self): + # Test shutdown() helper msg = self.cli_conn.recv(1024) self.assertEqual(msg, MSG) - # wait for _testShutdown to finish: on OS X, when the server + # wait for _testShutdown[_overflow] to finish: on OS X, when the server # closes the connection the client also becomes disconnected, # and the client's shutdown call will fail. (Issue #4397.) self.done.wait() + def testShutdown(self): + self.check_shutdown() + def _testShutdown(self): self.serv_conn.send(MSG) self.serv_conn.shutdown(2) - testShutdown_overflow = support.cpython_only(testShutdown) + @support.cpython_only + @unittest.skipIf(_testcapi is None, "requires _testcapi") + def testShutdown_overflow(self): + self.check_shutdown() @support.cpython_only + @unittest.skipIf(_testcapi is None, "requires _testcapi") def _testShutdown_overflow(self): import _testcapi self.serv_conn.send(MSG) @@ -3195,7 +3410,7 @@ def _testSendmsgTimeout(self): # Linux supports MSG_DONTWAIT when sending, but in general, it # only works when receiving. Could add other platforms if they # support it too. - @skipWithClientIf(sys.platform not in {"linux"}, + @skipWithClientIf(sys.platform not in {"linux", "android"}, "MSG_DONTWAIT not known to work on this platform when " "sending") def testSendmsgDontWait(self): @@ -3712,7 +3927,7 @@ def testFDPassCMSG_LEN(self): def _testFDPassCMSG_LEN(self): self.createAndSendFDs(1) - @unittest.skipIf(sys.platform == "darwin", "skipping, see issue #12958") + @unittest.skipIf(is_apple, "skipping, see issue #12958") @unittest.skipIf(AIX, "skipping, see issue #22397") @requireAttrs(socket, "CMSG_SPACE") def testFDPassSeparate(self): @@ -3723,7 +3938,7 @@ def testFDPassSeparate(self): maxcmsgs=2) @testFDPassSeparate.client_skip - @unittest.skipIf(sys.platform == "darwin", "skipping, see issue #12958") + @unittest.skipIf(is_apple, "skipping, see issue #12958") @unittest.skipIf(AIX, "skipping, see issue #22397") def _testFDPassSeparate(self): fd0, fd1 = self.newFDs(2) @@ -3736,7 +3951,7 @@ def _testFDPassSeparate(self): array.array("i", [fd1]))]), len(MSG)) - @unittest.skipIf(sys.platform == "darwin", "skipping, see issue #12958") + @unittest.skipIf(is_apple, "skipping, see issue #12958") @unittest.skipIf(AIX, "skipping, see issue #22397") @requireAttrs(socket, "CMSG_SPACE") def testFDPassSeparateMinSpace(self): @@ -3750,7 +3965,7 @@ def testFDPassSeparateMinSpace(self): maxcmsgs=2, ignoreflags=socket.MSG_CTRUNC) @testFDPassSeparateMinSpace.client_skip - @unittest.skipIf(sys.platform == "darwin", "skipping, see issue #12958") + @unittest.skipIf(is_apple, "skipping, see issue #12958") @unittest.skipIf(AIX, "skipping, see issue #22397") def _testFDPassSeparateMinSpace(self): fd0, fd1 = self.newFDs(2) @@ -3774,7 +3989,7 @@ def sendAncillaryIfPossible(self, msg, ancdata): nbytes = self.sendmsgToServer([msg]) self.assertEqual(nbytes, len(msg)) - @unittest.skipIf(sys.platform == "darwin", "see issue #24725") + @unittest.skipIf(is_apple, "skipping, see issue #12958") def testFDPassEmpty(self): # Try to pass an empty FD array. Can receive either no array # or an empty array. @@ -3848,6 +4063,7 @@ def checkTruncatedHeader(self, result, ignoreflags=0): self.checkFlags(flags, eor=True, checkset=socket.MSG_CTRUNC, ignore=ignoreflags) + @skipForRefleakHuntinIf(sys.platform == "darwin", "#80931") def testCmsgTruncNoBufSize(self): # Check that no ancillary data is received when no buffer size # is specified. @@ -3857,26 +4073,32 @@ def testCmsgTruncNoBufSize(self): # received. ignoreflags=socket.MSG_CTRUNC) + @testCmsgTruncNoBufSize.client_skip def _testCmsgTruncNoBufSize(self): self.createAndSendFDs(1) + @skipForRefleakHuntinIf(sys.platform == "darwin", "#80931") def testCmsgTrunc0(self): # Check that no ancillary data is received when buffer size is 0. self.checkTruncatedHeader(self.doRecvmsg(self.serv_sock, len(MSG), 0), ignoreflags=socket.MSG_CTRUNC) + @testCmsgTrunc0.client_skip def _testCmsgTrunc0(self): self.createAndSendFDs(1) # Check that no ancillary data is returned for various non-zero # (but still too small) buffer sizes. + @skipForRefleakHuntinIf(sys.platform == "darwin", "#80931") def testCmsgTrunc1(self): self.checkTruncatedHeader(self.doRecvmsg(self.serv_sock, len(MSG), 1)) + @testCmsgTrunc1.client_skip def _testCmsgTrunc1(self): self.createAndSendFDs(1) + @skipForRefleakHuntinIf(sys.platform == "darwin", "#80931") def testCmsgTrunc2Int(self): # The cmsghdr structure has at least three members, two of # which are ints, so we still shouldn't see any ancillary @@ -3884,13 +4106,16 @@ def testCmsgTrunc2Int(self): self.checkTruncatedHeader(self.doRecvmsg(self.serv_sock, len(MSG), SIZEOF_INT * 2)) + @testCmsgTrunc2Int.client_skip def _testCmsgTrunc2Int(self): self.createAndSendFDs(1) + @skipForRefleakHuntinIf(sys.platform == "darwin", "#80931") def testCmsgTruncLen0Minus1(self): self.checkTruncatedHeader(self.doRecvmsg(self.serv_sock, len(MSG), socket.CMSG_LEN(0) - 1)) + @testCmsgTruncLen0Minus1.client_skip def _testCmsgTruncLen0Minus1(self): self.createAndSendFDs(1) @@ -3902,8 +4127,9 @@ def checkTruncatedArray(self, ancbuf, maxdata, mindata=0): # mindata and maxdata bytes when received with buffer size # ancbuf, and that any complete file descriptor numbers are # valid. - msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock, - len(MSG), ancbuf) + with downgrade_malformed_data_warning(): # TODO: gh-110012 + msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock, + len(MSG), ancbuf) self.assertEqual(msg, MSG) self.checkRecvmsgAddress(addr, self.cli_addr) self.checkFlags(flags, eor=True, checkset=socket.MSG_CTRUNC) @@ -3921,29 +4147,38 @@ def checkTruncatedArray(self, ancbuf, maxdata, mindata=0): len(cmsg_data) - (len(cmsg_data) % fds.itemsize)]) self.checkFDs(fds) + @skipForRefleakHuntinIf(sys.platform == "darwin", "#80931") def testCmsgTruncLen0(self): self.checkTruncatedArray(ancbuf=socket.CMSG_LEN(0), maxdata=0) + @testCmsgTruncLen0.client_skip def _testCmsgTruncLen0(self): self.createAndSendFDs(1) + @skipForRefleakHuntinIf(sys.platform == "darwin", "#80931") def testCmsgTruncLen0Plus1(self): self.checkTruncatedArray(ancbuf=socket.CMSG_LEN(0) + 1, maxdata=1) + @testCmsgTruncLen0Plus1.client_skip def _testCmsgTruncLen0Plus1(self): self.createAndSendFDs(2) + @skipForRefleakHuntinIf(sys.platform == "darwin", "#80931") def testCmsgTruncLen1(self): self.checkTruncatedArray(ancbuf=socket.CMSG_LEN(SIZEOF_INT), maxdata=SIZEOF_INT) + @testCmsgTruncLen1.client_skip def _testCmsgTruncLen1(self): self.createAndSendFDs(2) + + @skipForRefleakHuntinIf(sys.platform == "darwin", "#80931") def testCmsgTruncLen2Minus1(self): self.checkTruncatedArray(ancbuf=socket.CMSG_LEN(2 * SIZEOF_INT) - 1, maxdata=(2 * SIZEOF_INT) - 1) + @testCmsgTruncLen2Minus1.client_skip def _testCmsgTruncLen2Minus1(self): self.createAndSendFDs(2) @@ -4245,8 +4480,9 @@ def testSingleCmsgTruncInData(self): self.serv_sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_RECVHOPLIMIT, 1) self.misc_event.set() - msg, ancdata, flags, addr = self.doRecvmsg( - self.serv_sock, len(MSG), socket.CMSG_LEN(SIZEOF_INT) - 1) + with downgrade_malformed_data_warning(): # TODO: gh-110012 + msg, ancdata, flags, addr = self.doRecvmsg( + self.serv_sock, len(MSG), socket.CMSG_LEN(SIZEOF_INT) - 1) self.assertEqual(msg, MSG) self.checkRecvmsgAddress(addr, self.cli_addr) @@ -4349,9 +4585,10 @@ def testSecondCmsgTruncInData(self): self.serv_sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_RECVTCLASS, 1) self.misc_event.set() - msg, ancdata, flags, addr = self.doRecvmsg( - self.serv_sock, len(MSG), - socket.CMSG_SPACE(SIZEOF_INT) + socket.CMSG_LEN(SIZEOF_INT) - 1) + with downgrade_malformed_data_warning(): # TODO: gh-110012 + msg, ancdata, flags, addr = self.doRecvmsg( + self.serv_sock, len(MSG), + socket.CMSG_SPACE(SIZEOF_INT) + socket.CMSG_LEN(SIZEOF_INT) - 1) self.assertEqual(msg, MSG) self.checkRecvmsgAddress(addr, self.cli_addr) @@ -4593,14 +4830,15 @@ class SendrecvmsgUnixStreamTestBase(SendrecvmsgConnectedBase, ConnectedStreamTestMixin, UnixStreamBase): pass -@unittest.skipIf(sys.platform == "darwin", "flaky on macOS") @requireAttrs(socket.socket, "sendmsg") @requireAttrs(socket, "AF_UNIX") +@unittest.skip("TODO: RUSTPYTHON; accept() on Unix sockets returns EINVAL") class SendmsgUnixStreamTest(SendmsgStreamTests, SendrecvmsgUnixStreamTestBase): pass @requireAttrs(socket.socket, "recvmsg") @requireAttrs(socket, "AF_UNIX") +@unittest.skip("TODO: RUSTPYTHON; intermittent accept() EINVAL on Unix sockets") class RecvmsgUnixStreamTest(RecvmsgTests, RecvmsgGenericStreamTests, SendrecvmsgUnixStreamTestBase): pass @@ -4613,6 +4851,7 @@ class RecvmsgIntoUnixStreamTest(RecvmsgIntoTests, RecvmsgGenericStreamTests, @requireAttrs(socket.socket, "sendmsg", "recvmsg") @requireAttrs(socket, "AF_UNIX", "SOL_SOCKET", "SCM_RIGHTS") +@unittest.skip("TODO: RUSTPYTHON; intermittent accept() EINVAL on Unix sockets") class RecvmsgSCMRightsStreamTest(SCMRightsTest, SendrecvmsgUnixStreamTestBase): pass @@ -4816,6 +5055,112 @@ def _testSend(self): self.assertEqual(msg, MSG) +class PurePythonSocketPairTest(SocketPairTest): + # Explicitly use socketpair AF_INET or AF_INET6 to ensure that is the + # code path we're using regardless platform is the pure python one where + # `_socket.socketpair` does not exist. (AF_INET does not work with + # _socket.socketpair on many platforms). + def socketpair(self): + # called by super().setUp(). + try: + return socket.socketpair(socket.AF_INET6) + except OSError: + return socket.socketpair(socket.AF_INET) + + # Local imports in this class make for easy security fix backporting. + + def setUp(self): + if hasattr(_socket, "socketpair"): + self._orig_sp = socket.socketpair + # This forces the version using the non-OS provided socketpair + # emulation via an AF_INET socket in Lib/socket.py. + socket.socketpair = socket._fallback_socketpair + else: + # This platform already uses the non-OS provided version. + self._orig_sp = None + super().setUp() + + def tearDown(self): + super().tearDown() + if self._orig_sp is not None: + # Restore the default socket.socketpair definition. + socket.socketpair = self._orig_sp + + def test_recv(self): + msg = self.serv.recv(1024) + self.assertEqual(msg, MSG) + + def _test_recv(self): + self.cli.send(MSG) + + def test_send(self): + self.serv.send(MSG) + + def _test_send(self): + msg = self.cli.recv(1024) + self.assertEqual(msg, MSG) + + def test_ipv4(self): + cli, srv = socket.socketpair(socket.AF_INET) + cli.close() + srv.close() + + def _test_ipv4(self): + pass + + @unittest.skipIf(not hasattr(_socket, 'IPPROTO_IPV6') or + not hasattr(_socket, 'IPV6_V6ONLY'), + "IPV6_V6ONLY option not supported") + @unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 required for this test') + def test_ipv6(self): + cli, srv = socket.socketpair(socket.AF_INET6) + cli.close() + srv.close() + + def _test_ipv6(self): + pass + + def test_injected_authentication_failure(self): + orig_getsockname = socket.socket.getsockname + inject_sock = None + + def inject_getsocketname(self): + nonlocal inject_sock + sockname = orig_getsockname(self) + # Connect to the listening socket ahead of the + # client socket. + if inject_sock is None: + inject_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + inject_sock.setblocking(False) + try: + inject_sock.connect(sockname[:2]) + except (BlockingIOError, InterruptedError): + pass + inject_sock.setblocking(True) + return sockname + + sock1 = sock2 = None + try: + socket.socket.getsockname = inject_getsocketname + with self.assertRaises(OSError): + sock1, sock2 = socket.socketpair() + finally: + socket.socket.getsockname = orig_getsockname + if inject_sock: + inject_sock.close() + if sock1: # This cleanup isn't needed on a successful test. + sock1.close() + if sock2: + sock2.close() + + def _test_injected_authentication_failure(self): + # No-op. Exists for base class threading infrastructure to call. + # We could refactor this test into its own lesser class along with the + # setUp and tearDown code to construct an ideal; it is simpler to keep + # it here and live with extra overhead one this _one_ failure test. + pass + + class NonBlockingTCPTests(ThreadedTCPSocketTest): def __init__(self, methodName='runTest'): @@ -4863,6 +5208,7 @@ def _testSetBlocking(self): pass @support.cpython_only + @unittest.skipIf(_testcapi is None, "requires _testcapi") def testSetBlocking_overflow(self): # Issue 15989 import _testcapi @@ -4880,8 +5226,6 @@ def testSetBlocking_overflow(self): @unittest.skipUnless(hasattr(socket, 'SOCK_NONBLOCK'), 'test needs socket.SOCK_NONBLOCK') @support.requires_linux_version(2, 6, 28) - # TODO: RUSTPYTHON, AssertionError: None != 0 - @unittest.expectedFailure def testInitNonBlocking(self): # create a socket with SOCK_NONBLOCK self.serv.close() @@ -4977,6 +5321,39 @@ def _testRecv(self): # send data: recv() will no longer block self.cli.sendall(MSG) + def testLargeTimeout(self): + # gh-126876: Check that a timeout larger than INT_MAX is replaced with + # INT_MAX in the poll() code path. The following assertion must not + # fail: assert(INT_MIN <= ms && ms <= INT_MAX). + if _testcapi is not None: + large_timeout = _testcapi.INT_MAX + 1 + else: + large_timeout = 2147483648 + + # test recv() with large timeout + conn, addr = self.serv.accept() + self.addCleanup(conn.close) + try: + conn.settimeout(large_timeout) + except OverflowError: + # On Windows, settimeout() fails with OverflowError, whereas + # we want to test recv(). Just give up silently. + return + msg = conn.recv(len(MSG)) + + def _testLargeTimeout(self): + # test sendall() with large timeout + if _testcapi is not None: + large_timeout = _testcapi.INT_MAX + 1 + else: + large_timeout = 2147483648 + self.cli.connect((HOST, self.port)) + try: + self.cli.settimeout(large_timeout) + except OverflowError: + return + self.cli.sendall(MSG) + class FileObjectClassTestCase(SocketConnectedTest): """Unit tests for the object returned by socket.makefile() @@ -5179,6 +5556,8 @@ def _testMakefileClose(self): self.write_file.write(self.write_msg) self.write_file.flush() + @unittest.skipUnless(hasattr(sys, 'getrefcount'), + 'test needs sys.getrefcount()') def testMakefileCloseSocketDestroy(self): refcount_before = sys.getrefcount(self.cli_conn) self.read_file.close() @@ -5617,7 +5996,7 @@ def test_setblocking_invalidfd(self): sock.setblocking(False) -@unittest.skipUnless(sys.platform == 'linux', 'Linux specific test') +@unittest.skipUnless(sys.platform in ('linux', 'android'), 'Linux specific test') class TestLinuxAbstractNamespace(unittest.TestCase): UNIX_PATH_MAX = 108 @@ -5742,7 +6121,8 @@ def testUnencodableAddr(self): self.addCleanup(os_helper.unlink, path) self.assertEqual(self.sock.getsockname(), path) - @unittest.skipIf(sys.platform == 'linux', 'Linux specific test') + @unittest.skipIf(sys.platform in ('linux', 'android'), + 'Linux behavior is tested by TestLinuxAbstractNamespace') def testEmptyAddress(self): # Test that binding empty address fails. self.assertRaises(OSError, self.sock.bind, "") @@ -5968,8 +6348,6 @@ class InheritanceTest(unittest.TestCase): @unittest.skipUnless(hasattr(socket, "SOCK_CLOEXEC"), "SOCK_CLOEXEC not defined") @support.requires_linux_version(2, 6, 28) - # TODO: RUSTPYTHON, AssertionError: 524289 != - @unittest.expectedFailure def test_SOCK_CLOEXEC(self): with socket.socket(socket.AF_INET, socket.SOCK_STREAM | socket.SOCK_CLOEXEC) as s: @@ -6062,8 +6440,6 @@ def checkNonblock(self, s, nonblock=True, timeout=0.0): self.assertTrue(s.getblocking()) @support.requires_linux_version(2, 6, 28) - # TODO: RUSTPYTHON, AssertionError: 2049 != - @unittest.expectedFailure def test_SOCK_NONBLOCK(self): # a lot of it seems silly and redundant, but I wanted to test that # changing back and forth worked ok @@ -6099,7 +6475,6 @@ def test_SOCK_NONBLOCK(self): @unittest.skipUnless(os.name == "nt", "Windows specific") @unittest.skipUnless(multiprocessing, "need multiprocessing") -@unittest.skip("TODO: RUSTPYTHON, socket sharing") class TestSocketSharing(SocketTCPTest): # This must be classmethod and not staticmethod or multiprocessing # won't be able to bootstrap it. @@ -6117,6 +6492,8 @@ def remoteProcessServer(cls, q): s2.close() s.close() + # TODO: RUSTPYTHON; multiprocessing.SemLock not implemented + @unittest.expectedFailure def testShare(self): # Transfer the listening server socket to another process # and service it from there. @@ -6340,7 +6717,6 @@ def _testCount(self): self.assertEqual(sent, count) self.assertEqual(file.tell(), count) - @unittest.skipIf(sys.platform == "darwin", "TODO: RUSTPYTHON, killed (for OOM?)") def testCount(self): count = 5000007 conn = self.accept_conn() @@ -6416,7 +6792,6 @@ def _testWithTimeout(self): sent = meth(file) self.assertEqual(sent, self.FILESIZE) - @unittest.skip("TODO: RUSTPYTHON") def testWithTimeout(self): conn = self.accept_conn() data = self.recv_data(conn) @@ -6470,6 +6845,7 @@ def test_errors(self): @unittest.skipUnless(hasattr(os, "sendfile"), 'os.sendfile() required for this test.') +@unittest.skip("TODO: RUSTPYTHON; os.sendfile count parameter not handled correctly") class SendfileUsingSendfileTest(SendfileUsingSendTest): """ Test the sendfile() implementation of socket.sendfile(). @@ -6494,9 +6870,9 @@ def create_alg(self, typ, name): # bpo-31705: On kernel older than 4.5, sendto() failed with ENOKEY, # at least on ppc64le architecture - @support.requires_linux_version(4, 5) - # TODO: RUSTPYTHON, OSError: bind(): bad family + # TODO: RUSTPYTHON - AF_ALG not fully implemented @unittest.expectedFailure + @support.requires_linux_version(4, 5) def test_sha256(self): expected = bytes.fromhex("ba7816bf8f01cfea414140de5dae2223b00361a396" "177a9cb410ff61f20015ad") @@ -6514,7 +6890,7 @@ def test_sha256(self): op.send(b'') self.assertEqual(op.recv(512), expected) - # TODO: RUSTPYTHON, OSError: bind(): bad family + # TODO: RUSTPYTHON - AF_ALG not fully implemented @unittest.expectedFailure def test_hmac_sha1(self): # gh-109396: In FIPS mode, Linux 6.5 requires a key @@ -6531,9 +6907,9 @@ def test_hmac_sha1(self): # Although it should work with 3.19 and newer the test blocks on # Ubuntu 15.10 with Kernel 4.2.0-19. - @support.requires_linux_version(4, 3) - # TODO: RUSTPYTHON, OSError: bind(): bad family + # TODO: RUSTPYTHON - AF_ALG not fully implemented @unittest.expectedFailure + @support.requires_linux_version(4, 3) def test_aes_cbc(self): key = bytes.fromhex('06a9214036b8a15b512e03d534120006') iv = bytes.fromhex('3dafba429d9eb430b422da802c9fac41') @@ -6574,10 +6950,16 @@ def test_aes_cbc(self): self.assertEqual(len(dec), msglen * multiplier) self.assertEqual(dec, msg * multiplier) - @support.requires_linux_version(4, 9) # see issue29324 - # TODO: RUSTPYTHON, OSError: bind(): bad family + # TODO: RUSTPYTHON - AF_ALG not fully implemented @unittest.expectedFailure + @support.requires_linux_version(4, 9) # see gh-73510 def test_aead_aes_gcm(self): + kernel_version = support._get_kernel_version("Linux") + if kernel_version is not None: + if kernel_version >= (6, 16) and kernel_version < (6, 18): + # See https://github.com/python/cpython/issues/139310. + self.skipTest("upstream Linux kernel issue") + key = bytes.fromhex('c939cc13397c1d37de6ae0e1cb7c423c') iv = bytes.fromhex('b3d8cc017cbb89b39e0f67e2') plain = bytes.fromhex('c3b3c41f113a31b73d9a5cd432103069') @@ -6639,9 +7021,9 @@ def test_aead_aes_gcm(self): res = op.recv(len(msg) - taglen) self.assertEqual(plain, res[assoclen:]) - @support.requires_linux_version(4, 3) # see test_aes_cbc - # TODO: RUSTPYTHON, OSError: bind(): bad family + # TODO: RUSTPYTHON - AF_ALG not fully implemented @unittest.expectedFailure + @support.requires_linux_version(4, 3) # see test_aes_cbc def test_drbg_pr_sha256(self): # deterministic random bit generator, prediction resistance, sha256 with self.create_alg('rng', 'drbg_pr_sha256') as algo: @@ -6652,8 +7034,6 @@ def test_drbg_pr_sha256(self): rn = op.recv(32) self.assertEqual(len(rn), 32) - # TODO: RUSTPYTHON, AttributeError: 'socket' object has no attribute 'sendmsg_afalg' - @unittest.expectedFailure def test_sendmsg_afalg_args(self): sock = socket.socket(socket.AF_ALG, socket.SOCK_SEQPACKET, 0) with sock: @@ -6672,8 +7052,6 @@ def test_sendmsg_afalg_args(self): with self.assertRaises(TypeError): sock.sendmsg_afalg(op=socket.ALG_OP_ENCRYPT, assoclen=-1) - # TODO: RUSTPYTHON, OSError: bind(): bad family - @unittest.expectedFailure def test_length_restriction(self): # bpo-35050, off-by-one error in length check sock = socket.socket(socket.AF_ALG, socket.SOCK_SEQPACKET, 0) diff --git a/crates/stdlib/src/socket.rs b/crates/stdlib/src/socket.rs index 08b05b56aa8..d6250e472bb 100644 --- a/crates/stdlib/src/socket.rs +++ b/crates/stdlib/src/socket.rs @@ -63,6 +63,10 @@ mod _socket { SOL_SOCKET, SOMAXCONN, TCP_NODELAY, WSAEBADF, WSAECONNRESET, WSAENOTSOCK, WSAEWOULDBLOCK, }; + pub use windows_sys::Win32::Networking::WinSock::{ + INVALID_SOCKET, SOCKET_ERROR, WSA_FLAG_OVERLAPPED, WSADuplicateSocketW, + WSAGetLastError, WSAIoctl, WSAPROTOCOL_INFOW, WSASocketW, + }; pub use windows_sys::Win32::Networking::WinSock::{ SO_REUSEADDR as SO_EXCLUSIVEADDRUSE, getprotobyname, getservbyname, getservbyport, getsockopt, setsockopt, @@ -82,6 +86,7 @@ mod _socket { pub const AI_PASSIVE: i32 = windows_sys::Win32::Networking::WinSock::AI_PASSIVE as _; pub const AI_NUMERICHOST: i32 = windows_sys::Win32::Networking::WinSock::AI_NUMERICHOST as _; + pub const FROM_PROTOCOL_INFO: i32 = -1; } // constants #[pyattr(name = "has_ipv6")] @@ -138,6 +143,82 @@ mod _socket { SOL_CAN_RAW, }; + // CAN BCM opcodes + #[cfg(target_os = "linux")] + #[pyattr] + const CAN_BCM_TX_SETUP: i32 = 1; + #[cfg(target_os = "linux")] + #[pyattr] + const CAN_BCM_TX_DELETE: i32 = 2; + #[cfg(target_os = "linux")] + #[pyattr] + const CAN_BCM_TX_READ: i32 = 3; + #[cfg(target_os = "linux")] + #[pyattr] + const CAN_BCM_TX_SEND: i32 = 4; + #[cfg(target_os = "linux")] + #[pyattr] + const CAN_BCM_RX_SETUP: i32 = 5; + #[cfg(target_os = "linux")] + #[pyattr] + const CAN_BCM_RX_DELETE: i32 = 6; + #[cfg(target_os = "linux")] + #[pyattr] + const CAN_BCM_RX_READ: i32 = 7; + #[cfg(target_os = "linux")] + #[pyattr] + const CAN_BCM_TX_STATUS: i32 = 8; + #[cfg(target_os = "linux")] + #[pyattr] + const CAN_BCM_TX_EXPIRED: i32 = 9; + #[cfg(target_os = "linux")] + #[pyattr] + const CAN_BCM_RX_STATUS: i32 = 10; + #[cfg(target_os = "linux")] + #[pyattr] + const CAN_BCM_RX_TIMEOUT: i32 = 11; + #[cfg(target_os = "linux")] + #[pyattr] + const CAN_BCM_RX_CHANGED: i32 = 12; + + // CAN BCM flags (linux/can/bcm.h) + #[cfg(target_os = "linux")] + #[pyattr] + const CAN_BCM_SETTIMER: i32 = 0x0001; + #[cfg(target_os = "linux")] + #[pyattr] + const CAN_BCM_STARTTIMER: i32 = 0x0002; + #[cfg(target_os = "linux")] + #[pyattr] + const CAN_BCM_TX_COUNTEVT: i32 = 0x0004; + #[cfg(target_os = "linux")] + #[pyattr] + const CAN_BCM_TX_ANNOUNCE: i32 = 0x0008; + #[cfg(target_os = "linux")] + #[pyattr] + const CAN_BCM_TX_CP_CAN_ID: i32 = 0x0010; + #[cfg(target_os = "linux")] + #[pyattr] + const CAN_BCM_RX_FILTER_ID: i32 = 0x0020; + #[cfg(target_os = "linux")] + #[pyattr] + const CAN_BCM_RX_CHECK_DLC: i32 = 0x0040; + #[cfg(target_os = "linux")] + #[pyattr] + const CAN_BCM_RX_NO_AUTOTIMER: i32 = 0x0080; + #[cfg(target_os = "linux")] + #[pyattr] + const CAN_BCM_RX_ANNOUNCE_RESUME: i32 = 0x0100; + #[cfg(target_os = "linux")] + #[pyattr] + const CAN_BCM_TX_RESET_MULTI_IDX: i32 = 0x0200; + #[cfg(target_os = "linux")] + #[pyattr] + const CAN_BCM_RX_RTR_FRAME: i32 = 0x0400; + #[cfg(target_os = "linux")] + #[pyattr] + const CAN_BCM_CAN_FD_FRAME: i32 = 0x0800; + #[cfg(all(target_os = "linux", target_env = "gnu"))] #[pyattr] use c::SOL_RDS; @@ -150,6 +231,48 @@ mod _socket { #[pyattr] use c::{AF_SYSTEM, PF_SYSTEM, SYSPROTO_CONTROL, TCP_KEEPALIVE}; + // RFC3542 IPv6 socket options for macOS (netinet6/in6.h) + // Not available in libc, define manually + #[cfg(target_vendor = "apple")] + #[pyattr] + const IPV6_RECVHOPLIMIT: i32 = 37; + #[cfg(target_vendor = "apple")] + #[pyattr] + const IPV6_RECVRTHDR: i32 = 38; + #[cfg(target_vendor = "apple")] + #[pyattr] + const IPV6_RECVHOPOPTS: i32 = 39; + #[cfg(target_vendor = "apple")] + #[pyattr] + const IPV6_RECVDSTOPTS: i32 = 40; + #[cfg(target_vendor = "apple")] + #[pyattr] + const IPV6_USE_MIN_MTU: i32 = 42; + #[cfg(target_vendor = "apple")] + #[pyattr] + const IPV6_RECVPATHMTU: i32 = 43; + #[cfg(target_vendor = "apple")] + #[pyattr] + const IPV6_PATHMTU: i32 = 44; + #[cfg(target_vendor = "apple")] + #[pyattr] + const IPV6_NEXTHOP: i32 = 48; + #[cfg(target_vendor = "apple")] + #[pyattr] + const IPV6_HOPOPTS: i32 = 49; + #[cfg(target_vendor = "apple")] + #[pyattr] + const IPV6_DSTOPTS: i32 = 50; + #[cfg(target_vendor = "apple")] + #[pyattr] + const IPV6_RTHDR: i32 = 51; + #[cfg(target_vendor = "apple")] + #[pyattr] + const IPV6_RTHDRDSTOPTS: i32 = 57; + #[cfg(target_vendor = "apple")] + #[pyattr] + const IPV6_RTHDR_TYPE_0: i32 = 0; + #[cfg(windows)] #[pyattr] use c::{ @@ -415,6 +538,7 @@ mod _socket { target_os = "dragonfly", target_os = "freebsd", target_os = "linux", + target_vendor = "apple", windows ))] #[pyattr] @@ -783,6 +907,21 @@ mod _socket { .map(|sock| sock as RawSocket) } + #[cfg(target_os = "linux")] + #[derive(FromArgs)] + struct SendmsgAfalgArgs { + #[pyarg(any, default)] + msg: Vec, + #[pyarg(named)] + op: u32, + #[pyarg(named, default)] + iv: Option, + #[pyarg(named, default)] + assoclen: OptionalArg, + #[pyarg(named, default)] + flags: i32, + } + #[pyattr(name = "socket")] #[pyattr(name = "SocketType")] #[pyclass(name = "socket")] @@ -849,10 +988,64 @@ mod _socket { sock: Socket, ) -> io::Result<()> { self.family.store(family); - self.kind.store(socket_kind); + // Mask out SOCK_NONBLOCK and SOCK_CLOEXEC flags from stored type + // to ensure consistent cross-platform behavior + #[cfg(any( + target_os = "android", + target_os = "dragonfly", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "illumos", + target_os = "linux", + target_os = "netbsd", + target_os = "openbsd", + target_os = "redox" + ))] + let masked_kind = socket_kind & !(c::SOCK_NONBLOCK | c::SOCK_CLOEXEC); + #[cfg(not(any( + target_os = "android", + target_os = "dragonfly", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "illumos", + target_os = "linux", + target_os = "netbsd", + target_os = "openbsd", + target_os = "redox" + )))] + let masked_kind = socket_kind; + self.kind.store(masked_kind); self.proto.store(proto); let mut s = self.sock.write(); let sock = s.insert(sock); + // If SOCK_NONBLOCK is set, use timeout 0 (non-blocking) + #[cfg(any( + target_os = "android", + target_os = "dragonfly", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "illumos", + target_os = "linux", + target_os = "netbsd", + target_os = "openbsd", + target_os = "redox" + ))] + let timeout = if socket_kind & c::SOCK_NONBLOCK != 0 { + 0.0 + } else { + DEFAULT_TIMEOUT.load() + }; + #[cfg(not(any( + target_os = "android", + target_os = "dragonfly", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "illumos", + target_os = "linux", + target_os = "netbsd", + target_os = "openbsd", + target_os = "redox" + )))] let timeout = DEFAULT_TIMEOUT.load(); self.timeout.store(timeout); if timeout >= 0.0 { @@ -996,6 +1189,129 @@ mod _socket { } Ok(addr6.into()) } + #[cfg(target_os = "linux")] + c::AF_CAN => { + let tuple: PyTupleRef = addr.downcast().map_err(|obj| { + vm.new_type_error(format!( + "{}(): AF_CAN address must be tuple, not {}", + caller, + obj.class().name() + )) + })?; + if tuple.is_empty() || tuple.len() > 2 { + return Err(vm + .new_type_error( + "AF_CAN address must be a tuple (interface,) or (interface, addr)", + ) + .into()); + } + let interface: PyStrRef = tuple[0].clone().downcast().map_err(|obj| { + vm.new_type_error(format!( + "{}(): AF_CAN interface must be str, not {}", + caller, + obj.class().name() + )) + })?; + let ifname = interface.as_str(); + + // Get interface index + let ifindex = if ifname.is_empty() { + 0 // Bind to all CAN interfaces + } else { + // Check interface name length (IFNAMSIZ is typically 16) + if ifname.len() >= 16 { + return Err(vm + .new_os_error("interface name too long".to_owned()) + .into()); + } + let cstr = std::ffi::CString::new(ifname) + .map_err(|_| vm.new_os_error("invalid interface name".to_owned()))?; + let idx = unsafe { libc::if_nametoindex(cstr.as_ptr()) }; + if idx == 0 { + return Err(io::Error::last_os_error().into()); + } + idx as i32 + }; + + // Create sockaddr_can + let mut storage: libc::sockaddr_storage = unsafe { std::mem::zeroed() }; + let can_addr = + &mut storage as *mut libc::sockaddr_storage as *mut libc::sockaddr_can; + unsafe { + (*can_addr).can_family = libc::AF_CAN as libc::sa_family_t; + (*can_addr).can_ifindex = ifindex; + } + let storage: socket2::SockAddrStorage = unsafe { std::mem::transmute(storage) }; + Ok(unsafe { + socket2::SockAddr::new( + storage, + std::mem::size_of::() as libc::socklen_t, + ) + }) + } + #[cfg(target_os = "linux")] + c::AF_ALG => { + let tuple: PyTupleRef = addr.downcast().map_err(|obj| { + vm.new_type_error(format!( + "{}(): AF_ALG address must be tuple, not {}", + caller, + obj.class().name() + )) + })?; + if tuple.len() != 2 { + return Err(vm + .new_type_error("AF_ALG address must be a tuple (type, name)") + .into()); + } + let alg_type: PyStrRef = tuple[0].clone().downcast().map_err(|obj| { + vm.new_type_error(format!( + "{}(): AF_ALG type must be str, not {}", + caller, + obj.class().name() + )) + })?; + let alg_name: PyStrRef = tuple[1].clone().downcast().map_err(|obj| { + vm.new_type_error(format!( + "{}(): AF_ALG name must be str, not {}", + caller, + obj.class().name() + )) + })?; + + let type_str = alg_type.as_str(); + let name_str = alg_name.as_str(); + + // salg_type is 14 bytes, salg_name is 64 bytes + if type_str.len() >= 14 { + return Err(vm.new_value_error("type too long".to_owned()).into()); + } + if name_str.len() >= 64 { + return Err(vm.new_value_error("name too long".to_owned()).into()); + } + + // Create sockaddr_alg + let mut storage: libc::sockaddr_storage = unsafe { std::mem::zeroed() }; + let alg_addr = + &mut storage as *mut libc::sockaddr_storage as *mut libc::sockaddr_alg; + unsafe { + (*alg_addr).salg_family = libc::AF_ALG as libc::sa_family_t; + // Copy type string + for (i, b) in type_str.bytes().enumerate() { + (*alg_addr).salg_type[i] = b; + } + // Copy name string + for (i, b) in name_str.bytes().enumerate() { + (*alg_addr).salg_name[i] = b; + } + } + let storage: socket2::SockAddrStorage = unsafe { std::mem::transmute(storage) }; + Ok(unsafe { + socket2::SockAddr::new( + storage, + std::mem::size_of::() as libc::socklen_t, + ) + }) + } _ => Err(vm.new_os_error(format!("{caller}(): bad family")).into()), } } @@ -1083,11 +1399,90 @@ mod _socket { let mut family = family.unwrap_or(-1); let mut socket_kind = socket_kind.unwrap_or(-1); let mut proto = proto.unwrap_or(-1); + + let sock; + + // On Windows, fileno can be bytes from socket.share() for fromshare() + #[cfg(windows)] + if let Some(fileno_obj) = fileno.flatten() { + use crate::vm::builtins::PyBytes; + if let Ok(bytes) = fileno_obj.clone().downcast::() { + let bytes_data = bytes.as_bytes(); + let expected_size = std::mem::size_of::(); + + if bytes_data.len() != expected_size { + return Err(vm + .new_value_error(format!( + "socket descriptor string has wrong size, should be {} bytes", + expected_size + )) + .into()); + } + + let mut info: c::WSAPROTOCOL_INFOW = unsafe { std::mem::zeroed() }; + unsafe { + std::ptr::copy_nonoverlapping( + bytes_data.as_ptr(), + &mut info as *mut c::WSAPROTOCOL_INFOW as *mut u8, + expected_size, + ); + } + + let fd = unsafe { + c::WSASocketW( + c::FROM_PROTOCOL_INFO, + c::FROM_PROTOCOL_INFO, + c::FROM_PROTOCOL_INFO, + &info, + 0, + c::WSA_FLAG_OVERLAPPED, + ) + }; + + if fd == c::INVALID_SOCKET { + return Err(Self::wsa_error().into()); + } + + crate::vm::stdlib::nt::raw_set_handle_inheritable(fd as _, false)?; + + family = info.iAddressFamily; + socket_kind = info.iSocketType; + proto = info.iProtocol; + + sock = unsafe { sock_from_raw_unchecked(fd as RawSocket) }; + return Ok(zelf.init_inner(family, socket_kind, proto, sock)?); + } + + // Not bytes, treat as regular fileno + let fileno = get_raw_sock(fileno_obj, vm)?; + sock = sock_from_raw(fileno, vm)?; + match sock.local_addr() { + Ok(addr) if family == -1 => family = addr.family() as i32, + Err(e) + if family == -1 + || matches!( + e.raw_os_error(), + Some(errcode!(ENOTSOCK)) | Some(errcode!(EBADF)) + ) => + { + std::mem::forget(sock); + return Err(e.into()); + } + _ => {} + } + if socket_kind == -1 { + socket_kind = sock.r#type().map_err(|e| e.into_pyexception(vm))?.into(); + } + proto = 0; + return Ok(zelf.init_inner(family, socket_kind, proto, sock)?); + } + + #[cfg(not(windows))] let fileno = fileno .flatten() .map(|obj| get_raw_sock(obj, vm)) .transpose()?; - let sock; + #[cfg(not(windows))] if let Some(fileno) = fileno { sock = sock_from_raw(fileno, vm)?; match sock.local_addr() { @@ -1121,7 +1516,11 @@ mod _socket { proto = 0; } } - } else { + return Ok(zelf.init_inner(family, socket_kind, proto, sock)?); + } + + // No fileno provided, create new socket + { if family == -1 { family = c::AF_INET as _ } @@ -1403,6 +1802,248 @@ mod _socket { .map_err(|e| e.into_pyexception(vm)) } + /// sendmsg_afalg([msg], *, op[, iv[, assoclen[, flags]]]) -> int + /// + /// Set operation mode and target IV for an AF_ALG socket. + #[cfg(target_os = "linux")] + #[pymethod] + fn sendmsg_afalg(&self, args: SendmsgAfalgArgs, vm: &VirtualMachine) -> PyResult { + let msg = args.msg; + let op = args.op; + let iv = args.iv; + let flags = args.flags; + + // Validate assoclen - must be non-negative if provided + let assoclen: Option = match args.assoclen { + OptionalArg::Present(val) if val < 0 => { + return Err(vm.new_type_error("assoclen must be non-negative".to_owned())); + } + OptionalArg::Present(val) => Some(val as u32), + OptionalArg::Missing => None, + }; + + // Build control messages for AF_ALG + let mut control_buf = Vec::new(); + + // Add ALG_SET_OP control message + { + let op_bytes = op.to_ne_bytes(); + let space = unsafe { libc::CMSG_SPACE(std::mem::size_of::() as u32) } as usize; + let old_len = control_buf.len(); + control_buf.resize(old_len + space, 0u8); + + let cmsg = control_buf[old_len..].as_mut_ptr() as *mut libc::cmsghdr; + unsafe { + (*cmsg).cmsg_len = libc::CMSG_LEN(std::mem::size_of::() as u32) as _; + (*cmsg).cmsg_level = libc::SOL_ALG; + (*cmsg).cmsg_type = libc::ALG_SET_OP; + let data = libc::CMSG_DATA(cmsg); + std::ptr::copy_nonoverlapping(op_bytes.as_ptr(), data, op_bytes.len()); + } + } + + // Add ALG_SET_IV control message if iv is provided + if let Some(iv_data) = iv { + let iv_bytes = iv_data.borrow_buf(); + // struct af_alg_iv { __u32 ivlen; __u8 iv[]; } + let iv_struct_size = 4 + iv_bytes.len(); + let space = unsafe { libc::CMSG_SPACE(iv_struct_size as u32) } as usize; + let old_len = control_buf.len(); + control_buf.resize(old_len + space, 0u8); + + let cmsg = control_buf[old_len..].as_mut_ptr() as *mut libc::cmsghdr; + unsafe { + (*cmsg).cmsg_len = libc::CMSG_LEN(iv_struct_size as u32) as _; + (*cmsg).cmsg_level = libc::SOL_ALG; + (*cmsg).cmsg_type = libc::ALG_SET_IV; + let data = libc::CMSG_DATA(cmsg); + // Write ivlen + let ivlen = (iv_bytes.len() as u32).to_ne_bytes(); + std::ptr::copy_nonoverlapping(ivlen.as_ptr(), data, 4); + // Write iv + std::ptr::copy_nonoverlapping(iv_bytes.as_ptr(), data.add(4), iv_bytes.len()); + } + } + + // Add ALG_SET_AEAD_ASSOCLEN control message if assoclen is provided + if let Some(assoclen_val) = assoclen { + let assoclen_bytes = assoclen_val.to_ne_bytes(); + let space = unsafe { libc::CMSG_SPACE(std::mem::size_of::() as u32) } as usize; + let old_len = control_buf.len(); + control_buf.resize(old_len + space, 0u8); + + let cmsg = control_buf[old_len..].as_mut_ptr() as *mut libc::cmsghdr; + unsafe { + (*cmsg).cmsg_len = libc::CMSG_LEN(std::mem::size_of::() as u32) as _; + (*cmsg).cmsg_level = libc::SOL_ALG; + (*cmsg).cmsg_type = libc::ALG_SET_AEAD_ASSOCLEN; + let data = libc::CMSG_DATA(cmsg); + std::ptr::copy_nonoverlapping( + assoclen_bytes.as_ptr(), + data, + assoclen_bytes.len(), + ); + } + } + + // Build buffers + let buffers = msg.iter().map(|buf| buf.borrow_buf()).collect::>(); + let iovecs: Vec = buffers + .iter() + .map(|buf| libc::iovec { + iov_base: buf.as_ptr() as *mut _, + iov_len: buf.len(), + }) + .collect(); + + // Set up msghdr + let mut msghdr: libc::msghdr = unsafe { std::mem::zeroed() }; + msghdr.msg_iov = iovecs.as_ptr() as *mut _; + msghdr.msg_iovlen = iovecs.len() as _; + if !control_buf.is_empty() { + msghdr.msg_control = control_buf.as_mut_ptr() as *mut _; + msghdr.msg_controllen = control_buf.len() as _; + } + + self.sock_op(vm, SelectKind::Write, || { + let sock = self.sock()?; + let fd = sock_fileno(&sock); + let ret = unsafe { libc::sendmsg(fd as libc::c_int, &msghdr, flags) }; + if ret < 0 { + Err(io::Error::last_os_error()) + } else { + Ok(ret as usize) + } + }) + .map_err(|e| e.into_pyexception(vm)) + } + + /// recvmsg(bufsize[, ancbufsize[, flags]]) -> (data, ancdata, msg_flags, address) + /// + /// Receive normal data and ancillary data from the socket. + #[cfg(all(unix, not(target_os = "redox")))] + #[pymethod] + fn recvmsg( + &self, + bufsize: isize, + ancbufsize: OptionalArg, + flags: OptionalArg, + vm: &VirtualMachine, + ) -> PyResult { + use std::mem::MaybeUninit; + + if bufsize < 0 { + return Err(vm.new_value_error("negative buffer size in recvmsg".to_owned())); + } + let bufsize = bufsize as usize; + + let ancbufsize = ancbufsize.unwrap_or(0); + if ancbufsize < 0 { + return Err( + vm.new_value_error("negative ancillary buffer size in recvmsg".to_owned()) + ); + } + let ancbufsize = ancbufsize as usize; + let flags = flags.unwrap_or(0); + + // Allocate buffers + let mut data_buf: Vec> = vec![MaybeUninit::uninit(); bufsize]; + let mut anc_buf: Vec> = vec![MaybeUninit::uninit(); ancbufsize]; + let mut addr_storage: libc::sockaddr_storage = unsafe { std::mem::zeroed() }; + + // Set up iovec + let mut iov = [libc::iovec { + iov_base: data_buf.as_mut_ptr().cast(), + iov_len: bufsize, + }]; + + // Set up msghdr + let mut msg: libc::msghdr = unsafe { std::mem::zeroed() }; + msg.msg_name = (&mut addr_storage as *mut libc::sockaddr_storage).cast(); + msg.msg_namelen = std::mem::size_of::() as libc::socklen_t; + msg.msg_iov = iov.as_mut_ptr(); + msg.msg_iovlen = 1; + if ancbufsize > 0 { + msg.msg_control = anc_buf.as_mut_ptr().cast(); + msg.msg_controllen = ancbufsize as _; + } + + let n = self + .sock_op(vm, SelectKind::Read, || { + let sock = self.sock()?; + let fd = sock_fileno(&sock); + let ret = unsafe { libc::recvmsg(fd as libc::c_int, &mut msg, flags) }; + if ret < 0 { + Err(io::Error::last_os_error()) + } else { + Ok(ret as usize) + } + }) + .map_err(|e| e.into_pyexception(vm))?; + + // Build data bytes + let data = unsafe { + data_buf.set_len(n); + std::mem::transmute::>, Vec>(data_buf) + }; + + // Build ancdata list + let ancdata = Self::parse_ancillary_data(&msg, vm)?; + + // Build address tuple + let address = if msg.msg_namelen > 0 { + let storage: socket2::SockAddrStorage = + unsafe { std::mem::transmute(addr_storage) }; + let addr = unsafe { socket2::SockAddr::new(storage, msg.msg_namelen) }; + get_addr_tuple(&addr, vm) + } else { + vm.ctx.none() + }; + + Ok(vm.ctx.new_tuple(vec![ + vm.ctx.new_bytes(data).into(), + ancdata, + vm.ctx.new_int(msg.msg_flags).into(), + address, + ])) + } + + /// Parse ancillary data from a received message header + #[cfg(all(unix, not(target_os = "redox")))] + fn parse_ancillary_data(msg: &libc::msghdr, vm: &VirtualMachine) -> PyResult { + let mut result = Vec::new(); + + // Calculate buffer end for truncation handling + let ctrl_buf = msg.msg_control as *const u8; + let ctrl_end = unsafe { ctrl_buf.add(msg.msg_controllen as _) }; + + let mut cmsg: *mut libc::cmsghdr = unsafe { libc::CMSG_FIRSTHDR(msg) }; + while !cmsg.is_null() { + let cmsg_ref = unsafe { &*cmsg }; + let data_ptr = unsafe { libc::CMSG_DATA(cmsg) }; + + // Calculate data length, respecting buffer truncation + let data_len_from_cmsg = + cmsg_ref.cmsg_len as usize - (data_ptr as usize - cmsg as usize); + let available = ctrl_end as usize - data_ptr as usize; + let data_len = data_len_from_cmsg.min(available); + + let data = unsafe { std::slice::from_raw_parts(data_ptr, data_len) }; + + let tuple = vm.ctx.new_tuple(vec![ + vm.ctx.new_int(cmsg_ref.cmsg_level).into(), + vm.ctx.new_int(cmsg_ref.cmsg_type).into(), + vm.ctx.new_bytes(data.to_vec()).into(), + ]); + + result.push(tuple.into()); + + cmsg = unsafe { libc::CMSG_NXTHDR(msg, cmsg) }; + } + + Ok(vm.ctx.new_list(result).into()) + } + // based on nix's implementation #[cfg(all(unix, not(target_os = "redox")))] fn pack_cmsgs_to_send( @@ -1447,7 +2088,7 @@ mod _socket { unsafe { (*pmhdr).cmsg_level = *lvl; (*pmhdr).cmsg_type = *typ; - (*pmhdr).cmsg_len = data.len() as _; + (*pmhdr).cmsg_len = libc::CMSG_LEN(data.len() as _) as _; ptr::copy_nonoverlapping(data.as_ptr(), libc::CMSG_DATA(pmhdr), data.len()); } @@ -1467,6 +2108,38 @@ mod _socket { Ok(()) } + #[pymethod] + fn __del__(&self, vm: &VirtualMachine) { + // Emit ResourceWarning if socket is still open + if self.sock.read().is_some() { + let laddr = if let Ok(sock) = self.sock() + && let Ok(addr) = sock.local_addr() + && let Ok(repr) = get_addr_tuple(&addr, vm).repr(vm) + { + format!(", laddr={}", repr.as_str()) + } else { + String::new() + }; + + let msg = format!( + "unclosed ", + self.fileno(), + self.family.load(), + self.kind.load(), + self.proto.load(), + laddr + ); + let _ = crate::vm::warn::warn( + vm.ctx.new_str(msg), + Some(vm.ctx.exceptions.resource_warning.to_owned()), + 1, + None, + vm, + ); + } + let _ = self.close(); + } + #[pymethod] #[inline] fn detach(&self) -> i64 { @@ -1629,6 +2302,136 @@ mod _socket { Ok(self.sock()?.shutdown(how)?) } + #[cfg(windows)] + fn wsa_error() -> io::Error { + io::Error::from_raw_os_error(unsafe { c::WSAGetLastError() }) + } + + #[cfg(windows)] + #[pymethod] + fn ioctl( + &self, + cmd: PyObjectRef, + option: PyObjectRef, + vm: &VirtualMachine, + ) -> Result { + use crate::vm::builtins::PyInt; + use crate::vm::convert::TryFromObject; + + let sock = self.sock()?; + let fd = sock_fileno(&sock); + let mut recv: u32 = 0; + + // Convert cmd to u32, returning ValueError for invalid/negative values + let cmd_int = cmd + .downcast::() + .map_err(|_| vm.new_type_error("an integer is required"))?; + let cmd_val = cmd_int.as_bigint(); + let cmd: u32 = cmd_val + .to_u32() + .ok_or_else(|| vm.new_value_error(format!("invalid ioctl command {}", cmd_val)))?; + + match cmd { + c::SIO_RCVALL | c::SIO_LOOPBACK_FAST_PATH => { + // Option must be an integer, not None + if vm.is_none(&option) { + return Err(vm + .new_type_error("an integer is required (got type NoneType)") + .into()); + } + let option_val: u32 = TryFromObject::try_from_object(vm, option)?; + let ret = unsafe { + c::WSAIoctl( + fd as _, + cmd, + &option_val as *const u32 as *const _, + std::mem::size_of::() as u32, + std::ptr::null_mut(), + 0, + &mut recv, + std::ptr::null_mut(), + None, + ) + }; + if ret == c::SOCKET_ERROR { + return Err(Self::wsa_error().into()); + } + Ok(recv) + } + c::SIO_KEEPALIVE_VALS => { + let tuple: PyTupleRef = option + .downcast() + .map_err(|_| vm.new_type_error("SIO_KEEPALIVE_VALS requires a tuple"))?; + if tuple.len() != 3 { + return Err(vm + .new_type_error( + "SIO_KEEPALIVE_VALS requires (onoff, keepalivetime, keepaliveinterval)", + ) + .into()); + } + + #[repr(C)] + struct TcpKeepalive { + onoff: u32, + keepalivetime: u32, + keepaliveinterval: u32, + } + + let ka = TcpKeepalive { + onoff: TryFromObject::try_from_object(vm, tuple[0].clone())?, + keepalivetime: TryFromObject::try_from_object(vm, tuple[1].clone())?, + keepaliveinterval: TryFromObject::try_from_object(vm, tuple[2].clone())?, + }; + + let ret = unsafe { + c::WSAIoctl( + fd as _, + cmd, + &ka as *const TcpKeepalive as *const _, + std::mem::size_of::() as u32, + std::ptr::null_mut(), + 0, + &mut recv, + std::ptr::null_mut(), + None, + ) + }; + if ret == c::SOCKET_ERROR { + return Err(Self::wsa_error().into()); + } + Ok(recv) + } + _ => Err(vm + .new_value_error(format!("invalid ioctl command {}", cmd)) + .into()), + } + } + + #[cfg(windows)] + #[pymethod] + fn share(&self, process_id: u32, _vm: &VirtualMachine) -> Result, IoOrPyException> { + let sock = self.sock()?; + let fd = sock_fileno(&sock); + + let mut info: MaybeUninit = MaybeUninit::uninit(); + + let ret = unsafe { c::WSADuplicateSocketW(fd as _, process_id, info.as_mut_ptr()) }; + + if ret == c::SOCKET_ERROR { + return Err(Self::wsa_error().into()); + } + + let info = unsafe { info.assume_init() }; + let bytes = unsafe { + std::slice::from_raw_parts( + &info as *const c::WSAPROTOCOL_INFOW as *const u8, + std::mem::size_of::(), + ) + }; + + Ok(bytes.to_vec()) + } + #[pygetset(name = "type")] fn kind(&self) -> i32 { self.kind.load() @@ -1729,6 +2532,50 @@ mod _socket { let path = ffi::OsStr::from_bytes(&path[..nul_pos]); return vm.fsdecode(path).into(); } + #[cfg(target_os = "linux")] + { + let family = addr.family(); + if family == libc::AF_CAN as libc::sa_family_t { + // AF_CAN address: (interface_name,) or (interface_name, can_id) + let can_addr = unsafe { &*(addr.as_ptr() as *const libc::sockaddr_can) }; + let ifindex = can_addr.can_ifindex; + let ifname = if ifindex == 0 { + String::new() + } else { + let mut buf = [0u8; libc::IF_NAMESIZE]; + let ret = unsafe { + libc::if_indextoname( + ifindex as libc::c_uint, + buf.as_mut_ptr() as *mut libc::c_char, + ) + }; + if ret.is_null() { + String::new() + } else { + let nul_pos = memchr::memchr(b'\0', &buf).unwrap_or(buf.len()); + String::from_utf8_lossy(&buf[..nul_pos]).into_owned() + } + }; + return vm.ctx.new_tuple(vec![vm.ctx.new_str(ifname).into()]).into(); + } + if family == libc::AF_ALG as libc::sa_family_t { + // AF_ALG address: (type, name) + let alg_addr = unsafe { &*(addr.as_ptr() as *const libc::sockaddr_alg) }; + let type_bytes = &alg_addr.salg_type; + let name_bytes = &alg_addr.salg_name; + let type_nul = memchr::memchr(b'\0', type_bytes).unwrap_or(type_bytes.len()); + let name_nul = memchr::memchr(b'\0', name_bytes).unwrap_or(name_bytes.len()); + let type_str = String::from_utf8_lossy(&type_bytes[..type_nul]).into_owned(); + let name_str = String::from_utf8_lossy(&name_bytes[..name_nul]).into_owned(); + return vm + .ctx + .new_tuple(vec![ + vm.ctx.new_str(type_str).into(), + vm.ctx.new_str(name_str).into(), + ]) + .into(); + } + } // TODO: support more address families (String::new(), 0).to_pyobject(vm) } @@ -1951,13 +2798,31 @@ mod _socket { flags: opts.flags, }; - let host = opts.host.as_ref().map(|s| s.as_str()); - let port = opts.port.as_ref().map(|p| -> std::borrow::Cow<'_, str> { - match p { - Either::A(s) => s.as_str().into(), - Either::B(i) => i.to_string().into(), + // Encode host using IDNA encoding + let host_encoded: Option = match opts.host.as_ref() { + Some(s) => { + let encoded = + vm.state + .codec_registry + .encode_text(s.to_owned(), "idna", None, vm)?; + let host_str = std::str::from_utf8(encoded.as_bytes()) + .map_err(|_| vm.new_runtime_error("idna output is not utf8".to_owned()))?; + Some(host_str.to_owned()) } - }); + None => None, + }; + let host = host_encoded.as_deref(); + + // Encode port using UTF-8 + let port: Option> = match opts.port.as_ref() { + Some(Either::A(s)) => { + Some(std::borrow::Cow::Borrowed(s.to_str().ok_or_else(|| { + vm.new_unicode_encode_error("surrogates not allowed".to_owned()) + })?)) + } + Some(Either::B(i)) => Some(std::borrow::Cow::Owned(i.to_string())), + None => None, + }; let port = port.as_ref().map(|p| p.as_ref()); let addrs = dns_lookup::getaddrinfo(host, port, Some(hints))