From 1eb3bfe68ba2366d287e8e347d59098e43ae9fd0 Mon Sep 17 00:00:00 2001 From: Brendan Morante Date: Tue, 5 May 2026 07:05:30 -0700 Subject: [PATCH] Fix CosmosClient memory leak in CosmosDbFactory.GetCosmosClient GetOrAdd was being called with the eager TValue overload of ConcurrentDictionary, constructing a new CosmosClient on every call and silently discarding it when the cache already held an entry for the key. Discarded instances hold unmanaged state and have finalizers; they accumulate until the pod is OOMKilled. Switch to the lazy Func overload so CreateCosmosClient only runs on cache miss. Add a regression test that counts CreateCosmosClient invocations via a test subclass. This requires changing the class from 'internal sealed' to 'internal' and CreateCosmosClient from 'private' to 'protected internal virtual'. Signed-off-by: Brendan Morante --- src/Scaler.Tests/CosmosDbFactoryTests.cs | 51 ++++++++++++++++++++++++ src/Scaler/Services/CosmosDbFactory.cs | 8 ++-- 2 files changed, 56 insertions(+), 3 deletions(-) create mode 100644 src/Scaler.Tests/CosmosDbFactoryTests.cs diff --git a/src/Scaler.Tests/CosmosDbFactoryTests.cs b/src/Scaler.Tests/CosmosDbFactoryTests.cs new file mode 100644 index 0000000..b980bca --- /dev/null +++ b/src/Scaler.Tests/CosmosDbFactoryTests.cs @@ -0,0 +1,51 @@ +using System.Threading; +using Microsoft.Azure.Cosmos; +using Xunit; + +namespace Keda.CosmosDb.Scaler.Tests +{ + public class CosmosDbFactoryTests + { + private const string DummyConnection1 = "AccountEndpoint=https://example1.com:443/;AccountKey=ZHVtbXkx"; + private const string DummyConnection2 = "AccountEndpoint=https://example2.com:443/;AccountKey=ZHVtbXky"; + + [Fact] + public void GetCosmosClient_OnlyConstructsOnceForSameKey() + { + var factory = new CountingCosmosDbFactory(); + + CosmosClient c1 = factory.GetCosmosClient(DummyConnection1, useCredentials: false, clientId: null); + CosmosClient c2 = factory.GetCosmosClient(DummyConnection1, useCredentials: false, clientId: null); + CosmosClient c3 = factory.GetCosmosClient(DummyConnection1, useCredentials: false, clientId: null); + + Assert.Same(c1, c2); + Assert.Same(c2, c3); + Assert.Equal(1, factory.CreateCount); + } + + [Fact] + public void GetCosmosClient_ConstructsOncePerDistinctKey() + { + var factory = new CountingCosmosDbFactory(); + + factory.GetCosmosClient(DummyConnection1, useCredentials: false, clientId: null); + factory.GetCosmosClient(DummyConnection2, useCredentials: false, clientId: null); + factory.GetCosmosClient(DummyConnection1, useCredentials: false, clientId: null); + factory.GetCosmosClient(DummyConnection2, useCredentials: false, clientId: null); + + Assert.Equal(2, factory.CreateCount); + } + + private sealed class CountingCosmosDbFactory : CosmosDbFactory + { + public int CreateCount; + + protected internal override CosmosClient CreateCosmosClient( + string endpointOrConnection, bool useCredentials, string clientId) + { + Interlocked.Increment(ref CreateCount); + return base.CreateCosmosClient(endpointOrConnection, useCredentials, clientId); + } + } + } +} diff --git a/src/Scaler/Services/CosmosDbFactory.cs b/src/Scaler/Services/CosmosDbFactory.cs index 7353ca5..44539d9 100644 --- a/src/Scaler/Services/CosmosDbFactory.cs +++ b/src/Scaler/Services/CosmosDbFactory.cs @@ -5,7 +5,7 @@ namespace Keda.CosmosDb.Scaler { - internal sealed class CosmosDbFactory + internal class CosmosDbFactory { private const string _applicationName = "keda-external-azure-cosmos-db"; // As per https://docs.microsoft.com/dotnet/api/microsoft.azure.cosmos.cosmosclient, it is recommended to @@ -14,10 +14,12 @@ internal sealed class CosmosDbFactory public CosmosClient GetCosmosClient(string endpointOrConnection, bool useCredentials, string clientId) { - return _cosmosClientCache.GetOrAdd((endpointOrConnection, clientId), CreateCosmosClient(endpointOrConnection, useCredentials, clientId)); + return _cosmosClientCache.GetOrAdd( + (endpointOrConnection, clientId), + _ => CreateCosmosClient(endpointOrConnection, useCredentials, clientId)); } - private CosmosClient CreateCosmosClient(string endpointOrConnection, bool useCredentials, string clientId) + protected internal virtual CosmosClient CreateCosmosClient(string endpointOrConnection, bool useCredentials, string clientId) { if (useCredentials) {