Trie树实现前缀自动补全 + AC自动机实现敏感词过滤

Trie树实现前缀自动补全 + AC自动机实现敏感词过滤,第1张

文章目录
      • 背景
      • 扩展
      • AC自动机

背景

最近参与了某业务系统的开发, 需要根据城市的名字简称,找到其官方的完整名称。比如云南的大理,其实其完整的名称是大理白族自治州。可以参考官方的行政区划,点这里。

通常来说,城市的简称,都是其完整名称的前缀。所以任务就转化成了:根据前缀,在一堆字符串中,找出满足条件的字符串。

Trie树可以派上用场,只需要对全国所有城市的完整名称,建一颗Trie树即可。这种前缀补全的功能,也有其他的一些经典应用,比如在命令行下,输入一个命令的前缀,或文件名的前缀,敲下 Tab,能够进行自动补全,也是利用了Trie树这种数据结构。

Trie树的原理参考我之前的这篇文章

代码如下


import java.util.*;

/**
 * @Author yogurtzzz
 * @Date 2022/3/31 15:57
 *
 * 利用 Trie 树实现字符串的存储和快速查找
 *
 * 该类的功能是, 给定一个字符串数组 list , 给定一个字符串 s
 * 返回在 list 中所有前缀为 s 的字符串
 *
 * 简单地说, 该类的功能是做字符串前缀匹配, 或者自动补全
 *
 * 比如, 输入 "内蒙", 能够查找到完整的名称为 "内蒙古自治区"
 **/
public class TrieTree {

	private Node root = new Node('0');

	/**
	 * 添加一个字符串到 [集合] 中
	 * **/
	public void addString(String s) {
		Node cur = root;
		for (int i = 0; i < s.length(); i++) {
			char c = s.charAt(i);
			cur = cur.getOrCreateSon(c);
		}
		cur.mark = true; // 在末尾打上标记
	}

	/**
	 * 返回 [集合] 中所有以 @param prefix 为前缀的字符串
	 * **/
	public List<String> findStringFuzzy(String prefix) {
		List<String> ans = new ArrayList<>();
		char[] chars = prefix.toCharArray();
		find(root, chars, 0, ans);
		return ans;
	}

	/**
	 * 查找
	 * **/
	private void find(Node cur, char[] prefix, int index, List<String> ans) {
		int n = prefix.length;
		if (index == n && cur != null) dfsAns(cur, prefix, ans); // 前缀匹配完成, 开始搜寻答案
		else {
			if (cur == null || cur.son == null || !cur.son.containsKey(prefix[index])) return;
			find(cur.son.get(prefix[index]), prefix, index + 1, ans);
		}
	}


	/**
	 * 返回这个节点下所有可能的字符串
	 * **/
	private void dfsAns(Node node, char[] prefix, List<String> ans) {
		StringBuilder sb = new StringBuilder(new String(prefix));
		if (node.mark) ans.add(sb.toString());
		if (node.son != null) {
			for (Node s : node.son.values()) dfsFinal(s, sb, ans);
		}
	}


	private void dfsFinal(Node cur, StringBuilder sb, List<String> ans) {
		sb.append(cur.c);
		if (cur.mark) ans.add(sb.toString());
		if (cur.son != null) {
			for (Node n : cur.son.values()) dfsFinal(n, sb, ans);
		}
		sb.deleteCharAt(sb.length() - 1); // 深搜恢复现场
	}

	static class Node {

		char c;

		boolean mark; // 标记是否为终点

		Map<Character, Node> son; // 子节点

		Node(char c) {
			this.c = c;
			mark = false;
		}

		void addSon(char c) {
			if (son == null) son = new HashMap<>();
			son.put(c, new Node(c));
		}

		Node getOrCreateSon(char c) {
			boolean isSon = son != null && son.containsKey(c);
			if (!isSon) addSon(c);
			return son.get(c);
		}
	}


