Skip to content

Commit

Permalink
Merge branch 'main' into stringallocations
Browse files Browse the repository at this point in the history
  • Loading branch information
lukebakken committed May 15, 2024
2 parents 4e1156b + bf9a35a commit 67b7f88
Show file tree
Hide file tree
Showing 9 changed files with 86 additions and 77 deletions.
8 changes: 5 additions & 3 deletions projects/RabbitMQ.Client.OAuth2/OAuth2Client.cs
Original file line number Diff line number Diff line change
Expand Up @@ -232,9 +232,11 @@ public void Dispose()

private Dictionary<string, string> buildRequestParameters()
{
var dict = new Dictionary<string, string>(_additionalRequestParameters);
dict.Add(CLIENT_ID, _clientId);
dict.Add(CLIENT_SECRET, _clientSecret);
var dict = new Dictionary<string, string>(_additionalRequestParameters)
{
{ CLIENT_ID, _clientId },
{ CLIENT_SECRET, _clientSecret }
};
if (_scope != null && _scope.Length > 0)
{
dict.Add(SCOPE, _scope);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ private void DoDeleteAutoDeleteExchange(string exchangeName)

bool AnyBindingsOnExchange(string exchange)
{
foreach (var recordedBinding in _recordedBindings)
foreach (RecordedBinding recordedBinding in _recordedBindings)
{
if (recordedBinding.Source == exchange)
{
Expand Down Expand Up @@ -400,15 +400,15 @@ await _recordedEntitiesSemaphore.WaitAsync()

private void DoDeleteRecordedConsumer(string consumerTag)
{
if (_recordedConsumers.Remove(consumerTag, out var recordedConsumer))
if (_recordedConsumers.Remove(consumerTag, out RecordedConsumer recordedConsumer))
{
DeleteAutoDeleteQueue(recordedConsumer.Queue);
}
}

private void DeleteAutoDeleteQueue(string queue)
{
if (_recordedQueues.TryGetValue(queue, out var recordedQueue) && recordedQueue.AutoDelete)
if (_recordedQueues.TryGetValue(queue, out RecordedQueue recordedQueue) && recordedQueue.AutoDelete)
{
// last consumer on this connection is gone, remove recorded queue if it is auto-deleted.
if (!AnyConsumersOnQueue(queue))
Expand All @@ -420,7 +420,7 @@ private void DeleteAutoDeleteQueue(string queue)

private bool AnyConsumersOnQueue(string queue)
{
foreach (var pair in _recordedConsumers)
foreach (KeyValuePair<string, RecordedConsumer> pair in _recordedConsumers)
{
if (pair.Value.Queue == queue)
{
Expand Down
8 changes: 4 additions & 4 deletions projects/RabbitMQ.Client/client/impl/ChannelBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,7 @@ private void OnChannelShutdown(ShutdownEventArgs reason)
if (_confirmsTaskCompletionSources?.Count > 0)
{
var exception = new AlreadyClosedException(reason);
foreach (var confirmsTaskCompletionSource in _confirmsTaskCompletionSources)
foreach (TaskCompletionSource<bool> confirmsTaskCompletionSource in _confirmsTaskCompletionSources)
{
confirmsTaskCompletionSource.TrySetException(exception);
}
Expand Down Expand Up @@ -635,7 +635,7 @@ protected void HandleAckNack(ulong deliveryTag, bool multiple, bool isNack)
if (_pendingDeliveryTags.Count == 0 && _confirmsTaskCompletionSources.Count > 0)
{
// Done, mark tasks
foreach (var confirmsTaskCompletionSource in _confirmsTaskCompletionSources)
foreach (TaskCompletionSource<bool> confirmsTaskCompletionSource in _confirmsTaskCompletionSources)
{
confirmsTaskCompletionSource.TrySetResult(_onlyAcksReceived);
}
Expand Down Expand Up @@ -754,7 +754,7 @@ protected async Task<bool> HandleChannelCloseOkAsync(IncomingCommand cmd, Cancel
*/
FinishClose();

if (_continuationQueue.TryPeek<ChannelCloseAsyncRpcContinuation>(out var k))
if (_continuationQueue.TryPeek<ChannelCloseAsyncRpcContinuation>(out ChannelCloseAsyncRpcContinuation k))
{
_continuationQueue.Next();
await k.HandleCommandAsync(cmd)
Expand Down Expand Up @@ -1905,7 +1905,7 @@ private static BasicProperties PopulateActivityAndPropagateTraceId<TProperties>(
props = new BasicProperties();
}

var headers = props.Headers ?? new Dictionary<string, object>();
IDictionary<string, object> headers = props.Headers ?? new Dictionary<string, object>();

// Inject the ActivityContext into the message headers to propagate trace context to the receiving service.
DistributedContextPropagator.Current.Inject(sendActivity, headers, InjectTraceContextIntoBasicProperties);
Expand Down
6 changes: 5 additions & 1 deletion projects/RabbitMQ.Client/client/impl/Connection.Commands.cs
Original file line number Diff line number Diff line change
Expand Up @@ -212,9 +212,13 @@ await UpdateSecretAsync(_config.CredentialsProvider.Password, "Token refresh", c
private IAuthMechanismFactory GetAuthMechanismFactory(string supportedMechanismNames)
{
// Our list is in order of preference, the server one is not.
foreach (var factory in _config.AuthMechanisms)
foreach (IAuthMechanismFactory factory in _config.AuthMechanisms)
{
#if NET6_0_OR_GREATER
if (supportedMechanismNames.Contains(factory.Name, StringComparison.OrdinalIgnoreCase))
#else
if (supportedMechanismNames.IndexOf(factory.Name, StringComparison.OrdinalIgnoreCase) >= 0)
#endif
{
return factory;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ protected void AddConsumer(IBasicConsumer consumer, string tag)
{
lock (_consumers)
{
var tagBytes = Encoding.UTF8.GetBytes(tag);
byte[] tagBytes = Encoding.UTF8.GetBytes(tag);
_consumers[tagBytes] = (consumer, tag);
}
}
Expand All @@ -39,7 +39,7 @@ protected void AddConsumer(IBasicConsumer consumer, string tag)
return consumerPair;
}

#if !NETSTANDARD
#if NET6_0_OR_GREATER
string consumerTag = Encoding.UTF8.GetString(tag.Span);
#else
string consumerTag;
Expand All @@ -60,18 +60,29 @@ public IBasicConsumer GetAndRemoveConsumer(string tag)
{
lock (_consumers)
{
var utf8 = Encoding.UTF8;
var pool = ArrayPool<byte>.Shared;
var buf = pool.Rent(utf8.GetMaxByteCount(tag.Length));
#if NETSTANDARD
int count = utf8.GetBytes(tag, 0, tag.Length, buf, 0);
ArrayPool<byte> pool = ArrayPool<byte>.Shared;
byte[]? buf = null;
try
{
buf = pool.Rent(Encoding.UTF8.GetMaxByteCount(tag.Length));
#if NET6_0_OR_GREATER
int count = Encoding.UTF8.GetBytes(tag, buf);
#else
int count = utf8.GetBytes(tag, buf);
int count = Encoding.UTF8.GetBytes(tag, 0, tag.Length, buf, 0);
#endif
var memory = buf.AsMemory(0, count);
var result = _consumers.Remove(memory, out var consumerPair) ? consumerPair.consumer : GetDefaultOrFallbackConsumer();
pool.Return(buf);
return result;
Memory<byte> memory = buf.AsMemory(0, count);
IBasicConsumer result = _consumers.Remove(memory,
out (IBasicConsumer consumer, string consumerTag) consumerPair) ?
consumerPair.consumer : GetDefaultOrFallbackConsumer();
return result;
}
finally
{
if (buf != null)
{
pool.Return(buf);
}
}
}
}

Expand Down
10 changes: 5 additions & 5 deletions projects/RabbitMQ.Client/client/impl/WireFormatting.Read.cs
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ public static object ReadFieldValue(ReadOnlySpan<byte> span, out int bytesRead)
switch ((char)span[0])
{
case 'S':
bytesRead = 1 + ReadLongstr(span.Slice(1), out var bytes);
bytesRead = 1 + ReadLongstr(span.Slice(1), out byte[] bytes);
return bytes;
case 't':
bytesRead = 2;
Expand All @@ -96,11 +96,11 @@ public static object ReadFieldValue(ReadOnlySpan<byte> span, out int bytesRead)
// Moved out of outer switch to have a shorter main method (improves performance)
static object ReadFieldValueSlow(ReadOnlySpan<byte> span, out int bytesRead)
{
var slice = span.Slice(1);
ReadOnlySpan<byte> slice = span.Slice(1);
switch ((char)span[0])
{
case 'F':
bytesRead = 1 + ReadDictionary(slice, out var dictionary);
bytesRead = 1 + ReadDictionary(slice, out Dictionary<string, object> dictionary);
return dictionary;
case 'A':
IList arrayResult = ReadArray(slice, out int arrayBytesRead);
Expand Down Expand Up @@ -134,10 +134,10 @@ static object ReadFieldValueSlow(ReadOnlySpan<byte> span, out int bytesRead)
bytesRead = 3;
return NetworkOrderDeserializer.ReadUInt16(slice);
case 'T':
bytesRead = 1 + ReadTimestamp(slice, out var timestamp);
bytesRead = 1 + ReadTimestamp(slice, out AmqpTimestamp timestamp);
return timestamp;
case 'x':
bytesRead = 1 + ReadLongstr(slice, out var binaryTableResult);
bytesRead = 1 + ReadLongstr(slice, out byte[] binaryTableResult);
return new BinaryTableValue(binaryTableResult);
default:
bytesRead = 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ public static int WriteShort(ref byte destination, ushort val)
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static int WriteShortstr(ref byte destination, ReadOnlySpan<byte> value)
{
var length = value.Length;
int length = value.Length;
if (length <= byte.MaxValue)
{
destination = (byte)length;
Expand Down
8 changes: 4 additions & 4 deletions projects/Test/Integration/TestAsyncConsumer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,11 @@ public async Task TestBasicRoundtripConcurrent()
{
QueueDeclareOk q = await _channel.QueueDeclareAsync();

string publish1 = GetUniqueString(1024);
string publish1 = GetUniqueString(512);
byte[] body = _encoding.GetBytes(publish1);
await _channel.BasicPublishAsync("", q.QueueName, body);

string publish2 = GetUniqueString(1024);
string publish2 = GetUniqueString(512);
body = _encoding.GetBytes(publish2);
await _channel.BasicPublishAsync("", q.QueueName, body);

Expand Down Expand Up @@ -141,9 +141,9 @@ public async Task TestBasicRoundtripConcurrentManyMessages()
const int publish_total = 4096;
string queueName = GenerateQueueName();

string publish1 = GetUniqueString(32768);
string publish1 = GetUniqueString(512);
byte[] body1 = _encoding.GetBytes(publish1);
string publish2 = GetUniqueString(32768);
string publish2 = GetUniqueString(512);
byte[] body2 = _encoding.GetBytes(publish2);

var publish1SyncSource = new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);
Expand Down
78 changes: 35 additions & 43 deletions projects/Test/Unit/TestTimerBasedCredentialRefresher.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
//---------------------------------------------------------------------------

using System;
using System.Threading;
using System.Threading.Tasks;
using RabbitMQ.Client;
using Xunit;
Expand Down Expand Up @@ -129,60 +130,51 @@ public void TestDoNotRegisterWhenHasNoExpiry()
[Fact]
public async Task TestRefreshToken()
{
var cbtcs = new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);
bool? callbackArg = null;
var credentialsProvider = new MockCredentialsProvider(_testOutputHelper, TimeSpan.FromSeconds(1));
Task cb(bool arg)
var tcs = new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);
using (var cts = new CancellationTokenSource(TimeSpan.FromSeconds(5)))
{
callbackArg = arg;
cbtcs.SetResult(true);
return Task.CompletedTask;
}

try
{
_refresher.Register(credentialsProvider, cb);

await cbtcs.Task.WaitAsync(TimeSpan.FromSeconds(5));
Assert.True(await cbtcs.Task);

Assert.True(credentialsProvider.RefreshCalled);
Assert.True(callbackArg);
}
finally
{
Assert.True(_refresher.Unregister(credentialsProvider));
using (CancellationTokenRegistration ctr = cts.Token.Register(() => tcs.TrySetCanceled()))
{
var credentialsProvider = new MockCredentialsProvider(_testOutputHelper, TimeSpan.FromSeconds(1));

Task cb(bool arg)
{
tcs.SetResult(arg);
return Task.CompletedTask;
}

_refresher.Register(credentialsProvider, cb);
Assert.True(await tcs.Task);
Assert.True(credentialsProvider.RefreshCalled);
Assert.True(_refresher.Unregister(credentialsProvider));
}
}
}

[Fact]
public async Task TestRefreshTokenFailed()
{
var cbtcs = new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);
bool? callbackArg = null;
var credentialsProvider = new MockCredentialsProvider(_testOutputHelper, TimeSpan.FromSeconds(1));
Task cb(bool arg)
var tcs = new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);
using (var cts = new CancellationTokenSource(TimeSpan.FromSeconds(5)))
{
callbackArg = arg;
cbtcs.SetResult(true);
return Task.CompletedTask;
}
using (CancellationTokenRegistration ctr = cts.Token.Register(() => tcs.TrySetCanceled()))
{
var credentialsProvider = new MockCredentialsProvider(_testOutputHelper, TimeSpan.FromSeconds(1));

var ex = new Exception();
credentialsProvider.PasswordThrows(ex);
Task cb(bool arg)
{
tcs.SetResult(arg);
return Task.CompletedTask;
}

try
{
_refresher.Register(credentialsProvider, cb);
await cbtcs.Task.WaitAsync(TimeSpan.FromSeconds(5));
Assert.True(await cbtcs.Task);
var ex = new Exception();
credentialsProvider.PasswordThrows(ex);

Assert.True(credentialsProvider.RefreshCalled);
Assert.False(callbackArg);
}
finally
{
Assert.True(_refresher.Unregister(credentialsProvider));
_refresher.Register(credentialsProvider, cb);
Assert.False(await tcs.Task);
Assert.True(credentialsProvider.RefreshCalled);
Assert.True(_refresher.Unregister(credentialsProvider));
}
}
}
}
Expand Down

0 comments on commit 67b7f88

Please sign in to comment.