refactor ddns into several packages and add tests for most important

components

WIP

restructure code into separate packages + add tests for request handling

more documentation for methods

remove useless comment
This commit is contained in:
Philipp Böhm 2017-01-29 19:02:14 +01:00
parent 5f52d85cf3
commit 6af394aa2f
10 changed files with 553 additions and 257 deletions

View File

@ -1,99 +0,0 @@
package main
import (
"bufio"
"fmt"
"os"
"strings"
"time"
)
type responder func()
func respondWithFAIL() {
fmt.Printf("FAIL\n")
}
func respondWithEND() {
fmt.Printf("END\n")
}
// This function implements the PowerDNS-Pipe-Backend protocol and generates
// the response data it possible
func RunBackend(conn *RedisConnection) {
bio := bufio.NewReader(os.Stdin)
// handshake with PowerDNS
_, _, _ = bio.ReadLine()
fmt.Printf("OK\tDDNS Go Backend\n")
for {
line, _, err := bio.ReadLine()
if err != nil {
respondWithFAIL()
continue
}
HandleRequest(string(line), conn)()
}
}
func HandleRequest(line string, conn *RedisConnection) responder {
if Verbose {
fmt.Printf("LOG\t'%s'\n", line)
}
parts := strings.Split(line, "\t")
if len(parts) != 6 {
return respondWithFAIL
}
query_name := parts[1]
query_class := parts[2]
query_type := parts[3]
query_id := parts[4]
var response, record string
record = query_type
switch query_type {
case "SOA":
response = fmt.Sprintf("%s. hostmaster.example.com. %d 1800 3600 7200 5",
DdnsSoaFqdn, getSoaSerial())
case "NS":
response = fmt.Sprintf("%s.", DdnsSoaFqdn)
case "A":
case "ANY":
// get the host part of the fqdn: pi.d.example.org -> pi
hostname := ""
if strings.HasSuffix(query_name, DdnsDomain) {
hostname = query_name[:len(query_name)-len(DdnsDomain)]
}
if hostname == "" || !conn.HostExist(hostname) {
return respondWithFAIL
}
host := conn.GetHost(hostname)
response = host.Ip
record = "A"
if !host.IsIPv4() {
record = "AAAA"
}
default:
return respondWithFAIL
}
fmt.Printf("DATA\t%s\t%s\t%s\t10\t%s\t%s\n",
query_name, query_class, record, query_id, response)
return respondWithEND
}
func getSoaSerial() int64 {
// return current time in milliseconds
return time.Now().UnixNano()
}

170
backend/backend.go Normal file
View File

