Skip to content

Commit

Permalink
Throttle outgoing requests by both peer and protocol id
Browse files Browse the repository at this point in the history
  • Loading branch information
StefanBratanov committed Jan 8, 2025
1 parent b0b414d commit a05d7b7
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 59 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

import static tech.pegasys.teku.networking.p2p.libp2p.LibP2PNetwork.REMOTE_OPEN_STREAMS_RATE_LIMIT;
import static tech.pegasys.teku.networking.p2p.libp2p.LibP2PNetwork.REMOTE_PARALLEL_OPEN_STREAMS_COUNT_LIMIT;
import static tech.pegasys.teku.spec.constants.NetworkConstants.MAX_CONCURRENT_REQUESTS;

import com.google.common.base.Preconditions;
import identify.pb.IdentifyOuterClass;
Expand Down Expand Up @@ -153,9 +152,7 @@ public P2PNetwork<Peer> build() {
}

protected List<? extends RpcHandler<?, ?, ?>> createRpcHandlers() {
return rpcMethods.stream()
.map(m -> new RpcHandler<>(asyncRunner, m, MAX_CONCURRENT_REQUESTS))
.toList();
return rpcMethods.stream().map(m -> new RpcHandler<>(asyncRunner, m)).toList();
}

protected LibP2PGossipNetwork createGossipNetwork() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,19 @@
import io.libp2p.core.PeerId;
import io.libp2p.core.crypto.PubKey;
import io.libp2p.protocol.Identify;
import io.libp2p.protocol.IdentifyController;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.commons.lang3.EnumUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import tech.pegasys.teku.infrastructure.async.SafeFuture;
import tech.pegasys.teku.infrastructure.async.ThrottlingTaskQueue;
import tech.pegasys.teku.networking.p2p.libp2p.rpc.RpcHandler;
import tech.pegasys.teku.networking.p2p.network.PeerAddress;
import tech.pegasys.teku.networking.p2p.peer.DisconnectReason;
Expand All @@ -41,6 +44,7 @@
import tech.pegasys.teku.networking.p2p.rpc.RpcRequestHandler;
import tech.pegasys.teku.networking.p2p.rpc.RpcResponseHandler;
import tech.pegasys.teku.networking.p2p.rpc.RpcStreamController;
import tech.pegasys.teku.spec.constants.NetworkConstants;

