larra/library/include/larra/client/client.hpp
Ivan-lis 30a5e69d14
All checks were successful
PR Check / on-push-commit-check (push) Successful in 11m21s
Fixed errors and improve roster tests
2024-11-21 23:22:03 +00:00

390 lines
18 KiB
C++

#pragma once
#include <spdlog/spdlog.h>
#include <algorithm>
#include <boost/asio/awaitable.hpp>
#include <boost/asio/co_spawn.hpp>
#include <boost/asio/connect.hpp>
#include <boost/asio/detached.hpp>
#include <boost/asio/ip/tcp.hpp>
#include <boost/asio/read.hpp>
#include <boost/asio/read_until.hpp>
#include <boost/asio/ssl.hpp>
#include <boost/asio/use_awaitable.hpp>
#include <charconv>
#include <format>
#include <larra/bind.hpp>
#include <larra/client/challenge_response.hpp>
#include <larra/client/options.hpp>
#include <larra/client/starttls_response.hpp>
#include <larra/client/xmpp_client_stream_features.hpp>
#include <larra/encryption.hpp>
#include <larra/features.hpp>
#include <larra/roster.hpp>
#include <larra/stream.hpp>
#include <larra/user_account.hpp>
#include <larra/xml_stream.hpp>
#include <optional>
#include <ranges>
#include <stdexcept>
#include <string>
#include <string_view>
#include <system_error>
#include <unordered_map>
#include <utility>
#include <variant>
namespace rng = std::ranges;
namespace views = std::views;
namespace iq = larra::xmpp::iq;
namespace larra::xmpp {
/// Default TCP port for XMPP client-to-server connections (RFC 6120 §14.7).
/// Used by the client when Options::port is not set.
/// `inline` gives the header-defined constant a single definition across translation units.
inline constexpr auto kDefaultXmppPort = 5222;
} // namespace larra::xmpp
namespace larra::xmpp::client {
template <typename Connection>
struct Client {
constexpr Client(BareJid jid, XmlStream<Connection> connection) : jid(std::move(jid)), connection(std::move(connection)) {
}
template <boost::asio::completion_token_for<void()> Token = boost::asio::use_awaitable_t<>>
constexpr auto Close(Token token = {}) {
this->active = false;
return boost::asio::async_initiate<Token, void()>(
[]<typename Handler>(Handler&& h, XmlStream<Connection> connection) { // NOLINT
boost::asio::co_spawn(
connection.next_layer().get_executor(),
[](auto h, XmlStream<Connection> connection) -> boost::asio::awaitable<void> {
co_await boost::asio::async_write(
connection.next_layer(), boost::asio::buffer("</stream:stream>"), boost::asio::use_awaitable);
std::string response;
co_await boost::asio::async_read_until(
connection.next_layer(), boost::asio::dynamic_buffer(response), "</stream:stream>", boost::asio::use_awaitable);
h();
}(std::move(h), std::move(connection)),
boost::asio::detached);
},
token,
std::move(this->connection));
}
constexpr Client(const Client&) = delete;
constexpr Client(Client&& client) noexcept : connection(std::move(client.connection)), jid(std::move(client.jid)) {
client.active = false;
}
constexpr ~Client() {
if(this->active) {
this->Close([] {});
}
}
template <typename T>
auto Send(const T& object) -> boost::asio::awaitable<void> {
co_await this->connection.Send(object);
}
[[nodiscard]] constexpr auto Jid() const -> const FullJid& {
return this->jid;
}
auto CreateResourceBind() -> boost::asio::awaitable<void> {
SPDLOG_INFO("Send IQ: Set::Bind");
co_await this->Send(::iq::SetBind{.id = "1", .payload = {}});
auto set_bind_response = co_await connection.template Read<Iq<::iq::Bind>>();
std::visit(utempl::Overloaded(
[](auto error) {
throw "Error response on IQ: Set::Bind: ''"; // TODO(unknown): Add exact error parsing
},
[&](::iq::ResultBind r) {
jid.resource = std::move(r.payload.jid->resource);
SPDLOG_INFO("Allocated resource: {}", jid.resource);
}),
set_bind_response);
co_return;
}
auto UpdateListOfContacts() -> boost::asio::awaitable<void> {
SPDLOG_INFO("Send IQ: Get::Roster");
co_await this->Send(::iq::GetRoster{.id = "1", .from = std::format("{}@{}", "invalid_user", jid.server), .payload = {}});
const auto get_roster_response = co_await connection.template Read<Iq<::iq::Roster>>();
std::visit(utempl::Overloaded(
[](auto error) {
throw "Error response on IQ: Get::Roster: ''"; // TODO(unknown): Add exact error parsing
},
[&](::iq::ResultRoster r) {
roster = std::move(r.payload);
SPDLOG_INFO("New roster: {}", ToString(roster));
}),
get_roster_response);
co_return;
}
private:
bool active = true;
XmlStream<Connection> connection{};
FullJid jid;
::iq::Roster roster;
};
/// Reports a failure of the TLS handshake that follows an accepted STARTTLS request
/// (raised from ClientCreateVisitor::ProcessTls).
struct StartTlsNegotiationError : std::runtime_error {
  /// @param error Description of the underlying failure, appended to "STARTTLS negotiation error: ".
  /// `explicit`: an arbitrary string should not silently convert into an exception object.
  explicit StartTlsNegotiationError(std::string_view error) :
      std::runtime_error(std::string("STARTTLS negotiation error: ").append(error)) {
  }
};
/// Thrown on the plain-TCP connection path when the server's stream features mark
/// STARTTLS as required.
struct ServerRequiresStartTls : std::exception {
  /// @return A fixed, statically allocated message.
  [[nodiscard]] const char* what() const noexcept override {
    return "XMPP Server requires STARTTLS";
  }
};
namespace impl {
/// Returns true if any element of `range` compares equal (with ==) to at least one of `values`.
/// `values` are only read, so they are compared as lvalues. Forwarding them on every iteration
/// (as std::forward inside a loop would) risks comparing against moved-from objects.
/// Elements are bound by const reference, so ranges that yield prvalues (e.g. views) work too.
template <std::ranges::range Range, typename... Args>
auto Contains(Range&& range, Args&&... values) {  // NOLINT
  return std::ranges::any_of(range, [&](const auto& element) {
    return ((element == values) || ...);
  });
}
/// Parses the whole of `input` as a base-10 integer of type T.
/// @return The parsed value, or std::nullopt when `input` is empty, is not a number,
///         does not fit in T, or has trailing characters after the number.
template <typename T = int>
inline auto ToInt(std::string_view input) -> std::optional<T> {
  T out{};
  const char* const end = input.data() + input.size();
  const auto [ptr, ec] = std::from_chars(input.data(), end, out);
  // Success requires no error AND full consumption; otherwise "12abc" would be accepted as 12.
  if(ec != std::errc{} || ptr != end) {
    return std::nullopt;
  }
  return out;
}
/// Decoded SASL <challenge/> payload, i.e. the SCRAM server-first-message (RFC 5802).
/// All members own their data, so the object can be moved freely without dangling views.
struct Challenge {
  std::string body;         // full base64-decoded challenge text
  std::string serverNonce;  // "r" attribute; expected to start with the client nonce
  std::string salt;         // base64-decoded "s" attribute
  int iterations;           // "i" attribute
  /// Parses a <challenge/> element.
  /// @throws std::runtime_error if the element is not named "challenge" or has no text content;
  ///         std::out_of_range if "r", "s" or "i" is missing;
  ///         std::bad_optional_access if "i" is not an integer.
  [[nodiscard]] inline static auto Parse(const xmlpp::Element* node) -> Challenge {
    if(node->get_name() != "challenge") {
      throw std::runtime_error(std::format("Invalid name {} for challenge", node->get_name()));
    }
    const auto* text = node->get_first_child_text();
    if(text == nullptr) {
      throw std::runtime_error("Empty SASL challenge");
    }
    std::string decoded = DecodeBase64(text->get_content());
    // Every key and value in `params` is a view into `decoded`.
    const auto params = std::views::split(decoded, ',')  //
                        | std::views::transform([](auto param) {  //
                            return std::string_view{param};  //
                          })  //
                        | std::views::transform([](std::string_view param) -> std::pair<std::string_view, std::string_view> {  //
                            const auto separator = param.find('=');  //
                            return {param.substr(0, separator), param.substr(separator + 1)};  //
                          })  //
                        | std::ranges::to<std::unordered_map<std::string_view, std::string_view>>();
    // Copy every attribute out *before* `decoded` is moved into `body`: after the move the
    // views in `params` may no longer point at live characters.
    std::string serverNonce{params.at("r")};
    std::string salt = DecodeBase64(params.at("s"));
    const int iterations = ToInt(params.at("i")).value();
    return {.body = std::move(decoded),
            .serverNonce = std::move(serverNonce),
            .salt = std::move(salt),
            .iterations = iterations};
  }
};
/// SASL <response/> element whose text is the base64-encoded output of GenerateScramAuthMessage
/// (presumably the SCRAM client-final-message — confirm against encryption.hpp).
template <typename Tag>
struct ChallengeResponse {
  static constexpr auto kDefaultName = "response";
  static constexpr auto kDefaultNamespace = "urn:ietf:params:xml:ns:xmpp-sasl";
  std::string_view password;
  // Refers to the salt owned by the parsed challenge; serialization only reads it.
  std::string& salt;
  std::string_view serverNonce;
  std::string_view firstServerMessage;
  std::string_view initialMessage;
  int iterations{};
  Tag tag;
  /// Appends the base64-encoded SCRAM message as the text content of `element`.
  friend constexpr auto operator<<(xmlpp::Element* element, const ChallengeResponse& self) {
    // Hand over a copy of the salt. Moving out of `self.salt` would empty the caller's string
    // through a const object, so serializing the same response twice would yield a wrong proof.
    auto text = EncodeBase64(GenerateScramAuthMessage(
        self.password, std::string{self.salt}, self.serverNonce, self.firstServerMessage, self.initialMessage, self.iterations, self.tag));
    element->add_child_text(text);
  }
};
struct StartTlsRequest {
static constexpr auto kDefaultName = "starttls";
static constexpr auto kDefaultNamespace = "urn:ietf:params:xml:ns:xmpp-tls";
friend constexpr auto operator<<(xmlpp::Element*, const StartTlsRequest&) {
}
};
/// Visitor used by CreateClient: connects, optionally negotiates STARTTLS, authenticates and
/// returns a Client for whichever stream alternative (plain or TLS-capable) it is invoked with.
struct ClientCreateVisitor {
  UserAccount account;
  // Non-owning; the referenced Options must outlive every coroutine started from this visitor.
  const Options& options;
  /// SASL PLAIN authentication.
  /// @throws std::runtime_error if the server does not offer PLAIN; otherwise rethrows the
  ///         non-success alternative of the server's sasl::Response.
  template <typename Socket>
  auto Auth(PlainUserAccount account, XmlStream<Socket>& stream, ServerToUserStream streamHeader, StreamFeatures features)
      -> boost::asio::awaitable<void> {
    SPDLOG_DEBUG("Start Plain Auth");
    if(!std::ranges::contains(features.saslMechanisms.mechanisms, "PLAIN")) {
      throw std::runtime_error("Server not support PLAIN auth");
    }
    const features::PlainAuthData data{.username = account.jid.username, .password = account.password};
    co_await stream.Send(data);
    sasl::Response response = co_await stream.template Read<sasl::Response>();
    std::visit(utempl::Overloaded(
                   [](auto error) {
                     throw std::move(error);
                   },
                   [](sasl::Success) {}),
               response);
  }
  /// SCRAM authentication with mechanism `methodName` and the hash selected by `tag`.
  /// It proceeds as follows:
  /// 1. Sends the client-first message.
  /// 2. Checks that the server nonce starts with the client nonce.
  /// 3. Sends the computed response.
  /// 4. Rethrows any non-success sasl::Response.
  template <typename Socket, typename Tag>
  auto ScramAuth(std::string methodName, EncryptionUserAccount account, XmlStream<Socket>& stream, Tag tag)
      -> boost::asio::awaitable<void> {
    SPDLOG_DEBUG("Start Scram Auth using '{}'", methodName);
    const auto nonce = GenerateNonce();
    SPDLOG_DEBUG("nonce: {}", nonce);
    const auto initialMessage = std::format("n,,n={},r={}", account.jid.username, nonce);
    const features::ScramAuthData authData{.mechanism = methodName, .initialMessage = initialMessage, .tag = tag};
    co_await stream.Send(authData);
    Challenge challenge = co_await stream.template Read<Challenge>();
    const std::string_view serverNonce = challenge.serverNonce;
    if(serverNonce.substr(0, nonce.size()) != nonce) {
      throw std::runtime_error("XMPP Server SCRAM nonce not started with client nonce");
    }
    const ChallengeResponse challengeResponse{.password = account.password,
                                              .salt = challenge.salt,  // Mutable reference
                                              .serverNonce = serverNonce,
                                              .firstServerMessage = challenge.body,
                                              // Drop the "n,," GS2 header: SCRAM signs the bare client-first message.
                                              .initialMessage = std::string_view{initialMessage}.substr(3),
                                              .iterations = challenge.iterations,
                                              .tag = tag};
    co_await stream.Send(challengeResponse);
    sasl::Response response = co_await stream.template Read<sasl::Response>();
    std::visit(utempl::Overloaded(
                   [](auto error) {
                     throw std::move(error);
                   },
                   [](sasl::Success) {}),
               response);
    SPDLOG_DEBUG("Success auth for JID {}", ToString(account.jid));
  }
  /// SCRAM only: prefers SHA-512, then SHA-256, then SHA-1; throws if none is offered.
  template <typename Socket>
  auto Auth(EncryptionRequiredUserAccount account, XmlStream<Socket>& stream, ServerToUserStream streamHeader, StreamFeatures features)
      -> boost::asio::awaitable<void> {
    if(std::ranges::contains(features.saslMechanisms.mechanisms, "SCRAM-SHA-512")) {
      co_return co_await ScramAuth("SCRAM-SHA-512", std::move(account), stream, sha512sum::EncryptionTag{});
    }
    if(std::ranges::contains(features.saslMechanisms.mechanisms, "SCRAM-SHA-256")) {
      co_return co_await ScramAuth("SCRAM-SHA-256", std::move(account), stream, sha256sum::EncryptionTag{});
    }
    if(std::ranges::contains(features.saslMechanisms.mechanisms, "SCRAM-SHA-1")) {
      co_return co_await ScramAuth("SCRAM-SHA-1", std::move(account), stream, sha1sum::EncryptionTag{});
    }
    throw std::runtime_error("Server not support SCRAM SHA 1 or SCRAM SHA 256 or SCRAM SHA 512 auth");
  }
  /// Uses SCRAM when any SCRAM mechanism is offered, otherwise falls back to PLAIN.
  template <typename Socket>
  auto Auth(EncryptionUserAccount account, XmlStream<Socket>& stream, ServerToUserStream streamHeader, StreamFeatures features)
      -> boost::asio::awaitable<void> {
    Contains(features.saslMechanisms.mechanisms, "SCRAM-SHA-1", "SCRAM-SHA-256", "SCRAM-SHA-512")
        ? co_await this->Auth(EncryptionRequiredUserAccount{std::move(account)}, stream, std::move(streamHeader), std::move(features))
        : co_await this->Auth(static_cast<PlainUserAccount>(std::move(account)), stream, std::move(streamHeader), std::move(features));
  }
  /// Dispatches to the Auth overload matching the stored account alternative.
  template <typename Socket>
  auto Auth(XmlStream<Socket>& stream, ServerToUserStream streamHeader, StreamFeatures features) -> boost::asio::awaitable<void> {
    co_return co_await std::visit(
        [&](auto& account) -> boost::asio::awaitable<void> {
          return this->Auth(account, stream, std::move(streamHeader), std::move(features));
        },
        this->account);
  }
  /// Resolves options.hostname (or the JID domain) on options.port (or kDefaultXmppPort).
  auto Resolve() -> boost::asio::awaitable<boost::asio::ip::tcp::resolver::results_type> {
    auto executor = co_await boost::asio::this_coro::executor;
    boost::asio::ip::tcp::resolver resolver(executor);
    co_return co_await resolver.async_resolve(this->options.hostname.value_or(account.Jid().server),
                                              std::to_string(this->options.port.value_or(kDefaultXmppPort)),
                                              boost::asio::use_awaitable);
  }
  auto Connect(auto& socket, boost::asio::ip::tcp::resolver::results_type resolveResults) -> boost::asio::awaitable<void> {
    co_await boost::asio::async_connect(socket, resolveResults, boost::asio::use_awaitable);
  }
  /// Sends <starttls/>, requires a success reply, sets SNI to the JID domain and runs the client
  /// TLS handshake. Handshake failures are rethrown as StartTlsNegotiationError.
  template <typename Socket>
  auto ProcessTls(XmlStream<boost::asio::ssl::stream<Socket>>& stream) -> boost::asio::awaitable<void> {
    const StartTlsRequest request;
    co_await stream.Send(request);
    starttls::Response response = co_await stream.template Read<starttls::Response>();
    std::visit(utempl::Overloaded(
                   [](starttls::Failure error) {
                     throw error;
                   },
                   [](starttls::Success) {}),
               response);
    auto& socket = stream.next_layer();
    SSL_set_tlsext_host_name(socket.native_handle(), this->account.Jid().server.c_str());
    try {
      co_await socket.async_handshake(boost::asio::ssl::stream<Socket>::handshake_type::client, boost::asio::use_awaitable);
    } catch(const std::exception& e) {
      throw StartTlsNegotiationError{e.what()};
    }
  }
  /// Reads the server's opening stream header.
  template <typename Socket>
  auto ReadStartStream(XmlStream<Socket>& stream) -> boost::asio::awaitable<ServerToUserStream> {
    // The first node preceding the header is read and discarded
    // (presumably the XML declaration — TODO confirm).
    co_await stream.ReadOne();
    co_return co_await stream.template ReadOne<ServerToUserStream>();
  }
  /// Plain-TCP path: connect, open the stream, authenticate, restart the stream, build the Client.
  /// @throws ServerRequiresStartTls if the server requires STARTTLS.
  template <typename Socket>
  inline auto operator()(XmlStream<Socket> stream)
      -> boost::asio::awaitable<std::variant<Client<Socket>, Client<boost::asio::ssl::stream<Socket>>>> {
    co_await this->Connect(stream.next_layer(), co_await this->Resolve());
    co_await stream.Send(UserStream{.from = account.Jid(), .to = account.Jid().server, .version = "1.0", .xmlLang = "en"});
    SPDLOG_DEBUG("UserStream sended");
    ServerToUserStream sToUStream = co_await ReadStartStream(stream);
    StreamFeatures features = co_await stream.template Read<StreamFeatures>();
    SPDLOG_DEBUG("features parsed");
    if(features.startTls && features.startTls->required == Required::kRequired) {
      throw ServerRequiresStartTls{};
    }
    co_await this->Auth(stream, std::move(sToUStream), std::move(features));
    co_await stream.Send(UserStream{.from = this->account.Jid(), .to = account.Jid().server, .version = "1.0", .xmlLang = "en"});
    ServerToUserStream secondSToUStream = co_await ReadStartStream(stream);
    StreamFeatures secondFeatures = co_await stream.template Read<StreamFeatures>();
    co_return Client{std::move(this->account.Jid()), std::move(stream)};
  }
  /// TLS-capable path.
  /// 1. Opens an initial stream over raw TCP.
  /// 2. If STARTTLS is not offered, either throws (TLS required) or reconnects through the plain path.
  /// 3. Otherwise upgrades to TLS, then restarts the stream, authenticates, restarts again and builds the Client.
  template <typename Socket>
  inline auto operator()(XmlStream<boost::asio::ssl::stream<Socket>> stream)
      -> boost::asio::awaitable<std::variant<Client<Socket>, Client<boost::asio::ssl::stream<Socket>>>> {
    auto& socket = stream.next_layer();
    co_await this->Connect(socket.next_layer(), co_await this->Resolve());
    // Before the handshake the header must go over the raw TCP socket.
    co_await stream.Send(
        UserStream{.from = account.Jid().Username("anonymous"), .to = account.Jid().server, .version = "1.0", .xmlLang = "en"},
        socket.next_layer());
    SPDLOG_DEBUG("UserStream sended");
    auto streamHeader = co_await this->ReadStartStream(stream);
    StreamFeatures features = co_await stream.template Read<StreamFeatures>();
    SPDLOG_DEBUG("features parsed(SSL)");
    if(!features.startTls) {
      if(this->options.useTls == Options::kRequire) {
        throw std::runtime_error("XMPP server not support STARTTLS");
      }
      socket.next_layer().close();
      co_return co_await (*this)(XmlStream<Socket>{Socket{std::move(socket.next_layer())}, std::move(stream.streambuf)});
    }
    co_await this->ProcessTls(stream);
    // After the handshake every byte must travel through TLS. Send through the stream itself, as Auth
    // does below, not the raw TCP socket. RFC 6120 §5.4.3.3 requires a new stream here.
    co_await stream.Send(UserStream{.from = account.Jid(), .to = account.Jid().server, .version = "1.0", .xmlLang = "en"});
    auto newStreamHeader = co_await this->ReadStartStream(stream);
    auto newFeatures = co_await stream.template Read<StreamFeatures>();
    co_await this->Auth(stream, std::move(newStreamHeader), std::move(newFeatures));
    // RFC 6120 §6.4.6: after SASL success the client must open a fresh stream. Without this the
    // server never sends the header that ReadStartStream waits for.
    co_await stream.Send(UserStream{.from = account.Jid(), .to = account.Jid().server, .version = "1.0", .xmlLang = "en"});
    ServerToUserStream secondSToUStream = co_await ReadStartStream(stream);
    StreamFeatures secondFeatures = co_await stream.template Read<StreamFeatures>();
    co_return Client{std::move(this->account.Jid()), std::move(stream)};
  }
};
} // namespace impl
/// Connects and authenticates `account`, returning a Client over either a plain socket
/// (Options::kNever) or a TLS-capable stream (any other useTls setting).
template <typename Socket = boost::asio::ip::tcp::socket>
inline auto CreateClient(UserAccount account, Options options = {})
    -> boost::asio::awaitable<std::variant<Client<Socket>, Client<boost::asio::ssl::stream<Socket>>>> {
  using SslSocket = boost::asio::ssl::stream<Socket>;
  using AnyStream = std::variant<XmlStream<Socket>, XmlStream<SslSocket>>;
  auto executor = co_await boost::asio::this_coro::executor;
  // NOTE(review): no verify mode or trust store is configured on this context, so the server
  // certificate is not validated — confirm this is intended.
  boost::asio::ssl::context ctx(boost::asio::ssl::context::sslv23);
  impl::ClientCreateVisitor visitor{.account = std::move(account), .options = options};
  if(options.useTls == Options::kNever) {
    co_return co_await std::visit(visitor, AnyStream{XmlStream{Socket{executor}}});
  }
  co_return co_await std::visit(visitor, AnyStream{XmlStream{SslSocket(executor, ctx)}});
}
} // namespace larra::xmpp::client