/******************************************************************************* Copyright(C) Jonas 'Sortie' Termansen 2013. This file is part of Sortix. Sortix is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version. Sortix is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with Sortix. If not, see . net/fs.cpp Filesystem based socket interface. *******************************************************************************/ // TODO: Should this be moved into user-space? #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include "fs.h" // TODO: This is declared in the header that isn't ready for // kernel usage, so we declare it here for now. #ifndef AF_UNIX #define AF_UNIX 3 #endif namespace Sortix { namespace NetFS { class Manager; class StreamSocket; class Manager : public AbstractInode { public: Manager(uid_t owner, gid_t group, mode_t mode); virtual ~Manager() { } virtual Ref open(ioctx_t* ctx, const char* filename, int flags, mode_t mode); public: bool Listen(StreamSocket* socket); void Unlisten(StreamSocket* socket); Ref Accept(StreamSocket* socket, ioctx_t* ctx, uint8_t* addr, size_t* addrsize, int flags); int AcceptPoll(StreamSocket* socket, ioctx_t* ctx, PollNode* node); bool Connect(StreamSocket* socket); private: StreamSocket* LookupServer(struct sockaddr_un* address); private: StreamSocket* first_server; StreamSocket* last_server; kthread_mutex_t manager_lock; }; class StreamSocket : public AbstractInode { public: StreamSocket(uid_t owner, gid_t group, mode_t mode, Ref manager); virtual ~StreamSocket(); virtual Ref accept(ioctx_t* ctx, uint8_t* addr, size_t* addrsize, int flags); virtual int bind(ioctx_t* ctx, const uint8_t* addr, size_t addrsize); virtual int connect(ioctx_t* ctx, const uint8_t* addr, size_t addrsize); virtual int listen(ioctx_t* ctx, int backlog); virtual ssize_t recv(ioctx_t* ctx, uint8_t* buf, size_t count, int flags); virtual ssize_t send(ioctx_t* ctx, const uint8_t* buf, size_t count, int flags); virtual ssize_t read(ioctx_t* ctx, uint8_t* buf, size_t count); virtual ssize_t write(ioctx_t* ctx, const uint8_t* buf, size_t count); virtual int poll(ioctx_t* ctx, PollNode* node); private: int do_bind(ioctx_t* ctx, const uint8_t* addr, size_t addrsize); public: /* For use by Manager. */ PollChannel accept_poll_channel; Ref manager; PipeEndpoint incoming; PipeEndpoint outgoing; StreamSocket* prev_socket; StreamSocket* next_socket; StreamSocket* first_pending; StreamSocket* last_pending; struct sockaddr_un* bound_address; bool is_listening; bool is_connected; bool is_refused; kthread_mutex_t socket_lock; kthread_cond_t pending_cond; kthread_cond_t accepted_cond; }; static void QueueAppend(StreamSocket** first, StreamSocket** last, StreamSocket* socket) { assert(!socket->prev_socket); assert(!socket->next_socket); socket->prev_socket = *last; socket->next_socket = NULL; if ( *last ) (*last)->next_socket = socket; if ( !*first ) *first = socket; *last = socket; } static void QueueRemove(StreamSocket** first, StreamSocket** last, StreamSocket* socket) { if ( socket->prev_socket ) socket->prev_socket->next_socket = socket->next_socket; else *first = socket->next_socket; if ( socket->next_socket ) socket->next_socket->prev_socket = socket->prev_socket; else *last = socket->prev_socket; } StreamSocket::StreamSocket(uid_t owner, gid_t group, mode_t mode, Ref manager) { inode_type = INODE_TYPE_STREAM; dev = (dev_t) manager.Get(); ino = (ino_t) this; this->type = S_IFSOCK; this->stat_uid = owner; this->stat_gid = group; this->stat_mode = (mode & S_SETABLE) | this->type; this->prev_socket = NULL; this->next_socket = NULL; this->first_pending = NULL; this->last_pending = NULL; this->bound_address = NULL; this->is_listening = false; this->is_connected = false; this->is_refused = false; this->manager = manager; this->socket_lock = KTHREAD_MUTEX_INITIALIZER; this->pending_cond = KTHREAD_COND_INITIALIZER; this->accepted_cond = KTHREAD_COND_INITIALIZER; } StreamSocket::~StreamSocket() { if ( is_listening ) manager->Unlisten(this); delete[] bound_address; } Ref StreamSocket::accept(ioctx_t* ctx, uint8_t* addr, size_t* addrsize, int flags) { ScopedLock lock(&socket_lock); if ( !is_listening ) return errno = EINVAL, Ref(NULL); return manager->Accept(this, ctx, addr, addrsize, flags); } int StreamSocket::do_bind(ioctx_t* ctx, const uint8_t* addr, size_t addrsize) { if ( is_connected || is_listening || bound_address ) return errno = EINVAL, -1; size_t path_offset = offsetof(struct sockaddr_un, sun_path); size_t path_len = (path_offset - addrsize) / sizeof(char); if ( addrsize < path_offset ) return errno = EINVAL, -1; uint8_t* buffer = new uint8_t[addrsize]; if ( !buffer ) return -1; if ( ctx->copy_from_src(buffer, addr, addrsize) ) { struct sockaddr_un* address = (struct sockaddr_un*) buffer; if ( address->sun_family == AF_UNIX ) { bool found_nul = false; for ( size_t i = 0; !found_nul && i < path_len; i++ ) if ( address->sun_path[i] == '\0' ) found_nul = true; if ( found_nul ) { bound_address = address; return 0; } errno = EINVAL; } else errno = EAFNOSUPPORT; } delete[] buffer; return -1; } int StreamSocket::bind(ioctx_t* ctx, const uint8_t* addr, size_t addrsize) { ScopedLock lock(&socket_lock); return do_bind(ctx, addr, addrsize); } int StreamSocket::connect(ioctx_t* ctx, const uint8_t* addr, size_t addrsize) { ScopedLock lock(&socket_lock); if ( is_listening ) return errno = EINVAL, -1; if ( is_connected ) return errno = EISCONN, -1; if ( addr && do_bind(ctx, addr, addrsize) != 0 ) return -1; if ( !bound_address ) return errno = EINVAL, -1; return manager->Connect(this) ? 0 : -1; } int StreamSocket::listen(ioctx_t* /*ctx*/, int /*backlog*/) { ScopedLock lock(&socket_lock); if ( is_connected || is_listening || !bound_address ) return errno = EINVAL, -1; if ( !manager->Listen(this) ) return -1; return 0; } ssize_t StreamSocket::recv(ioctx_t* ctx, uint8_t* buf, size_t count, int /*flags*/) { ScopedLock lock(&socket_lock); if ( !is_connected ) return errno = ENOTCONN, -1; return incoming.read(ctx, buf, count); } ssize_t StreamSocket::send(ioctx_t* ctx, const uint8_t* buf, size_t count, int /*flags*/) { ScopedLock lock(&socket_lock); if ( !is_connected ) return errno = ENOTCONN, -1; return outgoing.write(ctx, buf, count); } ssize_t StreamSocket::read(ioctx_t* ctx, uint8_t* buf, size_t count) { return recv(ctx, buf, count, 0); } ssize_t StreamSocket::write(ioctx_t* ctx, const uint8_t* buf, size_t count) { return send(ctx, buf, count, 0); } int StreamSocket::poll(ioctx_t* ctx, PollNode* node) { if ( is_connected ) // TODO: The poll API is broken, can't provide multiple sources on a poll // node. For now, polling the read channel should be most useful. return incoming.poll(ctx, node); if ( is_listening ) return manager->AcceptPoll(this, ctx, node); return errno = ENOTCONN, -1; } Manager::Manager(uid_t owner, gid_t group, mode_t mode) { inode_type = INODE_TYPE_UNKNOWN; dev = (dev_t) this; ino = 0; this->type = S_IFDIR; this->stat_uid = owner; this->stat_gid = group; this->stat_mode = (mode & S_SETABLE) | this->type; this->manager_lock = KTHREAD_MUTEX_INITIALIZER; this->first_server = NULL; this->last_server = NULL; } static int CompareAddress(const struct sockaddr_un* a, const struct sockaddr_un* b) { return strcmp(a->sun_path, b->sun_path); } StreamSocket* Manager::LookupServer(struct sockaddr_un* address) { for ( StreamSocket* iter = first_server; iter; iter = iter->next_socket ) if ( CompareAddress(iter->bound_address, address) == 0 ) return iter; return NULL; } static StreamSocket* QueuePop(StreamSocket** first, StreamSocket** last) { StreamSocket* ret = *first; assert(ret); QueueRemove(first, last, ret); return ret; } bool Manager::Listen(StreamSocket* socket) { ScopedLock lock(&manager_lock); if ( LookupServer(socket->bound_address) ) return errno = EADDRINUSE, false; QueueAppend(&first_server, &last_server, socket); socket->is_listening = true; return true; } void Manager::Unlisten(StreamSocket* socket) { ScopedLock lock(&manager_lock); while ( socket->first_pending ) { socket->first_pending->is_refused = true; kthread_cond_signal(&socket->first_pending->accepted_cond); socket->first_pending = socket->first_pending->next_socket; } socket->last_pending = NULL; QueueRemove(&first_server, &last_server, socket); socket->is_listening = false; } int Manager::AcceptPoll(StreamSocket* socket, ioctx_t* /*ctx*/, PollNode* node) { ScopedLock lock(&manager_lock); if ( socket->first_pending ) return (node->revents |= POLLIN | POLLRDNORM), 0; socket->accept_poll_channel.Register(node); return errno = EAGAIN, -1; } Ref Manager::Accept(StreamSocket* socket, ioctx_t* ctx, uint8_t* addr, size_t* addrsize, int /*flags*/) { ScopedLock lock(&manager_lock); // TODO: Support non-blocking accept! while ( !socket->first_pending ) if ( !kthread_cond_wait_signal(&socket->pending_cond, &manager_lock) ) return errno = EINTR, Ref(NULL); StreamSocket* client = socket->first_pending; struct sockaddr_un* client_addr = client->bound_address; size_t client_addr_size = offsetof(struct sockaddr_un, sun_path) + (strlen(client_addr->sun_path)+1) * sizeof(char); if ( addr ) { size_t caller_addrsize; if ( !ctx->copy_from_src(&caller_addrsize, addrsize, sizeof(caller_addrsize)) ) return Ref(NULL); if ( caller_addrsize < client_addr_size ) return errno = ERANGE, Ref(NULL); if ( !ctx->copy_from_src(addrsize, &client_addr_size, sizeof(client_addr_size)) ) return Ref(NULL); if ( !ctx->copy_to_dest(addr, client_addr, client_addr_size) ) return Ref(NULL); } // TODO: Give the caller the address of the remote! Ref server(new StreamSocket(0, 0, 0666, Ref(this))); if ( !server ) return Ref(NULL); QueuePop(&socket->first_pending, &socket->last_pending); if ( !client->outgoing.Connect(&server->incoming) ) return Ref(NULL); if ( !server->outgoing.Connect(&client->incoming) ) { client->outgoing.Disconnect(); server->incoming.Disconnect(); return Ref(NULL); } client->is_connected = true; server->is_connected = true; // TODO: Should the server socket inherit the address of the listening // socket or perhaps the one of the client's source/destination, or // nothing at all? kthread_cond_signal(&client->accepted_cond); return server; } bool Manager::Connect(StreamSocket* socket) { ScopedLock lock(&manager_lock); StreamSocket* server = LookupServer(socket->bound_address); if ( !server ) return errno = ECONNREFUSED, false; socket->is_refused = false; QueueAppend(&server->first_pending, &server->last_pending, socket); kthread_cond_signal(&server->pending_cond); server->accept_poll_channel.Signal(POLLIN | POLLRDNORM); while ( !(socket->is_connected || socket->is_refused) ) if ( !kthread_cond_wait_signal(&socket->accepted_cond, &manager_lock) && !(socket->is_connected || socket->is_refused) ) { QueueRemove(&server->first_pending, &server->last_pending, socket); return errno = EINTR, false; } return !socket->is_refused; } // TODO: Support a poll method in Manager. Ref Manager::open(ioctx_t* /*ctx*/, const char* filename, int /*flags*/, mode_t /*mode*/) { if ( !strcmp(filename, "stream") ) { StreamSocket* socket = new StreamSocket(0, 0, 0666, Ref(this)); return Ref(socket); } return errno = ENOENT, Ref(NULL); } void Init(const char* devpath, Ref slashdev) { ioctx_t ctx; SetupKernelIOCtx(&ctx); Ref node(new Manager(0, 0, 0666)); if ( !node ) PanicF("Unable to allocate %s/net/fs inode.", devpath); // TODO: Race condition! Create a mkdir function that returns what it // created, possibly with a O_MKDIR flag to open. if ( slashdev->mkdir(&ctx, "net", 0755) < 0 && errno != EEXIST ) PanicF("Could not create a %s/net directory", devpath); if ( slashdev->mkdir(&ctx, "net/fs", 0755) < 0 && errno != EEXIST ) PanicF("Could not create a %s/net/fs directory", devpath); Ref mpoint = slashdev->open(&ctx, "net/fs", O_READ | O_WRITE, 0); if ( !mpoint ) PanicF("Could not open the %s/net/fs directory", devpath); Ref mtable = CurrentProcess()->GetMTable(); // TODO: Make sure that the mount point is *empty*! Add a proper function // for this on the file descriptor class! if ( !mtable->AddMount(mpoint->ino, mpoint->dev, node) ) PanicF("Unable to mount filesystem on %s/net/fs", devpath); } } // namespace NetFS } // namespace Sortix