Skip to content

Commit

Permalink
feat(tls_cxx): Add support for DTLS
Browse files Browse the repository at this point in the history
  • Loading branch information
david-cermak committed Feb 9, 2024
1 parent a637bfc commit 5cd7cb0
Show file tree
Hide file tree
Showing 6 changed files with 348 additions and 4 deletions.
11 changes: 11 additions & 0 deletions components/mbedtls_cxx/examples/udp_mutual_auth/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# For more information about build system see
# https://docs.espressif.com/projects/esp-idf/en/latest/api-guides/build-system.html
# The following five lines of boilerplate have to be in your project's
# CMakeLists in this exact order for cmake to work correctly
cmake_minimum_required(VERSION 3.16)

include($ENV{IDF_PATH}/tools/cmake/project.cmake)
if("${IDF_TARGET}" STREQUAL "linux")
list(APPEND EXTRA_COMPONENT_DIRS "$ENV{IDF_PATH}/tools/mocks/freertos/")
endif()
project(udp_mutual)
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
idf_component_register(SRCS "udp_mutual.cpp"
INCLUDE_DIRS ".")
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
dependencies:
idf: ">=5.0"
espressif/mbedtls_cxx:
version: "*"
override_path: "../../.."
protocol_examples_common:
path: ${IDF_PATH}/examples/common_components/protocol_examples_common
253 changes: 253 additions & 0 deletions components/mbedtls_cxx/examples/udp_mutual_auth/main/udp_mutual.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,253 @@
/*
* SPDX-FileCopyrightText: 2024 Espressif Systems (Shanghai) CO LTD
*
* SPDX-License-Identifier: Unlicense OR CC0-1.0
*/
#include <sys/socket.h>
#include <netdb.h>
#include <unistd.h>
#include "esp_log.h"
#include "mbedtls_wrap.hpp"

static auto const *TAG = "simple_udp_example";

