diff --git a/src/Microsoft.ML.Data/DataLoadSave/Database/DatabaseLoaderCursor.cs b/src/Microsoft.ML.Data/DataLoadSave/Database/DatabaseLoaderCursor.cs index 7d029c96a8..ed092be157 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Database/DatabaseLoaderCursor.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Database/DatabaseLoaderCursor.cs @@ -75,6 +75,7 @@ public DbCommand Command { _command = Connection.CreateCommand(); _command.CommandText = _source.CommandText; + _command.CommandTimeout = _source.CommandTimeoutInSeconds; } return _command; } diff --git a/src/Microsoft.ML.Data/DataLoadSave/Database/DatabaseSource.cs b/src/Microsoft.ML.Data/DataLoadSave/Database/DatabaseSource.cs index 34f8973c81..77f971dba1 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Database/DatabaseSource.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Database/DatabaseSource.cs @@ -10,21 +10,38 @@ namespace Microsoft.ML.Data /// Exposes the data required for opening a database for reading. public sealed class DatabaseSource { + private const int DefaultCommandTimeoutInSeconds = 30; + /// Creates a new instance of the class. /// The factory used to create the .. /// The string used to open the connection. /// The text command to run against the data source. - public DatabaseSource(DbProviderFactory providerFactory, string connectionString, string commandText) + public DatabaseSource(DbProviderFactory providerFactory, string connectionString, string commandText) : + this(providerFactory, connectionString, commandText, DefaultCommandTimeoutInSeconds) + { + } + + /// Creates a new instance of the class. + /// The factory used to create the .. + /// The string used to open the connection. + /// The text command to run against the data source. + /// The timeout(in seconds) for database command. + public DatabaseSource(DbProviderFactory providerFactory, string connectionString, string commandText, int commandTimeoutInSeconds) { Contracts.CheckValue(providerFactory, nameof(providerFactory)); Contracts.CheckValue(connectionString, nameof(connectionString)); Contracts.CheckValue(commandText, nameof(commandText)); + Contracts.CheckUserArg(commandTimeoutInSeconds >= 0, nameof(commandTimeoutInSeconds)); ProviderFactory = providerFactory; ConnectionString = connectionString; CommandText = commandText; + CommandTimeoutInSeconds = commandTimeoutInSeconds; } + /// Gets the timeout for database command. + public int CommandTimeoutInSeconds { get; } + /// Gets the text command to run against the data source. public string CommandText { get; } diff --git a/test/Microsoft.ML.Tests/DatabaseLoaderTests.cs b/test/Microsoft.ML.Tests/DatabaseLoaderTests.cs index 6826eb9aba..b6ca011fed 100644 --- a/test/Microsoft.ML.Tests/DatabaseLoaderTests.cs +++ b/test/Microsoft.ML.Tests/DatabaseLoaderTests.cs @@ -27,6 +27,22 @@ public DatabaseLoaderTests(ITestOutputHelper output) [LightGBMFact] public void IrisLightGbm() + { + DatabaseSource dbs = GetIrisDatabaseSource("SELECT * FROM {0}"); + IrisLightGbmImpl(dbs); + } + + [LightGBMFact] + public void IrisLightGbmWithTimeout() + { + if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) //sqlite does not have built-in command for sleep + return; + DatabaseSource dbs = GetIrisDatabaseSource("WAITFOR DELAY '00:00:01'; SELECT * FROM {0}", 1); + var ex = Assert.Throws(() => IrisLightGbmImpl(dbs)); + Assert.Contains("Timeout", ex.InnerException.Message); + } + + private void IrisLightGbmImpl(DatabaseSource dbs) { var mlContext = new MLContext(seed: 1); @@ -41,7 +57,7 @@ public void IrisLightGbm() var loader = mlContext.Data.CreateDatabaseLoader(loaderColumns); - var trainingData = loader.Load(GetIrisDatabaseSource("SELECT * FROM {0}")); + var trainingData = loader.Load(dbs); IEstimator pipeline = mlContext.Transforms.Conversion.MapValueToKey("Label") .Append(mlContext.Transforms.Concatenate("Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth")) @@ -211,18 +227,20 @@ public void IrisSdcaMaximumEntropy() /// SQLite database is used on Linux and MacOS builds. /// /// Return the appropiate Iris DatabaseSource according to build OS. - private DatabaseSource GetIrisDatabaseSource(string command) + private DatabaseSource GetIrisDatabaseSource(string command, int commandTimeoutInSeconds = 30) { if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) return new DatabaseSource( SqlClientFactory.Instance, GetMSSQLConnectionString(TestDatasets.irisDb.name), - String.Format(command, $@"""{TestDatasets.irisDb.trainFilename}""")); + String.Format(command, $@"""{TestDatasets.irisDb.trainFilename}"""), + commandTimeoutInSeconds); else return new DatabaseSource( SQLiteFactory.Instance, GetSQLiteConnectionString(TestDatasets.irisDbSQLite.name), - String.Format(command, TestDatasets.irisDbSQLite.trainFilename)); + String.Format(command, TestDatasets.irisDbSQLite.trainFilename), + commandTimeoutInSeconds); } private string GetMSSQLConnectionString(string databaseName)