Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion third_party/proton/common/include/TraceDataIO/ByteSpan.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

namespace proton {

class BufferException : public std::runtime_error {
class BufferException : public std::out_of_range {
public:
explicit BufferException(const std::string &message);
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define PROTON_COMMON_ENTRY_DECODER_H_

#include "ByteSpan.h"
#include "Utility/Errors.h"
#include <cstdint>
#include <iostream>
#include <memory>
Expand All @@ -11,7 +12,7 @@ namespace proton {
class EntryBase;

template <typename EntryT> void decodeFn(ByteSpan &buffer, EntryT &entry) {
throw std::runtime_error("No decoder function is implemented");
throw makeLogicError("No decoder function is implemented");
}

class EntryDecoder {
Expand Down
6 changes: 3 additions & 3 deletions third_party/proton/common/lib/TraceDataIO/ByteSpan.cpp
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
#include "TraceDataIO/ByteSpan.h"
#include "Utility/Errors.h"

using namespace proton;

ByteSpan::ByteSpan(const uint8_t *data, size_t size)
: dataPtr(data), dataSize(size), pos(0) {
if (data == nullptr && size > 0) {
throw std::invalid_argument(
"Data pointer cannot be null for non-zero size");
throw makeInvalidArgument("Data pointer cannot be null for non-zero size");
}
}

Expand Down Expand Up @@ -74,4 +74,4 @@ void ByteSpan::seek(size_t position) {
}

BufferException::BufferException(const std::string &message)
: std::runtime_error(message) {}
: std::out_of_range(prefixErrorMessage(message)) {}
3 changes: 2 additions & 1 deletion third_party/proton/common/lib/TraceDataIO/Parser.cpp
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
#include "TraceDataIO/Parser.h"
#include "Utility/Errors.h"

using namespace proton;

ParserException::ParserException(const std::string &msg, ExceptionSeverity sev)
: std::runtime_error(msg), severity(sev) {}
: std::runtime_error(prefixErrorMessage(msg)), severity(sev) {}

ParserBase::ParserBase(ByteSpan &buffer, const ParserConfig &config)
: buffer(buffer), config(config) {}
Expand Down
15 changes: 8 additions & 7 deletions third_party/proton/csrc/include/Data/Metric.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define PROTON_DATA_METRIC_H_

#include "Runtime/Runtime.h"
#include "Utility/Errors.h"
#include "Utility/String.h"
#include "Utility/Traits.h"
#include <atomic>
Expand Down Expand Up @@ -78,7 +79,7 @@ class Metric {
void updateValue(int valueId, MetricValueType value) {
// Enforce type consistency: once a valueId has a type, it must not change.
if (values[valueId].index() != value.index()) {
throw std::runtime_error(
throw makeInvalidArgument(
std::string("Metric value type mismatch for valueId ") +
std::to_string(valueId) + " (" + getValueName(valueId) + ")" +
": current=" + getTypeNameForIndex(values[valueId].index()) +
Expand All @@ -101,8 +102,8 @@ class Metric {
std::is_arithmetic_v<
typename CurrentType::value_type>) {
if (currentValue.size() != otherValue.size()) {
throw std::runtime_error(
std::string("[PROTON] Vector metric size mismatch for "
throw makeInvalidArgument(
std::string("Vector metric size mismatch for "
"valueId ") +
std::to_string(valueId) + " (" + getValueName(valueId) +
"): current=" + std::to_string(currentValue.size()) +
Expand All @@ -112,8 +113,8 @@ class Metric {
currentValue[i] += otherValue[i];
}
} else {
throw std::runtime_error(
std::string("[PROTON] Metric aggregation not supported for "
throw makeLogicError(
std::string("Metric aggregation not supported for "
"valueId ") +
std::to_string(valueId) + " (" + getValueName(valueId) +
"): type=" + getTypeNameForIndex(values[valueId].index()));
Expand Down Expand Up @@ -487,8 +488,8 @@ class MetricBuffer {
std::shared_lock<std::shared_mutex> lock(metricDescriptorMutex);
auto it = metricDescriptors.find(id);
if (it == metricDescriptors.end()) {
throw std::runtime_error("[PROTON] MetricBuffer: unknown metric id: " +
std::to_string(id));
throw makeOutOfRange("MetricBuffer: unknown metric id: " +
std::to_string(id));
}
return it->second;
}
Expand Down
5 changes: 3 additions & 2 deletions third_party/proton/csrc/include/Data/PhaseStore.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#ifndef PROTON_DATA_PHASE_STORE_H_
#define PROTON_DATA_PHASE_STORE_H_

#include "Utility/Errors.h"

#include <cstddef>
#include <map>
#include <memory>
Expand Down Expand Up @@ -96,8 +98,7 @@ template <typename T> class PhaseStore final : public PhaseStoreBase {
std::shared_lock<std::shared_mutex> lock(phasesMutex);
auto it = phases.find(phase);
if (it == phases.end() || !it->second) {
throw std::runtime_error("[PROTON] Phase " + std::to_string(phase) +
" has no data.");
throw makeOutOfRange("Phase " + std::to_string(phase) + " has no data.");
}
return it->second;
}
Expand Down
12 changes: 6 additions & 6 deletions third_party/proton/csrc/include/Driver/Dispatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <dlfcn.h>

#include "Utility/Env.h"
#include "Utility/Errors.h"
#include <stdexcept>
#include <string>

Expand Down Expand Up @@ -113,15 +114,14 @@ template <typename ExternLib> class Dispatch {
}
}
if (*lib == nullptr) {
throw std::runtime_error("Could not load `" + std::string(name) + "`");
throw makeRuntimeError("Could not load `" + std::string(name) + "`");
}
}

static void check(typename ExternLib::RetType ret, const char *functionName) {
if (ret != ExternLib::success) {
throw std::runtime_error("Failed to execute " +
std::string(functionName) + " with error " +
std::to_string(ret));
throw makeRuntimeError("Failed to execute " + std::string(functionName) +
" with error " + std::to_string(ret));
}
}

Expand All @@ -132,8 +132,8 @@ template <typename ExternLib> class Dispatch {
if (handler == nullptr) {
handler = reinterpret_cast<FnT>(dlsym(ExternLib::lib, functionName));
if (handler == nullptr) {
throw std::runtime_error("Failed to load " +
std::string(ExternLib::name));
throw makeRuntimeError("Failed to load " +
std::string(ExternLib::name));
}
}
auto ret = handler(args...);
Expand Down
31 changes: 30 additions & 1 deletion third_party/proton/csrc/include/Utility/Errors.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,41 @@
#define PROTON_UTILITY_ERRORS_H_

#include <stdexcept>
#include <string>
#include <utility>

namespace proton {

inline constexpr const char *kProtonErrorPrefix = "[PROTON] ";

inline std::string prefixErrorMessage(std::string message) {
return std::string(kProtonErrorPrefix) + message;
}

inline std::runtime_error makeRuntimeError(std::string message) {
return std::runtime_error(prefixErrorMessage(std::move(message)));
}

inline std::invalid_argument makeInvalidArgument(std::string message) {
return std::invalid_argument(prefixErrorMessage(std::move(message)));
}

inline std::out_of_range makeOutOfRange(std::string message) {
return std::out_of_range(prefixErrorMessage(std::move(message)));
}

inline std::length_error makeLengthError(std::string message) {
return std::length_error(prefixErrorMessage(std::move(message)));
}

inline std::logic_error makeLogicError(std::string message) {
return std::logic_error(prefixErrorMessage(std::move(message)));
}

class NotImplemented : public std::logic_error {
public:
NotImplemented() : std::logic_error("Not yet implemented") {};
NotImplemented()
: std::logic_error(prefixErrorMessage("Not yet implemented")) {}
};

} // namespace proton
Expand Down
5 changes: 3 additions & 2 deletions third_party/proton/csrc/lib/Context/Shadow.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "Context/Shadow.h"
#include "Utility/Errors.h"

#include <stdexcept>
#include <thread>
Expand Down Expand Up @@ -30,10 +31,10 @@ size_t ShadowContextSource::getDepth() {

void ShadowContextSource::exitScope(const Scope &scope) {
if (threadContextStack[this].empty()) {
throw std::runtime_error("Context stack is empty");
throw makeLogicError("Context stack is empty");
}
if (threadContextStack[this].back() != scope) {
throw std::runtime_error("Context stack is not balanced");
throw makeLogicError("Context stack is not balanced");
}
threadContextStack[this].pop_back();
}
Expand Down
7 changes: 4 additions & 3 deletions third_party/proton/csrc/lib/Data/Data.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "Data/Data.h"
#include "Utility/Errors.h"
#include "Utility/String.h"

#include <fstream>
Expand Down Expand Up @@ -163,7 +164,7 @@ OutputFormat parseOutputFormat(const std::string &outputFormat) {
} else if (toLower(outputFormat) == "chrome_trace") {
return OutputFormat::ChromeTrace;
} else {
throw std::runtime_error("Unknown output format: " + outputFormat);
throw makeInvalidArgument("Unknown output format: " + outputFormat);
}
}

Expand All @@ -175,8 +176,8 @@ const std::string outputFormatToString(OutputFormat outputFormat) {
} else if (outputFormat == OutputFormat::ChromeTrace) {
return "chrome_trace";
}
throw std::runtime_error("Unknown output format: " +
std::to_string(static_cast<int>(outputFormat)));
throw makeInvalidArgument("Unknown output format: " +
std::to_string(static_cast<int>(outputFormat)));
}

} // namespace proton
37 changes: 18 additions & 19 deletions third_party/proton/csrc/lib/Data/Metric.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "Data/Metric.h"
#include "Utility/Errors.h"

#include <cstring>
#include <stdexcept>
Expand Down Expand Up @@ -46,16 +47,15 @@ MetricBuffer::getOrCreateMetricDescriptor(const std::string &name,
if (nameIt != metricNameToId.end()) {
auto &descriptor = metricDescriptors.at(nameIt->second);
if (descriptor.typeIndex != typeIndex) {
throw std::runtime_error(
"[PROTON] MetricBuffer: type mismatch for metric " + name +
throw makeInvalidArgument(
"MetricBuffer: type mismatch for metric " + name +
": current=" + getTypeNameForIndex(descriptor.typeIndex) +
", new=" + getTypeNameForIndex(typeIndex));
}
if (descriptor.size != size) {
throw std::runtime_error(
"[PROTON] MetricBuffer: size mismatch for metric " + name +
": current=" + std::to_string(descriptor.size) +
", new=" + std::to_string(size));
throw makeInvalidArgument(
"MetricBuffer: size mismatch for metric " + name + ": current=" +
std::to_string(descriptor.size) + ", new=" + std::to_string(size));
}
return descriptor;
}
Expand All @@ -68,16 +68,16 @@ MetricBuffer::getOrCreateMetricDescriptor(const std::string &name,
if (nameIt != metricNameToId.end()) {
auto &descriptor = metricDescriptors.at(nameIt->second);
if (descriptor.typeIndex != typeIndex) {
throw std::runtime_error(
"[PROTON] MetricBuffer: type mismatch for metric " + name +
throw makeInvalidArgument(
"MetricBuffer: type mismatch for metric " + name +
": current=" + getTypeNameForIndex(descriptor.typeIndex) +
", new=" + getTypeNameForIndex(typeIndex));
}
if (descriptor.size != size) {
throw std::runtime_error(
"[PROTON] MetricBuffer: size mismatch for metric " + name +
": current=" + std::to_string(descriptor.size) +
", new=" + std::to_string(size));
throw makeInvalidArgument("MetricBuffer: size mismatch for metric " +
name +
": current=" + std::to_string(descriptor.size) +
", new=" + std::to_string(size));
}
return descriptor;
}
Expand Down Expand Up @@ -124,9 +124,8 @@ collectTensorMetrics(Runtime *runtime,
}
tensorMetricsHost[name] = std::move(values);
} else {
throw std::runtime_error(
"[PROTON] Unsupported tensor metric type index: " +
std::to_string(tensorMetric.typeIndex));
throw makeInvalidArgument("Unsupported tensor metric type index: " +
std::to_string(tensorMetric.typeIndex));
}
}
return tensorMetricsHost;
Expand Down Expand Up @@ -163,11 +162,11 @@ void MetricBuffer::queue(uint64_t seqId, MetricValueType scalarMetric,
[](auto &&value) -> uint64_t {
using T = std::decay_t<decltype(value)>;
if constexpr (std::is_same_v<T, std::string>) {
throw std::runtime_error(
"[PROTON] String metrics are not supported in MetricBuffer");
throw makeInvalidArgument(
"String metrics are not supported in MetricBuffer");
} else if constexpr (is_std_vector_v<T>) {
throw std::runtime_error(
"[PROTON] Vector metrics are not supported in MetricBuffer");
throw makeInvalidArgument(
"Vector metrics are not supported in MetricBuffer");
} else {
static_assert(sizeof(T) == sizeof(uint64_t),
"MetricValueType alternative must be 8 bytes");
Expand Down
13 changes: 7 additions & 6 deletions third_party/proton/csrc/lib/Data/TraceData.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include "Context/Context.h"
#include "Profiler/Graph.h"
#include "TraceDataIO/TraceWriter.h"
#include "Utility/Errors.h"
#include "Utility/MsgPackWriter.h"
#include "nlohmann/json.hpp"

Expand Down Expand Up @@ -125,7 +126,7 @@ class TraceData::Trace {
std::vector<const TraceContext *> reversedContexts;
auto it = traceContextMap.find(contextId);
if (it == traceContextMap.end()) {
throw std::runtime_error("Context not found");
throw makeOutOfRange("Context not found");
}
auto *context = &it->second;
reversedContexts.push_back(context);
Expand Down Expand Up @@ -157,7 +158,7 @@ class TraceData::Trace {
Event &getEvent(size_t eventId) {
auto it = traceEvents.find(eventId);
if (it == traceEvents.end()) {
throw std::runtime_error("Event not found");
throw makeOutOfRange("Event not found");
}
return it->second;
}
Expand Down Expand Up @@ -697,7 +698,7 @@ void reconstructGraphScopeEvents(
continue;
}
if (!seenCaptureTag) {
throw std::runtime_error("Invalid graph contexts without capture tag");
throw makeLogicError("Invalid graph contexts without capture tag");
}
graphContexts.pop_back(); // Remove kernel name context
auto startTimeNs = std::get<uint64_t>(
Expand Down Expand Up @@ -883,7 +884,7 @@ void dumpCpuToGpuFlowEvents(
auto launchEventIt =
launchEventIdToCpuScopeEvent.find(event.launchEventId);
if (launchEventIt == launchEventIdToCpuScopeEvent.end()) {
throw std::runtime_error(
throw makeOutOfRange(
"Cannot find CPU scope event for kernel launch event id: " +
std::to_string(event.launchEventId));
}
Expand Down Expand Up @@ -1064,7 +1065,7 @@ void TraceData::dumpChromeTrace(std::ostream &os, size_t phase) const {
/*isGraphLinked=*/true);
}
if (hasKernelMetrics && hasCycleMetrics) {
throw std::runtime_error("only one active metric type is supported");
throw makeLogicError("only one active metric type is supported");
}
}
}
Expand Down Expand Up @@ -1113,7 +1114,7 @@ void TraceData::doDump(std::ostream &os, OutputFormat outputFormat,
if (outputFormat == OutputFormat::ChromeTrace) {
dumpChromeTrace(os, phase);
} else {
throw std::logic_error("Output format not supported");
throw makeInvalidArgument("Output format not supported");
}
}

Expand Down
Loading
Loading