Skip to content

Commit

Permalink
Merge pull request #215 from drift-labs/sina/add-grpc
Browse files Browse the repository at this point in the history
Add grpc
  • Loading branch information
SinaKhalili authored Jan 2, 2025
2 parents 6145be6 + a5d3631 commit 3e0ccb9
Show file tree
Hide file tree
Showing 27 changed files with 2,638 additions and 573 deletions.
90 changes: 90 additions & 0 deletions examples/grpc_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import asyncio
import os

from anchorpy.provider import Provider, Wallet
from dotenv import load_dotenv
from solana.rpc.async_api import AsyncClient
from solana.rpc.commitment import Commitment
from solders.keypair import Keypair

from driftpy.drift_client import AccountSubscriptionConfig, DriftClient
from driftpy.types import GrpcConfig

load_dotenv()

RED = "\033[91m"
GREEN = "\033[92m"
RESET = "\033[0m"

CLEAR_SCREEN = "\033c"


async def watch_drift_markets():
rpc_fqdn = os.environ.get("RPC_FQDN")
x_token = os.environ.get("X_TOKEN")
private_key = os.environ.get("PRIVATE_KEY")
rpc_url = os.environ.get("RPC_TRITON")

if not (rpc_fqdn and x_token and private_key and rpc_url):
raise ValueError("RPC_FQDN, X_TOKEN, PRIVATE_KEY, and RPC_TRITON must be set")

wallet = Wallet(Keypair.from_base58_string(private_key))
connection = AsyncClient(rpc_url)
provider = Provider(connection, wallet)

drift_client = DriftClient(
provider.connection,
provider.wallet,
"mainnet",
account_subscription=AccountSubscriptionConfig(
"grpc",
grpc_config=GrpcConfig(
endpoint=rpc_fqdn,
token=x_token,
commitment=Commitment("confirmed"),
),
),
)

await drift_client.subscribe()
print("Subscribed via gRPC. Listening for market updates...")

previous_prices = {}

while True:
print(CLEAR_SCREEN, end="")

perp_markets = drift_client.get_perp_market_accounts()

if not perp_markets:
print(f"{RED}No perp markets found (yet){RESET}")
else:
print("Drift Perp Markets (gRPC subscription)\n")
perp_markets.sort(key=lambda x: x.market_index)
for market in perp_markets[:20]:
market_index = market.market_index
last_price = market.amm.historical_oracle_data.last_oracle_price / 1e6

if market_index in previous_prices:
old_price = previous_prices[market_index]
if last_price > old_price:
color = GREEN
elif last_price < old_price:
color = RED
else:
color = RESET
else:
color = RESET

print(
f"Market Index: {market_index} | "
f"Price: {color}${last_price:.4f}{RESET}"
)

previous_prices[market_index] = last_price

await asyncio.sleep(1)


if __name__ == "__main__":
asyncio.run(watch_drift_markets())
781 changes: 408 additions & 373 deletions poetry.lock

Large diffs are not rendered by default.

9 changes: 7 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,9 @@ mypy = "^1.7.0"
deprecated = "^1.2.14"
events = "^0.5"
numpy = "^1.26.2"
jito-searcher-client = "0.1.4"
# jito-searcher-client = "0.1.5"
grpcio = "1.68.1"
protobuf = "5.29.2"

[tool.poetry.dev-dependencies]
pytest = "^7.2.0"
Expand All @@ -97,6 +99,9 @@ mkdocs-material = "^8.1.8"
bump2version = "^1.0.1"
autopep8 = "^2.0.4"
mypy = "^1.7.0"
python-dotenv = "^1.0.0"
ruff = "^0.8.4"


[tool.poetry.group.dev.dependencies]
pytest = "^7.4.4"
Expand All @@ -111,7 +116,7 @@ build-backend = "poetry.core.masonry.api"
asyncio_mode = "strict"

[tool.ruff]
exclude = [".git", "__pycache__", "docs/source/conf.py", "old", "build", "dist"]
exclude = [".git", "__pycache__", "docs/source/conf.py", "old", "build", "dist", "**/geyser_codegen/**"]

[tool.ruff.pycodestyle]
max-line-length = 88
Expand Down
31 changes: 30 additions & 1 deletion src/driftpy/account_subscription_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
DemoDriftClientAccountSubscriber,
DemoUserAccountSubscriber,
)
from driftpy.accounts.grpc.account_subscriber import GrpcConfig
from driftpy.accounts.grpc.drift_client import GrpcDriftClientAccountSubscriber
from driftpy.accounts.grpc.user import GrpcUserAccountSubscriber
from driftpy.accounts.polling import (
PollingDriftClientAccountSubscriber,
PollingUserAccountSubscriber,
Expand All @@ -32,13 +35,17 @@ def default():

def __init__(
self,
account_subscription_type: Literal["polling", "websocket", "cached", "demo"],
account_subscription_type: Literal[
"polling", "websocket", "cached", "demo", "grpc"
],
bulk_account_loader: Optional[BulkAccountLoader] = None,
commitment: Commitment = Commitment("confirmed"),
grpc_config: Optional[GrpcConfig] = None,
):
self.type = account_subscription_type
self.commitment = commitment
self.bulk_account_loader = None
self.grpc_config = grpc_config

if self.type != "polling":
return
Expand Down Expand Up @@ -117,6 +124,18 @@ def get_drift_client_subscriber(
oracle_infos,
self.commitment,
)
case "grpc":
if self.grpc_config is None:
raise ValueError("A grpc config is required for grpc subscription")
return GrpcDriftClientAccountSubscriber(
program,
self.grpc_config,
perp_market_indexes,
spot_market_indexes,
cast(list[FullOracleWrapper], oracle_infos),
should_find_all_markets_and_oracles,
self.commitment,
)

def get_user_client_subscriber(self, program: Program, user_pubkey: Pubkey):
match self.type:
Expand All @@ -138,3 +157,13 @@ def get_user_client_subscriber(self, program: Program, user_pubkey: Pubkey):
)
case "demo":
return DemoUserAccountSubscriber(user_pubkey, program, self.commitment)
case "grpc":
if self.grpc_config is None:
raise ValueError("A grpc config is required for grpc subscription")
return GrpcUserAccountSubscriber(
grpc_config=self.grpc_config,
account_name="user",
account_public_key=user_pubkey,
program=program,
commitment=self.commitment,
)
155 changes: 155 additions & 0 deletions src/driftpy/accounts/grpc/account_subscriber.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
import asyncio
import time
from typing import Callable, Optional, TypeVar

