mirror of https://github.com/yuzu-mirror/yuzu
Implement SSL service
This implements some missing network APIs including a large chunk of the SSL service, enough for Mario Maker (with an appropriate mod applied) to connect to the fan server [Open Course World](https://opencourse.world/). Connecting to first-party servers is out of scope of this PR and is a minefield I'd rather not step into. ## TLS TLS is implemented with multiple backends depending on the system's 'native' TLS library. Currently there are two backends: Schannel for Windows, and OpenSSL for Linux. (In reality Linux is a bit of a free-for-all where there's no one 'native' library, but OpenSSL is the closest it gets.) On macOS the 'native' library is SecureTransport but that isn't implemented in this PR. (Instead, all non-Windows OSes will use OpenSSL unless disabled with `-DENABLE_OPENSSL=OFF`.) Why have multiple backends instead of just using a single library, especially given that Yuzu already embeds mbedtls for cryptographic algorithms? Well, I tried implementing this on mbedtls first, but the problem is TLS policies - mainly trusted certificate policies, and to a lesser extent trusted algorithms, SSL versions, etc. ...In practice, the chance that someone is going to conduct a man-in-the-middle attack on a third-party game server is pretty low, but I'm a security nerd so I like to do the right security things. My base assumption is that we want to use the host system's TLS policies. An alternative would be to more closely emulate the Switch's TLS implementation (which is based on NSS). But for one thing, I don't feel like reverse engineering it. And I'd argue that for third-party servers such as Open Course World, it's theoretically preferable to use the system's policies rather than the Switch's, for two reasons 1. Someday the Switch will stop being updated, and the trusted cert list, algorithms, etc. will start to go stale, but users will still want to connect to third-party servers, and there's no reason they shouldn't have up-to-date security when doing so. At that point, homebrew users on actual hardware may patch the TLS implementation, but for emulators it's simpler to just use the host's stack. 2. Also, it's good to respect any custom certificate policies the user may have added systemwide. For example, they may have added custom trusted CAs in order to use TLS debugging tools or pass through corporate MitM middleboxes. Or they may have removed some CAs that are normally trusted out of paranoia. Note that this policy wouldn't work as-is for connecting to first-party servers, because some of them serve certificates based on Nintendo's own CA rather than a publicly trusted one. However, this could probably be solved easily by using appropriate APIs to adding Nintendo's CA as an alternate trusted cert for Yuzu's connections. That is not implemented in this PR because, again, first-party servers are out of scope. (If anything I'd rather have an option to _block_ connections to Nintendo servers, but that's not implemented here.) To use the host's TLS policies, there are three theoretical options: a) Import the host's trusted certificate list into a cross-platform TLS library (presumably mbedtls). b) Use the native TLS library to verify certificates but use a cross-platform TLS library for everything else. c) Use the native TLS library for everything. Two problems with option a). First, importing the trusted certificate list at minimum requires a bunch of platform-specific code, which mbedtls does not have built in. Interestingly, OpenSSL recently gained the ability to import the Windows certificate trust store... but that leads to the second problem, which is that a list of trusted certificates is [not expressive enough](https://bugs.archlinux.org/task/41909) to express a modern certificate trust policy. For example, Windows has the concept of [explicitly distrusted certificates](https://learn.microsoft.com/en-us/previous-versions/windows/it-pro/windows-server-2012-r2-and-2012/dn265983(v=ws.11)), and macOS requires Certificate Transparency validation for some certificates with complex rules for when it's required. Option b) (using native library just to verify certs) is probably feasible, but it would miss aspects of TLS policy other than trusted certs (like allowed algorithms), and in any case it might well require writing more code, not less, compared to using the native library for everything. So I ended up at option c), using the native library for everything. What I'd *really* prefer would be to use a third-party library that does option c) for me. Rust has a good library for this, [native-tls](https://docs.rs/native-tls/latest/native_tls/). I did search, but I couldn't find a good option in the C or C++ ecosystem, at least not any that wasn't part of some much larger framework. I was surprised - isn't this a pretty common use case? Well, many applications only need TLS for HTTPS, and they can use libcurl, which has a TLS abstraction layer internally but doesn't expose it. Other applications only support a single TLS library, or use one of the aforementioned larger frameworks, or are platform-specific to begin with, or of course are written in a non-C/C++ language, most of which have some canonical choice for TLS. But there are also many applications that have a set of TLS backends just like this; it's just that nobody has gone ahead and abstracted the pattern into a library, at least not a widespread one. Amusingly, there is one TLS abstraction layer that Yuzu already bundles: the one in ffmpeg. But it is missing some features that would be needed to use it here (like reusing an existing socket rather than managing the socket itself). Though, that does mean that the wiki's build instructions for Linux (and macOS for some reason?) already recommend installing OpenSSL, so no need to update those. ## Other APIs implemented - Sockets: - GetSockOpt(`SO_ERROR`) - SetSockOpt(`SO_NOSIGPIPE`) (stub, I have no idea what this does on Switch) - `DuplicateSocket` (because the SSL sysmodule calls it internally) - More `PollEvents` values - NSD: - `Resolve` and `ResolveEx` (stub, good enough for Open Course World and probably most third-party servers, but not first-party) - SFDNSRES: - `GetHostByNameRequest` and `GetHostByNameRequestWithOptions` - `ResolverSetOptionRequest` (stub) ## Fixes - Parts of the socket code were previously allocating a `sockaddr` object on the stack when calling functions that take a `sockaddr*` (e.g. `accept`). This might seem like the right thing to do to avoid illegal aliasing, but in fact `sockaddr` is not guaranteed to be large enough to hold any particular type of address, only the header. This worked in practice because in practice `sockaddr` is the same size as `sockaddr_in`, but it's not how the API is meant to be used. I changed this to allocate an `sockaddr_in` on the stack and `reinterpret_cast` it. I could try to do something cleverer with `aligned_storage`, but casting is the idiomatic way to use these particular APIs, so it's really the system's responsibility to avoid any aliasing issues. - I rewrote most of the `GetAddrInfoRequest[WithOptions]` implementation. The old implementation invoked the host's getaddrinfo directly from sfdnsres.cpp, and directly passed through the host's socket type, protocol, etc. values rather than looking up the corresponding constants on the Switch. To be fair, these constants don't tend to actually vary across systems, but still... I added a wrapper for `getaddrinfo` in `internal_network/network.cpp` similar to the ones for other socket APIs, and changed the `GetAddrInfoRequest` implementation to use it. While I was at it, I rewrote the serialization to use the same approach I used to implement `GetHostByNameRequest`, because it reduces the number of size calculations. While doing so I removed `AF_INET6` support because the Switch doesn't support IPv6; it might be nice to support IPv6 anyway, but that would have to apply to all of the socket APIs. I also corrected the IPC wrappers for `GetAddrInfoRequest` and `GetAddrInfoRequestWithOptions` based on reverse engineering and hardware testing. Every call to `GetAddrInfoRequestWithOptions` returns *four* different error codes (IPC status, getaddrinfo error code, netdb error code, and errno), and `GetAddrInfoRequest` returns three of those but in a different order, and it doesn't really matter but the existing implementation was a bit off, as I discovered while testing `GetHostByNameRequest`. - The new serialization code is based on two simple helper functions: ```cpp template <typename T> static void Append(std::vector<u8>& vec, T t); void AppendNulTerminated(std::vector<u8>& vec, std::string_view str); ``` I was thinking there must be existing functions somewhere that assist with serialization/deserialization of binary data, but all I could find was the helper methods in `IOFile` and `HLERequestContext`, not anything that could be used with a generic byte buffer. If I'm not missing something, then maybe I should move the above functions to a new header in `common`... right now they're just sitting in `sfdnsres.cpp` where they're used. - Not a fix, but `SocketBase::Recv`/`Send` is changed to use `std::span<u8>` rather than `std::vector<u8>&` to avoid needing to copy the data to/from a vector when those methods are called from the TLS implementation.pull/8/head
parent
ce191ba32b
commit
8e703e08df
@ -0,0 +1,44 @@
|
||||
// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project
|
||||
// SPDX-License-Identifier: GPL-2.0-or-later
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/hle/result.h"
|
||||
|
||||
#include "common/common_types.h"
|
||||
|
||||
#include <memory>
|
||||
#include <span>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace Network {
|
||||
class SocketBase;
|
||||
}
|
||||
|
||||
namespace Service::SSL {
|
||||
|
||||
constexpr Result ResultNoSocket{ErrorModule::SSLSrv, 103};
|
||||
constexpr Result ResultInvalidSocket{ErrorModule::SSLSrv, 106};
|
||||
constexpr Result ResultTimeout{ErrorModule::SSLSrv, 205};
|
||||
constexpr Result ResultInternalError{ErrorModule::SSLSrv, 999}; // made up
|
||||
|
||||
constexpr Result ResultWouldBlock{ErrorModule::SSLSrv, 204};
|
||||
// ^ ResultWouldBlock is returned from Read and Write, and oddly, DoHandshake,
|
||||
// with no way in the latter case to distinguish whether the client should poll
|
||||
// for read or write. The one official client I've seen handles this by always
|
||||
// polling for read (with a timeout).
|
||||
|
||||
class SSLConnectionBackend {
|
||||
public:
|
||||
virtual void SetSocket(std::shared_ptr<Network::SocketBase> socket) = 0;
|
||||
virtual Result SetHostName(const std::string& hostname) = 0;
|
||||
virtual Result DoHandshake() = 0;
|
||||
virtual ResultVal<size_t> Read(std::span<u8> data) = 0;
|
||||
virtual ResultVal<size_t> Write(std::span<const u8> data) = 0;
|
||||
virtual ResultVal<std::vector<std::vector<u8>>> GetServerCerts() = 0;
|
||||
};
|
||||
|
||||
ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend();
|
||||
|
||||
} // namespace Service::SSL
|
@ -0,0 +1,15 @@
|
||||
// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project
|
||||
// SPDX-License-Identifier: GPL-2.0-or-later
|
||||
|
||||
#include "core/hle/service/ssl/ssl_backend.h"
|
||||
|
||||
#include "common/logging/log.h"
|
||||
|
||||
namespace Service::SSL {
|
||||
|
||||
ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() {
|
||||
LOG_ERROR(Service_SSL, "No SSL backend on this platform");
|
||||
return ResultInternalError;
|
||||
}
|
||||
|
||||
} // namespace Service::SSL
|
@ -0,0 +1,342 @@
|
||||
// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project
|
||||
// SPDX-License-Identifier: GPL-2.0-or-later
|
||||
|
||||
#include "core/hle/service/ssl/ssl_backend.h"
|
||||
#include "core/internal_network/network.h"
|
||||
#include "core/internal_network/sockets.h"
|
||||
|
||||
#include "common/fs/file.h"
|
||||
#include "common/hex_util.h"
|
||||
#include "common/string_util.h"
|
||||
|
||||
#include <mutex>
|
||||
|
||||
#include <openssl/bio.h>
|
||||
#include <openssl/err.h>
|
||||
#include <openssl/ssl.h>
|
||||
#include <openssl/x509.h>
|
||||
|
||||
using namespace Common::FS;
|
||||
|
||||
namespace Service::SSL {
|
||||
|
||||
// Import OpenSSL's `SSL` type into the namespace. This is needed because the
|
||||
// namespace is also named `SSL`.
|
||||
using ::SSL;
|
||||
|
||||
namespace {
|
||||
|
||||
std::once_flag one_time_init_flag;
|
||||
bool one_time_init_success = false;
|
||||
|
||||
SSL_CTX* ssl_ctx;
|
||||
IOFile key_log_file; // only open if SSLKEYLOGFILE set in environment
|
||||
BIO_METHOD* bio_meth;
|
||||
|
||||
Result CheckOpenSSLErrors();
|
||||
void OneTimeInit();
|
||||
void OneTimeInitLogFile();
|
||||
bool OneTimeInitBIO();
|
||||
|
||||
} // namespace
|
||||
|
||||
class SSLConnectionBackendOpenSSL final : public SSLConnectionBackend {
|
||||
public:
|
||||
Result Init() {
|
||||
std::call_once(one_time_init_flag, OneTimeInit);
|
||||
|
||||
if (!one_time_init_success) {
|
||||
LOG_ERROR(Service_SSL,
|
||||
"Can't create SSL connection because OpenSSL one-time initialization failed");
|
||||
return ResultInternalError;
|
||||
}
|
||||
|
||||
ssl_ = SSL_new(ssl_ctx);
|
||||
if (!ssl_) {
|
||||
LOG_ERROR(Service_SSL, "SSL_new failed");
|
||||
return CheckOpenSSLErrors();
|
||||
}
|
||||
|
||||
SSL_set_connect_state(ssl_);
|
||||
|
||||
bio_ = BIO_new(bio_meth);
|
||||
if (!bio_) {
|
||||
LOG_ERROR(Service_SSL, "BIO_new failed");
|
||||
return CheckOpenSSLErrors();
|
||||
}
|
||||
|
||||
BIO_set_data(bio_, this);
|
||||
BIO_set_init(bio_, 1);
|
||||
SSL_set_bio(ssl_, bio_, bio_);
|
||||
|
||||
return ResultSuccess;
|
||||
}
|
||||
|
||||
void SetSocket(std::shared_ptr<Network::SocketBase> socket) override {
|
||||
socket_ = socket;
|
||||
}
|
||||
|
||||
Result SetHostName(const std::string& hostname) override {
|
||||
if (!SSL_set1_host(ssl_, hostname.c_str())) { // hostname for verification
|
||||
LOG_ERROR(Service_SSL, "SSL_set1_host({}) failed", hostname);
|
||||
return CheckOpenSSLErrors();
|
||||
}
|
||||
if (!SSL_set_tlsext_host_name(ssl_, hostname.c_str())) { // hostname for SNI
|
||||
LOG_ERROR(Service_SSL, "SSL_set_tlsext_host_name({}) failed", hostname);
|
||||
return CheckOpenSSLErrors();
|
||||
}
|
||||
return ResultSuccess;
|
||||
}
|
||||
|
||||
Result DoHandshake() override {
|
||||
SSL_set_verify_result(ssl_, X509_V_OK);
|
||||
int ret = SSL_do_handshake(ssl_);
|
||||
long verify_result = SSL_get_verify_result(ssl_);
|
||||
if (verify_result != X509_V_OK) {
|
||||
LOG_ERROR(Service_SSL, "SSL cert verification failed because: {}",
|
||||
X509_verify_cert_error_string(verify_result));
|
||||
return CheckOpenSSLErrors();
|
||||
}
|
||||
if (ret <= 0) {
|
||||
int ssl_err = SSL_get_error(ssl_, ret);
|
||||
if (ssl_err == SSL_ERROR_ZERO_RETURN ||
|
||||
(ssl_err == SSL_ERROR_SYSCALL && got_read_eof_)) {
|
||||
LOG_ERROR(Service_SSL, "SSL handshake failed because server hung up");
|
||||
return ResultInternalError;
|
||||
}
|
||||
}
|
||||
return HandleReturn("SSL_do_handshake", 0, ret).Code();
|
||||
}
|
||||
|
||||
ResultVal<size_t> Read(std::span<u8> data) override {
|
||||
size_t actual;
|
||||
int ret = SSL_read_ex(ssl_, data.data(), data.size(), &actual);
|
||||
return HandleReturn("SSL_read_ex", actual, ret);
|
||||
}
|
||||
|
||||
ResultVal<size_t> Write(std::span<const u8> data) override {
|
||||
size_t actual;
|
||||
int ret = SSL_write_ex(ssl_, data.data(), data.size(), &actual);
|
||||
return HandleReturn("SSL_write_ex", actual, ret);
|
||||
}
|
||||
|
||||
ResultVal<size_t> HandleReturn(const char* what, size_t actual, int ret) {
|
||||
int ssl_err = SSL_get_error(ssl_, ret);
|
||||
CheckOpenSSLErrors();
|
||||
switch (ssl_err) {
|
||||
case SSL_ERROR_NONE:
|
||||
return actual;
|
||||
case SSL_ERROR_ZERO_RETURN:
|
||||
LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_ZERO_RETURN", what);
|
||||
// DoHandshake special-cases this, but for Read and Write:
|
||||
return size_t(0);
|
||||
case SSL_ERROR_WANT_READ:
|
||||
LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_WANT_READ", what);
|
||||
return ResultWouldBlock;
|
||||
case SSL_ERROR_WANT_WRITE:
|
||||
LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_WANT_WRITE", what);
|
||||
return ResultWouldBlock;
|
||||
default:
|
||||
if (ssl_err == SSL_ERROR_SYSCALL && got_read_eof_) {
|
||||
LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_SYSCALL because server hung up", what);
|
||||
return size_t(0);
|
||||
}
|
||||
LOG_ERROR(Service_SSL, "{} => other SSL_get_error return value {}", what, ssl_err);
|
||||
return ResultInternalError;
|
||||
}
|
||||
}
|
||||
|
||||
ResultVal<std::vector<std::vector<u8>>> GetServerCerts() override {
|
||||
STACK_OF(X509)* chain = SSL_get_peer_cert_chain(ssl_);
|
||||
if (!chain) {
|
||||
LOG_ERROR(Service_SSL, "SSL_get_peer_cert_chain returned nullptr");
|
||||
return ResultInternalError;
|
||||
}
|
||||
std::vector<std::vector<u8>> ret;
|
||||
int count = sk_X509_num(chain);
|
||||
ASSERT(count >= 0);
|
||||
for (int i = 0; i < count; i++) {
|
||||
X509* x509 = sk_X509_value(chain, i);
|
||||
ASSERT_OR_EXECUTE(x509 != nullptr, { continue; });
|
||||
unsigned char* buf = nullptr;
|
||||
int len = i2d_X509(x509, &buf);
|
||||
ASSERT_OR_EXECUTE(len >= 0 && buf, { continue; });
|
||||
ret.emplace_back(buf, buf + len);
|
||||
OPENSSL_free(buf);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
~SSLConnectionBackendOpenSSL() {
|
||||
// these are null-tolerant:
|
||||
SSL_free(ssl_);
|
||||
BIO_free(bio_);
|
||||
}
|
||||
|
||||
static void KeyLogCallback(const SSL* ssl, const char* line) {
|
||||
std::string str(line);
|
||||
str.push_back('\n');
|
||||
// Do this in a single WriteString for atomicity if multiple instances
|
||||
// are running on different threads (though that can't currently
|
||||
// happen).
|
||||
if (key_log_file.WriteString(str) != str.size() || !key_log_file.Flush()) {
|
||||
LOG_CRITICAL(Service_SSL, "Failed to write to SSLKEYLOGFILE");
|
||||
}
|
||||
LOG_DEBUG(Service_SSL, "Wrote to SSLKEYLOGFILE: {}", line);
|
||||
}
|
||||
|
||||
static int WriteCallback(BIO* bio, const char* buf, size_t len, size_t* actual_p) {
|
||||
auto self = static_cast<SSLConnectionBackendOpenSSL*>(BIO_get_data(bio));
|
||||
ASSERT_OR_EXECUTE_MSG(
|
||||
self->socket_, { return 0; }, "OpenSSL asked to send but we have no socket");
|
||||
BIO_clear_retry_flags(bio);
|
||||
auto [actual, err] = self->socket_->Send({reinterpret_cast<const u8*>(buf), len}, 0);
|
||||
switch (err) {
|
||||
case Network::Errno::SUCCESS:
|
||||
*actual_p = actual;
|
||||
return 1;
|
||||
case Network::Errno::AGAIN:
|
||||
BIO_set_flags(bio, BIO_FLAGS_WRITE | BIO_FLAGS_SHOULD_RETRY);
|
||||
return 0;
|
||||
default:
|
||||
LOG_ERROR(Service_SSL, "Socket send returned Network::Errno {}", err);
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
static int ReadCallback(BIO* bio, char* buf, size_t len, size_t* actual_p) {
|
||||
auto self = static_cast<SSLConnectionBackendOpenSSL*>(BIO_get_data(bio));
|
||||
ASSERT_OR_EXECUTE_MSG(
|
||||
self->socket_, { return 0; }, "OpenSSL asked to recv but we have no socket");
|
||||
BIO_clear_retry_flags(bio);
|
||||
auto [actual, err] = self->socket_->Recv(0, {reinterpret_cast<u8*>(buf), len});
|
||||
switch (err) {
|
||||
case Network::Errno::SUCCESS:
|
||||
*actual_p = actual;
|
||||
if (actual == 0) {
|
||||
self->got_read_eof_ = true;
|
||||
}
|
||||
return actual ? 1 : 0;
|
||||
case Network::Errno::AGAIN:
|
||||
BIO_set_flags(bio, BIO_FLAGS_READ | BIO_FLAGS_SHOULD_RETRY);
|
||||
return 0;
|
||||
default:
|
||||
LOG_ERROR(Service_SSL, "Socket recv returned Network::Errno {}", err);
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
static long CtrlCallback(BIO* bio, int cmd, long larg, void* parg) {
|
||||
switch (cmd) {
|
||||
case BIO_CTRL_FLUSH:
|
||||
// Nothing to flush.
|
||||
return 1;
|
||||
case BIO_CTRL_PUSH:
|
||||
case BIO_CTRL_POP:
|
||||
case BIO_CTRL_GET_KTLS_SEND:
|
||||
case BIO_CTRL_GET_KTLS_RECV:
|
||||
// We don't support these operations, but don't bother logging them
|
||||
// as they're nothing unusual.
|
||||
return 0;
|
||||
default:
|
||||
LOG_DEBUG(Service_SSL, "OpenSSL BIO got ctrl({}, {}, {})", cmd, larg, parg);
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
SSL* ssl_ = nullptr;
|
||||
BIO* bio_ = nullptr;
|
||||
bool got_read_eof_ = false;
|
||||
|
||||
std::shared_ptr<Network::SocketBase> socket_;
|
||||
};
|
||||
|
||||
ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() {
|
||||
auto conn = std::make_unique<SSLConnectionBackendOpenSSL>();
|
||||
Result res = conn->Init();
|
||||
if (res.IsFailure()) {
|
||||
return res;
|
||||
}
|
||||
return conn;
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
Result CheckOpenSSLErrors() {
|
||||
unsigned long rc;
|
||||
const char* file;
|
||||
int line;
|
||||
const char* func;
|
||||
const char* data;
|
||||
int flags;
|
||||
while ((rc = ERR_get_error_all(&file, &line, &func, &data, &flags))) {
|
||||
std::string msg;
|
||||
msg.resize(1024, '\0');
|
||||
ERR_error_string_n(rc, msg.data(), msg.size());
|
||||
msg.resize(strlen(msg.data()), '\0');
|
||||
if (flags & ERR_TXT_STRING) {
|
||||
msg.append(" | ");
|
||||
msg.append(data);
|
||||
}
|
||||
Common::Log::FmtLogMessage(Common::Log::Class::Service_SSL, Common::Log::Level::Error,
|
||||
Common::Log::TrimSourcePath(file), line, func, "OpenSSL: {}",
|
||||
msg);
|
||||
}
|
||||
return ResultInternalError;
|
||||
}
|
||||
|
||||
void OneTimeInit() {
|
||||
ssl_ctx = SSL_CTX_new(TLS_client_method());
|
||||
if (!ssl_ctx) {
|
||||
LOG_ERROR(Service_SSL, "SSL_CTX_new failed");
|
||||
CheckOpenSSLErrors();
|
||||
return;
|
||||
}
|
||||
|
||||
SSL_CTX_set_verify(ssl_ctx, SSL_VERIFY_PEER, nullptr);
|
||||
|
||||
if (!SSL_CTX_set_default_verify_paths(ssl_ctx)) {
|
||||
LOG_ERROR(Service_SSL, "SSL_CTX_set_default_verify_paths failed");
|
||||
CheckOpenSSLErrors();
|
||||
return;
|
||||
}
|
||||
|
||||
OneTimeInitLogFile();
|
||||
|
||||
if (!OneTimeInitBIO()) {
|
||||
return;
|
||||
}
|
||||
|
||||
one_time_init_success = true;
|
||||
}
|
||||
|
||||
void OneTimeInitLogFile() {
|
||||
const char* logfile = getenv("SSLKEYLOGFILE");
|
||||
if (logfile) {
|
||||
key_log_file.Open(logfile, FileAccessMode::Append, FileType::TextFile,
|
||||
FileShareFlag::ShareWriteOnly);
|
||||
if (key_log_file.IsOpen()) {
|
||||
SSL_CTX_set_keylog_callback(ssl_ctx, &SSLConnectionBackendOpenSSL::KeyLogCallback);
|
||||
} else {
|
||||
LOG_CRITICAL(Service_SSL,
|
||||
"SSLKEYLOGFILE was set but file could not be opened; not logging keys!");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool OneTimeInitBIO() {
|
||||
bio_meth =
|
||||
BIO_meth_new(BIO_get_new_index() | BIO_TYPE_SOURCE_SINK, "SSLConnectionBackendOpenSSL");
|
||||
if (!bio_meth ||
|
||||
!BIO_meth_set_write_ex(bio_meth, &SSLConnectionBackendOpenSSL::WriteCallback) ||
|
||||
!BIO_meth_set_read_ex(bio_meth, &SSLConnectionBackendOpenSSL::ReadCallback) ||
|
||||
!BIO_meth_set_ctrl(bio_meth, &SSLConnectionBackendOpenSSL::CtrlCallback)) {
|
||||
LOG_ERROR(Service_SSL, "Failed to create BIO_METHOD");
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
} // namespace Service::SSL
|
@ -0,0 +1,529 @@
|
||||
// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project
|
||||
// SPDX-License-Identifier: GPL-2.0-or-later
|
||||
|
||||
#include "core/hle/service/ssl/ssl_backend.h"
|
||||
#include "core/internal_network/network.h"
|
||||
#include "core/internal_network/sockets.h"
|
||||
|
||||
#include "common/error.h"
|
||||
#include "common/fs/file.h"
|
||||
#include "common/hex_util.h"
|
||||
#include "common/string_util.h"
|
||||
|
||||
#include <mutex>
|
||||
|
||||
#define SECURITY_WIN32
|
||||
#include <Security.h>
|
||||
#include <schnlsp.h>
|
||||
|
||||
namespace {
|
||||
|
||||
std::once_flag one_time_init_flag;
|
||||
bool one_time_init_success = false;
|
||||
|
||||
SCHANNEL_CRED schannel_cred{
|
||||
.dwVersion = SCHANNEL_CRED_VERSION,
|
||||
.dwFlags = SCH_USE_STRONG_CRYPTO | // don't allow insecure protocols
|
||||
SCH_CRED_AUTO_CRED_VALIDATION | // validate certs
|
||||
SCH_CRED_NO_DEFAULT_CREDS, // don't automatically present a client certificate
|
||||
// ^ I'm assuming that nobody would want to connect Yuzu to a
|
||||
// service that requires some OS-provided corporate client
|
||||
// certificate, and presenting one to some arbitrary server
|
||||
// might be a privacy concern? Who knows, though.
|
||||
};
|
||||
|
||||
CredHandle cred_handle;
|
||||
|
||||
static void OneTimeInit() {
|
||||
SECURITY_STATUS ret =
|
||||
AcquireCredentialsHandle(nullptr, const_cast<LPTSTR>(UNISP_NAME), SECPKG_CRED_OUTBOUND,
|
||||
nullptr, &schannel_cred, nullptr, nullptr, &cred_handle, nullptr);
|
||||
if (ret != SEC_E_OK) {
|
||||
// SECURITY_STATUS codes are a type of HRESULT and can be used with NativeErrorToString.
|
||||
LOG_ERROR(Service_SSL, "AcquireCredentialsHandle failed: {}",
|
||||
Common::NativeErrorToString(ret));
|
||||
return;
|
||||
}
|
||||
|
||||
one_time_init_success = true;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace Service::SSL {
|
||||
|
||||
class SSLConnectionBackendSchannel final : public SSLConnectionBackend {
|
||||
public:
|
||||
Result Init() {
|
||||
std::call_once(one_time_init_flag, OneTimeInit);
|
||||
|
||||
if (!one_time_init_success) {
|
||||
LOG_ERROR(
|
||||
Service_SSL,
|
||||
"Can't create SSL connection because Schannel one-time initialization failed");
|
||||
return ResultInternalError;
|
||||
}
|
||||
|
||||
return ResultSuccess;
|
||||
}
|
||||
|
||||
void SetSocket(std::shared_ptr<Network::SocketBase> socket) override {
|
||||
socket_ = socket;
|
||||
}
|
||||
|
||||
Result SetHostName(const std::string& hostname) override {
|
||||
hostname_ = hostname;
|
||||
return ResultSuccess;
|
||||
}
|
||||
|
||||
Result DoHandshake() override {
|
||||
while (1) {
|
||||
Result r;
|
||||
switch (handshake_state_) {
|
||||
case HandshakeState::Initial:
|
||||
if ((r = FlushCiphertextWriteBuf()) != ResultSuccess ||
|
||||
(r = CallInitializeSecurityContext()) != ResultSuccess) {
|
||||
return r;
|
||||
}
|
||||
// CallInitializeSecurityContext updated `handshake_state_`.
|
||||
continue;
|
||||
case HandshakeState::ContinueNeeded:
|
||||
case HandshakeState::IncompleteMessage:
|
||||
if ((r = FlushCiphertextWriteBuf()) != ResultSuccess ||
|
||||
(r = FillCiphertextReadBuf()) != ResultSuccess) {
|
||||
return r;
|
||||
}
|
||||
if (ciphertext_read_buf_.empty()) {
|
||||
LOG_ERROR(Service_SSL, "SSL handshake failed because server hung up");
|
||||
return ResultInternalError;
|
||||
}
|
||||
if ((r = CallInitializeSecurityContext()) != ResultSuccess) {
|
||||
return r;
|
||||
}
|
||||
// CallInitializeSecurityContext updated `handshake_state_`.
|
||||
continue;
|
||||
case HandshakeState::DoneAfterFlush:
|
||||
if ((r = FlushCiphertextWriteBuf()) != ResultSuccess) {
|
||||
return r;
|
||||
}
|
||||
handshake_state_ = HandshakeState::Connected;
|
||||
return ResultSuccess;
|
||||
case HandshakeState::Connected:
|
||||
LOG_ERROR(Service_SSL, "Called DoHandshake but we already handshook");
|
||||
return ResultInternalError;
|
||||
case HandshakeState::Error:
|
||||
return ResultInternalError;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Result FillCiphertextReadBuf() {
|
||||
size_t fill_size = read_buf_fill_size_ ? read_buf_fill_size_ : 4096;
|
||||
read_buf_fill_size_ = 0;
|
||||
// This unnecessarily zeroes the buffer; oh well.
|
||||
size_t offset = ciphertext_read_buf_.size();
|
||||
ASSERT_OR_EXECUTE(offset + fill_size >= offset, { return ResultInternalError; });
|
||||
ciphertext_read_buf_.resize(offset + fill_size, 0);
|
||||
auto read_span = std::span(ciphertext_read_buf_).subspan(offset, fill_size);
|
||||
auto [actual, err] = socket_->Recv(0, read_span);
|
||||
switch (err) {
|
||||
case Network::Errno::SUCCESS:
|
||||
ASSERT(static_cast<size_t>(actual) <= fill_size);
|
||||
ciphertext_read_buf_.resize(offset + actual);
|
||||
return ResultSuccess;
|
||||
case Network::Errno::AGAIN:
|
||||
ciphertext_read_buf_.resize(offset);
|
||||
return ResultWouldBlock;
|
||||
default:
|
||||
ciphertext_read_buf_.resize(offset);
|
||||
LOG_ERROR(Service_SSL, "Socket recv returned Network::Errno {}", err);
|
||||
return ResultInternalError;
|
||||
}
|
||||
}
|
||||
|
||||
// Returns success if the write buffer has been completely emptied.
|
||||
Result FlushCiphertextWriteBuf() {
|
||||
while (!ciphertext_write_buf_.empty()) {
|
||||
auto [actual, err] = socket_->Send(ciphertext_write_buf_, 0);
|
||||
switch (err) {
|
||||
case Network::Errno::SUCCESS:
|
||||
ASSERT(static_cast<size_t>(actual) <= ciphertext_write_buf_.size());
|
||||
ciphertext_write_buf_.erase(ciphertext_write_buf_.begin(),
|
||||
ciphertext_write_buf_.begin() + actual);
|
||||
break;
|
||||
case Network::Errno::AGAIN:
|
||||
return ResultWouldBlock;
|
||||
default:
|
||||
LOG_ERROR(Service_SSL, "Socket send returned Network::Errno {}", err);
|
||||
return ResultInternalError;
|
||||
}
|
||||
}
|
||||
return ResultSuccess;
|
||||
}
|
||||
|
||||
Result CallInitializeSecurityContext() {
|
||||
unsigned long req = ISC_REQ_ALLOCATE_MEMORY | ISC_REQ_CONFIDENTIALITY | ISC_REQ_INTEGRITY |
|
||||
ISC_REQ_REPLAY_DETECT | ISC_REQ_SEQUENCE_DETECT | ISC_REQ_STREAM |
|
||||
ISC_REQ_USE_SUPPLIED_CREDS;
|
||||
unsigned long attr;
|
||||
// https://learn.microsoft.com/en-us/windows/win32/secauthn/initializesecuritycontext--schannel
|
||||
std::array<SecBuffer, 2> input_buffers{{
|
||||
// only used if `initial_call_done`
|
||||
{
|
||||
// [0]
|
||||
.cbBuffer = static_cast<unsigned long>(ciphertext_read_buf_.size()),
|
||||
.BufferType = SECBUFFER_TOKEN,
|
||||
.pvBuffer = ciphertext_read_buf_.data(),
|
||||
},
|
||||
{
|
||||
// [1] (will be replaced by SECBUFFER_MISSING when SEC_E_INCOMPLETE_MESSAGE is
|
||||
// returned, or SECBUFFER_EXTRA when SEC_E_CONTINUE_NEEDED is returned if the
|
||||
// whole buffer wasn't used)
|
||||
.BufferType = SECBUFFER_EMPTY,
|
||||
},
|
||||
}};
|
||||
std::array<SecBuffer, 2> output_buffers{{
|
||||
{
|
||||
.BufferType = SECBUFFER_TOKEN,
|
||||
}, // [0]
|
||||
{
|
||||
.BufferType = SECBUFFER_ALERT,
|
||||
}, // [1]
|
||||
}};
|
||||
SecBufferDesc input_desc{
|
||||
.ulVersion = SECBUFFER_VERSION,
|
||||
.cBuffers = static_cast<unsigned long>(input_buffers.size()),
|
||||
.pBuffers = input_buffers.data(),
|
||||
};
|
||||
SecBufferDesc output_desc{
|
||||
.ulVersion = SECBUFFER_VERSION,
|
||||
.cBuffers = static_cast<unsigned long>(output_buffers.size()),
|
||||
.pBuffers = output_buffers.data(),
|
||||
};
|
||||
ASSERT_OR_EXECUTE_MSG(
|
||||
input_buffers[0].cbBuffer == ciphertext_read_buf_.size(),
|
||||
{ return ResultInternalError; }, "read buffer too large");
|
||||
|
||||
bool initial_call_done = handshake_state_ != HandshakeState::Initial;
|
||||
if (initial_call_done) {
|
||||
LOG_DEBUG(Service_SSL, "Passing {} bytes into InitializeSecurityContext",
|
||||
ciphertext_read_buf_.size());
|
||||
}
|
||||
|
||||
SECURITY_STATUS ret =
|
||||
InitializeSecurityContextA(&cred_handle, initial_call_done ? &ctxt_ : nullptr,
|
||||
// Caller ensured we have set a hostname:
|
||||
const_cast<char*>(hostname_.value().c_str()), req,
|
||||
0, // Reserved1
|
||||
0, // TargetDataRep not used with Schannel
|
||||
initial_call_done ? &input_desc : nullptr,
|
||||
0, // Reserved2
|
||||
initial_call_done ? nullptr : &ctxt_, &output_desc, &attr,
|
||||
nullptr); // ptsExpiry
|
||||
|
||||
if (output_buffers[0].pvBuffer) {
|
||||
std::span span(static_cast<u8*>(output_buffers[0].pvBuffer),
|
||||
output_buffers[0].cbBuffer);
|
||||
ciphertext_write_buf_.insert(ciphertext_write_buf_.end(), span.begin(), span.end());
|
||||
FreeContextBuffer(output_buffers[0].pvBuffer);
|
||||
}
|
||||
|
||||
if (output_buffers[1].pvBuffer) {
|
||||
std::span span(static_cast<u8*>(output_buffers[1].pvBuffer),
|
||||
output_buffers[1].cbBuffer);
|
||||
// The documentation doesn't explain what format this data is in.
|
||||
LOG_DEBUG(Service_SSL, "Got a {}-byte alert buffer: {}", span.size(),
|
||||
Common::HexToString(span));
|
||||
}
|
||||
|
||||
switch (ret) {
|
||||
case SEC_I_CONTINUE_NEEDED:
|
||||
LOG_DEBUG(Service_SSL, "InitializeSecurityContext => SEC_I_CONTINUE_NEEDED");
|
||||
if (input_buffers[1].BufferType == SECBUFFER_EXTRA) {
|
||||
LOG_DEBUG(Service_SSL, "EXTRA of size {}", input_buffers[1].cbBuffer);
|
||||
ASSERT(input_buffers[1].cbBuffer <= ciphertext_read_buf_.size());
|
||||
ciphertext_read_buf_.erase(ciphertext_read_buf_.begin(),
|
||||
ciphertext_read_buf_.end() - input_buffers[1].cbBuffer);
|
||||
} else {
|
||||
ASSERT(input_buffers[1].BufferType == SECBUFFER_EMPTY);
|
||||
ciphertext_read_buf_.clear();
|
||||
}
|
||||
handshake_state_ = HandshakeState::ContinueNeeded;
|
||||
return ResultSuccess;
|
||||
case SEC_E_INCOMPLETE_MESSAGE:
|
||||
LOG_DEBUG(Service_SSL, "InitializeSecurityContext => SEC_E_INCOMPLETE_MESSAGE");
|
||||
ASSERT(input_buffers[1].BufferType == SECBUFFER_MISSING);
|
||||
read_buf_fill_size_ = input_buffers[1].cbBuffer;
|
||||
handshake_state_ = HandshakeState::IncompleteMessage;
|
||||
return ResultSuccess;
|
||||
case SEC_E_OK:
|
||||
LOG_DEBUG(Service_SSL, "InitializeSecurityContext => SEC_E_OK");
|
||||
ciphertext_read_buf_.clear();
|
||||
handshake_state_ = HandshakeState::DoneAfterFlush;
|
||||
return GrabStreamSizes();
|
||||
default:
|
||||
LOG_ERROR(Service_SSL,
|
||||
"InitializeSecurityContext failed (probably certificate/protocol issue): {}",
|
||||
Common::NativeErrorToString(ret));
|
||||
handshake_state_ = HandshakeState::Error;
|
||||
return ResultInternalError;
|
||||
}
|
||||
}
|
||||
|
||||
Result GrabStreamSizes() {
|
||||
SECURITY_STATUS ret =
|
||||
QueryContextAttributes(&ctxt_, SECPKG_ATTR_STREAM_SIZES, &stream_sizes_);
|
||||
if (ret != SEC_E_OK) {
|
||||
LOG_ERROR(Service_SSL, "QueryContextAttributes(SECPKG_ATTR_STREAM_SIZES) failed: {}",
|
||||
Common::NativeErrorToString(ret));
|
||||
handshake_state_ = HandshakeState::Error;
|
||||
return ResultInternalError;
|
||||
}
|
||||
return ResultSuccess;
|
||||
}
|
||||
|
||||
ResultVal<size_t> Read(std::span<u8> data) override {
|
||||
if (handshake_state_ != HandshakeState::Connected) {
|
||||
LOG_ERROR(Service_SSL, "Called Read but we did not successfully handshake");
|
||||
return ResultInternalError;
|
||||
}
|
||||
if (data.size() == 0 || got_read_eof_) {
|
||||
return size_t(0);
|
||||
}
|
||||
while (1) {
|
||||
if (!cleartext_read_buf_.empty()) {
|
||||
size_t read_size = std::min(cleartext_read_buf_.size(), data.size());
|
||||
std::memcpy(data.data(), cleartext_read_buf_.data(), read_size);
|
||||
cleartext_read_buf_.erase(cleartext_read_buf_.begin(),
|
||||
cleartext_read_buf_.begin() + read_size);
|
||||
return read_size;
|
||||
}
|
||||
if (!ciphertext_read_buf_.empty()) {
|
||||
std::array<SecBuffer, 5> buffers{{
|
||||
{
|
||||
.cbBuffer = static_cast<unsigned long>(ciphertext_read_buf_.size()),
|
||||
.BufferType = SECBUFFER_DATA,
|
||||
.pvBuffer = ciphertext_read_buf_.data(),
|
||||
},
|
||||
{
|
||||
.BufferType = SECBUFFER_EMPTY,
|
||||
},
|
||||
{
|
||||
.BufferType = SECBUFFER_EMPTY,
|
||||
},
|
||||
{
|
||||
.BufferType = SECBUFFER_EMPTY,
|
||||
},
|
||||
}};
|
||||
ASSERT_OR_EXECUTE_MSG(
|
||||
buffers[0].cbBuffer == ciphertext_read_buf_.size(),
|
||||
{ return ResultInternalError; }, "read buffer too large");
|
||||
SecBufferDesc desc{
|
||||
.ulVersion = SECBUFFER_VERSION,
|
||||
.cBuffers = static_cast<unsigned long>(buffers.size()),
|
||||
.pBuffers = buffers.data(),
|
||||
};
|
||||
SECURITY_STATUS ret =
|
||||
DecryptMessage(&ctxt_, &desc, /*MessageSeqNo*/ 0, /*pfQOP*/ nullptr);
|
||||
switch (ret) {
|
||||
case SEC_E_OK:
|
||||
ASSERT_OR_EXECUTE(buffers[0].BufferType == SECBUFFER_STREAM_HEADER,
|
||||
{ return ResultInternalError; });
|
||||
ASSERT_OR_EXECUTE(buffers[1].BufferType == SECBUFFER_DATA,
|
||||
{ return ResultInternalError; });
|
||||
ASSERT_OR_EXECUTE(buffers[2].BufferType == SECBUFFER_STREAM_TRAILER,
|
||||
{ return ResultInternalError; });
|
||||
cleartext_read_buf_.assign(static_cast<u8*>(buffers[1].pvBuffer),
|
||||
static_cast<u8*>(buffers[1].pvBuffer) +
|
||||
buffers[1].cbBuffer);
|
||||
if (buffers[3].BufferType == SECBUFFER_EXTRA) {
|
||||
ASSERT(buffers[3].cbBuffer <= ciphertext_read_buf_.size());
|
||||
ciphertext_read_buf_.erase(ciphertext_read_buf_.begin(),
|
||||
ciphertext_read_buf_.end() -
|
||||
buffers[3].cbBuffer);
|
||||
} else {
|
||||
ASSERT(buffers[3].BufferType == SECBUFFER_EMPTY);
|
||||
ciphertext_read_buf_.clear();
|
||||
}
|
||||
continue;
|
||||
case SEC_E_INCOMPLETE_MESSAGE:
|
||||
break;
|
||||
case SEC_I_CONTEXT_EXPIRED:
|
||||
// Server hung up by sending close_notify.
|
||||
got_read_eof_ = true;
|
||||
return size_t(0);
|
||||
default:
|
||||
LOG_ERROR(Service_SSL, "DecryptMessage failed: {}",
|
||||
Common::NativeErrorToString(ret));
|
||||
return ResultInternalError;
|
||||
}
|
||||
}
|
||||
Result r = FillCiphertextReadBuf();
|
||||
if (r != ResultSuccess) {
|
||||
return r;
|
||||
}
|
||||
if (ciphertext_read_buf_.empty()) {
|
||||
got_read_eof_ = true;
|
||||
return size_t(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ResultVal<size_t> Write(std::span<const u8> data) override {
|
||||
if (handshake_state_ != HandshakeState::Connected) {
|
||||
LOG_ERROR(Service_SSL, "Called Write but we did not successfully handshake");
|
||||
return ResultInternalError;
|
||||
}
|
||||
if (data.size() == 0) {
|
||||
return size_t(0);
|
||||
}
|
||||
data = data.subspan(0, std::min<size_t>(data.size(), stream_sizes_.cbMaximumMessage));
|
||||
if (!cleartext_write_buf_.empty()) {
|
||||
// Already in the middle of a write. It wouldn't make sense to not
|
||||
// finish sending the entire buffer since TLS has
|
||||
// header/MAC/padding/etc.
|
||||
if (data.size() != cleartext_write_buf_.size() ||
|
||||
std::memcmp(data.data(), cleartext_write_buf_.data(), data.size())) {
|
||||
LOG_ERROR(Service_SSL, "Called Write but buffer does not match previous buffer");
|
||||
return ResultInternalError;
|
||||
}
|
||||
return WriteAlreadyEncryptedData();
|
||||
} else {
|
||||
cleartext_write_buf_.assign(data.begin(), data.end());
|
||||
}
|
||||
|
||||
std::vector<u8> header_buf(stream_sizes_.cbHeader, 0);
|
||||
std::vector<u8> tmp_data_buf = cleartext_write_buf_;
|
||||
std::vector<u8> trailer_buf(stream_sizes_.cbTrailer, 0);
|
||||
|
||||
std::array<SecBuffer, 3> buffers{{
|
||||
{
|
||||
.cbBuffer = stream_sizes_.cbHeader,
|
||||
.BufferType = SECBUFFER_STREAM_HEADER,
|
||||
.pvBuffer = header_buf.data(),
|
||||
},
|
||||
{
|
||||
.cbBuffer = static_cast<unsigned long>(tmp_data_buf.size()),
|
||||
.BufferType = SECBUFFER_DATA,
|
||||
.pvBuffer = tmp_data_buf.data(),
|
||||
},
|
||||
{
|
||||
.cbBuffer = stream_sizes_.cbTrailer,
|
||||
.BufferType = SECBUFFER_STREAM_TRAILER,
|
||||
.pvBuffer = trailer_buf.data(),
|
||||
},
|
||||
}};
|
||||
ASSERT_OR_EXECUTE_MSG(
|
||||
buffers[1].cbBuffer == tmp_data_buf.size(), { return ResultInternalError; },
|
||||
"temp buffer too large");
|
||||
SecBufferDesc desc{
|
||||
.ulVersion = SECBUFFER_VERSION,
|
||||
.cBuffers = static_cast<unsigned long>(buffers.size()),
|
||||
.pBuffers = buffers.data(),
|
||||
};
|
||||
|
||||
SECURITY_STATUS ret = EncryptMessage(&ctxt_, /*fQOP*/ 0, &desc, /*MessageSeqNo*/ 0);
|
||||
if (ret != SEC_E_OK) {
|
||||
LOG_ERROR(Service_SSL, "EncryptMessage failed: {}", Common::NativeErrorToString(ret));
|
||||
return ResultInternalError;
|
||||
}
|
||||
ciphertext_write_buf_.insert(ciphertext_write_buf_.end(), header_buf.begin(),
|
||||
header_buf.end());
|
||||
ciphertext_write_buf_.insert(ciphertext_write_buf_.end(), tmp_data_buf.begin(),
|
||||
tmp_data_buf.end());
|
||||
ciphertext_write_buf_.insert(ciphertext_write_buf_.end(), trailer_buf.begin(),
|
||||
trailer_buf.end());
|
||||
return WriteAlreadyEncryptedData();
|
||||
}
|
||||
|
||||
ResultVal<size_t> WriteAlreadyEncryptedData() {
|
||||
Result r = FlushCiphertextWriteBuf();
|
||||
if (r != ResultSuccess) {
|
||||
return r;
|
||||
}
|
||||
// write buf is empty
|
||||
size_t cleartext_bytes_written = cleartext_write_buf_.size();
|
||||
cleartext_write_buf_.clear();
|
||||
return cleartext_bytes_written;
|
||||
}
|
||||
|
||||
ResultVal<std::vector<std::vector<u8>>> GetServerCerts() override {
|
||||
PCCERT_CONTEXT returned_cert = nullptr;
|
||||
SECURITY_STATUS ret =
|
||||
QueryContextAttributes(&ctxt_, SECPKG_ATTR_REMOTE_CERT_CONTEXT, &returned_cert);
|
||||
if (ret != SEC_E_OK) {
|
||||
LOG_ERROR(Service_SSL,
|
||||
"QueryContextAttributes(SECPKG_ATTR_REMOTE_CERT_CONTEXT) failed: {}",
|
||||
Common::NativeErrorToString(ret));
|
||||
return ResultInternalError;
|
||||
}
|
||||
PCCERT_CONTEXT some_cert = nullptr;
|
||||
std::vector<std::vector<u8>> certs;
|
||||
while ((some_cert = CertEnumCertificatesInStore(returned_cert->hCertStore, some_cert))) {
|
||||
certs.emplace_back(static_cast<u8*>(some_cert->pbCertEncoded),
|
||||
static_cast<u8*>(some_cert->pbCertEncoded) +
|
||||
some_cert->cbCertEncoded);
|
||||
}
|
||||
std::reverse(certs.begin(),
|
||||
certs.end()); // Windows returns certs in reverse order from what we want
|
||||
CertFreeCertificateContext(returned_cert);
|
||||
return certs;
|
||||
}
|
||||
|
||||
~SSLConnectionBackendSchannel() {
|
||||
if (handshake_state_ != HandshakeState::Initial) {
|
||||
DeleteSecurityContext(&ctxt_);
|
||||
}
|
||||
}
|
||||
|
||||
enum class HandshakeState {
|
||||
// Haven't called anything yet.
|
||||
Initial,
|
||||
// `SEC_I_CONTINUE_NEEDED` was returned by
|
||||
// `InitializeSecurityContext`; must finish sending data (if any) in
|
||||
// the write buffer, then read at least one byte before calling
|
||||
// `InitializeSecurityContext` again.
|
||||
ContinueNeeded,
|
||||
// `SEC_E_INCOMPLETE_MESSAGE` was returned by
|
||||
// `InitializeSecurityContext`; hopefully the write buffer is empty;
|
||||
// must read at least one byte before calling
|
||||
// `InitializeSecurityContext` again.
|
||||
IncompleteMessage,
|
||||
// `SEC_E_OK` was returned by `InitializeSecurityContext`; must
|
||||
// finish sending data in the write buffer before having `DoHandshake`
|
||||
// report success.
|
||||
DoneAfterFlush,
|
||||
// We finished the above and are now connected. At this point, writing
|
||||
// and reading are separate 'state machines' represented by the
|
||||
// nonemptiness of the ciphertext and cleartext read and write buffers.
|
||||
Connected,
|
||||
// Another error was returned and we shouldn't allow initialization
|
||||
// to continue.
|
||||
Error,
|
||||
} handshake_state_ = HandshakeState::Initial;
|
||||
|
||||
CtxtHandle ctxt_;
|
||||
SecPkgContext_StreamSizes stream_sizes_;
|
||||
|
||||
std::shared_ptr<Network::SocketBase> socket_;
|
||||
std::optional<std::string> hostname_;
|
||||
|
||||
std::vector<u8> ciphertext_read_buf_;
|
||||
std::vector<u8> ciphertext_write_buf_;
|
||||
std::vector<u8> cleartext_read_buf_;
|
||||
std::vector<u8> cleartext_write_buf_;
|
||||
|
||||
bool got_read_eof_ = false;
|
||||
size_t read_buf_fill_size_ = 0;
|
||||
};
|
||||
|
||||
ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() {
|
||||
auto conn = std::make_unique<SSLConnectionBackendSchannel>();
|
||||
Result res = conn->Init();
|
||||
if (res.IsFailure()) {
|
||||
return res;
|
||||
}
|
||||
return conn;
|
||||
}
|
||||
|
||||
} // namespace Service::SSL
|
Loading…
Reference in New Issue