@ -0,0 +1,170 @@
package backend
import (
"../config"
"../hosts"
"bufio"
"errors"
"fmt"
"io"
"strings"
"time"
)
// PowerDnsBackend implements the PowerDNS-Pipe-Backend protocol ABI-Version 1
type PowerDnsBackend struct {
config *config.Config
hosts hosts.HostBackend
in io.Reader
out io.Writer
}
// NewPowerDnsBackend creates an instance of a PowerDNS-Pipe-Backend using the supplied parameters
func NewPowerDnsBackend(config *config.Config, backend hosts.HostBackend, in io.Reader, out io.Writer) *PowerDnsBackend {
return &PowerDnsBackend{
config: config,
hosts: backend,
in: in,
out: out,
}
}
// Run reads requests from an input (normally STDIN) and prints response messages to an output (normally STDOUT)
func (b *PowerDnsBackend) Run() {
responses := make(chan backendResponse, 5)
go func() {
for response := range responses {
fmt.Fprintln(b.out, strings.Join(response, "\t"))
}
}()
// handshake with PowerDNS
bio := bufio.NewReader(b.in)
_, _, _ = bio.ReadLine()
responses <- handshakeResponse
for {
request, err := b.parseRequest(bio)
if err != nil {
responses <- failResponse
continue
}
if err = b.handleRequest(request, responses); err != nil {
responses <- newResponse("LOG", err.Error())
}
}
}
// handleRequest handles the supplied request by sending response messages on the supplied responses channel
func (b *PowerDnsBackend) handleRequest(request *backendRequest, responses chan backendResponse) error {
defer b.commitRequest(responses)
responseRecord := request.queryType
var response string
switch request.queryType {
case "SOA":
response = fmt.Sprintf("%s. hostmaster%s. %d 1800 3600 7200 5",
b.config.SOAFqdn, b.config.Domain, b.currentSOASerial())
case "NS":
response = fmt.Sprintf("%s.", b.config.SOAFqdn)
case "A", "AAAA", "ANY":
hostname, err := b.extractHostname(request.queryName)
if err != nil {
return err
}
var host *hosts.Host
if host, err = b.hosts.GetHost(hostname); err != nil {
return err
}
response = host.Ip
responseRecord = "A"
if !host.IsIPv4() {
responseRecord = "AAAA"
}
if (request.queryType == "A" || request.queryType == "AAAA") && request.queryType != responseRecord {
return errors.New("IP address is not valid for requested record")
}
default:
return errors.New("Unsupported query type")
}
responses <- newResponse("DATA", request.queryName, request.queryClass, responseRecord, "10", request.queryId, response)
return nil
}
func (b *PowerDnsBackend) commitRequest(responses chan backendResponse) {
responses <- endResponse
}
// parseRequest reads a line from input and tries to build a request structure from it
func (b *PowerDnsBackend) parseRequest(input *bufio.Reader) (*backendRequest, error) {
line, _, err := input.ReadLine()
if err != nil {
return nil, err
}
parts := strings.Split(string(line), "\t")
if len(parts) != 6 {
return nil, errors.New("Invalid line")
}
return &backendRequest{
query: parts[0],
queryName: parts[1],
queryClass: parts[2],
queryType: parts[3],
queryId: parts[4],
}, nil
}
// extractHostname extract the host part of the fqdn: pi.d.example.org -> pi
func (b *PowerDnsBackend) extractHostname(rawQueryName string) (string, error) {
queryName := strings.ToLower(rawQueryName)
hostname := ""
if strings.HasSuffix(queryName, b.config.Domain) {
hostname = queryName[:len(queryName)-len(b.config.Domain)]
}
if hostname == "" {
return "", errors.New("Query name does not correspond to our domain")
}
return hostname, nil
}
// currentSOASerial get the current SOA serial by returning the current time in seconds
func (b *PowerDnsBackend) currentSOASerial() int64 {
return time.Now().Unix()
}
type backendRequest struct {
query string
queryName string
queryClass string
queryType string
queryId string
}
type backendResponse []string
func newResponse(values ...string) backendResponse {
return values
}
var (
handshakeResponse backendResponse = []string{"OK", "DDNS Backend"}
endResponse backendResponse = []string{"END"}
failResponse backendResponse = []string{"FAIL"}
)

177
backend/backend_test.go Normal file
View File

