diff --git a/kernel/net/fs.cpp b/kernel/net/fs.cpp index 73f4c22a..d37689c1 100644 --- a/kernel/net/fs.cpp +++ b/kernel/net/fs.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2013, 2014, 2016, 2017, 2021 Jonas 'Sortie' Termansen. + * Copyright (c) 2013, 2014, 2016, 2017, 2021, 2022 Jonas 'Sortie' Termansen. * * Permission to use, copy, modify, and distribute this software for any * purpose with or without fee is hereby granted, provided that the above @@ -61,12 +61,15 @@ public: mode_t mode); public: + bool Bind(StreamSocket* socket, struct sockaddr_un* addr, size_t addrsize); bool Listen(StreamSocket* socket); + void Unbind(StreamSocket* socket); void Unlisten(StreamSocket* socket); Ref Accept(StreamSocket* socket, ioctx_t* ctx, uint8_t* addr, size_t* addrsize, int flags); int AcceptPoll(StreamSocket* socket, ioctx_t* ctx, PollNode* node); - bool Connect(StreamSocket* socket); + bool Connect(StreamSocket* socket, struct sockaddr_un* addr, + size_t addrsize); private: StreamSocket* LookupServer(struct sockaddr_un* address); @@ -108,9 +111,6 @@ public: virtual int getpeername(ioctx_t* ctx, uint8_t* addr, size_t* addrsize); virtual int getsockname(ioctx_t* ctx, uint8_t* addr, size_t* addrsize); -private: - int do_bind(ioctx_t* ctx, const uint8_t* addr, size_t addrsize); - public: /* For use by Manager. */ PollChannel accept_poll_channel; Ref manager; @@ -120,9 +120,12 @@ public: /* For use by Manager. */ StreamSocket* next_socket; StreamSocket* first_pending; StreamSocket* last_pending; - struct sockaddr_un* bound_address; - size_t bound_address_size; + struct sockaddr_un* name; + struct sockaddr_un* peer; + size_t name_size; + size_t peer_size; int shutdown_flags; + bool is_registered; bool is_listening; bool is_connected; bool is_refused; @@ -176,9 +179,12 @@ StreamSocket::StreamSocket(uid_t owner, gid_t group, mode_t mode, this->next_socket = NULL; this->first_pending = NULL; this->last_pending = NULL; - this->bound_address = NULL; - this->bound_address_size = 0; + this->name = NULL; + this->peer = NULL; + this->name_size = 0; + this->peer_size = 0; this->shutdown_flags = 0; + this->is_registered = false; this->is_listening = false; this->is_connected = false; this->is_refused = false; @@ -193,7 +199,9 @@ StreamSocket::~StreamSocket() { if ( is_listening ) manager->Unlisten(this); - free(bound_address); + free(peer); + if ( name ) + manager->Unbind(this); } bool StreamSocket::pass() @@ -222,56 +230,73 @@ Ref StreamSocket::accept4(ioctx_t* ctx, uint8_t* addr, size_t* addrsize, return manager->Accept(this, ctx, addr, addrsize, flags); } -int StreamSocket::do_bind(ioctx_t* ctx, const uint8_t* addr, size_t addrsize) +static struct sockaddr_un* import_addr(ioctx_t* ctx, + const uint8_t* user_addr, + size_t addrsize) { - if ( is_connected || is_listening || bound_address ) - return errno = EINVAL, -1; size_t path_offset = offsetof(struct sockaddr_un, sun_path); if ( addrsize < path_offset ) - return errno = EINVAL, -1; + return errno = EINVAL, (struct sockaddr_un*) NULL; size_t path_len = path_offset - addrsize; - struct sockaddr_un* address = (struct sockaddr_un*) malloc(addrsize); - if ( !address ) - return -1; - if ( !ctx->copy_from_src(address, addr, addrsize) ) - return free(address), -1; - if ( address->sun_family != AF_UNIX ) - return free(address), errno = EAFNOSUPPORT, -1; + struct sockaddr_un* addr = (struct sockaddr_un*) malloc(addrsize); + if ( !addr ) + return NULL; + if ( !ctx->copy_from_src(addr, user_addr, addrsize) ) + return free(addr), (struct sockaddr_un*) NULL; + if ( addr->sun_family != AF_UNIX ) + return free(addr), errno = EAFNOSUPPORT, (struct sockaddr_un*) NULL; bool found_nul = false; for ( size_t i = 0; !found_nul && i < path_len; i++ ) - if ( address->sun_path[i] == '\0' ) + if ( addr->sun_path[i] == '\0' ) found_nul = true; if ( !found_nul ) - return free(address), errno = EINVAL, -1; - bound_address = address; - bound_address_size = addrsize; + return free(addr), errno = EINVAL, (struct sockaddr_un*) NULL; + return addr; +} + +int StreamSocket::bind(ioctx_t* ctx, const uint8_t* user_addr, size_t addrsize) +{ + ScopedLock lock(&socket_lock); + if ( is_connected || is_listening || name ) + return errno = EINVAL, -1; + struct sockaddr_un* addr = import_addr(ctx, user_addr, addrsize); + if ( !addr ) + return -1; + if ( !manager->Bind(this, addr, addrsize) ) + return free(addr), -1; return 0; } -int StreamSocket::bind(ioctx_t* ctx, const uint8_t* addr, size_t addrsize) -{ - ScopedLock lock(&socket_lock); - return do_bind(ctx, addr, addrsize); -} - -int StreamSocket::connect(ioctx_t* ctx, const uint8_t* addr, size_t addrsize) +int StreamSocket::connect(ioctx_t* ctx, + const uint8_t* user_addr, + size_t addrsize) { ScopedLock lock(&socket_lock); if ( is_listening ) return errno = EINVAL, -1; if ( is_connected ) return errno = EISCONN, -1; - if ( addr && do_bind(ctx, addr, addrsize) != 0 ) + if ( !name ) + { + // TODO: Actually bind the socket to a unique random name. + name = import_addr(ctx, user_addr, addrsize); + if ( !name ) + return -1; + name_size = addrsize; + is_registered = false; + } + struct sockaddr_un* addr = import_addr(ctx, user_addr, addrsize); + if ( !addr ) return -1; - if ( !bound_address ) - return errno = EINVAL, -1; - return manager->Connect(this) ? 0 : -1; + if ( !manager->Connect(this, addr, addrsize) ) + return free(addr), -1; + return 0; } int StreamSocket::listen(ioctx_t* /*ctx*/, int /*backlog*/) { ScopedLock lock(&socket_lock); - if ( is_connected || is_listening || !bound_address ) + if ( is_connected || is_listening || !name ) return errno = EINVAL, -1; if ( !manager->Listen(this) ) return -1; @@ -426,9 +451,9 @@ int StreamSocket::getpeername(ioctx_t* ctx, uint8_t* addr, size_t* addrsize) size_t used_addrsize; if ( !ctx->copy_from_src(&used_addrsize, addrsize, sizeof(used_addrsize)) ) return -1; - if ( bound_address_size < used_addrsize ) - used_addrsize = bound_address_size; - if ( !ctx->copy_to_dest(addr, bound_address, bound_address_size) ) + if ( peer_size < used_addrsize ) + used_addrsize = peer_size; + if ( !ctx->copy_to_dest(addr, peer, used_addrsize) ) return -1; if ( !ctx->copy_to_dest(addrsize, &used_addrsize, sizeof(used_addrsize)) ) return -1; @@ -441,9 +466,9 @@ int StreamSocket::getsockname(ioctx_t* ctx, uint8_t* addr, size_t* addrsize) size_t used_addrsize; if ( !ctx->copy_from_src(&used_addrsize, addrsize, sizeof(used_addrsize)) ) return -1; - if ( bound_address_size < used_addrsize ) - used_addrsize = bound_address_size; - if ( !ctx->copy_to_dest(addr, bound_address, bound_address_size) ) + if ( name_size < used_addrsize ) + used_addrsize = name_size; + if ( !ctx->copy_to_dest(addr, name, used_addrsize) ) return -1; if ( !ctx->copy_to_dest(addrsize, &used_addrsize, sizeof(used_addrsize)) ) return -1; @@ -473,7 +498,7 @@ static int CompareAddress(const struct sockaddr_un* a, StreamSocket* Manager::LookupServer(struct sockaddr_un* address) { for ( StreamSocket* iter = first_server; iter; iter = iter->next_socket ) - if ( CompareAddress(iter->bound_address, address) == 0 ) + if ( CompareAddress(iter->name, address) == 0 ) return iter; return NULL; } @@ -486,17 +511,28 @@ static StreamSocket* QueuePop(StreamSocket** first, StreamSocket** last) return ret; } +bool Manager::Bind(StreamSocket* socket, + struct sockaddr_un* addr, + size_t addrsize) +{ + ScopedLock lock(&manager_lock); + if ( LookupServer(addr) ) + return errno = EADDRINUSE, false; + socket->name = addr; + socket->name_size = addrsize; + socket->is_registered = true; + QueueAppend(&first_server, &last_server, socket); + return true; +} + bool Manager::Listen(StreamSocket* socket) { ScopedLock lock(&manager_lock); - if ( LookupServer(socket->bound_address) ) - return errno = EADDRINUSE, false; - QueueAppend(&first_server, &last_server, socket); socket->is_listening = true; return true; } -void Manager::Unlisten(StreamSocket* socket) +void Manager::Unbind(StreamSocket* socket) { ScopedLock lock(&manager_lock); while ( socket->first_pending ) @@ -506,7 +542,16 @@ void Manager::Unlisten(StreamSocket* socket) socket->first_pending = socket->first_pending->next_socket; } socket->last_pending = NULL; - QueueRemove(&first_server, &last_server, socket); + if ( socket->is_registered ) + QueueRemove(&first_server, &last_server, socket); + free(socket->name); + socket->name = NULL; + socket->name_size = 0; +} + +void Manager::Unlisten(StreamSocket* socket) +{ + ScopedLock lock(&manager_lock); socket->is_listening = false; } @@ -536,17 +581,17 @@ Ref Manager::Accept(StreamSocket* socket, ioctx_t* ctx, return errno = EINTR, Ref(NULL); } - struct sockaddr_un* bound_address = socket->bound_address; - size_t bound_address_size = socket->bound_address_size; + StreamSocket* client = socket->first_pending; + if ( addr ) { size_t used_addrsize; if ( !ctx->copy_from_src(&used_addrsize, addrsize, sizeof(used_addrsize)) ) return Ref(NULL); - if ( bound_address_size < used_addrsize ) - used_addrsize = bound_address_size; - if ( !ctx->copy_to_dest(addr, bound_address, bound_address_size) ) + if ( client->name_size < used_addrsize ) + used_addrsize = client->name_size; + if ( !ctx->copy_to_dest(addr, client->name, used_addrsize) ) return Ref(NULL); if ( !ctx->copy_to_dest(addrsize, &used_addrsize, sizeof(used_addrsize)) ) @@ -557,14 +602,18 @@ Ref Manager::Accept(StreamSocket* socket, ioctx_t* ctx, if ( !server ) return Ref(NULL); - server->bound_address = (struct sockaddr_un*) malloc(bound_address_size); - if ( !server->bound_address ) + server->name = (struct sockaddr_un*) malloc(socket->name_size); + if ( !server->name ) + return Ref(NULL); + server->peer = (struct sockaddr_un*) malloc(client->name_size); + if ( !server->peer ) return Ref(NULL); - server->bound_address_size = bound_address_size; - memcpy(server->bound_address, bound_address, bound_address_size); + server->name_size = socket->name_size; + memcpy(server->name, socket->name, socket->name_size); + server->peer_size = client->name_size; + memcpy(server->peer, client->name, client->name_size); - StreamSocket* client = socket->first_pending; QueuePop(&socket->first_pending, &socket->last_pending); if ( !client->outgoing.Connect(&server->incoming) ) @@ -584,11 +633,13 @@ Ref Manager::Accept(StreamSocket* socket, ioctx_t* ctx, return server; } -bool Manager::Connect(StreamSocket* socket) +bool Manager::Connect(StreamSocket* socket, + struct sockaddr_un* addr, + size_t addrsize) { ScopedLock lock(&manager_lock); - StreamSocket* server = LookupServer(socket->bound_address); - if ( !server ) + StreamSocket* server = LookupServer(addr); + if ( !server || !server->is_listening ) return errno = ECONNREFUSED, false; socket->is_refused = false; @@ -605,7 +656,11 @@ bool Manager::Connect(StreamSocket* socket) return errno = EINTR, false; } - return !socket->is_refused; + if ( socket->is_refused ) + return false; + socket->peer = addr; + socket->peer_size = addrsize; + return true; } // TODO: Support a poll method in Manager.