const unsigned char servercert[] = "-----BEGIN CERTIFICATE-----\n"
"MIIDKzCCAhOgAwIBAgIUBxM3WJf2bP12kAfqhmhhjZWv0ukwDQYJKoZIhvcNAQEL\n"
"BQAwJTEjMCEGA1UEAwwaRVNQMzIgSFRUUFMgc2VydmVyIGV4YW1wbGUwHhcNMTgx\n"
"MDE3MTEzMjU3WhcNMjgxMDE0MTEzMjU3WjAlMSMwIQYDVQQDDBpFU1AzMiBIVFRQ\n"
"UyBzZXJ2ZXIgZXhhbXBsZTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEB\n"
"ALBint6nP77RCQcmKgwPtTsGK0uClxg+LwKJ3WXuye3oqnnjqJCwMEneXzGdG09T\n"
"sA0SyNPwrEgebLCH80an3gWU4pHDdqGHfJQa2jBL290e/5L5MB+6PTs2NKcojK/k\n"
"qcZkn58MWXhDW1NpAnJtjVniK2Ksvr/YIYSbyD+JiEs0MGxEx+kOl9d7hRHJaIzd\n"
"GF/vO2pl295v1qXekAlkgNMtYIVAjUy9CMpqaQBCQRL+BmPSJRkXBsYk8GPnieS4\n"
"sUsp53DsNvCCtWDT6fd9D1v+BB6nDk/FCPKhtjYOwOAZlX4wWNSZpRNr5dfrxKsb\n"
"jAn4PCuR2akdF4G8WLUeDWECAwEAAaNTMFEwHQYDVR0OBBYEFMnmdJKOEepXrHI/\n"
"ivM6mVqJgAX8MB8GA1UdIwQYMBaAFMnmdJKOEepXrHI/ivM6mVqJgAX8MA8GA1Ud\n"
"EwEB/wQFMAMBAf8wDQYJKoZIhvcNAQELBQADggEBADiXIGEkSsN0SLSfCF1VNWO3\n"
"emBurfOcDq4EGEaxRKAU0814VEmU87btIDx80+z5Dbf+GGHCPrY7odIkxGNn0DJY\n"
"W1WcF+DOcbiWoUN6DTkAML0SMnp8aGj9ffx3x+qoggT+vGdWVVA4pgwqZT7Ybntx\n"
"bkzcNFW0sqmCv4IN1t4w6L0A87ZwsNwVpre/j6uyBw7s8YoJHDLRFT6g7qgn0tcN\n"
"ZufhNISvgWCVJQy/SZjNBHSpnIdCUSJAeTY2mkM4sGxY0Widk8LnjydxZUSxC3Nl\n"
"hb6pnMh3jRq4h0+5CZielA4/a+TdrNPv/qok67ot/XJdY3qHCCd8O2b14OVq9jo=\n"
"-----END CERTIFICATE-----";
const unsigned char prvtkey[] = "-----BEGIN PRIVATE KEY-----\n"
"MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQCwYp7epz++0QkH\n"
"JioMD7U7BitLgpcYPi8Cid1l7snt6Kp546iQsDBJ3l8xnRtPU7ANEsjT8KxIHmyw\n"
"h/NGp94FlOKRw3ahh3yUGtowS9vdHv+S+TAfuj07NjSnKIyv5KnGZJ+fDFl4Q1tT\n"
"aQJybY1Z4itirL6/2CGEm8g/iYhLNDBsRMfpDpfXe4URyWiM3Rhf7ztqZdveb9al\n"
"3pAJZIDTLWCFQI1MvQjKamkAQkES/gZj0iUZFwbGJPBj54nkuLFLKedw7DbwgrVg\n"
"0+n3fQ9b/gQepw5PxQjyobY2DsDgGZV+MFjUmaUTa+XX68SrG4wJ+DwrkdmpHReB\n"
"vFi1Hg1hAgMBAAECggEAaTCnZkl/7qBjLexIryC/CBBJyaJ70W1kQ7NMYfniWwui\n"
"f0aRxJgOdD81rjTvkINsPp+xPRQO6oOadjzdjImYEuQTqrJTEUnntbu924eh+2D9\n"
"Mf2CAanj0mglRnscS9mmljZ0KzoGMX6Z/EhnuS40WiJTlWlH6MlQU/FDnwC6U34y\n"
"JKy6/jGryfsx+kGU/NRvKSru6JYJWt5v7sOrymHWD62IT59h3blOiP8GMtYKeQlX\n"
"49om9Mo1VTIFASY3lrxmexbY+6FG8YO+tfIe0tTAiGrkb9Pz6tYbaj9FjEWOv4Vc\n"
"+3VMBUVdGJjgqvE8fx+/+mHo4Rg69BUPfPSrpEg7sQKBgQDlL85G04VZgrNZgOx6\n"
"pTlCCl/NkfNb1OYa0BELqWINoWaWQHnm6lX8YjrUjwRpBF5s7mFhguFjUjp/NW6D\n"
"0EEg5BmO0ePJ3dLKSeOA7gMo7y7kAcD/YGToqAaGljkBI+IAWK5Su5yldrECTQKG\n"
"YnMKyQ1MWUfCYEwHtPvFvE5aPwKBgQDFBWXekpxHIvt/B41Cl/TftAzE7/f58JjV\n"
"MFo/JCh9TDcH6N5TMTRS1/iQrv5M6kJSSrHnq8pqDXOwfHLwxetpk9tr937VRzoL\n"
"CuG1Ar7c1AO6ujNnAEmUVC2DppL/ck5mRPWK/kgLwZSaNcZf8sydRgphsW1ogJin\n"
"7g0nGbFwXwKBgQCPoZY07Pr1TeP4g8OwWTu5F6dSvdU2CAbtZthH5q98u1n/cAj1\n"
"noak1Srpa3foGMTUn9CHu+5kwHPIpUPNeAZZBpq91uxa5pnkDMp3UrLIRJ2uZyr8\n"
"4PxcknEEh8DR5hsM/IbDcrCJQglM19ZtQeW3LKkY4BsIxjDf45ymH407IQKBgE/g\n"
"Ul6cPfOxQRlNLH4VMVgInSyyxWx1mODFy7DRrgCuh5kTVh+QUVBM8x9lcwAn8V9/\n"
"nQT55wR8E603pznqY/jX0xvAqZE6YVPcw4kpZcwNwL1RhEl8GliikBlRzUL3SsW3\n"
"q30AfqEViHPE3XpE66PPo6Hb1ymJCVr77iUuC3wtAoGBAIBrOGunv1qZMfqmwAY2\n"
"lxlzRgxgSiaev0lTNxDzZkmU/u3dgdTwJ5DDANqPwJc6b8SGYTp9rQ0mbgVHnhIB\n"
"jcJQBQkTfq6Z0H6OoTVi7dPs3ibQJFrtkoyvYAbyk36quBmNRjVh6rc8468bhXYr\n"
"v/t+MeGJP/0Zw8v/X2CFll96\n"
"-----END PRIVATE KEY-----";


