Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 34 additions & 25 deletions src/Microsoft.ML.Transforms/Text/LdaSingleBox.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,83 +8,93 @@
using System.Runtime.InteropServices;
using System.Security;
using Microsoft.ML.Runtime;
using Microsoft.Win32.SafeHandles;

namespace Microsoft.ML.TextAnalytics
{

internal static class LdaInterface
{
public struct LdaEngine
public sealed class SafeLdaEngineHandle : SafeHandleZeroOrMinusOneIsInvalid
{
public IntPtr Ptr;
private SafeLdaEngineHandle()
: base(true)
{
}

protected override bool ReleaseHandle()
{
DestroyEngine(handle);
return true;
}
}

private const string NativePath = "LdaNative";
[DllImport(NativePath), SuppressUnmanagedCodeSecurity]
internal static extern LdaEngine CreateEngine(int numTopic, int numVocab, float alphaSum, float beta, int numIter,
internal static extern SafeLdaEngineHandle CreateEngine(int numTopic, int numVocab, float alphaSum, float beta, int numIter,
int likelihoodInterval, int numThread, int mhstep, int maxDocToken);

[DllImport(NativePath), SuppressUnmanagedCodeSecurity]
internal static extern void AllocateModelMemory(LdaEngine engine, int numTopic, int numVocab, long tableSize, long aliasTableSize);
internal static extern void AllocateModelMemory(SafeLdaEngineHandle engine, int numTopic, int numVocab, long tableSize, long aliasTableSize);

[DllImport(NativePath), SuppressUnmanagedCodeSecurity]
internal static extern void AllocateDataMemory(LdaEngine engine, int docNum, long corpusSize);
internal static extern void AllocateDataMemory(SafeLdaEngineHandle engine, int docNum, long corpusSize);

[DllImport(NativePath, CharSet = CharSet.Ansi), SuppressUnmanagedCodeSecurity]
internal static extern void Train(LdaEngine engine, string trainOutput);
internal static extern void Train(SafeLdaEngineHandle engine, string trainOutput);

[DllImport(NativePath), SuppressUnmanagedCodeSecurity]
internal static extern void GetModelStat(LdaEngine engine, out long memBlockSize, out long aliasMemBlockSize);
internal static extern void GetModelStat(SafeLdaEngineHandle engine, out long memBlockSize, out long aliasMemBlockSize);

[DllImport(NativePath), SuppressUnmanagedCodeSecurity]
internal static extern void Test(LdaEngine engine, int numBurninIter, float[] pLogLikelihood);
internal static extern void Test(SafeLdaEngineHandle engine, int numBurninIter, float[] pLogLikelihood);

[DllImport(NativePath), SuppressUnmanagedCodeSecurity]
internal static extern void CleanData(LdaEngine engine);
internal static extern void CleanData(SafeLdaEngineHandle engine);

[DllImport(NativePath), SuppressUnmanagedCodeSecurity]
internal static extern void CleanModel(LdaEngine engine);
internal static extern void CleanModel(SafeLdaEngineHandle engine);

[DllImport(NativePath), SuppressUnmanagedCodeSecurity]
internal static extern void DestroyEngine(LdaEngine engine);
private static extern void DestroyEngine(IntPtr engine);

[DllImport(NativePath), SuppressUnmanagedCodeSecurity]
internal static extern void GetWordTopic(LdaEngine engine, int wordId, int[] pTopic, int[] pProb, ref int length);
internal static extern void GetWordTopic(SafeLdaEngineHandle engine, int wordId, int[] pTopic, int[] pProb, ref int length);

[DllImport(NativePath), SuppressUnmanagedCodeSecurity]
internal static extern void SetWordTopic(LdaEngine engine, int wordId, int[] pTopic, int[] pProb, int length);
internal static extern void SetWordTopic(SafeLdaEngineHandle engine, int wordId, int[] pTopic, int[] pProb, int length);

[DllImport(NativePath), SuppressUnmanagedCodeSecurity]
internal static extern void SetAlphaSum(LdaEngine engine, float avgDocLength);
internal static extern void SetAlphaSum(SafeLdaEngineHandle engine, float avgDocLength);

[DllImport(NativePath), SuppressUnmanagedCodeSecurity]
internal static extern int FeedInData(LdaEngine engine, int[] termId, int[] termFreq, int termNum, int numVocab);
internal static extern int FeedInData(SafeLdaEngineHandle engine, int[] termId, int[] termFreq, int termNum, int numVocab);

[DllImport(NativePath), SuppressUnmanagedCodeSecurity]
internal static extern int FeedInDataDense(LdaEngine engine, int[] termFreq, int termNum, int numVocab);
internal static extern int FeedInDataDense(SafeLdaEngineHandle engine, int[] termFreq, int termNum, int numVocab);

[DllImport(NativePath), SuppressUnmanagedCodeSecurity]
internal static extern void GetDocTopic(LdaEngine engine, int docId, int[] pTopic, int[] pProb, ref int numTopicReturn);
internal static extern void GetDocTopic(SafeLdaEngineHandle engine, int docId, int[] pTopic, int[] pProb, ref int numTopicReturn);

[DllImport(NativePath), SuppressUnmanagedCodeSecurity]
internal static extern void GetTopicSummary(LdaEngine engine, int topicId, int[] pWords, float[] pProb, ref int numTopicReturn);
internal static extern void GetTopicSummary(SafeLdaEngineHandle engine, int topicId, int[] pWords, float[] pProb, ref int numTopicReturn);

[DllImport(NativePath), SuppressUnmanagedCodeSecurity]
internal static extern void TestOneDoc(LdaEngine engine, int[] termId, int[] termFreq, int termNum, int[] pTopics, int[] pProbs, ref int numTopicsMax, int numBurnIter, bool reset);
internal static extern void TestOneDoc(SafeLdaEngineHandle engine, int[] termId, int[] termFreq, int termNum, int[] pTopics, int[] pProbs, ref int numTopicsMax, int numBurnIter, bool reset);

[DllImport(NativePath), SuppressUnmanagedCodeSecurity]
internal static extern void TestOneDocDense(LdaEngine engine, int[] termFreq, int termNum, int[] pTopics, int[] pProbs, ref int numTopicsMax, int numBurninIter, bool reset);
internal static extern void TestOneDocDense(SafeLdaEngineHandle engine, int[] termFreq, int termNum, int[] pTopics, int[] pProbs, ref int numTopicsMax, int numBurninIter, bool reset);

[DllImport(NativePath), SuppressUnmanagedCodeSecurity]
internal static extern void InitializeBeforeTrain(LdaEngine engine);
internal static extern void InitializeBeforeTrain(SafeLdaEngineHandle engine);

[DllImport(NativePath), SuppressUnmanagedCodeSecurity]
internal static extern void InitializeBeforeTest(LdaEngine engine);
internal static extern void InitializeBeforeTest(SafeLdaEngineHandle engine);
}

internal sealed class LdaSingleBox : IDisposable
{
private LdaInterface.LdaEngine _engine;
private LdaInterface.SafeLdaEngineHandle _engine;
private bool _isDisposed;
private int[] _topics;
private int[] _probabilities;
Expand Down Expand Up @@ -358,8 +368,7 @@ public void Dispose()
if (_isDisposed)
return;
_isDisposed = true;
LdaInterface.DestroyEngine(_engine);
_engine.Ptr = IntPtr.Zero;
_engine.Dispose();
}
}
}