Skip to content
Commits on Source (1)
  • Andrey Meshkov's avatar
    Pull request: proxy: return ServeHTTP to the Proxy instance · 1250a0b7
    Andrey Meshkov authored
    Merge in DNS/dnsproxy from fix-serve-http to master
    
    Squashed commit of the following:
    
    commit 901fe43e51c1854f43bcb5034e488e2f560a3d67
    Author: Andrey Meshkov <am@adguard.com>
    Date:   Fri Sep 30 16:02:30 2022 +0300
    
        stabilize race test
    
    commit 302b6609658b51bb3f7b45ec1cc79fde41486032
    Author: Andrey Meshkov <am@adguard.com>
    Date:   Fri Sep 30 15:48:00 2022 +0300
    
        improve tests
    
    commit b24c0a05dd9c2f85a94a9c3b21e41a39f28d2a21
    Author: Andrey Meshkov <am@adguard.com>
    Date:   Fri Sep 30 15:41:27 2022 +0300
    
        proxy: return ServeHTTP to the Proxy instance
    1250a0b7
......@@ -56,20 +56,14 @@ func (p *Proxy) listenH3(addr *net.UDPAddr) (err error) {
// createHTTPSListeners creates TCP/UDP listeners and HTTP/H3 servers.
func (p *Proxy) createHTTPSListeners() (err error) {
p.httpsServer = &http.Server{
Handler: &proxyHTTPHandler{
proxy: p,
h3: false,
},
Handler: p,
ReadHeaderTimeout: defaultTimeout,
WriteTimeout: defaultTimeout,
}
if p.HTTP3 {
p.h3Server = &http3.Server{
Handler: &proxyHTTPHandler{
proxy: p,
h3: true,
},
Handler: p,
}
}
......@@ -95,16 +89,6 @@ func (p *Proxy) createHTTPSListeners() (err error) {
return nil
}
// proxyHTTPHandler implements http.Handler and processes DoH queries.
type proxyHTTPHandler struct {
// h3 is true if this is an HTTP/3 requests handler.
h3 bool
proxy *Proxy
}
// type check
var _ http.Handler = &proxyHTTPHandler{}
// ServeHTTP is the http.Handler implementation that handles DoH queries.
// Here is what it returns:
//
......@@ -112,7 +96,7 @@ var _ http.Handler = &proxyHTTPHandler{}
// - http.StatusUnsupportedMediaType if request content type is not
// "application/dns-message";
// - http.StatusMethodNotAllowed if request method is not GET or POST.
func (h *proxyHTTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
log.Tracef("Incoming HTTPS request on %s", r.URL)
var buf []byte
......@@ -155,12 +139,12 @@ func (h *proxyHTTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
addr, prx, err := remoteAddr(r, h.h3)
addr, prx, err := remoteAddr(r)
if err != nil {
log.Debug("warning: getting real ip: %s", err)
}
d := h.proxy.newDNSContext(ProtoHTTPS, req)
d := p.newDNSContext(ProtoHTTPS, req)
d.Addr = addr
d.HTTPRequest = r
d.HTTPResponseWriter = w
......@@ -168,13 +152,13 @@ func (h *proxyHTTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if prx != nil {
ip, _ := netutil.IPAndPortFromAddr(prx)
log.Debug("request came from proxy server %s", prx)
if !h.proxy.proxyVerifier.Contains(ip) {
if !p.proxyVerifier.Contains(ip) {
log.Debug("proxy %s is not trusted, using original remote addr", ip)
d.Addr = prx
}
}
err = h.proxy.handleDNSRequest(d)
err = p.handleDNSRequest(d)
if err != nil {
log.Tracef("error handling DNS (%s) request: %s", d.Proto, err)
}
......@@ -239,7 +223,7 @@ func realIPFromHdrs(r *http.Request) (realIP net.IP) {
// remoteAddr returns the real client's address and the IP address of the latest
// proxy server if any.
func remoteAddr(r *http.Request, h3 bool) (addr, prx net.Addr, err error) {
func remoteAddr(r *http.Request) (addr, prx net.Addr, err error) {
var hostStr, portStr string
if hostStr, portStr, err = net.SplitHostPort(r.RemoteAddr); err != nil {
return nil, nil, err
......@@ -255,6 +239,8 @@ func remoteAddr(r *http.Request, h3 bool) (addr, prx net.Addr, err error) {
return nil, nil, fmt.Errorf("invalid ip: %s", hostStr)
}
h3 := r.Context().Value(http3.ServerContextKey) != nil
if realIP := realIPFromHdrs(r); realIP != nil {
log.Tracef("Using IP address from HTTP request: %s", realIP)
......@@ -271,5 +257,9 @@ func remoteAddr(r *http.Request, h3 bool) (addr, prx net.Addr, err error) {
return addr, prx, nil
}
if h3 {
return &net.UDPAddr{IP: host, Port: port}, nil, nil
}
return &net.TCPAddr{IP: host, Port: port}, nil, nil
}
......@@ -292,7 +292,7 @@ func TestRemoteAddr(t *testing.T) {
}
t.Run(tc.name, func(t *testing.T) {
addr, prx, err := remoteAddr(r, false)
addr, prx, err := remoteAddr(r)
if tc.wantErr != "" {
assert.Equal(t, tc.wantErr, err.Error())
......
......@@ -103,13 +103,51 @@ func TestUpstreamDoH(t *testing.T) {
require.True(t, lastState.DidResume)
})
}
}
func TestUpstreamDoH_raceReconnect(t *testing.T) {
testCases := []struct {
name string
http3Enabled bool
httpVersions []HTTPVersion
delayHandshakeH3 time.Duration
delayHandshakeH2 time.Duration
expectedProtocol HTTPVersion
}{{
name: "http1.1_h2",
http3Enabled: false,
httpVersions: []HTTPVersion{HTTPVersion11, HTTPVersion2},
expectedProtocol: HTTPVersion2,
}, {
name: "fallback_to_http2",
http3Enabled: false,
httpVersions: []HTTPVersion{HTTPVersion3, HTTPVersion2},
expectedProtocol: HTTPVersion2,
}, {
name: "http3",
http3Enabled: true,
httpVersions: []HTTPVersion{HTTPVersion3},
expectedProtocol: HTTPVersion3,
}, {
name: "race_http3_faster",
http3Enabled: true,
httpVersions: []HTTPVersion{HTTPVersion3, HTTPVersion2},
delayHandshakeH2: time.Second,
expectedProtocol: HTTPVersion3,
}, {
name: "race_http2_faster",
http3Enabled: true,
httpVersions: []HTTPVersion{HTTPVersion3, HTTPVersion2},
delayHandshakeH3: time.Second,
expectedProtocol: HTTPVersion2,
}}
// This is a different set of tests that are supposed to be run with -race.
// The difference is that the HTTP handler here adds additional time.Sleep
// call. This call would trigger the HTTP client re-connection which is
// important to test for race conditions.
for _, tc := range testCases {
t.Run(tc.name+"_race", func(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
const timeout = time.Millisecond * 100
var requestsCount int32
......@@ -142,7 +180,7 @@ func TestUpstreamDoH(t *testing.T) {
u, err := AddressToUpstream(address, opts)
require.NoError(t, err)
checkRaceCondition(t, u, address)
checkRaceCondition(u)
})
}
}
......
......@@ -63,7 +63,7 @@ func TestUpstreamDoQ(t *testing.T) {
u, err = AddressToUpstream(address, opts)
require.NoError(t, err)
checkRaceCondition(t, u, address)
checkRaceCondition(u)
}
func TestUpstreamDoQ_serverRestart(t *testing.T) {
......
......@@ -490,7 +490,7 @@ func checkUpstream(t *testing.T, u Upstream, addr string) {
// checkRaceCondition runs several goroutines in parallel and each of them calls
// checkUpstream several times.
func checkRaceCondition(t *testing.T, u Upstream, addr string) {
func checkRaceCondition(u Upstream) {
wg := sync.WaitGroup{}
// The number of requests to run in every goroutine.
......@@ -501,7 +501,9 @@ func checkRaceCondition(t *testing.T, u Upstream, addr string) {
makeRequests := func() {
defer wg.Done()
for i := 0; i < reqCount; i++ {
checkUpstream(t, u, addr)
req := createTestMessage()
// Ignore exchange errors here, the point is to check for races.
_, _ = u.Exchange(req)
}
}
......