diff --git a/encode.go b/encode.go index 792bde5..0535808 100644 --- a/encode.go +++ b/encode.go @@ -487,3 +487,40 @@ func isAVEqual(a, b *dynamodb.AttributeValue) bool { } return false } + +// isNil returns true if v is considered nil +// this is used to determine if an attribute should be set or removed +func isNil(v interface{}) bool { + if v == nil || v == "" { + return true + } + + // consider v nil if it's a special encoder defined on a value type, but v is a pointer + rv := reflect.ValueOf(v) + switch v.(type) { + case Marshaler: + if rv.Kind() == reflect.Ptr && rv.IsNil() { + if _, ok := rv.Type().Elem().MethodByName("MarshalDynamo"); ok { + return true + } + } + case dynamodbattribute.Marshaler: + if rv.Kind() == reflect.Ptr && rv.IsNil() { + if _, ok := rv.Type().Elem().MethodByName("MarshalDynamoDBAttributeValue"); ok { + return true + } + } + case encoding.TextMarshaler: + if rv.Kind() == reflect.Ptr && rv.IsNil() { + if _, ok := rv.Type().Elem().MethodByName("MarshalText"); ok { + return true + } + } + default: + // e.g. (*int)(nil) + return rv.Kind() == reflect.Ptr && rv.IsNil() + } + + // non-pointers or special encoders with a pointer receiver + return false +} diff --git a/encoding_test.go b/encoding_test.go index 4665bbf..1715e7a 100644 --- a/encoding_test.go +++ b/encoding_test.go @@ -167,6 +167,17 @@ var itemEncodingTests = []struct { "A": &dynamodb.AttributeValue{S: aws.String("hello")}, }, }, + { + name: "pointer (value receiver TextMarshaler)", + in: &struct { + A *textMarshaler + }{ + A: new(textMarshaler), + }, + out: map[string]*dynamodb.AttributeValue{ + "A": &dynamodb.AttributeValue{S: aws.String("false")}, + }, + }, { name: "rename", in: struct { @@ -390,9 +401,31 @@ func (tm *textMarshaler) UnmarshalText(text []byte) error { return nil } +type ptrTextMarshaler bool + +func (tm *ptrTextMarshaler) MarshalText() ([]byte, error) { + if tm == nil { + return []byte("null"), nil + } + if *tm { + return []byte("true"), nil + } + return []byte("false"), nil +} + +func (tm *ptrTextMarshaler) UnmarshalText(text []byte) error { + if string(text) == "null" { + return nil + } + *tm = string(text) == "true" + return nil +} + var ( _ Marshaler = new(customMarshaler) _ Unmarshaler = new(customMarshaler) _ encoding.TextMarshaler = new(textMarshaler) _ encoding.TextUnmarshaler = new(textMarshaler) + _ encoding.TextMarshaler = new(ptrTextMarshaler) + _ encoding.TextUnmarshaler = new(ptrTextMarshaler) ) diff --git a/update.go b/update.go index 1af41d2..d2fbfeb 100644 --- a/update.go +++ b/update.go @@ -59,9 +59,13 @@ func (u *Update) Range(name string, value interface{}) *Update { } // Set changes path to the given value. +// If value is an empty string or nil, path will be removed instead. // Paths that are reserved words are automatically escaped. // Use single quotes to escape complex values like 'User'.'Count'. func (u *Update) Set(path string, value interface{}) *Update { + if isNil(value) { + return u.Remove(path) + } path, err := u.escape(path) u.setError(err) expr, err := u.subExpr("🝕 = ?", path, value) diff --git a/update_test.go b/update_test.go index 6e1ed1f..a2d5235 100644 --- a/update_test.go +++ b/update_test.go @@ -73,3 +73,49 @@ func TestUpdate(t *testing.T) { t.Error("expected ConditionalCheckFailedException, not", err) } } + +func TestUpdateNil(t *testing.T) { + if testDB == nil { + t.Skip(offlineSkipMsg) + } + table := testDB.Table(testTable) + + // first, add an item to make sure there is at least one + item := widget{ + UserID: 4242, + Time: time.Now().UTC(), + Msg: "delete me", + Meta: map[string]string{ + "abc": "123", + }, + Count: 100, + } + err := table.Put(item).Run() + if err != nil { + t.Error("unexpected error:", err) + t.FailNow() + } + + // update Msg with 'nil', which should delete it + var result widget + err = table.Update("UserID", item.UserID).Range("Time", item.Time). + Set("Msg", (*textMarshaler)(nil)). + Set("Meta.'abc'", nil). + Set("Meta.'ok'", (*ptrTextMarshaler)(nil)). + Set("Count", (*int)(nil)). + Value(&result) + if err != nil { + t.Error("unexpected error:", err) + } + expected := widget{ + UserID: item.UserID, + Time: item.Time, + Msg: "", + Meta: map[string]string{ + "ok": "null", + }, + } + if !reflect.DeepEqual(result, expected) { + t.Errorf("bad result. %+v ≠ %+v", result, expected) + } +}