Skip to content

Commit

Permalink
Fix fdo.command FSIM not sending output, ignoring Transform, and brok…
Browse files Browse the repository at this point in the history
…en test

Signed-off-by: Ben Krieger <[email protected]>
  • Loading branch information
ben-krieger committed Oct 18, 2024
1 parent fd239a9 commit 0c2b155
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 45 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ Client options:
A dir to wget files into (FSIM disabled if empty)

Server options:
-command-date
Use fdo.command FSIM to have device run "date --utc"
-db string
SQLite database file path
-db-pass string
Expand Down
2 changes: 1 addition & 1 deletion examples/cmd/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,7 @@ func transferOwnership2(transport fdo.Transport, to1d *cose.Sign1[protocol.To1d,
Timeout: time.Second,
Transform: func(cmd string, args []string) (string, []string) {
return "sh", []string{"-c",
fmt.Sprintf("echo %q", strings.Join(args, " "))}
fmt.Sprintf("echo %q", strings.Join(append([]string{cmd}, args...), " "))}
},
}
}
Expand Down
13 changes: 13 additions & 0 deletions examples/cmd/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ var (
rvDelay int
printOwnerPubKey string
importVoucher string
cmdDate bool
downloads stringList
uploadDir string
uploadReqs stringList
Expand Down Expand Up @@ -94,6 +95,7 @@ func init() {
serverFlags.IntVar(&rvDelay, "rv-delay", 0, "Delay TO1 by N `seconds`")
serverFlags.StringVar(&printOwnerPubKey, "print-owner-public", "", "Print owner public key of `type` and exit")
serverFlags.StringVar(&importVoucher, "import-voucher", "", "Import a PEM encoded voucher file at `path`")
serverFlags.BoolVar(&cmdDate, "command-date", false, "Use fdo.command FSIM to have device run \"date --utc\"")
serverFlags.Var(&downloads, "download", "Use fdo.download FSIM for each `file` (flag may be used multiple times)")
serverFlags.StringVar(&uploadDir, "upload-dir", "uploads", "The directory `path` to put file uploads")
serverFlags.Var(&uploadReqs, "upload", "Use fdo.upload FSIM for each `file` (flag may be used multiple times)")
Expand Down Expand Up @@ -569,6 +571,17 @@ func ownerModules(ctx context.Context, guid protocol.GUID, info string, chain []
}
}
}

if cmdDate && slices.Contains(modules, "fdo.command") {
if !yield("fdo.command", &fsim.RunCommand{
Command: "date",
Args: []string{"--utc"},
Stdout: os.Stdout,
Stderr: os.Stderr,
}) {
return
}
}
}
}

Expand Down
34 changes: 19 additions & 15 deletions fsim/command_device.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
package fsim

