Skip to content
Commits on Source (5)
  • Andrey Meshkov's avatar
    Pull request: proxy: added a test that checks http version · 0357af0b
    Andrey Meshkov authored
    Merge in DNS/dnsproxy from doh2-test to master
    
    Squashed commit of the following:
    
    commit 1a88d2cea11fe9cf905a1a5d983bc58d4d7b2bac
    Author: Andrey Meshkov <am@adguard.com>
    Date:   Wed Oct 19 18:23:21 2022 +0300
    
        proxy: added a test that checks http version
    0357af0b
  • Andrey Meshkov's avatar
    Pull request: upstream: yet another attempt to fix races · 515fec8d
    Andrey Meshkov authored
    Merge in DNS/dnsproxy from fix-doh3-races to master
    
    Squashed commit of the following:
    
    commit 9b1f2b2b768041dd315a2ff49e3658e4574124ee
    Author: Andrey Meshkov <am@adguard.com>
    Date:   Thu Oct 20 13:22:44 2022 +0300
    
        fix review comments
    
    commit b2e11d956760847cb4e89dd904a1212da7846737
    Author: Andrey Meshkov <am@adguard.com>
    Date:   Thu Oct 20 12:50:58 2022 +0300
    
        fix review comments
    
    commit 84a7f30ef69020099be6247a466491c3e36cd7ec
    Author: Andrey Meshkov <am@adguard.com>
    Date:   Thu Oct 20 11:00:32 2022 +0300
    
        upstream: yet another attempt to fix races
    515fec8d
  • Andrey Meshkov's avatar
    Pull request: Fix some unit tests · 1d45bb99
    Andrey Meshkov authored
    Merge in DNS/dnsproxy from fix-unit-tests to master
    
    Squashed commit of the following:
    
    commit 1f287f124348739d946b270988750f7250a9a1ee
    Author: Andrey Meshkov <am@adguard.com>
    Date:   Thu Oct 20 14:56:08 2022 +0300
    
        run linter in a separate stage
    
    commit 520743c54abd525b2175e589d055f02cff4bb173
    Author: Andrey Meshkov <am@adguard.com>
    Date:   Thu Oct 20 14:49:02 2022 +0300
    
        fix review comments
    
    commit 9ebafc2821ecfe114d1ddadc85019f95d89e6cc1
    Author: Andrey Meshkov <am@adguard.com>
    Date:   Thu Oct 20 14:29:26 2022 +0300
    
        fix review comments and fix flaky TestHttpsProxyTrustedProxies
    
    commit 0ae8960b8e2181da5cc41dca28290c268dbf64ee
    Author: Andrey Meshkov <am@adguard.com>
    Date:   Thu Oct 20 14:05:49 2022 +0300
    
        stabilize bootstrap timeout test on GH actions
    
    commit e30c4c7989495e92edeebf098fc729dd151582da
    Author: Andrey Meshkov <am@adguard.com>
    Date:   Thu Oct 20 13:57:27 2022 +0300
    
        added a comment about rootCAs
    
    commit 60a8c02e8b93c7034f48b3e348837a0d225d4ab8
    Merge: 99a3986 515fec8d
    Author: Andrey Meshkov <am@adguard.com>
    Date:   Thu Oct 20 13:56:05 2022 +0300
    
        upstream: refactor unit-tests
    
    commit 99a3986969b81aa9b10320937ea719f5a274eec1
    Author: Andrey Meshkov <am@adguard.com>
    Date:   Thu Oct 20 12:32:52 2022 +0300
    
        upstream: refactor upstream plain/dot unit-tests
    
    commit 84a7f30ef69020099be6247a466491c3e36cd7ec
    Author: Andrey Meshkov <am@adguard.com>
    Date:   Thu Oct 20 11:00:32 2022 +0300
    
        upstream: yet another attempt to fix races
    1d45bb99
  • Andrey Meshkov's avatar
    Pull request: Add unit-tests checking that 0-RTT is used for QUIC connections · cc4140d3
    Andrey Meshkov authored
    Merge in DNS/dnsproxy from fix-236 to master
    
    Squashed commit of the following:
    
    commit 0a10e29305272baf478cf0e6c7626fbe23ab7f14
    Author: Andrey Meshkov <am@adguard.com>
    Date:   Mon Oct 24 10:37:28 2022 +0300
    
        fix review comments
    
    commit 67e35bd56f6e91050cac3f643d44ff21c3c0eec4
    Author: Andrey Meshkov <am@adguard.com>
    Date:   Mon Oct 24 10:10:02 2022 +0300
    
        fix review comments
    
    commit f5ca4b40
    Author: Andrey Meshkov <am@adguard.com>
    Date:   Sun Oct 23 20:41:44 2022 +0300
    
        Add unit-tests checking that 0-RTT is used for QUIC connections
    
        This test utilizes quic-go's logging.Tracer to record detailed information
    about QUIC connections and check that 0-RTT packets are being used when the
    upstream reconnects to the server.
    
        Closes #236
    cc4140d3
  • Eugene Burkov's avatar
    Pull request: close-upstream-conf · ad02cded
    Eugene Burkov authored
    Merge in DNS/dnsproxy from close-upstream-conf to master
    
    Squashed commit of the following:
    
    commit 69175c34643cb8e55e97bb0c2a56746730ff2c97
    Author: Eugene Burkov <E.Burkov@AdGuard.COM>
    Date:   Tue Oct 25 17:23:12 2022 +0300
    
        proxy: add todo
    
    commit d5d1996fecbd65ea109af3af7c2b913fd2248afd
    Author: Eugene Burkov <E.Burkov@AdGuard.COM>
    Date:   Tue Oct 25 17:06:59 2022 +0300
    
        proxy: fix docs
    
    commit d8156796c944b01633a635194c3fe52896361985
    Author: Eugene Burkov <E.Burkov@AdGuard.COM>
    Date:   Tue Oct 25 17:05:04 2022 +0300
    
        proxy: close upstream conf better
    ad02cded
