extmod/modssl: Add SSLContext class.
This commit adds the SSLContext class to the ssl module, and retains the
existing ssl.wrap_socket() function to maintain backwards compatibility.
CPython deprecated the ssl.wrap_socket() function since CPython 3.7 and
instead one should use ssl.SSLContext().wrap_socket(). This commit makes
that possible.
For the axtls implementation:
- ssl.SSLContext is added, although it doesn't hold much state because
axtls requires calling ssl_ctx_new() for each new socket
- ssl.SSLContext.wrap_socket() is added
- ssl.PROTOCOL_TLS_CLIENT and ssl.PROTOCOL_TLS_SERVER are added
For the mbedtls implementation:
- ssl.SSLContext is added, and holds most of the mbedtls state
- ssl.verify_mode is added (getter and setter)
- ssl.SSLContext.wrap_socket() is added
- ssl.PROTOCOL_TLS_CLIENT and ssl.PROTOCOL_TLS_SERVER are added
The signatures match CPython:
- SSLContext(protocol)
- SSLContext.wrap_socket(sock, *, server_side=False,
do_handshake_on_connect=True, server_hostname=None)
The existing ssl.wrap_socket() functions retain their existing signature.
Signed-off-by: Damien George <damien@micropython.org>
This commit is contained in:
parent
c2ea8b2f98
commit
e8a4c1dd53
@ -4,6 +4,7 @@
|
|||||||
* The MIT License (MIT)
|
* The MIT License (MIT)
|
||||||
*
|
*
|
||||||
* Copyright (c) 2015-2019 Paul Sokolovsky
|
* Copyright (c) 2015-2019 Paul Sokolovsky
|
||||||
|
* Copyright (c) 2023 Damien P. George
|
||||||
*
|
*
|
||||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
* of this software and associated documentation files (the "Software"), to deal
|
* of this software and associated documentation files (the "Software"), to deal
|
||||||
@ -35,6 +36,17 @@
|
|||||||
|
|
||||||
#include "ssl.h"
|
#include "ssl.h"
|
||||||
|
|
||||||
|
#define PROTOCOL_TLS_CLIENT (0)
|
||||||
|
#define PROTOCOL_TLS_SERVER (1)
|
||||||
|
|
||||||
|
// This corresponds to an SSLContext object.
|
||||||
|
typedef struct _mp_obj_ssl_context_t {
|
||||||
|
mp_obj_base_t base;
|
||||||
|
mp_obj_t key;
|
||||||
|
mp_obj_t cert;
|
||||||
|
} mp_obj_ssl_context_t;
|
||||||
|
|
||||||
|
// This corresponds to an SSLSocket object.
|
||||||
typedef struct _mp_obj_ssl_socket_t {
|
typedef struct _mp_obj_ssl_socket_t {
|
||||||
mp_obj_base_t base;
|
mp_obj_base_t base;
|
||||||
mp_obj_t sock;
|
mp_obj_t sock;
|
||||||
@ -53,8 +65,15 @@ struct ssl_args {
|
|||||||
mp_arg_val_t do_handshake;
|
mp_arg_val_t do_handshake;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
STATIC const mp_obj_type_t ssl_context_type;
|
||||||
STATIC const mp_obj_type_t ssl_socket_type;
|
STATIC const mp_obj_type_t ssl_socket_type;
|
||||||
|
|
||||||
|
STATIC mp_obj_t ssl_socket_make_new(mp_obj_ssl_context_t *ssl_context, mp_obj_t sock,
|
||||||
|
bool server_side, bool do_handshake_on_connect, mp_obj_t server_hostname);
|
||||||
|
|
||||||
|
/******************************************************************************/
|
||||||
|
// Helper functions.
|
||||||
|
|
||||||
// Table of error strings corresponding to SSL_xxx error codes.
|
// Table of error strings corresponding to SSL_xxx error codes.
|
||||||
STATIC const char *const ssl_error_tab1[] = {
|
STATIC const char *const ssl_error_tab1[] = {
|
||||||
"NOT_OK",
|
"NOT_OK",
|
||||||
@ -116,8 +135,71 @@ STATIC NORETURN void ssl_raise_error(int err) {
|
|||||||
nlr_raise(mp_obj_exception_make_new(&mp_type_OSError, 2, 0, args));
|
nlr_raise(mp_obj_exception_make_new(&mp_type_OSError, 2, 0, args));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/******************************************************************************/
|
||||||
|
// SSLContext type.
|
||||||
|
|
||||||
|
STATIC mp_obj_t ssl_context_make_new(const mp_obj_type_t *type_in, size_t n_args, size_t n_kw, const mp_obj_t *args) {
|
||||||
|
mp_arg_check_num(n_args, n_kw, 1, 1, false);
|
||||||
|
|
||||||
|
// The "protocol" argument is ignored in this implementation.
|
||||||
|
|
||||||
|
// Create SSLContext object.
|
||||||
|
#if MICROPY_PY_SSL_FINALISER
|
||||||
|
mp_obj_ssl_context_t *self = m_new_obj_with_finaliser(mp_obj_ssl_context_t);
|
||||||
|
#else
|
||||||
|
mp_obj_ssl_context_t *self = m_new_obj(mp_obj_ssl_context_t);
|
||||||
|
#endif
|
||||||
|
self->base.type = type_in;
|
||||||
|
self->key = mp_const_none;
|
||||||
|
self->cert = mp_const_none;
|
||||||
|
|
||||||
|
return MP_OBJ_FROM_PTR(self);
|
||||||
|
}
|
||||||
|
|
||||||
|
STATIC void ssl_context_load_key(mp_obj_ssl_context_t *self, mp_obj_t key_obj, mp_obj_t cert_obj) {
|
||||||
|
self->key = key_obj;
|
||||||
|
self->cert = cert_obj;
|
||||||
|
}
|
||||||
|
|
||||||
|
STATIC mp_obj_t ssl_context_wrap_socket(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args) {
|
||||||
|
enum { ARG_server_side, ARG_do_handshake_on_connect, ARG_server_hostname };
|
||||||
|
static const mp_arg_t allowed_args[] = {
|
||||||
|
{ MP_QSTR_server_side, MP_ARG_KW_ONLY | MP_ARG_BOOL, {.u_bool = false} },
|
||||||
|
{ MP_QSTR_do_handshake_on_connect, MP_ARG_KW_ONLY | MP_ARG_BOOL, {.u_bool = true} },
|
||||||
|
{ MP_QSTR_server_hostname, MP_ARG_KW_ONLY | MP_ARG_OBJ, {.u_rom_obj = MP_ROM_NONE} },
|
||||||
|
};
|
||||||
|
|
||||||
|
// Parse arguments.
|
||||||
|
mp_obj_ssl_context_t *self = MP_OBJ_TO_PTR(pos_args[0]);
|
||||||
|
mp_obj_t sock = pos_args[1];
|
||||||
|
mp_arg_val_t args[MP_ARRAY_SIZE(allowed_args)];
|
||||||
|
mp_arg_parse_all(n_args - 2, pos_args + 2, kw_args, MP_ARRAY_SIZE(allowed_args), allowed_args, args);
|
||||||
|
|
||||||
|
// Create and return the new SSLSocket object.
|
||||||
|
return ssl_socket_make_new(self, sock, args[ARG_server_side].u_bool,
|
||||||
|
args[ARG_do_handshake_on_connect].u_bool, args[ARG_server_hostname].u_obj);
|
||||||
|
}
|
||||||
|
STATIC MP_DEFINE_CONST_FUN_OBJ_KW(ssl_context_wrap_socket_obj, 2, ssl_context_wrap_socket);
|
||||||
|
|
||||||
|
STATIC const mp_rom_map_elem_t ssl_context_locals_dict_table[] = {
|
||||||
|
{ MP_ROM_QSTR(MP_QSTR_wrap_socket), MP_ROM_PTR(&ssl_context_wrap_socket_obj) },
|
||||||
|
};
|
||||||
|
STATIC MP_DEFINE_CONST_DICT(ssl_context_locals_dict, ssl_context_locals_dict_table);
|
||||||
|
|
||||||
|
STATIC MP_DEFINE_CONST_OBJ_TYPE(
|
||||||
|
ssl_context_type,
|
||||||
|
MP_QSTR_SSLContext,
|
||||||
|
MP_TYPE_FLAG_NONE,
|
||||||
|
make_new, ssl_context_make_new,
|
||||||
|
locals_dict, &ssl_context_locals_dict
|
||||||
|
);
|
||||||
|
|
||||||
|
/******************************************************************************/
|
||||||
|
// SSLSocket type.
|
||||||
|
|
||||||
|
STATIC mp_obj_t ssl_socket_make_new(mp_obj_ssl_context_t *ssl_context, mp_obj_t sock,
|
||||||
|
bool server_side, bool do_handshake_on_connect, mp_obj_t server_hostname) {
|
||||||
|
|
||||||
STATIC mp_obj_ssl_socket_t *ssl_socket_new(mp_obj_t sock, struct ssl_args *args) {
|
|
||||||
#if MICROPY_PY_SSL_FINALISER
|
#if MICROPY_PY_SSL_FINALISER
|
||||||
mp_obj_ssl_socket_t *o = m_new_obj_with_finaliser(mp_obj_ssl_socket_t);
|
mp_obj_ssl_socket_t *o = m_new_obj_with_finaliser(mp_obj_ssl_socket_t);
|
||||||
#else
|
#else
|
||||||
@ -130,43 +212,43 @@ STATIC mp_obj_ssl_socket_t *ssl_socket_new(mp_obj_t sock, struct ssl_args *args)
|
|||||||
o->blocking = true;
|
o->blocking = true;
|
||||||
|
|
||||||
uint32_t options = SSL_SERVER_VERIFY_LATER;
|
uint32_t options = SSL_SERVER_VERIFY_LATER;
|
||||||
if (!args->do_handshake.u_bool) {
|
if (!do_handshake_on_connect) {
|
||||||
options |= SSL_CONNECT_IN_PARTS;
|
options |= SSL_CONNECT_IN_PARTS;
|
||||||
}
|
}
|
||||||
if (args->key.u_obj != mp_const_none) {
|
if (ssl_context->key != mp_const_none) {
|
||||||
options |= SSL_NO_DEFAULT_KEY;
|
options |= SSL_NO_DEFAULT_KEY;
|
||||||
}
|
}
|
||||||
if ((o->ssl_ctx = ssl_ctx_new(options, SSL_DEFAULT_CLNT_SESS)) == NULL) {
|
if ((o->ssl_ctx = ssl_ctx_new(options, SSL_DEFAULT_CLNT_SESS)) == NULL) {
|
||||||
mp_raise_OSError(MP_EINVAL);
|
mp_raise_OSError(MP_EINVAL);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (args->key.u_obj != mp_const_none) {
|
if (ssl_context->key != mp_const_none) {
|
||||||
size_t len;
|
size_t len;
|
||||||
const byte *data = (const byte *)mp_obj_str_get_data(args->key.u_obj, &len);
|
const byte *data = (const byte *)mp_obj_str_get_data(ssl_context->key, &len);
|
||||||
int res = ssl_obj_memory_load(o->ssl_ctx, SSL_OBJ_RSA_KEY, data, len, NULL);
|
int res = ssl_obj_memory_load(o->ssl_ctx, SSL_OBJ_RSA_KEY, data, len, NULL);
|
||||||
if (res != SSL_OK) {
|
if (res != SSL_OK) {
|
||||||
mp_raise_ValueError(MP_ERROR_TEXT("invalid key"));
|
mp_raise_ValueError(MP_ERROR_TEXT("invalid key"));
|
||||||
}
|
}
|
||||||
|
|
||||||
data = (const byte *)mp_obj_str_get_data(args->cert.u_obj, &len);
|
data = (const byte *)mp_obj_str_get_data(ssl_context->cert, &len);
|
||||||
res = ssl_obj_memory_load(o->ssl_ctx, SSL_OBJ_X509_CERT, data, len, NULL);
|
res = ssl_obj_memory_load(o->ssl_ctx, SSL_OBJ_X509_CERT, data, len, NULL);
|
||||||
if (res != SSL_OK) {
|
if (res != SSL_OK) {
|
||||||
mp_raise_ValueError(MP_ERROR_TEXT("invalid cert"));
|
mp_raise_ValueError(MP_ERROR_TEXT("invalid cert"));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (args->server_side.u_bool) {
|
if (server_side) {
|
||||||
o->ssl_sock = ssl_server_new(o->ssl_ctx, (long)sock);
|
o->ssl_sock = ssl_server_new(o->ssl_ctx, (long)sock);
|
||||||
} else {
|
} else {
|
||||||
SSL_EXTENSIONS *ext = ssl_ext_new();
|
SSL_EXTENSIONS *ext = ssl_ext_new();
|
||||||
|
|
||||||
if (args->server_hostname.u_obj != mp_const_none) {
|
if (server_hostname != mp_const_none) {
|
||||||
ext->host_name = (char *)mp_obj_str_get_str(args->server_hostname.u_obj);
|
ext->host_name = (char *)mp_obj_str_get_str(server_hostname);
|
||||||
}
|
}
|
||||||
|
|
||||||
o->ssl_sock = ssl_client_new(o->ssl_ctx, (long)sock, NULL, 0, ext);
|
o->ssl_sock = ssl_client_new(o->ssl_ctx, (long)sock, NULL, 0, ext);
|
||||||
|
|
||||||
if (args->do_handshake.u_bool) {
|
if (do_handshake_on_connect) {
|
||||||
int r = ssl_handshake_status(o->ssl_sock);
|
int r = ssl_handshake_status(o->ssl_sock);
|
||||||
|
|
||||||
if (r != SSL_OK) {
|
if (r != SSL_OK) {
|
||||||
@ -178,18 +260,11 @@ STATIC mp_obj_ssl_socket_t *ssl_socket_new(mp_obj_t sock, struct ssl_args *args)
|
|||||||
ssl_raise_error(r);
|
ssl_raise_error(r);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return o;
|
return o;
|
||||||
}
|
}
|
||||||
|
|
||||||
STATIC void ssl_socket_print(const mp_print_t *print, mp_obj_t self_in, mp_print_kind_t kind) {
|
|
||||||
(void)kind;
|
|
||||||
mp_obj_ssl_socket_t *self = MP_OBJ_TO_PTR(self_in);
|
|
||||||
mp_printf(print, "<_SSLSocket %p>", self->ssl_sock);
|
|
||||||
}
|
|
||||||
|
|
||||||
STATIC mp_uint_t ssl_socket_read(mp_obj_t o_in, void *buf, mp_uint_t size, int *errcode) {
|
STATIC mp_uint_t ssl_socket_read(mp_obj_t o_in, void *buf, mp_uint_t size, int *errcode) {
|
||||||
mp_obj_ssl_socket_t *o = MP_OBJ_TO_PTR(o_in);
|
mp_obj_ssl_socket_t *o = MP_OBJ_TO_PTR(o_in);
|
||||||
|
|
||||||
@ -305,7 +380,6 @@ STATIC const mp_rom_map_elem_t ssl_socket_locals_dict_table[] = {
|
|||||||
{ MP_ROM_QSTR(MP_QSTR___del__), MP_ROM_PTR(&mp_stream_close_obj) },
|
{ MP_ROM_QSTR(MP_QSTR___del__), MP_ROM_PTR(&mp_stream_close_obj) },
|
||||||
#endif
|
#endif
|
||||||
};
|
};
|
||||||
|
|
||||||
STATIC MP_DEFINE_CONST_DICT(ssl_socket_locals_dict, ssl_socket_locals_dict_table);
|
STATIC MP_DEFINE_CONST_DICT(ssl_socket_locals_dict, ssl_socket_locals_dict_table);
|
||||||
|
|
||||||
STATIC const mp_stream_p_t ssl_socket_stream_p = {
|
STATIC const mp_stream_p_t ssl_socket_stream_p = {
|
||||||
@ -316,16 +390,23 @@ STATIC const mp_stream_p_t ssl_socket_stream_p = {
|
|||||||
|
|
||||||
STATIC MP_DEFINE_CONST_OBJ_TYPE(
|
STATIC MP_DEFINE_CONST_OBJ_TYPE(
|
||||||
ssl_socket_type,
|
ssl_socket_type,
|
||||||
// Save on qstr's, reuse same as for module
|
MP_QSTR_SSLSocket,
|
||||||
MP_QSTR_ssl,
|
|
||||||
MP_TYPE_FLAG_NONE,
|
MP_TYPE_FLAG_NONE,
|
||||||
print, ssl_socket_print,
|
|
||||||
protocol, &ssl_socket_stream_p,
|
protocol, &ssl_socket_stream_p,
|
||||||
locals_dict, &ssl_socket_locals_dict
|
locals_dict, &ssl_socket_locals_dict
|
||||||
);
|
);
|
||||||
|
|
||||||
|
/******************************************************************************/
|
||||||
|
// ssl module.
|
||||||
|
|
||||||
STATIC mp_obj_t mod_ssl_wrap_socket(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args) {
|
STATIC mp_obj_t mod_ssl_wrap_socket(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args) {
|
||||||
// TODO: Implement more args
|
enum {
|
||||||
|
ARG_key,
|
||||||
|
ARG_cert,
|
||||||
|
ARG_server_side,
|
||||||
|
ARG_server_hostname,
|
||||||
|
ARG_do_handshake,
|
||||||
|
};
|
||||||
static const mp_arg_t allowed_args[] = {
|
static const mp_arg_t allowed_args[] = {
|
||||||
{ MP_QSTR_key, MP_ARG_KW_ONLY | MP_ARG_OBJ, {.u_rom_obj = MP_ROM_NONE} },
|
{ MP_QSTR_key, MP_ARG_KW_ONLY | MP_ARG_OBJ, {.u_rom_obj = MP_ROM_NONE} },
|
||||||
{ MP_QSTR_cert, MP_ARG_KW_ONLY | MP_ARG_OBJ, {.u_rom_obj = MP_ROM_NONE} },
|
{ MP_QSTR_cert, MP_ARG_KW_ONLY | MP_ARG_OBJ, {.u_rom_obj = MP_ROM_NONE} },
|
||||||
@ -334,22 +415,40 @@ STATIC mp_obj_t mod_ssl_wrap_socket(size_t n_args, const mp_obj_t *pos_args, mp_
|
|||||||
{ MP_QSTR_do_handshake, MP_ARG_KW_ONLY | MP_ARG_BOOL, {.u_bool = true} },
|
{ MP_QSTR_do_handshake, MP_ARG_KW_ONLY | MP_ARG_BOOL, {.u_bool = true} },
|
||||||
};
|
};
|
||||||
|
|
||||||
// TODO: Check that sock implements stream protocol
|
// Parse arguments.
|
||||||
mp_obj_t sock = pos_args[0];
|
mp_obj_t sock = pos_args[0];
|
||||||
|
mp_arg_val_t args[MP_ARRAY_SIZE(allowed_args)];
|
||||||
|
mp_arg_parse_all(n_args - 1, pos_args + 1, kw_args, MP_ARRAY_SIZE(allowed_args), allowed_args, args);
|
||||||
|
|
||||||
struct ssl_args args;
|
// Create SSLContext.
|
||||||
mp_arg_parse_all(n_args - 1, pos_args + 1, kw_args,
|
mp_int_t protocol = args[ARG_server_side].u_bool ? PROTOCOL_TLS_SERVER : PROTOCOL_TLS_CLIENT;
|
||||||
MP_ARRAY_SIZE(allowed_args), allowed_args, (mp_arg_val_t *)&args);
|
mp_obj_t ssl_context_args[1] = { MP_OBJ_NEW_SMALL_INT(protocol) };
|
||||||
|
mp_obj_ssl_context_t *ssl_context = MP_OBJ_TO_PTR(ssl_context_make_new(&ssl_context_type, 1, 0, ssl_context_args));
|
||||||
|
|
||||||
return MP_OBJ_FROM_PTR(ssl_socket_new(sock, &args));
|
// Load key and cert if given.
|
||||||
|
if (args[ARG_key].u_obj != mp_const_none) {
|
||||||
|
ssl_context_load_key(ssl_context, args[ARG_key].u_obj, args[ARG_cert].u_obj);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create and return the new SSLSocket object.
|
||||||
|
return ssl_socket_make_new(ssl_context, sock, args[ARG_server_side].u_bool,
|
||||||
|
args[ARG_do_handshake].u_bool, args[ARG_server_hostname].u_obj);
|
||||||
}
|
}
|
||||||
STATIC MP_DEFINE_CONST_FUN_OBJ_KW(mod_ssl_wrap_socket_obj, 1, mod_ssl_wrap_socket);
|
STATIC MP_DEFINE_CONST_FUN_OBJ_KW(mod_ssl_wrap_socket_obj, 1, mod_ssl_wrap_socket);
|
||||||
|
|
||||||
STATIC const mp_rom_map_elem_t mp_module_ssl_globals_table[] = {
|
STATIC const mp_rom_map_elem_t mp_module_ssl_globals_table[] = {
|
||||||
{ MP_ROM_QSTR(MP_QSTR___name__), MP_ROM_QSTR(MP_QSTR_ssl) },
|
{ MP_ROM_QSTR(MP_QSTR___name__), MP_ROM_QSTR(MP_QSTR_ssl) },
|
||||||
{ MP_ROM_QSTR(MP_QSTR_wrap_socket), MP_ROM_PTR(&mod_ssl_wrap_socket_obj) },
|
|
||||||
};
|
|
||||||
|
|
||||||
|
// Functions.
|
||||||
|
{ MP_ROM_QSTR(MP_QSTR_wrap_socket), MP_ROM_PTR(&mod_ssl_wrap_socket_obj) },
|
||||||
|
|
||||||
|
// Classes.
|
||||||
|
{ MP_ROM_QSTR(MP_QSTR_SSLContext), MP_ROM_PTR(&ssl_context_type) },
|
||||||
|
|
||||||
|
// Constants.
|
||||||
|
{ MP_ROM_QSTR(MP_QSTR_PROTOCOL_TLS_CLIENT), MP_ROM_INT(PROTOCOL_TLS_CLIENT) },
|
||||||
|
{ MP_ROM_QSTR(MP_QSTR_PROTOCOL_TLS_SERVER), MP_ROM_INT(PROTOCOL_TLS_SERVER) },
|
||||||
|
};
|
||||||
STATIC MP_DEFINE_CONST_DICT(mp_module_ssl_globals, mp_module_ssl_globals_table);
|
STATIC MP_DEFINE_CONST_DICT(mp_module_ssl_globals, mp_module_ssl_globals_table);
|
||||||
|
|
||||||
const mp_obj_module_t mp_module_ssl = {
|
const mp_obj_module_t mp_module_ssl = {
|
||||||
|
|||||||
@ -5,6 +5,7 @@
|
|||||||
*
|
*
|
||||||
* Copyright (c) 2016 Linaro Ltd.
|
* Copyright (c) 2016 Linaro Ltd.
|
||||||
* Copyright (c) 2019 Paul Sokolovsky
|
* Copyright (c) 2019 Paul Sokolovsky
|
||||||
|
* Copyright (c) 2023 Damien P. George
|
||||||
*
|
*
|
||||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
* of this software and associated documentation files (the "Software"), to deal
|
* of this software and associated documentation files (the "Software"), to deal
|
||||||
@ -48,33 +49,38 @@
|
|||||||
|
|
||||||
#define MP_STREAM_POLL_RDWR (MP_STREAM_POLL_RD | MP_STREAM_POLL_WR)
|
#define MP_STREAM_POLL_RDWR (MP_STREAM_POLL_RD | MP_STREAM_POLL_WR)
|
||||||
|
|
||||||
typedef struct _mp_obj_ssl_socket_t {
|
// This corresponds to an SSLContext object.
|
||||||
|
typedef struct _mp_obj_ssl_context_t {
|
||||||
mp_obj_base_t base;
|
mp_obj_base_t base;
|
||||||
mp_obj_t sock;
|
|
||||||
mbedtls_entropy_context entropy;
|
mbedtls_entropy_context entropy;
|
||||||
mbedtls_ctr_drbg_context ctr_drbg;
|
mbedtls_ctr_drbg_context ctr_drbg;
|
||||||
mbedtls_ssl_context ssl;
|
|
||||||
mbedtls_ssl_config conf;
|
mbedtls_ssl_config conf;
|
||||||
mbedtls_x509_crt cacert;
|
mbedtls_x509_crt cacert;
|
||||||
mbedtls_x509_crt cert;
|
mbedtls_x509_crt cert;
|
||||||
mbedtls_pk_context pkey;
|
mbedtls_pk_context pkey;
|
||||||
|
int authmode;
|
||||||
|
} mp_obj_ssl_context_t;
|
||||||
|
|
||||||
|
// This corresponds to an SSLSocket object.
|
||||||
|
typedef struct _mp_obj_ssl_socket_t {
|
||||||
|
mp_obj_base_t base;
|
||||||
|
mp_obj_ssl_context_t *ssl_context;
|
||||||
|
mp_obj_t sock;
|
||||||
|
mbedtls_ssl_context ssl;
|
||||||
|
|
||||||
uintptr_t poll_mask; // Indicates which read or write operations the protocol needs next
|
uintptr_t poll_mask; // Indicates which read or write operations the protocol needs next
|
||||||
int last_error; // The last error code, if any
|
int last_error; // The last error code, if any
|
||||||
} mp_obj_ssl_socket_t;
|
} mp_obj_ssl_socket_t;
|
||||||
|
|
||||||
struct ssl_args {
|
STATIC const mp_obj_type_t ssl_context_type;
|
||||||
mp_arg_val_t key;
|
|
||||||
mp_arg_val_t cert;
|
|
||||||
mp_arg_val_t server_side;
|
|
||||||
mp_arg_val_t server_hostname;
|
|
||||||
mp_arg_val_t cert_reqs;
|
|
||||||
mp_arg_val_t cadata;
|
|
||||||
mp_arg_val_t do_handshake;
|
|
||||||
};
|
|
||||||
|
|
||||||
STATIC const mp_obj_type_t ssl_socket_type;
|
STATIC const mp_obj_type_t ssl_socket_type;
|
||||||
|
|
||||||
|
STATIC mp_obj_t ssl_socket_make_new(mp_obj_ssl_context_t *ssl_context, mp_obj_t sock,
|
||||||
|
bool server_side, bool do_handshake_on_connect, mp_obj_t server_hostname);
|
||||||
|
|
||||||
|
/******************************************************************************/
|
||||||
|
// Helper functions.
|
||||||
|
|
||||||
#ifdef MBEDTLS_DEBUG_C
|
#ifdef MBEDTLS_DEBUG_C
|
||||||
STATIC void mbedtls_debug(void *ctx, int level, const char *file, int line, const char *str) {
|
STATIC void mbedtls_debug(void *ctx, int level, const char *file, int line, const char *str) {
|
||||||
(void)ctx;
|
(void)ctx;
|
||||||
@ -84,6 +90,15 @@ STATIC void mbedtls_debug(void *ctx, int level, const char *file, int line, cons
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
STATIC NORETURN void mbedtls_raise_error(int err) {
|
STATIC NORETURN void mbedtls_raise_error(int err) {
|
||||||
|
// Handle special cases.
|
||||||
|
if (err == MBEDTLS_ERR_SSL_ALLOC_FAILED) {
|
||||||
|
mp_raise_OSError(MP_ENOMEM);
|
||||||
|
} else if (err == MBEDTLS_ERR_PK_BAD_INPUT_DATA) {
|
||||||
|
mp_raise_ValueError(MP_ERROR_TEXT("invalid key"));
|
||||||
|
} else if (err == MBEDTLS_ERR_X509_BAD_INPUT_DATA) {
|
||||||
|
mp_raise_ValueError(MP_ERROR_TEXT("invalid cert"));
|
||||||
|
}
|
||||||
|
|
||||||
// _mbedtls_ssl_send and _mbedtls_ssl_recv (below) turn positive error codes from the
|
// _mbedtls_ssl_send and _mbedtls_ssl_recv (below) turn positive error codes from the
|
||||||
// underlying socket into negative codes to pass them through mbedtls. Here we turn them
|
// underlying socket into negative codes to pass them through mbedtls. Here we turn them
|
||||||
// positive again so they get interpreted as the OSError they really are. The
|
// positive again so they get interpreted as the OSError they really are. The
|
||||||
@ -123,6 +138,178 @@ STATIC NORETURN void mbedtls_raise_error(int err) {
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/******************************************************************************/
|
||||||
|
// SSLContext type.
|
||||||
|
|
||||||
|
STATIC mp_obj_t ssl_context_make_new(const mp_obj_type_t *type_in, size_t n_args, size_t n_kw, const mp_obj_t *args) {
|
||||||
|
mp_arg_check_num(n_args, n_kw, 1, 1, false);
|
||||||
|
|
||||||
|
// This is the "protocol" argument.
|
||||||
|
mp_int_t endpoint = mp_obj_get_int(args[0]);
|
||||||
|
|
||||||
|
// Create SSLContext object.
|
||||||
|
#if MICROPY_PY_SSL_FINALISER
|
||||||
|
mp_obj_ssl_context_t *self = m_new_obj_with_finaliser(mp_obj_ssl_context_t);
|
||||||
|
#else
|
||||||
|
mp_obj_ssl_context_t *self = m_new_obj(mp_obj_ssl_context_t);
|
||||||
|
#endif
|
||||||
|
self->base.type = type_in;
|
||||||
|
|
||||||
|
// Initialise mbedTLS state.
|
||||||
|
mbedtls_ssl_config_init(&self->conf);
|
||||||
|
mbedtls_entropy_init(&self->entropy);
|
||||||
|
mbedtls_ctr_drbg_init(&self->ctr_drbg);
|
||||||
|
mbedtls_x509_crt_init(&self->cacert);
|
||||||
|
mbedtls_x509_crt_init(&self->cert);
|
||||||
|
mbedtls_pk_init(&self->pkey);
|
||||||
|
|
||||||
|
#ifdef MBEDTLS_DEBUG_C
|
||||||
|
// Debug level (0-4) 1=warning, 2=info, 3=debug, 4=verbose
|
||||||
|
mbedtls_debug_set_threshold(3);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
const byte seed[] = "upy";
|
||||||
|
int ret = mbedtls_ctr_drbg_seed(&self->ctr_drbg, mbedtls_entropy_func, &self->entropy, seed, sizeof(seed));
|
||||||
|
if (ret != 0) {
|
||||||
|
mbedtls_raise_error(ret);
|
||||||
|
}
|
||||||
|
|
||||||
|
ret = mbedtls_ssl_config_defaults(&self->conf, endpoint,
|
||||||
|
MBEDTLS_SSL_TRANSPORT_STREAM, MBEDTLS_SSL_PRESET_DEFAULT);
|
||||||
|
if (ret != 0) {
|
||||||
|
mbedtls_raise_error(ret);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (endpoint == MBEDTLS_SSL_IS_CLIENT) {
|
||||||
|
// The CPython default is MBEDTLS_SSL_VERIFY_REQUIRED, but to maintain
|
||||||
|
// backwards compatibility we use MBEDTLS_SSL_VERIFY_NONE for now.
|
||||||
|
self->authmode = MBEDTLS_SSL_VERIFY_NONE;
|
||||||
|
} else {
|
||||||
|
self->authmode = MBEDTLS_SSL_VERIFY_NONE;
|
||||||
|
}
|
||||||
|
mbedtls_ssl_conf_authmode(&self->conf, self->authmode);
|
||||||
|
mbedtls_ssl_conf_rng(&self->conf, mbedtls_ctr_drbg_random, &self->ctr_drbg);
|
||||||
|
#ifdef MBEDTLS_DEBUG_C
|
||||||
|
mbedtls_ssl_conf_dbg(&self->conf, mbedtls_debug, NULL);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
return MP_OBJ_FROM_PTR(self);
|
||||||
|
}
|
||||||
|
|
||||||
|
STATIC void ssl_context_attr(mp_obj_t self_in, qstr attr, mp_obj_t *dest) {
|
||||||
|
mp_obj_ssl_context_t *self = MP_OBJ_TO_PTR(self_in);
|
||||||
|
if (dest[0] == MP_OBJ_NULL) {
|
||||||
|
// Load attribute.
|
||||||
|
if (attr == MP_QSTR_verify_mode) {
|
||||||
|
dest[0] = MP_OBJ_NEW_SMALL_INT(self->authmode);
|
||||||
|
} else {
|
||||||
|
// Continue lookup in locals_dict.
|
||||||
|
dest[1] = MP_OBJ_SENTINEL;
|
||||||
|
}
|
||||||
|
} else if (dest[1] != MP_OBJ_NULL) {
|
||||||
|
// Store attribute.
|
||||||
|
if (attr == MP_QSTR_verify_mode) {
|
||||||
|
self->authmode = mp_obj_get_int(dest[1]);
|
||||||
|
dest[0] = MP_OBJ_NULL;
|
||||||
|
mbedtls_ssl_conf_authmode(&self->conf, self->authmode);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#if MICROPY_PY_SSL_FINALISER
|
||||||
|
STATIC mp_obj_t ssl_context___del__(mp_obj_t self_in) {
|
||||||
|
mp_obj_ssl_context_t *self = MP_OBJ_TO_PTR(self_in);
|
||||||
|
mbedtls_pk_free(&self->pkey);
|
||||||
|
mbedtls_x509_crt_free(&self->cert);
|
||||||
|
mbedtls_x509_crt_free(&self->cacert);
|
||||||
|
mbedtls_ctr_drbg_free(&self->ctr_drbg);
|
||||||
|
mbedtls_entropy_free(&self->entropy);
|
||||||
|
mbedtls_ssl_config_free(&self->conf);
|
||||||
|
return mp_const_none;
|
||||||
|
}
|
||||||
|
STATIC MP_DEFINE_CONST_FUN_OBJ_1(ssl_context___del___obj, ssl_context___del__);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
STATIC void ssl_context_load_key(mp_obj_ssl_context_t *self, mp_obj_t key_obj, mp_obj_t cert_obj) {
|
||||||
|
size_t key_len;
|
||||||
|
const byte *key = (const byte *)mp_obj_str_get_data(key_obj, &key_len);
|
||||||
|
// len should include terminating null
|
||||||
|
int ret;
|
||||||
|
#if MBEDTLS_VERSION_NUMBER >= 0x03000000
|
||||||
|
ret = mbedtls_pk_parse_key(&self->pkey, key, key_len + 1, NULL, 0, mbedtls_ctr_drbg_random, &self->ctr_drbg);
|
||||||
|
#else
|
||||||
|
ret = mbedtls_pk_parse_key(&self->pkey, key, key_len + 1, NULL, 0);
|
||||||
|
#endif
|
||||||
|
if (ret != 0) {
|
||||||
|
mbedtls_raise_error(MBEDTLS_ERR_PK_BAD_INPUT_DATA); // use general error for all key errors
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t cert_len;
|
||||||
|
const byte *cert = (const byte *)mp_obj_str_get_data(cert_obj, &cert_len);
|
||||||
|
// len should include terminating null
|
||||||
|
ret = mbedtls_x509_crt_parse(&self->cert, cert, cert_len + 1);
|
||||||
|
if (ret != 0) {
|
||||||
|
mbedtls_raise_error(MBEDTLS_ERR_X509_BAD_INPUT_DATA); // use general error for all cert errors
|
||||||
|
}
|
||||||
|
|
||||||
|
ret = mbedtls_ssl_conf_own_cert(&self->conf, &self->cert, &self->pkey);
|
||||||
|
if (ret != 0) {
|
||||||
|
mbedtls_raise_error(ret);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
STATIC void ssl_context_load_cadata(mp_obj_ssl_context_t *self, mp_obj_t cadata_obj) {
|
||||||
|
size_t cacert_len;
|
||||||
|
const byte *cacert = (const byte *)mp_obj_str_get_data(cadata_obj, &cacert_len);
|
||||||
|
// len should include terminating null
|
||||||
|
int ret = mbedtls_x509_crt_parse(&self->cacert, cacert, cacert_len + 1);
|
||||||
|
if (ret != 0) {
|
||||||
|
mbedtls_raise_error(MBEDTLS_ERR_X509_BAD_INPUT_DATA); // use general error for all cert errors
|
||||||
|
}
|
||||||
|
|
||||||
|
mbedtls_ssl_conf_ca_chain(&self->conf, &self->cacert, NULL);
|
||||||
|
}
|
||||||
|
|
||||||
|
STATIC mp_obj_t ssl_context_wrap_socket(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args) {
|
||||||
|
enum { ARG_server_side, ARG_do_handshake_on_connect, ARG_server_hostname };
|
||||||
|
static const mp_arg_t allowed_args[] = {
|
||||||
|
{ MP_QSTR_server_side, MP_ARG_KW_ONLY | MP_ARG_BOOL, {.u_bool = false} },
|
||||||
|
{ MP_QSTR_do_handshake_on_connect, MP_ARG_KW_ONLY | MP_ARG_BOOL, {.u_bool = true} },
|
||||||
|
{ MP_QSTR_server_hostname, MP_ARG_KW_ONLY | MP_ARG_OBJ, {.u_rom_obj = MP_ROM_NONE} },
|
||||||
|
};
|
||||||
|
|
||||||
|
// Parse arguments.
|
||||||
|
mp_obj_ssl_context_t *self = MP_OBJ_TO_PTR(pos_args[0]);
|
||||||
|
mp_obj_t sock = pos_args[1];
|
||||||
|
mp_arg_val_t args[MP_ARRAY_SIZE(allowed_args)];
|
||||||
|
mp_arg_parse_all(n_args - 2, pos_args + 2, kw_args, MP_ARRAY_SIZE(allowed_args), allowed_args, args);
|
||||||
|
|
||||||
|
// Create and return the new SSLSocket object.
|
||||||
|
return ssl_socket_make_new(self, sock, args[ARG_server_side].u_bool,
|
||||||
|
args[ARG_do_handshake_on_connect].u_bool, args[ARG_server_hostname].u_obj);
|
||||||
|
}
|
||||||
|
STATIC MP_DEFINE_CONST_FUN_OBJ_KW(ssl_context_wrap_socket_obj, 2, ssl_context_wrap_socket);
|
||||||
|
|
||||||
|
STATIC const mp_rom_map_elem_t ssl_context_locals_dict_table[] = {
|
||||||
|
#if MICROPY_PY_SSL_FINALISER
|
||||||
|
{ MP_ROM_QSTR(MP_QSTR___del__), MP_ROM_PTR(&ssl_context___del___obj) },
|
||||||
|
#endif
|
||||||
|
{ MP_ROM_QSTR(MP_QSTR_wrap_socket), MP_ROM_PTR(&ssl_context_wrap_socket_obj) },
|
||||||
|
};
|
||||||
|
STATIC MP_DEFINE_CONST_DICT(ssl_context_locals_dict, ssl_context_locals_dict_table);
|
||||||
|
|
||||||
|
STATIC MP_DEFINE_CONST_OBJ_TYPE(
|
||||||
|
ssl_context_type,
|
||||||
|
MP_QSTR_SSLContext,
|
||||||
|
MP_TYPE_FLAG_NONE,
|
||||||
|
make_new, ssl_context_make_new,
|
||||||
|
attr, ssl_context_attr,
|
||||||
|
locals_dict, &ssl_context_locals_dict
|
||||||
|
);
|
||||||
|
|
||||||
|
/******************************************************************************/
|
||||||
|
// SSLSocket type.
|
||||||
|
|
||||||
STATIC int _mbedtls_ssl_send(void *ctx, const byte *buf, size_t len) {
|
STATIC int _mbedtls_ssl_send(void *ctx, const byte *buf, size_t len) {
|
||||||
mp_obj_t sock = *(mp_obj_t *)ctx;
|
mp_obj_t sock = *(mp_obj_t *)ctx;
|
||||||
|
|
||||||
@ -158,8 +345,9 @@ STATIC int _mbedtls_ssl_recv(void *ctx, byte *buf, size_t len) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
STATIC mp_obj_t ssl_socket_make_new(mp_obj_ssl_context_t *ssl_context, mp_obj_t sock,
|
||||||
|
bool server_side, bool do_handshake_on_connect, mp_obj_t server_hostname) {
|
||||||
|
|
||||||
STATIC mp_obj_ssl_socket_t *socket_new(mp_obj_t sock, struct ssl_args *args) {
|
|
||||||
// Verify the socket object has the full stream protocol
|
// Verify the socket object has the full stream protocol
|
||||||
mp_get_stream_raise(sock, MP_STREAM_OP_READ | MP_STREAM_OP_WRITE | MP_STREAM_OP_IOCTL);
|
mp_get_stream_raise(sock, MP_STREAM_OP_READ | MP_STREAM_OP_WRITE | MP_STREAM_OP_IOCTL);
|
||||||
|
|
||||||
@ -175,44 +363,14 @@ STATIC mp_obj_ssl_socket_t *socket_new(mp_obj_t sock, struct ssl_args *args) {
|
|||||||
|
|
||||||
int ret;
|
int ret;
|
||||||
mbedtls_ssl_init(&o->ssl);
|
mbedtls_ssl_init(&o->ssl);
|
||||||
mbedtls_ssl_config_init(&o->conf);
|
|
||||||
mbedtls_x509_crt_init(&o->cacert);
|
|
||||||
mbedtls_x509_crt_init(&o->cert);
|
|
||||||
mbedtls_pk_init(&o->pkey);
|
|
||||||
mbedtls_ctr_drbg_init(&o->ctr_drbg);
|
|
||||||
#ifdef MBEDTLS_DEBUG_C
|
|
||||||
// Debug level (0-4) 1=warning, 2=info, 3=debug, 4=verbose
|
|
||||||
mbedtls_debug_set_threshold(3);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
mbedtls_entropy_init(&o->entropy);
|
ret = mbedtls_ssl_setup(&o->ssl, &ssl_context->conf);
|
||||||
const byte seed[] = "upy";
|
|
||||||
ret = mbedtls_ctr_drbg_seed(&o->ctr_drbg, mbedtls_entropy_func, &o->entropy, seed, sizeof(seed));
|
|
||||||
if (ret != 0) {
|
if (ret != 0) {
|
||||||
goto cleanup;
|
goto cleanup;
|
||||||
}
|
}
|
||||||
|
|
||||||
ret = mbedtls_ssl_config_defaults(&o->conf,
|
if (server_hostname != mp_const_none) {
|
||||||
args->server_side.u_bool ? MBEDTLS_SSL_IS_SERVER : MBEDTLS_SSL_IS_CLIENT,
|
const char *sni = mp_obj_str_get_str(server_hostname);
|
||||||
MBEDTLS_SSL_TRANSPORT_STREAM,
|
|
||||||
MBEDTLS_SSL_PRESET_DEFAULT);
|
|
||||||
if (ret != 0) {
|
|
||||||
goto cleanup;
|
|
||||||
}
|
|
||||||
|
|
||||||
mbedtls_ssl_conf_authmode(&o->conf, args->cert_reqs.u_int);
|
|
||||||
mbedtls_ssl_conf_rng(&o->conf, mbedtls_ctr_drbg_random, &o->ctr_drbg);
|
|
||||||
#ifdef MBEDTLS_DEBUG_C
|
|
||||||
mbedtls_ssl_conf_dbg(&o->conf, mbedtls_debug, NULL);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
ret = mbedtls_ssl_setup(&o->ssl, &o->conf);
|
|
||||||
if (ret != 0) {
|
|
||||||
goto cleanup;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (args->server_hostname.u_obj != mp_const_none) {
|
|
||||||
const char *sni = mp_obj_str_get_str(args->server_hostname.u_obj);
|
|
||||||
ret = mbedtls_ssl_set_hostname(&o->ssl, sni);
|
ret = mbedtls_ssl_set_hostname(&o->ssl, sni);
|
||||||
if (ret != 0) {
|
if (ret != 0) {
|
||||||
goto cleanup;
|
goto cleanup;
|
||||||
@ -221,49 +379,7 @@ STATIC mp_obj_ssl_socket_t *socket_new(mp_obj_t sock, struct ssl_args *args) {
|
|||||||
|
|
||||||
mbedtls_ssl_set_bio(&o->ssl, &o->sock, _mbedtls_ssl_send, _mbedtls_ssl_recv, NULL);
|
mbedtls_ssl_set_bio(&o->ssl, &o->sock, _mbedtls_ssl_send, _mbedtls_ssl_recv, NULL);
|
||||||
|
|
||||||
if (args->key.u_obj != mp_const_none) {
|
if (do_handshake_on_connect) {
|
||||||
size_t key_len;
|
|
||||||
const byte *key = (const byte *)mp_obj_str_get_data(args->key.u_obj, &key_len);
|
|
||||||
// len should include terminating null
|
|
||||||
#if MBEDTLS_VERSION_NUMBER >= 0x03000000
|
|
||||||
ret = mbedtls_pk_parse_key(&o->pkey, key, key_len + 1, NULL, 0, mbedtls_ctr_drbg_random, &o->ctr_drbg);
|
|
||||||
#else
|
|
||||||
ret = mbedtls_pk_parse_key(&o->pkey, key, key_len + 1, NULL, 0);
|
|
||||||
#endif
|
|
||||||
if (ret != 0) {
|
|
||||||
ret = MBEDTLS_ERR_PK_BAD_INPUT_DATA; // use general error for all key errors
|
|
||||||
goto cleanup;
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t cert_len;
|
|
||||||
const byte *cert = (const byte *)mp_obj_str_get_data(args->cert.u_obj, &cert_len);
|
|
||||||
// len should include terminating null
|
|
||||||
ret = mbedtls_x509_crt_parse(&o->cert, cert, cert_len + 1);
|
|
||||||
if (ret != 0) {
|
|
||||||
ret = MBEDTLS_ERR_X509_BAD_INPUT_DATA; // use general error for all cert errors
|
|
||||||
goto cleanup;
|
|
||||||
}
|
|
||||||
|
|
||||||
ret = mbedtls_ssl_conf_own_cert(&o->conf, &o->cert, &o->pkey);
|
|
||||||
if (ret != 0) {
|
|
||||||
goto cleanup;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (args->cadata.u_obj != mp_const_none) {
|
|
||||||
size_t cacert_len;
|
|
||||||
const byte *cacert = (const byte *)mp_obj_str_get_data(args->cadata.u_obj, &cacert_len);
|
|
||||||
// len should include terminating null
|
|
||||||
ret = mbedtls_x509_crt_parse(&o->cacert, cacert, cacert_len + 1);
|
|
||||||
if (ret != 0) {
|
|
||||||
ret = MBEDTLS_ERR_X509_BAD_INPUT_DATA; // use general error for all cert errors
|
|
||||||
goto cleanup;
|
|
||||||
}
|
|
||||||
|
|
||||||
mbedtls_ssl_conf_ca_chain(&o->conf, &o->cacert, NULL);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (args->do_handshake.u_bool) {
|
|
||||||
while ((ret = mbedtls_ssl_handshake(&o->ssl)) != 0) {
|
while ((ret = mbedtls_ssl_handshake(&o->ssl)) != 0) {
|
||||||
if (ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE) {
|
if (ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE) {
|
||||||
goto cleanup;
|
goto cleanup;
|
||||||
@ -274,26 +390,11 @@ STATIC mp_obj_ssl_socket_t *socket_new(mp_obj_t sock, struct ssl_args *args) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return o;
|
return MP_OBJ_FROM_PTR(o);
|
||||||
|
|
||||||
cleanup:
|
cleanup:
|
||||||
mbedtls_pk_free(&o->pkey);
|
|
||||||
mbedtls_x509_crt_free(&o->cert);
|
|
||||||
mbedtls_x509_crt_free(&o->cacert);
|
|
||||||
mbedtls_ssl_free(&o->ssl);
|
mbedtls_ssl_free(&o->ssl);
|
||||||
mbedtls_ssl_config_free(&o->conf);
|
mbedtls_raise_error(ret);
|
||||||
mbedtls_ctr_drbg_free(&o->ctr_drbg);
|
|
||||||
mbedtls_entropy_free(&o->entropy);
|
|
||||||
|
|
||||||
if (ret == MBEDTLS_ERR_SSL_ALLOC_FAILED) {
|
|
||||||
mp_raise_OSError(MP_ENOMEM);
|
|
||||||
} else if (ret == MBEDTLS_ERR_PK_BAD_INPUT_DATA) {
|
|
||||||
mp_raise_ValueError(MP_ERROR_TEXT("invalid key"));
|
|
||||||
} else if (ret == MBEDTLS_ERR_X509_BAD_INPUT_DATA) {
|
|
||||||
mp_raise_ValueError(MP_ERROR_TEXT("invalid cert"));
|
|
||||||
} else {
|
|
||||||
mbedtls_raise_error(ret);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
STATIC mp_obj_t mod_ssl_getpeercert(mp_obj_t o_in, mp_obj_t binary_form) {
|
STATIC mp_obj_t mod_ssl_getpeercert(mp_obj_t o_in, mp_obj_t binary_form) {
|
||||||
@ -309,12 +410,6 @@ STATIC mp_obj_t mod_ssl_getpeercert(mp_obj_t o_in, mp_obj_t binary_form) {
|
|||||||
}
|
}
|
||||||
STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_ssl_getpeercert_obj, mod_ssl_getpeercert);
|
STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_ssl_getpeercert_obj, mod_ssl_getpeercert);
|
||||||
|
|
||||||
STATIC void socket_print(const mp_print_t *print, mp_obj_t self_in, mp_print_kind_t kind) {
|
|
||||||
(void)kind;
|
|
||||||
mp_obj_ssl_socket_t *self = MP_OBJ_TO_PTR(self_in);
|
|
||||||
mp_printf(print, "<_SSLSocket %p>", self);
|
|
||||||
}
|
|
||||||
|
|
||||||
STATIC mp_uint_t socket_read(mp_obj_t o_in, void *buf, mp_uint_t size, int *errcode) {
|
STATIC mp_uint_t socket_read(mp_obj_t o_in, void *buf, mp_uint_t size, int *errcode) {
|
||||||
mp_obj_ssl_socket_t *o = MP_OBJ_TO_PTR(o_in);
|
mp_obj_ssl_socket_t *o = MP_OBJ_TO_PTR(o_in);
|
||||||
o->poll_mask = 0;
|
o->poll_mask = 0;
|
||||||
@ -397,13 +492,7 @@ STATIC mp_uint_t socket_ioctl(mp_obj_t o_in, mp_uint_t request, uintptr_t arg, i
|
|||||||
|
|
||||||
if (request == MP_STREAM_CLOSE) {
|
if (request == MP_STREAM_CLOSE) {
|
||||||
self->sock = MP_OBJ_NULL;
|
self->sock = MP_OBJ_NULL;
|
||||||
mbedtls_pk_free(&self->pkey);
|
|
||||||
mbedtls_x509_crt_free(&self->cert);
|
|
||||||
mbedtls_x509_crt_free(&self->cacert);
|
|
||||||
mbedtls_ssl_free(&self->ssl);
|
mbedtls_ssl_free(&self->ssl);
|
||||||
mbedtls_ssl_config_free(&self->conf);
|
|
||||||
mbedtls_ctr_drbg_free(&self->ctr_drbg);
|
|
||||||
mbedtls_entropy_free(&self->entropy);
|
|
||||||
} else if (request == MP_STREAM_POLL) {
|
} else if (request == MP_STREAM_POLL) {
|
||||||
// If the library signaled us that it needs reading or writing, only check that direction,
|
// If the library signaled us that it needs reading or writing, only check that direction,
|
||||||
// but save what the caller asked because we need to restore it later
|
// but save what the caller asked because we need to restore it later
|
||||||
@ -454,7 +543,6 @@ STATIC const mp_rom_map_elem_t ssl_socket_locals_dict_table[] = {
|
|||||||
#endif
|
#endif
|
||||||
{ MP_ROM_QSTR(MP_QSTR_getpeercert), MP_ROM_PTR(&mod_ssl_getpeercert_obj) },
|
{ MP_ROM_QSTR(MP_QSTR_getpeercert), MP_ROM_PTR(&mod_ssl_getpeercert_obj) },
|
||||||
};
|
};
|
||||||
|
|
||||||
STATIC MP_DEFINE_CONST_DICT(ssl_socket_locals_dict, ssl_socket_locals_dict_table);
|
STATIC MP_DEFINE_CONST_DICT(ssl_socket_locals_dict, ssl_socket_locals_dict_table);
|
||||||
|
|
||||||
STATIC const mp_stream_p_t ssl_socket_stream_p = {
|
STATIC const mp_stream_p_t ssl_socket_stream_p = {
|
||||||
@ -465,16 +553,25 @@ STATIC const mp_stream_p_t ssl_socket_stream_p = {
|
|||||||
|
|
||||||
STATIC MP_DEFINE_CONST_OBJ_TYPE(
|
STATIC MP_DEFINE_CONST_OBJ_TYPE(
|
||||||
ssl_socket_type,
|
ssl_socket_type,
|
||||||
// Save on qstr's, reuse same as for module
|
MP_QSTR_SSLSocket,
|
||||||
MP_QSTR_ssl,
|
|
||||||
MP_TYPE_FLAG_NONE,
|
MP_TYPE_FLAG_NONE,
|
||||||
print, socket_print,
|
|
||||||
protocol, &ssl_socket_stream_p,
|
protocol, &ssl_socket_stream_p,
|
||||||
locals_dict, &ssl_socket_locals_dict
|
locals_dict, &ssl_socket_locals_dict
|
||||||
);
|
);
|
||||||
|
|
||||||
|
/******************************************************************************/
|
||||||
|
// ssl module.
|
||||||
|
|
||||||
STATIC mp_obj_t mod_ssl_wrap_socket(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args) {
|
STATIC mp_obj_t mod_ssl_wrap_socket(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args) {
|
||||||
// TODO: Implement more args
|
enum {
|
||||||
|
ARG_key,
|
||||||
|
ARG_cert,
|
||||||
|
ARG_server_side,
|
||||||
|
ARG_server_hostname,
|
||||||
|
ARG_cert_reqs,
|
||||||
|
ARG_cadata,
|
||||||
|
ARG_do_handshake,
|
||||||
|
};
|
||||||
static const mp_arg_t allowed_args[] = {
|
static const mp_arg_t allowed_args[] = {
|
||||||
{ MP_QSTR_key, MP_ARG_KW_ONLY | MP_ARG_OBJ, {.u_rom_obj = MP_ROM_NONE} },
|
{ MP_QSTR_key, MP_ARG_KW_ONLY | MP_ARG_OBJ, {.u_rom_obj = MP_ROM_NONE} },
|
||||||
{ MP_QSTR_cert, MP_ARG_KW_ONLY | MP_ARG_OBJ, {.u_rom_obj = MP_ROM_NONE} },
|
{ MP_QSTR_cert, MP_ARG_KW_ONLY | MP_ARG_OBJ, {.u_rom_obj = MP_ROM_NONE} },
|
||||||
@ -485,25 +582,52 @@ STATIC mp_obj_t mod_ssl_wrap_socket(size_t n_args, const mp_obj_t *pos_args, mp_
|
|||||||
{ MP_QSTR_do_handshake, MP_ARG_KW_ONLY | MP_ARG_BOOL, {.u_bool = true} },
|
{ MP_QSTR_do_handshake, MP_ARG_KW_ONLY | MP_ARG_BOOL, {.u_bool = true} },
|
||||||
};
|
};
|
||||||
|
|
||||||
// TODO: Check that sock implements stream protocol
|
// Parse arguments.
|
||||||
mp_obj_t sock = pos_args[0];
|
mp_obj_t sock = pos_args[0];
|
||||||
|
mp_arg_val_t args[MP_ARRAY_SIZE(allowed_args)];
|
||||||
|
mp_arg_parse_all(n_args - 1, pos_args + 1, kw_args, MP_ARRAY_SIZE(allowed_args), allowed_args, args);
|
||||||
|
|
||||||
struct ssl_args args;
|
// Create SSLContext.
|
||||||
mp_arg_parse_all(n_args - 1, pos_args + 1, kw_args,
|
mp_int_t protocol = args[ARG_server_side].u_bool ? MBEDTLS_SSL_IS_SERVER : MBEDTLS_SSL_IS_CLIENT;
|
||||||
MP_ARRAY_SIZE(allowed_args), allowed_args, (mp_arg_val_t *)&args);
|
mp_obj_t ssl_context_args[1] = { MP_OBJ_NEW_SMALL_INT(protocol) };
|
||||||
|
mp_obj_ssl_context_t *ssl_context = MP_OBJ_TO_PTR(ssl_context_make_new(&ssl_context_type, 1, 0, ssl_context_args));
|
||||||
|
|
||||||
return MP_OBJ_FROM_PTR(socket_new(sock, &args));
|
// Load key and cert if given.
|
||||||
|
if (args[ARG_key].u_obj != mp_const_none) {
|
||||||
|
ssl_context_load_key(ssl_context, args[ARG_key].u_obj, args[ARG_cert].u_obj);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set the verify_mode.
|
||||||
|
mp_obj_t dest[2] = { MP_OBJ_SENTINEL, MP_OBJ_NEW_SMALL_INT(args[ARG_cert_reqs].u_int) };
|
||||||
|
ssl_context_attr(MP_OBJ_FROM_PTR(ssl_context), MP_QSTR_verify_mode, dest);
|
||||||
|
|
||||||
|
// Load cadata if given.
|
||||||
|
if (args[ARG_cadata].u_obj != mp_const_none) {
|
||||||
|
ssl_context_load_cadata(ssl_context, args[ARG_cadata].u_obj);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create and return the new SSLSocket object.
|
||||||
|
return ssl_socket_make_new(ssl_context, sock, args[ARG_server_side].u_bool,
|
||||||
|
args[ARG_do_handshake].u_bool, args[ARG_server_hostname].u_obj);
|
||||||
}
|
}
|
||||||
STATIC MP_DEFINE_CONST_FUN_OBJ_KW(mod_ssl_wrap_socket_obj, 1, mod_ssl_wrap_socket);
|
STATIC MP_DEFINE_CONST_FUN_OBJ_KW(mod_ssl_wrap_socket_obj, 1, mod_ssl_wrap_socket);
|
||||||
|
|
||||||
STATIC const mp_rom_map_elem_t mp_module_ssl_globals_table[] = {
|
STATIC const mp_rom_map_elem_t mp_module_ssl_globals_table[] = {
|
||||||
{ MP_ROM_QSTR(MP_QSTR___name__), MP_ROM_QSTR(MP_QSTR_ssl) },
|
{ MP_ROM_QSTR(MP_QSTR___name__), MP_ROM_QSTR(MP_QSTR_ssl) },
|
||||||
|
|
||||||
|
// Functions.
|
||||||
{ MP_ROM_QSTR(MP_QSTR_wrap_socket), MP_ROM_PTR(&mod_ssl_wrap_socket_obj) },
|
{ MP_ROM_QSTR(MP_QSTR_wrap_socket), MP_ROM_PTR(&mod_ssl_wrap_socket_obj) },
|
||||||
|
|
||||||
|
// Classes.
|
||||||
|
{ MP_ROM_QSTR(MP_QSTR_SSLContext), MP_ROM_PTR(&ssl_context_type) },
|
||||||
|
|
||||||
|
// Constants.
|
||||||
|
{ MP_ROM_QSTR(MP_QSTR_PROTOCOL_TLS_CLIENT), MP_ROM_INT(MBEDTLS_SSL_IS_CLIENT) },
|
||||||
|
{ MP_ROM_QSTR(MP_QSTR_PROTOCOL_TLS_SERVER), MP_ROM_INT(MBEDTLS_SSL_IS_SERVER) },
|
||||||
{ MP_ROM_QSTR(MP_QSTR_CERT_NONE), MP_ROM_INT(MBEDTLS_SSL_VERIFY_NONE) },
|
{ MP_ROM_QSTR(MP_QSTR_CERT_NONE), MP_ROM_INT(MBEDTLS_SSL_VERIFY_NONE) },
|
||||||
{ MP_ROM_QSTR(MP_QSTR_CERT_OPTIONAL), MP_ROM_INT(MBEDTLS_SSL_VERIFY_OPTIONAL) },
|
{ MP_ROM_QSTR(MP_QSTR_CERT_OPTIONAL), MP_ROM_INT(MBEDTLS_SSL_VERIFY_OPTIONAL) },
|
||||||
{ MP_ROM_QSTR(MP_QSTR_CERT_REQUIRED), MP_ROM_INT(MBEDTLS_SSL_VERIFY_REQUIRED) },
|
{ MP_ROM_QSTR(MP_QSTR_CERT_REQUIRED), MP_ROM_INT(MBEDTLS_SSL_VERIFY_REQUIRED) },
|
||||||
};
|
};
|
||||||
|
|
||||||
STATIC MP_DEFINE_CONST_DICT(mp_module_ssl_globals, mp_module_ssl_globals_table);
|
STATIC MP_DEFINE_CONST_DICT(mp_module_ssl_globals, mp_module_ssl_globals_table);
|
||||||
|
|
||||||
const mp_obj_module_t mp_module_ssl = {
|
const mp_obj_module_t mp_module_ssl = {
|
||||||
|
|||||||
@ -33,7 +33,7 @@ except OSError as er:
|
|||||||
ss = ssl.wrap_socket(TestSocket(), server_side=1, do_handshake=0)
|
ss = ssl.wrap_socket(TestSocket(), server_side=1, do_handshake=0)
|
||||||
|
|
||||||
# print
|
# print
|
||||||
print(repr(ss)[:12])
|
print(ss)
|
||||||
|
|
||||||
# setblocking() propagates call to the underlying stream object
|
# setblocking() propagates call to the underlying stream object
|
||||||
ss.setblocking(False)
|
ss.setblocking(False)
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
OSError: client
|
OSError: client
|
||||||
<_SSLSocket
|
<SSLSocket>
|
||||||
TestSocket.setblocking(False)
|
TestSocket.setblocking(False)
|
||||||
TestSocket.setblocking(True)
|
TestSocket.setblocking(True)
|
||||||
TestSocket.ioctl 4 0
|
TestSocket.ioctl 4 0
|
||||||
|
|||||||
@ -3,7 +3,7 @@ import ssl
|
|||||||
|
|
||||||
# CPython only supports server_hostname with SSLContext
|
# CPython only supports server_hostname with SSLContext
|
||||||
if hasattr(ssl, "SSLContext"):
|
if hasattr(ssl, "SSLContext"):
|
||||||
ssl = ssl.SSLContext()
|
ssl = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
||||||
|
|
||||||
|
|
||||||
def test_one(site, opts):
|
def test_one(site, opts):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user