Skip to content
Commits on Source (2)
  • Eugene Burkov's avatar
    Pull request 251: 324-refactor-bootstrap · 466aa934
    Eugene Burkov authored
    Merge in GO/dnsproxy from 324-refactor-bootstrap to master
    
    Updates #324.
    
    Squashed commit of the following:
    
    commit 9452c0b45a8602fd9371234c01ffda5009cf890f
    Author: Eugene Burkov <E.Burkov@AdGuard.COM>
    Date:   Wed Apr 12 14:50:04 2023 +0300
    
        fastest: imp docs
    
    commit 3d7e941c
    Author: Eugene Burkov <E.Burkov@AdGuard.COM>
    Date:   Tue Apr 11 20:11:59 2023 +0300
    
        bootstrap: fix test races
    
    commit 2aee12a9
    Author: Eugene Burkov <E.Burkov@AdGuard.COM>
    Date:   Tue Apr 11 19:58:42 2023 +0300
    
        bootstrap: fix logging
    
    commit 716c08a8
    Author: Eugene Burkov <E.Burkov@AdGuard.COM>
    Date:   Tue Apr 11 19:53:16 2023 +0300
    
        all: imp code, docs
    
    commit 76dd0bdd
    Author: Eugene Burkov <E.Burkov@AdGuard.COM>
    Date:   Tue Apr 11 16:28:47 2023 +0300
    
        upstream: clone req
    
    commit b5ed3874
    Author: Eugene Burkov <E.Burkov@AdGuard.COM>
    Date:   Tue Apr 11 16:20:30 2023 +0300
    
        fastip: fix helper upstream
    
    commit 15e8329b
    Author: Eugene Burkov <E.Burkov@AdGuard.COM>
    Date:   Tue Apr 11 16:09:19 2023 +0300
    
        upstream: imp code
    
    commit 9897feae
    Author: Eugene Burkov <E.Burkov@AdGuard.COM>
    Date:   Mon Apr 10 17:24:27 2023 +0300
    
        bootstrap: fix doc
    
    commit 3da08768
    Author: Eugene Burkov <E.Burkov@AdGuard.COM>
    Date:   Mon Apr 10 17:23:11 2023 +0300
    
        bootstrap: rm v6 tests
    
    commit 5c819b4a
    Author: Eugene Burkov <E.Burkov@AdGuard.COM>
    Date:   Mon Apr 10 17:06:35 2023 +0300
    
        all: imp code, add tests
    
    commit 442155aa
    Author: Eugene Burkov <E.Burkov@AdGuard.COM>
    Date:   Fri Apr 7 19:43:54 2023 +0300
    
        netutil: add doc
    
    commit 6475e227
    Author: Eugene Burkov <E.Burkov@AdGuard.COM>
    Date:   Fri Apr 7 19:42:20 2023 +0300
    
        all: imp docs
    
    commit 29732236
    Author: Eugene Burkov <E.Burkov@AdGuard.COM>
    Date:   Fri Apr 7 19:28:54 2023 +0300
    
        all: introduce bootstrap pkg
    466aa934
  • Ainar Garipov's avatar
    Pull request 252: AGDNS-1433-reuseport · 5864a879
    Ainar Garipov authored
    Merge in GO/dnsproxy from AGDNS-1433-reuseport to master
    
    Squashed commit of the following:
    
    commit f1415e5ec785c1a717c531aa265828d6524b37f4
    Merge: 4320c058 466aa934
    Author: Ainar Garipov <A.Garipov@AdGuard.COM>
    Date:   Thu Apr 13 14:05:24 2023 +0300
    
        Merge branch 'master' into AGDNS-1433-reuseport
    
    commit 4320c058
    Author: Ainar Garipov <A.Garipov@AdGuard.COM>
    Date:   Wed Apr 12 20:06:40 2023 +0300
    
        netutil: fix windows
    
    commit 71d1dbd1
    Author: Ainar Garipov <A.Garipov@AdGuard.COM>
    Date:   Wed Apr 12 19:59:39 2023 +0300
    
        proxy: add reuseport for tcp
    
    commit b796c92d
    Author: Ainar Garipov <A.Garipov@AdGuard.COM>
    Date:   Wed Apr 12 19:47:44 2023 +0300
    
        all: depr utils; imp logs; add reuseport
    5864a879
