Commit 2f15704

Hang Qu authored and facebook-github-bot committed
Update embedding_forward_quantized_cpu_template.cpp to use initialized output memory instead of uninitialized (#5054)
Summary:
X-link: facebookresearch/FBGEMM#2064

We observed that when the output memory is uninitialized, the output may contain garbage, because some of that memory is never written. The proposed fix is a quick workaround; it would be more efficient to fill only the untouched memory with zeros directly.

Reviewed By: sryap

Differential Revision: D85447298
1 parent 51210b8 commit 2f15704
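
The summary contrasts two strategies: zero-initializing the whole output (the workaround in this commit) and zeroing only the memory the kernel leaves untouched. The sketch below illustrates the difference in plain C++ without ATen. Everything in it is an assumption made for illustration: `write_rows` is a hypothetical stand-in for the lookup kernel, the untouched region is assumed to be the tail of each row, and `kGarbage` only simulates stale memory. It does not claim to show which bytes the real FBGEMM kernel skips.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Sentinel that simulates stale contents of an uninitialized buffer
// (illustrative only; real uninitialized memory holds arbitrary values).
constexpr float kGarbage = -12345.0f;

// Hypothetical stand-in for the kernel: writes the first `written` floats of
// each row and never touches the remaining `row_width - written` floats.
void write_rows(std::vector<float>& out, std::size_t rows,
                std::size_t row_width, std::size_t written) {
  assert(written <= row_width);
  assert(out.size() == rows * row_width);
  for (std::size_t r = 0; r < rows; ++r) {
    for (std::size_t c = 0; c < written; ++c) {
      out[r * row_width + c] = 1.0f;
    }
  }
}

// Strategy A (analogous to allocating with at::zeros): zero the whole
// buffer up front, then run the kernel.
std::vector<float> forward_zero_all(std::size_t rows, std::size_t row_width,
                                    std::size_t written) {
  std::vector<float> out(rows * row_width, 0.0f);
  write_rows(out, rows, row_width, written);
  return out;
}

// Strategy B (analogous to at::empty plus a targeted fill): start from
// simulated garbage, run the kernel, then zero only the part it skipped.
std::vector<float> forward_zero_untouched(std::size_t rows,
                                          std::size_t row_width,
                                          std::size_t written) {
  std::vector<float> out(rows * row_width, kGarbage);
  write_rows(out, rows, row_width, written);
  for (std::size_t r = 0; r < rows; ++r) {
    auto row = out.begin() + static_cast<std::ptrdiff_t>(r * row_width);
    std::fill(row + static_cast<std::ptrdiff_t>(written),
              row + static_cast<std::ptrdiff_t>(row_width), 0.0f);
  }
  return out;
}
```

Both functions return identical contents. Strategy A writes the kernel-covered elements twice (once with zero, once with the result), while strategy B writes every element exactly once, which is the efficiency gain the summary alludes to.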

1 file changed, +10 -3 lines changed


fbgemm_gpu/codegen/inference/embedding_forward_quantized_cpu_template.cpp

Lines changed: 10 additions & 3 deletions
@@ -209,7 +209,11 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
     if (o_dtype == SparseType::INT8) {
         total_adjusted_D += T * kINT8QparamsBytes;
     }
-    output = at::empty({B, total_adjusted_D}, dev_weights.options().dtype(getScalarType(o_dtype)).pinned_memory(pinned_memory));
+    if (!output_is_int4 && !output_is_int8) {
+        output = at::zeros({B, total_adjusted_D}, dev_weights.options().dtype(getScalarType(o_dtype)).pinned_memory(pinned_memory));
+    } else {
+        output = at::empty({B, total_adjusted_D}, dev_weights.options().dtype(getScalarType(o_dtype)).pinned_memory(pinned_memory));
+    }
     {% else %}
     constexpr int kINT8QparamsBytes = 4; // no bag int8 output aligns with fbgemm weights storage size and layout
     constexpr int kINT4QparamsElems = 8; // scale + bias takes 4 bytes which are 8 int4 elements
@@ -219,8 +223,11 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
     } else if (o_dtype == SparseType::INT4) {
         adjusted_D += kINT4QparamsElems;
     }
-    output = at::empty({total_L, adjusted_D}, dev_weights.options().dtype(getScalarType(o_dtype)).pinned_memory(pinned_memory));
-
+    if (!output_is_int4 && !output_is_int8) {
+        output = at::zeros({total_L, adjusted_D}, dev_weights.options().dtype(getScalarType(o_dtype)).pinned_memory(pinned_memory));
+    } else {
+        output = at::empty({total_L, adjusted_D}, dev_weights.options().dtype(getScalarType(o_dtype)).pinned_memory(pinned_memory));
+    }
     {% endif %}
 
 