From 0636e1c2347aae2e35f1b1d9f7b3666a8fd7f84b Mon Sep 17 00:00:00 2001 From: sha512sum Date: Tue, 8 Oct 2024 08:36:08 +0000 Subject: [PATCH] Add stream errors handling --- CMakeLists.txt | 9 +- library/include/larra/client/client.hpp | 128 ++++++------------ .../client/xmpp_client_stream_features.hpp | 3 +- library/include/larra/features.hpp | 1 + library/include/larra/jid.hpp | 5 + library/include/larra/printer_stream.hpp | 2 +- library/include/larra/serialization.hpp | 48 ++++++- library/include/larra/stream_error.hpp | 63 +++++---- .../{raw_xml_stream.hpp => xml_stream.hpp} | 75 +++++++--- .../{raw_xml_stream.cpp => xml_stream.cpp} | 2 +- tests/{raw_xml_stream.cpp => xml_stream.cpp} | 50 ++++--- 11 files changed, 216 insertions(+), 170 deletions(-) rename library/include/larra/{raw_xml_stream.hpp => xml_stream.hpp} (78%) rename library/src/{raw_xml_stream.cpp => xml_stream.cpp} (99%) rename tests/{raw_xml_stream.cpp => xml_stream.cpp} (78%) diff --git a/CMakeLists.txt b/CMakeLists.txt index d37d3c4..e96e89f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -55,7 +55,12 @@ CPMAddPackage( CPMAddPackage("gh:zeux/pugixml@1.14") CPMAddPackage("gh:fmtlib/fmt#10.2.1") -CPMAddPackage("gh:Neargye/nameof@0.10.4") +CPMAddPackage(NAME nameof + VERSION 0.10.4 + GIT_REPOSITORY "https://github.com/Neargye/nameof.git" + EXCLUDE_FROM_ALL ON + OPTIONS "NAMEOF_OPT_INSTALL ON" + ) CPMAddPackage( NAME spdlog @@ -167,7 +172,7 @@ if(TARGET Boost::pfr) OpenSSL::SSL nameof::nameof OpenSSL::Crypto spdlog xmlplusplus ${LIBXML2_LIBRARIES}) else() - find_package(Boost 1.85.0 REQUIRED) + find_package(Boost 1.85.0 COMPONENTS serialization REQUIRED) target_link_libraries(larra_xmpp PUBLIC utempl::utempl ${Boost_LIBRARIES} pugixml::pugixml OpenSSL::SSL nameof::nameof diff --git a/library/include/larra/client/client.hpp b/library/include/larra/client/client.hpp index 62b2467..666537f 100644 --- a/library/include/larra/client/client.hpp +++ b/library/include/larra/client/client.hpp @@ -11,15 +11,16 @@ #include #include #include +#include #include +#include +#include #include #include -#include #include #include +#include #include - -#include "larra/client/xmpp_client_stream_features.hpp" namespace larra::xmpp { constexpr auto kDefaultXmppPort = 5222; @@ -33,15 +34,15 @@ namespace views = std::views; template struct Client { - constexpr Client(BareJid jid, RawXmlStream connection) : jid(std::move(jid)), connection(std::move(connection)) {}; + constexpr Client(BareJid jid, XmlStream connection) : jid(std::move(jid)), connection(std::move(connection)) {}; template Token = boost::asio::use_awaitable_t<>> constexpr auto Close(Token token = {}) { this->active = false; return boost::asio::async_initiate( - [](Handler&& h, RawXmlStream connection) { // NOLINT + [](Handler&& h, XmlStream connection) { // NOLINT boost::asio::co_spawn( connection.next_layer().get_executor(), - [](auto h, RawXmlStream connection) -> boost::asio::awaitable { + [](auto h, XmlStream connection) -> boost::asio::awaitable { co_await boost::asio::async_write( connection.next_layer(), boost::asio::buffer(""), boost::asio::use_awaitable); std::string response; @@ -70,7 +71,7 @@ struct Client { private: bool active = true; - RawXmlStream connection; + XmlStream connection; BareJid jid; }; @@ -158,7 +159,7 @@ struct ClientCreateVisitor { const Options& options; template - auto Auth(PlainUserAccount account, RawXmlStream& stream, ServerToUserStream streamHeader, StreamFeatures features) + auto Auth(PlainUserAccount account, XmlStream& stream, ServerToUserStream streamHeader, StreamFeatures features) -> boost::asio::awaitable { SPDLOG_DEBUG("Start Plain Auth"); if(!std::ranges::contains(features.saslMechanisms.mechanisms, "PLAIN")) { @@ -166,11 +167,17 @@ struct ClientCreateVisitor { } const features::PlainAuthData data{.username = account.jid.username, .password = account.password}; co_await stream.Send(data); - std::ignore = co_await stream.Read(); + sasl::Response response = co_await stream.template Read(); + std::visit(utempl::Overloaded( + [](auto error) { + throw std::move(error); + }, + [](sasl::Success) {}), + response); } template - auto ScramAuth(std::string methodName, EncryptionUserAccount account, RawXmlStream& stream, Tag tag) + auto ScramAuth(std::string methodName, EncryptionUserAccount account, XmlStream& stream, Tag tag) -> boost::asio::awaitable { SPDLOG_DEBUG("Start Scram Auth using '{}'", methodName); const auto nonce = GenerateNonce(); @@ -191,23 +198,18 @@ struct ClientCreateVisitor { .iterations = challenge.iterations, .tag = tag}; co_await stream.Send(challengeResponse); - std::unique_ptr doc = co_await stream.Read(); - auto root = doc->get_root_node(); - if(!root || root->get_name() == "failure") { - if(auto textNode = root->get_first_child("text")) { - if(auto text = dynamic_cast(textNode)) { - if(auto childText = text->get_first_child_text()) { - throw std::runtime_error(std::format("Auth failed: {}", childText->get_content())); - } - } - } - throw std::runtime_error("Auth failed"); - } + sasl::Response response = co_await stream.template Read(); + std::visit(utempl::Overloaded( + [](auto error) { + throw std::move(error); + }, + [](sasl::Success) {}), + response); SPDLOG_DEBUG("Success auth for JID {}", ToString(account.jid)); } template - auto Auth(EncryptionRequiredUserAccount account, RawXmlStream& stream, ServerToUserStream streamHeader, StreamFeatures features) + auto Auth(EncryptionRequiredUserAccount account, XmlStream& stream, ServerToUserStream streamHeader, StreamFeatures features) -> boost::asio::awaitable { if(std::ranges::contains(features.saslMechanisms.mechanisms, "SCRAM-SHA-512")) { co_return co_await ScramAuth("SCRAM-SHA-512", std::move(account), stream, sha512sum::EncryptionTag{}); @@ -222,7 +224,7 @@ struct ClientCreateVisitor { } template - auto Auth(EncryptionUserAccount account, RawXmlStream& stream, ServerToUserStream streamHeader, StreamFeatures features) + auto Auth(EncryptionUserAccount account, XmlStream& stream, ServerToUserStream streamHeader, StreamFeatures features) -> boost::asio::awaitable { Contains(features.saslMechanisms.mechanisms, "SCRAM-SHA-1", "SCRAM-SHA-256", "SCRAM-SHA-512") ? co_await this->Auth(EncryptionRequiredUserAccount{std::move(account)}, stream, std::move(streamHeader), std::move(features)) @@ -230,7 +232,7 @@ struct ClientCreateVisitor { } template - auto Auth(RawXmlStream& stream, ServerToUserStream streamHeader, StreamFeatures features) -> boost::asio::awaitable { + auto Auth(XmlStream& stream, ServerToUserStream streamHeader, StreamFeatures features) -> boost::asio::awaitable { co_return co_await std::visit( [&](auto& account) -> boost::asio::awaitable { return this->Auth(std::move(account), stream, std::move(streamHeader), std::move(features)); @@ -251,17 +253,16 @@ struct ClientCreateVisitor { } template - auto ProcessTls(RawXmlStream>& stream) -> boost::asio::awaitable { + auto ProcessTls(XmlStream>& stream) -> boost::asio::awaitable { const StartTlsRequest request; co_await stream.Send(request); - std::unique_ptr doc = co_await stream.Read(); - if(auto node = doc->get_root_node()) { - if(node->get_name() == "proceed") { - goto proceed; // NOLINT - } - throw StartTlsNegotiationError{"Failure XMPP"}; - } - proceed: + starttls::Response response = co_await stream.template Read(); + std::visit(utempl::Overloaded( + [](starttls::Failure error) { + throw error; + }, + [](starttls::Success) {}), + response); auto& socket = stream.next_layer(); SSL_set_tlsext_host_name(socket.native_handle(), this->account.Jid().server.c_str()); try { @@ -270,57 +271,14 @@ struct ClientCreateVisitor { throw StartTlsNegotiationError{e.what()}; } } - static constexpr auto GetEnumerated(boost::asio::streambuf& streambuf) { - return std::views::zip(std::views::iota(std::size_t{}, streambuf.size()), ::larra::xmpp::impl::GetCharsRangeFromBuf(streambuf)); - } - using EnumeratedT = decltype(std::views::zip(std::views::iota(std::size_t{}, std::size_t{}), - ::larra::xmpp::impl::GetCharsRangeFromBuf(std::declval()))); - struct Splitter { - EnumeratedT range; - struct Sentinel { - std::ranges::sentinel_t end; - }; - struct Iterator { - std::ranges::iterator_t it; - std::ranges::sentinel_t end; - friend constexpr auto operator==(const Iterator& self, const Sentinel& it) -> bool { - return self.it == it.end; - } - auto operator++() -> Iterator& { - if(this->it == this->end) { - return *this; - } - this->it = std::ranges::find(this->it, this->end, '>', [](auto v) { - auto [_, c] = v; - return c; - }); - if(this->it != this->end) { - ++it; - } - - return *this; - }; - auto operator*() const { - return *it; - } - }; - auto begin() -> Iterator { - return Iterator{.it = std::ranges::begin(this->range), .end = std::ranges::end(this->range)}; - } - - auto end() -> Sentinel { - return {.end = std::ranges::end(this->range)}; - } - }; template - auto ReadStartStream(RawXmlStream& stream) -> boost::asio::awaitable { - auto doc = (co_await stream.ReadOne(), co_await stream.ReadOne()); - co_return ServerToUserStream::Parse(doc->get_root_node()); + auto ReadStartStream(XmlStream& stream) -> boost::asio::awaitable { + co_return (co_await stream.ReadOne(), co_await stream.template ReadOne()); } template - inline auto operator()(RawXmlStream stream) + inline auto operator()(XmlStream stream) -> boost::asio::awaitable, Client>>> { co_await this->Connect(stream.next_layer(), co_await this->Resolve()); @@ -339,7 +297,7 @@ struct ClientCreateVisitor { } template - inline auto operator()(RawXmlStream> stream) + inline auto operator()(XmlStream> stream) -> boost::asio::awaitable, Client>>> { auto& socket = stream.next_layer(); co_await this->Connect(socket.next_layer(), co_await this->Resolve()); @@ -353,14 +311,14 @@ struct ClientCreateVisitor { throw std::runtime_error("XMPP server not support STARTTLS"); } socket.next_layer().close(); - co_return co_await (*this)(RawXmlStream{Socket{std::move(socket.next_layer())}, std::move(stream.streambuf)}); + co_return co_await (*this)(XmlStream{Socket{std::move(socket.next_layer())}, std::move(stream.streambuf)}); } co_await this->ProcessTls(stream); co_await stream.Send(UserStream{.from = account.Jid(), .to = account.Jid().server}, socket.next_layer()); auto newStreamHeader = co_await this->ReadStartStream(stream); auto newFeatures = co_await stream.template Read(); co_await this->Auth(stream, std::move(newStreamHeader), std::move(newFeatures)); - co_return Client{std::move(this->account).Jid(), RawXmlStream{std::move(socket)}}; + co_return Client{std::move(this->account).Jid(), XmlStream{std::move(socket)}}; } }; @@ -374,8 +332,8 @@ inline auto CreateClient(UserAccount account, Options options = {}) co_return co_await std::visit( impl::ClientCreateVisitor{.account = std::move(account), .options = options}, options.useTls == Options::kNever - ? std::variant, RawXmlStream>>{RawXmlStream{Socket{executor}}} - : RawXmlStream{boost::asio::ssl::stream(executor, ctx)}); + ? std::variant, XmlStream>>{XmlStream{Socket{executor}}} + : XmlStream{boost::asio::ssl::stream(executor, ctx)}); } } // namespace larra::xmpp::client diff --git a/library/include/larra/client/xmpp_client_stream_features.hpp b/library/include/larra/client/xmpp_client_stream_features.hpp index 6dc8919..b2621c7 100644 --- a/library/include/larra/client/xmpp_client_stream_features.hpp +++ b/library/include/larra/client/xmpp_client_stream_features.hpp @@ -4,8 +4,7 @@ #include #include -#include -#include +#include namespace larra::xmpp::client::features { /* diff --git a/library/include/larra/features.hpp b/library/include/larra/features.hpp index 0e7cad0..f4ffc2e 100644 --- a/library/include/larra/features.hpp +++ b/library/include/larra/features.hpp @@ -17,6 +17,7 @@ struct SaslMechanisms { }; struct StreamFeatures { + static constexpr auto kDefaultName = "stream:features"; struct StartTlsType { Required required; [[nodiscard]] constexpr auto Required(Required required) const -> StartTlsType { diff --git a/library/include/larra/jid.hpp b/library/include/larra/jid.hpp index a61d10f..567f196 100644 --- a/library/include/larra/jid.hpp +++ b/library/include/larra/jid.hpp @@ -13,6 +13,7 @@ struct BareJid { [[nodiscard]] static auto Parse(std::string_view jid) -> BareJid; friend auto ToString(const BareJid& jid) -> std::string; + constexpr auto operator==(const BareJid&) const -> bool = default; template [[nodiscard]] constexpr auto Username(this Self&& self, std::string username) -> BareJid { return utils::FieldSetHelper::With<"username", BareJid>(std::forward(self), std::move(username)); @@ -30,6 +31,8 @@ struct BareResourceJid { [[nodiscard]] static auto Parse(std::string_view jid) -> BareResourceJid; friend auto ToString(const BareResourceJid& jid) -> std::string; + constexpr auto operator==(const BareResourceJid&) const -> bool = default; + template [[nodiscard]] constexpr auto Server(this Self&& self, std::string server) -> BareResourceJid { return utils::FieldSetHelper::With<"server", BareResourceJid>(std::forward(self), std::move(server)); @@ -48,6 +51,8 @@ struct FullJid { [[nodiscard]] static auto Parse(std::string_view jid) -> FullJid; friend auto ToString(const FullJid& jid) -> std::string; + constexpr auto operator==(const FullJid&) const -> bool = default; + template [[nodiscard]] constexpr auto Username(this Self&& self, std::string username) -> FullJid { return utils::FieldSetHelper::With<"username", FullJid>(std::forward(self), std::move(username)); diff --git a/library/include/larra/printer_stream.hpp b/library/include/larra/printer_stream.hpp index f272ffb..5e3f8b6 100644 --- a/library/include/larra/printer_stream.hpp +++ b/library/include/larra/printer_stream.hpp @@ -3,7 +3,7 @@ #include #include -#include +#include #include #include diff --git a/library/include/larra/serialization.hpp b/library/include/larra/serialization.hpp index ae2fc4e..707cc05 100644 --- a/library/include/larra/serialization.hpp +++ b/library/include/larra/serialization.hpp @@ -2,6 +2,7 @@ #include #include +#include #include #include @@ -80,23 +81,35 @@ struct SerializationBase { return false; } }(); + static constexpr auto StartCheck(xmlpp::Element* element) -> bool { + if constexpr(requires { + { T::StartCheck(element) } -> std::same_as; + }) { + return T::StartCheck(element); + } else { + return element && element->get_name() == kDefaultName; + } + }; }; template struct Serialization : SerializationBase { [[nodiscard]] static constexpr auto Parse(xmlpp::Element* element) -> T { + if(!Serialization::StartCheck(element)) { + throw std::runtime_error("StartCheck failed"); + } return T::Parse(element); } [[nodiscard]] static constexpr auto TryParse(xmlpp::Element* element) -> std::optional { if constexpr(HasTryParse) { - return T::TryParse(element); + return Serialization::StartCheck(element) ? T::TryParse(element) : std::nullopt; } else { try { - return T::Parse(element); + return Serialization::StartCheck(element) ? std::optional{T::Parse(element)} : std::nullopt; } catch(const std::exception& e) { - SPDLOG_WARN("Failed Parse but no TryParse found: {}", e.what()); + SPDLOG_WARN("Type {}: Failed Parse but no TryParse found: {}", e.what(), nameof::nameof_type()); } catch(...) { - SPDLOG_WARN("Failed Parse but no TryParse found"); + SPDLOG_WARN("Type {}: Failed Parse but no TryParse found", nameof::nameof_type()); } return std::nullopt; } @@ -122,9 +135,17 @@ struct Serialization> : SerializationBase { template struct Serialization> : SerializationBase<> { + static constexpr auto StartCheck(xmlpp::Element* element) { + return true; + } [[nodiscard]] static constexpr auto TryParse(xmlpp::Element* element) -> std::optional> { - return utempl::FirstOf(utempl::Tuple{[&] { - return Serialization::TryParse(element); + return utempl::FirstOf(utempl::Tuple{[&] -> std::optional { + if(Serialization::StartCheck(element)) { + return Serialization::TryParse(element); + } else { + SPDLOG_DEBUG("StartCheck failed for type {}", nameof::nameof_type()); + return std::nullopt; + } }...}, std::optional>{}); } @@ -140,4 +161,19 @@ struct Serialization> : SerializationBase<> { } }; +template <> +struct Serialization : SerializationBase<> { + static constexpr auto StartCheck(xmlpp::Element*) -> bool { + return true; + }; + [[nodiscard]] static constexpr auto TryParse(xmlpp::Element*) -> std::optional { + return std::monostate{}; + } + [[nodiscard]] static constexpr auto Parse(xmlpp::Element*) -> std::monostate { + return {}; + } + static constexpr auto Serialize(xmlpp::Element*, const std::monostate&) -> void { + } +}; + } // namespace larra::xmpp diff --git a/library/include/larra/stream_error.hpp b/library/include/larra/stream_error.hpp index e707608..82822db 100644 --- a/library/include/larra/stream_error.hpp +++ b/library/include/larra/stream_error.hpp @@ -48,12 +48,15 @@ constexpr auto ToKebabCaseName() -> std::string_view { namespace error::stream { +struct BaseError : std::exception {}; + +// DO NOT MOVE TO ANOTHER NAMESPACE(where no heirs). VIA friend A FUNCTION IS ADDED THAT VIA ADL WILL BE SEARCHED FOR HEIRS +// C++20 modules very unstable in clangd :( template -struct BaseError : std::exception { - static constexpr auto kDefaultName = "error"; - static constexpr auto kDefaultNamespace = "stream"; +struct ErrorImpl : BaseError { + static constexpr auto kDefaultName = "stream:error"; static inline const auto kKebabCaseName = static_cast(impl::ToKebabCaseName()); - static inline const std::string kErrorContentNamespace = "urn:ietf:params:xml:ns:xmpp-streams"; + static constexpr auto kErrorMessage = [] -> std::string_view { static constexpr auto str = [] { return std::array{std::string_view{"Stream Error: "}, nameof::nameof_short_type(), std::string_view{"\0", 1}} | std::views::join; @@ -76,38 +79,38 @@ struct BaseError : std::exception { } friend constexpr auto operator<<(xmlpp::Element* element, const T& obj) -> void { auto node = element->add_child_element(kKebabCaseName); - node->set_namespace_declaration(kErrorContentNamespace); + node->set_namespace_declaration("urn:ietf:params:xml:ns:xmpp-streams"); } [[nodiscard]] constexpr auto what() const noexcept -> const char* override { return kErrorMessage.data(); } }; -struct BadFormat : BaseError {}; -struct BadNamespacePrefix : BaseError {}; -struct Conflict : BaseError {}; -struct ConnectionTimeout : BaseError {}; -struct HostGone : BaseError {}; -struct HostUnknown : BaseError {}; -struct ImproperAdressing : BaseError {}; -struct InternalServerError : BaseError {}; -struct InvalidForm : BaseError {}; -struct InvalidNamespace : BaseError {}; -struct InvalidXml : BaseError {}; -struct NotAuthorized : BaseError {}; -struct NotWellFormed : BaseError {}; -struct PolicyViolation : BaseError {}; -struct RemoteConnectionFailed : BaseError {}; -struct Reset : BaseError {}; -struct ResourceConstraint : BaseError {}; -struct RestrictedXml : BaseError {}; -struct SeeOtherHost : BaseError {}; -struct SystemShutdown : BaseError {}; -struct UndefinedCondition : BaseError {}; -struct UnsupportedEncoding : BaseError {}; -struct UnsupportedFeature : BaseError {}; -struct UnsupportedStanzaType : BaseError {}; -struct UnsupportedVersion : BaseError {}; +struct BadFormat : ErrorImpl {}; +struct BadNamespacePrefix : ErrorImpl {}; +struct Conflict : ErrorImpl {}; +struct ConnectionTimeout : ErrorImpl {}; +struct HostGone : ErrorImpl {}; +struct HostUnknown : ErrorImpl {}; +struct ImproperAdressing : ErrorImpl {}; +struct InternalServerError : ErrorImpl {}; +struct InvalidForm : ErrorImpl {}; +struct InvalidNamespace : ErrorImpl {}; +struct InvalidXml : ErrorImpl {}; +struct NotAuthorized : ErrorImpl {}; +struct NotWellFormed : ErrorImpl {}; +struct PolicyViolation : ErrorImpl {}; +struct RemoteConnectionFailed : ErrorImpl {}; +struct Reset : ErrorImpl {}; +struct ResourceConstraint : ErrorImpl {}; +struct RestrictedXml : ErrorImpl {}; +struct SeeOtherHost : ErrorImpl {}; +struct SystemShutdown : ErrorImpl {}; +struct UndefinedCondition : ErrorImpl {}; +struct UnsupportedEncoding : ErrorImpl {}; +struct UnsupportedFeature : ErrorImpl {}; +struct UnsupportedStanzaType : ErrorImpl {}; +struct UnsupportedVersion : ErrorImpl {}; } // namespace error::stream diff --git a/library/include/larra/raw_xml_stream.hpp b/library/include/larra/xml_stream.hpp similarity index 78% rename from library/include/larra/raw_xml_stream.hpp rename to library/include/larra/xml_stream.hpp index 3ff16a5..168bfff 100644 --- a/library/include/larra/raw_xml_stream.hpp +++ b/library/include/larra/xml_stream.hpp @@ -10,6 +10,7 @@ #include #include #include +#include #include #include #include @@ -89,8 +90,8 @@ auto IsExtraContentAtTheDocument(const _xmlError* error) -> bool; } // namespace impl template -struct RawXmlStream : Stream { - constexpr RawXmlStream(Stream stream, std::unique_ptr buff = std::make_unique()) : +struct XmlStream : Stream { + constexpr XmlStream(Stream stream, std::unique_ptr buff = std::make_unique()) : Stream(std::forward(stream)), streambuf(std::move(buff)) {}; using Stream::Stream; auto next_layer() -> Stream& { @@ -101,7 +102,7 @@ struct RawXmlStream : Stream { return *this; } - auto ReadOne(auto& socket) -> boost::asio::awaitable> { + auto ReadOneRaw(auto& socket) -> boost::asio::awaitable> { auto doc = std::make_unique(); impl::Parser parser(*doc); for(;;) { @@ -144,10 +145,21 @@ struct RawXmlStream : Stream { co_return doc; } } - auto ReadOne() -> boost::asio::awaitable> { - co_return co_await this->ReadOne(this->next_layer()); + template + auto ReadOneRaw(auto& stream) -> boost::asio::awaitable { + auto doc = co_await this->ReadOneRaw(stream); + co_return Serialization::Parse(doc->get_root_node()); } - inline auto Read(auto& socket) -> boost::asio::awaitable> { + + template + auto ReadOneRaw(auto& stream) -> boost::asio::awaitable + requires requires(std::unique_ptr ptr) { + { Serialization::Parse(std::move(ptr)) } -> std::same_as; + } + { + co_return Serialization::Parse(co_await this->ReadOneRaw(stream)); + } + inline auto ReadRaw(auto& socket) -> boost::asio::awaitable> { auto doc = std::make_unique(); // Not movable :( impl::Parser parser(*doc); std::size_t lines = 1; @@ -186,7 +198,7 @@ struct RawXmlStream : Stream { if(!error) { auto linesAdd = impl::CountLines(impl::BufferToStringView(buff, n)); - SPDLOG_DEBUG("Readed {} bytes for RawXmlStream with {} lines", n, linesAdd); + SPDLOG_DEBUG("Readed {} bytes for XmlStream with {} lines", n, linesAdd); lines += linesAdd; if(linesAdd == 0) { @@ -209,27 +221,56 @@ struct RawXmlStream : Stream { co_return doc; } } - auto Read() -> boost::asio::awaitable> { - co_return co_await this->Read(this->next_layer()); - } template - auto Read(auto& stream) -> boost::asio::awaitable { - auto doc = co_await this->Read(stream); + auto ReadRaw(auto& stream) -> boost::asio::awaitable { + auto doc = co_await this->ReadRaw(stream); co_return Serialization::Parse(doc->get_root_node()); } template - auto Read(auto& stream) -> boost::asio::awaitable + auto ReadRaw(auto& stream) -> boost::asio::awaitable requires requires(std::unique_ptr ptr) { { Serialization::Parse(std::move(ptr)) } -> std::same_as; } { - co_return Serialization::Parse(co_await this->Read(stream)); + co_return Serialization::Parse(co_await this->ReadRaw(stream)); } + private: template - auto Read() -> boost::asio::awaitable { - co_return co_await this->template Read(this->next_layer()); + auto ReadImpl(boost::asio::awaitable> awaitable) -> boost::asio::awaitable { + co_return std::visit(utempl::Overloaded( + [](T value) -> T { + return std::move(value); + }, + [](StreamError error) -> T { + std::visit( + [](auto error) { + throw error; + }, + error); + std::unreachable(); + }), + co_await std::move(awaitable)); + } + + public: + template + auto Read(auto& stream) { + return this->ReadImpl(this->ReadRaw>(stream)); + } + + template + auto Read() { + return this->Read(this->next_layer()); + } + template + auto ReadOne(auto& stream) { + return this->ReadImpl(this->ReadOneRaw>(stream)); + } + template + auto ReadOne() { + return this->ReadOne(this->next_layer()); } auto Send(xmlpp::Document& doc, auto& stream, bool bAddXmlDecl, bool removeEnd) const -> boost::asio::awaitable { @@ -270,7 +311,7 @@ struct RawXmlStream : Stream { co_await this->Send(xso, this->next_layer()); } - RawXmlStream(RawXmlStream&& other) = default; + XmlStream(XmlStream&& other) = default; std::unique_ptr streambuf; // Not movable :( }; diff --git a/library/src/raw_xml_stream.cpp b/library/src/xml_stream.cpp similarity index 99% rename from library/src/raw_xml_stream.cpp rename to library/src/xml_stream.cpp index cc74ae2..69afa67 100644 --- a/library/src/raw_xml_stream.cpp +++ b/library/src/xml_stream.cpp @@ -1,8 +1,8 @@ #include #include -#include #include +#include #include #include diff --git a/tests/raw_xml_stream.cpp b/tests/xml_stream.cpp similarity index 78% rename from tests/raw_xml_stream.cpp rename to tests/xml_stream.cpp index a2c7638..ed3906a 100644 --- a/tests/raw_xml_stream.cpp +++ b/tests/xml_stream.cpp @@ -5,8 +5,8 @@ #include #include #include -#include #include +#include #include namespace larra::xmpp { @@ -23,7 +23,7 @@ constexpr std::string_view kDoc3 = constexpr std::string_view kDoc4 = R"()"; -TEST(RawXmlStream, ReadByOne) { +TEST(XmlStream, ReadByOne) { boost::asio::io_context context; bool error{}; @@ -31,14 +31,14 @@ TEST(RawXmlStream, ReadByOne) { context, // NOLINTNEXTLINE: Safe [&] -> boost::asio::awaitable { - RawXmlStream stream{impl::MockSocket{context.get_executor(), 1}}; + XmlStream stream{impl::MockSocket{context.get_executor(), 1}}; stream.AddReceivedData(kDoc); try { - auto doc = co_await stream.Read(); + auto doc = co_await stream.ReadRaw(stream.next_layer()); auto node = doc->get_root_node(); EXPECT_EQ(node->get_name(), std::string_view{"doc"}); EXPECT_FALSE(node->has_child_text()); - auto doc2 = co_await stream.Read(); + auto doc2 = co_await stream.ReadRaw(stream.next_layer()); auto node2 = doc2->get_root_node(); EXPECT_EQ(node2->get_name(), std::string_view{"doc2"}); EXPECT_FALSE(node2->has_child_text()); @@ -54,20 +54,20 @@ TEST(RawXmlStream, ReadByOne) { EXPECT_FALSE(error); } -TEST(RawXmlStream, ReadAll) { +TEST(XmlStream, ReadAll) { boost::asio::io_context context; bool error{}; boost::asio::co_spawn( context, // NOLINTNEXTLINE: Safe [&] -> boost::asio::awaitable { - RawXmlStream stream{impl::MockSocket{context.get_executor(), kDoc.size()}}; + XmlStream stream{impl::MockSocket{context.get_executor(), kDoc.size()}}; stream.AddReceivedData(kDoc); try { - auto doc = co_await stream.Read(); + auto doc = co_await stream.ReadRaw(stream.next_layer()); auto node = doc->get_root_node(); EXPECT_EQ(node->get_name(), std::string_view{"doc"}); EXPECT_FALSE(node->has_child_text()); - auto doc2 = co_await stream.Read(); + auto doc2 = co_await stream.ReadRaw(stream.next_layer()); auto node2 = doc2->get_root_node(); EXPECT_EQ(node2->get_name(), std::string_view{"doc2"}); EXPECT_FALSE(node2->has_child_text()); @@ -82,20 +82,20 @@ TEST(RawXmlStream, ReadAll) { EXPECT_FALSE(error); } -TEST(RawXmlStream, ReadAllWithEnd) { +TEST(XmlStream, ReadAllWithEnd) { boost::asio::io_context context; bool error{}; boost::asio::co_spawn( context, // NOLINTNEXTLINE: Safe [&] -> boost::asio::awaitable { - RawXmlStream stream{impl::MockSocket{context.get_executor(), kDoc2.size()}}; + XmlStream stream{impl::MockSocket{context.get_executor(), kDoc2.size()}}; stream.AddReceivedData(kDoc2); try { - auto doc = co_await stream.Read(); + auto doc = co_await stream.ReadRaw(stream.next_layer()); auto node = doc->get_root_node(); EXPECT_EQ(node->get_name(), std::string_view{"doc"}); EXPECT_FALSE(node->has_child_text()); - auto doc2 = co_await stream.Read(); + auto doc2 = co_await stream.ReadRaw(stream.next_layer()); auto node2 = doc2->get_root_node(); EXPECT_EQ(node2->get_name(), std::string_view{"doc2"}); EXPECT_FALSE(node2->has_child_text()); @@ -110,13 +110,13 @@ TEST(RawXmlStream, ReadAllWithEnd) { EXPECT_FALSE(error); } -TEST(RawXmlStream, ReadFeatures) { +TEST(XmlStream, ReadFeatures) { boost::asio::io_context context; bool error{}; boost::asio::co_spawn( context, // NOLINTNEXTLINE: Safe [&] -> boost::asio::awaitable { - RawXmlStream stream{impl::MockSocket{context.get_executor(), kDoc3.size()}}; + XmlStream stream{impl::MockSocket{context.get_executor(), kDoc3.size()}}; stream.AddReceivedData(kDoc3); try { auto features = co_await stream.template Read(); @@ -139,13 +139,13 @@ struct SomeStruct { } }; -TEST(RawXmlStream, Write) { +TEST(XmlStream, Write) { boost::asio::io_context context; bool error{}; boost::asio::co_spawn( context, // NOLINTNEXTLINE: Safe [&] -> boost::asio::awaitable { - RawXmlStream stream1{impl::MockSocket{context.get_executor()}}; + XmlStream stream1{impl::MockSocket{context.get_executor()}}; auto stream = std::move(stream1); try { co_await stream.Send(SomeStruct{}); @@ -160,23 +160,21 @@ TEST(RawXmlStream, Write) { EXPECT_FALSE(error); } -TEST(RawXmlStream, ReadOneByOne) { +TEST(XmlStream, ReadOneByOne) { boost::asio::io_context context; bool error{}; boost::asio::co_spawn( context, // NOLINTNEXTLINE: Safe [&] -> boost::asio::awaitable { - RawXmlStream stream{impl::MockSocket{context.get_executor(), 1}}; + XmlStream stream{impl::MockSocket{context.get_executor(), 1}}; stream.AddReceivedData(kDoc4); try { - auto doc = (co_await stream.ReadOne(), co_await stream.ReadOne()); - auto node = doc->get_root_node(); - EXPECT_TRUE(node); - if(!node) { - co_return; - } - auto stream = ServerToUserStream::Parse(node); + ServerToUserStream value = (co_await stream.ReadOne(), co_await stream.ReadOne()); + EXPECT_EQ(value.id, "68321991947053239"); + EXPECT_EQ(value.version, "1.0"); + EXPECT_EQ(value.to, BareJid::Parse("test1@localhost")); + EXPECT_EQ(value.from, "localhost"); } catch(const std::exception& err) { SPDLOG_ERROR("{}", err.what()); error = true;