Skip to content
Commits on Source (1)
  • Eugene Burkov's avatar
    Pull request 253: 324 refactor upstream · 1e3a23f1
    Eugene Burkov authored
    Merge in GO/dnsproxy from 324-refactor-upstream to master
    
    Closes #324.
    
    Squashed commit of the following:
    
    commit cceed0b0d49988db3bb58f3c64ca26e5a7ef03b7
    Author: Eugene Burkov <E.Burkov@AdGuard.COM>
    Date:   Tue Apr 18 15:48:15 2023 +0300
    
        proxy: fix check
    
    commit 3b0a231436afc51fa1a5baa2dcb3613517348dc2
    Author: Eugene Burkov <E.Burkov@AdGuard.COM>
    Date:   Tue Apr 18 13:55:34 2023 +0300
    
        all: imp tests again
    
    commit eba1be66
    Merge: c4e73af9 5864a879
    Author: Eugene Burkov <E.Burkov@AdGuard.COM>
    Date:   Mon Apr 17 19:20:29 2023 +0300
    
        Merge branch 'master' into 324-refactor-upstream
    
    commit c4e73af9
    Author: Eugene Burkov <E.Burkov@AdGuard.COM>
    Date:   Mon Apr 17 15:00:58 2023 +0300
    
        upstream: fix tests
    
    commit 1bc8d368
    Author: Eugene Burkov <E.Burkov@AdGuard.COM>
    Date:   Mon Apr 17 14:46:16 2023 +0300
    
        all: imp code, docs
    
    commit c9ec31f7
    Author: Eugene Burkov <E.Burkov@AdGuard.COM>
    Date:   Fri Apr 14 19:48:33 2023 +0300
    
        upstream: imp logging, field alignment
    
    commit 418df4c9
    Author: Eugene Burkov <E.Burkov@AdGuard.COM>
    Date:   Fri Apr 14 19:08:15 2023 +0300
    
        upstream: imp code
    
    commit 92df4534
    Author: Eugene Burkov <E.Burkov@AdGuard.COM>
    Date:   Fri Apr 14 18:45:54 2023 +0300
    
        upstream: bootstrap plain dns
    
    commit 44c9e386
    Author: Eugene Burkov <E.Burkov@AdGuard.COM>
    Date:   Fri Apr 14 14:07:10 2023 +0300
    
        upstream: fix tests, imp docs
    
    commit 4c0959d3
    Author: Eugene Burkov <E.Burkov@AdGuard.COM>
    Date:   Fri Apr 14 12:30:00 2023 +0300
    
        upstream: imp code, docs
    
    commit 33819170
    Author: Eugene Burkov <E.Burkov@AdGuard.COM>
    Date:   Wed Apr 12 21:27:55 2023 +0300
    
        upstream: refactor doq
    
    commit 0bd11e10
    Author: Eugene Burkov <E.Burkov@AdGuard.COM>
    Date:   Wed Apr 12 21:01:30 2023 +0300
    
        upstream: refactor doh
    
    commit 29ce52f7
    Author: Eugene Burkov <E.Burkov@AdGuard.COM>
    Date:   Tue Apr 11 19:34:37 2023 +0300
    
        all: refactor dot
    
    commit 6c9d87e9
    Author: Eugene Burkov <E.Burkov@AdGuard.COM>
    Date:   Tue Apr 11 01:34:41 2023 +0300
    
        upstream: refactor dnscrypt, imp code
    
    commit f1117298
    Author: Eugene Burkov <E.Burkov@AdGuard.COM>
    Date:   Tue Apr 11 19:53:16 2023 +0300
    
        all: imp code, docs
    
    commit 726d9377
    Author: Eugene Burkov <E.Burkov@AdGuard.COM>
    Date:   Mon Apr 10 17:06:35 2023 +0300
    
        all: imp code, add tests
    
    commit 550f1ab5
    Author: Eugene Burkov <E.Burkov@AdGuard.COM>
    Date:   Fri Apr 7 19:42:20 2023 +0300
    
        all: imp docs
    
    commit 66986633
    Author: Eugene Burkov <E.Burkov@AdGuard.COM>
    Date:   Fri Apr 7 19:28:54 2023 +0300
    
        all: introduce bootstrap pkg
    1e3a23f1