	/**
	 * 测试代码
	 * **/
	public static void main(String[] args) {
		List<String> list = Arrays.asList("海南市", "海北市", "辽宁哈哈市", "辽宁市", "辽宁哈嘻嘻市", "绵阳市", "绵花市", "绵阳咩咩市", "内蒙古自治区");
		TrieTree trie = new TrieTree();
		list.forEach(trie::addString);
		Scanner scanner = new Scanner(System.in);
		String line = null;
		while (!(line = scanner.nextLine()).equals("quit")) {
			List<String> result = trie.findStringFuzzy(line);
			result.forEach(s -> System.out.printf("%s,", s));
			System.out.println();
		}
	}
}
扩展

上面的功能完成后,我回顾了一下,实现的就是一个前缀模糊查找。(查找有固定前缀的字符串)
在 mysql 中就相当于 like 's%',但是像后缀 like '%s' 和 中缀 like '%s%' 又要怎么实现呢?

对于后缀模糊查找,容易想到的一个方法是,对所有字符串,按照字符逆序建一颗 Trie 树即可(可以理解为先把字符串反转一下,再建 Trie)

但是对于 %s% 好像就无能为力了。

查了很多的资料,发现一个比较接近的解决方案是:AC自动机。但是这个方案还是不太适用于 %s% 这种场景。

我们知道,在字符串的模式匹配中,一般有两种情形:

  • 单模式匹配:给定一个主串s和一个模式串p,查找sp出现的位置,方法有BF(暴力),KMP
  • 多模式匹配:给定一个主串s和多个模式串,查找s中出现的模式串都有哪些,经典的应用场景是敏感词过滤。

但是我们发现,这两种情形都是在一个主串中,根据一个或多个模式串,来查找主串中是否有匹配的部分,且都是精确匹配。

这和我们上面说的情形不太一样。上面的情形是:给定一堆字符串(主串),再给一个模式串(带通配符),在这一堆主串中,查找出所有满足模糊匹配条件的串。

其实是两个不太一样的问题。

由于对于 %s% 这样中缀模糊查找,仍然没找到解决方案,于是便暂时搁置,先研究一下利用AC自动机实现敏感词过滤,也挺有意思的。(有朋友说去研究一下ElasticSearch中的倒排索引的原理,手写一下,就能解决%s% 这个问题,这个留在之后再去做了)

AC自动机

对于在一个主串中查找多个模式串,一个简单粗暴的做法是,将每个模式串拿出来,单独和主串用KMP做匹配。由于KMP的时间复杂度是 O ( m + n ) O(m+n) O(m+n)

关于KMP算法的原理,参考我之前的这篇文章

关于KMP的时间复杂度如何计算,参考思否的这篇文章。

关于KMP在回溯指针j时,如何保证不漏掉正确答案?-> 用反证法证明即可,参考这篇文章

大概是,在KMP的匹配过程中,当匹配到i位置时,若下一个位置不匹配,则j最多回退i-1次。极端的例子是:T=“aaaabaaaab”,P=“aaaaa”。则在对主串进行一次遍历时(假设主串长度为m),则最多会遍历2m次(实际,上界到不了2m),而我们构造next数组时,是对模式串进行了一次同样的遍历匹配 *** 作(假设模式串长度为n),那么构造next数组时最多要2n次,加到一起,就是2m + 2n,所以复杂度是 O ( m + n ) O(m+n) O(m+n)

那么如果对所有模式串,依次与主串做一次KMP,则开销是非常大的。假设主串长度为m,模式串平均长度为n,模式串个数为N。容易算得,这种方式的时间复杂度是 O ( N × ( m + n ) ) O(N×(m+n)) O(N×(m+n))

