From 24b7ea605819cd901f9c4a1b65f52444021322b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20B=C3=B6hm?= Date: Tue, 28 Nov 2017 23:54:05 +0100 Subject: [PATCH] Introduce HostLookup and corresponding tests --- backend/backend_test.go | 177 ---------------------------------------- backend/lookup.go | 93 +++++++++++++++++++++ backend/lookup_test.go | 154 ++++++++++++++++++++++++++++++++++ 3 files changed, 247 insertions(+), 177 deletions(-) delete mode 100644 backend/backend_test.go create mode 100644 backend/lookup.go create mode 100644 backend/lookup_test.go diff --git a/backend/backend_test.go b/backend/backend_test.go deleted file mode 100644 index 2a155c1..0000000 --- a/backend/backend_test.go +++ /dev/null @@ -1,177 +0,0 @@ -package backend - -import ( - "bufio" - "bytes" - "errors" - c "github.com/pboehm/ddns/config" - h "github.com/pboehm/ddns/hosts" - "github.com/stretchr/testify/assert" - "os" - "testing" -) - -type testHostBackend struct { - hosts map[string]*h.Host -} - -func (b *testHostBackend) GetHost(hostname string) (*h.Host, error) { - host, ok := b.hosts[hostname] - if ok { - return host, nil - } else { - return nil, errors.New("Host not found") - } -} - -func (b *testHostBackend) SetHost(host *h.Host) error { - b.hosts[host.Hostname] = host - return nil -} - -func buildBackend(domain string) (*c.Config, *testHostBackend, *PowerDnsBackend) { - config := &c.Config{ - Verbose: false, - Domain: domain, - SOAFqdn: "dns" + domain, - } - - hosts := &testHostBackend{ - hosts: map[string]*h.Host{ - "www": { - Hostname: "www", - Ip: "10.11.12.13", - Token: "abcdef", - }, - "v4": { - Hostname: "v4", - Ip: "10.10.10.10", - Token: "ghijkl", - }, - "v6": { - Hostname: "v6", - Ip: "2001:db8:85a3::8a2e:370:7334", - Token: "ghijkl", - }, - }, - } - - return config, hosts, NewPowerDnsBackend(config, hosts, os.Stdin, os.Stdout) -} - -func buildRequest(queryName, queryType string) *backendRequest { - return &backendRequest{ - query: "Q", - queryName: queryName, - queryClass: "IN", - queryType: queryType, - queryId: "-1", - } -} - -func readResponse(t *testing.T, responses chan backendResponse) backendResponse { - select { - case response, ok := <-responses: - assert.True(t, ok) - return response - default: - assert.FailNow(t, "Couldn't read response because it is not available ...") - return nil - } -} - -func TestParseRequest(t *testing.T) { - _, _, backend := buildBackend(".example.org") - - reader := bufio.NewReader(bytes.NewBufferString("Q\twww.example.org\tIN\tCNAME\t-1\t203.0.113.210\n")) - request, err := backend.parseRequest(reader) - assert.Nil(t, err) - assert.Equal(t, buildRequest("www.example.org", "CNAME"), request) - - reader = bufio.NewReader(bytes.NewBufferString("Q\texample.org\tIN\tSOA\t-1\t203.0.113.210\n")) - request, err = backend.parseRequest(reader) - assert.Nil(t, err) - assert.Equal(t, buildRequest("example.org", "SOA"), request) - - reader = bufio.NewReader(bytes.NewBufferString("Q\texample.org\n")) - request, err = backend.parseRequest(reader) - assert.NotNil(t, err) - assert.Nil(t, request) -} - -func TestRequestHandling(t *testing.T) { - _, _, backend := buildBackend(".example.org") - - responses := make(chan backendResponse, 2) - err := backend.handleRequest(buildRequest("example.org", "SOA"), responses) - assert.Nil(t, err) - soaResponse := readResponse(t, responses) - assert.Equal(t, newResponse("DATA", "example.org", "IN", "SOA", "10", "-1"), soaResponse[0:6]) - assert.Regexp(t, "dns\\.example\\.org\\. hostmaster\\.example.org\\. \\d+ 1800 3600 7200 5", soaResponse[6]) - assert.Equal(t, endResponse, readResponse(t, responses)) - - responses = make(chan backendResponse, 2) - err = backend.handleRequest(buildRequest("example.org", "NS"), responses) - assert.Nil(t, err) - assert.Equal(t, newResponse("DATA", "example.org", "IN", "NS", "10", "-1", "dns.example.org."), readResponse(t, responses)) - assert.Equal(t, endResponse, readResponse(t, responses)) - - responses = make(chan backendResponse, 2) - err = backend.handleRequest(buildRequest("www.example.org", "ANY"), responses) - assert.Nil(t, err) - assert.Equal(t, newResponse("DATA", "www.example.org", "IN", "A", "10", "-1", "10.11.12.13"), readResponse(t, responses)) - assert.Equal(t, endResponse, readResponse(t, responses)) - - responses = make(chan backendResponse, 2) - err = backend.handleRequest(buildRequest("www.example.org", "A"), responses) - assert.Nil(t, err) - assert.Equal(t, newResponse("DATA", "www.example.org", "IN", "A", "10", "-1", "10.11.12.13"), readResponse(t, responses)) - assert.Equal(t, endResponse, readResponse(t, responses)) - - // Allow hostname to be mixed case which is used by Let's Encrypt for a little bit more security - responses = make(chan backendResponse, 2) - err = backend.handleRequest(buildRequest("wWW.eXaMPlE.oRg", "A"), responses) - assert.Nil(t, err) - assert.Equal(t, newResponse("DATA", "wWW.eXaMPlE.oRg", "IN", "A", "10", "-1", "10.11.12.13"), readResponse(t, responses)) - assert.Equal(t, endResponse, readResponse(t, responses)) - - responses = make(chan backendResponse, 1) - err = backend.handleRequest(buildRequest("notexisting.example.org", "A"), responses) - assert.NotNil(t, err) - assert.Equal(t, endResponse, readResponse(t, responses)) - - // Correct Handling of IPv4/IPv6 and ANY/A/AAAA - responses = make(chan backendResponse, 2) - err = backend.handleRequest(buildRequest("v4.example.org", "ANY"), responses) - assert.Nil(t, err) - assert.Equal(t, newResponse("DATA", "v4.example.org", "IN", "A", "10", "-1", "10.10.10.10"), readResponse(t, responses)) - assert.Equal(t, endResponse, readResponse(t, responses)) - - responses = make(chan backendResponse, 2) - err = backend.handleRequest(buildRequest("v4.example.org", "A"), responses) - assert.Nil(t, err) - assert.Equal(t, newResponse("DATA", "v4.example.org", "IN", "A", "10", "-1", "10.10.10.10"), readResponse(t, responses)) - assert.Equal(t, endResponse, readResponse(t, responses)) - - responses = make(chan backendResponse, 1) - err = backend.handleRequest(buildRequest("v4.example.org", "AAAA"), responses) - assert.NotNil(t, err) - assert.Equal(t, endResponse, readResponse(t, responses)) - - responses = make(chan backendResponse, 2) - err = backend.handleRequest(buildRequest("v6.example.org", "ANY"), responses) - assert.Nil(t, err) - assert.Equal(t, newResponse("DATA", "v6.example.org", "IN", "AAAA", "10", "-1", "2001:db8:85a3::8a2e:370:7334"), readResponse(t, responses)) - assert.Equal(t, endResponse, readResponse(t, responses)) - - responses = make(chan backendResponse, 2) - err = backend.handleRequest(buildRequest("v6.example.org", "AAAA"), responses) - assert.Nil(t, err) - assert.Equal(t, newResponse("DATA", "v6.example.org", "IN", "AAAA", "10", "-1", "2001:db8:85a3::8a2e:370:7334"), readResponse(t, responses)) - assert.Equal(t, endResponse, readResponse(t, responses)) - - responses = make(chan backendResponse, 1) - err = backend.handleRequest(buildRequest("v6.example.org", "A"), responses) - assert.NotNil(t, err) - assert.Equal(t, endResponse, readResponse(t, responses)) -} diff --git a/backend/lookup.go b/backend/lookup.go new file mode 100644 index 0000000..51d8c20 --- /dev/null +++ b/backend/lookup.go @@ -0,0 +1,93 @@ +package backend + +import ( + "errors" + "fmt" + "github.com/pboehm/ddns/config" + "github.com/pboehm/ddns/hosts" + "strings" + "time" +) + +type Request struct { + QType string + QName string + Remote string + Local string + RealRemote string + ZoneId string +} + +type Response struct { + QType string + QName string + Content string + TTL int +} + +type HostLookup struct { + config *config.Config + hosts hosts.HostBackend +} + +func (l *HostLookup) Lookup(request *Request) (*Response, error) { + responseRecord := request.QType + responseContent := "" + + switch request.QType { + case "SOA": + responseContent = fmt.Sprintf("%s. hostmaster%s. %d 1800 3600 7200 5", + l.config.SOAFqdn, l.config.Domain, l.currentSOASerial()) + + case "NS": + responseContent = l.config.SOAFqdn + + case "A", "AAAA", "ANY": + hostname, err := l.extractHostname(request.QName) + if err != nil { + return nil, err + } + + var host *hosts.Host + if host, err = l.hosts.GetHost(hostname); err != nil { + return nil, err + } + + responseContent = host.Ip + + responseRecord = "A" + if !host.IsIPv4() { + responseRecord = "AAAA" + } + + if (request.QType == "A" || request.QType == "AAAA") && request.QType != responseRecord { + return nil, errors.New("IP address is not valid for requested record") + } + + default: + return nil, errors.New("Invalid request") + } + + return &Response{QType: responseRecord, QName: request.QName, Content: responseContent, TTL: 5}, nil +} + +// extractHostname extract the host part of the fqdn: pi.d.example.org -> pi +func (l *HostLookup) extractHostname(rawQueryName string) (string, error) { + queryName := strings.ToLower(rawQueryName) + + hostname := "" + if strings.HasSuffix(queryName, l.config.Domain) { + hostname = queryName[:len(queryName)-len(l.config.Domain)] + } + + if hostname == "" { + return "", errors.New("Query name does not correspond to our domain") + } + + return hostname, nil +} + +// currentSOASerial get the current SOA serial by returning the current time in seconds +func (l *HostLookup) currentSOASerial() int64 { + return time.Now().Unix() +} diff --git a/backend/lookup_test.go b/backend/lookup_test.go new file mode 100644 index 0000000..0ec7ed7 --- /dev/null +++ b/backend/lookup_test.go @@ -0,0 +1,154 @@ +package backend + +import ( + "errors" + c "github.com/pboehm/ddns/config" + h "github.com/pboehm/ddns/hosts" + "github.com/stretchr/testify/assert" + "testing" +) + +type testHostBackend struct { + hosts map[string]*h.Host +} + +func (b *testHostBackend) GetHost(hostname string) (*h.Host, error) { + host, ok := b.hosts[hostname] + if ok { + return host, nil + } else { + return nil, errors.New("Host not found") + } +} + +func (b *testHostBackend) SetHost(host *h.Host) error { + b.hosts[host.Hostname] = host + return nil +} + +func buildLookup(domain string) (*c.Config, *testHostBackend, *HostLookup) { + config := &c.Config{ + Verbose: false, + Domain: domain, + SOAFqdn: "dns" + domain, + } + + hosts := &testHostBackend{ + hosts: map[string]*h.Host{ + "www": { + Hostname: "www", + Ip: "10.11.12.13", + Token: "abcdef", + }, + "v4": { + Hostname: "v4", + Ip: "10.10.10.10", + Token: "ghijkl", + }, + "v6": { + Hostname: "v6", + Ip: "2001:db8:85a3::8a2e:370:7334", + Token: "ghijkl", + }, + }, + } + + return config, hosts, &HostLookup{config, hosts} +} + +func buildRequest(queryName, queryType string) *Request { + return &Request{ + QType: queryType, + QName: queryName, + } +} + +func TestRequestHandling(t *testing.T) { + _, _, lookup := buildLookup(".example.org") + + response, err := lookup.Lookup(buildRequest("example.org", "SOA")) + assert.Nil(t, err) + assert.NotNil(t, response) + assert.Equal(t, "SOA", response.QType) + assert.Equal(t, "example.org", response.QName) + assert.Regexp(t, "dns\\.example\\.org\\. hostmaster\\.example.org\\. \\d+ 1800 3600 7200 5", response.Content) + assert.Equal(t, 5, response.TTL) + + response, err = lookup.Lookup(buildRequest("example.org", "NS")) + assert.Nil(t, err) + assert.NotNil(t, response) + assert.Equal(t, "NS", response.QType) + assert.Equal(t, "example.org", response.QName) + assert.Equal(t, "dns.example.org", response.Content) + assert.Equal(t, 5, response.TTL) + + response, err = lookup.Lookup(buildRequest("www.example.org", "ANY")) + assert.Nil(t, err) + assert.NotNil(t, response) + assert.Equal(t, "A", response.QType) + assert.Equal(t, "www.example.org", response.QName) + assert.Equal(t, "10.11.12.13", response.Content) + assert.Equal(t, 5, response.TTL) + + response, err = lookup.Lookup(buildRequest("www.example.org", "A")) + assert.Nil(t, err) + assert.NotNil(t, response) + assert.Equal(t, "A", response.QType) + assert.Equal(t, "www.example.org", response.QName) + assert.Equal(t, "10.11.12.13", response.Content) + assert.Equal(t, 5, response.TTL) + + // Allow hostname to be mixed case which is used by Let's Encrypt for a little bit more security + response, err = lookup.Lookup(buildRequest("wWW.eXaMPlE.oRg", "A")) + assert.Nil(t, err) + assert.NotNil(t, response) + assert.Equal(t, "A", response.QType) + assert.Equal(t, "wWW.eXaMPlE.oRg", response.QName) + assert.Equal(t, "10.11.12.13", response.Content) + assert.Equal(t, 5, response.TTL) + + response, err = lookup.Lookup(buildRequest("notexisting.example.org", "A")) + assert.NotNil(t, err) + assert.Nil(t, response) + + // Correct Handling of IPv4/IPv6 and ANY/A/AAAA + response, err = lookup.Lookup(buildRequest("v4.example.org", "ANY")) + assert.Nil(t, err) + assert.NotNil(t, response) + assert.Equal(t, "A", response.QType) + assert.Equal(t, "v4.example.org", response.QName) + assert.Equal(t, "10.10.10.10", response.Content) + assert.Equal(t, 5, response.TTL) + + response, err = lookup.Lookup(buildRequest("v4.example.org", "A")) + assert.Nil(t, err) + assert.NotNil(t, response) + assert.Equal(t, "A", response.QType) + assert.Equal(t, "v4.example.org", response.QName) + assert.Equal(t, "10.10.10.10", response.Content) + assert.Equal(t, 5, response.TTL) + + response, err = lookup.Lookup(buildRequest("v4.example.org", "AAAA")) + assert.NotNil(t, err) + assert.Nil(t, response) + + response, err = lookup.Lookup(buildRequest("v6.example.org", "ANY")) + assert.Nil(t, err) + assert.NotNil(t, response) + assert.Equal(t, "AAAA", response.QType) + assert.Equal(t, "v6.example.org", response.QName) + assert.Equal(t, "2001:db8:85a3::8a2e:370:7334", response.Content) + assert.Equal(t, 5, response.TTL) + + response, err = lookup.Lookup(buildRequest("v6.example.org", "AAAA")) + assert.Nil(t, err) + assert.NotNil(t, response) + assert.Equal(t, "AAAA", response.QType) + assert.Equal(t, "v6.example.org", response.QName) + assert.Equal(t, "2001:db8:85a3::8a2e:370:7334", response.Content) + assert.Equal(t, 5, response.TTL) + + response, err = lookup.Lookup(buildRequest("v6.example.org", "A")) + assert.NotNil(t, err) + assert.Nil(t, response) +}