......@@ -78,14 +78,25 @@ func TestFastestAddr_ExchangeFastest(t *testing.T) {
}
type errUpstream struct {
upstream.Upstream
err error
err error
closeErr error
}
func (u errUpstream) Exchange(_ *dns.Msg) (*dns.Msg, error) {
// Address implements the [upstream.Upstream] interface for errUpstream.
func (u *errUpstream) Address() string {
return "bad_upstream"
}
// Exchange implements the [upstream.Upstream] interface for errUpstream.
func (u *errUpstream) Exchange(_ *dns.Msg) (*dns.Msg, error) {
return nil, u.err
}
// Close implements the [upstream.Upstream] interface for errUpstream.
func (u *errUpstream) Close() error {
return u.closeErr
}
type testAUpstream struct {
recs []*dns.A
}
......
// Package bootstrap provides types and functions to resolve upstream hostnames
// and to dial retrieved addresses.
package bootstrap
import (
"context"
"fmt"
"net"
"net/netip"
"net/url"
"time"
proxynetutil "github.com/AdguardTeam/dnsproxy/internal/netutil"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
)
// DialHandler is a dial function for creating unencrypted network connections
// to the upstream server. It establishes the connection to the server
// specified at initialization and ignores the addr.
type DialHandler func(ctx context.Context, network, addr string) (conn net.Conn, err error)
// ResolveDialContext returns a DialHandler that uses addresses resolved from
// u using resolvers. u must not be nil.
//
// TODO(e.burkov): Use in the [upstream] package.
func ResolveDialContext(
u *url.URL,
timeout time.Duration,
resolvers []Resolver,
preferIPv6 bool,
) (h DialHandler, err error) {
defer func() { err = errors.Annotate(err, "dialing %q: %w", u.Host) }()
host, port, err := netutil.SplitHostPort(u.Host)
if err != nil {
// Don't wrap the error since it's informative enough as is and there is
// already deferred annotation here.
return nil, err
}
ctx := context.Background()
if timeout > 0 {
var cancel func()
ctx, cancel = context.WithTimeout(ctx, timeout)
defer cancel()
}
ips, err := LookupParallel(ctx, resolvers, host)
if err != nil {
return nil, fmt.Errorf("resolving hostname: %w", err)
}
proxynetutil.SortNetIPAddrs(ips, preferIPv6)
addrs := make([]string, 0, len(ips))
for _, ip := range ips {
if !ip.IsValid() {
// All invalid addresses should be in the tail after sorting.
break
}
addrs = append(addrs, netip.AddrPortFrom(ip, uint16(port)).String())
}
return NewDialContext(timeout, addrs...), nil
}
// NewDialContext returns a DialHandler that dials addrs and returns the first
// successful connection.
//
// TODO(e.burkov): Use in the [upstream] package.
func NewDialContext(timeout time.Duration, addrs ...string) (h DialHandler) {
dialer := &net.Dialer{
Timeout: timeout,
}
l := len(addrs)
if l == 0 {
log.Debug("bootstrap: no addresses to dial")
return func(ctx context.Context, _, _ string) (conn net.Conn, err error) {
return nil, errors.Error("no addresses")
}
}
return func(ctx context.Context, network, _ string) (conn net.Conn, err error) {
var errs []error
// Return first succeeded connection. Note that we're using addrs
// instead of what's passed to the function.
for i, addr := range addrs {
log.Debug("bootstrap: dialing %s (%d/%d)", addr, i+1, l)
start := time.Now()
conn, err := dialer.DialContext(ctx, network, addr)
elapsed := time.Since(start)
if err == nil {
log.Debug("bootstrap: connection to %s succeeded in %s", addr, elapsed)
return conn, nil
}
log.Debug("bootstrap: connection to %s failed in %s: %s", addr, elapsed, err)
errs = append(errs, err)
}
// TODO(e.burkov): Use errors.Join in Go 1.20.
return nil, errors.List("all dialers failed", errs...)
}
}
package bootstrap_test
import (
"context"
"errors"
"net"
"net/netip"
"net/url"
"testing"
"time"
"github.com/AdguardTeam/dnsproxy/internal/bootstrap"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// testTimeout is a common timeout used in tests of this package.
const testTimeout = 1 * time.Second
// newListener creates a new listener of zero address of the specified network
// type and returns it, adding it's closing to the test cleanup. sig is used to
// send the address of each accepted connection and must be read properly.
func newListener(t testing.TB, network string, sig chan net.Addr) (ipp netip.AddrPort) {
t.Helper()
// TODO(e.burkov): Listen IPv6 as well, when the CI adds IPv6 interfaces.
l, err := net.Listen(network, "127.0.0.1:0")
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, l.Close)
go func() {
pt := testutil.PanicT{}
for c, lerr := l.Accept(); !errors.Is(lerr, net.ErrClosed); c, lerr = l.Accept() {
require.NoError(pt, lerr)
testutil.RequireSend(pt, sig, c.LocalAddr(), testTimeout)
require.NoError(pt, c.Close())
}
}()
ipp, err = netip.ParseAddrPort(l.Addr().String())
require.NoError(t, err)
return ipp
}
// See the details here: https://github.com/AdguardTeam/dnsproxy/issues/18
func TestResolveDialContext(t *testing.T) {
sig := make(chan net.Addr, 1)
ipp := newListener(t, "tcp", sig)
port := ipp.Port()
testCases := []struct {
name string
addresses []netip.Addr
preferIPv6 bool
}{{
name: "v4",
addresses: []netip.Addr{netutil.IPv4Localhost()},
preferIPv6: false,
}, {
name: "both_prefer_v6",
addresses: []netip.Addr{netutil.IPv4Localhost(), netutil.IPv6Localhost()},
preferIPv6: true,
}, {
name: "both_prefer_v4",
addresses: []netip.Addr{netutil.IPv6Localhost(), netutil.IPv4Localhost()},
preferIPv6: false,
}, {
name: "strip_invalid",
addresses: []netip.Addr{{}, netutil.IPv4Localhost(), {}, netutil.IPv6Localhost(), {}},
preferIPv6: true,
}}
const hostname = "host.name"
pt := testutil.PanicT{}
for _, tc := range testCases {
r := &testResolver{
onLookupNetIP: func(
_ context.Context,
network string,
host string,
) (addrs []netip.Addr, err error) {
require.Equal(pt, "ip", network)
require.Equal(pt, hostname, host)
return tc.addresses, nil
},
}
t.Run(tc.name, func(t *testing.T) {
dialContext, err := bootstrap.ResolveDialContext(
&url.URL{Host: netutil.JoinHostPort(hostname, int(port))},
testTimeout,
[]bootstrap.Resolver{r},
tc.preferIPv6,
)
require.NoError(t, err)
conn, err := dialContext(context.Background(), "tcp", "")
require.NoError(t, err)
expected, ok := testutil.RequireReceive(t, sig, testTimeout)
require.True(t, ok)
assert.Equal(t, expected.String(), conn.RemoteAddr().String())
})
}
t.Run("no_addresses", func(t *testing.T) {
r := &testResolver{
onLookupNetIP: func(
_ context.Context,
network string,
host string,
) (addrs []netip.Addr, err error) {
require.Equal(pt, "ip", network)
require.Equal(pt, hostname, host)
return nil, nil
},
}
dialContext, err := bootstrap.ResolveDialContext(
&url.URL{Host: netutil.JoinHostPort(hostname, int(port))},
testTimeout,
[]bootstrap.Resolver{r},
false,
)
require.NoError(t, err)
_, err = dialContext(context.Background(), "tcp", "")
testutil.AssertErrorMsg(t, "no addresses", err)
})
t.Run("bad_hostname", func(t *testing.T) {
const errMsg = `dialing "bad hostname": address bad hostname: ` +
`missing port in address`
dialContext, err := bootstrap.ResolveDialContext(
&url.URL{Host: "bad hostname"},
testTimeout,
nil,
false,
)
testutil.AssertErrorMsg(t, errMsg, err)
assert.Nil(t, dialContext)
})
t.Run("no_resolvers", func(t *testing.T) {
dialContext, err := bootstrap.ResolveDialContext(
&url.URL{Host: netutil.JoinHostPort(hostname, int(port))},
testTimeout,
nil,
false,
)
assert.ErrorIs(t, err, bootstrap.ErrNoResolvers)
assert.Nil(t, dialContext)
})
}
package bootstrap
import (
"context"
"net"
"net/netip"
"time"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
)
// Resolver resolves the hostnames to IP addresses.
type Resolver interface {
// LookupIPAddr looks up the IP addresses for the given host. network must
// be one of "ip", "ip4" or "ip6".
LookupNetIP(ctx context.Context, network string, host string) (addrs []netip.Addr, err error)
}
// type check
var _ Resolver = &net.Resolver{}
// ErrNoResolvers is returned when zero resolvers specified.
const ErrNoResolvers errors.Error = "no resolvers specified"
// LookupParallel performs lookup for IP address of host with all resolvers
// concurrently.
func LookupParallel(
ctx context.Context,
resolvers []Resolver,
host string,
) (addrs []netip.Addr, err error) {
resolversNum := len(resolvers)
switch resolversNum {
case 0:
return nil, ErrNoResolvers
case 1:
return lookup(ctx, resolvers[0], host)
default:
// Go on.
}
// Size of channel must accommodate results of lookups from all resolvers,
// sending into channel will be block otherwise.
ch := make(chan *lookupResult, resolversNum)
for _, res := range resolvers {
go lookupAsync(ctx, res, host, ch)
}
var errs []error
for n := 0; n < resolversNum; n++ {
result := <-ch
if result.err == nil {
return result.addrs, nil
}
errs = append(errs, result.err)
}
// TODO(e.burkov): Use [errors.Join] in Go 1.20.
return nil, errors.List("all resolvers failed", errs...)
}
// lookupResult is a structure that represents the result of a lookup.
type lookupResult struct {
err error
addrs []netip.Addr
}
// lookupAsync tries to lookup for ip of host with r and sends the result into
// resCh. It's inteneded to be used as a goroutine.
func lookupAsync(ctx context.Context, r Resolver, host string, resCh chan *lookupResult) {
defer log.OnPanic("parallel lookup")
addrs, err := lookup(ctx, r, host)
resCh <- &lookupResult{
err: err,
addrs: addrs,
}
}
// lookup tries to lookup ip of host with r.
func lookup(ctx context.Context, r Resolver, host string) (addrs []netip.Addr, err error) {
start := time.Now()
addrs, err = r.LookupNetIP(ctx, "ip", host)
elapsed := time.Since(start)
if err != nil {
log.Debug("parallel lookup: lookup for %s failed in %s: %s", host, elapsed, err)
} else {
log.Debug("parallel lookup: lookup for %s succeeded in %s: %s", host, elapsed, addrs)
}
return addrs, err
}
package bootstrap_test
import (
"context"
"fmt"
"net/netip"
"testing"
"github.com/AdguardTeam/dnsproxy/internal/bootstrap"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// testResolver is the [Resolver] interface implementation for testing purposes.
type testResolver struct {
onLookupNetIP func(ctx context.Context, network, host string) (addrs []netip.Addr, err error)
}
// LookupNetIP implements the [Resolver] interface for *testResolver.
func (r *testResolver) LookupNetIP(
ctx context.Context,
network string,
host string,
) (addrs []netip.Addr, err error) {
return r.onLookupNetIP(ctx, network, host)
}
func TestLookupParallel(t *testing.T) {
const hostname = "host.name"
t.Run("no_resolvers", func(t *testing.T) {
addrs, err := bootstrap.LookupParallel(context.Background(), nil, "")
assert.ErrorIs(t, err, bootstrap.ErrNoResolvers)
assert.Nil(t, addrs)
})
pt := testutil.PanicT{}
hostAddrs := []netip.Addr{netutil.IPv4Localhost()}
immediate := &testResolver{
onLookupNetIP: func(_ context.Context, network, host string) ([]netip.Addr, error) {
require.Equal(pt, hostname, host)
require.Equal(pt, "ip", network)
return hostAddrs, nil
},
}
t.Run("one_resolver", func(t *testing.T) {
addrs, err := bootstrap.LookupParallel(
context.Background(),
[]bootstrap.Resolver{immediate},
hostname,
)
require.NoError(t, err)
assert.Equal(t, hostAddrs, addrs)
})
t.Run("two_resolvers", func(t *testing.T) {
delayCh := make(chan struct{}, 1)
delayed := &testResolver{
onLookupNetIP: func(_ context.Context, network, host string) ([]netip.Addr, error) {
require.Equal(pt, hostname, host)
require.Equal(pt, "ip", network)
testutil.RequireReceive(pt, delayCh, testTimeout)
return []netip.Addr{netutil.IPv6Localhost()}, nil
},
}
addrs, err := bootstrap.LookupParallel(
context.Background(),
[]bootstrap.Resolver{immediate, delayed},
hostname,
)
require.NoError(t, err)
testutil.RequireSend(t, delayCh, struct{}{}, testTimeout)
assert.Equal(t, hostAddrs, addrs)
})
t.Run("all_errors", func(t *testing.T) {
err := assert.AnError
wantErrMsg := fmt.Sprintf("all resolvers failed: 3 errors: %[1]q, %[1]q, %[1]q", err)
r := &testResolver{
onLookupNetIP: func(_ context.Context, network, host string) ([]netip.Addr, error) {
return nil, assert.AnError
},
}
addrs, err := bootstrap.LookupParallel(
context.Background(),
[]bootstrap.Resolver{r, r, r},
hostname,
)
testutil.AssertErrorMsg(t, wantErrMsg, err)
assert.Nil(t, addrs)
})
}
package netutil
import "net"
// ListenConfig returns the default [net.ListenConfig] used by the plain-DNS
// servers in this module.
//
// TODO(a.garipov): Add tests.
//
// TODO(a.garipov): DRY with AdGuard DNS when we can.
func ListenConfig() (lc *net.ListenConfig) {
return &net.ListenConfig{
Control: defaultListenControl,
}
}
//go:build unix
package netutil
import (
"fmt"
"syscall"
"github.com/AdguardTeam/golibs/errors"
"golang.org/x/sys/unix"
)
// defaultListenControl is used as a [net.ListenConfig.Control] function to set
// the SO_REUSEADDR and SO_REUSEPORT socket options on all sockets used by the
// DNS servers in this module.
func defaultListenControl(_, _ string, c syscall.RawConn) (err error) {
var opErr error
err = c.Control(func(fd uintptr) {
opErr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEADDR, 1)
if opErr != nil {
opErr = fmt.Errorf("setting SO_REUSEADDR: %w", opErr)
return
}
opErr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT, 1)
if opErr != nil {
opErr = fmt.Errorf("setting SO_REUSEPORT: %w", opErr)
}
})
return errors.WithDeferred(opErr, err)
}
//go:build windows
package netutil
import "syscall"
// defaultListenControl is nil on Windows, because it doesn't support
// SO_REUSEPORT.
var defaultListenControl func(_, _ string, _ syscall.RawConn) (_ error)
......@@ -7,6 +7,7 @@ package netutil
import (
"net"
"net/netip"
glnetutil "github.com/AdguardTeam/golibs/netutil"
"golang.org/x/exp/slices"
......@@ -49,3 +50,30 @@ func SortIPAddrs(addrs []net.IPAddr, preferIPv6 bool) {
return a.Less(b)
})
}
// SortNetIPAddrs sorts addrs in accordance with the protocol preferences.
// Invalid addresses are sorted near the end. Zones are ignored.
func SortNetIPAddrs(addrs []netip.Addr, preferIPv6 bool) {
l := len(addrs)
if l <= 1 {
return
}
slices.SortStableFunc(addrs, func(addrA, addrB netip.Addr) (sortsBefore bool) {
if !addrA.IsValid() {
return false
} else if !addrB.IsValid() {
return true
}
if aIs4, bIs4 := addrA.Is4(), addrB.Is4(); aIs4 != bIs4 {
if aIs4 {
return !preferIPv6
}
return preferIPv6
}
return addrA.Less(addrB)
})
}
package netutil
import "net"
// UDPGetOOBSize returns maximum size of the received OOB data.
func UDPGetOOBSize() (oobSize int) {
return udpGetOOBSize()
}
// UDPSetOptions sets flag options on a UDP socket to be able to receive the
// necessary OOB data.
func UDPSetOptions(c *net.UDPConn) (err error) {
return udpSetOptions(c)
}
// UDPRead reads the message from conn using buf and receives a control-message
// payload of size udpOOBSize from it. It returns the number of bytes copied
// into buf and the source address of the message.
func UDPRead(
conn *net.UDPConn,
buf []byte,
udpOOBSize int,
) (n int, localIP net.IP, remoteAddr *net.UDPAddr, err error) {
return udpRead(conn, buf, udpOOBSize)
}
// UDPWrite writes the data to the remoteAddr using conn.
func UDPWrite(
data []byte,
conn *net.UDPConn,
remoteAddr *net.UDPAddr,
localIP net.IP,
) (n int, err error) {
return udpWrite(data, conn, remoteAddr, localIP)
}
//go:build aix || dragonfly || linux || netbsd || openbsd || freebsd || solaris || darwin
// +build aix dragonfly linux netbsd openbsd freebsd solaris darwin
//go:build unix
package proxyutil
package netutil
import (
"fmt"
"net"
"github.com/AdguardTeam/golibs/mathutil"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
)
// ipv*Flags is the set of socket option flags for configuring IPv* UDP
// These are the set of socket option flags for configuring an IPv[46] UDP
// connection to receive an appropriate OOB data. For both versions the flags
// are:
//
......@@ -26,11 +26,7 @@ const (
func udpGetOOBSize() (oobSize int) {
l4, l6 := len(ipv4.NewControlMessage(ipv4Flags)), len(ipv6.NewControlMessage(ipv6Flags))
if l4 >= l6 {
return l4
}
return l6
return mathutil.Max(l4, l6)
}
func udpSetOptions(c *net.UDPConn) (err error) {
......
//go:build windows
// +build windows
package proxyutil
package netutil
import "net"
import (
"net"
)
func udpGetOOBSize() int {
return 0
......
//go:build darwin
// +build darwin
package proxyutil
package netutil
import (
"net"
......@@ -9,16 +8,16 @@ import (
"golang.org/x/net/ipv6"
)
// udpMakeOOBWithSrc makes the OOB data with a specified source IP.
func udpMakeOOBWithSrc(ip net.IP) []byte {
// udpMakeOOBWithSrc makes the OOB data with the specified source IP.
func udpMakeOOBWithSrc(ip net.IP) (b []byte) {
if ip4 := ip.To4(); ip4 != nil {
// Do not set the IPv4 source address via OOB, because it can
// cause the address to become unspecified on darwin.
// Do not set the IPv4 source address via OOB, because it can cause the
// address to become unspecified on darwin.
//
// See https://github.com/AdguardTeam/AdGuardHome/issues/2807.
//
// TODO(e.burkov): Develop a workaround to make it write OOB
// only when listening on an unspecified address.
// TODO(e.burkov): Develop a workaround to make it write OOB only when
// listening on an unspecified address.
return []byte{}
}
......
//go:build aix || dragonfly || linux || netbsd || openbsd || freebsd || solaris
// +build aix dragonfly linux netbsd openbsd freebsd solaris
//go:build !darwin
package proxyutil
package netutil
import (
"net"
......@@ -10,8 +9,8 @@ import (
"golang.org/x/net/ipv6"
)
// udpMakeOOBWithSrc makes the OOB data with a specified source IP.
func udpMakeOOBWithSrc(ip net.IP) []byte {
// udpMakeOOBWithSrc makes the OOB data with the specified source IP.
func udpMakeOOBWithSrc(ip net.IP) (b []byte) {
if ip4 := ip.To4(); ip4 != nil {
return (&ipv4.ControlMessage{
Src: ip,
......
......@@ -3,6 +3,7 @@
package proxy
import (
"context"
"fmt"
"io"
"net"
......@@ -13,7 +14,7 @@ import (
"time"
"github.com/AdguardTeam/dnsproxy/fastip"
"github.com/AdguardTeam/dnsproxy/proxyutil"
proxynetutil "github.com/AdguardTeam/dnsproxy/internal/netutil"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
......@@ -142,7 +143,7 @@ func (p *Proxy) Init() (err error) {
p.initCache()
if p.MaxGoroutines > 0 {
log.Info("MaxGoroutines is set to %d", p.MaxGoroutines)
log.Info("dnsproxy: max goroutines is set to %d", p.MaxGoroutines)
p.requestGoroutinesSema, err = newChanSemaphore(p.MaxGoroutines)
if err != nil {
......@@ -152,7 +153,7 @@ func (p *Proxy) Init() (err error) {
p.requestGoroutinesSema = newNoopSemaphore()
}
p.udpOOBSize = proxyutil.UDPGetOOBSize()
p.udpOOBSize = proxynetutil.UDPGetOOBSize()
p.bytesPool = &sync.Pool{
New: func() interface{} {
// 2 bytes may be used to store packet length (see TCP/TLS)
......@@ -163,7 +164,8 @@ func (p *Proxy) Init() (err error) {
}
if p.UpstreamMode == UModeFastestAddr {
log.Printf("Fastest IP is enabled")
log.Info("dnsproxy: fastest ip is enabled")
p.fastestAddr = fastip.NewFastestAddr()
if timeout := p.FastestPingTimeout; timeout > 0 {
p.fastestAddr.PingWaitTimeout = timeout
......@@ -188,10 +190,11 @@ func (p *Proxy) Init() (err error) {
// Start initializes the proxy server and starts listening
func (p *Proxy) Start() (err error) {
log.Info("dnsproxy: starting dns proxy server")
p.Lock()
defer p.Unlock()
log.Info("Starting the DNS proxy server")
err = p.validateConfig()
if err != nil {
return err
......@@ -202,12 +205,15 @@ func (p *Proxy) Start() (err error) {
return err
}
err = p.startListeners()
// TODO(a.garipov): Accept a context into this method.
ctx := context.Background()
err = p.startListeners(ctx)
if err != nil {
return err
return fmt.Errorf("starting listeners: %w", err)
}
p.started = true
return nil
}
......@@ -225,12 +231,13 @@ func closeAll[C io.Closer](errs []error, closers ...C) (appended []error) {
// Stop stops the proxy server including all its listeners
func (p *Proxy) Stop() error {
log.Info("Stopping the DNS proxy server")
log.Info("dnsproxy: stopping dns proxy server")
p.Lock()
defer p.Unlock()
if !p.started {
log.Info("The DNS proxy server is not started")
log.Info("dnsproxy: dns proxy server is not started")
return nil
}
......@@ -273,7 +280,9 @@ func (p *Proxy) Stop() error {
}
p.started = false
log.Println("Stopped the DNS proxy server")
log.Println("dnsproxy: stopped dns proxy server")
if len(errs) > 0 {
return errors.List("stopping dns proxy server", errs...)
}
......@@ -387,7 +396,7 @@ func (p *Proxy) needsLocalUpstream(req *dns.Msg) (ok bool) {
host := req.Question[0].Name
ip, err := netutil.IPFromReversedAddr(host)
if err != nil {
log.Debug("proxy: failed to parse ip from ptr request: %s", err)
log.Debug("dnsproxy: failed to parse ip from ptr request: %s", err)
return false
}
......@@ -580,7 +589,7 @@ func (dctx *DNSContext) processECS(cliIP net.IP) {
if ones, _ := ecs.Mask.Size(); ones != 0 {
dctx.ReqECS = ecs
log.Debug("passing through ecs: %s", dctx.ReqECS)
log.Debug("dnsproxy: passing through ecs: %s", dctx.ReqECS)
return
}
......@@ -599,7 +608,7 @@ func (dctx *DNSContext) processECS(cliIP net.IP) {
// Section 6.
dctx.ReqECS = setECS(dctx.Req, cliIP, 0)
log.Debug("setting ecs: %s", dctx.ReqECS)
log.Debug("dnsproxy: setting ecs: %s", dctx.ReqECS)
}
}
......
package proxy
import (
"context"
"fmt"
"net"
"time"
......@@ -11,13 +12,13 @@ import (
)
// startListeners configures and starts listener loops
func (p *Proxy) startListeners() error {
err := p.createUDPListeners()
func (p *Proxy) startListeners(ctx context.Context) error {
err := p.createUDPListeners(ctx)
if err != nil {
return err
}
err = p.createTCPListeners()
err = p.createTCPListeners(ctx)
if err != nil {
return err
}
......
package proxy
import (
"context"
"crypto/tls"
"fmt"
"io"
"net"
"time"
proxynetutil "github.com/AdguardTeam/dnsproxy/internal/netutil"
"github.com/AdguardTeam/dnsproxy/proxyutil"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns"
)
func (p *Proxy) createTCPListeners() (err error) {
func (p *Proxy) createTCPListeners(ctx context.Context) (err error) {
for _, a := range p.TCPListenAddr {
log.Printf("Creating a TCP server socket")
log.Info("dnsproxy: creating tcp server socket %s", a)
var tcpListen *net.TCPListener
tcpListen, err = net.ListenTCP("tcp", a)
lsnr, err := proxynetutil.ListenConfig().Listen(ctx, "tcp", a.String())
if err != nil {
return fmt.Errorf("listening to tcp socket: %w", err)
}
tcpListen := lsnr.(*net.TCPListener)
if err != nil {
return fmt.Errorf("starting listening on tcp socket: %w", err)
return fmt.Errorf("listening on tcp addr %s: %w", a, err)
}
p.tcpListen = append(p.tcpListen, tcpListen)
log.Printf("Listening to tcp://%s", tcpListen.Addr())
log.Info("dnsproxy: listening to tcp://%s", tcpListen.Addr())
}
return nil
......@@ -32,17 +39,18 @@ func (p *Proxy) createTCPListeners() (err error) {
func (p *Proxy) createTLSListeners() (err error) {
for _, a := range p.TLSListenAddr {
log.Printf("Creating a TLS server socket")
log.Info("dnsproxy: creating tls server socket %s", a)
var tcpListen *net.TCPListener
tcpListen, err = net.ListenTCP("tcp", a)
if err != nil {
return fmt.Errorf("starting tls listener: %w", err)
return fmt.Errorf("listening on tls addr %s: %w", a, err)
}
l := tls.NewListener(tcpListen, p.TLSConfig)
p.tlsListen = append(p.tlsListen, l)
log.Printf("Listening to tls://%s", l.Addr())
log.Info("dnsproxy: listening to tls://%s", l.Addr())
}
return nil
......@@ -53,15 +61,16 @@ func (p *Proxy) createTLSListeners() (err error) {
//
// See also the comment on Proxy.requestGoroutinesSema.
func (p *Proxy) tcpPacketLoop(l net.Listener, proto Proto, requestGoroutinesSema semaphore) {
log.Printf("Entering the %s listener loop on %s", proto, l.Addr())
log.Info("dnsproxy: entering %s listener loop on %s", proto, l.Addr())
for {
clientConn, err := l.Accept()
if err != nil {
if errors.Is(err, net.ErrClosed) {
log.Debug("tcpPacketLoop: connection closed: %s", err)
log.Debug("dnsproxy: tcp connection %s closed", l.Addr())
} else {
log.Info("got error when reading from TCP listen: %s", err)
log.Error("dnsproxy: reading from tcp: %s", err)
}
break
......@@ -80,11 +89,12 @@ func (p *Proxy) tcpPacketLoop(l net.Listener, proto Proto, requestGoroutinesSema
func (p *Proxy) handleTCPConnection(conn net.Conn, proto Proto) {
defer log.OnPanic("proxy.handleTCPConnection")
log.Tracef("handling tcp: started handling %s request from %s", proto, conn.RemoteAddr())
log.Debug("dnsproxy: handling new %s request from %s", proto, conn.RemoteAddr())
defer func() {
err := conn.Close()
if err != nil {
logWithNonCrit(err, "handling tcp: closing conn")
logWithNonCrit(err, "dnsproxy: handling tcp: closing conn")
}
}()
......@@ -111,7 +121,7 @@ func (p *Proxy) handleTCPConnection(conn net.Conn, proto Proto) {
req := &dns.Msg{}
err = req.Unpack(packet)
if err != nil {
log.Error("handling tcp: unpacking msg: %s", err)
log.Error("dnsproxy: handling tcp: unpacking msg: %s", err)
return
}
......
package proxy
import (
"context"
"fmt"
"net"
"github.com/AdguardTeam/dnsproxy/proxyutil"
proxynetutil "github.com/AdguardTeam/dnsproxy/internal/netutil"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns"
)
func (p *Proxy) createUDPListeners() error {
func (p *Proxy) createUDPListeners(ctx context.Context) (err error) {
for _, a := range p.UDPListenAddr {
udpListen, err := p.udpCreate(a)
var pc *net.UDPConn
pc, err := p.udpCreate(ctx, a)
if err != nil {
return err
return fmt.Errorf("listening on udp addr %s: %w", a, err)
}
p.udpListen = append(p.udpListen, udpListen)
p.udpListen = append(p.udpListen, pc)
}
return nil
}
// udpCreate - create a UDP listening socket
func (p *Proxy) udpCreate(udpAddr *net.UDPAddr) (*net.UDPConn, error) {
log.Info("Creating the UDP server socket")
udpListen, err := net.ListenUDP("udp", udpAddr)
func (p *Proxy) udpCreate(ctx context.Context, udpAddr *net.UDPAddr) (*net.UDPConn, error) {
log.Info("dnsproxy: creating udp server socket %s", udpAddr)
packetConn, err := proxynetutil.ListenConfig().ListenPacket(ctx, "udp", udpAddr.String())
if err != nil {
return nil, fmt.Errorf("listening to udp socket: %w", err)
}
udpListen := packetConn.(*net.UDPConn)
if p.Config.UDPBufferSize > 0 {
err = udpListen.SetReadBuffer(p.Config.UDPBufferSize)
if err != nil {
......@@ -39,14 +44,14 @@ func (p *Proxy) udpCreate(udpAddr *net.UDPAddr) (*net.UDPConn, error) {
}
}
err = proxyutil.UDPSetOptions(udpListen)
err = proxynetutil.UDPSetOptions(udpListen)
if err != nil {
_ = udpListen.Close()
return nil, fmt.Errorf("setting udp opts: %w", err)
}
log.Info("Listening to udp://%s", udpListen.LocalAddr())
log.Info("dnsproxy: listening to udp://%s", udpListen.LocalAddr())
return udpListen, nil
}
......@@ -55,7 +60,8 @@ func (p *Proxy) udpCreate(udpAddr *net.UDPAddr) (*net.UDPConn, error) {
//
// See also the comment on Proxy.requestGoroutinesSema.
func (p *Proxy) udpPacketLoop(conn *net.UDPConn, requestGoroutinesSema semaphore) {
log.Info("Entering the UDP listener loop on %s", conn.LocalAddr())
log.Info("dnsproxy: entering udp listener loop on %s", conn.LocalAddr())
b := make([]byte, dns.MaxMsgSize)
for {
p.RLock()
......@@ -64,7 +70,7 @@ func (p *Proxy) udpPacketLoop(conn *net.UDPConn, requestGoroutinesSema semaphore
}
p.RUnlock()
n, localIP, remoteAddr, err := proxyutil.UDPRead(conn, b, p.udpOOBSize)
n, localIP, remoteAddr, err := proxynetutil.UDPRead(conn, b, p.udpOOBSize)
// documentation says to handle the packet even if err occurs, so do that first
if n > 0 {
// make a copy of all bytes because ReadFrom() will overwrite contents of b on next call
......@@ -79,9 +85,9 @@ func (p *Proxy) udpPacketLoop(conn *net.UDPConn, requestGoroutinesSema semaphore
}
if err != nil {
if errors.Is(err, net.ErrClosed) {
log.Debug("udpPacketLoop: connection closed")
log.Debug("dnsproxy: udp connection %s closed", conn.LocalAddr())
} else {
log.Error("got error when reading from UDP listen: %s", err)
log.Error("dnsproxy: reading from udp: %s", err)
}
break
......@@ -91,12 +97,12 @@ func (p *Proxy) udpPacketLoop(conn *net.UDPConn, requestGoroutinesSema semaphore
// udpHandlePacket processes the incoming UDP packet and sends a DNS response
func (p *Proxy) udpHandlePacket(packet []byte, localIP net.IP, remoteAddr *net.UDPAddr, conn *net.UDPConn) {
log.Tracef("Start handling new UDP packet from %s", remoteAddr)
log.Debug("dnsproxy: handling new udp packet from %s", remoteAddr)
req := &dns.Msg{}
err := req.Unpack(packet)
if err != nil {
log.Error("unpacking udp packet: %s", err)
log.Error("dnsproxy: unpacking udp packet: %s", err)
return
}
......@@ -108,7 +114,7 @@ func (p *Proxy) udpHandlePacket(packet []byte, localIP net.IP, remoteAddr *net.U
err = p.handleDNSRequest(d)
if err != nil {
log.Tracef("error handling DNS (%s) request: %s", d.Proto, err)
log.Debug("dnsproxy: handling dns (proto %s) request: %s", d.Proto, err)
}
}
......@@ -128,7 +134,7 @@ func (p *Proxy) respondUDP(d *DNSContext) error {
conn := d.Conn.(*net.UDPConn)
rAddr := d.Addr.(*net.UDPAddr)
n, err := proxyutil.UDPWrite(bytes, conn, rAddr, d.localIP)
n, err := proxynetutil.UDPWrite(bytes, conn, rAddr, d.localIP)
if err != nil {
if errors.Is(err, net.ErrClosed) {
return nil
......
package proxyutil
import "net"
import (
"net"
proxynetutil "github.com/AdguardTeam/dnsproxy/internal/netutil"
)
// UDPGetOOBSize returns maximum size of the received OOB data.
//
// Deprecated: This function is deprecated. Packages in module dnsproxy should
// use internal/netutil.UDPGetOOBSize instead.
func UDPGetOOBSize() (oobSize int) {
return udpGetOOBSize()
return proxynetutil.UDPGetOOBSize()
}
// UDPSetOptions sets flag options on a UDP socket to be able to receive the
// necessary OOB data.
//
// Deprecated: This function is deprecated. Packages in module dnsproxy should
// use internal/netutil.UDPSetOptions instead.
func UDPSetOptions(c *net.UDPConn) (err error) {
return udpSetOptions(c)
return proxynetutil.UDPSetOptions(c)
}
// UDPRead udpRead reads the message from c using buf receives payload of size
// udpOOBSize from the UDP socket. It returns the number of bytes copied into
// buf, the number of bytes copied with OOB and the source address of the
// message.
//
// Deprecated: This function is deprecated. Packages in module dnsproxy should
// use internal/netutil.UDPRead instead.
func UDPRead(
c *net.UDPConn,
buf []byte,
udpOOBSize int,
) (n int, localIP net.IP, remoteAddr *net.UDPAddr, err error) {
return udpRead(c, buf, udpOOBSize)
return proxynetutil.UDPRead(c, buf, udpOOBSize)
}
// UDPWrite writes the data to the remoteAddr using conn.
//
// Deprecated: This function is deprecated. Packages in module dnsproxy should
// use internal/netutil.UDPWrite instead.
func UDPWrite(
data []byte,
conn *net.UDPConn,
remoteAddr *net.UDPAddr,
localIP net.IP,
) (n int, err error) {
return udpWrite(data, conn, remoteAddr, localIP)
return proxynetutil.UDPWrite(data, conn, remoteAddr, localIP)
}
......@@ -43,7 +43,7 @@ type bootstrapper struct {
// resolvers is a list of *net.Resolver to use to resolve the upstream
// hostname, if necessary.
resolvers []*Resolver
resolvers []Resolver
// dialContext is the dial function for creating unencrypted TCP
// connections.
......@@ -100,11 +100,11 @@ func newBootstrapperResolved(upsURL *url.URL, options *Options) (*bootstrapper,
// resolver address string (i.e. tls://one.one.one.one:853), options is the
// upstream configuration options.
func newBootstrapper(u *url.URL, options *Options) (b *bootstrapper, err error) {
resolvers := []*Resolver{}
resolvers := []Resolver{}
if len(options.Bootstrap) != 0 {
// Create a list of resolvers for parallel lookup
// Create a list of resolvers for parallel lookup.
for _, boot := range options.Bootstrap {
var r *Resolver
var r Resolver
r, err = NewResolver(boot, options)
if err != nil {
return nil, err
......@@ -202,15 +202,13 @@ func (n *bootstrapper) get() (*tls.Config, dialHandler, error) {
return nil, nil, fmt.Errorf("lookup %s: %w", host, err)
}
proxynetutil.SortIPAddrs(addrs, n.options.PreferIPv6)
proxynetutil.SortNetIPAddrs(addrs, n.options.PreferIPv6)
resolved := []string{}
resolved := make([]string, 0, len(addrs))
for _, addr := range addrs {
if addr.IP.To4() == nil && addr.IP.To16() == nil {
continue
if addr.IsValid() {
resolved = append(resolved, net.JoinHostPort(addr.String(), port))
}
resolved = append(resolved, net.JoinHostPort(addr.String(), port))
}
if len(resolved) == 0 {
......