Skip to content
Commits on Source (2)
  • Andrey Meshkov's avatar
    Pull request: upstream: enable 0-RTT for DoH3, fix #273 · 29e83b15
    Andrey Meshkov authored
    Merge in DNS/dnsproxy from doh3-0rtt to master
    
    Squashed commit of the following:
    
    commit fa70069464535bd712c7872b9dcc90068efe7e11
    Author: Andrey Meshkov <am@adguard.com>
    Date:   Thu Sep 22 13:07:06 2022 +0300
    
        fix review comments
    
    commit 7c6470b4c35d1ab30b7be44995156885cdb98b93
    Author: Andrey Meshkov <am@adguard.com>
    Date:   Thu Sep 22 10:51:19 2022 +0300
    
        upstream: enable 0-RTT for DoH3, fix #273
    29e83b15
  • Andrey Meshkov's avatar
    Pull request: Fixed a race condition in the DoH upstream · bde9a138
    Andrey Meshkov authored
    Merge in DNS/dnsproxy from fix-279-doh-race to master
    
    Squashed commit of the following:
    
    commit b036e970fbd8c80d1ed73ddbb6f5b07b880a0107
    Author: Andrey Meshkov <am@adguard.com>
    Date:   Thu Sep 29 15:18:33 2022 +0300
    
        better comment for getQUICConfig
    
    commit 59f22fa37a0ded8f9996717b214a4a27a59a6750
    Author: Andrey Meshkov <am@adguard.com>
    Date:   Thu Sep 29 15:13:30 2022 +0300
    
        fix build with go1.18
    
    commit 8bc66ad87db06314981ed3eae2d7703a8370d74b
    Author: Andrey Meshkov <am@adguard.com>
    Date:   Thu Sep 29 15:02:08 2022 +0300
    
        Add a test that specifically covers #279, fixed one more race
    
    commit de0dfe9d9d707dbc1301bf12850e753bb226b76e
    Author: Andrey Meshkov <am@adguard.com>
    Date:   Thu Sep 29 10:18:56 2022 +0300
    
        Fixed a race condition in the DoH upstream
    
        Also, added a test for races in DoH and DoQ upstreams.
        We should also add it to the other DNS implementations, but only
        after the tests there are made independent of external services.
    
        Closes #279
    bde9a138
