diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile index 3a3a16a..88118e7 100644 --- a/.devcontainer/Dockerfile +++ b/.devcontainer/Dockerfile @@ -4,9 +4,13 @@ FROM archlinux@sha256:a10e51dd0694d6c4142754e9d06cbce7baf91ace8031a30df37064d1091ab414 # Update the package database and install clang +# 1. system tools +# 2. build tools +# 3. libraries RUN pacman -Syyu --noconfirm \ - && pacman -S --noconfirm git less vim sudo base-devel \ - && pacman -S --noconfirm clang cmake make ninja gtk4 gtkmm-4.0 boost spdlog fmt pugixml + && pacman -S --noconfirm git less vim sudo base-devel python-pip \ + && pacman -S --noconfirm clang cmake make ninja meson \ + && pacman -S --noconfirm gtk4 gtkmm-4.0 boost spdlog fmt libxml++-5.0 # Create a non-root user 'dev' RUN useradd -ms /bin/bash dev \ diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index b2ffe09..4e6c412 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -2,6 +2,7 @@ "name": "Arch Linux with GCC & Clang++", "dockerComposeFile": "docker-compose.yml", "service": "devcontainer", + "initializeCommand": "docker stop ejabberd > /dev/null 2>&1 ; docker rm ejabberd > /dev/null 2>&1 ; exit 0", "workspaceFolder": "/workspaces/${localWorkspaceFolderBasename}", "forwardPorts": [ 5222, diff --git a/.devcontainer/docker-compose.yml b/.devcontainer/docker-compose.yml index c7b7f55..d33aca1 100644 --- a/.devcontainer/docker-compose.yml +++ b/.devcontainer/docker-compose.yml @@ -21,15 +21,16 @@ services: ejabberd_server: image: ghcr.io/processone/ejabberd container_name: ejabberd - # - # For some reasons below environment variables doesn't work inside vs code dev container - # Please, use devcontainer.json 'postStartCommand' for configuring ejabberd_server - # + pull_policy: always # Do not use cache for ejabberd + environment: - - CTL_ON_CREATE=register admin localhost "admin" ; - register test1 localhost "test1" + - CTL_ON_CREATE=register admin localhost admin ; + register test1 localhost test1 - CTL_ON_START=registered_users localhost ; - status + status ; + check_password test1 localhost test1 ; + help accounts + ports: - "5222:5222" - "5269:5269" diff --git a/.devcontainer/ejabberd.yml b/.devcontainer/ejabberd.yml index 9f52450..21c0a24 100644 --- a/.devcontainer/ejabberd.yml +++ b/.devcontainer/ejabberd.yml @@ -17,7 +17,13 @@ hosts: - localhost -loglevel: info +auth_method: internal +#auth_password_format: scram +#auth_scram_hash: sha256 +auth_use_cache: false + +loglevel: debug +hide_sensitive_log_data: false ca_file: /opt/ejabberd/conf/cacert.pem @@ -89,6 +95,11 @@ listen: s2s_use_starttls: optional +c2s_protocol_options: + - no_sslv3 + - cipher_server_preference + - no_compression + acl: admin: user: admin@localhost diff --git a/.devcontainer/post_create_config.sh b/.devcontainer/post_create_config.sh index 3927e55..65b424d 100755 --- a/.devcontainer/post_create_config.sh +++ b/.devcontainer/post_create_config.sh @@ -1,8 +1,12 @@ sleep 1 +# Check that ejabberd server started successfully +red='\e[1;31m' +off='\e[0m' +if [ "$( sudo docker container inspect -f '{{.State.Status}}' ejabberd )" != "running" ]; then printf "\n\n\t$red ERROR: ejabberd container is not running! $off Stop vscode dev environment \n\n\n\n"; exit 1; fi + printf "\n\n\tConfigure ejabber server\n\n" -sudo docker exec -it ejabberd ejabberdctl register admin localhost admin -sudo docker exec -it ejabberd ejabberdctl register sha512sum localhost 12345 +# sudo docker exec -it ejabberd ejabberdctl register user localhost password printf "\n\n\tList of registered users:\n" sudo docker exec -it ejabberd ejabberdctl registered_users localhost diff --git a/.gitignore b/.gitignore index 6ceee8b..4028266 100644 --- a/.gitignore +++ b/.gitignore @@ -11,7 +11,6 @@ compile_commands.json cpm-package-lock.cmake larraXMPPConfig.cmake larraXMPPVersionConfig.cmake -larra larra_xmpp_tests larra_xmpp_tests\[1\]_include.cmake larra_xmpp_tests\[1\]_tests.cmake diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000..1ee66a5 --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,17 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. + // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "type": "lldb", + "request": "launch", + "name": "Debug: connect", + "program": "${workspaceFolder}/build/examples/output/connect", + "args": [], + "cwd": "${workspaceFolder}", + "preLaunchTask": "Build Debug GCC" + } + ] +} \ No newline at end of file diff --git a/CMakeLists.txt b/CMakeLists.txt index 4e358d7..de8f26c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -75,8 +75,15 @@ CPMAddPackage( ) set(CPM_USE_LOCAL_PACKAGES ${TMP}) +find_package(LibXml2 REQUIRED) + pkg_check_modules(xmlplusplus libxml++-5.0) + + + + + if(xmlplusplus_FOUND) add_library(xmlplusplus INTERFACE EXCLUDE_FROM_ALL) @@ -156,12 +163,12 @@ target_include_directories(larra_xmpp PUBLIC if(TARGET Boost::pfr) target_link_libraries(larra_xmpp PUBLIC Boost::asio Boost::serialization utempl::utempl pugixml::pugixml OpenSSL::SSL - OpenSSL::Crypto spdlog xmlplusplus) + OpenSSL::Crypto spdlog xmlplusplus ${LIBXML2_LIBRARIES}) else() find_package(Boost 1.85.0 REQUIRED) target_link_libraries(larra_xmpp PUBLIC utempl::utempl ${Boost_LIBRARIES} pugixml::pugixml OpenSSL::SSL - OpenSSL::Crypto spdlog xmlplusplus) + OpenSSL::Crypto spdlog xmlplusplus ${LIBXML2_LIBRARIES}) endif() diff --git a/examples/src/connect.cpp b/examples/src/connect.cpp index ab24ba2..5c4e064 100644 --- a/examples/src/connect.cpp +++ b/examples/src/connect.cpp @@ -11,7 +11,8 @@ auto Coroutine() -> boost::asio::awaitable { try { auto client = co_await larra::xmpp::client::CreateClient>( - larra::xmpp::EncryptionUserAccount{{"test1", "localhost"}, "test1"}, {.useTls = larra::xmpp::client::Options::kNever}); + larra::xmpp::PlainUserAccount{.jid = {.username = "test1", .server = "localhost"}, .password = "test1"}, + {.useTls = larra::xmpp::client::Options::kNever}); } catch(const std::exception& err) { SPDLOG_ERROR("{}", err.what()); co_return; diff --git a/library/include/larra/client/client.hpp b/library/include/larra/client/client.hpp index 6d524f6..b1dd348 100644 --- a/library/include/larra/client/client.hpp +++ b/library/include/larra/client/client.hpp @@ -14,9 +14,12 @@ #include #include #include +#include #include #include #include + +#include "larra/client/xmpp_client_stream_features.hpp" namespace larra::xmpp { constexpr auto kDefaultXmppPort = 5222; @@ -25,22 +28,26 @@ constexpr auto kDefaultXmppPort = 5222; namespace larra::xmpp::client { +namespace rng = std::ranges; +namespace views = std::views; + template struct Client { - constexpr Client(BareJid jid, Connection connection) : jid(std::move(jid)), connection(std::move(connection)) {}; + constexpr Client(BareJid jid, RawXmlStream connection) : jid(std::move(jid)), connection(std::move(connection)) {}; template Token = boost::asio::use_awaitable_t<>> constexpr auto Close(Token token = {}) { this->active = false; return boost::asio::async_initiate( - [](Handler&& h, Connection connection) { + [](Handler&& h, RawXmlStream connection) { // NOLINT boost::asio::co_spawn( - connection.get_executor(), - [](auto h, Connection connection) -> boost::asio::awaitable { - co_await boost::asio::async_write(connection, boost::asio::buffer(""), boost::asio::use_awaitable); + connection.next_layer().get_executor(), + [](auto h, RawXmlStream connection) -> boost::asio::awaitable { + co_await boost::asio::async_write( + connection.next_layer(), boost::asio::buffer(""), boost::asio::use_awaitable); std::string response; co_await boost::asio::async_read_until( - connection, boost::asio::dynamic_buffer(response), "", boost::asio::use_awaitable); - std::move(h)(); + connection.next_layer(), boost::asio::dynamic_buffer(response), "", boost::asio::use_awaitable); + h(); }(std::move(h), std::move(connection)), boost::asio::detached); }, @@ -48,7 +55,7 @@ struct Client { std::move(this->connection)); } constexpr Client(const Client&) = delete; - constexpr Client(Client&& client) : connection(std::move(client.connection)), jid(std::move(client.jid)) { + constexpr Client(Client&& client) noexcept : connection(std::move(client.connection)), jid(std::move(client.jid)) { client.active = false; } constexpr ~Client() { @@ -63,7 +70,7 @@ struct Client { private: bool active = true; - Connection connection; + RawXmlStream connection; BareJid jid; }; @@ -79,159 +86,154 @@ struct ServerRequiresStartTls : std::exception { namespace impl { -inline auto StartStream(const BareJid& from, auto& connection) -> boost::asio::awaitable { - auto stream = UserStream{}.To(from.server).From(std::move(from)).Version("1.0").XmlLang("en"); - auto buffer = "" + ToString(stream); - co_await boost::asio::async_write(connection, boost::asio::buffer(buffer), boost::asio::transfer_all(), boost::asio::use_awaitable); - co_return; -} - template -auto Contains(Range&& range, Args&&... values) { +auto Contains(Range&& range, Args&&... values) { // NOLINT for(auto& value : range) { - if(((value == values) || ...)) { + if(((value == std::forward(values)) || ...)) { // NOLINT return true; } } return false; } - -inline auto ToInt(std::string_view input) -> std::optional { - int out{}; +template +inline auto ToInt(std::string_view input) -> std::optional { + T out{}; const std::from_chars_result result = std::from_chars(input.data(), input.data() + input.size(), out); return result.ec == std::errc::invalid_argument || result.ec == std::errc::result_out_of_range ? std::nullopt : std::optional{out}; } -inline auto ParseChallenge(std::string_view str) { - return std::views::split(str, ',') | std::views::transform([](auto param) { - return std::string_view{param}; - }) | - std::views::transform([](std::string_view param) -> std::pair { - auto v = param.find("="); - return {param.substr(0, v), param.substr(v + 1)}; - }) | - std::ranges::to>(); +struct Challenge { + std::string body; + std::string_view serverNonce; + std::string salt; + int iterations; + [[nodiscard]] inline static auto Parse(const xmlpp::Element* node) -> Challenge { + if(node->get_name() != "challenge") { + throw std::runtime_error(std::format("Invalid name {} for challenge", node->get_name())); + } + std::string decoded = DecodeBase64(node->get_first_child_text()->get_content()); + auto params = std::views::split(decoded, ',') // + | std::views::transform([](auto param) { // + return std::string_view{param}; // + }) // + | std::views::transform([](std::string_view param) -> std::pair { // + auto v = param.find("="); // + return {param.substr(0, v), param.substr(v + 1)}; // + }) // + | std::ranges::to>(); + return {.body = std::move(decoded), + .serverNonce = params.at("r"), + .salt = DecodeBase64(params.at("s")), + .iterations = ToInt(params.at("i")).value()}; + } }; -inline auto GetAuthData(const PlainUserAccount& account) -> std::string { - return EncodeBase64('\0' + account.jid.username + '\0' + account.password); -} +template +struct ChallengeResponse { + static constexpr auto kDefaultName = "response"; + static constexpr auto kDefaultNamespace = "urn:ietf:params:xml:ns:xmpp-sasl"; + std::string_view password; + std::string& salt; + std::string_view serverNonce; + std::string_view firstServerMessage; + std::string_view initialMessage; + int iterations{}; + Tag tag; + friend constexpr auto operator<<(xmlpp::Element* element, const ChallengeResponse& self) { + auto text = EncodeBase64(GenerateScramAuthMessage( + self.password, std::move(self.salt), self.serverNonce, self.firstServerMessage, self.initialMessage, self.iterations, self.tag)); + element->add_child_text(text); + } +}; + +struct StartTlsRequest { + static constexpr auto kDefaultName = "starttls"; + static constexpr auto kDefaultNamespace = "urn:ietf:params:xml:ns:xmpp-tls"; + friend constexpr auto operator<<(xmlpp::Element*, const StartTlsRequest&) { + } +}; struct ClientCreateVisitor { UserAccount account; const Options& options; - auto Auth(PlainUserAccount account, auto& socket, StreamFeatures features, ServerToUserStream stream) -> boost::asio::awaitable { + template + auto Auth(PlainUserAccount account, RawXmlStream& stream, ServerToUserStream streamHeader, StreamFeatures features) + -> boost::asio::awaitable { + SPDLOG_DEBUG("Start Plain Auth"); if(!std::ranges::contains(features.saslMechanisms.mechanisms, "PLAIN")) { throw std::runtime_error("Server not support PLAIN auth"); } - pugi::xml_document doc; - auto data = GetAuthData(account); - auto auth = doc.append_child("auth"); - auth.text().set(data.c_str(), data.size()); - auth.append_attribute("xmlns") = "urn:ietf:params:xml:ns:xmpp-sasl"; - auth.append_attribute("mechanism") = "PLAIN"; - - std::ostringstream strstream; - doc.print( - strstream, "", pugi::format_default | pugi::format_no_empty_element_tags | pugi::format_attribute_single_quote | pugi::format_raw); - std::string str = std::move(strstream.str()); - co_await boost::asio::async_write(socket, boost::asio::buffer(str), boost::asio::transfer_all(), boost::asio::use_awaitable); - std::string response; - co_await boost::asio::async_read_until(socket, boost::asio::dynamic_buffer(response), '>', boost::asio::use_awaitable); + const features::PlainAuthData data{.username = account.jid.username, .password = account.password}; + co_await stream.Send(data); + std::ignore = co_await stream.Read(); } - auto ScramAuth(std::string_view methodName, - const EncryptionUserAccount& account, - auto& socket, - auto tag) -> boost::asio::awaitable { - pugi::xml_document doc; - auto auth = doc.append_child("auth"); - auth.append_attribute("xmlns") = "urn:ietf:params:xml:ns:xmpp-sasl"; - auth.append_attribute("mechanism") = methodName.data(); - auto nonce = GenerateNonce(); + template + auto ScramAuth(std::string methodName, EncryptionUserAccount account, RawXmlStream& stream, Tag tag) + -> boost::asio::awaitable { + SPDLOG_DEBUG("Start Scram Auth using '{}'", methodName); + const auto nonce = GenerateNonce(); SPDLOG_DEBUG("nonce: {}", nonce); - auto initialMessage = std::format("n,,n={},r={}", account.jid.username, nonce); - auto data = EncodeBase64(initialMessage); - auth.text().set(data.c_str()); - std::ostringstream strstream; - doc.print(strstream, - "", - pugi::format_default | pugi::format_no_empty_element_tags | pugi::format_attribute_single_quote | pugi::format_raw | - pugi::format_no_escapes); - std::string str = std::move(strstream.str()); - co_await boost::asio::async_write(socket, boost::asio::buffer(str), boost::asio::transfer_all(), boost::asio::use_awaitable); - std::string response; - co_await boost::asio::async_read_until(socket, boost::asio::dynamic_buffer(response), "", boost::asio::use_awaitable); - doc.load_string(response.c_str()); - - auto decoded = DecodeBase64(doc.child("challenge").text().get()); - auto params = ParseChallenge(decoded); - auto serverNonce = params["r"]; + const auto initialMessage = std::format("n,,n={},r={}", account.jid.username, nonce); + const features::ScramAuthData authData{.mechanism = methodName, .initialMessage = initialMessage, .tag = tag}; + co_await stream.Send(authData); + Challenge challenge = co_await stream.template Read(); + const std::string_view serverNonce = challenge.serverNonce; if(serverNonce.substr(0, nonce.size()) != nonce) { throw std::runtime_error("XMPP Server SCRAM nonce not started with client nonce"); } - doc = pugi::xml_document{}; - auto success = doc.append_child("response"); - success.append_attribute("xmlns") = "urn:ietf:params:xml:ns:xmpp-sasl"; - success.text().set(EncodeBase64(GenerateScramAuthMessage(account.password, - DecodeBase64(params["s"]), - serverNonce, - decoded, - std::string_view{initialMessage}.substr(3), - ToInt(params["i"]).value(), - tag)) - .c_str()); - std::ostringstream strstream2; - doc.print(strstream2, - "", - pugi::format_default | pugi::format_no_empty_element_tags | pugi::format_attribute_single_quote | pugi::format_raw | - pugi::format_no_escapes); - str.clear(); - str = std::move(strstream2.str()); - co_await boost::asio::async_write(socket, boost::asio::buffer(str), boost::asio::transfer_all(), boost::asio::use_awaitable); - response.clear(); - co_await boost::asio::async_read_until(socket, boost::asio::dynamic_buffer(response), '>', boost::asio::use_awaitable); - doc.load_string(response.c_str()); - if(auto failure = doc.child("failure")) { - throw std::runtime_error(std::format("Auth failed: {}", failure.child("text").text().get())); + const ChallengeResponse challengeResponse{.password = account.password, + .salt = challenge.salt, // Mutable reference + .serverNonce = serverNonce, + .firstServerMessage = challenge.body, + .initialMessage = std::string_view{initialMessage}.substr(3), + .iterations = challenge.iterations, + .tag = tag}; + co_await stream.Send(challengeResponse); + std::unique_ptr doc = co_await stream.Read(); + auto root = doc->get_root_node(); + if(!root || root->get_name() == "failure") { + if(auto textNode = root->get_first_child("text")) { + if(auto text = dynamic_cast(textNode)) { + if(auto childText = text->get_first_child_text()) { + throw std::runtime_error(std::format("Auth failed: {}", childText->get_content())); + } + } + } + throw std::runtime_error("Auth failed"); } + SPDLOG_DEBUG("Success auth for JID {}", ToString(account.jid)); } - auto Auth(EncryptionRequiredUserAccount account, - auto& socket, - StreamFeatures features, - ServerToUserStream stream) -> boost::asio::awaitable { - // NOLINTBEGIN + template + auto Auth(EncryptionRequiredUserAccount account, RawXmlStream& stream, ServerToUserStream streamHeader, StreamFeatures features) + -> boost::asio::awaitable { if(std::ranges::contains(features.saslMechanisms.mechanisms, "SCRAM-SHA-512")) { - co_return co_await ScramAuth("SCRAM-SHA-512", account, socket, sha512sum::EncryptionTag{}); + co_return co_await ScramAuth("SCRAM-SHA-512", std::move(account), stream, sha512sum::EncryptionTag{}); } if(std::ranges::contains(features.saslMechanisms.mechanisms, "SCRAM-SHA-256")) { - co_return co_await ScramAuth("SCRAM-SHA-256", account, socket, sha256sum::EncryptionTag{}); + co_return co_await ScramAuth("SCRAM-SHA-256", std::move(account), stream, sha256sum::EncryptionTag{}); } if(std::ranges::contains(features.saslMechanisms.mechanisms, "SCRAM-SHA-1")) { - co_return co_await ScramAuth("SCRAM-SHA-1", account, socket, sha1sum::EncryptionTag{}); + co_return co_await ScramAuth("SCRAM-SHA-1", std::move(account), stream, sha1sum::EncryptionTag{}); } - // NOLINTEND throw std::runtime_error("Server not support SCRAM SHA 1 or SCRAM SHA 256 or SCRAM SHA 512 auth"); } - auto Auth(EncryptionUserAccount account, - auto& socket, - StreamFeatures features, - ServerToUserStream stream) -> boost::asio::awaitable { - return Contains(features.saslMechanisms.mechanisms, "SCRAM-SHA-1", "SCRAM-SHA-256", "SCRAM-SHA-512") - ? this->Auth(EncryptionRequiredUserAccount{std::move(account)}, socket, std::move(features), std::move(stream)) - : this->Auth(static_cast(std::move(account)), socket, std::move(features), std::move(stream)); + template + auto Auth(EncryptionUserAccount account, RawXmlStream& stream, ServerToUserStream streamHeader, StreamFeatures features) + -> boost::asio::awaitable { + Contains(features.saslMechanisms.mechanisms, "SCRAM-SHA-1", "SCRAM-SHA-256", "SCRAM-SHA-512") + ? co_await this->Auth(EncryptionRequiredUserAccount{std::move(account)}, stream, std::move(streamHeader), std::move(features)) + : co_await this->Auth(static_cast(std::move(account)), stream, std::move(streamHeader), std::move(features)); } - auto Auth(auto& socket, pugi::xml_document doc) -> boost::asio::awaitable { - co_return co_await std::visit>( + template + auto Auth(RawXmlStream& stream, ServerToUserStream streamHeader, StreamFeatures features) -> boost::asio::awaitable { + co_return co_await std::visit( [&](auto& account) -> boost::asio::awaitable { - return this->Auth(std::move(account), - socket, - StreamFeatures::Parse(doc.child("stream:stream").child("stream:features")), - ServerToUserStream::Parse(doc.child("stream:stream"))); + return this->Auth(std::move(account), stream, std::move(streamHeader), std::move(features)); }, this->account); } @@ -247,92 +249,162 @@ struct ClientCreateVisitor { auto Connect(auto& socket, boost::asio::ip::tcp::resolver::results_type resolveResults) -> boost::asio::awaitable { co_await boost::asio::async_connect(socket, resolveResults, boost::asio::use_awaitable); } - auto ReadStream(auto& socket, std::string& buffer) -> boost::asio::awaitable { - co_await boost::asio::async_read_until(socket, boost::asio::dynamic_buffer(buffer), "", boost::asio::use_awaitable); - } - auto ReadStream(auto& socket) -> boost::asio::awaitable { - std::string buffer; - co_await ReadStream(socket, buffer); - co_return buffer; - } + template - auto ProcessTls(boost::asio::ssl::stream& socket, std::string& buffer) -> boost::asio::awaitable { - co_await boost::asio::async_write(socket.next_layer(), - boost::asio::buffer(""), - boost::asio::transfer_all(), - boost::asio::use_awaitable); - buffer.clear(); - pugi::xml_document doc; - co_await boost::asio::async_read_until(socket.next_layer(), boost::asio::dynamic_buffer(buffer), ">", boost::asio::use_awaitable); - doc.load_string(buffer.c_str()); - if(doc.child("proceed").attribute("xmlns").as_string() != std::string_view{"urn:ietf:params:xml:ns:xmpp-tls"}) { + auto ProcessTls(RawXmlStream>& stream) -> boost::asio::awaitable { + const StartTlsRequest request; + co_await stream.Send(request); + std::unique_ptr doc = co_await stream.Read(); + if(auto node = doc->get_root_node()) { + if(node->get_name() == "proceed") { + goto proceed; // NOLINT + } throw StartTlsNegotiationError{"Failure XMPP"}; - }; - SSL_set_tlsext_host_name(socket.native_handle(), account.Jid().server.c_str()); + } + proceed: + auto& socket = stream.next_layer(); + SSL_set_tlsext_host_name(socket.native_handle(), this->account.Jid().server.c_str()); try { co_await socket.async_handshake(boost::asio::ssl::stream::handshake_type::client, boost::asio::use_awaitable); } catch(const std::exception& e) { throw StartTlsNegotiationError{e.what()}; } + } + static constexpr auto GetEnumerated(boost::asio::streambuf& streambuf) { + return std::views::zip(std::views::iota(std::size_t{}, streambuf.size()), ::larra::xmpp::impl::GetCharsRangeFromBuf(streambuf)); + } + using EnumeratedT = decltype(std::views::zip(std::views::iota(std::size_t{}, std::size_t{}), + ::larra::xmpp::impl::GetCharsRangeFromBuf(std::declval()))); + struct Splitter { + EnumeratedT range; + struct Sentinel { + std::ranges::sentinel_t end; + }; + struct Iterator { + std::ranges::iterator_t it; + std::ranges::sentinel_t end; + friend constexpr auto operator==(const Iterator& self, const Sentinel& it) -> bool { + return self.it == it.end; + } + auto operator++() -> Iterator& { + if(this->it == this->end) { + return *this; + } + this->it = std::ranges::find(this->it, this->end, '>', [](auto v) { + auto [_, c] = v; + return c; + }); + if(this->it != this->end) { + ++it; + } + + return *this; + }; + auto operator*() const { + return *it; + } + }; + auto begin() -> Iterator { + return Iterator{.it = std::ranges::begin(this->range), .end = std::ranges::end(this->range)}; + } + + auto end() -> Sentinel { + return {.end = std::ranges::end(this->range)}; + } }; + + auto GetStartStreamIndex(auto& socket, boost::asio::streambuf& streambuf) -> boost::asio::awaitable { + auto buf = streambuf.prepare(4096); // NOLINT + std::size_t n = co_await socket.async_read_some(buf, boost::asio::use_awaitable); + streambuf.commit(n); + auto splited = Splitter{GetEnumerated(streambuf)}; + using It = decltype(splited.begin()); + // clang-format off + co_return co_await + [&](this auto&& self, std::size_t n, It it) -> std::optional { + return n == 0 ? std::move(it) : it == splited.end() ? std::nullopt : self(n - 1, (++it, std::move(it))); + }(2, splited.begin()) + .transform([](auto value) -> boost::asio::awaitable { + auto [n, _] = *value; + co_return n; + }) + .or_else([&] -> std::optional> { + return this->GetStartStreamIndex(socket, streambuf); + }) + .value(); + // clang-format on + } + + auto ReadStartStream(auto& socket, boost::asio::streambuf& streambuf) -> boost::asio::awaitable { + auto n = co_await this->GetStartStreamIndex(socket, streambuf); + xmlpp::DomParser parser; + std::string dataToReed = + (::larra::xmpp::impl::GetCharsRangeFromBuf(streambuf) | std::views::take(n - 1) | std::ranges::to()) + "/>"; + + parser.parse_memory(dataToReed); + auto doc = parser.get_document(); + SPDLOG_DEBUG("Stream readed. Consuming {} bytes with stream data {}. Total buffer size: {}", n, dataToReed, streambuf.size()); + streambuf.consume(n); + co_return ServerToUserStream::Parse(doc->get_root_node()); + } + template - inline auto operator()(Socket&& socket) - -> boost::asio::awaitable>, Client>>>> { - co_await this->Connect(socket, co_await this->Resolve()); - co_await impl::StartStream(account.Jid(), socket); - auto response = co_await ReadStream(socket); - pugi::xml_document doc; - doc.load_string(response.c_str()); - auto streamNode = doc.child("stream:stream"); - auto features = streamNode.child("stream:features"); - if(features.child("starttls").child("required")) { + inline auto operator()(RawXmlStream stream) + -> boost::asio::awaitable, Client>>> { + co_await this->Connect(stream.next_layer(), co_await this->Resolve()); + + co_await stream.Send(UserStream{.from = account.Jid(), .to = account.Jid().server, .version = "1.0", .xmlLang = "en"}); + SPDLOG_DEBUG("UserStream sended"); + ServerToUserStream sToUStream = co_await ReadStartStream(stream.next_layer(), *stream.streambuf); + StreamFeatures features = co_await stream.template Read(); + SPDLOG_DEBUG("features parsed"); + + if(features.startTls && features.startTls->required == Required::kRequired) { throw ServerRequiresStartTls{}; } - co_await this->Auth(socket, std::move(doc)); - co_return Client{std::move(this->account).Jid(), std::move(socket)}; + co_await this->Auth(stream, std::move(sToUStream), std::move(features)); + co_return Client{std::move(this->account).Jid(), std::move(stream)}; } + template - inline auto operator()(boost::asio::ssl::stream&& socket) + inline auto operator()(RawXmlStream> stream) -> boost::asio::awaitable, Client>>> { + auto& socket = stream.next_layer(); co_await this->Connect(socket.next_layer(), co_await this->Resolve()); - co_await impl::StartStream(account.Jid().Username("anonymous"), socket.next_layer()); - auto response = co_await this->ReadStream(socket.next_layer()); - pugi::xml_document doc; - doc.load_string(response.c_str()); - auto streamNode = doc.child("stream:stream"); - auto stream = ServerToUserStream::Parse(streamNode); - auto features = streamNode.child("stream:features"); - if(!features.child("starttls")) { + co_await stream.Send(UserStream{.from = account.Jid().Username("anonymous"), .to = account.Jid().server}, socket.next_layer()); + SPDLOG_DEBUG("UserStream sended"); + auto streamHeader = co_await this->ReadStartStream(socket, *stream.streambuf); + StreamFeatures features = co_await stream.template Read(); + SPDLOG_DEBUG("features parsed(SSL)"); + if(!features.startTls) { if(this->options.useTls == Options::kRequire) { throw std::runtime_error("XMPP server not support STARTTLS"); } socket.next_layer().close(); - co_return co_await (*this)(socket.next_layer()); + co_return co_await (*this)(RawXmlStream{Socket{std::move(socket.next_layer())}, std::move(stream.streambuf)}); } - response.clear(); - co_await this->ProcessTls(socket, response); - co_await impl::StartStream(account.Jid(), socket); - response.clear(); - co_await this->ReadStream(socket, response); - doc.load_string(response.c_str()); - co_await this->Auth(socket, std::move(doc)); - co_return Client{std::move(this->account).Jid(), std::move(socket)}; + co_await this->ProcessTls(stream); + co_await stream.Send(UserStream{.from = account.Jid(), .to = account.Jid().server}, socket.next_layer()); + auto newStreamHeader = co_await this->ReadStartStream(socket, *stream.streambuf); + auto newFeatures = co_await stream.template Read(); + co_await this->Auth(stream, std::move(newStreamHeader), std::move(newFeatures)); + co_return Client{std::move(this->account).Jid(), RawXmlStream{std::move(socket)}}; } }; } // namespace impl template -inline auto CreateClient(UserAccount account, const Options& options = {}) +inline auto CreateClient(UserAccount account, Options options = {}) -> boost::asio::awaitable, Client>>> { auto executor = co_await boost::asio::this_coro::executor; - boost::asio::ssl::context ctx(boost::asio::ssl::context::sslv23); - co_return co_await std::visit, Client>>>>( - impl::ClientCreateVisitor{std::move(account), options}, - options.useTls == Options::kNever ? std::variant>{Socket{executor}} - : boost::asio::ssl::stream(executor, ctx)); + co_return co_await std::visit( + impl::ClientCreateVisitor{.account = std::move(account), .options = options}, + options.useTls == Options::kNever + ? std::variant, RawXmlStream>>{RawXmlStream{Socket{executor}}} + : RawXmlStream{boost::asio::ssl::stream(executor, ctx)}); } } // namespace larra::xmpp::client diff --git a/library/include/larra/client/xmpp_client_stream_features.hpp b/library/include/larra/client/xmpp_client_stream_features.hpp new file mode 100644 index 0000000..6dc8919 --- /dev/null +++ b/library/include/larra/client/xmpp_client_stream_features.hpp @@ -0,0 +1,53 @@ +#pragma once + +#include + +#include +#include +#include +#include + +namespace larra::xmpp::client::features { +/* + * Auth features + */ + +template +struct AuthData { + static constexpr auto kDefaultName = "auth"; + static constexpr auto kDefaultNamespace = "urn:ietf:params:xml:ns:xmpp-sasl"; + + friend constexpr auto operator<<(xmlpp::Element* element, const AuthData& self) { + SPDLOG_DEBUG("AuthData operator<<"); + element->set_attribute("mechanism", static_cast(static_cast(self).mechanism)); + static_cast(self).Write(element); + } +}; + +struct PlainAuthData : AuthData { + // Not by code style + static constexpr auto mechanism = "PLAIN"; + const std::string& username; // Very ugly, but shit happends + const std::string& password; // std::format can't work with '\0' in fmt string and operator+ requires std::string's + + [[nodiscard]] inline auto GetAuthData() const -> std::string { + return EncodeBase64('\0' + this->username + '\0' + this->password); + } + + constexpr auto Write(xmlpp::Element* element) const { + element->add_child_text(this->GetAuthData()); + } +}; + +template +struct ScramAuthData : AuthData> { + std::string_view mechanism; + std::string_view initialMessage; + Tag tag; + + constexpr auto Write(xmlpp::Element* element) const { + element->add_child_text(EncodeBase64(this->initialMessage)); + } +}; + +} // namespace larra::xmpp::client::features diff --git a/library/include/larra/features.hpp b/library/include/larra/features.hpp index 3f55695..0e7cad0 100644 --- a/library/include/larra/features.hpp +++ b/library/include/larra/features.hpp @@ -1,4 +1,6 @@ #pragma once +#include + #include #include #include @@ -10,7 +12,8 @@ enum class Required : bool { kNotRequired = false, kRequired = true }; struct SaslMechanisms { std::vector mechanisms; - static auto Parse(pugi::xml_node) -> SaslMechanisms; + + [[nodiscard]] static auto Parse(const xmlpp::Element*) -> SaslMechanisms; }; struct StreamFeatures { @@ -19,19 +22,21 @@ struct StreamFeatures { [[nodiscard]] constexpr auto Required(Required required) const -> StartTlsType { return {required}; }; - static auto Parse(pugi::xml_node) -> StartTlsType; + + [[nodiscard]] static auto Parse(const xmlpp::Element*) -> StartTlsType; }; struct BindType { Required required; [[nodiscard]] constexpr auto Required(Required required) const -> BindType { return {required}; }; - static auto Parse(pugi::xml_node) -> BindType; + + [[nodiscard]] static auto Parse(const xmlpp::Element*) -> BindType; }; std::optional startTls; std::optional bind; SaslMechanisms saslMechanisms; - std::vector others; + std::vector others; template [[nodiscard]] constexpr auto StartTls(this Self&& self, std::optional value) { return utils::FieldSetHelper::With<"startTls">(std::forward(self), std::move(value)); @@ -48,7 +53,8 @@ struct StreamFeatures { [[nodiscard]] constexpr auto Others(this Self&& self, std::vector value) { return utils::FieldSetHelper::With<"others">(std::forward(self), std::move(value)); } - static auto Parse(pugi::xml_node) -> StreamFeatures; + [[nodiscard]] static auto Parse(pugi::xml_node) -> StreamFeatures; + [[nodiscard]] static auto Parse(const xmlpp::Element*) -> StreamFeatures; }; } // namespace larra::xmpp diff --git a/library/include/larra/impl/mock_socket.hpp b/library/include/larra/impl/mock_socket.hpp new file mode 100644 index 0000000..0b27f0a --- /dev/null +++ b/library/include/larra/impl/mock_socket.hpp @@ -0,0 +1,73 @@ +#include +#include +#include + +namespace larra::xmpp::impl { + +class MockSocket { + public: + using executor_type = boost::asio::any_io_executor; + + MockSocket(boost::asio::any_io_executor executor, std::size_t writeWithMaxBlocksBy = 5) : // NOLINT + executor(std::move(executor)), writeWithMaxBlocksBy(writeWithMaxBlocksBy) { + } + auto get_executor() -> executor_type { + return executor; + } + + auto lowest_layer() -> MockSocket& { + return *this; + } + + template + auto async_connect(const EndpointType&, CompletionToken&& token) { + return boost::asio::async_initiate( + [](auto&& handler) { + handler(boost::system::error_code{}); + }, + token); + } + + template + auto async_write_some(const ConstBufferSequence& buffers, CompletionToken&& token) { + sentData.append(boost::asio::buffer_cast(*buffers.begin()), boost::asio::buffer_size(*buffers.begin())); + return boost::asio::async_initiate( + [](auto&& handler, std::size_t bytes_transferred) { + handler(boost::system::error_code{}, bytes_transferred); + }, + token, + boost::asio::buffer_size(*buffers.begin())); + } + + template + auto async_read_some(const MutableBufferSequence& buffers, CompletionToken&& token) { + std::size_t bytesToRead = + std::min({boost::asio::buffer_size(buffers), this->writeWithMaxBlocksBy, this->receivedData.size() - this->readPosition}); + std::memcpy(boost::asio::buffer_cast(buffers), receivedData.data() + readPosition, bytesToRead); // NOLINT + readPosition += bytesToRead; + return boost::asio::async_initiate( + [](auto&& handler, std::size_t bytes_transferred) { + handler(bytes_transferred == 0 ? boost::asio::error::eof : boost::system::error_code{}, bytes_transferred); + }, + token, + bytesToRead); + } + + auto AddReceivedData(std::string_view data) -> void { + receivedData += data; + } + + auto GetSentData() -> std::string { + auto sentData = std::move(this->sentData); + this->sentData = std::string{}; + return sentData; + } + + boost::asio::any_io_executor executor; + std::string sentData; + std::string receivedData; + std::size_t writeWithMaxBlocksBy; + std::size_t readPosition = 0; +}; + +} // namespace larra::xmpp::impl diff --git a/library/include/larra/impl/public_cast.hpp b/library/include/larra/impl/public_cast.hpp new file mode 100644 index 0000000..1f39ae5 --- /dev/null +++ b/library/include/larra/impl/public_cast.hpp @@ -0,0 +1,34 @@ +#pragma once +#include +#include + +namespace larra::xmpp::impl { + +template +struct PublicCastTag { + friend constexpr auto MagicGetPrivateMember(PublicCastTag); +}; + +template +struct PublicCast {}; + +template +struct PublicCast { + friend constexpr auto MagicGetPrivateMember(PublicCastTag) { + return ptr; + } +}; + +template +struct PublicCast { + friend constexpr auto MagicGetPrivateMember(PublicCastTag) { + return ptr; + } +}; + +template +constexpr auto GetPrivateMember(const T&) { + return MagicGetPrivateMember(PublicCastTag, I>{}); +}; + +} // namespace larra::xmpp::impl diff --git a/library/include/larra/printer_stream.hpp b/library/include/larra/printer_stream.hpp index e1eea57..f272ffb 100644 --- a/library/include/larra/printer_stream.hpp +++ b/library/include/larra/printer_stream.hpp @@ -3,31 +3,51 @@ #include #include +#include #include -#include +#include namespace larra::xmpp { +namespace impl { + +constexpr auto GetStringFromBuf(const auto& buffers, std::size_t n) -> std::string { + auto f = [&] { + if constexpr(requires { + { buffers.data() } -> std::convertible_to; + }) { + return impl::BufferToStringView(buffers); + } else { + return GetCharsRangeFromBuf(buffers); + } + }; + + return f() | std::views::take(n) | std::ranges::to(); +} + +constexpr auto GetStringFromBuf(const auto& buffers) -> std::string { + return GetStringFromBuf(buffers, buffers.size()); +} + +} // namespace impl + template struct PrintStream : Socket { using Socket::Socket; + PrintStream(PrintStream&&) = default; using Executor = Socket::executor_type; template > - auto async_write_some(const ConstBufferSequence& buffers, WriteToken&& token) { - std::ostringstream stream; // Write to buffer for concurrent logging - // std::osyncstream not realized in libc++ - stream << "Writing data to stream: "; - for(boost::asio::const_buffer buf : buffers) { - stream << std::string_view{static_cast(buf.data()), buf.size()}; - } - SPDLOG_INFO("{}", stream.str()); + auto async_write_some(const ConstBufferSequence& buffers, WriteToken&& token) { // NOLINT + + SPDLOG_INFO("Writing data to stream: {}", impl::GetStringFromBuf(buffers)); + return boost::asio::async_initiate( - [this](Handler&& token, const ConstBufferSequence& buffers) { - Socket::async_write_some(buffers, [token = std::move(token)](boost::system::error_code err, std::size_t s) mutable { - SPDLOG_INFO("Data writing completed"); - token(err, s); + [this](Handler&& handler, const ConstBufferSequence& buffers) { // NOLINT + Socket::async_write_some(buffers, [h = std::move(handler)](boost::system::error_code err, std::size_t s) mutable { + SPDLOG_INFO("Data writing completed: {}", s); + h(err, s); }); }, token, @@ -36,19 +56,14 @@ struct PrintStream : Socket { template > - auto async_read_some(const MutableBufferSequence& buffers, ReadToken&& token) { - SPDLOG_INFO("Reading data from stream"); + auto async_read_some(const MutableBufferSequence& buffers, ReadToken&& token) { // NOLINT + SPDLOG_INFO("Reading data from stream:"); return boost::asio::async_initiate( - [this](ReadToken&& token, const MutableBufferSequence& buffers) { - Socket::async_read_some(buffers, [buffers, token = std::move(token)](boost::system::error_code err, std::size_t s) mutable { - std::ostringstream stream; // Write to buffer for concurrent logging - // std::osyncstream not realized in libc++ - stream << "Data after read: "; - for(boost::asio::mutable_buffer buf : buffers) { - stream << std::string_view{static_cast(buf.data()), buf.size()}; - } - SPDLOG_INFO("{}", stream.str()); - token(err, s); + [this](Handler&& handler, const MutableBufferSequence& buffers) { // NOLINT + Socket::async_read_some(buffers, [buffers, h = std::move(handler)](boost::system::error_code err, std::size_t s) mutable { + SPDLOG_INFO("Readed data: {}", impl::GetStringFromBuf(buffers, s)); + + h(err, s); }); }, token, @@ -77,19 +92,15 @@ struct boost::asio::ssl::stream> : public boost template > - auto async_write_some(const ConstBufferSequence& buffers, WriteToken&& token) { - std::ostringstream stream; // Write to buffer for concurrent logging - // std::osyncstream not realized in libc++ - stream << "Writing data to stream(SSL): "; - for(boost::asio::const_buffer buf : buffers) { - stream << std::string_view{static_cast(buf.data()), buf.size()}; - } - SPDLOG_INFO("{}", stream.str()); + auto async_write_some(const ConstBufferSequence& buffers, WriteToken&& token) { // NOLINT + + SPDLOG_INFO("Writing data to stream(SSL): {}", ::larra::xmpp::impl::GetStringFromBuf(buffers)); + return boost::asio::async_initiate( - [this](Handler&& token, const ConstBufferSequence& buffers) { - Base::async_write_some(buffers, [token = std::move(token)](boost::system::error_code err, std::size_t s) mutable { - SPDLOG_INFO("Data writing completed(SSL)"); - std::move(token)(err, s); + [this](Handler&& handler, const ConstBufferSequence& buffers) { // NOLINT + Base::async_write_some(buffers, [h = std::move(handler)](boost::system::error_code err, std::size_t s) mutable { + SPDLOG_INFO("Data writing completed(SSL): {}", s); + h(err, s); }); }, token, @@ -98,19 +109,13 @@ struct boost::asio::ssl::stream> : public boost template > - auto async_read_some(const MutableBufferSequence& buffers, ReadToken&& token) { + auto async_read_some(const MutableBufferSequence& buffers, ReadToken&& token) { // NOLINT SPDLOG_INFO("Reading data from stream(SSL)"); return boost::asio::async_initiate( - [this](Handler&& token, const MutableBufferSequence& buffers) { - Base::async_read_some(buffers, [buffers, token = std::move(token)](boost::system::error_code err, std::size_t s) mutable { - std::ostringstream stream; // Write to buffer for concurrent logging - // std::osyncstream not realized in libc++ - stream << "Data after read(SSL): "; - for(boost::asio::mutable_buffer buf : buffers) { - stream << std::string_view{static_cast(buf.data()), buf.size()}; - } - SPDLOG_INFO("{}", stream.str()); - std::move(token)(err, s); + [this](Handler&& handler, const MutableBufferSequence& buffers) { // NOLINT + Base::async_read_some(buffers, [buffers, h = std::move(handler)](boost::system::error_code err, std::size_t s) mutable { + SPDLOG_INFO("Readed data(SSL): {}", ::larra::xmpp::impl::GetStringFromBuf(buffers, s)); + h(err, s); }); }, token, @@ -118,13 +123,13 @@ struct boost::asio::ssl::stream> : public boost } template > - auto async_handshake(Base::handshake_type type, HandshakeToken&& token = default_completion_token_t{}) { + auto async_handshake(Base::handshake_type type, HandshakeToken&& token = default_completion_token_t{}) { // NOLINT std::println("SSL Handshake start"); return boost::asio::async_initiate( - [this](Handler&& token, Base::handshake_type type) { - Base::async_handshake(type, [token = std::move(token)](boost::system::error_code error) mutable { + [this](Handler&& handler, Base::handshake_type type) { // NOLINT + Base::async_handshake(type, [h = std::move(handler)](boost::system::error_code error) mutable { SPDLOG_INFO("SSL Handshake completed"); - std::move(token)(error); + h(error); }); }, token, diff --git a/library/include/larra/raw_xml_stream.hpp b/library/include/larra/raw_xml_stream.hpp new file mode 100644 index 0000000..87879ad --- /dev/null +++ b/library/include/larra/raw_xml_stream.hpp @@ -0,0 +1,282 @@ +#pragma once +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +struct _xmlError; + +namespace larra::xmpp { + +template +concept AsXml = requires(xmlpp::Element* element, const T& obj) { + element << obj; + { T::kDefaultName } -> std::convertible_to; +}; + +template +concept HasDefaultNamespace = requires { + { T::kDefaultNamespace } -> std::convertible_to; +}; + +template +concept HasDefaultPrefix = requires { + { T::kPrefix } -> std::convertible_to; +}; + +template +concept HasAddXmlDecl = requires { + { T::kAddXmlDecl } -> std::convertible_to; +}; + +template +concept HasRemoveEnd = requires { + { T::kRemoveEnd } -> std::convertible_to; +}; + +struct XmlGroup : xmlpp::Element { + using Element::Element; +}; + +struct XmlPath : public xmlpp::Element { + public: + using Element::Element; + + [[nodiscard]] auto GetData() const -> std::string { + return get_attribute("d")->get_value(); + } +}; + +namespace impl { + +constexpr auto BufferToStringView(const boost::asio::const_buffer& buffer, size_t size) -> std::string_view { + assert(size <= buffer.size()); + return {boost::asio::buffer_cast(buffer), size}; +}; + +constexpr auto BufferToStringView(const boost::asio::const_buffer& buffer) -> std::string_view { + return BufferToStringView(buffer, buffer.size()); +}; + +class Parser : private xmlpp::SaxParser { + public: + inline explicit Parser(xmlpp::Document& document) : doc(document) {}; + ~Parser() override = default; + + auto ParseChunk(std::string_view str) -> const _xmlError*; + + std::stack context; + xmlpp::Document& doc; + + private: + inline auto on_start_document() -> void override { + } + inline auto on_end_document() -> void override { + } + auto on_start_element(const std::string& name, const AttributeList& properties) -> void override; + auto on_end_element(const std::string& name) -> void override; + auto on_characters(const std::string& characters) -> void override; + auto on_cdata_block(const std::string& characters) -> void override; +}; + +constexpr auto GetCharsRangeFromBuf(auto&& buf) { + return buf.data() // + | std::views::transform([](const auto& buf) -> std::string_view { // + return ::larra::xmpp::impl::BufferToStringView(buf); // + }) // + | std::views::join; +}; + +constexpr auto SplitStreamBuf(auto&& buf, char delim) { + return GetCharsRangeFromBuf(buf) // + | std::views::lazy_split(delim); // +}; + +auto GetIndex(const boost::asio::streambuf&, const _xmlError* error, std::size_t alreadyCountedLines = 1) -> std::size_t; + +auto CountLines(const boost::asio::streambuf&) -> std::size_t; + +auto CountLines(std::string_view) -> std::size_t; + +auto IsExtraContentAtTheDocument(const _xmlError* error) -> bool; + +} // namespace impl + +template +struct RawXmlStream : Stream { + constexpr RawXmlStream(Stream stream, std::unique_ptr buff = std::make_unique()) : + Stream(std::forward(stream)), streambuf(std::move(buff)) {}; + using Stream::Stream; + auto next_layer() -> Stream& { + return *this; + } + + auto next_layer() const -> const Stream& { + return *this; + } + + inline auto Read(auto& socket) -> boost::asio::awaitable> { + auto doc = std::make_unique(); // Not movable :( + impl::Parser parser(*doc); + std::size_t lines = 1; + std::size_t size{}; + for(auto elem : this->streambuf->data()) { + auto error = parser.ParseChunk(impl::BufferToStringView(elem)); + if(!error) { + auto linesAdd = impl::CountLines(impl::BufferToStringView(elem)); + lines += linesAdd; + if(linesAdd == 0) { + size += elem.size(); + } + if(parser.context.empty() && parser.doc.get_root_node() != nullptr) { + SPDLOG_DEBUG("Object already transferred"); + co_return doc; + } + continue; + } + if(!impl::IsExtraContentAtTheDocument(error)) { + throw std::runtime_error(std::format("Bad xml object: {}", xmlpp::format_xml_error(error))); + } + std::size_t size = impl::GetIndex(*this->streambuf, error, lines) - size; + this->streambuf->consume(size); + SPDLOG_DEBUG("Object already transferred"); + co_return doc; + } + this->streambuf->consume(this->streambuf->size()); + for(;;) { + auto buff = this->streambuf->prepare(4096); // NOLINT + auto [e, n] = co_await socket.async_read_some(buff, boost::asio::as_tuple(boost::asio::use_awaitable)); + if(e) { + boost::system::throw_exception_from_error(e, boost::source_location()); + } + this->streambuf->commit(n); + auto error = parser.ParseChunk(impl::BufferToStringView(buff, n)); + + if(!error) { + auto linesAdd = impl::CountLines(impl::BufferToStringView(buff, n)); + SPDLOG_DEBUG("Readed {} bytes for RawXmlStream with {} lines", n, linesAdd); + + lines += linesAdd; + if(linesAdd == 0) { + size += n; + } + this->streambuf->consume(this->streambuf->size()); + if(parser.context.empty() && parser.doc.get_root_node() != nullptr) { + co_return doc; + } + SPDLOG_DEBUG( + "Object not transferred. context size: {}, isValidRootNode: {}", parser.context.size(), parser.doc.get_root_node() != nullptr); + continue; + } + + if(!impl::IsExtraContentAtTheDocument(error)) { + throw std::runtime_error(std::format("Bad xml object: {}", xmlpp::format_xml_error(error))); + } + auto toConsume = impl::GetIndex(*this->streambuf, error, lines) - size; + this->streambuf->consume(toConsume); + co_return doc; + } + } + auto Read() -> boost::asio::awaitable> { + co_return co_await this->Read(this->next_layer()); + } + template + auto Read(auto& stream) -> boost::asio::awaitable { + auto doc = co_await this->Read(stream); + co_return T::Parse(doc->get_root_node()); + } + + template + auto Read(auto& stream) -> boost::asio::awaitable + requires requires(std::unique_ptr ptr) { + { T::Parse(std::move(ptr)) } -> std::same_as; + } + { + co_return T::Parse(co_await this->Read(stream)); + } + + template + auto Read() -> boost::asio::awaitable { + co_return co_await this->template Read(this->next_layer()); + } + + auto Send(xmlpp::Document& doc, auto& stream, bool bAddXmlDecl, bool removeEnd) const -> boost::asio::awaitable { + constexpr auto beginSize = sizeof("\n") - 1; + + auto str = doc.write_to_string(); + auto view = std::string_view{str}.substr(beginSize, str.size() - beginSize - 1); + if(bAddXmlDecl) { + if(removeEnd) { + std::string data = "" + static_cast(view.substr(0, view.size() - 2)) + ">"; + co_await boost::asio::async_write(stream, boost::asio::buffer(data), boost::asio::use_awaitable); + co_return; + } + std::string data = "" + static_cast(view); + co_await boost::asio::async_write(stream, boost::asio::buffer(data), boost::asio::use_awaitable); + co_return; + } + if(removeEnd) { + std::string data = static_cast(view.substr(0, view.size() - 2)) + ">"; + co_await boost::asio::async_write(stream, boost::asio::buffer(data), boost::asio::use_awaitable); + } else { + co_await boost::asio::async_write(stream, boost::asio::buffer(view), boost::asio::use_awaitable); + } + } + + auto Send(xmlpp::Document& doc, bool bAddXmlDecl = false) -> boost::asio::awaitable { + co_await this->Send(doc, this->next_layer(), bAddXmlDecl); + } + template + auto Send(const T& xso, auto& stream) const -> boost::asio::awaitable { + xmlpp::Document doc; + const std::string empty; + const std::string namespaceStr = [&] -> std::string { + if constexpr(HasDefaultNamespace) { + return T::kDefaultNamespace; + } else { + return empty; + } + }(); + const std::string prefixStr = [&] -> decltype(auto) { + if constexpr(HasDefaultPrefix) { + return T::kPrefix; + } else { + return empty; + } + }(); + const bool bAddXmlDecl = [&] -> bool { + if constexpr(HasAddXmlDecl) { + return T::kAddXmlDecl; + } + return false; + }(); + const bool removeEnd = [&] -> bool { + if constexpr(HasRemoveEnd) { + return T::kRemoveEnd; + } + return false; + }(); + + doc.create_root_node(T::kDefaultName, namespaceStr, prefixStr) << xso; + co_await this->Send(doc, stream, bAddXmlDecl, removeEnd); + } + + auto Send(const AsXml auto& xso) -> boost::asio::awaitable { + co_await this->Send(xso, this->next_layer()); + } + + RawXmlStream(RawXmlStream&& other) = default; + + std::unique_ptr streambuf; // Not movable :( +}; + +} // namespace larra::xmpp diff --git a/library/include/larra/stream.hpp b/library/include/larra/stream.hpp index ba09026..c66f4a1 100644 --- a/library/include/larra/stream.hpp +++ b/library/include/larra/stream.hpp @@ -1,4 +1,6 @@ #pragma once +#include + #include #include @@ -12,6 +14,11 @@ struct BasicStream { static constexpr bool kJidTo = JidTo; using FromType = std::optional>; using ToType = std::optional>; + static inline const std::string kDefaultNamespace = JidFrom || JidTo ? "jabber:client" : "jabber:server"; + static constexpr auto kRemoveEnd = true; + static constexpr auto kAddXmlDecl = true; + static inline const std::string kDefaultPrefix = ""; + static inline const std::string kDefaultName = "stream:stream"; FromType from; ToType to; std::optional id; @@ -19,37 +26,39 @@ struct BasicStream { std::optional xmlLang; template - constexpr auto From(this Self&& self, FromType value) -> BasicStream { + [[nodiscard]] constexpr auto From(this Self&& self, FromType value) -> BasicStream { return utils::FieldSetHelper::With<"from", BasicStream>(std::forward(self), std::move(value)); } template - constexpr auto To(this Self&& self, ToType value) -> BasicStream { + [[nodiscard]] constexpr auto To(this Self&& self, ToType value) -> BasicStream { return utils::FieldSetHelper::With<"to", BasicStream>(std::forward(self), std::move(value)); } template - constexpr auto Id(this Self&& self, std::optional value) -> BasicStream { + [[nodiscard]] constexpr auto Id(this Self&& self, std::optional value) -> BasicStream { return utils::FieldSetHelper::With<"id", BasicStream>(std::forward(self), std::move(value)); } template - constexpr auto Version(this Self&& self, std::optional value) -> BasicStream { + [[nodiscard]] constexpr auto Version(this Self&& self, std::optional value) -> BasicStream { return utils::FieldSetHelper::With<"version", BasicStream>(std::forward(self), std::move(value)); } template - constexpr auto XmlLang(this Self&& self, std::optional value) -> BasicStream { + [[nodiscard]] constexpr auto XmlLang(this Self&& self, std::optional value) -> BasicStream { return utils::FieldSetHelper::With<"xmlLang", BasicStream>(std::forward_like(self), std::move(value)); } - auto SerializeStream(pugi::xml_node& node) const -> void; - friend auto operator<<(pugi::xml_node node, const BasicStream& stream) -> pugi::xml_node { + auto SerializeStream(xmlpp::Element* node) const -> void; + + friend auto operator<<(xmlpp::Element* node, const BasicStream& stream) -> xmlpp::Element* { stream.SerializeStream(node); - return (std::move(node)); + return node; } friend auto ToString(const BasicStream&) -> std::string; static auto Parse(const pugi::xml_node& node) -> BasicStream; + [[nodiscard]] static auto Parse(const xmlpp::Element*) -> BasicStream; }; } // namespace impl diff --git a/library/include/larra/utils.hpp b/library/include/larra/utils.hpp index 0303d37..65bcefb 100644 --- a/library/include/larra/utils.hpp +++ b/library/include/larra/utils.hpp @@ -1,5 +1,6 @@ #pragma once #include +#include #include namespace larra::xmpp::utils { @@ -215,4 +216,53 @@ struct FieldSetHelper { // clang-format on }; +/* +template +SplitView(Range&& range, Delim&& delim) -> SplitView, Delim>; +*/ +#if __has_cpp_attribute(__cpp_lib_start_lifetime_as) + +template +inline auto StartLifetimeAsArray(void* ptr, std::size_t n) -> T* { + return std::start_lifetime_as_array(ptr, n); +} + +template +inline auto StartLifetimeAs(void* ptr) -> T* { + return std::start_lifetime_as(ptr); +} + +template +inline auto StartLifetimeAsArray(const void* ptr, std::size_t n) -> const T* { + return std::start_lifetime_as_array(ptr, n); +} + +template +inline auto StartLifetimeAs(const void* ptr) -> const T* { + return std::start_lifetime_as(ptr); +} + +#else +template +inline auto StartLifetimeAsArray(void* ptr, std::size_t n) -> T* { + return std::launder(reinterpret_cast(new(ptr) std::byte[n * sizeof(T)])); +} + +template +inline auto StartLifetimeAs(void* ptr) -> T* { + return StartLifetimeAsArray(ptr, 1); +} + +template +inline auto StartLifetimeAsArray(const void* ptr, std::size_t n) -> const T* { + return std::launder(reinterpret_cast(new(const_cast(ptr)) std::byte[n * sizeof(T)])); // NOLINT +} + +template +inline auto StartLifetimeAs(const void* ptr) -> const T* { + return StartLifetimeAsArray(ptr, 1); +} + +#endif + } // namespace larra::xmpp::utils diff --git a/library/src/features.cpp b/library/src/features.cpp index 313d134..fc34fcb 100644 --- a/library/src/features.cpp +++ b/library/src/features.cpp @@ -3,43 +3,51 @@ namespace { template -inline auto ToOptional(const pugi::xml_node& node) -> std::optional { - return node ? std::optional{T::Parse(node)} : std::nullopt; +inline auto ToOptional(const xmlpp::Node* node) -> std::optional { + auto ptr = dynamic_cast(node); + return ptr ? std::optional{T::Parse(ptr)} : std::nullopt; } } // namespace namespace larra::xmpp { -auto SaslMechanisms::Parse(pugi::xml_node node) -> SaslMechanisms { - std::vector response; - for(pugi::xml_node mechanism = node.child("mechanism"); mechanism; mechanism = mechanism.next_sibling("mechanism")) { - response.emplace_back(mechanism.child_value()); +auto SaslMechanisms::Parse(const xmlpp::Element* node) -> SaslMechanisms { + return {node->get_children("mechanism") | std::views::transform([](const xmlpp::Node* node) -> std::string { + auto ptr = dynamic_cast(node); + if(!ptr) { + throw std::runtime_error("Invalid node for mechanisms"); + } + if(!ptr->has_child_text()) { + throw std::runtime_error("Invalid node for mechanisms"); + } + return ptr->get_first_child_text()->get_content(); + }) | + std::ranges::to>()}; +} + +auto StreamFeatures::StartTlsType::Parse(const xmlpp::Element* node) -> StreamFeatures::StartTlsType { + return {node->get_first_child("required") ? Required::kRequired : Required::kNotRequired}; +} + +auto StreamFeatures::BindType::Parse(const xmlpp::Element* node) -> StreamFeatures::BindType { + return {node->get_first_child("required") ? Required::kRequired : Required::kNotRequired}; +} + +auto StreamFeatures::Parse(const xmlpp::Element* node) -> StreamFeatures { + auto ptr = dynamic_cast(node->get_first_child("mechanisms")); + if(!ptr) { + throw std::runtime_error("Not found or invalid node mechanisms for StreamFeatures"); } - return {response}; -} -auto StreamFeatures::StartTlsType::Parse(pugi::xml_node node) -> StreamFeatures::StartTlsType { - return {node.child("required") ? Required::kRequired : Required::kNotRequired}; -} - -auto StreamFeatures::BindType::Parse(pugi::xml_node node) -> StreamFeatures::BindType { - return {node.child("required") ? Required::kRequired : Required::kNotRequired}; -} - -auto StreamFeatures::Parse(pugi::xml_node node) -> StreamFeatures { - std::vector others; - for(pugi::xml_node current = node.first_child(); current; current = current.next_sibling()) { - // Проверяем, не является ли узел starttls, bind или mechanisms - if(std::string_view(current.name()) != "starttls" && std::string_view(current.name()) != "bind" && - std::string_view(current.name()) != "mechanisms") { - others.push_back(node); - } - } - return {ToOptional(node.child("starttls")), - ToOptional(node.child("bind")), - SaslMechanisms::Parse(node.child("mechanisms")), - std::move(others)}; + return {.startTls = ToOptional(node->get_first_child("starttls")), + .bind = ToOptional(node->get_first_child("bind")), + .saslMechanisms = SaslMechanisms::Parse(ptr), + .others = node->get_children() | std::views::filter([](const xmlpp::Node* node) -> bool { + auto name = node->get_name(); + return name != "starttls" && name != "mechanisms" && name != "bind"; + }) | + std::ranges::to>()}; } } // namespace larra::xmpp diff --git a/library/src/raw_xml_stream.cpp b/library/src/raw_xml_stream.cpp new file mode 100644 index 0000000..cc74ae2 --- /dev/null +++ b/library/src/raw_xml_stream.cpp @@ -0,0 +1,133 @@ +#include + +#include +#include +#include +#include +#include + +namespace larra::xmpp::impl { + +template struct PublicCast<&xmlpp::SaxParser::sax_handler_>; + +auto Parser::ParseChunk(std::string_view str) -> const xmlError* { + xmlResetLastError(); + + if(!context_) { + this->context_ = xmlCreatePushParserCtxt((this->*GetPrivateMember(static_cast(*this))).get(), + nullptr, // user_data + nullptr, // chunk + 0, // size + nullptr); // no filename for fetching external entities + + if(!this->context_) { + throw xmlpp::internal_error("Could not create parser context\n" + xmlpp::format_xml_error()); + } + initialize_context(); + } else { + xmlCtxtResetLastError(this->context_); + } + + xmlParseChunk(this->context_, str.data(), static_cast(str.size()), 0); + return xmlCtxtGetLastError(this->context_); +} + +auto Parser::on_start_element(const std::string& name, const AttributeList& attributes) -> void { + SPDLOG_DEBUG("Start element with name {}", name); + + std::string::size_type idx = name.find(':'); + std::string elementPrefix = idx == std::string::npos ? std::string{} : name.substr(0, idx); + + xmlpp::Element* elementNormal = nullptr; + if(this->doc.get_root_node() == nullptr) { + elementNormal = this->doc.create_root_node(name); + } else { + elementNormal = this->context.top()->add_child_element(name); + } + + auto node = elementNormal->cobj(); + delete elementNormal; // NOLINT: Api + elementNormal = nullptr; + + xmlpp::Element* elementDerived = nullptr; + if(name == "g") { + elementDerived = new XmlGroup(node); // NOLINT: Owned + } else if(name == "path") { + elementDerived = new XmlPath(node); // NOLINT: Owned + } else { + elementDerived = new xmlpp::Element(node); // NOLINT: Owned + } + if(elementDerived) { + this->context.push(elementDerived); + + for(const auto& attr_pair : attributes) { + const auto attr_name = attr_pair.name; + const auto attr_value = attr_pair.value; + const auto idx_colon = attr_name.find(':'); + if(idx_colon == std::string::npos) { + if(attr_name == "xmlns") { + elementDerived->set_namespace_declaration(attr_value); + } else { + elementDerived->set_attribute(attr_name, attr_value); + } + } else { + auto prefix = attr_name.substr(0, idx_colon); + auto suffix = attr_name.substr(idx_colon + 1); + if(prefix == "xmlns") { + elementDerived->set_namespace_declaration(attr_value, suffix); + } else { + auto attr = elementDerived->set_attribute(suffix, attr_value); + attr->set_namespace(prefix); + } + } + } + } +} + +auto Parser::on_end_element(const std::string& name) -> void { + SPDLOG_DEBUG("End element with name {}", name); + this->context.pop(); +} + +auto Parser::on_characters(const std::string& text) -> void { + SPDLOG_DEBUG("Add characters to element: {}", text); + if(!this->context.empty()) { + this->context.top()->add_child_text(text); + } +} + +auto Parser::on_cdata_block(const std::string& text) -> void { + this->on_characters(text); +} + +inline auto GetLines(const boost::asio::streambuf& buf) { + return SplitStreamBuf(buf, '\n'); +} + +auto CountLines(const boost::asio::streambuf& buf) -> std::size_t { + return std::ranges::fold_left(GetLines(buf), 0, [](auto accum, auto&&) { + return accum + 1; + }); +}; + +auto CountLines(std::string_view buf) -> std::size_t { + return std::ranges::fold_left(buf | std::views::split('\n'), 0, [](auto accum, auto&&) { + return accum + 1; + }); +} +auto GetIndex(const boost::asio::streambuf& buf, const xmlError* error, std::size_t alreadyCountedLines) -> std::size_t { + return std::ranges::fold_left( + GetLines(buf) | std::views::take(error->line - alreadyCountedLines) | std::views::transform([](auto&& line) -> std::size_t { + return std::ranges::fold_left(line, std::size_t{1}, [](auto accum, auto&&) { + return accum + 1; + }); + }), + error->int2 - 1, // columns + std::plus<>{}); +} + +auto IsExtraContentAtTheDocument(const xmlError* error) -> bool { + return error->code == XML_ERR_DOCUMENT_END; +} + +} // namespace larra::xmpp::impl diff --git a/library/src/stream.cpp b/library/src/stream.cpp index ae94d64..a4e5057 100644 --- a/library/src/stream.cpp +++ b/library/src/stream.cpp @@ -7,6 +7,10 @@ inline auto ToOptionalString(const pugi::xml_attribute& attribute) -> std::optio return attribute ? std::optional{std::string{attribute.as_string()}} : std::nullopt; } +inline auto ToOptionalString(const xmlpp::Attribute* attribute) -> std::optional { + return attribute ? std::optional{std::string{attribute->get_value()}} : std::nullopt; +} + template inline auto ToOptionalUser(const pugi::xml_attribute& attribute) { if constexpr(IsJid) { @@ -16,6 +20,15 @@ inline auto ToOptionalUser(const pugi::xml_attribute& attribute) { } } +template +inline auto ToOptionalUser(const xmlpp::Attribute* attribute) { + if constexpr(IsJid) { + return attribute ? std::optional{larra::xmpp::BareJid::Parse(attribute->get_value())} : std::nullopt; + } else { + return ToOptionalString(attribute); + } +} + auto ToString(std::string data) -> std::string { return std::move(data); }; @@ -25,60 +38,28 @@ auto ToString(std::string data) -> std::string { namespace larra::xmpp { template -auto impl::BasicStream::SerializeStream(pugi::xml_node& node) const -> void { +auto impl::BasicStream::SerializeStream(xmlpp::Element* node) const -> void { if(this->from) { - node.append_attribute("from") = ToString(*this->from).c_str(); + node->set_attribute("from", ToString(*this->from)); } if(this->to) { - node.append_attribute("to") = ToString(*this->to).c_str(); + node->set_attribute("to", ToString(*this->to)); } if(this->id) { - node.append_attribute("id") = this->id->c_str(); + node->set_attribute("id", *this->id); } if(this->version) { - node.append_attribute("version") = this->version->c_str(); + node->set_attribute("version", *this->version); } if(this->xmlLang) { - node.append_attribute("xml:lang") = this->xmlLang->c_str(); + node->set_attribute("lang", *this->xmlLang, "xml"); } - if constexpr(JidFrom || JidTo) { - node.append_attribute("xmlns") = "jabber:client"; - } else { - node.append_attribute("xmlns") = "jabber:server"; - } - node.append_attribute("xmlns:stream") = "http://etherx.jabber.org/streams"; + node->set_namespace_declaration("http://etherx.jabber.org/streams", "stream"); } -template auto ServerStream::SerializeStream(pugi::xml_node& node) const -> void; -template auto ServerToUserStream::SerializeStream(pugi::xml_node& node) const -> void; -template auto UserStream::SerializeStream(pugi::xml_node& node) const -> void; - -namespace impl { - -template -inline auto ToStringHelper(const BasicStream& stream) { - return std::format("", - stream.id ? std::format(" id='{}'", *stream.id) : "", - stream.from ? std::format(" from='{}'", *stream.from) : "", - stream.to ? std::format(" to='{}'", *stream.to) : "", - stream.version ? std::format(" version='{}'", *stream.version) : "", - stream.xmlLang ? std::format(" xml:lang='{}'", *stream.xmlLang) : "", - JidFrom || JidTo ? " xmlns='jabber:client'" : "xmlns='jabber:server'"); -}; - -auto ToString(const ServerStream& ref) -> std::string { - return ToStringHelper(ref); -} - -auto ToString(const UserStream& ref) -> std::string { - return ToStringHelper(ref); -} - -auto ToString(const ServerToUserStream& ref) -> std::string { - return ToStringHelper(ref); -} - -} // namespace impl +template auto ServerStream::SerializeStream(xmlpp::Element* node) const -> void; +template auto ServerToUserStream::SerializeStream(xmlpp::Element* node) const -> void; +template auto UserStream::SerializeStream(xmlpp::Element* node) const -> void; template auto impl::BasicStream::Parse(const pugi::xml_node& node) -> impl::BasicStream { @@ -88,8 +69,24 @@ auto impl::BasicStream::Parse(const pugi::xml_node& node) -> imp ToOptionalString(node.attribute("version")), ToOptionalString(node.attribute("xml:lang"))}; } + +template +auto impl::BasicStream::Parse(const xmlpp::Element* node) -> impl::BasicStream { + return {ToOptionalUser(node->get_attribute("from")), + ToOptionalUser(node->get_attribute("to")), + ToOptionalString(node->get_attribute("id")), + ToOptionalString(node->get_attribute("version")), + ToOptionalString(node->get_attribute("lang", "xml"))}; +} + template auto UserStream::Parse(const pugi::xml_node& node) -> UserStream; template auto ServerStream::Parse(const pugi::xml_node& node) -> ServerStream; template auto ServerToUserStream::Parse(const pugi::xml_node& node) -> ServerToUserStream; + +template auto UserStream::Parse(const xmlpp::Element* node) -> UserStream; + +template auto ServerStream::Parse(const xmlpp::Element* node) -> ServerStream; +template auto ServerToUserStream::Parse(const xmlpp::Element* node) -> ServerToUserStream; + } // namespace larra::xmpp diff --git a/tests/raw_xml_stream.cpp b/tests/raw_xml_stream.cpp new file mode 100644 index 0000000..451d3ed --- /dev/null +++ b/tests/raw_xml_stream.cpp @@ -0,0 +1,159 @@ +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace larra::xmpp { + +constexpr std::string_view kDoc = "\n\n"; + +constexpr std::string_view kDoc2 = "\n"; + +constexpr std::string_view kDoc3 = + "PLAINSCRAM-SHA-256X-OAUTH2"; + +TEST(RawXmlStream, ReadByOne) { + boost::asio::io_context context; + bool error{}; + + boost::asio::co_spawn( + context, + // NOLINTNEXTLINE: Safe + [&] -> boost::asio::awaitable { + RawXmlStream stream{impl::MockSocket{context.get_executor(), 1}}; + stream.AddReceivedData(kDoc); + try { + auto doc = co_await stream.Read(); + auto node = doc->get_root_node(); + EXPECT_EQ(node->get_name(), std::string_view{"doc"}); + EXPECT_FALSE(node->has_child_text()); + auto doc2 = co_await stream.Read(); + auto node2 = doc2->get_root_node(); + EXPECT_EQ(node2->get_name(), std::string_view{"doc2"}); + EXPECT_FALSE(node2->has_child_text()); + + } catch(const std::exception& err) { + SPDLOG_ERROR("{}", err.what()); + error = true; + } + }, + boost::asio::detached); + + context.run(); + EXPECT_FALSE(error); +} + +TEST(RawXmlStream, ReadAll) { + boost::asio::io_context context; + bool error{}; + boost::asio::co_spawn( + context, // NOLINTNEXTLINE: Safe + [&] -> boost::asio::awaitable { + RawXmlStream stream{impl::MockSocket{context.get_executor(), kDoc.size()}}; + stream.AddReceivedData(kDoc); + try { + auto doc = co_await stream.Read(); + auto node = doc->get_root_node(); + EXPECT_EQ(node->get_name(), std::string_view{"doc"}); + EXPECT_FALSE(node->has_child_text()); + auto doc2 = co_await stream.Read(); + auto node2 = doc2->get_root_node(); + EXPECT_EQ(node2->get_name(), std::string_view{"doc2"}); + EXPECT_FALSE(node2->has_child_text()); + + } catch(const std::exception& err) { + SPDLOG_ERROR("{}", err.what()); + error = true; + } + }, + boost::asio::detached); + context.run(); + EXPECT_FALSE(error); +} + +TEST(RawXmlStream, ReadAllWithEnd) { + boost::asio::io_context context; + bool error{}; + boost::asio::co_spawn( + context, // NOLINTNEXTLINE: Safe + [&] -> boost::asio::awaitable { + RawXmlStream stream{impl::MockSocket{context.get_executor(), kDoc2.size()}}; + stream.AddReceivedData(kDoc2); + try { + auto doc = co_await stream.Read(); + auto node = doc->get_root_node(); + EXPECT_EQ(node->get_name(), std::string_view{"doc"}); + EXPECT_FALSE(node->has_child_text()); + auto doc2 = co_await stream.Read(); + auto node2 = doc2->get_root_node(); + EXPECT_EQ(node2->get_name(), std::string_view{"doc2"}); + EXPECT_FALSE(node2->has_child_text()); + + } catch(const std::exception& err) { + SPDLOG_ERROR("{}", err.what()); + error = true; + } + }, + boost::asio::detached); + context.run(); + EXPECT_FALSE(error); +} + +TEST(RawXmlStream, ReadFeatures) { + boost::asio::io_context context; + bool error{}; + boost::asio::co_spawn( + context, // NOLINTNEXTLINE: Safe + [&] -> boost::asio::awaitable { + RawXmlStream stream{impl::MockSocket{context.get_executor(), kDoc3.size()}}; + stream.AddReceivedData(kDoc3); + try { + auto features = co_await stream.template Read(); + } catch(const std::exception& err) { + SPDLOG_ERROR("{}", err.what()); + error = true; + } + }, + boost::asio::detached); + context.run(); + EXPECT_FALSE(error); +} + +struct SomeStruct { + static constexpr auto kDefaultName = "some"; + static constexpr auto kDefaultNamespace = "namespace"; + static constexpr auto kPrefix = "prefix"; + friend auto operator<<(xmlpp::Element* node, const SomeStruct&) { + node->add_child_text("text"); + } +}; + +TEST(RawXmlStream, Write) { + boost::asio::io_context context; + bool error{}; + boost::asio::co_spawn( + context, // NOLINTNEXTLINE: Safe + [&] -> boost::asio::awaitable { + RawXmlStream stream1{impl::MockSocket{context.get_executor()}}; + auto stream = std::move(stream1); + try { + co_await stream.Send(SomeStruct{}); + EXPECT_EQ(stream.GetSentData(), std::string_view{"text"}); + } catch(const std::exception& err) { + SPDLOG_ERROR("{}", err.what()); + error = true; + } + }, + boost::asio::detached); + context.run(); + EXPECT_FALSE(error); +} + +} // namespace larra::xmpp diff --git a/tests/stream.cpp b/tests/stream.cpp index bb65ad7..f2572f0 100644 --- a/tests/stream.cpp +++ b/tests/stream.cpp @@ -9,33 +9,31 @@ constexpr std::string_view kSerializedData = )"; constexpr std::string_view kCheckSerializeData = - R"( -)"; + "\n\n"; TEST(Stream, Serialize) { UserStream originalStream; - originalStream.from = BareJid{"user", "example.com"}; + originalStream.from = BareJid{.username = "user", .server = "example.com"}; originalStream.to = "example.com"; originalStream.id = "abc"; originalStream.version = "1.0"; originalStream.xmlLang = "en"; - - pugi::xml_document doc; - pugi::xml_node streamNode = doc.append_child("stream:stream"); + xmlpp::Document doc; + auto streamNode = doc.create_root_node(UserStream::kDefaultName, UserStream::kDefaultNamespace, UserStream::kDefaultPrefix); streamNode << originalStream; - std::ostringstream oss; - doc.child("stream:stream").print(oss, "\t"); - const std::string serializedData = oss.str(); + const std::string serializedData = doc.write_to_string(); ASSERT_EQ(serializedData, kCheckSerializeData); } TEST(Stream, Deserialize) { - pugi::xml_document parsedDoc; - parsedDoc.load_string(kSerializedData.data()); + xmlpp::DomParser parser; + parser.parse_memory(static_cast(kSerializedData)); + auto parsedDoc = parser.get_document(); - const UserStream deserializedStream = UserStream::Parse(parsedDoc.child("stream:stream")); + const UserStream deserializedStream = UserStream::Parse(parsedDoc->get_root_node()); ASSERT_TRUE(deserializedStream.from.has_value()); ASSERT_EQ(ToString(*deserializedStream.from), "user@example.com");