Commits on Source (4)
• Pull request: increase MaxConnsPerHost, fix #278 · b14cd4bf
    Andrey Meshkov authored
    Merge in DNS/dnsproxy from fix-278 to master
    
    Squashed commit of the following:
    
    commit 19e0286c
    Author: Andrey Meshkov <am@adguard.com>
    Date:   Sun Dec 4 20:17:01 2022 +0300
    
        increase MaxConnsPerHost, fix #278
    b14cd4bf
• Pull request: 5248-imp-cache · f096b3d0
    Ainar Garipov authored
    Updates AdguardTeam/AdGuardHome#5248.
    
    Squashed commit of the following:
    
    commit 46255f5220ea71e67dd87c5f2cfec6f9d5c6add3
    Author: Ainar Garipov <A.Garipov@AdGuard.COM>
    Date:   Fri Dec 9 18:12:48 2022 +0300
    
        proxy: imp logs
    
    commit 6bc8867acf12847b1e6f316c6a6e85e54aedadb9
    Author: Ainar Garipov <A.Garipov@AdGuard.COM>
    Date:   Fri Dec 9 17:35:20 2022 +0300
    
        all: upd go; imp logs; typos
    
    commit b3c86258688658aa39fe36d970bcb76264d62567
    Author: Ainar Garipov <A.Garipov@AdGuard.COM>
    Date:   Fri Dec 9 16:10:36 2022 +0300
    
        proxy: imp cache locking
    f096b3d0
• Pull request: mv-fastip-to-netip-addr · 0ce51f57
    Ainar Garipov authored
    Merge in DNS/dnsproxy from mv-fastip-to-netip-addr to master
    
    Squashed commit of the following:
    
    commit e74715a22505e3e192e7d021c34ede802926bf23
    Author: Ainar Garipov <A.Garipov@AdGuard.COM>
    Date:   Tue Dec 13 18:02:35 2022 +0300
    
        fastip: imp docs, tests
    
    commit 31fe75affae136d9a3d45a2fc27dc3cda5cdba83
    Author: Ainar Garipov <A.Garipov@AdGuard.COM>
    Date:   Tue Dec 13 17:46:09 2022 +0300
    
        fastip: mv to netip.Addr
    0ce51f57
• Pull request: 5251-close-ups · dc6b8960
    Eugene Burkov authored
    Merge in DNS/dnsproxy from 5251-close-ups to master
    
    Updates AdguardTeam/AdGuardHome#5251.
    
    Squashed commit of the following:
    
    commit c4f01d3ff8f4dd3733124a6545ac3e6aa2e2f457
    Author: Eugene Burkov <E.Burkov@AdGuard.COM>
    Date:   Wed Dec 14 17:57:19 2022 +0300
    
        upstream: enhance filtered errs
    
    commit ce33519d844e2979d554d1001be1680f6299df08
    Merge: 175680b8 0ce51f57
    Author: Eugene Burkov <E.Burkov@AdGuard.COM>
    Date:   Wed Dec 14 13:29:27 2022 +0300
    
        Merge branch 'master' into 5251-close-ups
    
    commit 175680b8
    Author: Eugene Burkov <E.Burkov@AdGuard.COM>
    Date:   Tue Dec 13 15:06:01 2022 +0300
    
        upstream: fix doc
    
    commit 9ecfc0b2
    Author: Eugene Burkov <E.Burkov@AdGuard.COM>
    Date:   Tue Dec 13 14:59:04 2022 +0300
    
        all: imp docs
    
    commit b275d566
    Author: Eugene Burkov <E.Burkov@AdGuard.COM>
    Date:   Tue Dec 13 02:11:13 2022 +0300
    
        upstream: imp dot
    
    commit 0ed84f4c
    Author: Eugene Burkov <E.Burkov@AdGuard.COM>
    Date:   Mon Dec 12 20:40:35 2022 +0300
    
        fastip: fix golangci-lint issues
    
    commit c5ed13b2
    Author: Eugene Burkov <E.Burkov@AdGuard.COM>
    Date:   Mon Dec 12 20:15:45 2022 +0300
    
        all: fix staticcheck issues
    
    commit 6c8b28cc
    Author: Eugene Burkov <E.Burkov@AdGuard.COM>
    Date:   Mon Dec 12 20:05:24 2022 +0300
    
        upstream: imp more
    
    commit e4a2374d
    Author: Eugene Burkov <E.Burkov@AdGuard.COM>
    Date:   Mon Dec 12 19:01:57 2022 +0300
    
        upstream: upd golibs, imp code
    
    commit cae16109
    Author: Eugene Burkov <E.Burkov@AdGuard.COM>
    Date:   Mon Dec 12 16:27:21 2022 +0300
    
        upstream: filter dot errs
    
    commit 51005abc
    Author: Eugene Burkov <E.Burkov@AdGuard.COM>
    Date:   Mon Dec 12 14:10:23 2022 +0300
    
        upstream: use sync pool
    
    commit 81e907c5
    Author: Eugene Burkov <E.Burkov@AdGuard.COM>
    Date:   Sat Dec 10 03:37:21 2022 +0300
    
        upstream: imp dot
    dc6b8960
......@@ -5,7 +5,7 @@ plan:
key: DNSPROXYSPECS
name: dnsproxy - Build and run tests
variables:
dockerGo: adguard/golang-ubuntu:5.0
dockerGo: adguard/golang-ubuntu:5.4
dockerLint: golangci/golangci-lint:v1.50.0
stages:
......
......@@ -2,7 +2,7 @@ package fastip
import (
"encoding/binary"
"net"
"net/netip"
"time"
)
......@@ -18,11 +18,12 @@ type cacheEntry struct {
latencyMsec uint
}
// packCacheEntry - packss cache entry + ttl to bytes
// packCacheEntry packs the cache entry and the TTL to bytes in the following
// order:
//
// expire [4]byte
// status byte
// latency_msec [2]byte
// expire [4]byte (Unix time, seconds)
// status byte (0 for ok, 1 for timed out)
// latency [2]byte (milliseconds)
func packCacheEntry(ent *cacheEntry, ttl uint32) []byte {
expire := uint32(time.Now().Unix()) + ttl
......@@ -61,9 +62,8 @@ func unpackCacheEntry(data []byte) *cacheEntry {
// cacheFind - find entry in the cache for this IP
// returns null if nothing found or if the record for this ip is expired
func (f *FastestAddr) cacheFind(ip net.IP) *cacheEntry {
k := getCacheKey(ip)
val := f.ipCache.Get(k)
func (f *FastestAddr) cacheFind(ip netip.Addr) *cacheEntry {
val := f.ipCache.Get(ip.AsSlice())
if val == nil {
return nil
}
......@@ -75,7 +75,7 @@ func (f *FastestAddr) cacheFind(ip net.IP) *cacheEntry {
}
// cacheAddFailure - store unsuccessful attempt in cache
func (f *FastestAddr) cacheAddFailure(addr net.IP) {
func (f *FastestAddr) cacheAddFailure(ip netip.Addr) {
ent := cacheEntry{
status: 1,
}
......@@ -83,14 +83,14 @@ func (f *FastestAddr) cacheAddFailure(addr net.IP) {
f.ipCacheLock.Lock()
defer f.ipCacheLock.Unlock()
if f.cacheFind(addr) == nil {
f.cacheAdd(&ent, addr, fastestAddrCacheTTLSec)
if f.cacheFind(ip) == nil {
f.cacheAdd(&ent, ip, fastestAddrCacheTTLSec)
}
}
// store a successful ping result in cache
// replace previous result if our latency is lower
func (f *FastestAddr) cacheAddSuccessful(addr net.IP, latency uint) {
func (f *FastestAddr) cacheAddSuccessful(ip netip.Addr, latency uint) {
ent := cacheEntry{
latencyMsec: latency,
}
......@@ -98,24 +98,14 @@ func (f *FastestAddr) cacheAddSuccessful(addr net.IP, latency uint) {
f.ipCacheLock.Lock()
defer f.ipCacheLock.Unlock()
entCached := f.cacheFind(addr)
entCached := f.cacheFind(ip)
if entCached == nil || entCached.status != 0 || entCached.latencyMsec > latency {
f.cacheAdd(&ent, addr, fastestAddrCacheTTLSec)
f.cacheAdd(&ent, ip, fastestAddrCacheTTLSec)
}
}
// cacheAdd -- adds a new entry to the cache
func (f *FastestAddr) cacheAdd(ent *cacheEntry, addr net.IP, ttl uint32) {
ip := getCacheKey(addr)
func (f *FastestAddr) cacheAdd(ent *cacheEntry, ip netip.Addr, ttl uint32) {
val := packCacheEntry(ent, ttl)
f.ipCache.Set(ip, val)
}
// getCacheKey - gets cache key (compresses ipv4 to 4 bytes)
func getCacheKey(addr net.IP) net.IP {
ip := addr.To4()
if ip == nil {
ip = addr
}
return ip
f.ipCache.Set(ip.AsSlice(), val)
}
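For reference, the 7-byte layout described in the doc comment above can be exercised in isolation. A minimal stand-alone sketch; the big-endian byte order and the uint16 latency type here are assumptions for illustration, not taken from the diff:

package main

import (
	"encoding/binary"
	"fmt"
	"time"
)

// pack serializes an entry as: expire [4]byte, status byte,
// latency [2]byte, matching the layout in the doc comment above.
func pack(status byte, latencyMsec uint16, ttl uint32) (b []byte) {
	b = make([]byte, 7)
	binary.BigEndian.PutUint32(b[:4], uint32(time.Now().Unix())+ttl)
	b[4] = status
	binary.BigEndian.PutUint16(b[5:], latencyMsec)

	return b
}

// unpack reverses pack and reports whether the entry has expired.
func unpack(b []byte) (status byte, latencyMsec uint16, expired bool) {
	expire := binary.BigEndian.Uint32(b[:4])
	expired = uint32(time.Now().Unix()) > expire

	return b[4], binary.BigEndian.Uint16(b[5:]), expired
}

func main() {
	b := pack(0, 111, 60)
	status, lat, expired := unpack(b)
	fmt.Println(status, lat, expired) // 0 111 false
}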
......@@ -2,6 +2,7 @@ package fastip
import (
"net"
"net/netip"
"testing"
"time"
......@@ -14,7 +15,8 @@ func TestCacheAdd(t *testing.T) {
status: 0,
latencyMsec: 111,
}
ip := net.ParseIP("1.1.1.1")
ip := netip.MustParseAddr("1.1.1.1")
f.cacheAdd(&ent, ip, fastestAddrCacheTTLSec)
// check that it's there
......@@ -27,7 +29,8 @@ func TestCacheTtl(t *testing.T) {
status: 0,
latencyMsec: 111,
}
ip := net.ParseIP("1.1.1.1")
ip := netip.MustParseAddr("1.1.1.1")
f.cacheAdd(&ent, ip, 1)
// check that it's there
......@@ -42,8 +45,8 @@ func TestCacheTtl(t *testing.T) {
func TestCacheAddSuccessfulOverwrite(t *testing.T) {
f := NewFastestAddr()
ip := net.ParseIP("1.1.1.1")
ip := netip.MustParseAddr("1.1.1.1")
f.cacheAddFailure(ip)
// check that it's there
......@@ -63,8 +66,8 @@ func TestCacheAddSuccessfulOverwrite(t *testing.T) {
func TestCacheAddFailureNoOverwrite(t *testing.T) {
f := NewFastestAddr()
ip := net.ParseIP("1.1.1.1")
ip := netip.MustParseAddr("1.1.1.1")
f.cacheAddSuccessful(ip, 11)
// check that it's there
......@@ -88,12 +91,13 @@ func TestCache(t *testing.T) {
status: 0,
latencyMsec: 111,
}
// f.cacheAdd(&ent, net.ParseIP("1.1.1.1"), fastestAddrCacheMinTTLSec)
val := packCacheEntry(&ent, 1) // ttl=1
val := packCacheEntry(&ent, 1)
f.ipCache.Set(net.ParseIP("1.1.1.1").To4(), val)
ent = cacheEntry{
status: 0,
latencyMsec: 222,
}
f.cacheAdd(&ent, net.ParseIP("2.2.2.2"), fastestAddrCacheTTLSec)
f.cacheAdd(&ent, netip.MustParseAddr("2.2.2.2"), fastestAddrCacheTTLSec)
}
// Package fastip implements the algorithm that allows to
// query multiple resolvers, ping all IP addresses that were returned,
// and return the fastest one among them.
package fastip
// Package fastip implements the algorithm that allows to query multiple
// resolvers, ping all IP addresses that were returned, and return the fastest
// one among them.
package fastip
import (
"net"
"net/netip"
"strings"
"sync"
"time"
"github.com/AdguardTeam/dnsproxy/proxyutil"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/cache"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/miekg/dns"
"golang.org/x/exp/maps"
)
// DefaultPingWaitTimeout is the default period of time for waiting ping
......@@ -53,7 +58,7 @@ func NewFastestAddr() (f *FastestAddr) {
}
// ExchangeFastest queries each specified upstream and returns a response with
// the fastest IP address. The fastest IP address is cosidered to be the first
// the fastest IP address. The fastest IP address is considered to be the first
// one successfully dialed and other addresses are removed from the answer.
func (f *FastestAddr) ExchangeFastest(req *dns.Msg, ups []upstream.Upstream) (
resp *dns.Msg,
......@@ -67,16 +72,17 @@ func (f *FastestAddr) ExchangeFastest(req *dns.Msg, ups []upstream.Upstream) (
host := strings.ToLower(req.Question[0].Name)
ips := make([]net.IP, 0, len(replies))
ipSet := map[netip.Addr]struct{}{}
for _, r := range replies {
for _, rr := range r.Resp.Answer {
ip := proxyutil.IPFromRR(rr)
if ip != nil && !containsIP(ips, ip) {
ips = append(ips, ip)
ip := ipFromRR(rr)
if _, ok := ipSet[ip]; !ok && ip != (netip.Addr{}) {
ipSet[ip] = struct{}{}
}
}
}
ips := maps.Keys(ipSet)
if pingRes := f.pingAll(host, ips); pingRes != nil {
return f.prepareReply(pingRes, replies)
}
......@@ -86,25 +92,23 @@ func (f *FastestAddr) ExchangeFastest(req *dns.Msg, ups []upstream.Upstream) (
return replies[0].Resp, replies[0].Upstream, nil
}
// prepareReply converts replies into the DNS answer message accoding to
// pingRes. The returned upstreams is the one which replied with the fastest
// address.
func (f *FastestAddr) prepareReply(pingRes *pingResult, replies []upstream.ExchangeAllResult) (
m *dns.Msg,
u upstream.Upstream,
err error,
) {
ip := pingRes.ipp.IP
// prepareReply converts replies into the DNS answer message according to res.
// The returned upstream is the one which replied with the fastest address.
func (f *FastestAddr) prepareReply(
res *pingResult,
replies []upstream.ExchangeAllResult,
) (resp *dns.Msg, u upstream.Upstream, err error) {
ip := res.addrPort.Addr()
for _, r := range replies {
if hasInAns(r.Resp, ip) {
m = r.Resp
resp = r.Resp
u = r.Upstream
break
}
}
if m == nil {
if resp == nil {
log.Error("found no replies with IP %s, most likely this is a bug", ip)
return replies[0].Resp, replies[0].Upstream, nil
......@@ -112,35 +116,34 @@ func (f *FastestAddr) prepareReply(pingRes *pingResult, replies []upstream.Excha
// Modify the message and keep only A and AAAA records containing the
// fastest IP address.
ans := make([]dns.RR, 0, len(m.Answer))
for _, rr := range m.Answer {
ans := make([]dns.RR, 0, len(resp.Answer))
ipBytes := ip.AsSlice()
for _, rr := range resp.Answer {
switch addr := rr.(type) {
case *dns.A:
if ip.Equal(addr.A.To4()) {
if addr.A.Equal(ipBytes) {
ans = append(ans, rr)
}
case *dns.AAAA:
if ip.Equal(addr.AAAA) {
if addr.AAAA.Equal(ipBytes) {
ans = append(ans, rr)
}
default:
ans = append(ans, rr)
}
}
// Set new answer.
m.Answer = ans
resp.Answer = ans
return m, u, nil
return resp, u, nil
}
// hasInAns returns true if m contains ip in its Answer section.
func hasInAns(m *dns.Msg, ip net.IP) (ok bool) {
func hasInAns(m *dns.Msg, ip netip.Addr) (ok bool) {
for _, rr := range m.Answer {
respIP := proxyutil.IPFromRR(rr)
if respIP != nil && respIP.Equal(ip) {
respIP := ipFromRR(rr)
if respIP == ip {
return true
}
}
......@@ -148,17 +151,14 @@ func hasInAns(m *dns.Msg, ip net.IP) (ok bool) {
return false
}
// containsIP returns true if ips contains the ip.
func containsIP(ips []net.IP, ip net.IP) (ok bool) {
if len(ips) == 0 {
return false
// ipFromRR returns the IP address from rr if any.
func ipFromRR(rr dns.RR) (ip netip.Addr) {
switch rr := rr.(type) {
case *dns.A:
ip, _ = netutil.IPToAddr(rr.A, netutil.AddrFamilyIPv4)
case *dns.AAAA:
ip, _ = netutil.IPToAddr(rr.AAAA, netutil.AddrFamilyIPv6)
}
for _, i := range ips {
if i.Equal(ip) {
return true
}
}
return false
return ip
}
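The switch from []net.IP to map[netip.Addr]struct{} works because netip.Addr is a comparable value type, so deduplication no longer needs the removed containsIP helper. A minimal sketch of the same pattern, using maps.Keys from golang.org/x/exp/maps as the diff does:

package main

import (
	"fmt"
	"net/netip"

	"golang.org/x/exp/maps"
)

func main() {
	answers := []string{"1.1.1.1", "8.8.8.8", "1.1.1.1"}

	// netip.Addr is a comparable value type, so it can key a map
	// directly; net.IP is a []byte and would need Equal in a loop.
	ipSet := map[netip.Addr]struct{}{}
	for _, s := range answers {
		ip := netip.MustParseAddr(s)
		if _, ok := ipSet[ip]; !ok {
			ipSet[ip] = struct{}{}
		}
	}

	ips := maps.Keys(ipSet)
	fmt.Println(len(ips)) // 2
}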
......@@ -2,6 +2,7 @@ package fastip
import (
"net"
"net/netip"
"testing"
"github.com/AdguardTeam/dnsproxy/upstream"
......@@ -29,7 +30,7 @@ func TestFastestAddr_ExchangeFastest(t *testing.T) {
})
t.Run("one_dead", func(t *testing.T) {
port := listen(t, nil)
port := listen(t, netip.IPv4Unspecified())
f := NewFastestAddr()
f.pingPorts = []uint{port}
......
package fastip
import (
"net"
"net/netip"
"time"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
)
// pingTCPTimeout is a TCP connection timeout. It's higher than pingWaitTimeout
......@@ -14,8 +13,8 @@ const pingTCPTimeout = 4 * time.Second
// pingResult is the result of dialing the address.
type pingResult struct {
// ipp is the address-port pair the result is related to.
ipp netutil.IPPort
// addrPort is the address-port pair the result is related to.
addrPort netip.AddrPort
// latency is the duration of dialing process in milliseconds.
latency uint
// success is true when the dialing succeeded.
......@@ -24,15 +23,15 @@ type pingResult struct {
// pingAll pings all ips concurrently and returns as soon as the fastest one is
// found or the timeout is exceeded.
func (f *FastestAddr) pingAll(host string, ips []net.IP) (pr *pingResult) {
func (f *FastestAddr) pingAll(host string, ips []netip.Addr) (pr *pingResult) {
ipN := len(ips)
switch ipN {
case 0:
return nil
case 1:
return &pingResult{
ipp: netutil.IPPort{IP: ips[0]},
success: true,
addrPort: netip.AddrPortFrom(ips[0], 0),
success: true,
}
}
......@@ -45,7 +44,7 @@ func (f *FastestAddr) pingAll(host string, ips []net.IP) (pr *pingResult) {
cached := f.cacheFind(ip)
if cached == nil {
for _, port := range f.pingPorts {
go f.pingDoTCP(host, netutil.IPPort{IP: ip, Port: int(port)}, resCh)
go f.pingDoTCP(host, netip.AddrPortFrom(ip, uint16(port)), resCh)
}
scheduled += portN
......@@ -56,9 +55,9 @@ func (f *FastestAddr) pingAll(host string, ips []net.IP) (pr *pingResult) {
if pr == nil || cached.latencyMsec < pr.latency {
pr = &pingResult{
ipp: netutil.IPPort{IP: ip},
latency: cached.latencyMsec,
success: true,
addrPort: netip.AddrPortFrom(ip, 0),
latency: cached.latencyMsec,
success: true,
}
}
}
......@@ -66,7 +65,7 @@ func (f *FastestAddr) pingAll(host string, ips []net.IP) (pr *pingResult) {
cached := pr != nil
if scheduled == 0 {
if cached {
log.Debug("pingAll: %s: return cached response: %s", host, pr.ipp)
log.Debug("pingAll: %s: return cached response: %s", host, pr.addrPort)
} else {
log.Debug("pingAll: %s: returning nothing", host)
}
......@@ -81,7 +80,7 @@ func (f *FastestAddr) pingAll(host string, ips []net.IP) (pr *pingResult) {
log.Debug(
"pingAll: %s: got result for %s status %v",
host,
res.ipp,
res.addrPort,
res.success,
)
if !res.success {
......@@ -98,7 +97,7 @@ func (f *FastestAddr) pingAll(host string, ips []net.IP) (pr *pingResult) {
log.Debug(
"pingAll: %s: pinging timed out, returning cached: %s",
host,
pr.ipp,
pr.addrPort,
)
} else {
log.Debug(
......@@ -115,12 +114,11 @@ func (f *FastestAddr) pingAll(host string, ips []net.IP) (pr *pingResult) {
}
// pingDoTCP sends the result of dialing the specified address into resCh.
func (f *FastestAddr) pingDoTCP(host string, ipp netutil.IPPort, resCh chan *pingResult) {
log.Debug("pingDoTCP: %s: connecting to %s", host, ipp)
addr := ipp.String()
func (f *FastestAddr) pingDoTCP(host string, addrPort netip.AddrPort, resCh chan *pingResult) {
log.Debug("pingDoTCP: %s: connecting to %s", host, addrPort)
start := time.Now()
conn, err := f.pinger.Dial("tcp", addr)
conn, err := f.pinger.Dial("tcp", addrPort.String())
elapsed := time.Since(start)
success := err == nil
......@@ -133,22 +131,23 @@ func (f *FastestAddr) pingDoTCP(host string, ipp netutil.IPPort, resCh chan *pin
latency := uint(elapsed.Milliseconds())
resCh <- &pingResult{
ipp: ipp,
latency: latency,
success: success,
addrPort: addrPort,
latency: latency,
success: success,
}
addr := addrPort.Addr().Unmap()
if success {
log.Debug("pingDoTCP: %s: elapsed %s ms on %s", host, elapsed, ipp)
f.cacheAddSuccessful(ipp.IP, latency)
log.Debug("pingDoTCP: %s: elapsed %s ms on %s", host, elapsed, addrPort)
f.cacheAddSuccessful(addr, latency)
} else {
log.Debug(
"pingDoTCP: %s: failed to connect to %s, elapsed %s ms: %v",
host,
ipp,
addrPort,
elapsed,
err,
)
f.cacheAddFailure(ipp.IP)
f.cacheAddFailure(addr)
}
}
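pingAll above is a first-responder race: dial every candidate concurrently and take the first success before a timeout. A self-contained sketch of that pattern, independent of the FastestAddr internals (names are illustrative):

package main

import (
	"fmt"
	"net"
	"time"
)

type result struct {
	addr string
	err  error
}

// raceDial dials all addrs concurrently and returns the first one
// that answers, mirroring the first-success semantics of pingAll.
func raceDial(addrs []string, timeout time.Duration) (addr string, err error) {
	resCh := make(chan result, len(addrs))
	for _, a := range addrs {
		go func(a string) {
			conn, dialErr := net.DialTimeout("tcp", a, timeout)
			if dialErr == nil {
				_ = conn.Close()
			}
			resCh <- result{addr: a, err: dialErr}
		}(a)
	}

	timeoutCh := time.After(timeout)
	for range addrs {
		select {
		case res := <-resCh:
			if res.err == nil {
				return res.addr, nil
			}
		case <-timeoutCh:
			return "", fmt.Errorf("no address answered within %s", timeout)
		}
	}

	return "", fmt.Errorf("all %d dials failed", len(addrs))
}

func main() {
	addr, err := raceDial([]string{"1.1.1.1:443", "8.8.8.8:443"}, 2*time.Second)
	fmt.Println(addr, err)
}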
......@@ -2,6 +2,7 @@ package fastip
import (
"net"
"net/netip"
"runtime"
"sync"
"syscall"
......@@ -28,8 +29,8 @@ func TestFastestAddr_PingAll_timeout(t *testing.T) {
return nil
}
ip := net.IP{127, 0, 0, 1}
res := f.pingAll("", []net.IP{ip, ip})
ip := netutil.IPv4Localhost()
res := f.pingAll("", []netip.Addr{ip, ip})
require.Nil(t, res)
waitCh <- unit{}
......@@ -40,7 +41,8 @@ func TestFastestAddr_PingAll_timeout(t *testing.T) {
const lat uint = 42
ip1, ip2 := net.IP{127, 0, 0, 1}, net.IP{127, 0, 0, 2}
ip1 := netutil.IPv4Localhost()
ip2 := netip.MustParseAddr("127.0.0.2")
f.cacheAddSuccessful(ip1, lat)
waitCh := make(chan unit)
......@@ -50,7 +52,7 @@ func TestFastestAddr_PingAll_timeout(t *testing.T) {
return nil
}
res := f.pingAll("", []net.IP{ip1, ip2})
res := f.pingAll("", []netip.Addr{ip1, ip2})
require.NotNil(t, res)
assert.True(t, res.success)
......@@ -62,7 +64,7 @@ func TestFastestAddr_PingAll_timeout(t *testing.T) {
// assertCaching checks the cache of f for containing a connection to ip with
// the specified status.
func assertCaching(t *testing.T, f *FastestAddr, ip net.IP, status int) {
func assertCaching(t *testing.T, f *FastestAddr, ip netip.Addr, status int) {
t.Helper()
const tickDur = pingTCPTimeout / 16
......@@ -75,23 +77,23 @@ func assertCaching(t *testing.T, f *FastestAddr, ip net.IP, status int) {
}
func TestFastestAddr_PingAll_cache(t *testing.T) {
ip := net.IP{127, 0, 0, 1}
ip := netutil.IPv4Localhost()
t.Run("cached_failed", func(t *testing.T) {
f := NewFastestAddr()
f.cacheAddFailure(ip)
res := f.pingAll("", []net.IP{ip, ip})
res := f.pingAll("", []netip.Addr{ip, ip})
require.Nil(t, res)
})
t.Run("cached_succesfull", func(t *testing.T) {
t.Run("cached_successful", func(t *testing.T) {
const lat uint = 1
f := NewFastestAddr()
f.cacheAddSuccessful(ip, lat)
res := f.pingAll("", []net.IP{ip, ip})
res := f.pingAll("", []netip.Addr{ip, ip})
require.NotNil(t, res)
assert.True(t, res.success)
assert.Equal(t, lat, res.latency)
......@@ -102,11 +104,11 @@ func TestFastestAddr_PingAll_cache(t *testing.T) {
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, listener.Close)
ip := net.IP{127, 0, 0, 1}
ip := netutil.IPv4Localhost()
f := NewFastestAddr()
f.pingPorts = []uint{uint(listener.Addr().(*net.TCPAddr).Port)}
ips := []net.IP{ip, ip}
ips := []netip.Addr{ip, ip}
wg := &sync.WaitGroup{}
wg.Add(len(ips) * len(f.pingPorts))
......@@ -134,10 +136,10 @@ func TestFastestAddr_PingAll_cache(t *testing.T) {
}
// listen is a helper function that creates a new listener on ip for t.
func listen(t *testing.T, ip net.IP) (port uint) {
func listen(t *testing.T, ip netip.Addr) (port uint) {
t.Helper()
l, err := net.Listen("tcp", netutil.IPPort{IP: ip, Port: 0}.String())
l, err := net.Listen("tcp", netip.AddrPortFrom(ip, 0).String())
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, l.Close)
......@@ -145,20 +147,20 @@ func listen(t *testing.T, ip net.IP) (port uint) {
}
func TestFastestAddr_PingAll(t *testing.T) {
ip := net.IP{127, 0, 0, 1}
ip := netutil.IPv4Localhost()
t.Run("single", func(t *testing.T) {
f := NewFastestAddr()
res := f.pingAll("", []net.IP{ip})
res := f.pingAll("", []netip.Addr{ip})
require.NotNil(t, res)
assert.True(t, res.success)
assert.True(t, ip.Equal(res.ipp.IP))
assert.Equal(t, ip, res.addrPort.Addr())
// There was no ping so the port is zero.
assert.Zero(t, res.ipp.Port)
assert.Zero(t, res.addrPort.Port())
// Nothing in the cache since there was no ping.
ce := f.cacheFind(res.ipp.IP)
ce := f.cacheFind(res.addrPort.Addr())
require.Nil(t, ce)
})
......@@ -174,11 +176,9 @@ func TestFastestAddr_PingAll(t *testing.T) {
slowPort,
}
f.pinger.Control = func(_, address string, _ syscall.RawConn) error {
ipp, err := netutil.ParseIPPort(address)
require.NoError(t, err)
require.Contains(t, []uint{fastPort, slowPort}, uint(ipp.Port))
if ipp.Port == int(fastPort) {
addrPort := netip.MustParseAddrPort(address)
require.Contains(t, []uint{fastPort, slowPort}, uint(addrPort.Port()))
if addrPort.Port() == uint16(fastPort) {
return nil
}
......@@ -187,16 +187,15 @@ func TestFastestAddr_PingAll(t *testing.T) {
return nil
}
ips := []net.IP{ip, ip}
ips := []netip.Addr{ip, ip}
res := f.pingAll("", ips)
ctrlCh <- unit{}
require.NotNil(t, res)
assert.True(t, res.success)
assert.True(t, ip.Equal(res.ipp.IP))
assert.EqualValues(t, fastPort, res.ipp.Port)
assert.Equal(t, ip, res.addrPort.Addr())
assert.EqualValues(t, fastPort, res.addrPort.Port())
assertCaching(t, f, ip, 0)
})
......@@ -212,7 +211,7 @@ func TestFastestAddr_PingAll(t *testing.T) {
f := NewFastestAddr()
f.pingPorts = []uint{port}
res := f.pingAll("test", []net.IP{ip, ip})
res := f.pingAll("test", []netip.Addr{ip, ip})
require.Nil(t, res)
assertCaching(t, f, ip, 1)
......
......@@ -3,7 +3,7 @@ module github.com/AdguardTeam/dnsproxy
go 1.18
require (
github.com/AdguardTeam/golibs v0.10.9
github.com/AdguardTeam/golibs v0.11.2
github.com/ameshkov/dnscrypt/v2 v2.2.5
github.com/ameshkov/dnsstamps v1.0.3
github.com/beefsack/go-rate v0.0.0-20220214233405-116f4ca011a0
......@@ -13,7 +13,9 @@ require (
github.com/miekg/dns v1.1.50
github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/stretchr/testify v1.8.0
golang.org/x/net v0.0.0-20220826154423-83b083e8dc8b
golang.org/x/exp v0.0.0-20221019170559-20944726eadf
golang.org/x/net v0.1.0
golang.org/x/sys v0.1.1-0.20221102194838-fc697a31fa06
gopkg.in/yaml.v3 v3.0.1
)
......@@ -32,10 +34,8 @@ require (
github.com/onsi/ginkgo/v2 v2.2.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
golang.org/x/crypto v0.0.0-20220829220503-c86fa9a7ed90 // indirect
golang.org/x/exp v0.0.0-20220827204233-334a2380cb91 // indirect
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 // indirect
golang.org/x/sys v0.1.1-0.20221102194838-fc697a31fa06 // indirect
golang.org/x/text v0.3.7 // indirect
golang.org/x/text v0.4.0 // indirect
golang.org/x/tools v0.1.12 // indirect
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f // indirect
)
github.com/AdguardTeam/golibs v0.10.9 h1:F9oP2da0dQ9RQDM1lGR7LxUTfUWu8hEFOs4icwAkKM0=
github.com/AdguardTeam/golibs v0.10.9/go.mod h1:W+5rznZa1cSNSFt+gPS7f4Wytnr9fOrd5ZYqwadPw14=
github.com/AdguardTeam/golibs v0.11.2 h1:JbQB1Dg2JWStXgHh1QqBbOLWnP4t9oDjppoBH6TVXSE=
github.com/AdguardTeam/golibs v0.11.2/go.mod h1:87bN2x4VsTritptE3XZg9l8T6gznWsIxHBcQ1DeRIXA=
github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da h1:KjTM2ks9d14ZYCvmHS9iAKVt9AyzRSqNU1qabPih5BY=
github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da/go.mod h1:eHEWzANqSiWQsof+nXEI9bUVUyV6F53Fp89EuCh2EAA=
github.com/aead/poly1305 v0.0.0-20180717145839-3fee0db0b635 h1:52m0LGchQBBVqJRyYYufQuIbVqRawmubW3OFGqK1ekw=
......@@ -64,8 +64,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20220829220503-c86fa9a7ed90 h1:Y/gsMcFOcR+6S6f3YeMKl5g+dZMEWqcz5Czj/GWYbkM=
golang.org/x/crypto v0.0.0-20220829220503-c86fa9a7ed90/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/exp v0.0.0-20220827204233-334a2380cb91 h1:tnebWN09GYg9OLPss1KXj8txwZc6X6uMr6VFdcGNbHw=
golang.org/x/exp v0.0.0-20220827204233-334a2380cb91/go.mod h1:cyybsKvd6eL0RnXn6p/Grxp8F5bW7iYuBgsNCOHpMYE=
golang.org/x/exp v0.0.0-20221019170559-20944726eadf h1:nFVjjKDgNY37+ZSYCJmtYf7tOlfQswHqplG2eosjOMg=
golang.org/x/exp v0.0.0-20221019170559-20944726eadf/go.mod h1:cyybsKvd6eL0RnXn6p/Grxp8F5bW7iYuBgsNCOHpMYE=
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 h1:6zppjxzCulZykYSLyVDYbneBfbaBIQPYMevg0bEwv2s=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
......@@ -73,8 +73,8 @@ golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
golang.org/x/net v0.0.0-20210726213435-c6fcb2dbf985/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20220826154423-83b083e8dc8b h1:ZmngSVLe/wycRns9MKikG9OWIEjGcGAkacif7oYQaUY=
golang.org/x/net v0.0.0-20220826154423-83b083e8dc8b/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk=
golang.org/x/net v0.1.0 h1:hZ/3BUoy5aId7sCpA/Tc5lt8DkFgdVS2onTpJsZ/fl0=
golang.org/x/net v0.1.0/go.mod h1:Cx3nUiGt4eDBEyega/BKRp+/AlGL8hYe7U9odMt2Cco=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4 h1:uVc8UZUe6tr40fFVnUP5Oj+veunVezqYl9z7DYw9xzw=
......@@ -93,8 +93,8 @@ golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9sn
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.4.0 h1:BrVqGRd7+k1DiOgtnFvAkoQEWQvBc25ouMJM6429SFg=
golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
......
......@@ -30,8 +30,6 @@ type cache struct {
// itemsWithSubnetLock protects requests cache.
itemsWithSubnetLock sync.RWMutex
// cacheSize is the size of a key-value pair of cache.
cacheSize int
// optimistic defines if the cache should return expired items and resolve
// those again.
optimistic bool
......@@ -162,22 +160,30 @@ func (c *cache) unpackItem(data []byte, req *dns.Msg) (ci *cacheItem, expired bo
// initCache initializes cache if it's enabled.
func (p *Proxy) initCache() {
if !p.CacheEnabled {
log.Info("dnsproxy: cache: disabled")
return
}
log.Printf("DNS cache is enabled")
size := p.CacheSizeBytes
log.Info("dnsproxy: cache: enabled, size %d b", size)
p.cache = newCache(size, p.EnableEDNSClientSubnet, p.CacheOptimistic)
p.shortFlighter = newOptimisticResolver(p)
}
c := &cache{
optimistic: p.CacheOptimistic,
cacheSize: p.CacheSizeBytes,
// newCache returns a properly initialized cache.
func newCache(size int, withECS, optimistic bool) (c *cache) {
c = &cache{
items: createCache(size),
optimistic: optimistic,
}
p.cache = c
c.initLazy()
if p.EnableEDNSClientSubnet {
c.initLazyWithSubnet()
if withECS {
c.itemsWithSubnet = createCache(size)
}
p.shortFlighter = newOptimisticResolver(p)
return c
}
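The move from initLazy/initLazyWithSubnet to newCache trades double-checked lazy initialization for eager construction, so accessors no longer need an init check under the lock. A minimal sketch of the two patterns side by side (the types are illustrative, not the proxy's):

package main

import "sync"

// lazyStore initializes its map on first use, so every accessor must
// recheck initialization under the same lock that guards access.
type lazyStore struct {
	mu    sync.Mutex
	items map[string][]byte
}

func (s *lazyStore) set(k string, v []byte) {
	s.mu.Lock()
	defer s.mu.Unlock()

	if s.items == nil { // init check on every call
		s.items = map[string][]byte{}
	}
	s.items[k] = v
}

// eagerStore is fully constructed up front, mirroring newCache above:
// methods only lock around the actual access.
type eagerStore struct {
	mu    sync.Mutex
	items map[string][]byte
}

func newEagerStore() (s *eagerStore) {
	return &eagerStore{items: map[string][]byte{}}
}

func (s *eagerStore) set(k string, v []byte) {
	s.mu.Lock()
	defer s.mu.Unlock()

	s.items[k] = v
}

func main() {
	s := newEagerStore()
	s.set("example.com.", []byte{1})
}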
// get returns cached item for the req if it's found. expired is true if the
......@@ -241,35 +247,15 @@ func canLookUpInCache(cache glcache.Cache, req *dns.Msg) (ok bool) {
return cache != nil && req != nil && len(req.Question) == 1
}
// initLazy initializes the cache for general requests.
func (c *cache) initLazy() {
c.itemsLock.Lock()
defer c.itemsLock.Unlock()
if c.items == nil {
c.items = c.createCache()
}
}
// initLazyWithSubnet initializes the cache for requests with subnets.
func (c *cache) initLazyWithSubnet() {
c.itemsWithSubnetLock.Lock()
defer c.itemsWithSubnetLock.Unlock()
if c.itemsWithSubnet == nil {
c.itemsWithSubnet = c.createCache()
}
}
// createCache returns new Cache with predefined settings.
func (c *cache) createCache() (glc glcache.Cache) {
// createCache returns new Cache with the given cacheSize.
func createCache(cacheSize int) (glc glcache.Cache) {
conf := glcache.Config{
MaxSize: defaultCacheSize,
EnableLRU: true,
}
if c.cacheSize > 0 {
conf.MaxSize = uint(c.cacheSize)
if cacheSize > 0 {
conf.MaxSize = uint(cacheSize)
}
return glcache.New(conf)
......@@ -282,8 +268,6 @@ func (c *cache) set(m *dns.Msg, u upstream.Upstream) {
return
}
c.initLazy()
key := msgToKey(m)
packed := item.pack()
......@@ -301,8 +285,6 @@ func (c *cache) setWithSubnet(m *dns.Msg, u upstream.Upstream, subnet *net.IPNet
return
}
c.initLazyWithSubnet()
pref, _ := subnet.Mask.Size()
key := msgToKeyWithSubnet(m, subnet.IP, pref)
packed := item.pack()
......@@ -318,19 +300,20 @@ func (c *cache) clearItems() {
c.itemsLock.Lock()
defer c.itemsLock.Unlock()
if c.items != nil {
c.items.Clear()
}
c.items.Clear()
}
// clearItemsWithSubnet empties the subnet cache.
// clearItemsWithSubnet empties the subnet cache, if any.
func (c *cache) clearItemsWithSubnet() {
if c.itemsWithSubnet == nil {
// ECS disabled, return immediately.
return
}
c.itemsWithSubnetLock.Lock()
defer c.itemsWithSubnetLock.Unlock()
if c.itemsWithSubnet != nil {
c.itemsWithSubnet.Clear()
}
c.itemsWithSubnet.Clear()
}
// cacheTTL returns the number of seconds for which m is valid to be cached.
......@@ -344,16 +327,18 @@ func cacheTTL(m *dns.Msg) (ttl uint32) {
case m == nil:
return 0
case m.Truncated:
log.Tracef("refusing to cache truncated message")
log.Debug("dnsproxy: cache: truncated message; not caching")
return 0
case len(m.Question) != 1:
log.Tracef("refusing to cache message with wrong number of questions")
log.Debug("dnsproxy: cache: message with wrong number of questions; not caching")
return 0
default:
ttl = calculateTTL(m)
if ttl == 0 {
log.Debug("dnsproxy: cache: ttl calculated to be 0; not caching")
return 0
}
}
......@@ -363,18 +348,18 @@ func cacheTTL(m *dns.Msg) (ttl uint32) {
if isCacheableSucceded(m) {
return ttl
}
log.Debug("dnsproxy: cache: not a cacheable noerror response; not caching")
case dns.RcodeNameError:
if isCacheableNegative(m) {
return ttl
}
log.Debug("dnsproxy: cache: not a cacheable nxdomain response; not caching")
case dns.RcodeServerFailure:
return ttl
default:
log.Tracef(
"%s: refusing to cache message with response code %s",
m.Question[0].Name,
dns.RcodeToString[rcode],
)
log.Debug("dnsproxy: cache: response code %s; not caching", dns.RcodeToString[rcode])
}
return 0
......@@ -393,7 +378,7 @@ func hasIPAns(m *dns.Msg) (ok bool) {
}
// isCacheableSucceded returns true if m contains useful data to be cached
// treating it as a succeesful response.
// treating it as a successful response.
func isCacheableSucceded(m *dns.Msg) (ok bool) {
qType := m.Question[0].Qtype
......
......@@ -9,7 +9,6 @@ import (
"time"
"github.com/AdguardTeam/dnsproxy/upstream"
glcache "github.com/AdguardTeam/golibs/cache"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
......@@ -18,6 +17,9 @@ import (
"github.com/miekg/dns"
)
// testCacheSize is the maximum size of cache for tests.
const testCacheSize = 4096
func newRR(t *testing.T, rr string) (r dns.RR) {
t.Helper()
......@@ -28,18 +30,6 @@ func newRR(t *testing.T, rr string) (r dns.RR) {
return r
}
func TestCacheSanity(t *testing.T) {
testCache := &cache{}
request := (&dns.Msg{}).SetQuestion("google.com.", dns.TypeA)
ci, expired, key := testCache.get(request)
assert.Nil(t, ci)
assert.False(t, expired)
assert.Nil(t, key)
}
const testUpsAddr = "https://upstream.address"
var upstreamWithAddr = &funcUpstream{
......@@ -127,12 +117,7 @@ func TestCache_expired(t *testing.T) {
optimistic: true,
}}
testCache := &cache{
items: glcache.New(glcache.Config{
MaxSize: defaultCacheSize,
EnableLRU: true,
}),
}
testCache := newCache(testCacheSize, false, false)
for _, tc := range testCases {
ans.Hdr.Ttl = tc.ttl
req := (&dns.Msg{}).SetQuestion(host, dns.TypeA)
......@@ -169,7 +154,7 @@ func TestCache_expired(t *testing.T) {
}
func TestCacheDO(t *testing.T) {
testCache := &cache{}
testCache := newCache(testCacheSize, false, false)
// Fill the cache.
reply := (&dns.Msg{
......@@ -212,7 +197,7 @@ func TestCacheDO(t *testing.T) {
}
func TestCacheCNAME(t *testing.T) {
testCache := &cache{}
testCache := newCache(testCacheSize, false, false)
// Fill the cache
reply := (&dns.Msg{
......@@ -227,10 +212,9 @@ func TestCacheCNAME(t *testing.T) {
request := (&dns.Msg{}).SetQuestion("google.com.", dns.TypeA)
t.Run("no_cnames", func(t *testing.T) {
r, expired, key := testCache.get(request)
assert.False(t, expired)
assert.Nil(t, key)
r, expired, _ := testCache.get(request)
assert.Nil(t, r)
assert.False(t, expired)
})
// Now fill the cache with a cacheable CNAME response.
......@@ -250,7 +234,7 @@ func TestCacheCNAME(t *testing.T) {
}
func TestCache_uncacheable(t *testing.T) {
testCache := &cache{}
testCache := newCache(testCacheSize, false, false)
// Create a DNS request.
request := (&dns.Msg{}).SetQuestion("google.com.", dns.TypeA)
......@@ -260,14 +244,13 @@ func TestCache_uncacheable(t *testing.T) {
// We are testing that SERVFAIL responses aren't cached
testCache.set(reply, upstreamWithAddr)
r, expired, key := testCache.get(request)
assert.False(t, expired)
assert.Nil(t, key)
r, expired, _ := testCache.get(request)
assert.Nil(t, r)
assert.False(t, expired)
}
func TestCache_concurrent(t *testing.T) {
testCache := &cache{}
testCache := newCache(testCacheSize, false, false)
hosts := map[string]string{
dns.Fqdn("yandex.com"): "213.180.204.62",
......@@ -399,23 +382,41 @@ func TestCacheExpirationWithTTLOverride(t *testing.T) {
})
}
type testEntry struct {
q string
a []dns.RR
t uint16
}
type testCase struct {
ok require.BoolAssertionFunc
q string
a []dns.RR
t uint16
}
type testCases struct {
cache []testEntry
cases []testCase
}
func TestCache(t *testing.T) {
t.Run("simple", func(t *testing.T) {
testCases{
cache: []testEntry{{
q: "google.com.",
t: dns.TypeA,
a: []dns.RR{newRR(t, "google.com. 3600 IN A 8.8.8.8")},
t: dns.TypeA,
}},
cases: []testCase{{
ok: require.True,
q: "google.com.",
t: dns.TypeA,
a: []dns.RR{newRR(t, "google.com. 3600 IN A 8.8.8.8")},
ok: require.True,
t: dns.TypeA,
}, {
ok: require.False,
q: "google.com.",
t: dns.TypeMX,
ok: require.False,
}},
}.run(t)
})
......@@ -424,36 +425,36 @@ func TestCache(t *testing.T) {
testCases{
cache: []testEntry{{
q: "gOOgle.com.",
t: dns.TypeA,
a: []dns.RR{newRR(t, "google.com. 3600 IN A 8.8.8.8")},
t: dns.TypeA,
}},
cases: []testCase{{
ok: require.True,
q: "gOOgle.com.",
t: dns.TypeA,
a: []dns.RR{newRR(t, "google.com. 3600 IN A 8.8.8.8")},
ok: require.True,
t: dns.TypeA,
}, {
ok: require.True,
q: "google.com.",
t: dns.TypeA,
a: []dns.RR{newRR(t, "google.com. 3600 IN A 8.8.8.8")},
ok: require.True,
t: dns.TypeA,
}, {
ok: require.True,
q: "GOOGLE.COM.",
t: dns.TypeA,
a: []dns.RR{newRR(t, "google.com. 3600 IN A 8.8.8.8")},
ok: require.True,
t: dns.TypeA,
}, {
q: "gOOgle.com.",
t: dns.TypeMX,
ok: require.False,
}, {
ok: require.False,
q: "google.com.",
t: dns.TypeMX,
ok: require.False,
}, {
ok: require.False,
q: "GOOGLE.COM.",
t: dns.TypeMX,
ok: require.False,
}},
}.run(t)
})
......@@ -462,40 +463,40 @@ func TestCache(t *testing.T) {
testCases{
cache: []testEntry{{
q: "gOOgle.com.",
t: dns.TypeA,
a: []dns.RR{newRR(t, "google.com. 0 IN A 8.8.8.8")},
t: dns.TypeA,
}},
cases: []testCase{{
ok: require.False,
q: "google.com.",
t: dns.TypeA,
ok: require.False,
}, {
ok: require.False,
q: "google.com.",
t: dns.TypeA,
ok: require.False,
}, {
ok: require.False,
q: "google.com.",
t: dns.TypeA,
ok: require.False,
}, {
ok: require.False,
q: "google.com.",
t: dns.TypeMX,
ok: require.False,
}, {
ok: require.False,
q: "google.com.",
t: dns.TypeMX,
ok: require.False,
}, {
ok: require.False,
q: "google.com.",
t: dns.TypeMX,
ok: require.False,
}},
}.run(t)
})
}
func (tests testCases) run(t *testing.T) {
testCache := &cache{}
testCache := newCache(testCacheSize, false, false)
for _, res := range tests.cache {
reply := (&dns.Msg{
......@@ -533,24 +534,6 @@ func (tests testCases) run(t *testing.T) {
}
}
type testCases struct {
cache []testEntry
cases []testCase
}
type testEntry struct {
q string
t uint16
a []dns.RR
}
type testCase struct {
q string
t uint16
a []dns.RR
ok require.BoolAssertionFunc
}
// requireEqualMsgs asserts the messages are equal except their ID, Rdlength, and
// the case of questions.
func requireEqualMsgs(t *testing.T, expected, actual *dns.Msg) {
......@@ -610,18 +593,17 @@ func setAndGetCache(t *testing.T, c *cache, g *sync.WaitGroup, host, ip string)
}
func TestSubnet(t *testing.T) {
c := &cache{}
c := newCache(testCacheSize, true, false)
ip1234, ip2234, ip3234 := net.IP{1, 2, 3, 4}, net.IP{2, 2, 3, 4}, net.IP{3, 2, 3, 4}
req := (&dns.Msg{}).SetQuestion("example.com.", dns.TypeA)
t.Run("empty", func(t *testing.T) {
ci, expired, key := c.getWithSubnet(req, &net.IPNet{
ci, expired, _ := c.getWithSubnet(req, &net.IPNet{
IP: ip1234,
Mask: net.CIDRMask(24, netutil.IPv4BitLen),
})
assert.False(t, expired)
assert.Nil(t, key)
assert.Nil(t, ci)
assert.False(t, expired)
})
// Add a response with subnet.
......
......@@ -185,7 +185,7 @@ func (p *Proxy) validateConfig() error {
return nil
}
// validateListenAddrs returns an error if the addressses are not configured
// validateListenAddrs returns an error if the addresses are not configured
// properly.
func (p *Proxy) validateListenAddrs() error {
if !p.hasListenAddrs() {
......
......@@ -134,7 +134,8 @@ type Proxy struct {
Config // proxy configuration
}
// Init - initializes the proxy structures but does not start it
// Init populates fields of p but does not start it. Init must be called before
// calling Start.
func (p *Proxy) Init() (err error) {
p.initCache()
......@@ -448,56 +449,77 @@ const defaultUDPBufSize = 2048
// Resolve is the default resolving method used by the DNS proxy to query
// upstream servers.
func (p *Proxy) Resolve(d *DNSContext) (err error) {
func (p *Proxy) Resolve(dctx *DNSContext) (err error) {
if p.EnableEDNSClientSubnet {
d.processECS(p.EDNSAddr)
dctx.processECS(p.EDNSAddr)
}
d.calcFlagsAndSize()
dctx.calcFlagsAndSize()
// Use cache only if it's enabled and the query doesn't use custom upstream.
// Also don't lookup the cache for responses with DNSSEC checking disabled
// since only validated responses are cached and those may be not the
// desired result for user specifying CD flag.
cacheWorks := p.cache != nil && d.CustomUpstreamConfig == nil && !d.Req.CheckingDisabled
cacheWorks := p.cacheWorks(dctx)
if cacheWorks {
if p.replyFromCache(d) {
if p.replyFromCache(dctx) {
// Complete the response from cache.
d.scrub()
dctx.scrub()
return nil
}
// On cache miss request for DNSSEC from the upstream to cache it
// afterwards.
addDO(d.Req)
addDO(dctx.Req)
}
var ok bool
ok, err = p.replyFromUpstream(d)
ok, err = p.replyFromUpstream(dctx)
// Don't cache the responses having CD flag, just like Dnsmasq does. It
// prevents the cache from being poisoned with unvalidated answers which may
// differ from validated ones.
//
// See https://github.com/imp/dnsmasq/blob/770bce967cfc9967273d0acfb3ea018fb7b17522/src/forward.c#L1169-L1172.
if cacheWorks && ok && !d.Res.CheckingDisabled {
if cacheWorks && ok && !dctx.Res.CheckingDisabled {
// Cache the response with DNSSEC RRs.
p.cacheResp(d)
p.cacheResp(dctx)
}
filterMsg(d.Res, d.Res, d.adBit, d.doBit, 0)
filterMsg(dctx.Res, dctx.Res, dctx.adBit, dctx.doBit, 0)
// Complete the response.
d.scrub()
dctx.scrub()
if p.ResponseHandler != nil {
p.ResponseHandler(d, err)
p.ResponseHandler(dctx, err)
}
return err
}
// cacheWorks returns true if the cache works for the given context. If not, it
// returns false and logs the reason why.
func (p *Proxy) cacheWorks(dctx *DNSContext) (ok bool) {
var reason string
switch {
case p.cache == nil:
reason = "disabled"
case dctx.CustomUpstreamConfig != nil:
// See https://github.com/AdguardTeam/dnsproxy/issues/169.
reason = "custom upstreams used"
case dctx.Req.CheckingDisabled:
reason = "dnssec check disabled"
default:
return true
}
log.Debug("dnsproxy: cache: %s; not caching", reason)
return false
}
// processECS adds EDNS Client Subnet data into the request from d.
func (dctx *DNSContext) processECS(cliIP net.IP) {
if ecs, _ := ecsFromMsg(dctx.Req); ecs != nil {
......
......@@ -33,7 +33,7 @@ func (p *Proxy) replyFromCache(d *DNSContext) (hit bool) {
d.Res = ci.m
d.CachedUpstreamAddr = ci.u
log.Debug(hitMsg)
log.Debug("dnsproxy: cache: %s", hitMsg)
if p.cache.optimistic && expired {
// Build a reduced clone of the current context to avoid data race.
......@@ -75,7 +75,7 @@ func (p *Proxy) cacheResp(d *DNSContext) {
// TODO(a.meshkov): The whole response MUST be dropped if ECS in it
// doesn't correspond.
if !ecs.IP.Mask(ecs.Mask).Equal(d.ReqECS.IP.Mask(d.ReqECS.Mask)) || ones != reqOnes {
log.Debug("invalid response: ecs %s mismatches requested %s", ecs, d.ReqECS)
log.Debug("dnsproxy: cache: bad response: ecs %s does not match %s", ecs, d.ReqECS)
return
}
......@@ -89,7 +89,8 @@ func (p *Proxy) cacheResp(d *DNSContext) {
ecs.Mask = net.CIDRMask(scope, bits)
ecs.IP = ecs.IP.Mask(ecs.Mask)
}
log.Debug("ecs option in response: %s", ecs)
log.Debug("dnsproxy: cache: ecs option in response: %s", ecs)
p.cache.setWithSubnet(d.Res, d.Upstream, ecs)
case d.ReqECS != nil:
......@@ -103,6 +104,9 @@ func (p *Proxy) cacheResp(d *DNSContext) {
// ClearCache clears the DNS cache of p.
func (p *Proxy) ClearCache() {
p.cache.clearItems()
p.cache.clearItemsWithSubnet()
if p.cache != nil {
p.cache.clearItems()
p.cache.clearItemsWithSubnet()
log.Debug("dnsproxy: cache: cleared")
}
}
......@@ -86,7 +86,6 @@ func (p *Proxy) quicPacketLoop(l quic.EarlyListener, requestGoroutinesSema semap
log.Info("Entering the DNS-over-QUIC listener loop on %s", l.Addr())
for {
conn, err := l.Accept(context.Background())
if err != nil {
if isQUICNonCrit(err) {
log.Tracef("quic connection closed or timed out: %s", err)
......@@ -345,15 +344,12 @@ func isQUICNonCrit(err error) (ok bool) {
return true
}
// This error is returned when we're trying to accept a new stream from a
// connection that had no activity for over than the keep-alive timeout.
// This is a common scenario, no need for extra logs.
var qIdleErr *quic.IdleTimeoutError
if errors.As(err, &qIdleErr) {
// This error is returned when we're trying to accept a new stream from
// a connection that had no activity for longer than the keep-alive
// timeout. This is a common scenario, no need for extra logs.
return true
}
return false
return errors.As(err, &qIdleErr)
}
// closeQUICConn quietly closes the QUIC connection.
......
......@@ -26,8 +26,12 @@ type Upstream interface {
// Exchange sends the DNS query m to this upstream and returns the response
// that has been received or an error if something went wrong.
Exchange(m *dns.Msg) (*dns.Msg, error)
// Address returns the address of the upstream DNS resolver.
Address() string
// Closer used to close the upstreams properly. Exchange shouldn't be
// called after calling Close.
io.Closer
}
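Since Upstream now embeds io.Closer, callers own the upstream's lifetime and must stop calling Exchange once it's closed. A minimal usage sketch, assuming dnsproxy's AddressToUpstream constructor and an Options struct with a Timeout field:

package main

import (
	"fmt"
	"time"

	"github.com/AdguardTeam/dnsproxy/upstream"
	"github.com/miekg/dns"
)

func main() {
	u, err := upstream.AddressToUpstream(
		"tls://dns.adguard.com",
		&upstream.Options{Timeout: 5 * time.Second},
	)
	if err != nil {
		panic(err)
	}
	// Exchange shouldn't be called after Close, so close the upstream
	// only once all queries are done.
	defer func() { _ = u.Close() }()

	req := (&dns.Msg{}).SetQuestion("example.org.", dns.TypeA)
	resp, err := u.Exchange(req)
	fmt.Println(resp, err)
}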
......
......@@ -32,12 +32,13 @@ const (
transportDefaultIdleConnTimeout = 5 * time.Minute
// dohMaxConnsPerHost controls the maximum number of connections for
// each host.
dohMaxConnsPerHost = 1
// each host. Note that setting it to 1 may cause issues with Go's http
// implementation, see https://github.com/AdguardTeam/dnsproxy/issues/278.
dohMaxConnsPerHost = 2
// dohMaxIdleConns controls the maximum number of connections being idle
// at the same time.
dohMaxIdleConns = 1
dohMaxIdleConns = 2
)
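These constants are applied to Go's net/http transport. A minimal sketch of how such limits are typically wired up; the field names are net/http's, but the wiring shown is illustrative rather than the proxy's actual code:

package main

import (
	"net/http"
	"time"
)

// newDoHTransport mirrors the limits above: with MaxConnsPerHost set
// to 1, a single slow or stuck request can starve every other
// in-flight query (see issue 278 linked above).
func newDoHTransport() (t *http.Transport) {
	return &http.Transport{
		MaxConnsPerHost: 2,
		MaxIdleConns:    2,
		IdleConnTimeout: 5 * time.Minute,
	}
}

func main() {
	c := &http.Client{Transport: newDoHTransport()}
	defer c.CloseIdleConnections()
	_ = c
}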
// dnsOverHTTPS is a struct that implements the Upstream interface for the
......@@ -592,7 +593,7 @@ func (p *dnsOverHTTPS) probeQUIC(addr string, tlsConfig *tls.Config, ch chan err
ch <- nil
elapsed := time.Now().Sub(startTime)
elapsed := time.Since(startTime)
log.Debug("elapsed on establishing a QUIC connection: %s", elapsed)
}
......@@ -612,7 +613,7 @@ func (p *dnsOverHTTPS) probeTLS(dialContext dialHandler, tlsConfig *tls.Config,
ch <- nil
elapsed := time.Now().Sub(startTime)
elapsed := time.Since(startTime)
log.Debug("elapsed on establishing a TLS connection: %s", elapsed)
}
......
package upstream
import (
"context"
"crypto/tls"
"fmt"
"io"
"net"
"net/url"
"os"
"runtime"
"sync"
"time"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns"
)
// dialTimeout is the global timeout for establishing a TLS connection.
// TODO(ameshkov): use bootstrap timeout instead.
const dialTimeout = 10 * time.Second
// dnsOverTLS is a struct that implements the Upstream interface for the
// DNS-over-TLS protocol.
type dnsOverTLS struct {
boot *bootstrapper
pool *TLSPool
poolMu sync.Mutex
// boot resolves the hostname upstream addresses.
boot *bootstrapper
// connsMu protects conns.
connsMu sync.Mutex
// conns stores the connections ready for reuse. Don't use [sync.Pool]
// here, since there is no need to deallocate these connections.
//
// TODO(e.burkov, ameshkov): Currently, connections are just stored in
// FILO order, which eventually makes most of them unusable due to
// timeouts. This leads to weak performance for all exchanges coming
// across such connections.
conns []net.Conn
}
// type check
var _ Upstream = (*dnsOverTLS)(nil)
// newDoT returns the DNS-over-TLS Upstream.
func newDoT(uu *url.URL, opts *Options) (u Upstream, err error) {
addPort(uu, defaultPortDoT)
func newDoT(u *url.URL, opts *Options) (ups Upstream, err error) {
addPort(u, defaultPortDoT)
var b *bootstrapper
b, err = urlToBoot(uu, opts)
boot, err := urlToBoot(u, opts)
if err != nil {
return nil, fmt.Errorf("creating tls bootstrapper: %w", err)
}
u = &dnsOverTLS{boot: b}
ups = &dnsOverTLS{
boot: boot,
}
runtime.SetFinalizer(u, (*dnsOverTLS).Close)
runtime.SetFinalizer(ups, (*dnsOverTLS).Close)
return u, nil
return ups, nil
}
// Address implements the Upstream interface for *dnsOverTLS.
// Address implements the [Upstream] interface for *dnsOverTLS.
func (p *dnsOverTLS) Address() string { return p.boot.URL.String() }
// Exchange implements the Upstream interface for *dnsOverTLS.
// Exchange implements the [Upstream] interface for *dnsOverTLS.
func (p *dnsOverTLS) Exchange(m *dns.Msg) (reply *dns.Msg, err error) {
pool := p.getPool()
poolConn, err := pool.Get()
conn, err := p.conn()
if err != nil {
return nil, fmt.Errorf("getting connection to %s: %w", p.Address(), err)
return nil, fmt.Errorf("getting conn to %s: %w", p.Address(), err)
}
logBegin(p.Address(), m)
reply, err = p.exchangeConn(poolConn, m)
logFinish(p.Address(), err)
reply, err = p.exchangeWithConn(conn, m)
if err != nil {
log.Tracef("The TLS connection is expired due to %s", err)
// The pooled connection might have been closed already, see
// https://github.com/AdguardTeam/dnsproxy/issues/3. The following
// connection from pool may also be malformed, so dial a new one.
// The pooled connection might have been closed already (see https://github.com/AdguardTeam/dnsproxy/issues/3)
// So we're trying to re-connect right away here.
// We are forcing creation of a new connection instead of calling Get() again
// as there's no guarantee that other pooled connections are intact
poolConn, err = pool.Create()
err = errors.WithDeferred(err, conn.Close())
log.Debug("dot upstream: bad conn from pool: %s", err)
// Retry.
conn, err = p.dial()
if err != nil {
return nil, fmt.Errorf("creating new connection to %s: %w", p.Address(), err)
return nil, fmt.Errorf("dialing conn to %s: %w", p.Address(), err)
}
// Retry sending the DNS request
logBegin(p.Address(), m)
reply, err = p.exchangeConn(poolConn, m)
logFinish(p.Address(), err)
reply, err = p.exchangeWithConn(conn, m)
if err != nil {
return reply, errors.WithDeferred(err, conn.Close())
}
}
if err == nil {
pool.Put(poolConn)
}
return reply, err
p.putBack(conn)
return reply, nil
}
// Close implements the Upstream interface for *dnsOverTLS.
// Close implements the [Upstream] interface for *dnsOverTLS.
func (p *dnsOverTLS) Close() (err error) {
p.poolMu.Lock()
defer p.poolMu.Unlock()
runtime.SetFinalizer(p, nil)
if p.pool == nil {
return nil
p.connsMu.Lock()
defer p.connsMu.Unlock()
var closeErrs []error
for _, conn := range p.conns {
closeErr := conn.Close()
if closeErr != nil && isCriticalTCP(closeErr) {
closeErrs = append(closeErrs, closeErr)
}
}
if len(closeErrs) > 0 {
return errors.List("closing tls conns", closeErrs...)
}
return p.pool.Close()
return nil
}
func (p *dnsOverTLS) exchangeConn(conn net.Conn, m *dns.Msg) (reply *dns.Msg, err error) {
// conn returns the first available connection from the pool if there is any, or
// dials a new one otherwise.
func (p *dnsOverTLS) conn() (conn net.Conn, err error) {
// Dial a new connection outside the lock, if needed.
defer func() {
if err == nil {
return
}
if cerr := conn.Close(); cerr != nil {
err = &errors.Pair{Returned: err, Deferred: cerr}
if conn == nil {
conn, err = p.dial()
}
}()
p.connsMu.Lock()
defer p.connsMu.Unlock()
l := len(p.conns)
if l == 0 {
return nil, nil
}
p.conns, conn = p.conns[:l-1], p.conns[l-1]
err = conn.SetDeadline(time.Now().Add(dialTimeout))
if err != nil {
log.Debug("dot upstream: setting deadline to conn from pool: %s", err)
// If the deadline can't be updated, it means that the connection was
// already closed.
return nil, nil
}
log.Debug("dot upstream: using existing conn %s", conn.RemoteAddr())
return conn, nil
}
func (p *dnsOverTLS) putBack(conn net.Conn) {
p.connsMu.Lock()
defer p.connsMu.Unlock()
p.conns = append(p.conns, conn)
}
// exchangeWithConn tries to exchange the query using conn.
func (p *dnsOverTLS) exchangeWithConn(conn net.Conn, m *dns.Msg) (reply *dns.Msg, err error) {
addr := p.Address()
logBegin(addr, m)
defer func() { logFinish(addr, err) }()
dnsConn := dns.Conn{Conn: conn}
err = dnsConn.WriteMsg(m)
if err != nil {
return nil, fmt.Errorf("sending request to %s: %w", p.Address(), err)
return nil, fmt.Errorf("sending request to %s: %w", addr, err)
}
reply, err = dnsConn.ReadMsg()
if err != nil {
return nil, fmt.Errorf("reading response from %s: %w", p.Address(), err)
return nil, fmt.Errorf("reading response from %s: %w", addr, err)
} else if reply.Id != m.Id {
err = dns.ErrId
return reply, dns.ErrId
}
return reply, err
}
func (p *dnsOverTLS) getPool() (pool *TLSPool) {
p.poolMu.Lock()
defer p.poolMu.Unlock()
// dial dials a new connection that may be stored in the pool.
func (p *dnsOverTLS) dial() (conn net.Conn, err error) {
tlsConfig, dialContext, err := p.boot.get()
if err != nil {
return nil, err
}
conn, err = tlsDial(dialContext, "tcp", tlsConfig)
if err != nil {
return nil, fmt.Errorf("connecting to %s: %w", tlsConfig.ServerName, err)
}
return conn, nil
}
// tlsDial is basically the same as tls.DialWithDialer, but we will call
// our own dialContext function to get a connection.
func tlsDial(dialContext dialHandler, network string, config *tls.Config) (*tls.Conn, error) {
// We're using the bootstrapped address instead of what's passed to the
// function.
rawConn, err := dialContext(context.Background(), network, "")
if err != nil {
return nil, err
}
if p.pool == nil {
p.pool = &TLSPool{boot: p.boot}
// We want the timeout to cover the whole process: TCP connection and
// TLS handshake. dialTimeout will be used as the connection deadline.
conn := tls.Client(rawConn, config)
err = conn.SetDeadline(time.Now().Add(dialTimeout))
if err != nil {
// Must not happen in normal circumstances.
panic(fmt.Errorf("dnsproxy: tls dial: setting deadline: %w", err))
}
return p.pool
err = conn.Handshake()
if err != nil {
return nil, errors.WithDeferred(err, conn.Close())
}
return conn, nil
}
// isCriticalTCP returns true if err isn't an expected error in terms of closing
// the TCP connection.
func isCriticalTCP(err error) (ok bool) {
var netErr net.Error
if errors.As(err, &netErr) && netErr.Timeout() {
return false
}
switch {
case
errors.Is(err, io.EOF),
errors.Is(err, net.ErrClosed),
errors.Is(err, os.ErrDeadlineExceeded),
isConnBroken(err):
return false
default:
return true
}
}
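The retry branches above combine the exchange error with the connection's close error via golibs' errors.WithDeferred. A small stand-alone example of the same pattern, assuming only the github.com/AdguardTeam/golibs/errors package:

package main

import (
	"fmt"
	"io"
	"os"

	"github.com/AdguardTeam/golibs/errors"
)

// readAll combines a read error with the file's close error, the same
// way Exchange above combines exchange and conn.Close errors.
func readAll(name string) (data []byte, err error) {
	f, err := os.Open(name)
	if err != nil {
		return nil, err
	}
	defer func() { err = errors.WithDeferred(err, f.Close()) }()

	return io.ReadAll(f)
}

func main() {
	data, err := readAll("/etc/hosts")
	fmt.Println(len(data), err)
}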
package upstream
import (
"context"
"crypto/tls"
"fmt"
"io"
"net"
"sync"
"time"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
)
// dialTimeout is the global timeout for establishing a TLS connection.
// TODO(ameshkov): use bootstrap timeout instead.
const dialTimeout = 10 * time.Second
// TLSPool is a connections pool for the DNS-over-TLS Upstream.
//
// Example:
//
// pool := TLSPool{Address: "tls://1.1.1.1:853"}
// netConn, err := pool.Get()
// if err != nil {panic(err)}
// c := dns.Conn{Conn: netConn}
// q := dns.Msg{}
// q.SetQuestion("google.com.", dns.TypeA)
// log.Println(q)
// err = c.WriteMsg(&q)
// if err != nil {panic(err)}
// r, err := c.ReadMsg()
// if err != nil {panic(err)}
// log.Println(r)
// pool.Put(c.Conn)
type TLSPool struct {
boot *bootstrapper
// conns is the list of connections available in the pool.
conns []net.Conn
connsMu sync.Mutex
}
// type check
var _ io.Closer = (*TLSPool)(nil)
// Get gets a connection from the pool (if there's one available) or creates
// a new TLS connection.
func (n *TLSPool) Get() (conn net.Conn, err error) {
// Get the connection from the slice inside the lock.
n.connsMu.Lock()
num := len(n.conns)
if num > 0 {
last := num - 1
conn = n.conns[last]
n.conns = n.conns[:last]
}
n.connsMu.Unlock()
// If we got connection from the slice, update deadline and return it.
if conn != nil {
err = conn.SetDeadline(time.Now().Add(dialTimeout))
// If deadLine can't be updated it means that connection was already closed
if err == nil {
log.Tracef(
"Returning existing connection to %s with updated deadLine",
conn.RemoteAddr(),
)
return conn, nil
}
}
return n.Create()
}
// Create creates a new connection for the pool (but not puts it there).
func (n *TLSPool) Create() (conn net.Conn, err error) {
tlsConfig, dialContext, err := n.boot.get()
if err != nil {
return nil, err
}
conn, err = tlsDial(dialContext, "tcp", tlsConfig)
if err != nil {
return nil, fmt.Errorf("connecting to %s: %w", tlsConfig.ServerName, err)
}
return conn, nil
}
// Put returns the connection to the pool.
func (n *TLSPool) Put(conn net.Conn) {
if conn == nil {
return
}
n.connsMu.Lock()
defer n.connsMu.Unlock()
n.conns = append(n.conns, conn)
}
// Close implements io.Closer for *TLSPool.
func (n *TLSPool) Close() (err error) {
n.connsMu.Lock()
defer n.connsMu.Unlock()
var closeErrs []error
for _, c := range n.conns {
cErr := c.Close()
if cErr != nil {
closeErrs = append(closeErrs, cErr)
}
}
if len(closeErrs) > 0 {
return errors.List("failed to close some connections", closeErrs...)
}
return nil
}
// tlsDial is basically the same as tls.DialWithDialer, but we will call our own
// dialContext function to get connection.
func tlsDial(dialContext dialHandler, network string, config *tls.Config) (*tls.Conn, error) {
// We're using bootstrapped address instead of what's passed
// to the function.
rawConn, err := dialContext(context.Background(), network, "")
if err != nil {
return nil, err
}
// We want the timeout to cover the whole process: TCP connection and
// TLS handshake dialTimeout will be used as connection deadLine.
conn := tls.Client(rawConn, config)
err = conn.SetDeadline(time.Now().Add(dialTimeout))
if err != nil {
// Must not happen in normal circumstances.
panic(fmt.Errorf("cannot set deadline: %w", err))
}
err = conn.Handshake()
if err != nil {
return nil, errors.WithDeferred(err, conn.Close())
}
return conn, nil
}