Add stream errors handling

This commit is contained in:
sha512sum 2024-10-08 08:36:08 +00:00
parent bb336da4d9
commit 0636e1c234
11 changed files with 216 additions and 170 deletions

View file

@ -55,7 +55,12 @@ CPMAddPackage(
CPMAddPackage("gh:zeux/pugixml@1.14")
CPMAddPackage("gh:fmtlib/fmt#10.2.1")
CPMAddPackage("gh:Neargye/nameof@0.10.4")
CPMAddPackage(NAME nameof
VERSION 0.10.4
GIT_REPOSITORY "https://github.com/Neargye/nameof.git"
EXCLUDE_FROM_ALL ON
OPTIONS "NAMEOF_OPT_INSTALL ON"
)
CPMAddPackage(
NAME spdlog
@ -167,7 +172,7 @@ if(TARGET Boost::pfr)
OpenSSL::SSL nameof::nameof
OpenSSL::Crypto spdlog xmlplusplus ${LIBXML2_LIBRARIES})
else()
find_package(Boost 1.85.0 REQUIRED)
find_package(Boost 1.85.0 COMPONENTS serialization REQUIRED)
target_link_libraries(larra_xmpp PUBLIC
utempl::utempl ${Boost_LIBRARIES} pugixml::pugixml OpenSSL::SSL
nameof::nameof

View file

@ -11,15 +11,16 @@
#include <boost/asio/ssl.hpp>
#include <boost/asio/use_awaitable.hpp>
#include <charconv>
#include <larra/client/challenge_response.hpp>
#include <larra/client/options.hpp>
#include <larra/client/starttls_response.hpp>
#include <larra/client/xmpp_client_stream_features.hpp>
#include <larra/encryption.hpp>
#include <larra/features.hpp>
#include <larra/raw_xml_stream.hpp>
#include <larra/stream.hpp>
#include <larra/user_account.hpp>
#include <larra/xml_stream.hpp>
#include <ranges>
#include "larra/client/xmpp_client_stream_features.hpp"
namespace larra::xmpp {
constexpr auto kDefaultXmppPort = 5222;
@ -33,15 +34,15 @@ namespace views = std::views;
template <typename Connection>
struct Client {
constexpr Client(BareJid jid, RawXmlStream<Connection> connection) : jid(std::move(jid)), connection(std::move(connection)) {};
constexpr Client(BareJid jid, XmlStream<Connection> connection) : jid(std::move(jid)), connection(std::move(connection)) {};
template <boost::asio::completion_token_for<void()> Token = boost::asio::use_awaitable_t<>>
constexpr auto Close(Token token = {}) {
this->active = false;
return boost::asio::async_initiate<Token, void()>(
[]<typename Handler>(Handler&& h, RawXmlStream<Connection> connection) { // NOLINT
[]<typename Handler>(Handler&& h, XmlStream<Connection> connection) { // NOLINT
boost::asio::co_spawn(
connection.next_layer().get_executor(),
[](auto h, RawXmlStream<Connection> connection) -> boost::asio::awaitable<void> {
[](auto h, XmlStream<Connection> connection) -> boost::asio::awaitable<void> {
co_await boost::asio::async_write(
connection.next_layer(), boost::asio::buffer("</stream:stream>"), boost::asio::use_awaitable);
std::string response;
@ -70,7 +71,7 @@ struct Client {
private:
bool active = true;
RawXmlStream<Connection> connection;
XmlStream<Connection> connection;
BareJid jid;
};
@ -158,7 +159,7 @@ struct ClientCreateVisitor {
const Options& options;
template <typename Socket>
auto Auth(PlainUserAccount account, RawXmlStream<Socket>& stream, ServerToUserStream streamHeader, StreamFeatures features)
auto Auth(PlainUserAccount account, XmlStream<Socket>& stream, ServerToUserStream streamHeader, StreamFeatures features)
-> boost::asio::awaitable<void> {
SPDLOG_DEBUG("Start Plain Auth");
if(!std::ranges::contains(features.saslMechanisms.mechanisms, "PLAIN")) {
@ -166,11 +167,17 @@ struct ClientCreateVisitor {
}
const features::PlainAuthData data{.username = account.jid.username, .password = account.password};
co_await stream.Send(data);
std::ignore = co_await stream.Read();
sasl::Response response = co_await stream.template Read<sasl::Response>();
std::visit(utempl::Overloaded(
[](auto error) {
throw std::move(error);
},
[](sasl::Success) {}),
response);
}
template <typename Socket, typename Tag>
auto ScramAuth(std::string methodName, EncryptionUserAccount account, RawXmlStream<Socket>& stream, Tag tag)
auto ScramAuth(std::string methodName, EncryptionUserAccount account, XmlStream<Socket>& stream, Tag tag)
-> boost::asio::awaitable<void> {
SPDLOG_DEBUG("Start Scram Auth using '{}'", methodName);
const auto nonce = GenerateNonce();
@ -191,23 +198,18 @@ struct ClientCreateVisitor {
.iterations = challenge.iterations,
.tag = tag};
co_await stream.Send(challengeResponse);
std::unique_ptr<xmlpp::Document> doc = co_await stream.Read();
auto root = doc->get_root_node();
if(!root || root->get_name() == "failure") {
if(auto textNode = root->get_first_child("text")) {
if(auto text = dynamic_cast<xmlpp::Element*>(textNode)) {
if(auto childText = text->get_first_child_text()) {
throw std::runtime_error(std::format("Auth failed: {}", childText->get_content()));
}
}
}
throw std::runtime_error("Auth failed");
}
sasl::Response response = co_await stream.template Read<sasl::Response>();
std::visit(utempl::Overloaded(
[](auto error) {
throw std::move(error);
},
[](sasl::Success) {}),
response);
SPDLOG_DEBUG("Success auth for JID {}", ToString(account.jid));
}
template <typename Socket>
auto Auth(EncryptionRequiredUserAccount account, RawXmlStream<Socket>& stream, ServerToUserStream streamHeader, StreamFeatures features)
auto Auth(EncryptionRequiredUserAccount account, XmlStream<Socket>& stream, ServerToUserStream streamHeader, StreamFeatures features)
-> boost::asio::awaitable<void> {
if(std::ranges::contains(features.saslMechanisms.mechanisms, "SCRAM-SHA-512")) {
co_return co_await ScramAuth("SCRAM-SHA-512", std::move(account), stream, sha512sum::EncryptionTag{});
@ -222,7 +224,7 @@ struct ClientCreateVisitor {
}
template <typename Socket>
auto Auth(EncryptionUserAccount account, RawXmlStream<Socket>& stream, ServerToUserStream streamHeader, StreamFeatures features)
auto Auth(EncryptionUserAccount account, XmlStream<Socket>& stream, ServerToUserStream streamHeader, StreamFeatures features)
-> boost::asio::awaitable<void> {
Contains(features.saslMechanisms.mechanisms, "SCRAM-SHA-1", "SCRAM-SHA-256", "SCRAM-SHA-512")
? co_await this->Auth(EncryptionRequiredUserAccount{std::move(account)}, stream, std::move(streamHeader), std::move(features))
@ -230,7 +232,7 @@ struct ClientCreateVisitor {
}
template <typename Socket>
auto Auth(RawXmlStream<Socket>& stream, ServerToUserStream streamHeader, StreamFeatures features) -> boost::asio::awaitable<void> {
auto Auth(XmlStream<Socket>& stream, ServerToUserStream streamHeader, StreamFeatures features) -> boost::asio::awaitable<void> {
co_return co_await std::visit(
[&](auto& account) -> boost::asio::awaitable<void> {
return this->Auth(std::move(account), stream, std::move(streamHeader), std::move(features));
@ -251,17 +253,16 @@ struct ClientCreateVisitor {
}
template <typename Socket>
auto ProcessTls(RawXmlStream<boost::asio::ssl::stream<Socket>>& stream) -> boost::asio::awaitable<void> {
auto ProcessTls(XmlStream<boost::asio::ssl::stream<Socket>>& stream) -> boost::asio::awaitable<void> {
const StartTlsRequest request;
co_await stream.Send(request);
std::unique_ptr<xmlpp::Document> doc = co_await stream.Read();
if(auto node = doc->get_root_node()) {
if(node->get_name() == "proceed") {
goto proceed; // NOLINT
}
throw StartTlsNegotiationError{"Failure XMPP"};
}
proceed:
starttls::Response response = co_await stream.template Read<starttls::Response>();
std::visit(utempl::Overloaded(
[](starttls::Failure error) {
throw error;
},
[](starttls::Success) {}),
response);
auto& socket = stream.next_layer();
SSL_set_tlsext_host_name(socket.native_handle(), this->account.Jid().server.c_str());
try {
@ -270,57 +271,14 @@ struct ClientCreateVisitor {
throw StartTlsNegotiationError{e.what()};
}
}
static constexpr auto GetEnumerated(boost::asio::streambuf& streambuf) {
return std::views::zip(std::views::iota(std::size_t{}, streambuf.size()), ::larra::xmpp::impl::GetCharsRangeFromBuf(streambuf));
}
using EnumeratedT = decltype(std::views::zip(std::views::iota(std::size_t{}, std::size_t{}),
::larra::xmpp::impl::GetCharsRangeFromBuf(std::declval<boost::asio::streambuf&>())));
struct Splitter {
EnumeratedT range;
struct Sentinel {
std::ranges::sentinel_t<EnumeratedT> end;
};
struct Iterator {
std::ranges::iterator_t<EnumeratedT> it;
std::ranges::sentinel_t<EnumeratedT> end;
friend constexpr auto operator==(const Iterator& self, const Sentinel& it) -> bool {
return self.it == it.end;
}
auto operator++() -> Iterator& {
if(this->it == this->end) {
return *this;
}
this->it = std::ranges::find(this->it, this->end, '>', [](auto v) {
auto [_, c] = v;
return c;
});
if(this->it != this->end) {
++it;
}
return *this;
};
auto operator*() const {
return *it;
}
};
auto begin() -> Iterator {
return Iterator{.it = std::ranges::begin(this->range), .end = std::ranges::end(this->range)};
}
auto end() -> Sentinel {
return {.end = std::ranges::end(this->range)};
}
};
template <typename Socket>
auto ReadStartStream(RawXmlStream<Socket>& stream) -> boost::asio::awaitable<ServerToUserStream> {
auto doc = (co_await stream.ReadOne(), co_await stream.ReadOne());
co_return ServerToUserStream::Parse(doc->get_root_node());
auto ReadStartStream(XmlStream<Socket>& stream) -> boost::asio::awaitable<ServerToUserStream> {
co_return (co_await stream.ReadOne(), co_await stream.template ReadOne<ServerToUserStream>());
}
template <typename Socket>
inline auto operator()(RawXmlStream<Socket> stream)
inline auto operator()(XmlStream<Socket> stream)
-> boost::asio::awaitable<std::variant<Client<Socket>, Client<boost::asio::ssl::stream<Socket>>>> {
co_await this->Connect(stream.next_layer(), co_await this->Resolve());
@ -339,7 +297,7 @@ struct ClientCreateVisitor {
}
template <typename Socket>
inline auto operator()(RawXmlStream<boost::asio::ssl::stream<Socket>> stream)
inline auto operator()(XmlStream<boost::asio::ssl::stream<Socket>> stream)
-> boost::asio::awaitable<std::variant<Client<Socket>, Client<boost::asio::ssl::stream<Socket>>>> {
auto& socket = stream.next_layer();
co_await this->Connect(socket.next_layer(), co_await this->Resolve());
@ -353,14 +311,14 @@ struct ClientCreateVisitor {
throw std::runtime_error("XMPP server not support STARTTLS");
}
socket.next_layer().close();
co_return co_await (*this)(RawXmlStream<Socket>{Socket{std::move(socket.next_layer())}, std::move(stream.streambuf)});
co_return co_await (*this)(XmlStream<Socket>{Socket{std::move(socket.next_layer())}, std::move(stream.streambuf)});
}
co_await this->ProcessTls(stream);
co_await stream.Send(UserStream{.from = account.Jid(), .to = account.Jid().server}, socket.next_layer());
auto newStreamHeader = co_await this->ReadStartStream(stream);
auto newFeatures = co_await stream.template Read<StreamFeatures>();
co_await this->Auth(stream, std::move(newStreamHeader), std::move(newFeatures));
co_return Client{std::move(this->account).Jid(), RawXmlStream{std::move(socket)}};
co_return Client{std::move(this->account).Jid(), XmlStream{std::move(socket)}};
}
};
@ -374,8 +332,8 @@ inline auto CreateClient(UserAccount account, Options options = {})
co_return co_await std::visit(
impl::ClientCreateVisitor{.account = std::move(account), .options = options},
options.useTls == Options::kNever
? std::variant<RawXmlStream<Socket>, RawXmlStream<boost::asio::ssl::stream<Socket>>>{RawXmlStream{Socket{executor}}}
: RawXmlStream{boost::asio::ssl::stream<Socket>(executor, ctx)});
? std::variant<XmlStream<Socket>, XmlStream<boost::asio::ssl::stream<Socket>>>{XmlStream{Socket{executor}}}
: XmlStream{boost::asio::ssl::stream<Socket>(executor, ctx)});
}
} // namespace larra::xmpp::client

View file

@ -4,8 +4,7 @@
#include <larra/encryption.hpp>
#include <larra/jid.hpp>
#include <larra/raw_xml_stream.hpp>
#include <utility>
#include <larra/xml_stream.hpp>
namespace larra::xmpp::client::features {
/*

View file

@ -17,6 +17,7 @@ struct SaslMechanisms {
};
struct StreamFeatures {
static constexpr auto kDefaultName = "stream:features";
struct StartTlsType {
Required required;
[[nodiscard]] constexpr auto Required(Required required) const -> StartTlsType {

View file

@ -13,6 +13,7 @@ struct BareJid {
[[nodiscard]] static auto Parse(std::string_view jid) -> BareJid;
friend auto ToString(const BareJid& jid) -> std::string;
constexpr auto operator==(const BareJid&) const -> bool = default;
template <typename Self>
[[nodiscard]] constexpr auto Username(this Self&& self, std::string username) -> BareJid {
return utils::FieldSetHelper::With<"username", BareJid>(std::forward<Self>(self), std::move(username));
@ -30,6 +31,8 @@ struct BareResourceJid {
[[nodiscard]] static auto Parse(std::string_view jid) -> BareResourceJid;
friend auto ToString(const BareResourceJid& jid) -> std::string;
constexpr auto operator==(const BareResourceJid&) const -> bool = default;
template <typename Self>
[[nodiscard]] constexpr auto Server(this Self&& self, std::string server) -> BareResourceJid {
return utils::FieldSetHelper::With<"server", BareResourceJid>(std::forward<Self>(self), std::move(server));
@ -48,6 +51,8 @@ struct FullJid {
[[nodiscard]] static auto Parse(std::string_view jid) -> FullJid;
friend auto ToString(const FullJid& jid) -> std::string;
constexpr auto operator==(const FullJid&) const -> bool = default;
template <typename Self>
[[nodiscard]] constexpr auto Username(this Self&& self, std::string username) -> FullJid {
return utils::FieldSetHelper::With<"username", FullJid>(std::forward<Self>(self), std::move(username));

View file

@ -3,7 +3,7 @@
#include <boost/asio/ssl.hpp>
#include <boost/asio/write.hpp>
#include <larra/raw_xml_stream.hpp>
#include <larra/xml_stream.hpp>
#include <print>
#include <ranges>

View file

@ -2,6 +2,7 @@
#include <libxml++/libxml++.h>
#include <spdlog/spdlog.h>
#include <nameof.hpp>
#include <string>
#include <utempl/utils.hpp>
@ -80,23 +81,35 @@ struct SerializationBase {
return false;
}
}();
static constexpr auto StartCheck(xmlpp::Element* element) -> bool {
if constexpr(requires {
{ T::StartCheck(element) } -> std::same_as<bool>;
}) {
return T::StartCheck(element);
} else {
return element && element->get_name() == kDefaultName;
}
};
};
template <typename T>
struct Serialization : SerializationBase<T> {
[[nodiscard]] static constexpr auto Parse(xmlpp::Element* element) -> T {
if(!Serialization::StartCheck(element)) {
throw std::runtime_error("StartCheck failed");
}
return T::Parse(element);
}
[[nodiscard]] static constexpr auto TryParse(xmlpp::Element* element) -> std::optional<T> {
if constexpr(HasTryParse<T>) {
return T::TryParse(element);
return Serialization::StartCheck(element) ? T::TryParse(element) : std::nullopt;
} else {
try {
return T::Parse(element);
return Serialization::StartCheck(element) ? std::optional{T::Parse(element)} : std::nullopt;
} catch(const std::exception& e) {
SPDLOG_WARN("Failed Parse but no TryParse found: {}", e.what());
SPDLOG_WARN("Type {}: Failed Parse but no TryParse found: {}", e.what(), nameof::nameof_type<T>());
} catch(...) {
SPDLOG_WARN("Failed Parse but no TryParse found");
SPDLOG_WARN("Type {}: Failed Parse but no TryParse found", nameof::nameof_type<T>());
}
return std::nullopt;
}
@ -122,9 +135,17 @@ struct Serialization<std::optional<T>> : SerializationBase<T> {
template <typename... Ts>
struct Serialization<std::variant<Ts...>> : SerializationBase<> {
static constexpr auto StartCheck(xmlpp::Element* element) {
return true;
}
[[nodiscard]] static constexpr auto TryParse(xmlpp::Element* element) -> std::optional<std::variant<Ts...>> {
return utempl::FirstOf(utempl::Tuple{[&] {
return Serialization<Ts>::TryParse(element);
return utempl::FirstOf(utempl::Tuple{[&] -> std::optional<Ts> {
if(Serialization<Ts>::StartCheck(element)) {
return Serialization<Ts>::TryParse(element);
} else {
SPDLOG_DEBUG("StartCheck failed for type {}", nameof::nameof_type<Ts>());
return std::nullopt;
}
}...},
std::optional<std::variant<Ts...>>{});
}
@ -140,4 +161,19 @@ struct Serialization<std::variant<Ts...>> : SerializationBase<> {
}
};
template <>
struct Serialization<std::monostate> : SerializationBase<> {
static constexpr auto StartCheck(xmlpp::Element*) -> bool {
return true;
};
[[nodiscard]] static constexpr auto TryParse(xmlpp::Element*) -> std::optional<std::monostate> {
return std::monostate{};
}
[[nodiscard]] static constexpr auto Parse(xmlpp::Element*) -> std::monostate {
return {};
}
static constexpr auto Serialize(xmlpp::Element*, const std::monostate&) -> void {
}
};
} // namespace larra::xmpp

View file

@ -48,12 +48,15 @@ constexpr auto ToKebabCaseName() -> std::string_view {
namespace error::stream {
struct BaseError : std::exception {};
// DO NOT MOVE TO ANOTHER NAMESPACE(where no heirs). VIA friend A FUNCTION IS ADDED THAT VIA ADL WILL BE SEARCHED FOR HEIRS
// C++20 modules very unstable in clangd :(
template <typename T>
struct BaseError : std::exception {
static constexpr auto kDefaultName = "error";
static constexpr auto kDefaultNamespace = "stream";
struct ErrorImpl : BaseError {
static constexpr auto kDefaultName = "stream:error";
static inline const auto kKebabCaseName = static_cast<std::string>(impl::ToKebabCaseName<T>());
static inline const std::string kErrorContentNamespace = "urn:ietf:params:xml:ns:xmpp-streams";
static constexpr auto kErrorMessage = [] -> std::string_view {
static constexpr auto str = [] {
return std::array{std::string_view{"Stream Error: "}, nameof::nameof_short_type<T>(), std::string_view{"\0", 1}} | std::views::join;
@ -76,38 +79,38 @@ struct BaseError : std::exception {
}
friend constexpr auto operator<<(xmlpp::Element* element, const T& obj) -> void {
auto node = element->add_child_element(kKebabCaseName);
node->set_namespace_declaration(kErrorContentNamespace);
node->set_namespace_declaration("urn:ietf:params:xml:ns:xmpp-streams");
}
[[nodiscard]] constexpr auto what() const noexcept -> const char* override {
return kErrorMessage.data();
}
};
struct BadFormat : BaseError<BadFormat> {};
struct BadNamespacePrefix : BaseError<BadNamespacePrefix> {};
struct Conflict : BaseError<Conflict> {};
struct ConnectionTimeout : BaseError<ConnectionTimeout> {};
struct HostGone : BaseError<HostGone> {};
struct HostUnknown : BaseError<HostUnknown> {};
struct ImproperAdressing : BaseError<ImproperAdressing> {};
struct InternalServerError : BaseError<InternalServerError> {};
struct InvalidForm : BaseError<InvalidForm> {};
struct InvalidNamespace : BaseError<InvalidNamespace> {};
struct InvalidXml : BaseError<InvalidXml> {};
struct NotAuthorized : BaseError<NotAuthorized> {};
struct NotWellFormed : BaseError<NotWellFormed> {};
struct PolicyViolation : BaseError<PolicyViolation> {};
struct RemoteConnectionFailed : BaseError<RemoteConnectionFailed> {};
struct Reset : BaseError<Reset> {};
struct ResourceConstraint : BaseError<ResourceConstraint> {};
struct RestrictedXml : BaseError<RestrictedXml> {};
struct SeeOtherHost : BaseError<SeeOtherHost> {};
struct SystemShutdown : BaseError<SystemShutdown> {};
struct UndefinedCondition : BaseError<UndefinedCondition> {};
struct UnsupportedEncoding : BaseError<UnsupportedEncoding> {};
struct UnsupportedFeature : BaseError<UnsupportedFeature> {};
struct UnsupportedStanzaType : BaseError<UnsupportedStanzaType> {};
struct UnsupportedVersion : BaseError<UnsupportedVersion> {};
struct BadFormat : ErrorImpl<BadFormat> {};
struct BadNamespacePrefix : ErrorImpl<BadNamespacePrefix> {};
struct Conflict : ErrorImpl<Conflict> {};
struct ConnectionTimeout : ErrorImpl<ConnectionTimeout> {};
struct HostGone : ErrorImpl<HostGone> {};
struct HostUnknown : ErrorImpl<HostUnknown> {};
struct ImproperAdressing : ErrorImpl<ImproperAdressing> {};
struct InternalServerError : ErrorImpl<InternalServerError> {};
struct InvalidForm : ErrorImpl<InvalidForm> {};
struct InvalidNamespace : ErrorImpl<InvalidNamespace> {};
struct InvalidXml : ErrorImpl<InvalidXml> {};
struct NotAuthorized : ErrorImpl<NotAuthorized> {};
struct NotWellFormed : ErrorImpl<NotWellFormed> {};
struct PolicyViolation : ErrorImpl<PolicyViolation> {};
struct RemoteConnectionFailed : ErrorImpl<RemoteConnectionFailed> {};
struct Reset : ErrorImpl<Reset> {};
struct ResourceConstraint : ErrorImpl<ResourceConstraint> {};
struct RestrictedXml : ErrorImpl<RestrictedXml> {};
struct SeeOtherHost : ErrorImpl<SeeOtherHost> {};
struct SystemShutdown : ErrorImpl<SystemShutdown> {};
struct UndefinedCondition : ErrorImpl<UndefinedCondition> {};
struct UnsupportedEncoding : ErrorImpl<UnsupportedEncoding> {};
struct UnsupportedFeature : ErrorImpl<UnsupportedFeature> {};
struct UnsupportedStanzaType : ErrorImpl<UnsupportedStanzaType> {};
struct UnsupportedVersion : ErrorImpl<UnsupportedVersion> {};
} // namespace error::stream

View file

@ -10,6 +10,7 @@
#include <boost/asio/write.hpp>
#include <boost/system/result.hpp>
#include <larra/serialization.hpp>
#include <larra/stream_error.hpp>
#include <larra/utils.hpp>
#include <stack>
#include <utempl/utils.hpp>
@ -89,8 +90,8 @@ auto IsExtraContentAtTheDocument(const _xmlError* error) -> bool;
} // namespace impl
template <typename Stream, typename BufferType = boost::asio::streambuf>
struct RawXmlStream : Stream {
constexpr RawXmlStream(Stream stream, std::unique_ptr<BufferType> buff = std::make_unique<BufferType>()) :
struct XmlStream : Stream {
constexpr XmlStream(Stream stream, std::unique_ptr<BufferType> buff = std::make_unique<BufferType>()) :
Stream(std::forward<Stream>(stream)), streambuf(std::move(buff)) {};
using Stream::Stream;
auto next_layer() -> Stream& {
@ -101,7 +102,7 @@ struct RawXmlStream : Stream {
return *this;
}
auto ReadOne(auto& socket) -> boost::asio::awaitable<std::unique_ptr<xmlpp::Document>> {
auto ReadOneRaw(auto& socket) -> boost::asio::awaitable<std::unique_ptr<xmlpp::Document>> {
auto doc = std::make_unique<xmlpp::Document>();
impl::Parser parser(*doc);
for(;;) {
@ -144,10 +145,21 @@ struct RawXmlStream : Stream {
co_return doc;
}
}
auto ReadOne() -> boost::asio::awaitable<std::unique_ptr<xmlpp::Document>> {
co_return co_await this->ReadOne(this->next_layer());
template <typename T>
auto ReadOneRaw(auto& stream) -> boost::asio::awaitable<T> {
auto doc = co_await this->ReadOneRaw(stream);
co_return Serialization<T>::Parse(doc->get_root_node());
}
inline auto Read(auto& socket) -> boost::asio::awaitable<std::unique_ptr<xmlpp::Document>> {
template <typename T>
auto ReadOneRaw(auto& stream) -> boost::asio::awaitable<T>
requires requires(std::unique_ptr<xmlpp::Document> ptr) {
{ Serialization<T>::Parse(std::move(ptr)) } -> std::same_as<T>;
}
{
co_return Serialization<T>::Parse(co_await this->ReadOneRaw(stream));
}
inline auto ReadRaw(auto& socket) -> boost::asio::awaitable<std::unique_ptr<xmlpp::Document>> {
auto doc = std::make_unique<xmlpp::Document>(); // Not movable :(
impl::Parser parser(*doc);
std::size_t lines = 1;
@ -186,7 +198,7 @@ struct RawXmlStream : Stream {
if(!error) {
auto linesAdd = impl::CountLines(impl::BufferToStringView(buff, n));
SPDLOG_DEBUG("Readed {} bytes for RawXmlStream with {} lines", n, linesAdd);
SPDLOG_DEBUG("Readed {} bytes for XmlStream with {} lines", n, linesAdd);
lines += linesAdd;
if(linesAdd == 0) {
@ -209,27 +221,56 @@ struct RawXmlStream : Stream {
co_return doc;
}
}
auto Read() -> boost::asio::awaitable<std::unique_ptr<xmlpp::Document>> {
co_return co_await this->Read(this->next_layer());
}
template <typename T>
auto Read(auto& stream) -> boost::asio::awaitable<T> {
auto doc = co_await this->Read(stream);
auto ReadRaw(auto& stream) -> boost::asio::awaitable<T> {
auto doc = co_await this->ReadRaw(stream);
co_return Serialization<T>::Parse(doc->get_root_node());
}
template <typename T>
auto Read(auto& stream) -> boost::asio::awaitable<T>
auto ReadRaw(auto& stream) -> boost::asio::awaitable<T>
requires requires(std::unique_ptr<xmlpp::Document> ptr) {
{ Serialization<T>::Parse(std::move(ptr)) } -> std::same_as<T>;
}
{
co_return Serialization<T>::Parse(co_await this->Read(stream));
co_return Serialization<T>::Parse(co_await this->ReadRaw(stream));
}
private:
template <typename T>
auto Read() -> boost::asio::awaitable<T> {
co_return co_await this->template Read<T>(this->next_layer());
auto ReadImpl(boost::asio::awaitable<std::variant<StreamError, T>> awaitable) -> boost::asio::awaitable<T> {
co_return std::visit(utempl::Overloaded(
[](T value) -> T {
return std::move(value);
},
[](StreamError error) -> T {
std::visit(
[](auto error) {
throw error;
},
error);
std::unreachable();
}),
co_await std::move(awaitable));
}
public:
template <typename T = std::monostate>
auto Read(auto& stream) {
return this->ReadImpl(this->ReadRaw<std::variant<StreamError, T>>(stream));
}
template <typename T = std::monostate>
auto Read() {
return this->Read<T>(this->next_layer());
}
template <typename T = std::monostate>
auto ReadOne(auto& stream) {
return this->ReadImpl(this->ReadOneRaw<std::variant<StreamError, T>>(stream));
}
template <typename T = std::monostate>
auto ReadOne() {
return this->ReadOne<T>(this->next_layer());
}
auto Send(xmlpp::Document& doc, auto& stream, bool bAddXmlDecl, bool removeEnd) const -> boost::asio::awaitable<void> {
@ -270,7 +311,7 @@ struct RawXmlStream : Stream {
co_await this->Send(xso, this->next_layer());
}
RawXmlStream(RawXmlStream&& other) = default;
XmlStream(XmlStream&& other) = default;
std::unique_ptr<BufferType> streambuf; // Not movable :(
};

View file

@ -1,8 +1,8 @@
#include <libxml/parser.h>
#include <larra/impl/public_cast.hpp>
#include <larra/raw_xml_stream.hpp>
#include <larra/utils.hpp>
#include <larra/xml_stream.hpp>
#include <ranges>
#include <span>

View file

@ -5,8 +5,8 @@
#include <boost/asio/ip/tcp.hpp>
#include <larra/features.hpp>
#include <larra/impl/mock_socket.hpp>
#include <larra/raw_xml_stream.hpp>
#include <larra/stream.hpp>
#include <larra/xml_stream.hpp>
#include <utempl/utils.hpp>
namespace larra::xmpp {
@ -23,7 +23,7 @@ constexpr std::string_view kDoc3 =
constexpr std::string_view kDoc4 =
R"(<?xml version='1.0'?><stream:stream id='68321991947053239' version='1.0' xml:lang='en' xmlns:stream='http://etherx.jabber.org/streams' to='test1@localhost' from='localhost' xmlns='jabber:client'>)";
TEST(RawXmlStream, ReadByOne) {
TEST(XmlStream, ReadByOne) {
boost::asio::io_context context;
bool error{};
@ -31,14 +31,14 @@ TEST(RawXmlStream, ReadByOne) {
context,
// NOLINTNEXTLINE: Safe
[&] -> boost::asio::awaitable<void> {
RawXmlStream stream{impl::MockSocket{context.get_executor(), 1}};
XmlStream stream{impl::MockSocket{context.get_executor(), 1}};
stream.AddReceivedData(kDoc);
try {
auto doc = co_await stream.Read();
auto doc = co_await stream.ReadRaw(stream.next_layer());
auto node = doc->get_root_node();
EXPECT_EQ(node->get_name(), std::string_view{"doc"});
EXPECT_FALSE(node->has_child_text());
auto doc2 = co_await stream.Read();
auto doc2 = co_await stream.ReadRaw(stream.next_layer());
auto node2 = doc2->get_root_node();
EXPECT_EQ(node2->get_name(), std::string_view{"doc2"});
EXPECT_FALSE(node2->has_child_text());
@ -54,20 +54,20 @@ TEST(RawXmlStream, ReadByOne) {
EXPECT_FALSE(error);
}
TEST(RawXmlStream, ReadAll) {
TEST(XmlStream, ReadAll) {
boost::asio::io_context context;
bool error{};
boost::asio::co_spawn(
context, // NOLINTNEXTLINE: Safe
[&] -> boost::asio::awaitable<void> {
RawXmlStream stream{impl::MockSocket{context.get_executor(), kDoc.size()}};
XmlStream stream{impl::MockSocket{context.get_executor(), kDoc.size()}};
stream.AddReceivedData(kDoc);
try {
auto doc = co_await stream.Read();
auto doc = co_await stream.ReadRaw(stream.next_layer());
auto node = doc->get_root_node();
EXPECT_EQ(node->get_name(), std::string_view{"doc"});
EXPECT_FALSE(node->has_child_text());
auto doc2 = co_await stream.Read();
auto doc2 = co_await stream.ReadRaw(stream.next_layer());
auto node2 = doc2->get_root_node();
EXPECT_EQ(node2->get_name(), std::string_view{"doc2"});
EXPECT_FALSE(node2->has_child_text());
@ -82,20 +82,20 @@ TEST(RawXmlStream, ReadAll) {
EXPECT_FALSE(error);
}
TEST(RawXmlStream, ReadAllWithEnd) {
TEST(XmlStream, ReadAllWithEnd) {
boost::asio::io_context context;
bool error{};
boost::asio::co_spawn(
context, // NOLINTNEXTLINE: Safe
[&] -> boost::asio::awaitable<void> {
RawXmlStream stream{impl::MockSocket{context.get_executor(), kDoc2.size()}};
XmlStream stream{impl::MockSocket{context.get_executor(), kDoc2.size()}};
stream.AddReceivedData(kDoc2);
try {
auto doc = co_await stream.Read();
auto doc = co_await stream.ReadRaw(stream.next_layer());
auto node = doc->get_root_node();
EXPECT_EQ(node->get_name(), std::string_view{"doc"});
EXPECT_FALSE(node->has_child_text());
auto doc2 = co_await stream.Read();
auto doc2 = co_await stream.ReadRaw(stream.next_layer());
auto node2 = doc2->get_root_node();
EXPECT_EQ(node2->get_name(), std::string_view{"doc2"});
EXPECT_FALSE(node2->has_child_text());
@ -110,13 +110,13 @@ TEST(RawXmlStream, ReadAllWithEnd) {
EXPECT_FALSE(error);
}
TEST(RawXmlStream, ReadFeatures) {
TEST(XmlStream, ReadFeatures) {
boost::asio::io_context context;
bool error{};
boost::asio::co_spawn(
context, // NOLINTNEXTLINE: Safe
[&] -> boost::asio::awaitable<void> {
RawXmlStream stream{impl::MockSocket{context.get_executor(), kDoc3.size()}};
XmlStream stream{impl::MockSocket{context.get_executor(), kDoc3.size()}};
stream.AddReceivedData(kDoc3);
try {
auto features = co_await stream.template Read<StreamFeatures>();
@ -139,13 +139,13 @@ struct SomeStruct {
}
};
TEST(RawXmlStream, Write) {
TEST(XmlStream, Write) {
boost::asio::io_context context;
bool error{};
boost::asio::co_spawn(
context, // NOLINTNEXTLINE: Safe
[&] -> boost::asio::awaitable<void> {
RawXmlStream stream1{impl::MockSocket{context.get_executor()}};
XmlStream stream1{impl::MockSocket{context.get_executor()}};
auto stream = std::move(stream1);
try {
co_await stream.Send(SomeStruct{});
@ -160,23 +160,21 @@ TEST(RawXmlStream, Write) {
EXPECT_FALSE(error);
}
TEST(RawXmlStream, ReadOneByOne) {
TEST(XmlStream, ReadOneByOne) {
boost::asio::io_context context;
bool error{};
boost::asio::co_spawn(
context,
// NOLINTNEXTLINE: Safe
[&] -> boost::asio::awaitable<void> {
RawXmlStream stream{impl::MockSocket{context.get_executor(), 1}};
XmlStream stream{impl::MockSocket{context.get_executor(), 1}};
stream.AddReceivedData(kDoc4);
try {
auto doc = (co_await stream.ReadOne(), co_await stream.ReadOne());
auto node = doc->get_root_node();
EXPECT_TRUE(node);
if(!node) {
co_return;
}
auto stream = ServerToUserStream::Parse(node);
ServerToUserStream value = (co_await stream.ReadOne(), co_await stream.ReadOne<ServerToUserStream>());
EXPECT_EQ(value.id, "68321991947053239");
EXPECT_EQ(value.version, "1.0");
EXPECT_EQ(value.to, BareJid::Parse("test1@localhost"));
EXPECT_EQ(value.from, "localhost");
} catch(const std::exception& err) {
SPDLOG_ERROR("{}", err.what());
error = true;