Skip to content

Commit

Permalink
Support RWKV4 Raven 1B5-14B (CPU/GPU)
Browse files Browse the repository at this point in the history
  • Loading branch information
xcssa committed May 2, 2023
1 parent 071bc3a commit eade853
Show file tree
Hide file tree
Showing 9 changed files with 330 additions and 105 deletions.
6 changes: 6 additions & 0 deletions CRWKV.sln
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,19 @@ EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
Debug|x64 = Debug|x64
Release|Any CPU = Release|Any CPU
Release|x64 = Release|x64
EndGlobalSection
GlobalSection(ProjectConfigurationPlatforms) = postSolution
{D2C9B375-1A91-47C9-9953-5D9B7289AEBB}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{D2C9B375-1A91-47C9-9953-5D9B7289AEBB}.Debug|Any CPU.Build.0 = Debug|Any CPU
{D2C9B375-1A91-47C9-9953-5D9B7289AEBB}.Debug|x64.ActiveCfg = Debug|x64
{D2C9B375-1A91-47C9-9953-5D9B7289AEBB}.Debug|x64.Build.0 = Debug|x64
{D2C9B375-1A91-47C9-9953-5D9B7289AEBB}.Release|Any CPU.ActiveCfg = Release|Any CPU
{D2C9B375-1A91-47C9-9953-5D9B7289AEBB}.Release|Any CPU.Build.0 = Release|Any CPU
{D2C9B375-1A91-47C9-9953-5D9B7289AEBB}.Release|x64.ActiveCfg = Release|x64
{D2C9B375-1A91-47C9-9953-5D9B7289AEBB}.Release|x64.Build.0 = Release|x64
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
Expand Down
26 changes: 14 additions & 12 deletions CRWKV/CRWKV.csproj
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
<Project Sdk="Microsoft.NET.Sdk">
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<OutputType>Exe</OutputType>
<TargetFramework>net6.0</TargetFramework>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
<IncludeNativeLibrariesForSelfExtract>true</IncludeNativeLibrariesForSelfExtract>
</PropertyGroup>
<PropertyGroup>
<OutputType>Exe</OutputType>
<TargetFramework>net6.0</TargetFramework>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
<IncludeNativeLibrariesForSelfExtract>true</IncludeNativeLibrariesForSelfExtract>
<Platforms>AnyCPU;x64</Platforms>
</PropertyGroup>

<ItemGroup>
<PackageReference Include="Microsoft.ML.OnnxRuntime" Version="1.14.1" />
<PackageReference Include="Seq2SeqSharp" Version="2.5.0" />
</ItemGroup>
<ItemGroup>
<PackageReference Include="MathNet.Numerics" Version="5.0.0" />
<PackageReference Include="Microsoft.ML.OnnxRuntime.Gpu" Version="1.14.1" />
<PackageReference Include="Seq2SeqSharp" Version="2.5.0" />
</ItemGroup>

</Project>
29 changes: 9 additions & 20 deletions CRWKV/Program.cs
Original file line number Diff line number Diff line change
@@ -1,31 +1,20 @@
using RWKV;

Console.Write("Input Model Name(rwkv-4-pile-169m-uint8.onnx): ");
Console.Write("Input Model Name(RWKV_32_2560_16.onnx): ");
var modelName = Console.ReadLine();
if (string.IsNullOrEmpty(modelName))
modelName = "rwkv-4-pile-169m-uint8.onnx";
modelName = "RWKV_32_2560_16.onnx";

Console.Write("ctx_len(1024): ");
var ctx_len = 1024;
var ctx_len_str = Console.ReadLine();
if (!string.IsNullOrEmpty(ctx_len_str))
ctx_len = int.Parse(ctx_len_str);
var modelNames = modelName.Split("_");
var n_layer = int.Parse(modelNames[1]);
var n_embd = int.Parse(modelNames[2]);

Console.Write("n_layer(12): ");
var n_layer = 12;
var n_layer_str = Console.ReadLine();
if (!string.IsNullOrEmpty(n_layer_str))
n_layer = int.Parse(n_layer_str);
Console.WriteLine($"Loading...");

Console.Write("n_embd(768): ");
var n_embd = 768;
var n_embd_str = Console.ReadLine();
if (!string.IsNullOrEmpty(n_embd_str))
n_embd = int.Parse(n_embd_str);
var rf = new RunnerFactory(modelName, n_layer, n_embd);
rf.Init();
var r = rf.NewRunner();

