Skip to content

Commit

Permalink
Unify executor and futures logic
Browse files Browse the repository at this point in the history
  • Loading branch information
Neil Booth committed Jan 24, 2017
1 parent 9b5cb10 commit cb01609
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 89 deletions.
18 changes: 8 additions & 10 deletions server/block_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,9 +138,10 @@ class BlockProcessor(server.db.DB):
Coordinate backing up in case of chain reorganisations.
'''

def __init__(self, env, daemon):
def __init__(self, env, controller, daemon):
super().__init__(env)
self.daemon = daemon
self.controller = controller

# These are our state as we move ahead of DB state
self.fs_height = self.db_height
Expand Down Expand Up @@ -190,6 +191,7 @@ def on_prefetcher_first_caught_up(self):

async def main_loop(self):
'''Main loop for block processing.'''
self.controller.ensure_future(self.prefetcher.main_loop())
await self.prefetcher.reset_height()

while True:
Expand All @@ -205,16 +207,11 @@ def shutdown(self, executor):
self.logger.info('flushing state to DB for a clean shutdown...')
self.flush(True)

async def executor(self, func, *args, **kwargs):
'''Run func taking args in the executor.'''
loop = asyncio.get_event_loop()
await loop.run_in_executor(None, partial(func, *args, **kwargs))

async def first_caught_up(self):
'''Called when first caught up to daemon after starting.'''
# Flush everything with updated first_sync->False state.
self.first_sync = False
await self.executor(self.flush, True)
await self.controller.run_in_executor(self.flush, True)
if self.utxo_db.for_sync:
self.logger.info('{} synced to height {:,d}'
.format(VERSION, self.height))
Expand All @@ -240,7 +237,8 @@ async def check_and_advance_blocks(self, blocks, first):

if hprevs == chain:
start = time.time()
await self.executor(self.advance_blocks, blocks, headers)
await self.controller.run_in_executor(self.advance_blocks,
blocks, headers)
if not self.first_sync:
s = '' if len(blocks) == 1 else 's'
self.logger.info('processed {:,d} block{} in {:.1f}s'
Expand Down Expand Up @@ -277,14 +275,14 @@ async def reorg_chain(self, count=None):
self.logger.info('chain reorg detected')
else:
self.logger.info('faking a reorg of {:,d} blocks'.format(count))
await self.executor(self.flush, True)
await self.controller.run_in_executor(self.flush, True)

hashes = await self.reorg_hashes(count)
# Reverse and convert to hex strings.
hashes = [hash_to_str(hash) for hash in reversed(hashes)]
for hex_hashes in chunks(hashes, 50):
blocks = await self.daemon.raw_blocks(hex_hashes)
await self.executor(self.backup_blocks, blocks)
await self.controller.run_in_executor(self.backup_blocks, blocks)
await self.prefetcher.reset_height()

async def reorg_hashes(self, count):
Expand Down
99 changes: 58 additions & 41 deletions server/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import os
import ssl
import time
import traceback
from bisect import bisect_left
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
Expand Down Expand Up @@ -49,9 +50,9 @@ def __init__(self, env):
self.start_time = time.time()
self.coin = env.coin
self.daemon = Daemon(env.coin.daemon_urls(env.daemon_url))
self.bp = BlockProcessor(env, self.daemon)
self.mempool = MemPool(self.bp)
self.peers = PeerManager(env)
self.bp = BlockProcessor(env, self, self.daemon)
self.mempool = MemPool(self.bp, self)
self.peers = PeerManager(env, self)
self.env = env
self.servers = {}
# Map of session to the key of its list in self.groups
Expand All @@ -63,6 +64,7 @@ def __init__(self, env):
self.max_sessions = env.max_sessions
self.low_watermark = self.max_sessions * 19 // 20
self.max_subs = env.max_subs
self.futures = set()
# Cache some idea of room to avoid recounting on each subscription
self.subs_room = 0
self.next_stale_check = 0
Expand Down Expand Up @@ -199,43 +201,59 @@ async def serve_requests(self):
if session.items:
self.enqueue_session(session)

def initiate_shutdown(self):
'''Call this function to start the shutdown process.'''
self.shutdown_event.set()
async def run_in_executor(self, func, *args):
'''Wait whilst running func in the executor.'''
return await self.loop.run_in_executor(None, func, *args)

def schedule_executor(self, func, *args):
'''Schedule running func in the executor, return a task.'''
return self.ensure_future(self.run_in_executor(func, *args))

def ensure_future(self, coro):
'''Schedule the coro to be run.'''
future = asyncio.ensure_future(coro)
future.add_done_callback(self.on_future_done)
self.futures.add(future)
return future

def on_future_done(self, future):
'''Collect the result of a future after removing it from our set.'''
self.futures.remove(future)
try:
future.result()
except asyncio.CancelledError:
pass
except Exception:
self.log_error(traceback.format_exc())

async def wait_for_bp_catchup(self):
'''Called when the block processor catches up.'''
await self.bp.caught_up_event.wait()
self.logger.info('block processor has caught up')
self.ensure_future(self.peers.main_loop())
self.ensure_future(self.start_servers())
self.ensure_future(self.mempool.main_loop())
self.ensure_future(self.enqueue_delayed_sessions())
self.ensure_future(self.notify())
for n in range(4):
self.ensure_future(self.serve_requests())

async def main_loop(self):
'''Controller main loop.'''
def add_future(coro):
futures.append(asyncio.ensure_future(coro))

async def await_bp_catchup():
'''Wait for the block processor to catch up.
Then start the servers and the peer manager.
'''
await self.bp.caught_up_event.wait()
self.logger.info('block processor has caught up')
add_future(self.peers.main_loop())
add_future(self.start_servers())
add_future(self.mempool.main_loop())
add_future(self.enqueue_delayed_sessions())
add_future(self.notify())
for n in range(4):
add_future(self.serve_requests())

futures = []
add_future(self.bp.main_loop())
add_future(self.bp.prefetcher.main_loop())
add_future(await_bp_catchup())

# Perform a clean shutdown when this event is signalled.
await self.shutdown_event.wait()
self.ensure_future(self.bp.main_loop())
self.ensure_future(self.wait_for_bp_catchup())

# Shut down cleanly after waiting for shutdown to be signalled
await self.shutdown_event.wait()
self.logger.info('shutting down')
await self.shutdown(futures)
await self.shutdown()
self.logger.info('shutdown complete')

async def shutdown(self, futures):
def initiate_shutdown(self):
'''Call this function to start the shutdown process.'''
self.shutdown_event.set()

async def shutdown(self):
'''Perform the shutdown sequence.'''
self.state = self.SHUTTING_DOWN

Expand All @@ -244,13 +262,13 @@ async def shutdown(self, futures):
for session in self.sessions:
self.close_session(session)

# Cancel the futures
for future in futures:
# Cancel pending futures
for future in self.futures:
future.cancel()

# Wait for all futures to finish
while any(not future.done() for future in futures):
await asyncio.sleep(1)
while not all (future.done() for future in self.futures):
await asyncio.sleep(0.1)

# Finally shut down the block processor and executor
self.bp.shutdown(self.executor)
Expand Down Expand Up @@ -694,8 +712,7 @@ def job():
limit = self.env.max_send // 97
return list(self.bp.get_history(hashX, limit=limit))

loop = asyncio.get_event_loop()
history = await loop.run_in_executor(None, job)
history = await self.run_in_executor(job)
self.history_cache[hashX] = history
return history

Expand Down Expand Up @@ -725,8 +742,8 @@ async def get_utxos(self, hashX):
'''Get UTXOs asynchronously to reduce latency.'''
def job():
return list(self.bp.get_utxos(hashX, limit=None))
loop = asyncio.get_event_loop()
return await loop.run_in_executor(None, job)

return await self.run_in_executor(job)

def get_chunk(self, index):
'''Return header chunk as hex. Index is a non-negative integer.'''
Expand Down
9 changes: 4 additions & 5 deletions server/mempool.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,10 @@ class MemPool(util.LoggedClass):
A pair is a (hashX, value) tuple. tx hashes are hex strings.
'''

