Introduce HostLookup and corresponding tests
This commit is contained in:
parent
c2e7b6f2d2
commit
24b7ea6058
|
@ -1,177 +0,0 @@
|
||||||
package backend
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bufio"
|
|
||||||
"bytes"
|
|
||||||
"errors"
|
|
||||||
c "github.com/pboehm/ddns/config"
|
|
||||||
h "github.com/pboehm/ddns/hosts"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"os"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
type testHostBackend struct {
|
|
||||||
hosts map[string]*h.Host
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *testHostBackend) GetHost(hostname string) (*h.Host, error) {
|
|
||||||
host, ok := b.hosts[hostname]
|
|
||||||
if ok {
|
|
||||||
return host, nil
|
|
||||||
} else {
|
|
||||||
return nil, errors.New("Host not found")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *testHostBackend) SetHost(host *h.Host) error {
|
|
||||||
b.hosts[host.Hostname] = host
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func buildBackend(domain string) (*c.Config, *testHostBackend, *PowerDnsBackend) {
|
|
||||||
config := &c.Config{
|
|
||||||
Verbose: false,
|
|
||||||
Domain: domain,
|
|
||||||
SOAFqdn: "dns" + domain,
|
|
||||||
}
|
|
||||||
|
|
||||||
hosts := &testHostBackend{
|
|
||||||
hosts: map[string]*h.Host{
|
|
||||||
"www": {
|
|
||||||
Hostname: "www",
|
|
||||||
Ip: "10.11.12.13",
|
|
||||||
Token: "abcdef",
|
|
||||||
},
|
|
||||||
"v4": {
|
|
||||||
Hostname: "v4",
|
|
||||||
Ip: "10.10.10.10",
|
|
||||||
Token: "ghijkl",
|
|
||||||
},
|
|
||||||
"v6": {
|
|
||||||
Hostname: "v6",
|
|
||||||
Ip: "2001:db8:85a3::8a2e:370:7334",
|
|
||||||
Token: "ghijkl",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
return config, hosts, NewPowerDnsBackend(config, hosts, os.Stdin, os.Stdout)
|
|
||||||
}
|
|
||||||
|
|
||||||
func buildRequest(queryName, queryType string) *backendRequest {
|
|
||||||
return &backendRequest{
|
|
||||||
query: "Q",
|
|
||||||
queryName: queryName,
|
|
||||||
queryClass: "IN",
|
|
||||||
queryType: queryType,
|
|
||||||
queryId: "-1",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func readResponse(t *testing.T, responses chan backendResponse) backendResponse {
|
|
||||||
select {
|
|
||||||
case response, ok := <-responses:
|
|
||||||
assert.True(t, ok)
|
|
||||||
return response
|
|
||||||
default:
|
|
||||||
assert.FailNow(t, "Couldn't read response because it is not available ...")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParseRequest(t *testing.T) {
|
|
||||||
_, _, backend := buildBackend(".example.org")
|
|
||||||
|
|
||||||
reader := bufio.NewReader(bytes.NewBufferString("Q\twww.example.org\tIN\tCNAME\t-1\t203.0.113.210\n"))
|
|
||||||
request, err := backend.parseRequest(reader)
|
|
||||||
assert.Nil(t, err)
|
|
||||||
assert.Equal(t, buildRequest("www.example.org", "CNAME"), request)
|
|
||||||
|
|
||||||
reader = bufio.NewReader(bytes.NewBufferString("Q\texample.org\tIN\tSOA\t-1\t203.0.113.210\n"))
|
|
||||||
request, err = backend.parseRequest(reader)
|
|
||||||
assert.Nil(t, err)
|
|
||||||
assert.Equal(t, buildRequest("example.org", "SOA"), request)
|
|
||||||
|
|
||||||
reader = bufio.NewReader(bytes.NewBufferString("Q\texample.org\n"))
|
|
||||||
request, err = backend.parseRequest(reader)
|
|
||||||
assert.NotNil(t, err)
|
|
||||||
assert.Nil(t, request)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRequestHandling(t *testing.T) {
|
|
||||||
_, _, backend := buildBackend(".example.org")
|
|
||||||
|
|
||||||
responses := make(chan backendResponse, 2)
|
|
||||||
err := backend.handleRequest(buildRequest("example.org", "SOA"), responses)
|
|
||||||
assert.Nil(t, err)
|
|
||||||
soaResponse := readResponse(t, responses)
|
|
||||||
assert.Equal(t, newResponse("DATA", "example.org", "IN", "SOA", "10", "-1"), soaResponse[0:6])
|
|
||||||
assert.Regexp(t, "dns\\.example\\.org\\. hostmaster\\.example.org\\. \\d+ 1800 3600 7200 5", soaResponse[6])
|
|
||||||
assert.Equal(t, endResponse, readResponse(t, responses))
|
|
||||||
|
|
||||||
responses = make(chan backendResponse, 2)
|
|
||||||
err = backend.handleRequest(buildRequest("example.org", "NS"), responses)
|
|
||||||
assert.Nil(t, err)
|
|
||||||
assert.Equal(t, newResponse("DATA", "example.org", "IN", "NS", "10", "-1", "dns.example.org."), readResponse(t, responses))
|
|
||||||
assert.Equal(t, endResponse, readResponse(t, responses))
|
|
||||||
|
|
||||||
responses = make(chan backendResponse, 2)
|
|
||||||
err = backend.handleRequest(buildRequest("www.example.org", "ANY"), responses)
|
|
||||||
assert.Nil(t, err)
|
|
||||||
assert.Equal(t, newResponse("DATA", "www.example.org", "IN", "A", "10", "-1", "10.11.12.13"), readResponse(t, responses))
|
|
||||||
assert.Equal(t, endResponse, readResponse(t, responses))
|
|
||||||
|
|
||||||
responses = make(chan backendResponse, 2)
|
|
||||||
err = backend.handleRequest(buildRequest("www.example.org", "A"), responses)
|
|
||||||
assert.Nil(t, err)
|
|
||||||
assert.Equal(t, newResponse("DATA", "www.example.org", "IN", "A", "10", "-1", "10.11.12.13"), readResponse(t, responses))
|
|
||||||
assert.Equal(t, endResponse, readResponse(t, responses))
|
|
||||||
|
|
||||||
// Allow hostname to be mixed case which is used by Let's Encrypt for a little bit more security
|
|
||||||
responses = make(chan backendResponse, 2)
|
|
||||||
err = backend.handleRequest(buildRequest("wWW.eXaMPlE.oRg", "A"), responses)
|
|
||||||
assert.Nil(t, err)
|
|
||||||
assert.Equal(t, newResponse("DATA", "wWW.eXaMPlE.oRg", "IN", "A", "10", "-1", "10.11.12.13"), readResponse(t, responses))
|
|
||||||
assert.Equal(t, endResponse, readResponse(t, responses))
|
|
||||||
|
|
||||||
responses = make(chan backendResponse, 1)
|
|
||||||
err = backend.handleRequest(buildRequest("notexisting.example.org", "A"), responses)
|
|
||||||
assert.NotNil(t, err)
|
|
||||||
assert.Equal(t, endResponse, readResponse(t, responses))
|
|
||||||
|
|
||||||
// Correct Handling of IPv4/IPv6 and ANY/A/AAAA
|
|
||||||
responses = make(chan backendResponse, 2)
|
|
||||||
err = backend.handleRequest(buildRequest("v4.example.org", "ANY"), responses)
|
|
||||||
assert.Nil(t, err)
|
|
||||||
assert.Equal(t, newResponse("DATA", "v4.example.org", "IN", "A", "10", "-1", "10.10.10.10"), readResponse(t, responses))
|
|
||||||
assert.Equal(t, endResponse, readResponse(t, responses))
|
|
||||||
|
|
||||||
responses = make(chan backendResponse, 2)
|
|
||||||
err = backend.handleRequest(buildRequest("v4.example.org", "A"), responses)
|
|
||||||
assert.Nil(t, err)
|
|
||||||
assert.Equal(t, newResponse("DATA", "v4.example.org", "IN", "A", "10", "-1", "10.10.10.10"), readResponse(t, responses))
|
|
||||||
assert.Equal(t, endResponse, readResponse(t, responses))
|
|
||||||
|
|
||||||
responses = make(chan backendResponse, 1)
|
|
||||||
err = backend.handleRequest(buildRequest("v4.example.org", "AAAA"), responses)
|
|
||||||
assert.NotNil(t, err)
|
|
||||||
assert.Equal(t, endResponse, readResponse(t, responses))
|
|
||||||
|
|
||||||
responses = make(chan backendResponse, 2)
|
|
||||||
err = backend.handleRequest(buildRequest("v6.example.org", "ANY"), responses)
|
|
||||||
assert.Nil(t, err)
|
|
||||||
assert.Equal(t, newResponse("DATA", "v6.example.org", "IN", "AAAA", "10", "-1", "2001:db8:85a3::8a2e:370:7334"), readResponse(t, responses))
|
|
||||||
assert.Equal(t, endResponse, readResponse(t, responses))
|
|
||||||
|
|
||||||
responses = make(chan backendResponse, 2)
|
|
||||||
err = backend.handleRequest(buildRequest("v6.example.org", "AAAA"), responses)
|
|
||||||
assert.Nil(t, err)
|
|
||||||
assert.Equal(t, newResponse("DATA", "v6.example.org", "IN", "AAAA", "10", "-1", "2001:db8:85a3::8a2e:370:7334"), readResponse(t, responses))
|
|
||||||
assert.Equal(t, endResponse, readResponse(t, responses))
|
|
||||||
|
|
||||||
responses = make(chan backendResponse, 1)
|
|
||||||
err = backend.handleRequest(buildRequest("v6.example.org", "A"), responses)
|
|
||||||
assert.NotNil(t, err)
|
|
||||||
assert.Equal(t, endResponse, readResponse(t, responses))
|
|
||||||
}
|
|
|
@ -0,0 +1,93 @@
|
||||||
|
package backend
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"github.com/pboehm/ddns/config"
|
||||||
|
"github.com/pboehm/ddns/hosts"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Request struct {
|
||||||
|
QType string
|
||||||
|
QName string
|
||||||
|
Remote string
|
||||||
|
Local string
|
||||||
|
RealRemote string
|
||||||
|
ZoneId string
|
||||||
|
}
|
||||||
|
|
||||||
|
type Response struct {
|
||||||
|
QType string
|
||||||
|
QName string
|
||||||
|
Content string
|
||||||
|
TTL int
|
||||||
|
}
|
||||||
|
|
||||||
|
type HostLookup struct {
|
||||||
|
config *config.Config
|
||||||
|
hosts hosts.HostBackend
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *HostLookup) Lookup(request *Request) (*Response, error) {
|
||||||
|
responseRecord := request.QType
|
||||||
|
responseContent := ""
|
||||||
|
|
||||||
|
switch request.QType {
|
||||||
|
case "SOA":
|
||||||
|
responseContent = fmt.Sprintf("%s. hostmaster%s. %d 1800 3600 7200 5",
|
||||||
|
l.config.SOAFqdn, l.config.Domain, l.currentSOASerial())
|
||||||
|
|
||||||
|
case "NS":
|
||||||
|
responseContent = l.config.SOAFqdn
|
||||||
|
|
||||||
|
case "A", "AAAA", "ANY":
|
||||||
|
hostname, err := l.extractHostname(request.QName)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var host *hosts.Host
|
||||||
|
if host, err = l.hosts.GetHost(hostname); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
responseContent = host.Ip
|
||||||
|
|
||||||
|
responseRecord = "A"
|
||||||
|
if !host.IsIPv4() {
|
||||||
|
responseRecord = "AAAA"
|
||||||
|
}
|
||||||
|
|
||||||
|
if (request.QType == "A" || request.QType == "AAAA") && request.QType != responseRecord {
|
||||||
|
return nil, errors.New("IP address is not valid for requested record")
|
||||||
|
}
|
||||||
|
|
||||||
|
default:
|
||||||
|
return nil, errors.New("Invalid request")
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Response{QType: responseRecord, QName: request.QName, Content: responseContent, TTL: 5}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractHostname extract the host part of the fqdn: pi.d.example.org -> pi
|
||||||
|
func (l *HostLookup) extractHostname(rawQueryName string) (string, error) {
|
||||||
|
queryName := strings.ToLower(rawQueryName)
|
||||||
|
|
||||||
|
hostname := ""
|
||||||
|
if strings.HasSuffix(queryName, l.config.Domain) {
|
||||||
|
hostname = queryName[:len(queryName)-len(l.config.Domain)]
|
||||||
|
}
|
||||||
|
|
||||||
|
if hostname == "" {
|
||||||
|
return "", errors.New("Query name does not correspond to our domain")
|
||||||
|
}
|
||||||
|
|
||||||
|
return hostname, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// currentSOASerial get the current SOA serial by returning the current time in seconds
|
||||||
|
func (l *HostLookup) currentSOASerial() int64 {
|
||||||
|
return time.Now().Unix()
|
||||||
|
}
|
|
@ -0,0 +1,154 @@
|
||||||
|
package backend
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
c "github.com/pboehm/ddns/config"
|
||||||
|
h "github.com/pboehm/ddns/hosts"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
type testHostBackend struct {
|
||||||
|
hosts map[string]*h.Host
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *testHostBackend) GetHost(hostname string) (*h.Host, error) {
|
||||||
|
host, ok := b.hosts[hostname]
|
||||||
|
if ok {
|
||||||
|
return host, nil
|
||||||
|
} else {
|
||||||
|
return nil, errors.New("Host not found")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *testHostBackend) SetHost(host *h.Host) error {
|
||||||
|
b.hosts[host.Hostname] = host
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildLookup(domain string) (*c.Config, *testHostBackend, *HostLookup) {
|
||||||
|
config := &c.Config{
|
||||||
|
Verbose: false,
|
||||||
|
Domain: domain,
|
||||||
|
SOAFqdn: "dns" + domain,
|
||||||
|
}
|
||||||
|
|
||||||
|
hosts := &testHostBackend{
|
||||||
|
hosts: map[string]*h.Host{
|
||||||
|
"www": {
|
||||||
|
Hostname: "www",
|
||||||
|
Ip: "10.11.12.13",
|
||||||
|
Token: "abcdef",
|
||||||
|
},
|
||||||
|
"v4": {
|
||||||
|
Hostname: "v4",
|
||||||
|
Ip: "10.10.10.10",
|
||||||
|
Token: "ghijkl",
|
||||||
|
},
|
||||||
|
"v6": {
|
||||||
|
Hostname: "v6",
|
||||||
|
Ip: "2001:db8:85a3::8a2e:370:7334",
|
||||||
|
Token: "ghijkl",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
return config, hosts, &HostLookup{config, hosts}
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildRequest(queryName, queryType string) *Request {
|
||||||
|
return &Request{
|
||||||
|
QType: queryType,
|
||||||
|
QName: queryName,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestHandling(t *testing.T) {
|
||||||
|
_, _, lookup := buildLookup(".example.org")
|
||||||
|
|
||||||
|
response, err := lookup.Lookup(buildRequest("example.org", "SOA"))
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.NotNil(t, response)
|
||||||
|
assert.Equal(t, "SOA", response.QType)
|
||||||
|
assert.Equal(t, "example.org", response.QName)
|
||||||
|
assert.Regexp(t, "dns\\.example\\.org\\. hostmaster\\.example.org\\. \\d+ 1800 3600 7200 5", response.Content)
|
||||||
|
assert.Equal(t, 5, response.TTL)
|
||||||
|
|
||||||
|
response, err = lookup.Lookup(buildRequest("example.org", "NS"))
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.NotNil(t, response)
|
||||||
|
assert.Equal(t, "NS", response.QType)
|
||||||
|
assert.Equal(t, "example.org", response.QName)
|
||||||
|
assert.Equal(t, "dns.example.org", response.Content)
|
||||||
|
assert.Equal(t, 5, response.TTL)
|
||||||
|
|
||||||
|
response, err = lookup.Lookup(buildRequest("www.example.org", "ANY"))
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.NotNil(t, response)
|
||||||
|
assert.Equal(t, "A", response.QType)
|
||||||
|
assert.Equal(t, "www.example.org", response.QName)
|
||||||
|
assert.Equal(t, "10.11.12.13", response.Content)
|
||||||
|
assert.Equal(t, 5, response.TTL)
|
||||||
|
|
||||||
|
response, err = lookup.Lookup(buildRequest("www.example.org", "A"))
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.NotNil(t, response)
|
||||||
|
assert.Equal(t, "A", response.QType)
|
||||||
|
assert.Equal(t, "www.example.org", response.QName)
|
||||||
|
assert.Equal(t, "10.11.12.13", response.Content)
|
||||||
|
assert.Equal(t, 5, response.TTL)
|
||||||
|
|
||||||
|
// Allow hostname to be mixed case which is used by Let's Encrypt for a little bit more security
|
||||||
|
response, err = lookup.Lookup(buildRequest("wWW.eXaMPlE.oRg", "A"))
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.NotNil(t, response)
|
||||||
|
assert.Equal(t, "A", response.QType)
|
||||||
|
assert.Equal(t, "wWW.eXaMPlE.oRg", response.QName)
|
||||||
|
assert.Equal(t, "10.11.12.13", response.Content)
|
||||||
|
assert.Equal(t, 5, response.TTL)
|
||||||
|
|
||||||
|
response, err = lookup.Lookup(buildRequest("notexisting.example.org", "A"))
|
||||||
|
assert.NotNil(t, err)
|
||||||
|
assert.Nil(t, response)
|
||||||
|
|
||||||
|
// Correct Handling of IPv4/IPv6 and ANY/A/AAAA
|
||||||
|
response, err = lookup.Lookup(buildRequest("v4.example.org", "ANY"))
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.NotNil(t, response)
|
||||||
|
assert.Equal(t, "A", response.QType)
|
||||||
|
assert.Equal(t, "v4.example.org", response.QName)
|
||||||
|
assert.Equal(t, "10.10.10.10", response.Content)
|
||||||
|
assert.Equal(t, 5, response.TTL)
|
||||||
|
|
||||||
|
response, err = lookup.Lookup(buildRequest("v4.example.org", "A"))
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.NotNil(t, response)
|
||||||
|
assert.Equal(t, "A", response.QType)
|
||||||
|
assert.Equal(t, "v4.example.org", response.QName)
|
||||||
|
assert.Equal(t, "10.10.10.10", response.Content)
|
||||||
|
assert.Equal(t, 5, response.TTL)
|
||||||
|
|
||||||
|
response, err = lookup.Lookup(buildRequest("v4.example.org", "AAAA"))
|
||||||
|
assert.NotNil(t, err)
|
||||||
|
assert.Nil(t, response)
|
||||||
|
|
||||||
|
response, err = lookup.Lookup(buildRequest("v6.example.org", "ANY"))
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.NotNil(t, response)
|
||||||
|
assert.Equal(t, "AAAA", response.QType)
|
||||||
|
assert.Equal(t, "v6.example.org", response.QName)
|
||||||
|
assert.Equal(t, "2001:db8:85a3::8a2e:370:7334", response.Content)
|
||||||
|
assert.Equal(t, 5, response.TTL)
|
||||||
|
|
||||||
|
response, err = lookup.Lookup(buildRequest("v6.example.org", "AAAA"))
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.NotNil(t, response)
|
||||||
|
assert.Equal(t, "AAAA", response.QType)
|
||||||
|
assert.Equal(t, "v6.example.org", response.QName)
|
||||||
|
assert.Equal(t, "2001:db8:85a3::8a2e:370:7334", response.Content)
|
||||||
|
assert.Equal(t, 5, response.TTL)
|
||||||
|
|
||||||
|
response, err = lookup.Lookup(buildRequest("v6.example.org", "A"))
|
||||||
|
assert.NotNil(t, err)
|
||||||
|
assert.Nil(t, response)
|
||||||
|
}
|
Loading…
Reference in New Issue