-
Notifications
You must be signed in to change notification settings - Fork 45
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #215 from drift-labs/sina/add-grpc
Add grpc
- Loading branch information
Showing
27 changed files
with
2,638 additions
and
573 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
import asyncio | ||
import os | ||
|
||
from anchorpy.provider import Provider, Wallet | ||
from dotenv import load_dotenv | ||
from solana.rpc.async_api import AsyncClient | ||
from solana.rpc.commitment import Commitment | ||
from solders.keypair import Keypair | ||
|
||
from driftpy.drift_client import AccountSubscriptionConfig, DriftClient | ||
from driftpy.types import GrpcConfig | ||
|
||
load_dotenv() | ||
|
||
RED = "\033[91m" | ||
GREEN = "\033[92m" | ||
RESET = "\033[0m" | ||
|
||
CLEAR_SCREEN = "\033c" | ||
|
||
|
||
async def watch_drift_markets(): | ||
rpc_fqdn = os.environ.get("RPC_FQDN") | ||
x_token = os.environ.get("X_TOKEN") | ||
private_key = os.environ.get("PRIVATE_KEY") | ||
rpc_url = os.environ.get("RPC_TRITON") | ||
|
||
if not (rpc_fqdn and x_token and private_key and rpc_url): | ||
raise ValueError("RPC_FQDN, X_TOKEN, PRIVATE_KEY, and RPC_TRITON must be set") | ||
|
||
wallet = Wallet(Keypair.from_base58_string(private_key)) | ||
connection = AsyncClient(rpc_url) | ||
provider = Provider(connection, wallet) | ||
|
||
drift_client = DriftClient( | ||
provider.connection, | ||
provider.wallet, | ||
"mainnet", | ||
account_subscription=AccountSubscriptionConfig( | ||
"grpc", | ||
grpc_config=GrpcConfig( | ||
endpoint=rpc_fqdn, | ||
token=x_token, | ||
commitment=Commitment("confirmed"), | ||
), | ||
), | ||
) | ||
|
||
await drift_client.subscribe() | ||
print("Subscribed via gRPC. Listening for market updates...") | ||
|
||
previous_prices = {} | ||
|
||
while True: | ||
print(CLEAR_SCREEN, end="") | ||
|
||
perp_markets = drift_client.get_perp_market_accounts() | ||
|
||
if not perp_markets: | ||
print(f"{RED}No perp markets found (yet){RESET}") | ||
else: | ||
print("Drift Perp Markets (gRPC subscription)\n") | ||
perp_markets.sort(key=lambda x: x.market_index) | ||
for market in perp_markets[:20]: | ||
market_index = market.market_index | ||
last_price = market.amm.historical_oracle_data.last_oracle_price / 1e6 | ||
|
||
if market_index in previous_prices: | ||
old_price = previous_prices[market_index] | ||
if last_price > old_price: | ||
color = GREEN | ||
elif last_price < old_price: | ||
color = RED | ||
else: | ||
color = RESET | ||
else: | ||
color = RESET | ||
|
||
print( | ||
f"Market Index: {market_index} | " | ||
f"Price: {color}${last_price:.4f}{RESET}" | ||
) | ||
|
||
previous_prices[market_index] = last_price | ||
|
||
await asyncio.sleep(1) | ||
|
||
|
||
if __name__ == "__main__": | ||
asyncio.run(watch_drift_markets()) |
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,155 @@ | ||
import asyncio | ||
import time | ||
from typing import Callable, Optional, TypeVar | ||
|
||
import grpc.aio | ||
from anchorpy.program.core import Program | ||
from solana.rpc.commitment import Commitment | ||
from solders.pubkey import Pubkey | ||
|
||
from driftpy.accounts.grpc.geyser_codegen import geyser_pb2, geyser_pb2_grpc | ||
from driftpy.accounts.types import DataAndSlot | ||
from driftpy.accounts.ws.account_subscriber import WebsocketAccountSubscriber | ||
from driftpy.types import GrpcConfig | ||
|
||
T = TypeVar("T") | ||
|
||
|
||
class TritonAuthMetadataPlugin(grpc.AuthMetadataPlugin): | ||
def __init__(self, x_token: str): | ||
self.x_token = x_token | ||
|
||
def __call__( | ||
self, | ||
context: grpc.AuthMetadataContext, | ||
callback: grpc.AuthMetadataPluginCallback, | ||
): | ||
metadata = (("x-token", self.x_token),) | ||
callback(metadata, None) | ||
|
||
|
||
class GrpcAccountSubscriber(WebsocketAccountSubscriber[T]): | ||
def __init__( | ||
self, | ||
grpc_config: GrpcConfig, | ||
account_name: str, | ||
program: Program, | ||
account_public_key: Pubkey, | ||
commitment: Commitment = Commitment("confirmed"), | ||
decode: Optional[Callable[[bytes], T]] = None, | ||
initial_data: Optional[DataAndSlot[T]] = None, | ||
): | ||
super().__init__(account_public_key, program, commitment, decode, initial_data) | ||
self.client = self._create_grpc_client(grpc_config) | ||
self.stream = None | ||
self.listener_id = None | ||
self.account_name = account_name | ||
self.decode = ( | ||
decode if decode is not None else self.program.coder.accounts.decode | ||
) | ||
|
||
def _create_grpc_client(self, config: GrpcConfig) -> geyser_pb2_grpc.GeyserStub: | ||
auth = TritonAuthMetadataPlugin(config.token) | ||
ssl_creds = grpc.ssl_channel_credentials() | ||
call_creds = grpc.metadata_call_credentials(auth) | ||
combined_creds = grpc.composite_channel_credentials(ssl_creds, call_creds) | ||
|
||
channel = grpc.aio.secure_channel(config.endpoint, credentials=combined_creds) | ||
return geyser_pb2_grpc.GeyserStub(channel) | ||
|
||
async def subscribe(self) -> Optional[asyncio.Task[None]]: | ||
if self.listener_id is not None: | ||
return | ||
|
||
self.task = asyncio.create_task(self._subscribe_grpc()) | ||
return self.task | ||
|
||
async def _subscribe_grpc(self): | ||
"""Internal method to handle the gRPC subscription""" | ||
if self.data_and_slot is None: | ||
await self.fetch() | ||
|
||
try: | ||
request_iterator = self._create_subscribe_request() | ||
self.stream = self.client.Subscribe(request_iterator) | ||
await self.stream.wait_for_connection() | ||
|
||
self.listener_id = 1 | ||
|
||
async for update in self.stream: | ||
try: | ||
if update.HasField("ping") or update.HasField("pong"): | ||
continue | ||
|
||
if not update.HasField("account"): | ||
print(f"No account for {self.account_name}") | ||
continue | ||
|
||
slot = int(update.account.slot) | ||
account_info = { | ||
"owner": Pubkey.from_bytes(update.account.account.owner), | ||
"lamports": int(update.account.account.lamports), | ||
"data": bytes(update.account.account.data), | ||
"executable": update.account.account.executable, | ||
"rent_epoch": int(update.account.account.rent_epoch), | ||
} | ||
|
||
if not account_info["data"]: | ||
print(f"No data for {self.account_name}") | ||
continue | ||
|
||
decoded_data = ( | ||
self.decode(account_info["data"]) | ||
if self.decode | ||
else account_info | ||
) | ||
self.update_data(DataAndSlot(slot, decoded_data)) | ||
|
||
except Exception as e: | ||
print(f"Error processing account data for {self.account_name}: {e}") | ||
break | ||
|
||
except Exception as e: | ||
print(f"Error in gRPC subscription for {self.account_name}: {e}") | ||
if self.stream: | ||
self.stream.cancel() | ||
self.listener_id = None | ||
raise e | ||
|
||
async def _create_subscribe_request(self): | ||
request = geyser_pb2.SubscribeRequest() | ||
account_filter = geyser_pb2.SubscribeRequestFilterAccounts() | ||
account_filter.account.append(str(self.pubkey)) | ||
account_filter.nonempty_txn_signature = True | ||
request.accounts["account_monitor"].CopyFrom(account_filter) | ||
|
||
request.commitment = geyser_pb2.CommitmentLevel.CONFIRMED | ||
if self.commitment == Commitment("finalized"): | ||
request.commitment = geyser_pb2.CommitmentLevel.FINALIZED | ||
if self.commitment == Commitment("processed"): | ||
request.commitment = geyser_pb2.CommitmentLevel.PROCESSED | ||
|
||
yield request | ||
|
||
while True: | ||
await asyncio.sleep(30) | ||
ping_request = geyser_pb2.SubscribeRequest() | ||
ping_request.ping.id = int(time.time()) | ||
yield ping_request | ||
|
||
async def unsubscribe(self) -> None: | ||
if self.listener_id is not None: | ||
try: | ||
if self.stream: | ||
self.stream.cancel() | ||
self.listener_id = None | ||
except Exception as e: | ||
print(f"Error unsubscribing from account {self.account_name}: {e}") | ||
raise e | ||
|
||
def update_data(self, new_data: Optional[DataAndSlot[T]]): | ||
if new_data is None: | ||
return | ||
|
||
if self.data_and_slot is None or new_data.slot >= self.data_and_slot.slot: | ||
self.data_and_slot = new_data |
Oops, something went wrong.