......@@ -23,14 +23,11 @@ linters-settings:
linters:
enable:
- deadcode
- errcheck
- govet
- ineffassign
- staticcheck
- structcheck
- unused
- varcheck
- bodyclose
- depguard
- dupl
......
package proxy
import (
"bytes"
"context"
"crypto/tls"
"crypto/x509"
"encoding/base64"
"fmt"
"io"
"net"
"net/http"
"net/url"
"strings"
"testing"
......@@ -320,9 +321,20 @@ func sendTestDoHMessage(
packed, err := m.Pack()
require.NoError(t, err)
b := bytes.NewBuffer(packed)
u := fmt.Sprintf("https://%s/dns-query", tlsServerName)
req, err := http.NewRequest(http.MethodPost, u, b)
u := url.URL{
Scheme: "https",
Host: tlsServerName,
Path: "/dns-query",
RawQuery: fmt.Sprintf("dns=%s", base64.RawURLEncoding.EncodeToString(packed)),
}
method := http.MethodGet
if _, ok := client.Transport.(*http3.RoundTripper); ok {
// If we're using HTTP/3, use http3.MethodGet0RTT to force using 0-RTT.
method = http3.MethodGet0RTT
}
req, err := http.NewRequest(method, u.String(), nil)
require.NoError(t, err)
req.Header.Set("Content-Type", "application/dns-message")
......
......@@ -52,7 +52,8 @@ type dnsOverHTTPS struct {
// quicConfig is the QUIC configuration that is used if HTTP/3 is enabled
// for this upstream.
quicConfig *quic.Config
quicConfig *quic.Config
quicConfigGuard sync.Mutex
}
// type check
......@@ -113,21 +114,15 @@ func (p *dnsOverHTTPS) Exchange(m *dns.Msg) (resp *dns.Msg, err error) {
for i := 0; hasClient && p.shouldRetry(err) && i < 2; i++ {
log.Debug("re-creating the HTTP client and retrying due to %v", err)
p.clientGuard.Lock()
p.client = nil
// Re-create the token store to make sure we're not trying to use invalid
// tokens for 0-RTT.
p.quicConfig.TokenStore = newQUICTokenStore()
p.clientGuard.Unlock()
// Make sure we reset the client here.
p.resetClient(err)
resp, err = p.exchangeHTTPS(m)
}
if err != nil {
// If the request failed anyway, make sure we don't use this client.
p.clientGuard.Lock()
p.client = nil
p.clientGuard.Unlock()
p.resetClient(err)
}
return resp, err
......@@ -157,8 +152,20 @@ func (p *dnsOverHTTPS) exchangeHTTPSClient(m *dns.Msg, client *http.Client) (*dn
// It appears, that GET requests are more memory-efficient with Golang
// implementation of HTTP/2.
requestURL := p.Address() + "?dns=" + base64.RawURLEncoding.EncodeToString(buf)
req, err := http.NewRequest("GET", requestURL, nil)
method := http.MethodGet
if _, ok := client.Transport.(*http3.RoundTripper); ok {
// If we're using HTTP/3, use http3.MethodGet0RTT to force using 0-RTT.
method = http3.MethodGet0RTT
}
u := url.URL{
Scheme: p.boot.URL.Scheme,
Host: p.boot.URL.Host,
Path: p.boot.URL.Path,
RawQuery: fmt.Sprintf("dns=%s", base64.RawURLEncoding.EncodeToString(buf)),
}
req, err := http.NewRequest(method, u.String(), nil)
if err != nil {
return nil, fmt.Errorf("creating http request to %s: %w", p.boot.URL, err)
}
......@@ -228,6 +235,40 @@ func (p *dnsOverHTTPS) shouldRetry(err error) (ok bool) {
return false
}
// resetClient drops the cached *http.Client so that the next request has to
// create a new one.  err is the error that caused the reset: when it wraps
// quic.Err0RTTRejected the QUIC config is reset as well, for any other error
// (including nil) the QUIC config is left as is.
func (p *dnsOverHTTPS) resetClient(err error) {
	p.clientGuard.Lock()
	defer p.clientGuard.Unlock()

	p.client = nil

	if !errors.Is(err, quic.Err0RTTRejected) {
		// Nothing else to reset unless 0-RTT was rejected.
		return
	}

	p.resetQUICConfig()
}
// getQUICConfig reads the current QUIC config pointer while holding
// quicConfigGuard.  The returned value is shared, so callers must treat it as
// read-only and must not change any of its properties.
func (p *dnsOverHTTPS) getQUICConfig() (c *quic.Config) {
	p.quicConfigGuard.Lock()
	c = p.quicConfig
	p.quicConfigGuard.Unlock()

	return c
}
// resetQUICConfig re-creates the token store to make sure we're not trying to
// use invalid tokens for 0-RTT.  The current config is cloned and the clone,
// carrying a new token store, replaces it; the previous config value itself is
// not modified.
func (p *dnsOverHTTPS) resetQUICConfig() {
	p.quicConfigGuard.Lock()
	defer p.quicConfigGuard.Unlock()

	fresh := p.quicConfig.Clone()
	fresh.TokenStore = newQUICTokenStore()
	p.quicConfig = fresh
}
// getClient gets or lazily initializes an HTTP client (and transport) that will
// be used for this DoH resolver.
func (p *dnsOverHTTPS) getClient() (c *http.Client, err error) {
......@@ -268,6 +309,7 @@ func (p *dnsOverHTTPS) createClient() (*http.Client, error) {
}
p.client = client
return p.client, nil
}
......@@ -359,7 +401,7 @@ func (p *dnsOverHTTPS) createTransportH3(
},
DisableCompression: true,
TLSClientConfig: tlsConfig,
QuicConfig: p.quicConfig,
QuicConfig: p.getQUICConfig(),
}
return rt, nil
......@@ -445,7 +487,7 @@ func (p *dnsOverHTTPS) probeQUIC(addr string, tlsConfig *tls.Config, ch chan err
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(timeout))
defer cancel()
conn, err := quic.DialAddrEarlyContext(ctx, addr, tlsConfig, p.quicConfig)
conn, err := quic.DialAddrEarlyContext(ctx, addr, tlsConfig, p.getQUICConfig())
if err != nil {
ch <- fmt.Errorf("opening QUIC connection to %s: %w", p.Address(), err)
return
......
......@@ -8,6 +8,7 @@ import (
"net"
"net/http"
"strconv"
"sync/atomic"
"testing"
"time"
......@@ -67,24 +68,22 @@ func TestUpstreamDoH(t *testing.T) {
address := fmt.Sprintf("https://%s/dns-query", srv.addr)
var lastState tls.ConnectionState
u, err := AddressToUpstream(
address,
&Options{
InsecureSkipVerify: true,
HTTPVersions: tc.httpVersions,
VerifyConnection: func(state tls.ConnectionState) (err error) {
if state.NegotiatedProtocol != string(tc.expectedProtocol) {
return fmt.Errorf(
"expected %s, got %s",
tc.expectedProtocol,
state.NegotiatedProtocol,
)
}
lastState = state
return nil
},
opts := &Options{
InsecureSkipVerify: true,
HTTPVersions: tc.httpVersions,
VerifyConnection: func(state tls.ConnectionState) (err error) {
if state.NegotiatedProtocol != string(tc.expectedProtocol) {
return fmt.Errorf(
"expected %s, got %s",
tc.expectedProtocol,
state.NegotiatedProtocol,
)
}
lastState = state
return nil
},
)
}
u, err := AddressToUpstream(address, opts)
require.NoError(t, err)
// Test that it responds properly.
......@@ -104,6 +103,48 @@ func TestUpstreamDoH(t *testing.T) {
require.True(t, lastState.DidResume)
})
}
// This is a separate set of tests that is supposed to be run with -race.
// The difference is that the HTTP handler here adds an additional time.Sleep
// call. This call triggers a re-connection of the HTTP client, which is
// important when testing for race conditions.
for _, tc := range testCases {
t.Run(tc.name+"_race", func(t *testing.T) {
const timeout = time.Millisecond * 100
var requestsCount int32
handlerFunc := createDoHHandlerFunc()
mux := http.NewServeMux()
mux.HandleFunc("/dns-query", func(w http.ResponseWriter, r *http.Request) {
newVal := atomic.AddInt32(&requestsCount, 1)
if newVal%10 == 0 {
time.Sleep(timeout * 2)
}
handlerFunc(w, r)
})
srv := startDoHServer(t, testDoHServerOptions{
http3Enabled: tc.http3Enabled,
delayHandshakeH2: tc.delayHandshakeH2,
delayHandshakeH3: tc.delayHandshakeH3,
handler: mux,
})
t.Cleanup(srv.Shutdown)
// Create a DNS-over-HTTPS upstream that will be used for the
// race test.
address := fmt.Sprintf("https://%s/dns-query", srv.addr)
opts := &Options{
InsecureSkipVerify: true,
HTTPVersions: tc.httpVersions,
Timeout: timeout,
}
u, err := AddressToUpstream(address, opts)
require.NoError(t, err)
checkRaceCondition(t, u, address)
})
}
}
func TestUpstreamDoH_serverRestart(t *testing.T) {
......@@ -185,6 +226,7 @@ type testDoHServerOptions struct {
delayHandshakeH2 time.Duration
delayHandshakeH3 time.Duration
port int
handler http.Handler
}
// testDoHServer is an instance of a test DNS-over-HTTPS server.
......@@ -225,7 +267,10 @@ func startDoHServer(
opts testDoHServerOptions,
) (s *testDoHServer) {
tlsConfig := createServerTLSConfig(t, "127.0.0.1")
handler := createDoHHandler()
handler := opts.handler
if handler == nil {
handler = createDoHHandler()
}
// Step one is to create a regular HTTP server, we'll always have it
// running.
......@@ -296,11 +341,10 @@ func startDoHServer(
}
}
// createDoHHandler returns a very simple http.Handler that reads the incoming
// request and returns with a test message.
func createDoHHandler() (h http.Handler) {
mux := http.NewServeMux()
mux.HandleFunc("/dns-query", func(w http.ResponseWriter, r *http.Request) {
// createDoHHandlerFunc creates a simple http.HandlerFunc that reads the
// incoming DNS message and returns the test response.
func createDoHHandlerFunc() (f http.HandlerFunc) {
return func(w http.ResponseWriter, r *http.Request) {
dnsParam := r.URL.Query().Get("dns")
buf, err := base64.RawURLEncoding.DecodeString(dnsParam)
if err != nil {
......@@ -337,7 +381,14 @@ func createDoHHandler() (h http.Handler) {
w.Header().Set("Content-Type", "application/dns-message")
_, err = w.Write(buf)
})
}
}
// createDoHHandler builds the default handler for the test DoH server: a
// request multiplexer that serves "/dns-query" with the handler function
// returned by createDoHHandlerFunc.
func createDoHHandler() (h http.Handler) {
	router := http.NewServeMux()
	router.Handle("/dns-query", createDoHHandlerFunc())

	return router
}
......@@ -39,16 +39,19 @@ type dnsOverQUIC struct {
// boot is a bootstrap DNS abstraction that is used to resolve the upstream
// server's address and open a network connection to it.
boot *bootstrapper
// quicConfig is the QUIC configuration that is used for establishing
// connections to the upstream. This configuration includes the TokenStore
// that needs to be stored for the lifetime of dnsOverQUIC since we can
// re-create the connection.
quicConfig *quic.Config
quicConfig *quic.Config
quicConfigGuard sync.Mutex
// conn is the current active QUIC connection. It can be closed and
// re-opened when needed.
conn quic.Connection
// connGuard protects conn and quicConfig.
conn quic.Connection
connGuard sync.RWMutex
// bytesPool is a *sync.Pool we use to store byte buffers in. These byte
// buffers are used to read responses from the upstream.
bytesPool *sync.Pool
......@@ -110,7 +113,7 @@ func (p *dnsOverQUIC) Exchange(m *dns.Msg) (resp *dns.Msg, err error) {
log.Debug("re-creating the QUIC connection and retrying due to %v", err)
// Close the active connection to make sure we'll try to re-connect.
p.closeConnWithError(QUICCodeNoError)
p.closeConnWithError(err)
// Retry sending the request.
resp, err = p.exchangeQUIC(m)
......@@ -119,7 +122,7 @@ func (p *dnsOverQUIC) Exchange(m *dns.Msg) (resp *dns.Msg, err error) {
if err != nil {
// If we're unable to exchange messages, make sure the connection is
// closed and signal about an internal error.
p.closeConnWithError(QUICCodeInternalError)
p.closeConnWithError(err)
}
return resp, err
......@@ -224,6 +227,25 @@ func (p *dnsOverQUIC) hasConnection() (ok bool) {
return p.conn != nil
}
// getQUICConfig reads the current QUIC config pointer while holding
// quicConfigGuard.  The returned value is shared, so callers must treat it as
// read-only and must not change any of its properties.
func (p *dnsOverQUIC) getQUICConfig() (c *quic.Config) {
	p.quicConfigGuard.Lock()
	c = p.quicConfig
	p.quicConfigGuard.Unlock()

	return c
}
// resetQUICConfig swaps the QUIC config for a clone that carries a brand-new
// token store, as we may need to use a new one if we failed to connect.  The
// previous config value itself is not modified.
func (p *dnsOverQUIC) resetQUICConfig() {
	p.quicConfigGuard.Lock()
	defer p.quicConfigGuard.Unlock()

	fresh := p.quicConfig.Clone()
	fresh.TokenStore = newQUICTokenStore()
	p.quicConfig = fresh
}
// openStream opens a new QUIC stream for the specified connection.
func (p *dnsOverQUIC) openStream(conn quic.Connection) (quic.Stream, error) {
ctx, cancel := p.boot.newContext()
......@@ -271,7 +293,7 @@ func (p *dnsOverQUIC) openConnection() (conn quic.Connection, err error) {
ctx, cancel := p.boot.newContext()
defer cancel()
conn, err = quic.DialAddrEarlyContext(ctx, addr, tlsConfig, p.quicConfig)
conn, err = quic.DialAddrEarlyContext(ctx, addr, tlsConfig, p.getQUICConfig())
if err != nil {
return nil, fmt.Errorf("opening quic connection to %s: %w", p.Address(), err)
}
......@@ -280,9 +302,9 @@ func (p *dnsOverQUIC) openConnection() (conn quic.Connection, err error) {
}
// closeConnWithError closes the active connection with error to make sure that
// new queries were processed in another connection. We can do that in the case
// new queries were processed in another connection. We can do that in the case
// of a fatal error.
func (p *dnsOverQUIC) closeConnWithError(code quic.ApplicationErrorCode) {
func (p *dnsOverQUIC) closeConnWithError(err error) {
p.connGuard.Lock()
defer p.connGuard.Unlock()
......@@ -291,15 +313,21 @@ func (p *dnsOverQUIC) closeConnWithError(code quic.ApplicationErrorCode) {
return
}
err := p.conn.CloseWithError(code, "")
code := QUICCodeNoError
if err != nil {
code = QUICCodeInternalError
}
if errors.Is(err, quic.Err0RTTRejected) {
// Reset the TokenStore only if 0-RTT was rejected.
p.resetQUICConfig()
}
err = p.conn.CloseWithError(code, "")
if err != nil {
log.Error("failed to close the conn: %v", err)
}
p.conn = nil
// Re-create the token store to make sure we're not trying to use invalid
// tokens for 0-RTT.
p.quicConfig.TokenStore = newQUICTokenStore()
}
// readMsg reads the incoming DNS message from the QUIC stream.
......
......@@ -24,16 +24,14 @@ func TestUpstreamDoQ(t *testing.T) {
address := fmt.Sprintf("quic://%s", srv.addr)
var lastState tls.ConnectionState
u, err := AddressToUpstream(
address,
&Options{
InsecureSkipVerify: true,
VerifyConnection: func(state tls.ConnectionState) error {
lastState = state
return nil
},
opts := &Options{
InsecureSkipVerify: true,
VerifyConnection: func(state tls.ConnectionState) error {
lastState = state
return nil
},
)
}
u, err := AddressToUpstream(address, opts)
require.NoError(t, err)
uq := u.(*dnsOverQUIC)
......@@ -59,6 +57,13 @@ func TestUpstreamDoQ(t *testing.T) {
// Make sure that the session has been resumed.
require.True(t, lastState.DidResume)
// Re-create the upstream to make the test check initialization and
// check it for race conditions.
u, err = AddressToUpstream(address, opts)
require.NoError(t, err)
checkRaceCondition(t, u, address)
}
func TestUpstreamDoQ_serverRestart(t *testing.T) {
......
......@@ -14,6 +14,7 @@ import (
"net"
"net/url"
"os"
"sync"
"testing"
"time"
......@@ -476,6 +477,7 @@ func TestAddPort(t *testing.T) {
}
}
// checkUpstream sends a test message to the upstream and checks the result.
func checkUpstream(t *testing.T, u Upstream, addr string) {
t.Helper()
......@@ -486,6 +488,31 @@ func checkUpstream(t *testing.T, u Upstream, addr string) {
requireResponse(t, req, reply)
}
// checkRaceCondition exercises u concurrently: it starts several goroutines,
// each of which calls checkUpstream a fixed number of times, and blocks until
// all of them have finished.
//
// NOTE(review): checkUpstream runs on the spawned goroutines here.  If it
// reports failures via t.FailNow (for example through require), that happens
// outside the goroutine running the test, which the testing package
// documentation does not allow; confirm and consider collecting errors
// instead.
func checkRaceCondition(t *testing.T, u Upstream, addr string) {
	const (
		// workers is the overall number of goroutines to run.
		workers = 3

		// perWorker is the number of requests to run in every goroutine.
		perWorker = 10
	)

	var wg sync.WaitGroup
	for w := 0; w < workers; w++ {
		wg.Add(1)

		go func() {
			defer wg.Done()

			for n := 0; n < perWorker; n++ {
				checkUpstream(t, u, addr)
			}
		}()
	}

	wg.Wait()
}
// createTestMessage returns the message produced by createHostTestMessage for
// the default test hostname.
func createTestMessage() (m *dns.Msg) {
	const testHost = "google-public-dns-a.google.com"

	return createHostTestMessage(testHost)
}
......