Skip to content

Commit

Permalink
Reimplement SQLite upsert, loadOrStore, and remove without transactions
Browse files Browse the repository at this point in the history
Signed-off-by: Ben Krieger <[email protected]>
  • Loading branch information
ben-krieger committed Oct 31, 2024
1 parent dca3304 commit 27edd33
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 84 deletions.
52 changes: 38 additions & 14 deletions fdotest/server_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -493,10 +493,17 @@ func RunServerStateSuite(t *testing.T, state AllServerState) { //nolint:gocyclo
if _, err := rand.Read(guid[:]); err != nil {
t.Fatal(err)
}
ov := &fdo.Voucher{Header: *cbor.NewBstr(fdo.VoucherHeader{GUID: guid})}
expectOV := &fdo.Voucher{Header: *cbor.NewBstr(fdo.VoucherHeader{
GUID: guid,
ManufacturerKey: protocol.PublicKey{
Type: protocol.Rsa2048RestrKeyType,
Encoding: protocol.X509KeyEnc,
Body: cbor.RawBytes{0x40},
},
})}
dnsAddr := "owner.fidoalliance.org"
fakeHash := sha256.Sum256([]byte("fake blob"))
expect := cose.Sign1[protocol.To1d, []byte]{
expectBlob := cose.Sign1[protocol.To1d, []byte]{
Payload: cbor.NewByteWrap(protocol.To1d{
RV: []protocol.RvTO2Addr{
{
Expand All @@ -515,27 +522,32 @@ func RunServerStateSuite(t *testing.T, state AllServerState) { //nolint:gocyclo
if err != nil {
t.Fatal(err)
}
if err := expect.Sign(testKey, nil, nil, nil); err != nil {
if err := expectBlob.Sign(testKey, nil, nil, nil); err != nil {
t.Fatal(err)
}

// Shadow state to limit testable functions
var state fdo.RendezvousBlobPersistentState = state

// Store and retrieve rendezvous blob
// Store and retrieve rendezvous blob 2 times to test upsert
if _, _, err := state.RVBlob(context.TODO(), guid); !errors.Is(err, fdo.ErrNotFound) {
t.Fatalf("expected ErrNotFound, got %v", err)
}
exp := time.Now().Add(time.Hour)
if err := state.SetRVBlob(context.TODO(), ov, &expect, exp); err != nil {
t.Fatal(err)
}
got, _, err := state.RVBlob(context.TODO(), guid)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(*got, expect) {
t.Fatalf("expected %+v, got %+v", expect, got)
for range 2 {
exp := time.Now().Add(time.Hour)
if err := state.SetRVBlob(context.TODO(), expectOV, &expectBlob, exp); err != nil {
t.Fatal(err)
}
gotBlob, gotOV, err := state.RVBlob(context.TODO(), guid)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(*gotBlob, expectBlob) {
t.Fatalf("expected blob %+v, got %+v", expectBlob, gotBlob)
}
if !gotOV.Header.Val.Equal(&expectOV.Header.Val) {
t.Fatalf("expected blob %+v, got %+v", expectOV.Header.Val, gotOV.Header.Val)
}
}
})

Expand Down Expand Up @@ -588,6 +600,18 @@ func RunServerStateSuite(t *testing.T, state AllServerState) { //nolint:gocyclo
if _, err := state.Voucher(context.TODO(), newGUID); err != nil {
t.Fatal(err)
}

// Remove voucher
removed, err := state.RemoveVoucher(context.TODO(), newGUID)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(removed, ov) {
t.Errorf("removed voucher should match replaced %+v, got %+v", ov, removed)
}
if _, err := state.RemoveVoucher(context.TODO(), newGUID); !errors.Is(err, fdo.ErrNotFound) {
t.Fatalf("removed voucher GUID should return not found, got error %v", err)
}
})

t.Run("OwnerKeyPersistentState", func(t *testing.T) {
Expand Down
144 changes: 74 additions & 70 deletions sqlite/sqlite.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,32 +218,20 @@ func (db *DB) NewToken(ctx context.Context, protocol protocol.Protocol) (string,
}

func (db *DB) loadOrStoreSecret(ctx context.Context) ([]byte, error) {
tx, err := db.db.BeginTx(ctx, nil)
if err != nil {
return nil, fmt.Errorf("error starting transaction: %w", err)
}
defer func() { _ = tx.Rollback() }()

var readSecret []byte
if err := query(db.debugCtx(ctx), tx, "secrets", []string{"secret"}, map[string]any{"type": "hmac"}, &readSecret); err != nil && !errors.Is(err, fdo.ErrNotFound) {
return nil, fmt.Errorf("error reading hmac secret: %w", err)
}
if len(readSecret) > 0 {
return readSecret, nil
}

// Insert new secret
var secret [64]byte
if _, err := rand.Read(secret[:]); err != nil {
// Insert (or ignore) a new HMAC secret
secret := make([]byte, 64)
if _, err := rand.Read(secret); err != nil {
return nil, err
}
if err := insert(db.debugCtx(ctx), tx, "secrets", map[string]any{"type": "hmac", "secret": secret[:]}, nil); err != nil {
if err := db.insertOrIgnore(ctx, "secrets", map[string]any{"type": "hmac", "secret": secret}); err != nil {
return nil, fmt.Errorf("error writing hmac secret: %w", err)
}
if err := tx.Commit(); err != nil {
return nil, err

// Read secret
if err := db.query(ctx, "secrets", []string{"secret"}, map[string]any{"type": "hmac"}, &secret); err != nil {
return nil, fmt.Errorf("error reading hmac secret: %w", err)
}
return secret[:], nil
return secret, nil
}

type contextKey struct{}
Expand Down Expand Up @@ -308,20 +296,7 @@ func (db *DB) sessionID(ctx context.Context) ([]byte, bool) {
}

func (db *DB) insert(ctx context.Context, table string, kvs, upsertWhere map[string]any) error {
if len(upsertWhere) == 0 {
return insert(ctx, db.db, table, kvs, upsertWhere)
}

tx, err := db.db.BeginTx(ctx, nil)
if err != nil {
return fmt.Errorf("error starting transaction: %w", err)
}
defer func() { _ = tx.Rollback() }()

if err := insert(ctx, tx, table, kvs, upsertWhere); err != nil {
return err
}
return tx.Commit()
return insert(db.debugCtx(ctx), db.db, table, kvs, upsertWhere)
}

func (db *DB) insertOrIgnore(ctx context.Context, table string, kvs map[string]any) error {
Expand All @@ -336,17 +311,25 @@ func (db *DB) query(ctx context.Context, table string, columns []string, where m
return query(db.debugCtx(ctx), db.db, table, columns, where, into...)
}

// Allows using *sql.DB or *sql.Tx
type execer interface {
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
}

// Allows using *sql.DB or *sql.Tx
type querier interface {
QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
}

// Allows using *sql.DB or *sql.Tx
type queryexecer interface {
querier
execer
}

func insert(ctx context.Context, db execer, table string, kvs, upsertWhere map[string]any) error {
var orIgnore string
if upsertWhere != nil {
if upsertWhere != nil && len(upsertWhere) == 0 {
orIgnore = "OR IGNORE "
}

Expand All @@ -357,22 +340,41 @@ func insert(ctx context.Context, db execer, table string, kvs, upsertWhere map[s
}
markers := slices.Repeat([]string{"?"}, len(columns))

var upsert string
if len(upsertWhere) > 0 {
upsertWhereKeys := slices.Collect(maps.Keys(upsertWhere))

var updates []string
for _, key := range columns {
if slices.Contains(upsertWhereKeys, key) {
continue
}
updates = append(updates, fmt.Sprintf("`%s` = excluded.`%s`", key, key))
}

whereClauses := make([]string, len(upsertWhereKeys))
for i, key := range upsertWhereKeys {
whereClauses[i] = fmt.Sprintf("`%s` = ?", key)
args = append(args, upsertWhere[key])
}

upsert = " ON CONFLICT DO UPDATE SET "
upsert += strings.Join(updates, " AND ")
upsert += " WHERE "
upsert += strings.Join(whereClauses, " AND ")
}

query := fmt.Sprintf(
"INSERT %sINTO %s (%s) VALUES (%s)",
"INSERT %sINTO %s (%s) VALUES (%s)%s",
orIgnore,
table,
"`"+strings.Join(columns, "`, `")+"`",
strings.Join(markers, ", "),
upsert,
)
debug(ctx, "sqlite: %s\n%+v", query, kvs)
if _, err := db.ExecContext(ctx, query, args...); err != nil {
return err
}

if len(upsertWhere) > 0 {
return update(ctx, db, table, kvs, upsertWhere)
}
return nil
debug(ctx, "sqlite: %s\n%+v", query, args)
_, err := db.ExecContext(ctx, query, args...)
return err
}

func update(ctx context.Context, db execer, table string, kvs, where map[string]any) error {
Expand Down Expand Up @@ -440,7 +442,7 @@ func query(ctx context.Context, db querier, table string, columns []string, wher
return nil
}

func remove(ctx context.Context, db execer, table string, where map[string]any) error {
func remove(ctx context.Context, db queryexecer, table string, where map[string]any, returning map[string]any) error {
whereKeys := slices.Collect(maps.Keys(where))
clauses := make([]string, len(whereKeys))
for i, key := range whereKeys {
Expand All @@ -451,12 +453,33 @@ func remove(ctx context.Context, db execer, table string, where map[string]any)
whereVals[i] = where[key]
}

var returningQuery string
returningArgs := make([]any, len(returning))
if len(returning) > 0 {
returningKeys := slices.Collect(maps.Keys(returning))
returningQuery = " RETURNING `" + strings.Join(returningKeys, "`, `") + "`"
for i, key := range returningKeys {
returningArgs[i] = returning[key]
}
}

query := fmt.Sprintf(
`DELETE FROM %s WHERE %s`,
`DELETE FROM %s WHERE %s%s`,
table,
strings.Join(clauses, " AND "),
returningQuery,
)
debug(ctx, "sqlite: %s\n%+v", query, where)
debug(ctx, "sqlite: %s\n%+v", query, whereVals)

if returningQuery != "" {
row := db.QueryRowContext(ctx, query, whereVals...)
if err := row.Scan(returningArgs...); errors.Is(err, sql.ErrNoRows) {
return fdo.ErrNotFound
} else if err != nil {
return err
}
return nil
}

result, err := db.ExecContext(ctx, query, whereVals...)
if err != nil {
Expand Down Expand Up @@ -829,19 +852,9 @@ func (db *DB) ReplaceVoucher(ctx context.Context, guid protocol.GUID, ov *fdo.Vo

// RemoveVoucher untracks a voucher, deleting it, and returns it for extension.
func (db *DB) RemoveVoucher(ctx context.Context, guid protocol.GUID) (*fdo.Voucher, error) {
ctx = db.debugCtx(ctx)

tx, err := db.db.BeginTx(ctx, nil)
if err != nil {
return nil, fmt.Errorf("error starting transaction: %w", err)
}
defer func() { _ = tx.Rollback() }()

var data []byte
if err := query(ctx, tx, "owner_vouchers", []string{"cbor"},
map[string]any{"guid": guid[:]},
&data,
); err != nil {
if err := remove(db.debugCtx(ctx), db.db, "owner_vouchers",
map[string]any{"guid": guid[:]}, map[string]any{"cbor": &data}); err != nil {
return nil, err
}
if data == nil {
Expand All @@ -852,15 +865,6 @@ func (db *DB) RemoveVoucher(ctx context.Context, guid protocol.GUID) (*fdo.Vouch
if err := cbor.Unmarshal(data, &ov); err != nil {
return nil, fmt.Errorf("error unmarshaling ownership voucher: %w", err)
}

if err := remove(ctx, tx, "owner_vouchers", map[string]any{"guid": guid[:]}); err != nil {
return nil, err
}

if err := tx.Commit(); err != nil {
return nil, err
}

return &ov, nil
}

Expand Down

0 comments on commit 27edd33

Please sign in to comment.