diff --git a/cmd/crproxy/cluster/auth/auth.go b/cmd/crproxy/cluster/auth/auth.go index 96985fa..c40c079 100644 --- a/cmd/crproxy/cluster/auth/auth.go +++ b/cmd/crproxy/cluster/auth/auth.go @@ -1,20 +1,25 @@ package auth import ( - "bytes" "context" - "encoding/json" + "crypto/rsa" + "database/sql" "fmt" "log/slog" "net/http" "net/url" "os" "sync/atomic" + "time" + + _ "github.com/go-sql-driver/mysql" "github.com/daocloud/crproxy/internal/pki" "github.com/daocloud/crproxy/internal/server" + "github.com/daocloud/crproxy/manager" "github.com/daocloud/crproxy/signing" "github.com/daocloud/crproxy/token" + "github.com/emicklei/go-restful/v3" "github.com/gorilla/handlers" "github.com/spf13/cobra" ) @@ -34,10 +39,11 @@ type flagpole struct { AllowAnonymous bool AnonymousRateLimitPerSecond uint64 + AnonymousNoAllowlist bool BlobsURLs []string - WebhookURL string + DBURL string } func NewCommand() *cobra.Command { @@ -66,40 +72,68 @@ func NewCommand() *cobra.Command { cmd.Flags().StringToStringVar(&flags.SimpleAuthUserpass, "simple-auth-userpass", flags.SimpleAuthUserpass, "Simple auth userpass") cmd.Flags().BoolVar(&flags.AllowAnonymous, "allow-anonymous", flags.AllowAnonymous, "Allow anonymous") + cmd.Flags().Uint64Var(&flags.AnonymousRateLimitPerSecond, "anonymous-rate-limit-per-second", flags.AnonymousRateLimitPerSecond, "Rate limit for anonymous users per second") cmd.Flags().StringSliceVar(&flags.BlobsURLs, "blobs-url", flags.BlobsURLs, "Blobs urls") - cmd.Flags().StringVar(&flags.WebhookURL, "webhook-url", flags.WebhookURL, "Webhook url") + cmd.Flags().StringVar(&flags.DBURL, "db-url", flags.DBURL, "Database URL") return cmd } func runE(ctx context.Context, flags *flagpole) error { - mux := http.NewServeMux() - logger := slog.New(slog.NewJSONHandler(os.Stderr, nil)) - privateKeyData, err := os.ReadFile(flags.TokenPrivateKeyFile) - if err != nil { - logger.Error("failed to ReadFile", "file", flags.TokenPrivateKeyFile, "error", err) - os.Exit(1) - } - privateKey, err := pki.DecodePrivateKey(privateKeyData) - if err != nil { - logger.Error("failed to DecodePrivateKey", "file", flags.TokenPrivateKeyFile, "error", err) - os.Exit(1) - } + var privateKey *rsa.PrivateKey + var err error + if flags.TokenPrivateKeyFile != "" { + privateKeyData, err := os.ReadFile(flags.TokenPrivateKeyFile) + if err != nil { + return fmt.Errorf("failed to read token private key file: %w", err) + } + privateKey, err = pki.DecodePrivateKey(privateKeyData) + if err != nil { + return fmt.Errorf("failed to decode private key: %w", err) + } + if flags.TokenPublicKeyFile != "" { + publicKeyData, err := pki.EncodePublicKey(&privateKey.PublicKey) + if err != nil { + return fmt.Errorf("failed to encode public key: %w", err) + } + + err = os.WriteFile(flags.TokenPublicKeyFile, publicKeyData, 0644) + if err != nil { + return fmt.Errorf("failed to write token public key file: %w", err) + } + } - if flags.TokenPublicKeyFile != "" { - publicKeyData, err := pki.EncodePublicKey(&privateKey.PublicKey) + } else { + privateKey, err = pki.GenerateKey() if err != nil { - return fmt.Errorf("failed to encode public key: %w", err) + return fmt.Errorf("failed to generate private key: %w", err) } + } + + container := restful.NewContainer() - err = os.WriteFile(flags.TokenPublicKeyFile, publicKeyData, 0644) + var mgr *manager.Manager + if flags.DBURL != "" { + dburl := flags.DBURL + db, err := sql.Open("mysql", dburl) if err != nil { - return fmt.Errorf("failed to write token public key file: %w", err) + return fmt.Errorf("failed to connect to database: %w", err) + } + defer db.Close() + + if err = db.Ping(); err != nil { + return fmt.Errorf("failed to ping database: %w", err) } + + mgr = manager.NewManager(privateKey, db, 1*time.Minute) + + mgr.Register(container) + + mgr.InitTable(ctx) } getHosts := getBlobsURLs(flags.BlobsURLs) @@ -116,55 +150,37 @@ func runE(ctx context.Context, flags *flagpole) error { } return t.Attribute, true } - if flags.SimpleAuthUserpass == nil { - return token.Attribute{}, false - } - pass, ok := flags.SimpleAuthUserpass[userinfo.Username()] - if !ok { - return token.Attribute{}, false - } - upass, ok := userinfo.Password() - if !ok { - return token.Attribute{}, false - } - if upass != pass { - return token.Attribute{}, false - } - if flags.WebhookURL != "" { - body, _ := json.Marshal(t) - wu, err := url.Parse(flags.WebhookURL) - if err != nil { - logger.Error("failed to parse webhook url", "url", flags.WebhookURL, "error", err) - return token.Attribute{}, false + var has bool + if flags.SimpleAuthUserpass != nil { + pass, ok := flags.SimpleAuthUserpass[userinfo.Username()] + if ok { + upass, ok := userinfo.Password() + if !ok { + return token.Attribute{}, false + } + if upass != pass { + return token.Attribute{}, false + } + t.NoRateLimit = true + t.NoAllowlist = true + t.NoBlock = true + t.AllowTagsList = true + has = true } - wu.User = userinfo + } - resp, err := http.Post(wu.String(), "application/json", bytes.NewBuffer(body)) - if err != nil { - logger.Error("failed to post webhook", "url", flags.WebhookURL, "error", err) + if !has { + if mgr == nil { return token.Attribute{}, false } - defer resp.Body.Close() - switch resp.StatusCode { - case http.StatusOK: - err = json.NewDecoder(resp.Body).Decode(&t) - if err != nil { - logger.Error("failed to decode webhook response", "url", flags.WebhookURL, "error", err) - return token.Attribute{}, false - } - case http.StatusForbidden: - return token.Attribute{}, false - default: - logger.Error("failed to post webhook", "url", flags.WebhookURL, "status", resp.StatusCode) + attr, err := mgr.GetToken(r.Context(), userinfo, t) + if err != nil { + logger.Info("Failed to retrieve token", "user", userinfo, "err", err) return token.Attribute{}, false } - } else { - t.NoRateLimit = true - t.NoAllowlist = true - t.NoBlock = true - t.AllowTagsList = true + t.Attribute = attr } if !t.Block { @@ -177,9 +193,9 @@ func runE(ctx context.Context, flags *flagpole) error { } gen := token.NewGenerator(token.NewEncoder(signing.NewSigner(privateKey)), authFunc, logger) - mux.Handle("/auth/token", gen) + container.Handle("/auth/token", gen) - var handler http.Handler = mux + var handler http.Handler = container handler = handlers.LoggingHandler(os.Stderr, handler) if flags.Behind { handler = handlers.ProxyHeaders(handler) diff --git a/go.mod b/go.mod index 2425d2c..60c5fb1 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,9 @@ require ( github.com/denverdino/aliyungo v0.0.0 github.com/distribution/reference v0.6.0 github.com/docker/distribution v2.8.2+incompatible + github.com/emicklei/go-restful-openapi/v2 v2.11.0 + github.com/emicklei/go-restful/v3 v3.12.1 + github.com/go-sql-driver/mysql v1.8.1 github.com/google/go-containerregistry v0.20.2 github.com/gorilla/handlers v1.5.2 github.com/huaweicloud/huaweicloud-sdk-go-obs v3.24.6+incompatible @@ -15,6 +18,7 @@ require ( github.com/wzshiming/geario v0.0.0-20240308093553-a996e3817533 github.com/wzshiming/hostmatcher v0.0.3 github.com/wzshiming/httpseek v0.1.0 + github.com/wzshiming/swaggerui v0.0.0-20241218081300-1c57a69746ef golang.org/x/crypto v0.28.0 ) @@ -25,6 +29,7 @@ replace ( require ( cloud.google.com/go/compute/metadata v0.3.0 // indirect + filippo.io/edwards25519 v1.1.0 // indirect github.com/Azure/azure-sdk-for-go v56.3.0+incompatible // indirect github.com/Azure/go-autorest v14.2.0+incompatible // indirect github.com/Azure/go-autorest/autorest v0.11.24 // indirect @@ -43,6 +48,10 @@ require ( github.com/docker/go-metrics v0.0.1 // indirect github.com/docker/libtrust v0.0.0-20150114040149-fa567046d9b1 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect + github.com/go-openapi/jsonpointer v0.19.5 // indirect + github.com/go-openapi/jsonreference v0.20.0 // indirect + github.com/go-openapi/spec v0.20.9 // indirect + github.com/go-openapi/swag v0.19.15 // indirect github.com/gofrs/uuid v4.0.0+incompatible // indirect github.com/golang-jwt/jwt/v4 v4.5.1 // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect @@ -54,7 +63,8 @@ require ( github.com/gorilla/mux v1.8.1 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/jmespath/go-jmespath v0.4.0 // indirect - github.com/kr/text v0.2.0 // indirect + github.com/josharian/intern v1.0.0 // indirect + github.com/mailru/easyjson v0.7.6 // indirect github.com/mitchellh/go-homedir v1.1.0 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/opencontainers/image-spec v1.1.0-rc3 // indirect @@ -80,4 +90,5 @@ require ( google.golang.org/genproto/googleapis/rpc v0.0.0-20240701130421-f6361c86f094 // indirect google.golang.org/grpc v1.65.0 // indirect google.golang.org/protobuf v1.34.2 // indirect + gopkg.in/yaml.v2 v2.4.0 // indirect ) diff --git a/go.sum b/go.sum index b156672..d443a49 100644 --- a/go.sum +++ b/go.sum @@ -2,6 +2,8 @@ cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMT cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc= cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= +filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= +filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= github.com/Azure/azure-sdk-for-go v56.3.0+incompatible h1:DmhwMrUIvpeoTDiWRDtNHqelNUd3Og8JCkrLHQK795c= github.com/Azure/azure-sdk-for-go v56.3.0+incompatible/go.mod h1:9XXNKU+eRnpl9moKnB4QOLf1HestfXbmab5FXxiDBjc= github.com/Azure/go-autorest v14.2.0+incompatible h1:V5VMDjClD3GiElqLWO7mz2MxNAK/vTfRHdAubSIPRgs= @@ -61,6 +63,11 @@ github.com/docker/go-metrics v0.0.1 h1:AgB/0SvBxihN0X8OR4SjsblXkbMvalQ8cjmtKQ2rQ github.com/docker/go-metrics v0.0.1/go.mod h1:cG1hvH2utMXtqgqqYE9plW6lDxS3/5ayHzueweSI3Vw= github.com/docker/libtrust v0.0.0-20150114040149-fa567046d9b1 h1:ZClxb8laGDf5arXfYcAtECDFgAgHklGI8CxgjHnXKJ4= github.com/docker/libtrust v0.0.0-20150114040149-fa567046d9b1/go.mod h1:cyGadeNEkKy96OOhEzfZl+yxihPEzKnqJwvfuSUqbZE= +github.com/emicklei/go-restful-openapi/v2 v2.11.0 h1:Ur+yGxoOH/7KRmcj/UoMFqC3VeNc9VOe+/XidumxTvk= +github.com/emicklei/go-restful-openapi/v2 v2.11.0/go.mod h1:4CTuOXHFg3jkvCpnXN+Wkw5prVUnP8hIACssJTYorWo= +github.com/emicklei/go-restful/v3 v3.11.0/go.mod h1:6n3XBCmQQb25CM2LCACGz8ukIrRry+4bhvbpWn3mrbc= +github.com/emicklei/go-restful/v3 v3.12.1 h1:PJMDIM/ak7btuL8Ex0iYET9hxM3CI2sjZtzpL63nKAU= +github.com/emicklei/go-restful/v3 v3.12.1/go.mod h1:6n3XBCmQQb25CM2LCACGz8ukIrRry+4bhvbpWn3mrbc= github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= @@ -73,6 +80,18 @@ github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeME github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE= github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk= +github.com/go-openapi/jsonpointer v0.19.3/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg= +github.com/go-openapi/jsonpointer v0.19.5 h1:gZr+CIYByUqjcgeLXnQu2gHYQC9o73G2XUeOFYEICuY= +github.com/go-openapi/jsonpointer v0.19.5/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg= +github.com/go-openapi/jsonreference v0.20.0 h1:MYlu0sBgChmCfJxxUKZ8g1cPWFOB37YSZqewK7OKeyA= +github.com/go-openapi/jsonreference v0.20.0/go.mod h1:Ag74Ico3lPc+zR+qjn4XBUmXymS4zJbYVCZmcgkasdo= +github.com/go-openapi/spec v0.20.9 h1:xnlYNQAwKd2VQRRfwTEI0DcK+2cbuvI/0c7jx3gA8/8= +github.com/go-openapi/spec v0.20.9/go.mod h1:2OpW+JddWPrpXSCIX8eOx7lZ5iyuWj3RYR6VaaBKcWA= +github.com/go-openapi/swag v0.19.5/go.mod h1:POnQmlKehdgb5mhVOsnJFsivZCEZ/vjK9gh66Z9tfKk= +github.com/go-openapi/swag v0.19.15 h1:D2NRCBzS9/pEY3gP9Nl8aDqGUcPFrwG2p+CNFrLyrCM= +github.com/go-openapi/swag v0.19.15/go.mod h1:QYRuS/SOXUCsnplDa677K7+DxSOj6IPNl/eQntq43wQ= +github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y= +github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/gofrs/uuid v4.0.0+incompatible h1:1SD/1F5pU8p29ybwgQSwpQk+mwdRrXCYuPhW6m+TnJw= github.com/gofrs/uuid v4.0.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= @@ -136,16 +155,26 @@ github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9Y github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8= github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U= +github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= +github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= +github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/magiconair/properties v1.8.6/go.mod h1:y3VJvCyxH9uVvJTWEGAELF3aiYNyPKd5NZ3oSwXrF60= +github.com/mailru/easyjson v0.0.0-20190614124828-94de47d64c63/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= +github.com/mailru/easyjson v0.0.0-20190626092158-b2ccc519800e/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= +github.com/mailru/easyjson v0.7.6 h1:8yTIVnZgCoiM1TgqoeTl+LfU5Jg6/xL3QhGQnimLYnA= +github.com/mailru/easyjson v0.7.6/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG+4E0Y= github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= @@ -153,10 +182,12 @@ github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= +github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/modocache/gover v0.0.0-20171022184752-b58185e213c5/go.mod h1:caMODM3PzxT8aQXRPkAt8xlV/e7d7w8GM5g0fa5F0D8= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= +github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= github.com/opencontainers/image-spec v1.1.0-rc3 h1:fzg1mXZFj8YdPeNkRXMg+zb88BFV0Ys52cJydRwBkb8= @@ -206,6 +237,7 @@ github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpE github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= @@ -226,6 +258,8 @@ github.com/wzshiming/hostmatcher v0.0.3 h1:+JYAq6vUZXDEQ1Ipfdc/D7HmaIMngcc71fton github.com/wzshiming/hostmatcher v0.0.3/go.mod h1:F04RIvIWEvOIrIKOlQlMuR8vQMKAVf2YhpU6l31Wwz4= github.com/wzshiming/httpseek v0.1.0 h1:lEgL7EBELT/VV9UaTp+m3kw5Pe1KOUdY+IPnKkag6tI= github.com/wzshiming/httpseek v0.1.0/go.mod h1:YoZhlLIwNjTBDXIT8NpK5zRjOgZouRXPaBfjVXdqMMs= +github.com/wzshiming/swaggerui v0.0.0-20241218081300-1c57a69746ef h1:pco7HMr+x2elW3GSIkaQLa+pIAZVDMjdDYjSfW6Ydj4= +github.com/wzshiming/swaggerui v0.0.0-20241218081300-1c57a69746ef/go.mod h1:vHxbopIWgVhM4ItpXF43QkaGhmsXn7Q8Z0FjzHwLfcI= github.com/wzshiming/trie v0.3.1 h1:YpuoqmEQFJiW0mns/mM6Qk4kdWrXc8kc28/KR1vn0m8= github.com/wzshiming/trie v0.3.1/go.mod h1:c9thxXTh4KcGkejt4sUsO4c5GUmWpxeWzOJ7AZJaI+8= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= @@ -353,6 +387,8 @@ google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6h google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= @@ -362,6 +398,7 @@ gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gotest.tools/v3 v3.0.3 h1:4AuOwCGf4lLR9u3YOe2awrHygurzhO/HeQ6laiA6Sx0= diff --git a/internal/slices/slices.go b/internal/slices/slices.go new file mode 100644 index 0000000..3ee3474 --- /dev/null +++ b/internal/slices/slices.go @@ -0,0 +1,10 @@ +package slices + +// Map returns a new slice containing the results of applying the given function +func Map[S ~[]T, T any, O any](s S, f func(T) O) []O { + out := make([]O, len(s)) + for i := range s { + out[i] = f(s[i]) + } + return out +} diff --git a/manager/controller/error.go b/manager/controller/error.go new file mode 100644 index 0000000..daf5469 --- /dev/null +++ b/manager/controller/error.go @@ -0,0 +1,6 @@ +package controller + +type Error struct { + Code string `json:"code"` + Message string `json:"message"` +} diff --git a/manager/controller/jwt.go b/manager/controller/jwt.go new file mode 100644 index 0000000..840baf0 --- /dev/null +++ b/manager/controller/jwt.go @@ -0,0 +1,59 @@ +package controller + +import ( + "crypto/rsa" + "encoding/json" + "errors" + "fmt" + "net/http" + "strings" + + "github.com/daocloud/crproxy/signing" + "github.com/emicklei/go-restful/v3" +) + +type Session struct { + UserID int64 `json:"user_id"` +} + +func validJWT(key *rsa.PrivateKey, authHeader string) (Session, error) { + if !strings.HasPrefix(authHeader, "Bearer ") { + return Session{}, errors.New("invalid token") + } + + jwtToken := strings.Split(authHeader, " ") + if len(jwtToken) != 2 { + return Session{}, errors.New("invalid token format") + } + + data, err := signing.NewVerifier(&key.PublicKey).Verify(jwtToken[1]) + if err != nil { + return Session{}, fmt.Errorf("failed to decode signature: %w", err) + } + + var session Session + err = json.Unmarshal(data, &session) + if err != nil { + return Session{}, fmt.Errorf("failed to unmarshal session: %w", err) + } + return session, nil +} + +func generateJWT(key *rsa.PrivateKey, session Session) (string, error) { + data, err := json.Marshal(session) + if err != nil { + return "", err + } + + return signing.NewSigner(key).Sign(data) +} + +func unauthorizedResponse(resp *restful.Response) { + resp.AddHeader("WWW-Authenticate", `Bearer realm="/users/login"`) + resp.WriteHeader(http.StatusUnauthorized) +} + +func getSession(key *rsa.PrivateKey, req *restful.Request) (Session, error) { + authHeader := req.HeaderParameter("Authorization") + return validJWT(key, authHeader) +} diff --git a/manager/controller/token.go b/manager/controller/token.go new file mode 100644 index 0000000..d04fb74 --- /dev/null +++ b/manager/controller/token.go @@ -0,0 +1,175 @@ +package controller + +import ( + "crypto/rsa" + "net/http" + "strconv" + + "github.com/daocloud/crproxy/internal/slices" + "github.com/daocloud/crproxy/manager/model" + "github.com/daocloud/crproxy/manager/service" + "github.com/emicklei/go-restful/v3" +) + +type TokenRequest struct { + Account string `json:"account"` + Password string `json:"password"` + Data string `json:"data"` +} + +type TokenResponse struct { + TokenID int64 `json:"token_id"` +} + +type TokenDetailResponse struct { + TokenID int64 `json:"token_id"` + Account string `json:"account"` + Data string `json:"data"` +} + +type TokenController struct { + key *rsa.PrivateKey + tokenService *service.TokenService +} + +func NewTokenController(key *rsa.PrivateKey, tokenService *service.TokenService) *TokenController { + return &TokenController{key: key, tokenService: tokenService} +} + +func (tc *TokenController) RegisterRoutes(ws *restful.WebService) { + ws.Route(ws.POST("/tokens").To(tc.Create). + Doc("Create a new token for a user."). + Operation("createToken"). + Produces(restful.MIME_JSON). + Consumes(restful.MIME_JSON). + Param(ws.HeaderParameter("Authorization", "Bearer ")). + Reads(TokenRequest{}). + Writes(TokenResponse{}). + Returns(http.StatusCreated, "Token created successfully.", TokenResponse{}). + Returns(http.StatusBadRequest, "Invalid request format. Ensure that the token data is provided and is valid.", Error{})) + + ws.Route(ws.GET("/tokens").To(tc.List). + Doc("Retrieve all tokens by user."). + Operation("listToken"). + Produces(restful.MIME_JSON). + Param(ws.HeaderParameter("Authorization", "Bearer ")). + Writes([]TokenDetailResponse{}). + Returns(http.StatusOK, "Tokens found.", []TokenDetailResponse{}). + Returns(http.StatusNotFound, "No tokens found for the user.", Error{})) + + ws.Route(ws.GET("/tokens/{id}").To(tc.Get). + Doc("Retrieve a token by its ID."). + Operation("getToken"). + Param(ws.HeaderParameter("Authorization", "Bearer ")). + Produces(restful.MIME_JSON). + Param(ws.PathParameter("id", "Token ID").DataType("int64")). + Writes(TokenDetailResponse{}). + Returns(http.StatusOK, "Token found.", TokenDetailResponse{}). + Returns(http.StatusNotFound, "Token not found.", Error{})) + + ws.Route(ws.DELETE("/tokens/{id}").To(tc.Delete). + Doc("Delete a token by its ID."). + Operation("Token"). + Param(ws.HeaderParameter("Authorization", "Bearer ")). + Produces(restful.MIME_JSON). + Param(ws.PathParameter("id", "Token ID").DataType("int64")). + Returns(http.StatusNoContent, "Token deleted successfully.", nil). + Returns(http.StatusNotFound, "Token not found.", Error{})) +} + +func (tc *TokenController) Create(req *restful.Request, resp *restful.Response) { + var tokenRequest TokenRequest + if err := req.ReadEntity(&tokenRequest); err != nil { + resp.WriteHeaderAndEntity(http.StatusBadRequest, Error{Code: "TokenRequestError", Message: "Failed to read token request: " + err.Error()}) + return + } + + session, err := getSession(tc.key, req) + if err != nil { + unauthorizedResponse(resp) + return + } + + tokenID, err := tc.tokenService.Create(req.Request.Context(), model.Token{ + UserID: session.UserID, + Account: tokenRequest.Account, + Password: tokenRequest.Password, + Data: tokenRequest.Data, + }) + if err != nil { + resp.WriteHeaderAndEntity(http.StatusInternalServerError, Error{Code: "TokenCreationError", Message: "Failed to create token: " + err.Error()}) + return + } + + resp.WriteHeaderAndEntity(http.StatusCreated, TokenResponse{TokenID: tokenID}) +} + +func (tc *TokenController) List(req *restful.Request, resp *restful.Response) { + session, err := getSession(tc.key, req) + if err != nil { + unauthorizedResponse(resp) + return + } + + tokens, err := tc.tokenService.GetByUserID(req.Request.Context(), session.UserID) + if err != nil { + resp.WriteHeaderAndEntity(http.StatusNotFound, Error{Code: "TokensNotFoundError", Message: "No tokens found for the user: " + err.Error()}) + return + } + + resp.WriteEntity(slices.Map(tokens, func(token model.Token) TokenDetailResponse { + return TokenDetailResponse{ + TokenID: token.TokenID, + Account: token.Account, + Data: token.Data, + } + })) +} + +func (tc *TokenController) Get(req *restful.Request, resp *restful.Response) { + session, err := getSession(tc.key, req) + if err != nil { + unauthorizedResponse(resp) + return + } + + tokenIDStr := req.PathParameter("id") + tokenID, err := strconv.ParseInt(tokenIDStr, 10, 64) + if err != nil { + resp.WriteHeaderAndEntity(http.StatusBadRequest, Error{Code: "InvalidTokenIDError", Message: "Invalid token ID: " + err.Error()}) + return + } + + token, err := tc.tokenService.Get(req.Request.Context(), tokenID, session.UserID) + if err != nil { + resp.WriteHeaderAndEntity(http.StatusNotFound, Error{Code: "TokenNotFoundError", Message: "Token not found: " + err.Error()}) + return + } + + resp.WriteHeaderAndEntity(http.StatusOK, TokenDetailResponse{ + TokenID: token.TokenID, + Account: token.Account, + Data: token.Data, + }) +} + +func (tc *TokenController) Delete(req *restful.Request, resp *restful.Response) { + session, err := getSession(tc.key, req) + if err != nil { + unauthorizedResponse(resp) + return + } + + tokenIDStr := req.PathParameter("id") + tokenID, err := strconv.ParseInt(tokenIDStr, 10, 64) + if err != nil { + resp.WriteHeaderAndEntity(http.StatusBadRequest, Error{Code: "InvalidTokenIDError", Message: "Invalid token ID: " + err.Error()}) + return + } + + if err := tc.tokenService.Delete(req.Request.Context(), tokenID, session.UserID); err != nil { + resp.WriteHeaderAndEntity(http.StatusNotFound, Error{Code: "TokenNotFoundError", Message: "Token not found: " + err.Error()}) + return + } + resp.WriteHeader(http.StatusNoContent) +} diff --git a/manager/controller/user.go b/manager/controller/user.go new file mode 100644 index 0000000..1d4e35d --- /dev/null +++ b/manager/controller/user.go @@ -0,0 +1,173 @@ +package controller + +import ( + "crypto/rsa" + "net/http" + + "github.com/daocloud/crproxy/manager/service" + "github.com/emicklei/go-restful/v3" +) + +type UserRequest struct { + Nickname string `json:"nickname,omitempty"` + Account string `json:"account"` + Password string `json:"password"` +} + +type UserDetailResponse struct { + UserID int64 `json:"user_id"` + Nickname string `json:"nickname"` +} + +type UserLoginRequest struct { + Account string `json:"account"` + Password string `json:"password"` +} + +type UserLoginResponse struct { + Token string `json:"token"` +} + +type UpdateNicknameRequest struct { + Nickname string `json:"nickname"` +} + +type UserController struct { + key *rsa.PrivateKey + userService *service.UserService +} + +func NewUserController(key *rsa.PrivateKey, userService *service.UserService) *UserController { + return &UserController{key: key, userService: userService} +} + +func (uc *UserController) RegisterRoutes(ws *restful.WebService) { + ws.Route(ws.POST("/users").To(uc.Create). + Doc("Create a new user with account and password."). + Operation("createUser"). + Produces(restful.MIME_JSON). + Consumes(restful.MIME_JSON). + Reads(UserRequest{}). + Writes(UserDetailResponse{}). + Returns(http.StatusCreated, "User created successfully. Returns the created user's ID and nickname.", UserDetailResponse{}). + Returns(http.StatusBadRequest, "Invalid request format. Ensure that the nickname, account, and password are provided and are valid.", Error{})) + + ws.Route(ws.POST("/users/login").To(uc.GetUserLogin). + Doc("Retrieve a token by login account."). + Operation("userLogin"). + Produces(restful.MIME_JSON). + Consumes(restful.MIME_JSON). + Reads(UserLoginRequest{}). + Writes(UserLoginResponse{}). + Returns(http.StatusOK, "Token retrieved successfully.", UserLoginResponse{}). + Returns(http.StatusUnauthorized, "Invalid account or password.", Error{}). + Returns(http.StatusBadRequest, "Invalid request format. Ensure that the account and password are provided.", Error{})) + + ws.Route(ws.GET("/users").To(uc.Get). + Doc("Retrieve a user."). + Operation("getUser"). + Produces(restful.MIME_JSON). + Consumes(restful.MIME_JSON). + Param(ws.HeaderParameter("Authorization", "Bearer ")). + Writes(UserDetailResponse{}). + Returns(http.StatusOK, "User found. Returns the user's ID and nickname.", UserDetailResponse{}). + Returns(http.StatusUnauthorized, "Unauthorized access. Please provide a valid token.", Error{}). + Returns(http.StatusNotFound, "User with the specified ID does not exist. Please check the ID and try again.", Error{})) + + ws.Route(ws.PUT("/users/nickname").To(uc.UpdateNickname). + Doc("Update the nickname of an existing user identified by their unique ID."). + Operation("updateNickname"). + Produces(restful.MIME_JSON). + Consumes(restful.MIME_JSON). + Param(ws.HeaderParameter("Authorization", "Bearer ")). + Reads(UpdateNicknameRequest{}). + Writes(UserDetailResponse{}). + Returns(http.StatusOK, "Nickname updated successfully. Returns the updated user's ID and new nickname.", UserDetailResponse{}). + Returns(http.StatusUnauthorized, "Unauthorized access. Please provide a valid token.", Error{}). + Returns(http.StatusNotFound, "User with the specified ID does not exist. Please check the ID and try again.", Error{}). + Returns(http.StatusBadRequest, "Invalid request format. Ensure that the new nickname is provided and is valid.", Error{})) +} + +func (uc *UserController) Create(req *restful.Request, resp *restful.Response) { + var userRequest UserRequest + if err := req.ReadEntity(&userRequest); err != nil { + resp.WriteHeaderAndEntity(http.StatusBadRequest, Error{Code: "UserRequestError", Message: "Failed to read user request: " + err.Error()}) + return + } + + userID, err := uc.userService.Create(req.Request.Context(), userRequest.Nickname, userRequest.Account, userRequest.Password) + if err != nil { + resp.WriteHeaderAndEntity(http.StatusInternalServerError, Error{Code: "UserCreationError", Message: "Failed to create user: " + err.Error()}) + return + } + + resp.WriteHeaderAndEntity(http.StatusCreated, UserDetailResponse{UserID: userID, Nickname: userRequest.Nickname}) +} + +func (uc *UserController) GetUserLogin(req *restful.Request, resp *restful.Response) { + var userRequest UserRequest + if err := req.ReadEntity(&userRequest); err != nil { + resp.WriteHeaderAndEntity(http.StatusBadRequest, Error{Code: "UserRequestError", Message: "Failed to read login request: " + err.Error()}) + return + } + + login, err := uc.userService.GetLoginByAccount(req.Request.Context(), userRequest.Account) + if err != nil { + resp.WriteHeaderAndEntity(http.StatusForbidden, Error{Code: "LoginNotFoundError", Message: "Login not found for the specified account: " + err.Error()}) + return + } + + if login.Password != userRequest.Password { + resp.WriteHeaderAndEntity(http.StatusForbidden, Error{Code: "InvalidCredentialsError", Message: "Invalid account or password"}) + return + } + + token, err := generateJWT(uc.key, Session{ + UserID: login.UserID, + }) + if err != nil { + resp.WriteHeaderAndEntity(http.StatusInternalServerError, Error{Code: "TokenGenerationError", Message: "Failed to generate token: " + err.Error()}) + return + } + + resp.WriteHeaderAndEntity(http.StatusOK, UserLoginResponse{ + Token: token, + }) +} + +func (uc *UserController) Get(req *restful.Request, resp *restful.Response) { + session, err := getSession(uc.key, req) + if err != nil { + unauthorizedResponse(resp) + return + } + + user, err := uc.userService.GetByID(req.Request.Context(), session.UserID) + if err != nil { + resp.WriteHeaderAndEntity(http.StatusNotFound, Error{Code: "UserNotFoundError", Message: "User with the specified ID does not exist: " + err.Error()}) + return + } + + resp.WriteHeaderAndEntity(http.StatusOK, UserDetailResponse{UserID: user.UserID, Nickname: user.Nickname}) +} + +func (uc *UserController) UpdateNickname(req *restful.Request, resp *restful.Response) { + session, err := getSession(uc.key, req) + if err != nil { + unauthorizedResponse(resp) + return + } + + var updateRequest UpdateNicknameRequest + if err := req.ReadEntity(&updateRequest); err != nil { + resp.WriteHeaderAndEntity(http.StatusBadRequest, Error{Code: "NicknameUpdateError", Message: "Failed to read nickname update request: " + err.Error()}) + return + } + + if err := uc.userService.UpdateNickname(req.Request.Context(), session.UserID, updateRequest.Nickname); err != nil { + resp.WriteHeaderAndEntity(http.StatusInternalServerError, Error{Code: "NicknameUpdateError", Message: "Failed to update nickname: " + err.Error()}) + return + } + + resp.WriteHeaderAndEntity(http.StatusOK, UserDetailResponse{UserID: session.UserID, Nickname: updateRequest.Nickname}) +} diff --git a/manager/dao/db.go b/manager/dao/db.go new file mode 100644 index 0000000..4386622 --- /dev/null +++ b/manager/dao/db.go @@ -0,0 +1,37 @@ +package dao + +import ( + "context" + "database/sql" +) + +type dbCtxKey struct{} + +// contextKey is a key type for storing database context values. +var contextKey = dbCtxKey{} + +// WithDB returns a new context with the given database connection. +func WithDB(ctx context.Context, db DB) context.Context { + return context.WithValue(ctx, contextKey, db) +} + +// GetDB retrieves the database connection from the context. +func GetDB(ctx context.Context) DB { + db := ctx.Value(contextKey) + if db == nil { + return nil + } + d, _ := db.(DB) + return d +} + +var ( + _ DB = (*sql.Tx)(nil) + _ DB = (*sql.DB)(nil) +) + +type DB interface { + ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) + QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) + QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row +} diff --git a/manager/dao/login.go b/manager/dao/login.go new file mode 100644 index 0000000..5e1c630 --- /dev/null +++ b/manager/dao/login.go @@ -0,0 +1,97 @@ +package dao + +import ( + "context" + "database/sql" + "fmt" + + "github.com/daocloud/crproxy/manager/model" +) + +type Login struct{} + +func NewLogin() *Login { + return &Login{} +} + +const loginTableSQL = ` +CREATE TABLE IF NOT EXISTS logins ( + id SERIAL PRIMARY KEY, + user_id BIGINT NOT NULL, + type VARCHAR(50) NOT NULL, + account VARCHAR(255) NOT NULL, + password VARCHAR(255) NOT NULL, + create_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + delete_at TIMESTAMP +); +` + +func (l *Login) InitTable(ctx context.Context) error { + db := GetDB(ctx) + _, err := db.ExecContext(ctx, loginTableSQL) + if err != nil { + return fmt.Errorf("failed to create logins table: %w", err) + } + return nil +} + +const createLoginSQL = ` +INSERT INTO logins (user_id, type, account, password) VALUES (?, ?, ?, ?) +` + +func (l *Login) Create(ctx context.Context, login model.Login) (int64, error) { + db := GetDB(ctx) + result, err := db.ExecContext(ctx, createLoginSQL, login.UserID, login.Type, login.Account, login.Password) + if err != nil { + return 0, fmt.Errorf("failed to create login: %w", err) + } + + return result.LastInsertId() +} + +const getLoginByIDSQL = ` +SELECT id, user_id, type, account, password FROM logins WHERE id = ? AND delete_at IS NULL +` + +func (l *Login) GetByID(ctx context.Context, id int64) (model.Login, error) { + db := GetDB(ctx) + var login model.Login + err := db.QueryRowContext(ctx, getLoginByIDSQL, id).Scan(&login.LoginID, &login.UserID, &login.Type, &login.Account, &login.Password) + if err != nil { + if err == sql.ErrNoRows { + return model.Login{}, fmt.Errorf("login not found: %w", err) + } + return model.Login{}, fmt.Errorf("failed to get login: %w", err) + } + return login, nil +} + +const getLoginByAccountSQL = ` +SELECT id, user_id, type, account, password FROM logins WHERE account = ? AND delete_at IS NULL +` + +func (l *Login) GetByAccount(ctx context.Context, account string) (model.Login, error) { + db := GetDB(ctx) + var login model.Login + err := db.QueryRowContext(ctx, getLoginByAccountSQL, account).Scan(&login.LoginID, &login.UserID, &login.Type, &login.Account, &login.Password) + if err != nil { + if err == sql.ErrNoRows { + return model.Login{}, fmt.Errorf("login not found for account %s: %w", account, err) + } + return model.Login{}, fmt.Errorf("failed to get login by account: %w", err) + } + return login, nil +} + +const deleteLoginByID = ` +UPDATE logins SET delete_at = NOW() WHERE id = ? +` + +func (l *Login) DeleteByID(ctx context.Context, id int64) error { + db := GetDB(ctx) + _, err := db.ExecContext(ctx, deleteLoginByID, id) + if err != nil { + return fmt.Errorf("failed to delete login: %w", err) + } + return nil +} diff --git a/manager/dao/token.go b/manager/dao/token.go new file mode 100644 index 0000000..abbcebe --- /dev/null +++ b/manager/dao/token.go @@ -0,0 +1,125 @@ +package dao + +import ( + "context" + "database/sql" + "fmt" + + "github.com/daocloud/crproxy/manager/model" +) + +type Token struct{} + +func NewToken() *Token { + return &Token{} +} + +const tokenTableSQL = ` +CREATE TABLE IF NOT EXISTS tokens ( + id SERIAL PRIMARY KEY, + user_id BIGINT NOT NULL, + account VARCHAR(255) NOT NULL, + password VARCHAR(255) NOT NULL, + data JSON NOT NULL, + create_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + delete_at TIMESTAMP +); +` + +func (t *Token) InitTable(ctx context.Context) error { + db := GetDB(ctx) + _, err := db.ExecContext(ctx, tokenTableSQL) + if err != nil { + return fmt.Errorf("failed to create tokens table: %w", err) + } + return nil +} + +const createTokenSQL = ` +INSERT INTO tokens (user_id, account, password, data) VALUES (?, ?, ?, ?) +` + +func (t *Token) Create(ctx context.Context, token model.Token) (int64, error) { + db := GetDB(ctx) + result, err := db.ExecContext(ctx, createTokenSQL, token.UserID, token.Account, token.Password, token.Data) + if err != nil { + return 0, fmt.Errorf("failed to create token: %w", err) + } + + return result.LastInsertId() +} + +const getTokenByIDSQL = ` +SELECT id, user_id, account, data FROM tokens WHERE id = ? AND user_id = ? AND delete_at IS NULL +` + +func (t *Token) GetByID(ctx context.Context, tokenID, userID int64) (model.Token, error) { + db := GetDB(ctx) + var token model.Token + err := db.QueryRowContext(ctx, getTokenByIDSQL, tokenID, userID).Scan(&token.TokenID, &token.UserID, &token.Account, &token.Data) + if err != nil { + if err == sql.ErrNoRows { + return model.Token{}, fmt.Errorf("token not found: %w", err) + } + return model.Token{}, fmt.Errorf("failed to get token: %w", err) + } + return token, nil +} + +const getTokensByUserIDSQL = ` +SELECT id, user_id, account, data FROM tokens WHERE user_id = ? AND delete_at IS NULL +` + +func (t *Token) GetByUserID(ctx context.Context, userID int64) ([]model.Token, error) { + db := GetDB(ctx) + rows, err := db.QueryContext(ctx, getTokensByUserIDSQL, userID) + if err != nil { + return nil, fmt.Errorf("failed to get tokens by user ID: %w", err) + } + defer rows.Close() + + var tokens []model.Token + for rows.Next() { + var token model.Token + if err := rows.Scan(&token.TokenID, &token.UserID, &token.Account, &token.Data); err != nil { + return nil, fmt.Errorf("failed to scan token: %w", err) + } + tokens = append(tokens, token) + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("error occurred during rows iteration: %w", err) + } + + return tokens, nil +} + +const getTokenSQL = ` +SELECT id, user_id, account, data FROM tokens WHERE account = ? AND password = ? AND delete_at IS NULL +` + +func (t *Token) GetByAccount(ctx context.Context, account, password string) (model.Token, error) { + db := GetDB(ctx) + var token model.Token + err := db.QueryRowContext(ctx, getTokenSQL, account, password).Scan(&token.TokenID, &token.UserID, &token.Account, &token.Data) + if err != nil { + if err == sql.ErrNoRows { + return model.Token{}, fmt.Errorf("token not found: %w", err) + } + return model.Token{}, fmt.Errorf("failed to get token: %w", err) + } + return token, nil +} + +const deleteTokenByIDSQL = ` +UPDATE tokens SET delete_at = NOW(), password = NULL WHERE id = ? AND user_id = ? AND delete_at IS NULL +` + +func (t *Token) DeleteByID(ctx context.Context, tokenID, userID int64) error { + db := GetDB(ctx) + _, err := db.ExecContext(ctx, deleteTokenByIDSQL, tokenID, userID) + if err != nil { + return fmt.Errorf("failed to delete token: %w", err) + } + return nil +} diff --git a/manager/dao/user.go b/manager/dao/user.go new file mode 100644 index 0000000..77d0b86 --- /dev/null +++ b/manager/dao/user.go @@ -0,0 +1,79 @@ +package dao + +import ( + "context" + "database/sql" + "fmt" + + "github.com/daocloud/crproxy/manager/model" +) + +type User struct{} + +func NewUser() *User { + return &User{} +} + +const userTableSQL = ` +CREATE TABLE IF NOT EXISTS users ( + id SERIAL PRIMARY KEY, + nickname VARCHAR(255) NOT NULL, + create_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + update_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + delete_at TIMESTAMP +); +` + +func (c *User) InitTable(ctx context.Context) error { + db := GetDB(ctx) + _, err := db.ExecContext(ctx, userTableSQL) + if err != nil { + return fmt.Errorf("failed to create users table: %w", err) + } + return nil +} + +const createSQL = ` +INSERT INTO users (nickname) VALUES (?) +` + +func (c *User) Create(ctx context.Context, u model.User) (int64, error) { + db := GetDB(ctx) + result, err := db.ExecContext(ctx, createSQL, u.Nickname) + if err != nil { + return 0, fmt.Errorf("failed to create user: %w", err) + } + + return result.LastInsertId() +} + +const getUserSQL = ` +SELECT id, nickname FROM users WHERE id = ? +` + +func (c *User) GetByID(ctx context.Context, id int64) (model.User, error) { + db := GetDB(ctx) + var u model.User + + err := db.QueryRowContext(ctx, getUserSQL, id).Scan(&u.UserID, &u.Nickname) + if err != nil { + if err == sql.ErrNoRows { + return model.User{}, fmt.Errorf("user not found: %w", err) + } + return model.User{}, fmt.Errorf("failed to get user: %w", err) + } + return u, nil +} + +const updateNicknameSQL = ` +UPDATE users SET nickname = ? WHERE id = ? +` + +func (c *User) UpdateNickname(ctx context.Context, id int64, nickname string) error { + db := GetDB(ctx) + _, err := db.ExecContext(ctx, updateNicknameSQL, nickname, id) + if err != nil { + return fmt.Errorf("failed to update nickname: %w", err) + } + return nil +} diff --git a/manager/manager.go b/manager/manager.go new file mode 100644 index 0000000..02dacea --- /dev/null +++ b/manager/manager.go @@ -0,0 +1,128 @@ +package manager + +import ( + "context" + "crypto/rsa" + "database/sql" + "encoding/json" + "net/http" + "net/url" + "sync" + "time" + + "github.com/daocloud/crproxy/manager/controller" + "github.com/daocloud/crproxy/manager/dao" + "github.com/daocloud/crproxy/manager/service" + "github.com/daocloud/crproxy/token" + restfulspec "github.com/emicklei/go-restful-openapi/v2" + "github.com/emicklei/go-restful/v3" + "github.com/wzshiming/swaggerui" +) + +type Manager struct { + key *rsa.PrivateKey + db *sql.DB + + UserDAO *dao.User + LoginDAO *dao.Login + TokenDAO *dao.Token + + UserService *service.UserService + UserController *controller.UserController + TokenService *service.TokenService + TokenController *controller.TokenController + + tokenCache map[string]tokenTTL + cacheMutex sync.RWMutex + cacheTTL time.Duration +} + +func NewManager(key *rsa.PrivateKey, db *sql.DB, cacheTTL time.Duration) *Manager { + m := &Manager{ + key: key, + db: db, + tokenCache: map[string]tokenTTL{}, + cacheTTL: cacheTTL, + } + return m +} + +func (m *Manager) InitTable(ctx context.Context) { + ctx = dao.WithDB(ctx, m.db) + m.UserDAO.InitTable(ctx) + m.LoginDAO.InitTable(ctx) + m.TokenDAO.InitTable(ctx) +} + +func (m *Manager) Register(container *restful.Container) { + m.UserDAO = dao.NewUser() + m.LoginDAO = dao.NewLogin() + m.TokenDAO = dao.NewToken() + + m.UserService = service.NewUserService(m.db, m.UserDAO, m.LoginDAO) + m.UserController = controller.NewUserController(m.key, m.UserService) + m.TokenService = service.NewTokenService(m.db, m.TokenDAO) + m.TokenController = controller.NewTokenController(m.key, m.TokenService) + + ws := new(restful.WebService) + m.UserController.RegisterRoutes(ws) + m.TokenController.RegisterRoutes(ws) + + container.Add(ws) + + config := restfulspec.Config{ + WebServices: []*restful.WebService{ws}, + APIPath: "/swagger.json", + } + + container.Add(restfulspec.NewOpenAPIService(config)) + + container.Handle("/swaggerui/", http.FileServerFS(swaggerui.FS)) +} + +func (m *Manager) GetToken(ctx context.Context, userinfo *url.Userinfo, t *token.Token) (token.Attribute, error) { + pwd, _ := userinfo.Password() + username := userinfo.Username() + + // Check cache first + m.cacheMutex.RLock() + cached, found := m.tokenCache[username] + m.cacheMutex.RUnlock() + + if found && time.Since(cached.last) < m.cacheTTL { + return cached.attr, cached.err + } + + m.cacheMutex.Lock() + defer m.cacheMutex.Unlock() + cached, found = m.tokenCache[username] + if found && time.Since(cached.last) < m.cacheTTL { + return cached.attr, cached.err + } + + tt, err := m.TokenService.GetByAccount(ctx, username, pwd) + if err != nil { + m.tokenCache[username] = tokenTTL{err: err, last: time.Now()} + return token.Attribute{}, err + } + + var attr token.Attribute + err = json.Unmarshal([]byte(tt.Data), &attr) + if err != nil { + m.tokenCache[username] = tokenTTL{err: err, last: time.Now()} + return token.Attribute{}, err + } + + attr.UserID = tt.UserID + attr.TokenID = tt.TokenID + + m.tokenCache[username] = tokenTTL{attr: attr, last: time.Now()} + + return attr, nil +} + +type tokenTTL struct { + err error + attr token.Attribute + last time.Time +} diff --git a/manager/model/login.go b/manager/model/login.go new file mode 100644 index 0000000..d08688e --- /dev/null +++ b/manager/model/login.go @@ -0,0 +1,10 @@ +package model + +type Login struct { + LoginID int64 + UserID int64 + + Type string + Account string + Password string +} diff --git a/manager/model/token.go b/manager/model/token.go new file mode 100644 index 0000000..c3d48ee --- /dev/null +++ b/manager/model/token.go @@ -0,0 +1,10 @@ +package model + +type Token struct { + TokenID int64 + UserID int64 + + Account string + Password string + Data string +} diff --git a/manager/model/user.go b/manager/model/user.go new file mode 100644 index 0000000..c26860e --- /dev/null +++ b/manager/model/user.go @@ -0,0 +1,6 @@ +package model + +type User struct { + UserID int64 + Nickname string +} diff --git a/manager/service/token.go b/manager/service/token.go new file mode 100644 index 0000000..f0fdf1a --- /dev/null +++ b/manager/service/token.go @@ -0,0 +1,46 @@ +package service + +import ( + "context" + "database/sql" + + "github.com/daocloud/crproxy/manager/dao" + "github.com/daocloud/crproxy/manager/model" +) + +type TokenService struct { + db *sql.DB + tokenDao *dao.Token +} + +func NewTokenService(db *sql.DB, tokenDao *dao.Token) *TokenService { + return &TokenService{ + db: db, + tokenDao: tokenDao, + } +} + +func (s *TokenService) Create(ctx context.Context, token model.Token) (int64, error) { + ctx = dao.WithDB(ctx, s.db) + return s.tokenDao.Create(ctx, token) +} + +func (s *TokenService) GetByAccount(ctx context.Context, account, password string) (model.Token, error) { + ctx = dao.WithDB(ctx, s.db) + return s.tokenDao.GetByAccount(ctx, account, password) +} + +func (s *TokenService) Get(ctx context.Context, tokenID, userID int64) (model.Token, error) { + ctx = dao.WithDB(ctx, s.db) + return s.tokenDao.GetByID(ctx, tokenID, userID) +} + +func (s *TokenService) Delete(ctx context.Context, tokenID, userID int64) error { + ctx = dao.WithDB(ctx, s.db) + return s.tokenDao.DeleteByID(ctx, tokenID, userID) +} + +func (s *TokenService) GetByUserID(ctx context.Context, userID int64) ([]model.Token, error) { + ctx = dao.WithDB(ctx, s.db) + return s.tokenDao.GetByUserID(ctx, userID) +} diff --git a/manager/service/user.go b/manager/service/user.go new file mode 100644 index 0000000..1982d2d --- /dev/null +++ b/manager/service/user.go @@ -0,0 +1,83 @@ +package service + +import ( + "context" + "database/sql" + "errors" + "fmt" + + "github.com/daocloud/crproxy/manager/dao" + "github.com/daocloud/crproxy/manager/model" +) + +type UserService struct { + db *sql.DB + userDao *dao.User + loginDao *dao.Login +} + +func NewUserService(db *sql.DB, userDao *dao.User, loginDao *dao.Login) *UserService { + return &UserService{ + db: db, + userDao: userDao, + loginDao: loginDao, + } +} + +func (s *UserService) Create(ctx context.Context, nickname, account, password string) (int64, error) { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return 0, err + } + + ctx = dao.WithDB(ctx, tx) + + _, err = s.loginDao.GetByAccount(ctx, account) + if err == nil { + return 0, fmt.Errorf("account already exists") + } + + if !errors.Is(err, sql.ErrNoRows) { + return 0, fmt.Errorf("failed to check account: %w", err) + } + + user := model.User{Nickname: nickname} + userID, err := s.userDao.Create(ctx, user) + if err != nil { + tx.Rollback() + return 0, err + } + + login := model.Login{ + UserID: userID, + Account: account, + Password: password, + } + _, err = s.loginDao.Create(ctx, login) + if err != nil { + tx.Rollback() + return 0, err + } + + err = tx.Commit() + if err != nil { + return 0, err + } + + return userID, nil +} + +func (s *UserService) GetByID(ctx context.Context, id int64) (model.User, error) { + ctx = dao.WithDB(ctx, s.db) + return s.userDao.GetByID(ctx, id) +} + +func (s *UserService) GetLoginByAccount(ctx context.Context, account string) (model.Login, error) { + ctx = dao.WithDB(ctx, s.db) + return s.loginDao.GetByAccount(ctx, account) +} + +func (s *UserService) UpdateNickname(ctx context.Context, id int64, nickname string) error { + ctx = dao.WithDB(ctx, s.db) + return s.userDao.UpdateNickname(ctx, id, nickname) +} diff --git a/token/encoding.go b/token/encoding.go index f61baa2..60f2d4d 100644 --- a/token/encoding.go +++ b/token/encoding.go @@ -41,6 +41,9 @@ type Token struct { } type Attribute struct { + UserID int64 `json:"user_id,omitempty"` + TokenID int64 `json:"token_id,omitempty"` + NoRateLimit bool `json:"no_rate_limit,omitempty"` RateLimitPerSecond uint64 `json:"rate_limit_per_second,omitempty"`