......@@ -26,6 +26,8 @@ jobs:
with:
go-version: '${{ env.GO_VERSION }}'
- name: Run tests
env:
CI: "1"
run: |-
make test
- name: Upload coverage
......
......@@ -29,7 +29,7 @@ jobs:
with:
# This field is required. Don't set the patch version to always use
# the latest patch version.
version: v1.48.0
version: v1.50.0
notify:
needs:
- golangci
......
......@@ -12,7 +12,6 @@ run:
# autogenerated files. If it's not please let us know.
skip-files:
- ".*generated.*"
- ".*_test.go"
# all available settings of specific linters
linters-settings:
......
......@@ -6,14 +6,41 @@ plan:
name: dnsproxy - Build and run tests
variables:
dockerGo: adguard/golang-ubuntu:5.0
dockerLint: golangci/golangci-lint:v1.50.0
stages:
- Lint:
manual: false
final: false
jobs:
- Lint
- Tests:
manual: false
final: false
jobs:
- Test
Lint:
docker:
image: ${bamboo.dockerLint}
volumes:
${system.GO_CACHE_DIR}: "${bamboo.cacheGo}"
${system.GO_PKG_CACHE_DIR}: "${bamboo.cacheGoPkg}"
${bamboo.build.working.directory}: "/app"
key: LINT
other:
clean-working-dir: true
tasks:
- checkout:
force-clean-build: 'true'
- script:
interpreter: SHELL
scripts:
- |-
golangci-lint run -v
requirements:
- adg-docker: 'true'
Test:
docker:
image: ${bamboo.dockerGo}
......@@ -27,16 +54,12 @@ Test:
force-clean-build: 'true'
- script:
interpreter: SHELL
environment: GOFLAGS="-buildvcs=false"
environment: GOFLAGS="-buildvcs=false" CI="1"
scripts:
- |-
set -e -f -u -x
go version
golangci-lint --version
# Run linter.
golangci-lint run
# Run tests.
make VERBOSE=1 test
......
......@@ -98,7 +98,7 @@ func TestFastestAddr_PingAll_cache(t *testing.T) {
})
t.Run("not_cached", func(t *testing.T) {
listener, err := net.Listen("tcp", ":0")
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, listener.Close)
......@@ -225,7 +225,7 @@ func TestFastestAddr_PingAll(t *testing.T) {
func getFreePort(t *testing.T) (port uint) {
t.Helper()
l, err := net.Listen("tcp", ":0")
l, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
port = uint(l.Addr().(*net.TCPAddr).Port)
......
......@@ -4,9 +4,9 @@ import (
"net"
"testing"
"github.com/stretchr/testify/assert"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestLookupIPAddr(t *testing.T) {
......@@ -23,7 +23,8 @@ func TestLookupIPAddr(t *testing.T) {
p.UpstreamConfig.Upstreams = append(upstreams, dnsUpstream)
// Init the proxy
p.Init()
err = p.Init()
require.NoError(t, err)
// Now let's try doing some lookups
addrs, err := p.LookupIPAddr("dns.google")
......
......@@ -203,15 +203,16 @@ func (p *Proxy) Start() (err error) {
return nil
}
// closeAll closes all elements in the toClose slice and if there's any error
// appends it to the errs slice.
func closeAll[T io.Closer](toClose []T, errs *[]error) {
for _, c := range toClose {
// closeAll closes all closers and appends the occurred errors to errs.
func closeAll[C io.Closer](errs []error, closers ...C) (appended []error) {
for _, c := range closers {
err := c.Close()
if err != nil {
*errs = append(*errs, err)
errs = append(errs, err)
}
}
return errs
}
// Stop stops the proxy server including all its listeners
......@@ -225,19 +226,17 @@ func (p *Proxy) Stop() error {
return nil
}
errs := []error{}
closeAll(p.tcpListen, &errs)
errs := closeAll(nil, p.tcpListen...)
p.tcpListen = nil
closeAll(p.udpListen, &errs)
errs = closeAll(errs, p.udpListen...)
p.udpListen = nil
closeAll(p.tlsListen, &errs)
errs = closeAll(errs, p.tlsListen...)
p.tlsListen = nil
if p.httpsServer != nil {
closeAll([]io.Closer{p.httpsServer}, &errs)
errs = closeAll(errs, p.httpsServer)
p.httpsServer = nil
// No need to close these since they're closed by httpsServer.Close().
......@@ -245,24 +244,24 @@ func (p *Proxy) Stop() error {
}
if p.h3Server != nil {
closeAll([]io.Closer{p.h3Server}, &errs)
errs = closeAll(errs, p.h3Server)
p.h3Server = nil
}
closeAll(p.h3Listen, &errs)
errs = closeAll(errs, p.h3Listen...)
p.h3Listen = nil
closeAll(p.quicListen, &errs)
errs = closeAll(errs, p.quicListen...)
p.quicListen = nil
closeAll(p.dnsCryptUDPListen, &errs)
errs = closeAll(errs, p.dnsCryptUDPListen...)
p.dnsCryptUDPListen = nil
closeAll(p.dnsCryptTCPListen, &errs)
errs = closeAll(errs, p.dnsCryptTCPListen...)
p.dnsCryptTCPListen = nil
if p.UpstreamConfig != nil {
closeAll([]io.Closer{p.UpstreamConfig}, &errs)
errs = closeAll(errs, p.UpstreamConfig)
}
p.started = false
......
......@@ -785,7 +785,7 @@ func TestProxy_ReplyFromUpstream_badResponse(t *testing.T) {
hdr := dns.RR_Header{
Name: m.Question[0].Name,
Class: dns.ClassINET,
Rrtype: uint16(dns.TypeA),
Rrtype: dns.TypeA,
}
resp.Answer = append(resp.Answer, &dns.A{
Hdr: hdr,
......@@ -1083,7 +1083,7 @@ func createTestDNSCryptProxy(t *testing.T) (*Proxy, dnscrypt.ResolverConfig) {
}
func getFreePort() uint {
l, _ := net.Listen("tcp", ":0")
l, _ := net.Listen("tcp", "127.0.0.1:0")
port := uint(l.Addr().(*net.TCPAddr).Port)
// stop listening immediately
......@@ -1198,7 +1198,9 @@ func createHostTestMessage(host string) *dns.Msg {
return &req
}
func requireResponse(t *testing.T, req, reply *dns.Msg) {
func requireResponse(t testing.TB, req, reply *dns.Msg) {
t.Helper()
require.NotNil(t, reply)
require.Lenf(t, reply.Answer, 1, "wrong number of answers: %d", len(reply.Answer))
require.Equal(t, req.Id, reply.Id)
......
......@@ -60,24 +60,25 @@ func TestHttpsProxy(t *testing.T) {
}
}
func TestHttpsProxyTrustedProxies(t *testing.T) {
// Prepare the proxy server.
tlsConf, caPem := createServerTLSConfig(t)
dnsProxy := createTestProxy(t, tlsConf)
func TestProxy_trustedProxies(t *testing.T) {
clientIP, proxyIP := net.IP{1, 2, 3, 4}, net.IP{127, 0, 0, 1}
var gotAddr net.Addr
dnsProxy.RequestHandler = func(_ *Proxy, d *DNSContext) (err error) {
gotAddr = d.Addr
doRequest := func(t *testing.T, proxyAddr string, expectedClientIP net.IP) {
// Prepare the proxy server.
tlsConf, caPem := createServerTLSConfig(t)
dnsProxy := createTestProxy(t, tlsConf)
return dnsProxy.Resolve(d)
}
var gotAddr net.Addr
dnsProxy.RequestHandler = func(_ *Proxy, d *DNSContext) (err error) {
gotAddr = d.Addr
client := createTestHTTPClient(dnsProxy, caPem, false)
return dnsProxy.Resolve(d)
}
clientIP, proxyIP := net.IP{1, 2, 3, 4}, net.IP{127, 0, 0, 1}
msg := createTestMessage()
client := createTestHTTPClient(dnsProxy, caPem, false)
msg := createTestMessage()
doRequest := func(t *testing.T, proxyAddr string) {
dnsProxy.TrustedProxies = []string{proxyAddr}
// Start listening.
......@@ -91,20 +92,17 @@ func TestHttpsProxyTrustedProxies(t *testing.T) {
resp := sendTestDoHMessage(t, client, msg, hdrs)
requireResponse(t, msg, resp)
ip, _ := netutil.IPAndPortFromAddr(gotAddr)
require.True(t, ip.Equal(expectedClientIP))
}
t.Run("success", func(t *testing.T) {
doRequest(t, proxyIP.String())
ip, _ := netutil.IPAndPortFromAddr(gotAddr)
assert.True(t, ip.Equal(clientIP))
doRequest(t, proxyIP.String(), clientIP)
})
t.Run("not_in_trusted", func(t *testing.T) {
doRequest(t, "127.0.0.2")
ip, _ := netutil.IPAndPortFromAddr(gotAddr)
assert.True(t, ip.Equal(proxyIP))
doRequest(t, "127.0.0.2", proxyIP)
})
}
......@@ -344,10 +342,17 @@ func sendTestDoHMessage(
req.Header.Set(k, v)
}
httpResp, err := client.Do(req)
httpResp, err := client.Do(req) // nolint:bodyclose
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, httpResp.Body.Close)
require.True(
t,
httpResp.ProtoAtLeast(2, 0),
"the proto is too old: %s",
httpResp.Proto,
)
body, err := io.ReadAll(httpResp.Body)
require.NoError(t, err)
......@@ -372,6 +377,8 @@ func createTestHTTPClient(dnsProxy *Proxy, caPem []byte, http3Enabled bool) (cli
var transport http.RoundTripper
if http3Enabled {
tlsClientConfig.NextProtos = []string{"h3"}
transport = &http3.RoundTripper{
Dial: func(
ctx context.Context,
......@@ -395,10 +402,12 @@ func createTestHTTPClient(dnsProxy *Proxy, caPem []byte, http3Enabled bool) (cli
return dialer.DialContext(ctx, network, dnsProxy.Addr(ProtoHTTPS).String())
}
tlsClientConfig.NextProtos = []string{"h2", "http/1.1"}
transport = &http.Transport{
TLSClientConfig: tlsClientConfig,
DisableCompression: true,
DialContext: dialContext,
ForceAttemptHTTP2: true,
}
}
......
......@@ -36,7 +36,10 @@ func TestQuicProxy(t *testing.T) {
// Open QUIC connection.
conn, err := quic.DialAddrEarly(addr.String(), tlsConfig, nil)
require.NoError(t, err)
defer conn.CloseWithError(DoQCodeNoError, "")
defer func() {
// TODO(ameshkov): check the error here.
_ = conn.CloseWithError(DoQCodeNoError, "")
}()
// Send several test messages.
for i := 0; i < 10; i++ {
......
......@@ -3,6 +3,7 @@ package proxy
import (
"fmt"
"io"
"sort"
"strings"
"github.com/AdguardTeam/dnsproxy/upstream"
......@@ -224,8 +225,23 @@ func (uc *UpstreamConfig) getUpstreamsForDomain(host string) (ups []upstream.Ups
// Close implements the io.Closer interface for *UpstreamConfig.
func (uc *UpstreamConfig) Close() (err error) {
closeErrs := []error{}
closeAll(uc.Upstreams, &closeErrs)
closeErrs := closeAll(nil, uc.Upstreams...)
for _, specUps := range []map[string][]upstream.Upstream{
uc.DomainReservedUpstreams,
uc.SpecifiedDomainUpstreams,
} {
domains := make([]string, 0, len(specUps))
for domain := range specUps {
domains = append(domains, domain)
}
// TODO(e.burkov): Use functions from golang.org/x/exp.
sort.Stable(sort.StringSlice(domains))
for _, domain := range domains {
closeErrs = closeAll(closeErrs, specUps[domain]...)
}
}
if len(closeErrs) > 0 {
return errors.List("failed to close some upstreams", closeErrs...)
......
......@@ -28,6 +28,7 @@ var compatProtoDQ = []string{NextProtoDQ, "doq-i00", "dq", "doq-i02"}
// RootCAs is the CertPool that must be used by all upstreams. Redefining
// RootCAs makes sense on iOS to overcome the 15MB memory limit of the
// NEPacketTunnelProvider.
// TODO(ameshkov): remove this and replace with an upstream option.
var RootCAs *x509.CertPool
// CipherSuites is a custom list of TLSv1.2 ciphers.
......
......@@ -17,6 +17,7 @@ import (
"github.com/AdguardTeam/golibs/netutil"
"github.com/ameshkov/dnscrypt/v2"
"github.com/ameshkov/dnsstamps"
"github.com/lucas-clemente/quic-go/logging"
"github.com/miekg/dns"
)
......@@ -66,6 +67,10 @@ type Options struct {
// will be passed to. It's called in dnsCrypt.exchangeDNSCrypt.
// Upstream.Exchange method returns any error caused by it.
VerifyDNSCryptCertificate func(cert *dnscrypt.Cert) error
// QUICTracer is an optional object that allows tracing every QUIC
// connection and logging every packet that goes through.
QUICTracer logging.Tracer
}
// Clone copies o to a new struct. Note, that this is not a deep clone.
......
......@@ -49,8 +49,15 @@ func TestDNSCryptTruncated(t *testing.T) {
testutil.CleanupAndRequireSuccess(t, udpConn.Close)
// Start the server
go s.ServeUDP(udpConn)
go s.ServeTCP(tcpConn)
go func() {
// TODO(ameshkov): check the error here.
_ = s.ServeUDP(udpConn)
}()
go func() {
// TODO(ameshkov): check the error here.
_ = s.ServeTCP(tcpConn)
}()
// Now prepare a client for this test server
stamp, err := rc.CreateStamp(udpConn.LocalAddr().String())
......
......@@ -76,6 +76,7 @@ func newDoH(uu *url.URL, opts *Options) (u Upstream, err error) {
quicConfig: &quic.Config{
KeepAlivePeriod: QUICKeepAlivePeriod,
TokenStore: newQUICTokenStore(),
Tracer: opts.QUICTracer,
},
}
......@@ -156,8 +157,8 @@ func (p *dnsOverHTTPS) Close() (err error) {
// this point it should only be done for HTTP/3 as it may leak due to keep-alive
// connections.
func (p *dnsOverHTTPS) closeClient(client *http.Client) (err error) {
if t, ok := client.Transport.(*http3.RoundTripper); ok {
return t.Close()
if isHTTP3(client) {
return client.Transport.(io.Closer).Close()
}
return nil
......@@ -186,7 +187,7 @@ func (p *dnsOverHTTPS) exchangeHTTPSClient(
// It appears, that GET requests are more memory-efficient with Golang
// implementation of HTTP/2.
method := http.MethodGet
if _, ok := client.Transport.(*http3.RoundTripper); ok {
if isHTTP3(client) {
// If we're using HTTP/3, use http3.MethodGet0RTT to force using 0-RTT.
method = http3.MethodGet0RTT
}
......@@ -414,6 +415,53 @@ func (p *dnsOverHTTPS) createTransport() (t http.RoundTripper, err error) {
return transport, nil
}
// http3Transport is a wrapper over *http3.RoundTripper that tries to optimize
// its behavior.  The main thing that it does is trying to force use a single
// connection to a host instead of creating a new one all the time.  It also
// helps mitigate race issues with quic-go.
type http3Transport struct {
	// baseTransport is the underlying HTTP/3 round-tripper that actually
	// performs the requests.
	baseTransport *http3.RoundTripper

	// closed is set once Close has been called; RoundTrip checks it so that
	// no requests are sent over a closed transport.
	closed bool

	// mu protects closed and serializes Close against in-flight RoundTrip
	// calls.
	mu sync.RWMutex
}

// type check
var _ http.RoundTripper = (*http3Transport)(nil)
// RoundTrip implements the http.RoundTripper interface for *http3Transport.
// It prefers an already established QUIC connection to the target host and
// only lets the base transport dial a new one when no cached connection is
// available.
func (h *http3Transport) RoundTrip(req *http.Request) (resp *http.Response, err error) {
	h.mu.RLock()
	defer h.mu.RUnlock()

	if h.closed {
		return nil, net.ErrClosed
	}

	// First, attempt the request strictly over a cached connection.
	opt := http3.RoundTripOpt{OnlyCachedConn: true}
	resp, err = h.baseTransport.RoundTripOpt(req, opt)
	if !errors.Is(err, http3.ErrNoCachedConn) {
		return resp, err
	}

	// There was no cached connection, so trigger creating a new one.
	return h.baseTransport.RoundTrip(req)
}
// type check
var _ io.Closer = (*http3Transport)(nil)

// Close implements the io.Closer interface for *http3Transport.  After it
// returns, RoundTrip fails with net.ErrClosed.
func (h *http3Transport) Close() (err error) {
	h.mu.Lock()
	defer h.mu.Unlock()

	// Mark the transport closed first so that no new requests are accepted,
	// then shut down the underlying HTTP/3 round-tripper.
	h.closed = true

	return h.baseTransport.Close()
}
// createTransportH3 tries to create an HTTP/3 transport for this upstream.
// We should be able to fall back to H1/H2 in case if HTTP/3 is unavailable or
// if it is too slow. In order to do that, this method will run two probes
......@@ -450,7 +498,7 @@ func (p *dnsOverHTTPS) createTransportH3(
QuicConfig: p.getQUICConfig(),
}
return rt, nil
return &http3Transport{baseTransport: rt}, nil
}
// probeH3 runs a test to check whether QUIC is faster than TLS for this
......@@ -599,3 +647,10 @@ func (p *dnsOverHTTPS) supportedHTTPVersions() (v []HTTPVersion) {
return v
}
// isHTTP3 reports whether client is an HTTP/3 client, i.e. whether its
// transport is the *http3Transport wrapper.
func isHTTP3(client *http.Client) (ok bool) {
	switch client.Transport.(type) {
	case *http3Transport:
		return true
	default:
		return false
	}
}
......@@ -3,10 +3,12 @@ package upstream
import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/base64"
"fmt"
"net"
"net/http"
"os"
"strconv"
"sync/atomic"
"testing"
......@@ -108,8 +110,10 @@ func TestUpstreamDoH(t *testing.T) {
}
func TestUpstreamDoH_raceReconnect(t *testing.T) {
// TODO(ameshkov): report or fix races in quic-go and enable this back.
t.Skip("Disable this test temporarily until races are fixed in quic-go")
// TODO(ameshkov): fix other races before removing this.
if os.Getenv("CI") == "1" {
t.Skip("Skipping this test on CI until all races are fixed")
}
testCases := []struct {
name string
......@@ -233,6 +237,7 @@ func TestUpstreamDoH_serverRestart(t *testing.T) {
_, portStr, err := net.SplitHostPort(srv.addr)
require.NoError(t, err)
port, err := strconv.Atoi(portStr)
require.NoError(t, err)
// Shutdown the first server.
srv.Shutdown()
......@@ -258,6 +263,7 @@ func TestUpstreamDoH_serverRestart(t *testing.T) {
http3Enabled: true,
port: port,
})
defer srv.Shutdown()
// Check that everything works after the second restart.
checkUpstream(t, u, address)
......@@ -265,6 +271,58 @@ func TestUpstreamDoH_serverRestart(t *testing.T) {
}
}
// TestUpstreamDoH_0RTT checks that the DoH3 upstream performs a 0-RTT
// handshake when it reconnects to a server it has already talked to.  The
// QUIC tracer records every connection so that the test can inspect them.
func TestUpstreamDoH_0RTT(t *testing.T) {
	// Run the first server instance.
	srv := startDoHServer(t, testDoHServerOptions{
		http3Enabled: true,
	})
	t.Cleanup(srv.Shutdown)

	// Create a DNS-over-HTTPS upstream with the tracer attached so that all
	// QUIC connections it makes are recorded.
	tracer := &quicTracer{}
	address := fmt.Sprintf("h3://%s/dns-query", srv.addr)
	u, err := AddressToUpstream(address, &Options{
		InsecureSkipVerify: true,
		QUICTracer:         tracer,
	})
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, u.Close)

	uh := u.(*dnsOverHTTPS)

	req := createTestMessage()

	// Trigger connection to a DoH3 server.
	resp, err := uh.Exchange(req)
	require.NoError(t, err)
	requireResponse(t, req, resp)

	// Close the active connection to make sure we'll reconnect.  The lock is
	// held while swapping the client out to avoid racing with Exchange.
	func() {
		uh.clientMu.Lock()
		defer uh.clientMu.Unlock()

		err = uh.closeClient(uh.client)
		require.NoError(t, err)

		uh.client = nil
	}()

	// Trigger second connection; this one should be able to use 0-RTT since
	// session data from the first connection is cached.
	resp, err = uh.Exchange(req)
	require.NoError(t, err)
	requireResponse(t, req, resp)

	// Check traced connections info: exactly one per Exchange above.
	conns := tracer.getConnectionsInfo()
	require.Len(t, conns, 2)

	// Examine the first connection (no 0-RTT there: nothing was cached yet).
	require.False(t, conns[0].is0RTT())

	// Examine the second connection (the one that used 0-RTT).
	require.True(t, conns[1].is0RTT())
}
// testDoHServerOptions allows customizing testDoHServer behavior.
type testDoHServerOptions struct {
http3Enabled bool
......@@ -282,6 +340,9 @@ type testDoHServer struct {
// tlsConfig is the TLS configuration that is used for this server.
tlsConfig *tls.Config
// rootCAs is the pool with root certificates used by the test server.
rootCAs *x509.CertPool
// server is an HTTP/1.1 and HTTP/2 server.
server *http.Server
......@@ -311,7 +372,7 @@ func startDoHServer(
t *testing.T,
opts testDoHServerOptions,
) (s *testDoHServer) {
tlsConfig := createServerTLSConfig(t, "127.0.0.1")
tlsConfig, rootCAs := createServerTLSConfig(t, "127.0.0.1")
handler := opts.handler
if handler == nil {
handler = createDoHHandler()
......@@ -320,7 +381,8 @@ func startDoHServer(
// Step one is to create a regular HTTP server, we'll always have it
// running.
server := &http.Server{
Handler: handler,
Handler: handler,
ReadTimeout: time.Second,
}
// Listen TCP first.
......@@ -342,7 +404,10 @@ func startDoHServer(
tlsListen := tls.NewListener(tcpListen, tlsConfigH2)
// Run the H1/H2 server.
go server.Serve(tlsListen)
go func() {
// TODO(ameshkov): check the error here.
_ = server.Serve(tlsListen)
}()
// Get the real address that the listener now listens to.
tcpAddr = tcpListen.Addr().(*net.TCPAddr)
......@@ -373,11 +438,15 @@ func startDoHServer(
require.NoError(t, err)
// Run the H3 server.
go serverH3.ServeListener(listenerH3)
go func() {
// TODO(ameshkov): check the error here.
_ = serverH3.ServeListener(listenerH3)
}()
}
return &testDoHServer{
tlsConfig: tlsConfig,
rootCAs: rootCAs,
server: server,
serverH3: serverH3,
listenerH3: listenerH3,
......@@ -425,7 +494,11 @@ func createDoHHandlerFunc() (f http.HandlerFunc) {
}
w.Header().Set("Content-Type", "application/dns-message")
_, err = w.Write(buf)
if err != nil {
panic(fmt.Errorf("unexpected error on writing response: %w", err))
}
}
}
......
......@@ -2,23 +2,105 @@ package upstream
import (
"crypto/tls"
"crypto/x509"
"fmt"
"io"
"net"
"sync"
"testing"
"time"
"github.com/AdguardTeam/golibs/testutil"
"github.com/miekg/dns"
"github.com/stretchr/testify/require"
)
// TODO(ameshkov): make it not depend on external servers.
func TestTLSPoolReconnect(t *testing.T) {
func TestUpstream_dnsOverTLS(t *testing.T) {
srv := startDoTServer(t, func(w dns.ResponseWriter, req *dns.Msg) {
resp := respondToTestMessage(req)
err := w.WriteMsg(resp)
pt := testutil.PanicT{}
require.NoError(pt, err)
})
testutil.CleanupAndRequireSuccess(t, srv.Close)
// Create a DoT upstream that we'll be testing.
addr := fmt.Sprintf("tls://127.0.0.1:%d", srv.port)
u, err := AddressToUpstream(addr, &Options{InsecureSkipVerify: true})
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, u.Close)
// Test that it responds properly.
for i := 0; i < 10; i++ {
checkUpstream(t, u, addr)
}
}
// TestUpstream_dnsOverTLS_race queries a single DoT upstream from several
// goroutines at once; run with -race to catch data races in the upstream.
func TestUpstream_dnsOverTLS_race(t *testing.T) {
	const parallelQueries = 10

	dotSrv := startDoTServer(t, func(w dns.ResponseWriter, req *dns.Msg) {
		pt := testutil.PanicT{}
		require.NoError(pt, w.WriteMsg(respondToTestMessage(req)))
	})
	testutil.CleanupAndRequireSuccess(t, dotSrv.Close)

	// Create the DoT upstream under test.
	upsAddr := fmt.Sprintf("tls://127.0.0.1:%d", dotSrv.port)
	ups, err := AddressToUpstream(upsAddr, &Options{InsecureSkipVerify: true})
	require.NoError(t, err)
	testutil.CleanupAndRequireSuccess(t, ups.Close)

	// Hammer the upstream from parallelQueries goroutines in parallel.
	var wg sync.WaitGroup
	wg.Add(parallelQueries)
	for i := 0; i < parallelQueries; i++ {
		go func() {
			defer wg.Done()

			pt := testutil.PanicT{}
			req := createTestMessage()
			resp, exchErr := ups.Exchange(req)
			require.NoError(pt, exchErr)
			requireResponse(pt, req, resp)
		}()
	}

	wg.Wait()
}
func TestUpstream_dnsOverTLS_poolReconnect(t *testing.T) {
srv := startDoTServer(t, func(w dns.ResponseWriter, req *dns.Msg) {
resp := respondToTestMessage(req)
err := w.WriteMsg(resp)
pt := testutil.PanicT{}
require.NoError(pt, err)
})
testutil.CleanupAndRequireSuccess(t, srv.Close)
// This var is used to store the last connection state in order to check
// if session resumption works as expected.
var lastState tls.ConnectionState
// Init the upstream to the test DoT server that also keeps track of the
// session resumptions.
addr := fmt.Sprintf("tls://127.0.0.1:%d", srv.port)
u, err := AddressToUpstream(
"tls://one.one.one.one",
addr,
&Options{
Bootstrap: []string{"8.8.8.8:53"},
Timeout: timeout,
InsecureSkipVerify: true,
VerifyConnection: func(state tls.ConnectionState) error {
lastState = state
return nil
},
},
......@@ -32,10 +114,12 @@ func TestTLSPoolReconnect(t *testing.T) {
require.NoError(t, err)
requireResponse(t, req, reply)
// Now let's close the pooled connection and return it back to the pool.
// Now let's close the pooled connection.
p := u.(*dnsOverTLS)
conn, _ := p.pool.Get()
conn.Close()
// And return it back to the pool.
p.pool.Put(conn)
// Send the second test message.
......@@ -51,17 +135,21 @@ func TestTLSPoolReconnect(t *testing.T) {
require.True(t, lastState.DidResume)
}
// TODO(ameshkov): make it not depend on external servers.
func TestTLSPoolDeadLine(t *testing.T) {
u, err := AddressToUpstream(
"tls://one.one.one.one",
&Options{
Bootstrap: []string{"8.8.8.8:53"},
Timeout: timeout,
},
)
func TestUpstream_dnsOverTLS_poolDeadline(t *testing.T) {
srv := startDoTServer(t, func(w dns.ResponseWriter, req *dns.Msg) {
resp := respondToTestMessage(req)
err := w.WriteMsg(resp)
pt := testutil.PanicT{}
require.NoError(pt, err)
})
testutil.CleanupAndRequireSuccess(t, srv.Close)
// Create a DoT upstream that we'll be testing.
addr := fmt.Sprintf("tls://127.0.0.1:%d", srv.port)
u, err := AddressToUpstream(addr, &Options{InsecureSkipVerify: true})
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, u.Close)
// Send the first test message.
req := createTestMessage()
......@@ -71,7 +159,7 @@ func TestTLSPoolDeadLine(t *testing.T) {
p := u.(*dnsOverTLS)
// Now let's get connection from the pool and use it.
// Now let's get connection from the pool and use it again.
conn, err := p.pool.Get()
require.NoError(t, err)
......@@ -79,9 +167,11 @@ func TestTLSPoolDeadLine(t *testing.T) {
require.NoError(t, err)
requireResponse(t, req, response)
// Update connection's deadLine and put it back to the pool.
// Update the connection's deadLine.
err = conn.SetDeadline(time.Now().Add(10 * time.Hour))
require.NoError(t, err)
// And put it back to the pool.
p.pool.Put(conn)
// Get connection from the pool and reuse it.
......@@ -99,4 +189,58 @@ func TestTLSPoolDeadLine(t *testing.T) {
// Connection with expired deadLine can't be used.
response, err = p.exchangeConn(conn, req)
require.Error(t, err)
require.Nil(t, response)
}
// testDoTServer is a test DNS-over-TLS server that can be used in unit-tests.
type testDoTServer struct {
	// srv is the *dns.Server instance that listens for DoT requests.
	srv *dns.Server

	// tlsConfig is the TLS configuration that is used for this server.
	tlsConfig *tls.Config

	// rootCAs is the pool with root certificates used by the test server.
	rootCAs *x509.CertPool

	// port is the TCP port the server listens on.
	port int
}

// type check
var _ io.Closer = (*testDoTServer)(nil)
// startDoTServer starts *testDoTServer on a random port.  handler is invoked
// for every DNS query received over TLS.  The caller is responsible for
// closing the returned server.
func startDoTServer(t *testing.T, handler dns.HandlerFunc) (s *testDoTServer) {
	t.Helper()

	// Listen on an OS-assigned port so that parallel tests don't collide.
	tcpListener, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)

	// Wrap the plain TCP listener with TLS using a fresh self-signed config.
	tlsConfig, rootCAs := createServerTLSConfig(t, "127.0.0.1")
	tlsListener := tls.NewListener(tcpListener, tlsConfig)

	srv := &dns.Server{
		Listener:  tlsListener,
		TLSConfig: tlsConfig,
		Net:       "tls",
		Handler:   handler,
	}

	// Serve in the background; ActivateAndServe returns once the server is
	// shut down via Close.
	go func() {
		pt := testutil.PanicT{}
		require.NoError(pt, srv.ActivateAndServe())
	}()

	return &testDoTServer{
		srv:       srv,
		tlsConfig: tlsConfig,
		rootCAs:   rootCAs,
		// Extract the actual port the listener was bound to.
		port: tcpListener.Addr().(*net.TCPAddr).Port,
	}
}
// Close implements the io.Closer interface for *testDoTServer.  It shuts the
// wrapped DNS server down, which also makes ActivateAndServe return.
func (s *testDoTServer) Close() error {
	return s.srv.Shutdown()
}
package upstream
import (
"fmt"
"io"
"net"
"testing"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/testutil"
"github.com/miekg/dns"
"github.com/stretchr/testify/require"
)
// TODO(ameshkov): make this test not depend on external resources.
func TestDNSTruncated(t *testing.T) {
// AdGuard DNS
address := "94.140.14.14:53"
func TestUpstream_plainDNS(t *testing.T) {
srv := startDNSServer(t, func(w dns.ResponseWriter, req *dns.Msg) {
resp := respondToTestMessage(req)
u, err := AddressToUpstream(address, &Options{Timeout: timeout})
err := w.WriteMsg(resp)
pt := testutil.PanicT{}
require.NoError(pt, err)
})
testutil.CleanupAndRequireSuccess(t, srv.Close)
addr := fmt.Sprintf("127.0.0.1:%d", srv.port)
u, err := AddressToUpstream(addr, &Options{})
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, u.Close)
req := new(dns.Msg)
req.SetQuestion("unit-test2.dns.adguard.com.", dns.TypeTXT)
req.RecursionDesired = true
for i := 0; i < 10; i++ {
checkUpstream(t, u, addr)
}
}
func TestUpstream_plainDNS_truncatedResponse(t *testing.T) {
srv := startDNSServer(t, func(w dns.ResponseWriter, req *dns.Msg) {
resp := respondToTestMessage(req)
if w.LocalAddr().Network() == "udp" {
// Make sure the response is truncated.
resp.Truncated = true
resp.Answer = nil
}
err := w.WriteMsg(resp)
pt := testutil.PanicT{}
require.NoError(pt, err)
})
testutil.CleanupAndRequireSuccess(t, srv.Close)
addr := fmt.Sprintf("127.0.0.1:%d", srv.port)
u, err := AddressToUpstream(addr, &Options{})
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, u.Close)
// The plain DNS upstream must know how to fall back to TCP so even though
// the response over UDP is truncated, it should re-request it over TCP and
// get the full response.
req := createTestMessage()
resp, err := u.Exchange(req)
require.NoError(t, err)
requireResponse(t, req, resp)
}
// testDNSServer is a simple DNS server that can be used in unit-tests.  It
// serves the same handler over both UDP and TCP on the same port.
type testDNSServer struct {
	// udpSrv handles queries arriving over UDP.
	udpSrv *dns.Server

	// tcpSrv handles queries arriving over TCP.
	tcpSrv *dns.Server

	// port is the port number shared by both listeners.
	port int

	// udpListener is the packet connection udpSrv serves on.
	udpListener net.PacketConn

	// tcpListener is the stream listener tcpSrv serves on.
	tcpListener net.Listener
}

// type check
var _ io.Closer = (*testDNSServer)(nil)
// startDNSServer a test DNS server.
func startDNSServer(t *testing.T, handler dns.HandlerFunc) (s *testDNSServer) {
t.Helper()
s = &testDNSServer{}
res, err := u.Exchange(req)
udpListener, err := net.ListenPacket("udp", "127.0.0.1:0")
require.NoError(t, err)
require.False(t, res.Truncated)
s.port = udpListener.LocalAddr().(*net.UDPAddr).Port
s.udpListener = udpListener
s.tcpListener, err = net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", s.port))
require.NoError(t, err)
s.udpSrv = &dns.Server{
PacketConn: s.udpListener,
Handler: handler,
}
s.tcpSrv = &dns.Server{
Listener: s.tcpListener,
Handler: handler,
}
go func() {
pt := testutil.PanicT{}
require.NoError(pt, s.udpSrv.ActivateAndServe())
}()
go func() {
pt := testutil.PanicT{}
require.NoError(pt, s.tcpSrv.ActivateAndServe())
}()
return s
}
// Close implements the io.Closer interface for *testDNSServer.
func (s *testDNSServer) Close() (err error) {
udpErr := s.udpSrv.Shutdown()
tcpErr := s.tcpSrv.Shutdown()
return errors.WithDeferred(udpErr, tcpErr)
}
......@@ -77,6 +77,7 @@ func newDoQ(uu *url.URL, opts *Options) (u Upstream, err error) {
quicConfig: &quic.Config{
KeepAlivePeriod: QUICKeepAlivePeriod,
TokenStore: newQUICTokenStore(),
Tracer: opts.QUICTracer,
},
}
......