class SecureLink: public Tls {
public:
explicit SecureLink() : Tls(), addr("localhost", 3333, AF_INET, SOCK_DGRAM) {}
~SecureLink() override
{
if (sock >= 0) {
::close(sock);
}
}
int send(const unsigned char *buf, size_t len) override
{
return sendto(sock, buf, len, 0, addr, ai_size);
}
int recv(unsigned char *buf, size_t len) override
{
socklen_t socklen = sizeof(sockaddr);
return recvfrom(sock, buf, len, 0, addr, &socklen);
}
int recv_tout(unsigned char *buf, size_t len, int timeout) override
{
struct timeval tv {
timeout / 1000, (timeout % 1000 ) * 1000
};
fd_set read_fds;
FD_ZERO( &read_fds );
FD_SET( sock, &read_fds );

int ret = select(sock + 1, &read_fds, nullptr, nullptr, timeout == 0 ? nullptr : &tv);
if (ret == 0) {
return MBEDTLS_ERR_SSL_TIMEOUT;
}
if (ret < 0) {
if (errno == EINTR) {
return MBEDTLS_ERR_SSL_WANT_READ;
}
return ret;
}
return recv(buf, len);
}
bool open(bool server_not_client)
{
if (!addr) {
ESP_LOGE(TAG, "Failed to resolve endpoint");
return false;
}
sock = addr.get_sock();
if (sock < 0) {
ESP_LOGE(TAG, "Failed to create socket");
return false;
}
if (server_not_client) {
int err = bind(sock, addr, ai_size);
if (err < 0) {
ESP_LOGE(TAG, "Socket unable to bind: errno %d", errno);
return false;
}
}
if (!init(is_server{server_not_client}, do_verify{false})) {
return false;
}

return handshake() == 0;
}

private:
int sock{-1};
/**
* RAII wrapper of the address_info
*/
struct addr_info {
struct addrinfo *ai = nullptr;
explicit addr_info(const char *host, int port, int family, int type)
{
struct addrinfo hints {};
hints.ai_family = family;
hints.ai_socktype = type;
if (getaddrinfo(host, nullptr, &hints, &ai) < 0) {
freeaddrinfo(ai);
ai = nullptr;
}
auto *p = (struct sockaddr_in *)ai->ai_addr;
p->sin_port = htons(port);
}
~addr_info()
{
freeaddrinfo(ai);
}
explicit operator bool() const
{
return ai != nullptr;
}
operator sockaddr *() const
{
auto *p = (struct sockaddr_in *)ai->ai_addr;
return (struct sockaddr *)p;
}

int get_sock() const
{
return socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol);
}
} addr;
const int ai_size{sizeof(struct sockaddr_in)};
};

static void tls_client()
{
const unsigned char message[] = "Hello\n";
unsigned char reply[128];
SecureLink client;
if (!client.open(false)) {
ESP_LOGE(TAG, "Failed to CONNECT! %d", errno);
return;
}
ESP_LOGI(TAG, "client opened...");
if (client.write(message, sizeof(message)) < 0) {
ESP_LOGE(TAG, "Failed to write!");
return;
}
int len = client.read(reply, sizeof(reply));
if (len < 0) {
ESP_LOGE(TAG, "Failed to read!");
return;
}
ESP_LOGI(TAG, "Successfully received: %.*s", len, reply);
}