public class LibP2PPeer implements Peer {
private static final Logger LOG = LogManager.getLogger();
Expand All @@ -56,6 +60,10 @@ public class LibP2PPeer implements Peer {
private volatile PeerClientType peerClientType = PeerClientType.UNKNOWN;
private volatile Optional<String> maybeAgentString = Optional.empty();

// used for throttling outgoing requests by protocol id
private final Map<List<String>, ThrottlingTaskQueue> outgoingRequestsQueuesByRpcMethod =
new ConcurrentHashMap<>();

private volatile Optional<DisconnectReason> disconnectReason = Optional.empty();
private volatile boolean disconnectLocallyInitiated = false;
private volatile DisconnectRequestHandler disconnectRequestHandler =
Expand Down Expand Up @@ -109,10 +117,6 @@ private PeerClientType getPeerTypeFromAgentString(final String agentVersion) {
return EnumUtils.getEnumIgnoreCase(PeerClientType.class, agent, PeerClientType.UNKNOWN);
}

public Optional<String> getMaybeAgentString() {
return maybeAgentString;
}

public PubKey getPubKey() {
return pubKey;
}
Expand Down Expand Up @@ -161,7 +165,7 @@ private SafeFuture<IdentifyOuterClass.Identify> getIdentify() {
.muxerSession()
.createStream(new Identify())
.getController()
.thenCompose(controller -> controller.id()))
.thenCompose(IdentifyController::id))
.exceptionallyCompose(
error -> {
LOG.debug("Failed to get peer identity", error);
Expand Down Expand Up @@ -208,14 +212,18 @@ SafeFuture<RpcStreamController<TOutgoingHandler>> sendRequest(
final TRequest request,
final RespHandler responseHandler) {
@SuppressWarnings("unchecked")
RpcHandler<TOutgoingHandler, TRequest, RespHandler> rpcHandler =
final RpcHandler<TOutgoingHandler, TRequest, RespHandler> rpcHandler =
(RpcHandler<TOutgoingHandler, TRequest, RespHandler>) rpcHandlers.get(rpcMethod);
if (rpcHandler == null) {
throw new IllegalArgumentException(
"Unknown rpc method invoked: " + String.join(",", rpcMethod.getIds()));
}

return rpcHandler.sendRequest(connection, request, responseHandler);
return outgoingRequestsQueuesByRpcMethod
.computeIfAbsent(
rpcMethod.getIds(),
__ -> ThrottlingTaskQueue.create(NetworkConstants.MAX_CONCURRENT_REQUESTS))
.queueTask(() -> rpcHandler.sendRequest(connection, request, responseHandler));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
import tech.pegasys.teku.infrastructure.async.AsyncRunner;
import tech.pegasys.teku.infrastructure.async.SafeFuture;
import tech.pegasys.teku.infrastructure.async.SafeFuture.Interruptor;
import tech.pegasys.teku.infrastructure.async.ThrottlingTaskQueue;
import tech.pegasys.teku.infrastructure.exceptions.ExceptionUtil;
import tech.pegasys.teku.networking.p2p.libp2p.LibP2PNodeId;
import tech.pegasys.teku.networking.p2p.libp2p.rpc.RpcHandler.Controller;
Expand All @@ -63,15 +62,12 @@ public class RpcHandler<

private final AsyncRunner asyncRunner;
private final RpcMethod<TOutgoingHandler, TRequest, TRespHandler> rpcMethod;
private final ThrottlingTaskQueue concurrentRequestsQueue;

public RpcHandler(
final AsyncRunner asyncRunner,
final RpcMethod<TOutgoingHandler, TRequest, TRespHandler> rpcMethod,
final int maxConcurrentRequests) {
final RpcMethod<TOutgoingHandler, TRequest, TRespHandler> rpcMethod) {
this.asyncRunner = asyncRunner;
this.rpcMethod = rpcMethod;
concurrentRequestsQueue = ThrottlingTaskQueue.create(maxConcurrentRequests);
}

public RpcMethod<TOutgoingHandler, TRequest, TRespHandler> getRpcMethod() {
Expand All @@ -80,13 +76,6 @@ public RpcMethod<TOutgoingHandler, TRequest, TRespHandler> getRpcMethod() {

public SafeFuture<RpcStreamController<TOutgoingHandler>> sendRequest(
final Connection connection, final TRequest request, final TRespHandler responseHandler) {
return concurrentRequestsQueue.queueTask(
() -> sendRequestInternal(connection, request, responseHandler));
}

public SafeFuture<RpcStreamController<TOutgoingHandler>> sendRequestInternal(
final Connection connection, final TRequest request, final TRespHandler responseHandler) {

final Bytes initialPayload;
try {
initialPayload = rpcMethod.encodeRequest(request);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
import io.libp2p.core.mux.StreamMuxer.Session;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.IntStream;
import kotlin.Unit;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
Expand All @@ -53,9 +52,8 @@ public class RpcHandlerTest {

StubAsyncRunner asyncRunner = new StubAsyncRunner();
RpcMethod<RpcRequestHandler, Object, RpcResponseHandler<?>> rpcMethod = mock(RpcMethod.class);
int maxConcurrentRequests = 2;
RpcHandler<RpcRequestHandler, Object, RpcResponseHandler<?>> rpcHandler =
new RpcHandler<>(asyncRunner, rpcMethod, maxConcurrentRequests);
new RpcHandler<>(asyncRunner, rpcMethod);

Connection connection = mock(Connection.class);
Session session = mock(Session.class);
Expand Down Expand Up @@ -249,39 +247,6 @@ void sendRequest_interruptBeforeInitialPayloadWritten(
verify(stream).close();
}

@Test
@SuppressWarnings("FutureReturnValueIgnored")
void requestIsThrottledIfQueueIsFull() {
// fill the queue
IntStream.range(0, maxConcurrentRequests)
.forEach(__ -> rpcHandler.sendRequest(connection, request, responseHandler));

final StreamPromise<Controller<RpcRequestHandler>> streamPromise1 =
new StreamPromise<>(new CompletableFuture<>(), new CompletableFuture<>());
when(session.createStream((ProtocolBinding<Controller<RpcRequestHandler>>) any()))
.thenReturn(streamPromise1);
final Stream stream1 = mock(Stream.class);
streamPromise1.getStream().complete(stream1);
streamPromise1.getController().complete(controller);
final CompletableFuture<String> protocolIdFuture1 = new CompletableFuture<>();
when(stream1.getProtocol()).thenReturn(protocolIdFuture1);
protocolIdFuture1.complete("test");

final SafeFuture<RpcStreamController<RpcRequestHandler>> throttledResult =
rpcHandler.sendRequest(connection, request, responseHandler);

assertThat(throttledResult).isNotDone();

// empty the queue
streamPromise.getStream().complete(stream);
streamPromise.getController().complete(controller);
stream.getProtocol().complete("test");
writeFuture.complete(null);

// throttled request should have completed now
assertThat(throttledResult).isCompleted();
}

@SuppressWarnings("UnnecessaryAsync")
private Class<? extends Exception> executeInterrupts(
final boolean closeStream, final boolean exceedTimeout) {
Expand Down

0 comments on commit a05d7b7

Please sign in to comment.