diff --git a/docs/arch/11-auth-server-storage.md b/docs/arch/11-auth-server-storage.md index 4198c16b36..c8d49674f4 100644 --- a/docs/arch/11-auth-server-storage.md +++ b/docs/arch/11-auth-server-storage.md @@ -233,3 +233,33 @@ When passing configuration across process boundaries (operator → proxy-runner) - [Redis Storage Configuration Guide](../redis-storage.md) — User-facing setup guide - [Operator Architecture](09-operator-architecture.md) — CRD and controller design - [Core Concepts](02-core-concepts.md) — Platform terminology + +## CIMD Storage Decorator + +When `authServer.cimd.enabled: true` is set, the embedded authorization server wraps its storage backend in a `CIMDStorageDecorator` before passing it to fosite. This decorator enables MCP clients to present HTTPS URLs as `client_id` values without first calling `/oauth/register`. + +### What it does + +`CIMDStorageDecorator` embeds the full `storage.Storage` interface and overrides only `GetClient`. When fosite calls `GetClient("https://vscode.dev/oauth/client-metadata.json")` during an authorization request: + +1. The decorator detects the HTTPS URL using `oauthproto.IsClientIDMetadataDocumentURL` +2. It fetches the Client ID Metadata Document from that URL via `pkg/oauthproto/cimd.FetchClientMetadataDocument` (with SSRF protection, 10 KB cap, 5-second timeout) +3. It builds a `fosite.Client` from the document fields, caches it with a configurable TTL, and returns it to fosite +4. Concurrent fetches for the same URL are deduplicated via `singleflight` + +All other `Storage` methods (`RegisterClient`, token storage, upstream token storage, etc.) delegate to the underlying backend unchanged. DCR clients (opaque string IDs) continue to work exactly as before. + +### Unwrap pattern + +`CIMDStorageDecorator` implements `Unwrap() Storage` to expose the concrete backend through the decorator layer. Two call sites in `server_impl.go` depend on this: + +- **`DCRCredentialStore` assertion** (`newServer`): The `DCRCredentialStore` interface is narrower than `Storage` and not embedded in it. The assertion `unwrapStorage(stor).(storage.DCRCredentialStore)` reaches the concrete backend through the decorator. +- **`RedisStorage` migration** (`runLegacyMigration`): A type assertion to `*storage.RedisStorage` is needed to run a one-shot data migration. Same `unwrapStorage` call. + +Both call sites use the `unwrapStorage(stor)` helper rather than asserting directly on `stor`. + +### Air-gapped environments + +When the embedded authorization server is deployed in an environment that cannot reach `https://toolhive.dev/oauth/client-metadata.json` or any public CIMD metadata URL, set `authServer.cimd.enabled: false`. Clients will fall back to DCR (`/oauth/register`) which uses only the local storage backend and requires no outbound connectivity. + +**Implementation:** `pkg/authserver/storage/cimd_decorator.go` diff --git a/pkg/authserver/server/registration/client.go b/pkg/authserver/server/registration/client.go index 4f72a8582c..d6cdcf39a5 100644 --- a/pkg/authserver/server/registration/client.go +++ b/pkg/authserver/server/registration/client.go @@ -43,12 +43,14 @@ import ( // native apps that register redirect URIs like "http://localhost/callback" and then // request authorization with dynamic ports like "http://localhost:57403/callback". type LoopbackClient struct { - *fosite.DefaultClient + *fosite.DefaultOpenIDConnectClient } -// NewLoopbackClient creates a new LoopbackClient wrapping the provided DefaultClient. -func NewLoopbackClient(client *fosite.DefaultClient) *LoopbackClient { - return &LoopbackClient{DefaultClient: client} +// NewLoopbackClient creates a new LoopbackClient wrapping the provided client. +// The wrapper preserves all OIDC fields (including TokenEndpointAuthMethod) +// while adding RFC 8252 §7.3 dynamic port matching for loopback redirect URIs. +func NewLoopbackClient(client *fosite.DefaultOpenIDConnectClient) *LoopbackClient { + return &LoopbackClient{DefaultOpenIDConnectClient: client} } // MatchRedirectURI checks if the given redirect URI matches one of the client's @@ -167,8 +169,14 @@ func New(cfg Config) (fosite.Client, error) { // Wrap public clients in LoopbackClient for RFC 8252 Section 7.3 // dynamic port matching for native app loopback redirect URIs. + // Use DefaultOpenIDConnectClient so TokenEndpointAuthMethod ("none" for + // public clients) is preserved through the LoopbackClient wrapper. if cfg.Public { - return NewLoopbackClient(defaultClient), nil + oidcClient := &fosite.DefaultOpenIDConnectClient{ + DefaultClient: defaultClient, + TokenEndpointAuthMethod: "none", + } + return NewLoopbackClient(oidcClient), nil } return defaultClient, nil diff --git a/pkg/authserver/server/registration/client_test.go b/pkg/authserver/server/registration/client_test.go index b536eb50a6..37f6280006 100644 --- a/pkg/authserver/server/registration/client_test.go +++ b/pkg/authserver/server/registration/client_test.go @@ -32,7 +32,7 @@ func TestNewLoopbackClient(t *testing.T) { Public: true, } - client := NewLoopbackClient(defaultClient) + client := NewLoopbackClient(&fosite.DefaultOpenIDConnectClient{DefaultClient: defaultClient}) assert.NotNil(t, client) assert.Equal(t, "test-client", client.GetID()) @@ -196,10 +196,12 @@ func TestLoopbackClient_MatchRedirectURI(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - client := NewLoopbackClient(&fosite.DefaultClient{ - ID: "test-client", - RedirectURIs: tt.registeredURIs, - Public: true, + client := NewLoopbackClient(&fosite.DefaultOpenIDConnectClient{ + DefaultClient: &fosite.DefaultClient{ + ID: "test-client", + RedirectURIs: tt.registeredURIs, + Public: true, + }, }) result := client.MatchRedirectURI(tt.requestedURI) @@ -247,10 +249,12 @@ func TestLoopbackClient_GetMatchingRedirectURI(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - client := NewLoopbackClient(&fosite.DefaultClient{ - ID: "test-client", - RedirectURIs: tt.registeredURIs, - Public: true, + client := NewLoopbackClient(&fosite.DefaultOpenIDConnectClient{ + DefaultClient: &fosite.DefaultClient{ + ID: "test-client", + RedirectURIs: tt.registeredURIs, + Public: true, + }, }) result := client.GetMatchingRedirectURI(tt.requestedURI) diff --git a/pkg/authserver/server_impl.go b/pkg/authserver/server_impl.go index 83d063adc5..0c61b7d73b 100644 --- a/pkg/authserver/server_impl.go +++ b/pkg/authserver/server_impl.go @@ -107,9 +107,10 @@ func newServer(ctx context.Context, cfg Config, stor storage.Storage, opts ...se // provably safe for the production backends; surfacing a bad backend as // a constructor error keeps misconfiguration fail-loud at boot rather // than at first DCR resolve. - dcrStore, ok := stor.(storage.DCRCredentialStore) + baseStore := unwrapStorage(stor) + dcrStore, ok := baseStore.(storage.DCRCredentialStore) if !ok { - return nil, fmt.Errorf("storage backend %T does not implement storage.DCRCredentialStore", stor) + return nil, fmt.Errorf("storage backend %T does not implement storage.DCRCredentialStore", baseStore) } slog.Debug("creating OAuth2 configuration") @@ -168,13 +169,8 @@ func newServer(ctx context.Context, cfg Config, stor storage.Storage, opts ...se // Run one-shot bulk migration of legacy data before handler construction. // TODO(migration): Remove once all deployments have upgraded past this version. - if rs, ok := stor.(*storage.RedisStorage); ok { - for i := range cfg.Upstreams { - upCfg := &cfg.Upstreams[i] - if err := rs.MigrateLegacyUpstreamData(ctx, upCfg.Name, string(upCfg.Type)); err != nil { - return nil, fmt.Errorf("legacy data migration failed for upstream %q: %w", upCfg.Name, err) - } - } + if err := runLegacyMigration(ctx, stor, cfg.Upstreams); err != nil { + return nil, err } handlerInstance, err := handlers.NewHandler(fositeProvider, authServerConfig, stor, upstreams) @@ -294,3 +290,31 @@ func createProvider(authServerConfig *oauthserver.AuthorizationServerConfig, sto compose.OAuth2PKCEFactory, // PKCE for public clients ) } + +// unwrapStorage peels off one decorator layer if the storage implements +// Unwrap(), returning the concrete backend. Both newServer (DCRCredentialStore +// assertion) and runLegacyMigration (RedisStorage type assertion) need this. +func unwrapStorage(stor storage.Storage) storage.Storage { + if unwrapper, ok := stor.(interface{ Unwrap() storage.Storage }); ok { + return unwrapper.Unwrap() + } + return stor +} + +// runLegacyMigration runs one-shot Redis data migrations before handlers are +// constructed. It is a no-op for non-Redis backends and passes through any +// decorator wrapping so the concrete type can be reached. +func runLegacyMigration(ctx context.Context, stor storage.Storage, upstreams []UpstreamConfig) error { + base := unwrapStorage(stor) + rs, ok := base.(*storage.RedisStorage) + if !ok { + return nil + } + for i := range upstreams { + upCfg := &upstreams[i] + if err := rs.MigrateLegacyUpstreamData(ctx, upCfg.Name, string(upCfg.Type)); err != nil { + return fmt.Errorf("legacy data migration failed for upstream %q: %w", upCfg.Name, err) + } + } + return nil +} diff --git a/pkg/authserver/storage/cimd_decorator.go b/pkg/authserver/storage/cimd_decorator.go new file mode 100644 index 0000000000..0a1fc2f0e3 --- /dev/null +++ b/pkg/authserver/storage/cimd_decorator.go @@ -0,0 +1,222 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package storage + +import ( + "context" + "fmt" + "net/url" + "strings" + "time" + + lru "github.com/hashicorp/golang-lru/v2" + "github.com/ory/fosite" + "golang.org/x/sync/singleflight" + + "github.com/stacklok/toolhive/pkg/authserver/server/registration" + "github.com/stacklok/toolhive/pkg/oauthproto" + "github.com/stacklok/toolhive/pkg/oauthproto/cimd" +) + +// CIMDStorageDecorator wraps storage.Storage and intercepts GetClient calls +// for HTTPS client_id values, fetching and caching the corresponding Client +// ID Metadata Document instead of requiring prior DCR registration. +// +// All other Storage methods delegate to the underlying storage unchanged. +// Only GetClient is overridden. DCR clients (opaque IDs) continue to work +// exactly as before. +type CIMDStorageDecorator struct { + Storage // embed full interface — all methods delegate + sf singleflight.Group // deduplicates concurrent fetches for the same URL + cache *lru.Cache[string, *cimdCacheEntry] + ttl time.Duration +} + +type cimdCacheEntry struct { + client fosite.Client + expires time.Time +} + +// NewCIMDStorageDecorator wraps base with CIMD client lookup. When enabled=false +// it returns base unchanged (no allocation). cacheMaxSize must be >= 1; +// fallbackTTL is the fixed TTL applied to every cache entry (Cache-Control +// header parsing is not yet implemented; all entries use this value). +func NewCIMDStorageDecorator( + base Storage, + enabled bool, + cacheMaxSize int, + fallbackTTL time.Duration, +) (Storage, error) { + if !enabled { + return base, nil + } + + if cacheMaxSize < 1 { + return nil, fmt.Errorf("CIMD storage decorator cacheMaxSize must be >= 1, got %d", cacheMaxSize) + } + + c, err := lru.New[string, *cimdCacheEntry](cacheMaxSize) + if err != nil { + return nil, fmt.Errorf("failed to create CIMD LRU cache: %w", err) + } + + return &CIMDStorageDecorator{ + Storage: base, + cache: c, + ttl: fallbackTTL, + }, nil +} + +// GetClient intercepts HTTPS client_id values to resolve them via CIMD. +// Opaque DCR-issued IDs are delegated to the underlying storage unchanged. +func (d *CIMDStorageDecorator) GetClient(ctx context.Context, id string) (fosite.Client, error) { + if !oauthproto.IsClientIDMetadataDocumentURL(id) { + return d.Storage.GetClient(ctx, id) + } + return d.fetchOrCached(ctx, id) +} + +// Unwrap returns the underlying storage so that type assertions (e.g. for +// storage.DCRCredentialStore in server_impl.go) can reach the concrete type. +func (d *CIMDStorageDecorator) Unwrap() Storage { + return d.Storage +} + +func (d *CIMDStorageDecorator) fetchOrCached(ctx context.Context, id string) (fosite.Client, error) { + // Check cache first (outside singleflight to avoid holding the group lock for cache hits) + if entry, ok := d.cache.Get(id); ok && time.Now().Before(entry.expires) { + return entry.client, nil + } + + // Deduplicate concurrent fetches for the same URL. The shared fetch uses a + // context detached from the caller so that one caller cancelling does not + // abort the in-flight request for other waiters. The HTTP client inside + // FetchClientMetadataDocument enforces its own 5-second timeout. + fetchCtx := context.WithoutCancel(ctx) + result, err, _ := d.sf.Do(id, func() (interface{}, error) { + // Re-check cache inside singleflight (another goroutine may have populated it) + if entry, ok := d.cache.Get(id); ok && time.Now().Before(entry.expires) { + return entry.client, nil + } + return d.fetch(fetchCtx, id) + }) + if err != nil { + return nil, err + } + client, ok := result.(fosite.Client) + if !ok { + return nil, fmt.Errorf("CIMD singleflight returned unexpected type %T", result) + } + return client, nil +} + +func (d *CIMDStorageDecorator) fetch(ctx context.Context, id string) (fosite.Client, error) { + doc, err := cimd.FetchClientMetadataDocument(ctx, id) + if err != nil { + return nil, fmt.Errorf("%w: %w", fosite.ErrNotFound.WithHint("CIMD fetch failed"), err) + } + + // Reject documents that declare an auth method this AS does not support. + // The embedded AS only advertises "none"; accepting a doc that says + // "private_key_jwt" and then silently treating the client as public would + // mislead operators and break clients that actually try to use JWT assertions. + if m := doc.TokenEndpointAuthMethod; m != "" && m != defaultCIMDTokenEndpointAuthMethod { + return nil, fmt.Errorf("%w: CIMD document at %s claims token_endpoint_auth_method %q "+ + "but this server only supports %q", + fosite.ErrNotFound.WithHint("unsupported token_endpoint_auth_method"), + id, m, defaultCIMDTokenEndpointAuthMethod) + } + + client := buildFositeClient(doc) + + d.cache.Add(id, &cimdCacheEntry{ + client: client, + expires: time.Now().Add(d.ttl), + }) + + return client, nil +} + +// defaultCIMDGrantTypes are the OAuth 2.0 grant types applied when the CIMD +// document omits grant_types. CIMD clients are typically public native apps +// that use the authorization code flow with refresh token rotation. +var defaultCIMDGrantTypes = []string{"authorization_code", "refresh_token"} + +// defaultCIMDResponseTypes are the OAuth 2.0 response types applied when the +// CIMD document omits response_types. +var defaultCIMDResponseTypes = []string{"code"} + +// defaultCIMDTokenEndpointAuthMethod is the token endpoint authentication +// method applied when the CIMD document omits token_endpoint_auth_method. +// Documents that declare any other value are rejected by fetch() before +// buildFositeClient is called. +const defaultCIMDTokenEndpointAuthMethod = "none" + +// buildFositeClient converts a ClientMetadataDocument into a fosite.Client. +// Redirect URIs containing http://localhost are wrapped in a LoopbackClient +// so that RFC 8252 §7.3 dynamic port matching applies. +func buildFositeClient(doc *cimd.ClientMetadataDocument) fosite.Client { + grantTypes := doc.GrantTypes + if len(grantTypes) == 0 { + grantTypes = defaultCIMDGrantTypes + } + + responseTypes := doc.ResponseTypes + if len(responseTypes) == 0 { + responseTypes = defaultCIMDResponseTypes + } + + tokenEndpointAuthMethod := doc.TokenEndpointAuthMethod + if tokenEndpointAuthMethod == "" { + tokenEndpointAuthMethod = defaultCIMDTokenEndpointAuthMethod + } + + var scopes []string + if doc.Scope != "" { + scopes = strings.Fields(doc.Scope) + } + + defaultClient := &fosite.DefaultClient{ + ID: doc.ClientID, + RedirectURIs: doc.RedirectURIs, + GrantTypes: grantTypes, + ResponseTypes: responseTypes, + Scopes: scopes, + // CIMD clients don't pre-declare audience; leave empty so the AS + // applies its own audience policy rather than rejecting all values. + Audience: nil, + Public: true, + } + + openIDClient := &fosite.DefaultOpenIDConnectClient{ + DefaultClient: defaultClient, + TokenEndpointAuthMethod: tokenEndpointAuthMethod, + } + + // Wrap in LoopbackClient when any redirect URI targets localhost so that + // RFC 8252 §7.3 dynamic port matching works for native app clients. + // Pass openIDClient directly so TokenEndpointAuthMethod is preserved — + // LoopbackClient now embeds *fosite.DefaultOpenIDConnectClient. + if hasLoopbackRedirectURI(doc.RedirectURIs) { + return registration.NewLoopbackClient(openIDClient) + } + + return openIDClient +} + +// hasLoopbackRedirectURI returns true when any of the redirect URIs in the +// list targets a loopback address over HTTP. The host is parsed from each URI +// to prevent bypass via hosts like "http://localhost.evil.com/". +func hasLoopbackRedirectURI(uris []string) bool { + for _, uri := range uris { + parsed, err := url.Parse(uri) + if err != nil { + continue + } + if parsed.Scheme == "http" && oauthproto.IsLoopbackHost(parsed.Hostname()) { + return true + } + } + return false +} diff --git a/pkg/authserver/storage/cimd_decorator_test.go b/pkg/authserver/storage/cimd_decorator_test.go new file mode 100644 index 0000000000..b56fc2f256 --- /dev/null +++ b/pkg/authserver/storage/cimd_decorator_test.go @@ -0,0 +1,440 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package storage + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/ory/fosite" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/stacklok/toolhive/pkg/oauthproto/cimd" +) + +// serveCIMDDoc starts an httptest.Server that serves a valid CIMD document at +// path. The document's client_id equals the full URL (scheme + host + path) as +// required by ValidateClientMetadataDocument. The returned server URL is the +// base (without path); append path to form the client_id. +// +// An optional pre-handler runs before the default JSON response, allowing +// tests to inject counters or delays. Pass nil to use the default behaviour. +func serveCIMDDoc(t *testing.T, path string, preHandler func()) *httptest.Server { + t.Helper() + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != path { + http.NotFound(w, r) + return + } + if preHandler != nil { + preHandler() + } + // client_id must equal the URL we are serving from. + clientID := "http://" + r.Host + r.URL.Path + doc := cimd.ClientMetadataDocument{ + ClientID: clientID, + RedirectURIs: []string{"https://example.com/callback"}, + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(doc) + })) + t.Cleanup(srv.Close) + return srv +} + +// newTestBase creates a MemoryStorage suitable for use as the decorator base in tests. +func newTestBase(t *testing.T) *MemoryStorage { + t.Helper() + base := NewMemoryStorage() + t.Cleanup(func() { _ = base.Close() }) + return base +} + +// newEnabledDecorator creates a CIMDStorageDecorator wrapping base. +func newEnabledDecorator(t *testing.T, base *MemoryStorage, maxSize int, ttl time.Duration) *CIMDStorageDecorator { + t.Helper() + got, err := NewCIMDStorageDecorator(base, true, maxSize, ttl) + require.NoError(t, err) + return got.(*CIMDStorageDecorator) +} + +// cimdURL returns the CIMD URL for the given server and path. +func cimdURL(srv *httptest.Server, path string) string { + return srv.URL + path +} + +// --- Constructor tests --- + +func TestNewCIMDStorageDecorator_DisabledReturnsBase(t *testing.T) { + t.Parallel() + base := newTestBase(t) + got, err := NewCIMDStorageDecorator(base, false, 10, time.Minute) + require.NoError(t, err) + assert.Same(t, base, got, "disabled decorator must return base unchanged") +} + +func TestNewCIMDStorageDecorator_ZeroCacheSizeReturnsError(t *testing.T) { + t.Parallel() + base := newTestBase(t) + _, err := NewCIMDStorageDecorator(base, true, 0, time.Minute) + require.Error(t, err) +} + +func TestNewCIMDStorageDecorator_NegativeCacheSizeReturnsError(t *testing.T) { + t.Parallel() + base := newTestBase(t) + _, err := NewCIMDStorageDecorator(base, true, -1, time.Minute) + require.Error(t, err) +} + +func TestNewCIMDStorageDecorator_EnabledReturnsCIMDDecorator(t *testing.T) { + t.Parallel() + base := newTestBase(t) + got, err := NewCIMDStorageDecorator(base, true, 10, time.Minute) + require.NoError(t, err) + require.NotNil(t, got) + _, isCIMD := got.(*CIMDStorageDecorator) + assert.True(t, isCIMD, "enabled decorator must return a *CIMDStorageDecorator") +} + +// --- Unwrap --- + +func TestCIMDStorageDecorator_UnwrapReturnsBase(t *testing.T) { + t.Parallel() + base := newTestBase(t) + dec := newEnabledDecorator(t, base, 10, time.Minute) + assert.Same(t, base, dec.Unwrap()) +} + +// --- GetClient delegation for non-CIMD IDs --- + +func TestCIMDStorageDecorator_GetClient_OpaqueIDDelegatesToBase(t *testing.T) { + t.Parallel() + base := newTestBase(t) + ctx := context.Background() + + dc := &fosite.DefaultClient{ID: "opaque-client-id"} + require.NoError(t, base.RegisterClient(ctx, dc)) + + dec := newEnabledDecorator(t, base, 10, time.Minute) + + got, err := dec.GetClient(ctx, "opaque-client-id") + require.NoError(t, err) + assert.Equal(t, "opaque-client-id", got.GetID()) +} + +func TestCIMDStorageDecorator_GetClient_UnknownOpaqueIDReturnsError(t *testing.T) { + t.Parallel() + base := newTestBase(t) + dec := newEnabledDecorator(t, base, 10, time.Minute) + _, err := dec.GetClient(context.Background(), "unknown-opaque-id") + require.Error(t, err) +} + +// --- fetchOrCached / fetch (loopback HTTP accepted by FetchClientMetadataDocument) --- +// These tests call fetchOrCached directly (same package) using http://127.0.0.1 +// URLs, which FetchClientMetadataDocument accepts for testing purposes. + +func TestCIMDStorageDecorator_FetchOrCached_FetchesAndReturnsClient(t *testing.T) { + t.Parallel() + + var fetchCount atomic.Int32 + srv := serveCIMDDoc(t, "/metadata.json", func() { fetchCount.Add(1) }) + + id := cimdURL(srv, "/metadata.json") + dec := newEnabledDecorator(t, newTestBase(t), 10, time.Minute) + + got, err := dec.fetchOrCached(context.Background(), id) + require.NoError(t, err) + assert.Equal(t, id, got.GetID()) + assert.Equal(t, int32(1), fetchCount.Load()) +} + +func TestCIMDStorageDecorator_FetchOrCached_CacheHitAvoidsSecondFetch(t *testing.T) { + t.Parallel() + + var fetchCount atomic.Int32 + srv := serveCIMDDoc(t, "/metadata.json", func() { fetchCount.Add(1) }) + + id := cimdURL(srv, "/metadata.json") + dec := newEnabledDecorator(t, newTestBase(t), 10, time.Minute) + + ctx := context.Background() + _, err := dec.fetchOrCached(ctx, id) + require.NoError(t, err) + + _, err = dec.fetchOrCached(ctx, id) + require.NoError(t, err) + + assert.Equal(t, int32(1), fetchCount.Load(), "second call must be served from cache") +} + +func TestCIMDStorageDecorator_FetchOrCached_LRUEvictionForcesRefetch(t *testing.T) { + t.Parallel() + + var fetchCount atomic.Int32 + srv := serveCIMDDoc(t, "/a.json", func() { fetchCount.Add(1) }) + srv2 := serveCIMDDoc(t, "/b.json", func() { fetchCount.Add(1) }) + + id1 := cimdURL(srv, "/a.json") + id2 := cimdURL(srv2, "/b.json") + + // maxSize=1 forces eviction after the first entry. + dec := newEnabledDecorator(t, newTestBase(t), 1, time.Minute) + ctx := context.Background() + + _, err := dec.fetchOrCached(ctx, id1) + require.NoError(t, err) + + // Fetching id2 evicts id1 from the single-slot cache. + _, err = dec.fetchOrCached(ctx, id2) + require.NoError(t, err) + + // id1 must re-fetch. + _, err = dec.fetchOrCached(ctx, id1) + require.NoError(t, err) + + assert.Equal(t, int32(3), fetchCount.Load(), "id1 must be fetched twice due to LRU eviction") +} + +func TestCIMDStorageDecorator_FetchOrCached_SingleflightDeduplicatesConcurrentFetches(t *testing.T) { + t.Parallel() + + var fetchCount atomic.Int32 + // Barrier lets us hold all goroutines until they are all in-flight. + ready := make(chan struct{}) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + <-ready + fetchCount.Add(1) + clientID := "http://" + r.Host + r.URL.Path + doc := cimd.ClientMetadataDocument{ + ClientID: clientID, + RedirectURIs: []string{"https://example.com/callback"}, + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(doc) + })) + t.Cleanup(srv.Close) + + id := cimdURL(srv, "/metadata.json") + dec := newEnabledDecorator(t, newTestBase(t), 10, time.Minute) + + const goroutines = 5 + errs := make([]error, goroutines) + var wg sync.WaitGroup + wg.Add(goroutines) + + // Each goroutine signals on startBarrier immediately before calling + // fetchOrCached. Draining all signals before closing ready ensures they + // are all scheduled and about to enter sf.Do, making the singleflight + // deduplication deterministic without relying on time.Sleep. + startBarrier := make(chan struct{}, goroutines) + + for i := 0; i < goroutines; i++ { + go func(i int) { + defer wg.Done() + startBarrier <- struct{}{} + _, errs[i] = dec.fetchOrCached(context.Background(), id) + }(i) + } + + for range goroutines { + <-startBarrier + } + close(ready) + + done := make(chan struct{}) + go func() { wg.Wait(); close(done) }() + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for concurrent fetchOrCached goroutines") + } + + for i, e := range errs { + require.NoError(t, e, "goroutine %d returned an error", i) + } + assert.Equal(t, int32(1), fetchCount.Load(), "singleflight must collapse concurrent fetches into one") +} + +func TestCIMDStorageDecorator_FetchOrCached_FetchFailureReturnsNotFound(t *testing.T) { + t.Parallel() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + t.Cleanup(srv.Close) + + dec := newEnabledDecorator(t, newTestBase(t), 10, time.Minute) + _, err := dec.fetchOrCached(context.Background(), srv.URL+"/meta.json") + require.Error(t, err) + assert.ErrorIs(t, err, fosite.ErrNotFound, "fetch failure must be wrapped as fosite.ErrNotFound") +} + +func TestCIMDStorageDecorator_FetchOrCached_ExpiredCacheEntryRefetches(t *testing.T) { + t.Parallel() + + var fetchCount atomic.Int32 + srv := serveCIMDDoc(t, "/metadata.json", func() { fetchCount.Add(1) }) + + id := cimdURL(srv, "/metadata.json") + dec := newEnabledDecorator(t, newTestBase(t), 10, 1*time.Millisecond) + + ctx := context.Background() + _, err := dec.fetchOrCached(ctx, id) + require.NoError(t, err) + + time.Sleep(10 * time.Millisecond) + + _, err = dec.fetchOrCached(ctx, id) + require.NoError(t, err) + + assert.Equal(t, int32(2), fetchCount.Load(), "expired cache entry must trigger a re-fetch") +} + +// --- GetClient with HTTPS CIMD URLs --- +// Verify that GetClient routes HTTPS client_id values through fetchOrCached by +// pre-populating the cache directly (avoiding real network). + +func TestCIMDStorageDecorator_GetClient_CIMDURLHitsCacheDirectly(t *testing.T) { + t.Parallel() + + base := newTestBase(t) + dec := newEnabledDecorator(t, base, 10, time.Minute) + + const httpsID = "https://example.com/meta.json" + fakeClient := &fosite.DefaultClient{ID: httpsID} + + // Pre-populate the cache so no real HTTP fetch is needed. + dec.cache.Add(httpsID, &cimdCacheEntry{ + client: fakeClient, + expires: time.Now().Add(time.Minute), + }) + + got, err := dec.GetClient(context.Background(), httpsID) + require.NoError(t, err) + assert.Equal(t, httpsID, got.GetID()) +} + +// --- buildFositeClient --- + +func TestBuildFositeClient_Defaults(t *testing.T) { + t.Parallel() + + doc := &cimd.ClientMetadataDocument{ + ClientID: "https://example.com/meta.json", + RedirectURIs: []string{"https://example.com/callback"}, + } + + got := buildFositeClient(doc) + assert.Equal(t, "https://example.com/meta.json", got.GetID()) + assert.True(t, got.IsPublic()) + assert.ElementsMatch(t, []string{"authorization_code", "refresh_token"}, []string(got.GetGrantTypes())) + assert.ElementsMatch(t, []string{"code"}, []string(got.GetResponseTypes())) +} + +func TestBuildFositeClient_ExplicitGrantTypes(t *testing.T) { + t.Parallel() + + doc := &cimd.ClientMetadataDocument{ + ClientID: "https://example.com/meta.json", + RedirectURIs: []string{"https://example.com/callback"}, + GrantTypes: []string{"authorization_code"}, + } + + got := buildFositeClient(doc) + assert.ElementsMatch(t, []string{"authorization_code"}, []string(got.GetGrantTypes())) +} + +func TestBuildFositeClient_ScopeParsing(t *testing.T) { + t.Parallel() + + doc := &cimd.ClientMetadataDocument{ + ClientID: "https://example.com/meta.json", + RedirectURIs: []string{"https://example.com/callback"}, + Scope: "openid profile email", + } + + got := buildFositeClient(doc) + assert.ElementsMatch(t, []string{"openid", "profile", "email"}, []string(got.GetScopes())) +} + +func TestBuildFositeClient_LoopbackRedirectWrapsInLoopbackClient(t *testing.T) { + t.Parallel() + + doc := &cimd.ClientMetadataDocument{ + ClientID: "https://example.com/meta.json", + RedirectURIs: []string{"http://localhost/callback"}, + } + + got := buildFositeClient(doc) + // LoopbackClient adds MatchRedirectURI — check the distinctive method is present. + type loopbackMatcher interface { + MatchRedirectURI(string) bool + } + _, ok := got.(loopbackMatcher) + assert.True(t, ok, "loopback redirect URI must produce a LoopbackClient") + + // TokenEndpointAuthMethod must be preserved through the LoopbackClient wrapper. + oidc, ok := got.(fosite.OpenIDConnectClient) + require.True(t, ok, "LoopbackClient must implement fosite.OpenIDConnectClient") + assert.Equal(t, "none", oidc.GetTokenEndpointAuthMethod(), + "loopback client must preserve TokenEndpointAuthMethod from the OIDC client") +} + +func TestBuildFositeClient_NonLoopbackRedirectReturnsOpenIDConnectClient(t *testing.T) { + t.Parallel() + + doc := &cimd.ClientMetadataDocument{ + ClientID: "https://example.com/meta.json", + RedirectURIs: []string{"https://example.com/callback"}, + } + + got := buildFositeClient(doc) + _, ok := got.(*fosite.DefaultOpenIDConnectClient) + assert.True(t, ok, "non-loopback redirect URI must produce a DefaultOpenIDConnectClient") +} + +func TestBuildFositeClient_TokenEndpointAuthMethodDefault(t *testing.T) { + t.Parallel() + + doc := &cimd.ClientMetadataDocument{ + ClientID: "https://example.com/meta.json", + RedirectURIs: []string{"https://example.com/callback"}, + } + + got := buildFositeClient(doc) + if oidc, ok := got.(fosite.OpenIDConnectClient); ok { + assert.Equal(t, "none", oidc.GetTokenEndpointAuthMethod()) + } +} + +func TestFetch_RejectsUnsupportedTokenEndpointAuthMethod(t *testing.T) { + t.Parallel() + + // Serve a CIMD doc that declares a non-"none" auth method. + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + clientID := "http://" + r.Host + r.URL.Path + doc := cimd.ClientMetadataDocument{ + ClientID: clientID, + RedirectURIs: []string{"https://example.com/callback"}, + TokenEndpointAuthMethod: "private_key_jwt", + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(doc) + })) + t.Cleanup(srv.Close) + + dec := newEnabledDecorator(t, newTestBase(t), 10, time.Minute) + _, err := dec.fetchOrCached(context.Background(), srv.URL+"/meta.json") + require.Error(t, err, "fetch must fail when token_endpoint_auth_method is not \"none\"") +}