def __init__(self, bp):
def __init__(self, bp, controller):
super().__init__()
self.daemon = bp.daemon
self.controller = controller
self.coin = bp.coin
self.db = bp
self.touched = bp.touched
Expand Down Expand Up @@ -139,7 +140,6 @@ async def main_loop(self):
break

def async_process_some(self, unfetched, limit):
loop = asyncio.get_event_loop()
pending = []
txs = self.txs

Expand All @@ -162,9 +162,8 @@ async def process(unprocessed):
deferred = pending
pending = []

def job():
return self.process_raw_txs(raw_txs, deferred)
result, deferred = await loop.run_in_executor(None, job)
result, deferred = await self.controller.run_in_executor \
(self.process_raw_txs, raw_txs, deferred)

pending.extend(deferred)
hashXs = self.hashXs
Expand Down
37 changes: 5 additions & 32 deletions server/peers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,8 @@

'''Peer management.'''

import asyncio
import socket
import traceback
from collections import namedtuple
from functools import partial

import lib.util as util
from server.irc import IRC
Expand All @@ -30,12 +27,11 @@ class PeerManager(util.LoggedClass):
VERSION = '1.0'
DEFAULT_PORTS = {'t': 50001, 's': 50002}

def __init__(self, env):
def __init__(self, env, controller):
super().__init__()
self.env = env
self.loop = asyncio.get_event_loop()
self.controller = controller
self.irc = IRC(env, self)
self.futures = set()
self.identities = []
# Keyed by nick
self.irc_peers = {}
Expand All @@ -51,10 +47,6 @@ def __init__(self, env):
env.report_ssl_port_tor,
'_tor'))

