KMP 算法原理解析,公式推导,代码实现
KMP 算法原理解析,公式推导,代码实现
一、 简介
KMP 算法是一个解决模式串在文本串中是否出现过的问题,如果出现过则返回最早出现的位置。
KMP 的时间复杂度是O ( n + m ) O(n + m)O(n+m)一趟扫描即可。文本串无需回溯。
KMP 方法采用最长公共前后缀和的方法避免了文本串的回溯,降低了时间复杂度。
二、 最长公共前后缀 && 部分匹配表 next
数组
1. 字符串的前缀:
字符串的前缀是指不包含最后一个字符的所有以第一个字符(索引为0)开头的连续子串。
比如"ABABA" 的前缀为:
A
、
AB
、
ABA
、
ABAB
2. 字符串的后缀
字符串的后缀是指不包含第一个字符的所有以最后一个字符结尾的连续子串
比如"ABABA" 的后缀为:
BABA
、
ABA
、
BA
、
A
3. 公共前后缀:
字符串的公共前后缀是指前缀和后缀取交集后的字符串
比如"ABABA" 的公共前后缀为:
A
、
ABA
4. 最长公共前后缀
所有公共前后缀中长度最长的那个,上述的示例中就是
ABA
5. 部分匹配表 next
next
是一个数组,长度与给定的字符串相同,next[i]
表示字符串区间[0, i]
的子串的最长公共前后缀的长度值。
以字符串ABABA
为例:
next[0]
:由于子串A
没有公共前后缀,因此next[0] = 0
,
next[1]
:子串AB
的前缀为A
,后缀为B
,交集为空,因此next[1] = 0
next[2]
: 子串ABA
的前缀为A
,AB
,后缀为BA
,A
, 交集为A
, 因此next[2] = 1
nexr[3]
:子串ABAB
的前缀为A
,AB
,ABA
,后缀为BAB
,AB
,B
,交集为AB
,因此next[3] = 2
next[4]
:子串ABABA
的公共前后缀为A
,ABA
,取长度最大ABA
,因此next[4] = 3
6. next
数组的作用
KMP
算法可以避免文本串回溯,就是通过next
数组实现的。有了next
数组,在模式串与文本串失配的时候,不需要回溯文本串的指针i
,只需要回溯匹配串的指针j
即可。
比如模式串为ABABA
, 文本串为ABCBABABA
当进行匹配时,指针会在位置2
处发生失配。
如果是暴力算法,在发生这种情况后,指针i
会回缩到i = 1
,指针j
会回溯到j = 0
。两个指针都需要回溯,总的时间复杂度就变成了O ( n 2 ) O(n^2)O(n2)
而
KMP
算法可以不用回溯i
,只需要将j
指针回溯到j = next[j - 1] (j > 0)
即可。(当j == 0
时,直接右移一步i
指针即可)。
next[2 - 1] = 0
,因此将j
回溯到0
即可。
j
回溯到0后继续比较,这次不太幸运,在j = 0
处就失配了,在模式串头部失配时,只需向右滑动i
即可,i
滑动到3处时还是失配,继续向右滑,滑倒i = 4
处开始与模式串相等,继续比较。
从i = 4
,j = 0
处比较,这次可以顺利比较到末尾。最后j
的值超出模式串的索引范围,比较停止,也是匹配成功的停机条件。
**结论:next
数组的作用就是当发生失配时,告诉指针j
应该回溯到那个位置。而且应该回溯的位置就是j = next[j - 1] (j > 0)
三、 回溯的原理
该算法是利用已经得到的"部分匹配"的结果将模式串向右"滑动"尽可能远的一段距离后,继续进行比较。
1. 从全图视角说起
首先我们要开个全图视角:
当指针在
i = 2
,j = 2
的地方失配时,根据暴力法,i
要回溯到i = 1
, 字符为B
的位置上,但是其实我们已经知道他是B
,而模式串的首字母是’A’ 肯定匹配不上,所以可以直接将模式串向后滑,i
指针不需要回溯。后面
i = 2
,j = 0
处发生失配时,i
更不需要回溯,而是直接右移即可。
结论:其实在文本匹配的时候,文本串的指针根本不用回溯,只需要回溯模式串即可!
注:这里只举个了简单的例子,读者可以自行去观察其他例子是否和结论一样。
2. 公式推导
那么如何确定每次回溯到什么位置呢?也即是当文本串在i
处与模式串在j
出发生失配时,i
应该和模式串中那个位置的字符继续进行比较。又是怎么跟next
数组扯上关系的呢?
假设这个位置为k (k < j)
。因此,模式串区间[0, k - 1]
的子串肯定与文本串中[i - k + 1, i - 1]
的子串相等。且这个k
是最大的那个,也就是在区间[k + 1, j]
的上不会有位置在满足这个等式(贪心的思想,保证模式串能滑的更远):
p 0 p 2 . . . p k − 1 = s i − k + 1 s i − k + 2 . . . s i − 1
其中p i
表示模式串的字符,s i
表示主串的字符。
除了这个等式,我们还可以建立一个等式,那就是在发生失配的时候,失配位置之前的串肯定也是相等的,所以这个等式就是:
p j − k + 1 p j − k + 2 . . . p j − 1 = s i − k + 1 s i − k + 2 . . . s i − 1
注:为了让两个等式中的子串长度相等,第二个等式中只从失配的前一位处往前截了k
位。没有截到模式串的第一位。
通过上述两个等式就可以建立第三个等式:
p 0 p 2 . . . p k − 1 = p j − k + 1 p j − k + 2 . . . p j − 1
3. 谜底揭晓,大吃一惊
这第三个等式一出,大家可以惊奇的发现,其实k
这个位置的确定,跟文本串完全没有关系,只与模式串本身有关!
而这等式的前半部分,不就是模式串[0, j - 1]
区间子串的一个前缀串吗?这等式的后半部分不就是模式串[0, j - 1]
区间子串 的一个后缀串吗?他俩要相等。不就是要找到[0, j - 1]
区间子串的公共前后缀吗?
而且要保证k
是最大的,因此不就是要找到最长的公共前后缀吗?这个k
不也就是这个最长公共前后缀的长度吗?因此next[j - 1]
表示的不仅是[0, j - 1]
区间的最大公共前后缀的长度,也是模式串在j
处失配时要回溯的位置!
数学真的是太妙了!
四、代码实现
1. next
数组实现
要想知道在失配时模式串要回溯的位置,就需要先知道模式串的next
数组。
求解方式:递推+双指针
/**
* 求每个s[i] 位置上的最长公共前后缀的长度记录在 next[i] 中
* @param patternStr
* @return
*/
public static int[] kmpNext(String patternStr){
int n = patternStr.length();
int[] next = new int[n];
// 第一个字符没有公共前后缀,因此 next[0] = 0;
next[0] = 0;
// 记录当前最长公共前后缀长度,在原理推导中可知也是新的前缀要与后缀比较的索引位置
int k = 0;
// 从第二个位置开始遍历求解,也表示要比较的后缀的索引位置。
int j = 1;
while (j < n){
// 比较前后缀,以求next[1]为例, 新的前缀比老的前缀多个p[0],新的后缀比老的后缀多个p[1],
// 所以应该比较 p[0] 跟 p[1] 是否相等。如果相等那么新的最长公共前后缀就比原来长1.
if(patternStr.charAt(k) == patternStr.charAt(j)){
next[j] = ++ k;
j ++;
}else {
// 如果不相等,而且之前的长度是0的话,那么这次新的长度就也是0
if(k == 0){
next[j] = k;
j ++;
}else {
// 如果不是0,那么可以缩小比较前缀的范围,因为在 k这个位置不相同,那么可能在 k 之前的位置上相同
k = next[k - 1];
}
}
}
return next;
}
2. KMP 算法实现
求得了next
数组,那么实现
KMP
算法就很简单,只需要在失配的时候回溯j
即可。
/**
* 匹配到一个就立即停止
* @param textStr 文本串
* @param patternStr 模式串
* @param next
* @return
*/
public static int kmp(String textStr, String patternStr, int[] next){
int n = textStr.length(), m = patternStr.length(), i = 0, j = 0;
while (i < n && j < m){
if(textStr.charAt(i) == patternStr.charAt(j)){
i ++;
j ++;
}else {
if(j == 0){
i ++;
}else {
j = next[j - 1];
}
}
}
// 匹配成立条件:j 越界
return j == m ? i - j : - 1;
}
/**
* 匹配所有可能的串
* @param textStr
* @param patternStr
* @param next
* @return
*/
public static List<Integer> kmpAll(String textStr, String patternStr, int[] next){
int n = textStr.length(), m = patternStr.length(), i = 0, j = 0;
List<Integer> res = new ArrayList<>();
while (i < n && j < m){
if(textStr.charAt(i) == patternStr.charAt(j)){
i ++;
j ++;
}else {
if(j == 0){
i ++;
}else {
j = next[j - 1];
}
}
// 匹配到一个字符串
if(j == m){
res.add(i - j);
j = next[j - 1];
}
}
return res;
}