@ -0,0 +1,177 @@
package backend
import (
c "../config"
h "../hosts"
"bufio"
"bytes"
"errors"
"github.com/stretchr/testify/assert"
"os"
"testing"
)
type testHostBackend struct {
hosts map[string]*h.Host
}
func (b *testHostBackend) GetHost(hostname string) (*h.Host, error) {
host, ok := b.hosts[hostname]
if ok {
return host, nil
} else {
return nil, errors.New("Host not found")
}
}
func (b *testHostBackend) SetHost(host *h.Host) error {
b.hosts[host.Hostname] = host
return nil
}
func buildBackend(domain string) (*c.Config, *testHostBackend, *PowerDnsBackend) {
config := &c.Config{
Verbose: false,
Domain: domain,
SOAFqdn: "dns" + domain,
}
hosts := &testHostBackend{
hosts: map[string]*h.Host{
"www": {
Hostname: "www",
Ip: "10.11.12.13",
Token: "abcdef",
},
"v4": {
Hostname: "v4",
Ip: "10.10.10.10",
Token: "ghijkl",
},
"v6": {
Hostname: "v6",
Ip: "2001:db8:85a3::8a2e:370:7334",
Token: "ghijkl",
},
},
}
return config, hosts, NewPowerDnsBackend(config, hosts, os.Stdin, os.Stdout)
}
func buildRequest(queryName, queryType string) *backendRequest {
return &backendRequest{
query: "Q",
queryName: queryName,
queryClass: "IN",
queryType: queryType,
queryId: "-1",
}
}
func readResponse(t *testing.T, responses chan backendResponse) backendResponse {
select {
case response, ok := <-responses:
assert.True(t, ok)
return response
default:
assert.FailNow(t, "Couldn't read response because it is not available ...")
return nil
}
}
func TestParseRequest(t *testing.T) {
_, _, backend := buildBackend(".example.org")
reader := bufio.NewReader(bytes.NewBufferString("Q\twww.example.org\tIN\tCNAME\t-1\t203.0.113.210\n"))
request, err := backend.parseRequest(reader)
assert.Nil(t, err)
assert.Equal(t, buildRequest("www.example.org", "CNAME"), request)
reader = bufio.NewReader(bytes.NewBufferString("Q\texample.org\tIN\tSOA\t-1\t203.0.113.210\n"))
request, err = backend.parseRequest(reader)
assert.Nil(t, err)
assert.Equal(t, buildRequest("example.org", "SOA"), request)
reader = bufio.NewReader(bytes.NewBufferString("Q\texample.org\n"))
request, err = backend.parseRequest(reader)
assert.NotNil(t, err)
assert.Nil(t, request)
}
func TestRequestHandling(t *testing.T) {
_, _, backend := buildBackend(".example.org")
responses := make(chan backendResponse, 2)
err := backend.handleRequest(buildRequest("example.org", "SOA"), responses)
assert.Nil(t, err)
soaResponse := readResponse(t, responses)
assert.Equal(t, newResponse("DATA", "example.org", "IN", "SOA", "10", "-1"), soaResponse[0:6])
assert.Regexp(t, "dns\\.example\\.org\\. hostmaster\\.example.org\\. \\d+ 1800 3600 7200 5", soaResponse[6])
assert.Equal(t, endResponse, readResponse(t, responses))
responses = make(chan backendResponse, 2)
err = backend.handleRequest(buildRequest("example.org", "NS"), responses)
assert.Nil(t, err)
assert.Equal(t, newResponse("DATA", "example.org", "IN", "NS", "10", "-1", "dns.example.org."), readResponse(t, responses))
assert.Equal(t, endResponse, readResponse(t, responses))
responses = make(chan backendResponse, 2)
err = backend.handleRequest(buildRequest("www.example.org", "ANY"), responses)
assert.Nil(t, err)
assert.Equal(t, newResponse("DATA", "www.example.org", "IN", "A", "10", "-1", "10.11.12.13"), readResponse(t, responses))
assert.Equal(t, endResponse, readResponse(t, responses))
responses = make(chan backendResponse, 2)
err = backend.handleRequest(buildRequest("www.example.org", "A"), responses)
assert.Nil(t, err)
assert.Equal(t, newResponse("DATA", "www.example.org", "IN", "A", "10", "-1", "10.11.12.13"), readResponse(t, responses))
assert.Equal(t, endResponse, readResponse(t, responses))
// Allow hostname to be mixed case which is used by Let's Encrypt for a little bit more security
responses = make(chan backendResponse, 2)
err = backend.handleRequest(buildRequest("wWW.eXaMPlE.oRg", "A"), responses)
assert.Nil(t, err)
assert.Equal(t, newResponse("DATA", "wWW.eXaMPlE.oRg", "IN", "A", "10", "-1", "10.11.12.13"), readResponse(t, responses))
assert.Equal(t, endResponse, readResponse(t, responses))
responses = make(chan backendResponse, 1)
err = backend.handleRequest(buildRequest("notexisting.example.org", "A"), responses)
assert.NotNil(t, err)
assert.Equal(t, endResponse, readResponse(t, responses))
// Correct Handling of IPv4/IPv6 and ANY/A/AAAA
responses = make(chan backendResponse, 2)
err = backend.handleRequest(buildRequest("v4.example.org", "ANY"), responses)
assert.Nil(t, err)
assert.Equal(t, newResponse("DATA", "v4.example.org", "IN", "A", "10", "-1", "10.10.10.10"), readResponse(t, responses))
assert.Equal(t, endResponse, readResponse(t, responses))
responses = make(chan backendResponse, 2)
err = backend.handleRequest(buildRequest("v4.example.org", "A"), responses)
assert.Nil(t, err)
assert.Equal(t, newResponse("DATA", "v4.example.org", "IN", "A", "10", "-1", "10.10.10.10"), readResponse(t, responses))
assert.Equal(t, endResponse, readResponse(t, responses))
responses = make(chan backendResponse, 1)
err = backend.handleRequest(buildRequest("v4.example.org", "AAAA"), responses)
assert.NotNil(t, err)
assert.Equal(t, endResponse, readResponse(t, responses))
responses = make(chan backendResponse, 2)
err = backend.handleRequest(buildRequest("v6.example.org", "ANY"), responses)
assert.Nil(t, err)
assert.Equal(t, newResponse("DATA", "v6.example.org", "IN", "AAAA", "10", "-1", "2001:db8:85a3::8a2e:370:7334"), readResponse(t, responses))
assert.Equal(t, endResponse, readResponse(t, responses))
responses = make(chan backendResponse, 2)
err = backend.handleRequest(buildRequest("v6.example.org", "AAAA"), responses)
assert.Nil(t, err)
assert.Equal(t, newResponse("DATA", "v6.example.org", "IN", "AAAA", "10", "-1", "2001:db8:85a3::8a2e:370:7334"), readResponse(t, responses))
assert.Equal(t, endResponse, readResponse(t, responses))
responses = make(chan backendResponse, 1)
err = backend.handleRequest(buildRequest("v6.example.org", "A"), responses)
assert.NotNil(t, err)
assert.Equal(t, endResponse, readResponse(t, responses))
}

