Files
traefik/pkg/proxy/httputil/proxy_test.go
2025-12-19 11:18:04 +01:00

183 lines
4.8 KiB
Go

package httputil
import (
"crypto/tls"
"errors"
"net/http"
"net/http/httptest"
"net/http/httputil"
"net/url"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/traefik/traefik/v3/pkg/testhelpers"
)
func Test_rewriteRequestBuilder(t *testing.T) {
tests := []struct {
name string
target *url.URL
passHostHeader bool
preservePath bool
incomingURL string
expectedScheme string
expectedHost string
expectedPath string
expectedRawPath string
expectedQuery string
notAppendXFF bool
}{
{
name: "Basic proxy",
target: testhelpers.MustParseURL("http://example.com"),
passHostHeader: false,
preservePath: false,
incomingURL: "http://localhost/test?param=value",
expectedScheme: "http",
expectedHost: "example.com",
expectedPath: "/test",
expectedQuery: "param=value",
},
{
name: "Basic proxy - notAppendXFF",
target: testhelpers.MustParseURL("http://example.com"),
passHostHeader: false,
preservePath: false,
incomingURL: "http://localhost/test?param=value",
expectedScheme: "http",
expectedHost: "example.com",
expectedPath: "/test",
expectedQuery: "param=value",
notAppendXFF: true,
},
{
name: "HTTPS target",
target: testhelpers.MustParseURL("https://secure.example.com"),
passHostHeader: false,
preservePath: false,
incomingURL: "http://localhost/secure",
expectedScheme: "https",
expectedHost: "secure.example.com",
expectedPath: "/secure",
},
{
name: "PassHostHeader",
target: testhelpers.MustParseURL("http://example.com"),
passHostHeader: true,
preservePath: false,
incomingURL: "http://original.host/test",
expectedScheme: "http",
expectedHost: "original.host",
expectedPath: "/test",
},
{
name: "Preserve path",
target: testhelpers.MustParseURL("http://example.com/base"),
passHostHeader: false,
preservePath: true,
incomingURL: "http://localhost/foo%2Fbar",
expectedScheme: "http",
expectedHost: "example.com",
expectedPath: "/base/foo/bar",
expectedRawPath: "/base/foo%2Fbar",
},
{
name: "Handle semicolons in query",
target: testhelpers.MustParseURL("http://example.com"),
passHostHeader: false,
preservePath: false,
incomingURL: "http://localhost/test?param1=value1;param2=value2",
expectedScheme: "http",
expectedHost: "example.com",
expectedPath: "/test",
expectedQuery: "param1=value1&param2=value2",
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
t.Parallel()
rewriteRequest := rewriteRequestBuilder(test.target, test.passHostHeader, test.preservePath)
ctx := t.Context()
if test.notAppendXFF {
ctx = SetNotAppendXFF(ctx)
}
reqIn := httptest.NewRequest(http.MethodGet, test.incomingURL, http.NoBody)
reqIn = reqIn.WithContext(ctx)
reqIn.Header.Add("X-Forwarded-For", "1.2.3.4")
reqIn.RemoteAddr = "127.0.0.1:1234"
reqOut := httptest.NewRequest(http.MethodGet, test.incomingURL, http.NoBody)
pr := &httputil.ProxyRequest{
In: reqIn,
Out: reqOut,
}
rewriteRequest(pr)
if test.notAppendXFF {
assert.Equal(t, "1.2.3.4", reqOut.Header.Get("X-Forwarded-For"))
} else {
// When not disabled, X-Forwarded-For should have RemoteAddr appended
assert.Equal(t, "1.2.3.4, 127.0.0.1", reqOut.Header.Get("X-Forwarded-For"))
}
assert.Equal(t, test.expectedScheme, reqOut.URL.Scheme)
assert.Equal(t, test.expectedHost, reqOut.Host)
assert.Equal(t, test.expectedPath, reqOut.URL.Path)
assert.Equal(t, test.expectedRawPath, reqOut.URL.RawPath)
assert.Equal(t, test.expectedQuery, reqOut.URL.RawQuery)
assert.Empty(t, reqOut.RequestURI)
assert.Equal(t, "HTTP/1.1", reqOut.Proto)
assert.Equal(t, 1, reqOut.ProtoMajor)
assert.Equal(t, 1, reqOut.ProtoMinor)
assert.False(t, !test.passHostHeader && reqOut.Host != reqOut.URL.Host)
})
}
}
func Test_isTLSConfigError(t *testing.T) {
testCases := []struct {
desc string
err error
expected bool
}{
{
desc: "nil",
},
{
desc: "TLS ECHRejectionError",
err: &tls.ECHRejectionError{},
},
{
desc: "TLS AlertError",
err: tls.AlertError(0),
},
{
desc: "Random error",
err: errors.New("random error"),
},
{
desc: "TLS RecordHeaderError",
err: tls.RecordHeaderError{},
expected: true,
},
{
desc: "TLS CertificateVerificationError",
err: &tls.CertificateVerificationError{},
expected: true,
},
}
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
actual := isTLSConfigError(test.err)
require.Equal(t, test.expected, actual)
})
}
}