Fix verifyServerCertMatchesURI function behavior

Co-authored-by: Mathis Urien <contact.lbf38@gmail.com>
This commit is contained in:
Kevin Pollet
2026-01-28 12:04:05 +01:00
committed by GitHub
parent 121dfa6060
commit 3678d2332f
4 changed files with 93 additions and 39 deletions
+7 -4
View File
@@ -10,8 +10,9 @@ import (
// connectCert holds our certificates as a client of the Consul Connect protocol. // connectCert holds our certificates as a client of the Consul Connect protocol.
type connectCert struct { type connectCert struct {
root []string trustDomain string
leaf keyPair root []string
leaf keyPair
} }
func (c *connectCert) getRoot() []types.FileOrContent { func (c *connectCert) getRoot() []types.FileOrContent {
@@ -52,7 +53,8 @@ func (c *connectCert) equals(other *connectCert) bool {
} }
func (c *connectCert) serversTransport(item itemData) *dynamic.ServersTransport { func (c *connectCert) serversTransport(item itemData) *dynamic.ServersTransport {
spiffeID := fmt.Sprintf("spiffe:///ns/%s/dc/%s/svc/%s", spiffeID := fmt.Sprintf("spiffe://%s/ns/%s/dc/%s/svc/%s",
c.trustDomain,
item.Namespace, item.Namespace,
item.Datacenter, item.Datacenter,
item.Name, item.Name,
@@ -72,7 +74,8 @@ func (c *connectCert) serversTransport(item itemData) *dynamic.ServersTransport
} }
func (c *connectCert) tcpServersTransport(item itemData) *dynamic.TCPServersTransport { func (c *connectCert) tcpServersTransport(item itemData) *dynamic.TCPServersTransport {
spiffeID := fmt.Sprintf("spiffe:///ns/%s/dc/%s/svc/%s", spiffeID := fmt.Sprintf("spiffe://%s/ns/%s/dc/%s/svc/%s",
c.trustDomain,
item.Namespace, item.Namespace,
item.Datacenter, item.Datacenter,
item.Name, item.Name,
+16 -10
View File
@@ -465,7 +465,7 @@ func (p *Provider) watchConnectTLS(ctx context.Context) error {
} }
leafWatcher.HybridHandler = leafWatcherHandler(ctx, leafChan) leafWatcher.HybridHandler = leafWatcherHandler(ctx, leafChan)
rootsChan := make(chan []string) rootsChan := make(chan caRootList)
rootsWatcher, err := watch.Parse(map[string]any{ rootsWatcher, err := watch.Parse(map[string]any{
"type": "connect_roots", "type": "connect_roots",
}) })
@@ -497,9 +497,9 @@ func (p *Provider) watchConnectTLS(ctx context.Context) error {
}() }()
var ( var (
certInfo *connectCert certInfo *connectCert
leafCerts keyPair leafCert keyPair
rootCerts []string caRoots caRootList
) )
for { for {
@@ -510,13 +510,14 @@ func (p *Provider) watchConnectTLS(ctx context.Context) error {
case err := <-errChan: case err := <-errChan:
return fmt.Errorf("leaf or roots watcher terminated: %w", err) return fmt.Errorf("leaf or roots watcher terminated: %w", err)
case rootCerts = <-rootsChan: case caRoots = <-rootsChan:
case leafCerts = <-leafChan: case leafCert = <-leafChan:
} }
newCertInfo := &connectCert{ newCertInfo := &connectCert{
root: rootCerts, trustDomain: caRoots.trustDomain,
leaf: leafCerts, root: caRoots.roots,
leaf: leafCert,
} }
if newCertInfo.isReady() && !newCertInfo.equals(certInfo) { if newCertInfo.isReady() && !newCertInfo.equals(certInfo) {
log.Ctx(ctx).Debug().Msgf("Updating connect certs for service %s", p.ServiceName) log.Ctx(ctx).Debug().Msgf("Updating connect certs for service %s", p.ServiceName)
@@ -546,7 +547,12 @@ func (p *Provider) includesHealthStatus(status string) bool {
return false return false
} }
func rootsWatchHandler(ctx context.Context, dest chan<- []string) func(watch.BlockingParamVal, any) { type caRootList struct {
trustDomain string
roots []string
}
func rootsWatchHandler(ctx context.Context, dest chan<- caRootList) func(watch.BlockingParamVal, any) {
return func(_ watch.BlockingParamVal, raw any) { return func(_ watch.BlockingParamVal, raw any) {
if raw == nil { if raw == nil {
log.Ctx(ctx).Error().Msg("Root certificate watcher called with nil") log.Ctx(ctx).Error().Msg("Root certificate watcher called with nil")
@@ -566,7 +572,7 @@ func rootsWatchHandler(ctx context.Context, dest chan<- []string) func(watch.Blo
select { select {
case <-ctx.Done(): case <-ctx.Done():
case dest <- roots: case dest <- caRootList{trustDomain: v.TrustDomain, roots: roots}:
} }
} }
} }
+6 -25
View File
@@ -5,7 +5,6 @@ import (
"crypto/x509" "crypto/x509"
"errors" "errors"
"fmt" "fmt"
"net/url"
"os" "os"
"strings" "strings"
@@ -160,37 +159,19 @@ func VerifyPeerCertificate(uri string, cfg *tls.Config, rawCerts [][]byte) error
return nil return nil
} }
// verifyServerCertMatchesURI is used on tls connections dialed to a server // verifyServerCertMatchesURI verifies that the given certificate contains the specified URI in its SANs.
// to ensure that the certificate it presented has the correct URI.
func verifyServerCertMatchesURI(uri string, cert *x509.Certificate) error { func verifyServerCertMatchesURI(uri string, cert *x509.Certificate) error {
if cert == nil { if cert == nil {
return errors.New("peer certificate mismatch: no peer certificate presented") return errors.New("peer certificate mismatch: no peer certificate presented")
} }
// Our certs will only ever have a single URI for now so only check that for _, certURI := range cert.URIs {
if len(cert.URIs) < 1 { if strings.EqualFold(certURI.String(), uri) {
return errors.New("peer certificate mismatch: peer certificate invalid") return nil
}
} }
gotURI := cert.URIs[0] return fmt.Errorf("peer certificate mismatch: no SAN URI in peer certificate matches %s", uri)
// Override the hostname since we rely on x509 constraints to limit ability to spoof the trust domain if needed
// (i.e. because a root is shared with other PKI or Consul clusters).
// This allows for seamless migrations between trust domains.
expectURI := &url.URL{}
id, err := url.Parse(uri)
if err != nil {
return fmt.Errorf("%q is not a valid URI", uri)
}
*expectURI = *id
expectURI.Host = gotURI.Host
if strings.EqualFold(gotURI.String(), expectURI.String()) {
return nil
}
return fmt.Errorf("peer certificate mismatch got %s, want %s", gotURI, uri)
} }
// verifyChain performs standard TLS verification without enforcing remote hostname matching. // verifyChain performs standard TLS verification without enforcing remote hostname matching.
+64
View File
@@ -0,0 +1,64 @@
package tls
import (
"crypto/x509"
"net/url"
"testing"
"github.com/stretchr/testify/require"
)
func Test_verifyServerCertMatchesURI(t *testing.T) {
tests := []struct {
desc string
uri string
cert *x509.Certificate
expErr require.ErrorAssertionFunc
}{
{
desc: "returns error when certificate is nil",
uri: "spiffe://foo.com",
expErr: require.Error,
},
{
desc: "returns error when certificate has no URIs",
uri: "spiffe://foo.com",
cert: &x509.Certificate{URIs: nil},
expErr: require.Error,
},
{
desc: "returns error when no URI matches",
uri: "spiffe://foo.com",
cert: &x509.Certificate{URIs: []*url.URL{
{Scheme: "spiffe", Host: "other.org"},
}},
expErr: require.Error,
},
{
desc: "returns nil when URI matches",
uri: "spiffe://foo.com",
cert: &x509.Certificate{URIs: []*url.URL{
{Scheme: "spiffe", Host: "foo.com"},
}},
expErr: require.NoError,
},
{
desc: "returns nil when one of the URI matches",
uri: "spiffe://foo.com",
cert: &x509.Certificate{URIs: []*url.URL{
{Scheme: "spiffe", Host: "example.org"},
{Scheme: "spiffe", Host: "foo.com"},
}},
expErr: require.NoError,
},
}
for _, test := range tests {
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
err := verifyServerCertMatchesURI(test.uri, test.cert)
test.expErr(t, err)
})
}
}