Skip to content

Commit

Permalink
Add more information when returning 403 (#833)
Browse files Browse the repository at this point in the history
  • Loading branch information
Artem Yushkovskiy authored Aug 15, 2019
1 parent 1f584d4 commit 768ceff
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 12 deletions.
38 changes: 27 additions & 11 deletions platform_api/handlers/jobs_handler.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import asyncio
import json
import logging
from dataclasses import dataclass, replace
from pathlib import PurePath
from typing import Any, Dict, List, Optional, Sequence, Set

import aiohttp.web
import trafaret as t
from aiohttp_security import check_authorized, check_permission
from aiohttp_security import check_authorized
from multidict import MultiDictProxy
from neuro_auth_client import AuthClient, Permission
from neuro_auth_client.client import ClientSubTreeViewRoot
Expand Down Expand Up @@ -254,8 +255,7 @@ async def create_job(self, request: aiohttp.web.Request) -> aiohttp.web.Response
permissions = infer_permissions_from_container(
user, container, cluster_config.registry
)
logger.info("Checking whether %r has %r", user, permissions)
await check_permission(request, permissions[0].action, permissions)
await self._check_permissions(request, user, permissions)

name = request_payload.get("name")
description = request_payload.get("description")
Expand All @@ -282,8 +282,7 @@ async def handle_get(self, request: aiohttp.web.Request) -> aiohttp.web.Response
job = await self._jobs_service.get_job(job_id)

permission = Permission(uri=str(job.to_uri()), action="read")
logger.info("Checking whether %r has %r", user, permission)
await check_permission(request, permission.action, [permission])
await self._check_permissions(request, user, [permission])

cluster_name = self._jobs_service.get_cluster_name(job)
response_payload = convert_job_to_job_response(job, cluster_name)
Expand Down Expand Up @@ -355,8 +354,7 @@ async def handle_delete(
job = await self._jobs_service.get_job(job_id)

permission = Permission(uri=str(job.to_uri()), action="write")
logger.info("Checking whether %r has %r", user, permission)
await check_permission(request, permission.action, [permission])
await self._check_permissions(request, user, [permission])

await self._jobs_service.delete_job(job_id)
raise aiohttp.web.HTTPNoContent()
Expand All @@ -369,8 +367,7 @@ async def stream_log(
job = await self._jobs_service.get_job(job_id)

permission = Permission(uri=str(job.to_uri()), action="read")
logger.info("Checking whether %r has %r", user, permission)
await check_permission(request, permission.action, [permission])
await self._check_permissions(request, user, [permission])

log_reader = await self._jobs_service.get_job_log_reader(job_id)
# TODO: expose. make configurable
Expand Down Expand Up @@ -401,8 +398,7 @@ async def stream_top(
job = await self._jobs_service.get_job(job_id)

permission = Permission(uri=str(job.to_uri()), action="read")
logger.info("Checking whether %r has %r", user, permission)
await check_permission(request, permission.action, [permission])
await self._check_permissions(request, user, [permission])

logger.info("Websocket connection starting")
ws = aiohttp.web.WebSocketResponse()
Expand Down Expand Up @@ -457,6 +453,26 @@ def _convert_job_stats_to_ws_message(self, job_stats: JobStats) -> Dict[str, Any
message["gpu_memory"] = job_stats.gpu_memory
return message

async def _check_permissions(
self, request: aiohttp.web.Request, user: User, permissions: List[Permission]
) -> None:
await check_authorized(request)
assert permissions, "empty permission set to check"
logger.info("Checking whether %r has %r", user, permissions)
missing = await self._auth_client.get_missing_permissions(
user.name, permissions
)
if missing:
error_details = {
"resources": [self._permission_to_primitive(p) for p in missing]
}
raise aiohttp.web.HTTPForbidden(
text=json.dumps(error_details), content_type="application/json"
)

def _permission_to_primitive(self, perm: Permission) -> Dict[str, str]:
return {"uri": perm.uri, "action": perm.action}


class JobFilterException(ValueError):
pass
Expand Down
4 changes: 4 additions & 0 deletions platform_api/user.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
from dataclasses import dataclass, field

from aiohttp.web import HTTPUnauthorized, Request
Expand All @@ -10,6 +11,9 @@
)


logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class User:
name: str
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
"dataclasses==0.6", # backport from 3.7 stdlib
"iso8601==0.1.12",
"trafaret==1.2.0",
"neuro_auth_client==1.0.6",
"neuro_auth_client==1.0.7",
# Circle CI fails on the latest cryptography version
# because the server has too old OpenSSL version
"cryptography==2.7",
Expand Down
94 changes: 94 additions & 0 deletions tests/integration/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import aiohttp.web
import multidict
import pytest
from aiohttp import WSServerHandshakeError
from aiohttp.web import (
HTTPAccepted,
HTTPBadRequest,
Expand Down Expand Up @@ -276,6 +277,8 @@ async def test_forbidden_storage_uri(
url, headers=regular_user.headers, json=payload
) as response:
assert response.status == HTTPForbidden.status_code, await response.text()
data = await response.json()
assert data == {"resources": [{"action": "write", "uri": "storage:"}]}

@pytest.mark.asyncio
async def test_forbidden_image(
Expand All @@ -298,6 +301,10 @@ async def test_forbidden_image(
url, headers=regular_user.headers, json=payload
) as response:
assert response.status == HTTPForbidden.status_code, await response.text()
data = await response.json()
assert data == {
"resources": [{"action": "read", "uri": "image://anotheruser/image"}]
}

@pytest.mark.asyncio
async def test_allowed_image(
Expand Down Expand Up @@ -1143,6 +1150,10 @@ async def test_get_shared_job(

async with client.get(url, headers=follower.headers) as response:
assert response.status == HTTPForbidden.status_code
data = await response.json()
assert data == {
"resources": [{"action": "read", "uri": f"job://{owner.name}/{job_id}"}]
}

permission = Permission(uri=f"job://{owner.name}/{job_id}", action="read")
await auth_client.grant_user_permissions(
Expand Down Expand Up @@ -1429,6 +1440,35 @@ async def test_delete_job(
assert jobs[0]["status"] == "succeeded"
assert jobs[0]["id"] == job_id

@pytest.mark.asyncio
async def test_delete_job_forbidden(
self,
api: ApiConfig,
client: aiohttp.ClientSession,
job_submit: Dict[str, Any],
jobs_client: JobsClient,
regular_user_factory: Callable[..., Awaitable[_User]],
regular_user: _User,
) -> None:
url = api.jobs_base_url
async with client.post(
url, headers=regular_user.headers, json=job_submit
) as response:
assert response.status == HTTPAccepted.status_code, await response.text()
result = await response.json()
job_id = result["id"]

url = api.generate_job_url(job_id)
another_user = await regular_user_factory()
async with client.delete(url, headers=another_user.headers) as response:
assert response.status == HTTPForbidden.status_code, await response.text()
result = await response.json()
assert result == {
"resources": [
{"action": "write", "uri": f"job://{regular_user.name}/{job_id}"}
]
}

@pytest.mark.asyncio
async def test_delete_already_deleted(
self,
Expand Down Expand Up @@ -1494,6 +1534,35 @@ async def test_job_log(
expected_payload = "\n".join(str(i) for i in range(1, 6)) + "\n"
assert actual_payload == expected_payload.encode()

@pytest.mark.asyncio
async def test_job_log_forbidden(
self,
api: ApiConfig,
client: aiohttp.ClientSession,
job_submit: Dict[str, Any],
jobs_client: JobsClient,
regular_user_factory: Callable[..., Awaitable[_User]],
regular_user: _User,
) -> None:
url = api.jobs_base_url
async with client.post(
url, headers=regular_user.headers, json=job_submit
) as response:
assert response.status == HTTPAccepted.status_code
result = await response.json()
job_id = result["id"]

url = api.jobs_base_url + f"/{job_id}/log"
another_user = await regular_user_factory()
async with client.get(url, headers=another_user.headers) as response:
assert response.status == HTTPForbidden.status_code
result = await response.json()
assert result == {
"resources": [
{"action": "read", "uri": f"job://{regular_user.name}/{job_id}"}
]
}

@pytest.mark.asyncio
async def test_create_validation_failure(
self, api: ApiConfig, client: aiohttp.ClientSession, regular_user: _User
Expand Down Expand Up @@ -1881,3 +1950,28 @@ async def test_job_top_close_when_job_succeeded(
assert job["status"] == "succeeded"

await jobs_client.delete_job(job_id=job_id)

@pytest.mark.asyncio
async def test_job_top_forbidden(
self,
api: ApiConfig,
client: aiohttp.ClientSession,
jobs_client: JobsClient,
job_submit: Dict[str, Any],
regular_user_factory: Callable[..., Awaitable[_User]],
regular_user: _User,
) -> None:

url = api.jobs_base_url
async with client.post(
url, headers=regular_user.headers, json=job_submit
) as response:
assert response.status == HTTPAccepted.status_code, await response.text()
result = await response.json()
job_id = result["id"]

url = api.jobs_base_url + f"/{job_id}/top"
another_user = await regular_user_factory()
with pytest.raises(WSServerHandshakeError, match="403"):
async with client.ws_connect(url, headers=another_user.headers):
pass

0 comments on commit 768ceff

Please sign in to comment.