Skip to content

Commit

Permalink
Allow symlinked directories to be registered.
Browse files Browse the repository at this point in the history
We already allow symlinks in the parents of the registered directory, so we
might as well allow the symlinked directory itself to be registered.

Also centralized the logic for checking that a path is indeed a directory.
This allows us to get rid of another Lstat call.
  • Loading branch information
LTLA committed Oct 20, 2024
1 parent 31c608f commit 51dc679
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 35 deletions.
16 changes: 2 additions & 14 deletions handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,18 +123,6 @@ func checkVerificationCode(path string, verifier *verificationRegistry, timeout

/**********************************************************************/

func checkValidRegistrationPath(path string) error {
info, err := os.Lstat(path) // distinguish between directories and symlinks to one.
if err != nil {
return fmt.Errorf("failed to stat; %w", err)
} else if info.Mode() & fs.ModeSymlink != 0 {
return errors.New("path cannot be a symbolic link to a directory")
} else if !info.IsDir() {
return errors.New("path should be a directory")
}
return nil
}

func newRegisterStartHandler(verifier *verificationRegistry) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
if r.Body == nil {
Expand All @@ -155,7 +143,7 @@ func newRegisterStartHandler(verifier *verificationRegistry) func(http.ResponseW
return
}

err = checkValidRegistrationPath(regpath)
err = verifyDirectory(regpath)
if err != nil {
dumpErrorResponse(w, http.StatusBadRequest, err.Error())
return
Expand Down Expand Up @@ -195,7 +183,7 @@ func newRegisterFinishHandler(db *sql.DB, verifier *verificationRegistry, tokeni
return
}

err = checkValidRegistrationPath(regpath)
err = verifyDirectory(regpath)
if err != nil {
dumpHttpErrorResponse(w, err)
return
Expand Down
6 changes: 3 additions & 3 deletions handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ func TestRegisterHandlers(t *testing.T) {
}
})

t.Run("register start failure", func(t *testing.T) {
t.Run("register start symlink", func(t *testing.T) {
handler := http.HandlerFunc(newRegisterStartHandler(verifier))

tmp, err := os.MkdirTemp("", "")
Expand All @@ -254,8 +254,8 @@ func TestRegisterHandlers(t *testing.T) {
req := createJsonRequest("POST", "/register/start", map[string]interface{}{ "path": to_add2 }, t)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != http.StatusBadRequest {
t.Fatalf("registration of a symlink should have failed")
if rr.Code != http.StatusAccepted {
t.Fatalf("should have succeeded")
}
})

Expand Down
18 changes: 0 additions & 18 deletions list.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,26 +7,8 @@ import (
"path/filepath"
"strings"
"errors"
"net/http"
)

func verifyDirectory(dir string) error {
info, err := os.Stat(dir)
if errors.Is(err, os.ErrNotExist) {
return newHttpError(http.StatusNotFound, fmt.Errorf("path %q does not exist", dir))
}

if err != nil {
return fmt.Errorf("failed to check %q; %w", dir, err)
}

if !info.IsDir() {
return newHttpError(http.StatusBadRequest, fmt.Errorf("%q is not a directory", dir))
}

return nil
}

func listFiles(dir string, recursive bool) ([]string, error) {
err := verifyDirectory(dir)
if err != nil {
Expand Down
24 changes: 24 additions & 0 deletions utils.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
package main

import (
"fmt"
"os"
"net/http"
"errors"
)

type httpError struct {
Status int
Reason error
Expand All @@ -16,3 +23,20 @@ func (r *httpError) Unwrap() error {
func newHttpError(status int, reason error) *httpError {
return &httpError{ Status: status, Reason: reason }
}

func verifyDirectory(dir string) error {
info, err := os.Stat(dir)
if errors.Is(err, os.ErrNotExist) {
return newHttpError(http.StatusNotFound, fmt.Errorf("path %q does not exist", dir))
}

if err != nil {
return fmt.Errorf("failed to check %q; %w", dir, err)
}

if !info.IsDir() {
return newHttpError(http.StatusBadRequest, fmt.Errorf("%q is not a directory", dir))
}

return nil
}

0 comments on commit 51dc679

Please sign in to comment.