Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions onnxruntime/core/framework/tensorprotoutils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,27 @@ Status TensorProtoWithExternalDataToTensorProto(
return Status::OK();
}

Status ValidateExternalDataPath(const std::filesystem::path& base_dir,
const std::filesystem::path& location) {
// Reject absolute paths
ORT_RETURN_IF(location.is_absolute(),
"Absolute paths not allowed for external data location");
if (!base_dir.empty()) {
// Resolve and verify the path stays within model directory
auto base_canonical = std::filesystem::weakly_canonical(base_dir);
// If the symlink exists, it resolves to the target path;
// so if the symllink is outside the directory it would be caught here.
auto resolved = std::filesystem::weakly_canonical(base_dir / location);
// Check that resolved path starts with base directory
auto [base_end, resolved_it] = std::mismatch(
base_canonical.begin(), base_canonical.end(),
resolved.begin(), resolved.end());
ORT_RETURN_IF(base_end != base_canonical.end(),
"External data path: ", location, " escapes model directory: ", base_dir);
}
return Status::OK();
}

Status GetExternalDataInfo(const ONNX_NAMESPACE::TensorProto& tensor_proto,
const std::filesystem::path& tensor_proto_dir,
std::basic_string<ORTCHAR_T>& external_file_path,
Expand Down
12 changes: 12 additions & 0 deletions onnxruntime/core/framework/tensorprotoutils.h
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,18 @@ Status TensorProtoWithExternalDataToTensorProto(
const std::filesystem::path& model_path,
ONNX_NAMESPACE::TensorProto& new_tensor_proto);

/// <summary>
/// The functions will make sure the 'location' specified in the external data is under the 'base_dir'.
/// If the `base_dir` is empty, the function only ensures that `location` is not an absolute path.
/// </summary>
/// <param name="base_dir">model location directory</param>
/// <param name="location">location is a string retrieved from TensorProto external data that is not
/// an in-memory tag</param>
/// <returns>The function will fail if the resolved full path is not under the model directory
/// or one of the subdirectories</returns>
Status ValidateExternalDataPath(const std::filesystem::path& base_dir,
const std::filesystem::path& location);

#endif // !defined(SHARED_PROVIDER)

inline bool HasType(const ONNX_NAMESPACE::AttributeProto& at_proto) {
Expand Down
20 changes: 17 additions & 3 deletions onnxruntime/core/graph/graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3724,9 +3724,14 @@ Status Graph::ConvertInitializersIntoOrtValues() {
std::vector<Graph*> all_subgraphs;
FindAllSubgraphs(all_subgraphs);

const auto& model_path = GetModel().ModelPath();
PathString model_dir;
if (!model_path.empty()) {
ORT_RETURN_IF_ERROR(GetDirNameFromFilePath(model_path, model_dir));
}

auto put_weights_maybe_in_memory_func = [&](Graph& graph) -> Status {
// if we have any initializers that are not in memory, put them there.
const auto& model_path = graph.ModelPath();
auto& graph_proto = *graph.graph_proto_;
for (int i = 0, lim = graph_proto.initializer_size(); i < lim; ++i) {
auto& tensor_proto = *graph_proto.mutable_initializer(i);
Expand All @@ -3744,9 +3749,18 @@ Status Graph::ConvertInitializersIntoOrtValues() {
"The model contains initializers with arbitrary in-memory references.",
"This is an invalid model.");
}
} else {
// Validate external data location
std::unique_ptr<onnxruntime::ExternalDataInfo> external_data_info;
ORT_RETURN_IF_ERROR(onnxruntime::ExternalDataInfo::Create(tensor_proto.external_data(), external_data_info));
const auto& location = external_data_info->GetRelPath();
auto st = utils::ValidateExternalDataPath(model_dir, location);
if (!st.IsOK()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
"External data path validation failed for initializer: ", tensor_proto.name(),
". Error: ", st.ErrorMessage());
}
}
// ignore data on disk, that will be loaded either by EP or at session_state finalize
// ignore valid in-memory references
continue;
}

Expand Down
5 changes: 5 additions & 0 deletions onnxruntime/core/providers/shared_library/provider_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,11 @@ inline bool HasExternalDataInMemory(const ONNX_NAMESPACE::TensorProto& ten_proto
return g_host->Utils__HasExternalDataInMemory(ten_proto);
}

inline Status ValidateExternalDataPath(const std::filesystem::path& base_dir,
const std::filesystem::path& location) {
return g_host->Utils__ValidateExternalDataPath(base_dir, location);
}

} // namespace utils

