spit. not so much polish.
This commit is contained in:
parent
58c556ad15
commit
4847a89ecf
|
@ -48,7 +48,7 @@ struct req {
|
||||||
|
|
||||||
typedef struct req goreq;
|
typedef struct req goreq;
|
||||||
|
|
||||||
static struct req *initializeReq(size_t sz, int ipVersion) {
|
static struct req *initializeReq(size_t sz, int ipLen) {
|
||||||
struct req *r = malloc(sizeof(struct req));
|
struct req *r = malloc(sizeof(struct req));
|
||||||
memset(r, 0, sizeof(*r));
|
memset(r, 0, sizeof(*r));
|
||||||
r->buf = malloc(sz);
|
r->buf = malloc(sz);
|
||||||
|
@ -57,12 +57,12 @@ static struct req *initializeReq(size_t sz, int ipVersion) {
|
||||||
r->iov.iov_len = sz;
|
r->iov.iov_len = sz;
|
||||||
r->hdr.msg_iov = &r->iov;
|
r->hdr.msg_iov = &r->iov;
|
||||||
r->hdr.msg_iovlen = 1;
|
r->hdr.msg_iovlen = 1;
|
||||||
switch(ipVersion) {
|
switch(ipLen) {
|
||||||
case 4:
|
case 4:
|
||||||
r->hdr.msg_name = &r->sa;
|
r->hdr.msg_name = &r->sa;
|
||||||
r->hdr.msg_namelen = sizeof(r->sa);
|
r->hdr.msg_namelen = sizeof(r->sa);
|
||||||
break;
|
break;
|
||||||
case 6:
|
case 16:
|
||||||
r->hdr.msg_name = &r->sa6;
|
r->hdr.msg_name = &r->sa6;
|
||||||
r->hdr.msg_namelen = sizeof(r->sa6);
|
r->hdr.msg_namelen = sizeof(r->sa6);
|
||||||
break;
|
break;
|
||||||
|
|
|
@ -45,6 +45,11 @@ type UDPConn struct {
|
||||||
// closed is an atomic variable that indicates whether the connection has been closed.
|
// closed is an atomic variable that indicates whether the connection has been closed.
|
||||||
// TODO: Make an atomic bool type that we can use here.
|
// TODO: Make an atomic bool type that we can use here.
|
||||||
closed uint32
|
closed uint32
|
||||||
|
// shutdown is a sequence of funcs to be called when the UDPConn closes.
|
||||||
|
shutdown []func()
|
||||||
|
|
||||||
|
// file is the os file underlying this connection.
|
||||||
|
file *os.File
|
||||||
|
|
||||||
// local is the local address of this UDPConn.
|
// local is the local address of this UDPConn.
|
||||||
local net.Addr
|
local net.Addr
|
||||||
|
@ -61,7 +66,8 @@ type UDPConn struct {
|
||||||
// sendReqC is a channel containing indices into sendReqs
|
// sendReqC is a channel containing indices into sendReqs
|
||||||
// that are free to use (that is, not in the kernel).
|
// that are free to use (that is, not in the kernel).
|
||||||
sendReqC chan int
|
sendReqC chan int
|
||||||
is4 bool
|
// is4 indicates whether the conn is an IPv4 connection.
|
||||||
|
is4 bool
|
||||||
// reads counts the number of outstanding read requests.
|
// reads counts the number of outstanding read requests.
|
||||||
// It is accessed atomically.
|
// It is accessed atomically.
|
||||||
reads int32
|
reads int32
|
||||||
|
@ -72,52 +78,64 @@ func NewUDPConn(pconn net.PacketConn) (*UDPConn, error) {
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("cannot use io_uring with conn of type %T", pconn)
|
return nil, fmt.Errorf("cannot use io_uring with conn of type %T", pconn)
|
||||||
}
|
}
|
||||||
// this is dumb
|
local := conn.LocalAddr()
|
||||||
local := conn.LocalAddr().String()
|
udpAddr, ok := local.(*net.UDPAddr)
|
||||||
ip, err := netaddr.ParseIPPort(local)
|
if !ok {
|
||||||
if err != nil {
|
return nil, fmt.Errorf("cannot use io_uring with conn.LocalAddr of type %T", local)
|
||||||
return nil, fmt.Errorf("failed to parse UDPConn local addr %s as IP: %w", local, err)
|
|
||||||
}
|
|
||||||
ipVersion := 6
|
|
||||||
if ip.IP().Is4() {
|
|
||||||
ipVersion = 4
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: probe for system capabilities: https://unixism.net/loti/tutorial/probe_liburing.html
|
// TODO: probe for system capabilities: https://unixism.net/loti/tutorial/probe_liburing.html
|
||||||
|
|
||||||
file, err := conn.File()
|
file, err := conn.File()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
// conn.File dup'd the conn's fd. We no longer need the original conn.
|
// conn.File dup'd the conn's fd. We no longer need the original conn.
|
||||||
conn.Close()
|
conn.Close()
|
||||||
recvRing := new(C.go_uring)
|
|
||||||
sendRing := new(C.go_uring)
|
u := &UDPConn{
|
||||||
|
recvRing: new(C.go_uring),
|
||||||
|
sendRing: new(C.go_uring),
|
||||||
|
file: file,
|
||||||
|
local: local,
|
||||||
|
is4: len(udpAddr.IP) == 4,
|
||||||
|
}
|
||||||
|
|
||||||
fd := file.Fd()
|
fd := file.Fd()
|
||||||
for _, r := range []*C.go_uring{recvRing, sendRing} {
|
u.shutdown = append(u.shutdown, func() { file.Close() })
|
||||||
ret := C.initialize(r, C.int(fd))
|
|
||||||
if ret < 0 {
|
if ret := C.initialize(u.recvRing, C.int(fd)); ret < 0 {
|
||||||
// TODO: free recvRing if sendRing initialize failed
|
u.doShutdown()
|
||||||
return nil, fmt.Errorf("uring initialization failed: %d", ret)
|
return nil, fmt.Errorf("recvRing initialization failed: %w", syscall.Errno(-ret))
|
||||||
}
|
|
||||||
}
|
}
|
||||||
u := &UDPConn{
|
u.shutdown = append(u.shutdown, func() { C.io_uring_queue_exit(u.recvRing) })
|
||||||
recvRing: recvRing,
|
|
||||||
sendRing: sendRing,
|
if ret := C.initialize(u.sendRing, C.int(fd)); ret < 0 {
|
||||||
local: conn.LocalAddr(),
|
u.doShutdown()
|
||||||
is4: ipVersion == 4,
|
return nil, fmt.Errorf("sendRing initialization failed: %w", syscall.Errno(-ret))
|
||||||
}
|
}
|
||||||
|
u.shutdown = append(u.shutdown, func() { C.io_uring_queue_exit(u.sendRing) })
|
||||||
|
|
||||||
// Initialize buffers
|
// Initialize buffers
|
||||||
for _, reqs := range []*[8]*C.goreq{&u.recvReqs, &u.sendReqs} {
|
for i := range u.recvReqs {
|
||||||
for i := range reqs {
|
u.recvReqs[i] = C.initializeReq(bufferSize, C.int(len(udpAddr.IP)))
|
||||||
reqs[i] = C.initializeReq(bufferSize, C.int(ipVersion))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
for i := range u.sendReqs {
|
||||||
|
u.sendReqs[i] = C.initializeReq(bufferSize, C.int(len(udpAddr.IP)))
|
||||||
|
}
|
||||||
|
u.shutdown = append(u.shutdown, func() {
|
||||||
|
for _, r := range u.recvReqs {
|
||||||
|
C.freeReq(r)
|
||||||
|
}
|
||||||
|
for _, r := range u.sendReqs {
|
||||||
|
C.freeReq(r)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
// Initialize recv half.
|
// Initialize recv half.
|
||||||
for i := range u.recvReqs {
|
for i := range u.recvReqs {
|
||||||
if err := u.submitRecvRequest(i); err != nil {
|
if err := u.submitRecvRequest(i); err != nil {
|
||||||
u.Close() // TODO: will this crash?
|
u.doShutdown()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -233,20 +251,17 @@ func (u *UDPConn) Close() error {
|
||||||
}
|
}
|
||||||
// TODO: block until no one else uses our rings.
|
// TODO: block until no one else uses our rings.
|
||||||
// (Or is that unnecessary now?)
|
// (Or is that unnecessary now?)
|
||||||
C.io_uring_queue_exit(u.recvRing)
|
u.doShutdown()
|
||||||
C.io_uring_queue_exit(u.sendRing)
|
|
||||||
|
|
||||||
// Free buffers
|
|
||||||
for _, r := range u.recvReqs {
|
|
||||||
C.freeReq(r)
|
|
||||||
}
|
|
||||||
for _, r := range u.sendReqs {
|
|
||||||
C.freeReq(r)
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (u *UDPConn) doShutdown() {
|
||||||
|
for _, fn := range u.shutdown {
|
||||||
|
fn()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Implement net.PacketConn, for convenience integrating with magicsock.
|
// Implement net.PacketConn, for convenience integrating with magicsock.
|
||||||
|
|
||||||
var _ net.PacketConn = (*UDPConn)(nil)
|
var _ net.PacketConn = (*UDPConn)(nil)
|
||||||
|
|
Loading…
Reference in New Issue