diff --git a/rapi/router.go b/rapi/router.go new file mode 100644 index 0000000..67ca215 --- /dev/null +++ b/rapi/router.go @@ -0,0 +1,39 @@ +package rapi + +import ( + "net/http" + + "github.com/go-chi/chi" +) + +type Router interface { + http.Handler + GetChiRouter() chi.Router + Route(pattern string, fn func(r Router)) Router + SetAuthMiddleware(middlewares ...func(http.Handler) http.Handler) + SetOptAuthMiddleware(middlewares ...func(http.Handler) http.Handler) + Use(middlewares ...func(http.Handler) http.Handler) + With(middlewares ...func(http.Handler) http.Handler) Router + Auth() Router + OptAuth() Router + Connect(pattern string, re RouterElement) + Delete(pattern string, re RouterElement) + Get(pattern string, re RouterElement) + Head(pattern string, re RouterElement) + Options(pattern string, re RouterElement) + Patch(pattern string, re RouterElement) + Post(pattern string, re RouterElement) + Put(pattern string, re RouterElement) + Trace(pattern string, re RouterElement) + // router のエンドポイントと input, output の型定義を出力する + GetRouterDefinition() ([]*RouterDefinition, map[string]*TypeStructure) +} + +type RouterDefinition struct { + InputTypeStructure *TypeStructure `json:"input_type_structure"` + OutputTypeStructure *TypeStructure `json:"output_type_structure"` + FullPathName string `json:"full_path_name"` + CurrentPathName string `json:"current_path_name"` + Method string `json:"method"` + WithAuth bool `json:"with_auth"` +} diff --git a/rapi/router_element.go b/rapi/router_element.go new file mode 100644 index 0000000..8752daf --- /dev/null +++ b/rapi/router_element.go @@ -0,0 +1,30 @@ +package rapi + +import ( + "context" + "net/http" +) + +type RouterElement interface { + GetHandleFunc() http.HandlerFunc + GetEmptyInput() any + GetEmptyOutput() any +} + +type HandlerMethod[I any] interface { + RouterElement + // 共通のリクエストパラメーター受け取り処理をセット + SetInputFunc(func(ctx context.Context, r *http.Request, param any) error) + // 共通のバリデーション処理をセット + SetValidateFunc(func(ctx context.Context, param any) error) + // エラーをレンダリングする処理をセット + SetHandleErrorFunc(func(ctx context.Context, w http.ResponseWriter, r *http.Request, err error)) + // レスポンスをレンダリングする処理をセット + SetRenderFunc(func(ctx context.Context, w http.ResponseWriter, r *http.Request, output any)) + // エラーをレンダリングする直前にエラーを書き換える処理をセット + BeforeHandleError(func(ctx context.Context, w http.ResponseWriter, r *http.Request, err error) error) + // 共通リクエストパラメーター受け取り処理の後に必要な処理があればセット + AfterInput(func(ctx context.Context, r *http.Request, param *I) error) + // 共通のバリデーション処理の後に必要な処理があればセット + AfterValidate(func(ctx context.Context, param *I) error) +} diff --git a/rapi/router_element_impl.go b/rapi/router_element_impl.go new file mode 100644 index 0000000..41d101b --- /dev/null +++ b/rapi/router_element_impl.go @@ -0,0 +1,120 @@ +package rapi + +import ( + "context" + "net/http" +) + +func NewHandlerMethod[I, O any](f func(ctx context.Context, param *I) (*O, error)) HandlerMethod[I] { + return &handlerMethod[I, O]{ + ServiceFunc: f, + } +} + +type handlerMethod[I, O any] struct { + InputFunc func(ctx context.Context, r *http.Request, param any) error + AfterInputFunc func(ctx context.Context, r *http.Request, param *I) error + ValidateFunc func(ctx context.Context, param any) error + AfterValidateFunc func(ctx context.Context, param *I) error + ServiceFunc func(ctx context.Context, param *I) (*O, error) + BeforeHandleErrorFunc func(ctx context.Context, w http.ResponseWriter, r *http.Request, err error) error + HandleErrorFunc func(ctx context.Context, w http.ResponseWriter, r *http.Request, err error) + RenderFunc func(ctx context.Context, w http.ResponseWriter, r *http.Request, output any) +} + +// --- handlerMethod implements --- + +func (h *handlerMethod[I, O]) SetInputFunc(f func(ctx context.Context, r *http.Request, param any) error) { + h.InputFunc = f +} + +func (h *handlerMethod[I, O]) SetValidateFunc(f func(ctx context.Context, param any) error) { + h.ValidateFunc = f +} + +func (h *handlerMethod[I, O]) SetRenderFunc(f func(ctx context.Context, w http.ResponseWriter, r *http.Request, output any)) { + h.RenderFunc = f +} + +func (h *handlerMethod[I, O]) SetHandleErrorFunc(f func(ctx context.Context, w http.ResponseWriter, r *http.Request, err error)) { + h.HandleErrorFunc = f +} + +func (h *handlerMethod[I, O]) BeforeHandleError(f func(ctx context.Context, w http.ResponseWriter, r *http.Request, err error) error) { + h.BeforeHandleErrorFunc = f +} + +func (h *handlerMethod[I, O]) AfterInput(f func(ctx context.Context, r *http.Request, param *I) error) { + h.AfterInputFunc = f +} + +func (h *handlerMethod[I, O]) AfterValidate(f func(ctx context.Context, param *I) error) { + h.AfterValidateFunc = f +} + +// --- RouterElement implements --- + +func (h *handlerMethod[I, O]) handleError(ctx context.Context, w http.ResponseWriter, r *http.Request, err error) { + if h.BeforeHandleErrorFunc != nil { + err = h.BeforeHandleErrorFunc(ctx, w, r, err) + } + if h.HandleErrorFunc != nil { + h.HandleErrorFunc(ctx, w, r, err) + } +} + +func (h *handlerMethod[I, O]) GetHandleFunc() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + var param I + if h.InputFunc != nil { + if err := h.InputFunc(ctx, r, ¶m); err != nil { + h.handleError(ctx, w, r, err) + return + } + } + if h.AfterInputFunc != nil { + if err := h.AfterInputFunc(ctx, r, ¶m); err != nil { + h.handleError(ctx, w, r, err) + return + } + } + + if h.ValidateFunc != nil { + if err := h.ValidateFunc(ctx, ¶m); err != nil { + h.handleError(ctx, w, r, err) + return + } + } + + if h.AfterValidateFunc != nil { + if err := h.AfterValidateFunc(ctx, ¶m); err != nil { + h.handleError(ctx, w, r, err) + return + } + } + var output *O + var err error + if h.ServiceFunc != nil { + output, err = h.ServiceFunc(ctx, ¶m) + if err != nil { + h.handleError(ctx, w, r, err) + return + } + } + + if h.RenderFunc == nil { + panic("RenderFunc is required") + } + h.RenderFunc(ctx, w, r, output) + } +} + +func (h *handlerMethod[I, O]) GetEmptyInput() any { + return *new(I) +} + +func (h *handlerMethod[I, O]) GetEmptyOutput() any { + return *new(O) +} diff --git a/rapi/router_impl.go b/rapi/router_impl.go new file mode 100644 index 0000000..733e54e --- /dev/null +++ b/rapi/router_impl.go @@ -0,0 +1,196 @@ +package rapi + +import ( + "net/http" + + "github.com/go-chi/chi" +) + +func NewRouter() Router { + r := &router{ + chiRouter: chi.NewRouter(), + children: []*router{}, + authMiddlewares: chi.Middlewares{}, + optAuthMiddlewares: chi.Middlewares{}, + } + r.root = r + return r +} + +type router struct { + method string + path string + root *router + parent *router + withAuth bool + chiRouter chi.Router + element RouterElement + children []*router + authMiddlewares chi.Middlewares + optAuthMiddlewares chi.Middlewares +} + +func (r *router) sub() *router { + subRouter := &router{ + children: []*router{}, + root: r.root, + parent: r, + } + r.children = append(r.children, subRouter) + return subRouter +} + +func (r *router) handle(method string, pattern string, re RouterElement) { + subRouter := r.sub() + subRouter.chiRouter = r.chiRouter + subRouter.element = re + subRouter.path = pattern + subRouter.method = method + + switch method { + case http.MethodConnect: + subRouter.chiRouter.Connect(pattern, re.GetHandleFunc()) + case http.MethodDelete: + subRouter.chiRouter.Delete(pattern, re.GetHandleFunc()) + case http.MethodGet: + subRouter.chiRouter.Get(pattern, re.GetHandleFunc()) + case http.MethodHead: + subRouter.chiRouter.Head(pattern, re.GetHandleFunc()) + case http.MethodOptions: + subRouter.chiRouter.Options(pattern, re.GetHandleFunc()) + case http.MethodPatch: + subRouter.chiRouter.Patch(pattern, re.GetHandleFunc()) + case http.MethodPost: + subRouter.chiRouter.Post(pattern, re.GetHandleFunc()) + case http.MethodPut: + subRouter.chiRouter.Put(pattern, re.GetHandleFunc()) + case http.MethodTrace: + subRouter.chiRouter.Trace(pattern, re.GetHandleFunc()) + } +} + +func (r *router) GetChiRouter() chi.Router { + return r.chiRouter +} + +func (r *router) ServeHTTP(w http.ResponseWriter, req *http.Request) { + r.chiRouter.ServeHTTP(w, req) +} + +func (r *router) Route(pattern string, fn func(r Router)) Router { + subRouter := r.sub() + subRouter.path = pattern + + if fn != nil { + r.chiRouter.Route(pattern, func(chiRouter chi.Router) { + subRouter.chiRouter = chiRouter + fn(subRouter) + }) + } else { + subRouter.chiRouter = r.chiRouter.Route(pattern, nil) + } + + return subRouter +} + +func (r *router) Use(middlewares ...func(http.Handler) http.Handler) { + r.chiRouter.Use(middlewares...) +} + +func (r *router) SetAuthMiddleware(middlewares ...func(http.Handler) http.Handler) { + r.root.authMiddlewares = append(r.authMiddlewares, middlewares...) +} + +func (r *router) SetOptAuthMiddleware(middlewares ...func(http.Handler) http.Handler) { + r.root.optAuthMiddlewares = append(r.optAuthMiddlewares, middlewares...) +} + +func (r *router) With(middlewares ...func(http.Handler) http.Handler) Router { + return r.with(middlewares...) +} + +func (r *router) with(middlewares ...func(http.Handler) http.Handler) *router { + subRouter := r.sub() + subRouter.chiRouter = r.chiRouter.With(middlewares...) + return subRouter +} + +func (r *router) Auth() Router { + subRouter := r.with(r.root.authMiddlewares...) + subRouter.withAuth = true + return subRouter +} + +func (r *router) OptAuth() Router { + subRouter := r.with(r.root.optAuthMiddlewares...) + subRouter.withAuth = true + return subRouter +} + +func (r *router) Connect(pattern string, re RouterElement) { + r.handle(http.MethodConnect, pattern, re) +} + +func (r *router) Delete(pattern string, re RouterElement) { + r.handle(http.MethodDelete, pattern, re) +} + +func (r *router) Get(pattern string, re RouterElement) { + r.handle(http.MethodGet, pattern, re) +} + +func (r *router) Head(pattern string, re RouterElement) { + r.handle(http.MethodHead, pattern, re) +} + +func (r *router) Options(pattern string, re RouterElement) { + r.handle(http.MethodOptions, pattern, re) +} + +func (r *router) Patch(pattern string, re RouterElement) { + r.handle(http.MethodPatch, pattern, re) +} + +func (r *router) Post(pattern string, re RouterElement) { + r.handle(http.MethodPost, pattern, re) +} + +func (r *router) Put(pattern string, re RouterElement) { + r.handle(http.MethodPut, pattern, re) +} + +func (r *router) Trace(pattern string, re RouterElement) { + r.handle(http.MethodTrace, pattern, re) +} + +// router のエンドポイントと input, output の型定義を出力する +func (r *router) GetRouterDefinition() ([]*RouterDefinition, map[string]*TypeStructure) { + ts := NewTypeScanner() + ts.DisableStructField() + ts.AddStructTagName("json", "form") + + routerDefinitions := []*RouterDefinition{} + + // 再帰で全てのRouter定義をappend + var appendRouterDefinition func(r *router, parentPath string) + appendRouterDefinition = func(r *router, parentPath string) { + if r.element != nil { + routerDefinition := &RouterDefinition{ + FullPathName: parentPath + r.path, + CurrentPathName: r.path, + Method: r.method, + WithAuth: r.withAuth, + InputTypeStructure: ts.Scan(r.element.GetEmptyInput()), + OutputTypeStructure: ts.Scan(r.element.GetEmptyOutput()), + } + routerDefinitions = append(routerDefinitions, routerDefinition) + } + + for _, child := range r.children { + appendRouterDefinition(child, parentPath+r.path) + } + } + + appendRouterDefinition(r.root, "") + return routerDefinitions, ts.Export() +} diff --git a/rapi/type_scanner.go b/rapi/type_scanner.go new file mode 100644 index 0000000..b672c4c --- /dev/null +++ b/rapi/type_scanner.go @@ -0,0 +1,45 @@ +package rapi + +// 型情報を読み取り、整理して、 JSON 化可能な形で出力する interface +type TypeScanner interface { + // any の型情報を出力する + Scan(value any) *TypeStructure + // 今まで scan したすべての型情報を slice で出力する + Export() map[string]*TypeStructure + ScanUnion(values []any) *UnionStructure + ExportUnion() map[string]*UnionStructure + EnableStructField() TypeScanner + DisableStructField() TypeScanner + AddStructTagName(tagName ...string) TypeScanner +} + +// 一つの型情報 +type TypeStructure struct { + Name string `json:"name"` + GoTypeName string `json:"go_type_name"` + Kind string `json:"kind"` + // map の key の型情報。map じゃない場合は nil + KeyType *TypeStructure `json:"key_type,omitempty"` + // map, slice, array の要素の型情報。それ以外は nil + ElemType *TypeStructure `json:"elem_type,omitempty"` + // struct の field の型情報。それ以外は nil + Fields map[string]*TypeStructure `json:"fields,omitempty"` +} + +type UnionStructure struct { + Name string `json:"name"` + GoTypeName string `json:"go_type_name"` + Kind string `json:"kind"` + Values []any `json:"values"` +} + +const ( + TypeKindString = "string" + TypeKindInt = "int" + TypeKindFloat = "float" + TypeKindBool = "bool" + TypeKindArray = "array" + TypeKindMap = "map" + TypeKindStruct = "struct" + TypeKindAny = "any" +) diff --git a/rapi/type_scanner_impl.go b/rapi/type_scanner_impl.go new file mode 100644 index 0000000..37662e7 --- /dev/null +++ b/rapi/type_scanner_impl.go @@ -0,0 +1,231 @@ +package rapi + +import ( + "fmt" + "reflect" + "strings" +) + +type typeScanner struct { + types map[string]*TypeStructure + unions map[string]*UnionStructure + structFieldEnabled bool + structTagNames []string + unnamedCount int +} + +func NewTypeScanner() TypeScanner { + return &typeScanner{ + types: map[string]*TypeStructure{}, + unions: map[string]*UnionStructure{}, + structFieldEnabled: true, + structTagNames: []string{}, + } +} + +func (t *typeScanner) EnableStructField() TypeScanner { + t.structFieldEnabled = true + return t +} + +func (t *typeScanner) DisableStructField() TypeScanner { + t.structFieldEnabled = false + return t +} + +func (t *typeScanner) AddStructTagName(tagName ...string) TypeScanner { + t.structTagNames = append(t.structTagNames, tagName...) + return t +} + +func (t *typeScanner) Scan(value any) *TypeStructure { + return t.scan(reflect.TypeOf(value), false) +} + +func (ts *TypeStructure) getFieldsRemovedStruct() *TypeStructure { + copied := &TypeStructure{} + copied.Name = ts.Name + copied.GoTypeName = ts.GoTypeName + copied.Kind = ts.Kind + return copied +} + +func (t *typeScanner) scan(rt reflect.Type, ignoreField bool) *TypeStructure { + if rt == nil { + return nil + } + + // pointer の場合は pointer じゃなくなるまで Elem + for rt.Kind() == reflect.Ptr { + rt = rt.Elem() + } + + var typeKind string + + switch rt.Kind() { + case reflect.String: + typeKind = TypeKindString + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + typeKind = TypeKindInt + case reflect.Float32, reflect.Float64: + typeKind = TypeKindFloat + case reflect.Bool: + typeKind = TypeKindBool + case reflect.Map: + typeKind = TypeKindMap + case reflect.Slice, reflect.Array, reflect.Chan: + typeKind = TypeKindArray + case reflect.Struct: + typeKind = TypeKindStruct + case reflect.Interface: + typeKind = TypeKindAny + default: + panic(fmt.Sprintf("Invalid type: %s. (%s)", rt.Kind().String(), rt.String())) + } + + ts := &TypeStructure{ + Kind: typeKind, + } + + switch rt.Kind() { + // map + case reflect.Map: + ts.Name = rt.Name() + ts.GoTypeName = rt.String() + ts.KeyType = t.scan(rt.Key(), true) + ts.ElemType = t.scan(rt.Elem(), true) + + // array + case reflect.Slice, reflect.Array, reflect.Chan: + ts.Name = rt.Name() + ts.GoTypeName = rt.String() + ts.ElemType = t.scan(rt.Elem(), true) + + // struct + case reflect.Struct: + name := rt.Name() + + if name == "" { + name = "__unnamed__." + string(rune(t.unnamedCount)) + t.unnamedCount++ + } else { + name = rt.String() + } + + if v, ok := t.types[name]; ok { + if ignoreField { + return v.getFieldsRemovedStruct() + } + return v + } + + ts.Name = name + ts.GoTypeName = rt.String() + ts.Fields = map[string]*TypeStructure{} + + t.types[name] = ts + for i := 0; i < rt.NumField(); i++ { + keyName := "" + field := rt.Field(i) + + // embedded field の場合は、自身のフィールドとして処理する + if field.Anonymous { + fieldTs := t.scan(field.Type, false) + if fieldTs != nil { + for k, v := range fieldTs.Fields { + ts.Fields[k] = v + } + } + continue + } + + for _, tagName := range t.structTagNames { + tagValue := field.Tag.Get(tagName) + tagValue = strings.Split(tagValue, ",")[0] + if tagValue != "-" { + keyName = tagValue + } + if keyName != "" { + break + } + } + + if keyName == "" { + // struct tag による命名がなく、 structFieldEnabled が false の場合は、そのフィールドは存在しないものとして扱う + if !t.structFieldEnabled { + continue + } + // tag による命名がない場合は、フィールド名をそのまま使用する + keyName = field.Name + } + + fieldTs := t.scan(field.Type, true) + if fieldTs != nil { + ts.Fields[keyName] = fieldTs + } + } + + // json 化するときに再帰的に参照され続けないように fields を削除 + if ignoreField { + return ts.getFieldsRemovedStruct() + } + + // primitive or other + default: + ts.Name = rt.Name() + ts.GoTypeName = rt.String() + } + + return ts +} + +func (t *typeScanner) Export() map[string]*TypeStructure { + // types をコピーして返す + types := map[string]*TypeStructure{} + for k, v := range t.types { + types[k] = v + } + return types +} + +func (t *typeScanner) ScanUnion(values []any) *UnionStructure { + if len(values) == 0 { + return nil + } + rt := reflect.TypeOf(values[0]) + typeName := rt.Name() + if typeName == "" { + return nil + } + + var kind string + switch rt.Kind() { + case reflect.String: + kind = TypeKindString + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + kind = TypeKindInt + case reflect.Float32, reflect.Float64: + kind = TypeKindFloat + default: + panic(fmt.Sprintf("Invalid type: %s. (%s)", rt.Kind().String(), rt.String())) + } + + us := &UnionStructure{ + Name: typeName, + GoTypeName: rt.String(), + Kind: kind, + Values: values, + } + + t.unions[us.GoTypeName] = us + return us +} + +func (t *typeScanner) ExportUnion() map[string]*UnionStructure { + // unions をコピーして返す + unions := map[string]*UnionStructure{} + for k, v := range t.unions { + unions[k] = v + } + return unions +} diff --git a/rapi/util.go b/rapi/util.go new file mode 100644 index 0000000..ca8e037 --- /dev/null +++ b/rapi/util.go @@ -0,0 +1,17 @@ +package rapi + +import ( + "net/http" + "reflect" + + "github.com/go-chi/chi" + "github.com/rabee-inc/go-pkg/util" +) + +func FillURLParam(r *http.Request, param any) { + _ = util.EachTaggedFields(param, "url", func(tagValue string, reflectParam reflect.Value, fieldNum int) error { + urlParam := chi.URLParam(r, tagValue) + reflectParam.Field(fieldNum).SetString(urlParam) + return nil + }) +} diff --git a/util/util.go b/util/util.go index 9b91583..20ec4e6 100644 --- a/util/util.go +++ b/util/util.go @@ -1,5 +1,7 @@ package util +import "reflect" + // AssignIfNotNil ... src が nil でない場合に dest に代入する。代入前の値を含むpointerを返す func AssignIfNotNil[T any](dest *T, src *T) *T { if src != nil { @@ -9,3 +11,21 @@ func AssignIfNotNil[T any](dest *T, src *T) *T { } return dest } + +// EachTaggedFields ... 指定のタグをがついてるフィールドをループする +func EachTaggedFields(param any, tagName string, callback func(tagValue string, reflectParam reflect.Value, fieldNum int) error) error { + val := reflect.Indirect(reflect.ValueOf(param)) + for i := 0; i < val.NumField(); i++ { + typeField := val.Type().Field(i) + + tag := typeField.Tag.Get(tagName) + if tag == "" { + continue + } + + if err := callback(tag, val, i); err != nil { + return err + } + } + return nil +}