Skip to content
Commits on Source (5)
......@@ -6,69 +6,72 @@ import (
"testing"
"github.com/miekg/dns"
"github.com/stretchr/testify/require"
)
const ipv4OnlyHost = "ipv4only.arpa"
// Valid NAT-64 prefix for 2001:67c:27e4:15::64 server
var testNAT64Prefix = []byte{32, 1, 6, 124, 39, 228, 16, 100, 0, 0, 0, 0} //nolint
// Valid NAT-64 prefix for 2001:67c:27e4:15::64 server.
var testNAT64Prefix = []byte{32, 1, 6, 124, 39, 228, 16, 100, 0, 0, 0, 0}
func TestProxyWithDNS64(t *testing.T) {
// Create test proxy and manually set NAT64 prefix
// Create test proxy and manually set NAT64 prefix.
dnsProxy := createTestProxy(t, nil)
dnsProxy.SetNAT64Prefix(testNAT64Prefix)
err := dnsProxy.Start()
if err != nil {
t.Fatalf("Failed to start dns proxy")
}
require.NoError(t, err)
// Let's create test A request to ipv4OnlyHost and exchange it with test proxy
// Let's create test A request to ipv4OnlyHost and exchange it with the
// test proxy.
req := createHostTestMessage(ipv4OnlyHost)
resp, _, err := dnsProxy.exchange(req, dnsProxy.UpstreamConfig.Upstreams)
if err != nil {
t.Fatalf("Can not exchange test message for %s cause: %s", ipv4OnlyHost, err)
}
require.NoError(t, err)
require.Len(t, resp.Answer, 2)
a, ok := resp.Answer[0].(*dns.A)
if !ok {
t.Fatalf("Answer for %s is not an A record!", ipv4OnlyHost)
}
var mappedIPs []net.IP
for _, rr := range resp.Answer {
a, ok := rr.(*dns.A)
require.True(t, ok)
// Let's manually add NAT64 prefix to IPv4 response
mappedIP := make(net.IP, net.IPv6len)
copy(mappedIP, testNAT64Prefix)
for index, b := range a.A {
mappedIP[NAT64PrefixLength+index] = b
// Let's manually add NAT64 prefix to IPv4 response.
mappedIP := make(net.IP, net.IPv6len)
copy(mappedIP, testNAT64Prefix)
for index, b := range a.A {
mappedIP[NAT64PrefixLength+index] = b
}
mappedIPs = append(mappedIPs, mappedIP)
}
// Create test context with AAAA request to ipv4OnlyHost and resolve it
// Create test context with AAAA request to ipv4OnlyHost and resolve it.
testDNSContext := createTestDNSContext(ipv4OnlyHost)
err = dnsProxy.Resolve(testDNSContext)
if err != nil {
t.Fatalf("Error while DNSContext resolve: %s", err)
}
require.NoError(t, err)
// Response should be AAAA answer
// Response should be AAAA answer.
res := testDNSContext.Res
if res == nil {
t.Fatalf("No response")
}
require.NotNil(t, res)
ans, ok := res.Answer[0].(*dns.AAAA)
if !ok {
t.Fatalf("Answer for %s is not AAAA record", ipv4OnlyHost)
}
for _, rr := range res.Answer {
aaaa, ok := rr.(*dns.AAAA)
require.True(t, ok)
// Compare manually mapped IP with IP that was resolved by dnsproxy
// with calculated NAT64 prefix.
found := false
for _, mappedIP := range mappedIPs {
if aaaa.AAAA.Equal(mappedIP) {
found = true
break
}
}
// Compare manually mapped IP with IP that was resolved by dnsproxy with calculated NAT64 prefix
if !ans.AAAA.Equal(mappedIP) {
t.Fatalf("Manually mapped IP %s not equlas to response %s", mappedIP.String(), ans.AAAA.String())
require.True(t, found)
}
err = dnsProxy.Stop()
if err != nil {
t.Fatalf("Failed to stop dns proxy")
}
require.NoError(t, err)
}
func TestDNS64Race(t *testing.T) {
......@@ -76,26 +79,20 @@ func TestDNS64Race(t *testing.T) {
dnsProxy.SetNAT64Prefix(testNAT64Prefix)
dnsProxy.UpstreamConfig.Upstreams = append(dnsProxy.UpstreamConfig.Upstreams, dnsProxy.UpstreamConfig.Upstreams[0])
// Start listening
// Start listening.
err := dnsProxy.Start()
if err != nil {
t.Fatalf("cannot start the DNS proxy: %s", err)
}
require.NoError(t, err)
// Create a DNS-over-UDP client connection
// Create a DNS-over-UDP client connection.
addr := dnsProxy.Addr(ProtoUDP)
conn, err := dns.Dial("udp", addr.String())
if err != nil {
t.Fatalf("cannot connect to the proxy: %s", err)
}
require.NoError(t, err)
sendTestAAAAMessagesAsync(t, conn)
// Stop the proxy
// Stop the proxy.
err = dnsProxy.Stop()
if err != nil {
t.Fatalf("cannot stop the DNS proxy: %s", err)
}
require.NoError(t, err)
}
func sendTestAAAAMessagesAsync(t *testing.T, conn *dns.Conn) {
......@@ -116,31 +113,14 @@ func sendTestAAAAMessageAsync(t *testing.T, conn *dns.Conn, g *sync.WaitGroup, h
req := createAAAATestMessage(host)
err := conn.WriteMsg(req)
if err != nil {
t.Errorf("cannot write message: %s", err)
return
}
require.NoError(t, err)
res, err := conn.ReadMsg()
if err != nil {
t.Errorf("cannot read response to message: %s", err)
return
}
if len(res.Answer) == 0 {
t.Errorf("No answers!")
return
}
require.NoError(t, err)
require.True(t, len(res.Answer) > 0)
_, ok := res.Answer[0].(*dns.AAAA)
if !ok {
t.Errorf("Answer for %s is not AAAA record!", host)
return
}
require.True(t, ok)
}
func createAAAATestMessage(host string) *dns.Msg {
......
......@@ -50,7 +50,7 @@ func TestFilteringHandler(t *testing.T) {
if err != nil {
t.Fatalf("error in the first request: %s", err)
}
assertResponse(t, r)
requireResponse(t, req, r)
// Now send the second and make sure it is blocked
m.Lock()
......
......@@ -385,7 +385,7 @@ func TestExchangeWithReservedDomains(t *testing.T) {
if err != nil {
t.Fatalf("cannot read response to message: %s", err)
}
assertResponse(t, res)
requireResponse(t, req, res)
// create adguard.com test message
req = createHostTestMessage("adguard.com")
......@@ -494,7 +494,7 @@ func TestOneByOneUpstreamsExchange(t *testing.T) {
if err != nil {
t.Fatalf("cannot read response to message: %s", err)
}
assertResponse(t, res)
requireResponse(t, req, res)
elapsed := time.Since(start)
if elapsed > 3*timeOut {
......@@ -559,7 +559,7 @@ func TestFallback(t *testing.T) {
if err != nil {
t.Fatalf("cannot read response to message: %s", err)
}
assertResponse(t, res)
requireResponse(t, req, res)
elapsed := time.Since(start)
if elapsed > 3*timeout {
......@@ -626,7 +626,7 @@ func TestFallbackFromInvalidBootstrap(t *testing.T) {
if err != nil {
t.Fatalf("cannot read response to message: %s", err)
}
assertResponse(t, res)
requireResponse(t, req, res)
elapsed := time.Since(start)
if elapsed > 3*timeout {
......@@ -1144,20 +1144,19 @@ func sendTestMessageAsync(t *testing.T, conn *dns.Conn, g *sync.WaitGroup) {
req := createTestMessage()
err := conn.WriteMsg(req)
if err != nil {
t.Errorf("cannot write message: %s", err)
return
}
require.NoError(t, err)
res, err := conn.ReadMsg()
if err != nil {
t.Errorf("cannot read response to message: %s", err)
require.NoError(t, err)
return
}
// We do not check if msg IDs match because the order of responses may
// be different.
assertResponse(t, res)
require.NotNil(t, res)
require.Lenf(t, res.Answer, 1, "wrong number of answers: %d", len(res.Answer))
a, ok := res.Answer[0].(*dns.A)
require.Truef(t, ok, "wrong answer type: %v", res.Answer[0])
require.Equalf(t, net.IPv4(8, 8, 8, 8), a.A.To16(), "wrong answer: %v", a.A)
}
// sendTestMessagesAsync sends messages in parallel
......@@ -1185,7 +1184,7 @@ func sendTestMessages(t *testing.T, conn *dns.Conn) {
if err != nil {
t.Fatalf("cannot read response to message #%d: %s", i, err)
}
assertResponse(t, res)
requireResponse(t, req, res)
}
}
......@@ -1204,19 +1203,15 @@ func createHostTestMessage(host string) *dns.Msg {
return &req
}
func assertResponse(t *testing.T, reply *dns.Msg) {
t.Helper()
func requireResponse(t *testing.T, req, reply *dns.Msg) {
require.NotNil(t, reply)
require.Lenf(t, reply.Answer, 1, "wrong number of answers: %d", len(reply.Answer))
require.Equal(t, req.Id, reply.Id)
if len(reply.Answer) != 1 {
t.Fatalf("DNS upstream returned reply with wrong number of answers - %d", len(reply.Answer))
}
if a, ok := reply.Answer[0].(*dns.A); ok {
if !net.IPv4(8, 8, 8, 8).Equal(a.A) {
t.Fatalf("DNS upstream returned wrong answer instead of 8.8.8.8: %v", a.A)
}
} else {
t.Fatalf("DNS upstream returned wrong answer type instead of A: %v", reply.Answer[0])
}
a, ok := reply.Answer[0].(*dns.A)
require.Truef(t, ok, "wrong answer type: %v", reply.Answer[0])
require.Equalf(t, net.IPv4(8, 8, 8, 8), a.A.To16(), "wrong answer: %v", a.A)
}
func createServerTLSConfig(t *testing.T) (*tls.Config, []byte) {
......
......@@ -30,7 +30,7 @@ func TestRatelimitingProxy(t *testing.T) {
if err != nil {
t.Fatalf("error in the first request: %s", err)
}
assertResponse(t, r)
requireResponse(t, req, r)
// Send the second message (blocked)
req = createTestMessage()
......
......@@ -46,5 +46,5 @@ func checkDNSCryptProxy(t *testing.T, proto string, stamp dnsstamps.ServerStamp)
msg := createTestMessage()
reply, err := c.Exchange(msg, ri)
assert.Nil(t, err)
assertResponse(t, reply)
requireResponse(t, msg, reply)
}
......@@ -57,7 +57,7 @@ func TestHttpsProxy(t *testing.T) {
clientIP, proxyIP := net.IP{1, 2, 3, 4}, net.IP{127, 0, 0, 1}
msg := createTestMessage()
doRequest := func(t *testing.T, proxyAddr string) (reply *dns.Msg) {
doRequest := func(t *testing.T, proxyAddr string) {
dnsProxy.TrustedProxies = []string{proxyAddr}
// Start listening.
......@@ -95,25 +95,22 @@ func TestHttpsProxy(t *testing.T) {
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
reply = &dns.Msg{}
reply := &dns.Msg{}
err = reply.Unpack(body)
require.NoError(t, err)
return reply
requireResponse(t, msg, reply)
}
t.Run("success", func(t *testing.T) {
reply := doRequest(t, proxyIP.String())
doRequest(t, proxyIP.String())
assertResponse(t, reply)
ip, _ := netutil.IPAndPortFromAddr(gotAddr)
assert.True(t, ip.Equal(clientIP))
})
t.Run("not_in_trusted", func(t *testing.T) {
reply := doRequest(t, "127.0.0.2")
doRequest(t, "127.0.0.2")
assertResponse(t, reply)
ip, _ := netutil.IPAndPortFromAddr(gotAddr)
assert.True(t, ip.Equal(proxyIP))
})
......
......@@ -96,5 +96,5 @@ func sendTestQUICMessage(t *testing.T, conn quic.Connection, doqVersion DoQVersi
require.NoError(t, err)
// Check the response
assertResponse(t, reply)
requireResponse(t, msg, reply)
}
......@@ -39,7 +39,7 @@ func TestExchangeParallel(t *testing.T) {
t.Fatalf("shouldn't happen. This upstream can't resolve DNS request: %s", u.Address())
}
assertResponse(t, resp)
requireResponse(t, req, resp)
elapsed := time.Since(start)
if elapsed > timeout {
t.Fatalf("exchange took more time than the configured timeout: %v", elapsed)
......
......@@ -23,7 +23,7 @@ func TestTLSPoolReconnect(t *testing.T) {
if err != nil {
t.Fatalf("first DNS message failed: %s", err)
}
assertResponse(t, reply)
requireResponse(t, req, reply)
// Now let's close the pooled connection and return it back to the pool
p := u.(*dnsOverTLS)
......@@ -37,7 +37,7 @@ func TestTLSPoolReconnect(t *testing.T) {
if err != nil {
t.Fatalf("second DNS message failed: %s", err)
}
assertResponse(t, reply)
requireResponse(t, req, reply)
// Now assert that the number of connections in the pool is not changed
if len(p.pool.conns) != 1 {
......@@ -64,7 +64,7 @@ func TestTLSPoolDeadLine(t *testing.T) {
if err != nil {
t.Fatalf("first DNS message failed: %s", err)
}
assertResponse(t, response)
requireResponse(t, req, response)
p := u.(*dnsOverTLS)
......@@ -77,7 +77,7 @@ func TestTLSPoolDeadLine(t *testing.T) {
if err != nil {
t.Fatalf("first DNS message failed: %s", err)
}
assertResponse(t, response)
requireResponse(t, req, response)
// Update connection's deadLine and put it back to the pool
err = conn.SetDeadline(time.Now().Add(10 * time.Hour))
......@@ -95,7 +95,7 @@ func TestTLSPoolDeadLine(t *testing.T) {
if err != nil {
t.Fatalf("first DNS message failed: %s", err)
}
assertResponse(t, response)
requireResponse(t, req, response)
// Set connection's deadLine to the past and try to reuse it
err = conn.SetDeadline(time.Now().Add(-10 * time.Hour))
......
......@@ -73,13 +73,12 @@ func (p *dnsOverQUIC) Exchange(m *dns.Msg) (res *dns.Msg, err error) {
// When sending queries over a QUIC connection, the DNS Message ID MUST be
// set to zero.
id := m.Id
var reply *dns.Msg
m.Id = 0
defer func() {
// Restore the original ID to not break compatibility with proxies
m.Id = id
if reply != nil {
reply.Id = id
if res != nil {
res.Id = id
}
}()
......
......@@ -97,7 +97,7 @@ func TestUpstreamRace(t *testing.T) {
abort <- fmt.Sprintf("%s failed to resolve: %v", u.Address(), err)
return
}
assertResponse(t, res)
requireResponse(t, req, res)
t.Logf("Finished %d", idx)
ch <- idx
}(i)
......@@ -412,7 +412,7 @@ func checkUpstream(t *testing.T, u Upstream, addr string) {
reply, err := u.Exchange(req)
require.NoErrorf(t, err, "couldn't talk to upstream %s", addr)
assertResponse(t, reply)
requireResponse(t, req, reply)
}
func createTestMessage() *dns.Msg {
......@@ -433,14 +433,15 @@ func createHostTestMessage(host string) (req *dns.Msg) {
}
}
func assertResponse(t *testing.T, reply *dns.Msg) {
func requireResponse(t *testing.T, req, reply *dns.Msg) {
require.NotNil(t, reply)
require.Lenf(t, reply.Answer, 1, "wrong number of answers: %d", len(reply.Answer))
require.Equal(t, req.Id, reply.Id)
a, ok := reply.Answer[0].(*dns.A)
require.Truef(t, ok, "wrong answer type: %v", reply.Answer[0])
assert.Equalf(t, net.IPv4(8, 8, 8, 8), a.A.To16(), "wrong answer: %v", a.A)
require.Equalf(t, net.IPv4(8, 8, 8, 8), a.A.To16(), "wrong answer: %v", a.A)
}
func TestAddPort(t *testing.T) {
......