diff --git a/handlers.go b/handlers.go index dcde69a..efb8647 100644 --- a/handlers.go +++ b/handlers.go @@ -123,18 +123,6 @@ func checkVerificationCode(path string, verifier *verificationRegistry, timeout /**********************************************************************/ -func checkValidRegistrationPath(path string) error { - info, err := os.Lstat(path) // distinguish between directories and symlinks to one. - if err != nil { - return fmt.Errorf("failed to stat; %w", err) - } else if info.Mode() & fs.ModeSymlink != 0 { - return errors.New("path cannot be a symbolic link to a directory") - } else if !info.IsDir() { - return errors.New("path should be a directory") - } - return nil -} - func newRegisterStartHandler(verifier *verificationRegistry) func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { if r.Body == nil { @@ -155,7 +143,7 @@ func newRegisterStartHandler(verifier *verificationRegistry) func(http.ResponseW return } - err = checkValidRegistrationPath(regpath) + err = verifyDirectory(regpath) if err != nil { dumpErrorResponse(w, http.StatusBadRequest, err.Error()) return @@ -195,7 +183,7 @@ func newRegisterFinishHandler(db *sql.DB, verifier *verificationRegistry, tokeni return } - err = checkValidRegistrationPath(regpath) + err = verifyDirectory(regpath) if err != nil { dumpHttpErrorResponse(w, err) return diff --git a/handlers_test.go b/handlers_test.go index f508a81..e047c5e 100644 --- a/handlers_test.go +++ b/handlers_test.go @@ -237,7 +237,7 @@ func TestRegisterHandlers(t *testing.T) { } }) - t.Run("register start failure", func(t *testing.T) { + t.Run("register start symlink", func(t *testing.T) { handler := http.HandlerFunc(newRegisterStartHandler(verifier)) tmp, err := os.MkdirTemp("", "") @@ -254,8 +254,8 @@ func TestRegisterHandlers(t *testing.T) { req := createJsonRequest("POST", "/register/start", map[string]interface{}{ "path": to_add2 }, t) rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) - if rr.Code != http.StatusBadRequest { - t.Fatalf("registration of a symlink should have failed") + if rr.Code != http.StatusAccepted { + t.Fatalf("should have succeeded") } }) diff --git a/list.go b/list.go index f2a7c4f..426df12 100644 --- a/list.go +++ b/list.go @@ -7,26 +7,8 @@ import ( "path/filepath" "strings" "errors" - "net/http" ) -func verifyDirectory(dir string) error { - info, err := os.Stat(dir) - if errors.Is(err, os.ErrNotExist) { - return newHttpError(http.StatusNotFound, fmt.Errorf("path %q does not exist", dir)) - } - - if err != nil { - return fmt.Errorf("failed to check %q; %w", dir, err) - } - - if !info.IsDir() { - return newHttpError(http.StatusBadRequest, fmt.Errorf("%q is not a directory", dir)) - } - - return nil -} - func listFiles(dir string, recursive bool) ([]string, error) { err := verifyDirectory(dir) if err != nil { diff --git a/utils.go b/utils.go index 5c52b38..3dd387f 100644 --- a/utils.go +++ b/utils.go @@ -1,5 +1,12 @@ package main +import ( + "fmt" + "os" + "net/http" + "errors" +) + type httpError struct { Status int Reason error @@ -16,3 +23,20 @@ func (r *httpError) Unwrap() error { func newHttpError(status int, reason error) *httpError { return &httpError{ Status: status, Reason: reason } } + +func verifyDirectory(dir string) error { + info, err := os.Stat(dir) + if errors.Is(err, os.ErrNotExist) { + return newHttpError(http.StatusNotFound, fmt.Errorf("path %q does not exist", dir)) + } + + if err != nil { + return fmt.Errorf("failed to check %q; %w", dir, err) + } + + if !info.IsDir() { + return newHttpError(http.StatusBadRequest, fmt.Errorf("%q is not a directory", dir)) + } + + return nil +}