diff --git a/client/discover.go b/client/discover.go index bfd7aec93f5..442e35fe543 100644 --- a/client/discover.go +++ b/client/discover.go @@ -14,8 +14,27 @@ package client +import ( + "github.com/coreos/etcd/pkg/srv" +) + // Discoverer is an interface that wraps the Discover method. type Discoverer interface { // Discover looks up the etcd servers for the domain. Discover(domain string) ([]string, error) } + +type srvDiscover struct{} + +// NewSRVDiscover constructs a new Discoverer that uses the stdlib to lookup SRV records. +func NewSRVDiscover() Discoverer { + return &srvDiscover{} +} + +func (d *srvDiscover) Discover(domain string) ([]string, error) { + srvs, err := srv.GetClient("etcd-client", domain) + if err != nil { + return nil, err + } + return srvs.Endpoints, nil +} diff --git a/client/srv.go b/client/srv.go deleted file mode 100644 index fdfa3435921..00000000000 --- a/client/srv.go +++ /dev/null @@ -1,65 +0,0 @@ -// Copyright 2015 The etcd Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package client - -import ( - "fmt" - "net" - "net/url" -) - -var ( - // indirection for testing - lookupSRV = net.LookupSRV -) - -type srvDiscover struct{} - -// NewSRVDiscover constructs a new Discoverer that uses the stdlib to lookup SRV records. -func NewSRVDiscover() Discoverer { - return &srvDiscover{} -} - -// Discover looks up the etcd servers for the domain. -func (d *srvDiscover) Discover(domain string) ([]string, error) { - var urls []*url.URL - - updateURLs := func(service, scheme string) error { - _, addrs, err := lookupSRV(service, "tcp", domain) - if err != nil { - return err - } - for _, srv := range addrs { - urls = append(urls, &url.URL{ - Scheme: scheme, - Host: net.JoinHostPort(srv.Target, fmt.Sprintf("%d", srv.Port)), - }) - } - return nil - } - - errHTTPS := updateURLs("etcd-client-ssl", "https") - errHTTP := updateURLs("etcd-client", "http") - - if errHTTPS != nil && errHTTP != nil { - return nil, fmt.Errorf("dns lookup errors: %s and %s", errHTTPS, errHTTP) - } - - endpoints := make([]string, len(urls)) - for i := range urls { - endpoints[i] = urls[i].String() - } - return endpoints, nil -} diff --git a/client/srv_test.go b/client/srv_test.go deleted file mode 100644 index 64cf6032322..00000000000 --- a/client/srv_test.go +++ /dev/null @@ -1,102 +0,0 @@ -// Copyright 2015 The etcd Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package client - -import ( - "errors" - "net" - "reflect" - "testing" -) - -func TestSRVDiscover(t *testing.T) { - defer func() { lookupSRV = net.LookupSRV }() - - tests := []struct { - withSSL []*net.SRV - withoutSSL []*net.SRV - expected []string - }{ - { - []*net.SRV{}, - []*net.SRV{}, - []string{}, - }, - { - []*net.SRV{ - {Target: "10.0.0.1", Port: 2480}, - {Target: "10.0.0.2", Port: 2480}, - {Target: "10.0.0.3", Port: 2480}, - }, - []*net.SRV{}, - []string{"https://10.0.0.1:2480", "https://10.0.0.2:2480", "https://10.0.0.3:2480"}, - }, - { - []*net.SRV{ - {Target: "10.0.0.1", Port: 2480}, - {Target: "10.0.0.2", Port: 2480}, - {Target: "10.0.0.3", Port: 2480}, - }, - []*net.SRV{ - {Target: "10.0.0.1", Port: 7001}, - }, - []string{"https://10.0.0.1:2480", "https://10.0.0.2:2480", "https://10.0.0.3:2480", "http://10.0.0.1:7001"}, - }, - { - []*net.SRV{ - {Target: "10.0.0.1", Port: 2480}, - {Target: "10.0.0.2", Port: 2480}, - {Target: "10.0.0.3", Port: 2480}, - }, - []*net.SRV{ - {Target: "10.0.0.1", Port: 7001}, - }, - []string{"https://10.0.0.1:2480", "https://10.0.0.2:2480", "https://10.0.0.3:2480", "http://10.0.0.1:7001"}, - }, - { - []*net.SRV{ - {Target: "a.example.com", Port: 2480}, - {Target: "b.example.com", Port: 2480}, - {Target: "c.example.com", Port: 2480}, - }, - []*net.SRV{}, - []string{"https://a.example.com:2480", "https://b.example.com:2480", "https://c.example.com:2480"}, - }, - } - - for i, tt := range tests { - lookupSRV = func(service string, proto string, domain string) (string, []*net.SRV, error) { - if service == "etcd-client-ssl" { - return "", tt.withSSL, nil - } - if service == "etcd-client" { - return "", tt.withoutSSL, nil - } - return "", nil, errors.New("Unknown service in mock") - } - - d := NewSRVDiscover() - - endpoints, err := d.Discover("example.com") - if err != nil { - t.Fatalf("%d: err: %#v", i, err) - } - - if !reflect.DeepEqual(endpoints, tt.expected) { - t.Errorf("#%d: endpoints = %v, want %v", i, endpoints, tt.expected) - } - - } -} diff --git a/embed/config.go b/embed/config.go index 93431d1c672..e3926f66cb4 100644 --- a/embed/config.go +++ b/embed/config.go @@ -22,10 +22,10 @@ import ( "net/url" "strings" - "github.com/coreos/etcd/discovery" "github.com/coreos/etcd/etcdserver" "github.com/coreos/etcd/pkg/cors" "github.com/coreos/etcd/pkg/netutil" + "github.com/coreos/etcd/pkg/srv" "github.com/coreos/etcd/pkg/transport" "github.com/coreos/etcd/pkg/types" @@ -321,11 +321,15 @@ func (cfg *Config) PeerURLsMapAndToken(which string) (urlsmap types.URLsMap, tok urlsmap[cfg.Name] = cfg.APUrls token = cfg.Durl case cfg.DNSCluster != "": - var clusterStr string - clusterStr, err = discovery.SRVGetCluster(cfg.Name, cfg.DNSCluster, cfg.APUrls) - if err != nil { - return nil, "", err + clusterStrs, cerr := srv.GetCluster("etcd-server", cfg.Name, cfg.DNSCluster, cfg.APUrls) + if cerr != nil { + plog.Errorf("couldn't resolve during SRV discovery (%v)", cerr) + return nil, "", cerr + } + for _, s := range clusterStrs { + plog.Noticef("got bootstrap from DNS for etcd-server at %s", s) } + clusterStr := strings.Join(clusterStrs, ",") if strings.Contains(clusterStr, "https://") && cfg.PeerTLSInfo.CAFile == "" { cfg.PeerTLSInfo.ServerName = cfg.DNSCluster } diff --git a/etcdmain/util.go b/etcdmain/util.go index 23e19b44057..5de07275b5b 100644 --- a/etcdmain/util.go +++ b/etcdmain/util.go @@ -18,7 +18,7 @@ import ( "fmt" "os" - "github.com/coreos/etcd/client" + "github.com/coreos/etcd/pkg/srv" "github.com/coreos/etcd/pkg/transport" ) @@ -26,11 +26,12 @@ func discoverEndpoints(dns string, ca string, insecure bool) (endpoints []string if dns == "" { return nil } - endpoints, err := client.NewSRVDiscover().Discover(dns) + srvs, err := srv.GetClient("etcd-client", dns) if err != nil { fmt.Fprintln(os.Stderr, err) os.Exit(1) } + endpoints = srvs.Endpoints plog.Infof("discovered the cluster %s from %s", endpoints, dns) if insecure { return endpoints diff --git a/discovery/srv.go b/pkg/srv/srv.go similarity index 50% rename from discovery/srv.go rename to pkg/srv/srv.go index 782b6888f54..71a0af7956e 100644 --- a/discovery/srv.go +++ b/pkg/srv/srv.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package discovery +package srv import ( "fmt" @@ -25,14 +25,13 @@ import ( var ( // indirection for testing - lookupSRV = net.LookupSRV + lookupSRV = net.LookupSRV // net.DefaultResolver.LookupSRV when ctxs don't conflict resolveTCPAddr = net.ResolveTCPAddr ) -// SRVGetCluster gets the cluster information via DNS discovery. -// TODO(barakmich): Currently ignores priority and weight (as they don't make as much sense for a bootstrap) +// GetCluster gets the cluster information via DNS discovery. // Also sees each entry as a separate instance. -func SRVGetCluster(name, dns string, apurls types.URLs) (string, error) { +func GetCluster(service, name, dns string, apurls types.URLs) ([]string, error) { tempName := int(0) tcp2ap := make(map[string]url.URL) @@ -40,8 +39,7 @@ func SRVGetCluster(name, dns string, apurls types.URLs) (string, error) { for _, url := range apurls { tcpAddr, err := resolveTCPAddr("tcp", url.Host) if err != nil { - plog.Errorf("couldn't resolve host %s during SRV discovery", url.Host) - return "", err + return nil, err } tcp2ap[tcpAddr.String()] = url } @@ -55,9 +53,9 @@ func SRVGetCluster(name, dns string, apurls types.URLs) (string, error) { for _, srv := range addrs { port := fmt.Sprintf("%d", srv.Port) host := net.JoinHostPort(srv.Target, port) - tcpAddr, err := resolveTCPAddr("tcp", host) - if err != nil { - plog.Warningf("couldn't resolve host %s during SRV discovery", host) + tcpAddr, terr := resolveTCPAddr("tcp", host) + if terr != nil { + terr = err continue } n := "" @@ -73,31 +71,69 @@ func SRVGetCluster(name, dns string, apurls types.URLs) (string, error) { shortHost := strings.TrimSuffix(srv.Target, ".") urlHost := net.JoinHostPort(shortHost, port) stringParts = append(stringParts, fmt.Sprintf("%s=%s://%s", n, scheme, urlHost)) - plog.Noticef("got bootstrap from DNS for %s at %s://%s", service, scheme, urlHost) if ok && url.Scheme != scheme { - plog.Errorf("bootstrap at %s from DNS for %s has scheme mismatch with expected peer %s", scheme+"://"+urlHost, service, url.String()) + err = fmt.Errorf("bootstrap at %s from DNS for %s has scheme mismatch with expected peer %s", scheme+"://"+urlHost, service, url.String()) } } + if len(stringParts) == 0 { + return err + } return nil } failCount := 0 - err := updateNodeMap("etcd-server-ssl", "https") + err := updateNodeMap(service+"-ssl", "https") srvErr := make([]string, 2) if err != nil { - srvErr[0] = fmt.Sprintf("error querying DNS SRV records for _etcd-server-ssl %s", err) + srvErr[0] = fmt.Sprintf("error querying DNS SRV records for _%s-ssl %s", service, err) failCount++ } - err = updateNodeMap("etcd-server", "http") + err = updateNodeMap(service, "http") if err != nil { - srvErr[1] = fmt.Sprintf("error querying DNS SRV records for _etcd-server %s", err) + srvErr[1] = fmt.Sprintf("error querying DNS SRV records for _%s %s", service, err) failCount++ } if failCount == 2 { - plog.Warningf(srvErr[0]) - plog.Warningf(srvErr[1]) - plog.Errorf("SRV discovery failed: too many errors querying DNS SRV records") - return "", err + return nil, fmt.Errorf("srv: too many errors querying DNS SRV records (%q, %q)", srvErr[0], srvErr[1]) + } + return stringParts, nil +} + +type SRVClients struct { + Endpoints []string + SRVs []*net.SRV +} + +// GetClient looks up the client endpoints for a service and domain. +func GetClient(service, domain string) (*SRVClients, error) { + var urls []*url.URL + var srvs []*net.SRV + + updateURLs := func(service, scheme string) error { + _, addrs, err := lookupSRV(service, "tcp", domain) + if err != nil { + return err + } + for _, srv := range addrs { + urls = append(urls, &url.URL{ + Scheme: scheme, + Host: net.JoinHostPort(srv.Target, fmt.Sprintf("%d", srv.Port)), + }) + } + srvs = append(srvs, addrs...) + return nil + } + + errHTTPS := updateURLs(service+"-ssl", "https") + errHTTP := updateURLs(service, "http") + + if errHTTPS != nil && errHTTP != nil { + return nil, fmt.Errorf("dns lookup errors: %s and %s", errHTTPS, errHTTP) + } + + endpoints := make([]string, len(urls)) + for i := range urls { + endpoints[i] = urls[i].String() } - return strings.Join(stringParts, ","), nil + return &SRVClients{Endpoints: endpoints, SRVs: srvs}, nil } diff --git a/discovery/srv_test.go b/pkg/srv/srv_test.go similarity index 59% rename from discovery/srv_test.go rename to pkg/srv/srv_test.go index b9914a5544c..0386c9d2a09 100644 --- a/discovery/srv_test.go +++ b/pkg/srv/srv_test.go @@ -12,11 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -package discovery +package srv import ( "errors" "net" + "reflect" "strings" "testing" @@ -110,12 +111,90 @@ func TestSRVGetCluster(t *testing.T) { return "", nil, errors.New("Unknown service in mock") } urls := testutil.MustNewURLs(t, tt.urls) - str, err := SRVGetCluster(name, "example.com", urls) + str, err := GetCluster("etcd-server", name, "example.com", urls) if err != nil { t.Fatalf("%d: err: %#v", i, err) } - if str != tt.expected { + if strings.Join(str, ",") != tt.expected { t.Errorf("#%d: cluster = %s, want %s", i, str, tt.expected) } } } + +func TestSRVDiscover(t *testing.T) { + defer func() { lookupSRV = net.LookupSRV }() + + tests := []struct { + withSSL []*net.SRV + withoutSSL []*net.SRV + expected []string + }{ + { + []*net.SRV{}, + []*net.SRV{}, + []string{}, + }, + { + []*net.SRV{ + {Target: "10.0.0.1", Port: 2480}, + {Target: "10.0.0.2", Port: 2480}, + {Target: "10.0.0.3", Port: 2480}, + }, + []*net.SRV{}, + []string{"https://10.0.0.1:2480", "https://10.0.0.2:2480", "https://10.0.0.3:2480"}, + }, + { + []*net.SRV{ + {Target: "10.0.0.1", Port: 2480}, + {Target: "10.0.0.2", Port: 2480}, + {Target: "10.0.0.3", Port: 2480}, + }, + []*net.SRV{ + {Target: "10.0.0.1", Port: 7001}, + }, + []string{"https://10.0.0.1:2480", "https://10.0.0.2:2480", "https://10.0.0.3:2480", "http://10.0.0.1:7001"}, + }, + { + []*net.SRV{ + {Target: "10.0.0.1", Port: 2480}, + {Target: "10.0.0.2", Port: 2480}, + {Target: "10.0.0.3", Port: 2480}, + }, + []*net.SRV{ + {Target: "10.0.0.1", Port: 7001}, + }, + []string{"https://10.0.0.1:2480", "https://10.0.0.2:2480", "https://10.0.0.3:2480", "http://10.0.0.1:7001"}, + }, + { + []*net.SRV{ + {Target: "a.example.com", Port: 2480}, + {Target: "b.example.com", Port: 2480}, + {Target: "c.example.com", Port: 2480}, + }, + []*net.SRV{}, + []string{"https://a.example.com:2480", "https://b.example.com:2480", "https://c.example.com:2480"}, + }, + } + + for i, tt := range tests { + lookupSRV = func(service string, proto string, domain string) (string, []*net.SRV, error) { + if service == "etcd-client-ssl" { + return "", tt.withSSL, nil + } + if service == "etcd-client" { + return "", tt.withoutSSL, nil + } + return "", nil, errors.New("Unknown service in mock") + } + + srvs, err := GetClient("etcd-client", "example.com") + if err != nil { + t.Fatalf("%d: err: %#v", i, err) + } + + if !reflect.DeepEqual(srvs.Endpoints, tt.expected) { + t.Errorf("#%d: endpoints = %v, want %v", i, srvs.Endpoints, tt.expected) + } + + } +}