Fix non-blocking accept4(2) and getting the Unix socket peer address.

Rename the internal kernel method from accept to accept4.

fixup! Fix non-blocking accept4(2) and getting the unix socket peer address.
This commit is contained in:
Jonas 'Sortie' Termansen 2017-02-25 17:00:24 +01:00
parent 8f3e11b162
commit 4eb9caaa39
9 changed files with 91 additions and 84 deletions

View File

@ -850,11 +850,15 @@ int Descriptor::poll(ioctx_t* ctx, PollNode* node)
return vnode->poll(ctx, node); return vnode->poll(ctx, node);
} }
Ref<Descriptor> Descriptor::accept(ioctx_t* ctx, uint8_t* addr, size_t* addrlen, int flags) Ref<Descriptor> Descriptor::accept4(ioctx_t* ctx, uint8_t* addr,
size_t* addrlen, int flags)
{ {
Ref<Vnode> retvnode = vnode->accept(ctx, addr, addrlen, flags); int old_ctx_dflags = ctx->dflags;
ctx->dflags = ContextFlags(old_ctx_dflags, dflags);
Ref<Vnode> retvnode = vnode->accept4(ctx, addr, addrlen, flags);
if ( !retvnode ) if ( !retvnode )
return Ref<Descriptor>(); return Ref<Descriptor>();
ctx->dflags = old_ctx_dflags;
return Ref<Descriptor>(new Descriptor(retvnode, O_READ | O_WRITE)); return Ref<Descriptor>(new Descriptor(retvnode, O_READ | O_WRITE));
} }

View File

