From d9a1cbcd73441942e9a347b26875c126aeee630c Mon Sep 17 00:00:00 2001
From: Sonder <55493212+AndSonder@users.noreply.github.com>
Date: Tue, 21 Nov 2023 12:58:26 +0800
Subject: [PATCH] [feat][AutoParallel] Visualize pipeline-parallel timing
 diagram in static graph mode (#58313)

* merge from openvino master
* add InterpreterRunTime() to record the interpreter's run time
* add profiler helper script to produce the JSON file
* add color map and support the Perfetto format
* recover codes
* control include env for gpu_timer.h
* fix logic for profiler_helper_static.py
* fix build error
* fix build error
* recover thirdparty
* add flag control: the new IR is not supported yet
* set auto_parallel_profiler flag to false
* fix
* add auto_parallel_profiler as a command-line parameter
* fix value name
* support gettimeofday for the Windows env
* fix Windows build error
* fix Windows build error
* use job_type_to_id
* fix repeated timing of the same stream
* add step line for timeline
* add step timeline and fix logic when jobs overlap
* update time-recording logic
* fix bug when profiling starts from a non-zero step
* fix note
* remove FLAGS_auto_parallel_profiler
* use run config instead of FLAGS_auto_parallelxx
* fix color map logic
* fix color map logic
* fix bug when the log step does not start from 0
* fix
* fix
* don't use set_enable_auto_parallel_profiler
* fix bug
* disable auto_parallel_profiler when the flag is not set on the command line
* fix bug
* remove resettime
* fix build bug
* fix
* remove set enable
* fix build error
* fix build error
* fix build error
* fix CI error
* fix
* fix run error
* fix
* fix
* fix calculate_stream_timer logic
* remove fluid header
* fix build error
* set default value for enable_job_schedule_profiler
---
 .../new_executor/interpreter_base_impl.h      |   6 +-
 .../framework/new_executor/interpretercore.cc |   9 +-
 .../framework/new_executor/interpretercore.h  |   5 +-
 .../framework/new_executor/pir_interpreter.cc |   8 +-
 .../framework/new_executor/pir_interpreter.h  |   8 +-
 .../new_executor/program_interpreter.cc       |  53 ++++-
 .../new_executor/program_interpreter.h        |  18 +-
 .../new_executor/standalone_executor.cc       |  38 +++-
 .../new_executor/standalone_executor.h        |   4 +-
 paddle/fluid/pybind/pybind.cc                 |   6 +-
 paddle/phi/kernels/autotune/gpu_timer.h       | 111 ++++++++++
 python/paddle/base/executor.py                |  16 +-
 .../distributed/auto_parallel/constants.py    |   1 +
 .../auto_parallel/static/engine.py            |   7 +
 .../static/profiler_helper_static.py          | 198 ++++++++++++++++++
 15 files changed, 467 insertions(+), 21 deletions(-)
 create mode 100644 python/paddle/distributed/auto_parallel/static/profiler_helper_static.py

diff --git a/paddle/fluid/framework/new_executor/interpreter_base_impl.h b/paddle/fluid/framework/new_executor/interpreter_base_impl.h
index bea61cdeeec84..d23246a784cc8 100644
--- a/paddle/fluid/framework/new_executor/interpreter_base_impl.h
+++ b/paddle/fluid/framework/new_executor/interpreter_base_impl.h
@@ -71,7 +71,9 @@ class InterpreterBaseImpl {
                                           bool need_fetch = true) = 0;
 
   virtual paddle::framework::FetchList Run(
-      const std::vector<std::string>& feed_names, bool need_fetch = true) = 0;
+      const std::vector<std::string>& feed_names,
+      bool need_fetch = true,
+      bool enable_job_schedule_profiler = false) = 0;
 
   virtual void ShareWorkQueueFrom(InterpreterBaseImpl* src) = 0;
@@ -104,6 +106,8 @@
       std::vector<paddle::framework::OpFuncNode>* op_func_nodes) = 0;
 
   virtual bool IsStaticBuild() const = 0;
+
+  virtual std::tuple<double, double> InterpreterRunTime() = 0;
 };
 
 inline void SetDeviceId(const platform::Place& place) {
diff --git a/paddle/fluid/framework/new_executor/interpretercore.cc b/paddle/fluid/framework/new_executor/interpretercore.cc
index d7efd510535e8..4f929a709345a 100644
--- a/paddle/fluid/framework/new_executor/interpretercore.cc
+++ b/paddle/fluid/framework/new_executor/interpretercore.cc
@@ -71,8 +71,9 @@ FetchList InterpreterCore::Run(
 }
 
 FetchList InterpreterCore::Run(const std::vector<std::string>& feed_names,
-                               bool need_fetch) {
-  return impl_->Run(feed_names, need_fetch);
+                               bool need_fetch,
+                               bool enable_job_schedule_profiler) {
+  return impl_->Run(feed_names, need_fetch, enable_job_schedule_profiler);
 }
 
 void InterpreterCore::ShareWorkQueueFrom(std::shared_ptr<InterpreterCore> src) {
@@ -130,5 +131,9 @@ void InterpreterCore::Build(
 
 bool InterpreterCore::IsStaticBuild() const { return impl_->IsStaticBuild(); }
 
+std::tuple<double, double> InterpreterCore::InterpreterRunTime() {
+  return impl_->InterpreterRunTime();
+}
+
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/new_executor/interpretercore.h b/paddle/fluid/framework/new_executor/interpretercore.h
index 022bc0c06f5b2..77ad1b8cbc361 100644
--- a/paddle/fluid/framework/new_executor/interpretercore.h
+++ b/paddle/fluid/framework/new_executor/interpretercore.h
@@ -51,7 +51,8 @@ class InterpreterCore {
                                           bool need_fetch = true);
 
   paddle::framework::FetchList Run(const std::vector<std::string>& feed_names,
-                                   bool need_fetch = true);
+                                   bool need_fetch = true,
+                                   bool enable_job_schedule_profiler = false);
 
   void ShareWorkQueueFrom(std::shared_ptr<InterpreterCore> src);
 
@@ -80,6 +81,8 @@ class InterpreterCore {
 
   bool IsStaticBuild() const;
 
+  std::tuple<double, double> InterpreterRunTime();
+
  private:
   DISABLE_COPY_AND_ASSIGN(InterpreterCore);
diff --git a/paddle/fluid/framework/new_executor/pir_interpreter.cc b/paddle/fluid/framework/new_executor/pir_interpreter.cc
index 5d9eacaa077e0..fd4f171eac739 100644
--- a/paddle/fluid/framework/new_executor/pir_interpreter.cc
+++ b/paddle/fluid/framework/new_executor/pir_interpreter.cc
@@ -284,6 +284,11 @@ void PirInterpreter::ShareBuildResultsFrom(const InterpreterBaseImpl& src) {
           << ") to InterpreterCore(" << this << ")";
 }
 
+std::tuple<double, double> PirInterpreter::InterpreterRunTime() {
+  PADDLE_THROW(platform::errors::Unimplemented(
+      "PirInterpreter::InterpreterRunTime is not implemented."));
+}
+
 const interpreter::PirDependencyBuilder&
 PirInterpreter::GetPirDependencyBuilder() const {
   return ir_dependency_builder_;
@@ -1188,7 +1193,8 @@ paddle::framework::FetchList PirInterpreter::Run(
 }
 
 FetchList PirInterpreter::Run(const std::vector<std::string>& feed_names,
-                              bool need_fetch) {
+                              bool need_fetch,
+                              bool enable_job_schedule_profiler) {
   SetDeviceId(place_);
   CheckCUDAGraphBeforeRun(feed_names);
diff --git a/paddle/fluid/framework/new_executor/pir_interpreter.h b/paddle/fluid/framework/new_executor/pir_interpreter.h
index e75817f5e9393..586a750cbb08e 100644
--- a/paddle/fluid/framework/new_executor/pir_interpreter.h
+++ b/paddle/fluid/framework/new_executor/pir_interpreter.h
@@ -54,13 +54,17 @@ class PirInterpreter : public InterpreterBaseImpl {
       const std::vector<phi::DenseTensor>& feed_tensors,
       bool need_fetch = true) override;
 
-  paddle::framework::FetchList Run(const std::vector<std::string>& feed_names,
-                                   bool need_fetch = true) override;
+  paddle::framework::FetchList Run(
+      const std::vector<std::string>& feed_names,
+      bool need_fetch = true,
+      bool enable_job_schedule_profiler = false) override;
 
   void ShareWorkQueueFrom(InterpreterBaseImpl* src) override;
 
   void ShareBuildResultsFrom(const InterpreterBaseImpl& src) override;
 
+  std::tuple<double, double> InterpreterRunTime() override;
+
   std::shared_ptr<std::vector<size_t>> GetDependencyCount() const override;
 
   bool IsSharedResultsBuild() const override;
diff --git a/paddle/fluid/framework/new_executor/program_interpreter.cc b/paddle/fluid/framework/new_executor/program_interpreter.cc
index 2978e1bf81c41..04d116a53a525 100644
--- a/paddle/fluid/framework/new_executor/program_interpreter.cc
+++ b/paddle/fluid/framework/new_executor/program_interpreter.cc
@@ -54,7 +54,8 @@ ProgramInterpreter::ProgramInterpreter(const platform::Place& place,
       block_(block),
       stream_analyzer_(place),
       execution_config_(execution_config),
-      var_scope_(scope) {
+      var_scope_(scope),
+      enable_job_schedule_profiler_(false) {
   VLOG(4) << "ProgramInterpreter(): " << this << " on " << place_;
 
   exception_notifier_ = main_thread_blocker_.RegisterEvent(kExceptionCaught);
@@ -90,6 +91,10 @@ ProgramInterpreter::ProgramInterpreter(const platform::Place& place,
   };
 
   PrepareForCUDAGraphCapture();
+
+#if defined(PADDLE_WITH_CUDA)
+  calculate_stream_timer_ = std::make_unique<phi::CalculateStreamTimer>(place);
+#endif
 }
 
 ProgramInterpreter::~ProgramInterpreter() {
@@ -126,6 +131,7 @@ void ProgramInterpreter::RunImpl() {
     async_work_queue_ = GetWorkQueue();
     ExecuteInstructionList(vec_instruction_);
   }
+
 #ifdef PADDLE_WITH_CUSTOM_DEVICE
   if (platform::is_custom_place(place_)) {
     platform::DeviceContextPool::Instance().Get(place_)->Wait();
@@ -134,7 +140,10 @@ void ProgramInterpreter::RunImpl() {
 }
 
 FetchList ProgramInterpreter::Run(const std::vector<std::string>& feed_names,
-                                  bool need_fetch) {
+                                  bool need_fetch,
+                                  bool enable_job_schedule_profiler) {
+  enable_job_schedule_profiler_ = enable_job_schedule_profiler;
+
   std::vector<paddle::framework::OpFuncNode> op_func_nodes;
   Build(feed_names, &op_func_nodes);
 
@@ -633,6 +642,15 @@ void ProgramInterpreter::ClearLoDTensorArrayInLocalScope() {
   }
 }
 
+std::tuple<double, double> ProgramInterpreter::InterpreterRunTime() {
+  double start_time = 0, end_time = 0;
+#if defined(PADDLE_WITH_CUDA)
+  start_time = calculate_stream_timer_->StartTime();
+  end_time = calculate_stream_timer_->EndTime();
+#endif
+  return std::make_tuple(start_time, end_time);
+}
+
 void ProgramInterpreter::Convert(
     std::vector<paddle::framework::OpFuncNode>* op_func_nodes) {
   auto& vec_meta_info = var_scope_.MutableVecMetaInfo();
@@ -1040,6 +1058,15 @@ void ProgramInterpreter::RunInstruction(const Instruction& instr_node) {
   try {
     instr_node.WaitEvent(place_);
+#if defined(PADDLE_WITH_CUDA)
+    if (enable_job_schedule_profiler_) {
+      if (!calculate_stream_timer_->IsStarted() &&
+          !interpreter::IsCommunicationOp(instr_node)) {
+        VLOG(3) << "Start calculated stream timer from op: " << op->Type();
+        calculate_stream_timer_->Start();
+      }
+    }
+#endif
 
     if (!instr_node.IsArtificial()) {
       RunOperator(instr_node);
@@ -1094,6 +1121,17 @@ void ProgramInterpreter::ExecuteInstructionList(
 
   exception_holder_.Clear();
 
+  if (enable_job_schedule_profiler_) {
+    for (int i = vec_instr.size() - 1; i >= 0; --i) {
+      auto& instr_node = vec_instr[i];
+      if (!interpreter::IsCommunicationOp(instr_node)) {
+        VLOG(3) << "Last calculated op type: " << instr_node.OpBase()->Type();
+        last_calculate_instr_id_ = i;
+        break;
+      }
+    }
+  }
+
   for (size_t i = 0; i < dependecy_count_->size(); ++i) {
     if ((*dependecy_count_)[i] == 0) {
       // NOTE(zhiqiu): hot fix for jit input var
@@ -1205,6 +1243,17 @@ void ProgramInterpreter::RunInstructionAsync(size_t instr_id) {
 
     RunInstruction(instr_node);
 
+#if defined(PADDLE_WITH_CUDA)
+    if (enable_job_schedule_profiler_) {
+      if (instr_id == last_calculate_instr_id_ &&
+          calculate_stream_timer_->IsStarted()) {
+        VLOG(3) << "Stop calculated stream timer from op: "
+                << instr_node.OpBase()->Type();
+        calculate_stream_timer_->Stop();
+      }
+    }
+#endif
+
     if (UNLIKELY(exception_holder_.IsCaught())) {
       VLOG(4) << "Exception caught";
      if (exception_notifier_ != nullptr) {
diff --git a/paddle/fluid/framework/new_executor/program_interpreter.h b/paddle/fluid/framework/new_executor/program_interpreter.h
index 9c4b8f9bf1c9b..5b8b4dbb36a81 100644
--- a/paddle/fluid/framework/new_executor/program_interpreter.h
+++ b/paddle/fluid/framework/new_executor/program_interpreter.h
@@ -16,6 +16,10 @@
 
 #include "paddle/fluid/framework/new_executor/interpreter_base_impl.h"
 
+#if defined(PADDLE_WITH_CUDA)
+#include "paddle/phi/kernels/autotune/gpu_timer.h"
+#endif
+
 namespace paddle {
 namespace framework {
 
@@ -46,8 +50,10 @@ class ProgramInterpreter : public InterpreterBaseImpl {
       const std::vector<phi::DenseTensor>& feed_tensors,
       bool need_fetch = true) override;
 
-  paddle::framework::FetchList Run(const std::vector<std::string>& feed_names,
-                                   bool need_fetch = true) override;
+  paddle::framework::FetchList Run(
+      const std::vector<std::string>& feed_names,
+      bool need_fetch = true,
+      bool enable_job_schedule_profiler = false) override;
 
   void Build(
       const std::vector<std::string>& feed_names,
@@ -99,6 +105,8 @@ class ProgramInterpreter : public InterpreterBaseImpl {
 
   bool IsStaticBuild() const override { return static_build_; }
 
+  std::tuple<double, double> InterpreterRunTime() override;
+
  private:
   // build graph
   void Convert(std::vector<paddle::framework::OpFuncNode>* op_func_nodes);
@@ -211,6 +219,12 @@ class ProgramInterpreter : public InterpreterBaseImpl {
   InstructionSchedulingPriorityLess instruction_scheduling_priority_less;
 
   std::vector<HookFunc> hookfuncs_;
+
+#if defined(PADDLE_WITH_CUDA)
+  std::unique_ptr<phi::CalculateStreamTimer> calculate_stream_timer_;
+#endif
+  size_t last_calculate_instr_id_;
+  bool enable_job_schedule_profiler_;
 };
 
 }  // namespace framework
diff --git a/paddle/fluid/framework/new_executor/standalone_executor.cc b/paddle/fluid/framework/new_executor/standalone_executor.cc
index db1e522ad636f..e2edadd444cd2 100644
--- a/paddle/fluid/framework/new_executor/standalone_executor.cc
+++ b/paddle/fluid/framework/new_executor/standalone_executor.cc
@@ -152,7 +152,8 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place,
 }
 
 paddle::framework::FetchList StandaloneExecutor::Run(
-    const std::vector<std::string>& feed_names) {
+    const std::vector<std::string>& feed_names,
+    const bool enable_job_schedule_profiler) {
   platform::RecordEvent record_event(
       "StandaloneExecutor::run", platform::TracerEventType::UserDefined, 1);
 
@@ -190,7 +191,7 @@ paddle::framework::FetchList StandaloneExecutor::Run(
         VLOG(6) << "Run job (" << job_idx << "), type = " << job_type
                 << ", micro_batch_id =" << job->MicroBatchId();
 
-        // Note(sonder): Share build results don't work for new IR now.
+        // NOTE(sonder): Sharing build results does not work with the new IR yet.
        if (type_to_first_id.count(job_type) != 0 &&
            !FLAGS_enable_pir_in_executor) {
          interpretercores_[job_idx]->ShareBuildResultsFrom(
@@ -211,13 +212,42 @@ paddle::framework::FetchList StandaloneExecutor::Run(
       if (jobs.size() > 1 && job_type != "forward") {
         const std::vector<std::string> tmp_feed_names = {};
         interpretercores_[job_idx]->Run(tmp_feed_names,
-                                        /*need_fetch = */ false);
+                                        /*need_fetch = */ false,
+                                        /*enable_job_schedule_profiler = */
+                                        enable_job_schedule_profiler);
       } else {
-        interpretercores_[job_idx]->Run(feed_names, /*need_fetch = */ false);
+        interpretercores_[job_idx]->Run(feed_names,
+                                        /*need_fetch = */ false,
+                                        /*enable_job_schedule_profiler = */
+                                        enable_job_schedule_profiler);
       }
     }
   }
 
+  // record each job's run time
+#if defined(PADDLE_WITH_CUDA)
+  if (enable_job_schedule_profiler) {
+    for (size_t job_idx = 0; job_idx < jobs.size(); ++job_idx) {
+      const auto& job = jobs[job_idx];
+      const std::string& job_type = job->Type();
+      double start_time, end_time;
+      std::tie(start_time, end_time) =
+          interpretercores_[job_idx]->InterpreterRunTime();
+
+      // NOTE(sonder): Records the run time of each job so that the
+      // pipeline-parallel timeline can be generated. The job run time is
+      // extracted from the logs by the script profiler_helper_static.py.
+      // Do not modify this log format, since the script matches it with a
+      // regular expression.
+      VLOG(0) << "Profiler Info: Job (" << job->MicroBatchId()
+              << "), type = " << job_type
+              << ", micro_batch_id = " << job->MicroBatchId()
+              << ", job_start_time = " << std::to_string(start_time)
+              << ", job_end_time = " << std::to_string(end_time);
+    }
+  }
+#endif
+
   // return Fetch Tensors
   if (FLAGS_enable_pir_in_executor) {
     framework::FetchList fetch_res;
diff --git a/paddle/fluid/framework/new_executor/standalone_executor.h b/paddle/fluid/framework/new_executor/standalone_executor.h
index 8feef6e5b2f91..621cbf9335dfc 100644
--- a/paddle/fluid/framework/new_executor/standalone_executor.h
+++ b/paddle/fluid/framework/new_executor/standalone_executor.h
@@ -39,7 +39,9 @@ class StandaloneExecutor {
 
   ~StandaloneExecutor() {}
 
-  paddle::framework::FetchList Run(const std::vector<std::string>& feed_names);
+  paddle::framework::FetchList Run(
+      const std::vector<std::string>& feed_names,
+      const bool enable_job_schedule_profiler = false);
 
  private:
   bool is_interpretercore_build_result_shared_{false};
diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc
index 87d6a029ccf78..e9877b5325357 100644
--- a/paddle/fluid/pybind/pybind.cc
+++ b/paddle/fluid/pybind/pybind.cc
@@ -2012,11 +2012,13 @@ All parameter, weight, gradient are variables in Paddle.
                     const interpreter::Plan &,
                     Scope *>())
      .def("run",
-          [](StandaloneExecutor &self, std::vector<std::string> feed_names) {
+          [](StandaloneExecutor &self,
+             std::vector<std::string> feed_names,
+             bool enable_job_schedule_profiler = false) {
             paddle::framework::FetchList ret;
             {
               pybind11::gil_scoped_release release;
-              ret = self.Run(feed_names);
+              ret = self.Run(feed_names, enable_job_schedule_profiler);
             }
             return py::cast(std::move(ret));
           });
diff --git a/paddle/phi/kernels/autotune/gpu_timer.h b/paddle/phi/kernels/autotune/gpu_timer.h
index 87eca2613a7b5..c50a571a7fd95 100644
--- a/paddle/phi/kernels/autotune/gpu_timer.h
+++ b/paddle/phi/kernels/autotune/gpu_timer.h
@@ -14,9 +14,15 @@
 
 #pragma once
 
+#include "paddle/phi/backends/context_pool.h"
+#include "paddle/phi/backends/dynload/port.h"
+#include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/backends/gpu/gpu_decls.h"
+#include "paddle/phi/common/place.h"
+#include "paddle/phi/core/device_context.h"
 #include "paddle/phi/core/enforce.h"
 #include "paddle/phi/core/errors.h"
+
 #ifdef PADDLE_WITH_CUDA
 #include <cuda_runtime.h>
 #endif
@@ -26,6 +32,16 @@
 
 namespace phi {
 
+static void CUDART_CB RecordEventTimerCallback(cudaStream_t stream,
+                                               cudaError_t status,
+                                               void *user_data) {
+  struct timeval time_now {};
+  gettimeofday(&time_now, nullptr);
+  double *cpu_time = static_cast<double *>(user_data);
+  *cpu_time = (time_now.tv_sec * 1000) + (time_now.tv_usec / 1000.0);
+  VLOG(3) << "RecordEventCallback: " << std::to_string(*cpu_time);
+}
+
 class GpuTimer {
  public:
  GpuTimer() {
@@ -85,4 +101,99 @@ class GpuTimer {
   gpuEvent_t stop_;
 };
 
+class CalculateStreamTimer {
+ public:
+  CalculateStreamTimer()
+      : calculated_stream_(nullptr),
+        start_time_(0),
+        end_time_(0),
+        is_started_(false) {}
+
+  explicit CalculateStreamTimer(const phi::Place &place)
+      : calculated_stream_(nullptr),
+        start_time_(0),
+        end_time_(0),
+        is_started_(false),
+        place_(place) {}
+
+  void Start() {
+    // NOTE(sonder): Since the start time of a stream event cannot be obtained
+    // directly, a stream callback is enqueued that records the host time with
+    // "gettimeofday" when the stream reaches this point.
+    if (!is_started_) {
+      calculated_stream_ = dynamic_cast<phi::GPUContext *>(
+                               phi::DeviceContextPool::Instance().Get(place_))
+                               ->stream();
+    }
+    if (calculated_stream_ != nullptr) {
+#ifdef PADDLE_WITH_HIP
+      PADDLE_ENFORCE_GPU_SUCCESS(
+          hipStreamAddCallback(calculated_stream_,
+                               RecordEventTimerCallback,
+                               reinterpret_cast<void *>(&start_time_),
+                               0));
+#else
+      PADDLE_ENFORCE_GPU_SUCCESS(
+          cudaStreamAddCallback(calculated_stream_,
+                                RecordEventTimerCallback,
+                                reinterpret_cast<void *>(&start_time_),
+                                0));
+#endif
+      is_started_ = true;
+    }
+  }
+
+  void Stop() {
+    if (calculated_stream_ != nullptr) {
+#ifdef PADDLE_WITH_HIP
+      PADDLE_ENFORCE_GPU_SUCCESS(
+          hipStreamAddCallback(calculated_stream_,
+                               RecordEventTimerCallback,
+                               reinterpret_cast<void *>(&end_time_),
+                               0));
+#else
+      PADDLE_ENFORCE_GPU_SUCCESS(
+          cudaStreamAddCallback(calculated_stream_,
+                                RecordEventTimerCallback,
+                                reinterpret_cast<void *>(&end_time_),
+                                0));
+#endif
+      is_started_ = false;
+    }
+  }
+
+  double StartTime() {
+    if (calculated_stream_ != nullptr) {
+#ifdef PADDLE_WITH_HIP
+      PADDLE_ENFORCE_GPU_SUCCESS(hipStreamSynchronize(calculated_stream_));
+#else
+      PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(calculated_stream_));
+#endif
+    }
+    return start_time_;
+  }
+
+  double EndTime() {
+    if (calculated_stream_ != nullptr) {
+#ifdef PADDLE_WITH_HIP
+      PADDLE_ENFORCE_GPU_SUCCESS(hipStreamSynchronize(calculated_stream_));
+#else
+      PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(calculated_stream_));
+#endif
+    }
+    return end_time_;
+  }
+
+  bool IsStarted() { return is_started_; }
+
+  void SetStream(gpuStream_t stream) { calculated_stream_ = stream; }
+
+ private:
+  gpuStream_t calculated_stream_;
+  double start_time_;
+  double end_time_;
+  bool is_started_;
+  const phi::Place place_;
+};
+
 }  // namespace phi
diff --git a/python/paddle/base/executor.py b/python/paddle/base/executor.py
index ee586155360a1..2887bf0bb2102 100755
--- a/python/paddle/base/executor.py
+++ b/python/paddle/base/executor.py
@@ -808,7 +808,9 @@ def __init__(self, place, plan, scope):
         self._scope = scope
         self._new_exe = self._create_new_executor()
 
-    def run(self, feed_names, return_numpy=True):
+    def run(
+        self, feed_names, return_numpy=True, enable_job_schedule_profiler=False
+    ):
         """
         Args:
             feed_names(list): This parameter represents the input names of the model.
@@ -818,7 +820,9 @@ def run(self, feed_names, return_numpy=True):
                 (the Tensor specified in the fetch list) to numpy.ndarray. if it is False,
                 the type of the return value is a list of :code:`LoDTensor`. The default is True.
""" - tensors = self._new_exe.run(feed_names)._move_to_list() + tensors = self._new_exe.run( + feed_names, enable_job_schedule_profiler + )._move_to_list() if return_numpy: tensors = as_numpy(tensors, copy=True) if not get_flags("FLAGS_enable_pir_in_executor")[ @@ -1201,6 +1205,8 @@ def __init__(self, place=None): self.op_role_key = core.op_proto_and_checker_maker.kOpRoleAttrName() + self.enable_job_schedule_profiler = False + def _is_optimizer_op(self, op): return self.op_role_key in op.attr_names and int( op.all_attrs()[self.op_role_key] @@ -1921,7 +1927,11 @@ def _run_impl( else: tensor._copy_from(cpu_tensor, self.place) - ret = new_exe.run(list(feed.keys()), return_numpy) + ret = new_exe.run( + list(feed.keys()), + return_numpy, + self.enable_job_schedule_profiler, + ) set_flags(stored_flag) return ret diff --git a/python/paddle/distributed/auto_parallel/constants.py b/python/paddle/distributed/auto_parallel/constants.py index 0b95f3be8af98..bdcdb04b93617 100644 --- a/python/paddle/distributed/auto_parallel/constants.py +++ b/python/paddle/distributed/auto_parallel/constants.py @@ -114,6 +114,7 @@ def set_field_default_config(category, field, default_value): set_field_default_config(PIPELINE, "accumulate_steps", 1) set_field_default_config(PIPELINE, "generation_batch_size", 1) set_field_default_config(PIPELINE, "enable_send_recv_overlap", False) +set_field_default_config(PIPELINE, "schedule_profiler", False) ######################################### # quantization configuration diff --git a/python/paddle/distributed/auto_parallel/static/engine.py b/python/paddle/distributed/auto_parallel/static/engine.py index ce1351fc69b28..93f1d477baf5c 100644 --- a/python/paddle/distributed/auto_parallel/static/engine.py +++ b/python/paddle/distributed/auto_parallel/static/engine.py @@ -254,6 +254,8 @@ def __init__( paddle.framework.set_flags({'FLAGS_new_executor_sequential_run': 1}) paddle.framework.set_flags({'FLAGS_new_executor_static_build': 1}) + self.enable_job_schedule_profiler = False + def _prepare_data_spec(self, data, split, batch_size): inputs_spec = [] labels_spec = [] @@ -1492,6 +1494,11 @@ def run(self, data=None, feed=None, fetch_list=None, mode=None): and not self._has_prepared_reader[self._mode] ): self._prepare_reader() + + self._executor.enable_job_schedule_profiler = ( + self.enable_job_schedule_profiler + ) + outs = self._executor.run( self.main_program, feed=feed_dict, diff --git a/python/paddle/distributed/auto_parallel/static/profiler_helper_static.py b/python/paddle/distributed/auto_parallel/static/profiler_helper_static.py new file mode 100644 index 0000000000000..08a048c0bb68e --- /dev/null +++ b/python/paddle/distributed/auto_parallel/static/profiler_helper_static.py @@ -0,0 +1,198 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import json
+import logging
+import os
+import re
+from argparse import ArgumentParser
+
+import paddle
+from paddle.base.log_helper import get_logger
+
+_logger = get_logger(
+    __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s'
+)
+
+
+color_map = {
+    "forward": "thread_state_running",  # RGB: 126, 200, 148
+    "backward": "rail_idle",  # RGB: 238, 142, 0
+    "optimizer": "rail_response",  # RGB: 238, 142, 0
+    "default": "thread_state_unknown",  # RGB: 199, 155, 125
+}
+
+
+def parse_args():
+    parser = ArgumentParser()
+    device_count = paddle.device.cuda.device_count()
+    all_devices = ",".join([str(i) for i in range(device_count)])
+    parser.add_argument("--devices", type=str, default=all_devices)
+    parser.add_argument("--log_dir", type=str, required=True)
+    args = parser.parse_args()
+    return args
+
+
+def process_job_log(log_data, device_id):
+    log_pattern = r'.*?Profiler Info: Job \((\d+)\), type = (\w+), micro_batch_id = (\d+), job_start_time = (\d+\.\d+), job_end_time = (\d+\.\d+)'
+    matches = re.findall(log_pattern, log_data)
+    events = []
+
+    step_times = []
+    step_start_time = 0
+    step_end_time = 0
+
+    for match in matches:
+        job_id, job_type, micro_batch_id, job_start_time, job_end_time = match
+
+        start_time = float(job_start_time.strip()) * 1000
+        end_time = float(job_end_time.strip()) * 1000
+
+        # A new step begins with the forward job of micro batch 0.
+        if job_type == "forward" and micro_batch_id == "0":
+            if step_start_time != 0:
+                step_times.append([step_start_time, step_end_time])
+            step_start_time = start_time
+        step_end_time = end_time
+
+        event_start = {
+            "name": job_type + "_" + str(job_id),
+            "cat": job_type,
+            "ph": "B",
+            "ts": start_time,
+            "pid": 0,
+            "tid": "GPU" + str(device_id),
+        }
+        event_end = {
+            "name": job_type + "_" + str(job_id),
+            "cat": job_type,
+            "ph": "E",
+            "pid": 0,
+            "ts": end_time,
+            "tid": "GPU" + str(device_id),
+        }
+        if job_type in color_map:
+            event_start["cname"] = color_map[job_type]
+            event_end["cname"] = color_map[job_type]
+
+        events.append(event_start)
+        events.append(event_end)
+
+    step_times.append([step_start_time, step_end_time])
+    return events, step_times
+
+
+def main():
+    args = parse_args()
+    all_events = []
+    step_infos = []
+    start_step = 0
+
+    for device_id in args.devices.split(","):
+        _logger.info(f"Process device {device_id}")
+        device_id = int(device_id)
+        log_file = os.path.join(args.log_dir, "workerlog." + str(device_id))
+        with open(log_file, "r") as f:
+            log_data = f.read()
+
+        start_step_pattern = (
+            r'.*?Schedule Profiler start at step (\d+) and end at step.*'
+        )
+        start_step_match = re.findall(start_step_pattern, log_data)
+        start_step = (
+            int(start_step_match[0]) if len(start_step_match) > 0 else 0
+        )
+
+        events, step_times = process_job_log(log_data, device_id)
+        all_events.extend(events)
+        for i, info in enumerate(step_times):
+            if len(step_infos) <= i:
+                step_infos.append([float("inf"), float("-inf")])
+            step_infos[i][0] = min(step_infos[i][0], info[0])
+            step_infos[i][1] = max(step_infos[i][1], info[1])
+
+    for i, info in enumerate(step_infos):
+        start_time = info[0]
+        if i > 0:
+            start_time = max(start_time, step_infos[i - 1][1])
+        event_start = {
+            "name": "step" + str(i + start_step),
+            "cat": "step",
+            "ph": "B",
+            "ts": start_time,
+            "pid": 0,
+            "tid": "Step",
+            "cname": color_map["default"],
+        }
+        event_end = {
+            "name": "step" + str(i + start_step),
+            "cat": "step",
+            "ph": "E",
+            "ts": info[1],
+            "pid": 0,
+            "tid": "Step",
+            "cname": color_map["default"],
+        }
+
+        all_events.append(event_start)
+        all_events.append(event_end)
+
+    save_path = os.path.join(args.log_dir, "pipeline_profile.json")
+    with open(save_path, "w") as f:
+        f.write(json.dumps({"traceEvents": all_events}))
+    _logger.info(f"Save pipeline profile to {save_path}")
+
+    # support Perfetto format
+    save_path = os.path.join(args.log_dir, "pipeline_profile_perfetto.json")
+    all_events.append(
+        {
+            "args": {"name": "STEP"},
+            "cat": "__metadata",
+            "name": "thread_name",
+            "ph": "M",
+            "pid": 0,
+            "tid": 2333,
+            "ts": 0,
+        }
+    )
+    for i in range(len(args.devices.split(","))):
+        all_events.append(
+            {
+                "args": {"name": f"GPU:{i}"},
+                "cat": "__metadata",
+                "name": "thread_name",
+                "ph": "M",
+                "pid": 0,
+                "tid": i + 2334,
+                "ts": 0,
+            }
+        )
+
+    # Map the string thread ids used above onto the numeric metadata tids.
+    json_str = json.dumps({"traceEvents": all_events})
+    json_str = json_str.replace('"Step"', '2333')
+    for i in range(len(args.devices.split(","))):
+        json_str = json_str.replace(f'"GPU{i}"', f'{i + 2334}')
+
+    with open(save_path, "w") as f:
+        f.write(json_str)
+    _logger.info(f"Save pipeline profile to {save_path}")
+
+
+if __name__ == "__main__":
+    main()
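
Usage sketch (not part of the patch): the pieces above fit together roughly as
follows, assuming a training script launched with
"python -m paddle.distributed.launch --log_dir ./log train.py", which writes
one workerlog.<rank> file per device. Here model, loss, optimizer and
train_dataset stand in for the user's own objects; only schedule_profiler,
enable_job_schedule_profiler and profiler_helper_static.py are introduced by
this patch, and the consumer of the schedule_profiler config key lives outside
this diff.

    # train.py (illustrative)
    from paddle.distributed.fleet import auto

    strategy = auto.Strategy()
    strategy.pipeline.enable = True
    strategy.pipeline.schedule_profiler = True  # config key added by this patch

    # model, loss, optimizer and train_dataset are defined by the user.
    engine = auto.Engine(model, loss, optimizer, strategy=strategy)
    engine.enable_job_schedule_profiler = True  # attribute added by this patch
    engine.fit(train_dataset, epochs=1, batch_size=8)

    # Afterwards, turn the "Profiler Info: Job ..." lines in workerlog.* into
    # a timeline, then open pipeline_profile.json in chrome://tracing or
    # pipeline_profile_perfetto.json at https://ui.perfetto.dev:
    #
    #   python python/paddle/distributed/auto_parallel/static/profiler_helper_static.py \
    #       --devices 0,1 --log_dir ./log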