首页 > 解决方案 > 我可以在 Go 中使用特定值锁定吗?

问题描述

在回答另一个问题时,我编写了一个sync.Map用于缓存 API 请求的小结构。

type PostManager struct {
    sync.Map
}

func (pc PostManager) Fetch(id int) Post {
    post, ok := pc.Load(id)
    if ok {
        fmt.Printf("Using cached post %v\n", id)
        return post.(Post)
    }
    fmt.Printf("Fetching post %v\n", id)
    post = pc.fetchPost(id)
    pc.Store(id, post)

    return post.(Post)
}

不幸的是,如果两个 goroutine 同时获取同一个未缓存的 Post,它们都会发出请求。

var postManager PostManager

wg.Add(3)

var firstPost Post
var secondPost Post
var secondPostAgain Post

go func() {
    // Fetches and caches 1
    firstPost = postManager.Fetch(1)
    defer wg.Done()
}()

go func() {
    // Fetches and caches 2
    secondPost = postManager.Fetch(2)
    defer wg.Done()
}()

go func() {
    // Also fetches and caches 2
    secondPostAgain = postManager.Fetch(2)
    defer wg.Done()
}()

wg.Wait()

我需要确保当同时提取相同 ID 时,只允许一个实际发出请求。另一个必须等​​待并将使用缓存的 Post。但也不要锁定不同 ID 的提取。

在上面的示例中,我希望只有一个调用pc.fetchPost(1)andpc.fetchPost(2)并且它们应该是同时的。

链接到完整代码

标签: gomutex

解决方案


如果提取已经在进行中,看起来可以使用第二张地图等待。

type PostManager struct {
    sync.Map
    q sync.Map
}

func (pc *PostManager) Fetch(id int) Post {
    post, ok := pc.Load(id)
    if ok {
        fmt.Printf("Using cached post %v\n", id)
        return post.(Post)
    }
    fmt.Printf("Fetching post %v\n", id)
    if c, loaded := pc.q.LoadOrStore(id, make(chan struct{})); !loaded {
        post = pc.fetchPost(id)
        pc.Store(id, post)
        close(c.(chan struct{}))
    } else {
        <-c.(chan struct{})
        post,_ = pc.Load(id)
    }
    return post.(Post)
}

或者,更复杂一点,使用相同的地图;-)

func (pc *PostManager) Fetch(id int) Post {
    p, ok := pc.Load(id)

    if !ok {
        fmt.Printf("Fetching post %v\n", id)
        if p, ok = pc.LoadOrStore(id, make(chan struct{})); !ok {
            fetched = pc.fetchPost(id)
            pc.Store(id, fetched)
            close(p.(chan struct{}))
            return fetched
        }
    }

    if cached, ok := p.(Post); ok {
        fmt.Printf("Using cached post %v\n", id)
        return cached
    }

    fmt.Printf("Wating for cached post %v\n", id)
    <-p.(chan struct{})
    return pc.Fetch(id)
}

推荐阅读