@ -160,7 +160,7 @@ public:
void Disconnect(); void Disconnect();
void Unmount(); void Unmount();
Channel* Connect(ioctx_t* ctx); Channel* Connect(ioctx_t* ctx);
Channel* Accept(); Channel* Accept(ioctx_t* ctx);
Ref<Inode> BootstrapNode(ino_t ino, mode_t type); Ref<Inode> BootstrapNode(ino_t ino, mode_t type);
Ref<Inode> OpenNode(ino_t ino, mode_t type); Ref<Inode> OpenNode(ino_t ino, mode_t type);
@ -181,8 +181,8 @@ class ServerNode : public AbstractInode
public: public:
ServerNode(Ref<Server> server); ServerNode(Ref<Server> server);
virtual ~ServerNode(); virtual ~ServerNode();
virtual Ref<Inode> accept(ioctx_t* ctx, uint8_t* addr, size_t* addrlen, virtual Ref<Inode> accept4(ioctx_t* ctx, uint8_t* addr, size_t* addrlen,
int flags); int flags);
private: private:
Ref<Server> server; Ref<Server> server;
@ -242,8 +242,8 @@ public:
virtual int poll(ioctx_t* ctx, PollNode* node); virtual int poll(ioctx_t* ctx, PollNode* node);
virtual int rename_here(ioctx_t* ctx, Ref<Inode> from, const char* oldname, virtual int rename_here(ioctx_t* ctx, Ref<Inode> from, const char* oldname,
const char* newname); const char* newname);
virtual Ref<Inode> accept(ioctx_t* ctx, uint8_t* addr, size_t* addrlen, virtual Ref<Inode> accept4(ioctx_t* ctx, uint8_t* addr, size_t* addrlen,
int flags); int flags);
virtual int bind(ioctx_t* ctx, const uint8_t* addr, size_t addrlen); virtual int bind(ioctx_t* ctx, const uint8_t* addr, size_t addrlen);
virtual int connect(ioctx_t* ctx, const uint8_t* addr, size_t addrlen); virtual int connect(ioctx_t* ctx, const uint8_t* addr, size_t addrlen);
virtual int listen(ioctx_t* ctx, int backlog); virtual int listen(ioctx_t* ctx, int backlog);
@ -594,13 +594,17 @@ Channel* Server::Connect(ioctx_t* ctx)
return channel; return channel;
} }
Channel* Server::Accept() Channel* Server::Accept(ioctx_t* ctx)
{ {
ScopedLock lock(&connect_lock); ScopedLock lock(&connect_lock);
listener_system_tid = CurrentThread()->system_tid; listener_system_tid = CurrentThread()->system_tid;
while ( !connecting && !unmounted ) while ( !connecting && !unmounted )
{
if ( ctx->dflags & O_NONBLOCK )
return errno = EWOULDBLOCK, (Channel*) NULL;
if ( !kthread_cond_wait_signal(&connecting_cond, &connect_lock) ) if ( !kthread_cond_wait_signal(&connecting_cond, &connect_lock) )
return errno = EINTR, (Channel*) NULL; return errno = EINTR, (Channel*) NULL;
}
if ( unmounted ) if ( unmounted )
return errno = ECONNRESET, (Channel*) NULL; return errno = ECONNRESET, (Channel*) NULL;
Channel* result = connecting; Channel* result = connecting;
@ -638,18 +642,19 @@ ServerNode::~ServerNode()
server->Disconnect(); server->Disconnect();
} }
Ref<Inode> ServerNode::accept(ioctx_t* ctx, uint8_t* addr, size_t* addrlen, Ref<Inode> ServerNode::accept4(ioctx_t* ctx, uint8_t* addr, size_t* addrlen,
int flags) int flags)
{ {
(void) addr; (void) addr;
(void) flags; if ( flags & ~(0) )
return errno = EINVAL, Ref<Inode>(NULL);
size_t out_addrlen = 0; size_t out_addrlen = 0;
if ( addrlen && !ctx->copy_to_dest(addrlen, &out_addrlen, sizeof(out_addrlen)) ) if ( addrlen && !ctx->copy_to_dest(addrlen, &out_addrlen, sizeof(out_addrlen)) )
return Ref<Inode>(NULL); return Ref<Inode>(NULL);
Ref<ChannelNode> node(new ChannelNode); Ref<ChannelNode> node(new ChannelNode);
if ( !node ) if ( !node )
return Ref<Inode>(NULL); return Ref<Inode>(NULL);
Channel* channel = server->Accept(); Channel* channel = server->Accept(ctx);
if ( !channel ) if ( !channel )
return Ref<Inode>(NULL); return Ref<Inode>(NULL);
node->Construct(channel); node->Construct(channel);
@ -1462,8 +1467,8 @@ int Unode::rename_here(ioctx_t* ctx, Ref<Inode> from, const char* oldname,
return ret; return ret;
} }
Ref<Inode> Unode::accept(ioctx_t* /*ctx*/, uint8_t* /*addr*/, Ref<Inode> Unode::accept4(ioctx_t* /*ctx*/, uint8_t* /*addr*/,
size_t* /*addrlen*/, int /*flags*/) size_t* /*addrlen*/, int /*flags*/)
{ {
return errno = ENOTSOCK, Ref<Inode>(); return errno = ENOTSOCK, Ref<Inode>();
} }

View File

@ -94,8 +94,8 @@ public:
int poll(ioctx_t* ctx, PollNode* node); int poll(ioctx_t* ctx, PollNode* node);
int rename_here(ioctx_t* ctx, Ref<Descriptor> from, const char* oldpath, int rename_here(ioctx_t* ctx, Ref<Descriptor> from, const char* oldpath,
const char* newpath); const char* newpath);
Ref<Descriptor> accept(ioctx_t* ctx, uint8_t* addr, size_t* addrlen, Ref<Descriptor> accept4(ioctx_t* ctx, uint8_t* addr, size_t* addrlen,
int flags); int flags);
int bind(ioctx_t* ctx, const uint8_t* addr, size_t addrlen); int bind(ioctx_t* ctx, const uint8_t* addr, size_t addrlen);
int connect(ioctx_t* ctx, const uint8_t* addr, size_t addrlen); int connect(ioctx_t* ctx, const uint8_t* addr, size_t addrlen);
int listen(ioctx_t* ctx, int backlog); int listen(ioctx_t* ctx, int backlog);

View File

@ -104,8 +104,8 @@ public:
virtual int poll(ioctx_t* ctx, PollNode* node) = 0; virtual int poll(ioctx_t* ctx, PollNode* node) = 0;
virtual int rename_here(ioctx_t* ctx, Ref<Inode> from, const char* oldname, virtual int rename_here(ioctx_t* ctx, Ref<Inode> from, const char* oldname,
const char* newname) = 0; const char* newname) = 0;
virtual Ref<Inode> accept(ioctx_t* ctx, uint8_t* addr, size_t* addrlen, virtual Ref<Inode> accept4(ioctx_t* ctx, uint8_t* addr, size_t* addrlen,
int flags) = 0; int flags) = 0;
virtual int bind(ioctx_t* ctx, const uint8_t* addr, size_t addrlen) = 0; virtual int bind(ioctx_t* ctx, const uint8_t* addr, size_t addrlen) = 0;
virtual int connect(ioctx_t* ctx, const uint8_t* addr, size_t addrlen) = 0; virtual int connect(ioctx_t* ctx, const uint8_t* addr, size_t addrlen) = 0;
virtual int listen(ioctx_t* ctx, int backlog) = 0; virtual int listen(ioctx_t* ctx, int backlog) = 0;
@ -210,8 +210,8 @@ public:
virtual int poll(ioctx_t* ctx, PollNode* node); virtual int poll(ioctx_t* ctx, PollNode* node);
virtual int rename_here(ioctx_t* ctx, Ref<Inode> from, const char* oldname, virtual int rename_here(ioctx_t* ctx, Ref<Inode> from, const char* oldname,
const char* newname); const char* newname);
virtual Ref<Inode> accept(ioctx_t* ctx, uint8_t* addr, size_t* addrlen, virtual Ref<Inode> accept4(ioctx_t* ctx, uint8_t* addr, size_t* addrlen,
int flags); int flags);
virtual int bind(ioctx_t* ctx, const uint8_t* addr, size_t addrlen); virtual int bind(ioctx_t* ctx, const uint8_t* addr, size_t addrlen);
virtual int connect(ioctx_t* ctx, const uint8_t* addr, size_t addrlen); virtual int connect(ioctx_t* ctx, const uint8_t* addr, size_t addrlen);
virtual int listen(ioctx_t* ctx, int backlog); virtual int listen(ioctx_t* ctx, int backlog);

View File

@ -93,7 +93,7 @@ public:
int poll(ioctx_t* ctx, PollNode* node); int poll(ioctx_t* ctx, PollNode* node);
int rename_here(ioctx_t* ctx, Ref<Vnode> from, const char* oldname, int rename_here(ioctx_t* ctx, Ref<Vnode> from, const char* oldname,
const char* newname); const char* newname);
Ref<Vnode> accept(ioctx_t* ctx, uint8_t* addr, size_t* addrlen, int flags); Ref<Vnode> accept4(ioctx_t* ctx, uint8_t* addr, size_t* addrlen, int flags);
int bind(ioctx_t* ctx, const uint8_t* addr, size_t addrlen); int bind(ioctx_t* ctx, const uint8_t* addr, size_t addrlen);
int connect(ioctx_t* ctx, const uint8_t* addr, size_t addrlen); int connect(ioctx_t* ctx, const uint8_t* addr, size_t addrlen);
int listen(ioctx_t* ctx, int backlog); int listen(ioctx_t* ctx, int backlog);

View File

@ -512,8 +512,8 @@ int AbstractInode::rename_here(ioctx_t* /*ctx*/, Ref<Inode> /*from*/,
return errno = ENOTDIR, -1; return errno = ENOTDIR, -1;
} }
Ref<Inode> AbstractInode::accept(ioctx_t* /*ctx*/, uint8_t* /*addr*/, Ref<Inode> AbstractInode::accept4(ioctx_t* /*ctx*/, uint8_t* /*addr*/,
size_t* /*addrlen*/, int /*flags*/) size_t* /*addrlen*/, int /*flags*/)
{ {
return errno = ENOTSOCK, Ref<Inode>(); return errno = ENOTSOCK, Ref<Inode>();
} }

