diff --git a/src/Microsoft.ML.Data/DataLoadSave/Database/DatabaseLoaderCursor.cs b/src/Microsoft.ML.Data/DataLoadSave/Database/DatabaseLoaderCursor.cs
index 7d029c96a8..ed092be157 100644
--- a/src/Microsoft.ML.Data/DataLoadSave/Database/DatabaseLoaderCursor.cs
+++ b/src/Microsoft.ML.Data/DataLoadSave/Database/DatabaseLoaderCursor.cs
@@ -75,6 +75,7 @@ public DbCommand Command
{
_command = Connection.CreateCommand();
_command.CommandText = _source.CommandText;
+ _command.CommandTimeout = _source.CommandTimeoutInSeconds;
}
return _command;
}
diff --git a/src/Microsoft.ML.Data/DataLoadSave/Database/DatabaseSource.cs b/src/Microsoft.ML.Data/DataLoadSave/Database/DatabaseSource.cs
index 34f8973c81..77f971dba1 100644
--- a/src/Microsoft.ML.Data/DataLoadSave/Database/DatabaseSource.cs
+++ b/src/Microsoft.ML.Data/DataLoadSave/Database/DatabaseSource.cs
@@ -10,21 +10,38 @@ namespace Microsoft.ML.Data
/// Exposes the data required for opening a database for reading.
public sealed class DatabaseSource
{
+ private const int DefaultCommandTimeoutInSeconds = 30;
+
/// Creates a new instance of the class.
/// The factory used to create the ..
/// The string used to open the connection.
/// The text command to run against the data source.
- public DatabaseSource(DbProviderFactory providerFactory, string connectionString, string commandText)
+ public DatabaseSource(DbProviderFactory providerFactory, string connectionString, string commandText) :
+ this(providerFactory, connectionString, commandText, DefaultCommandTimeoutInSeconds)
+ {
+ }
+
+ /// Creates a new instance of the class.
+ /// The factory used to create the ..
+ /// The string used to open the connection.
+ /// The text command to run against the data source.
+ /// The timeout(in seconds) for database command.
+ public DatabaseSource(DbProviderFactory providerFactory, string connectionString, string commandText, int commandTimeoutInSeconds)
{
Contracts.CheckValue(providerFactory, nameof(providerFactory));
Contracts.CheckValue(connectionString, nameof(connectionString));
Contracts.CheckValue(commandText, nameof(commandText));
+ Contracts.CheckUserArg(commandTimeoutInSeconds >= 0, nameof(commandTimeoutInSeconds));
ProviderFactory = providerFactory;
ConnectionString = connectionString;
CommandText = commandText;
+ CommandTimeoutInSeconds = commandTimeoutInSeconds;
}
+ /// Gets the timeout for database command.
+ public int CommandTimeoutInSeconds { get; }
+
/// Gets the text command to run against the data source.
public string CommandText { get; }
diff --git a/test/Microsoft.ML.Tests/DatabaseLoaderTests.cs b/test/Microsoft.ML.Tests/DatabaseLoaderTests.cs
index 6826eb9aba..b6ca011fed 100644
--- a/test/Microsoft.ML.Tests/DatabaseLoaderTests.cs
+++ b/test/Microsoft.ML.Tests/DatabaseLoaderTests.cs
@@ -27,6 +27,22 @@ public DatabaseLoaderTests(ITestOutputHelper output)
[LightGBMFact]
public void IrisLightGbm()
+ {
+ DatabaseSource dbs = GetIrisDatabaseSource("SELECT * FROM {0}");
+ IrisLightGbmImpl(dbs);
+ }
+
+ [LightGBMFact]
+ public void IrisLightGbmWithTimeout()
+ {
+ if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) //sqlite does not have built-in command for sleep
+ return;
+ DatabaseSource dbs = GetIrisDatabaseSource("WAITFOR DELAY '00:00:01'; SELECT * FROM {0}", 1);
+ var ex = Assert.Throws(() => IrisLightGbmImpl(dbs));
+ Assert.Contains("Timeout", ex.InnerException.Message);
+ }
+
+ private void IrisLightGbmImpl(DatabaseSource dbs)
{
var mlContext = new MLContext(seed: 1);
@@ -41,7 +57,7 @@ public void IrisLightGbm()
var loader = mlContext.Data.CreateDatabaseLoader(loaderColumns);
- var trainingData = loader.Load(GetIrisDatabaseSource("SELECT * FROM {0}"));
+ var trainingData = loader.Load(dbs);
IEstimator pipeline = mlContext.Transforms.Conversion.MapValueToKey("Label")
.Append(mlContext.Transforms.Concatenate("Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth"))
@@ -211,18 +227,20 @@ public void IrisSdcaMaximumEntropy()
/// SQLite database is used on Linux and MacOS builds.
///
/// Return the appropiate Iris DatabaseSource according to build OS.
- private DatabaseSource GetIrisDatabaseSource(string command)
+ private DatabaseSource GetIrisDatabaseSource(string command, int commandTimeoutInSeconds = 30)
{
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
return new DatabaseSource(
SqlClientFactory.Instance,
GetMSSQLConnectionString(TestDatasets.irisDb.name),
- String.Format(command, $@"""{TestDatasets.irisDb.trainFilename}"""));
+ String.Format(command, $@"""{TestDatasets.irisDb.trainFilename}"""),
+ commandTimeoutInSeconds);
else
return new DatabaseSource(
SQLiteFactory.Instance,
GetSQLiteConnectionString(TestDatasets.irisDbSQLite.name),
- String.Format(command, TestDatasets.irisDbSQLite.trainFilename));
+ String.Format(command, TestDatasets.irisDbSQLite.trainFilename),
+ commandTimeoutInSeconds);
}
private string GetMSSQLConnectionString(string databaseName)