All checks were successful
PR Check / on-push-commit-check (push) Successful in 11m21s
390 lines
18 KiB
C++
390 lines
18 KiB
C++
#pragma once
|
|
#include <spdlog/spdlog.h>
|
|
|
|
#include <boost/asio/awaitable.hpp>
|
|
#include <boost/asio/co_spawn.hpp>
|
|
#include <boost/asio/connect.hpp>
|
|
#include <boost/asio/detached.hpp>
|
|
#include <boost/asio/ip/tcp.hpp>
|
|
#include <boost/asio/read.hpp>
|
|
#include <boost/asio/read_until.hpp>
|
|
#include <boost/asio/ssl.hpp>
|
|
#include <boost/asio/use_awaitable.hpp>
|
|
#include <charconv>
|
|
#include <larra/bind.hpp>
|
|
#include <larra/client/challenge_response.hpp>
|
|
#include <larra/client/options.hpp>
|
|
#include <larra/client/starttls_response.hpp>
|
|
#include <larra/client/xmpp_client_stream_features.hpp>
|
|
#include <larra/encryption.hpp>
|
|
#include <larra/features.hpp>
|
|
#include <larra/roster.hpp>
|
|
#include <larra/stream.hpp>
|
|
#include <larra/user_account.hpp>
|
|
#include <larra/xml_stream.hpp>
|
|
#include <ranges>
|
|
#include <utility>
|
|
|
|
namespace rng = std::ranges;
|
|
namespace views = std::views;
|
|
namespace iq = larra::xmpp::iq;
|
|
|
|
namespace larra::xmpp {

/// Port used for the connection when `Options::port` is not set.
/// Declared `inline` so that every translation unit including this header refers to the same object.
inline constexpr auto kDefaultXmppPort = 5222;

}  // namespace larra::xmpp
|
|
|
|
namespace larra::xmpp::client {
|
|
|
|
/// XMPP client session running over an already negotiated `XmlStream`.
///
/// Holds the stream, the client's JID and the most recently received roster. A client that is still
/// `active` when it is destroyed starts an asynchronous stream close from its destructor.
template <typename Connection>
struct Client {
  /// Takes ownership of `connection`; the stored JID is initialised from `jid`.
  constexpr Client(BareJid jid, XmlStream<Connection> connection) : connection(std::move(connection)), jid(std::move(jid)) {
  }

  /// Marks the client inactive and asynchronously closes the stream.
  ///
  /// Writes the closing `</stream:stream>` tag, reads until the peer's closing tag arrives and then
  /// invokes the completion handler. The connection is moved into the spawned coroutine, so this
  /// object must not be used for I/O afterwards. The coroutine is spawned detached; the handler is
  /// only invoked when both the write and the read complete without throwing.
  template <boost::asio::completion_token_for<void()> Token = boost::asio::use_awaitable_t<>>
  constexpr auto Close(Token token = {}) {
    this->active = false;
    return boost::asio::async_initiate<Token, void()>(
        []<typename Handler>(Handler&& h, XmlStream<Connection> connection) {  // NOLINT
          // Fetch the executor first: the evaluation order of co_spawn's arguments is unspecified
          // and the coroutine invocation below moves `connection` away.
          auto executor = connection.next_layer().get_executor();
          boost::asio::co_spawn(
              executor,
              [](auto h, XmlStream<Connection> connection) -> boost::asio::awaitable<void> {
                // A string_view is used on purpose: boost::asio::buffer() over a string literal
                // (a char array) would also transmit the terminating '\0' after the closing tag.
                constexpr std::string_view kStreamCloseTag = "</stream:stream>";
                co_await boost::asio::async_write(connection.next_layer(), boost::asio::buffer(kStreamCloseTag), boost::asio::use_awaitable);
                std::string response;
                co_await boost::asio::async_read_until(
                    connection.next_layer(), boost::asio::dynamic_buffer(response), kStreamCloseTag, boost::asio::use_awaitable);
                h();
              }(std::move(h), std::move(connection)),
              boost::asio::detached);
        },
        token,
        std::move(this->connection));
  }

  constexpr Client(const Client&) = delete;

  /// Moves the whole session state.
  ///
  /// The `active` flag travels with the stream, so a client moved from an already closed instance
  /// does not try to close the moved-out connection again; the cached roster is carried over too.
  constexpr Client(Client&& client) noexcept
      : active(std::exchange(client.active, false)),
        connection(std::move(client.connection)),
        jid(std::move(client.jid)),
        roster(std::move(client.roster)) {
  }