Console.WriteLine($"Loading({modelName})[{ctx_len},{n_layer},{n_embd}]...");
var r = new Runner(modelName, ctx_len, n_layer, n_embd);
r.Init();
while (true)
{
Console.Write(">");
Expand Down
158 changes: 158 additions & 0 deletions CRWKV/RWKV/OnnxModel.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
using Microsoft.ML.OnnxRuntime;
using Microsoft.ML.OnnxRuntime.Tensors;

namespace RWKV
{
    /// <summary>
    /// Numeric precision of the loaded ONNX model, detected from the element
    /// type of its "instate0" input tensor.
    /// </summary>
    public enum OnnxModelType
    {
        FP16,
        FP32
    }

    /// <summary>
    /// Wrapper around an ONNX Runtime inference session for an RWKV model
    /// exported with per-layer recurrent state inputs/outputs. Supports FP16
    /// and FP32 model variants, preferring the CUDA execution provider with a
    /// CPU fallback.
    /// </summary>
    public class OnnxModel
    {
        // FP16 bit pattern for negative infinity (0xFC00 == 64512), the
        // half-precision counterpart of float.NegativeInfinity used below.
        private const ushort Fp16NegativeInfinity = 0xFC00;

        private readonly InferenceSession _inferenceSession;
        private readonly Type _type;                  // element type of the state tensors
        private readonly int _embed;                  // embedding width: length of each 1-D state vector
        private readonly int _layers;                 // layer count: 5 state tensors are kept per layer
        private readonly List<string> _input_names;
        private readonly List<string> _output_names;
        private readonly List<NamedOnnxValue> _inputs; // reused across Forward calls to avoid reallocation
        private readonly OnnxModelType _modelType;

        public OnnxModelType ModelType => _modelType;

        /// <summary>Loads the model and detects its precision.</summary>
        /// <param name="model">Path to the .onnx file.</param>
        /// <param name="embed">Embedding dimension (n_embd).</param>
        /// <param name="layers">Number of RWKV layers (n_layer).</param>
        /// <exception cref="NotSupportedException">
        /// The model's state element type is neither FP16 nor FP32.
        /// </exception>
        public OnnxModel(string model, int embed, int layers)
        {
            var options = new SessionOptions();
            // Execution providers are consulted in registration order, so CUDA
            // must be appended BEFORE CPU or the GPU is never used. CUDA
            // registration throws when no CUDA runtime is available, in which
            // case we silently fall back to CPU-only execution.
            try
            {
                options.AppendExecutionProvider_CUDA();
            }
            catch (Exception)
            {
                // No usable CUDA device/runtime — CPU provider below suffices.
            }
            options.AppendExecutionProvider_CPU();
            _inferenceSession = new InferenceSession(model, options);
            _type = _inferenceSession.InputMetadata["instate0"].ElementType;
            _embed = embed;
            _layers = layers;
            _input_names = _inferenceSession.InputMetadata.Select(x => x.Key).ToList();
            _output_names = _inferenceSession.OutputMetadata.Select(x => x.Key).ToList();
            _inputs = new List<NamedOnnxValue>();

            if (_type == typeof(Float16))
            {
                _modelType = OnnxModelType.FP16;
            }
            else if (_type == typeof(float))
            {
                _modelType = OnnxModelType.FP32;
            }
            else
            {
                throw new NotSupportedException();
            }
        }

        /// <summary>
        /// Builds the initial (all-zero) recurrent state: 5 state vectors per
        /// layer, with the 5th initialized to -infinity (presumably the
        /// running-max accumulator of RWKV's numerically stable WKV softmax —
        /// TODO confirm against the exported model).
        /// </summary>
        /// <returns>
        /// A List&lt;Tensor&lt;Float16&gt;&gt; or List&lt;Tensor&lt;float&gt;&gt;
        /// matching the model precision; pass it back into Forward.
        /// </returns>
        public object GetEmptyStates()
        {
            switch (_modelType)
            {
                case OnnxModelType.FP16:
                {
                    var state = new List<Tensor<Float16>>();
                    for (int i = 0; i < _layers; i++)
                    {
                        state.Add(GDenseTensor<Float16>(0));
                        state.Add(GDenseTensor<Float16>(0));
                        state.Add(GDenseTensor<Float16>(0));
                        state.Add(GDenseTensor<Float16>(0));
                        state.Add(GDenseTensor<Float16>(Fp16NegativeInfinity));
                    }
                    return state;
                }
                case OnnxModelType.FP32:
                {
                    var state = new List<Tensor<float>>();
                    for (int i = 0; i < _layers; i++)
                    {
                        state.Add(GDenseTensor<float>(0));
                        state.Add(GDenseTensor<float>(0));
                        state.Add(GDenseTensor<float>(0));
                        state.Add(GDenseTensor<float>(0));
                        state.Add(GDenseTensor<float>(float.NegativeInfinity));
                    }
                    return state;
                }
                default:
                    throw new NotSupportedException();
            }
        }

        /// <summary>
        /// Runs one inference step for token <paramref name="xi"/>.
        /// </summary>
        /// <param name="xi">Token id fed as the model's first input.</param>
        /// <param name="state">State object previously returned by
        /// GetEmptyStates or by an earlier Forward call.</param>
        /// <returns>The logits (always widened to float) and the new state.</returns>
        public (IEnumerable<float> logits, object state) Forward(int xi, object state)
        {
            switch (_modelType)
            {
                case OnnxModelType.FP16:
                {
                    var ret = Forward_FP16(xi, (List<Tensor<Float16>>)state);
                    return (ret.logits.Select(x => HalfToSinglePrecision(x)).AsEnumerable(), ret.state);
                }
                case OnnxModelType.FP32:
                {
                    var ret = Forward_FP32(xi, (List<Tensor<float>>)state);
                    return (ret.logits.AsEnumerable(), ret.state);
                }
                default:
                    throw new NotSupportedException();
            }
        }

        // One FP16 inference step: first input is the token id, the remaining
        // inputs are the per-layer state tensors in metadata order. The first
        // output is the logits, the rest are the updated state tensors.
        private (Tensor<Float16> logits, IList<Tensor<Float16>> state) Forward_FP16(int xi, List<Tensor<Float16>> state)
        {
            _inputs.Clear();
            var input = new DenseTensor<int>(new[] { xi }, new[] { 1 });
            _inputs.Add(NamedOnnxValue.CreateFromTensor(_input_names.First(), input));
            for (int i = 1; i < _input_names.Count; i++)
            {
                _inputs.Add(NamedOnnxValue.CreateFromTensor(_input_names[i], state[i - 1]));
            }
            var data = _inferenceSession.Run(_inputs);
            return (data.First().AsTensor<Float16>(), data.Skip(1).Select(x => x.AsTensor<Float16>()).ToList());
        }

        // FP32 twin of Forward_FP16; identical wiring, float element type.
        private (Tensor<float> logits, IList<Tensor<float>> state) Forward_FP32(int xi, IList<Tensor<float>> state)
        {
            _inputs.Clear();
            var input = new DenseTensor<int>(new[] { xi }, new[] { 1 });
            _inputs.Add(NamedOnnxValue.CreateFromTensor(_input_names.First(), input));
            for (int i = 1; i < _input_names.Count; i++)
            {
                _inputs.Add(NamedOnnxValue.CreateFromTensor(_input_names[i], state[i - 1]));
            }
            var data = _inferenceSession.Run(_inputs);
            return (data.First().AsTensor<float>(), data.Skip(1).Select(x => x.AsTensor<float>()).ToList());
        }

        /// <summary>
        /// Converts an IEEE-754 half (raw bits) to a single-precision float.
        /// Uses System.Half so zeros, subnormals, infinities and NaNs convert
        /// correctly — the previous hand-rolled shift arithmetic mapped
        /// half 0x0000 to ~3.05e-5 and ±Inf/NaN to large finite values.
        /// </summary>
        private static float HalfToSinglePrecision(ushort half)
        {
            return (float)BitConverter.UInt16BitsToHalf(half);
        }

        // Builds a 1-D dense tensor of length _embed with every element set
        // to the given value.
        private DenseTensor<T> GDenseTensor<T>(T value)
        {
            var tvalue = new DenseTensor<T>(_embed);
            tvalue.Fill(value);
            return tvalue;
        }
    }
}
Loading

0 comments on commit eade853

Please sign in to comment.