import grpc.aio
from anchorpy.program.core import Program
from solana.rpc.commitment import Commitment
from solders.pubkey import Pubkey

from driftpy.accounts.grpc.geyser_codegen import geyser_pb2, geyser_pb2_grpc
from driftpy.accounts.types import DataAndSlot
from driftpy.accounts.ws.account_subscriber import WebsocketAccountSubscriber
from driftpy.types import GrpcConfig

T = TypeVar("T")


class TritonAuthMetadataPlugin(grpc.AuthMetadataPlugin):
def __init__(self, x_token: str):
self.x_token = x_token

def __call__(
self,
context: grpc.AuthMetadataContext,
callback: grpc.AuthMetadataPluginCallback,
):
metadata = (("x-token", self.x_token),)
callback(metadata, None)


class GrpcAccountSubscriber(WebsocketAccountSubscriber[T]):
def __init__(
self,
grpc_config: GrpcConfig,
account_name: str,
program: Program,
account_public_key: Pubkey,
commitment: Commitment = Commitment("confirmed"),
decode: Optional[Callable[[bytes], T]] = None,
initial_data: Optional[DataAndSlot[T]] = None,
):
super().__init__(account_public_key, program, commitment, decode, initial_data)
self.client = self._create_grpc_client(grpc_config)
self.stream = None
self.listener_id = None
self.account_name = account_name
self.decode = (
decode if decode is not None else self.program.coder.accounts.decode
)

def _create_grpc_client(self, config: GrpcConfig) -> geyser_pb2_grpc.GeyserStub:
auth = TritonAuthMetadataPlugin(config.token)
ssl_creds = grpc.ssl_channel_credentials()
call_creds = grpc.metadata_call_credentials(auth)
combined_creds = grpc.composite_channel_credentials(ssl_creds, call_creds)

channel = grpc.aio.secure_channel(config.endpoint, credentials=combined_creds)
return geyser_pb2_grpc.GeyserStub(channel)

async def subscribe(self) -> Optional[asyncio.Task[None]]:
if self.listener_id is not None:
return

self.task = asyncio.create_task(self._subscribe_grpc())
return self.task

async def _subscribe_grpc(self):
"""Internal method to handle the gRPC subscription"""
if self.data_and_slot is None:
await self.fetch()

try:
request_iterator = self._create_subscribe_request()
self.stream = self.client.Subscribe(request_iterator)
await self.stream.wait_for_connection()

self.listener_id = 1

async for update in self.stream:
try:
if update.HasField("ping") or update.HasField("pong"):
continue

if not update.HasField("account"):
print(f"No account for {self.account_name}")
continue

slot = int(update.account.slot)
account_info = {
"owner": Pubkey.from_bytes(update.account.account.owner),
"lamports": int(update.account.account.lamports),
"data": bytes(update.account.account.data),
"executable": update.account.account.executable,
"rent_epoch": int(update.account.account.rent_epoch),
}

if not account_info["data"]:
print(f"No data for {self.account_name}")
continue

decoded_data = (
self.decode(account_info["data"])
if self.decode
else account_info
)
self.update_data(DataAndSlot(slot, decoded_data))

except Exception as e:
print(f"Error processing account data for {self.account_name}: {e}")
break

except Exception as e:
print(f"Error in gRPC subscription for {self.account_name}: {e}")
if self.stream:
self.stream.cancel()
self.listener_id = None
raise e

async def _create_subscribe_request(self):
request = geyser_pb2.SubscribeRequest()
account_filter = geyser_pb2.SubscribeRequestFilterAccounts()
account_filter.account.append(str(self.pubkey))
account_filter.nonempty_txn_signature = True
request.accounts["account_monitor"].CopyFrom(account_filter)

request.commitment = geyser_pb2.CommitmentLevel.CONFIRMED
if self.commitment == Commitment("finalized"):
request.commitment = geyser_pb2.CommitmentLevel.FINALIZED
if self.commitment == Commitment("processed"):
request.commitment = geyser_pb2.CommitmentLevel.PROCESSED

yield request

while True:
await asyncio.sleep(30)
ping_request = geyser_pb2.SubscribeRequest()
ping_request.ping.id = int(time.time())
yield ping_request

async def unsubscribe(self) -> None:
if self.listener_id is not None:
try:
if self.stream:
self.stream.cancel()
self.listener_id = None
except Exception as e:
print(f"Error unsubscribing from account {self.account_name}: {e}")
raise e

def update_data(self, new_data: Optional[DataAndSlot[T]]):
if new_data is None:
return

if self.data_and_slot is None or new_data.slot >= self.data_and_slot.slot:
self.data_and_slot = new_data
Loading

0 comments on commit 3e0ccb9

Please sign in to comment.