Skip to content

Commit

Permalink
improve ghost inputs by not allowing to queue multiple events server …
Browse files Browse the repository at this point in the history
…side
  • Loading branch information
misko committed Oct 14, 2024
1 parent ed4765f commit d2c5964
Showing 1 changed file with 21 additions and 13 deletions.
34 changes: 21 additions & 13 deletions hudes/websocket_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ class Client:
request_idx: int = 0
sent_batch: int = -1
request_full_val: bool = False
active_inference: bool = False
active_inference: int = 0
active_request_idx: int = -1
mesh_grid_size: int = -1
mesh_grids: int = 0
Expand Down Expand Up @@ -198,7 +198,7 @@ async def inference_runner_clients(mad, client_runner_q, inference_q, stop):
client = active_clients[client_id]

# client still waiting for response just skip
if client.active_inference:
if client.active_inference > 0:
continue

client.active_request_idx = client.request_idx
Expand All @@ -222,33 +222,39 @@ async def inference_runner_clients(mad, client_runner_q, inference_q, stop):
client.sent_batch = client.batch_idx
client.force_update = True

if client.sgd > 0 and not client.active_inference:
client.active_inference = True
if client.active_inference == 0 and client.sgd > 0:
client.active_inference += 1
inference_q.put(("sgd", copy.copy(client)))
client.sgd = 0

if client.force_update or len(client.next_step) > 0:
if client.mesh_grids > 0: # if we have grids, step size changes mesh
client.force_update = True

if client.force_update or (
len(client.next_step) > 0 and client.active_inference == 0
):

client.current_step = client.next_step
client.next_step = {}

client.active_inference = True

if client.mesh_grids > 0:
logging.debug(
f"inference_runner_clients: req mesh {client.force_update}"
)
client.active_inference += 1
inference_q.put(("mesh", copy.copy(client)))
else:
logging.debug(
f"inference_runner_clients: req train {client.force_update}"
)
client.active_inference += 1
inference_q.put(("train", copy.copy(client)))
client.force_update = False

if client.request_full_val:
# send weight vector for inference
logging.debug(f"inference_runner_clients: req inference")
client.active_inference += 1
inference_q.put(
("val", copy.copy(client)),
)
Expand All @@ -259,7 +265,10 @@ async def inference_result_sender(results_q, stop):

logging.info(f"inference_result_sender: started")
while True:
msg = await asyncio.to_thread(results_q.get)
if not results_q.empty():
msg = results_q.get()
else:
msg = await asyncio.to_thread(results_q.get)

if stop is not None and stop.done():
logging.info(f"inference_result_sender: returning")
Expand All @@ -268,6 +277,8 @@ async def inference_result_sender(results_q, stop):
client_id, train_or_val, res = msg
client = active_clients[client_id]

client.active_inference -= 1 # allow next thing to run
# asyncio.sleep(0.00001)
try:
if train_or_val in ("train", "mesh", "sgd"):
# TODO need to be ok with getting errors here
Expand Down Expand Up @@ -317,7 +328,7 @@ async def inference_result_sender(results_q, stop):
websockets.exceptions.ConnectionClosedError,
) as e:
pass
client.active_inference = False
assert client.active_inference >= 0


async def wait_for_stop(inference_q, results_q, stop, client_runner_q):
Expand Down Expand Up @@ -370,7 +381,7 @@ async def process_client(websocket, client_runner_q):
batch_idx=0,
websocket=websocket,
request_idx=0,
active_inference=False,
active_inference=0,
sent_batch=-1,
)
active_clients[current_client] = client
Expand Down Expand Up @@ -428,9 +439,6 @@ async def process_client(websocket, client_runner_q):
client.total_sgd_steps += msg.sgd_steps
client.request_idx = msg.request_idx

if client.mesh_grids > 0: # if we have grids, step size changes mesh
client.force_update = True

else:
logging.warning("received invalid type from client")

Expand Down

0 comments on commit d2c5964

Please sign in to comment.