extmod/modtls_mbedtls: Add a thread-global ptr for current SSL context.

This is necessary for mbedTLS callbacks that do not carry any user state,
so those callbacks can be customised per SSL context.

Signed-off-by: iabdalkader <i.abdalkader@gmail.com>
This commit is contained in:
iabdalkader 2024-10-16 14:08:43 +02:00 committed by Damien George
parent 09ea901317
commit 2644f577f1
3 changed files with 28 additions and 0 deletions

View File

@ -166,6 +166,13 @@ static NORETURN void mbedtls_raise_error(int err) {
#endif #endif
} }
// Stores the current SSLContext for use in mbedtls callbacks where the current state is not passed.
static inline void store_active_context(mp_obj_ssl_context_t *ssl_context) {
#if MICROPY_PY_SSL_MBEDTLS_NEED_ACTIVE_CONTEXT
MP_STATE_THREAD(tls_ssl_context) = ssl_context;
#endif
}
static void ssl_check_async_handshake_failure(mp_obj_ssl_socket_t *sslsock, int *errcode) { static void ssl_check_async_handshake_failure(mp_obj_ssl_socket_t *sslsock, int *errcode) {
if ( if (
#if MBEDTLS_VERSION_NUMBER >= 0x03000000 #if MBEDTLS_VERSION_NUMBER >= 0x03000000
@ -497,6 +504,9 @@ static int _mbedtls_ssl_recv(void *ctx, byte *buf, size_t len) {
static mp_obj_t ssl_socket_make_new(mp_obj_ssl_context_t *ssl_context, mp_obj_t sock, static mp_obj_t ssl_socket_make_new(mp_obj_ssl_context_t *ssl_context, mp_obj_t sock,
bool server_side, bool do_handshake_on_connect, mp_obj_t server_hostname) { bool server_side, bool do_handshake_on_connect, mp_obj_t server_hostname) {
// Store the current SSL context.
store_active_context(ssl_context);
// Verify the socket object has the full stream protocol // Verify the socket object has the full stream protocol
mp_get_stream_raise(sock, MP_STREAM_OP_READ | MP_STREAM_OP_WRITE | MP_STREAM_OP_IOCTL); mp_get_stream_raise(sock, MP_STREAM_OP_READ | MP_STREAM_OP_WRITE | MP_STREAM_OP_IOCTL);
@ -602,6 +612,9 @@ static mp_uint_t socket_read(mp_obj_t o_in, void *buf, mp_uint_t size, int *errc
return MP_STREAM_ERROR; return MP_STREAM_ERROR;
} }
// Store the current SSL context.
store_active_context(o->ssl_context);
int ret = mbedtls_ssl_read(&o->ssl, buf, size); int ret = mbedtls_ssl_read(&o->ssl, buf, size);
if (ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) { if (ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) {
// end of stream // end of stream
@ -643,6 +656,9 @@ static mp_uint_t socket_write(mp_obj_t o_in, const void *buf, mp_uint_t size, in
return MP_STREAM_ERROR; return MP_STREAM_ERROR;
} }
// Store the current SSL context.
store_active_context(o->ssl_context);
int ret = mbedtls_ssl_write(&o->ssl, buf, size); int ret = mbedtls_ssl_write(&o->ssl, buf, size);
if (ret >= 0) { if (ret >= 0) {
return ret; return ret;
@ -680,6 +696,9 @@ static mp_uint_t socket_ioctl(mp_obj_t o_in, mp_uint_t request, uintptr_t arg, i
mp_obj_t sock = self->sock; mp_obj_t sock = self->sock;
if (request == MP_STREAM_CLOSE) { if (request == MP_STREAM_CLOSE) {
// Clear the SSL context.
store_active_context(NULL);
if (sock == MP_OBJ_NULL) { if (sock == MP_OBJ_NULL) {
// Already closed socket, do nothing. // Already closed socket, do nothing.
return 0; return 0;

View File

@ -1814,6 +1814,11 @@ typedef double mp_float_t;
#define MICROPY_PY_SSL_FINALISER (MICROPY_ENABLE_FINALISER) #define MICROPY_PY_SSL_FINALISER (MICROPY_ENABLE_FINALISER)
#endif #endif
// Whether to add a root pointer for the current ssl object
#ifndef MICROPY_PY_SSL_MBEDTLS_NEED_ACTIVE_CONTEXT
#define MICROPY_PY_SSL_MBEDTLS_NEED_ACTIVE_CONTEXT (MICROPY_PY_SSL_ECDSA_SIGN_ALT)
#endif
// Whether to provide the "vfs" module // Whether to provide the "vfs" module
#ifndef MICROPY_PY_VFS #ifndef MICROPY_PY_VFS
#define MICROPY_PY_VFS (MICROPY_CONFIG_ROM_LEVEL_AT_LEAST_CORE_FEATURES && MICROPY_VFS) #define MICROPY_PY_VFS (MICROPY_CONFIG_ROM_LEVEL_AT_LEAST_CORE_FEATURES && MICROPY_VFS)

View File

@ -293,6 +293,10 @@ typedef struct _mp_state_thread_t {
bool prof_callback_is_executing; bool prof_callback_is_executing;
struct _mp_code_state_t *current_code_state; struct _mp_code_state_t *current_code_state;
#endif #endif
#if MICROPY_PY_SSL_MBEDTLS_NEED_ACTIVE_CONTEXT
struct _mp_obj_ssl_context_t *tls_ssl_context;
#endif
} mp_state_thread_t; } mp_state_thread_t;
// This structure combines the above 3 structures. // This structure combines the above 3 structures.