diff --git a/backend.go b/backend.go deleted file mode 100644 index b5b0620..0000000 --- a/backend.go +++ /dev/null @@ -1,99 +0,0 @@ -package main - -import ( - "bufio" - "fmt" - "os" - "strings" - "time" -) - -type responder func() - -func respondWithFAIL() { - fmt.Printf("FAIL\n") -} - -func respondWithEND() { - fmt.Printf("END\n") -} - -// This function implements the PowerDNS-Pipe-Backend protocol and generates -// the response data it possible -func RunBackend(conn *RedisConnection) { - bio := bufio.NewReader(os.Stdin) - - // handshake with PowerDNS - _, _, _ = bio.ReadLine() - fmt.Printf("OK\tDDNS Go Backend\n") - - for { - line, _, err := bio.ReadLine() - if err != nil { - respondWithFAIL() - continue - } - - HandleRequest(string(line), conn)() - } -} - -func HandleRequest(line string, conn *RedisConnection) responder { - if Verbose { - fmt.Printf("LOG\t'%s'\n", line) - } - - parts := strings.Split(line, "\t") - if len(parts) != 6 { - return respondWithFAIL - } - - query_name := parts[1] - query_class := parts[2] - query_type := parts[3] - query_id := parts[4] - - var response, record string - record = query_type - - switch query_type { - case "SOA": - response = fmt.Sprintf("%s. hostmaster.example.com. %d 1800 3600 7200 5", - DdnsSoaFqdn, getSoaSerial()) - - case "NS": - response = fmt.Sprintf("%s.", DdnsSoaFqdn) - - case "A": - case "ANY": - // get the host part of the fqdn: pi.d.example.org -> pi - hostname := "" - if strings.HasSuffix(query_name, DdnsDomain) { - hostname = query_name[:len(query_name)-len(DdnsDomain)] - } - - if hostname == "" || !conn.HostExist(hostname) { - return respondWithFAIL - } - - host := conn.GetHost(hostname) - response = host.Ip - - record = "A" - if !host.IsIPv4() { - record = "AAAA" - } - - default: - return respondWithFAIL - } - - fmt.Printf("DATA\t%s\t%s\t%s\t10\t%s\t%s\n", - query_name, query_class, record, query_id, response) - return respondWithEND -} - -func getSoaSerial() int64 { - // return current time in milliseconds - return time.Now().UnixNano() -} diff --git a/backend/backend.go b/backend/backend.go new file mode 100644 index 0000000..9e1b340 --- /dev/null +++ b/backend/backend.go @@ -0,0 +1,170 @@ +package backend + +import ( + "../config" + "../hosts" + "bufio" + "errors" + "fmt" + "io" + "strings" + "time" +) + +// PowerDnsBackend implements the PowerDNS-Pipe-Backend protocol ABI-Version 1 +type PowerDnsBackend struct { + config *config.Config + hosts hosts.HostBackend + in io.Reader + out io.Writer +} + +// NewPowerDnsBackend creates an instance of a PowerDNS-Pipe-Backend using the supplied parameters +func NewPowerDnsBackend(config *config.Config, backend hosts.HostBackend, in io.Reader, out io.Writer) *PowerDnsBackend { + return &PowerDnsBackend{ + config: config, + hosts: backend, + in: in, + out: out, + } +} + +// Run reads requests from an input (normally STDIN) and prints response messages to an output (normally STDOUT) +func (b *PowerDnsBackend) Run() { + responses := make(chan backendResponse, 5) + + go func() { + for response := range responses { + fmt.Fprintln(b.out, strings.Join(response, "\t")) + } + }() + + // handshake with PowerDNS + bio := bufio.NewReader(b.in) + _, _, _ = bio.ReadLine() + responses <- handshakeResponse + + for { + request, err := b.parseRequest(bio) + if err != nil { + responses <- failResponse + continue + } + + if err = b.handleRequest(request, responses); err != nil { + responses <- newResponse("LOG", err.Error()) + } + } +} + +// handleRequest handles the supplied request by sending response messages on the supplied responses channel +func (b *PowerDnsBackend) handleRequest(request *backendRequest, responses chan backendResponse) error { + defer b.commitRequest(responses) + + responseRecord := request.queryType + var response string + + switch request.queryType { + case "SOA": + response = fmt.Sprintf("%s. hostmaster%s. %d 1800 3600 7200 5", + b.config.SOAFqdn, b.config.Domain, b.currentSOASerial()) + + case "NS": + response = fmt.Sprintf("%s.", b.config.SOAFqdn) + + case "A", "AAAA", "ANY": + hostname, err := b.extractHostname(request.queryName) + if err != nil { + return err + } + + var host *hosts.Host + if host, err = b.hosts.GetHost(hostname); err != nil { + return err + } + + response = host.Ip + + responseRecord = "A" + if !host.IsIPv4() { + responseRecord = "AAAA" + } + + if (request.queryType == "A" || request.queryType == "AAAA") && request.queryType != responseRecord { + return errors.New("IP address is not valid for requested record") + } + + default: + return errors.New("Unsupported query type") + } + + responses <- newResponse("DATA", request.queryName, request.queryClass, responseRecord, "10", request.queryId, response) + + return nil +} + +func (b *PowerDnsBackend) commitRequest(responses chan backendResponse) { + responses <- endResponse +} + +// parseRequest reads a line from input and tries to build a request structure from it +func (b *PowerDnsBackend) parseRequest(input *bufio.Reader) (*backendRequest, error) { + line, _, err := input.ReadLine() + if err != nil { + return nil, err + } + + parts := strings.Split(string(line), "\t") + if len(parts) != 6 { + return nil, errors.New("Invalid line") + } + + return &backendRequest{ + query: parts[0], + queryName: parts[1], + queryClass: parts[2], + queryType: parts[3], + queryId: parts[4], + }, nil +} + +// extractHostname extract the host part of the fqdn: pi.d.example.org -> pi +func (b *PowerDnsBackend) extractHostname(rawQueryName string) (string, error) { + queryName := strings.ToLower(rawQueryName) + + hostname := "" + if strings.HasSuffix(queryName, b.config.Domain) { + hostname = queryName[:len(queryName)-len(b.config.Domain)] + } + + if hostname == "" { + return "", errors.New("Query name does not correspond to our domain") + } + + return hostname, nil +} + +// currentSOASerial get the current SOA serial by returning the current time in seconds +func (b *PowerDnsBackend) currentSOASerial() int64 { + return time.Now().Unix() +} + +type backendRequest struct { + query string + queryName string + queryClass string + queryType string + queryId string +} + +type backendResponse []string + +func newResponse(values ...string) backendResponse { + return values +} + +var ( + handshakeResponse backendResponse = []string{"OK", "DDNS Backend"} + endResponse backendResponse = []string{"END"} + failResponse backendResponse = []string{"FAIL"} +) diff --git a/backend/backend_test.go b/backend/backend_test.go new file mode 100644 index 0000000..dfd2c55 --- /dev/null +++ b/backend/backend_test.go @@ -0,0 +1,177 @@ +package backend + +import ( + c "../config" + h "../hosts" + "bufio" + "bytes" + "errors" + "github.com/stretchr/testify/assert" + "os" + "testing" +) + +type testHostBackend struct { + hosts map[string]*h.Host +} + +func (b *testHostBackend) GetHost(hostname string) (*h.Host, error) { + host, ok := b.hosts[hostname] + if ok { + return host, nil + } else { + return nil, errors.New("Host not found") + } +} + +func (b *testHostBackend) SetHost(host *h.Host) error { + b.hosts[host.Hostname] = host + return nil +} + +func buildBackend(domain string) (*c.Config, *testHostBackend, *PowerDnsBackend) { + config := &c.Config{ + Verbose: false, + Domain: domain, + SOAFqdn: "dns" + domain, + } + + hosts := &testHostBackend{ + hosts: map[string]*h.Host{ + "www": { + Hostname: "www", + Ip: "10.11.12.13", + Token: "abcdef", + }, + "v4": { + Hostname: "v4", + Ip: "10.10.10.10", + Token: "ghijkl", + }, + "v6": { + Hostname: "v6", + Ip: "2001:db8:85a3::8a2e:370:7334", + Token: "ghijkl", + }, + }, + } + + return config, hosts, NewPowerDnsBackend(config, hosts, os.Stdin, os.Stdout) +} + +func buildRequest(queryName, queryType string) *backendRequest { + return &backendRequest{ + query: "Q", + queryName: queryName, + queryClass: "IN", + queryType: queryType, + queryId: "-1", + } +} + +func readResponse(t *testing.T, responses chan backendResponse) backendResponse { + select { + case response, ok := <-responses: + assert.True(t, ok) + return response + default: + assert.FailNow(t, "Couldn't read response because it is not available ...") + return nil + } +} + +func TestParseRequest(t *testing.T) { + _, _, backend := buildBackend(".example.org") + + reader := bufio.NewReader(bytes.NewBufferString("Q\twww.example.org\tIN\tCNAME\t-1\t203.0.113.210\n")) + request, err := backend.parseRequest(reader) + assert.Nil(t, err) + assert.Equal(t, buildRequest("www.example.org", "CNAME"), request) + + reader = bufio.NewReader(bytes.NewBufferString("Q\texample.org\tIN\tSOA\t-1\t203.0.113.210\n")) + request, err = backend.parseRequest(reader) + assert.Nil(t, err) + assert.Equal(t, buildRequest("example.org", "SOA"), request) + + reader = bufio.NewReader(bytes.NewBufferString("Q\texample.org\n")) + request, err = backend.parseRequest(reader) + assert.NotNil(t, err) + assert.Nil(t, request) +} + +func TestRequestHandling(t *testing.T) { + _, _, backend := buildBackend(".example.org") + + responses := make(chan backendResponse, 2) + err := backend.handleRequest(buildRequest("example.org", "SOA"), responses) + assert.Nil(t, err) + soaResponse := readResponse(t, responses) + assert.Equal(t, newResponse("DATA", "example.org", "IN", "SOA", "10", "-1"), soaResponse[0:6]) + assert.Regexp(t, "dns\\.example\\.org\\. hostmaster\\.example.org\\. \\d+ 1800 3600 7200 5", soaResponse[6]) + assert.Equal(t, endResponse, readResponse(t, responses)) + + responses = make(chan backendResponse, 2) + err = backend.handleRequest(buildRequest("example.org", "NS"), responses) + assert.Nil(t, err) + assert.Equal(t, newResponse("DATA", "example.org", "IN", "NS", "10", "-1", "dns.example.org."), readResponse(t, responses)) + assert.Equal(t, endResponse, readResponse(t, responses)) + + responses = make(chan backendResponse, 2) + err = backend.handleRequest(buildRequest("www.example.org", "ANY"), responses) + assert.Nil(t, err) + assert.Equal(t, newResponse("DATA", "www.example.org", "IN", "A", "10", "-1", "10.11.12.13"), readResponse(t, responses)) + assert.Equal(t, endResponse, readResponse(t, responses)) + + responses = make(chan backendResponse, 2) + err = backend.handleRequest(buildRequest("www.example.org", "A"), responses) + assert.Nil(t, err) + assert.Equal(t, newResponse("DATA", "www.example.org", "IN", "A", "10", "-1", "10.11.12.13"), readResponse(t, responses)) + assert.Equal(t, endResponse, readResponse(t, responses)) + + // Allow hostname to be mixed case which is used by Let's Encrypt for a little bit more security + responses = make(chan backendResponse, 2) + err = backend.handleRequest(buildRequest("wWW.eXaMPlE.oRg", "A"), responses) + assert.Nil(t, err) + assert.Equal(t, newResponse("DATA", "wWW.eXaMPlE.oRg", "IN", "A", "10", "-1", "10.11.12.13"), readResponse(t, responses)) + assert.Equal(t, endResponse, readResponse(t, responses)) + + responses = make(chan backendResponse, 1) + err = backend.handleRequest(buildRequest("notexisting.example.org", "A"), responses) + assert.NotNil(t, err) + assert.Equal(t, endResponse, readResponse(t, responses)) + + // Correct Handling of IPv4/IPv6 and ANY/A/AAAA + responses = make(chan backendResponse, 2) + err = backend.handleRequest(buildRequest("v4.example.org", "ANY"), responses) + assert.Nil(t, err) + assert.Equal(t, newResponse("DATA", "v4.example.org", "IN", "A", "10", "-1", "10.10.10.10"), readResponse(t, responses)) + assert.Equal(t, endResponse, readResponse(t, responses)) + + responses = make(chan backendResponse, 2) + err = backend.handleRequest(buildRequest("v4.example.org", "A"), responses) + assert.Nil(t, err) + assert.Equal(t, newResponse("DATA", "v4.example.org", "IN", "A", "10", "-1", "10.10.10.10"), readResponse(t, responses)) + assert.Equal(t, endResponse, readResponse(t, responses)) + + responses = make(chan backendResponse, 1) + err = backend.handleRequest(buildRequest("v4.example.org", "AAAA"), responses) + assert.NotNil(t, err) + assert.Equal(t, endResponse, readResponse(t, responses)) + + responses = make(chan backendResponse, 2) + err = backend.handleRequest(buildRequest("v6.example.org", "ANY"), responses) + assert.Nil(t, err) + assert.Equal(t, newResponse("DATA", "v6.example.org", "IN", "AAAA", "10", "-1", "2001:db8:85a3::8a2e:370:7334"), readResponse(t, responses)) + assert.Equal(t, endResponse, readResponse(t, responses)) + + responses = make(chan backendResponse, 2) + err = backend.handleRequest(buildRequest("v6.example.org", "AAAA"), responses) + assert.Nil(t, err) + assert.Equal(t, newResponse("DATA", "v6.example.org", "IN", "AAAA", "10", "-1", "2001:db8:85a3::8a2e:370:7334"), readResponse(t, responses)) + assert.Equal(t, endResponse, readResponse(t, responses)) + + responses = make(chan backendResponse, 1) + err = backend.handleRequest(buildRequest("v6.example.org", "A"), responses) + assert.NotNil(t, err) + assert.Equal(t, endResponse, readResponse(t, responses)) +} diff --git a/config/config.go b/config/config.go new file mode 100644 index 0000000..2532421 --- /dev/null +++ b/config/config.go @@ -0,0 +1,10 @@ +package config + +type Config struct { + Verbose bool + Domain string + SOAFqdn string + HostExpirationDays int + Listen string + RedisHost string +} diff --git a/ddns.go b/ddns.go index 515d64b..11e2e57 100644 --- a/ddns.go +++ b/ddns.go @@ -1,63 +1,63 @@ package main import ( + "./backend" + "./config" + "./hosts" + "./web" "flag" "log" + "os" "strings" ) -func HandleErr(err error) { - if err != nil { - log.Fatal(err) - } -} - const ( CmdBackend string = "backend" CmdWeb string = "web" ) -var ( - DdnsDomain string - DdnsWebListenSocket string - DdnsRedisHost string - DdnsSoaFqdn string - Verbose bool -) +var serviceConfig *config.Config func init() { - flag.StringVar(&DdnsDomain, "domain", "", + flag.StringVar(&serviceConfig.Domain, "domain", "", "The subdomain which should be handled by DDNS") - flag.StringVar(&DdnsWebListenSocket, "listen", ":8080", + flag.StringVar(&serviceConfig.Listen, "listen", ":8080", "Which socket should the web service use to bind itself") - flag.StringVar(&DdnsRedisHost, "redis", ":6379", + flag.StringVar(&serviceConfig.RedisHost, "redis", ":6379", "The Redis socket that should be used") - flag.StringVar(&DdnsSoaFqdn, "soa_fqdn", "", + flag.StringVar(&serviceConfig.SOAFqdn, "soa_fqdn", "", "The FQDN of the DNS server which is returned as a SOA record") - flag.BoolVar(&Verbose, "verbose", false, + flag.IntVar(&serviceConfig.HostExpirationDays, "expiration-days", 10, + "The number of days after a host is released when it is not updated") + + flag.BoolVar(&serviceConfig.Verbose, "verbose", false, "Be more verbose") } -func ValidateCommandArgs(cmd string) { - if DdnsDomain == "" { +func usage() { + log.Fatal("Usage: ./ddns [backend|web]") +} + +func validateCommandArgs(cmd string) { + if serviceConfig.Domain == "" { log.Fatal("You have to supply the domain via --domain=DOMAIN") - } else if !strings.HasPrefix(DdnsDomain, ".") { + } else if !strings.HasPrefix(serviceConfig.Domain, ".") { // get the domain in the right format - DdnsDomain = "." + DdnsDomain + serviceConfig.Domain = "." + serviceConfig.Domain } if cmd == CmdBackend { - if DdnsSoaFqdn == "" { + if serviceConfig.SOAFqdn == "" { log.Fatal("You have to supply the server FQDN via --soa_fqdn=FQDN") } } } -func PrepareForExecution() string { +func prepareForExecution() string { flag.Parse() if len(flag.Args()) != 1 { @@ -65,28 +65,24 @@ func PrepareForExecution() string { } cmd := flag.Args()[0] - ValidateCommandArgs(cmd) + validateCommandArgs(cmd) return cmd } func main() { - cmd := PrepareForExecution() + cmd := prepareForExecution() - conn := OpenConnection(DdnsRedisHost) - defer conn.Close() + redis := hosts.NewRedisBackend(serviceConfig) + defer redis.Close() switch cmd { case CmdBackend: - log.Printf("Starting PDNS Backend\n") - RunBackend(conn) + backend.NewPowerDnsBackend(serviceConfig, redis, os.Stdin, os.Stdout).Run() + case CmdWeb: - log.Printf("Starting Web Service\n") - RunWebService(conn) + web.NewWebService(serviceConfig, redis).Run() + default: usage() } } - -func usage() { - log.Fatal("Usage: ./ddns [backend|web]") -} diff --git a/hosts/hosts.go b/hosts/hosts.go new file mode 100644 index 0000000..6d55d91 --- /dev/null +++ b/hosts/hosts.go @@ -0,0 +1,36 @@ +package hosts + +import ( + "crypto/sha1" + "fmt" + "strings" + "time" +) + +type Host struct { + Hostname string `redis:"-"` + Ip string `redis:"ip"` + Token string `redis:"token"` +} + +func (h *Host) GenerateAndSetToken() { + hash := sha1.New() + hash.Write([]byte(fmt.Sprintf("%d", time.Now().UnixNano()))) + hash.Write([]byte(h.Hostname)) + + h.Token = fmt.Sprintf("%x", hash.Sum(nil)) +} + +func (h *Host) IsIPv4() bool { + if strings.Contains(h.Ip, ".") { + return true + } + + return false +} + +type HostBackend interface { + GetHost(string) (*Host, error) + + SetHost(*Host) error +} diff --git a/hosts/redis.go b/hosts/redis.go new file mode 100644 index 0000000..0ca1919 --- /dev/null +++ b/hosts/redis.go @@ -0,0 +1,76 @@ +package hosts + +import ( + "../config" + "github.com/garyburd/redigo/redis" + "time" +) + +type RedisBackend struct { + expirationSeconds int + pool *redis.Pool +} + +func NewRedisBackend(config *config.Config) *RedisBackend { + return &RedisBackend{ + expirationSeconds: config.HostExpirationDays * 24 * 60 * 60, + pool: &redis.Pool{ + MaxIdle: 3, + IdleTimeout: 240 * time.Second, + + Dial: func() (redis.Conn, error) { + c, err := redis.Dial("tcp", config.RedisHost) + if err != nil { + return nil, err + } + return c, err + }, + + TestOnBorrow: func(c redis.Conn, t time.Time) error { + _, err := c.Do("PING") + return err + }, + }, + } +} + +func (r *RedisBackend) Close() { + r.pool.Close() +} + +func (r *RedisBackend) GetHost(name string) (*Host, error) { + conn := r.pool.Get() + defer conn.Close() + + host := Host{Hostname: name} + + var err error + var data []interface{} + + if data, err = redis.Values(conn.Do("HGETALL", host.Hostname)); err != nil { + return nil, err + } + + if err = redis.ScanStruct(data, &host); err != nil { + return nil, err + } + + return &host, nil +} + +func (r *RedisBackend) SetHost(host *Host) error { + conn := r.pool.Get() + defer conn.Close() + + var err error + + if _, err = conn.Do("HMSET", redis.Args{}.Add(host.Hostname).AddFlat(host)...); err != nil { + return err + } + + if _, err = conn.Do("EXPIRE", host.Hostname, r.expirationSeconds); err != nil { + return err + } + + return nil +} diff --git a/redis.go b/redis.go deleted file mode 100644 index 5f2c89f..0000000 --- a/redis.go +++ /dev/null @@ -1,100 +0,0 @@ -package main - -import ( - "crypto/sha1" - "fmt" - "github.com/garyburd/redigo/redis" - "strings" - "time" -) - -const HostExpirationSeconds int = 10 * 24 * 60 * 60 // 10 Days - -type RedisConnection struct { - *redis.Pool -} - -func OpenConnection(server string) *RedisConnection { - return &RedisConnection{newPool(server)} -} - -func newPool(server string) *redis.Pool { - return &redis.Pool{ - MaxIdle: 3, - - IdleTimeout: 240 * time.Second, - - Dial: func() (redis.Conn, error) { - c, err := redis.Dial("tcp", server) - if err != nil { - return nil, err - } - return c, err - }, - - TestOnBorrow: func(c redis.Conn, t time.Time) error { - _, err := c.Do("PING") - return err - }, - } -} - -func (self *RedisConnection) GetHost(name string) *Host { - conn := self.Get() - defer conn.Close() - - host := Host{Hostname: name} - - if self.HostExist(name) { - data, err := redis.Values(conn.Do("HGETALL", host.Hostname)) - HandleErr(err) - - HandleErr(redis.ScanStruct(data, &host)) - } - - return &host -} - -func (self *RedisConnection) SaveHost(host *Host) { - conn := self.Get() - defer conn.Close() - - _, err := conn.Do("HMSET", redis.Args{}.Add(host.Hostname).AddFlat(host)...) - HandleErr(err) - - _, err = conn.Do("EXPIRE", host.Hostname, HostExpirationSeconds) - HandleErr(err) -} - -func (self *RedisConnection) HostExist(name string) bool { - conn := self.Get() - defer conn.Close() - - exists, err := redis.Bool(conn.Do("EXISTS", name)) - HandleErr(err) - - return exists -} - -type Host struct { - Hostname string `redis:"-"` - Ip string `redis:"ip"` - Token string `redis:"token"` -} - -func (self *Host) GenerateAndSetToken() { - hash := sha1.New() - hash.Write([]byte(fmt.Sprintf("%d", time.Now().UnixNano()))) - hash.Write([]byte(self.Hostname)) - - self.Token = fmt.Sprintf("%x", hash.Sum(nil)) -} - -// Returns true when this host has a IPv4 Address and false if IPv6 -func (self *Host) IsIPv4() bool { - if strings.Contains(self.Ip, ".") { - return true - } - - return false -} diff --git a/template.go b/web/template.go similarity index 99% rename from template.go rename to web/template.go index 117402f..bb2b1cd 100644 --- a/template.go +++ b/web/template.go @@ -1,4 +1,4 @@ -package main +package web const indexTemplate string = ` diff --git a/web.go b/web/web.go similarity index 57% rename from web.go rename to web/web.go index fe9c18a..66df140 100644 --- a/web.go +++ b/web/web.go @@ -1,49 +1,74 @@ -package main +package web import ( + "../config" + "../hosts" "fmt" "github.com/gin-gonic/gin" "html/template" + "log" "net" "net/http" "regexp" ) -func RunWebService(conn *RedisConnection) { +type WebService struct { + config *config.Config + hosts hosts.HostBackend +} + +func NewWebService(config *config.Config, hosts hosts.HostBackend) *WebService { + return &WebService{ + config: config, + hosts: hosts, + } +} + +func (w *WebService) Run() { r := gin.Default() - r.SetHTMLTemplate(BuildTemplate()) + r.SetHTMLTemplate(buildTemplate()) r.GET("/", func(g *gin.Context) { - g.HTML(200, "index.html", gin.H{"domain": DdnsDomain}) + g.HTML(200, "index.html", gin.H{"domain": w.config.Domain}) }) r.GET("/available/:hostname", func(c *gin.Context) { - hostname, valid := ValidHostname(c.Params.ByName("hostname")) + hostname, valid := isValidHostname(c.Params.ByName("hostname")) + + if valid { + _, err := w.hosts.GetHost(hostname) + valid = err == nil + } c.JSON(200, gin.H{ - "available": valid && !conn.HostExist(hostname), + "available": valid, }) }) r.GET("/new/:hostname", func(c *gin.Context) { - hostname, valid := ValidHostname(c.Params.ByName("hostname")) + hostname, valid := isValidHostname(c.Params.ByName("hostname")) if !valid { c.JSON(404, gin.H{"error": "This hostname is not valid"}) return } - if conn.HostExist(hostname) { + var err error + + if _, err = w.hosts.GetHost(hostname); err == nil { c.JSON(403, gin.H{ "error": "This hostname has already been registered.", }) return } - host := &Host{Hostname: hostname, Ip: "127.0.0.1"} + host := &hosts.Host{Hostname: hostname, Ip: "127.0.0.1"} host.GenerateAndSetToken() - conn.SaveHost(host) + if err = w.hosts.SetHost(host); err != nil { + c.JSON(400, gin.H{"error": "Could not register host."}) + return + } c.JSON(200, gin.H{ "hostname": host.Hostname, @@ -53,7 +78,7 @@ func RunWebService(conn *RedisConnection) { }) r.GET("/update/:hostname/:token", func(c *gin.Context) { - hostname, valid := ValidHostname(c.Params.ByName("hostname")) + hostname, valid := isValidHostname(c.Params.ByName("hostname")) token := c.Params.ByName("token") if !valid { @@ -61,15 +86,14 @@ func RunWebService(conn *RedisConnection) { return } - if !conn.HostExist(hostname) { + host, err := w.hosts.GetHost(hostname) + if err != nil { c.JSON(404, gin.H{ "error": "This hostname has not been registered or is expired.", }) return } - host := conn.GetHost(hostname) - if host.Token != token { c.JSON(403, gin.H{ "error": "You have supplied the wrong token to manipulate this host", @@ -77,7 +101,7 @@ func RunWebService(conn *RedisConnection) { return } - ip, err := GetRemoteAddr(c.Request) + ip, err := extractRemoteAddr(c.Request) if err != nil { c.JSON(400, gin.H{ "error": "Your sender IP address is not in the right format", @@ -86,7 +110,11 @@ func RunWebService(conn *RedisConnection) { } host.Ip = ip - conn.SaveHost(host) + if err = w.hosts.SetHost(host); err != nil { + c.JSON(400, gin.H{ + "error": "Could not update registered IP address", + }) + } c.JSON(200, gin.H{ "current_ip": ip, @@ -94,13 +122,13 @@ func RunWebService(conn *RedisConnection) { }) }) - r.Run(DdnsWebListenSocket) + r.Run(w.config.Listen) } // Get the Remote Address of the client. At First we try to get the // X-Forwarded-For Header which holds the IP if we are behind a proxy, // otherwise the RemoteAddr is used -func GetRemoteAddr(req *http.Request) (string, error) { +func extractRemoteAddr(req *http.Request) (string, error) { header_data, ok := req.Header["X-Forwarded-For"] if ok { @@ -112,14 +140,16 @@ func GetRemoteAddr(req *http.Request) (string, error) { } // Get index template from bindata -func BuildTemplate() *template.Template { +func buildTemplate() *template.Template { html, err := template.New("index.html").Parse(indexTemplate) - HandleErr(err) + if err != nil { + log.Fatal(err) + } return html } -func ValidHostname(host string) (string, bool) { +func isValidHostname(host string) (string, bool) { valid, _ := regexp.Match("^[a-z0-9]{1,32}$", []byte(host)) return host, valid