diff --git a/app.go b/app.go index c4dce9c530..cac4417a3a 100644 --- a/app.go +++ b/app.go @@ -31,11 +31,13 @@ const ( // App represents an app type App struct { - broker *Broker - client *http.Client - mode AppMode // mode - route string // route - addr string // upstream address http://host:port + broker *Broker + client *http.Client + mode AppMode // mode + route string // route + addr string // upstream address http://host:port + keyID string // access key ID + keySecret string // access key secret } // Boot represents the initial message sent when a client connects to an app @@ -53,13 +55,15 @@ func toAppMode(mode string) AppMode { return unicastMode } -func newApp(broker *Broker, mode, route, addr string) *App { +func newApp(broker *Broker, mode, route, addr, keyID, keySecret string) *App { return &App{ broker, &http.Client{}, // TODO tune keep-alive and idle timeout toAppMode(mode), route, addr, + keyID, + keySecret, } } @@ -76,6 +80,8 @@ func (app *App) send(clientID string, session *Session, data []byte) error { return fmt.Errorf("failed creating request: %v", err) } + req.SetBasicAuth(app.keyID, app.keySecret) + req.Header.Set("Content-Type", "application/json; charset=utf-8") req.Header.Set("Wave-Client-ID", clientID) if session.subject != anon { @@ -87,9 +93,12 @@ func (app *App) send(clientID string, session *Session, data []byte) error { resp, err := app.client.Do(req) if err != nil { - return fmt.Errorf("failed sending request: %v", err) + return fmt.Errorf("request failed: %v", err) } defer resp.Body.Close() + if resp.StatusCode != 200 { + return fmt.Errorf("request failed: %s", http.StatusText(resp.StatusCode)) + } if _, err := readWithLimit(resp.Body, 0); err != nil { // apps always return empty plain-text responses. return fmt.Errorf("failed reading response: %v", err) } diff --git a/broker.go b/broker.go index 2f3822ca2b..db87982101 100644 --- a/broker.go +++ b/broker.go @@ -83,8 +83,8 @@ func newBroker(site *Site) *Broker { } } -func (b *Broker) addApp(mode, route, addr string) { - s := newApp(b, mode, route, addr) +func (b *Broker) addApp(mode, route, addr, keyID, keySecret string) { + s := newApp(b, mode, route, addr, keyID, keySecret) b.appsMux.Lock() b.apps[route] = s diff --git a/protocol.go b/protocol.go index 168f331d44..13263675af 100644 --- a/protocol.go +++ b/protocol.go @@ -88,9 +88,11 @@ type AppRequest struct { // RegisterApp represents a request to register an app. type RegisterApp struct { - Mode string `json:"mode"` - Route string `json:"route"` - Address string `json:"address"` + Mode string `json:"mode"` + Route string `json:"route"` + Address string `json:"address"` + KeyID string `json:"key_id"` + KeySecret string `json:"key_secret"` } // UnregisterApp represents a request to unregister an app. diff --git a/protocol.md b/protocol.md index fac6228373..e85347cc1a 100644 --- a/protocol.md +++ b/protocol.md @@ -34,8 +34,10 @@ Relevant environment variables: - `WAVE_ADDRESS`: The `protocol://ip:port` of the Wave server as visible from the app server. - `WAVE_APP_ADDRESS`: The `protocol://ip:port` of the app server as visible from the Wave server. - `WAVE_APP_MODE`: The sync mode of the app, one of `unicast`, `multicast` or `broadcast`. -- `WAVE_ACCESS_KEY_ID`: The API access key ID, typically a 20-character cryptographically random string. -- `WAVE_ACCESS_KEY_SECRET`: The API access key secret, typically a 40-character cryptographically random string. +- `WAVE_ACCESS_KEY_ID`: The Wave server API access key ID, typically a cryptographically random string. +- `WAVE_ACCESS_KEY_SECRET`: The Wave server API access key secret, typically a cryptographically random string. +- `WAVE_APP_ACCESS_KEY_ID`: The app server API access key ID, typically a cryptographically random string. +- `WAVE_APP_ACCESS_KEY_SECRET`: The app server API access key secret, typically a cryptographically random string. ### Startup @@ -46,16 +48,23 @@ On app launch, the app registers itself with the Wave server by sending a `POST` "register_app": { "mode": "$WAVE_APP_MODE", "address": "$WAVE_APP_ADDRESS" + "key_id": "$WAVE_APP_ACCESS_KEY_ID", + "key_secret": "$WAVE_APP_ACCESS_KEY_SECRET", "route": "/foo", } } ``` +The `key_id` and `key_secret` are automatically generated at startup if `$WAVE_APP_ACCESS_KEY_ID` or `$WAVE_APP_ACCESS_KEY_SECRET` are empty. + ### Accepting requests The Wave server now starts forwarding browser requests from the Wave server's `/foo` to the app server's `/`. Consequently, the app framework requires exactly one HTTP handler, listening to `POST` requests at `/`. -Before parsing the HTTP request and handing over control to the app, the body of the HTTP request is captured and a plain-text empty-string response is sent to the Wave server. The Wave server ignores responses. +On receiving a request, the app server: +1. Verifies if the credentials in the request's basic-authentication header match `$WAVE_APP_ACCESS_KEY_ID` and `$WAVE_APP_ACCESS_KEY_SECRET`. +2. Captures the headers and body of the HTTP request. +3. Responds with a plain-text empty-string (200 status code). Note that the Wave server ignores responses. ### Processing requests diff --git a/py/h2o_wave/core.py b/py/h2o_wave/core.py index d53c3e4d53..5ba119f64f 100644 --- a/py/h2o_wave/core.py +++ b/py/h2o_wave/core.py @@ -13,6 +13,7 @@ # limitations under the License. import json +import secrets import warnings import logging import os @@ -47,6 +48,8 @@ def __init__(self): self.hub_address = _get_env('ADDRESS', 'http://127.0.0.1:10101') self.hub_access_key_id: str = _get_env('ACCESS_KEY_ID', 'access_key_id') self.hub_access_key_secret: str = _get_env('ACCESS_KEY_SECRET', 'access_key_secret') + self.app_access_key_id: str = _get_env('APP_ACCESS_KEY_ID', None) or secrets.token_urlsafe(16) + self.app_access_key_secret: str = _get_env('APP_ACCESS_KEY_SECRET', None) or secrets.token_urlsafe(16) _config = _Config() diff --git a/py/h2o_wave/server.py b/py/h2o_wave/server.py index 04837d02fe..68c7ff1837 100644 --- a/py/h2o_wave/server.py +++ b/py/h2o_wave/server.py @@ -26,6 +26,8 @@ import warnings import pickle import traceback +import base64 +import binascii from typing import Dict, Tuple, Callable, Any, Awaitable, Optional from urllib.parse import urlparse @@ -38,8 +40,8 @@ from starlette.responses import PlainTextResponse from starlette.background import BackgroundTask -from .core import Expando, expando_to_dict, _config, marshal, unmarshal, _content_type_json, AsyncSite, _get_env, \ - UNICAST, MULTICAST +from .core import Expando, expando_to_dict, _config, marshal, _content_type_json, AsyncSite, _get_env, UNICAST, \ + MULTICAST from .ui import markdown_card logger = logging.getLogger(__name__) @@ -237,7 +239,14 @@ def __init__(self, route: str, handle: HandleAsync, mode=None, on_startup: Optio async def _register(self): app_address = _get_env('APP_ADDRESS', _config.app_address) logger.debug(f'Registering app at {app_address} ...') - await self._wave.call('register_app', mode=self._mode, route=self._route, address=app_address) + await self._wave.call( + 'register_app', + mode=self._mode, + route=self._route, + address=app_address, + key_id=_config.app_access_key_id, + key_secret=_config.app_access_key_secret, + ) logger.debug('Register: success!') async def _unregister(self): @@ -246,6 +255,21 @@ async def _unregister(self): logger.debug('Unregister: success!') async def _receive(self, req: Request): + basic_auth = req.headers.get("Authorization") + if basic_auth is None: + return PlainTextResponse(content='Unauthorized', status_code=401) + try: + scheme, credentials = basic_auth.split() + if scheme.lower() != 'basic': + return PlainTextResponse(content='Unauthorized', status_code=401) + decoded = base64.b64decode(credentials).decode("ascii") + except (ValueError, UnicodeDecodeError, binascii.Error) as exc: + return PlainTextResponse(content='Unauthorized', status_code=401) + + key_id, _, key_secret = decoded.partition(":") + if key_id != _config.app_access_key_id or key_secret != _config.app_access_key_secret: + return PlainTextResponse(content='Unauthorized', status_code=401) + client_id = req.headers.get('Wave-Client-ID') subject = req.headers.get('Wave-Subject-ID') username = req.headers.get('Wave-Username') @@ -253,6 +277,7 @@ async def _receive(self, req: Request): refresh_token = req.headers.get('Wave-Refresh-Token') auth = Auth(username, subject, access_token, refresh_token) args = await req.json() + return PlainTextResponse('', background=BackgroundTask(self._process, client_id, auth, args)) async def _process(self, client_id: str, auth: Auth, args: dict): diff --git a/web_server.go b/web_server.go index 6f0314a101..ff00dcaad7 100644 --- a/web_server.go +++ b/web_server.go @@ -123,7 +123,7 @@ func (s *WebServer) post(w http.ResponseWriter, r *http.Request) { } if req.RegisterApp != nil { q := req.RegisterApp - s.broker.addApp(q.Mode, q.Route, q.Address) + s.broker.addApp(q.Mode, q.Route, q.Address, q.KeyID, q.KeySecret) } else if req.UnregisterApp != nil { q := req.UnregisterApp s.broker.dropApp(q.Route)