Skip to content

Commit

Permalink
[Support] Validate number of arguments passed to formatv() (#105745)
Browse files Browse the repository at this point in the history
Change formatv() to validate that the number of arguments passed matches
number of replacement fields in the format string, and that the replacement
indices do not contain holes.

To support cases where this cannot be guaranteed, introduce a formatv()
overload that allows disabling validation with a bool flag as its first argument.
  • Loading branch information
jurahul authored Aug 29, 2024
1 parent 9edd998 commit fc11020
Show file tree
Hide file tree
Showing 7 changed files with 205 additions and 58 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1401,7 +1401,10 @@ void StdLibraryFunctionsChecker::checkPostCall(const CallEvent &Call,
ErrnoNote =
llvm::formatv("After calling '{0}' {1}", FunctionName, ErrnoNote);
} else {
CaseNote = llvm::formatv(Case.getNote().str().c_str(), FunctionName);
// Disable formatv() validation as the case note may not always have the
// {0} placeholder for function name.
CaseNote =
llvm::formatv(false, Case.getNote().str().c_str(), FunctionName);
}
const SVal RV = Call.getReturnValue();

Expand Down
1 change: 1 addition & 0 deletions llvm/benchmarks/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ set(LLVM_LINK_COMPONENTS
add_benchmark(DummyYAML DummyYAML.cpp PARTIAL_SOURCES_INTENDED)
add_benchmark(xxhash xxhash.cpp PARTIAL_SOURCES_INTENDED)
add_benchmark(GetIntrinsicForClangBuiltin GetIntrinsicForClangBuiltin.cpp PARTIAL_SOURCES_INTENDED)
add_benchmark(FormatVariadicBM FormatVariadicBM.cpp PARTIAL_SOURCES_INTENDED)
63 changes: 63 additions & 0 deletions llvm/benchmarks/FormatVariadicBM.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
//===- FormatVariadicBM.cpp - formatv() benchmark ---------- --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "benchmark/benchmark.h"
#include "llvm/Support/FormatVariadic.h"
#include <algorithm>
#include <string>
#include <vector>

using namespace llvm;
using namespace std;

// Generate a list of format strings that have `NumReplacements` replacements
// by permuting the replacements and some literal text.
static vector<string> getFormatStrings(int NumReplacements) {
vector<string> Components;
for (int I = 0; I < NumReplacements; I++)
Components.push_back("{" + to_string(I) + "}");
// Intersperse these with some other literal text (_).
const string_view Literal = "____";
for (char C : Literal)
Components.push_back(string(1, C));

vector<string> Formats;
do {
string Concat;
for (const string &C : Components)
Concat += C;
Formats.emplace_back(Concat);
} while (next_permutation(Components.begin(), Components.end()));
return Formats;
}

// Generate the set of formats to exercise outside the benchmark code.
static const vector<vector<string>> Formats = {
getFormatStrings(1), getFormatStrings(2), getFormatStrings(3),
getFormatStrings(4), getFormatStrings(5),
};

// Benchmark formatv() for a variety of format strings and 1-5 replacements.
static void BM_FormatVariadic(benchmark::State &state) {
for (auto _ : state) {
for (const string &Fmt : Formats[0])
formatv(Fmt.c_str(), 1).str();
for (const string &Fmt : Formats[1])
formatv(Fmt.c_str(), 1, 2).str();
for (const string &Fmt : Formats[2])
formatv(Fmt.c_str(), 1, 2, 3).str();
for (const string &Fmt : Formats[3])
formatv(Fmt.c_str(), 1, 2, 3, 4).str();
for (const string &Fmt : Formats[4])
formatv(Fmt.c_str(), 1, 2, 3, 4, 5).str();
}
}

BENCHMARK(BM_FormatVariadic);

BENCHMARK_MAIN();
39 changes: 22 additions & 17 deletions llvm/include/llvm/Support/FormatVariadic.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,23 +67,20 @@ class formatv_object_base {
protected:
StringRef Fmt;
ArrayRef<support::detail::format_adapter *> Adapters;

static bool consumeFieldLayout(StringRef &Spec, AlignStyle &Where,
size_t &Align, char &Pad);

static std::pair<ReplacementItem, StringRef>
splitLiteralAndReplacement(StringRef Fmt);
bool Validate;

formatv_object_base(StringRef Fmt,
ArrayRef<support::detail::format_adapter *> Adapters)
: Fmt(Fmt), Adapters(Adapters) {}
ArrayRef<support::detail::format_adapter *> Adapters,
bool Validate)
: Fmt(Fmt), Adapters(Adapters), Validate(Validate) {}

formatv_object_base(formatv_object_base const &rhs) = delete;
formatv_object_base(formatv_object_base &&rhs) = default;

public:
void format(raw_ostream &S) const {
for (auto &R : parseFormatString(Fmt)) {
const auto Replacements = parseFormatString(Fmt, Adapters.size(), Validate);
for (const auto &R : Replacements) {
if (R.Type == ReplacementType::Empty)
continue;
if (R.Type == ReplacementType::Literal) {
Expand All @@ -101,9 +98,10 @@ class formatv_object_base {
Align.format(S, R.Options);
}
}
static SmallVector<ReplacementItem, 2> parseFormatString(StringRef Fmt);

static std::optional<ReplacementItem> parseReplacementItem(StringRef Spec);
// Parse and optionally validate format string (in debug builds).
static SmallVector<ReplacementItem, 2>
parseFormatString(StringRef Fmt, size_t NumArgs, bool Validate);

std::string str() const {
std::string Result;
Expand Down Expand Up @@ -149,8 +147,8 @@ template <typename Tuple> class formatv_object : public formatv_object_base {
};

public:
formatv_object(StringRef Fmt, Tuple &&Params)
: formatv_object_base(Fmt, ParameterPointers),
formatv_object(StringRef Fmt, Tuple &&Params, bool Validate)
: formatv_object_base(Fmt, ParameterPointers, Validate),
Parameters(std::move(Params)) {
ParameterPointers = std::apply(create_adapters(), Parameters);
}
Expand Down Expand Up @@ -247,15 +245,22 @@ template <typename Tuple> class formatv_object : public formatv_object_base {
// assertion. Otherwise, it will try to do something reasonable, but in general
// the details of what that is are undefined.
//

// formatv() with validation enable/disable controlled by the first argument.
template <typename... Ts>
inline auto formatv(const char *Fmt, Ts &&...Vals)
inline auto formatv(bool Validate, const char *Fmt, Ts &&...Vals)
-> formatv_object<decltype(std::make_tuple(
support::detail::build_format_adapter(std::forward<Ts>(Vals))...))> {
using ParamTuple = decltype(std::make_tuple(
support::detail::build_format_adapter(std::forward<Ts>(Vals))...));
return formatv_object<ParamTuple>(
Fmt, std::make_tuple(support::detail::build_format_adapter(
std::forward<Ts>(Vals))...));
auto Params = std::make_tuple(
support::detail::build_format_adapter(std::forward<Ts>(Vals))...);
return formatv_object<ParamTuple>(Fmt, std::move(Params), Validate);
}

// formatv() with validation enabled.
template <typename... Ts> inline auto formatv(const char *Fmt, Ts &&...Vals) {
return formatv<Ts...>(true, Fmt, std::forward<Ts>(Vals)...);
}

} // end namespace llvm
Expand Down
85 changes: 72 additions & 13 deletions llvm/lib/Support/FormatVariadic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ static std::optional<AlignStyle> translateLocChar(char C) {
LLVM_BUILTIN_UNREACHABLE;
}

bool formatv_object_base::consumeFieldLayout(StringRef &Spec, AlignStyle &Where,
size_t &Align, char &Pad) {
static bool consumeFieldLayout(StringRef &Spec, AlignStyle &Where,
size_t &Align, char &Pad) {
Where = AlignStyle::Right;
Align = 0;
Pad = ' ';
Expand All @@ -35,8 +35,7 @@ bool formatv_object_base::consumeFieldLayout(StringRef &Spec, AlignStyle &Where,

if (Spec.size() > 1) {
// A maximum of 2 characters at the beginning can be used for something
// other
// than the width.
// other than the width.
// If Spec[1] is a loc char, then Spec[0] is a pad char and Spec[2:...]
// contains the width.
// Otherwise, if Spec[0] is a loc char, then Spec[1:...] contains the width.
Expand All @@ -55,8 +54,7 @@ bool formatv_object_base::consumeFieldLayout(StringRef &Spec, AlignStyle &Where,
return !Failed;
}

std::optional<ReplacementItem>
formatv_object_base::parseReplacementItem(StringRef Spec) {
static std::optional<ReplacementItem> parseReplacementItem(StringRef Spec) {
StringRef RepString = Spec.trim("{}");

// If the replacement sequence does not start with a non-negative integer,
Expand All @@ -82,15 +80,14 @@ formatv_object_base::parseReplacementItem(StringRef Spec) {
RepString = StringRef();
}
RepString = RepString.trim();
if (!RepString.empty()) {
assert(false && "Unexpected characters found in replacement string!");
}
assert(RepString.empty() &&
"Unexpected characters found in replacement string!");

return ReplacementItem{Spec, Index, Align, Where, Pad, Options};
}

std::pair<ReplacementItem, StringRef>
formatv_object_base::splitLiteralAndReplacement(StringRef Fmt) {
static std::pair<ReplacementItem, StringRef>
splitLiteralAndReplacement(StringRef Fmt) {
while (!Fmt.empty()) {
// Everything up until the first brace is a literal.
if (Fmt.front() != '{') {
Expand Down Expand Up @@ -143,15 +140,77 @@ formatv_object_base::splitLiteralAndReplacement(StringRef Fmt) {
return std::make_pair(ReplacementItem{Fmt}, StringRef());
}

#ifndef NDEBUG
#define ENABLE_VALIDATION 1
#else
#define ENABLE_VALIDATION 0 // Conveniently enable validation in release mode.
#endif

SmallVector<ReplacementItem, 2>
formatv_object_base::parseFormatString(StringRef Fmt) {
formatv_object_base::parseFormatString(StringRef Fmt, size_t NumArgs,
bool Validate) {
SmallVector<ReplacementItem, 2> Replacements;
ReplacementItem I;

#if ENABLE_VALIDATION
const StringRef SavedFmtStr = Fmt;
size_t NumExpectedArgs = 0;
#endif

while (!Fmt.empty()) {
ReplacementItem I;
std::tie(I, Fmt) = splitLiteralAndReplacement(Fmt);
if (I.Type != ReplacementType::Empty)
Replacements.push_back(I);
#if ENABLE_VALIDATION
if (I.Type == ReplacementType::Format)
NumExpectedArgs = std::max(NumExpectedArgs, I.Index + 1);
#endif
}

#if ENABLE_VALIDATION
if (!Validate)
return Replacements;

// Perform additional validation. Verify that the number of arguments matches
// the number of replacement indices and that there are no holes in the
// replacement indices.

// When validation fails, return an array of replacement items that
// will print an error message as the outout of this formatv() (used when
// validation is enabled in release mode).
auto getErrorReplacements = [SavedFmtStr](StringLiteral ErrorMsg) {
return SmallVector<ReplacementItem, 2>{
ReplacementItem("Invalid formatv() call: "), ReplacementItem(ErrorMsg),
ReplacementItem(" for format string: "), ReplacementItem(SavedFmtStr)};
};

if (NumExpectedArgs != NumArgs) {
errs() << formatv(
"Expected {0} Args, but got {1} for format string '{2}'\n",
NumExpectedArgs, NumArgs, SavedFmtStr);
assert(0 && "Invalid formatv() call");
return getErrorReplacements("Unexpected number of arguments");
}

// Find the number of unique indices seen. All replacement indices
// are < NumExpectedArgs.
SmallVector<bool> Indices(NumExpectedArgs);
size_t Count = 0;
for (const ReplacementItem &I : Replacements) {
if (I.Type != ReplacementType::Format || Indices[I.Index])
continue;
Indices[I.Index] = true;
++Count;
}

if (Count != NumExpectedArgs) {
errs() << formatv(
"Replacement field indices cannot have holes for format string '{0}'\n",
SavedFmtStr);
assert(0 && "Invalid format string");
return getErrorReplacements("Replacement indices have holes");
}
#endif // ENABLE_VALIDATION
return Replacements;
}

Expand Down
Loading

0 comments on commit fc11020

Please sign in to comment.