From 3678d2332f485dc446a51d8c59f13284f677f4a3 Mon Sep 17 00:00:00 2001 From: Kevin Pollet Date: Wed, 28 Jan 2026 12:04:05 +0100 Subject: [PATCH] Fix verifyServerCertMatchesURI function behavior Co-authored-by: Mathis Urien --- pkg/provider/consulcatalog/connect_tls.go | 11 ++-- pkg/provider/consulcatalog/consul_catalog.go | 26 +++++--- pkg/tls/certificate.go | 31 ++-------- pkg/tls/certificate_test.go | 64 ++++++++++++++++++++ 4 files changed, 93 insertions(+), 39 deletions(-) create mode 100644 pkg/tls/certificate_test.go diff --git a/pkg/provider/consulcatalog/connect_tls.go b/pkg/provider/consulcatalog/connect_tls.go index b73acb1a1..16456e51a 100644 --- a/pkg/provider/consulcatalog/connect_tls.go +++ b/pkg/provider/consulcatalog/connect_tls.go @@ -10,8 +10,9 @@ import ( // connectCert holds our certificates as a client of the Consul Connect protocol. type connectCert struct { - root []string - leaf keyPair + trustDomain string + root []string + leaf keyPair } func (c *connectCert) getRoot() []types.FileOrContent { @@ -52,7 +53,8 @@ func (c *connectCert) equals(other *connectCert) bool { } func (c *connectCert) serversTransport(item itemData) *dynamic.ServersTransport { - spiffeID := fmt.Sprintf("spiffe:///ns/%s/dc/%s/svc/%s", + spiffeID := fmt.Sprintf("spiffe://%s/ns/%s/dc/%s/svc/%s", + c.trustDomain, item.Namespace, item.Datacenter, item.Name, @@ -72,7 +74,8 @@ func (c *connectCert) serversTransport(item itemData) *dynamic.ServersTransport } func (c *connectCert) tcpServersTransport(item itemData) *dynamic.TCPServersTransport { - spiffeID := fmt.Sprintf("spiffe:///ns/%s/dc/%s/svc/%s", + spiffeID := fmt.Sprintf("spiffe://%s/ns/%s/dc/%s/svc/%s", + c.trustDomain, item.Namespace, item.Datacenter, item.Name, diff --git a/pkg/provider/consulcatalog/consul_catalog.go b/pkg/provider/consulcatalog/consul_catalog.go index 9da53ce26..0519db395 100644 --- a/pkg/provider/consulcatalog/consul_catalog.go +++ b/pkg/provider/consulcatalog/consul_catalog.go @@ -465,7 +465,7 @@ func (p *Provider) watchConnectTLS(ctx context.Context) error { } leafWatcher.HybridHandler = leafWatcherHandler(ctx, leafChan) - rootsChan := make(chan []string) + rootsChan := make(chan caRootList) rootsWatcher, err := watch.Parse(map[string]any{ "type": "connect_roots", }) @@ -497,9 +497,9 @@ func (p *Provider) watchConnectTLS(ctx context.Context) error { }() var ( - certInfo *connectCert - leafCerts keyPair - rootCerts []string + certInfo *connectCert + leafCert keyPair + caRoots caRootList ) for { @@ -510,13 +510,14 @@ func (p *Provider) watchConnectTLS(ctx context.Context) error { case err := <-errChan: return fmt.Errorf("leaf or roots watcher terminated: %w", err) - case rootCerts = <-rootsChan: - case leafCerts = <-leafChan: + case caRoots = <-rootsChan: + case leafCert = <-leafChan: } newCertInfo := &connectCert{ - root: rootCerts, - leaf: leafCerts, + trustDomain: caRoots.trustDomain, + root: caRoots.roots, + leaf: leafCert, } if newCertInfo.isReady() && !newCertInfo.equals(certInfo) { log.Ctx(ctx).Debug().Msgf("Updating connect certs for service %s", p.ServiceName) @@ -546,7 +547,12 @@ func (p *Provider) includesHealthStatus(status string) bool { return false } -func rootsWatchHandler(ctx context.Context, dest chan<- []string) func(watch.BlockingParamVal, any) { +type caRootList struct { + trustDomain string + roots []string +} + +func rootsWatchHandler(ctx context.Context, dest chan<- caRootList) func(watch.BlockingParamVal, any) { return func(_ watch.BlockingParamVal, raw any) { if raw == nil { log.Ctx(ctx).Error().Msg("Root certificate watcher called with nil") @@ -566,7 +572,7 @@ func rootsWatchHandler(ctx context.Context, dest chan<- []string) func(watch.Blo select { case <-ctx.Done(): - case dest <- roots: + case dest <- caRootList{trustDomain: v.TrustDomain, roots: roots}: } } } diff --git a/pkg/tls/certificate.go b/pkg/tls/certificate.go index f99796783..da1325a52 100644 --- a/pkg/tls/certificate.go +++ b/pkg/tls/certificate.go @@ -5,7 +5,6 @@ import ( "crypto/x509" "errors" "fmt" - "net/url" "os" "strings" @@ -160,37 +159,19 @@ func VerifyPeerCertificate(uri string, cfg *tls.Config, rawCerts [][]byte) error return nil } -// verifyServerCertMatchesURI is used on tls connections dialed to a server -// to ensure that the certificate it presented has the correct URI. +// verifyServerCertMatchesURI verifies that the given certificate contains the specified URI in its SANs. func verifyServerCertMatchesURI(uri string, cert *x509.Certificate) error { if cert == nil { return errors.New("peer certificate mismatch: no peer certificate presented") } - // Our certs will only ever have a single URI for now so only check that - if len(cert.URIs) < 1 { - return errors.New("peer certificate mismatch: peer certificate invalid") + for _, certURI := range cert.URIs { + if strings.EqualFold(certURI.String(), uri) { + return nil + } } - gotURI := cert.URIs[0] - - // Override the hostname since we rely on x509 constraints to limit ability to spoof the trust domain if needed - // (i.e. because a root is shared with other PKI or Consul clusters). - // This allows for seamless migrations between trust domains. - - expectURI := &url.URL{} - id, err := url.Parse(uri) - if err != nil { - return fmt.Errorf("%q is not a valid URI", uri) - } - *expectURI = *id - expectURI.Host = gotURI.Host - - if strings.EqualFold(gotURI.String(), expectURI.String()) { - return nil - } - - return fmt.Errorf("peer certificate mismatch got %s, want %s", gotURI, uri) + return fmt.Errorf("peer certificate mismatch: no SAN URI in peer certificate matches %s", uri) } // verifyChain performs standard TLS verification without enforcing remote hostname matching. diff --git a/pkg/tls/certificate_test.go b/pkg/tls/certificate_test.go new file mode 100644 index 000000000..c97c26531 --- /dev/null +++ b/pkg/tls/certificate_test.go @@ -0,0 +1,64 @@ +package tls + +import ( + "crypto/x509" + "net/url" + "testing" + + "github.com/stretchr/testify/require" +) + +func Test_verifyServerCertMatchesURI(t *testing.T) { + tests := []struct { + desc string + uri string + cert *x509.Certificate + expErr require.ErrorAssertionFunc + }{ + { + desc: "returns error when certificate is nil", + uri: "spiffe://foo.com", + expErr: require.Error, + }, + { + desc: "returns error when certificate has no URIs", + uri: "spiffe://foo.com", + cert: &x509.Certificate{URIs: nil}, + expErr: require.Error, + }, + { + desc: "returns error when no URI matches", + uri: "spiffe://foo.com", + cert: &x509.Certificate{URIs: []*url.URL{ + {Scheme: "spiffe", Host: "other.org"}, + }}, + expErr: require.Error, + }, + { + desc: "returns nil when URI matches", + uri: "spiffe://foo.com", + cert: &x509.Certificate{URIs: []*url.URL{ + {Scheme: "spiffe", Host: "foo.com"}, + }}, + expErr: require.NoError, + }, + { + desc: "returns nil when one of the URI matches", + uri: "spiffe://foo.com", + cert: &x509.Certificate{URIs: []*url.URL{ + {Scheme: "spiffe", Host: "example.org"}, + {Scheme: "spiffe", Host: "foo.com"}, + }}, + expErr: require.NoError, + }, + } + + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + + err := verifyServerCertMatchesURI(test.uri, test.cert) + test.expErr(t, err) + }) + } +}