Skip to content

Commit

Permalink
add direct database access client
Browse files Browse the repository at this point in the history
  • Loading branch information
PythonFZ committed May 21, 2024
1 parent c9b3198 commit 6cbba72
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 13 deletions.
4 changes: 0 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,3 @@ print(c2.text)
You can set any attribute of the `Client`, except `{"address", "room", "sio"}` and it will be shared.
The data must be JSON-serializable to be shared.


# TODOs
- direct database access
- frozen client
60 changes: 53 additions & 7 deletions tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def start_server():
else:
raise TimeoutError("Server did not start in time")

yield f"http://localhost:{port}"
yield f"http://localhost:{port}", None

thread.kill()

Expand Down Expand Up @@ -65,14 +65,14 @@ def start_server():
else:
raise TimeoutError("Server did not start in time")

yield f"http://localhost:{port}"
yield f"http://localhost:{port}", db_path

thread.kill()


@pytest.mark.parametrize("server", ["eventlet_memory_server", "eventlet_sql_server"])
def test_example(server, request):
eventlet_server = request.getfixturevalue(server)
eventlet_server, _ = request.getfixturevalue(server)
c1 = znsocket.client.Client(eventlet_server, room="tmp")
c2 = znsocket.client.Client(eventlet_server, room="tmp")
c3 = znsocket.client.Client(eventlet_server)
Expand All @@ -95,15 +95,15 @@ def test_example(server, request):

@pytest.mark.parametrize("server", ["eventlet_memory_server", "eventlet_sql_server"])
def test_attribute_error(server, request):
eventlet_server = request.getfixturevalue(server)
eventlet_server, _ = request.getfixturevalue(server)
c1 = znsocket.client.Client(eventlet_server, room="tmp")
with pytest.raises(AttributeError):
_ = c1.non_existent_attribute


@pytest.mark.parametrize("server", ["eventlet_memory_server", "eventlet_sql_server"])
def test_multiple_attributes(server, request):
eventlet_server = request.getfixturevalue(server)
eventlet_server, _ = request.getfixturevalue(server)
c1 = znsocket.client.Client(eventlet_server, room="tmp")
c2 = znsocket.client.Client(eventlet_server, room="tmp")

Expand All @@ -127,7 +127,7 @@ def test_multiple_attributes(server, request):

@pytest.mark.parametrize("server", ["eventlet_memory_server", "eventlet_sql_server"])
def test_frozen_client(server, request):
eventlet_server = request.getfixturevalue(server)
eventlet_server, _ = request.getfixturevalue(server)
client = znsocket.client.FrozenClient(eventlet_server, room="tmp")

client.a = "1"
Expand All @@ -140,7 +140,7 @@ def test_frozen_client(server, request):

@pytest.mark.parametrize("server", ["eventlet_memory_server", "eventlet_sql_server"])
def test_frozen_client_pull(server, request):
eventlet_server = request.getfixturevalue(server)
eventlet_server, _ = request.getfixturevalue(server)
client = znsocket.client.Client(eventlet_server, room="tmp")
client.a = "1"
client.b = "2"
Expand Down Expand Up @@ -170,3 +170,49 @@ def test_frozen_client_pull(server, request):

assert frozen_client.a == "3"
assert frozen_client.b == "4"


@pytest.mark.parametrize("server", ["eventlet_sql_server"])
def test_db_client(server, request):
eventlet_server, db_path = request.getfixturevalue(server)

db_client = znsocket.client.DBClient(
db=SqlDatabase(engine=f"sqlite:///{db_path}"), room="tmp"
)
db_client.a = "1"
db_client.b = "2"

assert db_client.a == "1"
assert db_client.b == "2"

db_client.a = "3"
db_client.b = "4"

assert db_client.a == "3"
assert db_client.b == "4"


@pytest.mark.parametrize("server", ["eventlet_sql_server"])
def test_db_client_shared(server, request):
eventlet_server, db_path = request.getfixturevalue(server)

db_client = znsocket.client.DBClient(
db=SqlDatabase(engine=f"sqlite:///{db_path}"), room="tmp"
)
client = znsocket.client.Client(eventlet_server, room="tmp")

client.a = "1"
client.b = "2"

eventlet.sleep(0.1)

assert db_client.a == "1"
assert db_client.b == "2"

db_client.a = "3"
db_client.b = "4"

eventlet.sleep(0.1)

assert client.a == "3"
assert client.b == "4"
4 changes: 2 additions & 2 deletions znsocket/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .client import Client, FrozenClient
from .client import Client, DBClient, FrozenClient

__all__ = ["Client", "FrozenClient"]
__all__ = ["Client", "FrozenClient", "DBClient"]
28 changes: 28 additions & 0 deletions znsocket/client.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import dataclasses
import uuid
from typing import Any

import socketio

from znsocket.db.base import Database


@dataclasses.dataclass
class Client:
Expand Down Expand Up @@ -90,3 +93,28 @@ def __getattribute__(self, name: str) -> Any:
return self._data[name]
else:
return super().__getattribute__(name)


@dataclasses.dataclass
class DBClient:
db: Database
sid: str = None
room: str = None

def __post_init__(self):
self.sid = uuid.uuid4().hex
self.db.join_room(self.sid, self.room)

def __setattr__(self, name: str, value: Any) -> None:
if name not in [x.name for x in dataclasses.fields(self)]:
self.db.set_room_storage(self.sid, name, value)
else:
super().__setattr__(name, value)

def __getattribute__(self, name: str) -> Any:
if name.startswith("_"):
return super().__getattribute__(name)
if name not in [x.name for x in dataclasses.fields(self)]:
return self.db.get_room_storage(self.sid, name)
else:
return super().__getattribute__(name)

0 comments on commit 6cbba72

Please sign in to comment.