# viterbi.py
import heapq
import operator
  2. MIN_FLOAT=-3.14e100
  3. def get_top_states(t_state_v,K=4):
  4. items = t_state_v.items()
  5. topK= sorted(items,key=operator.itemgetter(1),reverse=True)[:K]
  6. return [x[0] for x in topK]
  7. def viterbi(obs, states, start_p, trans_p, emit_p):
  8. V = [{}] #tabular
  9. mem_path = [{}]
  10. all_states = trans_p.keys()
  11. for y in states.get(obs[0],all_states): #init
  12. V[0][y] = start_p[y] + emit_p[y].get(obs[0],MIN_FLOAT)
  13. mem_path[0][y] = ''
  14. for t in range(1,len(obs)):
  15. V.append({})
  16. mem_path.append({})
  17. prev_states = get_top_states(V[t-1])
  18. prev_states =[ x for x in mem_path[t-1].keys() if len(trans_p[x])>0 ]
  19. prev_states_expect_next = set( (y for x in prev_states for y in trans_p[x].keys() ) )
  20. obs_states = states.get(obs[t],all_states)
  21. obs_states = set(obs_states) & set(prev_states_expect_next)
  22. if len(obs_states)==0: obs_states = all_states
  23. for y in obs_states:
  24. (prob,state ) = max([(V[t-1][y0] + trans_p[y0].get(y,MIN_FLOAT) + emit_p[y].get(obs[t],MIN_FLOAT) ,y0) for y0 in prev_states])
  25. V[t][y] =prob
  26. mem_path[t][y] = state
  27. last = [(V[-1][y], y) for y in mem_path[-1].keys() ]
  28. #if len(last)==0:
  29. #print obs
  30. (prob, state) = max(last)
  31. route = [None] * len(obs)
  32. i = len(obs)-1
  33. while i>=0:
  34. route[i] = state
  35. state = mem_path[i][state]
  36. i-=1
  37. return (prob, route)