diff --git a/csrc/transformer/ds_transformer_cuda.cpp b/csrc/transformer/ds_transformer_cuda.cpp index f22b8a0743f1..2bb96fa99d67 100755 --- a/csrc/transformer/ds_transformer_cuda.cpp +++ b/csrc/transformer/ds_transformer_cuda.cpp @@ -877,7 +877,11 @@ std::vector ds_transformer_backward(int layer_id, seq_len = g_output.size(1); layer->SetSeqLength(seq_len); } - + auto options = torch::TensorOptions() + .dtype(g_output.options().dtype()) + .layout(torch::kStrided) + .device(torch::kCUDA) + .requires_grad(true); auto workspace = torch::empty({get_workspace_size(bsz, seq_len, layer->GetHiddenSize(), @@ -885,7 +889,7 @@ std::vector ds_transformer_backward(int layer_id, layer->GetNumHeads(), layer->IsTrainingMode(), layer->GeluCheckpoint())}, - grad_output.options()); + options); Context::Instance().SetWorkSpace((T*)workspace.data_ptr()); auto grad_input = torch::empty_like(input);