View File

@ -731,13 +731,15 @@ int sys_accept4(int fd, void* addr, size_t* addrlen, int flags)
int fdflags = 0; int fdflags = 0;
if ( flags & SOCK_CLOEXEC ) fdflags |= FD_CLOEXEC; if ( flags & SOCK_CLOEXEC ) fdflags |= FD_CLOEXEC;
if ( flags & SOCK_CLOFORK ) fdflags |= FD_CLOFORK; if ( flags & SOCK_CLOFORK ) fdflags |= FD_CLOFORK;
flags &= ~(SOCK_CLOEXEC | SOCK_CLOFORK); int descflags = 0;
if ( flags & SOCK_NONBLOCK ) descflags |= O_NONBLOCK;
flags &= ~(SOCK_CLOEXEC | SOCK_CLOFORK | SOCK_NONBLOCK);
ioctx_t ctx; SetupUserIOCtx(&ctx); ioctx_t ctx; SetupUserIOCtx(&ctx);
Ref<Descriptor> conn = desc->accept(&ctx, (uint8_t*) addr, addrlen, flags); Ref<Descriptor> conn = desc->accept4(&ctx, (uint8_t*) addr, addrlen, flags);
if ( !conn ) if ( !conn )
return -1; return -1;
if ( flags & SOCK_NONBLOCK ) if ( descflags )
conn->SetFlags(conn->GetFlags() | O_NONBLOCK); conn->SetFlags(conn->GetFlags() | descflags);
return CurrentProcess()->GetDTable()->Allocate(conn, fdflags); return CurrentProcess()->GetDTable()->Allocate(conn, fdflags);
} }

