diff --git a/main.cpp b/main.cpp index 81f1986..ce4169f 100644 --- a/main.cpp +++ b/main.cpp @@ -8,19 +8,19 @@ int main() { std::cout << LdH::Sockets::Address::parse("google.com", "443").to_string() << std::endl; - auto server = LdH::Sockets::listen_tcp(LdH::Sockets::Address::parse("127.0.0.1", "8081"), 1); + auto server = LdH::Sockets::listen_udp(LdH::Sockets::Address::parse("127.0.0.2", "8081")); auto server_thread = LdH::fork("server", [&] { - auto server_stream = server.wait_for_connection(); char buffer[10]; - server_stream.sock.read(10, buffer); + server.recvOneTruncating(10, buffer); std::cout << buffer << std::endl; - server_stream.sock.close(); + server.close(); }); + auto client_thread = LdH::fork("client", [&] { - auto client = LdH::Sockets::connect_tcp(LdH::Sockets::Address::parse("127.0.0.1", "8081")); + auto client = LdH::Sockets::connect_udp(LdH::Sockets::Address::parse("127.0.0.2", "8081")); char buffer[10] = "hello\n"; - client.write(10, buffer); + client.sendOne(10, buffer); client.close(); }); @@ -28,7 +28,6 @@ int main() { client_thread.join(); server_thread.destroy(); client_thread.destroy(); - server.close(); LdH::Sockets::deinit_sockets(); } diff --git a/modules/sockets/src/berkeley_sockets.cppm b/modules/sockets/src/berkeley_sockets.cppm index afab237..bbb984f 100644 --- a/modules/sockets/src/berkeley_sockets.cppm +++ b/modules/sockets/src/berkeley_sockets.cppm @@ -5,6 +5,7 @@ import ru.landgrafhomyak.BGTU.networks_1.exceptions; import ru.landgrafhomyak.BGTU.networks_1.streams; namespace LdH::Sockets::Berkeley { +#pragma region context concept export template concept BerkeleySocketsContext = requires { typename ctx_t::address_t; } && @@ -29,8 +30,14 @@ namespace LdH::Sockets::Berkeley { requires(typename ctx_t::socket_t s) { { s.close() } -> std::same_as; } && requires(typename ctx_t::socket_t s, std::size_t c, char const *d) { { s.send_stream(c, d) } -> std::same_as; } && requires(typename ctx_t::socket_t s, std::size_t c, char *d) { { s.recv_stream(c, d) } -> std::same_as; } && + requires(typename ctx_t::socket_t s, std::size_t c, char *d) { { s.recv_datagram(c, d) } -> std::same_as; } && + requires(typename ctx_t::socket_t s, typename ctx_t::address_t e, std::size_t c, char const *d) { { s.send_datagram(e, c, d) } -> std::same_as; } && requires(typename ctx_t::socket_t s, typename ctx_t::address_t *a) { { s.accept(a) } -> std::same_as; }; +#pragma endregion + +#pragma region address + template requires BerkeleySocketsContext class _addr_internals; @@ -87,6 +94,9 @@ namespace LdH::Sockets::Berkeley { static Address wrap(ctx_t::address_t const &raw) { return Address{true, raw}; } }; +#pragma endregion + +#pragma region abstract socket enum _socket_state_t { _socket_state_UNINITIALIZED, _socket_state_MOVED, @@ -226,13 +236,15 @@ namespace LdH::Sockets::Berkeley { } }; - export - template requires BerkeleySocketsContext - class StreamSocketsServer; - template class _socket_internals { }; +#pragma endregion + +#pragma region stream sockets + export + template requires BerkeleySocketsContext + class StreamSocketsServer; export @@ -364,4 +376,243 @@ namespace LdH::Sockets::Berkeley { sock.listen(queue_size); return _socket_internals >::wrap(std::move(sock)); } +#pragma endregion + +#pragma region datagram sockets + + template requires BerkeleySocketsContext + class _abstract_datagram_socket { + private: + using _refcnt_t = signed long long; + + std::atomic<_refcnt_t> _refcnt; + + protected: + ctx_t::socket_t _value; + + private: + static constexpr _refcnt_t _refcnt_UNINITIALIZED = -0x0800'0000'0000'0000ll; + static constexpr _refcnt_t _refcnt_MOVED = -0x0800'0000'0000'0001ll; + static constexpr _refcnt_t _refcnt_CLOSED = -0x07FFF'FFFF'FFFF'FFFFll; + + protected: + explicit _abstract_datagram_socket(ctx_t::socket_t &&value) : _refcnt{0}, _value{std::move(value)} { + } + + public: + _abstract_datagram_socket() : _refcnt{_refcnt_UNINITIALIZED}, _value{} { + }; + + _abstract_datagram_socket(_abstract_datagram_socket &&other) noexcept : _refcnt{other._refcnt.load()}, _value{std::move(other._value)} { + }; + + _abstract_datagram_socket &operator=(_abstract_datagram_socket &&other) noexcept { + _refcnt_t current = other->_refcnt.load(); + while (true) { + if (current > 0) + LdH::abort("Can't move socket while it is in use"); + if (this->_refcnt.compare_exchange_weak(current, _refcnt_MOVED)) + break; + } + + current = this->_refcnt.load(); + while (true) { + if (current >= 0) { + other._refcnt.store(0); // rollback + LdH::abort("Variable already initialized"); + } + if (this->_refcnt.compare_exchange_weak(current, 0)) + break; + } + + this->_value = std::move(other._value); + + return *this; + } + + protected: + void _start_usage() { + _refcnt_t current = this->_refcnt.load(); + while (true) { + if (current < 0) { + switch (current) { + case _refcnt_UNINITIALIZED: + LdH::abort("Socket not initialized"); + case _refcnt_MOVED: + LdH::abort("Socket was moved to another location"); + case _refcnt_CLOSED: + LdH::abort("Socket was closed"); + default: + LdH::abort("Socket wrapper corrupted"); + } + } + if (this->_refcnt.compare_exchange_weak(current, current + 1)) + break; + } + } + + void _finish_usage() { + _refcnt_t current = this->_refcnt.load(); + while (true) { + if (current < 0) { + switch (current) { + case _refcnt_UNINITIALIZED: + LdH::abort("Socket not initialized"); + case _refcnt_MOVED: + LdH::abort("Socket was moved to another location"); + case _refcnt_CLOSED: + LdH::abort("Socket was closed"); + default: + LdH::abort("Socket wrapper corrupted"); + } + } + if (this->_refcnt.compare_exchange_weak(current, current - 1)) + break; + } + } + + protected: + void _close() { + _refcnt_t current = this->_refcnt.load(); + while (true) { + if (current != 0) { + switch (current) { + case _refcnt_UNINITIALIZED: + LdH::abort("Socket not initialized"); + case _refcnt_MOVED: + LdH::abort("Socket was moved to another location"); + case _refcnt_CLOSED: + LdH::abort("Socket was already closed"); + default: + if (current > 0) + LdH::abort("Can't close socket while it is in use"); + else + LdH::abort("Socket wrapper corrupted"); + } + } + if (this->_refcnt.compare_exchange_weak(current, _refcnt_CLOSED)) + break; + } + } + + ~_abstract_datagram_socket() noexcept = default; + }; + + export + template requires BerkeleySocketsContext + class ServerDatagramSocket : public _abstract_datagram_socket { + private: + explicit ServerDatagramSocket(ctx_t::socket_t &&value) : _abstract_datagram_socket{std::move(value)} { + } + + template + friend + class _socket_internals; + + public: + ServerDatagramSocket() = default; + + ServerDatagramSocket(ServerDatagramSocket &&other) noexcept = default; + + ServerDatagramSocket &operator=(ServerDatagramSocket &&other) noexcept = default; + + public: + Address recvOneTruncating(std::size_t size, char *data) { + this->_start_usage(); + auto addr = this->_value.recv_datagram(size, data); + this->_finish_usage(); + return _addr_internals::wrap(std::move(addr)); + } + + void sendOne(Address destination, std::size_t size, char const *data) { + this->_start_usage(); + this->_value.send_datagram(destination, size, data); + this->_finish_usage(); + } + + void close() { + this->_close(); + } + + ~ServerDatagramSocket() noexcept = default; + }; + + + template + class _socket_internals > { + public: + static ServerDatagramSocket wrap(ctx_t::socket_t &&raw) { + return ServerDatagramSocket{std::move(raw)}; + } + }; + + export + template requires BerkeleySocketsContext + class ClientDatagramSocket : public _abstract_datagram_socket, public LdH::Streams::InputMessanger, public LdH::Streams::OutputMessanger { + private: + explicit ClientDatagramSocket(ctx_t::socket_t &&value) : _abstract_datagram_socket{std::move(value)} { + } + + template + friend + class _socket_internals; + + public: + ClientDatagramSocket() = default; + + ClientDatagramSocket(ClientDatagramSocket &&other) noexcept = default; + + ClientDatagramSocket &operator=(ClientDatagramSocket &&other) noexcept = default; + + public: + void recvOneTruncating(std::size_t size, char *data) override { + this->_start_usage(); + this->_value.recv_stream(size, data); + this->_finish_usage(); + } + + void sendOne(std::size_t size, char const *data) override { + this->_start_usage(); + this->_value.send_stream(size, data); + this->_finish_usage(); + } + + void close() override { + this->_close(); + } + + ~ClientDatagramSocket() noexcept override = default; + }; + + + template + class _socket_internals > { + public: + static ClientDatagramSocket wrap(ctx_t::socket_t &&raw) { + return ClientDatagramSocket{std::move(raw)}; + } + }; + + export + template requires BerkeleySocketsContext + [[nodiscard]] + ClientDatagramSocket connect_udp(Address addr) { + if (!_addr_internals::has_value(addr)) + LdH::abort("Address not initialized"); + typename ctx_t::socket_t sock = ctx_t::socket_t::create(_addr_internals::unwrap(addr), ctx_t::sock_type::dgram(), ctx_t::proto::udp()); + sock.connect(_addr_internals::unwrap(addr)); + return _socket_internals >::wrap(std::move(sock)); + } + + export + template requires BerkeleySocketsContext + [[nodiscard]] + ServerDatagramSocket listen_udp(Address addr) { + if (!_addr_internals::has_value(addr)) + LdH::abort("Address not initialized"); + typename ctx_t::socket_t sock = ctx_t::socket_t::create(_addr_internals::unwrap(addr), ctx_t::sock_type::dgram(), ctx_t::proto::udp()); + sock.bind(_addr_internals::unwrap(addr)); + return _socket_internals >::wrap(std::move(sock)); + } +#pragma endregion } diff --git a/modules/sockets/src/platform/windows.cpp.inc b/modules/sockets/src/platform/windows.cpp.inc index 2b44bb1..8f71697 100644 --- a/modules/sockets/src/platform/windows.cpp.inc +++ b/modules/sockets/src/platform/windows.cpp.inc @@ -172,6 +172,23 @@ namespace LdH::Sockets { if (size <= 0) return; } } + + void send_datagram(address_t dest, std::size_t size, char const *data) { + int sent_count = ::sendto(this->_value, data, size, 0, reinterpret_cast(&dest._value), sizeof(dest._value)); + if (sent_count == SOCKET_ERROR) { + LdH::throwFromWindowsErrCode(WSAGetLastError()); + } + } + + address_t recv_datagram(std::size_t size, char *data) { + address_t out; + int out_size = sizeof(out); + int sent_count = ::recvfrom(this->_value, data, size, 0, reinterpret_cast(&out), &out_size); + if (sent_count == SOCKET_ERROR) { + LdH::throwFromWindowsErrCode(WSAGetLastError()); + } + return out; + } }; }; @@ -197,4 +214,18 @@ namespace LdH::Sockets { StreamSocketsServer listen_tcp(Address addr, std::size_t queue_size) { return Berkeley::listen_tcp<_WinsockContext>(addr, queue_size); } + + export using ClientDatagramSocket = Berkeley::ClientDatagramSocket<_WinsockContext>; + + export using ServerDatagramSocket = Berkeley::ServerDatagramSocket<_WinsockContext>; + + export + ClientDatagramSocket connect_udp(Address addr) { + return Berkeley::connect_udp<_WinsockContext>(addr); + } + + export + ServerDatagramSocket listen_udp(Address addr) { + return Berkeley::listen_udp<_WinsockContext>(addr); + } }