Fix scram auth and add tests

This commit is contained in:
sha512sum 2024-09-26 18:11:30 +00:00
parent 0a38c55c19
commit 77c1b81e95
6 changed files with 149 additions and 36 deletions

View file

@ -1,3 +1,5 @@
#include <spdlog/spdlog.h>
#include <boost/asio/co_spawn.hpp> #include <boost/asio/co_spawn.hpp>
#include <boost/asio/detached.hpp> #include <boost/asio/detached.hpp>
#include <larra/client/client.hpp> #include <larra/client/client.hpp>
@ -5,18 +7,20 @@
#include <print> #include <print>
auto Coroutine() -> boost::asio::awaitable<void> { auto Coroutine() -> boost::asio::awaitable<void> {
std::println("Connecting client..."); SPDLOG_INFO("Connecting client...");
try { try {
auto client = co_await larra::xmpp::client::CreateClient<larra::xmpp::PrintStream<boost::asio::ip::tcp::socket>>( auto client = co_await larra::xmpp::client::CreateClient<larra::xmpp::PrintStream<boost::asio::ip::tcp::socket>>(
larra::xmpp::EncryptionUserAccount{{"sha512sum", "localhost"}, "12345"}, {.useTls = larra::xmpp::client::Options::kNever}); larra::xmpp::EncryptionUserAccount{{"test1", "localhost"}, "test1"}, {.useTls = larra::xmpp::client::Options::kNever});
} catch(const std::exception& err) { } catch(const std::exception& err) {
std::println("Err: {}", err.what()); SPDLOG_ERROR("{}", err.what());
co_return; co_return;
} }
std::println("Done!"); SPDLOG_INFO("Done connecting client!");
} }
auto main() -> int { auto main() -> int {
spdlog::set_level(spdlog::level::trace);
boost::asio::io_context io_context; boost::asio::io_context io_context;
boost::asio::co_spawn(io_context, Coroutine(), boost::asio::detached); boost::asio::co_spawn(io_context, Coroutine(), boost::asio::detached);
io_context.run(); io_context.run();

View file

@ -1,6 +1,10 @@
#pragma once #pragma once
#include <spdlog/spdlog.h>
#include <boost/asio/awaitable.hpp> #include <boost/asio/awaitable.hpp>
#include <boost/asio/co_spawn.hpp>
#include <boost/asio/connect.hpp> #include <boost/asio/connect.hpp>
#include <boost/asio/detached.hpp>
#include <boost/asio/ip/tcp.hpp> #include <boost/asio/ip/tcp.hpp>
#include <boost/asio/read.hpp> #include <boost/asio/read.hpp>
#include <boost/asio/read_until.hpp> #include <boost/asio/read_until.hpp>
@ -23,11 +27,44 @@ namespace larra::xmpp::client {
template <typename Connection> template <typename Connection>
struct Client { struct Client {
Client(BareJid jid, Connection connection) : jid(std::move(jid)), connection(std::move(connection)) {}; constexpr Client(BareJid jid, Connection connection) : jid(std::move(jid)), connection(std::move(connection)) {};
template <boost::asio::completion_token_for<void()> Token = boost::asio::use_awaitable_t<>>
constexpr auto Close(Token token = {}) {
this->active = false;
return boost::asio::async_initiate<Token, void()>(
[]<typename Handler>(Handler&& h, Connection connection) {
boost::asio::co_spawn(
connection.get_executor(),
[](auto h, Connection connection) -> boost::asio::awaitable<void> {
co_await boost::asio::async_write(connection, boost::asio::buffer("</stream:stream>"), boost::asio::use_awaitable);
std::string response;
co_await boost::asio::async_read_until(
connection, boost::asio::dynamic_buffer(response), "</stream:stream>", boost::asio::use_awaitable);
std::move(h)();
}(std::move(h), std::move(connection)),
boost::asio::detached);
},
token,
std::move(this->connection));
}
constexpr Client(const Client&) = delete;
constexpr Client(Client&& client) : connection(std::move(client.connection)), jid(std::move(client.jid)) {
client.active = false;
}
constexpr ~Client() {
if(this->active) {
this->Close([] {});
}
}
[[nodiscard]] constexpr auto Jid() const -> const BareJid& {
return this->jid;
}
const BareJid jid; // NOLINT: const
private: private:
bool active = true;
Connection connection; Connection connection;
BareJid jid;
}; };
struct StartTlsNegotiationError : std::runtime_error { struct StartTlsNegotiationError : std::runtime_error {
@ -113,6 +150,7 @@ struct ClientCreateVisitor {
auth.append_attribute("xmlns") = "urn:ietf:params:xml:ns:xmpp-sasl"; auth.append_attribute("xmlns") = "urn:ietf:params:xml:ns:xmpp-sasl";
auth.append_attribute("mechanism") = methodName.data(); auth.append_attribute("mechanism") = methodName.data();
auto nonce = GenerateNonce(); auto nonce = GenerateNonce();
SPDLOG_DEBUG("nonce: {}", nonce);
auto initialMessage = std::format("n,,n={},r={}", account.jid.username, nonce); auto initialMessage = std::format("n,,n={},r={}", account.jid.username, nonce);
auto data = EncodeBase64(initialMessage); auto data = EncodeBase64(initialMessage);
auth.text().set(data.c_str()); auth.text().set(data.c_str());
@ -136,10 +174,13 @@ struct ClientCreateVisitor {
doc = pugi::xml_document{}; doc = pugi::xml_document{};
auto success = doc.append_child("response"); auto success = doc.append_child("response");
success.append_attribute("xmlns") = "urn:ietf:params:xml:ns:xmpp-sasl"; success.append_attribute("xmlns") = "urn:ietf:params:xml:ns:xmpp-sasl";
success.text().set( success.text().set(EncodeBase64(GenerateScramAuthMessage(account.password,
EncodeBase64( DecodeBase64(params["s"]),
GenerateScramAuthMessage( serverNonce,
account.password, DecodeBase64(params["s"]), serverNonce, decoded, initialMessage, ToInt(params["i"]).value(), tag)) decoded,
std::string_view{initialMessage}.substr(3),
ToInt(params["i"]).value(),
tag))
.c_str()); .c_str());
std::ostringstream strstream2; std::ostringstream strstream2;
doc.print(strstream2, doc.print(strstream2,

View file

@ -40,7 +40,9 @@ auto GenerateNonce(std::size_t length = 24) -> std::string; // NOLINT
namespace sha512sum { namespace sha512sum {
struct EncryptionTag {}; struct EncryptionTag {
friend auto ToString(const EncryptionTag&) -> std::string;
};
auto Pbdkf2(std::string_view password, UnsignedStringView salt, int iterations, EncryptionTag = {}) -> UnsignedString; auto Pbdkf2(std::string_view password, UnsignedStringView salt, int iterations, EncryptionTag = {}) -> UnsignedString;
@ -60,7 +62,9 @@ auto GenerateScramAuthMessage(std::string_view password,
namespace sha256sum { namespace sha256sum {
struct EncryptionTag {}; struct EncryptionTag {
friend auto ToString(const EncryptionTag&) -> std::string;
};
auto Pbdkf2(std::string_view password, UnsignedStringView salt, int iterations, EncryptionTag = {}) -> UnsignedString; auto Pbdkf2(std::string_view password, UnsignedStringView salt, int iterations, EncryptionTag = {}) -> UnsignedString;
@ -80,7 +84,9 @@ auto GenerateScramAuthMessage(std::string_view password,
namespace sha1sum { namespace sha1sum {
struct EncryptionTag {}; struct EncryptionTag {
friend auto ToString(const EncryptionTag&) -> std::string;
};
auto Pbdkf2(std::string_view password, UnsignedStringView salt, int iterations, EncryptionTag = {}) -> UnsignedString; auto Pbdkf2(std::string_view password, UnsignedStringView salt, int iterations, EncryptionTag = {}) -> UnsignedString;

View file

@ -1,4 +1,6 @@
#pragma once #pragma once
#include <spdlog/spdlog.h>
#include <boost/asio/ssl.hpp> #include <boost/asio/ssl.hpp>
#include <boost/asio/write.hpp> #include <boost/asio/write.hpp>
#include <print> #include <print>
@ -20,11 +22,11 @@ struct PrintStream : Socket {
for(boost::asio::const_buffer buf : buffers) { for(boost::asio::const_buffer buf : buffers) {
stream << std::string_view{static_cast<const char*>(buf.data()), buf.size()}; stream << std::string_view{static_cast<const char*>(buf.data()), buf.size()};
} }
std::println("{}", stream.str()); SPDLOG_INFO("{}", stream.str());
return boost::asio::async_initiate<WriteToken, void(boost::system::error_code, std::size_t)>( return boost::asio::async_initiate<WriteToken, void(boost::system::error_code, std::size_t)>(
[this]<typename Handler>(Handler&& token, const ConstBufferSequence& buffers) { [this]<typename Handler>(Handler&& token, const ConstBufferSequence& buffers) {
Socket::async_write_some(buffers, [token = std::move(token)](boost::system::error_code err, std::size_t s) mutable { Socket::async_write_some(buffers, [token = std::move(token)](boost::system::error_code err, std::size_t s) mutable {
std::println("Data writing completed"); SPDLOG_INFO("Data writing completed");
token(err, s); token(err, s);
}); });
}, },
@ -35,7 +37,7 @@ struct PrintStream : Socket {
BOOST_ASIO_COMPLETION_TOKEN_FOR(void(boost::system::error_code, std::size_t)) BOOST_ASIO_COMPLETION_TOKEN_FOR(void(boost::system::error_code, std::size_t))
ReadToken = boost::asio::default_completion_token_t<Executor>> ReadToken = boost::asio::default_completion_token_t<Executor>>
auto async_read_some(const MutableBufferSequence& buffers, ReadToken&& token) { auto async_read_some(const MutableBufferSequence& buffers, ReadToken&& token) {
std::println("Reading data from stream"); SPDLOG_INFO("Reading data from stream");
return boost::asio::async_initiate<ReadToken, void(boost::system::error_code, std::size_t)>( return boost::asio::async_initiate<ReadToken, void(boost::system::error_code, std::size_t)>(
[this](ReadToken&& token, const MutableBufferSequence& buffers) { [this](ReadToken&& token, const MutableBufferSequence& buffers) {
Socket::async_read_some(buffers, [buffers, token = std::move(token)](boost::system::error_code err, std::size_t s) mutable { Socket::async_read_some(buffers, [buffers, token = std::move(token)](boost::system::error_code err, std::size_t s) mutable {
@ -45,7 +47,7 @@ struct PrintStream : Socket {
for(boost::asio::mutable_buffer buf : buffers) { for(boost::asio::mutable_buffer buf : buffers) {
stream << std::string_view{static_cast<const char*>(buf.data()), buf.size()}; stream << std::string_view{static_cast<const char*>(buf.data()), buf.size()};
} }
std::println("{}", stream.str()); SPDLOG_INFO("{}", stream.str());
token(err, s); token(err, s);
}); });
}, },
@ -82,12 +84,12 @@ struct boost::asio::ssl::stream<larra::xmpp::PrintStream<Socket>> : public boost
for(boost::asio::const_buffer buf : buffers) { for(boost::asio::const_buffer buf : buffers) {
stream << std::string_view{static_cast<const char*>(buf.data()), buf.size()}; stream << std::string_view{static_cast<const char*>(buf.data()), buf.size()};
} }
std::println("{}", stream.str()); SPDLOG_INFO("{}", stream.str());
return boost::asio::async_initiate<WriteToken, void(boost::system::error_code, std::size_t)>( return boost::asio::async_initiate<WriteToken, void(boost::system::error_code, std::size_t)>(
[this]<typename Handler>(Handler&& token, const ConstBufferSequence& buffers) { [this]<typename Handler>(Handler&& token, const ConstBufferSequence& buffers) {
Base::async_write_some(buffers, [token = std::move(token)](boost::system::error_code err, std::size_t s) mutable { Base::async_write_some(buffers, [token = std::move(token)](boost::system::error_code err, std::size_t s) mutable {
std::println("Data writing completed(SSL)"); SPDLOG_INFO("Data writing completed(SSL)");
token(err, s); std::move(token)(err, s);
}); });
}, },
token, token,
@ -97,7 +99,7 @@ struct boost::asio::ssl::stream<larra::xmpp::PrintStream<Socket>> : public boost
BOOST_ASIO_COMPLETION_TOKEN_FOR(void(boost::system::error_code, std::size_t)) BOOST_ASIO_COMPLETION_TOKEN_FOR(void(boost::system::error_code, std::size_t))
ReadToken = boost::asio::default_completion_token_t<Executor>> ReadToken = boost::asio::default_completion_token_t<Executor>>
auto async_read_some(const MutableBufferSequence& buffers, ReadToken&& token) { auto async_read_some(const MutableBufferSequence& buffers, ReadToken&& token) {
std::println("Reading data from stream(SSL)"); SPDLOG_INFO("Reading data from stream(SSL)");
return boost::asio::async_initiate<ReadToken, void(boost::system::error_code, std::size_t)>( return boost::asio::async_initiate<ReadToken, void(boost::system::error_code, std::size_t)>(
[this]<typename Handler>(Handler&& token, const MutableBufferSequence& buffers) { [this]<typename Handler>(Handler&& token, const MutableBufferSequence& buffers) {
Base::async_read_some(buffers, [buffers, token = std::move(token)](boost::system::error_code err, std::size_t s) mutable { Base::async_read_some(buffers, [buffers, token = std::move(token)](boost::system::error_code err, std::size_t s) mutable {
@ -107,8 +109,8 @@ struct boost::asio::ssl::stream<larra::xmpp::PrintStream<Socket>> : public boost
for(boost::asio::mutable_buffer buf : buffers) { for(boost::asio::mutable_buffer buf : buffers) {
stream << std::string_view{static_cast<const char*>(buf.data()), buf.size()}; stream << std::string_view{static_cast<const char*>(buf.data()), buf.size()};
} }
std::println("{}", stream.str()); SPDLOG_INFO("{}", stream.str());
token(err, s); std::move(token)(err, s);
}); });
}, },
token, token,
@ -121,8 +123,8 @@ struct boost::asio::ssl::stream<larra::xmpp::PrintStream<Socket>> : public boost
return boost::asio::async_initiate<HandshakeToken, void(boost::system::error_code)>( return boost::asio::async_initiate<HandshakeToken, void(boost::system::error_code)>(
[this]<typename Handler>(Handler&& token, Base::handshake_type type) { [this]<typename Handler>(Handler&& token, Base::handshake_type type) {
Base::async_handshake(type, [token = std::move(token)](boost::system::error_code error) mutable { Base::async_handshake(type, [token = std::move(token)](boost::system::error_code error) mutable {
std::println("SSL Handshake completed"); SPDLOG_INFO("SSL Handshake completed");
token(error); std::move(token)(error);
}); });
}, },
token, token,

View file

@ -2,6 +2,7 @@
#include <openssl/hmac.h> #include <openssl/hmac.h>
#include <openssl/rand.h> #include <openssl/rand.h>
#include <openssl/sha.h> #include <openssl/sha.h>
#include <spdlog/spdlog.h>
#include <boost/algorithm/string.hpp> #include <boost/algorithm/string.hpp>
#include <boost/archive/iterators/base64_from_binary.hpp> #include <boost/archive/iterators/base64_from_binary.hpp>
@ -129,26 +130,44 @@ inline auto GenerateAuthScramMessageImpl(std::string_view password,
std::string_view initialMessage, std::string_view initialMessage,
int iterations, int iterations,
TagType tag) -> std::string { TagType tag) -> std::string {
SPDLOG_TRACE(
"Call function GenerateScramAuthMessage with params: [password: ({}), salt: Base64({}), serverNonce: ({}), firstServerMessage: ({}), "
"initialMessage: ({}), iterations: ({}), tag: ({})]",
password,
EncodeBase64(salt),
serverNonce,
firstServerMessage,
initialMessage,
iterations,
ToString(tag));
auto clientFinalMessageBare = std::format("c=biws,r={}", serverNonce); auto clientFinalMessageBare = std::format("c=biws,r={}", serverNonce);
auto saltedPassword = Pbkdf2Impl(password, ToUnsignedCharStringView(salt), iterations, tag); SPDLOG_DEBUG("clientFinalMessageBare: ({})", clientFinalMessageBare);
auto saltedPassword = Pbdkf2(password, ToUnsignedCharStringView(salt), iterations, tag);
SPDLOG_DEBUG("saltedPassword: Base64({})", EncodeBase64(ToCharStringView(saltedPassword)));
std::string clientKeyStr = "Client Key"; // NOLINT std::string clientKeyStr = "Client Key"; // NOLINT
auto clientKey = HmacImpl(ToCharStringView(saltedPassword), ToUnsignedCharStringView(clientKeyStr), tag); auto clientKey = Hmac(ToCharStringView(saltedPassword), ToUnsignedCharStringView(clientKeyStr), tag);
auto storedKey = HashImpl(clientKey, tag); SPDLOG_DEBUG("clientKey: Base64({})", EncodeBase64(ToCharStringView(clientKey)));
auto storedKey = Hash(ToUnsignedCharStringView(clientKey), tag);
SPDLOG_DEBUG("storedKey: Base64({})", EncodeBase64(ToCharStringView(storedKey)));
auto authMessage = std::format("{},{},{}", initialMessage, firstServerMessage, clientFinalMessageBare); auto authMessage = std::format("{},{},{}", initialMessage, firstServerMessage, clientFinalMessageBare);
auto clientSignature = HmacImpl(ToCharStringView(storedKey), ToUnsignedCharStringView(authMessage), tag); SPDLOG_DEBUG("authMessage: ({})", authMessage);
auto clientProof = std::views::zip(clientKey, clientSignature) | // No std::views::enumerate in libc++ auto clientSignature = Hmac(ToCharStringView(storedKey), ToUnsignedCharStringView(authMessage), tag);
std::views::transform([&](auto arg) { SPDLOG_DEBUG("clientSignature: Base64({})", EncodeBase64(ToCharStringView(clientSignature)));
auto clientProof = std::views::zip(clientKey, ToUnsignedCharStringView(clientSignature)) | std::views::transform([&](auto arg) {
return std::get<0>(arg) ^ std::get<1>(arg); return std::get<0>(arg) ^ std::get<1>(arg);
}) | }) |
std::ranges::to<std::string>(); std::ranges::to<std::string>();
std::string serverKeyStr = "Server Key"; auto clientProofBase64 = EncodeBase64(clientProof);
auto serverKey = HmacImpl(ToCharStringView(saltedPassword), ToUnsignedCharStringView(serverKeyStr), tag); SPDLOG_DEBUG("clientProof: Base64({})", clientProofBase64);
auto serverSignature = HmacImpl(ToCharStringView(serverKey), ToUnsignedCharStringView(authMessage), tag); return std::format("{},p={}", clientFinalMessageBare, clientProofBase64);
return std::format("{},p={}", clientFinalMessageBare, EncodeBase64(ToCharStringView(clientProof)));
} }
namespace sha512sum { namespace sha512sum {
auto ToString(const EncryptionTag&) -> std::string {
return "sha512sum";
}
auto Pbdkf2(std::string_view password, UnsignedStringView salt, int iterations, EncryptionTag tag) -> UnsignedString { auto Pbdkf2(std::string_view password, UnsignedStringView salt, int iterations, EncryptionTag tag) -> UnsignedString {
return Pbkdf2Impl(password, salt, iterations, tag); return Pbkdf2Impl(password, salt, iterations, tag);
} }
@ -175,6 +194,10 @@ auto GenerateScramAuthMessage(std::string_view password,
namespace sha256sum { namespace sha256sum {
auto ToString(const EncryptionTag&) -> std::string {
return "sha256sum";
}
auto Pbdkf2(std::string_view password, UnsignedStringView salt, int iterations, EncryptionTag tag) -> UnsignedString { auto Pbdkf2(std::string_view password, UnsignedStringView salt, int iterations, EncryptionTag tag) -> UnsignedString {
return Pbkdf2Impl(password, salt, iterations, tag); return Pbkdf2Impl(password, salt, iterations, tag);
} }
@ -201,6 +224,10 @@ auto GenerateScramAuthMessage(std::string_view password,
namespace sha1sum { namespace sha1sum {
auto ToString(const EncryptionTag&) -> std::string {
return "sha1sum";
}
auto Pbdkf2(std::string_view password, UnsignedStringView salt, int iterations, EncryptionTag tag) -> UnsignedString { auto Pbdkf2(std::string_view password, UnsignedStringView salt, int iterations, EncryptionTag tag) -> UnsignedString {
return Pbkdf2Impl(password, salt, iterations, tag); return Pbkdf2Impl(password, salt, iterations, tag);
} }

33
tests/scram.cpp Normal file
View file

@ -0,0 +1,33 @@
#include <gtest/gtest.h>
#include <larra/client/client.hpp>
namespace larra::xmpp::client {
constexpr std::string_view kUsername = "test1";
constexpr std::string_view kPassword = "test1";
constexpr std::string_view kClientNonce = "jzabxxqWYZEmuM9pRq7JR4VQ";
constexpr std::string_view kServerNonce = "jzabxxqWYZEmuM9pRq7JR4VQ65tCij07kpOM2/+obhuYEQ==";
constexpr std::string_view kServerFirst =
"r=jzabxxqWYZEmuM9pRq7JR4VQ65tCij07kpOM2/+obhuYEQ==,"
"s=cTJBDYBtebLYSyDKwIut5w==,i=4096";
const std::string kSalt = DecodeBase64("cTJBDYBtebLYSyDKwIut5w==");
constexpr std::string_view kClientFirst = "n=test1,r=jzabxxqWYZEmuM9pRq7JR4VQ";
constexpr int kIterations = 4096;
constexpr std::string_view kExpectedDataNoBase64 =
"c=biws,r=jzabxxqWYZEmuM9pRq7JR4VQ65tCij07kpOM2/+obhuYEQ==,p=Xpco7kbX/I0OQ7ubCScmCdG1Nml2QBIJw4dp2jdJl9bNq2Uny43CQ88zrCvfBnJuLdXM/"
"kw8VBzb6oy6BkRdog==";
constexpr sha512sum::EncryptionTag kTag{};
auto ToUnsignedCharStringView(auto& range) -> UnsignedStringView {
return {new(range.data()) unsigned char[range.size()], range.size()};
};
TEST(SCRAM, SHA512) {
auto data = GenerateScramAuthMessage(kUsername, kSalt, kServerNonce, kServerFirst, kClientFirst, kIterations, kTag);
EXPECT_EQ(data, kExpectedDataNoBase64);
}
} // namespace larra::xmpp::client