Skip to content

Commit

Permalink
[FIX] new user (#769)
Browse files Browse the repository at this point in the history
* attempt to fix adding a new user when posting

* style

* fix tests
  • Loading branch information
jdkent authored Jun 14, 2024
1 parent d540f69 commit 7df1e11
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 13 deletions.
11 changes: 7 additions & 4 deletions store/neurostore/resources/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
raiseload,
selectinload,
)
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.exc import SQLAlchemyError, IntegrityError
from sqlalchemy import func
from webargs.flaskparser import parser
from webargs import fields
Expand Down Expand Up @@ -242,9 +242,12 @@ def update_or_create(cls, data, id=None, user=None, record=None, flush=True):
current_user = user or get_current_user()
if not current_user:
current_user = create_user()

db.session.add(current_user)
db.session.commit()
try:
db.session.add(current_user)
db.session.commit()
except (SQLAlchemyError, IntegrityError):
db.session.rollback()
current_user = User.query.filter_by(external_id=context["user"]).first()

id = id or data.get("id", None) # want to handle case of {"id": "asdfasf"}

Expand Down
6 changes: 6 additions & 0 deletions store/neurostore/tests/api/test_studies.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,3 +286,9 @@ def test_studies_flat(auth_client, ingest_neurosynth, session):

assert "analyses" not in flat_resp.json()["results"][0]
assert "analyses" in reg_resp.json()["results"][0]


def test_create_study_new_user(new_user_client, mock_auth0_auth, session):

study_resp = new_user_client.post("/api/studies/", data={"name": "test"})
assert study_resp.status_code == 200
50 changes: 41 additions & 9 deletions store/neurostore/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@
Entity,
)
from auth0.v3.authentication import GetToken
from auth0.v3.authentication.users import Users
from unittest.mock import patch


import shortuuid
import vcr

Expand Down Expand Up @@ -81,6 +85,9 @@ def mock_decode_token(token):
algorithm="HS256",
):
return {"sub": os.environ.get("COMPOSE_AUTH0_CLIENT_ID") + "@clients"}
# new user not in the database
elif token == encode({"sub": "newuser-id"}, "789", algorithm="HS256"):
return {"sub": "newuser-id"}


@pytest.fixture(scope="session")
Expand Down Expand Up @@ -210,6 +217,14 @@ def session(db):
connection.close()


@pytest.fixture(scope="session")
def mock_auth0_auth():
with patch.object(
Users, "userinfo", return_value={"name": "newuser", "nickname": "new user"}
):
yield


"""
Data population fixtures
"""
Expand All @@ -221,6 +236,12 @@ def auth_client(auth_clients):
return auth_clients[0]


@pytest.fixture(scope="function")
def new_user_client(auth_clients):
"""Return authorized client wrapper for new user"""
return next(c for c in auth_clients if c.username == "newuser-id")


@pytest.fixture(scope="function")
def auth_clients(mock_add_users, app):
"""Return authorized client wrapper"""
Expand Down Expand Up @@ -260,25 +281,34 @@ def mock_add_users(app, db, session, mock_auth):
algorithm="HS256",
),
},
{
"name": "newuser",
"password": "newpassword",
"access_token": encode({"sub": "newuser-id"}, "789", algorithm="HS256"),
},
]

tokens = {}
for u in users:
token_info = mock_decode_token(u["access_token"])
user = User(
name=u["name"],
external_id=token_info["sub"],
)
if User.query.filter_by(external_id=token_info["sub"]).first() is None:
db.session.add(user)
db.session.commit()

tokens[u["name"]] = {
"token": u["access_token"],
"external_id": token_info["sub"],
"id": User.query.filter_by(external_id=token_info["sub"]).first().id,
}

if u["name"] != "newuser":
user = User(
name=u["name"],
external_id=token_info["sub"],
)
if User.query.filter_by(external_id=token_info["sub"]).first() is None:
db.session.add(user)
db.session.commit()

tokens[u["name"]]["id"] = (
User.query.filter_by(external_id=token_info["sub"]).first().id
)

yield tokens


Expand Down Expand Up @@ -397,6 +427,8 @@ def user_data(session, mock_add_users):
)
public_studies = []
for user_info in mock_add_users.values():
if user_info["external_id"] == "newuser-id":
continue
user = User.query.filter_by(id=user_info["id"]).first()
for level in ["group", "meta"]:
entity = Entity(level=level)
Expand Down

0 comments on commit 7df1e11

Please sign in to comment.