namespace graph_utils {
Expand Down
13 changes: 13 additions & 0 deletions onnxruntime/core/providers/shared_library/provider_interfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ struct ProviderHost;
struct ProviderHostCPU;

class ExternalDataInfo;

class PhiloxGenerator;
using ProviderType = const std::string&;
class RandomGenerator;
Expand Down Expand Up @@ -999,6 +1000,9 @@ struct ProviderHost {

virtual bool Utils__HasExternalDataInMemory(const ONNX_NAMESPACE::TensorProto& ten_proto) = 0;

virtual Status Utils__ValidateExternalDataPath(const std::filesystem::path& base_path,
const std::filesystem::path& location) = 0;

// Model
virtual std::unique_ptr<Model> Model__construct(ONNX_NAMESPACE::ModelProto&& model_proto, const PathString& model_path,
const IOnnxRuntimeOpSchemaRegistryList* local_registries,
Expand Down Expand Up @@ -1136,6 +1140,15 @@ struct ProviderHost {

virtual Status GraphUtils__ConvertInMemoryDataToInline(Graph& graph, const std::string& name) = 0;

// ExternalDataInfo
virtual void ExternalDataInfo__operator_delete(ExternalDataInfo*) = 0;
virtual const PathString& ExternalDataInfo__GetRelPath(const ExternalDataInfo*) const = 0;
virtual int64_t ExternalDataInfo__GetOffset(const ExternalDataInfo*) const = 0;
virtual size_t ExternalDataInfo__GetLength(const ExternalDataInfo*) const = 0;
virtual const std::string& ExternalDataInfo__GetChecksum(const ExternalDataInfo*) const = 0;
virtual Status ExternalDataInfo__Create(const ONNX_NAMESPACE::StringStringEntryProtos& input,
std::unique_ptr<ExternalDataInfo>& out) = 0;

// Initializer
virtual Initializer* Initializer__constructor(ONNX_NAMESPACE::TensorProto_DataType data_type,
std::string_view name,
Expand Down
33 changes: 33 additions & 0 deletions onnxruntime/core/providers/shared_library/provider_wrappedtypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -1201,6 +1201,39 @@ struct ConstGraphNodes final {
PROVIDER_DISALLOW_ALL(ConstGraphNodes)
};

class ExternalDataInfo {
public:
static void operator delete(void* p) {
g_host->ExternalDataInfo__operator_delete(reinterpret_cast<ExternalDataInfo*>(p));
}

const PathString& GetRelPath() const {
return g_host->ExternalDataInfo__GetRelPath(this);
}

int64_t GetOffset() const {
return g_host->ExternalDataInfo__GetOffset(this);
}

size_t GetLength() const {
return g_host->ExternalDataInfo__GetLength(this);
}

const std::string& GetChecksum() const {
return g_host->ExternalDataInfo__GetChecksum(this);
}

static Status Create(
const ONNX_NAMESPACE::StringStringEntryProtos& input,
std::unique_ptr<ExternalDataInfo>& out) {
return g_host->ExternalDataInfo__Create(input, out);
}

ExternalDataInfo() = delete;
ExternalDataInfo(const ExternalDataInfo&) = delete;
ExternalDataInfo& operator=(const ExternalDataInfo& v) = delete;
};

class Initializer {
public:
Initializer(ONNX_NAMESPACE::TensorProto_DataType data_type,
Expand Down
25 changes: 25 additions & 0 deletions onnxruntime/core/session/provider_bridge_ort.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "core/framework/run_options.h"
#include "core/framework/sparse_utils.h"
#include "core/framework/tensorprotoutils.h"
#include "core/framework/tensor_external_data_info.h"
#include "core/framework/TensorSeq.h"
#include "core/graph/constants.h"
#include "core/graph/graph_proto_serializer.h"
Expand Down Expand Up @@ -1281,6 +1282,11 @@ struct ProviderHostImpl : ProviderHost {
return onnxruntime::utils::HasExternalDataInMemory(ten_proto);
}

Status Utils__ValidateExternalDataPath(const std::filesystem::path& base_path,
const std::filesystem::path& location) override {
return onnxruntime::utils::ValidateExternalDataPath(base_path, location);
}

// Model (wrapped)
std::unique_ptr<Model> Model__construct(ONNX_NAMESPACE::ModelProto&& model_proto, const PathString& model_path,
const IOnnxRuntimeOpSchemaRegistryList* local_registries,
Expand Down Expand Up @@ -1487,6 +1493,25 @@ struct ProviderHostImpl : ProviderHost {
graph_utils::MakeInitializerCopyIfNotExist(src_graph, dst_graph, name, load_in_memory);
}

// ExternalDataInfo (wrapped)
void ExternalDataInfo__operator_delete(ExternalDataInfo* p) override { delete p; }
const PathString& ExternalDataInfo__GetRelPath(const ExternalDataInfo* p) const override {
return p->GetRelPath();
}
int64_t ExternalDataInfo__GetOffset(const ExternalDataInfo* p) const override {
return narrow<int64_t>(p->GetOffset());
}
size_t ExternalDataInfo__GetLength(const ExternalDataInfo* p) const override {
return p->GetLength();
}
const std::string& ExternalDataInfo__GetChecksum(const ExternalDataInfo* p) const override {
return p->GetChecksum();
}
Status ExternalDataInfo__Create(const ONNX_NAMESPACE::StringStringEntryProtos& input,
std::unique_ptr<ExternalDataInfo>& out) override {
return ExternalDataInfo::Create(input, out);
}

// Initializer (wrapped)
Initializer* Initializer__constructor(ONNX_NAMESPACE::TensorProto_DataType data_type,
std::string_view name,
Expand Down
84 changes: 84 additions & 0 deletions onnxruntime/test/framework/tensorutils_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

#include <cstdint>
#include <limits>
#include <fstream>

#include "gtest/gtest.h"
#include "gmock/gmock.h"
Expand Down Expand Up @@ -502,5 +503,88 @@ TEST(TensorProtoUtilsTest, ConstantTensorProtoWithExternalData) {
TestConstantNodeConversionWithExternalData<float>(TensorProto_DataType_FLOAT);
TestConstantNodeConversionWithExternalData<double>(TensorProto_DataType_DOUBLE);
}

// Test fixture for creating temporary directories and files for path validation tests.
class PathValidationTest : public ::testing::Test {
protected:
void SetUp() override {
// Create a temporary directory for the tests.
base_dir_ = std::filesystem::temp_directory_path() / "PathValidationTest";
outside_dir_ = std::filesystem::temp_directory_path() / "outside";
std::filesystem::create_directories(base_dir_);
std::filesystem::create_directories(outside_dir_);
}

void TearDown() override {
// Clean up the temporary directory.
std::filesystem::remove_all(base_dir_);
std::filesystem::remove_all(outside_dir_);
}

std::filesystem::path base_dir_;
std::filesystem::path outside_dir_;
};

// Test cases for ValidateExternalDataPath.
TEST_F(PathValidationTest, ValidateExternalDataPath) {
// Valid relative path.
ASSERT_STATUS_OK(utils::ValidateExternalDataPath(base_dir_, "data.bin"));

// Empty base directory.
ASSERT_STATUS_OK(utils::ValidateExternalDataPath("", "data.bin"));

// Empty location.
// Only validate it is not an absolute path.
ASSERT_TRUE(utils::ValidateExternalDataPath(base_dir_, "").IsOK());

// Path with ".." that escapes the base directory.
ASSERT_FALSE(utils::ValidateExternalDataPath(base_dir_, "../data.bin").IsOK());

// Absolute path.
#ifdef _WIN32
ASSERT_FALSE(utils::ValidateExternalDataPath(base_dir_, "C:\\data.bin").IsOK());
ASSERT_FALSE(utils::ValidateExternalDataPath("", "C:\\data.bin").IsOK());
#else
ASSERT_FALSE(utils::ValidateExternalDataPath(base_dir_, "/data.bin").IsOK());
ASSERT_FALSE(utils::ValidateExternalDataPath("", "/data.bin").IsOK());
#endif // Absolute path.

// Windows vs Unix path separators.
ASSERT_STATUS_OK(utils::ValidateExternalDataPath(base_dir_, "sub/data.bin"));
ASSERT_STATUS_OK(utils::ValidateExternalDataPath(base_dir_, "sub\\data.bin"));

// Base directory does not exist.
ASSERT_STATUS_OK(utils::ValidateExternalDataPath("non_existent_dir", "data.bin"));
}

TEST_F(PathValidationTest, ValidateExternalDataPathWithSymlinkInside) {
// Symbolic link that points inside the base directory.
try {
auto target = base_dir_ / "target.bin";
std::ofstream{target};
auto link = base_dir_ / "link.bin";
std::filesystem::create_symlink(target, link);
} catch (const std::exception& e) {
GTEST_SKIP() << "Skipping symlink tests since symlink creation is not supported in this environment. Exception: "
<< e.what();
}
ASSERT_STATUS_OK(utils::ValidateExternalDataPath(base_dir_, "link.bin"));
}

TEST_F(PathValidationTest, ValidateExternalDataPathWithSymlinkOutside) {
// Symbolic link that points outside the base directory.
auto outside_target = outside_dir_ / "outside.bin";
try {
{
std::ofstream{outside_target};
auto outside_link = base_dir_ / "outside_link.bin";
std::filesystem::create_symlink(outside_target, outside_link);
}
} catch (const std::exception& e) {
GTEST_SKIP() << "Skipping symlink tests since symlink creation is not supported in this environment. Exception: " << e.what();
}
ASSERT_FALSE(utils::ValidateExternalDataPath(base_dir_, "outside_link.bin").IsOK());
}

} // namespace test
} // namespace onnxruntime
39 changes: 36 additions & 3 deletions onnxruntime/test/shared_lib/test_inference.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4762,9 +4762,9 @@ TEST(CApiTest, custom_cast) {
custom_op_domain, nullptr);
}

TEST(CApiTest, ModelWithMaliciousExternalDataShouldFailToLoad) {
TEST(CApiTest, ModelWithMaliciousExternalDataInMemoryShouldFailToLoad) {
// Attempt to create an ORT session with the malicious model
// This should fail due to the invalid external data reference
// This should fail due to the invalid external in-memory reference
constexpr const ORTCHAR_T* model_path = TSTR("testdata/test_evil_weights.onnx");

Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "test");
Expand All @@ -4785,7 +4785,7 @@ TEST(CApiTest, ModelWithMaliciousExternalDataShouldFailToLoad) {
}

// Verify that loading the model failed
EXPECT_TRUE(exception_thrown) << "Expected model loading to fail due to malicious external data";
EXPECT_TRUE(exception_thrown) << "Expected model loading to fail due to malicious in-memory data";

// Verify that the exception message indicates security or external data issues
EXPECT_TRUE(exception_message.find("in-memory") != std::string::npos ||
Expand All @@ -4794,3 +4794,36 @@ TEST(CApiTest, ModelWithMaliciousExternalDataShouldFailToLoad) {
exception_message.find("model") != std::string::npos)
<< "Exception message should indicate external data or security issue. Got: " << exception_message;
}

TEST(CApiTest, ModelWithExternalDataOutsideModelDirectoryShouldFailToLoad) {
// Attempt to create an ORT session with the malicious model
// This should fail due to the external file that is not under model directory structure
// i.e. ../../../../etc/passwd
constexpr const ORTCHAR_T* model_path = TSTR("testdata/test_arbitrary_external_file.onnx");

Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "test");
Ort::SessionOptions session_options;

bool exception_thrown = false;
std::string exception_message;

try {
// This should throw an exception due to malicious external data
Ort::Session session(env, model_path, session_options);
} catch (const Ort::Exception& e) {
exception_thrown = true;
exception_message = e.what();
} catch (const std::exception& e) {
exception_thrown = true;
exception_message = e.what();
}

// Verify that loading the model failed
EXPECT_TRUE(exception_thrown) << "Expected model loading to fail due to malicious external data";

// Verify that the exception message indicates security or external data issues
EXPECT_TRUE(exception_message.find("External data path escapes model directory") != std::string::npos ||
exception_message.find("invalid") != std::string::npos ||
exception_message.find("model") != std::string::npos)
<< "Exception message should indicate external data or security issue. Got: " << exception_message;
}
Binary file not shown.
Loading
Loading