From deccebf6accefcaabaeb63119d1b97ddb939a351 Mon Sep 17 00:00:00 2001 From: Jokeren Date: Wed, 29 Apr 2026 15:49:41 -0400 Subject: [PATCH 1/3] Refactor error handling in Proton library - Changed exception types from std::runtime_error to more specific exceptions such as std::out_of_range, std::invalid_argument, and std::logic_error across multiple files for better error categorization. - Introduced utility functions in Utility/Errors.h to create prefixed error messages for consistency and clarity. - Updated various throw statements to utilize the new error handling functions, improving readability and maintainability of the code. - Ensured that all error messages are prefixed with "[PROTON]" for easier identification in logs. --- .../common/include/TraceDataIO/ByteSpan.h | 2 +- .../common/include/TraceDataIO/EntryDecoder.h | 3 +- .../common/lib/TraceDataIO/ByteSpan.cpp | 5 ++- .../proton/common/lib/TraceDataIO/Parser.cpp | 3 +- third_party/proton/csrc/include/Data/Metric.h | 15 +++---- .../proton/csrc/include/Data/PhaseStore.h | 5 ++- .../proton/csrc/include/Driver/Dispatch.h | 13 +++--- .../proton/csrc/include/Utility/Errors.h | 31 +++++++++++++- .../proton/csrc/lib/Context/Shadow.cpp | 5 ++- third_party/proton/csrc/lib/Data/Data.cpp | 7 ++-- third_party/proton/csrc/lib/Data/Metric.cpp | 39 +++++++++--------- .../proton/csrc/lib/Data/TraceData.cpp | 13 +++--- third_party/proton/csrc/lib/Data/TreeData.cpp | 22 +++++----- third_party/proton/csrc/lib/Driver/Device.cpp | 4 +- .../lib/Profiler/Cupti/CuptiPCSampling.cpp | 7 ++-- .../csrc/lib/Profiler/Cupti/CuptiProfiler.cpp | 10 ++--- .../proton/csrc/lib/Profiler/GPUProfiler.cpp | 10 +++-- .../proton/csrc/lib/Profiler/Graph.cpp | 5 ++- .../InstrumentationProfiler.cpp | 41 ++++++++++++++----- .../lib/Profiler/Instrumentation/Metadata.cpp | 3 +- .../Profiler/RocTracer/RoctracerProfiler.cpp | 4 +- .../proton/csrc/lib/Session/Session.cpp | 22 +++++----- 22 files changed, 166 insertions(+), 103 deletions(-) diff --git a/third_party/proton/common/include/TraceDataIO/ByteSpan.h b/third_party/proton/common/include/TraceDataIO/ByteSpan.h index 220347e34f63..2e3c43cc3214 100644 --- a/third_party/proton/common/include/TraceDataIO/ByteSpan.h +++ b/third_party/proton/common/include/TraceDataIO/ByteSpan.h @@ -8,7 +8,7 @@ namespace proton { -class BufferException : public std::runtime_error { +class BufferException : public std::out_of_range { public: explicit BufferException(const std::string &message); }; diff --git a/third_party/proton/common/include/TraceDataIO/EntryDecoder.h b/third_party/proton/common/include/TraceDataIO/EntryDecoder.h index ae3fe5e92ae6..46eddf9d52fa 100644 --- a/third_party/proton/common/include/TraceDataIO/EntryDecoder.h +++ b/third_party/proton/common/include/TraceDataIO/EntryDecoder.h @@ -2,6 +2,7 @@ #define PROTON_COMMON_ENTRY_DECODER_H_ #include "ByteSpan.h" +#include "Utility/Errors.h" #include #include #include @@ -11,7 +12,7 @@ namespace proton { class EntryBase; template void decodeFn(ByteSpan &buffer, EntryT &entry) { - throw std::runtime_error("No decoder function is implemented"); + throw makeLogicError("No decoder function is implemented"); } class EntryDecoder { diff --git a/third_party/proton/common/lib/TraceDataIO/ByteSpan.cpp b/third_party/proton/common/lib/TraceDataIO/ByteSpan.cpp index f218c8ff7f9f..bb5a6253c4cc 100644 --- a/third_party/proton/common/lib/TraceDataIO/ByteSpan.cpp +++ b/third_party/proton/common/lib/TraceDataIO/ByteSpan.cpp @@ -1,11 +1,12 @@ #include "TraceDataIO/ByteSpan.h" +#include "Utility/Errors.h" using namespace proton; ByteSpan::ByteSpan(const uint8_t *data, size_t size) : dataPtr(data), dataSize(size), pos(0) { if (data == nullptr && size > 0) { - throw std::invalid_argument( + throw makeInvalidArgument( "Data pointer cannot be null for non-zero size"); } } @@ -74,4 +75,4 @@ void ByteSpan::seek(size_t position) { } BufferException::BufferException(const std::string &message) - : std::runtime_error(message) {} + : std::out_of_range(prefixErrorMessage(message)) {} diff --git a/third_party/proton/common/lib/TraceDataIO/Parser.cpp b/third_party/proton/common/lib/TraceDataIO/Parser.cpp index 819db452449f..5d47b72c0c5f 100644 --- a/third_party/proton/common/lib/TraceDataIO/Parser.cpp +++ b/third_party/proton/common/lib/TraceDataIO/Parser.cpp @@ -1,9 +1,10 @@ #include "TraceDataIO/Parser.h" +#include "Utility/Errors.h" using namespace proton; ParserException::ParserException(const std::string &msg, ExceptionSeverity sev) - : std::runtime_error(msg), severity(sev) {} + : std::runtime_error(prefixErrorMessage(msg)), severity(sev) {} ParserBase::ParserBase(ByteSpan &buffer, const ParserConfig &config) : buffer(buffer), config(config) {} diff --git a/third_party/proton/csrc/include/Data/Metric.h b/third_party/proton/csrc/include/Data/Metric.h index 338f05d472f5..8a11741b870c 100644 --- a/third_party/proton/csrc/include/Data/Metric.h +++ b/third_party/proton/csrc/include/Data/Metric.h @@ -2,6 +2,7 @@ #define PROTON_DATA_METRIC_H_ #include "Runtime/Runtime.h" +#include "Utility/Errors.h" #include "Utility/String.h" #include "Utility/Traits.h" #include @@ -76,7 +77,7 @@ class Metric { void updateValue(int valueId, MetricValueType value) { // Enforce type consistency: once a valueId has a type, it must not change. if (values[valueId].index() != value.index()) { - throw std::runtime_error( + throw makeInvalidArgument( std::string("Metric value type mismatch for valueId ") + std::to_string(valueId) + " (" + getValueName(valueId) + ")" + ": current=" + getTypeNameForIndex(values[valueId].index()) + @@ -99,8 +100,8 @@ class Metric { std::is_arithmetic_v< typename CurrentType::value_type>) { if (currentValue.size() != otherValue.size()) { - throw std::runtime_error( - std::string("[PROTON] Vector metric size mismatch for " + throw makeInvalidArgument( + std::string("Vector metric size mismatch for " "valueId ") + std::to_string(valueId) + " (" + getValueName(valueId) + "): current=" + std::to_string(currentValue.size()) + @@ -110,8 +111,8 @@ class Metric { currentValue[i] += otherValue[i]; } } else { - throw std::runtime_error( - std::string("[PROTON] Metric aggregation not supported for " + throw makeLogicError( + std::string("Metric aggregation not supported for " "valueId ") + std::to_string(valueId) + " (" + getValueName(valueId) + "): type=" + getTypeNameForIndex(values[valueId].index())); @@ -475,8 +476,8 @@ class MetricBuffer { std::shared_lock lock(metricDescriptorMutex); auto it = metricDescriptors.find(id); if (it == metricDescriptors.end()) { - throw std::runtime_error("[PROTON] MetricBuffer: unknown metric id: " + - std::to_string(id)); + throw makeOutOfRange("MetricBuffer: unknown metric id: " + + std::to_string(id)); } return it->second; } diff --git a/third_party/proton/csrc/include/Data/PhaseStore.h b/third_party/proton/csrc/include/Data/PhaseStore.h index d79b3801fc69..c381d622c977 100644 --- a/third_party/proton/csrc/include/Data/PhaseStore.h +++ b/third_party/proton/csrc/include/Data/PhaseStore.h @@ -1,6 +1,8 @@ #ifndef PROTON_DATA_PHASE_STORE_H_ #define PROTON_DATA_PHASE_STORE_H_ +#include "Utility/Errors.h" + #include #include #include @@ -96,8 +98,7 @@ template class PhaseStore final : public PhaseStoreBase { std::shared_lock lock(phasesMutex); auto it = phases.find(phase); if (it == phases.end() || !it->second) { - throw std::runtime_error("[PROTON] Phase " + std::to_string(phase) + - " has no data."); + throw makeOutOfRange("Phase " + std::to_string(phase) + " has no data."); } return it->second; } diff --git a/third_party/proton/csrc/include/Driver/Dispatch.h b/third_party/proton/csrc/include/Driver/Dispatch.h index 920151302dd0..338327df95ba 100644 --- a/third_party/proton/csrc/include/Driver/Dispatch.h +++ b/third_party/proton/csrc/include/Driver/Dispatch.h @@ -4,6 +4,7 @@ #include #include "Utility/Env.h" +#include "Utility/Errors.h" #include #include @@ -113,15 +114,15 @@ template class Dispatch { } } if (*lib == nullptr) { - throw std::runtime_error("Could not load `" + std::string(name) + "`"); + throw makeRuntimeError("Could not load `" + std::string(name) + "`"); } } static void check(typename ExternLib::RetType ret, const char *functionName) { if (ret != ExternLib::success) { - throw std::runtime_error("Failed to execute " + - std::string(functionName) + " with error " + - std::to_string(ret)); + throw makeRuntimeError("Failed to execute " + + std::string(functionName) + " with error " + + std::to_string(ret)); } } @@ -132,8 +133,8 @@ template class Dispatch { if (handler == nullptr) { handler = reinterpret_cast(dlsym(ExternLib::lib, functionName)); if (handler == nullptr) { - throw std::runtime_error("Failed to load " + - std::string(ExternLib::name)); + throw makeRuntimeError("Failed to load " + + std::string(ExternLib::name)); } } auto ret = handler(args...); diff --git a/third_party/proton/csrc/include/Utility/Errors.h b/third_party/proton/csrc/include/Utility/Errors.h index 09c44025dc45..b4183e202193 100644 --- a/third_party/proton/csrc/include/Utility/Errors.h +++ b/third_party/proton/csrc/include/Utility/Errors.h @@ -2,12 +2,41 @@ #define PROTON_UTILITY_ERRORS_H_ #include +#include +#include namespace proton { +inline constexpr const char *kProtonErrorPrefix = "[PROTON] "; + +inline std::string prefixErrorMessage(std::string message) { + return std::string(kProtonErrorPrefix) + message; +} + +inline std::runtime_error makeRuntimeError(std::string message) { + return std::runtime_error(prefixErrorMessage(std::move(message))); +} + +inline std::invalid_argument makeInvalidArgument(std::string message) { + return std::invalid_argument(prefixErrorMessage(std::move(message))); +} + +inline std::out_of_range makeOutOfRange(std::string message) { + return std::out_of_range(prefixErrorMessage(std::move(message))); +} + +inline std::length_error makeLengthError(std::string message) { + return std::length_error(prefixErrorMessage(std::move(message))); +} + +inline std::logic_error makeLogicError(std::string message) { + return std::logic_error(prefixErrorMessage(std::move(message))); +} + class NotImplemented : public std::logic_error { public: - NotImplemented() : std::logic_error("Not yet implemented") {}; + NotImplemented() + : std::logic_error(prefixErrorMessage("Not yet implemented")) {} }; } // namespace proton diff --git a/third_party/proton/csrc/lib/Context/Shadow.cpp b/third_party/proton/csrc/lib/Context/Shadow.cpp index 70366e6b9956..bef7b8bd13a0 100644 --- a/third_party/proton/csrc/lib/Context/Shadow.cpp +++ b/third_party/proton/csrc/lib/Context/Shadow.cpp @@ -1,4 +1,5 @@ #include "Context/Shadow.h" +#include "Utility/Errors.h" #include #include @@ -29,10 +30,10 @@ size_t ShadowContextSource::getDepth() { void ShadowContextSource::exitScope(const Scope &scope) { if (threadContextStack[this].empty()) { - throw std::runtime_error("Context stack is empty"); + throw makeLogicError("Context stack is empty"); } if (threadContextStack[this].back() != scope) { - throw std::runtime_error("Context stack is not balanced"); + throw makeLogicError("Context stack is not balanced"); } threadContextStack[this].pop_back(); } diff --git a/third_party/proton/csrc/lib/Data/Data.cpp b/third_party/proton/csrc/lib/Data/Data.cpp index a1831755b623..6ca6fa213e5c 100644 --- a/third_party/proton/csrc/lib/Data/Data.cpp +++ b/third_party/proton/csrc/lib/Data/Data.cpp @@ -1,4 +1,5 @@ #include "Data/Data.h" +#include "Utility/Errors.h" #include "Utility/String.h" #include @@ -173,7 +174,7 @@ OutputFormat parseOutputFormat(const std::string &outputFormat) { } else if (toLower(outputFormat) == "chrome_trace") { return OutputFormat::ChromeTrace; } else { - throw std::runtime_error("Unknown output format: " + outputFormat); + throw makeInvalidArgument("Unknown output format: " + outputFormat); } } @@ -185,8 +186,8 @@ const std::string outputFormatToString(OutputFormat outputFormat) { } else if (outputFormat == OutputFormat::ChromeTrace) { return "chrome_trace"; } - throw std::runtime_error("Unknown output format: " + - std::to_string(static_cast(outputFormat))); + throw makeInvalidArgument("Unknown output format: " + + std::to_string(static_cast(outputFormat))); } } // namespace proton diff --git a/third_party/proton/csrc/lib/Data/Metric.cpp b/third_party/proton/csrc/lib/Data/Metric.cpp index 6cd23181f29a..5228b427ff61 100644 --- a/third_party/proton/csrc/lib/Data/Metric.cpp +++ b/third_party/proton/csrc/lib/Data/Metric.cpp @@ -1,4 +1,5 @@ #include "Data/Metric.h" +#include "Utility/Errors.h" #include #include @@ -46,15 +47,15 @@ MetricBuffer::getOrCreateMetricDescriptor(const std::string &name, if (nameIt != metricNameToId.end()) { auto &descriptor = metricDescriptors.at(nameIt->second); if (descriptor.typeIndex != typeIndex) { - throw std::runtime_error( - "[PROTON] MetricBuffer: type mismatch for metric " + name + - ": current=" + getTypeNameForIndex(descriptor.typeIndex) + + throw makeInvalidArgument( + "MetricBuffer: type mismatch for metric " + name + ": current=" + + getTypeNameForIndex(descriptor.typeIndex) + ", new=" + getTypeNameForIndex(typeIndex)); } if (descriptor.size != size) { - throw std::runtime_error( - "[PROTON] MetricBuffer: size mismatch for metric " + name + - ": current=" + std::to_string(descriptor.size) + + throw makeInvalidArgument( + "MetricBuffer: size mismatch for metric " + name + ": current=" + + std::to_string(descriptor.size) + ", new=" + std::to_string(size)); } return descriptor; @@ -68,16 +69,15 @@ MetricBuffer::getOrCreateMetricDescriptor(const std::string &name, if (nameIt != metricNameToId.end()) { auto &descriptor = metricDescriptors.at(nameIt->second); if (descriptor.typeIndex != typeIndex) { - throw std::runtime_error( - "[PROTON] MetricBuffer: type mismatch for metric " + name + - ": current=" + getTypeNameForIndex(descriptor.typeIndex) + + throw makeInvalidArgument( + "MetricBuffer: type mismatch for metric " + name + ": current=" + + getTypeNameForIndex(descriptor.typeIndex) + ", new=" + getTypeNameForIndex(typeIndex)); } if (descriptor.size != size) { - throw std::runtime_error( - "[PROTON] MetricBuffer: size mismatch for metric " + name + - ": current=" + std::to_string(descriptor.size) + - ", new=" + std::to_string(size)); + throw makeInvalidArgument( + "MetricBuffer: size mismatch for metric " + name + ": current=" + + std::to_string(descriptor.size) + ", new=" + std::to_string(size)); } return descriptor; } @@ -124,9 +124,8 @@ collectTensorMetrics(Runtime *runtime, } tensorMetricsHost[name] = std::move(values); } else { - throw std::runtime_error( - "[PROTON] Unsupported tensor metric type index: " + - std::to_string(tensorMetric.typeIndex)); + throw makeInvalidArgument("Unsupported tensor metric type index: " + + std::to_string(tensorMetric.typeIndex)); } } return tensorMetricsHost; @@ -163,11 +162,11 @@ void MetricBuffer::queue(size_t metricId, MetricValueType scalarMetric, [](auto &&value) -> uint64_t { using T = std::decay_t; if constexpr (std::is_same_v) { - throw std::runtime_error( - "[PROTON] String metrics are not supported in MetricBuffer"); + throw makeInvalidArgument( + "String metrics are not supported in MetricBuffer"); } else if constexpr (is_std_vector_v) { - throw std::runtime_error( - "[PROTON] Vector metrics are not supported in MetricBuffer"); + throw makeInvalidArgument( + "Vector metrics are not supported in MetricBuffer"); } else { static_assert(sizeof(T) == sizeof(uint64_t), "MetricValueType alternative must be 8 bytes"); diff --git a/third_party/proton/csrc/lib/Data/TraceData.cpp b/third_party/proton/csrc/lib/Data/TraceData.cpp index 0ea4a71e9893..953e80589bab 100644 --- a/third_party/proton/csrc/lib/Data/TraceData.cpp +++ b/third_party/proton/csrc/lib/Data/TraceData.cpp @@ -1,6 +1,7 @@ #include "Data/TraceData.h" #include "Profiler/Graph.h" #include "TraceDataIO/TraceWriter.h" +#include "Utility/Errors.h" #include "Utility/MsgPackWriter.h" #include "nlohmann/json.hpp" @@ -122,7 +123,7 @@ class TraceData::Trace { std::vector contexts; auto it = traceContextMap.find(contextId); if (it == traceContextMap.end()) { - throw std::runtime_error("Context not found"); + throw makeOutOfRange("Context not found"); } std::reference_wrapper context = it->second; contexts.push_back(context.get()); @@ -146,7 +147,7 @@ class TraceData::Trace { Event &getEvent(size_t eventId) { auto it = traceEvents.find(eventId); if (it == traceEvents.end()) { - throw std::runtime_error("Event not found"); + throw makeOutOfRange("Event not found"); } return it->second; } @@ -688,7 +689,7 @@ void reconstructGraphScopeEvents( } } if (!seenCaptureTag) { - throw std::runtime_error("Invalid graph contexts without capture tag"); + throw makeLogicError("Invalid graph contexts without capture tag"); } if (!isMetadataKernel) { graphContexts @@ -877,7 +878,7 @@ void dumpCpuToGpuFlowEvents( auto launchEventIt = launchEventIdToCpuScopeEvent.find(event.launchEventId); if (launchEventIt == launchEventIdToCpuScopeEvent.end()) { - throw std::runtime_error( + throw makeOutOfRange( "Cannot find CPU scope event for kernel launch event id: " + std::to_string(event.launchEventId)); } @@ -1062,7 +1063,7 @@ void TraceData::dumpChromeTrace(std::ostream &os, size_t phase) const { /*isGraphLinked=*/true); } if (hasKernelMetrics && hasCycleMetrics) { - throw std::runtime_error("only one active metric type is supported"); + throw makeLogicError("only one active metric type is supported"); } } } @@ -1111,7 +1112,7 @@ void TraceData::doDump(std::ostream &os, OutputFormat outputFormat, if (outputFormat == OutputFormat::ChromeTrace) { dumpChromeTrace(os, phase); } else { - throw std::logic_error("Output format not supported"); + throw makeInvalidArgument("Output format not supported"); } } diff --git a/third_party/proton/csrc/lib/Data/TreeData.cpp b/third_party/proton/csrc/lib/Data/TreeData.cpp index 928767ecc99b..29659382538e 100644 --- a/third_party/proton/csrc/lib/Data/TreeData.cpp +++ b/third_party/proton/csrc/lib/Data/TreeData.cpp @@ -2,6 +2,7 @@ #include "Context/Context.h" #include "Data/Metric.h" #include "Device.h" +#include "Utility/Errors.h" #include "Utility/MsgPackWriter.h" #include #include @@ -44,14 +45,13 @@ struct MetricSummary { void updateDeviceIdMask(uint64_t deviceType, uint64_t deviceId) { if (deviceType >= static_cast(DeviceType::COUNT)) { - throw std::runtime_error("[PROTON] Invalid deviceType " + - std::to_string(deviceType)); + throw makeOutOfRange("Invalid deviceType " + std::to_string(deviceType)); } if (deviceId >= kMaxRegisteredDeviceIds) { - throw std::runtime_error("[PROTON] DeviceId " + std::to_string(deviceId) + - " exceeds MaxRegisteredDeviceIds " + - std::to_string(kMaxRegisteredDeviceIds) + - " for deviceType " + std::to_string(deviceType)); + throw makeOutOfRange("DeviceId " + std::to_string(deviceId) + + " exceeds MaxRegisteredDeviceIds " + + std::to_string(kMaxRegisteredDeviceIds) + + " for deviceType " + std::to_string(deviceType)); } deviceIdMasks[static_cast(deviceType)] |= (1u << static_cast(deviceId)); @@ -81,7 +81,7 @@ struct MetricSummary { } else if (metricKind == MetricKind::Flexible) { // Flexible metrics are tracked in a separate map. } else { - throw std::runtime_error("MetricKind not supported"); + throw makeLogicError("MetricKind not supported"); } } } @@ -292,7 +292,7 @@ json TreeData::buildHatchetJson(TreeData::Tree *tree, } else if (metricKind == MetricKind::Flexible) { // Flexible metrics are handled in a different way } else { - throw std::runtime_error("MetricKind not supported"); + throw makeLogicError("MetricKind not supported"); } } }; @@ -548,7 +548,7 @@ TreeData::buildHatchetMsgPack(TreeData::Tree *tree, } else if (metricKind == MetricKind::Flexible) { // Flexible metrics are tracked in a separate map. } else { - throw std::runtime_error("MetricKind not supported"); + throw makeLogicError("MetricKind not supported"); } } if (isRoot) { @@ -643,7 +643,7 @@ TreeData::buildHatchetMsgPack(TreeData::Tree *tree, writer.packStr(cycleMetricDeviceTypeName); writer.packStr(std::to_string(deviceType)); } else { - throw std::runtime_error("MetricKind not supported"); + throw makeLogicError("MetricKind not supported"); } } if (isRoot) { @@ -926,7 +926,7 @@ void TreeData::doDump(std::ostream &os, OutputFormat outputFormat, } else if (outputFormat == OutputFormat::HatchetMsgPack) { dumpHatchetMsgPack(os, phase); } else { - throw std::logic_error("Output format not supported"); + throw makeInvalidArgument("Output format not supported"); } } diff --git a/third_party/proton/csrc/lib/Driver/Device.cpp b/third_party/proton/csrc/lib/Driver/Device.cpp index 24f4d1612623..aca33b7e3edd 100644 --- a/third_party/proton/csrc/lib/Driver/Device.cpp +++ b/third_party/proton/csrc/lib/Driver/Device.cpp @@ -13,7 +13,7 @@ Device getDevice(DeviceType type, uint64_t index) { if (type == DeviceType::HIP) { return hip::getDevice(index); } - throw std::runtime_error("DeviceType not supported"); + throw makeInvalidArgument("DeviceType not supported"); } const std::string getDeviceTypeString(DeviceType type) { @@ -22,7 +22,7 @@ const std::string getDeviceTypeString(DeviceType type) { } else if (type == DeviceType::HIP) { return DeviceTraits::name; } - throw std::runtime_error("DeviceType not supported"); + throw makeInvalidArgument("DeviceType not supported"); } } // namespace proton diff --git a/third_party/proton/csrc/lib/Profiler/Cupti/CuptiPCSampling.cpp b/third_party/proton/csrc/lib/Profiler/Cupti/CuptiPCSampling.cpp index 631be416021c..488b78773bc4 100644 --- a/third_party/proton/csrc/lib/Profiler/Cupti/CuptiPCSampling.cpp +++ b/third_party/proton/csrc/lib/Profiler/Cupti/CuptiPCSampling.cpp @@ -3,6 +3,7 @@ #include "Driver/GPU/CudaApi.h" #include "Driver/GPU/CuptiApi.h" #include "Utility/Atomic.h" +#include "Utility/Errors.h" #include "Utility/Map.h" #include "Utility/String.h" #include @@ -142,8 +143,8 @@ CUpti_PCSamplingData allocPCSamplingData(size_t collectNumPCs, CUPTI_API_VERSION >= CUPTI_CUDA12_4_VERSION) || (libVersion >= CUPTI_CUDA12_4_VERSION && CUPTI_API_VERSION < CUPTI_CUDA12_4_VERSION)) { - throw std::runtime_error( - "[PROTON] CUPTI API version: " + std::to_string(CUPTI_API_VERSION) + + throw makeRuntimeError( + "CUPTI API version: " + std::to_string(CUPTI_API_VERSION) + " and CUPTI driver version: " + std::to_string(libVersion) + " are not compatible. Please set the environment variable " " TRITON_CUPTI_INCLUDE_PATH and TRITON_CUPTI_LIB_PATH to resolve the " @@ -377,7 +378,7 @@ void CuptiPCSampling::processPCSamplingData(ConfigureData *configureData, auto *stallReason = &pcData->stallReason[j]; if (!configureData->stallReasonIndexToMetricIndex.count( stallReason->pcSamplingStallReasonIndex)) - throw std::runtime_error("[PROTON] Invalid stall reason index"); + throw makeOutOfRange("Invalid stall reason index"); for (const auto &[data, baseEntry] : dataToEntry) { auto entry = baseEntry; if (lineInfo.fileName.size()) diff --git a/third_party/proton/csrc/lib/Profiler/Cupti/CuptiProfiler.cpp b/third_party/proton/csrc/lib/Profiler/Cupti/CuptiProfiler.cpp index 3a7613e6b06a..087d84535a22 100644 --- a/third_party/proton/csrc/lib/Profiler/Cupti/CuptiProfiler.cpp +++ b/third_party/proton/csrc/lib/Profiler/Cupti/CuptiProfiler.cpp @@ -9,6 +9,7 @@ #include "Profiler/Graph.h" #include "Runtime/CudaRuntime.h" #include "Utility/Env.h" +#include "Utility/Errors.h" #include "Utility/Map.h" #include "Utility/String.h" #include "Utility/Vector.h" @@ -125,7 +126,7 @@ uint32_t processActivityKernel( const GraphState::NodeState &nodeState = nodeIdToState->at( kernel->graphNodeId); // nodeIdToState must have the nodeId if (nodeState.status.isMissingName()) { - throw std::runtime_error("Kernel name is missing for a graph node."); + throw makeLogicError("Kernel name is missing for a graph node."); } const bool isMetricKernel = nodeState.status.isMetricNode(); for (auto &[data, entry] : externState.dataToGraphEntry) { @@ -408,7 +409,7 @@ void CuptiProfiler::CuptiProfilerPimpl::allocBuffer(uint8_t **buffer, getIntEnv("TRITON_PROFILE_BUFFER_SIZE", 64 * 1024 * 1024); *buffer = static_cast(aligned_alloc(AlignSize, envBufferSize)); if (*buffer == nullptr) { - throw std::runtime_error("[PROTON] aligned_alloc failed"); + throw makeRuntimeError("aligned_alloc failed"); } *bufferSize = envBufferSize; *maxNumRecords = 0; @@ -438,7 +439,7 @@ void CuptiProfiler::CuptiProfilerPimpl::completeBuffer(CUcontext ctx, } else if (status == CUPTI_ERROR_MAX_LIMIT_REACHED) { break; } else { - throw std::runtime_error("[PROTON] cupti::activityGetNextRecord failed"); + throw makeRuntimeError("cupti::activityGetNextRecord failed"); } } while (true); @@ -814,8 +815,7 @@ void CuptiProfiler::doSetMode(const std::vector &modeAndOptions) { periodicFlushingFormat, modeAndOptions, "CuptiProfiler"); } else if (!mode.empty()) { - throw std::invalid_argument("[PROTON] CuptiProfiler: unsupported mode: " + - mode); + throw makeInvalidArgument("CuptiProfiler: unsupported mode: " + mode); } } diff --git a/third_party/proton/csrc/lib/Profiler/GPUProfiler.cpp b/third_party/proton/csrc/lib/Profiler/GPUProfiler.cpp index ce10bdff6827..fc9f0c1a101e 100644 --- a/third_party/proton/csrc/lib/Profiler/GPUProfiler.cpp +++ b/third_party/proton/csrc/lib/Profiler/GPUProfiler.cpp @@ -1,5 +1,6 @@ #include "Profiler/GPUProfiler.h" #include "Profiler/Graph.h" +#include "Utility/Errors.h" #include #include @@ -185,13 +186,14 @@ void setPeriodicFlushingMode(bool &periodicFlushingEnabled, const std::string key = modeAndOptions[1].substr(0, delimiterPos); const std::string value = modeAndOptions[1].substr(delimiterPos + 1); if (key != "format") { - throw std::invalid_argument(std::string("[PROTON] ") + profilerName + - ": unsupported option key: " + key); + throw makeInvalidArgument(profilerName + + std::string(": unsupported option key: ") + + key); } if (value != "hatchet_msgpack" && value != "chrome_trace" && value != "hatchet") { - throw std::invalid_argument(std::string("[PROTON] ") + profilerName + - ": unsupported format: " + value); + throw makeInvalidArgument(profilerName + + std::string(": unsupported format: ") + value); } periodicFlushingFormat = value; } else { diff --git a/third_party/proton/csrc/lib/Profiler/Graph.cpp b/third_party/proton/csrc/lib/Profiler/Graph.cpp index 0856e80a4e16..22bfa3bcd761 100644 --- a/third_party/proton/csrc/lib/Profiler/Graph.cpp +++ b/third_party/proton/csrc/lib/Profiler/Graph.cpp @@ -2,6 +2,7 @@ #include "Data/Data.h" #include "Runtime/Runtime.h" +#include "Utility/Errors.h" #include #include @@ -81,8 +82,8 @@ void emitMetricRecords(MetricBuffer &metricBuffer, uint64_t *hostBasePtr, break; } default: - throw std::runtime_error("[PROTON] Unsupported metric type index: " + - std::to_string(metricTypeIndex)); + throw makeOutOfRange("Unsupported metric type index: " + + std::to_string(metricTypeIndex)); break; } diff --git a/third_party/proton/csrc/lib/Profiler/Instrumentation/InstrumentationProfiler.cpp b/third_party/proton/csrc/lib/Profiler/Instrumentation/InstrumentationProfiler.cpp index afbd396adb53..d4bd1cc91985 100644 --- a/third_party/proton/csrc/lib/Profiler/Instrumentation/InstrumentationProfiler.cpp +++ b/third_party/proton/csrc/lib/Profiler/Instrumentation/InstrumentationProfiler.cpp @@ -3,10 +3,12 @@ #include "Runtime/CudaRuntime.h" #include "Runtime/HipRuntime.h" +#include "Utility/Errors.h" #include "Utility/Numeric.h" #include "Utility/String.h" #include #include +#include #include #include #include @@ -47,7 +49,7 @@ void InstrumentationProfiler::doStop() { void InstrumentationProfiler::doSetMode( const std::vector &modeAndOptions) { if (modeAndOptions.empty()) { - throw std::runtime_error("Mode cannot be empty"); + throw makeInvalidArgument("Mode cannot be empty"); } if (proton::toLower(modeAndOptions[0]) == proton::toLower(DeviceTraits::name)) { @@ -56,7 +58,7 @@ void InstrumentationProfiler::doSetMode( proton::toLower(DeviceTraits::name)) { runtime = &HipRuntime::instance(); } else { - throw std::runtime_error("Unknown device type: " + modeAndOptions[0]); + throw makeInvalidArgument("Unknown device type: " + modeAndOptions[0]); } for (size_t i = 1; i < modeAndOptions.size(); ++i) { auto delimiterPos = modeAndOptions[i].find('='); @@ -71,6 +73,26 @@ void InstrumentationProfiler::doSetMode( } namespace { +uint32_t parseUnitId(const std::string &unitId) { + auto trimmedId = proton::trim(unitId); + size_t parsedEnd = 0; + unsigned long id = 0; + try { + id = std::stoul(trimmedId, &parsedEnd); + } catch (const std::invalid_argument &) { + throw makeInvalidArgument("Invalid sampling warp id: " + trimmedId); + } catch (const std::out_of_range &) { + throw makeOutOfRange("Sampling warp id out of range: " + trimmedId); + } + if (parsedEnd != trimmedId.size()) { + throw makeInvalidArgument("Invalid sampling warp id: " + trimmedId); + } + if (id > std::numeric_limits::max()) { + throw makeOutOfRange("Sampling warp id out of range: " + trimmedId); + } + return static_cast(id); +} + std::vector getUnitIdVector(const std::map &modeOptions, size_t totalUnits) { @@ -82,8 +104,7 @@ getUnitIdVector(const std::map &modeOptions, if (proton::trim(uintId).empty()) { continue; } - uint32_t id = std::stoi(uintId); - unitIdVector.push_back(id); + unitIdVector.push_back(parseUnitId(uintId)); } } if (unitIdVector.empty()) { @@ -105,7 +126,7 @@ InstrumentationProfiler::getParserConfig(uint64_t functionId, functionMetadata.at(functionId).getScratchMemorySize(); if (!(modeOptions.count("granularity") == 0 || modeOptions.at("granularity") == "GRANULARITY.WARP")) { - throw std::runtime_error("Only warp granularity is supported for now"); + throw makeInvalidArgument("Only warp granularity is supported for now"); } config->totalUnits = functionMetadata.at(functionId).getNumWarps(); config->numBlocks = bufferSize / config->scratchMemSize; @@ -114,7 +135,7 @@ InstrumentationProfiler::getParserConfig(uint64_t functionId, // Check if the uidVec is valid for (auto uid : config->uidVec) if (uid >= config->totalUnits) { - throw std::runtime_error( + throw makeOutOfRange( "Invalid sampling warp id: " + std::to_string(uid) + ". We have " + std::to_string(config->totalUnits) + " warps in total. Please check the proton sampling options."); @@ -132,7 +153,7 @@ void InstrumentationProfiler::initFunctionMetadata( const std::vector> &scopeIdParentPairs, const std::string &metadataPath) { if (functionScopeIdNames.count(functionId)) { - throw std::runtime_error( + throw makeInvalidArgument( "Duplicate function id: " + std::to_string(functionId) + " for function " + functionName); } @@ -141,7 +162,7 @@ void InstrumentationProfiler::initFunctionMetadata( auto scopeId = pair.first; auto scopeName = pair.second; if (functionScopeIdNames[functionId].count(scopeId)) { - throw std::runtime_error( + throw makeInvalidArgument( "Duplicate scope id: " + std::to_string(scopeId) + " for function " + functionName); } @@ -198,7 +219,7 @@ void InstrumentationProfiler::exitInstrumentedOp(uint64_t streamId, } if (size > MAX_HOST_BUFFER_SIZE) { - throw std::runtime_error( + throw makeLengthError( "Buffer size " + std::to_string(size) + " exceeds the limit " + std::to_string(MAX_HOST_BUFFER_SIZE) + ", not supported yet in proton"); } else if (size > DEFAULT_HOST_BUFFER_SIZE) { @@ -214,7 +235,7 @@ void InstrumentationProfiler::exitInstrumentedOp(uint64_t streamId, auto circularLayoutConfig = std::dynamic_pointer_cast(config); if (!circularLayoutConfig) { - throw std::runtime_error( + throw makeLogicError( "Only circular layout parser is supported for now"); } diff --git a/third_party/proton/csrc/lib/Profiler/Instrumentation/Metadata.cpp b/third_party/proton/csrc/lib/Profiler/Instrumentation/Metadata.cpp index 0bf9147b2cfb..27337cd615c4 100644 --- a/third_party/proton/csrc/lib/Profiler/Instrumentation/Metadata.cpp +++ b/third_party/proton/csrc/lib/Profiler/Instrumentation/Metadata.cpp @@ -1,6 +1,7 @@ #include #include "Profiler/Instrumentation/Metadata.h" +#include "Utility/Errors.h" #include "nlohmann/json.hpp" using json = nlohmann::json; @@ -10,7 +11,7 @@ namespace proton { void InstrumentationMetadata::parse() { std::ifstream metadataFile(metadataPath); if (!metadataFile.is_open()) { - throw std::runtime_error("Failed to open metadata file: " + metadataPath); + throw makeRuntimeError("Failed to open metadata file: " + metadataPath); } json metadataJson; diff --git a/third_party/proton/csrc/lib/Profiler/RocTracer/RoctracerProfiler.cpp b/third_party/proton/csrc/lib/Profiler/RocTracer/RoctracerProfiler.cpp index 141ac34a1329..6a032f970acf 100644 --- a/third_party/proton/csrc/lib/Profiler/RocTracer/RoctracerProfiler.cpp +++ b/third_party/proton/csrc/lib/Profiler/RocTracer/RoctracerProfiler.cpp @@ -6,6 +6,7 @@ #include "Driver/GPU/RoctracerApi.h" #include "Runtime/HipRuntime.h" #include "Utility/Env.h" +#include "Utility/Errors.h" #include "Driver/GPU/RoctxTypes.h" #include "hip/amd_detail/hip_runtime_prof.h" @@ -495,8 +496,7 @@ void RoctracerProfiler::doSetMode( periodicFlushingFormat, modeAndOptions, "RoctracerProfiler"); } else if (!mode.empty()) { - throw std::invalid_argument( - "[PROTON] RoctracerProfiler: unsupported mode: " + mode); + throw makeInvalidArgument("RoctracerProfiler: unsupported mode: " + mode); } } diff --git a/third_party/proton/csrc/lib/Session/Session.cpp b/third_party/proton/csrc/lib/Session/Session.cpp index 74eeb0f55915..9e7265430019 100644 --- a/third_party/proton/csrc/lib/Session/Session.cpp +++ b/third_party/proton/csrc/lib/Session/Session.cpp @@ -6,6 +6,7 @@ #include "Profiler/Cupti/CuptiProfiler.h" #include "Profiler/Instrumentation/InstrumentationProfiler.h" #include "Profiler/Roctracer/RoctracerProfiler.h" +#include "Utility/Errors.h" #include "Utility/String.h" namespace proton { @@ -20,7 +21,7 @@ Profiler *makeProfiler(const std::string &name) { } else if (proton::toLower(name) == "instrumentation") { return &InstrumentationProfiler::instance(); } - throw std::runtime_error("Unknown profiler: " + name); + throw makeInvalidArgument("Unknown profiler: " + name); } std::unique_ptr makeData(const std::string &dataName, @@ -31,7 +32,7 @@ std::unique_ptr makeData(const std::string &dataName, } else if (toLower(dataName) == "trace") { return std::make_unique(path, contextSource); } - throw std::runtime_error("Unknown data: " + dataName); + throw makeInvalidArgument("Unknown data: " + dataName); } std::unique_ptr @@ -41,15 +42,15 @@ makeContextSource(const std::string &contextSourceName) { } else if (toLower(contextSourceName) == "python") { return std::make_unique(); } - throw std::runtime_error("Unknown context source: " + contextSourceName); + throw makeInvalidArgument("Unknown context source: " + contextSourceName); } void throwIfSessionNotInitialized( const std::map> &sessions, size_t sessionId) { if (!sessions.count(sessionId)) { - throw std::runtime_error("Session has not been initialized: " + - std::to_string(sessionId)); + throw makeOutOfRange("Session has not been initialized: " + + std::to_string(sessionId)); } } @@ -80,8 +81,9 @@ Profiler *SessionManager::validateAndSetProfilerMode(Profiler *profiler, for (auto &[id, session] : sessions) { if (session->getProfiler() == profiler && session->getProfiler()->getMode() != modeAndOptions) { - throw std::runtime_error("Cannot add a session with the same profiler " - "but a different mode than existing sessions"); + throw makeInvalidArgument( + "Cannot add a session with the same profiler " + "but a different mode than existing sessions"); } } return profiler->setMode(modeAndOptions); @@ -332,8 +334,7 @@ std::vector SessionManager::getDataMsgPack(size_t sessionId, auto *session = getSessionOrThrow(sessionId); auto *treeData = dynamic_cast(session->data.get()); if (!treeData) { - throw std::runtime_error( - "Only TreeData is supported for getData() for now"); + throw makeLogicError("Only TreeData is supported for getData() for now"); } return treeData->toMsgPack(phase); } @@ -343,8 +344,7 @@ std::string SessionManager::getData(size_t sessionId, size_t phase) { auto *session = getSessionOrThrow(sessionId); auto *treeData = dynamic_cast(session->data.get()); if (!treeData) { - throw std::runtime_error( - "Only TreeData is supported for getData() for now"); + throw makeLogicError("Only TreeData is supported for getData() for now"); } return treeData->toJsonString(phase); } From ae6ca380290e3cd52013088f55b97879f7f9f037 Mon Sep 17 00:00:00 2001 From: Jokeren Date: Wed, 29 Apr 2026 15:51:41 -0400 Subject: [PATCH 2/3] Refactor error messages for clarity and consistency in Proton library --- .../proton/common/lib/TraceDataIO/ByteSpan.cpp | 3 +-- .../proton/csrc/include/Driver/Dispatch.h | 5 ++--- third_party/proton/csrc/lib/Data/Metric.cpp | 18 +++++++++--------- .../proton/csrc/lib/Profiler/GPUProfiler.cpp | 5 ++--- .../InstrumentationProfiler.cpp | 3 +-- .../proton/csrc/lib/Session/Session.cpp | 5 ++--- 6 files changed, 17 insertions(+), 22 deletions(-) diff --git a/third_party/proton/common/lib/TraceDataIO/ByteSpan.cpp b/third_party/proton/common/lib/TraceDataIO/ByteSpan.cpp index bb5a6253c4cc..6b66b981415e 100644 --- a/third_party/proton/common/lib/TraceDataIO/ByteSpan.cpp +++ b/third_party/proton/common/lib/TraceDataIO/ByteSpan.cpp @@ -6,8 +6,7 @@ using namespace proton; ByteSpan::ByteSpan(const uint8_t *data, size_t size) : dataPtr(data), dataSize(size), pos(0) { if (data == nullptr && size > 0) { - throw makeInvalidArgument( - "Data pointer cannot be null for non-zero size"); + throw makeInvalidArgument("Data pointer cannot be null for non-zero size"); } } diff --git a/third_party/proton/csrc/include/Driver/Dispatch.h b/third_party/proton/csrc/include/Driver/Dispatch.h index 338327df95ba..08c8554c3fa1 100644 --- a/third_party/proton/csrc/include/Driver/Dispatch.h +++ b/third_party/proton/csrc/include/Driver/Dispatch.h @@ -120,9 +120,8 @@ template class Dispatch { static void check(typename ExternLib::RetType ret, const char *functionName) { if (ret != ExternLib::success) { - throw makeRuntimeError("Failed to execute " + - std::string(functionName) + " with error " + - std::to_string(ret)); + throw makeRuntimeError("Failed to execute " + std::string(functionName) + + " with error " + std::to_string(ret)); } } diff --git a/third_party/proton/csrc/lib/Data/Metric.cpp b/third_party/proton/csrc/lib/Data/Metric.cpp index 5228b427ff61..86aae465b429 100644 --- a/third_party/proton/csrc/lib/Data/Metric.cpp +++ b/third_party/proton/csrc/lib/Data/Metric.cpp @@ -48,15 +48,14 @@ MetricBuffer::getOrCreateMetricDescriptor(const std::string &name, auto &descriptor = metricDescriptors.at(nameIt->second); if (descriptor.typeIndex != typeIndex) { throw makeInvalidArgument( - "MetricBuffer: type mismatch for metric " + name + ": current=" + - getTypeNameForIndex(descriptor.typeIndex) + + "MetricBuffer: type mismatch for metric " + name + + ": current=" + getTypeNameForIndex(descriptor.typeIndex) + ", new=" + getTypeNameForIndex(typeIndex)); } if (descriptor.size != size) { throw makeInvalidArgument( "MetricBuffer: size mismatch for metric " + name + ": current=" + - std::to_string(descriptor.size) + - ", new=" + std::to_string(size)); + std::to_string(descriptor.size) + ", new=" + std::to_string(size)); } return descriptor; } @@ -70,14 +69,15 @@ MetricBuffer::getOrCreateMetricDescriptor(const std::string &name, auto &descriptor = metricDescriptors.at(nameIt->second); if (descriptor.typeIndex != typeIndex) { throw makeInvalidArgument( - "MetricBuffer: type mismatch for metric " + name + ": current=" + - getTypeNameForIndex(descriptor.typeIndex) + + "MetricBuffer: type mismatch for metric " + name + + ": current=" + getTypeNameForIndex(descriptor.typeIndex) + ", new=" + getTypeNameForIndex(typeIndex)); } if (descriptor.size != size) { - throw makeInvalidArgument( - "MetricBuffer: size mismatch for metric " + name + ": current=" + - std::to_string(descriptor.size) + ", new=" + std::to_string(size)); + throw makeInvalidArgument("MetricBuffer: size mismatch for metric " + + name + + ": current=" + std::to_string(descriptor.size) + + ", new=" + std::to_string(size)); } return descriptor; } diff --git a/third_party/proton/csrc/lib/Profiler/GPUProfiler.cpp b/third_party/proton/csrc/lib/Profiler/GPUProfiler.cpp index fc9f0c1a101e..abb421e27be7 100644 --- a/third_party/proton/csrc/lib/Profiler/GPUProfiler.cpp +++ b/third_party/proton/csrc/lib/Profiler/GPUProfiler.cpp @@ -186,9 +186,8 @@ void setPeriodicFlushingMode(bool &periodicFlushingEnabled, const std::string key = modeAndOptions[1].substr(0, delimiterPos); const std::string value = modeAndOptions[1].substr(delimiterPos + 1); if (key != "format") { - throw makeInvalidArgument(profilerName + - std::string(": unsupported option key: ") + - key); + throw makeInvalidArgument( + profilerName + std::string(": unsupported option key: ") + key); } if (value != "hatchet_msgpack" && value != "chrome_trace" && value != "hatchet") { diff --git a/third_party/proton/csrc/lib/Profiler/Instrumentation/InstrumentationProfiler.cpp b/third_party/proton/csrc/lib/Profiler/Instrumentation/InstrumentationProfiler.cpp index d4bd1cc91985..aa14935602ab 100644 --- a/third_party/proton/csrc/lib/Profiler/Instrumentation/InstrumentationProfiler.cpp +++ b/third_party/proton/csrc/lib/Profiler/Instrumentation/InstrumentationProfiler.cpp @@ -235,8 +235,7 @@ void InstrumentationProfiler::exitInstrumentedOp(uint64_t streamId, auto circularLayoutConfig = std::dynamic_pointer_cast(config); if (!circularLayoutConfig) { - throw makeLogicError( - "Only circular layout parser is supported for now"); + throw makeLogicError("Only circular layout parser is supported for now"); } int64_t timeShiftCost = 0; diff --git a/third_party/proton/csrc/lib/Session/Session.cpp b/third_party/proton/csrc/lib/Session/Session.cpp index 9e7265430019..7ed8aac2941f 100644 --- a/third_party/proton/csrc/lib/Session/Session.cpp +++ b/third_party/proton/csrc/lib/Session/Session.cpp @@ -81,9 +81,8 @@ Profiler *SessionManager::validateAndSetProfilerMode(Profiler *profiler, for (auto &[id, session] : sessions) { if (session->getProfiler() == profiler && session->getProfiler()->getMode() != modeAndOptions) { - throw makeInvalidArgument( - "Cannot add a session with the same profiler " - "but a different mode than existing sessions"); + throw makeInvalidArgument("Cannot add a session with the same profiler " + "but a different mode than existing sessions"); } } return profiler->setMode(modeAndOptions); From cc3f145a6ea5e99756ae3054bf93c84b811c54f3 Mon Sep 17 00:00:00 2001 From: Jokeren Date: Wed, 29 Apr 2026 15:53:18 -0400 Subject: [PATCH 3/3] Remove redundant parseUnitId function and simplify unit ID parsing in InstrumentationProfiler --- .../InstrumentationProfiler.cpp | 23 ++----------------- 1 file changed, 2 insertions(+), 21 deletions(-) diff --git a/third_party/proton/csrc/lib/Profiler/Instrumentation/InstrumentationProfiler.cpp b/third_party/proton/csrc/lib/Profiler/Instrumentation/InstrumentationProfiler.cpp index aa14935602ab..2626b9e59137 100644 --- a/third_party/proton/csrc/lib/Profiler/Instrumentation/InstrumentationProfiler.cpp +++ b/third_party/proton/csrc/lib/Profiler/Instrumentation/InstrumentationProfiler.cpp @@ -73,26 +73,6 @@ void InstrumentationProfiler::doSetMode( } namespace { -uint32_t parseUnitId(const std::string &unitId) { - auto trimmedId = proton::trim(unitId); - size_t parsedEnd = 0; - unsigned long id = 0; - try { - id = std::stoul(trimmedId, &parsedEnd); - } catch (const std::invalid_argument &) { - throw makeInvalidArgument("Invalid sampling warp id: " + trimmedId); - } catch (const std::out_of_range &) { - throw makeOutOfRange("Sampling warp id out of range: " + trimmedId); - } - if (parsedEnd != trimmedId.size()) { - throw makeInvalidArgument("Invalid sampling warp id: " + trimmedId); - } - if (id > std::numeric_limits::max()) { - throw makeOutOfRange("Sampling warp id out of range: " + trimmedId); - } - return static_cast(id); -} - std::vector getUnitIdVector(const std::map &modeOptions, size_t totalUnits) { @@ -104,7 +84,8 @@ getUnitIdVector(const std::map &modeOptions, if (proton::trim(uintId).empty()) { continue; } - unitIdVector.push_back(parseUnitId(uintId)); + uint32_t id = std::stoi(uintId); + unitIdVector.push_back(id); } } if (unitIdVector.empty()) {