static void tls_server()
{
unsigned char message[128];
SecureLink server;
const_buf cert{servercert, sizeof(servercert)};
const_buf key{prvtkey, sizeof(prvtkey)};
if (!server.set_own_cert(cert, key)) {
ESP_LOGE(TAG, "Failed to set own cert");
return;
}
ESP_LOGI(TAG, "openning...");
if (!server.open(true)) {
ESP_LOGE(TAG, "Failed to OPEN! %d", errno);
return;
}
int len = server.read(message, sizeof(message));
if (len < 0) {
ESP_LOGE(TAG, "Failed to read!");
return;
}
ESP_LOGI(TAG, "Received from client: %.*s", len, message);
if (server.write(message, len) < 0) {
ESP_LOGE(TAG, "Failed to write!");
return;
}
ESP_LOGI(TAG, "Written back");
}


#if CONFIG_IDF_TARGET_LINUX
/**
* Linux target: We're already connected, just run the client
*/
#include <thread>
int main()
{
std::thread t2(tls_server);
usleep(1000);
std::thread t1(tls_client);
t1.join();
t2.join();
return 0;
}
#else
/**
* ESP32 chipsets: Need to initialize system components
* and connect to network
*/

#include "nvs_flash.h"
#include "esp_event.h"
#include "protocol_examples_common.h"
#include "esp_netif.h"

extern "C" void app_main()
{
ESP_ERROR_CHECK(nvs_flash_init());
ESP_ERROR_CHECK(esp_netif_init());
ESP_ERROR_CHECK(esp_event_loop_create_default());
ESP_ERROR_CHECK(example_connect());

tls_client();
}
#endif
10 changes: 9 additions & 1 deletion components/mbedtls_cxx/include/mbedtls_wrap.hpp
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
/*
* SPDX-FileCopyrightText: 2023 Espressif Systems (Shanghai) CO LTD
* SPDX-FileCopyrightText: 2023-2024 Espressif Systems (Shanghai) CO LTD
*
* SPDX-License-Identifier: Apache-2.0
*/
#pragma once

#include <utility>
#include <memory>
#include <mbedtls/timing.h>
#include <mbedtls/ssl_cookie.h>
#include "mbedtls/ssl.h"
#include "mbedtls/entropy.h"
#include "mbedtls/ctr_drbg.h"
Expand All @@ -23,6 +25,7 @@ class Tls {
Tls();
virtual ~Tls();
bool init(is_server server, do_verify verify);
bool init_dtls();
bool deinit();
int handshake();
int write(const unsigned char *buf, size_t len);
Expand All @@ -32,6 +35,7 @@ class Tls {
bool set_hostname(const char *name);
virtual int send(const unsigned char *buf, size_t len) = 0;
virtual int recv(unsigned char *buf, size_t len) = 0;
virtual int recv_tout(unsigned char *buf, size_t len, int timeout) = 0;
size_t get_available_bytes();

protected:
Expand All @@ -42,7 +46,10 @@ class Tls {
mbedtls_ssl_config conf_{};
mbedtls_ctr_drbg_context ctr_drbg_{};
mbedtls_entropy_context entropy_{};
mbedtls_timing_delay_context timer_{};
mbedtls_ssl_cookie_ctx cookie_{};
virtual void delay() {}
bool is_server_{false};

bool set_session();
bool get_session();
Expand All @@ -53,6 +60,7 @@ class Tls {
static void print_error(const char *function, int error_code);
static int bio_write(void *ctx, const unsigned char *buf, size_t len);
static int bio_read(void *ctx, unsigned char *buf, size_t len);
static int bio_read_tout(void *ctx, unsigned char *buf, size_t len, uint32_t timeout);
int mbedtls_pk_parse_key( mbedtls_pk_context *ctx,
const unsigned char *key, size_t keylen,
const unsigned char *pwd, size_t pwdlen);
Expand Down
Loading

0 comments on commit 5cd7cb0

Please sign in to comment.