From ba03f3fe6fa2a24cafdee95ee6ff156d2511a48a Mon Sep 17 00:00:00 2001 From: sha512sum Date: Tue, 8 Oct 2024 18:14:04 +0000 Subject: [PATCH] Add initiating new stream after SASL --- library/include/larra/client/client.hpp | 15 +++++++++++---- library/include/larra/serialization.hpp | 2 +- library/src/features.cpp | 7 ++----- 3 files changed, 14 insertions(+), 10 deletions(-) diff --git a/library/include/larra/client/client.hpp b/library/include/larra/client/client.hpp index 666537f..eb14506 100644 --- a/library/include/larra/client/client.hpp +++ b/library/include/larra/client/client.hpp @@ -235,7 +235,7 @@ struct ClientCreateVisitor { auto Auth(XmlStream& stream, ServerToUserStream streamHeader, StreamFeatures features) -> boost::asio::awaitable { co_return co_await std::visit( [&](auto& account) -> boost::asio::awaitable { - return this->Auth(std::move(account), stream, std::move(streamHeader), std::move(features)); + return this->Auth(account, stream, std::move(streamHeader), std::move(features)); }, this->account); } @@ -293,7 +293,10 @@ struct ClientCreateVisitor { } co_await this->Auth(stream, std::move(sToUStream), std::move(features)); - co_return Client{std::move(this->account).Jid(), std::move(stream)}; + co_await stream.Send(UserStream{.from = this->account.Jid(), .to = account.Jid().server, .version = "1.0", .xmlLang = "en"}); + ServerToUserStream secondSToUStream = co_await ReadStartStream(stream); + StreamFeatures secondFeatures = co_await stream.template Read(); + co_return Client{std::move(this->account.Jid()), std::move(stream)}; } template @@ -301,7 +304,9 @@ struct ClientCreateVisitor { -> boost::asio::awaitable, Client>>> { auto& socket = stream.next_layer(); co_await this->Connect(socket.next_layer(), co_await this->Resolve()); - co_await stream.Send(UserStream{.from = account.Jid().Username("anonymous"), .to = account.Jid().server}, socket.next_layer()); + co_await stream.Send( + UserStream{.from = account.Jid().Username("anonymous"), .to = account.Jid().server, .version = "1.0", .xmlLang = "en"}, + socket.next_layer()); SPDLOG_DEBUG("UserStream sended"); auto streamHeader = co_await this->ReadStartStream(stream); StreamFeatures features = co_await stream.template Read(); @@ -318,7 +323,9 @@ struct ClientCreateVisitor { auto newStreamHeader = co_await this->ReadStartStream(stream); auto newFeatures = co_await stream.template Read(); co_await this->Auth(stream, std::move(newStreamHeader), std::move(newFeatures)); - co_return Client{std::move(this->account).Jid(), XmlStream{std::move(socket)}}; + ServerToUserStream secondSToUStream = co_await ReadStartStream(stream); + StreamFeatures secondFeatures = co_await stream.template Read(); + co_return Client{std::move(this->account.Jid()), std::move(stream)}; } }; diff --git a/library/include/larra/serialization.hpp b/library/include/larra/serialization.hpp index 707cc05..60b2f58 100644 --- a/library/include/larra/serialization.hpp +++ b/library/include/larra/serialization.hpp @@ -107,7 +107,7 @@ struct Serialization : SerializationBase { try { return Serialization::StartCheck(element) ? std::optional{T::Parse(element)} : std::nullopt; } catch(const std::exception& e) { - SPDLOG_WARN("Type {}: Failed Parse but no TryParse found: {}", e.what(), nameof::nameof_type()); + SPDLOG_WARN("Type {}: Failed Parse but no TryParse found: {}", nameof::nameof_type(), e.what()); } catch(...) { SPDLOG_WARN("Type {}: Failed Parse but no TryParse found", nameof::nameof_type()); } diff --git a/library/src/features.cpp b/library/src/features.cpp index fc34fcb..3849b37 100644 --- a/library/src/features.cpp +++ b/library/src/features.cpp @@ -36,13 +36,10 @@ auto StreamFeatures::BindType::Parse(const xmlpp::Element* node) -> StreamFeatur auto StreamFeatures::Parse(const xmlpp::Element* node) -> StreamFeatures { auto ptr = dynamic_cast(node->get_first_child("mechanisms")); - if(!ptr) { - throw std::runtime_error("Not found or invalid node mechanisms for StreamFeatures"); - } - + using SaslMechanisms = struct SaslMechanisms; return {.startTls = ToOptional(node->get_first_child("starttls")), .bind = ToOptional(node->get_first_child("bind")), - .saslMechanisms = SaslMechanisms::Parse(ptr), + .saslMechanisms = ptr ? SaslMechanisms::Parse(ptr) : SaslMechanisms{.mechanisms = {}}, .others = node->get_children() | std::views::filter([](const xmlpp::Node* node) -> bool { auto name = node->get_name(); return name != "starttls" && name != "mechanisms" && name != "bind";