diff --git a/library/include/larra/client/client.hpp b/library/include/larra/client/client.hpp index ff29035..62b2467 100644 --- a/library/include/larra/client/client.hpp +++ b/library/include/larra/client/client.hpp @@ -313,40 +313,9 @@ struct ClientCreateVisitor { } }; - auto GetStartStreamIndex(auto& socket, boost::asio::streambuf& streambuf) -> boost::asio::awaitable { - auto splited = Splitter{GetEnumerated(streambuf)}; - using It = decltype(splited.begin()); - // clang-format off - co_return co_await - [&](this auto&& self, std::size_t n, It it) -> std::optional { - return n == 0 ? std::move(it) : it == splited.end() ? std::nullopt : self(n - 1, (++it, std::move(it))); - }(2, splited.begin()) - .transform([](auto value) -> boost::asio::awaitable { - auto [n, _] = *value; - co_return n; - }) // NOLINTNEXTLINE - .or_else([&] { // NOLINTNEXTLINE - return std::optional{[](auto self, auto& socket, auto& streambuf) -> boost::asio::awaitable { - auto buf = streambuf.prepare(4096); // NOLINT - std::size_t n = co_await socket.async_read_some(buf, boost::asio::use_awaitable); - streambuf.commit(n); - co_return co_await self->GetStartStreamIndex(socket, streambuf); - }(this, socket, streambuf)}; - }) - .value(); - // clang-format on - } - - auto ReadStartStream(auto& socket, boost::asio::streambuf& streambuf) -> boost::asio::awaitable { - auto n = co_await this->GetStartStreamIndex(socket, streambuf); - xmlpp::DomParser parser; - std::string dataToReed = - (::larra::xmpp::impl::GetCharsRangeFromBuf(streambuf) | std::views::take(n - 1) | std::ranges::to()) + "/>"; - - parser.parse_memory(dataToReed); - auto doc = parser.get_document(); - SPDLOG_DEBUG("Stream readed. Consuming {} bytes with stream data {}. Total buffer size: {}", n, dataToReed, streambuf.size()); - streambuf.consume(n); + template + auto ReadStartStream(RawXmlStream& stream) -> boost::asio::awaitable { + auto doc = (co_await stream.ReadOne(), co_await stream.ReadOne()); co_return ServerToUserStream::Parse(doc->get_root_node()); } @@ -357,7 +326,7 @@ struct ClientCreateVisitor { co_await stream.Send(UserStream{.from = account.Jid(), .to = account.Jid().server, .version = "1.0", .xmlLang = "en"}); SPDLOG_DEBUG("UserStream sended"); - ServerToUserStream sToUStream = co_await ReadStartStream(stream.next_layer(), *stream.streambuf); + ServerToUserStream sToUStream = co_await ReadStartStream(stream); StreamFeatures features = co_await stream.template Read(); SPDLOG_DEBUG("features parsed"); @@ -376,7 +345,7 @@ struct ClientCreateVisitor { co_await this->Connect(socket.next_layer(), co_await this->Resolve()); co_await stream.Send(UserStream{.from = account.Jid().Username("anonymous"), .to = account.Jid().server}, socket.next_layer()); SPDLOG_DEBUG("UserStream sended"); - auto streamHeader = co_await this->ReadStartStream(socket, *stream.streambuf); + auto streamHeader = co_await this->ReadStartStream(stream); StreamFeatures features = co_await stream.template Read(); SPDLOG_DEBUG("features parsed(SSL)"); if(!features.startTls) { @@ -388,7 +357,7 @@ struct ClientCreateVisitor { } co_await this->ProcessTls(stream); co_await stream.Send(UserStream{.from = account.Jid(), .to = account.Jid().server}, socket.next_layer()); - auto newStreamHeader = co_await this->ReadStartStream(socket, *stream.streambuf); + auto newStreamHeader = co_await this->ReadStartStream(stream); auto newFeatures = co_await stream.template Read(); co_await this->Auth(stream, std::move(newStreamHeader), std::move(newFeatures)); co_return Client{std::move(this->account).Jid(), RawXmlStream{std::move(socket)}}; diff --git a/library/include/larra/raw_xml_stream.hpp b/library/include/larra/raw_xml_stream.hpp index 87879ad..42b9918 100644 --- a/library/include/larra/raw_xml_stream.hpp +++ b/library/include/larra/raw_xml_stream.hpp @@ -58,6 +58,8 @@ struct XmlPath : public xmlpp::Element { namespace impl { +constexpr std::size_t kXmlStreamReadChunkSize = 4096; + constexpr auto BufferToStringView(const boost::asio::const_buffer& buffer, size_t size) -> std::string_view { assert(size <= buffer.size()); return {boost::asio::buffer_cast(buffer), size}; @@ -124,6 +126,52 @@ struct RawXmlStream : Stream { return *this; } + auto ReadOne(auto& socket) -> boost::asio::awaitable> { + auto doc = std::make_unique(); + impl::Parser parser(*doc); + for(;;) { + auto enumerated = std::views::zip(std::views::iota(std::size_t{}, this->streambuf->size()), + ::larra::xmpp::impl::GetCharsRangeFromBuf(*this->streambuf)); + + auto it = std::ranges::find(enumerated, '>', [](auto v) { + auto [_, c] = v; + return c; + }); + if(it == std::ranges::end(enumerated)) { + for(const auto& buf : this->streambuf->data()) { + auto error = parser.ParseChunk(impl::BufferToStringView(buf)); + if(error) { + throw std::runtime_error(std::format("Bad xml object: {}", xmlpp::format_xml_error(error))); + } + } + this->streambuf->consume(this->streambuf->size()); + auto buff = this->streambuf->prepare(impl::kXmlStreamReadChunkSize); + auto n = co_await socket.async_read_some(buff, boost::asio::use_awaitable); + this->streambuf->commit(n); + continue; + } + auto [i, _] = *it; + auto toRead = i + 1; + for(const auto& buf : this->streambuf->data()) { + if(toRead == 0) { + break; + } + auto toReadCurrent = std::min(buf.size(), toRead); + + auto error = parser.ParseChunk(impl::BufferToStringView(buf, toReadCurrent)); + if(error) { + throw std::runtime_error(std::format("Bad xml object: {}", xmlpp::format_xml_error(error))); + } + toRead -= toReadCurrent; + } + + this->streambuf->consume(i + 1); + co_return doc; + } + } + auto ReadOne() -> boost::asio::awaitable> { + co_return co_await this->ReadOne(this->next_layer()); + } inline auto Read(auto& socket) -> boost::asio::awaitable> { auto doc = std::make_unique(); // Not movable :( impl::Parser parser(*doc); @@ -153,7 +201,7 @@ struct RawXmlStream : Stream { } this->streambuf->consume(this->streambuf->size()); for(;;) { - auto buff = this->streambuf->prepare(4096); // NOLINT + auto buff = this->streambuf->prepare(impl::kXmlStreamReadChunkSize); auto [e, n] = co_await socket.async_read_some(buff, boost::asio::as_tuple(boost::asio::use_awaitable)); if(e) { boost::system::throw_exception_from_error(e, boost::source_location()); diff --git a/tests/raw_xml_stream.cpp b/tests/raw_xml_stream.cpp index 451d3ed..a2c7638 100644 --- a/tests/raw_xml_stream.cpp +++ b/tests/raw_xml_stream.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #include namespace larra::xmpp { @@ -19,6 +20,9 @@ constexpr std::string_view kDoc3 = "xmlns='urn:ietf:params:xml:ns:xmpp-sasl'>PLAINSCRAM-SHA-256X-OAUTH2"; +constexpr std::string_view kDoc4 = + R"()"; + TEST(RawXmlStream, ReadByOne) { boost::asio::io_context context; bool error{}; @@ -156,4 +160,32 @@ TEST(RawXmlStream, Write) { EXPECT_FALSE(error); } +TEST(RawXmlStream, ReadOneByOne) { + boost::asio::io_context context; + bool error{}; + boost::asio::co_spawn( + context, + // NOLINTNEXTLINE: Safe + [&] -> boost::asio::awaitable { + RawXmlStream stream{impl::MockSocket{context.get_executor(), 1}}; + stream.AddReceivedData(kDoc4); + try { + auto doc = (co_await stream.ReadOne(), co_await stream.ReadOne()); + auto node = doc->get_root_node(); + EXPECT_TRUE(node); + if(!node) { + co_return; + } + auto stream = ServerToUserStream::Parse(node); + } catch(const std::exception& err) { + SPDLOG_ERROR("{}", err.what()); + error = true; + } + }, + boost::asio::detached); + + context.run(); + EXPECT_FALSE(error); +} + } // namespace larra::xmpp