View File

@ -24,6 +24,7 @@
#include <errno.h> #include <errno.h>
#include <stddef.h> #include <stddef.h>
#include <stdint.h> #include <stdint.h>
#include <stdlib.h>
#include <string.h> #include <string.h>
#include <sortix/fcntl.h> #include <sortix/fcntl.h>
@ -82,8 +83,8 @@ class StreamSocket : public AbstractInode
public: public:
StreamSocket(uid_t owner, gid_t group, mode_t mode, Ref<Manager> manager); StreamSocket(uid_t owner, gid_t group, mode_t mode, Ref<Manager> manager);
virtual ~StreamSocket(); virtual ~StreamSocket();
virtual Ref<Inode> accept(ioctx_t* ctx, uint8_t* addr, size_t* addrsize, virtual Ref<Inode> accept4(ioctx_t* ctx, uint8_t* addr, size_t* addrsize,
int flags); int flags);
virtual int bind(ioctx_t* ctx, const uint8_t* addr, size_t addrsize); virtual int bind(ioctx_t* ctx, const uint8_t* addr, size_t addrsize);
virtual int connect(ioctx_t* ctx, const uint8_t* addr, size_t addrsize); virtual int connect(ioctx_t* ctx, const uint8_t* addr, size_t addrsize);
virtual int listen(ioctx_t* ctx, int backlog); virtual int listen(ioctx_t* ctx, int backlog);
@ -116,6 +117,7 @@ public: /* For use by Manager. */
StreamSocket* first_pending; StreamSocket* first_pending;
StreamSocket* last_pending; StreamSocket* last_pending;
struct sockaddr_un* bound_address; struct sockaddr_un* bound_address;
size_t bound_address_size;
bool is_listening; bool is_listening;
bool is_connected; bool is_connected;
bool is_refused; bool is_refused;
@ -167,6 +169,7 @@ StreamSocket::StreamSocket(uid_t owner, gid_t group, mode_t mode,
this->first_pending = NULL; this->first_pending = NULL;
this->last_pending = NULL; this->last_pending = NULL;
this->bound_address = NULL; this->bound_address = NULL;
this->bound_address_size = 0;
this->is_listening = false; this->is_listening = false;
this->is_connected = false; this->is_connected = false;
this->is_refused = false; this->is_refused = false;
@ -181,11 +184,11 @@ StreamSocket::~StreamSocket()
{ {
if ( is_listening ) if ( is_listening )
manager->Unlisten(this); manager->Unlisten(this);
delete[] bound_address; free(bound_address);
} }
Ref<Inode> StreamSocket::accept(ioctx_t* ctx, uint8_t* addr, size_t* addrsize, Ref<Inode> StreamSocket::accept4(ioctx_t* ctx, uint8_t* addr, size_t* addrsize,
int flags) int flags)
{ {
ScopedLock lock(&socket_lock); ScopedLock lock(&socket_lock);
if ( !is_listening ) if ( !is_listening )
@ -198,33 +201,25 @@ int StreamSocket::do_bind(ioctx_t* ctx, const uint8_t* addr, size_t addrsize)
if ( is_connected || is_listening || bound_address ) if ( is_connected || is_listening || bound_address )
return errno = EINVAL, -1; return errno = EINVAL, -1;
size_t path_offset = offsetof(struct sockaddr_un, sun_path); size_t path_offset = offsetof(struct sockaddr_un, sun_path);
size_t path_len = (path_offset - addrsize) / sizeof(char);
if ( addrsize < path_offset ) if ( addrsize < path_offset )
return errno = EINVAL, -1; return errno = EINVAL, -1;
uint8_t* buffer = new uint8_t[addrsize]; size_t path_len = path_offset - addrsize;
if ( !buffer ) struct sockaddr_un* address = (struct sockaddr_un*) malloc(addrsize);
if ( !address )
return -1; return -1;
if ( ctx->copy_from_src(buffer, addr, addrsize) ) if ( !ctx->copy_from_src(address, addr, addrsize) )
{ return free(address), -1;
struct sockaddr_un* address = (struct sockaddr_un*) buffer; if ( address->sun_family != AF_UNIX )
if ( address->sun_family == AF_UNIX ) return free(address), errno = EAFNOSUPPORT, -1;
{ bool found_nul = false;
bool found_nul = false; for ( size_t i = 0; !found_nul && i < path_len; i++ )
for ( size_t i = 0; !found_nul && i < path_len; i++ ) if ( address->sun_path[i] == '\0' )
if ( address->sun_path[i] == '\0' ) found_nul = true;
found_nul = true; if ( !found_nul )
if ( found_nul ) return free(address), errno = EINVAL, -1;
{ bound_address = address;
bound_address = address; bound_address_size = addrsize;
return 0; return 0;
}
errno = EINVAL;
}
else
errno = EAFNOSUPPORT;
}
delete[] buffer;
return -1;
} }
int StreamSocket::bind(ioctx_t* ctx, const uint8_t* addr, size_t addrsize) int StreamSocket::bind(ioctx_t* ctx, const uint8_t* addr, size_t addrsize)
@ -465,40 +460,43 @@ int Manager::AcceptPoll(StreamSocket* socket, ioctx_t* /*ctx*/, PollNode* node)
} }
Ref<StreamSocket> Manager::Accept(StreamSocket* socket, ioctx_t* ctx, Ref<StreamSocket> Manager::Accept(StreamSocket* socket, ioctx_t* ctx,
uint8_t* addr, size_t* addrsize, int /*flags*/) uint8_t* addr, size_t* addrsize, int flags)
{ {
if ( flags & ~(0) )
return errno = EINVAL, Ref<StreamSocket>(NULL);
ScopedLock lock(&manager_lock); ScopedLock lock(&manager_lock);
// TODO: Support non-blocking accept!
while ( !socket->first_pending ) while ( !socket->first_pending )
{
if ( (ctx->dflags & O_NONBLOCK) || (flags & SOCK_NONBLOCK) )
return errno = EWOULDBLOCK, Ref<StreamSocket>(NULL);
if ( !kthread_cond_wait_signal(&socket->pending_cond, &manager_lock) ) if ( !kthread_cond_wait_signal(&socket->pending_cond, &manager_lock) )
return errno = EINTR, Ref<StreamSocket>(NULL); return errno = EINTR, Ref<StreamSocket>(NULL);
StreamSocket* client = socket->first_pending;
struct sockaddr_un* client_addr = client->bound_address;
size_t client_addr_size = offsetof(struct sockaddr_un, sun_path) +
(strlen(client_addr->sun_path)+1) * sizeof(char);
if ( addr )
{
size_t caller_addrsize;
if ( !ctx->copy_from_src(&caller_addrsize, addrsize, sizeof(caller_addrsize)) )
return Ref<StreamSocket>(NULL);
if ( caller_addrsize < client_addr_size )
return errno = ERANGE, Ref<StreamSocket>(NULL);
if ( !ctx->copy_from_src(addrsize, &client_addr_size, sizeof(client_addr_size)) )
return Ref<StreamSocket>(NULL);
if ( !ctx->copy_to_dest(addr, client_addr, client_addr_size) )
return Ref<StreamSocket>(NULL);
} }
// TODO: Give the caller the address of the remote! struct sockaddr_un* bound_address = socket->bound_address;
size_t bound_address_size = socket->bound_address_size;
if ( addr )
{
size_t used_addrsize;
if ( !ctx->copy_from_src(&used_addrsize, addrsize,
sizeof(used_addrsize)) )
return Ref<StreamSocket>(NULL);
if ( bound_address_size < used_addrsize )
used_addrsize = bound_address_size;
if ( !ctx->copy_to_dest(addr, bound_address, bound_address_size) )
return Ref<StreamSocket>(NULL);
if ( !ctx->copy_to_dest(addrsize, &used_addrsize,
sizeof(used_addrsize)) )
return Ref<StreamSocket>(NULL);
}
Ref<StreamSocket> server(new StreamSocket(0, 0, 0666, Ref<Manager>(this))); Ref<StreamSocket> server(new StreamSocket(0, 0, 0666, Ref<Manager>(this)));
if ( !server ) if ( !server )
return Ref<StreamSocket>(NULL); return Ref<StreamSocket>(NULL);
StreamSocket* client = socket->first_pending;
QueuePop(&socket->first_pending, &socket->last_pending); QueuePop(&socket->first_pending, &socket->last_pending);
if ( !client->outgoing.Connect(&server->incoming) ) if ( !client->outgoing.Connect(&server->incoming) )
@ -513,10 +511,6 @@ Ref<StreamSocket> Manager::Accept(StreamSocket* socket, ioctx_t* ctx,
client->is_connected = true; client->is_connected = true;
server->is_connected = true; server->is_connected = true;
// TODO: Should the server socket inherit the address of the listening
// socket or perhaps the one of the client's source/destination, or
// nothing at all?
kthread_cond_signal(&client->accepted_cond); kthread_cond_signal(&client->accepted_cond);
return server; return server;

View File

@ -391,12 +391,14 @@ int Vnode::poll(ioctx_t* ctx, PollNode* node)
return inode->poll(ctx, node); return inode->poll(ctx, node);
} }
Ref<Vnode> Vnode::accept(ioctx_t* ctx, uint8_t* addr, size_t* addrlen, int flags) Ref<Vnode> Vnode::accept4(ioctx_t* ctx, uint8_t* addr, size_t* addrlen,
int flags)
{ {
Ref<Inode> retinode = inode->accept(ctx, addr, addrlen, flags); Ref<Inode> retinode = inode->accept4(ctx, addr, addrlen, flags);
if ( !retinode ) if ( !retinode )
return Ref<Vnode>(); return Ref<Vnode>();
return Ref<Vnode>(new Vnode(retinode, Ref<Vnode>(), retinode->ino, retinode->dev)); return Ref<Vnode>(new Vnode(retinode, Ref<Vnode>(), retinode->ino,
retinode->dev));
} }
int Vnode::bind(ioctx_t* ctx, const uint8_t* addr, size_t addrlen) int Vnode::bind(ioctx_t* ctx, const uint8_t* addr, size_t addrlen)