  /// Starts a fire-and-forget `Close` if the client is still active.
  constexpr ~Client() {
    if(this->active) {
      this->Close([] {});
    }
  }

  /// Forwards `object` to `XmlStream::Send` of the owned connection.
  template <typename T>
  auto Send(const T& object) -> boost::asio::awaitable<void> {
    co_await this->connection.Send(object);
  }

  /// Returns the JID stored in this client.
  [[nodiscard]] constexpr auto Jid() const -> const FullJid& {
    return this->jid;
  }

  /// Sends an IQ `SetBind` request (id "1") and stores the resource from the result in the JID.
  ///
  /// On any non-result alternative of the response a string literal (`const char*`) is thrown.
  auto CreateResourceBind() -> boost::asio::awaitable<void> {
    SPDLOG_INFO("Send IQ: Set::Bind");
    co_await this->Send(::iq::SetBind{.id = "1", .payload = {}});

    auto set_bind_response = co_await connection.template Read<Iq<::iq::Bind>>();
    std::visit(utempl::Overloaded(
                   [](auto error) {
                     throw "Error response on IQ: Set::Bind: ''";  // TODO(unknown): Add exact error parsing
                   },
                   [&](::iq::ResultBind r) {
                     // NOTE(review): assumes the result always carries a JID — confirm `payload.jid` cannot be empty.
                     jid.resource = std::move(r.payload.jid->resource);
                     SPDLOG_INFO("Allocated resource: {}", jid.resource);
                   }),
               set_bind_response);
    co_return;
  }

  /// Sends an IQ `GetRoster` request (id "1") and replaces the cached roster with the result payload.
  ///
  /// On any non-result alternative of the response a string literal (`const char*`) is thrown.
  auto UpdateListOfContacts() -> boost::asio::awaitable<void> {
    SPDLOG_INFO("Send IQ: Get::Roster");
    // NOTE(review): the sender is hard-coded to "invalid_user@<server>" rather than the client's
    // own JID — confirm this is intentional.
    co_await this->Send(::iq::GetRoster{.id = "1", .from = std::format("{}@{}", "invalid_user", jid.server), .payload = {}});

    const auto get_roster_response = co_await connection.template Read<Iq<::iq::Roster>>();
    std::visit(utempl::Overloaded(
                   [](auto error) {
                     throw "Error response on IQ: Get::Roster: ''";  // TODO(unknown): Add exact error parsing
                   },
                   [&](::iq::ResultRoster r) {
                     roster = std::move(r.payload);
                     SPDLOG_INFO("New roster: {}", ToString(roster));
                   }),
               get_roster_response);
    co_return;
  }

 private:
  bool active = true;  ///< False once `Close` was called or the object was moved from.
  XmlStream<Connection> connection{};
  FullJid jid;
  ::iq::Roster roster;  ///< Payload of the last successful roster request.
};
|
|
|
|
/// Error raised when the TLS handshake that follows a successful STARTTLS exchange fails.
struct StartTlsNegotiationError : std::runtime_error {
  /// @param error description of the underlying failure; it is appended to the fixed
  ///        "STARTTLS negotiation error: " prefix to form `what()`.
  explicit StartTlsNegotiationError(std::string_view error)
      : std::runtime_error(std::string{"STARTTLS negotiation error: "}.append(error)) {
  }
};
|
|
|
|
/// Error raised when the server advertises STARTTLS as required while the stream is being
/// negotiated without TLS.
struct ServerRequiresStartTls : std::exception {
  /// Returns a fixed description with static storage duration.
  [[nodiscard]] auto what() const noexcept -> const char* override {
    static constexpr const char* kMessage = "XMPP Server requires STARTTLS";
    return kMessage;
  }
};
|
|
|
|
namespace impl {
|
|
|
|
template <std::ranges::range Range, typename... Args>
|
|
auto Contains(Range&& range, Args&&... values) { // NOLINT
|
|
for(auto& value : range) {
|
|
if(((value == std::forward<Args>(values)) || ...)) { // NOLINT
|
|
return true;
|
|
}
|
|
}
|
|
return false;
|
|
}
|
|
/// Parses the integer at the start of `input` using `std::from_chars`.
///
/// @tparam T integer type to produce.
/// @return the parsed value, or `std::nullopt` when `std::from_chars` reports an error (no number at
///         the start of `input`, or a value that does not fit into `T`). Characters following the
///         parsed number are not treated as an error.
template <typename T = int>
inline auto ToInt(std::string_view input) -> std::optional<T> {
  T parsed{};
  const char* const first = input.data();
  const std::errc status = std::from_chars(first, first + input.size(), parsed).ec;
  if(status != std::errc{}) {
    return std::nullopt;
  }
  return std::optional{parsed};
}
|
|
|
|
/// Parsed content of a SASL `<challenge/>` element as used by the SCRAM authentication flow.
struct Challenge {
  std::string body;  ///< Base64-decoded text of the element.
  /// Value of the `r` attribute; a view into `body`, not an owning string.
  /// NOTE(review): copying or moving a `Challenge` may relocate `body`'s characters (short strings
  /// are commonly stored inline) — confirm every user keeps this view valid.
  std::string_view serverNonce;
  std::string salt;  ///< Base64-decoded value of the `s` attribute.
  int iterations;    ///< Value of the `i` attribute.

