extmod/modlwip: Make socket.connect raise ETIMEDOUT on non-zero timeout.

If the socket timeout is 0 then a failed socket.connect() raises
EINPROGRESS (which is what the lwIP bindings already did), but if the
socket timeout is non-zero then a failed socket.connect() should raise
ETIMEDOUT.  The latter is fixed in this commit.

A test is added for these timeout cases.

Signed-off-by: Damien George <damien@micropython.org>
This commit is contained in:
Damien George 2024-06-07 13:33:47 +10:00
parent 80a4f632ee
commit df0d7e9429
2 changed files with 40 additions and 12 deletions

View File

@ -326,6 +326,10 @@ typedef struct _lwip_socket_obj_t {
int8_t state; int8_t state;
} lwip_socket_obj_t; } lwip_socket_obj_t;
static inline bool socket_is_timedout(lwip_socket_obj_t *socket, mp_uint_t ticks_start) {
return socket->timeout != -1 && (mp_uint_t)(mp_hal_ticks_ms() - ticks_start) >= socket->timeout;
}
static inline void poll_sockets(void) { static inline void poll_sockets(void) {
mp_event_wait_ms(1); mp_event_wait_ms(1);
} }
@ -1130,21 +1134,21 @@ static mp_obj_t lwip_socket_connect(mp_obj_t self_in, mp_obj_t addr_in) {
MICROPY_PY_LWIP_EXIT MICROPY_PY_LWIP_EXIT
// And now we wait... // And now we wait...
if (socket->timeout != -1) { mp_uint_t ticks_start = mp_hal_ticks_ms();
for (mp_uint_t retries = socket->timeout / 100; retries--;) { for (;;) {
mp_hal_delay_ms(100); poll_sockets();
if (socket->state != STATE_CONNECTING) { if (socket->state != STATE_CONNECTING) {
break; break;
}
if (socket_is_timedout(socket, ticks_start)) {
if (socket->timeout == 0) {
mp_raise_OSError(MP_EINPROGRESS);
} else {
mp_raise_OSError(MP_ETIMEDOUT);
} }
} }
if (socket->state == STATE_CONNECTING) {
mp_raise_OSError(MP_EINPROGRESS);
}
} else {
while (socket->state == STATE_CONNECTING) {
poll_sockets();
}
} }
if (socket->state == STATE_CONNECTED) { if (socket->state == STATE_CONNECTED) {
err = ERR_OK; err = ERR_OK;
} else { } else {

View File

@ -0,0 +1,24 @@
# Test that socket.connect() on a socket with timeout raises EINPROGRESS or ETIMEDOUT appropriately.
import errno
import socket
def test(peer_addr, timeout, expected_exc):
s = socket.socket()
s.settimeout(timeout)
try:
s.connect(peer_addr)
print("OK")
except OSError as er:
print(er.args[0] in expected_exc)
s.close()
if __name__ == "__main__":
# This test needs an address that doesn't respond to TCP connections.
# 1.1.1.1:8000 seem to reliably timeout, so use that.
addr = socket.getaddrinfo("1.1.1.1", 8000)[0][-1]
test(addr, 0, (errno.EINPROGRESS,))
test(addr, 1, (errno.ETIMEDOUT, "timed out")) # CPython uses a string instead of errno