@@ -69,23 +69,17 @@ def forward(self, inputs):
69
69
sent_out = self .post_linear (sent_out )
70
70
71
71
elif self .opt .lcf == 'fusion' :
72
- # cdw_sent_out = self.CDW_LSA(global_context_features,
73
- # spc_mask_vec=spc_mask_vec,
74
- # lcf_matrix=lcf_cdw_matrix,
75
- # left_lcf_matrix=left_lcf_cdw_matrix,
76
- # right_lcf_matrix=right_lcf_cdw_matrix)
77
- # cdm_sent_out = self.CDM_LSA(global_context_features,
78
- # spc_mask_vec=spc_mask_vec,
79
- # lcf_matrix=lcf_cdm_matrix,
80
- # left_lcf_matrix=left_lcf_cdm_matrix,
81
- # right_lcf_matrix=right_lcf_cdm_matrix)
82
- # sent_out = self.fusion_linear(torch.cat((global_context_features, cdw_sent_out, cdm_sent_out), -1))
83
- sent_out = self .CDW_LSA (global_context_features ,
84
- spc_mask_vec = spc_mask_vec ,
85
- lcf_matrix = lcf_cdw_matrix ,
86
- left_lcf_matrix = left_lcf_cdm_matrix ,
87
- right_lcf_matrix = right_lcf_cdm_matrix )
88
- sent_out = torch .cat ((global_context_features , sent_out ), - 1 )
72
+ cdw_sent_out = self .CDW_LSA (global_context_features ,
73
+ spc_mask_vec = spc_mask_vec ,
74
+ lcf_matrix = lcf_cdw_matrix ,
75
+ left_lcf_matrix = left_lcf_cdw_matrix ,
76
+ right_lcf_matrix = right_lcf_cdw_matrix )
77
+ cdm_sent_out = self .CDM_LSA (global_context_features ,
78
+ spc_mask_vec = spc_mask_vec ,
79
+ lcf_matrix = lcf_cdm_matrix ,
80
+ left_lcf_matrix = left_lcf_cdm_matrix ,
81
+ right_lcf_matrix = right_lcf_cdm_matrix )
82
+ sent_out = self .fusion_linear (torch .cat ((global_context_features , cdw_sent_out , cdm_sent_out ), - 1 ))
89
83
sent_out = self .post_linear (sent_out )
90
84
91
85
else :
0 commit comments