Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion core/lowering/register_trt_placeholder_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ RegisterOperators trt_placeholder_ops_reg({
[](Stack& stack) {
auto attn_mask = pop(stack).to<at::Tensor>();
if (attn_mask.scalar_type() == at::kBool) {
attn_mask = attn_mask;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This one is unnecessary and is very confusing.

attn_mask.masked_fill_(attn_mask.logical_not(), -std::numeric_limits<float>::infinity());
}
return attn_mask;
Expand Down
1 change: 1 addition & 0 deletions cpp/src/types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ TensorFormat::TensorFormat(at::MemoryFormat t) {
switch (t) {
case at::MemoryFormat::ChannelsLast:
value = TensorFormat::kChannelsLast;
break;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This one fixes an oversight of the implementation.

case at::MemoryFormat::Contiguous:
default:
value = TensorFormat::kContiguous;
Expand Down
Loading