向前向后算法forwardbackward algorithm.docx
《向前向后算法forwardbackward algorithm.docx》由会员分享,可在线阅读,更多相关《向前向后算法forwardbackward algorithm.docx(19页珍藏版)》请在冰豆网上搜索。
向前向后算法forwardbackwardalgorithm
向前-向后算法(forward-backwardalgorithm)
本文承接上篇博客《隐马尔可夫模型及的评估和解码问题》,用到的概念和例子都是那里面的。
学习问题
在HMM模型中,已知隐藏状态的集合S,观察值的集合O,以及一个观察序列(o1,o2,...,on),求使得该观察序列出现的可能性最大的模型参数(包括初始状态概率矩阵π,状态转移矩阵A,发射矩阵B)。
这正好就是EM算法要求解的问题:
已知一系列的观察值X,在隐含变量Y未知的情况下求最佳参数θ*,使得:
在中文词性标注里,根据为训练语料,我们观察到了一系列的词(对应EM中的X),如果每个词的词性(即隐藏状态)也是知道的,那它就不需要用EM来求模型参数θ了,因为Y是已知的,不存在隐含变量了。
当没有隐含变量时,直接用maximumlikelihood就可以把模型参数求出来。
预备知识
首先你得对下面的公式表示认同。
以下都是针对相互独立的事件,
P(A,B)=P(B|A)*P(A)
P(A,B,C)=P(C)*P(A,B|C)=P(A,C|B)*P(B)=P(B,C|A)*P(A)
P(A,B,C,D)=P(D)*P(A,B|D)*P(C|A)=P(D)*P(A,B|D)*P(C|B)
P(A,B|C)=P(D1,A,B|C)+P(D2,A,B|C) D1,D2是事件D的一个全划分
理解了上面几个式子,你也就能理解本文中出现的公式是怎么推导出来的了。
EM算法求解
我们已经知道如果隐含变量Y是已知的,那么求解模型参数直接利用MaximumLikelihood就可以了。
EM算法的基本思路是:
随机初始化一组参数θ(0),根据后验概率Pr(Y|X;θ)来更新Y的期望E(Y),然后用E(Y)代替Y求出新的模型参数θ
(1)。
如此迭代直到θ趋于稳定。
在HMM问题中,隐含变量自然就是状态变量,要求状态变量的期望值,其实就是求时刻ti观察到xi时处于状态si的概率,为了求此概率,需要用到向前变量和向后变量。
向前变量
向前变量 是假定的参数
它表示t时刻满足状态
,且t时刻之前(包括t时刻)满足给定的观测序列
的概率。
1.令初始值
2.归纳法计算
3.最后计算
复杂度
向后变量
向后变量
它表示在时刻t出现状态
,且t时刻以后的观察序列满足
的概率。
1.初始值
2.归纳计算
E-Step
定义变量
为t时刻处于状态i,t+1时刻处于状态j的概率。
定义变量
表示t时刻呈现状态i的概率。
实际上
是从其他所有状态转移到状态i的次数的期望值。
是从状态i转移出去的次数的期望值。
是从状态i转移到状态j的次数的期望值。
M-Step
是在初始时刻出现状态i的频率的期望值,
是从状态i转移到状态j的次数的期望值 除以 从状态i转移出去的次数的期望值,
是在状态j下观察到活动为k的次数的期望值 除以 从其他所有状态转移到状态j的次数的期望值,
然后用新的参数
再来计算向前变量、向后变量、
和
。
如此循环迭代,直到前后两次参数的变化量小于某个值为止。
下面给出我的java代码:
1importjava.io.BufferedReader;
2importjava.io.File;
3importjava.io.FileReader;
4importjava.io.IOException;
5importjava.util.Arrays;
6importjava.util.HashMap;
7importjava.util.LinkedList;
8importjava.util.List;
9importjava.util.Map;
10importjava.util.Map.Entry;
11
12/**
13*隐马尔可夫模型参数学习。
14*
15*@Author:
zhangchaoyang
16*@Since:
2015年4月4日
17*@Version:
1.0
18*/
19publicclassHmmLearn{
20
21privateintstateCount;//状态的个数
22privateMapobserveIndexMap=newHashMap();//观察值及其索引编号
23/**
24*通过学习得到的模型参数
25*/
26privatedouble[]stateProb;//初始状态概率矩阵
27privatedouble[][]stateTrans;//状态转移矩阵
28privatedouble[][]emission;//混淆矩阵
29
30privateListobserveSeqs=newLinkedList();//训练集中所有的观察序列
31
32/**
33*迭代终止条件
34*/
35privatefinalintITERATION_MAX=100;
36privatefinaldoubleDELTA_PI=1E-3;
37privatefinaldoubleDELTA_A=1E-2;
38privatefinaldoubleDELTA_B=1E-2;
39
40/**
41*
42*@paramstateCount
43*指定状态取值有多少种
44*@paramobserveFile
45*存储观察序列的文件,各个观察序列用空白符或换行符隔开即可
46*@throwsIOException
47*/
48publicvoidinitParam(intstateCount,StringobserveFile)
49throwsIOException{
50this.stateCount=stateCount;
51intobserveCount=0;
52BufferedReaderbr=newBufferedReader(newFileReader(newFile(
53observeFile)));
54Stringline=null;
55while((line=br.readLine())!
=null){
56String[]arr=line.split("\\s+");
57for(Stringseq:
arr){
58if(seq.length()>1){//长度为1的观察序列必须过滤掉,不然在更新stateTrans时会出现NaN的情况
59observeSeqs.add(seq);
60for(inti=0;i61Stringobserve=seq.substring(i,i+1);
62if(!
observeIndexMap.containsKey(observe)){
63observeIndexMap.put(observe,observeCount++);
64}
65}
66}
67}
68}
69br.close();
70
71stateProb=newdouble[stateCount];
72initWeightRandomly(stateProb,1E5);
73//initWeightEqually(stateProb);
74stateTrans=newdouble[stateCount][];
75for(inti=0;i76stateTrans[i]=newdouble[stateCount];
77initWeightRandomly(stateTrans[i],1E5);
78//initWeightEqually(stateTrans[i]);
79}
80emission=newdouble[stateCount][];
81for(inti=0;i82emission[i]=newdouble[observeCount];
83initWeightRandomly(emission[i],1E9);
84//initWeightEqually(emission[i]);
85}
86}
87
88/**
89*随机地初始化权重,使得各权重非负,且和为1.
90*
91*@paramarr
92*@paramprecision
93*/
94publicvoidinitWeightRandomly(double[]arr,doubleprecision){
95intlen=arr.length-1;
96int[]position=newint[len];
97for(inti=0;i98position[i]=(int)(Math.random()*precision);
99}
100Arrays.sort(position);
101intpre=0;
102for(inti=0;i103arr[i]=1.0*(position[i]-pre)/precision;
104pre=position[i];
105}
106arr[len]=1.0*(precision-pre)/precision;
107}
108
109/**
110*均等地初始化权重,使得各权重非负,且和为1.
111*
112*@paramarr
113*/
114publicvoidinitWeightEqually(double[]arr){
115intlen=arr.length;
116for(inti=0;i117arr[i]=1.0/len;
118}
119}
120
121/**
122*BaumWelch算法学习HMM的模型参数
123*/
124publicvoidbaumWelch(){
125longbegin=System.currentTimeMillis();
126intiter=0;
127while(iter++128double[]stateProb_new=newdouble[stateCount];
129double[][]stateTrans_new=newdouble[stateCount][];
130double[][]emission_new=newdouble[stateCount][];
131for(inti=0;i132