  /// Builds a `Challenge` from `node`.
  ///
  /// The decoded text is split on ',' into `key=value` attributes; `r`, `s` and `i` are required.
  ///
  /// @throws std::runtime_error if `node` is not named "challenge" or has no text content.
  /// @throws std::out_of_range if one of the `r`, `s`, `i` attributes is missing.
  /// @throws std::bad_optional_access if the `i` attribute is not an integer.
  [[nodiscard]] inline static auto Parse(const xmlpp::Element* node) -> Challenge {
    if(node->get_name() != "challenge") {
      throw std::runtime_error(std::format("Invalid name {} for challenge", node->get_name()));
    }
    // An element without character data (e.g. `<challenge/>`) has no text child; report it as an
    // error instead of dereferencing a null pointer.
    const auto* text = node->get_first_child_text();
    if(text == nullptr) {
      throw std::runtime_error("Challenge element has no text content");
    }
    std::string decoded = DecodeBase64(text->get_content());
    auto params = std::views::split(decoded, ',')                   //
                  | std::views::transform([](auto param) {          //
                      return std::string_view{param};               //
                    })                                              //
                  | std::views::transform([](std::string_view param) -> std::pair<std::string_view, std::string_view> {
                      const auto separator = param.find('=');
                      return {param.substr(0, separator), param.substr(separator + 1)};
                    })  //
                  | std::ranges::to<std::unordered_map<std::string_view, std::string_view>>();

    // Everything in `params` points into `decoded`, so all attributes are read before `decoded` is
    // moved into the result; the nonce is remembered as an offset and re-created over the final `body`.
    const std::string_view nonce = params.at("r");
    std::string decodedSalt = DecodeBase64(params.at("s"));
    const int iterationCount = ToInt(params.at("i")).value();
    const auto nonceOffset = static_cast<std::size_t>(nonce.data() - decoded.data());

    Challenge challenge{.body = std::move(decoded), .serverNonce = {}, .salt = std::move(decodedSalt), .iterations = iterationCount};
    challenge.serverNonce = std::string_view{challenge.body}.substr(nonceOffset, nonce.size());
    return challenge;
  }
};
|
|
|
|
template <typename Tag>
|
|
struct ChallengeResponse {
|
|
static constexpr auto kDefaultName = "response";
|
|
static constexpr auto kDefaultNamespace = "urn:ietf:params:xml:ns:xmpp-sasl";
|
|
std::string_view password;
|
|
std::string& salt;
|
|
std::string_view serverNonce;
|
|
std::string_view firstServerMessage;
|
|
std::string_view initialMessage;
|
|
int iterations{};
|
|
Tag tag;
|
|
friend constexpr auto operator<<(xmlpp::Element* element, const ChallengeResponse& self) {
|
|
auto text = EncodeBase64(GenerateScramAuthMessage(
|
|
self.password, std::move(self.salt), self.serverNonce, self.firstServerMessage, self.initialMessage, self.iterations, self.tag));
|
|
element->add_child_text(text);
|
|
}
|
|
};
|
|
|
|
struct StartTlsRequest {
|
|
static constexpr auto kDefaultName = "starttls";
|
|
static constexpr auto kDefaultNamespace = "urn:ietf:params:xml:ns:xmpp-tls";
|
|
friend constexpr auto operator<<(xmlpp::Element*, const StartTlsRequest&) {
|
|
}
|
|
};
|
|
|
|
/// Visitor that turns a not-yet-connected `XmlStream` into an authenticated `Client`.
///
/// Each `operator()` resolves and connects the socket, opens the XMPP stream, authenticates
/// `account` via SASL and restarts the stream; the TLS overload additionally performs STARTTLS.
struct ClientCreateVisitor {
  UserAccount account;     ///< Account to authenticate with.
  const Options& options;  ///< Host/port overrides and TLS policy; referenced, not copied.

