diff --git a/grumpy-runtime-src/runtime/dict.go b/grumpy-runtime-src/runtime/dict.go index 32eeca43..f97dd542 100644 --- a/grumpy-runtime-src/runtime/dict.go +++ b/grumpy-runtime-src/runtime/dict.go @@ -28,7 +28,13 @@ var ( dictItemIteratorType = newBasisType("dictionary-itemiterator", reflect.TypeOf(dictItemIterator{}), toDictItemIteratorUnsafe, ObjectType) dictKeyIteratorType = newBasisType("dictionary-keyiterator", reflect.TypeOf(dictKeyIterator{}), toDictKeyIteratorUnsafe, ObjectType) dictValueIteratorType = newBasisType("dictionary-valueiterator", reflect.TypeOf(dictValueIterator{}), toDictValueIteratorUnsafe, ObjectType) - deletedEntry = &dictEntry{} + + // Not as a real object = just a memory address. This isn't a pointer + // since the only usecase for this is to have a unique memory address. + // By having it a value type, the compiler is able to make all + // `&deletedEntry` effectively constant (avoids a memory read if this + // had a pointer type). + deletedEntry Object ) const ( @@ -47,25 +53,16 @@ type dictEntry struct { value *Object } +func (d dictEntry) isEmpty() bool { return d.key == nil } +func (d dictEntry) isDeleted() bool { return d.key == &deletedEntry } +func (d dictEntry) isSet() bool { return !d.isEmpty() && !d.isDeleted() } + // dictTable is the hash table underlying Dict. -type dictTable struct { - // used is the number of slots in the entries table that contain values. - used int32 - // fill is the number of slots that are used or once were used but have - // since been cleared. Thus used <= fill <= len(entries). - fill int - // entries is a slice of immutable dict entries. Although elements in - // the slice will be modified to point to different dictEntry objects - // as the dictionary is updated, the slice itself (i.e. location in - // memory and size) will not change for the lifetime of a dictTable. - // When the table is no longer large enough to hold a dict's contents, - // a new dictTable will be created. - entries []*dictEntry -} +type dictTable []dictEntry // newDictTable allocates a table where at least minCapacity entries can be // accommodated. minCapacity must be <= maxDictSize. -func newDictTable(minCapacity int) *dictTable { +func newDictTable(minCapacity int) dictTable { // This takes the given capacity and sets all bits less than the highest bit. // Adding 1 to that value causes the number to become a multiple of 2 again. // The minDictSize is mixed in to make sure the resulting value is at least @@ -77,168 +74,160 @@ func newDictTable(minCapacity int) *dictTable { numEntries |= numEntries >> 4 numEntries |= numEntries >> 8 numEntries |= numEntries >> 16 - return &dictTable{entries: make([]*dictEntry, numEntries+1)} -} - -// loadEntry atomically loads the i'th entry in t and returns it. -func (t *dictTable) loadEntry(i int) *dictEntry { - p := (*unsafe.Pointer)(unsafe.Pointer(&t.entries[i])) - return (*dictEntry)(atomic.LoadPointer(p)) -} - -// storeEntry atomically sets the i'th entry in t to entry. -func (t *dictTable) storeEntry(i int, entry *dictEntry) { - p := (*unsafe.Pointer)(unsafe.Pointer(&t.entries[i])) - atomic.StorePointer(p, unsafe.Pointer(entry)) -} - -func (t *dictTable) loadUsed() int { - return int(atomic.LoadInt32(&t.used)) -} - -func (t *dictTable) incUsed(n int) { - atomic.AddInt32(&t.used, int32(n)) + return make(dictTable, numEntries+1) } // insertAbsentEntry adds the populated entry to t assuming that the key // specified in entry is absent from t. Since the key is absent, no key // comparisons are necessary to perform the insert. -func (t *dictTable) insertAbsentEntry(entry *dictEntry) { - mask := uint(len(t.entries) - 1) - i := uint(entry.hash) & mask - perturb := uint(entry.hash) - index := i +func (t dictTable) insertAbsentEntry(entry dictEntry) { + mask := uint(len(t) - 1) + i, perturb := uint(entry.hash)&mask, uint(entry.hash) // The key we're trying to insert is known to be absent from the dict - // so probe for the first nil entry. - for ; t.entries[index] != nil; index = i & mask { + // so probe for the first empty entry. +top: + index := i & mask + if !t[index].isEmpty() { i, perturb = dictNextIndex(i, perturb) + // We avoid a `for` loop so this method can be inlined and save + // +1ns/call to insertAbsentEntry (which adds up since this is + // called a lot). + goto top } - t.entries[index] = entry - t.incUsed(1) - t.fill++ + t[index] = entry } -// lookupEntry returns the index and entry in t with the given hash and key. -// Elements in the table are updated with immutable entries atomically and -// lookupEntry loads them atomically. So it is not necessary to lock the dict -// to do entry lookups in a consistent way. -func (t *dictTable) lookupEntry(f *Frame, hash int, key *Object) (int, *dictEntry, *BaseException) { - mask := uint(len(t.entries) - 1) +// lookupEntry returns the index and whether the given hash and key exist in +// the table. Calls to this either should be performed on the read(only) table +// or on the write table while it is locked. +func (t dictTable) lookupEntry(f *Frame, hash int, key *Object) (int, bool, *BaseException) { + mask := uint(len(t) - 1) i, perturb := uint(hash)&mask, uint(hash) // free is the first slot that's available. We don't immediately use it // because it has been previously used and therefore an exact match may // be found further on. free := -1 - var freeEntry *dictEntry - index := int(i & mask) - entry := t.loadEntry(index) for { - if entry == nil { + index := int(i & mask) + switch entry := t[index]; entry.key { + case key: + return index, true, nil + + case nil: if free != -1 { index = free - // Store the entry instead of fetching by index - // later since it may have changed by then. - entry = freeEntry } - break - } - if entry == deletedEntry { + return index, false, nil + + case &deletedEntry: if free == -1 { free = index } - } else if entry.hash == hash { - o, raised := Eq(f, entry.key, key) - if raised != nil { - return -1, nil, raised - } - eq, raised := IsTrue(f, o) - if raised != nil { - return -1, nil, raised - } - if eq { - break + + default: + if entry.hash == hash { + o, raised := Eq(f, entry.key, key) + if raised != nil { + return index, false, raised + } + if eq, raised := IsTrue(f, o); raised != nil || eq { + return index, eq, raised + } } } i, perturb = dictNextIndex(i, perturb) - index = int(i & mask) - entry = t.loadEntry(index) } - return index, entry, nil } -// writeEntry replaces t's entry at the given index with entry. If writing -// entry would cause t's fill ratio to grow too large then a new table is +// writeEntry replaces d's entry at the given index with entry. If writing +// entry would cause d's fill ratio to grow too large then a new table is // created, the entry is instead inserted there and that table is returned. t // remains unchanged. When a sufficiently sized table cannot be created, false // will be returned for the second value, otherwise true will be returned. -func (t *dictTable) writeEntry(f *Frame, index int, entry *dictEntry) (*dictTable, bool) { - if t.entries[index] == deletedEntry { - t.storeEntry(index, entry) - t.incUsed(1) - return nil, true - } - if t.entries[index] != nil { - t.storeEntry(index, entry) - return nil, true - } - if (t.fill+1)*3 <= len(t.entries)*2 { - // New entry does not necessitate growing the table. - t.storeEntry(index, entry) - t.incUsed(1) - t.fill++ - return nil, true +func (d *Dict) writeEntry(f *Frame, index int, entry dictEntry, overwrite bool) (prevEntry dictEntry, ok bool) { + prevEntry, d.write[index] = d.write[index], entry + ok = true + + var usedDelta int32 + if entry.isSet() { + if !overwrite { + return + } + usedDelta++ + } + + if prevEntry.isEmpty() { + d.fill++ + } else if prevEntry.isSet() { + usedDelta-- } + + used := atomic.AddInt32(&d.used, usedDelta) + if int(d.fill)*3 <= len(d.write)*2 { + // Write entry does not necessitate growing the table. + return + } + // Grow the table. var n int - if t.used <= 50000 { - n = int(t.used * 4) - } else if t.used <= maxDictSize/2 { - n = int(t.used * 2) + if used <= 50000 { + n = int(used) * 4 + } else if used <= maxDictSize/2 { + n = int(used) * 2 } else { - return nil, false + ok = false + return } + newTable := newDictTable(n) - for _, oldEntry := range t.entries { - if oldEntry != nil && oldEntry != deletedEntry { + for _, oldEntry := range d.write { + if oldEntry.isSet() { newTable.insertAbsentEntry(oldEntry) } } - newTable.insertAbsentEntry(entry) - return newTable, true + d.fill = used + d.write = newTable + return } // dictEntryIterator is used to iterate over the entries in a dictTable in an // arbitrary order. type dictEntryIterator struct { - index int64 - table *dictTable -} - -// newDictEntryIterator creates a dictEntryIterator object for d. It assumes -// that d.mutex is held by the caller. -func newDictEntryIterator(d *Dict) dictEntryIterator { - return dictEntryIterator{table: d.loadTable()} -} - -// next advances this iterator to the next occupied entry and returns it. The -// second return value is true if the dict changed since iteration began, false -// otherwise. -func (iter *dictEntryIterator) next() *dictEntry { - numEntries := len(iter.table.entries) - var entry *dictEntry - for entry == nil { - // 64bit atomic ops need to be 8 byte aligned. This compile time check - // verifies alignment by creating a negative constant for an unsigned type. - // See sync/atomic docs for details. - const blank = -(unsafe.Offsetof(iter.index) % 8) - index := int(atomic.AddInt64(&iter.index, 1)) - 1 + index int32 + table dictTable +} + +// newDictEntryIterator creates a dictEntryIterator object for d. +func newDictEntryIterator(f *Frame, d *Dict) (iter dictEntryIterator) { + if rtable := d.loadReadTable(); rtable != nil { + iter.table = *rtable + } else { + d.mutex.Lock(f) + iter.table = d.write + if iter.table == nil { + iter.table = *d.loadReadTable() + } else { + // Promote to prevent unlocked mutations to the + // dictTable we are going to iterate over. + d.promoteWriteToRead() + } + d.mutex.Unlock(f) + } + return iter +} + +// next advances this iterator to the next occupied entry and returns it. +func (iter *dictEntryIterator) next() (entry dictEntry) { + numEntries := len(iter.table) + for !entry.isSet() { + index := int(atomic.AddInt32(&iter.index, 1)) - 1 if index >= numEntries { + // Clear so we don't return a deleted entry and users can just use + // `isEmpty` for speed. + entry = dictEntry{} break } - entry = iter.table.loadEntry(index) - if entry == deletedEntry { - entry = nil - } + entry = iter.table[index] } return entry } @@ -263,10 +252,25 @@ func (g *dictVersionGuard) check() bool { // thread safe. type Dict struct { Object - table *dictTable + read *dictTable + + // used is the number of slots in the entries table where + // slot.value!=nil. + used int32 + // We use a recursive mutex for synchronization because the hash and // key comparison operations may re-enter DelItem/SetItem. mutex recursiveMutex + write dictTable + + // fill is the number of slots where slot.key != nil. + // Thus used <= fill <= len(entries). + fill int32 + + // The number of reads hitting the write table - helps gauge when the + // write table should be promoted to the read table. + misses int32 + // version is incremented whenever the Dict is modified. See: // https://www.python.org/dev/peps/pep-0509/ version int64 @@ -274,35 +278,53 @@ type Dict struct { // NewDict returns an empty Dict. func NewDict() *Dict { - return &Dict{Object: Object{typ: DictType}, table: newDictTable(0)} + return &Dict{ + Object: Object{typ: DictType}, + // We start ready to write so populating is fast(er). + write: newDictTable(0), + } } func newStringDict(items map[string]*Object) *Dict { if len(items) > maxDictSize/2 { panic(fmt.Sprintf("dictionary too big: %d", len(items))) } - n := len(items) * 2 - table := newDictTable(n) + table := newDictTable(len(items) * 2) for key, value := range items { - table.insertAbsentEntry(&dictEntry{hashString(key), NewStr(key).ToObject(), value}) + table.insertAbsentEntry(dictEntry{hashString(key), NewStr(key).ToObject(), value}) + } + return &Dict{ + Object: Object{typ: DictType}, + read: &table, + used: int32(len(items)), + fill: int32(len(items)), } - return &Dict{Object: Object{typ: DictType}, table: table} } func toDictUnsafe(o *Object) *Dict { return (*Dict)(o.toPointer()) } -// loadTable atomically loads and returns d's underlying dictTable. -func (d *Dict) loadTable() *dictTable { - p := (*unsafe.Pointer)(unsafe.Pointer(&d.table)) - return (*dictTable)(atomic.LoadPointer(p)) +// unsafeReadTablePointer returns `&d.read` as an unsafe pointer. +func (d *Dict) unsafeReadTablePointer() *unsafe.Pointer { + return (*unsafe.Pointer)(unsafe.Pointer(&d.read)) +} + +// loadReadTable atomically fetches the read table. If nil, the read table +// isn't available and a fallback to the write table should be tried. +func (d *Dict) loadReadTable() *dictTable { + return (*dictTable)(atomic.LoadPointer(d.unsafeReadTablePointer())) } -// storeTable atomically updates d's underlying dictTable to the one given. -func (d *Dict) storeTable(table *dictTable) { - p := (*unsafe.Pointer)(unsafe.Pointer(&d.table)) - atomic.StorePointer(p, unsafe.Pointer(table)) +// promoteWriteToRead promotes the write table to the read table. The mutex +// needs to be held for this operation. +func (d *Dict) promoteWriteToRead() (table dictTable) { + table, d.write = d.write, nil + // We must use a pointer to a local variable to prevent setting a + // pointer to d.write (which would be bad). + atomic.StorePointer(d.unsafeReadTablePointer(), unsafe.Pointer(&table)) + d.misses = 0 + return table } // loadVersion atomically loads and returns d's version. @@ -323,14 +345,41 @@ func (d *Dict) incVersion() { atomic.AddInt64(&d.version, 1) } +// populateWriteTable makes sure that d.write is populated with the dict's +// table, possibly copying it from the read table. Lock must be held. +func (d *Dict) populateWriteTable() dictTable { + if d.write == nil { + // Copy the read-only table so we can do modifications. + oldTable := *d.loadReadTable() + if d.used == d.fill { + // No deletion markers - use builtin copy for speed. + d.write = make(dictTable, len(oldTable)) + copy(d.write, oldTable) + } else { + // Deletion markers - take the time to clean them out. + d.write = newDictTable(int(d.used)) + for _, oldEntry := range oldTable { + if oldEntry.isSet() { + d.write.insertAbsentEntry(oldEntry) + } + } + } + // NOTE: d.read remains set until later. This allows reads to + // happen while d.write is edited. Once we are ready to + // publish, d.read must be cleared. + d.fill = d.used + d.misses = 0 + } else if d.misses > 0 { + d.misses-- + } + return d.write +} + // DelItem removes the entry associated with key from d. It returns true if an // item was removed, or false if it did not exist in d. func (d *Dict) DelItem(f *Frame, key *Object) (bool, *BaseException) { originValue, raised := d.putItem(f, key, nil, true) - if raised != nil { - return false, raised - } - return originValue != nil, nil + return originValue != nil, raised } // DelItemString removes the entry associated with key from d. It returns true @@ -346,14 +395,43 @@ func (d *Dict) GetItem(f *Frame, key *Object) (*Object, *BaseException) { if raised != nil { return nil, raised } - _, entry, raised := d.loadTable().lookupEntry(f, hash.Value(), key) - if raised != nil { - return nil, raised + + var table dictTable + if rtable := d.loadReadTable(); rtable != nil { + table = *rtable } - if entry != nil && entry != deletedEntry { - return entry.value, nil + +top: + if table != nil { + index, exists, raised := table.lookupEntry(f, hash.Value(), key) + if raised != nil || !exists { + return nil, raised + } + return table[index].value, nil } - return nil, nil + + d.mutex.Lock(f) + d.misses++ + table = d.write + if table == nil { + table = *d.loadReadTable() + d.mutex.Unlock(f) + goto top + } else if d.misses > d.used { + table = d.promoteWriteToRead() + d.mutex.Unlock(f) + goto top + } + + index, exists, raised := table.lookupEntry(f, hash.Value(), key) + // TODO: If the table changes during lookup, do we retry (like in + // putItem)? + var value *Object + if exists && raised == nil { + value = table[index].value + } + d.mutex.Unlock(f) + return value, raised } // GetItemString looks up key in d, returning the associated value or nil if @@ -370,22 +448,34 @@ func (d *Dict) Pop(f *Frame, key *Object) (*Object, *BaseException) { // Keys returns a list containing all the keys in d. func (d *Dict) Keys(f *Frame) *List { - d.mutex.Lock(f) - keys := make([]*Object, d.Len()) - i := 0 - for _, entry := range d.table.entries { - if entry != nil && entry != deletedEntry { - keys[i] = entry.key - i++ + var table dictTable + if rtable := d.loadReadTable(); rtable != nil { + table = *rtable + } else { + d.mutex.Lock(f) + d.misses++ + table = d.write + if table == nil { + table = *d.loadReadTable() + } else if d.misses > d.used { + d.promoteWriteToRead() + d.mutex.Unlock(f) + } else { + defer d.mutex.Unlock(f) + } + } + keys := make([]*Object, 0, d.Len()) + for _, entry := range table { + if entry.isSet() { + keys = append(keys, entry.key) } } - d.mutex.Unlock(f) return NewList(keys...) } // Len returns the number of entries in d. func (d *Dict) Len() int { - return d.loadTable().loadUsed() + return int(atomic.LoadInt32(&d.used)) } // putItem associates value with key in d, returning the old associated value if @@ -395,40 +485,43 @@ func (d *Dict) putItem(f *Frame, key, value *Object, overwrite bool) (*Object, * if raised != nil { return nil, raised } + hashValue := hash.Value() + + entryKey := key + if value == nil { + entryKey = &deletedEntry + } + + var originValue *Object d.mutex.Lock(f) - t := d.table v := d.version - index, entry, raised := t.lookupEntry(f, hash.Value(), key) - var originValue *Object + +top: + table := d.populateWriteTable() + index, _, raised := table.lookupEntry(f, hashValue, key) if raised == nil { if v != d.version { // Dictionary was recursively modified. Blow up instead // of trying to recover. raised = f.RaiseType(RuntimeErrorType, "dictionary changed during write") - } else { - if value == nil { - // Going to delete the entry. - if entry != nil && entry != deletedEntry { - d.table.storeEntry(index, deletedEntry) - d.table.incUsed(-1) - d.incVersion() - } - } else if overwrite || entry == nil { - newEntry := &dictEntry{hash.Value(), key, value} - if newTable, ok := t.writeEntry(f, index, newEntry); ok { - if newTable != nil { - d.storeTable(newTable) - } - d.incVersion() - } else { - raised = f.RaiseType(OverflowErrorType, errResultTooLarge) - } - } - if entry != nil && entry != deletedEntry { - originValue = entry.value + } else if &d.write[0] != &table[0] { + goto top // Entry lookup caused tables to shift. Try again. + } else if prevEntry, ok := d.writeEntry(f, index, dictEntry{hashValue, entryKey, value}, overwrite); ok { + originValue = prevEntry.value + if value != nil || originValue != nil { + d.incVersion() } + } else { + raised = f.RaiseType(OverflowErrorType, errResultTooLarge) } } + + if d.read != nil { + // Time to "publish" the write table. Must use atomic for write + // since other goroutines might be reading concurrently. + atomic.StorePointer(d.unsafeReadTablePointer(), nil) + } + d.mutex.Unlock(f) return originValue, raised } @@ -454,11 +547,9 @@ func (d *Dict) Update(f *Frame, o *Object) (raised *BaseException) { var iter *Object if o.isInstance(DictType) { d2 := toDictUnsafe(o) - d2.mutex.Lock(f) // Concurrent modifications to d2 will cause Update to raise // "dictionary changed during iteration". - iter = newDictItemIterator(d2).ToObject() - d2.mutex.Unlock(f) + iter = newDictItemIterator(f, d2).ToObject() } else { iter, raised = Iter(f, o) } @@ -483,21 +574,30 @@ func dictsAreEqual(f *Frame, d1, d2 *Dict) (bool, *BaseException) { if d1 == d2 { return true, nil } - // Do not hold both locks at the same time to avoid deadlock. - d1.mutex.Lock(f) - iter := newDictEntryIterator(d1) - g1 := newDictVersionGuard(d1) + + // NOTE: The length, iterator, and version may not be consistent. This + // is actually OK. If the length is changing concurrently to this call, + // then the programmer hasn't bothered to implement proper locking in + // their code and in reality they don't know which statement is + // happening before and which is happening after (mutator vs. eq). + // Additionally, it shouldn't matter + // that the version is potentially one higher (mutation in flight with + // initial setup) for the same reason - they can't define an ordering. + // + // Put another way, if the operation is "atomic" and doesn't bleed back + // into Python, then this should be too. If it isn't, they should have + // a lock. + iter := newDictEntryIterator(f, d1) len1 := d1.Len() - d1.mutex.Unlock(f) - d2.mutex.Lock(f) - g2 := newDictVersionGuard(d1) + g1 := newDictVersionGuard(d1) + len2 := d2.Len() - d2.mutex.Unlock(f) + g2 := newDictVersionGuard(d2) if len1 != len2 { return false, nil } result := true - for entry := iter.next(); entry != nil && result; entry = iter.next() { + for entry := iter.next(); !entry.isEmpty() && result; entry = iter.next() { if v, raised := d2.GetItem(f, entry.key); raised != nil { return false, raised } else if v == nil { @@ -524,9 +624,15 @@ func dictClear(f *Frame, args Args, _ KWArgs) (*Object, *BaseException) { return nil, raised } d := toDictUnsafe(args[0]) + d.mutex.Lock(f) - d.table = newDictTable(0) + // Start ready to write... + d.write = newDictTable(0) + atomic.StoreInt32(&d.used, 0) d.incVersion() + d.fill = 0 + + atomic.StorePointer(d.unsafeReadTablePointer(), nil) d.mutex.Unlock(f) return None, nil } @@ -599,9 +705,7 @@ func dictItems(f *Frame, args Args, kwargs KWArgs) (*Object, *BaseException) { return nil, raised } d := toDictUnsafe(args[0]) - d.mutex.Lock(f) - iter := newDictItemIterator(d).ToObject() - d.mutex.Unlock(f) + iter := newDictItemIterator(f, d).ToObject() return ListType.Call(f, Args{iter}, nil) } @@ -610,9 +714,7 @@ func dictIterItems(f *Frame, args Args, kwargs KWArgs) (*Object, *BaseException) return nil, raised } d := toDictUnsafe(args[0]) - d.mutex.Lock(f) - iter := newDictItemIterator(d).ToObject() - d.mutex.Unlock(f) + iter := newDictItemIterator(f, d).ToObject() return iter, nil } @@ -628,9 +730,7 @@ func dictIterValues(f *Frame, args Args, kwargs KWArgs) (*Object, *BaseException return nil, raised } d := toDictUnsafe(args[0]) - d.mutex.Lock(f) - iter := newDictValueIterator(d).ToObject() - d.mutex.Unlock(f) + iter := newDictValueIterator(f, d).ToObject() return iter, nil } @@ -677,9 +777,7 @@ func dictInit(f *Frame, o *Object, args Args, kwargs KWArgs) (*Object, *BaseExce func dictIter(f *Frame, o *Object) (*Object, *BaseException) { d := toDictUnsafe(o) - d.mutex.Lock(f) - iter := newDictKeyIterator(d).ToObject() - d.mutex.Unlock(f) + iter := newDictKeyIterator(f, d).ToObject() return iter, nil } @@ -701,8 +799,12 @@ func dictNE(f *Frame, v, w *Object) (*Object, *BaseException) { } func dictNew(f *Frame, t *Type, _ Args, _ KWArgs) (*Object, *BaseException) { + if t == DictType { + return NewDict().ToObject(), nil + } d := toDictUnsafe(newObject(t)) - d.table = &dictTable{entries: make([]*dictEntry, minDictSize, minDictSize)} + table := newDictTable(0) + d.read = &table return d.ToObject(), nil } @@ -734,7 +836,7 @@ func dictPopItem(f *Frame, args Args, _ KWArgs) (item *Object, raised *BaseExcep } d := toDictUnsafe(args[0]) d.mutex.Lock(f) - iter := newDictEntryIterator(d) + iter := newDictEntryIterator(f, d) entry := iter.next() if entry == nil { raised = f.RaiseType(KeyErrorType, "popitem(): dictionary is empty") @@ -754,15 +856,14 @@ func dictRepr(f *Frame, o *Object) (*Object, *BaseException) { return NewStr("{...}").ToObject(), nil } defer f.reprLeave(d.ToObject()) - // Lock d so that we get a consistent view of it. Otherwise we may - // return a state that d was never actually in. - d.mutex.Lock(f) - defer d.mutex.Unlock(f) + + // Grab a snapshot of our current state: + iter := newDictEntryIterator(f, d) + var buf bytes.Buffer buf.WriteString("{") - iter := newDictEntryIterator(d) i := 0 - for entry := iter.next(); entry != nil; entry = iter.next() { + for entry := iter.next(); !entry.isEmpty(); entry = iter.next() { if i > 0 { buf.WriteString(", ") } @@ -885,12 +986,11 @@ type dictItemIterator struct { guard dictVersionGuard } -// newDictItemIterator creates a dictItemIterator object for d. It assumes that -// d.mutex is held by the caller. -func newDictItemIterator(d *Dict) *dictItemIterator { +// newDictItemIterator creates a dictItemIterator object for d. +func newDictItemIterator(f *Frame, d *Dict) *dictItemIterator { return &dictItemIterator{ Object: Object{typ: dictItemIteratorType}, - iter: newDictEntryIterator(d), + iter: newDictEntryIterator(f, d), guard: newDictVersionGuard(d), } } @@ -928,12 +1028,11 @@ type dictKeyIterator struct { guard dictVersionGuard } -// newDictKeyIterator creates a dictKeyIterator object for d. It assumes that -// d.mutex is held by the caller. -func newDictKeyIterator(d *Dict) *dictKeyIterator { +// newDictKeyIterator creates a dictKeyIterator object for d. +func newDictKeyIterator(f *Frame, d *Dict) *dictKeyIterator { return &dictKeyIterator{ Object: Object{typ: dictKeyIteratorType}, - iter: newDictEntryIterator(d), + iter: newDictEntryIterator(f, d), guard: newDictVersionGuard(d), } } @@ -953,10 +1052,7 @@ func dictKeyIteratorIter(f *Frame, o *Object) (*Object, *BaseException) { func dictKeyIteratorNext(f *Frame, o *Object) (*Object, *BaseException) { iter := toDictKeyIteratorUnsafe(o) entry, raised := dictIteratorNext(f, &iter.iter, &iter.guard) - if raised != nil { - return nil, raised - } - return entry.key, nil + return entry.key, raised } func initDictKeyIteratorType(map[string]*Object) { @@ -971,12 +1067,11 @@ type dictValueIterator struct { guard dictVersionGuard } -// newDictValueIterator creates a dictValueIterator object for d. It assumes -// that d.mutex is held by the caller. -func newDictValueIterator(d *Dict) *dictValueIterator { +// newDictValueIterator creates a dictValueIterator object for d. +func newDictValueIterator(f *Frame, d *Dict) *dictValueIterator { return &dictValueIterator{ Object: Object{typ: dictValueIteratorType}, - iter: newDictEntryIterator(d), + iter: newDictEntryIterator(f, d), guard: newDictVersionGuard(d), } } @@ -996,10 +1091,7 @@ func dictValueIteratorIter(f *Frame, o *Object) (*Object, *BaseException) { func dictValueIteratorNext(f *Frame, o *Object) (*Object, *BaseException) { iter := toDictValueIteratorUnsafe(o) entry, raised := dictIteratorNext(f, &iter.iter, &iter.guard) - if raised != nil { - return nil, raised - } - return entry.value, nil + return entry.value, raised } func initDictValueIteratorType(map[string]*Object) { @@ -1020,18 +1112,17 @@ func dictNextIndex(i, perturb uint) (uint, uint) { return (i << 2) + i + perturb + 1, perturb >> 5 } -func dictIteratorNext(f *Frame, iter *dictEntryIterator, guard *dictVersionGuard) (*dictEntry, *BaseException) { +func dictIteratorNext(f *Frame, iter *dictEntryIterator, guard *dictVersionGuard) (entry dictEntry, raises *BaseException) { // NOTE: The behavior here diverges from CPython where an iterator that // is exhausted will always return StopIteration regardless whether the // underlying dict is subsequently modified. In Grumpy, an iterator for // a dict that has been modified will always raise RuntimeError even if // the iterator was exhausted before the modification. - entry := iter.next() + entry = iter.next() if !guard.check() { - return nil, f.RaiseType(RuntimeErrorType, "dictionary changed during iteration") - } - if entry == nil { - return nil, f.Raise(StopIterationType.ToObject(), nil, nil) + raises = f.RaiseType(RuntimeErrorType, "dictionary changed during iteration") + } else if entry.isEmpty() { + raises = f.Raise(StopIterationType.ToObject(), nil, nil) } - return entry, nil + return } diff --git a/grumpy-runtime-src/runtime/dict_test.go b/grumpy-runtime-src/runtime/dict_test.go index a6545c4d..20942e7a 100644 --- a/grumpy-runtime-src/runtime/dict_test.go +++ b/grumpy-runtime-src/runtime/dict_test.go @@ -100,6 +100,7 @@ func TestDictDelItem(t *testing.T) { {args: wrapArgs(testDict, "a"), want: newTestDict("b", 2, "c", 3).ToObject()}, {args: wrapArgs(testDict, "c"), want: newTestDict("b", 2).ToObject()}, {args: wrapArgs(testDict, "a"), wantExc: mustCreateException(KeyErrorType, "a")}, + {args: wrapArgs(testDict, "b"), want: NewDict().ToObject()}, {args: wrapArgs(NewDict(), NewList()), wantExc: mustCreateException(TypeErrorType, "unhashable type: 'list'")}, } for _, cas := range cases { @@ -411,7 +412,8 @@ func TestDictHasKey(t *testing.T) { } func TestDictItemIteratorIter(t *testing.T) { - iter := &newDictItemIterator(NewDict()).Object + f := NewRootFrame() + iter := newDictItemIterator(f, NewDict()).ToObject() cas := &invokeTestCase{args: wrapArgs(iter), want: iter} if err := runInvokeMethodTestCase(dictItemIteratorType, "__iter__", cas); err != "" { t.Error(err) @@ -537,7 +539,8 @@ func TestDictItems(t *testing.T) { } func TestDictKeyIteratorIter(t *testing.T) { - iter := &newDictKeyIterator(NewDict()).Object + f := NewRootFrame() + iter := newDictKeyIterator(f, NewDict()).ToObject() cas := &invokeTestCase{args: wrapArgs(iter), want: iter} if err := runInvokeMethodTestCase(dictKeyIteratorType, "__iter__", cas); err != "" { t.Error(err) @@ -561,9 +564,23 @@ func TestDictKeyIterModified(t *testing.T) { } func TestDictKeys(t *testing.T) { + bigDict := newTestDict( + "abc", "def", + "ghi", "jkl", + "mno", "pqr", + "stu", "vwx", + "yzA", "BCD", + "EFG", "HIJ", + "KLM", "OPQ", + "RST", "UVW", + "XYZ", "123", + "456", "789", + "10!", "@#$", + "%^&", "*()") cases := []invokeTestCase{ {args: wrapArgs(NewDict()), want: NewList().ToObject()}, {args: wrapArgs(newTestDict("foo", None, 42, None)), want: newTestList(42, "foo").ToObject()}, + {args: wrapArgs(bigDict), want: newTestList("abc", "yzA", "KLM", "XYZ", "10!", "456", "stu", "%^&", "mno", "RST", "ghi", "EFG").ToObject()}, } for _, cas := range cases { if err := runInvokeMethodTestCase(DictType, "keys", &cas); err != "" { @@ -912,7 +929,9 @@ func newTestDict(elems ...interface{}) *Dict { for i := 0; i < numItems; i++ { k := mustNotRaise(WrapNative(f, reflect.ValueOf(elems[i*2]))) v := mustNotRaise(WrapNative(f, reflect.ValueOf(elems[i*2+1]))) - d.SetItem(f, k, v) + if raised := d.SetItem(f, k, v); raised != nil { + panic(raised) + } } return d } diff --git a/grumpy-runtime-src/runtime/set.go b/grumpy-runtime-src/runtime/set.go index 54aa2914..afb6c5c5 100644 --- a/grumpy-runtime-src/runtime/set.go +++ b/grumpy-runtime-src/runtime/set.go @@ -82,10 +82,7 @@ func toSetUnsafe(o *Object) *Set { // Add inserts key into s. If key already exists then does nothing. func (s *Set) Add(f *Frame, key *Object) (bool, *BaseException) { origin, raised := s.dict.putItem(f, key, None, true) - if raised != nil { - return false, raised - } - return origin == nil, nil + return origin == nil, raised } // Contains returns true if key exists in s. @@ -181,10 +178,7 @@ func setIsSuperset(f *Frame, args Args, _ KWArgs) (*Object, *BaseException) { func setIter(f *Frame, o *Object) (*Object, *BaseException) { s := toSetUnsafe(o) - s.dict.mutex.Lock(f) - iter := &newDictKeyIterator(s.dict).Object - s.dict.mutex.Unlock(f) - return iter, nil + return newDictKeyIterator(f, s.dict).ToObject(), nil } func setLE(f *Frame, v, w *Object) (*Object, *BaseException) { @@ -311,10 +305,7 @@ func frozenSetIsSuperset(f *Frame, args Args, _ KWArgs) (*Object, *BaseException func frozenSetIter(f *Frame, o *Object) (*Object, *BaseException) { s := toFrozenSetUnsafe(o) - s.dict.mutex.Lock(f) - iter := &newDictKeyIterator(s.dict).Object - s.dict.mutex.Unlock(f) - return iter, nil + return newDictKeyIterator(f, s.dict).ToObject(), nil } func frozenSetLE(f *Frame, v, w *Object) (*Object, *BaseException) { @@ -388,15 +379,15 @@ func setCompare(f *Frame, op compareOp, v *setBase, w *Object) (*Object, *BaseEx op = op.swapped() v, s2 = s2, v } - v.dict.mutex.Lock(f) - iter := newDictEntryIterator(v.dict) - g1 := newDictVersionGuard(v.dict) + + // NOTE: See comment in dictsAreEqual for why an inconsistent view here is actually ok. + iter := newDictEntryIterator(f, v.dict) len1 := v.dict.Len() - v.dict.mutex.Unlock(f) - s2.dict.mutex.Lock(f) - g2 := newDictVersionGuard(s2.dict) + g1 := newDictVersionGuard(v.dict) + len2 := s2.dict.Len() - s2.dict.mutex.Unlock(f) + g2 := newDictVersionGuard(s2.dict) + result := (op != compareOpNE) switch op { case compareOpLT: @@ -412,7 +403,7 @@ func setCompare(f *Frame, op compareOp, v *setBase, w *Object) (*Object, *BaseEx return GetBool(!result).ToObject(), nil } } - for entry := iter.next(); entry != nil; entry = iter.next() { + for entry := iter.next(); !entry.isEmpty(); entry = iter.next() { contains, raised := s2.contains(f, entry.key) if raised != nil { return nil, raised