diff --git a/src/Scaler.Tests/CosmosDbFactoryTests.cs b/src/Scaler.Tests/CosmosDbFactoryTests.cs new file mode 100644 index 0000000..b980bca --- /dev/null +++ b/src/Scaler.Tests/CosmosDbFactoryTests.cs @@ -0,0 +1,51 @@ +using System.Threading; +using Microsoft.Azure.Cosmos; +using Xunit; + +namespace Keda.CosmosDb.Scaler.Tests +{ + public class CosmosDbFactoryTests + { + private const string DummyConnection1 = "AccountEndpoint=https://example1.com:443/;AccountKey=ZHVtbXkx"; + private const string DummyConnection2 = "AccountEndpoint=https://example2.com:443/;AccountKey=ZHVtbXky"; + + [Fact] + public void GetCosmosClient_OnlyConstructsOnceForSameKey() + { + var factory = new CountingCosmosDbFactory(); + + CosmosClient c1 = factory.GetCosmosClient(DummyConnection1, useCredentials: false, clientId: null); + CosmosClient c2 = factory.GetCosmosClient(DummyConnection1, useCredentials: false, clientId: null); + CosmosClient c3 = factory.GetCosmosClient(DummyConnection1, useCredentials: false, clientId: null); + + Assert.Same(c1, c2); + Assert.Same(c2, c3); + Assert.Equal(1, factory.CreateCount); + } + + [Fact] + public void GetCosmosClient_ConstructsOncePerDistinctKey() + { + var factory = new CountingCosmosDbFactory(); + + factory.GetCosmosClient(DummyConnection1, useCredentials: false, clientId: null); + factory.GetCosmosClient(DummyConnection2, useCredentials: false, clientId: null); + factory.GetCosmosClient(DummyConnection1, useCredentials: false, clientId: null); + factory.GetCosmosClient(DummyConnection2, useCredentials: false, clientId: null); + + Assert.Equal(2, factory.CreateCount); + } + + private sealed class CountingCosmosDbFactory : CosmosDbFactory + { + public int CreateCount; + + protected internal override CosmosClient CreateCosmosClient( + string endpointOrConnection, bool useCredentials, string clientId) + { + Interlocked.Increment(ref CreateCount); + return base.CreateCosmosClient(endpointOrConnection, useCredentials, clientId); + } + } + } +} diff --git a/src/Scaler/Services/CosmosDbFactory.cs b/src/Scaler/Services/CosmosDbFactory.cs index 7353ca5..44539d9 100644 --- a/src/Scaler/Services/CosmosDbFactory.cs +++ b/src/Scaler/Services/CosmosDbFactory.cs @@ -5,7 +5,7 @@ namespace Keda.CosmosDb.Scaler { - internal sealed class CosmosDbFactory + internal class CosmosDbFactory { private const string _applicationName = "keda-external-azure-cosmos-db"; // As per https://docs.microsoft.com/dotnet/api/microsoft.azure.cosmos.cosmosclient, it is recommended to @@ -14,10 +14,12 @@ internal sealed class CosmosDbFactory public CosmosClient GetCosmosClient(string endpointOrConnection, bool useCredentials, string clientId) { - return _cosmosClientCache.GetOrAdd((endpointOrConnection, clientId), CreateCosmosClient(endpointOrConnection, useCredentials, clientId)); + return _cosmosClientCache.GetOrAdd( + (endpointOrConnection, clientId), + _ => CreateCosmosClient(endpointOrConnection, useCredentials, clientId)); } - private CosmosClient CreateCosmosClient(string endpointOrConnection, bool useCredentials, string clientId) + protected internal virtual CosmosClient CreateCosmosClient(string endpointOrConnection, bool useCredentials, string clientId) { if (useCredentials) {