From d2dd3f9460d918ef7a1977eee55f00b8621a79fe Mon Sep 17 00:00:00 2001
From: Doralitze <doralitze@chaotikum.org>
Date: Mon, 13 Mar 2023 01:38:40 +0100
Subject: [PATCH] fix: unix socketaddr generation

---
 src/net/sock_address_factory.cpp | 21 +++++++++++++++++----
 src/net/socketaddress.hpp        |  6 +++---
 test/address_tests.cpp           |  2 +-
 3 files changed, 21 insertions(+), 8 deletions(-)

diff --git a/src/net/sock_address_factory.cpp b/src/net/sock_address_factory.cpp
index 88352fd..054d287 100644
--- a/src/net/sock_address_factory.cpp
+++ b/src/net/sock_address_factory.cpp
@@ -162,6 +162,17 @@ namespace rmrf::net {
     }
 
     std::list<socketaddr> get_socketaddr_list(const std::string &interface_description, const std::string &service_or_port, const socket_t socket_type) {
+        if (socket_type == socket_t::UNIX) {
+            sockaddr_un storage;
+            strncpy(storage.sun_path, interface_description.c_str(), sizeof(storage.sun_path));
+            // Required as the automatic initialization of sockaddr_un is broken on linux.
+            // This will be optimized out on platforms where it is not.
+            ((sockaddr*) &storage)->sa_family = AF_UNIX;
+            const socketaddr sa{storage};
+            std::list<socketaddr> l = {sa};
+            return l;
+        }
+
         int port = -1;
 
         try {
@@ -177,9 +188,9 @@ namespace rmrf::net {
             return l;
         }
 
+        // Attempt DNS lookup
         struct addrinfo hints = {};
-
-        struct addrinfo* addrs;
+        struct addrinfo* addrs = nullptr;
 
         hints.ai_family = AF_INET6;
         hints.ai_socktype = get_socket_type_hint(socket_type);
@@ -190,8 +201,10 @@ namespace rmrf::net {
             dns_error = getaddrinfo(interface_description.c_str(), NULL, &hints, &addrs);
 
             if (dns_error != 0) {
-                freeaddrinfo(addrs);
-                throw std::invalid_argument("Something went wrong with the DNS lookup. Error code: " + format_network_error(dns_error));
+                if (addrs != nullptr) {
+                    freeaddrinfo(addrs);
+                }
+                throw std::invalid_argument("Something went wrong with the DNS lookup. Error: " + format_network_error(dns_error));
             }
         }
 
diff --git a/src/net/socketaddress.hpp b/src/net/socketaddress.hpp
index 106c5a1..18f3516 100644
--- a/src/net/socketaddress.hpp
+++ b/src/net/socketaddress.hpp
@@ -70,7 +70,7 @@ public:
     template <typename T, typename std::enable_if<has_field<T>::value, T>::type * = nullptr>
     explicit socketaddr(T *other) : addr{}, len{} {
         if (other->*(family_map<T>::sa_family_field) != family_map<T>::sa_family) {
-            throw netio_exception("Address family mismatch in sockaddr structure.");
+            throw netio_exception("Unable to construct socketaddr object. Address family mismatch in sockaddr structure.");
         }
 
         memcpy(&addr, other, sizeof(T));
@@ -88,7 +88,7 @@ public:
     template <typename T, typename std::enable_if<has_field<T>::value, T>::type * = nullptr>
     explicit socketaddr(const T& other) : addr{}, len{} {
         if (other.*(family_map<T>::sa_family_field) != family_map<T>::sa_family) {
-            throw netio_exception("Address family mismatch in sockaddr structure.");
+            throw netio_exception("Unable to construct socketaddr object from reference. Address family mismatch in sockaddr structure.");
         }
 
         memcpy(&addr, &other, sizeof(T));
@@ -98,7 +98,7 @@ public:
     template <typename T>
     socketaddr& operator=(const T *rhs) {
         if (rhs->*(family_map<T>::sa_family_field) != family_map<T>::sa_family) {
-            throw netio_exception("Address family mismatch in sockaddr structure.");
+            throw netio_exception("Unable to construct socketaddr object from rhs. Address family mismatch in sockaddr structure.");
         }
 
         memcpy(&addr, rhs, sizeof(T));
diff --git a/test/address_tests.cpp b/test/address_tests.cpp
index 867be3e..0c9512b 100644
--- a/test/address_tests.cpp
+++ b/test/address_tests.cpp
@@ -65,5 +65,5 @@ BOOST_AUTO_TEST_CASE(Socketaddr_comparison) {
 
 BOOST_AUTO_TEST_CASE(Unix_socket_construction_test) {
     const auto sa = get_first_general_socketaddr("/tmp/9Lq7BNBnBycd6nxy.socket", "", socket_t::UNIX);
-    BOOST_CHECK_EQUAL(sa.str(), "FileSocket /tmp/9Lq7BNBnBycd6nxy.socket");
+    BOOST_CHECK_EQUAL(sa.str(), "SocketAddress: FileSocket /tmp/9Lq7BNBnBycd6nxy.socket");
 }
-- 
GitLab