import (
"bufio"
"bytes"
"context"
"fmt"
Expand Down Expand Up @@ -41,8 +40,8 @@ type Command struct {

// Internal state
cmd *exec.Cmd
out *bufio.Reader
err *bufio.Reader
out *bytes.Buffer
err *bytes.Buffer
errc chan error
}

Expand All @@ -68,6 +67,7 @@ func (c *Command) Receive(ctx context.Context, messageName string, messageBody i
func (c *Command) receive(ctx context.Context, messageName string, messageBody io.Reader) error {
switch messageName {
case "command":
c.reset()
return cbor.NewDecoder(messageBody).Decode(&c.arg0)

case "args":
Expand Down Expand Up @@ -130,12 +130,12 @@ func (c *Command) execute(ctx context.Context) error {
if c.stdout {
var buf bytes.Buffer
c.cmd.Stdout = &buf
c.out = bufio.NewReader(&buf)
c.out = &buf
}
if c.stderr {
var buf bytes.Buffer
c.cmd.Stderr = &buf
c.err = bufio.NewReader(&buf)
c.err = &buf
}
if debugEnabled() {
slog.Debug("fdo.command", "args", c.cmd.Args)
Expand All @@ -160,6 +160,11 @@ func (c *Command) Yield(ctx context.Context, respond func(message string) io.Wri
return nil
}

// Check exited before writing any output to avoid race conditions where
// output is lost if process exits between writing stdout/stderr and the
// exited check
exited := c.cmd.ProcessState != nil

// Send any data on the stdout/stderr pipes
if c.stdout {
if err := cborEncodeBuffer(respond("stdout"), c.out); err != nil {
Expand All @@ -173,7 +178,7 @@ func (c *Command) Yield(ctx context.Context, respond func(message string) io.Wri
}

// Continue if process is still running
if c.cmd.ProcessState == nil {
if !exited {
return nil
}

Expand All @@ -191,25 +196,21 @@ func (c *Command) Yield(ctx context.Context, respond func(message string) io.Wri
return cbor.NewEncoder(respond("exitcode")).Encode(code)
}

func cborEncodeBuffer(w io.Writer, r *bufio.Reader) error {
n := r.Buffered()
func cborEncodeBuffer(w io.Writer, r *bytes.Buffer) error {
n := r.Len()
if n == 0 {
return nil
}

buf, err := r.Peek(n)
if err != nil {
buf := make([]byte, n)
if _, err := r.Read(buf); err != nil {
return fmt.Errorf("error reading from buffer: %w", err)
}

if err := cbor.NewEncoder(w).Encode(buf); err != nil {
return fmt.Errorf("error sending buffer: %w", err)
}

if _, err := r.Discard(n); err != nil {
return fmt.Errorf("error reading from buffer: %w", err)
}

return nil
}

Expand All @@ -220,5 +221,8 @@ func (c *Command) reset() {
}
_ = c.cmd.Process.Kill()
}
*c = Command{Timeout: c.Timeout}
*c = Command{
Timeout: c.Timeout,
Transform: c.Transform,
}
}
19 changes: 7 additions & 12 deletions fsim/command_owner.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@ type RunCommand struct {

// If set, stdout will be requested from the device and written to this
// writer
Stdout io.WriteCloser
Stdout io.Writer

// If set, stderr will be requested from the device and written to this
// writer
Stderr io.WriteCloser
Stderr io.Writer

// If set, the exit code will be sent on this channel. It should be
// buffered with a size of 1.
Expand All @@ -57,6 +57,7 @@ var _ serviceinfo.OwnerModule = (*RunCommand)(nil)
func (c *RunCommand) HandleInfo(ctx context.Context, messageName string, messageBody io.Reader) error {
if err := c.handleInfo(ctx, messageName, messageBody); err != nil {
c.cleanup()
return err
}
return nil
}
Expand All @@ -74,27 +75,27 @@ func (c *RunCommand) handleInfo(ctx context.Context, messageName string, message
return nil

case "stdout":
var buf cbor.Bstr[[]byte]
var buf []byte
if err := cbor.NewDecoder(messageBody).Decode(&buf); err != nil {
return fmt.Errorf("error decoding message %q: %w", messageName, err)
}
if c.Stdout == nil {
return fmt.Errorf("stdout received but not requested")
}
if _, err := c.Stdout.Write(buf.Val); err != nil {
if _, err := c.Stdout.Write(buf); err != nil {
return fmt.Errorf("error writing stdout: %w", err)
}
return nil

case "stderr":
var buf cbor.Bstr[[]byte]
var buf []byte
if err := cbor.NewDecoder(messageBody).Decode(&buf); err != nil {
return fmt.Errorf("error decoding message %q: %w", messageName, err)
}
if c.Stderr == nil {
return fmt.Errorf("stderr received but not requested")
}
if _, err := c.Stderr.Write(buf.Val); err != nil {
if _, err := c.Stderr.Write(buf); err != nil {
return fmt.Errorf("error writing stderr: %w", err)
}
return nil
Expand Down Expand Up @@ -220,12 +221,6 @@ func (c *RunCommand) sendArgsAndExecute(producer *serviceinfo.Producer) (moreInf
}

func (c *RunCommand) cleanup() {
if c.Stdout != nil {
_ = c.Stdout.Close()
}
if c.Stderr != nil {
_ = c.Stderr.Close()
}
if c.ExitChan != nil {
close(c.ExitChan)
}
Expand Down
22 changes: 5 additions & 17 deletions fsim/fsim_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -263,22 +263,10 @@ func TestClientWithCommandModule(t *testing.T) {
run := runData{exitChan: make(chan int, 1)}

if !yield("fdo.command", &fsim.RunCommand{
Command: "date",
Args: []string{"--utc"},
Stdout: struct {
io.Writer
io.Closer
}{
Writer: &run.outbuf,
Closer: io.NopCloser(nil),
},
Stderr: struct {
io.Writer
io.Closer
}{
Writer: &run.errbuf,
Closer: io.NopCloser(nil),
},
Command: "date",
Args: []string{"--utc"},
Stdout: &run.outbuf,
Stderr: &run.errbuf,
ExitChan: run.exitChan,
}) {
return
Expand All @@ -300,7 +288,7 @@ func TestClientWithCommandModule(t *testing.T) {
default:
t.Error("expected exit code on channel")
}
if !strings.Contains(" UTC ", run.outbuf.String()) {
if !strings.Contains(run.outbuf.String(), " UTC ") {
t.Errorf("expected stdout to include UTC, got\n%s", run.outbuf.String())
}
if run.errbuf.Len() > 0 {
Expand Down

0 comments on commit 0c2b155

Please sign in to comment.