async def executor(self, func, *args, **kwargs):
'''Run func taking args in the executor.'''
await self.loop.run_in_executor(None, partial(func, *args, **kwargs))

@classmethod
def real_name(cls, identity):
'''Real name as used on IRC.'''
Expand All @@ -70,38 +62,19 @@ def port_text(letter, port):
ssl = port_text('s', identity.ssl_port)
return '{} v{}{}{}'.format(identity.host, cls.VERSION, tcp, ssl)

def ensure_future(self, coro):
'''Convert a coro into a future and add it to our pending list
to be waited for.'''
self.futures.add(asyncio.ensure_future(coro))

def start_irc(self):
'''Start up the IRC connections if enabled.'''
if self.env.irc:
name_pairs = [(self.real_name(identity), identity.nick_suffix)
for identity in self.identities]
self.ensure_future(self.irc.start(name_pairs))
self.controller.ensure_future(self.irc.start(name_pairs))
else:
self.logger.info('IRC is disabled')

async def main_loop(self):
'''Start and then enter the main loop.'''
'''Main loop. No loop for now.'''
self.start_irc()

try:
while True:
await asyncio.sleep(10)
done = [future for future in self.futures if future.done()]
self.futures.difference_update(done)
for future in done:
try:
future.result()
except:
self.log_error(traceback.format_exc())
finally:
for future in self.futures:
future.cancel()

def dns_lookup_peer(self, nick, hostname, details):
try:
ip_addr = None
Expand All @@ -119,7 +92,7 @@ def dns_lookup_peer(self, nick, hostname, details):

def add_irc_peer(self, *args):
'''Schedule DNS lookup of peer.'''
self.ensure_future(self.executor(self.dns_lookup_peer, *args))
self.controller.schedule_executor(self.dns_lookup_peer, *args)

def remove_irc_peer(self, nick):
'''Remove a peer from our IRC peers map.'''
Expand Down
1 change: 0 additions & 1 deletion server/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
'''Classes for local RPC server and remote client TCP/SSL servers.'''


import asyncio
import time
import traceback
from functools import partial
Expand Down

0 comments on commit cb01609

Please sign in to comment.