xmpp_proxy/main.cpp
2024-11-15 13:16:40 +00:00

214 lines
8.7 KiB
C++

#include <algorithm>
#include <array>
#include <boost/asio.hpp>
#include <boost/asio/co_spawn.hpp>
#include <boost/cobalt.hpp>
#include <boost/cobalt/main.hpp>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <filesystem>
#include <fstream>
#include <larra/printer_stream.hpp>
#include <larra/stream.hpp>
#include <larra/xml_stream.hpp>
#include <nlohmann/json.hpp>
#include <print>
#include <ranges>
#include <span>
#include <stdexcept>
#include <string>
#include <string_view>
#include <vector>
// TCP port used when connecting to a remote XMPP server (server-to-server).
static constexpr auto kXmppS2SPort = std::uint16_t{5269};
// IP protocol family the listening acceptor is opened with (see Listen).
enum class IpVersion {
  kV6,
  kV4,
};
// SOCKS5 proxy endpoint as described in the configuration file.
struct Socks5Proxy {
  std::string host;
  std::uint16_t port;
  /// Reads {"host": <string>, "port": <uint16>} from `json`.
  /// Missing or mistyped keys surface as nlohmann::json exceptions.
  static constexpr auto Parse(nlohmann::json json) -> Socks5Proxy {
    auto host = json["host"].get<std::string>();
    const auto port = json["port"].get<std::uint16_t>();
    return Socks5Proxy{.host = std::move(host), .port = port};
  }
};
// Routing rule: destinations whose top-level domain equals `tld` are reached through `proxy`.
struct TldToProxy {
  std::string tld;
  Socks5Proxy proxy;
  /// Reads {"tld": <string>, "proxy": {"host": ..., "port": ...}} from `json`.
  static constexpr auto Parse(nlohmann::json json) -> TldToProxy {
    auto tld = json["tld"].get<std::string>();
    auto proxy = Socks5Proxy::Parse(json["proxy"]);
    return TldToProxy{.tld = std::move(tld), .proxy = std::move(proxy)};
  }
};
// Top-level proxy configuration.
struct Options {
  IpVersion ipVersion;
  std::uint16_t listenPort;
  bool debug;
  std::vector<TldToProxy> data;
  /// Loads the configuration from the JSON file at `file`.
  /// Expected keys: "ipVersion" ("v4" or "v6"), "listenPort", "debug" and "data"
  /// (an array of objects accepted by TldToProxy::Parse).
  /// @throws std::runtime_error if the file cannot be opened or "ipVersion" is invalid;
  ///         nlohmann::json exceptions for malformed JSON or missing/mistyped keys.
  static constexpr auto Parse(std::filesystem::path file) -> Options {
    std::ifstream input(file);
    if(!input) {
      // Without this check a missing/unreadable file would surface as a confusing JSON parse error.
      throw std::runtime_error{"Cannot open config file: " + file.string()};
    }
    auto json = nlohmann::json::parse(input);
    // The view refers to the string stored inside `json`, which outlives this expression.
    auto ipVersionStr = json["ipVersion"].get<std::string_view>();
    return {.ipVersion = ipVersionStr == "v6" ? IpVersion::kV6
                         : ipVersionStr == "v4" ? IpVersion::kV4
                                                : (throw std::runtime_error{"Invalid ip version in config"}),
            .listenPort = json["listenPort"].get<std::uint16_t>(),
            .debug = json["debug"].get<bool>(),
            .data = json["data"] | std::views::transform(&TldToProxy::Parse) | std::ranges::to<std::vector<TldToProxy>>()};
  }
};
// Performs a SOCKS5 (RFC 1928) CONNECT over `socket`, which must already be connected to the
// proxy. Uses the "no authentication" method and a DOMAINNAME destination, so the proxy itself
// resolves `address`. `socksProxy` is not read here; the caller connected `socket` to it.
// @throws std::runtime_error on an invalid destination or any negative/unknown proxy reply;
//         Boost.Asio errors from the underlying reads/writes propagate unchanged.
// NOLINTNEXTLINE
auto ConnectVia(auto& socket, const Socks5Proxy& socksProxy, std::string_view address, std::uint16_t port) -> boost::cobalt::task<void> {
  // VER=5, NMETHODS=1, METHOD=0 (no authentication)
  constexpr std::array kHandshakeRequest{std::byte{0x05}, std::byte{0x01}, std::byte{0x00}};
  // VER=5, CMD=CONNECT, RSV=0, ATYP=DOMAINNAME
  constexpr std::array kSocks5RequestStart{std::byte{0x05}, std::byte{0x01}, std::byte{0x00}, std::byte{0x03}};
  // The domain length is sent in a single byte.
  constexpr std::size_t kMaxDomainLength = 255;
  // header + length byte + domain + 2-byte port
  constexpr std::size_t kSocks5RequestMaxSize = kSocks5RequestStart.size() + 1 + kMaxDomainLength + 2;

  if(address.empty() || address.size() > kMaxDomainLength) {
    throw std::runtime_error{"SOCKS5: destination domain must be 1..255 bytes long"};
  }

  std::array<std::byte, 2> handshakeResponse{};
  co_await boost::asio::async_write(socket, boost::asio::buffer(kHandshakeRequest), boost::asio::transfer_all(), boost::cobalt::use_op);
  co_await boost::asio::async_read(socket, boost::asio::buffer(handshakeResponse), boost::asio::transfer_all(), boost::cobalt::use_op);
  if(handshakeResponse[0] != std::byte{0x05} || handshakeResponse[1] != std::byte{0x00}) {
    throw std::runtime_error{"SOCKS5: proxy rejected the no-authentication method"};
  }

  std::array<std::byte, kSocks5RequestMaxSize> request{};
  auto out = std::ranges::copy(kSocks5RequestStart, request.begin()).out;
  *out++ = static_cast<std::byte>(address.size());
  out = std::ranges::copy(std::as_bytes(std::span{address}), out).out;
  // Port in network (big-endian) byte order.
  *out++ = static_cast<std::byte>(port >> 8U);
  *out++ = static_cast<std::byte>(port & 0xFFU);
  const auto requestSize = static_cast<std::size_t>(out - request.begin());
  co_await boost::asio::async_write(
      socket, boost::asio::buffer(request.data(), requestSize), boost::asio::transfer_all(), boost::cobalt::use_op);

  // Reply: VER REP RSV ATYP BND.ADDR BND.PORT -- the BND.ADDR length depends on ATYP.
  std::array<std::byte, 4> replyHeader{};
  co_await boost::asio::async_read(socket, boost::asio::buffer(replyHeader), boost::asio::transfer_all(), boost::cobalt::use_op);
  if(replyHeader[0] != std::byte{0x05} || replyHeader[1] != std::byte{0x00}) {
    throw std::runtime_error{"SOCKS5: CONNECT request was rejected by the proxy"};
  }
  std::size_t boundAddressSize = 0;
  if(replyHeader[3] == std::byte{0x01}) {
    boundAddressSize = 4;  // IPv4
  } else if(replyHeader[3] == std::byte{0x04}) {
    boundAddressSize = 16;  // IPv6
  } else if(replyHeader[3] == std::byte{0x03}) {
    std::array<std::byte, 1> lengthByte{};
    co_await boost::asio::async_read(socket, boost::asio::buffer(lengthByte), boost::asio::transfer_all(), boost::cobalt::use_op);
    boundAddressSize = std::to_integer<std::size_t>(lengthByte[0]);
  } else {
    throw std::runtime_error{"SOCKS5: unknown address type in CONNECT reply"};
  }
  // Drain BND.ADDR and BND.PORT so no part of the reply leaks into the proxied stream.
  std::array<std::byte, kMaxDomainLength + 2> boundAddressAndPort{};
  co_await boost::asio::async_read(socket,
                                   boost::asio::buffer(boundAddressAndPort.data(), boundAddressSize + 2),
                                   boost::asio::transfer_all(),
                                   boost::cobalt::use_op);
  co_return;
}
// A domain split at its last dot. Both views point into the parsed string,
// which must outlive the Domain.
struct Domain {
  std::string_view name;
  std::string_view tld;
  /// Splits `data` at its last '.': "a.b.onion" -> {name "a.b", tld "onion"}.
  /// A name without any dot has no TLD: "localhost" -> {name "localhost", tld ""}.
  static constexpr auto Parse(std::string_view data) -> Domain {
    const auto dot = data.rfind('.');
    if(dot == std::string_view::npos) {
      // Guard: `npos + 1` would wrap to 0 and report the whole name as the TLD.
      return {.name = data, .tld = std::string_view{}};
    }
    return {.name = data.substr(0, dot), .tld = data.substr(dot + 1)};
  }
};
// Copies bytes from `in` to `out` until a read or write fails (e.g. EOF or a closed
// connection); the failure propagates as an exception to the awaiting coroutine.
// NOLINTNEXTLINE
auto Proxify(auto& in, auto& out) -> boost::cobalt::task<void> {
  std::array<std::byte, 1024> buffer;  // NOLINT
  for(;;) {
    const auto n = co_await in.async_read_some(boost::asio::buffer(buffer), boost::cobalt::use_op);
    // The composed async_write keeps writing until all `n` bytes are sent; a single
    // async_write_some may write fewer and would silently drop the remainder.
    co_await boost::asio::async_write(out, boost::asio::buffer(buffer, n), boost::cobalt::use_op);
  }
}
/// Bridges a Boost.Asio awaitable into a Boost.Cobalt task: the awaitable is spawned on
/// the calling coroutine's executor and its result is returned to the caller.
template <typename T>
auto Cobaltify(boost::asio::awaitable<T> task) -> boost::cobalt::task<T> {
  const auto currentExecutor = co_await boost::cobalt::this_coro::executor;
  co_return co_await boost::asio::co_spawn(currentExecutor, std::move(task), boost::cobalt::use_op);
}
/// Resolves `address`:`port` and returns a TCP socket connected to the first endpoint
/// that accepts the connection. Resolution/connection errors propagate as exceptions.
auto ConnectTo(std::string_view address, std::uint16_t port) -> boost::cobalt::task<boost::asio::ip::tcp::socket> {
  auto executor = co_await boost::cobalt::this_coro::executor;
  boost::asio::ip::tcp::resolver resolver{executor};
  // Host/service overload; the resolver-query-object overload is deprecated in Boost.Asio.
  auto endpoints = co_await resolver.async_resolve(address, std::to_string(port), boost::cobalt::use_op);
  boost::asio::ip::tcp::socket sock{executor};
  co_await boost::asio::async_connect(sock, endpoints, boost::cobalt::use_op);
  co_return sock;
}
// Serializes `stream` as an XMPP stream-opening tag and writes it to `sock`.
// The text is produced by string surgery on libxml++ output:
//  1. serialize the header as the document root and declare the `db` prefix for
//     "jabber:server:dialback";
//  2. drop the XML declaration libxml++ emits (`beginSize` bytes) and the final character
//     (presumably a trailing newline -- TODO confirm);
//  3. drop two more trailing characters (presumably the "/>" of the self-closing root, so the
//     tag is left open) and append ">";
//  4. prepend a shorter XML declaration and turn every '"' into '\''.
// NOTE(review): step 4 would produce malformed XML if an attribute value contains '\'';
// steps 2-3 assume the serialized text is long enough -- confirm against libxml++'s output.
// NOLINTNEXTLINE
auto SendHeader(auto& sock, larra::xmpp::ServerStream stream) -> boost::cobalt::task<void> {
// Length of the declaration libxml++ is expected to emit, excluding the string literal's NUL.
constexpr auto beginSize = sizeof("<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n") - 1;
xmlpp::Document doc;
using S = larra::xmpp::Serialization<larra::xmpp::ServerStream>;
auto node = doc.create_root_node(S::kDefaultName, S::kDefaultNamespace, S::kPrefix);
S::Serialize(node, stream);
node->set_namespace_declaration("jabber:server:dialback", "db");
auto str = doc.write_to_string();
// Strip the leading declaration and the last character.
auto view = std::string_view{str}.substr(beginSize, str.size() - beginSize - 1);
// Drop the last two characters and close the tag with ">" so the stream element stays open.
std::string data = "<?xml version=\"1.0\"?>" + static_cast<std::string>(view.substr(0, view.size() - 2)) + ">";
std::ranges::replace(data, '"', '\'');
co_await boost::asio::async_write(sock, boost::asio::buffer(data), boost::cobalt::use_op);
}
/// If `addr` contains ".xmpp", returns the suffix starting right after that dot
/// (e.g. "muc.xmpp.ru" -> "xmpp.ru"); otherwise returns `addr` unchanged.
/// The result views into `addr`.
constexpr auto GetAddress(std::string_view addr) {
  constexpr std::string_view kMarker = ".xmpp";
  if(const auto found = addr.find(kMarker); found != std::string_view::npos) {
    return addr.substr(found + 1);
  }
  return addr;
}
// Handles one inbound S2S connection:
//  1. reads the peer's stream header;
//  2. picks a SOCKS5 proxy by the destination's TLD (first matching entry of `options.data`),
//     or connects directly when none matches;
//  3. replays the header upstream and then pipes bytes in both directions.
// Any exception ends the connection and is logged at debug level.
// `options` is held by reference, so the caller must keep it alive while this runs.
// NOLINTNEXTLINE
auto Process(auto socket, auto transform, const Options& options) -> boost::cobalt::detached {
  try {
    larra::xmpp::XmlStream stream{transform(std::move(socket))};
    // The first, untyped ReadOne result is discarded; the second is parsed as the stream header.
    larra::xmpp::ServerStream header =
        (co_await Cobaltify(stream.ReadOne()), co_await Cobaltify(stream.template ReadOne<larra::xmpp::ServerStream>()));
    // `address` and `domain` view into `header`; they must not be used after `header` is moved below.
    std::string_view address = GetAddress(header.to.value());
    auto domain = Domain::Parse(address);
    spdlog::debug("Got domain address {}. Tld: {}", address, domain.tld);

    const auto route = std::ranges::find_if(options.data, [&](const TldToProxy& entry) {
      return entry.tld == domain.tld;
    });
    const bool viaProxy = route != options.data.end();
    if(viaProxy) {
      spdlog::debug("Connect via proxy {}:{}", route->proxy.host, route->proxy.port);
    } else {
      spdlog::debug("Tld not found");
    }
    const std::string_view connectHost = viaProxy ? std::string_view{route->proxy.host} : address;
    const std::uint16_t connectPort = viaProxy ? route->proxy.port : kXmppS2SPort;

    larra::xmpp::XmlStream upstream = transform(co_await ConnectTo(connectHost, connectPort));
    if(viaProxy) {
      co_await ConnectVia(upstream, route->proxy, address, kXmppS2SPort);
    }
    co_await SendHeader(upstream, std::move(header));
    // NOTE(review): the client side is piped from `stream.next_layer()`; if XmlStream buffered
    // bytes received after the header, they are not forwarded -- confirm against XmlStream.
    co_await boost::cobalt::join(Proxify(upstream, stream.next_layer()), Proxify(stream.next_layer(), upstream));
  } catch(const std::exception& err) {
    // Runtime call (not the SPDLOG_DEBUG macro, which is compiled out at the default
    // SPDLOG_ACTIVE_LEVEL) so failures actually appear when debug logging is enabled.
    spdlog::debug("Exception {}", err.what());
  }
}
/// Accepts TCP connections on `port` (any local address of the chosen IP family)
/// indefinitely, yielding each accepted socket to the consumer.
auto Listen(IpVersion version, std::uint16_t port) -> boost::cobalt::generator<boost::asio::ip::tcp::socket> {
  auto executor = co_await boost::cobalt::this_coro::executor;
  const auto protocol = (version == IpVersion::kV4) ? boost::asio::ip::tcp::v4() : boost::asio::ip::tcp::v6();
  const boost::asio::ip::tcp::endpoint localEndpoint{protocol, port};
  boost::asio::ip::tcp::acceptor acceptor(executor, localEndpoint);
  while(true) {
    auto accepted = co_await acceptor.async_accept(boost::cobalt::use_op);
    co_yield std::move(accepted);
  }
}
/// Program body: parses the config file named by `args[1]`, then accepts inbound
/// connections forever and hands each one to a detached Process coroutine.
/// In debug mode the socket is wrapped in larra::xmpp::PrintStream.
auto Main(std::span<const char*> args) -> boost::cobalt::task<void> {
  spdlog::set_pattern("[%H:%M:%S %z] [DEBUG] %v");
  if(args.size() != 2) {
    // Usage errors go to stderr so they do not mix with regular program output.
    std::println(stderr, "Invalid args.\nUsage: xmpp_proxy <config>");
    co_return;
  }
  auto options = Options::Parse(args[1]);
  if(options.debug) {
    spdlog::set_level(spdlog::level::debug);
    spdlog::debug("Started");
  }
  // Every Process holds `options` by reference; the loop below has no normal exit, so
  // `options` outlives the connections unless accepting throws.
  for(auto listener = Listen(options.ipVersion, options.listenPort);;) {
    auto socket = co_await listener;
    if(options.debug) {
      Process(
          std::move(socket),
          [](auto sock) {
            return larra::xmpp::PrintStream{std::move(sock)};
          },
          options);
    } else {
      Process(
          std::move(socket),
          [](auto arg) {
            return std::move(arg);
          },
          options);
    }
  }
}
auto main(int argc, const char* argv[]) -> int {
run(Main(std::span{argv, static_cast<std::size_t>(argc)}));
};