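This post covers the general idea behind the KMP algorithm and applies it to a few example problems. For the detailed inner workings, see the references.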

Algorithm idea

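The KMP algorithm is used for string matching. With a text of length n and a pattern of length m, brute-force matching takes $O(mn)$ time, while KMP brings this down to $O(m+n)$.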
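
For comparison, here is a minimal sketch of the brute-force baseline mentioned above. The name `naive_find` is hypothetical and not part of the original post; the point is only to show where the $O(mn)$ cost comes from: every start position may re-compare up to m characters.

```python
def naive_find(haystack: str, needle: str) -> int:
    """Brute-force search: try every start position in haystack."""
    n, m = len(haystack), len(needle)
    for start in range(n - m + 1):
        k = 0
        # Compare needle against haystack[start:start + m] character by character.
        while k < m and haystack[start + k] == needle[k]:
            k += 1
        if k == m:
            return start
    return -1
```

KMP avoids the repeated comparisons by never moving backwards in the text.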

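The main idea is to reuse the information from characters that have already been matched, so that no comparison is repeated. The algorithm has two steps: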

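  1. For the given pattern string p, compute the next array, where next[i] is the length of the longest proper prefix of p[0..i] (inclusive) that is also a suffix of it.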

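This array can be computed with dynamic programming. For position i, suppose next[i-1] is already known and let $k = \text{next}[i-1]$. Compare p[k] with p[i]: if they are equal, then next[i] = k + 1; otherwise jump back to $k = \text{next}[k-1]$ and compare again, stopping once k reaches 0. All arrays here are 0-indexed.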

  2. Using the next array and the same fallback idea, match the pattern p against the text s.
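
Step 1 can be sketched on its own as follows. `build_next` is a hypothetical helper name used only for illustration; it follows the recurrence described above, with k starting from the previous entry of the array.

```python
def build_next(p: str) -> list:
    """Return nxt, where nxt[i] is the length of the longest proper
    prefix of p[0..i] that is also a suffix of p[0..i]."""
    nxt = [0] * len(p)
    for i in range(1, len(p)):
        k = nxt[i - 1]
        # Fall back to shorter borders until p[k] can extend one of them.
        while k > 0 and p[k] != p[i]:
            k = nxt[k - 1]
        if p[k] == p[i]:
            k += 1
        nxt[i] = k
    return nxt


print(build_next("ababaca"))  # [0, 0, 1, 2, 3, 0, 1]
```

For example, the entry 3 at index 4 says that "aba" is both a prefix and a suffix of "ababa".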

Template

A template that checks whether needle is a substring of haystack and returns the index of the first match:

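def KMP(haystack: str, needle: str) -> int:
    """
    haystack is the text to search in; needle is the pattern.
    Return the index of the first occurrence of needle in haystack, or -1 if there is none.
    """
    n, m = len(haystack), len(needle)
    if m == 0:
        # An empty pattern matches at index 0 (same convention as str.find).
        return 0

    # nxt[i]: length of the longest proper prefix of needle[:i+1] that is also its suffix.
    nxt = [0] * m
    j = 0
    for i in range(1, m):
        while j > 0 and needle[i] != needle[j]:
            j = nxt[j-1]
        if needle[i] == needle[j]:
            j += 1
        nxt[i] = j

    # pos: number of pattern characters matched so far.
    pos = 0
    for i in range(0, n):
        while pos > 0 and needle[pos] != haystack[i]:
            pos = nxt[pos-1]

        if needle[pos] == haystack[i]:
            pos += 1
        if pos == m:
            return i - m + 1
    return -1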
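
The template stops at the first match. A possible variant, sketched below, reports every match, including overlapping ones. The name `kmp_find_all` is hypothetical, and returning an empty list for an empty pattern is an assumption made here, not something the template specifies. The only real change is that after a full match the search falls back to `nxt[pos-1]` instead of returning.

```python
from typing import List


def kmp_find_all(haystack: str, needle: str) -> List[int]:
    """Return the start indices of all (possibly overlapping) matches."""
    m = len(needle)
    if m == 0:
        return []  # assumption: an empty pattern yields no matches

    # Same next-array construction as in the template.
    nxt = [0] * m
    j = 0
    for i in range(1, m):
        while j > 0 and needle[i] != needle[j]:
            j = nxt[j - 1]
        if needle[i] == needle[j]:
            j += 1
        nxt[i] = j

    hits = []
    pos = 0
    for i, ch in enumerate(haystack):
        while pos > 0 and needle[pos] != ch:
            pos = nxt[pos - 1]
        if needle[pos] == ch:
            pos += 1
        if pos == m:
            hits.append(i - m + 1)
            # Keep the longest border so overlapping matches are not missed.
            pos = nxt[pos - 1]
    return hits


print(kmp_find_all("abababa", "aba"))  # [0, 2, 4]
```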

Example problems

Reference code
class Solution:
    def isFlipedString(self, s1: str, s2: str) -> bool:
        if not s1 and not s2:
            return True
        
        if not s1 or not s2:
            return False
        
        n, m = len(s1), len(s2)
        if n != m:
            return False
        
        target = s1 + s1

        # Build the nxt array for the pattern s2
        nxt = [0] * n
        j = 0
        for i in range(1, n):
            while j > 0 and s2[j] != s2[i]:
                j = nxt[j-1]
            if s2[j] == s2[i]:
                j += 1
            nxt[i] = j
        
        # Search for s2 in target = s1 + s1
        pos = 0
        for i in range(0, 2*n):
            while pos > 0 and s2[pos] != target[i]:
                pos = nxt[pos-1]
            if s2[pos] == target[i]:
                pos += 1
            if pos == n:
                return True

        return False
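
The solution above relies on one observation: s2 is a rotation of s1 exactly when the two strings have the same length and s2 occurs in s1 + s1. A short sketch using Python's built-in substring test can serve as a cross-check when testing the KMP version. The name `is_rotation` is hypothetical and not part of the original solution.

```python
def is_rotation(s1: str, s2: str) -> bool:
    """Check whether s2 is a rotation of s1 using the built-in `in` operator."""
    # Every rotation of s1 appears as a length-len(s1) window of s1 + s1.
    return len(s1) == len(s2) and s2 in s1 + s1


print(is_rotation("waterbottle", "erbottlewat"))  # True
```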
Reference code
class Solution:
    def strStr(self, haystack: str, needle: str) -> int:
        n, m = len(haystack), len(needle)

        nxt = [0] * m
        j = 0
        for i in range(1, m):
            while j > 0 and needle[i] != needle[j]:
                j = nxt[j-1]
            if needle[i] == needle[j]:
                j += 1
            nxt[i] = j

        pos = 0
        for i in range(0, n):
            while pos >0 and needle[pos] != haystack[i]:
                pos = nxt[pos-1]
            
            if needle[pos] == haystack[i]:
                pos += 1
            if pos == m:
                return i - m + 1

        return -1
Reference code
class Solution:
    def maxRepeating(self, sequence: str, word: str) -> int:
        # Build the nxt array for word
        m, n = len(sequence), len(word)
        nxt = [0] * n 
        for i in range(1, n):
            j = nxt[i-1]
            while j > 0 and word[j] != word[i]:
                j = nxt[j-1]
            if word[j] == word[i]:
                j += 1
            nxt[i] = j
        
        ans = [0] * m
        j = 0
        for i in range(m):
            while j > 0 and word[j] != sequence[i]:
                j = nxt[j-1]
            if word[j] == sequence[i]:
                j += 1
            if j == n:
                ans[i] = (0 if i == n - 1 else ans[i-n]) + 1
                j = nxt[j-1]
        return max(ans)
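
In the solution above, ans[i] counts how many consecutive copies of word end at index i: a match ending at i extends the run ending at i - n by one. For testing, a brute-force version like the sketch below can be compared against it. The name `max_repeating_bruteforce` is hypothetical, and the code assumes word is non-empty (an empty word would loop forever).

```python
def max_repeating_bruteforce(sequence: str, word: str) -> int:
    """Largest k such that word repeated k times is a substring of sequence.

    Assumes word is non-empty.
    """
    k = 0
    # Keep growing the repeated block while it still occurs in sequence.
    while word * (k + 1) in sequence:
        k += 1
    return k


print(max_repeating_bruteforce("ababc", "ab"))  # 2
```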

References