这样肯定是不行的,那我们来看看用朴素的 Trie 树可以怎么做?我们先对多个模式串,建一颗 Trie 树,然后从主串s的第一个位置开始,在 Trie 树中进行查找,查找完毕后(匹配到或者没匹配到),则把主串中的指针i往后移一位,从第二个位置开始查找,再从第三个位置进行查找。如此以来就能找到s中包含的全部模式串。假设模式串的平均长度为n,主串长度为m,最坏情况下,在主串的每一个位置都要进行n次匹配,则时间复杂度是 O ( m × n ) O(m × n) O(m×n)

这种效率也无法达到我们的需求。

此时轮到AC自动机登场了。

多模式匹配中,朴素的 Trie 和 AC自动机的关系;就和单模式匹配中,BF法 和 KMP 的关系一样。

AC自动机仅仅是在 Trie 树上运用了 KMP 的思想,增加了一个类似 next 数组的东西,叫做 fail 指针。

整个匹配的过程,主串只会被扫描一次,而不会不断回退,当在主串的某个位置匹配失败后,根据 Trie 树上该节点的父节点的 fail 指针,找到 Trie 树上下一个要匹配的节点,继续进行匹配即可。

可以这样简单的理解fail指针,假设Trie树上的一个节点p,其到根节点root,构成的字符串为abc,节点pfail指针,指向节点q,而节点qroot节点,构成的字符串是bc,这两个节点代表的字符串,存在一个公共的前后缀(这和KMP中next数组的含义一样)。
在AC自动机匹配时,假设主串为abcd,则在匹配第四个位置d时,Trie树上是到节点pc,此时匹配失败,此时查看pc节点的父节点pfail指针为q,则继续查看q的子节点中有没有字符为d的,发现有,则成功匹配到模式串bcd

(图片来源极客时间)

由于每个节点的fail指针,都一定依赖于其上层的节点,那么我们构造fail指针的时候,需要从根节点往下,一层一层构造,所以需要使用层序遍历BFS。根节点rootfail指针为null,第一层节点的fail指针为root。因为第一个位置不存在公共前后缀,需要从头开始匹配。(因为公共前后缀的长度一定要小于当前长度,才能构成公共前后缀,至少从第二个位置开始,才可能存在公共前后缀)。
比如a,是没有公共前后缀的,aa的公共前后缀长度为1。

我们用BFS,每次处理一个节点,并为这个节点的所有子节点,填充fail指针。

假设当前节点为p,它有一个子节点s,那么填充sfail指针的过程如下:

  • p.fail,看这个节点的子节点中,有没有和s节点字符相同的,若有,假设为s',则填充s.fail = s’
  • p.fail的子节点中,没有和s字符相同的节点,则更新p = p.fail,继续上面的判断,直到p = null
  • p = null,说明在root节点的子节点中,都没有发现和s节点字符相同的节点(才会在p = p.fail时将p更新为null),那么在s节点就不存在公共前后缀,需要从头匹配,所以此时s.fail = root

代码如下:

import java.util.*;

/**
 * @Author yogurtzzz
 * @Date 2022/4/26 10:30
 *
 * AC 自动机
 **/
public class AcMachine {

	private AcNode root;

	/**
	 * 用一个敏感词集合构建一个 AC 自动机
	 * **/
	public AcMachine(Set<String> words) {
		root = new AcNode('0');
		buildTrieTree(words); // 先把 Trie 树建起来
		fillFailPointer(); // 再填充每个节点的 fail 指针
	}

