mirror of https://git.suyu.dev/suyu/suyu
				
				
				
			
						commit
						2461c78e3f
					
				@ -0,0 +1,45 @@
 | 
			
		||||
// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project
 | 
			
		||||
// SPDX-License-Identifier: GPL-2.0-or-later
 | 
			
		||||
 | 
			
		||||
#pragma once
 | 
			
		||||
 | 
			
		||||
#include "core/hle/result.h"
 | 
			
		||||
 | 
			
		||||
#include "common/common_types.h"
 | 
			
		||||
 | 
			
		||||
#include <memory>
 | 
			
		||||
#include <span>
 | 
			
		||||
#include <string>
 | 
			
		||||
#include <vector>
 | 
			
		||||
 | 
			
		||||
namespace Network {
 | 
			
		||||
class SocketBase;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
namespace Service::SSL {
 | 
			
		||||
 | 
			
		||||
constexpr Result ResultNoSocket{ErrorModule::SSLSrv, 103};
 | 
			
		||||
constexpr Result ResultInvalidSocket{ErrorModule::SSLSrv, 106};
 | 
			
		||||
constexpr Result ResultTimeout{ErrorModule::SSLSrv, 205};
 | 
			
		||||
constexpr Result ResultInternalError{ErrorModule::SSLSrv, 999}; // made up
 | 
			
		||||
 | 
			
		||||
// ResultWouldBlock is returned from Read and Write, and oddly, DoHandshake,
 | 
			
		||||
// with no way in the latter case to distinguish whether the client should poll
 | 
			
		||||
// for read or write.  The one official client I've seen handles this by always
 | 
			
		||||
// polling for read (with a timeout).
 | 
			
		||||
constexpr Result ResultWouldBlock{ErrorModule::SSLSrv, 204};
 | 
			
		||||
 | 
			
		||||
class SSLConnectionBackend {
 | 
			
		||||
public:
 | 
			
		||||
    virtual ~SSLConnectionBackend() {}
 | 
			
		||||
    virtual void SetSocket(std::shared_ptr<Network::SocketBase> socket) = 0;
 | 
			
		||||
    virtual Result SetHostName(const std::string& hostname) = 0;
 | 
			
		||||
    virtual Result DoHandshake() = 0;
 | 
			
		||||
    virtual ResultVal<size_t> Read(std::span<u8> data) = 0;
 | 
			
		||||
    virtual ResultVal<size_t> Write(std::span<const u8> data) = 0;
 | 
			
		||||
    virtual ResultVal<std::vector<std::vector<u8>>> GetServerCerts() = 0;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend();
 | 
			
		||||
 | 
			
		||||
} // namespace Service::SSL
 | 
			
		||||
@ -0,0 +1,16 @@
 | 
			
		||||
// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project
 | 
			
		||||
// SPDX-License-Identifier: GPL-2.0-or-later
 | 
			
		||||
 | 
			
		||||
#include "core/hle/service/ssl/ssl_backend.h"
 | 
			
		||||
 | 
			
		||||
#include "common/logging/log.h"
 | 
			
		||||
 | 
			
		||||
namespace Service::SSL {
 | 
			
		||||
 | 
			
		||||
ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() {
 | 
			
		||||
    LOG_ERROR(Service_SSL,
 | 
			
		||||
              "Can't create SSL connection because no SSL backend is available on this platform");
 | 
			
		||||
    return ResultInternalError;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
} // namespace Service::SSL
 | 
			
		||||
@ -0,0 +1,351 @@
 | 
			
		||||
// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project
 | 
			
		||||
// SPDX-License-Identifier: GPL-2.0-or-later
 | 
			
		||||
 | 
			
		||||
#include "core/hle/service/ssl/ssl_backend.h"
 | 
			
		||||
#include "core/internal_network/network.h"
 | 
			
		||||
#include "core/internal_network/sockets.h"
 | 
			
		||||
 | 
			
		||||
#include "common/fs/file.h"
 | 
			
		||||
#include "common/hex_util.h"
 | 
			
		||||
#include "common/string_util.h"
 | 
			
		||||
 | 
			
		||||
#include <mutex>
 | 
			
		||||
 | 
			
		||||
#include <openssl/bio.h>
 | 
			
		||||
#include <openssl/err.h>
 | 
			
		||||
#include <openssl/ssl.h>
 | 
			
		||||
#include <openssl/x509.h>
 | 
			
		||||
 | 
			
		||||
using namespace Common::FS;
 | 
			
		||||
 | 
			
		||||
namespace Service::SSL {
 | 
			
		||||
 | 
			
		||||
// Import OpenSSL's `SSL` type into the namespace.  This is needed because the
 | 
			
		||||
// namespace is also named `SSL`.
 | 
			
		||||
using ::SSL;
 | 
			
		||||
 | 
			
		||||
namespace {
 | 
			
		||||
 | 
			
		||||
std::once_flag one_time_init_flag;
 | 
			
		||||
bool one_time_init_success = false;
 | 
			
		||||
 | 
			
		||||
SSL_CTX* ssl_ctx;
 | 
			
		||||
IOFile key_log_file; // only open if SSLKEYLOGFILE set in environment
 | 
			
		||||
BIO_METHOD* bio_meth;
 | 
			
		||||
 | 
			
		||||
Result CheckOpenSSLErrors();
 | 
			
		||||
void OneTimeInit();
 | 
			
		||||
void OneTimeInitLogFile();
 | 
			
		||||
bool OneTimeInitBIO();
 | 
			
		||||
 | 
			
		||||
} // namespace
 | 
			
		||||
 | 
			
		||||
class SSLConnectionBackendOpenSSL final : public SSLConnectionBackend {
 | 
			
		||||
public:
 | 
			
