Skip to content

Commit

Permalink
Reject payloads over the threshold set by server (nats-io#378)
Browse files Browse the repository at this point in the history
* Reject payloads over the threshold set by server

* format and test fix
  • Loading branch information
mtmk authored Feb 7, 2024
1 parent e16ac1a commit 4ab46d2
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 7 deletions.
13 changes: 12 additions & 1 deletion src/NATS.Client.Core/Commands/CommandWriter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ internal sealed class CommandWriter : IAsyncDisposable
private const int MaxSendSize = 16384;

private readonly ILogger<CommandWriter> _logger;
private readonly NatsConnection _connection;
private readonly ObjectPool _pool;
private readonly int _arrayPoolInitialSize;
private readonly object _lock = new();
Expand All @@ -42,9 +43,10 @@ internal sealed class CommandWriter : IAsyncDisposable
private CancellationTokenSource? _ctsReader;
private volatile bool _disposed;

public CommandWriter(ObjectPool pool, NatsOpts opts, ConnectionStatsCounter counter, Action<PingCommand> enqueuePing, TimeSpan? overrideCommandTimeout = default)
public CommandWriter(NatsConnection connection, ObjectPool pool, NatsOpts opts, ConnectionStatsCounter counter, Action<PingCommand> enqueuePing, TimeSpan? overrideCommandTimeout = default)
{
_logger = opts.LoggerFactory.CreateLogger<CommandWriter>();
_connection = connection;
_pool = pool;

// Derive ArrayPool rent size from buffer size to
Expand Down Expand Up @@ -245,6 +247,12 @@ public ValueTask PublishAsync<T>(string subject, T? value, NatsHeaders? headers,
if (value != null)
serializer.Serialize(payloadBuffer, value);

var size = payloadBuffer.WrittenMemory.Length + (headersBuffer?.WrittenMemory.Length ?? 0);
if (_connection.ServerInfo is { } info && size > info.MaxPayload)
{
ThrowOnMaxPayload(size, info.MaxPayload);
}

return PublishLockedAsync(subject, replyTo, payloadBuffer, headersBuffer, cancellationToken);
}

Expand Down Expand Up @@ -309,6 +317,9 @@ public async ValueTask UnsubscribeAsync(int sid, int? maxMsgs, CancellationToken
// only used for internal testing
internal bool TestStallFlush() => _channelLock.Writer.TryWrite(1);

[MethodImpl(MethodImplOptions.NoInlining)]
private static void ThrowOnMaxPayload(int size, int max) => throw new NatsException($"Payload size {size} exceeds server's maximum payload size {max}");

private static async Task ReaderLoopAsync(ILogger<CommandWriter> logger, ISocketConnection connection, PipeReader pipeReader, Channel<int> channelSize, CancellationToken cancellationToken)
{
try
Expand Down
4 changes: 2 additions & 2 deletions src/NATS.Client.Core/Commands/PriorityCommandWriter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ internal sealed class PriorityCommandWriter : IAsyncDisposable
{
private int _disposed;

public PriorityCommandWriter(ObjectPool pool, ISocketConnection socketConnection, NatsOpts opts, ConnectionStatsCounter counter, Action<PingCommand> enqueuePing)
public PriorityCommandWriter(NatsConnection connection, ObjectPool pool, ISocketConnection socketConnection, NatsOpts opts, ConnectionStatsCounter counter, Action<PingCommand> enqueuePing)
{
CommandWriter = new CommandWriter(pool, opts, counter, enqueuePing, overrideCommandTimeout: Timeout.InfiniteTimeSpan);
CommandWriter = new CommandWriter(connection, pool, opts, counter, enqueuePing, overrideCommandTimeout: Timeout.InfiniteTimeSpan);
CommandWriter.Reset(socketConnection);
}

Expand Down
6 changes: 3 additions & 3 deletions src/NATS.Client.Core/NatsConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ public partial class NatsConnection : INatsConnection
public Func<(string Host, int Port), ValueTask<(string Host, int Port)>>? OnConnectingAsync;

internal readonly ConnectionStatsCounter Counter; // allow to call from external sources
internal ServerInfo? WritableServerInfo;
internal volatile ServerInfo? WritableServerInfo;
internal bool IsDisposed;

#pragma warning restore SA1401
Expand Down Expand Up @@ -79,7 +79,7 @@ public NatsConnection(NatsOpts opts)
_cancellationTimerPool = new CancellationTimerPool(_pool, _disposedCancellationTokenSource.Token);
_name = opts.Name;
Counter = new ConnectionStatsCounter();
CommandWriter = new CommandWriter(_pool, Opts, Counter, EnqueuePing);
CommandWriter = new CommandWriter(this, _pool, Opts, Counter, EnqueuePing);
InboxPrefix = NewInbox(opts.InboxPrefix);
SubscriptionManager = new SubscriptionManager(this, InboxPrefix);
_logger = opts.LoggerFactory.CreateLogger<NatsConnection>();
Expand Down Expand Up @@ -431,7 +431,7 @@ private async ValueTask SetupReaderWriterAsync(bool reconnect)
// Authentication
_userCredentials?.Authenticate(_clientOpts, WritableServerInfo);

await using (var priorityCommandWriter = new PriorityCommandWriter(_pool, _socket!, Opts, Counter, EnqueuePing))
await using (var priorityCommandWriter = new PriorityCommandWriter(this, _pool, _socket!, Opts, Counter, EnqueuePing))
{
// add CONNECT and PING command to priority lane
await priorityCommandWriter.CommandWriter.ConnectAsync(_clientOpts, CancellationToken.None).ConfigureAwait(false);
Expand Down
74 changes: 74 additions & 0 deletions tests/NATS.Client.Core.Tests/ProtocolTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,80 @@ public async Task Protocol_parser_under_load(int size)
counts.Count.Should().BeGreaterOrEqualTo(3);
}

[Fact]
public async Task Proactively_reject_payloads_over_the_threshold_set_by_server()
{
await using var server = NatsServer.Start();
await using var nats = server.CreateClientConnection();

var cts = new CancellationTokenSource(TimeSpan.FromSeconds(10));

var sync = 0;
var count = 0;
var signal1 = new WaitSignal<NatsMsg<byte[]>>();
var signal2 = new WaitSignal<NatsMsg<byte[]>>();
var subTask = Task.Run(
async () =>
{
await foreach (var m in nats.SubscribeAsync<byte[]>("foo.*", cancellationToken: cts.Token))
{
if (m.Subject == "foo.sync")
{
Interlocked.Exchange(ref sync, 1);
continue;
}

Interlocked.Increment(ref count);

if (m.Subject == "foo.signal1")
{
signal1.Pulse(m);
}
else if (m.Subject == "foo.signal2")
{
signal2.Pulse(m);
}
else if (m.Subject == "foo.end")
{
break;
}
}
},
cancellationToken: cts.Token);

await Retry.Until(
reason: "subscription is active",
condition: () => Volatile.Read(ref sync) == 1,
action: async () => await nats.PublishAsync("foo.sync", cancellationToken: cts.Token),
retryDelay: TimeSpan.FromSeconds(.3));
{
var payload = new byte[nats.ServerInfo!.MaxPayload];
await nats.PublishAsync("foo.signal1", payload, cancellationToken: cts.Token);
var msg1 = await signal1;
Assert.Equal(payload.Length, msg1.Data!.Length);
}

{
var payload = new byte[nats.ServerInfo!.MaxPayload + 1];
var exception = await Assert.ThrowsAsync<NatsException>(async () =>
await nats.PublishAsync("foo.none", payload, cancellationToken: cts.Token));
Assert.Matches(@"Payload size \d+ exceeds server's maximum payload size \d+", exception.Message);
}

{
var payload = new byte[123];
await nats.PublishAsync("foo.signal2", payload, cancellationToken: cts.Token);
var msg1 = await signal2;
Assert.Equal(payload.Length, msg1.Data!.Length);
}

await nats.PublishAsync("foo.end", cancellationToken: cts.Token);

await subTask;

Assert.Equal(3, Volatile.Read(ref count));
}

private sealed class NatsSubReconnectTest : NatsSubBase
{
private readonly Action<int> _callback;
Expand Down
2 changes: 1 addition & 1 deletion tests/NATS.Client.TestUtilities/MockServer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ public MockServer(
var stream = client.GetStream();

var sw = new StreamWriter(stream, Encoding.ASCII);
await sw.WriteAsync("INFO {}\r\n");
await sw.WriteAsync("INFO {\"max_payload\":1048576}\r\n");
await sw.FlushAsync();

var sr = new StreamReader(stream, Encoding.ASCII);
Expand Down

0 comments on commit 4ab46d2

Please sign in to comment.