Blame view

tests/test_keywords_query.py 3.73 KB
ceaf6d03   tangwang   召回限定:must条件补充主干词命...
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
  import hanlp
  from typing import List, Tuple, Dict, Any
  
  class KeywordExtractor:
      """
      Noun keyword extractor built on HanLP.

      A query is tokenized with a span-emitting HanLP tokenizer, the tokens
      are tagged with a HanLP part-of-speech model, and every token whose tag
      starts with ``N`` and that is long enough is kept as a keyword.
      """

      # Words that are dropped even when they qualify as keywords.
      IGNORE_KEYWORDS = frozenset({'玩具'})
      # Minimum number of characters a token must have to be kept.
      MIN_KEYWORD_LEN = 2

      def __init__(self):
          # Tokenizer (described as fine-grained by the original author).
          # Span output is switched on so each token comes with its offsets.
          self.tok = hanlp.load(hanlp.pretrained.tok.CTB9_TOK_ELECTRA_BASE_CRF)
          self.tok.config.output_spans = True

          # Part-of-speech tagging model.
          self.pos_tag = hanlp.load(hanlp.pretrained.pos.CTB9_POS_ELECTRA_SMALL)

      def extract_keywords(self, query: str) -> str:
          """
          Extract keywords (nouns with length >= 2) from a query.

          Args:
              query: Input text. Leading and trailing whitespace is stripped
                  before tokenization, so token offsets refer to the
                  stripped text.

          Returns:
              The kept keywords joined into one string. A single space is
              inserted between two kept keywords that are not adjacent in
              the stripped query. Returns an empty string when the query is
              empty or whitespace-only, or when nothing qualifies.
          """
          query = query.strip()
          if not query:
              # Nothing to analyse; avoid running the models on empty input.
              return ""

          # Tokens with positions: [[word, start, end], ...]
          spans = self.tok(query)
          if not spans:
              return ""
          words = [span[0] for span in spans]

          # One part-of-speech tag per token, in token order.
          tags = self.pos_tag(words)

          pieces = []
          last_end_pos = 0
          for (word, start_pos, end_pos), postag in zip(spans, tags):
              if len(word) < self.MIN_KEYWORD_LEN or not postag.startswith('N'):
                  continue
              if word in self.IGNORE_KEYWORDS:
                  continue
              # A gap between this keyword and the previously kept one means
              # they were not contiguous in the query: separate them.
              if pieces and start_pos != last_end_pos:
                  pieces.append(" ")
              pieces.append(word)
              last_end_pos = end_pos

          return "".join(pieces).strip()
  
  
  # Manual smoke test: prints the extracted keywords for a set of sample queries.
  if __name__ == "__main__":
      # Chinese (9 representative queries kept)
      chinese_samples = [
          "2.4G遥控大蛇",
          "充气的篮球",
          "遥控 塑料 飞船 汽车 ",
          "亚克力相框",
          "8寸 搪胶蘑菇钉",
          "7寸娃娃",
          "太空沙套装",
          "脚蹬工程车",
          "捏捏乐钥匙扣",
      ]

      # English (newly added)
      english_samples = [
          "plastic toy car",
          "remote control helicopter",
          "inflatable beach ball",
          "music keychain",
          "sand play set",
      ]

      # Common product searches
      product_samples = [
          "plastic dinosaur toy",
          "wireless bluetooth speaker",
          "4K action camera",
          "stainless steel water bottle",
          "baby stroller with cup holder",
      ]

      # Question-style / natural-language queries
      question_samples = [
          "what is the best smartphone under 500 dollars",
          "how to clean a laptop screen",
          "where can I buy organic coffee beans",
      ]

      # Queries containing digits and special characters
      symbol_samples = [
          "USB-C to HDMI adapter 4K",
          "LED strip lights 16.4ft",
          "Nintendo Switch OLED model",
          "iPhone 15 Pro Max case",
      ]

      # Short phrases
      short_samples = [
          "gaming mouse",
          "mechanical keyboard",
          "wireless earbuds",
      ]

      # Long-tail queries
      long_tail_samples = [
          "rechargeable AA batteries with charger",
          "foldable picnic blanket waterproof",
      ]

      # Product attribute combinations
      attribute_samples = [
          "women's running shoes size 8",
          "men's cotton t-shirt crew neck",
      ]

      # Other languages (kept as-is, used for multilingual testing)
      other_language_samples = [
          "свет USB с пультом дистанционного управления красочные",  # Russian
      ]

      # Groups are listed in the order in which their queries are printed.
      sample_groups = (
          chinese_samples,
          english_samples,
          product_samples,
          question_samples,
          symbol_samples,
          short_samples,
          long_tail_samples,
          attribute_samples,
          other_language_samples,
      )

      kw_extractor = KeywordExtractor()

      for sample_group in sample_groups:
          for text in sample_group:
              extracted = kw_extractor.extract_keywords(text)
              print(f"{text:30} => {extracted}")