From 2958508fbd8680cf21e095e8d63584b1eec8943a Mon Sep 17 00:00:00 2001 From: Prithvi Prabhu Date: Fri, 16 Apr 2021 15:06:35 -0700 Subject: [PATCH] feat: Simplify server -> app request forwarding #745 --- app.go | 32 +++++++++++++++++------- auth.go | 8 ++++-- client.go | 44 ++------------------------------- py/h2o_wave/server.py | 57 ++++++++++++------------------------------- 4 files changed, 47 insertions(+), 94 deletions(-) diff --git a/app.go b/app.go index 8d6d1a22a9..c4dce9c530 100644 --- a/app.go +++ b/app.go @@ -16,6 +16,7 @@ package wave import ( "bytes" + "fmt" "net/http" ) @@ -62,22 +63,35 @@ func newApp(broker *Broker, mode, route, addr string) *App { } } -func (app *App) forward(data []byte) { - if !app.send(data) { +func (app *App) forward(clientID string, session *Session, data []byte) { + if err := app.send(clientID, session, data); err != nil { + echo(Log{"t": "app", "route": app.route, "host": app.addr, "error": err.Error()}) app.broker.dropApp(app.route) } } -func (app *App) send(data []byte) bool { - resp, err := app.client.Post(app.addr, "text/plain; charset=utf-8", bytes.NewReader(data)) +func (app *App) send(clientID string, session *Session, data []byte) error { + req, err := http.NewRequest("POST", app.addr, bytes.NewReader(data)) if err != nil { - echo(Log{"t": "app", "route": app.route, "host": app.addr, "error": err.Error()}) - return false + return fmt.Errorf("failed creating request: %v", err) + } + + req.Header.Set("Content-Type", "application/json; charset=utf-8") + req.Header.Set("Wave-Client-ID", clientID) + if session.subject != anon { + req.Header.Set("Wave-Subject-ID", session.subject) + req.Header.Set("Wave-Username", session.username) + req.Header.Set("Wave-Access-Token", session.token.AccessToken) + req.Header.Set("Wave-Refresh-Token", session.token.RefreshToken) + } + + resp, err := app.client.Do(req) + if err != nil { + return fmt.Errorf("failed sending request: %v", err) } defer resp.Body.Close() if _, err := readWithLimit(resp.Body, 0); err != nil { // apps always return empty plain-text responses. - echo(Log{"t": "app", "route": app.route, "host": app.addr, "error": err.Error()}) - return false + return fmt.Errorf("failed reading response: %v", err) } - return true + return nil } diff --git a/auth.go b/auth.go index 3af94c45db..b214c8cce4 100644 --- a/auth.go +++ b/auth.go @@ -60,9 +60,13 @@ type Session struct { token *oauth2.Token } +const ( + anon = "anon" +) + var anonymous = &Session{ - subject: "anonymous", - username: "anonymous", + subject: anon, + username: anon, token: &oauth2.Token{}, } diff --git a/client.go b/client.go index fa3c5cc57f..2b58bc89bb 100644 --- a/client.go +++ b/client.go @@ -15,7 +15,6 @@ package wave import ( - "bytes" "context" "encoding/json" "time" @@ -117,7 +116,7 @@ func (c *Client) listen() { echo(Log{"t": "query", "client": c.addr, "route": m.addr, "error": "service unavailable"}) continue } - app.forward(c.format(m.data)) + app.forward(c.id, c.session, m.data) case watchMsgT: c.subscribe(m.addr) // subscribe even if page is currently NA @@ -136,7 +135,7 @@ func (c *Client) listen() { } } - app.forward(c.format(boot)) + app.forward(c.id, c.session, boot) continue } @@ -214,42 +213,3 @@ func (c *Client) flush() { func (c *Client) quit() { close(c.data) } - -var ( - usernameHeader = []byte("u:") - subjectHeader = []byte("s:") - clientIDHeader = []byte("c:") - accessTokenHeader = []byte("a:") - refreshTokenHeader = []byte("r:") - queryBodySep = []byte("\n\n") -) - -func (c *Client) format(data []byte) []byte { - var buf bytes.Buffer - - s := c.session - - buf.Write(usernameHeader) - buf.WriteString(s.username) - buf.WriteByte('\n') - - buf.Write(subjectHeader) - buf.WriteString(s.subject) - buf.WriteByte('\n') - - buf.Write(clientIDHeader) - buf.WriteString(c.id) - buf.WriteByte('\n') - - buf.Write(accessTokenHeader) - buf.WriteString(s.token.AccessToken) - buf.WriteByte('\n') - - buf.Write(refreshTokenHeader) - buf.WriteString(s.token.AccessToken) - buf.Write(queryBodySep) - - buf.Write(data) - - return buf.Bytes() -} diff --git a/py/h2o_wave/server.py b/py/h2o_wave/server.py index c40053343b..04837d02fe 100644 --- a/py/h2o_wave/server.py +++ b/py/h2o_wave/server.py @@ -246,29 +246,33 @@ async def _unregister(self): logger.debug('Unregister: success!') async def _receive(self, req: Request): - b = await req.body() - return PlainTextResponse('', background=BackgroundTask(self._process, b.decode('utf-8'))) - - async def _process(self, query: str): - username, subject, client_id, access_token, refresh_token, args = _parse_query(query) - logger.debug(f'user: {username}, client: {client_id}') + client_id = req.headers.get('Wave-Client-ID') + subject = req.headers.get('Wave-Subject-ID') + username = req.headers.get('Wave-Username') + access_token = req.headers.get('Wave-Access-Token') + refresh_token = req.headers.get('Wave-Refresh-Token') + auth = Auth(username, subject, access_token, refresh_token) + args = await req.json() + return PlainTextResponse('', background=BackgroundTask(self._process, client_id, auth, args)) + + async def _process(self, client_id: str, auth: Auth, args: dict): + logger.debug(f'user: {auth.username}, client: {client_id}') logger.debug(args) app_state, user_state, client_state = self._state - args_state: dict = unmarshal(args) - events_state: Optional[dict] = args_state.get('', None) + events_state: Optional[dict] = args.get('', None) if isinstance(events_state, dict): events_state = {k: Expando(v) for k, v in events_state.items()} - del args_state[''] + del args[''] q = Q( site=self._site, mode=self._mode, - auth=Auth(username, subject, access_token, refresh_token), + auth=auth, client_id=client_id, route=self._route, app_state=app_state, - user_state=_session_for(user_state, username), + user_state=_session_for(user_state, auth.subject), client_state=_session_for(client_state, client_id), - args=Expando(args_state), + args=Expando(args), events=Expando(events_state), ) # noinspection PyBroadException,PyPep8 @@ -356,35 +360,6 @@ def _save_state(state: WebAppState): pickle.dump(checkpoint, p) -def _parse_query(query: str) -> Tuple[str, str, str, str, str, str]: - username = '' - subject = '' - client_id = '' - access_token = '' - refresh_token = '' - - # format: - # u:username\ns:subject\nc:client_id\na:access_token\nr:refresh_token\n\nbody - - head, body = query.split('\n\n', maxsplit=1) - for line in head.splitlines(): - kv = line.split(':', maxsplit=1) - if len(kv) == 2: - k, v = kv - if k == 'u': - username = v - elif k == 's': - subject = v - elif k == 'c': - client_id = v - elif k == 'a': - access_token = v - elif k == 'r': - refresh_token = v - - return username, subject, client_id, access_token, refresh_token, body - - class _Main: def __init__(self, app: Optional[_App] = None): self._app: Optional[_App] = app