diff --git a/android/onceper.go b/android/onceper.go index 5fefb710c..f06f428d5 100644 --- a/android/onceper.go +++ b/android/onceper.go @@ -20,8 +20,23 @@ import ( ) type OncePer struct { - values sync.Map - valuesLock sync.Mutex + values sync.Map +} + +type onceValueWaiter chan bool + +func (once *OncePer) maybeWaitFor(key OnceKey, value interface{}) interface{} { + if wait, isWaiter := value.(onceValueWaiter); isWaiter { + // The entry in the map is a placeholder waiter because something else is constructing the value + // wait until the waiter is signalled, then load the real value. + <-wait + value, _ = once.values.Load(key) + if _, isWaiter := value.(onceValueWaiter); isWaiter { + panic(fmt.Errorf("Once() waiter completed but key is still not valid")) + } + } + + return value } // Once computes a value the first time it is called with a given key per OncePer, and returns the @@ -29,21 +44,20 @@ type OncePer struct { func (once *OncePer) Once(key OnceKey, value func() interface{}) interface{} { // Fast path: check if the key is already in the map if v, ok := once.values.Load(key); ok { - return v + return once.maybeWaitFor(key, v) } - // Slow path: lock so that we don't call the value function twice concurrently - once.valuesLock.Lock() - defer once.valuesLock.Unlock() - - // Check again with the lock held - if v, ok := once.values.Load(key); ok { - return v + // Slow path: create a OnceValueWrapper and attempt to insert it + waiter := make(onceValueWaiter) + if v, loaded := once.values.LoadOrStore(key, waiter); loaded { + // Got a value, something else inserted its own waiter or a constructed value + return once.maybeWaitFor(key, v) } - // Still not in the map, call the value function and store it + // The waiter is inserted, call the value constructor, store it, and signal the waiter v := value() once.values.Store(key, v) + close(waiter) return v } diff --git a/android/onceper_test.go b/android/onceper_test.go index d2ca9ad70..f27799bec 100644 --- a/android/onceper_test.go +++ b/android/onceper_test.go @@ -133,3 +133,14 @@ func TestNewCustomOnceKey(t *testing.T) { t.Errorf(`second call to Once with the NewCustomOnceKey from equal key should return "a": %q`, b) } } + +func TestOncePerReentrant(t *testing.T) { + once := OncePer{} + key1 := NewOnceKey("key") + key2 := NewOnceKey("key") + + a := once.Once(key1, func() interface{} { return once.Once(key2, func() interface{} { return "a" }) }) + if a != "a" { + t.Errorf(`reentrant Once should return "a": %q`, a) + } +}