Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/csharp/Config.cs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ public void RemoveModelData(string modelFilename)
Result.VerifySuccess(NativeMethods.OgaConfigRemoveModelData(_configHandle, StringUtils.ToUtf8(modelFilename)));
}

public void Overlay(string json)
{
Result.VerifySuccess(NativeMethods.OgaConfigOverlay(_configHandle, StringUtils.ToUtf8(json)));
}

public void SetDecoderProviderOptionsHardwareDeviceType(string provider, string hardware_device_type)
{
Result.VerifySuccess(NativeMethods.OgaConfigSetDecoderProviderOptionsHardwareDeviceType(_configHandle, StringUtils.ToUtf8(provider), StringUtils.ToUtf8(hardware_device_type)));
Expand Down
3 changes: 3 additions & 0 deletions src/csharp/NativeMethods.cs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ internal class NativeLib
[DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
public static extern IntPtr /* OgaResult* */ OgaConfigRemoveModelData(IntPtr /* OgaConfig* */ config, byte[] /* const char* */ model_filename);

[DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
public static extern IntPtr /* OgaResult* */ OgaConfigOverlay(IntPtr /* OgaConfig* */ config, byte[] /* const char* */ json);

[DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
public static extern IntPtr /* OgaResult* */ OgaConfigSetDecoderProviderOptionsHardwareDeviceType(IntPtr /* OgaConfig* */ config, byte[] /* const char* */ provider_name, byte[] /* const char* */ hardware_device_type);

Expand Down
14 changes: 14 additions & 0 deletions src/java/src/main/java/ai/onnxruntime/genai/Config.java
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,18 @@ public void setProviderOption(String providerName, String optionKey, String opti
setProviderOption(nativeHandle, providerName, optionKey, optionValue);
}

/**
* Overlay JSON on top of the config file
*
* @param json The JSON string to overlay
*/
public void overlay(String json) {
if (nativeHandle == 0) {
throw new IllegalStateException("Instance has been freed and is invalid");
}
overlay(nativeHandle, json);
}

@Override
public void close() {
if (nativeHandle != 0) {
Expand Down Expand Up @@ -85,4 +97,6 @@ long nativeHandle() {

private native void setProviderOption(
long configHandle, String providerName, String optionKey, String optionValue);

private native void overlay(long configHandle, String json);
}
8 changes: 8 additions & 0 deletions src/java/src/main/native/ai_onnxruntime_genai_Config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,11 @@ Java_ai_onnxruntime_genai_Config_setProviderOption(JNIEnv* env, jobject thiz, jl

ThrowIfError(env, OgaConfigSetProviderOption(config, c_provider_name, c_option_key, c_option_value));
}

JNIEXPORT void JNICALL
Java_ai_onnxruntime_genai_Config_overlay(JNIEnv* env, jobject thiz, jlong native_handle, jstring json) {
CString c_json{env, json};
OgaConfig* config = reinterpret_cast<OgaConfig*>(native_handle);

ThrowIfError(env, OgaConfigOverlay(config, c_json));
}
65 changes: 65 additions & 0 deletions src/objectivec/include/ort_genai_objc.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,63 @@ typedef NS_ENUM(NSInteger, OGAElementType) {
OGAElementTypeUint64, // maps to c type uint64_t
};

/**
* An ORT GenAI config.
*/
@interface OGAConfig : NSObject

- (instancetype)init NS_UNAVAILABLE;

/**
* Creates a config.
*
* @param path The path to the ONNX GenAI model folder.
* @return The instance, or nil if an error occurs.
*/
- (nullable instancetype)initWithPath:(NSString*)path
error:(NSError**)error NS_DESIGNATED_INITIALIZER;

/**
* Clear the list of providers in the given config
*
* @param error Optional error information set if an error occurs.
*/
- (BOOL)clearProvidersWithError:(NSError**)error;

/**
* Add the provider at the end of the list of providers in the given config if it doesn't already exist.
* If it already exists, do nothing.
*
* @param provider The provider to set on the config.
* @param error Optional error information set if an error occurs.
*/
- (BOOL)appendProvider:(NSString*)provider
error:(NSError**)error;

/**
* Set a provider option.
*
* @param provider The provider to set the option on
* @param key The key of the option to set
* @param value The value of the option to set
* @param error Optional error information set if an error occurs.
*/
- (BOOL)setProviderOption:(NSString*)provider
key:(NSString*)key
value:(NSString*)value
error:(NSError**)error;

/**
* Overlay JSON on top of config file
*
* @param json The JSON to overlay on the config.
* @param error Optional error information set if an error occurs.
*/
- (BOOL)overlay:(NSString*)json
error:(NSError**)error;

@end

/**
* An ORT GenAI model.
*/
Expand All @@ -63,6 +120,14 @@ typedef NS_ENUM(NSInteger, OGAElementType) {
- (nullable instancetype)initWithPath:(NSString*)path
error:(NSError**)error NS_DESIGNATED_INITIALIZER;

/**
* Creates a model.
*
* @param config The OGAConfig object
* @return The instance, or nil if an error occurs.
*/
- (nullable instancetype)initWithConfig:(OGAConfig*)config
error:(NSError**)error NS_DESIGNATED_INITIALIZER;
@end

/**
Expand Down
60 changes: 60 additions & 0 deletions src/objectivec/oga_config.mm
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#import "error_utils.h"
#import "oga_internal.h"
#import "ort_genai_objc.h"

@implementation OGAConfig {
std::unique_ptr<OgaConfig> _config;
}

- (nullable instancetype)initWithPath:(NSString*)path error:(NSError**)error {
if ((self = [super init]) == nil) {
return nil;
}

try {
_config = OgaConfig::Create([path UTF8String]);
return self;
}
OGA_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}

- (BOOL)clearProvidersWithError:(NSError**)error {
try {
_config->ClearProviders();
return YES;
}
OGA_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error)
}

- (BOOL)appendProvider:(NSString*)provider error:(NSError**)error {
try {
_config->AppendProvider([provider UTF8String]);
return YES;
}
OGA_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error)
}

- (BOOL)setProviderOption:(NSString*)provider key:(NSString*)key value:(NSString*)value error:(NSError**)error {
try {
_config->SetProviderOption([provider UTF8String], [key UTF8String], [value UTF8String]);
return YES;
}
OGA_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error)
}

- (BOOL)overlay:(NSString*)json error:(NSError**)error {
try {
_config->Overlay([json UTF8String]);
return YES;
}
OGA_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error)
}

- (const OgaConfig&)CXXAPIOgaConfig {
return *(_config.get());
}

@end
6 changes: 6 additions & 0 deletions src/objectivec/oga_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@

NS_ASSUME_NONNULL_BEGIN

@interface OGAConfig ()

- (OgaConfig&)CXXAPIOgaConfig;

@end

@interface OGAModel ()

- (const OgaModel&)CXXAPIOgaModel;
Expand Down
12 changes: 12 additions & 0 deletions src/objectivec/oga_model.mm
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,18 @@ - (nullable instancetype)initWithPath:(NSString*)path error:(NSError**)error {
OGA_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}

- (nullable instancetype)initWithConfig:(OGAConfig*)config error:(NSError**)error {
if ((self = [super init]) == nil) {
return nil;
}

try {
_model = OgaModel::Create([config CXXAPIOgaConfig]);
return self;
}
OGA_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}

- (const OgaModel&)CXXAPIOgaModel {
return *(_model.get());
}
Expand Down
16 changes: 15 additions & 1 deletion src/objectivec/test/ort_genai_api_test.mm
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,20 @@ - (void)testTensor_And_AddExtraInput {

NSError* error = nil;
BOOL ret = NO;
OGAModel* model = [[OGAModel alloc] initWithPath:[ORTGenAIAPITest getModelPath] error:&error];

OGAConfig* config = [[OGAConfig alloc] initWithPath:[ORTGenAIAPITest getModelPath] error:&error];
ORTAssertNullableResultSuccessful(config, error);

ret = [config clearProvidersWithError:&error];
ORTAssertBoolResultSuccessful(ret, error);

ret = [config appendProvider:@"cpu" error:&error];
ORTAssertBoolResultSuccessful(ret, error);

ret = [config overlay:@"{'num_beams': 1}" error:&error];
ORTAssertBoolResultSuccessful(ret, error);

OGAModel* model = [[OGAModel alloc] initWithConfig:config error:&error];
ORTAssertNullableResultSuccessful(model, error);

OGAGeneratorParams* params = [[OGAGeneratorParams alloc] initWithModel:model error:&error];
Expand All @@ -65,6 +78,7 @@ - (void)testGetOutput {

NSError* error = nil;
BOOL ret = NO;

OGAModel* model = [[OGAModel alloc] initWithPath:[ORTGenAIAPITest getModelPath] error:&error];
ORTAssertNullableResultSuccessful(model, error);

Expand Down
4 changes: 2 additions & 2 deletions src/ort_genai_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -245,8 +245,8 @@ OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateConfig(const char* config_path, OgaC
OGA_EXPORT OgaResult* OGA_API_CALL OgaConfigClearProviders(OgaConfig* config);

/**
* \brief Add the provider at the end of the list of providers in the given config if it doesn't already exist
* if it already exists, does nothing.
* \brief Add the provider at the end of the list of providers in the given config if it doesn't already exist.
* If it already exists, do nothing.
* \param[in] config The config to set the provider on.
* \param[in] provider The provider to set on the config.
* \return OgaResult containing the error message if the setting of the provider failed.
Expand Down
Loading