Why does my Markov chain repeat?

Problem description

I am trying to simulate a Markov chain using a class in Python. Here is my code:

import random
...
class Chain:
  def __init__(self, probabilities, start):
    self.probs = probabilities
    self.start = start
    self.names = list(self.probs.keys())

  def __iter__(self):
    self.pos = self.start
    return self

  def __next__(self):
    self.random_num = random.randrange(100)

    prob_l = self.probs[self.pos]
    for ind, prob in enumerate(prob_l):
      self.prob_sum += prob
      if self.random_num < self.prob_sum:
        exclude_names = self.names[:ind] + self.names[ind + 1 :]
        self.prob_sum = 0
        self.pos = exclude_names[ind]
        return self.pos
    return self.pos


chain = Chain({"A": [50, 25], "B": [50, 25], "C": [50, 50]}, "A")
chain_iter = iter(chain)
for k in range(100):
    print(next(chain_iter))

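It works as expected, but sometimes the letter C repeats. Since the dictionary entry for C contains two 50s, there should be a 50/50 chance of going to A or B. C should never repeat.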

Tags: python, class

Solution


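You see repeated C's because the index of the current node is not computed correctly. The question's code removes the entry at `ind` (the position in the probability list) from the name list, when it should remove the current node itself.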
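
To make the indexing problem concrete, here is a small sketch that builds the candidate list both ways for the current node "C". The helper names `targets_buggy` and `targets_fixed` are hypothetical and exist only for this illustration; the slicing expressions are the ones used in the question and in the corrected class.

```python
names = ["A", "B", "C"]  # same node order as the dictionary keys in the question

def targets_buggy(ind):
    # Question's approach: drop the entry at the *probability* index `ind`,
    # then index the shortened list with that same `ind`.
    exclude_names = names[:ind] + names[ind + 1:]
    return exclude_names[ind]

def targets_fixed(pos, ind):
    # Drop the *current node* instead, then pick the ind-th remaining node.
    i = names.index(pos)
    exclude_names = names[:i] + names[i + 1:]
    return exclude_names[ind]

# From "C", the two probability slots (ind 0 and 1) should lead to "A" and "B".
print([targets_buggy(ind) for ind in range(2)])       # ['B', 'C']
print([targets_fixed("C", ind) for ind in range(2)])  # ['A', 'B']
```

With the buggy slicing, the second slot of "C" maps back to "C" itself, which is where the repeats come from.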

Here is the updated code with comments:

import random

class Chain:
  def __init__(self, probabilities, start):  # only called once
    self.probs = probabilities
    self.start = start
    self.names = list(self.probs.keys())

  def __iter__(self):  # only called once
    return self

  def __next__(self):  # each iteration
    self.pos = self.start  # current node; this is the value returned
    self.random_num = random.randrange(100)  # draw a fresh percentile (0-99) on every step
    i = self.names.index(self.pos)  # index of the current node in the full name list
    prob_l = self.probs[self.pos]  # probabilities of moving to each other node
    self.prob_sum = 0  # reset the cumulative sum before scanning
    for ind, prob in enumerate(prob_l):  # scan the other nodes in order
      self.prob_sum += prob  # cumulative probability so far
      if self.random_num < self.prob_sum:  # percentile falls in this bucket, move there
        exclude_names = self.names[:i] + self.names[i + 1:]  # full list without the current node
        self.start = exclude_names[ind]  # node for the next iteration
        break  # a transition was chosen, stop scanning
    # if no bucket matched (row sums to less than 100), the chain stays on this node
    return self.pos  # emit the current node
    
chain = Chain({"A": [50, 25], "B": [50, 25], "C": [50, 50]}, "A")
chain_iter = iter(chain)
for k in range(100):
   print(next(chain_iter), end=" ")

Output (wrapped)

A B A A B B C B A B A C A B C B A C B A B C B B C 
B A B C B A C A B A A B C B B A B B A B A C B C B 
B C A B A C A C B A C A B A B A C A B B C B A B A 
C A C A C B A A C B A A B B B A A A C A A B A B B
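
As an alternative sketch (not the code from the answer above), the same chain can be written as a generator around `random.choices`, which picks one item according to relative weights. It assumes that any leftover probability (100 minus the row sum) means staying on the current node, which matches how the class behaves when no bucket is hit. The function name `walk` and its `seed` parameter are assumptions added for illustration.

```python
import random

def walk(probabilities, start, steps, seed=None):
    """Yield `steps` node names, beginning with `start`."""
    rng = random.Random(seed)  # private generator so runs can be reproduced
    names = list(probabilities)
    pos = start
    for _ in range(steps):
        yield pos
        others = [n for n in names if n != pos]  # candidate targets, current node excluded
        weights = list(probabilities[pos])
        stay = 100 - sum(weights)  # assumption: leftover percent means "stay put"
        # random.choices draws one item according to the relative weights
        pos = rng.choices(others + [pos], weights=weights + [stay])[0]

probs = {"A": [50, 25], "B": [50, 25], "C": [50, 50]}
path = list(walk(probs, "A", 1000, seed=0))
# "C" has no leftover probability, so it should never follow itself
print(any(a == b == "C" for a, b in zip(path, path[1:])))  # False
```

Excluding the current node by name rather than by position avoids the index bookkeeping that caused the original bug.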
