From 77c1b81e95313a0aebf334b1b4f929195aa738b8 Mon Sep 17 00:00:00 2001 From: sha512sum Date: Thu, 26 Sep 2024 18:11:30 +0000 Subject: [PATCH] Fix scram auth and add tests --- examples/src/connect.cpp | 12 ++++-- library/include/larra/client/client.hpp | 55 +++++++++++++++++++++--- library/include/larra/encryption.hpp | 12 ++++-- library/include/larra/printer_stream.hpp | 26 +++++------ library/src/encryption.cpp | 47 +++++++++++++++----- tests/scram.cpp | 33 ++++++++++++++ 6 files changed, 149 insertions(+), 36 deletions(-) create mode 100644 tests/scram.cpp diff --git a/examples/src/connect.cpp b/examples/src/connect.cpp index 013806d..ab24ba2 100644 --- a/examples/src/connect.cpp +++ b/examples/src/connect.cpp @@ -1,3 +1,5 @@ +#include + #include #include #include @@ -5,18 +7,20 @@ #include auto Coroutine() -> boost::asio::awaitable { - std::println("Connecting client..."); + SPDLOG_INFO("Connecting client..."); + try { auto client = co_await larra::xmpp::client::CreateClient>( - larra::xmpp::EncryptionUserAccount{{"sha512sum", "localhost"}, "12345"}, {.useTls = larra::xmpp::client::Options::kNever}); + larra::xmpp::EncryptionUserAccount{{"test1", "localhost"}, "test1"}, {.useTls = larra::xmpp::client::Options::kNever}); } catch(const std::exception& err) { - std::println("Err: {}", err.what()); + SPDLOG_ERROR("{}", err.what()); co_return; } - std::println("Done!"); + SPDLOG_INFO("Done connecting client!"); } auto main() -> int { + spdlog::set_level(spdlog::level::trace); boost::asio::io_context io_context; boost::asio::co_spawn(io_context, Coroutine(), boost::asio::detached); io_context.run(); diff --git a/library/include/larra/client/client.hpp b/library/include/larra/client/client.hpp index e49c9db..6d524f6 100644 --- a/library/include/larra/client/client.hpp +++ b/library/include/larra/client/client.hpp @@ -1,6 +1,10 @@ #pragma once +#include + #include +#include #include +#include #include #include #include @@ -23,11 +27,44 @@ namespace larra::xmpp::client { template struct Client { - Client(BareJid jid, Connection connection) : jid(std::move(jid)), connection(std::move(connection)) {}; + constexpr Client(BareJid jid, Connection connection) : jid(std::move(jid)), connection(std::move(connection)) {}; + template Token = boost::asio::use_awaitable_t<>> + constexpr auto Close(Token token = {}) { + this->active = false; + return boost::asio::async_initiate( + [](Handler&& h, Connection connection) { + boost::asio::co_spawn( + connection.get_executor(), + [](auto h, Connection connection) -> boost::asio::awaitable { + co_await boost::asio::async_write(connection, boost::asio::buffer(""), boost::asio::use_awaitable); + std::string response; + co_await boost::asio::async_read_until( + connection, boost::asio::dynamic_buffer(response), "", boost::asio::use_awaitable); + std::move(h)(); + }(std::move(h), std::move(connection)), + boost::asio::detached); + }, + token, + std::move(this->connection)); + } + constexpr Client(const Client&) = delete; + constexpr Client(Client&& client) : connection(std::move(client.connection)), jid(std::move(client.jid)) { + client.active = false; + } + constexpr ~Client() { + if(this->active) { + this->Close([] {}); + } + } + + [[nodiscard]] constexpr auto Jid() const -> const BareJid& { + return this->jid; + } - const BareJid jid; // NOLINT: const private: + bool active = true; Connection connection; + BareJid jid; }; struct StartTlsNegotiationError : std::runtime_error { @@ -113,6 +150,7 @@ struct ClientCreateVisitor { auth.append_attribute("xmlns") = "urn:ietf:params:xml:ns:xmpp-sasl"; auth.append_attribute("mechanism") = methodName.data(); auto nonce = GenerateNonce(); + SPDLOG_DEBUG("nonce: {}", nonce); auto initialMessage = std::format("n,,n={},r={}", account.jid.username, nonce); auto data = EncodeBase64(initialMessage); auth.text().set(data.c_str()); @@ -136,11 +174,14 @@ struct ClientCreateVisitor { doc = pugi::xml_document{}; auto success = doc.append_child("response"); success.append_attribute("xmlns") = "urn:ietf:params:xml:ns:xmpp-sasl"; - success.text().set( - EncodeBase64( - GenerateScramAuthMessage( - account.password, DecodeBase64(params["s"]), serverNonce, decoded, initialMessage, ToInt(params["i"]).value(), tag)) - .c_str()); + success.text().set(EncodeBase64(GenerateScramAuthMessage(account.password, + DecodeBase64(params["s"]), + serverNonce, + decoded, + std::string_view{initialMessage}.substr(3), + ToInt(params["i"]).value(), + tag)) + .c_str()); std::ostringstream strstream2; doc.print(strstream2, "", diff --git a/library/include/larra/encryption.hpp b/library/include/larra/encryption.hpp index 70e869d..567b01c 100644 --- a/library/include/larra/encryption.hpp +++ b/library/include/larra/encryption.hpp @@ -40,7 +40,9 @@ auto GenerateNonce(std::size_t length = 24) -> std::string; // NOLINT namespace sha512sum { -struct EncryptionTag {}; +struct EncryptionTag { + friend auto ToString(const EncryptionTag&) -> std::string; +}; auto Pbdkf2(std::string_view password, UnsignedStringView salt, int iterations, EncryptionTag = {}) -> UnsignedString; @@ -60,7 +62,9 @@ auto GenerateScramAuthMessage(std::string_view password, namespace sha256sum { -struct EncryptionTag {}; +struct EncryptionTag { + friend auto ToString(const EncryptionTag&) -> std::string; +}; auto Pbdkf2(std::string_view password, UnsignedStringView salt, int iterations, EncryptionTag = {}) -> UnsignedString; @@ -80,7 +84,9 @@ auto GenerateScramAuthMessage(std::string_view password, namespace sha1sum { -struct EncryptionTag {}; +struct EncryptionTag { + friend auto ToString(const EncryptionTag&) -> std::string; +}; auto Pbdkf2(std::string_view password, UnsignedStringView salt, int iterations, EncryptionTag = {}) -> UnsignedString; diff --git a/library/include/larra/printer_stream.hpp b/library/include/larra/printer_stream.hpp index 3a1bad4..e1eea57 100644 --- a/library/include/larra/printer_stream.hpp +++ b/library/include/larra/printer_stream.hpp @@ -1,4 +1,6 @@ #pragma once +#include + #include #include #include @@ -20,11 +22,11 @@ struct PrintStream : Socket { for(boost::asio::const_buffer buf : buffers) { stream << std::string_view{static_cast(buf.data()), buf.size()}; } - std::println("{}", stream.str()); + SPDLOG_INFO("{}", stream.str()); return boost::asio::async_initiate( [this](Handler&& token, const ConstBufferSequence& buffers) { Socket::async_write_some(buffers, [token = std::move(token)](boost::system::error_code err, std::size_t s) mutable { - std::println("Data writing completed"); + SPDLOG_INFO("Data writing completed"); token(err, s); }); }, @@ -35,7 +37,7 @@ struct PrintStream : Socket { BOOST_ASIO_COMPLETION_TOKEN_FOR(void(boost::system::error_code, std::size_t)) ReadToken = boost::asio::default_completion_token_t> auto async_read_some(const MutableBufferSequence& buffers, ReadToken&& token) { - std::println("Reading data from stream"); + SPDLOG_INFO("Reading data from stream"); return boost::asio::async_initiate( [this](ReadToken&& token, const MutableBufferSequence& buffers) { Socket::async_read_some(buffers, [buffers, token = std::move(token)](boost::system::error_code err, std::size_t s) mutable { @@ -45,7 +47,7 @@ struct PrintStream : Socket { for(boost::asio::mutable_buffer buf : buffers) { stream << std::string_view{static_cast(buf.data()), buf.size()}; } - std::println("{}", stream.str()); + SPDLOG_INFO("{}", stream.str()); token(err, s); }); }, @@ -82,12 +84,12 @@ struct boost::asio::ssl::stream> : public boost for(boost::asio::const_buffer buf : buffers) { stream << std::string_view{static_cast(buf.data()), buf.size()}; } - std::println("{}", stream.str()); + SPDLOG_INFO("{}", stream.str()); return boost::asio::async_initiate( [this](Handler&& token, const ConstBufferSequence& buffers) { Base::async_write_some(buffers, [token = std::move(token)](boost::system::error_code err, std::size_t s) mutable { - std::println("Data writing completed(SSL)"); - token(err, s); + SPDLOG_INFO("Data writing completed(SSL)"); + std::move(token)(err, s); }); }, token, @@ -97,7 +99,7 @@ struct boost::asio::ssl::stream> : public boost BOOST_ASIO_COMPLETION_TOKEN_FOR(void(boost::system::error_code, std::size_t)) ReadToken = boost::asio::default_completion_token_t> auto async_read_some(const MutableBufferSequence& buffers, ReadToken&& token) { - std::println("Reading data from stream(SSL)"); + SPDLOG_INFO("Reading data from stream(SSL)"); return boost::asio::async_initiate( [this](Handler&& token, const MutableBufferSequence& buffers) { Base::async_read_some(buffers, [buffers, token = std::move(token)](boost::system::error_code err, std::size_t s) mutable { @@ -107,8 +109,8 @@ struct boost::asio::ssl::stream> : public boost for(boost::asio::mutable_buffer buf : buffers) { stream << std::string_view{static_cast(buf.data()), buf.size()}; } - std::println("{}", stream.str()); - token(err, s); + SPDLOG_INFO("{}", stream.str()); + std::move(token)(err, s); }); }, token, @@ -121,8 +123,8 @@ struct boost::asio::ssl::stream> : public boost return boost::asio::async_initiate( [this](Handler&& token, Base::handshake_type type) { Base::async_handshake(type, [token = std::move(token)](boost::system::error_code error) mutable { - std::println("SSL Handshake completed"); - token(error); + SPDLOG_INFO("SSL Handshake completed"); + std::move(token)(error); }); }, token, diff --git a/library/src/encryption.cpp b/library/src/encryption.cpp index e01c239..5a7b2fd 100644 --- a/library/src/encryption.cpp +++ b/library/src/encryption.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include #include @@ -129,26 +130,44 @@ inline auto GenerateAuthScramMessageImpl(std::string_view password, std::string_view initialMessage, int iterations, TagType tag) -> std::string { + SPDLOG_TRACE( + "Call function GenerateScramAuthMessage with params: [password: ({}), salt: Base64({}), serverNonce: ({}), firstServerMessage: ({}), " + "initialMessage: ({}), iterations: ({}), tag: ({})]", + password, + EncodeBase64(salt), + serverNonce, + firstServerMessage, + initialMessage, + iterations, + ToString(tag)); auto clientFinalMessageBare = std::format("c=biws,r={}", serverNonce); - auto saltedPassword = Pbkdf2Impl(password, ToUnsignedCharStringView(salt), iterations, tag); + SPDLOG_DEBUG("clientFinalMessageBare: ({})", clientFinalMessageBare); + auto saltedPassword = Pbdkf2(password, ToUnsignedCharStringView(salt), iterations, tag); + SPDLOG_DEBUG("saltedPassword: Base64({})", EncodeBase64(ToCharStringView(saltedPassword))); std::string clientKeyStr = "Client Key"; // NOLINT - auto clientKey = HmacImpl(ToCharStringView(saltedPassword), ToUnsignedCharStringView(clientKeyStr), tag); - auto storedKey = HashImpl(clientKey, tag); + auto clientKey = Hmac(ToCharStringView(saltedPassword), ToUnsignedCharStringView(clientKeyStr), tag); + SPDLOG_DEBUG("clientKey: Base64({})", EncodeBase64(ToCharStringView(clientKey))); + auto storedKey = Hash(ToUnsignedCharStringView(clientKey), tag); + SPDLOG_DEBUG("storedKey: Base64({})", EncodeBase64(ToCharStringView(storedKey))); auto authMessage = std::format("{},{},{}", initialMessage, firstServerMessage, clientFinalMessageBare); - auto clientSignature = HmacImpl(ToCharStringView(storedKey), ToUnsignedCharStringView(authMessage), tag); - auto clientProof = std::views::zip(clientKey, clientSignature) | // No std::views::enumerate in libc++ - std::views::transform([&](auto arg) { + SPDLOG_DEBUG("authMessage: ({})", authMessage); + auto clientSignature = Hmac(ToCharStringView(storedKey), ToUnsignedCharStringView(authMessage), tag); + SPDLOG_DEBUG("clientSignature: Base64({})", EncodeBase64(ToCharStringView(clientSignature))); + auto clientProof = std::views::zip(clientKey, ToUnsignedCharStringView(clientSignature)) | std::views::transform([&](auto arg) { return std::get<0>(arg) ^ std::get<1>(arg); }) | std::ranges::to(); - std::string serverKeyStr = "Server Key"; - auto serverKey = HmacImpl(ToCharStringView(saltedPassword), ToUnsignedCharStringView(serverKeyStr), tag); - auto serverSignature = HmacImpl(ToCharStringView(serverKey), ToUnsignedCharStringView(authMessage), tag); - return std::format("{},p={}", clientFinalMessageBare, EncodeBase64(ToCharStringView(clientProof))); + auto clientProofBase64 = EncodeBase64(clientProof); + SPDLOG_DEBUG("clientProof: Base64({})", clientProofBase64); + return std::format("{},p={}", clientFinalMessageBare, clientProofBase64); } namespace sha512sum { +auto ToString(const EncryptionTag&) -> std::string { + return "sha512sum"; +} + auto Pbdkf2(std::string_view password, UnsignedStringView salt, int iterations, EncryptionTag tag) -> UnsignedString { return Pbkdf2Impl(password, salt, iterations, tag); } @@ -175,6 +194,10 @@ auto GenerateScramAuthMessage(std::string_view password, namespace sha256sum { +auto ToString(const EncryptionTag&) -> std::string { + return "sha256sum"; +} + auto Pbdkf2(std::string_view password, UnsignedStringView salt, int iterations, EncryptionTag tag) -> UnsignedString { return Pbkdf2Impl(password, salt, iterations, tag); } @@ -201,6 +224,10 @@ auto GenerateScramAuthMessage(std::string_view password, namespace sha1sum { +auto ToString(const EncryptionTag&) -> std::string { + return "sha1sum"; +} + auto Pbdkf2(std::string_view password, UnsignedStringView salt, int iterations, EncryptionTag tag) -> UnsignedString { return Pbkdf2Impl(password, salt, iterations, tag); } diff --git a/tests/scram.cpp b/tests/scram.cpp new file mode 100644 index 0000000..5bbe979 --- /dev/null +++ b/tests/scram.cpp @@ -0,0 +1,33 @@ +#include + +#include + +namespace larra::xmpp::client { + +constexpr std::string_view kUsername = "test1"; +constexpr std::string_view kPassword = "test1"; +constexpr std::string_view kClientNonce = "jzabxxqWYZEmuM9pRq7JR4VQ"; +constexpr std::string_view kServerNonce = "jzabxxqWYZEmuM9pRq7JR4VQ65tCij07kpOM2/+obhuYEQ=="; +constexpr std::string_view kServerFirst = + "r=jzabxxqWYZEmuM9pRq7JR4VQ65tCij07kpOM2/+obhuYEQ==," + "s=cTJBDYBtebLYSyDKwIut5w==,i=4096"; +const std::string kSalt = DecodeBase64("cTJBDYBtebLYSyDKwIut5w=="); +constexpr std::string_view kClientFirst = "n=test1,r=jzabxxqWYZEmuM9pRq7JR4VQ"; +constexpr int kIterations = 4096; + +constexpr std::string_view kExpectedDataNoBase64 = + "c=biws,r=jzabxxqWYZEmuM9pRq7JR4VQ65tCij07kpOM2/+obhuYEQ==,p=Xpco7kbX/I0OQ7ubCScmCdG1Nml2QBIJw4dp2jdJl9bNq2Uny43CQ88zrCvfBnJuLdXM/" + "kw8VBzb6oy6BkRdog=="; + +constexpr sha512sum::EncryptionTag kTag{}; + +auto ToUnsignedCharStringView(auto& range) -> UnsignedStringView { + return {new(range.data()) unsigned char[range.size()], range.size()}; +}; + +TEST(SCRAM, SHA512) { + auto data = GenerateScramAuthMessage(kUsername, kSalt, kServerNonce, kServerFirst, kClientFirst, kIterations, kTag); + EXPECT_EQ(data, kExpectedDataNoBase64); +} + +} // namespace larra::xmpp::client