diff --git a/examples/csharp/HelloPhi/Program.cs b/examples/csharp/HelloPhi/Program.cs index fc0b5d0d89..fe0895744b 100644 --- a/examples/csharp/HelloPhi/Program.cs +++ b/examples/csharp/HelloPhi/Program.cs @@ -158,7 +158,7 @@ static string GetPrompt(bool interactive) while (!generator.IsDone()) { generator.GenerateNextToken(); - Console.Write(tokenizerStream.Decode(generator.GetSequence(0)[^1])); + Console.Write(tokenizerStream.Decode(generator.GetNextTokens()[0])); } Console.WriteLine(); watch.Stop(); @@ -195,7 +195,7 @@ static string GetPrompt(bool interactive) while (!generator.IsDone()) { generator.GenerateNextToken(); - Console.Write(tokenizerStream.Decode(generator.GetSequence(0)[^1])); + Console.Write(tokenizerStream.Decode(generator.GetNextTokens()[0])); } Console.WriteLine(); watch.Stop(); diff --git a/examples/csharp/HelloPhi3V/Program.cs b/examples/csharp/HelloPhi3V/Program.cs index 71766c575d..09e1038bd5 100644 --- a/examples/csharp/HelloPhi3V/Program.cs +++ b/examples/csharp/HelloPhi3V/Program.cs @@ -175,7 +175,7 @@ void PrintUsage() { break; } - Console.Write(stream.Decode(generator.GetSequence(0)[^1])); + Console.Write(stream.Decode(generator.GetNextTokens()[0])); } watch.Stop(); var runTimeInSeconds = watch.Elapsed.TotalSeconds; diff --git a/examples/csharp/HelloPhi4MM/Program.cs b/examples/csharp/HelloPhi4MM/Program.cs index d10b88b00d..bcec8d714f 100644 --- a/examples/csharp/HelloPhi4MM/Program.cs +++ b/examples/csharp/HelloPhi4MM/Program.cs @@ -222,7 +222,7 @@ void PrintUsage() { break; } - Console.Write(stream.Decode(generator.GetSequence(0)[^1])); + Console.Write(stream.Decode(generator.GetNextTokens()[0])); } watch.Stop(); var runTimeInSeconds = watch.Elapsed.TotalSeconds; diff --git a/src/csharp/Generator.cs b/src/csharp/Generator.cs index f9d593db1f..0c7eb31d81 100644 --- a/src/csharp/Generator.cs +++ b/src/csharp/Generator.cs @@ -61,6 +61,15 @@ public void RewindTo(ulong newLength) Result.VerifySuccess(NativeMethods.OgaGenerator_RewindTo(_generatorHandle, (UIntPtr)newLength)); } + public ReadOnlySpan GetNextTokens() + { + Result.VerifySuccess(NativeMethods.OgaGenerator_GetNextTokens(_generatorHandle, out IntPtr tokenIds, out UIntPtr tokenCount)); + unsafe + { + return new ReadOnlySpan(tokenIds.ToPointer(), (int)tokenCount.ToUInt64()); + } + } + public ReadOnlySpan GetSequence(ulong index) { ulong sequenceLength = NativeMethods.OgaGenerator_GetSequenceCount(_generatorHandle, (UIntPtr)index).ToUInt64(); diff --git a/src/csharp/NativeMethods.cs b/src/csharp/NativeMethods.cs index c22b9a790a..ce505b6ec1 100644 --- a/src/csharp/NativeMethods.cs +++ b/src/csharp/NativeMethods.cs @@ -125,6 +125,11 @@ internal class NativeLib [DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)] public static extern byte OgaGenerator_IsDone(IntPtr /* const OgaGenerator* */ generator); + [DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)] + public static extern IntPtr /* OgaResult* */ OgaGenerator_GetNextTokens(IntPtr /* const OgaGenerator* */ generator, + out IntPtr /* const int32_t** */ outTokenIds, + out UIntPtr /* size_t* */ outTokenCount); + // This function is used to generate the next token in the sequence using the greedy search algorithm. [DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)] public static extern IntPtr /* OgaResult* */ OgaGenerator_GenerateNextToken(IntPtr /* OgaGenerator* */ generator);