		||||
    Result Init() {
 | 
			
		||||
        std::call_once(one_time_init_flag, OneTimeInit);
 | 
			
		||||
 | 
			
		||||
        if (!one_time_init_success) {
 | 
			
		||||
            LOG_ERROR(Service_SSL,
 | 
			
		||||
                      "Can't create SSL connection because OpenSSL one-time initialization failed");
 | 
			
		||||
            return ResultInternalError;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        ssl = SSL_new(ssl_ctx);
 | 
			
		||||
        if (!ssl) {
 | 
			
		||||
            LOG_ERROR(Service_SSL, "SSL_new failed");
 | 
			
		||||
            return CheckOpenSSLErrors();
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        SSL_set_connect_state(ssl);
 | 
			
		||||
 | 
			
		||||
        bio = BIO_new(bio_meth);
 | 
			
		||||
        if (!bio) {
 | 
			
		||||
            LOG_ERROR(Service_SSL, "BIO_new failed");
 | 
			
		||||
            return CheckOpenSSLErrors();
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        BIO_set_data(bio, this);
 | 
			
		||||
        BIO_set_init(bio, 1);
 | 
			
		||||
        SSL_set_bio(ssl, bio, bio);
 | 
			
		||||
 | 
			
		||||
        return ResultSuccess;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    void SetSocket(std::shared_ptr<Network::SocketBase> socket_in) override {
 | 
			
		||||
        socket = std::move(socket_in);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    Result SetHostName(const std::string& hostname) override {
 | 
			
		||||
        if (!SSL_set1_host(ssl, hostname.c_str())) { // hostname for verification
 | 
			
		||||
            LOG_ERROR(Service_SSL, "SSL_set1_host({}) failed", hostname);
 | 
			
		||||
            return CheckOpenSSLErrors();
 | 
			
		||||
        }
 | 
			
		||||
        if (!SSL_set_tlsext_host_name(ssl, hostname.c_str())) { // hostname for SNI
 | 
			
		||||
            LOG_ERROR(Service_SSL, "SSL_set_tlsext_host_name({}) failed", hostname);
 | 
			
		||||
            return CheckOpenSSLErrors();
 | 
			
		||||
        }
 | 
			
		||||
        return ResultSuccess;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    Result DoHandshake() override {
 | 
			
		||||
        SSL_set_verify_result(ssl, X509_V_OK);
 | 
			
		||||
        const int ret = SSL_do_handshake(ssl);
 | 
			
		||||
        const long verify_result = SSL_get_verify_result(ssl);
 | 
			
		||||
        if (verify_result != X509_V_OK) {
 | 
			
		||||
            LOG_ERROR(Service_SSL, "SSL cert verification failed because: {}",
 | 
			
		||||
                      X509_verify_cert_error_string(verify_result));
 | 
			
		||||
            return CheckOpenSSLErrors();
 | 
			
		||||
        }
 | 
			
		||||
        if (ret <= 0) {
 | 
			
		||||
            const int ssl_err = SSL_get_error(ssl, ret);
 | 
			
		||||
            if (ssl_err == SSL_ERROR_ZERO_RETURN ||
 | 
			
		||||
                (ssl_err == SSL_ERROR_SYSCALL && got_read_eof)) {
 | 
			
		||||
                LOG_ERROR(Service_SSL, "SSL handshake failed because server hung up");
 | 
			
		||||
                return ResultInternalError;
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
        return HandleReturn("SSL_do_handshake", 0, ret).Code();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    ResultVal<size_t> Read(std::span<u8> data) override {
 | 
			
		||||
        size_t actual;
 | 
			
		||||
        const int ret = SSL_read_ex(ssl, data.data(), data.size(), &actual);
 | 
			
		||||
        return HandleReturn("SSL_read_ex", actual, ret);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    ResultVal<size_t> Write(std::span<const u8> data) override {
 | 
			
		||||
        size_t actual;
 | 
			
		||||
        const int ret = SSL_write_ex(ssl, data.data(), data.size(), &actual);
 | 
			
		||||
        return HandleReturn("SSL_write_ex", actual, ret);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    ResultVal<size_t> HandleReturn(const char* what, size_t actual, int ret) {
 | 
			
		||||
        const int ssl_err = SSL_get_error(ssl, ret);
 | 
			
		||||
        CheckOpenSSLErrors();
 | 
			
		||||
        switch (ssl_err) {
 | 
			
		||||
        case SSL_ERROR_NONE:
 | 
			
		||||
            return actual;
 | 
			
		||||
        case SSL_ERROR_ZERO_RETURN:
 | 
			
		||||
            LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_ZERO_RETURN", what);
 | 
			
		||||
            // DoHandshake special-cases this, but for Read and Write:
 | 
			
		||||
            return size_t(0);
 | 
			
		||||
        case SSL_ERROR_WANT_READ:
 | 
			
		||||
            LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_WANT_READ", what);
 | 
			
		||||
            return ResultWouldBlock;
 | 
			
		||||
        case SSL_ERROR_WANT_WRITE:
 | 
			
		||||
            LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_WANT_WRITE", what);
 | 
			
		||||
            return ResultWouldBlock;
 | 
			
		||||
        default:
 | 
			
		||||
            if (ssl_err == SSL_ERROR_SYSCALL && got_read_eof) {
 | 
			
		||||
                LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_SYSCALL because server hung up", what);
 | 
			
		||||
                return size_t(0);
 | 
			
		||||
            }
 | 
			
		||||
            LOG_ERROR(Service_SSL, "{} => other SSL_get_error return value {}", what, ssl_err);
 | 
			
		||||
            return ResultInternalError;
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    ResultVal<std::vector<std::vector<u8>>> GetServerCerts() override {
 | 
			
		||||
        STACK_OF(X509)* chain = SSL_get_peer_cert_chain(ssl);
 | 
			
		||||
        if (!chain) {
 | 
			
		||||
            LOG_ERROR(Service_SSL, "SSL_get_peer_cert_chain returned nullptr");
 | 
			
		||||
            return ResultInternalError;
 | 
			
		||||
        }
 | 
			
		||||
        std::vector<std::vector<u8>> ret;
 | 
			
		||||
        int count = sk_X509_num(chain);
 | 
			
		||||
        ASSERT(count >= 0);
 | 
			
		||||
        for (int i = 0; i < count; i++) {
 | 
			
		||||
            X509* x509 = sk_X509_value(chain, i);
 | 
			
		||||
            ASSERT_OR_EXECUTE(x509 != nullptr, { continue; });
 | 
			
		||||
            unsigned char* buf = nullptr;
 | 
			
		||||
            int len = i2d_X509(x509, &buf);
 | 
			
		||||
            ASSERT_OR_EXECUTE(len >= 0 && buf, { continue; });
 | 
			
		||||
            ret.emplace_back(buf, buf + len);
 | 
			
		||||
            OPENSSL_free(buf);
 | 
			
		||||
        }
 | 
			
		||||
        return ret;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    ~SSLConnectionBackendOpenSSL() {
 | 
			
		||||
        // these are null-tolerant:
 | 
			
		||||
        SSL_free(ssl);
 | 
			
		||||
        BIO_free(bio);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    static void KeyLogCallback(const SSL* ssl, const char* line) {
 | 
			
		||||
        std::string str(line);
 | 
			
		||||
        str.push_back('\n');
 | 
			
		||||
        // Do this in a single WriteString for atomicity if multiple instances
 | 
			
		||||
        // are running on different threads (though that can't currently
 | 
			
		||||
        // happen).
 | 
			
		||||
        if (key_log_file.WriteString(str) != str.size() || !key_log_file.Flush()) {
 | 
			
		||||
            LOG_CRITICAL(Service_SSL, "Failed to write to SSLKEYLOGFILE");
 | 
			
		||||
        }
 | 
			
		||||
        LOG_DEBUG(Service_SSL, "Wrote to SSLKEYLOGFILE: {}", line);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    static int WriteCallback(BIO* bio, const char* buf, size_t len, size_t* actual_p) {
 | 
			
		||||
        auto self = static_cast<SSLConnectionBackendOpenSSL*>(BIO_get_data(bio));
 | 
			
		||||
        ASSERT_OR_EXECUTE_MSG(
 | 
			
		||||
            self->socket, { return 0; }, "OpenSSL asked to send but we have no socket");
 | 
			
		||||
        BIO_clear_retry_flags(bio);
 | 
			
		||||
        auto [actual, err] = self->socket->Send({reinterpret_cast<const u8*>(buf), len}, 0);
 | 
			
		||||
        switch (err) {
 | 
			
		||||
        case Network::Errno::SUCCESS:
 | 
			
		||||
            *actual_p = actual;
 | 
			
		||||
            return 1;
 | 
			
		||||
        case Network::Errno::AGAIN:
 | 
			
		||||
            BIO_set_flags(bio, BIO_FLAGS_WRITE | BIO_FLAGS_SHOULD_RETRY);
 | 
			
		||||
            return 0;
 | 
			
		||||
        default:
 | 
			
		||||
            LOG_ERROR(Service_SSL, "Socket send returned Network::Errno {}", err);
 | 
			
		||||
            return -1;
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    static int ReadCallback(BIO* bio, char* buf, size_t len, size_t* actual_p) {
 | 
			
		||||
        auto self = static_cast<SSLConnectionBackendOpenSSL*>(BIO_get_data(bio));
 | 
			
		||||
        ASSERT_OR_EXECUTE_MSG(
 | 
			
		||||
            self->socket, { return 0; }, "OpenSSL asked to recv but we have no socket");
 | 
			
		||||
        BIO_clear_retry_flags(bio);
 | 
			
		||||
        auto [actual, err] = self->socket->Recv(0, {reinterpret_cast<u8*>(buf), len});
 | 
			
		||||
        switch (err) {
 | 
			
		||||
        case Network::Errno::SUCCESS:
 | 
			
		||||
            *actual_p = actual;
 | 
			
		||||
            if (actual == 0) {
 | 
			
		||||
                self->got_read_eof = true;
 | 
			
		||||
            }
 | 
			
		||||
            return actual ? 1 : 0;
 | 
			
		||||
        case Network::Errno::AGAIN:
 | 
			
		||||
            BIO_set_flags(bio, BIO_FLAGS_READ | BIO_FLAGS_SHOULD_RETRY);
 | 
			
		||||
            return 0;
 | 
			
		||||
        default:
 | 
			
		||||
            LOG_ERROR(Service_SSL, "Socket recv returned Network::Errno {}", err);
 | 
			
		||||
            return -1;
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    static long CtrlCallback(BIO* bio, int cmd, long l_arg, void* p_arg) {
 | 
			
		||||
        switch (cmd) {
 | 
			
		||||
        case BIO_CTRL_FLUSH:
 | 
			
		||||
            // Nothing to flush.
 | 
			
		||||
            return 1;
 | 
			
		||||
        case BIO_CTRL_PUSH:
 | 
			
		||||
        case BIO_CTRL_POP:
 | 
			
		||||
#ifdef BIO_CTRL_GET_KTLS_SEND
 | 
			
		||||
        case BIO_CTRL_GET_KTLS_SEND:
 | 
			
		||||
        case BIO_CTRL_GET_KTLS_RECV:
 | 
			
		||||
#endif
 | 
			
		||||
            // We don't support these operations, but don't bother logging them
 | 
			
		||||
            // as they're nothing unusual.
 | 
			
		||||
            return 0;
 | 
			
		||||
        default:
 | 
			
		||||
            LOG_DEBUG(Service_SSL, "OpenSSL BIO got ctrl({}, {}, {})", cmd, l_arg, p_arg);
 | 
			
		||||
            return 0;
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    SSL* ssl = nullptr;
 | 
			
		||||
    BIO* bio = nullptr;
 | 
			
		||||
    bool got_read_eof = false;
 | 
			
		||||
 | 
			
		||||
    std::shared_ptr<Network::SocketBase> socket;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() {
 | 
			
		||||
    auto conn = std::make_unique<SSLConnectionBackendOpenSSL>();
 | 
			
		||||
    const Result res = conn->Init();
 | 
			
		||||
    if (res.IsFailure()) {
 | 
			
		||||
        return res;
 | 
			
		||||
    }
 | 
			
		||||
    return conn;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
namespace {
 | 
			
		||||
 | 
			
		||||
Result CheckOpenSSLErrors() {
 | 
			
		||||
    unsigned long rc;
 | 
			
		||||
    const char* file;
 | 
			
		||||
    int line;
 | 
			
		||||
    const char* func;
 | 
			
		||||
    const char* data;
 | 
			
		||||
    int flags;
 | 
			
		||||
#if OPENSSL_VERSION_NUMBER >= 0x30000000L
 | 
			
		||||
    while ((rc = ERR_get_error_all(&file, &line, &func, &data, &flags)))
 | 
			
		||||
#else
 | 
			
		||||
    // Can't get function names from OpenSSL on this version, so use mine:
 | 
			
		||||
    func = __func__;
 | 
			
		||||
    while ((rc = ERR_get_error_line_data(&file, &line, &data, &flags)))
 | 
			
		||||
#endif
 | 
			
		||||
    {
 | 
			
		||||
        std::string msg;
 | 
			
		||||
        msg.resize(1024, '\0');
 | 
			
		||||
        ERR_error_string_n(rc, msg.data(), msg.size());
 | 
			
		||||
        msg.resize(strlen(msg.data()), '\0');
 | 
			
		||||
        if (flags & ERR_TXT_STRING) {
 | 
			
		||||
            msg.append(" | ");
 | 
			
		||||
            msg.append(data);
 | 
			
		||||
        }
 | 
			
		||||
        Common::Log::FmtLogMessage(Common::Log::Class::Service_SSL, Common::Log::Level::Error,
 | 
			
		||||
                                   Common::Log::TrimSourcePath(file), line, func, "OpenSSL: {}",
 | 
			
		||||
                                   msg);
 | 
			
		||||
    }
 | 
			
		||||
    return ResultInternalError;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void OneTimeInit() {
 | 
			
		||||
    ssl_ctx = SSL_CTX_new(TLS_client_method());
 | 
			
		||||
    if (!ssl_ctx) {
 | 
			
		||||
        LOG_ERROR(Service_SSL, "SSL_CTX_new failed");
 | 
			
		||||
        CheckOpenSSLErrors();
 | 
			
		||||
        return;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    SSL_CTX_set_verify(ssl_ctx, SSL_VERIFY_PEER, nullptr);
 | 
			
		||||
 | 
			
		||||
    if (!SSL_CTX_set_default_verify_paths(ssl_ctx)) {
 | 
			
		||||
        LOG_ERROR(Service_SSL, "SSL_CTX_set_default_verify_paths failed");
 | 
			
		||||
        CheckOpenSSLErrors();
 | 
			
		||||
        return;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    OneTimeInitLogFile();
 | 
			
		||||
 | 
			
		||||
    if (!OneTimeInitBIO()) {
 | 
			
		||||
        return;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    one_time_init_success = true;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void OneTimeInitLogFile() {
 | 
			
		||||
    const char* logfile = getenv("SSLKEYLOGFILE");
 | 
			
		||||
    if (logfile) {
 | 
			
		||||
        key_log_file.Open(logfile, FileAccessMode::Append, FileType::TextFile,
 | 
			
		||||
                          FileShareFlag::ShareWriteOnly);
 | 
			
		||||
        if (key_log_file.IsOpen()) {
 | 
			
		||||
            SSL_CTX_set_keylog_callback(ssl_ctx, &SSLConnectionBackendOpenSSL::KeyLogCallback);
 | 
			
		||||
        } else {
 | 
			
		||||
            LOG_CRITICAL(Service_SSL,
 | 
			
		||||
                         "SSLKEYLOGFILE was set but file could not be opened; not logging keys!");
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
bool OneTimeInitBIO() {
 | 
			
		||||
    bio_meth =
 | 
			
		||||
        BIO_meth_new(BIO_get_new_index() | BIO_TYPE_SOURCE_SINK, "SSLConnectionBackendOpenSSL");
 | 
			
		||||
    if (!bio_meth ||
 | 
			
		||||
        !BIO_meth_set_write_ex(bio_meth, &SSLConnectionBackendOpenSSL::WriteCallback) ||
 | 
			
		||||
        !BIO_meth_set_read_ex(bio_meth, &SSLConnectionBackendOpenSSL::ReadCallback) ||
 | 
			
		||||
        !BIO_meth_set_ctrl(bio_meth, &SSLConnectionBackendOpenSSL::CtrlCallback)) {
 | 
			
		||||
        LOG_ERROR(Service_SSL, "Failed to create BIO_METHOD");
 | 
			
		||||
        return false;
 | 
			
		||||
    }
 | 
			
		||||
    return true;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
} // namespace
 | 
			
		||||
 | 
			
		||||
} // namespace Service::SSL
 | 
			
		||||
@ -0,0 +1,543 @@
 | 
			
		||||
// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project
 | 
			
		||||
// SPDX-License-Identifier: GPL-2.0-or-later
 | 
			
		||||
 | 
			
		||||
#include "core/hle/service/ssl/ssl_backend.h"
 | 
			
		||||
#include "core/internal_network/network.h"
 | 
			
		||||
#include "core/internal_network/sockets.h"
 | 
			
		||||
 | 
			
		||||
#include "common/error.h"
 | 
			
		||||
#include "common/fs/file.h"
 | 
			
		||||
#include "common/hex_util.h"
 | 
			
		||||
#include "common/string_util.h"
 | 
			
		||||
 | 
			
		||||
#include <mutex>
 | 
			
		||||
 | 
			
		||||
namespace {
 | 
			
		||||
 | 
			
		||||
// These includes are inside the namespace to avoid a conflict on MinGW where
 | 
			
		||||
// the headers define an enum containing Network and Service as enumerators
 | 
			
		||||
// (which clash with the correspondingly named namespaces).
 | 
			
		||||
#define SECURITY_WIN32
 | 
			
		||||
#include <schnlsp.h>
 | 
			
		||||
#include <security.h>
 | 
			
		||||
 | 
			
		||||
std::once_flag one_time_init_flag;
 | 
			
		||||
bool one_time_init_success = false;
 | 
			
		||||
 | 
			
		||||
SCHANNEL_CRED schannel_cred{};
 | 
			
		||||
CredHandle cred_handle;
 | 
			
		||||
 | 
			
		||||
static void OneTimeInit() {
 | 
			
		||||
    schannel_cred.dwVersion = SCHANNEL_CRED_VERSION;
 | 
			
		||||
    schannel_cred.dwFlags =
 | 
			
		||||
        SCH_USE_STRONG_CRYPTO |         // don't allow insecure protocols
 | 
			
		||||
        SCH_CRED_AUTO_CRED_VALIDATION | // validate certs
 | 
			
		||||
        SCH_CRED_NO_DEFAULT_CREDS;      // don't automatically present a client certificate
 | 
			
		||||
    // ^ I'm assuming that nobody would want to connect Yuzu to a
 | 
			
		||||
    // service that requires some OS-provided corporate client
 | 
			
		||||
    // certificate, and presenting one to some arbitrary server
 | 
			
		||||
    // might be a privacy concern?  Who knows, though.
 | 
			
		||||
 | 
			
		||||
    const SECURITY_STATUS ret =
 | 
			
		||||
        AcquireCredentialsHandle(nullptr, const_cast<LPTSTR>(UNISP_NAME), SECPKG_CRED_OUTBOUND,
 | 
			
		||||
                                 nullptr, &schannel_cred, nullptr, nullptr, &cred_handle, nullptr);
 | 
			
		||||
    if (ret != SEC_E_OK) {
 | 
			
		||||
        // SECURITY_STATUS codes are a type of HRESULT and can be used with NativeErrorToString.
 | 
			
		||||
        LOG_ERROR(Service_SSL, "AcquireCredentialsHandle failed: {}",
 | 
			
		||||
                  Common::NativeErrorToString(ret));
 | 
			
		||||
        return;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    if (getenv("SSLKEYLOGFILE")) {
 | 
			
		||||
        LOG_CRITICAL(Service_SSL, "SSLKEYLOGFILE was set but Schannel does not support exporting "
 | 
			
		||||
                                  "keys; not logging keys!");
 | 
			
		||||
        // Not fatal.
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    one_time_init_success = true;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
} // namespace
 | 
			
		||||
 | 
			
		||||
namespace Service::SSL {
 | 
			
		||||
 | 
			
		||||
class SSLConnectionBackendSchannel final : public SSLConnectionBackend {
 | 
			
		||||
public:
 | 
			
		||||
    Result Init() {
 | 
			
		||||
        std::call_once(one_time_init_flag, OneTimeInit);
 | 
			
		||||
 | 
			
		||||
        if (!one_time_init_success) {
 | 
			
		||||
            LOG_ERROR(
 | 
			
		||||
                Service_SSL,
 | 
			
		||||
                "Can't create SSL connection because Schannel one-time initialization failed");
 | 
			
		||||
            return ResultInternalError;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        return ResultSuccess;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    void SetSocket(std::shared_ptr<Network::SocketBase> socket_in) override {
 | 
			
		||||
        socket = std::move(socket_in);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    Result SetHostName(const std::string& hostname_in) override {
 | 
			
		||||
        hostname = hostname_in;
 | 
			
		||||
        return ResultSuccess;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    Result DoHandshake() override {
 | 
			
		||||
        while (1) {
 | 
			
		||||
            Result r;
 | 
			
		||||
            switch (handshake_state) {
 | 
			
		||||
            case HandshakeState::Initial:
 | 
			
		||||
                if ((r = FlushCiphertextWriteBuf()) != ResultSuccess ||
 | 
			
		||||
                    (r = CallInitializeSecurityContext()) != ResultSuccess) {
 | 
			
		||||
                    return r;
 | 
			
		||||
                }
 | 
			
		||||
                // CallInitializeSecurityContext updated `handshake_state`.
 | 
			
		||||
                continue;
 | 
			
		||||
            case HandshakeState::ContinueNeeded:
 | 
			
		||||
            case HandshakeState::IncompleteMessage:
 | 
			
		||||
                if ((r = FlushCiphertextWriteBuf()) != ResultSuccess ||
 | 
			
		||||
                    (r = FillCiphertextReadBuf()) != ResultSuccess) {
 | 
			
		||||
                    return r;
 | 
			
		||||
                }
 | 
			
		||||
                if (ciphertext_read_buf.empty()) {
 | 
			
		||||
                    LOG_ERROR(Service_SSL, "SSL handshake failed because server hung up");
 | 
			
		||||
                    return ResultInternalError;
 | 
			
		||||
                }
 | 
			
		||||
                if ((r = CallInitializeSecurityContext()) != ResultSuccess) {
 | 
			
		||||
                    return r;
 | 
			
		||||
                }
 | 
			
		||||
                // CallInitializeSecurityContext updated `handshake_state`.
 | 
			
		||||
                continue;
 | 
			
		||||
            case HandshakeState::DoneAfterFlush:
 | 
			
		||||
                if ((r = FlushCiphertextWriteBuf()) != ResultSuccess) {
 | 
			
		||||
                    return r;
 | 
			
		||||
                }
 | 
			
		||||
                handshake_state = HandshakeState::Connected;
 | 
			
		||||
                return ResultSuccess;
 | 
			
		||||
            case HandshakeState::Connected:
 | 
			
		||||
                LOG_ERROR(Service_SSL, "Called DoHandshake but we already handshook");
 | 
			
		||||
                return ResultInternalError;
 | 
			
		||||
            case HandshakeState::Error:
 | 
			
		||||
                return ResultInternalError;
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    Result FillCiphertextReadBuf() {
 | 
			
		||||
        const size_t fill_size = read_buf_fill_size ? read_buf_fill_size : 4096;
 | 
			
		||||
        read_buf_fill_size = 0;
 | 
			
		||||
        // This unnecessarily zeroes the buffer; oh well.
 | 
			
		||||
        const size_t offset = ciphertext_read_buf.size();
 | 
			
		||||
        ASSERT_OR_EXECUTE(offset + fill_size >= offset, { return ResultInternalError; });
 | 
			
		||||
        ciphertext_read_buf.resize(offset + fill_size, 0);
 | 
			
		||||
        const auto read_span = std::span(ciphertext_read_buf).subspan(offset, fill_size);
 | 
			
		||||
        const auto [actual, err] = socket->Recv(0, read_span);
 | 
			
		||||
        switch (err) {
 | 
			
		||||
        case Network::Errno::SUCCESS:
 | 
			
		||||
            ASSERT(static_cast<size_t>(actual) <= fill_size);
 | 
			
		||||
            ciphertext_read_buf.resize(offset + actual);
 | 
			
		||||
            return ResultSuccess;
 | 
			
		||||
        case Network::Errno::AGAIN:
 | 
			
		||||
            ciphertext_read_buf.resize(offset);
 | 
			
		||||
            return ResultWouldBlock;
 | 
			
		||||
        default:
 | 
			
		||||
            ciphertext_read_buf.resize(offset);
 | 
			
		||||
            LOG_ERROR(Service_SSL, "Socket recv returned Network::Errno {}", err);
 | 
			
		||||
            return ResultInternalError;
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // Returns success if the write buffer has been completely emptied.
 | 
			
		||||
    Result FlushCiphertextWriteBuf() {
 | 
			
		||||
        while (!ciphertext_write_buf.empty()) {
 | 
			
		||||
            const auto [actual, err] = socket->Send(ciphertext_write_buf, 0);
 | 
			
		||||
            switch (err) {
 | 
			
		||||
            case Network::Errno::SUCCESS:
 | 
			
		||||
                ASSERT(static_cast<size_t>(actual) <= ciphertext_write_buf.size());
 | 
			
		||||
                ciphertext_write_buf.erase(ciphertext_write_buf.begin(),
 | 
			
		||||
                                           ciphertext_write_buf.begin() + actual);
 | 
			
		||||
                break;
 | 
			
		||||
            case Network::Errno::AGAIN:
 | 
			
		||||
                return ResultWouldBlock;
 | 
			
		||||
            default:
 | 
			
		||||
                LOG_ERROR(Service_SSL, "Socket send returned Network::Errno {}", err);
 | 
			
		||||
                return ResultInternalError;
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
        return ResultSuccess;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    Result CallInitializeSecurityContext() {
 | 
			
		||||
        const unsigned long req = ISC_REQ_ALLOCATE_MEMORY | ISC_REQ_CONFIDENTIALITY |
 | 
			
		||||
                                  ISC_REQ_INTEGRITY | ISC_REQ_REPLAY_DETECT |
 | 
			
		||||
                                  ISC_REQ_SEQUENCE_DETECT | ISC_REQ_STREAM |
 | 
			
		||||
                                  ISC_REQ_USE_SUPPLIED_CREDS;
 | 
			
		||||
        unsigned long attr;
 | 
			
		||||
        // https://learn.microsoft.com/en-us/windows/win32/secauthn/initializesecuritycontext--schannel
 | 
			
		||||
        std::array<SecBuffer, 2> input_buffers{{
 | 
			
		||||
            // only used if `initial_call_done`
 | 
			
		||||
            {
 | 
			
		||||
                // [0]
 | 
			
		||||
                .cbBuffer = static_cast<unsigned long>(ciphertext_read_buf.size()),
 | 
			
		||||
                .BufferType = SECBUFFER_TOKEN,
 | 
			
		||||
                .pvBuffer = ciphertext_read_buf.data(),
 | 
			
		||||
            },
 | 
			
		||||
            {
 | 
			
		||||
                // [1] (will be replaced by SECBUFFER_MISSING when SEC_E_INCOMPLETE_MESSAGE is
 | 
			
		||||
                //     returned, or SECBUFFER_EXTRA when SEC_E_CONTINUE_NEEDED is returned if the
 | 
			
		||||
                //     whole buffer wasn't used)
 | 
			
		||||
                .cbBuffer = 0,
 | 
			
		||||
                .BufferType = SECBUFFER_EMPTY,
 | 
			
		||||
                .pvBuffer = nullptr,
 | 
			
		||||
            },
 | 
			
		||||
        }};
 | 
			
		||||
        std::array<SecBuffer, 2> output_buffers{{
 | 
			
		||||
            {
 | 
			
		||||
                .cbBuffer = 0,
 | 
			
		||||
                .BufferType = SECBUFFER_TOKEN,
 | 
			
		||||
                .pvBuffer = nullptr,
 | 
			
		||||
            }, // [0]
 | 
			
		||||
            {
 | 
			
		||||
                .cbBuffer = 0,
 | 
			
		||||
                .BufferType = SECBUFFER_ALERT,
 | 
			
		||||
                .pvBuffer = nullptr,
 | 
			
		||||
            }, // [1]
 | 
			
		||||
        }};
 | 
			
		||||
        SecBufferDesc input_desc{
 | 
			
		||||
            .ulVersion = SECBUFFER_VERSION,
 | 
			
		||||
            .cBuffers = static_cast<unsigned long>(input_buffers.size()),
 | 
			
		||||
            .pBuffers = input_buffers.data(),
 | 
			
		||||
        };
 | 
			
		||||
        SecBufferDesc output_desc{
 | 
			
		||||
            .ulVersion = SECBUFFER_VERSION,
 | 
			
		||||
            .cBuffers = static_cast<unsigned long>(output_buffers.size()),
 | 
			
		||||
            .pBuffers = output_buffers.data(),
 | 
			
		||||
        };
 | 
			
		||||
        ASSERT_OR_EXECUTE_MSG(
 | 
			
		||||
            input_buffers[0].cbBuffer == ciphertext_read_buf.size(),
 | 
			
		||||
            { return ResultInternalError; }, "read buffer too large");
 | 
			
		||||
 | 
			
		||||
        bool initial_call_done = handshake_state != HandshakeState::Initial;
 | 
			
		||||
        if (initial_call_done) {
 | 
			
		||||
            LOG_DEBUG(Service_SSL, "Passing {} bytes into InitializeSecurityContext",
 | 
			
		||||
                      ciphertext_read_buf.size());
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        const SECURITY_STATUS ret =
 | 
			
		||||
            InitializeSecurityContextA(&cred_handle, initial_call_done ? &ctxt : nullptr,
 | 
			
		||||
                                       // Caller ensured we have set a hostname:
 | 
			
		||||
                                       const_cast<char*>(hostname.value().c_str()), req,
 | 
			
		||||
                                       0, // Reserved1
 | 
			
		||||
                                       0, // TargetDataRep not used with Schannel
 | 
			
		||||
                                       initial_call_done ? &input_desc : nullptr,
 | 
			
		||||
                                       0, // Reserved2
 | 
			
		||||
                                       initial_call_done ? nullptr : &ctxt, &output_desc, &attr,
 | 
			
		||||
                                       nullptr); // ptsExpiry
 | 
			
		||||
 | 
			
		||||
        if (output_buffers[0].pvBuffer) {
 | 
			
		||||
            const std::span span(static_cast<u8*>(output_buffers[0].pvBuffer),
 | 
			
		||||
                                 output_buffers[0].cbBuffer);
 | 
			
		||||
            ciphertext_write_buf.insert(ciphertext_write_buf.end(), span.begin(), span.end());
 | 
			
		||||
            FreeContextBuffer(output_buffers[0].pvBuffer);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        if (output_buffers[1].pvBuffer) {
 | 
			
		||||
            const std::span span(static_cast<u8*>(output_buffers[1].pvBuffer),
 | 
			
		||||
                                 output_buffers[1].cbBuffer);
 | 
			
		||||
            // The documentation doesn't explain what format this data is in.
 | 
			
		||||
            LOG_DEBUG(Service_SSL, "Got a {}-byte alert buffer: {}", span.size(),
 | 
			
		||||
                      Common::HexToString(span));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        switch (ret) {
 | 
			
		||||
        case SEC_I_CONTINUE_NEEDED:
 | 
			
		||||
            LOG_DEBUG(Service_SSL, "InitializeSecurityContext => SEC_I_CONTINUE_NEEDED");
 | 
			
		||||
            if (input_buffers[1].BufferType == SECBUFFER_EXTRA) {
 | 
			
		||||
                LOG_DEBUG(Service_SSL, "EXTRA of size {}", input_buffers[1].cbBuffer);
 | 
			
		||||
                ASSERT(input_buffers[1].cbBuffer <= ciphertext_read_buf.size());
 | 
			
		||||
                ciphertext_read_buf.erase(ciphertext_read_buf.begin(),
 | 
			
		||||
                                          ciphertext_read_buf.end() - input_buffers[1].cbBuffer);
 | 
			
		||||
            } else {
 | 
			
		||||
                ASSERT(input_buffers[1].BufferType == SECBUFFER_EMPTY);
 | 
			
		||||
                ciphertext_read_buf.clear();
 | 
			
		||||
            }
 | 
			
		||||
            handshake_state = HandshakeState::ContinueNeeded;
 | 
			
		||||
            return ResultSuccess;
 | 
			
		||||
        case SEC_E_INCOMPLETE_MESSAGE:
 | 
			
		||||
            LOG_DEBUG(Service_SSL, "InitializeSecurityContext => SEC_E_INCOMPLETE_MESSAGE");
 | 
			
		||||
            ASSERT(input_buffers[1].BufferType == SECBUFFER_MISSING);
 | 
			
		||||
            read_buf_fill_size = input_buffers[1].cbBuffer;
 | 
			
		||||
            handshake_state = HandshakeState::IncompleteMessage;
 | 
			
		||||
            return ResultSuccess;
 | 
			
		||||
        case SEC_E_OK:
 | 
			
		||||
            LOG_DEBUG(Service_SSL, "InitializeSecurityContext => SEC_E_OK");
 | 
			
		||||
            ciphertext_read_buf.clear();
 | 
			
		||||
            handshake_state = HandshakeState::DoneAfterFlush;
 | 
			
		||||
            return GrabStreamSizes();
 | 
			
		||||
        default:
 | 
			
		||||
            LOG_ERROR(Service_SSL,
 | 
			
		||||
                      "InitializeSecurityContext failed (probably certificate/protocol issue): {}",
 | 
			
		||||
                      Common::NativeErrorToString(ret));
 | 
			
		||||
            handshake_state = HandshakeState::Error;
 | 
			
		||||
            return ResultInternalError;
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    Result GrabStreamSizes() {
 | 
			
		||||
        const SECURITY_STATUS ret =
 | 
			
		||||
            QueryContextAttributes(&ctxt, SECPKG_ATTR_STREAM_SIZES, &stream_sizes);
 | 
			
		||||
        if (ret != SEC_E_OK) {
 | 
			
		||||
            LOG_ERROR(Service_SSL, "QueryContextAttributes(SECPKG_ATTR_STREAM_SIZES) failed: {}",
 | 
			
		||||
                      Common::NativeErrorToString(ret));
 | 
			
		||||
            handshake_state = HandshakeState::Error;
 | 
			
		||||
            return ResultInternalError;
 | 
			
		||||
        }
 | 
			
		||||
        return ResultSuccess;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    ResultVal<size_t> Read(std::span<u8> data) override {
 | 
			
		||||
        if (handshake_state != HandshakeState::Connected) {
 | 
			
		||||
            LOG_ERROR(Service_SSL, "Called Read but we did not successfully handshake");
 | 
			
		||||
            return ResultInternalError;
 | 
			
		||||
        }
 | 
			
		||||
        if (data.size() == 0 || got_read_eof) {
 | 
			
		||||
            return size_t(0);
 | 
			
		||||
        }
 | 
			
		||||
        while (1) {
 | 
			
		||||
            if (!cleartext_read_buf.empty()) {
 | 
			
		||||
                const size_t read_size = std::min(cleartext_read_buf.size(), data.size());
 | 
			
		||||
                std::memcpy(data.data(), cleartext_read_buf.data(), read_size);
 | 
			
		||||
                cleartext_read_buf.erase(cleartext_read_buf.begin(),
 | 
			
		||||
                                         cleartext_read_buf.begin() + read_size);
 | 
			
		||||
                return read_size;
 | 
			
		||||
            }
 | 
			
		||||
            if (!ciphertext_read_buf.empty()) {
 | 
			
		||||
                SecBuffer empty{
 | 
			
		||||
                    .cbBuffer = 0,
 | 
			
		||||
                    .BufferType = SECBUFFER_EMPTY,
 | 
			
		||||
                    .pvBuffer = nullptr,
 | 
			
		||||
                };
 | 
			
		||||
                std::array<SecBuffer, 5> buffers{{
 | 
			
		||||
                    {
 | 
			
		||||
                        .cbBuffer = static_cast<unsigned long>(ciphertext_read_buf.size()),
 | 
			
		||||
                        .BufferType = SECBUFFER_DATA,
 | 
			
		||||
                        .pvBuffer = ciphertext_read_buf.data(),
 | 
			
		||||
                    },
 | 
			
		||||
                    empty,
 | 
			
		||||
                    empty,
 | 
			
		||||
                    empty,
 | 
			
		||||
                }};
 | 
			
		||||
                ASSERT_OR_EXECUTE_MSG(
 | 
			
		||||
                    buffers[0].cbBuffer == ciphertext_read_buf.size(),
 | 
			
		||||
                    { return ResultInternalError; }, "read buffer too large");
 | 
			
		||||
                SecBufferDesc desc{
 | 
			
		||||
                    .ulVersion = SECBUFFER_VERSION,
 | 
			
		||||
                    .cBuffers = static_cast<unsigned long>(buffers.size()),
 | 
			
		||||
                    .pBuffers = buffers.data(),
 | 
			
		||||
                };
 | 
			
		||||
                SECURITY_STATUS ret =
 | 
			
		||||
                    DecryptMessage(&ctxt, &desc, /*MessageSeqNo*/ 0, /*pfQOP*/ nullptr);
 | 
			
		||||
                switch (ret) {
 | 
			
		||||
                case SEC_E_OK:
 | 
			
		||||
                    ASSERT_OR_EXECUTE(buffers[0].BufferType == SECBUFFER_STREAM_HEADER,
 | 
			
		||||
                                      { return ResultInternalError; });
 | 
			
		||||
                    ASSERT_OR_EXECUTE(buffers[1].BufferType == SECBUFFER_DATA,
 | 
			
		||||
                                      { return ResultInternalError; });
 | 
			
		||||
                    ASSERT_OR_EXECUTE(buffers[2].BufferType == SECBUFFER_STREAM_TRAILER,
 | 
			
		||||
                                      { return ResultInternalError; });
 | 
			
		||||
                    cleartext_read_buf.assign(static_cast<u8*>(buffers[1].pvBuffer),
 | 
			
		||||
                                              static_cast<u8*>(buffers[1].pvBuffer) +
 | 
			
		||||
                                                  buffers[1].cbBuffer);
 | 
			
		||||
                    if (buffers[3].BufferType == SECBUFFER_EXTRA) {
 | 
			
		||||
                        ASSERT(buffers[3].cbBuffer <= ciphertext_read_buf.size());
 | 
			
		||||
                        ciphertext_read_buf.erase(ciphertext_read_buf.begin(),
 | 
			
		||||
                                                  ciphertext_read_buf.end() - buffers[3].cbBuffer);
 | 
			
		||||
                    } else {
 | 
			
		||||
                        ASSERT(buffers[3].BufferType == SECBUFFER_EMPTY);
 | 
			
		||||
                        ciphertext_read_buf.clear();
 | 
			
		||||
                    }
 | 
			
		||||
                    continue;
 | 
			
		||||
                case SEC_E_INCOMPLETE_MESSAGE:
 | 
			
		||||
                    break;
 | 
			
		||||
                case SEC_I_CONTEXT_EXPIRED:
 | 
			
		||||
                    // Server hung up by sending close_notify.
 | 
			
		||||
                    got_read_eof = true;
 | 
			
		||||
                    return size_t(0);
 | 
			
		||||
                default:
 | 
			
		||||
                    LOG_ERROR(Service_SSL, "DecryptMessage failed: {}",
 | 
			
		||||
                              Common::NativeErrorToString(ret));
 | 
			
		||||
                    return ResultInternalError;
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
            const Result r = FillCiphertextReadBuf();
 | 
			
		||||
            if (r != ResultSuccess) {
 | 
			
		||||
                return r;
 | 
			
		||||
            }
 | 
			
		||||
            if (ciphertext_read_buf.empty()) {
 | 
			
		||||
                got_read_eof = true;
 | 
			
		||||
                return size_t(0);
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    ResultVal<size_t> Write(std::span<const u8> data) override {
 | 
			
		||||
        if (handshake_state != HandshakeState::Connected) {
 | 
			
		||||
            LOG_ERROR(Service_SSL, "Called Write but we did not successfully handshake");
 | 
			
		||||
            return ResultInternalError;
 | 
			
		||||
        }
 | 
			
		||||
        if (data.size() == 0) {
 | 
			
		||||
            return size_t(0);
 | 
			
		||||
        }
 | 
			
		||||
        data = data.subspan(0, std::min<size_t>(data.size(), stream_sizes.cbMaximumMessage));
 | 
			
		||||
        if (!cleartext_write_buf.empty()) {
 | 
			
		||||
            // Already in the middle of a write.  It wouldn't make sense to not
 | 
			
		||||
            // finish sending the entire buffer since TLS has
 | 
			
		||||
            // header/MAC/padding/etc.
 | 
			
		||||
            if (data.size() != cleartext_write_buf.size() ||
 | 
			
		||||
                std::memcmp(data.data(), cleartext_write_buf.data(), data.size())) {
 | 
			
		||||
                LOG_ERROR(Service_SSL, "Called Write but buffer does not match previous buffer");
 | 
			
		||||
                return ResultInternalError;
 | 
			
		||||
            }
 | 
			
		||||
            return WriteAlreadyEncryptedData();
 | 
			
		||||
        } else {
 | 
			
		||||
            cleartext_write_buf.assign(data.begin(), data.end());
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        std::vector<u8> header_buf(stream_sizes.cbHeader, 0);
 | 
			
		||||
        std::vector<u8> tmp_data_buf = cleartext_write_buf;
 | 
			
		||||
        std::vector<u8> trailer_buf(stream_sizes.cbTrailer, 0);
 | 
			
		||||
 | 
			
		||||
        std::array<SecBuffer, 3> buffers{{
 | 
			
		||||
            {
 | 
			
		||||
                .cbBuffer = stream_sizes.cbHeader,
 | 
			
		||||
                .BufferType = SECBUFFER_STREAM_HEADER,
 | 
			
		||||
                .pvBuffer = header_buf.data(),
 | 
			
		||||
            },
 | 
			
		||||
            {
 | 
			
		||||
                .cbBuffer = static_cast<unsigned long>(tmp_data_buf.size()),
 | 
			
		||||
                .BufferType = SECBUFFER_DATA,
 | 
			
		||||
                .pvBuffer = tmp_data_buf.data(),
 | 
			
		||||
            },
 | 
			
		||||
            {
 | 
			
		||||
                .cbBuffer = stream_sizes.cbTrailer,
 | 
			
		||||
                .BufferType = SECBUFFER_STREAM_TRAILER,
 | 
			
		||||
                .pvBuffer = trailer_buf.data(),
 | 
			
		||||
            },
 | 
			
		||||
        }};
 | 
			
		||||
        ASSERT_OR_EXECUTE_MSG(
 | 
			
		||||
            buffers[1].cbBuffer == tmp_data_buf.size(), { return ResultInternalError; },
 | 
			
		||||
            "temp buffer too large");
 | 
			
		||||
        SecBufferDesc desc{
 | 
			
		||||
            .ulVersion = SECBUFFER_VERSION,
 | 
			
		||||
            .cBuffers = static_cast<unsigned long>(buffers.size()),
 | 
			
		||||
            .pBuffers = buffers.data(),
 | 
			
		||||
        };
 | 
			
		||||
 | 
			
		||||
        const SECURITY_STATUS ret = EncryptMessage(&ctxt, /*fQOP*/ 0, &desc, /*MessageSeqNo*/ 0);
 | 
			
		||||
        if (ret != SEC_E_OK) {
 | 
			
		||||
            LOG_ERROR(Service_SSL, "EncryptMessage failed: {}", Common::NativeErrorToString(ret));
 | 
			
		||||
            return ResultInternalError;
 | 
			
		||||
        }
 | 
			
		||||
        ciphertext_write_buf.insert(ciphertext_write_buf.end(), header_buf.begin(),
 | 
			
		||||
                                    header_buf.end());
 | 
			
		||||
        ciphertext_write_buf.insert(ciphertext_write_buf.end(), tmp_data_buf.begin(),
 | 
			
		||||
                                    tmp_data_buf.end());
 | 
			
		||||
        ciphertext_write_buf.insert(ciphertext_write_buf.end(), trailer_buf.begin(),
 | 
			
		||||
                                    trailer_buf.end());
 | 
			
		||||
        return WriteAlreadyEncryptedData();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    ResultVal<size_t> WriteAlreadyEncryptedData() {
 | 
			
		||||
        const Result r = FlushCiphertextWriteBuf();
 | 
			
		||||
        if (r != ResultSuccess) {
 | 
			
		||||
            return r;
 | 
			
		||||
        }
 | 
			
		||||
        // write buf is empty
 | 
			
		||||
        const size_t cleartext_bytes_written = cleartext_write_buf.size();
 | 
			
		||||
        cleartext_write_buf.clear();
 | 
			
		||||
        return cleartext_bytes_written;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    ResultVal<std::vector<std::vector<u8>>> GetServerCerts() override {
 | 
			
		||||
        PCCERT_CONTEXT returned_cert = nullptr;
 | 
			
		||||
        const SECURITY_STATUS ret =
 | 
			
		||||
            QueryContextAttributes(&ctxt, SECPKG_ATTR_REMOTE_CERT_CONTEXT, &returned_cert);
 | 
			
		||||
        if (ret != SEC_E_OK) {
 | 
			
		||||
            LOG_ERROR(Service_SSL,
 | 
			
		||||
                      "QueryContextAttributes(SECPKG_ATTR_REMOTE_CERT_CONTEXT) failed: {}",
 | 
			
		||||
                      Common::NativeErrorToString(ret));
 | 
			
		||||
            return ResultInternalError;
 | 
			
		||||
        }
 | 
			
		||||
        PCCERT_CONTEXT some_cert = nullptr;
 | 
			
		||||
        std::vector<std::vector<u8>> certs;
 | 
			
		||||
        while ((some_cert = CertEnumCertificatesInStore(returned_cert->hCertStore, some_cert))) {
 | 
			
		||||
            certs.emplace_back(static_cast<u8*>(some_cert->pbCertEncoded),
 | 
			
		||||
                               static_cast<u8*>(some_cert->pbCertEncoded) +
 | 
			
		||||
                                   some_cert->cbCertEncoded);
 | 
			
		||||
        }
 | 
			
		||||
        std::reverse(certs.begin(),
 | 
			
		||||
                     certs.end()); // Windows returns certs in reverse order from what we want
 | 
			
		||||
        CertFreeCertificateContext(returned_cert);
 | 
			
		||||
        return certs;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    ~SSLConnectionBackendSchannel() {
 | 
			
		||||
        if (handshake_state != HandshakeState::Initial) {
 | 
			
		||||
            DeleteSecurityContext(&ctxt);
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    enum class HandshakeState {
 | 
			
		||||
        // Haven't called anything yet.
 | 
			
		||||
        Initial,
 | 
			
		||||
        // `SEC_I_CONTINUE_NEEDED` was returned by
 | 
			
		||||
        // `InitializeSecurityContext`; must finish sending data (if any) in
 | 
			
		||||
        // the write buffer, then read at least one byte before calling
 | 
			
		||||
        // `InitializeSecurityContext` again.
 | 
			
		||||
        ContinueNeeded,
 | 
			
		||||
        // `SEC_E_INCOMPLETE_MESSAGE` was returned by
 | 
			
		||||
        // `InitializeSecurityContext`; hopefully the write buffer is empty;
 | 
			
		||||
        // must read at least one byte before calling
 | 
			
		||||
        // `InitializeSecurityContext` again.
 | 
			
		||||
        IncompleteMessage,
 | 
			
		||||
        // `SEC_E_OK` was returned by `InitializeSecurityContext`; must
 | 
			
		||||
        // finish sending data in the write buffer before having `DoHandshake`
 | 
			
		||||
        // report success.
 | 
			
		||||
        DoneAfterFlush,
 | 
			
		||||
        // We finished the above and are now connected.  At this point, writing
 | 
			
		||||
        // and reading are separate 'state machines' represented by the
 | 
			
		||||
        // nonemptiness of the ciphertext and cleartext read and write buffers.
 | 
			
		||||
        Connected,
 | 
			
		||||
        // Another error was returned and we shouldn't allow initialization
 | 
			
		||||
        // to continue.
 | 
			
		||||
        Error,
 | 
			
		||||
    } handshake_state = HandshakeState::Initial;
 | 
			
		||||
 | 
			
		||||
    CtxtHandle ctxt;
 | 
			
		||||
    SecPkgContext_StreamSizes stream_sizes;
 | 
			
		||||
 | 
			
		||||
    std::shared_ptr<Network::SocketBase> socket;
 | 
			
		||||
    std::optional<std::string> hostname;
 | 
			
		||||
 | 
			
		||||
    std::vector<u8> ciphertext_read_buf;
 | 
			
		||||
    std::vector<u8> ciphertext_write_buf;
 | 
			
		||||
    std::vector<u8> cleartext_read_buf;
 | 
			
		||||
    std::vector<u8> cleartext_write_buf;
 | 
			
		||||
 | 
			
		||||
    bool got_read_eof = false;
 | 
			
		||||
    size_t read_buf_fill_size = 0;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() {
 | 
			
		||||
    auto conn = std::make_unique<SSLConnectionBackendSchannel>();
 | 
			
		||||
    const Result res = conn->Init();
 | 
			
		||||
    if (res.IsFailure()) {
 | 
			
		||||
        return res;
 | 
			
		||||
    }
 | 
			
		||||
    return conn;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
} // namespace Service::SSL
 | 
			
		||||
@ -0,0 +1,219 @@
 | 
			
		||||
// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project
 | 
			
		||||
// SPDX-License-Identifier: GPL-2.0-or-later
 | 
			
		||||
 | 
			
		||||
#include "core/hle/service/ssl/ssl_backend.h"
 | 
			
		||||
#include "core/internal_network/network.h"
 | 
			
		||||
#include "core/internal_network/sockets.h"
 | 
			
		||||
 | 
			
		||||
#include <mutex>
 | 
			
		||||
 | 
			
		||||
#include <Security/SecureTransport.h>
 | 
			
		||||
 | 
			
		||||
// SecureTransport has been deprecated in its entirety in favor of
 | 
			
		||||
// Network.framework, but that does not allow layering TLS on top of an
 | 
			
		||||
// arbitrary socket.
 | 
			
		||||
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
 | 
			
		||||
 | 
			
		||||
namespace {
 | 
			
		||||
 | 
			
		||||
template <typename T>
 | 
			
		||||
struct CFReleaser {
 | 
			
		||||
    T ptr;
 | 
			
		||||
 | 
			
		||||
    YUZU_NON_COPYABLE(CFReleaser);
 | 
			
		||||
    constexpr CFReleaser() : ptr(nullptr) {}
 | 
			
		||||
    constexpr CFReleaser(T ptr) : ptr(ptr) {}
 | 
			
		||||
    constexpr operator T() {
 | 
			
		||||
        return ptr;
 | 
			
		||||
    }
 | 
			
		||||
    ~CFReleaser() {
 | 
			
		||||
        if (ptr) {
 | 
			
		||||
            CFRelease(ptr);
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
std::string CFStringToString(CFStringRef cfstr) {
 | 
			
		||||
    CFReleaser<CFDataRef> cfdata(
 | 
			
		||||
        CFStringCreateExternalRepresentation(nullptr, cfstr, kCFStringEncodingUTF8, 0));
 | 
			
		||||
    ASSERT_OR_EXECUTE(cfdata, { return "???"; });
 | 
			
		||||
    return std::string(reinterpret_cast<const char*>(CFDataGetBytePtr(cfdata)),
 | 
			
		||||
                       CFDataGetLength(cfdata));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
std::string OSStatusToString(OSStatus status) {
 | 
			
		||||
    CFReleaser<CFStringRef> cfstr(SecCopyErrorMessageString(status, nullptr));
 | 
			
		||||
    if (!cfstr) {
 | 
			
		||||
        return "[unknown error]";
 | 
			
		||||
    }
 | 
			
		||||
    return CFStringToString(cfstr);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
} // namespace
 | 
			
		||||
 | 
			
		||||
namespace Service::SSL {
 | 
			
		||||
 | 
			
		||||
class SSLConnectionBackendSecureTransport final : public SSLConnectionBackend {
 | 
			
		||||
public:
 | 
			
		||||
    Result Init() {
 | 
			
		||||
        static std::once_flag once_flag;
 | 
			
		||||
        std::call_once(once_flag, []() {
 | 
			
		||||
            if (getenv("SSLKEYLOGFILE")) {
 | 
			
		||||
                LOG_CRITICAL(Service_SSL, "SSLKEYLOGFILE was set but SecureTransport does not "
 | 
			
		||||
                                          "support exporting keys; not logging keys!");
 | 
			
		||||
                // Not fatal.
 | 
			
		||||
            }
 | 
			
		||||
        });
 | 
			
		||||
 | 
			
		||||
        context.ptr = SSLCreateContext(nullptr, kSSLClientSide, kSSLStreamType);
 | 
			
		||||
        if (!context) {
 | 
			
		||||
            LOG_ERROR(Service_SSL, "SSLCreateContext failed");
 | 
			
		||||
            return ResultInternalError;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        OSStatus status;
 | 
			
		||||
        if ((status = SSLSetIOFuncs(context, ReadCallback, WriteCallback)) ||
 | 
			
		||||
            (status = SSLSetConnection(context, this))) {
 | 
			
		||||
            LOG_ERROR(Service_SSL, "SSLContext initialization failed: {}",
 | 
			
		||||
                      OSStatusToString(status));
 | 
			
		||||
            return ResultInternalError;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        return ResultSuccess;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    void SetSocket(std::shared_ptr<Network::SocketBase> in_socket) override {
 | 
			
		||||
        socket = std::move(in_socket);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    Result SetHostName(const std::string& hostname) override {
 | 
			
		||||
        OSStatus status = SSLSetPeerDomainName(context, hostname.c_str(), hostname.size());
 | 
			
		||||
        if (status) {
 | 
			
		||||
            LOG_ERROR(Service_SSL, "SSLSetPeerDomainName failed: {}", OSStatusToString(status));
 | 
			
		||||
            return ResultInternalError;
 | 
			
		||||
        }
 | 
			
		||||
        return ResultSuccess;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    Result DoHandshake() override {
 | 
			
		||||
        OSStatus status = SSLHandshake(context);
 | 
			
		||||
        return HandleReturn("SSLHandshake", 0, status).Code();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    ResultVal<size_t> Read(std::span<u8> data) override {
 | 
			
		||||
        size_t actual;
 | 
			
		||||
        OSStatus status = SSLRead(context, data.data(), data.size(), &actual);
 | 
			
		||||
        ;
 | 
			
		||||
        return HandleReturn("SSLRead", actual, status);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    ResultVal<size_t> Write(std::span<const u8> data) override {
 | 
			
		||||
        size_t actual;
 | 
			
		||||
        OSStatus status = SSLWrite(context, data.data(), data.size(), &actual);
 | 
			
		||||
        ;
 | 
			
		||||
        return HandleReturn("SSLWrite", actual, status);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    ResultVal<size_t> HandleReturn(const char* what, size_t actual, OSStatus status) {
 | 
			
		||||
        switch (status) {
 | 
			
		||||
        case 0:
 | 
			
		||||
            return actual;
 | 
			
		||||
        case errSSLWouldBlock:
 | 
			
		||||
            return ResultWouldBlock;
 | 
			
		||||
        default: {
 | 
			
		||||
            std::string reason;
 | 
			
		||||
            if (got_read_eof) {
 | 
			
		||||
                reason = "server hung up";
 | 
			
		||||
            } else {
 | 
			
		||||
                reason = OSStatusToString(status);
 | 
			
		||||
            }
 | 
			
		||||
            LOG_ERROR(Service_SSL, "{} failed: {}", what, reason);
 | 
			
		||||
            return ResultInternalError;
 | 
			
		||||
        }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    ResultVal<std::vector<std::vector<u8>>> GetServerCerts() override {
 | 
			
		||||
        CFReleaser<SecTrustRef> trust;
 | 
			
		||||
        OSStatus status = SSLCopyPeerTrust(context, &trust.ptr);
 | 
			
		||||
        if (status) {
 | 
			
		||||
            LOG_ERROR(Service_SSL, "SSLCopyPeerTrust failed: {}", OSStatusToString(status));
 | 
			
		||||
            return ResultInternalError;
 | 
			
		||||
        }
 | 
			
		||||
        std::vector<std::vector<u8>> ret;
 | 
			
		||||
        for (CFIndex i = 0, count = SecTrustGetCertificateCount(trust); i < count; i++) {
 | 
			
		||||
            SecCertificateRef cert = SecTrustGetCertificateAtIndex(trust, i);
 | 
			
		||||
            CFReleaser<CFDataRef> data(SecCertificateCopyData(cert));
 | 
			
		||||
            ASSERT_OR_EXECUTE(data, { return ResultInternalError; });
 | 
			
		||||
            const u8* ptr = CFDataGetBytePtr(data);
 | 
			
		||||
            ret.emplace_back(ptr, ptr + CFDataGetLength(data));
 | 
			
		||||
        }
 | 
			
		||||
        return ret;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    static OSStatus ReadCallback(SSLConnectionRef connection, void* data, size_t* dataLength) {
 | 
			
		||||
        return ReadOrWriteCallback(connection, data, dataLength, true);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    static OSStatus WriteCallback(SSLConnectionRef connection, const void* data,
 | 
			
		||||
                                  size_t* dataLength) {
 | 
			
		||||
        return ReadOrWriteCallback(connection, const_cast<void*>(data), dataLength, false);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    static OSStatus ReadOrWriteCallback(SSLConnectionRef connection, void* data, size_t* dataLength,
 | 
			
		||||
                                        bool is_read) {
 | 
			
		||||
        auto self =
 | 
			
		||||
            static_cast<SSLConnectionBackendSecureTransport*>(const_cast<void*>(connection));
 | 
			
		||||
        ASSERT_OR_EXECUTE_MSG(
 | 
			
		||||
            self->socket, { return 0; }, "SecureTransport asked to {} but we have no socket",
 | 
			
		||||
            is_read ? "read" : "write");
 | 
			
		||||
 | 
			
		||||
        // SecureTransport callbacks (unlike OpenSSL BIO callbacks) are
 | 
			
		||||
        // expected to read/write the full requested dataLength or return an
 | 
			
		||||
        // error, so we have to add a loop ourselves.
 | 
			
		||||
        size_t requested_len = *dataLength;
 | 
			
		||||
        size_t offset = 0;
 | 
			
		||||
        while (offset < requested_len) {
 | 
			
		||||
            std::span cur(reinterpret_cast<u8*>(data) + offset, requested_len - offset);
 | 
			
		||||
            auto [actual, err] = is_read ? self->socket->Recv(0, cur) : self->socket->Send(cur, 0);
 | 
			
		||||
            LOG_CRITICAL(Service_SSL, "op={}, offset={} actual={}/{} err={}", is_read, offset,
 | 
			
		||||
                         actual, cur.size(), static_cast<s32>(err));
 | 
			
		||||
            switch (err) {
 | 
			
		||||
            case Network::Errno::SUCCESS:
 | 
			
		||||
                offset += actual;
 | 
			
		||||
                if (actual == 0) {
 | 
			
		||||
                    ASSERT(is_read);
 | 
			
		||||
                    self->got_read_eof = true;
 | 
			
		||||
                    return errSecEndOfData;
 | 
			
		||||
                }
 | 
			
		||||
                break;
 | 
			
		||||
            case Network::Errno::AGAIN:
 | 
			
		||||
                *dataLength = offset;
 | 
			
		||||
                return errSSLWouldBlock;
 | 
			
		||||
            default:
 | 
			
		||||
                LOG_ERROR(Service_SSL, "Socket {} returned Network::Errno {}",
 | 
			
		||||
                          is_read ? "recv" : "send", err);
 | 
			
		||||
                return errSecIO;
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
        ASSERT(offset == requested_len);
 | 
			
		||||
        return 0;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
private:
 | 
			
		||||
    CFReleaser<SSLContextRef> context = nullptr;
 | 
			
		||||
    bool got_read_eof = false;
 | 
			
		||||
 | 
			
		||||
    std::shared_ptr<Network::SocketBase> socket;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() {
 | 
			
		||||
    auto conn = std::make_unique<SSLConnectionBackendSecureTransport>();
 | 
			
		||||
    const Result res = conn->Init();
 | 
			
		||||
    if (res.IsFailure()) {
 | 
			
		||||
        return res;
 | 
			
		||||
    }
 | 
			
		||||
    return conn;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
} // namespace Service::SSL
 | 
			
		||||
					Loading…
					
					
				
		Reference in New Issue