  /// Authenticates with SASL PLAIN.
  ///
  /// @throws std::runtime_error if the server does not list the PLAIN mechanism.
  /// Any non-success alternative of the SASL response is thrown as is.
  template <typename Socket>
  auto Auth(PlainUserAccount account, XmlStream<Socket>& stream, ServerToUserStream streamHeader, StreamFeatures features)
      -> boost::asio::awaitable<void> {
    SPDLOG_DEBUG("Start Plain Auth");
    if(!std::ranges::contains(features.saslMechanisms.mechanisms, "PLAIN")) {
      throw std::runtime_error("Server not support PLAIN auth");
    }
    const features::PlainAuthData data{.username = account.jid.username, .password = account.password};
    co_await stream.Send(data);
    sasl::Response response = co_await stream.template Read<sasl::Response>();
    std::visit(utempl::Overloaded(
                   [](auto error) {
                     throw std::move(error);
                   },
                   [](sasl::Success) {}),
               response);
  }

  /// Runs one SCRAM exchange using the mechanism `methodName` and the hash selected by `tag`.
  ///
  /// Sends the client-first message, reads the server challenge, checks that the server nonce
  /// starts with the client nonce, sends the challenge response and reads the final SASL result.
  ///
  /// @throws std::runtime_error if the server nonce does not start with the client nonce.
  /// Any non-success alternative of the SASL response is thrown as is.
  template <typename Socket, typename Tag>
  auto ScramAuth(std::string methodName, EncryptionUserAccount account, XmlStream<Socket>& stream, Tag tag)
      -> boost::asio::awaitable<void> {
    SPDLOG_DEBUG("Start Scram Auth using '{}'", methodName);
    const auto nonce = GenerateNonce();
    SPDLOG_DEBUG("nonce: {}", nonce);
    const auto initialMessage = std::format("n,,n={},r={}", account.jid.username, nonce);
    const features::ScramAuthData authData{.mechanism = methodName, .initialMessage = initialMessage, .tag = tag};
    co_await stream.Send(authData);
    Challenge challenge = co_await stream.template Read<Challenge>();
    const std::string_view serverNonce = challenge.serverNonce;
    if(serverNonce.substr(0, nonce.size()) != nonce) {
      throw std::runtime_error("XMPP Server SCRAM nonce not started with client nonce");
    }
    const ChallengeResponse challengeResponse{.password = account.password,
                                              .salt = challenge.salt,  // Mutable reference
                                              .serverNonce = serverNonce,
                                              .firstServerMessage = challenge.body,
                                              // Drops the leading "n,," of the client-first message.
                                              .initialMessage = std::string_view{initialMessage}.substr(3),
                                              .iterations = challenge.iterations,
                                              .tag = tag};
    co_await stream.Send(challengeResponse);
    sasl::Response response = co_await stream.template Read<sasl::Response>();
    std::visit(utempl::Overloaded(
                   [](auto error) {
                     throw std::move(error);
                   },
                   [](sasl::Success) {}),
               response);
    SPDLOG_DEBUG("Success auth for JID {}", ToString(account.jid));
  }

  /// Authenticates with the strongest SCRAM mechanism the server lists (SHA-512, SHA-256, SHA-1).
  ///
  /// @throws std::runtime_error if none of the three mechanisms is offered.
  template <typename Socket>
  auto Auth(EncryptionRequiredUserAccount account, XmlStream<Socket>& stream, ServerToUserStream streamHeader, StreamFeatures features)
      -> boost::asio::awaitable<void> {
    if(std::ranges::contains(features.saslMechanisms.mechanisms, "SCRAM-SHA-512")) {
      co_return co_await ScramAuth("SCRAM-SHA-512", std::move(account), stream, sha512sum::EncryptionTag{});
    }
    if(std::ranges::contains(features.saslMechanisms.mechanisms, "SCRAM-SHA-256")) {
      co_return co_await ScramAuth("SCRAM-SHA-256", std::move(account), stream, sha256sum::EncryptionTag{});
    }
    if(std::ranges::contains(features.saslMechanisms.mechanisms, "SCRAM-SHA-1")) {
      co_return co_await ScramAuth("SCRAM-SHA-1", std::move(account), stream, sha1sum::EncryptionTag{});
    }
    throw std::runtime_error("Server not support SCRAM SHA 1 or SCRAM SHA 256 or SCRAM SHA 512 auth");
  }