10
config/config.go Normal file
View File

@ -0,0 +1,10 @@
package config
type Config struct {
Verbose bool
Domain string
SOAFqdn string
HostExpirationDays int
Listen string
RedisHost string
}

68
ddns.go
View File

@ -1,63 +1,63 @@
package main package main
import ( import (
"./backend"
"./config"
"./hosts"
"./web"
"flag" "flag"
"log" "log"
"os"
"strings" "strings"
) )
func HandleErr(err error) {
if err != nil {
log.Fatal(err)
}
}
const ( const (
CmdBackend string = "backend" CmdBackend string = "backend"
CmdWeb string = "web" CmdWeb string = "web"
) )
var ( var serviceConfig *config.Config
DdnsDomain string
DdnsWebListenSocket string
DdnsRedisHost string
DdnsSoaFqdn string
Verbose bool
)
func init() { func init() {
flag.StringVar(&DdnsDomain, "domain", "", flag.StringVar(&serviceConfig.Domain, "domain", "",
"The subdomain which should be handled by DDNS") "The subdomain which should be handled by DDNS")
flag.StringVar(&DdnsWebListenSocket, "listen", ":8080", flag.StringVar(&serviceConfig.Listen, "listen", ":8080",
"Which socket should the web service use to bind itself") "Which socket should the web service use to bind itself")
flag.StringVar(&DdnsRedisHost, "redis", ":6379", flag.StringVar(&serviceConfig.RedisHost, "redis", ":6379",
"The Redis socket that should be used") "The Redis socket that should be used")
flag.StringVar(&DdnsSoaFqdn, "soa_fqdn", "", flag.StringVar(&serviceConfig.SOAFqdn, "soa_fqdn", "",
"The FQDN of the DNS server which is returned as a SOA record") "The FQDN of the DNS server which is returned as a SOA record")
flag.BoolVar(&Verbose, "verbose", false, flag.IntVar(&serviceConfig.HostExpirationDays, "expiration-days", 10,
"The number of days after a host is released when it is not updated")
flag.BoolVar(&serviceConfig.Verbose, "verbose", false,
"Be more verbose") "Be more verbose")
} }
func ValidateCommandArgs(cmd string) { func usage() {
if DdnsDomain == "" { log.Fatal("Usage: ./ddns [backend|web]")
}
func validateCommandArgs(cmd string) {
if serviceConfig.Domain == "" {
log.Fatal("You have to supply the domain via --domain=DOMAIN") log.Fatal("You have to supply the domain via --domain=DOMAIN")
} else if !strings.HasPrefix(DdnsDomain, ".") { } else if !strings.HasPrefix(serviceConfig.Domain, ".") {
// get the domain in the right format // get the domain in the right format
DdnsDomain = "." + DdnsDomain serviceConfig.Domain = "." + serviceConfig.Domain
} }
if cmd == CmdBackend { if cmd == CmdBackend {
if DdnsSoaFqdn == "" { if serviceConfig.SOAFqdn == "" {
log.Fatal("You have to supply the server FQDN via --soa_fqdn=FQDN") log.Fatal("You have to supply the server FQDN via --soa_fqdn=FQDN")
} }
} }
} }
func PrepareForExecution() string { func prepareForExecution() string {
flag.Parse() flag.Parse()
if len(flag.Args()) != 1 { if len(flag.Args()) != 1 {
@ -65,28 +65,24 @@ func PrepareForExecution() string {
} }
cmd := flag.Args()[0] cmd := flag.Args()[0]
ValidateCommandArgs(cmd) validateCommandArgs(cmd)
return cmd return cmd
} }
func main() { func main() {
cmd := PrepareForExecution() cmd := prepareForExecution()
conn := OpenConnection(DdnsRedisHost) redis := hosts.NewRedisBackend(serviceConfig)
defer conn.Close() defer redis.Close()
switch cmd { switch cmd {
case CmdBackend: case CmdBackend:
log.Printf("Starting PDNS Backend\n") backend.NewPowerDnsBackend(serviceConfig, redis, os.Stdin, os.Stdout).Run()
RunBackend(conn)
case CmdWeb: case CmdWeb:
log.Printf("Starting Web Service\n") web.NewWebService(serviceConfig, redis).Run()
RunWebService(conn)
default: default:
usage() usage()
} }
} }
func usage() {
log.Fatal("Usage: ./ddns [backend|web]")
}

36
hosts/hosts.go Normal file
View File

@ -0,0 +1,36 @@
package hosts
import (
"crypto/sha1"
"fmt"
"strings"
"time"
)
type Host struct {
Hostname string `redis:"-"`
Ip string `redis:"ip"`
Token string `redis:"token"`
}
func (h *Host) GenerateAndSetToken() {
hash := sha1.New()
hash.Write([]byte(fmt.Sprintf("%d", time.Now().UnixNano())))
hash.Write([]byte(h.Hostname))
h.Token = fmt.Sprintf("%x", hash.Sum(nil))
}
func (h *Host) IsIPv4() bool {
if strings.Contains(h.Ip, ".") {
return true
}
return false
}
type HostBackend interface {
GetHost(string) (*Host, error)
SetHost(*Host) error
}

76
hosts/redis.go Normal file
View File

@ -0,0 +1,76 @@
package hosts
import (
"../config"
"github.com/garyburd/redigo/redis"
"time"
)
type RedisBackend struct {
expirationSeconds int
pool *redis.Pool
}
func NewRedisBackend(config *config.Config) *RedisBackend {
return &RedisBackend{
expirationSeconds: config.HostExpirationDays * 24 * 60 * 60,
pool: &redis.Pool{
MaxIdle: 3,
IdleTimeout: 240 * time.Second,
Dial: func() (redis.Conn, error) {
c, err := redis.Dial("tcp", config.RedisHost)
if err != nil {
return nil, err
}
return c, err
},
TestOnBorrow: func(c redis.Conn, t time.Time) error {
_, err := c.Do("PING")
return err
},
},
}
}
func (r *RedisBackend) Close() {
r.pool.Close()
}
func (r *RedisBackend) GetHost(name string) (*Host, error) {
conn := r.pool.Get()
defer conn.Close()
host := Host{Hostname: name}
var err error
var data []interface{}
if data, err = redis.Values(conn.Do("HGETALL", host.Hostname)); err != nil {
return nil, err
}
if err = redis.ScanStruct(data, &host); err != nil {
return nil, err
}
return &host, nil
}
func (r *RedisBackend) SetHost(host *Host) error {
conn := r.pool.Get()
defer conn.Close()
var err error
if _, err = conn.Do("HMSET", redis.Args{}.Add(host.Hostname).AddFlat(host)...); err != nil {
return err
}
if _, err = conn.Do("EXPIRE", host.Hostname, r.expirationSeconds); err != nil {
return err
}
return nil
}

100
redis.go
View File

@ -1,100 +0,0 @@
package main
import (
"crypto/sha1"
"fmt"
"github.com/garyburd/redigo/redis"
"strings"
"time"
)
const HostExpirationSeconds int = 10 * 24 * 60 * 60 // 10 Days
type RedisConnection struct {
*redis.Pool
}
func OpenConnection(server string) *RedisConnection {
return &RedisConnection{newPool(server)}
}
func newPool(server string) *redis.Pool {
return &redis.Pool{
MaxIdle: 3,
IdleTimeout: 240 * time.Second,
Dial: func() (redis.Conn, error) {
c, err := redis.Dial("tcp", server)
if err != nil {
return nil, err
}
return c, err
},
TestOnBorrow: func(c redis.Conn, t time.Time) error {
_, err := c.Do("PING")
return err
},
}
}
func (self *RedisConnection) GetHost(name string) *Host {
conn := self.Get()
defer conn.Close()
host := Host{Hostname: name}
if self.HostExist(name) {
data, err := redis.Values(conn.Do("HGETALL", host.Hostname))
HandleErr(err)
HandleErr(redis.ScanStruct(data, &host))
}
return &host
}
func (self *RedisConnection) SaveHost(host *Host) {
conn := self.Get()
defer conn.Close()
_, err := conn.Do("HMSET", redis.Args{}.Add(host.Hostname).AddFlat(host)...)
HandleErr(err)
_, err = conn.Do("EXPIRE", host.Hostname, HostExpirationSeconds)
HandleErr(err)
}
func (self *RedisConnection) HostExist(name string) bool {
conn := self.Get()
defer conn.Close()
exists, err := redis.Bool(conn.Do("EXISTS", name))
HandleErr(err)
return exists
}
type Host struct {
Hostname string `redis:"-"`
Ip string `redis:"ip"`
Token string `redis:"token"`
}
func (self *Host) GenerateAndSetToken() {
hash := sha1.New()
hash.Write([]byte(fmt.Sprintf("%d", time.Now().UnixNano())))
hash.Write([]byte(self.Hostname))
self.Token = fmt.Sprintf("%x", hash.Sum(nil))
}
// Returns true when this host has a IPv4 Address and false if IPv6
func (self *Host) IsIPv4() bool {
if strings.Contains(self.Ip, ".") {
return true
}
return false
}

View File

@ -1,4 +1,4 @@
package main package web
const indexTemplate string = ` const indexTemplate string = `
<!DOCTYPE html> <!DOCTYPE html>

View File

@ -1,49 +1,74 @@
package main package web
import ( import (
"../config"
"../hosts"
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"html/template" "html/template"
"log"
"net" "net"
"net/http" "net/http"
"regexp" "regexp"
) )
func RunWebService(conn *RedisConnection) { type WebService struct {
config *config.Config
hosts hosts.HostBackend
}
func NewWebService(config *config.Config, hosts hosts.HostBackend) *WebService {
return &WebService{
config: config,
hosts: hosts,
}
}
func (w *WebService) Run() {
r := gin.Default() r := gin.Default()
r.SetHTMLTemplate(BuildTemplate()) r.SetHTMLTemplate(buildTemplate())
r.GET("/", func(g *gin.Context) { r.GET("/", func(g *gin.Context) {
g.HTML(200, "index.html", gin.H{"domain": DdnsDomain}) g.HTML(200, "index.html", gin.H{"domain": w.config.Domain})
}) })
r.GET("/available/:hostname", func(c *gin.Context) { r.GET("/available/:hostname", func(c *gin.Context) {
hostname, valid := ValidHostname(c.Params.ByName("hostname")) hostname, valid := isValidHostname(c.Params.ByName("hostname"))
if valid {
_, err := w.hosts.GetHost(hostname)
valid = err == nil
}
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"available": valid && !conn.HostExist(hostname), "available": valid,
}) })
}) })
r.GET("/new/:hostname", func(c *gin.Context) { r.GET("/new/:hostname", func(c *gin.Context) {
hostname, valid := ValidHostname(c.Params.ByName("hostname")) hostname, valid := isValidHostname(c.Params.ByName("hostname"))
if !valid { if !valid {
c.JSON(404, gin.H{"error": "This hostname is not valid"}) c.JSON(404, gin.H{"error": "This hostname is not valid"})
return return
} }
if conn.HostExist(hostname) { var err error
if _, err = w.hosts.GetHost(hostname); err == nil {
c.JSON(403, gin.H{ c.JSON(403, gin.H{
"error": "This hostname has already been registered.", "error": "This hostname has already been registered.",
}) })
return return
} }
host := &Host{Hostname: hostname, Ip: "127.0.0.1"} host := &hosts.Host{Hostname: hostname, Ip: "127.0.0.1"}
host.GenerateAndSetToken() host.GenerateAndSetToken()
conn.SaveHost(host) if err = w.hosts.SetHost(host); err != nil {
c.JSON(400, gin.H{"error": "Could not register host."})
return
}
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"hostname": host.Hostname, "hostname": host.Hostname,
@ -53,7 +78,7 @@ func RunWebService(conn *RedisConnection) {
}) })
r.GET("/update/:hostname/:token", func(c *gin.Context) { r.GET("/update/:hostname/:token", func(c *gin.Context) {
hostname, valid := ValidHostname(c.Params.ByName("hostname")) hostname, valid := isValidHostname(c.Params.ByName("hostname"))
token := c.Params.ByName("token") token := c.Params.ByName("token")
if !valid { if !valid {
@ -61,15 +86,14 @@ func RunWebService(conn *RedisConnection) {
return return
} }
if !conn.HostExist(hostname) { host, err := w.hosts.GetHost(hostname)
if err != nil {
c.JSON(404, gin.H{ c.JSON(404, gin.H{
"error": "This hostname has not been registered or is expired.", "error": "This hostname has not been registered or is expired.",
}) })
return return
} }
host := conn.GetHost(hostname)
if host.Token != token { if host.Token != token {
c.JSON(403, gin.H{ c.JSON(403, gin.H{
"error": "You have supplied the wrong token to manipulate this host", "error": "You have supplied the wrong token to manipulate this host",
@ -77,7 +101,7 @@ func RunWebService(conn *RedisConnection) {
return return
} }
ip, err := GetRemoteAddr(c.Request) ip, err := extractRemoteAddr(c.Request)
if err != nil { if err != nil {
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"error": "Your sender IP address is not in the right format", "error": "Your sender IP address is not in the right format",
@ -86,7 +110,11 @@ func RunWebService(conn *RedisConnection) {
} }
host.Ip = ip host.Ip = ip
conn.SaveHost(host) if err = w.hosts.SetHost(host); err != nil {
c.JSON(400, gin.H{
"error": "Could not update registered IP address",
})
}
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"current_ip": ip, "current_ip": ip,
@ -94,13 +122,13 @@ func RunWebService(conn *RedisConnection) {
}) })
}) })
r.Run(DdnsWebListenSocket) r.Run(w.config.Listen)
} }
// Get the Remote Address of the client. At First we try to get the // Get the Remote Address of the client. At First we try to get the
// X-Forwarded-For Header which holds the IP if we are behind a proxy, // X-Forwarded-For Header which holds the IP if we are behind a proxy,
// otherwise the RemoteAddr is used // otherwise the RemoteAddr is used
func GetRemoteAddr(req *http.Request) (string, error) { func extractRemoteAddr(req *http.Request) (string, error) {
header_data, ok := req.Header["X-Forwarded-For"] header_data, ok := req.Header["X-Forwarded-For"]
if ok { if ok {
@ -112,14 +140,16 @@ func GetRemoteAddr(req *http.Request) (string, error) {
} }
// Get index template from bindata // Get index template from bindata
func BuildTemplate() *template.Template { func buildTemplate() *template.Template {
html, err := template.New("index.html").Parse(indexTemplate) html, err := template.New("index.html").Parse(indexTemplate)
HandleErr(err) if err != nil {
log.Fatal(err)
}
return html return html
} }
func ValidHostname(host string) (string, bool) { func isValidHostname(host string) (string, bool) {
valid, _ := regexp.Match("^[a-z0-9]{1,32}$", []byte(host)) valid, _ := regexp.Match("^[a-z0-9]{1,32}$", []byte(host))
return host, valid return host, valid