diff --git a/src/Microsoft.ML.Transforms/Text/LdaSingleBox.cs b/src/Microsoft.ML.Transforms/Text/LdaSingleBox.cs index 97054d8bb8..d0d38f272b 100644 --- a/src/Microsoft.ML.Transforms/Text/LdaSingleBox.cs +++ b/src/Microsoft.ML.Transforms/Text/LdaSingleBox.cs @@ -8,83 +8,93 @@ using System.Runtime.InteropServices; using System.Security; using Microsoft.ML.Runtime; +using Microsoft.Win32.SafeHandles; namespace Microsoft.ML.TextAnalytics { internal static class LdaInterface { - public struct LdaEngine + public sealed class SafeLdaEngineHandle : SafeHandleZeroOrMinusOneIsInvalid { - public IntPtr Ptr; + private SafeLdaEngineHandle() + : base(true) + { + } + + protected override bool ReleaseHandle() + { + DestroyEngine(handle); + return true; + } } private const string NativePath = "LdaNative"; [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - internal static extern LdaEngine CreateEngine(int numTopic, int numVocab, float alphaSum, float beta, int numIter, + internal static extern SafeLdaEngineHandle CreateEngine(int numTopic, int numVocab, float alphaSum, float beta, int numIter, int likelihoodInterval, int numThread, int mhstep, int maxDocToken); [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - internal static extern void AllocateModelMemory(LdaEngine engine, int numTopic, int numVocab, long tableSize, long aliasTableSize); + internal static extern void AllocateModelMemory(SafeLdaEngineHandle engine, int numTopic, int numVocab, long tableSize, long aliasTableSize); [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - internal static extern void AllocateDataMemory(LdaEngine engine, int docNum, long corpusSize); + internal static extern void AllocateDataMemory(SafeLdaEngineHandle engine, int docNum, long corpusSize); [DllImport(NativePath, CharSet = CharSet.Ansi), SuppressUnmanagedCodeSecurity] - internal static extern void Train(LdaEngine engine, string trainOutput); + internal static extern void Train(SafeLdaEngineHandle engine, string trainOutput); [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - internal static extern void GetModelStat(LdaEngine engine, out long memBlockSize, out long aliasMemBlockSize); + internal static extern void GetModelStat(SafeLdaEngineHandle engine, out long memBlockSize, out long aliasMemBlockSize); [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - internal static extern void Test(LdaEngine engine, int numBurninIter, float[] pLogLikelihood); + internal static extern void Test(SafeLdaEngineHandle engine, int numBurninIter, float[] pLogLikelihood); [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - internal static extern void CleanData(LdaEngine engine); + internal static extern void CleanData(SafeLdaEngineHandle engine); [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - internal static extern void CleanModel(LdaEngine engine); + internal static extern void CleanModel(SafeLdaEngineHandle engine); [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - internal static extern void DestroyEngine(LdaEngine engine); + private static extern void DestroyEngine(IntPtr engine); [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - internal static extern void GetWordTopic(LdaEngine engine, int wordId, int[] pTopic, int[] pProb, ref int length); + internal static extern void GetWordTopic(SafeLdaEngineHandle engine, int wordId, int[] pTopic, int[] pProb, ref int length); [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - internal static extern void SetWordTopic(LdaEngine engine, int wordId, int[] pTopic, int[] pProb, int length); + internal static extern void SetWordTopic(SafeLdaEngineHandle engine, int wordId, int[] pTopic, int[] pProb, int length); [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - internal static extern void SetAlphaSum(LdaEngine engine, float avgDocLength); + internal static extern void SetAlphaSum(SafeLdaEngineHandle engine, float avgDocLength); [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - internal static extern int FeedInData(LdaEngine engine, int[] termId, int[] termFreq, int termNum, int numVocab); + internal static extern int FeedInData(SafeLdaEngineHandle engine, int[] termId, int[] termFreq, int termNum, int numVocab); [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - internal static extern int FeedInDataDense(LdaEngine engine, int[] termFreq, int termNum, int numVocab); + internal static extern int FeedInDataDense(SafeLdaEngineHandle engine, int[] termFreq, int termNum, int numVocab); [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - internal static extern void GetDocTopic(LdaEngine engine, int docId, int[] pTopic, int[] pProb, ref int numTopicReturn); + internal static extern void GetDocTopic(SafeLdaEngineHandle engine, int docId, int[] pTopic, int[] pProb, ref int numTopicReturn); [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - internal static extern void GetTopicSummary(LdaEngine engine, int topicId, int[] pWords, float[] pProb, ref int numTopicReturn); + internal static extern void GetTopicSummary(SafeLdaEngineHandle engine, int topicId, int[] pWords, float[] pProb, ref int numTopicReturn); [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - internal static extern void TestOneDoc(LdaEngine engine, int[] termId, int[] termFreq, int termNum, int[] pTopics, int[] pProbs, ref int numTopicsMax, int numBurnIter, bool reset); + internal static extern void TestOneDoc(SafeLdaEngineHandle engine, int[] termId, int[] termFreq, int termNum, int[] pTopics, int[] pProbs, ref int numTopicsMax, int numBurnIter, bool reset); [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - internal static extern void TestOneDocDense(LdaEngine engine, int[] termFreq, int termNum, int[] pTopics, int[] pProbs, ref int numTopicsMax, int numBurninIter, bool reset); + internal static extern void TestOneDocDense(SafeLdaEngineHandle engine, int[] termFreq, int termNum, int[] pTopics, int[] pProbs, ref int numTopicsMax, int numBurninIter, bool reset); [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - internal static extern void InitializeBeforeTrain(LdaEngine engine); + internal static extern void InitializeBeforeTrain(SafeLdaEngineHandle engine); [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - internal static extern void InitializeBeforeTest(LdaEngine engine); + internal static extern void InitializeBeforeTest(SafeLdaEngineHandle engine); } internal sealed class LdaSingleBox : IDisposable { - private LdaInterface.LdaEngine _engine; + private LdaInterface.SafeLdaEngineHandle _engine; private bool _isDisposed; private int[] _topics; private int[] _probabilities; @@ -358,8 +368,7 @@ public void Dispose() if (_isDisposed) return; _isDisposed = true; - LdaInterface.DestroyEngine(_engine); - _engine.Ptr = IntPtr.Zero; + _engine.Dispose(); } } }