  /// Uses SCRAM when the server lists any SCRAM mechanism, otherwise falls back to PLAIN.
  template <typename Socket>
  auto Auth(EncryptionUserAccount account, XmlStream<Socket>& stream, ServerToUserStream streamHeader, StreamFeatures features)
      -> boost::asio::awaitable<void> {
    Contains(features.saslMechanisms.mechanisms, "SCRAM-SHA-1", "SCRAM-SHA-256", "SCRAM-SHA-512")
        ? co_await this->Auth(EncryptionRequiredUserAccount{std::move(account)}, stream, std::move(streamHeader), std::move(features))
        : co_await this->Auth(static_cast<PlainUserAccount>(std::move(account)), stream, std::move(streamHeader), std::move(features));
  }

  /// Dispatches to the `Auth` overload matching the alternative currently held by `account`.
  template <typename Socket>
  auto Auth(XmlStream<Socket>& stream, ServerToUserStream streamHeader, StreamFeatures features) -> boost::asio::awaitable<void> {
    co_return co_await std::visit(
        [&](auto& account) -> boost::asio::awaitable<void> {
          return this->Auth(account, stream, std::move(streamHeader), std::move(features));
        },
        this->account);
  }

  /// Resolves `options.hostname` (falling back to the JID's server) and `options.port`
  /// (falling back to `kDefaultXmppPort`).
  auto Resolve() -> boost::asio::awaitable<boost::asio::ip::tcp::resolver::results_type> {
    auto executor = co_await boost::asio::this_coro::executor;
    boost::asio::ip::tcp::resolver resolver(executor);
    co_return co_await resolver.async_resolve(this->options.hostname.value_or(account.Jid().server),
                                              std::to_string(this->options.port.value_or(kDefaultXmppPort)),
                                              boost::asio::use_awaitable);
  }

  /// Connects `socket` to one of the resolved endpoints.
  auto Connect(auto& socket, boost::asio::ip::tcp::resolver::results_type resolveResults) -> boost::asio::awaitable<void> {
    co_await boost::asio::async_connect(socket, resolveResults, boost::asio::use_awaitable);
  }

  /// Sends the STARTTLS request, reads the server's answer and performs the client TLS handshake.
  ///
  /// A `starttls::Failure` answer is thrown as is.
  /// @throws StartTlsNegotiationError if the handshake throws a `std::exception`.
  template <typename Socket>
  auto ProcessTls(XmlStream<boost::asio::ssl::stream<Socket>>& stream) -> boost::asio::awaitable<void> {
    const StartTlsRequest request;
    co_await stream.Send(request);
    starttls::Response response = co_await stream.template Read<starttls::Response>();
    std::visit(utempl::Overloaded(
                   [](starttls::Failure error) {
                     throw error;
                   },
                   [](starttls::Success) {}),
               response);
    auto& socket = stream.next_layer();
    // Pass the JID's server name to OpenSSL as the TLS server-name extension for the handshake.
    SSL_set_tlsext_host_name(socket.native_handle(), this->account.Jid().server.c_str());
    try {
      co_await socket.async_handshake(boost::asio::ssl::stream<Socket>::handshake_type::client, boost::asio::use_awaitable);
    } catch(const std::exception& e) {
      throw StartTlsNegotiationError{e.what()};
    }
  }

  /// Reads and discards one item from `stream`, then reads and returns the server's stream header.
  template <typename Socket>
  auto ReadStartStream(XmlStream<Socket>& stream) -> boost::asio::awaitable<ServerToUserStream> {
    co_return (co_await stream.ReadOne(), co_await stream.template ReadOne<ServerToUserStream>());
  }

