format: const qualifier
changlan committed Jun 29, 2019
1 parent f244654 commit 4ecd470
Showing 2 changed files with 29 additions and 28 deletions.
byteps/common/core_loops.cc: 51 changes (26 additions, 25 deletions)
@@ -36,26 +36,27 @@ void FinishOrProceed(std::shared_ptr<TensorTableEntry> task) {
     size_t i = task->offset / 4;
     size_t j = (task->offset + task->len) / 4 - 1;
     if (task->device == CPU_DEVICE_ID) {
-      BPS_LOG(DEBUG) << "Sampled key=" << task->key
-                     << " rank=" << BytePSGlobal::GetLocalRank() << " input[0]="
-                     << *(reinterpret_cast<float *>(task->tensor->data()) + i)
-                     << "\tinput[-1]="
-                     << *(reinterpret_cast<float *>(task->tensor->data()) + j)
-                     << "\toutput[0]="
-                     << *(reinterpret_cast<float *>(task->output->data()) + i)
-                     << "\toutput[-1]="
-                     << *(reinterpret_cast<float *>(task->output->data()) + j)
-                     << "\t after stage: " << LogStrings[this_op];
+      BPS_LOG(DEBUG)
+          << "Sampled key=" << task->key
+          << " rank=" << BytePSGlobal::GetLocalRank() << " input[0]="
+          << *(reinterpret_cast<const float *>(task->tensor->data()) + i)
+          << "\tinput[-1]="
+          << *(reinterpret_cast<const float *>(task->tensor->data()) + j)
+          << "\toutput[0]="
+          << *(reinterpret_cast<const float *>(task->output->data()) + i)
+          << "\toutput[-1]="
+          << *(reinterpret_cast<const float *>(task->output->data()) + j)
+          << "\t after stage: " << LogStrings[this_op];
     } else {
       float i0, i1, o0, o1;
-      cudaMemcpy(&i0, reinterpret_cast<float *>(task->tensor->data()) + i, 4,
-                 cudaMemcpyDeviceToHost);
-      cudaMemcpy(&i1, reinterpret_cast<float *>(task->tensor->data()) + j, 4,
-                 cudaMemcpyDeviceToHost);
-      cudaMemcpy(&o0, reinterpret_cast<float *>(task->output->data()) + i, 4,
-                 cudaMemcpyDeviceToHost);
-      cudaMemcpy(&o1, reinterpret_cast<float *>(task->output->data()) + j, 4,
-                 cudaMemcpyDeviceToHost);
+      cudaMemcpy(&i0, reinterpret_cast<const float *>(task->tensor->data()) + i,
+                 4, cudaMemcpyDeviceToHost);
+      cudaMemcpy(&i1, reinterpret_cast<const float *>(task->tensor->data()) + j,
+                 4, cudaMemcpyDeviceToHost);
+      cudaMemcpy(&o0, reinterpret_cast<const float *>(task->output->data()) + i,
+                 4, cudaMemcpyDeviceToHost);
+      cudaMemcpy(&o1, reinterpret_cast<const float *>(task->output->data()) + j,
+                 4, cudaMemcpyDeviceToHost);
       BPS_LOG(DEBUG) << "Sampled key=" << task->key
                      << " rank=" << BytePSGlobal::GetLocalRank()
                      << " input[0]=" << i0 << "\tinput[-1]=" << i1
@@ -146,9 +147,9 @@ inline void PostNcclCalls(
   auto len = task->len;
   auto offset = task->offset;
   auto unit_len = tensor->size() / tensor->shape().num_elements();
-  auto p = reinterpret_cast<char *>(tensor->data()) + offset;
+  auto p = (char *)(tensor->data()) + offset;
   if (task->device == CPU_DEVICE_ID) {
-    p = reinterpret_cast<char *>(task->gpu_ptr) + offset;
+    p = (char *)(task->gpu_ptr) + offset;
   }

   auto nccl_dtype = getNcclDataType(tensor->dtype());
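The PostNcclCalls hunks go the opposite direction: p must remain a writable char * (it is presumably handed on to NCCL calls that take non-const buffers), and a C-style cast, unlike reinterpret_cast alone, is allowed to strip the const that data() now carries. A sketch of the equivalence, assuming only a const-qualified accessor (raw is an illustrative name):

    #include <cstddef>

    // raw stands in for tensor->data(), now returning const void *.
    char *WritablePtr(const void *raw, std::size_t offset) {
      // What the diff writes:
      char *p = (char *)(raw) + offset;
      // ...and the explicit equivalent: a C-style cast here acts as
      // reinterpret_cast combined with const_cast.
      char *q =
          const_cast<char *>(reinterpret_cast<const char *>(raw)) + offset;
      return (p == q) ? p : nullptr;  // both yield the same address
    }

Writing through such a pointer is well-defined only when the underlying storage is not actually const, i.e. mutable tensor memory reached through a const-qualified interface.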
@@ -171,7 +172,7 @@

   if (this_op == REDUCE) {
     // We reduce to task->output except that it is a CPU tensor
-    auto out_p = reinterpret_cast<char *>(task->output->data()) + offset;
+    auto out_p = (char *)(task->output->data()) + offset;
     if (task->device == CPU_DEVICE_ID && task->tensor == task->output) {
       out_p = p;
     }
@@ -333,9 +334,9 @@ bool RunCopyDevice2HostLoopOnce() {

   auto len = task->len;
   auto offset = task->offset;
-  auto p = reinterpret_cast<char *>(tensor->data()) + offset;
+  auto p = reinterpret_cast<const char *>(tensor->data()) + offset;
   if (task->device == CPU_DEVICE_ID) {
-    p = reinterpret_cast<char *>(task->gpu_ptr) + offset;
+    p = reinterpret_cast<const char *>(task->gpu_ptr) + offset;
   }
   auto unit_len = tensor->size() / tensor->shape().num_elements();
   char *cpubuff;
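Here, by contrast, the pointer only feeds the source side of a device-to-host copy, so it can keep the qualifier end to end: CUDA's copy APIs accept const void *src. A sketch under that assumption (CopyToHost is an illustrative wrapper, not a BytePS function):

    #include <cstddef>
    #include <cuda_runtime.h>

    // Device-to-host copy with a const-qualified device source pointer.
    // No cast-away is needed: cudaMemcpyAsync takes const void *src.
    cudaError_t CopyToHost(void *dst, const char *src, std::size_t len,
                           cudaStream_t stream) {
      return cudaMemcpyAsync(dst, src, len, cudaMemcpyDeviceToHost, stream);
    }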
@@ -518,9 +519,9 @@ void CopyHost2Device(std::shared_ptr<byteps::common::TensorTableEntry> task) {
   BPS_CHECK(cpubuff) << task->tensor_name
                      << ": CPU buffer not initialized, size=" << len;

-  auto gpu_addr = reinterpret_cast<char *>(tensor->data()) + offset;
+  auto gpu_addr = (char *)(tensor->data()) + offset;
   if (task->device == CPU_DEVICE_ID) {
-    gpu_addr = reinterpret_cast<char *>(task->gpu_ptr) + offset;
+    gpu_addr = (char *)(task->gpu_ptr) + offset;
   }

   auto unit_len = tensor->size() / tensor->shape().num_elements();
byteps/common/global.cc: 6 changes (3 additions, 3 deletions)
@@ -67,7 +67,7 @@ std::shared_ptr<CpuReducer> BytePSGlobal::_cpu_reducer;
 uint64_t BytePSGlobal::_sample_key = std::numeric_limits<uint64_t>::max();

 BytePSScheduledQueue* BytePSGlobal::GetScheduledQueue(QueueType queueType) {
-  return reinterpret_cast<BytePSScheduledQueue*>(_queues[queueType]);
+  return (BytePSScheduledQueue*)(_queues[queueType]);
 }

 void BytePSGlobal::CreateScheduledQueue(QueueType queueType) {
@@ -258,9 +258,9 @@ void BytePSGlobal::Init() {

   // Create CUDA streams for GPU-CPU copies
   _copy_host2device_stream =
-      reinterpret_case<cudaStream_t*>(std::malloc(sizeof(cudaStream_t) * 1));
+      reinterpret_cast<cudaStream_t*>(std::malloc(sizeof(cudaStream_t) * 1));
   _copy_device2host_stream =
-      reinterpret_case<cudaStream_t*>(std::malloc(sizeof(cudaStream_t) * 1));
+      reinterpret_cast<cudaStream_t*>(std::malloc(sizeof(cudaStream_t) * 1));
   CUDA_CALL(cudaStreamCreateWithFlags(_copy_host2device_stream,
                                       cudaStreamNonBlocking));
   CUDA_CALL(cudaStreamCreateWithFlags(_copy_device2host_stream,
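Beyond cast style, this hunk is a real fix: reinterpret_case in the removed lines is not a C++ construct and would not compile, so correcting it to reinterpret_cast is the substantive change. A standalone sketch of the same stream setup, with BytePS's CUDA_CALL macro replaced by an explicit check (only the CUDA runtime API is assumed):

    #include <cstdio>
    #include <cstdlib>
    #include <cuda_runtime.h>

    // Allocate storage for one cudaStream_t and create a non-blocking
    // stream in it, mirroring the pattern in BytePSGlobal::Init().
    cudaStream_t *MakeNonBlockingStream() {
      auto *stream =
          reinterpret_cast<cudaStream_t *>(std::malloc(sizeof(cudaStream_t)));
      cudaError_t err = cudaStreamCreateWithFlags(stream, cudaStreamNonBlocking);
      if (err != cudaSuccess) {
        std::fprintf(stderr, "stream creation failed: %s\n",
                     cudaGetErrorString(err));
        std::exit(1);
      }
      return stream;
    }

Since cudaStream_t is itself a small handle type, new cudaStream_t or a by-value member would avoid the malloc-plus-cast pair entirely; the sketch keeps the original's shape for comparability.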
