From 774191567bd4af66f967362538751359823cecb4 Mon Sep 17 00:00:00 2001 From: Ben Krieger Date: Thu, 17 Oct 2024 14:56:15 -0400 Subject: [PATCH] Fix fdo.command FSIM not sending output, ignoring Transform, and broken test Signed-off-by: Ben Krieger --- README.md | 2 ++ examples/cmd/client.go | 2 +- examples/cmd/server.go | 13 +++++++++++++ fsim/command_device.go | 34 +++++++++++++++++++--------------- fsim/command_owner.go | 19 +++++++------------ fsim/fsim_test.go | 22 +++++----------------- 6 files changed, 47 insertions(+), 45 deletions(-) diff --git a/README.md b/README.md index 00434b2..38bf582 100644 --- a/README.md +++ b/README.md @@ -66,6 +66,8 @@ Client options: A dir to wget files into (FSIM disabled if empty) Server options: + -command-date + Use fdo.command FSIM to have device run "date --utc" -db string SQLite database file path -db-pass string diff --git a/examples/cmd/client.go b/examples/cmd/client.go index 424b200..c2c2a03 100644 --- a/examples/cmd/client.go +++ b/examples/cmd/client.go @@ -433,7 +433,7 @@ func transferOwnership2(transport fdo.Transport, to1d *cose.Sign1[protocol.To1d, Timeout: time.Second, Transform: func(cmd string, args []string) (string, []string) { return "sh", []string{"-c", - fmt.Sprintf("echo %q", strings.Join(args, " "))} + fmt.Sprintf("echo %q", strings.Join(append([]string{cmd}, args...), " "))} }, } } diff --git a/examples/cmd/server.go b/examples/cmd/server.go index d1696cd..2ea308d 100644 --- a/examples/cmd/server.go +++ b/examples/cmd/server.go @@ -61,6 +61,7 @@ var ( rvDelay int printOwnerPubKey string importVoucher string + cmdDate bool downloads stringList uploadDir string uploadReqs stringList @@ -94,6 +95,7 @@ func init() { serverFlags.IntVar(&rvDelay, "rv-delay", 0, "Delay TO1 by N `seconds`") serverFlags.StringVar(&printOwnerPubKey, "print-owner-public", "", "Print owner public key of `type` and exit") serverFlags.StringVar(&importVoucher, "import-voucher", "", "Import a PEM encoded voucher file at `path`") + serverFlags.BoolVar(&cmdDate, "command-date", false, "Use fdo.command FSIM to have device run \"date --utc\"") serverFlags.Var(&downloads, "download", "Use fdo.download FSIM for each `file` (flag may be used multiple times)") serverFlags.StringVar(&uploadDir, "upload-dir", "uploads", "The directory `path` to put file uploads") serverFlags.Var(&uploadReqs, "upload", "Use fdo.upload FSIM for each `file` (flag may be used multiple times)") @@ -569,6 +571,17 @@ func ownerModules(ctx context.Context, guid protocol.GUID, info string, chain [] } } } + + if cmdDate && slices.Contains(modules, "fdo.command") { + if !yield("fdo.command", &fsim.RunCommand{ + Command: "date", + Args: []string{"--utc"}, + Stdout: os.Stdout, + Stderr: os.Stderr, + }) { + return + } + } } } diff --git a/fsim/command_device.go b/fsim/command_device.go index 28e0623..e7e1aeb 100644 --- a/fsim/command_device.go +++ b/fsim/command_device.go @@ -4,7 +4,6 @@ package fsim import ( - "bufio" "bytes" "context" "fmt" @@ -41,8 +40,8 @@ type Command struct { // Internal state cmd *exec.Cmd - out *bufio.Reader - err *bufio.Reader + out *bytes.Buffer + err *bytes.Buffer errc chan error } @@ -68,6 +67,7 @@ func (c *Command) Receive(ctx context.Context, messageName string, messageBody i func (c *Command) receive(ctx context.Context, messageName string, messageBody io.Reader) error { switch messageName { case "command": + c.reset() return cbor.NewDecoder(messageBody).Decode(&c.arg0) case "args": @@ -130,12 +130,12 @@ func (c *Command) execute(ctx context.Context) error { if c.stdout { var buf bytes.Buffer c.cmd.Stdout = &buf - c.out = bufio.NewReader(&buf) + c.out = &buf } if c.stderr { var buf bytes.Buffer c.cmd.Stderr = &buf - c.err = bufio.NewReader(&buf) + c.err = &buf } if debugEnabled() { slog.Debug("fdo.command", "args", c.cmd.Args) @@ -160,6 +160,11 @@ func (c *Command) Yield(ctx context.Context, respond func(message string) io.Wri return nil } + // Check exited before writing any output to avoid race conditions where + // output is lost if process exits between writing stdout/stderr and the + // exited check + exited := c.cmd.ProcessState != nil + // Send any data on the stdout/stderr pipes if c.stdout { if err := cborEncodeBuffer(respond("stdout"), c.out); err != nil { @@ -173,7 +178,7 @@ func (c *Command) Yield(ctx context.Context, respond func(message string) io.Wri } // Continue if process is still running - if c.cmd.ProcessState == nil { + if !exited { return nil } @@ -191,14 +196,14 @@ func (c *Command) Yield(ctx context.Context, respond func(message string) io.Wri return cbor.NewEncoder(respond("exitcode")).Encode(code) } -func cborEncodeBuffer(w io.Writer, r *bufio.Reader) error { - n := r.Buffered() +func cborEncodeBuffer(w io.Writer, r *bytes.Buffer) error { + n := r.Len() if n == 0 { return nil } - buf, err := r.Peek(n) - if err != nil { + buf := make([]byte, n) + if _, err := r.Read(buf); err != nil { return fmt.Errorf("error reading from buffer: %w", err) } @@ -206,10 +211,6 @@ func cborEncodeBuffer(w io.Writer, r *bufio.Reader) error { return fmt.Errorf("error sending buffer: %w", err) } - if _, err := r.Discard(n); err != nil { - return fmt.Errorf("error reading from buffer: %w", err) - } - return nil } @@ -220,5 +221,8 @@ func (c *Command) reset() { } _ = c.cmd.Process.Kill() } - *c = Command{Timeout: c.Timeout} + *c = Command{ + Timeout: c.Timeout, + Transform: c.Transform, + } } diff --git a/fsim/command_owner.go b/fsim/command_owner.go index 4aa3a9e..1d09e77 100644 --- a/fsim/command_owner.go +++ b/fsim/command_owner.go @@ -31,11 +31,11 @@ type RunCommand struct { // If set, stdout will be requested from the device and written to this // writer - Stdout io.WriteCloser + Stdout io.Writer // If set, stderr will be requested from the device and written to this // writer - Stderr io.WriteCloser + Stderr io.Writer // If set, the exit code will be sent on this channel. It should be // buffered with a size of 1. @@ -57,6 +57,7 @@ var _ serviceinfo.OwnerModule = (*RunCommand)(nil) func (c *RunCommand) HandleInfo(ctx context.Context, messageName string, messageBody io.Reader) error { if err := c.handleInfo(ctx, messageName, messageBody); err != nil { c.cleanup() + return err } return nil } @@ -74,27 +75,27 @@ func (c *RunCommand) handleInfo(ctx context.Context, messageName string, message return nil case "stdout": - var buf cbor.Bstr[[]byte] + var buf []byte if err := cbor.NewDecoder(messageBody).Decode(&buf); err != nil { return fmt.Errorf("error decoding message %q: %w", messageName, err) } if c.Stdout == nil { return fmt.Errorf("stdout received but not requested") } - if _, err := c.Stdout.Write(buf.Val); err != nil { + if _, err := c.Stdout.Write(buf); err != nil { return fmt.Errorf("error writing stdout: %w", err) } return nil case "stderr": - var buf cbor.Bstr[[]byte] + var buf []byte if err := cbor.NewDecoder(messageBody).Decode(&buf); err != nil { return fmt.Errorf("error decoding message %q: %w", messageName, err) } if c.Stderr == nil { return fmt.Errorf("stderr received but not requested") } - if _, err := c.Stderr.Write(buf.Val); err != nil { + if _, err := c.Stderr.Write(buf); err != nil { return fmt.Errorf("error writing stderr: %w", err) } return nil @@ -220,12 +221,6 @@ func (c *RunCommand) sendArgsAndExecute(producer *serviceinfo.Producer) (moreInf } func (c *RunCommand) cleanup() { - if c.Stdout != nil { - _ = c.Stdout.Close() - } - if c.Stderr != nil { - _ = c.Stderr.Close() - } if c.ExitChan != nil { close(c.ExitChan) } diff --git a/fsim/fsim_test.go b/fsim/fsim_test.go index bc208cb..602feb9 100644 --- a/fsim/fsim_test.go +++ b/fsim/fsim_test.go @@ -263,22 +263,10 @@ func TestClientWithCommandModule(t *testing.T) { run := runData{exitChan: make(chan int, 1)} if !yield("fdo.command", &fsim.RunCommand{ - Command: "date", - Args: []string{"--utc"}, - Stdout: struct { - io.Writer - io.Closer - }{ - Writer: &run.outbuf, - Closer: io.NopCloser(nil), - }, - Stderr: struct { - io.Writer - io.Closer - }{ - Writer: &run.errbuf, - Closer: io.NopCloser(nil), - }, + Command: "date", + Args: []string{"--utc"}, + Stdout: &run.outbuf, + Stderr: &run.errbuf, ExitChan: run.exitChan, }) { return @@ -300,7 +288,7 @@ func TestClientWithCommandModule(t *testing.T) { default: t.Error("expected exit code on channel") } - if !strings.Contains(" UTC ", run.outbuf.String()) { + if !strings.Contains(run.outbuf.String(), " UTC ") { t.Errorf("expected stdout to include UTC, got\n%s", run.outbuf.String()) } if run.errbuf.Len() > 0 {