Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
mmrath committed May 30, 2020
1 parent c269bfa commit 28dd33a
Show file tree
Hide file tree
Showing 4 changed files with 187 additions and 3 deletions.
9 changes: 9 additions & 0 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"reflect"
"strconv"
"strings"
"sync"
"time"
)

Expand Down Expand Up @@ -92,6 +93,14 @@ type DbMap struct {
tables []*TableMap
logger GorpLogger
logPrefix string
lock sync.RWMutex

Cache Cache
}

type Cache interface {
Load(key interface{}) (value interface{}, ok bool)
Store(key, value interface{})
}

func (m *DbMap) WithContext(ctx context.Context) SqlExecutor {
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module github.com/mmrath/gorp

go 1.13
go 1.14

require (
github.com/go-sql-driver/mysql v1.4.1
Expand Down
42 changes: 40 additions & 2 deletions gorp.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"reflect"
"regexp"
"strings"
"sync"
"time"
)

Expand Down Expand Up @@ -247,7 +248,37 @@ func expandNamedQuery(m *DbMap, query string, keyGetter func(key string) reflect
}), args
}

type fieldCacheKey struct {
t reflect.Type
name string
cols string
}

type fieldCacheEntry struct {
mapping [][]int
err error
}

func columnToFieldIndex(m *DbMap, t reflect.Type, cols []string) ([][]int, error) {
var ck fieldCacheKey
var err error
if m.Cache != nil {
ck.t = t
ck.cols = strings.Join(cols, ",")

rv, ok := m.Cache.Load(ck)
if ok {
entry := rv.(*fieldCacheEntry)
return entry.mapping, entry.err
}
} else {
m.lock.Lock()
if m.Cache == nil {
m.Cache = &sync.Map{}
}
m.lock.Unlock()
}

colToFieldIndex := make([][]int, len(cols))

// check if type t is a mapped table - if so we'll
Expand Down Expand Up @@ -289,13 +320,20 @@ func columnToFieldIndex(m *DbMap, t reflect.Type, cols []string) ([][]int, error
missingColNames = append(missingColNames, colName)
}
}

if len(missingColNames) > 0 {
return colToFieldIndex, &NoFieldInTypeError{
err = &NoFieldInTypeError{
TypeName: t.Name(),
MissingColNames: missingColNames,
}
}
return colToFieldIndex, nil
entry := &fieldCacheEntry{
mapping: colToFieldIndex,
err: err,
}
m.Cache.Store(ck, entry)

return colToFieldIndex, err
}

// toSliceType returns the element type of the given object, if the object is a
Expand Down
137 changes: 137 additions & 0 deletions gorp_mapping_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
//Original: https://github.com/paulquerna-okta/gorp/blob/1fb48e4c1f26abac8d69b7b2cca5aeafde0c3532/mapping_test.go

package gorp

import (
"reflect"
"sync"
"testing"
"time"
)

type testUser struct {
ID uint64 `db:"id"`
Username string `db:"user_name"`
HashedPassword []byte `db:"hashed_password"`
EMail string `db:"email"`
CreatedAt time.Time `db:"created_at"`
UpdatedAt time.Time `db:"updated_at"`
}

type testCoolUser struct {
testUser
IsCool bool `db:"is_cool"`
BestFriends []string `db:"best_friends"`
}

func BenchmarkColumnToFieldIndex(b *testing.B) {
structType := reflect.TypeOf(testUser{})
dbmap := &DbMap{Cache: &sync.Map{}}
b.ResetTimer()
for n := 0; n < b.N; n++ {
_, err := columnToFieldIndex(dbmap,
structType,
[]string{
"user_name",
"email",
"created_at",
"updated_at",
"id",
})
if err != nil {
panic(err)
}
}
}

func TestColumnToFieldIndexBasic(t *testing.T) {
structType := reflect.TypeOf(testUser{})
dbmap := &DbMap{}
cols, err := columnToFieldIndex(dbmap,
structType,
[]string{
"email",
})
if err != nil {
t.Fatal(err)
}
if len(cols) != 1 {
t.Fatal("cols should have 1 result", cols)
}
if cols[0][0] != 3 {
t.Fatal("cols[0][0] should map to email field in testUser", cols)
}
}

func TestColumnToFieldIndexSome(t *testing.T) {
structType := reflect.TypeOf(testUser{})
dbmap := &DbMap{}
cols, err := columnToFieldIndex(dbmap,
structType,
[]string{
"id",
"email",
"created_at",
})
if err != nil {
t.Fatal(err)
}
if len(cols) != 3 {
t.Fatal("cols should have 3 results", cols)
}
if cols[0][0] != 0 {
t.Fatal("cols[0][0] should map to id field in testUser", cols)
}
if cols[1][0] != 3 {
t.Fatal("cols[1][0] should map to email field in testUser", cols)
}
if cols[2][0] != 4 {
t.Fatal("cols[2][0] should map to created_at field in testUser", cols)
}
}

func TestColumnToFieldIndexEmbedded(t *testing.T) {
structType := reflect.TypeOf(testCoolUser{})
dbmap := &DbMap{}
cols, err := columnToFieldIndex(dbmap,
structType,
[]string{
"id",
"email",
"is_cool",
})
if err != nil {
t.Fatal(err)
}
if len(cols) != 3 {
t.Fatal("cols should have 3 results", cols)
}
if cols[0][0] != 0 && cols[0][1] != 0 {
t.Fatal("cols[0][0] should map to id field in testCoolUser", cols)
}
if cols[1][0] != 0 && cols[1][1] != 3 {
t.Fatal("cols[1][0] should map to email field in testCoolUser", cols)
}
if cols[2][0] != 1 {
t.Fatal("cols[2][0] should map to is_cool field in testCoolUser", cols)
}
}

func TestColumnToFieldIndexEmbeddedFriends(t *testing.T) {
structType := reflect.TypeOf(testCoolUser{})
dbmap := &DbMap{}
cols, err := columnToFieldIndex(dbmap,
structType,
[]string{
"best_friends",
})
if err != nil {
t.Fatal(err)
}
if len(cols) != 1 {
t.Fatal("cols should have 1 results", cols)
}
if cols[0][0] != 2 {
t.Fatal("cols[0][0] should map to BestFriends field in testCoolUser", cols)
}
}

0 comments on commit 28dd33a

Please sign in to comment.