......@@ -23,8 +23,6 @@ type DialHandler func(ctx context.Context, network, addr string) (conn net.Conn,
// ResolveDialContext returns a DialHandler that uses addresses resolved from
// u using resolvers. u must not be nil.
//
// TODO(e.burkov): Use in the [upstream] package.
func ResolveDialContext(
u *url.URL,
timeout time.Duration,
......@@ -68,9 +66,7 @@ func ResolveDialContext(
}
// NewDialContext returns a DialHandler that dials addrs and returns the first
// successful connection.
//
// TODO(e.burkov): Use in the [upstream] package.
// successful connection. At least a single addr should be specified.
func NewDialContext(timeout time.Duration, addrs ...string) (h DialHandler) {
dialer := &net.Dialer{
Timeout: timeout,
......@@ -80,7 +76,7 @@ func NewDialContext(timeout time.Duration, addrs ...string) (h DialHandler) {
if l == 0 {
log.Debug("bootstrap: no addresses to dial")
return func(ctx context.Context, _, _ string) (conn net.Conn, err error) {
return func(_ context.Context, _, _ string) (conn net.Conn, err error) {
return nil, errors.Error("no addresses")
}
}
......
......@@ -48,7 +48,7 @@ func LookupParallel(
}
var errs []error
for n := 0; n < resolversNum; n++ {
for range resolvers {
result := <-ch
if result.err == nil {
return result.addrs, nil
......
......@@ -14,11 +14,12 @@ func TestLookupIPAddr(t *testing.T) {
p := Proxy{}
upstreams := make([]upstream.Upstream, 0)
// Use AdGuard DNS here
opts := &upstream.Options{Timeout: defaultTimeout}
dnsUpstream, err := upstream.AddressToUpstream("94.140.14.14", opts)
if err != nil {
t.Fatalf("cannot prepare the upstream: %s", err)
}
dnsUpstream, err := upstream.AddressToUpstream("94.140.14.14", &upstream.Options{
Timeout: defaultTimeout,
})
require.NoError(t, err)
p.UpstreamConfig = &UpstreamConfig{}
p.UpstreamConfig.Upstreams = append(upstreams, dnsUpstream)
......@@ -28,22 +29,13 @@ func TestLookupIPAddr(t *testing.T) {
// Now let's try doing some lookups
addrs, err := p.LookupIPAddr("dns.google")
assert.Nil(t, err)
assert.True(t, len(addrs) == 2 || len(addrs) == 4)
assertContainsIP(t, addrs, "8.8.8.8")
assertContainsIP(t, addrs, "8.8.4.4")
if len(addrs) == 4 {
assertContainsIP(t, addrs, "2001:4860:4860::8888")
assertContainsIP(t, addrs, "2001:4860:4860::8844")
}
}
require.NoError(t, err)
require.NotEmpty(t, addrs)
func assertContainsIP(t *testing.T, addrs []net.IPAddr, ip string) {
for _, addr := range addrs {
if addr.String() == ip {
return
}
assert.Contains(t, addrs, net.IPAddr{IP: net.IP{8, 8, 8, 8}})
assert.Contains(t, addrs, net.IPAddr{IP: net.IP{8, 8, 4, 4}})
if len(addrs) > 2 {
assert.Contains(t, addrs, net.IPAddr{IP: net.ParseIP("2001:4860:4860::8888")})
assert.Contains(t, addrs, net.IPAddr{IP: net.ParseIP("2001:4860:4860::8844")})
}
t.Fatalf("%s not found in %v", ip, addrs)
}
package upstream
import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"net"
"net/url"
"sync"
"time"
proxynetutil "github.com/AdguardTeam/dnsproxy/internal/netutil"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
)
// NextProtoDQ is the ALPN token for DoQ. During the connection establishment,
// DNS/QUIC support is indicated by selecting the ALPN token "doq" in the
// crypto handshake.
// The current draft version is https://datatracker.ietf.org/doc/rfc9250/.
const NextProtoDQ = "doq"
// compatProtoDQ is a list of ALPN tokens used by a QUIC connection.
// NextProtoDQ is the latest draft version supported by dnsproxy, but it also
// includes previous drafts.
var compatProtoDQ = []string{NextProtoDQ, "doq-i00", "dq", "doq-i02"}
// RootCAs is the CertPool that must be used by all upstreams. Redefining
// RootCAs makes sense on iOS to overcome the 15MB memory limit of the
// NEPacketTunnelProvider.
// TODO(ameshkov): remove this and replace with an upstream option.
var RootCAs *x509.CertPool
// CipherSuites is a custom list of TLSv1.2 ciphers.
var CipherSuites []uint16
// TODO(ameshkov): refactor bootstrapper, it's overcomplicated and hard to
// understand what it does.
type bootstrapper struct {
// URL is the upstream server address.
URL *url.URL
// resolvers is a list of *net.Resolver to use to resolve the upstream
// hostname, if necessary.
resolvers []Resolver
// dialContext is the dial function for creating unencrypted TCP
// connections.
dialContext dialHandler
// resolvedConfig is a *tls.Config that is used for encrypted DNS protocols.
resolvedConfig *tls.Config
// sessionsCache is necessary to achieve TLS session resumption. We create
// once when the bootstrapper is created and re-use every time when we need
// to create a new tls.Config.
sessionsCache tls.ClientSessionCache
// guard protects dialContext and resolvedConfig.
guard sync.RWMutex
// options is the Options that were passed to the AddressToUpstream
// function. It configures different upstream properties: callbacks for
// checking certificates, timeout, etc.
options *Options
}
// newBootstrapperResolved creates a new bootstrapper that already contains
// resolved config. This can be done only in the case when we already know the
// resolver IP address passed via options.
func newBootstrapperResolved(upsURL *url.URL, options *Options) (*bootstrapper, error) {
// get a host without port
host, port, err := net.SplitHostPort(upsURL.Host)
if err != nil {
return nil, fmt.Errorf("bootstrapper requires port in address %s", upsURL)
}
var resolverAddresses []string
for _, ip := range options.ServerIPAddrs {
addr := net.JoinHostPort(ip.String(), port)
resolverAddresses = append(resolverAddresses, addr)
}
b := &bootstrapper{
URL: upsURL,
options: options,
// Use the default capacity for the LRU cache. It may be useful to
// store several caches since the user may be routed to different
// servers in case there's load balancing on the server-side.
sessionsCache: tls.NewLRUClientSessionCache(0),
}
b.dialContext = b.createDialContext(resolverAddresses)
b.resolvedConfig = b.createTLSConfig(host)
return b, nil
}
// newBootstrapper initializes a new bootstrapper instance. u is the original
// resolver address string (i.e. tls://one.one.one.one:853), options is the
// upstream configuration options.
func newBootstrapper(u *url.URL, options *Options) (b *bootstrapper, err error) {
resolvers := []Resolver{}
if len(options.Bootstrap) != 0 {
// Create a list of resolvers for parallel lookup.
for _, boot := range options.Bootstrap {
var r Resolver
r, err = NewResolver(boot, options)
if err != nil {
return nil, err
}
resolvers = append(resolvers, r)
}
} else {
r, _ := NewResolver("", options) // NewResolver("") always succeeds
// nil resolver if the default one
resolvers = append(resolvers, r)
}
return &bootstrapper{
URL: u,
resolvers: resolvers,
options: options,
// Use the default capacity for the LRU cache. It may be useful to
// store several caches since the user may be routed to different
// servers in case there's load balancing on the server-side.
sessionsCache: tls.NewLRUClientSessionCache(0),
}, nil
}
// dialHandler describes the dial function for creating unencrypted network
// connections to the upstream server. Internally, this function will use the
// supplied bootstrap DNS servers to resolve the upstream's IP address and only
// then it will actually establish a connection.
type dialHandler func(ctx context.Context, network, addr string) (net.Conn, error)
// get is the main function of bootstrapper that does two crucial things.
// First, it creates an instance of a dialHandler function that should be used
// by the Upstream to establish a connection to the upstream DNS server. This
// dialHandler in a lazy manner resolves the DNS server IP address using the
// bootstrap DNS servers supplied to this bootstrapper instance. It will also
// create an instance of *tls.Config that should be used for establishing an
// encrypted connection for DoH/DoT/DoQ.
func (n *bootstrapper) get() (*tls.Config, dialHandler, error) {
n.guard.RLock()
if n.dialContext != nil && n.resolvedConfig != nil { // fast path
tlsConfig, dialContext := n.resolvedConfig, n.dialContext
n.guard.RUnlock()
return tlsConfig.Clone(), dialContext, nil
}
//
// Slow path: resolve the IP address of the n.address's host
//
// get a host without port
u := n.URL
host, port, err := net.SplitHostPort(u.Host)
if err != nil {
n.guard.RUnlock()
return nil, nil, fmt.Errorf("bootstrapper requires port in address %s", u)
}
// if n.address's host is an IP, just use it right away.
ip := net.ParseIP(host)
if ip != nil {
n.guard.RUnlock()
resolverAddress := net.JoinHostPort(host, port)
// Upgrade lock to protect n.resolvedConfig.
// TODO(ameshkov): rework, that's not how it should be done.
n.guard.Lock()
defer n.guard.Unlock()
n.dialContext = n.createDialContext([]string{resolverAddress})
n.resolvedConfig = n.createTLSConfig(host)
return n.resolvedConfig, n.dialContext, nil
}
// Don't lock anymore (we can launch multiple lookup requests at a time)
// Otherwise, it might mess with the timeout specified for the Upstream
// See here: https://github.com/AdguardTeam/dnsproxy/issues/15
n.guard.RUnlock()
//
// if it's a hostname
//
var ctx context.Context
if n.options.Timeout > 0 {
var cancel func()
ctx, cancel = context.WithTimeout(context.Background(), n.options.Timeout)
defer cancel()
} else {
ctx = context.Background()
}
addrs, err := LookupParallel(ctx, n.resolvers, host)
if err != nil {
return nil, nil, fmt.Errorf("lookup %s: %w", host, err)
}
proxynetutil.SortNetIPAddrs(addrs, n.options.PreferIPv6)
resolved := make([]string, 0, len(addrs))
for _, addr := range addrs {
if addr.IsValid() {
resolved = append(resolved, net.JoinHostPort(addr.String(), port))
}
}
if len(resolved) == 0 {
// couldn't find any suitable IP address
return nil, nil, fmt.Errorf("couldn't find any suitable IP address for host %s", host)
}
n.guard.Lock()
defer n.guard.Unlock()
n.dialContext = n.createDialContext(resolved)
n.resolvedConfig = n.createTLSConfig(host)
return n.resolvedConfig, n.dialContext, nil
}
// createTLSConfig creates a client TLS config that will be used to establish
// an encrypted connection for DoH/DoT/DoQ.
func (n *bootstrapper) createTLSConfig(host string) *tls.Config {
tlsConfig := &tls.Config{
ServerName: host,
RootCAs: RootCAs,
CipherSuites: CipherSuites,
ClientSessionCache: n.sessionsCache,
MinVersion: tls.VersionTLS12,
InsecureSkipVerify: n.options.InsecureSkipVerify,
VerifyPeerCertificate: n.options.VerifyServerCertificate,
VerifyConnection: n.options.VerifyConnection,
}
// Depending on the URL scheme, we choose what ALPN will be advertised by
// the client.
switch n.URL.Scheme {
case "tls":
// Don't use the ALPN since some servers currently do not accept it.
//
// See https://github.com/ameshkov/dnslookup/issues/19.
case "https":
httpVersions := n.options.HTTPVersions
if httpVersions == nil {
httpVersions = DefaultHTTPVersions
}
var nextProtos []string
for _, v := range httpVersions {
nextProtos = append(nextProtos, string(v))
}
tlsConfig.NextProtos = nextProtos
case "quic":
tlsConfig.NextProtos = compatProtoDQ
}
return tlsConfig
}
// createDialContext returns a dialHandler function that tries to establish the
// connection to each of the provided addresses one by one.
func (n *bootstrapper) createDialContext(addresses []string) (dialContext dialHandler) {
dialer := &net.Dialer{
Timeout: n.options.Timeout,
}
return func(ctx context.Context, network, _ string) (net.Conn, error) {
if len(addresses) == 0 {
return nil, errors.Error("no addresses")
}
var errs []error
// Return first connection without error
// Note that we're using bootstrapped resolverAddress instead of what's passed to the function
for _, resolverAddress := range addresses {
log.Tracef("Dialing to %s", resolverAddress)
start := time.Now()
conn, err := dialer.DialContext(ctx, network, resolverAddress)
elapsed := time.Since(start)
if err == nil {
log.Tracef(
"dialer has successfully initialized connection to %s in %s",
resolverAddress,
elapsed,
)
return conn, nil
}
errs = append(errs, err)
log.Tracef(
"dialer failed to initialize connection to %s, in %s, cause: %s",
resolverAddress,
elapsed,
err,
)
}
return nil, errors.List("all dialers failed", errs...)
}
}
// newContext creates a new context with deadline if needed. If no timeout is
// set cancel would be a simple noop.
func (n *bootstrapper) newContext() (ctx context.Context, cancel context.CancelFunc) {
ctx = context.Background()
cancel = func() {}
if n.options.Timeout > 0 {
ctx, cancel = context.WithDeadline(ctx, time.Now().Add(n.options.Timeout))
}
return ctx, cancel
}
package upstream
import (
"context"
"testing"
"time"
)
// See the details here: https://github.com/AdguardTeam/dnsproxy/issues/18
func TestDialContext(t *testing.T) {
resolved := []struct {
addresses []string
host string
}{{
addresses: []string{"216.239.32.59:443"},
host: "dns.google.com",
}, {
addresses: []string{"94.140.14.14:855", "94.140.14.14:853"},
host: "dns.adguard.com",
}, {
addresses: []string{"1.1.1.1:5555", "1.1.1.1:853", "8.8.8.8:85"},
host: "1dot1dot1dot1.cloudflare-dns.com",
}}
b := bootstrapper{options: &Options{Timeout: 2 * time.Second}}
for _, test := range resolved {
dialContext := b.createDialContext(test.addresses)
_, err := dialContext(context.TODO(), "tcp", "")
if err != nil {
t.Fatalf("Couldn't dial to %s: %s", test.host, err)
}
}
}
......@@ -128,7 +128,7 @@ func ExchangeAll(ups []Upstream, req *dns.Msg) (res []ExchangeAllResult, err err
}
// exchangeResult represents the result of DNS exchange.
type exchangeResult struct {
type exchangeResult = struct {
// upstream is the Upstream that successfully resolved the request.
upstream Upstream
......
......@@ -8,11 +8,14 @@ import (
"fmt"
"io"
"net"
"net/netip"
"net/url"
"strconv"
"strings"
"sync/atomic"
"time"
"github.com/AdguardTeam/dnsproxy/internal/bootstrap"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/ameshkov/dnscrypt/v2"
......@@ -130,123 +133,138 @@ const (
defaultPortDoQ = 853
)
// AddressToUpstream converts addr to an Upstream instance:
// RootCAs is the CertPool that must be used by all upstreams. Redefining
// RootCAs makes sense on iOS to overcome the 15MB memory limit of the
// NEPacketTunnelProvider.
//
// - 8.8.8.8:53 or udp://dns.adguard.com for plain DNS;
// - tcp://8.8.8.8:53 for plain DNS-over-TCP;
// - tls://1.1.1.1 for DNS-over-TLS;
// - https://dns.adguard.com/dns-query for DNS-over-HTTPS;
// TODO(ameshkov): remove this and replace with an upstream option.
var RootCAs *x509.CertPool
// CipherSuites is a custom list of TLSv1.2 ciphers.
var CipherSuites []uint16
// AddressToUpstream converts addr to an Upstream using the specified options.
// addr can be either a URL, or a plain address, either a domain name or an IP.
//
// - udp://5.3.5.3:53 or 5.3.5.3:53 for plain DNS using IP address;
// - udp://name.server:53 or name.server:53 for plain DNS using domain name;
// - tcp://5.3.5.3:53 for plain DNS-over-TCP using IP address;
// - tcp://name.server:53 for plain DNS-over-TCP using domain name;
// - tls://5.3.5.3:853 for DNS-over-TLS using IP address;
// - tls://name.server:853 for DNS-over-TLS using domain name;
// - https://5.3.5.3:443/dns-query for DNS-over-HTTPS using IP address;
// - https://name.server:443/dns-query for DNS-over-HTTPS using domain name;
// - quic://5.3.5.3:853 for DNS-over-QUIC using IP address;
// - quic://name.server:853 for DNS-over-QUIC using domain name;
// - h3://dns.google for DNS-over-HTTPS that only works with HTTP/3;
// - sdns://... for DNS stamp, see https://dnscrypt.info/stamps-specifications.
//
// opts are applied to the u. nil is a valid value for opts.
// If addr doesn't have port specified, the default port of the appropriate
// protocol will be used.
//
// opts are applied to the u and shouldn't be modified afterwards, nil value is
// valid.
//
// TODO(e.burkov): Clone opts?
func AddressToUpstream(addr string, opts *Options) (u Upstream, err error) {
if opts == nil {
opts = &Options{}
}
var uu *url.URL
if strings.Contains(addr, "://") {
var uu *url.URL
// Parse as URL.
uu, err = url.Parse(addr)
if err != nil {
return nil, fmt.Errorf("failed to parse %s: %w", addr, err)
}
} else {
// Probably, plain UDP upstream defined by address or address:port.
_, port, splitErr := net.SplitHostPort(addr)
if splitErr == nil {
// Validate port.
_, err = strconv.ParseUint(port, 10, 16)
if err != nil {
return nil, fmt.Errorf("invalid address %s: %w", addr, err)
}
}
return urlToUpstream(uu, opts)
}
var host, port string
host, port, err = net.SplitHostPort(addr)
if err != nil {
return &plainDNS{address: net.JoinHostPort(addr, "53"), timeout: opts.Timeout}, nil
}
// Validate port.
portN, err := strconv.ParseUint(port, 10, 16)
if err != nil {
return nil, fmt.Errorf("invalid address: %s", addr)
}
return &plainDNS{address: netutil.JoinHostPort(host, int(portN)), timeout: opts.Timeout}, nil
}
// urlToBoot creates a bootstrapper with the specified options.
func urlToBoot(u *url.URL, opts *Options) (b *bootstrapper, err error) {
if len(opts.ServerIPAddrs) == 0 {
return newBootstrapper(u, opts)
uu = &url.URL{
Scheme: "udp",
Host: addr,
}
}
return newBootstrapperResolved(u, opts)
return urlToUpstream(uu, opts)
}
// urlToUpstream converts uu to an Upstream using opts.
func urlToUpstream(uu *url.URL, opts *Options) (u Upstream, err error) {
switch sch := uu.Scheme; sch {
case "sdns":
return stampToUpstream(uu, opts)
return parseStamp(uu, opts)
case "udp", "tcp":
return newPlain(uu, opts.Timeout, sch == "tcp"), nil
return newPlain(uu, opts)
case "quic":
return newDoQ(uu, opts)
case "tls":
return newDoT(uu, opts)
case "h3":
opts.HTTPVersions = []HTTPVersion{HTTPVersion3}
uu.Scheme = "https"
return newDoH(uu, opts)
case "https":
case "h3", "https":
return newDoH(uu, opts)
default:
return nil, fmt.Errorf("unsupported url scheme: %s", sch)
}
}
// stampToUpstream converts a DNS stamp to an Upstream
// options -- Upstream customization options
func stampToUpstream(upsURL *url.URL, opts *Options) (Upstream, error) {
// parseStamp converts a DNS stamp to an Upstream.
func parseStamp(upsURL *url.URL, opts *Options) (u Upstream, err error) {
stamp, err := dnsstamps.NewServerStampFromString(upsURL.String())
if err != nil {
return nil, fmt.Errorf("failed to parse %s: %w", upsURL, err)
}
// TODO(e.burkov): Port?
if stamp.ServerAddrStr != "" {
host, _, err := net.SplitHostPort(stamp.ServerAddrStr)
host, _, err := netutil.SplitHostPort(stamp.ServerAddrStr)
if err != nil {
host = stamp.ServerAddrStr
}
// Parse and add to options
// Parse and add to options.
ip := net.ParseIP(host)
if ip == nil {
return nil, fmt.Errorf("invalid server address in the stamp: %s", stamp.ServerAddrStr)
return nil, fmt.Errorf("invalid server stamp address %s", stamp.ServerAddrStr)
}
// TODO(e.burkov): Append?
opts.ServerIPAddrs = []net.IP{ip}
}
switch stamp.Proto {
case dnsstamps.StampProtoTypePlain:
return &plainDNS{address: stamp.ServerAddrStr, timeout: opts.Timeout}, nil
return newPlain(&url.URL{Scheme: "udp", Host: stamp.ServerAddrStr}, opts)
case dnsstamps.StampProtoTypeDNSCrypt:
b, err := newBootstrapper(upsURL, opts)
if err != nil {
return nil, fmt.Errorf("bootstrap server parse: %s", err)
}
return &dnsCrypt{boot: b}, nil
return newDNSCrypt(upsURL, opts)
case dnsstamps.StampProtoTypeDoH:
return AddressToUpstream(fmt.Sprintf("https://%s%s", stamp.ProviderName, stamp.Path), opts)
return newDoH(&url.URL{Scheme: "https", Host: stamp.ProviderName, Path: stamp.Path}, opts)
case dnsstamps.StampProtoTypeDoQ:
return AddressToUpstream(fmt.Sprintf("quic://%s%s", stamp.ProviderName, stamp.Path), opts)
return newDoQ(&url.URL{Scheme: "quic", Host: stamp.ProviderName, Path: stamp.Path}, opts)
case dnsstamps.StampProtoTypeTLS:
return AddressToUpstream(fmt.Sprintf("tls://%s", stamp.ProviderName), opts)
return newDoT(&url.URL{Scheme: "tls", Host: stamp.ProviderName}, opts)
default:
return nil, fmt.Errorf("unsupported stamp protocol %s", &stamp.Proto)
}
return nil, fmt.Errorf("unsupported protocol %v in %s", stamp.Proto, upsURL)
}
// addPort appends port to u if it's absent.
func addPort(u *url.URL, port int) {
if u != nil && u.Port() == "" {
u.Host = netutil.JoinHostPort(strings.Trim(u.Host, "[]"), port)
if u != nil {
_, _, err := net.SplitHostPort(u.Host)
if err != nil {
u.Host = netutil.JoinHostPort(u.Host, port)
return
}
}
}
......@@ -269,3 +287,78 @@ func logFinish(upstreamAddress string, err error) {
}
log.Debug("%s: response: %s", upstreamAddress, status)
}
// DialerInitializer returns the handler that it creates. All the subsequent
// calls to it, except the first one, will return the same handler so that
// resolving will be performed only once.
type DialerInitializer func() (handler bootstrap.DialHandler, err error)
// newDialerInitializer creates an initializer of the dialer that will dial the
// addresses resolved from u using opts.
func newDialerInitializer(u *url.URL, opts *Options) (di DialerInitializer, err error) {
host, port, err := netutil.SplitHostPort(u.Host)
if err != nil {
return nil, fmt.Errorf("invalid address: %s: %w", u.Host, err)
}
if addrsLen := len(opts.ServerIPAddrs); addrsLen > 0 {
// Don't resolve the addresses of the server since those from the
// options should be used.
addrs := make([]string, 0, addrsLen)
for _, addr := range opts.ServerIPAddrs {
addrs = append(addrs, netutil.JoinHostPort(addr.String(), port))
}
handler := bootstrap.NewDialContext(opts.Timeout, addrs...)
return func() (bootstrap.DialHandler, error) { return handler, nil }, nil
} else if _, err = netip.ParseAddr(host); err == nil {
// Don't resolve the address of the server since it's already an IP.
handler := bootstrap.NewDialContext(opts.Timeout, u.Host)
return func() (bootstrap.DialHandler, error) { return handler, nil }, nil
}
bootstraps := opts.Bootstrap
if len(opts.Bootstrap) == 0 {
// Use the default resolver for bootstrapping.
bootstraps = []string{""}
}
// Prepare resolvers for bootstrapping.
resolvers := make([]Resolver, 0, len(bootstraps))
for _, boot := range bootstraps {
var r Resolver
r, err = NewResolver(boot, opts)
if err != nil {
// Don't wrap the error since it's informative enough as is.
return nil, err
}
resolvers = append(resolvers, r)
}
var dialHandler atomic.Value
di = func() (h bootstrap.DialHandler, resErr error) {
// Check if the dial handler has already been created.
h, ok := dialHandler.Load().(bootstrap.DialHandler)
if ok {
return h, nil
}
// TODO(e.burkov): It may appear that several exchanges will try to
// resolve the upstream hostname at the same time. Currently, the last
// successful value will be stored in dialHandler, but ideally we should
// resolve only once.
h, resolveErr := bootstrap.ResolveDialContext(u, opts.Timeout, resolvers, opts.PreferIPv6)
if resolveErr != nil {
return nil, fmt.Errorf("creating dial handler: %w", resolveErr)
}
dialHandler.Store(h)
return h, nil
}
return di, nil
}
......@@ -3,6 +3,7 @@ package upstream
import (
"fmt"
"io"
"net/url"
"os"
"sync"
"time"
......@@ -13,55 +14,83 @@ import (
"github.com/miekg/dns"
)
// dnsCrypt is a struct that implements the Upstream interface for the DNSCrypt
// protocol.
// dnsCrypt implements the [Upstream] interface for the DNSCrypt protocol.
type dnsCrypt struct {
boot *bootstrapper
client *dnscrypt.Client // DNSCrypt client properties
serverInfo *dnscrypt.ResolverInfo // DNSCrypt resolver info
// mu protects client and serverInfo.
mu *sync.RWMutex
sync.RWMutex // protects DNSCrypt client
// client stores the DNSCrypt client properties.
client *dnscrypt.Client
// serverInfo stores the DNSCrypt server properties.
serverInfo *dnscrypt.ResolverInfo
// addr is the DNSCrypt server URL.
addr *url.URL
// verifyCert is a callback that verifies the resolver's certificate.
verifyCert func(cert *dnscrypt.Cert) (err error)
// timeout is the timeout for the DNS requests.
timeout time.Duration
}
// newDNSCrypt returns a new DNSCrypt Upstream.
func newDNSCrypt(addr *url.URL, opts *Options) (u *dnsCrypt, err error) {
return &dnsCrypt{
mu: &sync.RWMutex{},
addr: addr,
verifyCert: opts.VerifyDNSCryptCertificate,
timeout: opts.Timeout,
}, nil
}
// type check
var _ Upstream = (*dnsCrypt)(nil)
// Address implements the Upstream interface for *dnsCrypt.
func (p *dnsCrypt) Address() string { return p.boot.URL.String() }
// Exchange implements the Upstream interface for *dnsCrypt.
func (p *dnsCrypt) Exchange(m *dns.Msg) (*dns.Msg, error) {
reply, err := p.exchangeDNSCrypt(m)
// Address implements the [Upstream] interface for *dnsCrypt.
func (p *dnsCrypt) Address() string { return p.addr.String() }
// Exchange implements the [Upstream] interface for *dnsCrypt.
func (p *dnsCrypt) Exchange(m *dns.Msg) (resp *dns.Msg, err error) {
resp, err = p.exchangeDNSCrypt(m)
if errors.Is(err, os.ErrDeadlineExceeded) || errors.Is(err, io.EOF) {
// If request times out, it is possible that the server configuration has been changed.
// It is safe to assume that the key was rotated (for instance, as it is described here: https://dnscrypt.pl/2017/02/26/how-key-rotation-is-automated/).
// We should re-fetch the server certificate info so that the new requests were not failing.
p.Lock()
p.client = nil
p.serverInfo = nil
p.Unlock()
// Retry the request one more time
// If request times out, it is possible that the server configuration
// has been changed. It is safe to assume that the key was rotated, see
// https://dnscrypt.pl/2017/02/26/how-key-rotation-is-automated.
// Re-fetch the server certificate info for new requests to not fail.
_, _, err = p.resetClient()
if err != nil {
return nil, err
}
return p.exchangeDNSCrypt(m)
}
return reply, err
return resp, err
}
// Close implements the Upstream interface for *dnsCrypt.
// Close implements the [Upstream] interface for *dnsCrypt.
func (p *dnsCrypt) Close() (err error) {
// Nothing to close here.
return nil
}
// exchangeDNSCrypt attempts to send the DNS query and returns the response
func (p *dnsCrypt) exchangeDNSCrypt(m *dns.Msg) (reply *dns.Msg, err error) {
p.RLock()
client := p.client
resolverInfo := p.serverInfo
p.RUnlock()
// exchangeDNSCrypt attempts to send the DNS query and returns the response.
func (p *dnsCrypt) exchangeDNSCrypt(m *dns.Msg) (resp *dns.Msg, err error) {
var client *dnscrypt.Client
var resolverInfo *dnscrypt.ResolverInfo
func() {
p.mu.RLock()
defer p.mu.RUnlock()
client = p.client
resolverInfo = p.serverInfo
}()
// Check the client and server info are set and the certificate is not
// expired, since any of these cases require a client reset.
//
// TODO(ameshkov): Consider using [time.Time] for [dnscrypt.Cert.NotAfter].
now := uint32(time.Now().Unix())
if client == nil || resolverInfo == nil || resolverInfo.ResolverCert.NotAfter < now {
client, resolverInfo, err = p.resetClient()
......@@ -71,39 +100,47 @@ func (p *dnsCrypt) exchangeDNSCrypt(m *dns.Msg) (reply *dns.Msg, err error) {
}
}
reply, err = client.Exchange(m, resolverInfo)
if reply != nil && reply.Truncated {
log.Tracef("truncated message received, retrying over tcp, question: %v", m.Question[0])
tcpClient := dnscrypt.Client{Timeout: p.boot.options.Timeout, Net: "tcp"}
reply, err = tcpClient.Exchange(m, resolverInfo)
resp, err = client.Exchange(m, resolverInfo)
if resp != nil && resp.Truncated {
q := &m.Question[0]
log.Debug("dnscrypt %s: received truncated, falling back to tcp with %s", p.addr, q)
tcpClient := dnscrypt.Client{Timeout: p.timeout, Net: "tcp"}
resp, err = tcpClient.Exchange(m, resolverInfo)
}
if err == nil && reply != nil && reply.Id != m.Id {
if err == nil && resp != nil && resp.Id != m.Id {
err = dns.ErrId
}
return reply, err
return resp, err
}
// resetClient renews the DNSCrypt client and server properties and also sets
// those to nil on fail.
func (p *dnsCrypt) resetClient() (client *dnscrypt.Client, ri *dnscrypt.ResolverInfo, err error) {
p.Lock()
defer p.Unlock()
addr := p.Address()
// Using "udp" for DNSCrypt upstreams by default.
client = &dnscrypt.Client{Timeout: p.boot.options.Timeout, Net: "udp"}
ri, err = client.Dial(p.Address())
// Use UDP for DNSCrypt upstreams by default.
client = &dnscrypt.Client{Timeout: p.timeout, Net: "udp"}
ri, err = client.Dial(addr)
if err != nil {
return nil, nil, fmt.Errorf("fetching certificate info from %s: %w", p.Address(), err)
}
if p.boot.options.VerifyDNSCryptCertificate != nil {
err = p.boot.options.VerifyDNSCryptCertificate(ri.ResolverCert)
// Trigger client and server info renewal on the next request.
client, ri = nil, nil
err = fmt.Errorf("fetching certificate info from %s: %w", addr, err)
} else if p.verifyCert != nil {
err = p.verifyCert(ri.ResolverCert)
if err != nil {
return nil, nil, fmt.Errorf("verifying certificate info from %s: %w", p.Address(), err)
// Trigger client and server info renewal on the next request.
client, ri = nil, nil
err = fmt.Errorf("verifying certificate info from %s: %w", addr, err)
}
}
p.mu.Lock()
defer p.mu.Unlock()
p.client = client
p.serverInfo = ri
return client, ri, nil
return p.client, p.serverInfo, nil
}
......@@ -13,6 +13,7 @@ import (
"sync"
"time"
"github.com/AdguardTeam/dnsproxy/internal/bootstrap"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns"
......@@ -44,58 +45,99 @@ const (
// dnsOverHTTPS is a struct that implements the Upstream interface for the
// DNS-over-HTTPS protocol.
type dnsOverHTTPS struct {
boot *bootstrapper
// getDialer either returns an initialized dial handler or creates a new
// one.
getDialer DialerInitializer
// addr is the DNS-over-HTTPS server URL.
addr *url.URL
// tlsConf is the configuration of TLS.
tlsConf *tls.Config
// The Client's Transport typically has internal state (cached TCP
// connections), so Clients should be reused instead of created as
// needed. Clients are safe for concurrent use by multiple goroutines.
client *http.Client
clientMu sync.Mutex
// connections), so Clients should be reused instead of created as needed.
// Clients are safe for concurrent use by multiple goroutines.
client *http.Client
// quicConfig is the QUIC configuration that is used if HTTP/3 is enabled
// for this upstream.
quicConfig *quic.Config
quicConfigGuard sync.Mutex
quicConfig *quic.Config
// clientMu protects client.
clientMu sync.Mutex
// quicConfMu protects quicConfig.
quicConfMu sync.Mutex
// timeout is used in HTTP client and for H3 probes.
timeout time.Duration
}
// type check
var _ Upstream = (*dnsOverHTTPS)(nil)
// newDoH returns the DNS-over-HTTPS Upstream.
func newDoH(uu *url.URL, opts *Options) (u Upstream, err error) {
addPort(uu, defaultPortDoH)
func newDoH(addr *url.URL, opts *Options) (u Upstream, err error) {
addPort(addr, defaultPortDoH)
var b *bootstrapper
b, err = urlToBoot(uu, opts)
if err != nil {
return nil, fmt.Errorf("creating https bootstrapper: %w", err)
var httpVersions []HTTPVersion
if addr.Scheme == "h3" {
addr.Scheme = "https"
httpVersions = []HTTPVersion{HTTPVersion3}
} else if httpVersions = opts.HTTPVersions; len(opts.HTTPVersions) == 0 {
httpVersions = DefaultHTTPVersions
}
u = &dnsOverHTTPS{
boot: b,
getDialer, err := newDialerInitializer(addr, opts)
if err != nil {
// Don't wrap the error since it's informative enough as is.
return nil, err
}
ups := &dnsOverHTTPS{
getDialer: getDialer,
addr: addr,
quicConfig: &quic.Config{
KeepAlivePeriod: QUICKeepAlivePeriod,
TokenStore: newQUICTokenStore(),
Tracer: opts.QUICTracer,
},
tlsConf: &tls.Config{
ServerName: addr.Hostname(),
RootCAs: RootCAs,
CipherSuites: CipherSuites,
// Use the default capacity for the LRU cache. It may be useful to
// store several caches since the user may be routed to different
// servers in case there's load balancing on the server-side.
ClientSessionCache: tls.NewLRUClientSessionCache(0),
MinVersion: tls.VersionTLS12,
InsecureSkipVerify: opts.InsecureSkipVerify,
VerifyPeerCertificate: opts.VerifyServerCertificate,
VerifyConnection: opts.VerifyConnection,
},
timeout: opts.Timeout,
}
for _, v := range httpVersions {
ups.tlsConf.NextProtos = append(ups.tlsConf.NextProtos, string(v))
}
runtime.SetFinalizer(u, (*dnsOverHTTPS).Close)
runtime.SetFinalizer(ups, (*dnsOverHTTPS).Close)
return u, nil
return ups, nil
}
// Address implements the Upstream interface for *dnsOverHTTPS.
func (p *dnsOverHTTPS) Address() string { return p.boot.URL.String() }
// Address implements the [Upstream] interface for *dnsOverHTTPS.
func (p *dnsOverHTTPS) Address() string { return p.addr.String() }
// Exchange implements the Upstream interface for *dnsOverHTTPS.
func (p *dnsOverHTTPS) Exchange(m *dns.Msg) (resp *dns.Msg, err error) {
// Quote from https://www.rfc-editor.org/rfc/rfc8484.html:
// In order to maximize HTTP cache friendliness, DoH clients using media
// formats that include the ID field from the DNS message header, such
// as "application/dns-message", SHOULD use a DNS ID of 0 in every DNS
// request.
//
// See https://www.rfc-editor.org/rfc/rfc8484.html.
id := m.Id
m.Id = 0
defer func() {
......@@ -167,9 +209,11 @@ func (p *dnsOverHTTPS) closeClient(client *http.Client) (err error) {
// exchangeHTTPS logs the request and its result and calls exchangeHTTPSClient.
func (p *dnsOverHTTPS) exchangeHTTPS(client *http.Client, req *dns.Msg) (resp *dns.Msg, err error) {
logBegin(p.Address(), req)
addr := p.Address()
logBegin(addr, req)
resp, err = p.exchangeHTTPSClient(client, req)
logFinish(p.Address(), err)
logFinish(addr, err)
return resp, err
}
......@@ -194,15 +238,15 @@ func (p *dnsOverHTTPS) exchangeHTTPSClient(
}
u := url.URL{
Scheme: p.boot.URL.Scheme,
Host: p.boot.URL.Host,
Path: p.boot.URL.Path,
Scheme: p.addr.Scheme,
Host: p.addr.Host,
Path: p.addr.Path,
RawQuery: fmt.Sprintf("dns=%s", base64.RawURLEncoding.EncodeToString(buf)),
}
httpReq, err := http.NewRequest(method, u.String(), nil)
if err != nil {
return nil, fmt.Errorf("creating http request to %s: %w", p.boot.URL, err)
return nil, fmt.Errorf("creating http request to %s: %w", p.addr, err)
}
httpReq.Header.Set("Accept", "application/dns-message")
......@@ -210,13 +254,13 @@ func (p *dnsOverHTTPS) exchangeHTTPSClient(
httpResp, err := client.Do(httpReq)
if err != nil {
return nil, fmt.Errorf("requesting %s: %w", p.boot.URL, err)
return nil, fmt.Errorf("requesting %s: %w", p.addr, err)
}
defer log.OnCloserError(httpResp.Body, log.DEBUG)
body, err := io.ReadAll(httpResp.Body)
if err != nil {
return nil, fmt.Errorf("reading %s: %w", p.boot.URL, err)
return nil, fmt.Errorf("reading %s: %w", p.addr, err)
}
if httpResp.StatusCode != http.StatusOK {
......@@ -225,7 +269,7 @@ func (p *dnsOverHTTPS) exchangeHTTPSClient(
"expected status %d, got %d from %s",
http.StatusOK,
httpResp.StatusCode,
p.boot.URL,
p.addr,
)
}
......@@ -234,7 +278,7 @@ func (p *dnsOverHTTPS) exchangeHTTPSClient(
if err != nil {
return nil, fmt.Errorf(
"unpacking response from %s: body is %s: %w",
p.boot.URL,
p.addr,
body,
err,
)
......@@ -300,8 +344,8 @@ func (p *dnsOverHTTPS) resetClient(resetErr error) (client *http.Client, err err
// getQUICConfig returns the QUIC config in a thread-safe manner. Note, that
// this method returns a pointer, it is forbidden to change its properties.
func (p *dnsOverHTTPS) getQUICConfig() (c *quic.Config) {
p.quicConfigGuard.Lock()
defer p.quicConfigGuard.Unlock()
p.quicConfMu.Lock()
defer p.quicConfMu.Unlock()
return p.quicConfig
}
......@@ -309,8 +353,8 @@ func (p *dnsOverHTTPS) getQUICConfig() (c *quic.Config) {
// resetQUICConfig Re-create the token store to make sure we're not trying to
// use invalid for 0-RTT.
func (p *dnsOverHTTPS) resetQUICConfig() {
p.quicConfigGuard.Lock()
defer p.quicConfigGuard.Unlock()
p.quicConfMu.Lock()
defer p.quicConfMu.Unlock()
p.quicConfig = p.quicConfig.Clone()
p.quicConfig.TokenStore = newQUICTokenStore()
......@@ -323,6 +367,7 @@ func (p *dnsOverHTTPS) getClient() (c *http.Client, isCached bool, err error) {
p.clientMu.Lock()
defer p.clientMu.Unlock()
if p.client != nil {
return p.client, true, nil
}
......@@ -330,7 +375,7 @@ func (p *dnsOverHTTPS) getClient() (c *http.Client, isCached bool, err error) {
// Timeout can be exceeded while waiting for the lock. This happens quite
// often on mobile devices.
elapsed := time.Since(startTime)
if p.boot.options.Timeout > 0 && elapsed > p.boot.options.Timeout {
if p.timeout > 0 && elapsed > p.timeout {
return nil, false, fmt.Errorf("timeout exceeded: %s", elapsed)
}
......@@ -352,8 +397,10 @@ func (p *dnsOverHTTPS) createClient() (*http.Client, error) {
client := &http.Client{
Transport: transport,
Timeout: p.boot.options.Timeout,
Jar: nil,
// TODO(ameshkov): p.timeout may appear zero that will disable the
// timeout for client, consider using the default.
Timeout: p.timeout,
Jar: nil,
}
p.client = client
......@@ -368,15 +415,16 @@ func (p *dnsOverHTTPS) createClient() (*http.Client, error) {
// HTTP3 is enabled in the upstream options). If this attempt is successful,
// it returns an HTTP3 transport, otherwise it returns the H1/H2 transport.
func (p *dnsOverHTTPS) createTransport() (t http.RoundTripper, err error) {
tlsConfig, dialContext, err := p.boot.get()
dialContext, err := p.getDialer()
if err != nil {
return nil, fmt.Errorf("bootstrapping %s: %w", p.boot.URL, err)
return nil, fmt.Errorf("bootstrapping %s: %w", p.addr, err)
}
// First, we attempt to create an HTTP3 transport. If the probe QUIC
// connection is established successfully, we'll be using HTTP3 for this
// upstream.
transportH3, err := p.createTransportH3(tlsConfig, dialContext)
tlsConf := p.tlsConf.Clone()
transportH3, err := p.createTransportH3(tlsConf, dialContext)
if err == nil {
log.Debug("using HTTP/3 for this upstream: QUIC was faster")
return transportH3, nil
......@@ -389,15 +437,15 @@ func (p *dnsOverHTTPS) createTransport() (t http.RoundTripper, err error) {
}
transport := &http.Transport{
TLSClientConfig: tlsConfig,
TLSClientConfig: tlsConf,
DisableCompression: true,
DialContext: dialContext,
IdleConnTimeout: transportDefaultIdleConnTimeout,
MaxConnsPerHost: dohMaxConnsPerHost,
MaxIdleConns: dohMaxIdleConns,
// Since we have a custom DialContext, we need to use this field to
// make golang http.Client attempt to use HTTP/2. Otherwise, it would
// only be used when negotiated on the TLS level.
// Since we have a custom DialContext, we need to use this field to make
// golang http.Client attempt to use HTTP/2. Otherwise, it would only be
// used when negotiated on the TLS level.
ForceAttemptHTTP2: true,
}
......@@ -470,7 +518,7 @@ func (h *http3Transport) Close() (err error) {
// will create the *http3.RoundTripper instance.
func (p *dnsOverHTTPS) createTransportH3(
tlsConfig *tls.Config,
dialContext dialHandler,
dialContext bootstrap.DialHandler,
) (roundTripper http.RoundTripper, err error) {
if !p.supportsH3() {
return nil, errors.Error("HTTP3 support is not enabled")
......@@ -507,7 +555,7 @@ func (p *dnsOverHTTPS) createTransportH3(
// should use to establish the QUIC connections.
func (p *dnsOverHTTPS) probeH3(
tlsConfig *tls.Config,
dialContext dialHandler,
dialContext bootstrap.DialHandler,
) (addr string, err error) {
// We're using bootstrapped address instead of what's passed to the function
// it does not create an actual connection, but it helps us determine
......@@ -521,7 +569,7 @@ func (p *dnsOverHTTPS) probeH3(
udpConn, ok := rawConn.(*net.UDPConn)
if !ok {
return "", fmt.Errorf("not a UDP connection to %s", p.Address())
return "", fmt.Errorf("not a UDP connection to %s", p.addr)
}
addr = udpConn.RemoteAddr().String()
......@@ -575,7 +623,7 @@ func (p *dnsOverHTTPS) probeH3(
func (p *dnsOverHTTPS) probeQUIC(addr string, tlsConfig *tls.Config, ch chan error) {
startTime := time.Now()
timeout := p.boot.options.Timeout
timeout := p.timeout
if timeout == 0 {
timeout = dialTimeout
}
......@@ -584,7 +632,7 @@ func (p *dnsOverHTTPS) probeQUIC(addr string, tlsConfig *tls.Config, ch chan err
conn, err := quic.DialAddrEarlyContext(ctx, addr, tlsConfig, p.getQUICConfig())
if err != nil {
ch <- fmt.Errorf("opening QUIC connection to %s: %w", p.Address(), err)
ch <- fmt.Errorf("opening QUIC connection to %s: %w", p.addr, err)
return
}
......@@ -599,7 +647,7 @@ func (p *dnsOverHTTPS) probeQUIC(addr string, tlsConfig *tls.Config, ch chan err
// probeTLS attempts to establish a TLS connection to the specified address. We
// run probeQUIC and probeTLS in parallel and see which one is faster.
func (p *dnsOverHTTPS) probeTLS(dialContext dialHandler, tlsConfig *tls.Config, ch chan error) {
func (p *dnsOverHTTPS) probeTLS(dialContext bootstrap.DialHandler, tlsConfig *tls.Config, ch chan error) {
startTime := time.Now()
conn, err := tlsDial(dialContext, "tcp", tlsConfig)
......@@ -619,8 +667,8 @@ func (p *dnsOverHTTPS) probeTLS(dialContext dialHandler, tlsConfig *tls.Config,
// supportsH3 returns true if HTTP/3 is supported by this upstream.
func (p *dnsOverHTTPS) supportsH3() (ok bool) {
for _, v := range p.supportedHTTPVersions() {
if v == HTTPVersion3 {
for _, v := range p.tlsConf.NextProtos {
if v == string(HTTPVersion3) {
return true
}
}
......@@ -630,8 +678,8 @@ func (p *dnsOverHTTPS) supportsH3() (ok bool) {
// supportsHTTP returns true if HTTP/1.1 or HTTP2 is supported by this upstream.
func (p *dnsOverHTTPS) supportsHTTP() (ok bool) {
for _, v := range p.supportedHTTPVersions() {
if v == HTTPVersion11 || v == HTTPVersion2 {
for _, v := range p.tlsConf.NextProtos {
if v == string(HTTPVersion11) || v == string(HTTPVersion2) {
return true
}
}
......@@ -639,16 +687,6 @@ func (p *dnsOverHTTPS) supportsHTTP() (ok bool) {
return false
}
// supportedHTTPVersions returns the list of supported HTTP versions.
func (p *dnsOverHTTPS) supportedHTTPVersions() (v []HTTPVersion) {
v = p.boot.options.HTTPVersions
if v == nil {
v = DefaultHTTPVersions
}
return v
}
// isHTTP3 checks if the *http.Client is an HTTP/3 client.
func isHTTP3(client *http.Client) (ok bool) {
_, ok = client.Transport.(*http3Transport)
......
......@@ -12,6 +12,7 @@ import (
"sync"
"time"
"github.com/AdguardTeam/dnsproxy/internal/bootstrap"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns"
......@@ -21,14 +22,20 @@ import (
// TODO(ameshkov): use bootstrap timeout instead.
const dialTimeout = 10 * time.Second
// dnsOverTLS is a struct that implements the Upstream interface for the
// DNS-over-TLS protocol.
// dnsOverTLS implements the [Upstream] interface for the DNS-over-TLS protocol.
type dnsOverTLS struct {
// boot resolves the hostname upstream addresses.
boot *bootstrapper
// addr is the DNS-over-TLS server URL.
addr *url.URL
// getDialer either returns an initialized dial handler or creates a
// new one.
getDialer DialerInitializer
// tlsConf is the configuration of TLS.
tlsConf *tls.Config
// connsMu protects conns.
connsMu sync.Mutex
connsMu *sync.Mutex
// conns stores the connections ready for reuse. Don't use [sync.Pool]
// here, since there is no need to deallocate these connections.
......@@ -44,31 +51,52 @@ type dnsOverTLS struct {
var _ Upstream = (*dnsOverTLS)(nil)
// newDoT returns the DNS-over-TLS Upstream.
func newDoT(u *url.URL, opts *Options) (ups Upstream, err error) {
addPort(u, defaultPortDoT)
func newDoT(addr *url.URL, opts *Options) (ups Upstream, err error) {
addPort(addr, defaultPortDoT)
boot, err := urlToBoot(u, opts)
getDialer, err := newDialerInitializer(addr, opts)
if err != nil {
return nil, fmt.Errorf("creating tls bootstrapper: %w", err)
// Don't wrap the error since it's informative enough as is.
return nil, err
}
ups = &dnsOverTLS{
boot: boot,
tlsUps := &dnsOverTLS{
addr: addr,
getDialer: getDialer,
tlsConf: &tls.Config{
ServerName: addr.Hostname(),
RootCAs: RootCAs,
CipherSuites: CipherSuites,
// Use the default capacity for the LRU cache. It may be useful to
// store several caches since the user may be routed to different
// servers in case there's load balancing on the server-side.
ClientSessionCache: tls.NewLRUClientSessionCache(0),
MinVersion: tls.VersionTLS12,
InsecureSkipVerify: opts.InsecureSkipVerify,
VerifyPeerCertificate: opts.VerifyServerCertificate,
VerifyConnection: opts.VerifyConnection,
},
connsMu: &sync.Mutex{},
}
runtime.SetFinalizer(ups, (*dnsOverTLS).Close)
runtime.SetFinalizer(tlsUps, (*dnsOverTLS).Close)
return ups, nil
return tlsUps, nil
}
// Address implements the [Upstream] interface for *dnsOverTLS.
func (p *dnsOverTLS) Address() string { return p.boot.URL.String() }
func (p *dnsOverTLS) Address() string { return p.addr.String() }
// Exchange implements the [Upstream] interface for *dnsOverTLS.
func (p *dnsOverTLS) Exchange(m *dns.Msg) (reply *dns.Msg, err error) {
conn, err := p.conn()
h, err := p.getDialer()
if err != nil {
return nil, fmt.Errorf("getting conn to %s: %w", p.addr, err)
}
conn, err := p.conn(h)
if err != nil {
return nil, fmt.Errorf("getting conn to %s: %w", p.Address(), err)
return nil, fmt.Errorf("getting conn to %s: %w", p.addr, err)
}
reply, err = p.exchangeWithConn(conn, m)
......@@ -78,12 +106,17 @@ func (p *dnsOverTLS) Exchange(m *dns.Msg) (reply *dns.Msg, err error) {
// connection from pool may also be malformed, so dial a new one.
err = errors.WithDeferred(err, conn.Close())
log.Debug("dot upstream: bad conn from pool: %s", err)
log.Debug("dot %s: bad conn from pool: %s", p.addr, err)
// Retry.
conn, err = p.dial()
conn, err = tlsDial(h, "tcp", p.tlsConf.Clone())
if err != nil {
return nil, fmt.Errorf("dialing conn to %s: %w", p.Address(), err)
return nil, fmt.Errorf(
"dialing %s: connecting to %s: %w",
p.addr,
p.tlsConf.ServerName,
err,
)
}
reply, err = p.exchangeWithConn(conn, m)
......@@ -121,11 +154,12 @@ func (p *dnsOverTLS) Close() (err error) {
// conn returns the first available connection from the pool if there is any, or
// dials a new one otherwise.
func (p *dnsOverTLS) conn() (conn net.Conn, err error) {
func (p *dnsOverTLS) conn(h bootstrap.DialHandler) (conn net.Conn, err error) {
// Dial a new connection outside the lock, if needed.
defer func() {
if conn == nil {
conn, err = p.dial()
conn, err = tlsDial(h, "tcp", p.tlsConf.Clone())
err = errors.Annotate(err, "connecting to %s: %w", p.tlsConf.ServerName)
}
}()
......@@ -184,24 +218,13 @@ func (p *dnsOverTLS) exchangeWithConn(conn net.Conn, m *dns.Msg) (reply *dns.Msg
return reply, err
}
// dial dials a new connection that may be stored in pool.
func (p *dnsOverTLS) dial() (conn net.Conn, err error) {
tlsConfig, dialContext, err := p.boot.get()
if err != nil {
return nil, err
}
conn, err = tlsDial(dialContext, "tcp", tlsConfig)
if err != nil {
return nil, fmt.Errorf("connecting to %s: %w", tlsConfig.ServerName, err)
}
return conn, nil
}
// tlsDial is basically the same as tls.DialWithDialer, but we will call our own
// dialContext function to get connection.
func tlsDial(dialContext dialHandler, network string, config *tls.Config) (*tls.Conn, error) {
func tlsDial(
dialContext bootstrap.DialHandler,
network string,
conf *tls.Config,
) (c *tls.Conn, err error) {
// We're using bootstrapped address instead of what's passed to the
// function.
rawConn, err := dialContext(context.Background(), network, "")
......@@ -211,7 +234,7 @@ func tlsDial(dialContext dialHandler, network string, config *tls.Config) (*tls.
// We want the timeout to cover the whole process: TCP connection and
// TLS handshake dialTimeout will be used as connection deadLine.
conn := tls.Client(rawConn, config)
conn := tls.Client(rawConn, conf)
err = conn.SetDeadline(time.Now().Add(dialTimeout))
if err != nil {
// Must not happen in normal circumstances.
......
......@@ -159,7 +159,10 @@ func TestUpstream_dnsOverTLS_poolDeadline(t *testing.T) {
require.Len(t, p.conns, 1)
conn := p.conns[0]
usedConn, err := p.conn()
dialHandler, err := p.getDialer()
require.NoError(t, err)
usedConn, err := p.conn(dialHandler)
require.NoError(t, err)
require.Same(t, usedConn, conn)
......@@ -177,7 +180,7 @@ func TestUpstream_dnsOverTLS_poolDeadline(t *testing.T) {
require.Len(t, p.conns, 1)
conn = p.conns[0]
usedConn, err = p.conn()
usedConn, err = p.conn(dialHandler)
require.NoError(t, err)
require.Same(t, usedConn, conn)
......
package upstream
import (
"context"
"fmt"
"net/url"
"time"
"github.com/AdguardTeam/dnsproxy/internal/bootstrap"
"github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns"
)
// plainDNS is a struct that implements the Upstream interface for the regular
// DNS protocol.
// plainDNS implements the [Upstream] interface for the regular DNS protocol.
type plainDNS struct {
address string
timeout time.Duration
preferTCP bool
// addr is the DNS server URL. Scheme is always "udp" or "tcp".
addr *url.URL
// getDialer either returns an initialized dial handler or creates a new
// one.
getDialer DialerInitializer
// timeout is the timeout for DNS requests.
timeout time.Duration
}
// type check
var _ Upstream = &plainDNS{}
// newPlain returns the plain DNS Upstream.
func newPlain(uu *url.URL, timeout time.Duration, preferTCP bool) (u *plainDNS) {
addPort(uu, defaultPortPlain)
func newPlain(addr *url.URL, opts *Options) (u *plainDNS, err error) {
addPort(addr, defaultPortPlain)
return &plainDNS{
address: uu.Host,
timeout: timeout,
preferTCP: preferTCP,
getDialer, err := newDialerInitializer(addr, opts)
if err != nil {
return nil, err
}
return &plainDNS{
addr: addr,
getDialer: getDialer,
timeout: opts.Timeout,
}, nil
}
// Address implements the Upstream interface for *plainDNS.
// Address implements the [Upstream] interface for *plainDNS.
func (p *plainDNS) Address() string {
if p.preferTCP {
return "tcp://" + p.address
if p.addr.Scheme == "udp" {
return p.addr.Host
}
return p.address
return p.addr.String()
}
// Exchange implements the Upstream interface for *plainDNS.
func (p *plainDNS) Exchange(m *dns.Msg) (*dns.Msg, error) {
if p.preferTCP {
tcpClient := dns.Client{Net: "tcp", Timeout: p.timeout}
// dialExchange performs a DNS exchange with the specified dial handler.
// network must be either "udp" or "tcp".
func (p *plainDNS) dialExchange(
network string,
dial bootstrap.DialHandler,
m *dns.Msg,
) (resp *dns.Msg, err error) {
addr := p.Address()
client := &dns.Client{Timeout: p.timeout}
conn := &dns.Conn{}
if network == "udp" {
conn.UDPSize = dns.MinMsgSize
}
logBegin(p.Address(), m)
reply, _, tcpErr := tcpClient.Exchange(m, p.address)
logFinish(p.Address(), tcpErr)
logBegin(addr, m)
conn.Conn, err = dial(context.Background(), network, "")
if err != nil {
logFinish(addr, err)
return nil, fmt.Errorf("dialing %s over %s: %w", p.addr.Host, network, err)
}
resp, _, err = client.ExchangeWithConn(m, conn)
logFinish(addr, err)
return resp, err
}
return reply, tcpErr
// Exchange implements the [Upstream] interface for *plainDNS.
func (p *plainDNS) Exchange(m *dns.Msg) (resp *dns.Msg, err error) {
dial, err := p.getDialer()
if err != nil {
// Don't wrap the error since it's informative enough as is.
return nil, err
}
client := dns.Client{Timeout: p.timeout, UDPSize: dns.MaxMsgSize}
addr := p.Address()
logBegin(p.Address(), m)
reply, _, err := client.Exchange(m, p.address)
logFinish(p.Address(), err)
resp, err = p.dialExchange(p.addr.Scheme, dial, m)
if p.addr.Scheme == "udp" {
if resp == nil || !resp.Truncated {
return resp, err
}
if reply != nil && reply.Truncated {
log.Tracef("Truncated message was received, retrying over TCP, question: %s", m.Question[0].String())
tcpClient := dns.Client{Net: "tcp", Timeout: p.timeout}
log.Debug("plain %s: received truncated, falling back to tcp with %s", addr, &m.Question[0])
logBegin(p.Address(), m)
reply, _, err = tcpClient.Exchange(m, p.address)
logFinish(p.Address(), err)
resp, err = p.dialExchange("tcp", dial, m)
}
return reply, err
return resp, err
}
// Close implements the Upstream interface for *plainDNS.
// Close implements the [Upstream] interface for *plainDNS.
func (p *plainDNS) Close() (err error) {
// Nothing to close here.
return nil
}
......@@ -2,6 +2,7 @@ package upstream
import (
"context"
"crypto/tls"
"fmt"
"net"
"net/url"
......@@ -20,10 +21,12 @@ const (
// QUICCodeNoError is used when the connection or stream needs to be closed,
// but there is no error to signal.
QUICCodeNoError = quic.ApplicationErrorCode(0)
// QUICCodeInternalError signals that the DoQ implementation encountered
// an internal error and is incapable of pursuing the transaction or the
// connection.
QUICCodeInternalError = quic.ApplicationErrorCode(1)
// QUICKeepAlivePeriod is the value that we pass to *quic.Config and that
// controls the period with with keep-alive frames are being sent to the
// connection. We set it to 20s as it would be in the quic-go@v0.27.1 with
......@@ -32,53 +35,95 @@ const (
//
// TODO(ameshkov): Consider making it configurable.
QUICKeepAlivePeriod = time.Second * 20
// NextProtoDQ is the ALPN token for DoQ. During the connection establishment,
// DNS/QUIC support is indicated by selecting the ALPN token "doq" in the
// crypto handshake.
//
// See https://datatracker.ietf.org/doc/rfc9250.
NextProtoDQ = "doq"
)
// dnsOverQUIC is a struct that implements the Upstream interface for the
// DNS-over-QUIC protocol (spec: https://www.rfc-editor.org/rfc/rfc9250.html).
// compatProtoDQ is a list of ALPN tokens used by a QUIC connection.
// NextProtoDQ is the latest draft version supported by dnsproxy, but it also
// includes previous drafts.
var compatProtoDQ = []string{NextProtoDQ, "doq-i00", "dq", "doq-i02"}
// dnsOverQUIC implements the [Upstream] interface for the DNS-over-QUIC
// protocol (spec: https://www.rfc-editor.org/rfc/rfc9250.html).
type dnsOverQUIC struct {
// boot is a bootstrap DNS abstraction that is used to resolve the upstream
// server's address and open a network connection to it.
boot *bootstrapper
// getDialer either returns an initialized dial handler or creates a new
// one.
getDialer DialerInitializer
// addr is the DNS-over-QUIC server URL.
addr *url.URL
// tlsConf is the configuration of TLS.
tlsConf *tls.Config
// quicConfig is the QUIC configuration that is used for establishing
// connections to the upstream. This configuration includes the TokenStore
// that needs to be stored for the lifetime of dnsOverQUIC since we can
// re-create the connection.
quicConfig *quic.Config
quicConfigGuard sync.Mutex
quicConfig *quic.Config
// conn is the current active QUIC connection. It can be closed and
// re-opened when needed.
conn quic.Connection
connMu sync.RWMutex
conn quic.Connection
// bytesPool is a *sync.Pool we use to store byte buffers in. These byte
// buffers are used to read responses from the upstream.
bytesPool *sync.Pool
bytesPoolGuard sync.Mutex
bytesPool *sync.Pool
// quicConfigMu protects quicConfig.
quicConfigMu sync.Mutex
// connMu protects conn.
connMu sync.RWMutex
// bytesPoolGuard protects bytesPool.
bytesPoolMu sync.Mutex
// timeout is the timeout for the upstream connection.
timeout time.Duration
}
// type check
var _ Upstream = (*dnsOverQUIC)(nil)
// newDoQ returns the DNS-over-QUIC Upstream.
func newDoQ(uu *url.URL, opts *Options) (u Upstream, err error) {
addPort(uu, defaultPortDoQ)
func newDoQ(addr *url.URL, opts *Options) (u Upstream, err error) {
addPort(addr, defaultPortDoQ)
var b *bootstrapper
b, err = urlToBoot(uu, opts)
getDialer, err := newDialerInitializer(addr, opts)
if err != nil {
return nil, fmt.Errorf("creating quic bootstrapper: %w", err)
return nil, err
}
u = &dnsOverQUIC{
boot: b,
getDialer: getDialer,
addr: addr,
quicConfig: &quic.Config{
KeepAlivePeriod: QUICKeepAlivePeriod,
TokenStore: newQUICTokenStore(),
Tracer: opts.QUICTracer,
},
tlsConf: &tls.Config{
ServerName: addr.Hostname(),
RootCAs: RootCAs,
CipherSuites: CipherSuites,
// Use the default capacity for the LRU cache. It may be useful to
// store several caches since the user may be routed to different
// servers in case there's load balancing on the server-side.
ClientSessionCache: tls.NewLRUClientSessionCache(0),
MinVersion: tls.VersionTLS12,
InsecureSkipVerify: opts.InsecureSkipVerify,
VerifyPeerCertificate: opts.VerifyServerCertificate,
VerifyConnection: opts.VerifyConnection,
NextProtos: compatProtoDQ,
},
timeout: opts.Timeout,
}
runtime.SetFinalizer(u, (*dnsOverQUIC).Close)
......@@ -86,10 +131,10 @@ func newDoQ(uu *url.URL, opts *Options) (u Upstream, err error) {
return u, nil
}
// Address implements the Upstream interface for *dnsOverQUIC.
func (p *dnsOverQUIC) Address() string { return p.boot.URL.String() }
// Address implements the [Upstream] interface for *dnsOverQUIC.
func (p *dnsOverQUIC) Address() string { return p.addr.String() }
// Exchange implements the Upstream interface for *dnsOverQUIC.
// Exchange implements the [Upstream] interface for *dnsOverQUIC.
func (p *dnsOverQUIC) Exchange(m *dns.Msg) (resp *dns.Msg, err error) {
// When sending queries over a QUIC connection, the DNS Message ID MUST be
// set to zero.
......@@ -134,7 +179,7 @@ func (p *dnsOverQUIC) Exchange(m *dns.Msg) (resp *dns.Msg, err error) {
return resp, err
}
// Close implements the Upstream interface for *dnsOverQUIC.
// Close implements the [Upstream] interface for *dnsOverQUIC.
func (p *dnsOverQUIC) Close() (err error) {
p.connMu.Lock()
defer p.connMu.Unlock()
......@@ -191,8 +236,8 @@ func (p *dnsOverQUIC) shouldRetry(err error) (ok bool) {
// getBytesPool returns (creates if needed) a pool we store byte buffers in.
func (p *dnsOverQUIC) getBytesPool() (pool *sync.Pool) {
p.bytesPoolGuard.Lock()
defer p.bytesPoolGuard.Unlock()
p.bytesPoolMu.Lock()
defer p.bytesPoolMu.Unlock()
if p.bytesPool == nil {
p.bytesPool = &sync.Pool{
......@@ -250,8 +295,8 @@ func (p *dnsOverQUIC) hasConnection() (ok bool) {
// getQUICConfig returns the QUIC config in a thread-safe manner. Note, that
// this method returns a pointer, it is forbidden to change its properties.
func (p *dnsOverQUIC) getQUICConfig() (c *quic.Config) {
p.quicConfigGuard.Lock()
defer p.quicConfigGuard.Unlock()
p.quicConfigMu.Lock()
defer p.quicConfigMu.Unlock()
return p.quicConfig
}
......@@ -259,8 +304,8 @@ func (p *dnsOverQUIC) getQUICConfig() (c *quic.Config) {
// resetQUICConfig re-creates the tokens store as we may need to use a new one
// if we failed to connect.
func (p *dnsOverQUIC) resetQUICConfig() {
p.quicConfigGuard.Lock()
defer p.quicConfigGuard.Unlock()
p.quicConfigMu.Lock()
defer p.quicConfigMu.Unlock()
p.quicConfig = p.quicConfig.Clone()
p.quicConfig.TokenStore = newQUICTokenStore()
......@@ -268,7 +313,7 @@ func (p *dnsOverQUIC) resetQUICConfig() {
// openStream opens a new QUIC stream for the specified connection.
func (p *dnsOverQUIC) openStream(conn quic.Connection) (quic.Stream, error) {
ctx, cancel := p.boot.newContext()
ctx, cancel := p.withDeadline(context.Background())
defer cancel()
stream, err := conn.OpenStreamSync(ctx)
......@@ -288,7 +333,7 @@ func (p *dnsOverQUIC) openStream(conn quic.Connection) (quic.Stream, error) {
// openConnection opens a new QUIC connection.
func (p *dnsOverQUIC) openConnection() (conn quic.Connection, err error) {
tlsConfig, dialContext, err := p.boot.get()
dialContext, err := p.getDialer()
if err != nil {
return nil, fmt.Errorf("failed to bootstrap QUIC connection: %w", err)
}
......@@ -305,17 +350,17 @@ func (p *dnsOverQUIC) openConnection() (conn quic.Connection, err error) {
udpConn, ok := rawConn.(*net.UDPConn)
if !ok {
return nil, fmt.Errorf("failed to open connection to %s", p.Address())
return nil, fmt.Errorf("failed to open connection to %s", p.addr)
}
addr := udpConn.RemoteAddr().String()
ctx, cancel := p.boot.newContext()
ctx, cancel := p.withDeadline(context.Background())
defer cancel()
conn, err = quic.DialAddrEarlyContext(ctx, addr, tlsConfig, p.getQUICConfig())
conn, err = quic.DialAddrEarlyContext(ctx, addr, p.tlsConf.Clone(), p.getQUICConfig())
if err != nil {
return nil, fmt.Errorf("opening quic connection to %s: %w", p.Address(), err)
return nil, fmt.Errorf("opening quic connection to %s: %w", p.addr, err)
}
return conn, nil
......@@ -360,7 +405,7 @@ func (p *dnsOverQUIC) readMsg(stream quic.Stream) (m *dns.Msg, err error) {
respBuf := *bufPtr
n, err := stream.Read(respBuf)
if err != nil && n == 0 {
return nil, fmt.Errorf("reading response from %s: %w", p.Address(), err)
return nil, fmt.Errorf("reading response from %s: %w", p.addr, err)
}
// All DNS messages (queries and responses) sent over DoQ connections MUST
......@@ -371,7 +416,7 @@ func (p *dnsOverQUIC) readMsg(stream quic.Stream) (m *dns.Msg, err error) {
m = new(dns.Msg)
err = m.Unpack(respBuf[2:])
if err != nil {
return nil, fmt.Errorf("unpacking response from %s: %w", p.Address(), err)
return nil, fmt.Errorf("unpacking response from %s: %w", p.addr, err)
}
return m, nil
......@@ -443,3 +488,14 @@ func isQUICRetryError(err error) (ok bool) {
return false
}
func (p *dnsOverQUIC) withDeadline(
parent context.Context,
) (ctx context.Context, cancel context.CancelFunc) {
ctx, cancel = parent, func() {}
if p.timeout > 0 {
ctx, cancel = context.WithDeadline(ctx, time.Now().Add(p.timeout))
}
return ctx, cancel
}
......@@ -12,6 +12,7 @@ import (
"io"
"math/big"
"net"
"net/netip"
"net/url"
"os"
"sync"
......@@ -19,7 +20,9 @@ import (
"time"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/ameshkov/dnsstamps"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
......@@ -249,14 +252,16 @@ func TestAddressToUpstream_bads(t *testing.T) {
addr: "asdf://1.1.1.1",
wantErrMsg: "unsupported url scheme: asdf",
}, {
addr: "12345.1.1.1:1234567",
wantErrMsg: "invalid address: 12345.1.1.1:1234567",
addr: "12345.1.1.1:1234567",
wantErrMsg: `invalid address 12345.1.1.1:1234567: ` +
`strconv.ParseUint: parsing "1234567": value out of range`,
}, {
addr: ":1234567",
wantErrMsg: "invalid address: :1234567",
addr: ":1234567",
wantErrMsg: `invalid address :1234567: ` +
`strconv.ParseUint: parsing "1234567": value out of range`,
}, {
addr: "host:",
wantErrMsg: "invalid address: host:",
wantErrMsg: `invalid address host:: strconv.ParseUint: parsing "": invalid syntax`,
}}
for _, tc := range testCases {
......@@ -347,43 +352,60 @@ func TestUpstreamsWithServerIP(t *testing.T) {
// use invalid bootstrap to make sure it fails if tries to use it
invalidBootstrap := []string{"1.2.3.4:55"}
h := func(w dns.ResponseWriter, m *dns.Msg) {
require.NoError(testutil.PanicT{}, w.WriteMsg(respondToTestMessage(m)))
}
dotSrv := startDoTServer(t, h)
dohSrv := startDoHServer(t, testDoHServerOptions{})
_, dohPort, err := net.SplitHostPort(dohSrv.addr)
require.NoError(t, err)
dotStamp := (&dnsstamps.ServerStamp{
ServerAddrStr: netip.AddrPortFrom(netutil.IPv4Localhost(), uint16(dotSrv.port)).String(),
Proto: dnsstamps.StampProtoTypeTLS,
ProviderName: netip.AddrPortFrom(netutil.IPv4Localhost(), uint16(dotSrv.port)).String(),
}).String()
dohStamp := (&dnsstamps.ServerStamp{
ServerAddrStr: dohSrv.addr,
Proto: dnsstamps.StampProtoTypeDoH,
ProviderName: dohSrv.addr,
Path: "/dns-query",
}).String()
upstreams := []struct {
name string
address string
serverIP net.IP
bootstrap []string
serverIPs []net.IP
}{{
address: "tls://dns.adguard.com",
serverIP: net.IP{94, 140, 14, 14},
bootstrap: invalidBootstrap,
}, {
address: "https://dns.adguard.com/dns-query",
serverIP: net.IP{94, 140, 14, 14},
bootstrap: invalidBootstrap,
}, {
// AdGuard DNS DoH with the IP address specified.
address: "sdns://AgcAAAAAAAAADzE3Ni4xMDMuMTMwLjEzMAAPZG5zLmFkZ3VhcmQuY29tCi9kbnMtcXVlcnk",
serverIP: nil,
bootstrap: invalidBootstrap,
}, {
// AdGuard DNS DoT with the IP address specified.
address: "sdns://AwAAAAAAAAAAEzE3Ni4xMDMuMTMwLjEzMDo4NTMAD2Rucy5hZGd1YXJkLmNvbQ",
serverIP: nil,
bootstrap: invalidBootstrap,
name: "dot",
address: fmt.Sprintf("tls://some.dns.server:%d", dotSrv.port),
serverIPs: []net.IP{netutil.IPv4Localhost().AsSlice()},
}, {
name: "doh",
address: fmt.Sprintf("https://some.dns.server:%s/dns-query", dohPort),
serverIPs: []net.IP{netutil.IPv4Localhost().AsSlice()},
}, {
name: "dot_stamp",
address: dotStamp,
serverIPs: nil,
}, {
name: "doh_stamp",
address: dohStamp,
serverIPs: nil,
}}
for _, tc := range upstreams {
opts := &Options{
Bootstrap: tc.bootstrap,
Timeout: timeout,
ServerIPAddrs: []net.IP{tc.serverIP},
}
u, err := AddressToUpstream(tc.address, opts)
if err != nil {
t.Fatalf("Failed to generate upstream from address %s: %s", tc.address, err)
}
testutil.CleanupAndRequireSuccess(t, u.Close)
t.Run(tc.name, func(t *testing.T) {
opts := &Options{
Bootstrap: invalidBootstrap,
Timeout: timeout,
ServerIPAddrs: tc.serverIPs,
InsecureSkipVerify: true,
}
u, err := AddressToUpstream(tc.address, opts)
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, u.Close)
t.Run(tc.address, func(t *testing.T) {
checkUpstream(t, u, tc.address)
})
}
......@@ -413,6 +435,11 @@ func TestAddPort(t *testing.T) {
}, {
name: "ipv6",
want: "[::1]:1",
host: "::1",
port: 1,
}, {
name: "ipv6_with_brackets",
want: "[::1]:1",
host: "[::1]",
port: 1,
}, {
......@@ -426,7 +453,7 @@ func TestAddPort(t *testing.T) {
host: "1.2.3.4:2",
port: 1,
}, {
name: "ipv6_with_port",
name: "ipv6_with_brackets_and_port",
want: "[::1]:2",
host: "[::1]:2",
port: 1,
......
......@@ -5,12 +5,11 @@ import (
"fmt"
"net"
"net/netip"
"net/url"
"strings"
"github.com/AdguardTeam/dnsproxy/internal/bootstrap"
proxynetutil "github.com/AdguardTeam/dnsproxy/internal/netutil"
"github.com/AdguardTeam/dnsproxy/proxyutil"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns"
)
......@@ -30,14 +29,15 @@ func NewResolver(resolverAddress string, opts *Options) (r Resolver, err error)
return &net.Resolver{}, nil
}
if opts == nil {
opts = &Options{}
upsOpts := &Options{
// Avoid recursion in case the bootstrap resolver is not valid.
Bootstrap: []string{""},
}
// TODO(ameshkov): Aren't other options needed here?
upsOpts := &Options{
Timeout: opts.Timeout,
VerifyServerCertificate: opts.VerifyServerCertificate,
if opts != nil {
upsOpts.Timeout = opts.Timeout
upsOpts.VerifyServerCertificate = opts.VerifyServerCertificate
}
ur := upstreamResolver{}
......@@ -49,71 +49,37 @@ func NewResolver(resolverAddress string, opts *Options) (r Resolver, err error)
return ur, err
}
// Validate the bootstrap resolver. It must be either a plain DNS resolver,
// or a DoT/DoH resolver defined by IP address (not a hostname).
if !isResolverValidBootstrap(ur.Upstream) {
if err = validateBootstrap(ur.Upstream); err != nil {
log.Error("upstream bootstrap %s: %s", resolverAddress, err)
ur.Upstream = nil
err = fmt.Errorf("resolver %q is not a valid bootstrap DNS server", resolverAddress)
log.Error("upstream bootstrap: %s", err)
return ur, err
}
return ur, err
}
// isResolverValidBootstrap checks if the upstream is eligible to be a bootstrap
// DNS server DNSCrypt and plain DNS resolvers are okay DoH and DoT are okay
// only in the case if an IP address is used in the IP address.
//
// TODO(e.burkov): Refactor using the actual upstream types instead of parsing
// their addresses.
func isResolverValidBootstrap(upstream Upstream) bool {
if u, ok := upstream.(*dnsOverTLS); ok {
urlAddr, err := url.Parse(u.Address())
if err != nil {
return false
}
host, _, err := net.SplitHostPort(urlAddr.Host)
if err != nil {
return false
}
if ip := net.ParseIP(host); ip != nil {
return true
}
return false
}
if u, ok := upstream.(*dnsOverHTTPS); ok {
urlAddr, err := url.Parse(u.Address())
if err != nil {
return false
}
host, _, err := net.SplitHostPort(urlAddr.Host)
if err != nil {
host = urlAddr.Host
}
if ip := net.ParseIP(host); ip != nil {
return true
}
return false
}
a := upstream.Address()
if strings.HasPrefix(a, "sdns://") {
return true
}
a = strings.TrimPrefix(a, "tcp://")
host, _, err := net.SplitHostPort(a)
if err != nil {
return false
// validateBootstrap returns error if the upstream is not eligible to be a
// bootstrap DNS server. DNSCrypt is always okay. Plain DNS, DNS-over-TLS,
// DNS-over-HTTPS, and DNS-over-QUIC are okay only if those are defined by IP.
func validateBootstrap(upstream Upstream) (err error) {
switch upstream := upstream.(type) {
case *dnsCrypt:
return nil
case *dnsOverTLS:
_, err = netip.ParseAddr(upstream.addr.Hostname())
case *dnsOverHTTPS:
_, err = netip.ParseAddr(upstream.addr.Hostname())
case *dnsOverQUIC:
_, err = netip.ParseAddr(upstream.addr.Hostname())
case *plainDNS:
_, err = netip.ParseAddr(upstream.addr.Hostname())
default:
err = fmt.Errorf("unknown upstream type: %T", upstream)
}
ip := net.ParseIP(host)
return ip != nil
return errors.Annotate(err, "bootstrap %s: %w", upstream.Address())
}
// upstreamResolver is a wrapper around Upstream that implements the
......@@ -129,51 +95,57 @@ type upstreamResolver struct {
var _ Resolver = upstreamResolver{}
// LookupNetIP implements the [Resolver] interface for upstreamResolver.
//
// TODO(e.burkov): Do not look up concurrently for "ip4" and "ip6" networks.
func (r upstreamResolver) LookupNetIP(
ctx context.Context,
network string,
host string,
) (ipAddrs []netip.Addr, err error) {
// TODO(e.burkov): Investigate when r.ups is nil and why.
// TODO(e.burkov): Investigate when [r.Upstream] is nil and why.
if r.Upstream == nil || host == "" {
return []netip.Addr{}, nil
}
host = dns.Fqdn(host)
var resCh chan *resolveResult
n := 1
answers := make([][]dns.RR, 1, 2)
var errs []error
switch network {
case "ip4":
resCh = make(chan *resolveResult, n)
case "ip4", "ip6":
qtype := dns.TypeA
if network == "ip6" {
qtype = dns.TypeAAAA
}
go r.resolveAsync(host, dns.TypeA, resCh)
case "ip6":
resCh = make(chan *resolveResult, n)
var resp *dns.Msg
resp, err = r.resolve(host, qtype)
if err != nil {
return []netip.Addr{}, err
}
go r.resolveAsync(host, dns.TypeAAAA, resCh)
answers[0] = resp.Answer
case "ip":
n = 2
resCh = make(chan *resolveResult, n)
resCh := make(chan *resolveResult, 2)
go r.resolveAsync(host, dns.TypeA, resCh)
go r.resolveAsync(host, dns.TypeAAAA, resCh)
default:
return []netip.Addr{}, fmt.Errorf("unsupported network: %s", network)
}
go r.resolveAsync(resCh, host, dns.TypeA)
go r.resolveAsync(resCh, host, dns.TypeAAAA)
var errs []error
for ; n > 0; n-- {
re := <-resCh
if re.err != nil {
errs = append(errs, re.err)
answers = answers[:0:cap(answers)]
for i := 0; i < 2; i++ {
res := <-resCh
if res.err != nil {
errs = append(errs, res.err)
continue
continue
}
answers = append(answers, res.resp.Answer)
}
default:
return []netip.Addr{}, fmt.Errorf("unsupported network %s", network)
}
for _, rr := range re.resp.Answer {
for _, ans := range answers {
for _, rr := range ans {
if addr, ok := netip.AddrFromSlice(proxyutil.IPFromRR(rr)); ok {
ipAddrs = append(ipAddrs, addr)
}
......@@ -187,8 +159,8 @@ func (r upstreamResolver) LookupNetIP(
// Use the previous dnsproxy behavior: prefer IPv4 by default.
//
// TODO(a.garipov): Consider unexporting this entire method or documenting
// that the order of addrs is undefined.
// TODO(a.garipov): Consider unexporting this entire method or
// documenting that the order of addrs is undefined.
proxynetutil.SortNetIPAddrs(ipAddrs, false)
return ipAddrs, nil
......@@ -212,14 +184,18 @@ func (r upstreamResolver) resolve(host string, qtype uint16) (resp *dns.Msg, err
}
// resolveResult is the result of a single concurrent lookup.
type resolveResult struct {
type resolveResult = struct {
resp *dns.Msg
err error
}
// resolveAsync performs a single DNS lookup and sends the result to ch. It's
// intended to be used as a goroutine.
func (r upstreamResolver) resolveAsync(host string, qtype uint16, ch chan *resolveResult) {
func (r upstreamResolver) resolveAsync(
resCh chan<- *resolveResult,
host string,
qtype uint16,
) {
resp, err := r.resolve(host, qtype)
ch <- &resolveResult{resp, err}
resCh <- &resolveResult{resp: resp, err: err}
}
......@@ -49,23 +49,23 @@ func TestNewResolver_validity(t *testing.T) {
}, {
name: "invalid_tls",
addr: "tls://dns.adguard.com",
wantErrMsg: `resolver "tls://dns.adguard.com" is not a valid ` +
`bootstrap DNS server`,
wantErrMsg: `bootstrap tls://dns.adguard.com:853: ` +
`ParseAddr("dns.adguard.com"): unexpected character (at "dns.adguard.com")`,
}, {
name: "invalid_https",
addr: "https://dns.adguard.com/dns-query",
wantErrMsg: `resolver "https://dns.adguard.com/dns-query" is not a ` +
`valid bootstrap DNS server`,
wantErrMsg: `bootstrap https://dns.adguard.com:443/dns-query: ` +
`ParseAddr("dns.adguard.com"): unexpected character (at "dns.adguard.com")`,
}, {
name: "invalid_tcp",
addr: "tcp://dns.adguard.com",
wantErrMsg: `resolver "tcp://dns.adguard.com" is not a valid ` +
`bootstrap DNS server`,
wantErrMsg: `bootstrap tcp://dns.adguard.com:53: ` +
`ParseAddr("dns.adguard.com"): unexpected character (at "dns.adguard.com")`,
}, {
name: "invalid_no_scheme",
addr: "dns.adguard.com",
wantErrMsg: `resolver "dns.adguard.com" is not a valid bootstrap ` +
`DNS server`,
wantErrMsg: `bootstrap dns.adguard.com:53: ` +
`ParseAddr("dns.adguard.com"): unexpected character (at "dns.adguard.com")`,
}}
for _, tc := range testCases {
......