@@ -22,38 +22,85 @@ def load_model(model_path, mode):
 
 
 def _ner_predict(
-    text,
-    model,
-    tokenizer,
-    max_len,
-    device
+    text,
+    model,
+    tokenizer,
+    max_len,
+    device,
+    exceed_strategy
 ):
     model.to(device)
 
-    encoding = tokenizer(
-        text,
-        add_special_tokens=False,
-        max_length=max_len,
-        padding="max_length",
-        truncation=True,
-        return_offsets_mapping=True,
-        return_tensors="pt",
-    )
-
-    input_ids = encoding["input_ids"].to(device)
-    attention_mask = encoding["attention_mask"].to(device)
-    token_type_ids = torch.zeros_like(input_ids).to(device)
-    offset_mapping = encoding["offset_mapping"].squeeze().tolist()
-    tokens = tokenizer.convert_ids_to_tokens(input_ids.squeeze())
+    pred_label_ids = []
+
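+    # "truncation" simply cuts the input off after max_len tokens, while
+    # "sliding_window" covers the whole text with overlapping windows.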
+    if exceed_strategy == "truncation":
+        encoding = tokenizer(
+            text,
+            add_special_tokens=False,
+            max_length=max_len,
+            padding="max_length",
+            truncation=True,
+            return_offsets_mapping=True,
+            return_tensors="pt",
+        )
 
-    with torch.no_grad():
-        outputs = model(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            token_type_ids=token_type_ids,
+        input_ids = encoding["input_ids"].to(device)
+        attention_mask = encoding["attention_mask"].to(device)
+        token_type_ids = torch.zeros_like(input_ids).to(device)
+        offset_mapping = encoding["offset_mapping"].squeeze().tolist()
+        tokens = tokenizer.convert_ids_to_tokens(input_ids.squeeze())
+
+        with torch.no_grad():
+            outputs = model(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                token_type_ids=token_type_ids,
+            )
+        pred_label_ids = outputs["pred_label_ids"].cpu().numpy()[0]
+
+    else:  # exceed_strategy == "sliding_window"
+        full_encoding = tokenizer(
+            text,
+            add_special_tokens=False,
+            return_offsets_mapping=True,
+            return_tensors="pt"
         )
 
-    pred_label_ids = outputs["pred_label_ids"].cpu().numpy()[0]
+        input_ids = full_encoding["input_ids"].to(device)
+        tokens = full_encoding.tokens()
+        attention_mask = full_encoding["attention_mask"].to(device)
+        token_type_ids = torch.zeros_like(input_ids).to(device)
+        offset_mapping = full_encoding["offset_mapping"].squeeze().tolist()
+
+        if len(tokens) <= max_len:
+            with torch.no_grad():
+                outputs = model(
+                    input_ids=input_ids,
+                    attention_mask=attention_mask,
+                    token_type_ids=token_type_ids,
+                )
+            pred_label_ids = outputs["pred_label_ids"].cpu().numpy()[0]
+
+        else:
+            window_size = max_len
+            stride = window_size // 2
+
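+            # Slide a max_len-wide window over the full token sequence,
+            # advancing by half a window so neighbouring windows overlap.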
+            start_token_idx = 0
+            while True:
+                end_token_idx = min(start_token_idx + window_size, len(tokens))
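+                # the encoded tensors have shape (1, seq_len): slice dim 1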
+                with torch.no_grad():
+                    window_pred_label_ids = model(
+                        input_ids=input_ids[:, start_token_idx:end_token_idx],
+                        attention_mask=attention_mask[:, start_token_idx:end_token_idx],
+                        token_type_ids=token_type_ids[:, start_token_idx:end_token_idx],
+                    )["pred_label_ids"].cpu().numpy()[0]
+                if end_token_idx >= len(tokens):
+                    pred_label_ids.extend(window_pred_label_ids)  # keep the whole final window
+                    break
+                else:
+                    pred_label_ids.extend(window_pred_label_ids[0:stride])  # keep only the first `stride` labels of each window
+
+                start_token_idx += stride
 
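+    # Map the token-level label ids back onto character positions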
     entities = []
     char_labels = ["O"] * len(text)
@@ -64,10 +111,10 @@ def _ner_predict(
         offset = offset_mapping[i]
         if label.startswith("B-"):
             char_labels[offset[0]] = label
-            char_labels[offset[0]+1: offset[1]] = ["I-" + label[2:]] * (offset[1] - offset[0] - 1)
+            char_labels[offset[0] + 1: offset[1]] = ["I-" + label[2:]] * (offset[1] - offset[0] - 1)
         elif label.startswith("I-"):
             char_labels[offset[0]: offset[1]] = [label] * (offset[1] - offset[0])
-
+
     # Infer entities from char_labels
     i = 0
     while i < len(char_labels):
@@ -76,31 +123,31 @@ def _ner_predict(
             start = i
             i += 1
             while i < len(char_labels) and (
-                char_labels[i] == f"I-{entity_type}"
-                or (char_labels[i] == f"O" and tokens[i].startswith("##"))
+                char_labels[i] == f"I-{entity_type}"
+                or (char_labels[i] == "O" and tokens[i].startswith("##"))
             ):
                 i += 1
             end = i
-            entities.append({"start": start,
-                             "end": end,
-                             "type": entity_type,
-                             "text": text[start:end]})  # missing id
+            entities.append({"start": start,
+                             "end": end,
+                             "type": entity_type,
+                             "text": text[start:end]})  # id is still missing
         else:
             i += 1
     return entities
 
 
 def _re_predict(
-    text,
-    e1,
-    e2,
-    model,
-    tokenizer,
-    max_len,
-    device
+    text,
+    e1,
+    e2,
+    model,
+    tokenizer,
+    max_len,
+    device,
+    exceed_strategy
 ):
     model.to(device)
-
+
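+    # exceed_strategy is accepted for symmetry with _ner_predict but is not
+    # used yet: the relation input below is always truncated to max_len.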
     encoding = tokenizer(
         text,
         add_special_tokens=False,
@@ -114,10 +161,10 @@ def _re_predict(
     attention_mask = encoding["attention_mask"].to(device)
     token_type_ids = torch.zeros_like(input_ids).to(device)
     offset_mapping = encoding["offset_mapping"].squeeze().tolist()
-
+
     e1_mask = _create_entity_mask(input_ids, offset_mapping, e1[0], e1[1])
     e2_mask = _create_entity_mask(input_ids, offset_mapping, e2[0], e2[1])
-
+
     with torch.no_grad():
         outputs = model(
             input_ids=input_ids,
@@ -126,24 +173,25 @@ def _re_predict(
             e1_mask=e1_mask,
             e2_mask=e2_mask
         )
-
+
     logits = outputs["logits"]
     probs = torch.nn.functional.softmax(logits, dim=1).cpu().numpy()[0]
-
+
     pred_idx = logits.argmax(dim=1).item()
     relation = id2relation[pred_idx]
     probability = probs[pred_idx]
-    return {"source": text[e1[0]:e1[1]],
-            "target": text[e2[0]:e2[1]],
-            "type": relation,
+    return {"source": text[e1[0]:e1[1]],
+            "target": text[e2[0]:e2[1]],
+            "type": relation,
             "probability": float(probability)}  # source_id and target_id are missing; probability is extra
 
 
 def ner_predict(
-    text: str,
-    model_path: str = "experimental/scripts/ke/checkpoints/ner/final_model",
-    max_len: int = 512,
-    device: str = "cuda" if torch.cuda.is_available() else "cpu",
+    text: str,
+    model_path: str = "experimental/scripts/ke/checkpoints/ner/final_model",
+    max_len: int = 512,
+    device: str = "cuda" if torch.cuda.is_available() else "cpu",
+    exceed_strategy: str = "truncation"
 ) -> list[dict]:
     """
     Use the model to predict entities
@@ -158,16 +206,17 @@ def ner_predict(
         list[dict]: the list of prediction results
     """
     model, tokenizer = load_model(model_path, "ner")
-    return _ner_predict(text, model, tokenizer, max_len, device)
-
+    return _ner_predict(text, model, tokenizer, max_len, device, exceed_strategy)
+
 
 def re_predict(
-    text: str,
-    e1_range: tuple[int, int],
-    e2_range: tuple[int, int],
-    model_path: str = "experimental/scripts/ke/checkpoints/re/final_model",
-    max_len: int = 512,
-    device: str = "cuda" if torch.cuda.is_available() else "cpu",
+    text: str,
+    e1_range: tuple[int, int],
+    e2_range: tuple[int, int],
+    model_path: str = "experimental/scripts/ke/checkpoints/re/final_model",
+    max_len: int = 512,
+    device: str = "cuda" if torch.cuda.is_available() else "cpu",
+    exceed_strategy: str = "truncation"
 ) -> dict:
     """
     Use the model to predict the relation
@@ -184,4 +233,4 @@ def re_predict(
         dict: the prediction result
     """
     model, tokenizer = load_model(model_path, "re")
-    return _re_predict(text, e1_range, e2_range, model, tokenizer, max_len, device)
+    return _re_predict(text, e1_range, e2_range, model, tokenizer, max_len, device, exceed_strategy)
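
A quick sanity check for the new sliding-window path (a sketch: the import
path is hypothetical, and it assumes the final_model checkpoints referenced
above exist):

    from predict import ner_predict  # hypothetical module name

    # Any text longer than max_len tokens exercises the windowed branch.
    long_text = "Alan Turing worked at Bletchley Park. " * 200
    entities = ner_predict(long_text, exceed_strategy="sliding_window")
    for ent in entities:
        print(ent["type"], ent["start"], ent["end"], ent["text"])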