  /// Creates a client over a plain (non-TLS) stream.
  ///
  /// @throws ServerRequiresStartTls if the server marks STARTTLS as required.
  template <typename Socket>
  inline auto operator()(XmlStream<Socket> stream)
      -> boost::asio::awaitable<std::variant<Client<Socket>, Client<boost::asio::ssl::stream<Socket>>>> {
    co_await this->Connect(stream.next_layer(), co_await this->Resolve());

    co_await stream.Send(UserStream{.from = account.Jid(), .to = account.Jid().server, .version = "1.0", .xmlLang = "en"});
    SPDLOG_DEBUG("UserStream sended");
    ServerToUserStream sToUStream = co_await ReadStartStream(stream);
    StreamFeatures features = co_await stream.template Read<StreamFeatures>();
    SPDLOG_DEBUG("features parsed");

    if(features.startTls && features.startTls->required == Required::kRequired) {
      throw ServerRequiresStartTls{};
    }

    co_await this->Auth(stream, std::move(sToUStream), std::move(features));
    // Stream restart after successful authentication.
    co_await stream.Send(UserStream{.from = this->account.Jid(), .to = account.Jid().server, .version = "1.0", .xmlLang = "en"});
    ServerToUserStream secondSToUStream = co_await ReadStartStream(stream);
    StreamFeatures secondFeatures = co_await stream.template Read<StreamFeatures>();
    co_return Client{std::move(this->account.Jid()), std::move(stream)};
  }

  /// Creates a client over a TLS-capable stream.
  ///
  /// If the server does not offer STARTTLS the socket is closed and the plain overload is used
  /// instead, unless `options.useTls` is `Options::kRequire`.
  ///
  /// @throws std::runtime_error if STARTTLS is required by the options but not offered.
  template <typename Socket>
  inline auto operator()(XmlStream<boost::asio::ssl::stream<Socket>> stream)
      -> boost::asio::awaitable<std::variant<Client<Socket>, Client<boost::asio::ssl::stream<Socket>>>> {
    auto& socket = stream.next_layer();
    co_await this->Connect(socket.next_layer(), co_await this->Resolve());
    co_await stream.Send(
        UserStream{.from = account.Jid().Username("anonymous"), .to = account.Jid().server, .version = "1.0", .xmlLang = "en"},
        socket.next_layer());
    SPDLOG_DEBUG("UserStream sended");
    auto streamHeader = co_await this->ReadStartStream(stream);
    StreamFeatures features = co_await stream.template Read<StreamFeatures>();
    SPDLOG_DEBUG("features parsed(SSL)");
    if(!features.startTls) {
      if(this->options.useTls == Options::kRequire) {
        throw std::runtime_error("XMPP server not support STARTTLS");
      }
      socket.next_layer().close();
      co_return co_await (*this)(XmlStream<Socket>{Socket{std::move(socket.next_layer())}, std::move(stream.streambuf)});
    }
    co_await this->ProcessTls(stream);
    // NOTE(review): the handshake has completed here, yet this header is still passed
    // `socket.next_layer()` — confirm which layer `XmlStream::Send` writes to in this case.
    co_await stream.Send(UserStream{.from = account.Jid(), .to = account.Jid().server}, socket.next_layer());
    auto newStreamHeader = co_await this->ReadStartStream(stream);
    auto newFeatures = co_await stream.template Read<StreamFeatures>();
    co_await this->Auth(stream, std::move(newStreamHeader), std::move(newFeatures));
    // Stream restart after successful authentication: the client has to send a new stream header
    // before the server answers with its own, exactly as the plain overload does. It goes through
    // the same single-argument Send that Auth has just used on this stream.
    co_await stream.Send(UserStream{.from = this->account.Jid(), .to = account.Jid().server, .version = "1.0", .xmlLang = "en"});
    ServerToUserStream secondSToUStream = co_await ReadStartStream(stream);
    StreamFeatures secondFeatures = co_await stream.template Read<StreamFeatures>();
    co_return Client{std::move(this->account.Jid()), std::move(stream)};
  }
};
|
|
|
|
} // namespace impl
|
|
|
|
template <typename Socket = boost::asio::ip::tcp::socket>
|
|
inline auto CreateClient(UserAccount account, Options options = {})
|
|
-> boost::asio::awaitable<std::variant<Client<Socket>, Client<boost::asio::ssl::stream<Socket>>>> {
|
|
auto executor = co_await boost::asio::this_coro::executor;
|
|
boost::asio::ssl::context ctx(boost::asio::ssl::context::sslv23);
|
|
co_return co_await std::visit(
|
|
impl::ClientCreateVisitor{.account = std::move(account), .options = options},
|
|
options.useTls == Options::kNever
|
|
? std::variant<XmlStream<Socket>, XmlStream<boost::asio::ssl::stream<Socket>>>{XmlStream{Socket{executor}}}
|
|
: XmlStream{boost::asio::ssl::stream<Socket>(executor, ctx)});
|
|
}
|
|
|
|
} // namespace larra::xmpp::client
|