	/**
	 * 进行敏感词过滤
	 * @param s 原字符串
	 * @return 脱敏后的字符串   
	 * **/
	public String filterWithSensitiveWords(String s) {
		char[] cs = s.toCharArray();
		List<Integer> begins = new ArrayList<>();
		List<Integer> lens = new ArrayList<>();

		// 开始查找
		AcNode cur = root;
		for (int j = 0; j < cs.length; j++) {
			char c = cs[j];
			if (cur.hasChild(c)) {
				cur = cur.getOrCreateChild(c);
				if (cur.len != -1) {
					// 该节点为结束节点
					lens.add(cur.len);
					begins.add(j - cur.len + 1);
				}
			} else {
				while (cur.fail != null) {
					cur = cur.fail;
					if (cur.hasChild(c)) {
						cur = cur.getOrCreateChild(c);
						if (cur.len != -1) {
							lens.add(cur.len);
							begins.add(j - cur.len + 1);
						}
						break;
					}
				}
			}
		}

		// 查找出所有敏感词出现的起始位置和长度后, 对原字符串进行敏感词屏蔽
		StringBuilder sb = new StringBuilder();
		int i = 0;
		for (int j = 0; j < begins.size(); j++) {
			int begin = begins.get(j);
			int len = lens.get(j);
			for (; i < begin; i++) sb.append(cs[i]);
			for (; i < begin + len; i++) sb.append('*');
		}
		while (i < cs.length) {
			sb.append(cs[i]);
			i++;
		}
		return sb.toString();
	}

	private void buildTrieTree(Set<String> words) {
		words.forEach(this::addWord);
	}

	private void addWord(String word) {
		AcNode p = root;
		char[] cs = word.toCharArray();
		for (char c : cs) {
			p = p.getOrCreateChild(c);
		}
		p.len = cs.length;
	}

	private void fillFailPointer() {
		// BFS 层序遍历填充 fail 指针
		Queue<AcNode> q = new LinkedList<>();
		q.offer(root);
		while (!q.isEmpty()) {
			AcNode x = q.poll();
			if (x.children == null) continue;
			// 处理当前节点的子节点的fail指针
			x.children.values().forEach(node -> {
				q.offer(node); // 加入队列
				if (x == root) node.fail = root;
				else {
					AcNode last = x.fail;
					while (last != null) {
						if (last.hasChild(node.c)) {
							node.fail = last.children.get(node.c);
							break;
						} else {
							last = last.fail;
						}
					}
					if (last == null) node.fail = root;
				}
			});
		}
	}

	private static class AcNode {

		private char c;

		private int len = -1; // 如果是结尾节点, 记录长度

		private Map<Character, AcNode> children;

		private AcNode fail;

		AcNode(char c) {
			this.c = c;
		}

		AcNode getOrCreateChild(char c) {
			if (children == null) children = new HashMap<>();
			if (children.containsKey(c)) return children.get(c);
			AcNode newNode = new AcNode(c);
			children.put(c, newNode);
			return newNode;
		}

		boolean hasChild(char c) {
			return children != null && children.containsKey(c);
		}
	}

	public static void main(String[] args) {
		testEn();
		System.out.println();
		testZh();
	}

	private static void testEn() {
		Set<String> wordSet = new HashSet<>(Arrays.asList("she", "bleed", "dog", "doggy", "hurt"));
		AcMachine acMachine = new AcMachine(wordSet);
		String s = "A girl is bitten by a doggy, and she is badly hurt, bleeding along the road";
		System.out.printf("sensitive words : ");
		for (String x : wordSet) System.out.printf("%s ", x);
		System.out.println();
		System.out.println(s);
		String filteredString = acMachine.filterWithSensitiveWords(s);
		System.out.println(filteredString);
	}

	private static void testZh() {
		Set<String> wordSet = new HashSet<>(Arrays.asList("血", "政府", "暴力"));
		AcMachine acMachine = new AcMachine(wordSet);
		System.out.printf("sensitive words : ");
		for (String x : wordSet) System.out.printf("%s ", x);
		System.out.println();
		String s = "美国政府呼吁民众不要暴力, 因为暴力会造成流血";
		System.out.println(s);
		System.out.println(acMachine.filterWithSensitiveWords(s));
	}
}

关于AC自动机的原理,参考极客时间的这篇文章,以及这篇文章

欢迎分享,转载请注明来源:内存溢出

原文地址: http://outofmemory.cn/langs/737742.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-04-28
下一篇 2022-04-28

发表评论

登录后才能评论

评论列表(0条)

保存