Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 16 additions & 6 deletions disco/disco.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"net/netip"

"go4.org/mem"
"golang.org/x/crypto/nacl/box"
"tailscale.com/types/key"
)

Expand All @@ -48,6 +49,19 @@ const (

const v0 = byte(0)

// v1 Ping and Pong are padded as follows. CallMeMaybe is still on v0 and unpadded.
const v1 = byte(1)

// paddedPayloadLen is the length to which we pad Ping and Pong payloads so
// that each padded disco packet is as large as the largest WireGuard packet we
// would subsequently send. This ensures that any UDP path we discover actually
// supports the packet sizes the net stack will send over that path. Any peer
// behind a small-MTU link will have to depend on DERP.
// cf. https://github.com/coder/coder/issues/15523
// Our inner IP packets can be up to 1280 bytes; with the WireGuard header of
// 30 bytes, that is 1310. The final 2 is the inner payload header's type and version.
const paddedPayloadLen = 1310 - len(Magic) - keyLen - NonceLen - box.Overhead - 2

var errShort = errors.New("short message")

// LooksLikeDiscoWrapper reports whether p looks like it's a packet
Expand Down Expand Up @@ -120,12 +134,8 @@ type Ping struct {
}

func (m *Ping) AppendMarshal(b []byte) []byte {
dataLen := 12
hasKey := !m.NodeKey.IsZero()
if hasKey {
dataLen += key.NodePublicRawLen
}
ret, d := appendMsgHeader(b, TypePing, v0, dataLen)
ret, d := appendMsgHeader(b, TypePing, v1, paddedPayloadLen)
n := copy(d, m.TxID[:])
if hasKey {
m.NodeKey.AppendTo(d[:n])
Expand Down Expand Up @@ -217,7 +227,7 @@ type Pong struct {
const pongLen = 12 + 16 + 2

func (m *Pong) AppendMarshal(b []byte) []byte {
ret, d := appendMsgHeader(b, TypePong, v0, pongLen)
ret, d := appendMsgHeader(b, TypePong, v1, paddedPayloadLen)
d = d[copy(d, m.TxID[:]):]
ip16 := m.Src.Addr().As16()
d = d[copy(d, ip16[:]):]
Expand Down
75 changes: 69 additions & 6 deletions disco/disco_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,31 +25,31 @@ func TestMarshalAndParse(t *testing.T) {
m: &Ping{
TxID: [12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
},
want: "01 00 01 02 03 04 05 06 07 08 09 0a 0b 0c",
want: "01 01 01 02 03 04 05 06 07 08 09 0a 0b 0c",
},
{
name: "ping_with_nodekey_src",
m: &Ping{
TxID: [12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
NodeKey: key.NodePublicFromRaw32(mem.B([]byte{1: 1, 2: 2, 30: 30, 31: 31})),
},
want: "01 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 00 01 02 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 1e 1f",
want: "01 01 01 02 03 04 05 06 07 08 09 0a 0b 0c 00 01 02 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 1e 1f",
},
{
name: "pong",
m: &Pong{
TxID: [12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
Src: mustIPPort("2.3.4.5:1234"),
},
want: "02 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 00 00 00 00 00 00 00 00 00 00 ff ff 02 03 04 05 04 d2",
want: "02 01 01 02 03 04 05 06 07 08 09 0a 0b 0c 00 00 00 00 00 00 00 00 00 00 ff ff 02 03 04 05 04 d2",
},
{
name: "pongv6",
m: &Pong{
TxID: [12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
Src: mustIPPort("[fed0::12]:6666"),
},
want: "02 00 01 02 03 04 05 06 07 08 09 0a 0b 0c fe d0 00 00 00 00 00 00 00 00 00 00 00 00 00 12 1a 0a",
want: "02 01 01 02 03 04 05 06 07 08 09 0a 0b 0c fe d0 00 00 00 00 00 00 00 00 00 00 00 00 00 12 1a 0a",
},
{
name: "call_me_maybe",
Expand Down Expand Up @@ -77,8 +77,8 @@ func TestMarshalAndParse(t *testing.T) {
}

gotHex := fmt.Sprintf("% x", got)
if gotHex != tt.want {
t.Fatalf("wrong marshal\n got: %s\nwant: %s\n", gotHex, tt.want)
if !strings.HasPrefix(gotHex, tt.want) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we expect the end of the ping payload to be a bunch of 00s we should probably test that, either by changing the want for each of the ping/pong cases to be the full value or by some other method

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't care that it's zeros per se, but I'll add a check that it's the padded length.

t.Fatalf("wrong marshal\n got: %s\nwant prefix: %s\n", gotHex, tt.want)
}

back, err := Parse([]byte(got))
Expand All @@ -92,6 +92,69 @@ func TestMarshalAndParse(t *testing.T) {
}
}

// TestParsePingPongV0 feeds hand-written version-0 (unpadded) Ping and Pong
// wire payloads to Parse and checks that each one decodes to the expected
// message.
func TestParsePingPongV0(t *testing.T) {
	// Every case uses the same transaction ID.
	txID := [12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}

	// frame assembles a v0 payload: message type, version byte 0x00, the
	// transaction ID, and then any type-specific trailing bytes.
	frame := func(msgType byte, body ...byte) []byte {
		p := []byte{msgType, 0x00}
		p = append(p, txID[:]...)
		return append(p, body...)
	}

	type testCase struct {
		name    string
		payload []byte
		m       Message
	}
	cases := []testCase{
		{
			name:    "ping",
			m:       &Ping{TxID: txID},
			payload: frame(0x01),
		},
		{
			name: "ping_with_nodekey_src",
			m: &Ping{
				TxID:    txID,
				NodeKey: key.NodePublicFromRaw32(mem.B([]byte{1: 1, 2: 2, 30: 30, 31: 31})),
			},
			payload: frame(0x01,
				0x00, 0x01, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
				0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x1e, 0x1f,
			),
		},
		{
			name: "pong",
			m: &Pong{
				TxID: txID,
				Src:  mustIPPort("2.3.4.5:1234"),
			},
			payload: frame(0x02,
				0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0x02, 0x03, 0x04, 0x05,
				0x04, 0xd2,
			),
		},
		{
			name: "pongv6",
			m: &Pong{
				TxID: txID,
				Src:  mustIPPort("[fed0::12]:6666"),
			},
			payload: frame(0x02,
				0xfe, 0xd0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x12,
				0x1a, 0x0a,
			),
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got, err := Parse(tc.payload)
			if err != nil {
				t.Fatalf("parse back: %v", err)
			}
			if reflect.DeepEqual(got, tc.m) {
				return
			}
			t.Errorf("message in %+v doesn't match Parse result %+v", tc.m, got)
		})
	}
}

func mustIPPort(s string) netip.AddrPort {
ipp, err := netip.ParseAddrPort(s)
if err != nil {
Expand Down
6 changes: 6 additions & 0 deletions wgengine/magicsock/magicsock.go
Original file line number Diff line number Diff line change
Expand Up @@ -2377,6 +2377,12 @@ func (c *Conn) bindSocket(ruc *RebindingUDPConn, network string, curPortFate cur
continue
}
trySetSocketBuffer(pconn, c.logf)
// CODER: https://github.com/coder/coder/issues/15523
// Attempt to tell the OS not to fragment packets over this interface. We pad disco Ping and Pong packets to the
// size of the direct UDP packets that get sent for direct connections. Thus, any interfaces or paths that
// cannot fully support direct connections due to MTU limitations will not be selected. If no direct paths meet
// the MTU requirements for a peer, we will fall back to DERP for that peer.
tryPreventFragmentation(pconn, c.logf, network)
// Success.
if debugBindSocket() {
c.logf("magicsock: bindSocket: successfully listened %v port %d", network, port)
Expand Down
36 changes: 36 additions & 0 deletions wgengine/magicsock/magicsock_darwin.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package magicsock

import (
"net"

"golang.org/x/sys/unix"
"tailscale.com/types/logger"
"tailscale.com/types/nettype"
)

// tryPreventFragmentation makes a best-effort attempt to set the
// "don't fragment" socket option on pconn so that oversized datagrams are
// rejected rather than fragmented.
//
// network selects the address family: "udp6" sets IPV6_DONTFRAG at the
// IPPROTO_IPV6 level; any other value sets IP_DONTFRAG at the IPPROTO_IP level.
//
// The function never fails the caller. Every outcome is reported through logf,
// and "success" is logged only when the socket option was actually applied.
func tryPreventFragmentation(pconn nettype.PacketConn, logf logger.Logf, network string) {
	c, ok := pconn.(*net.UDPConn)
	if !ok {
		logf("magicsock: dontfrag: failed because it was not a UDPConn")
		return
	}

	s, err := c.SyscallConn()
	if err != nil {
		// s is nil when SyscallConn fails; calling s.Control would panic.
		logf("magicsock: dontfrag: failed to get syscall conn: %v", err)
		return
	}

	level, option := unix.IPPROTO_IP, unix.IP_DONTFRAG
	if network == "udp6" {
		level, option = unix.IPPROTO_IPV6, unix.IPV6_DONTFRAG
	}

	// Capture the setsockopt result outside the callback so that it can gate
	// the success message below.
	var sockErr error
	err = s.Control(func(fd uintptr) {
		sockErr = unix.SetsockoptInt(int(fd), level, option, 1)
	})
	if err != nil {
		logf("magicsock: dontfrag: control connection failed: %v", err)
		return
	}
	if sockErr != nil {
		logf("magicsock: dontfrag: SetsockoptInt failed: %v", sockErr)
		return
	}
	logf("magicsock: dontfrag: success on %s", pconn.LocalAddr().String())
}
27 changes: 27 additions & 0 deletions wgengine/magicsock/magicsock_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -404,3 +404,30 @@ func init() {
// message. These contain a single uint16 of data.
controlMessageSize = unix.CmsgSpace(2)
}

// tryPreventFragmentation makes a best-effort attempt to put pconn into
// "always set DF / never fragment locally" path-MTU-discovery mode, so that
// datagrams larger than the path MTU fail instead of being fragmented.
//
// network selects the address family: "udp6" sets IPV6_MTU_DISCOVER at the
// IPPROTO_IPV6 level; any other value sets IP_MTU_DISCOVER at the IPPROTO_IP
// level. In both cases the mode requested is the family's PMTUDISC_DO.
//
// The function never fails the caller. Every outcome is reported through logf,
// and "success" is logged only when the socket option was actually applied.
func tryPreventFragmentation(pconn nettype.PacketConn, logf logger.Logf, network string) {
	c, ok := pconn.(*net.UDPConn)
	if !ok {
		logf("magicsock: dontfrag: failed because it was not a UDPConn")
		return
	}

	s, err := c.SyscallConn()
	if err != nil {
		// s is nil when SyscallConn fails; calling s.Control would panic.
		logf("magicsock: dontfrag: failed to get syscall conn: %v", err)
		return
	}

	level, option, value := unix.IPPROTO_IP, unix.IP_MTU_DISCOVER, unix.IP_PMTUDISC_DO
	if network == "udp6" {
		// Use the IPv6 spelling of the mode to match the IPv6 option.
		level, option, value = unix.IPPROTO_IPV6, unix.IPV6_MTU_DISCOVER, unix.IPV6_PMTUDISC_DO
	}

	// Capture the setsockopt result outside the callback so that it can gate
	// the success message below.
	var sockErr error
	err = s.Control(func(fd uintptr) {
		sockErr = unix.SetsockoptInt(int(fd), level, option, value)
	})
	if err != nil {
		logf("magicsock: dontfrag: control connection failed: %v", err)
		return
	}
	if sockErr != nil {
		logf("magicsock: dontfrag: SetsockoptInt failed: %v", sockErr)
		return
	}
	logf("magicsock: dontfrag: success on %s", pconn.LocalAddr().String())
}
47 changes: 47 additions & 0 deletions wgengine/magicsock/magicsock_windows.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package magicsock

import (
"net"

"golang.org/x/sys/windows"
"tailscale.com/types/logger"
"tailscale.com/types/nettype"
)

// https://github.com/tpn/winsdk-10/blob/9b69fd26ac0c7d0b83d378dba01080e93349c2ed/Include/10.0.16299.0/shared/ws2ipdef.h
const (
IP_MTU_DISCOVER = 71 // IPV6_MTU_DISCOVER has the same value, which is nice.
)

// Values for the IP_MTU_DISCOVER socket option, numbered consecutively from
// zero via iota. Only IP_PMTUDISC_DO is used in this file.
// NOTE(review): the names and ordering presumably mirror the PMTUD_STATE
// enumeration in ws2ipdef.h — confirm against the SDK header before reordering.
const (
IP_PMTUDISC_NOT_SET = iota
IP_PMTUDISC_DO
IP_PMTUDISC_DONT
IP_PMTUDISC_PROBE
IP_PMTUDISC_MAX
)

// tryPreventFragmentation makes a best-effort attempt to set the
// IP_MTU_DISCOVER socket option on pconn to IP_PMTUDISC_DO, so that datagrams
// larger than the path MTU fail instead of being fragmented.
//
// network selects the option level: "udp6" uses IPPROTO_IPV6; any other value
// uses IPPROTO_IP. The option number is the same at both levels.
//
// The function never fails the caller. Every outcome is reported through logf,
// and "success" is logged only when the socket option was actually applied.
func tryPreventFragmentation(pconn nettype.PacketConn, logf logger.Logf, network string) {
	c, ok := pconn.(*net.UDPConn)
	if !ok {
		logf("magicsock: dontfrag: failed because it was not a UDPConn")
		return
	}

	s, err := c.SyscallConn()
	if err != nil {
		// s is nil when SyscallConn fails; calling s.Control would panic.
		logf("magicsock: dontfrag: failed to get syscall conn: %v", err)
		return
	}

	level := windows.IPPROTO_IP
	if network == "udp6" {
		level = windows.IPPROTO_IPV6
	}

	// Capture the setsockopt result outside the callback so that it can gate
	// the success message below.
	var sockErr error
	err = s.Control(func(fd uintptr) {
		sockErr = windows.SetsockoptInt(windows.Handle(fd), level, IP_MTU_DISCOVER, IP_PMTUDISC_DO)
	})
	if err != nil {
		logf("magicsock: dontfrag: control connection failed: %v", err)
		return
	}
	if sockErr != nil {
		logf("magicsock: dontfrag: SetsockoptInt failed: %v", sockErr)
		return
	}
	logf("magicsock: dontfrag: success on %s", pconn.LocalAddr().String())
}
Loading