diff --git a/src/Microsoft.ML.Data/DataLoadSave/Database/DatabaseLoaderCursor.cs b/src/Microsoft.ML.Data/DataLoadSave/Database/DatabaseLoaderCursor.cs
index b8e7ea87e5..d831383af3 100644
--- a/src/Microsoft.ML.Data/DataLoadSave/Database/DatabaseLoaderCursor.cs
+++ b/src/Microsoft.ML.Data/DataLoadSave/Database/DatabaseLoaderCursor.cs
@@ -71,6 +71,7 @@ public DbCommand Command
{
_command = Connection.CreateCommand();
_command.CommandText = _source.CommandText;
+ _command.CommandTimeout = _source.CommandTimeoutInSeconds;
}
return _command;
}
diff --git a/src/Microsoft.ML.Data/DataLoadSave/Database/DatabaseSource.cs b/src/Microsoft.ML.Data/DataLoadSave/Database/DatabaseSource.cs
index 34f8973c81..a2606ea755 100644
--- a/src/Microsoft.ML.Data/DataLoadSave/Database/DatabaseSource.cs
+++ b/src/Microsoft.ML.Data/DataLoadSave/Database/DatabaseSource.cs
@@ -11,7 +11,7 @@ namespace Microsoft.ML.Data
public sealed class DatabaseSource
{
/// Creates a new instance of the class.
- /// The factory used to create the ..
+ /// The factory used to create the .
/// The string used to open the connection.
/// The text command to run against the data source.
public DatabaseSource(DbProviderFactory providerFactory, string connectionString, string commandText)
@@ -25,9 +25,29 @@ public DatabaseSource(DbProviderFactory providerFactory, string connectionString
CommandText = commandText;
}
+ /// Creates a new instance of the class.
+ /// The factory used to create the .
+ /// The string used to open the connection.
+ /// The text command to run against the data source.
+ /// The time in seconds to wait for the command to execute.
+ public DatabaseSource(DbProviderFactory providerFactory, string connectionString, string commandText, int commandTimeoutInSeconds)
+ {
+ Contracts.CheckValue(providerFactory, nameof(providerFactory));
+ Contracts.CheckValue(connectionString, nameof(connectionString));
+ Contracts.CheckValue(commandText, nameof(commandText));
+
+ ProviderFactory = providerFactory;
+ ConnectionString = connectionString;
+ CommandText = commandText;
+ CommandTimeoutInSeconds = commandTimeoutInSeconds;
+ }
+
/// Gets the text command to run against the data source.
public string CommandText { get; }
+ /// Gets the command timeout.
+ public int CommandTimeoutInSeconds { get; }
+
/// Gets the string used to open the connection.
public string ConnectionString { get; }
diff --git a/test/Microsoft.ML.Tests/DatabaseLoaderTests.cs b/test/Microsoft.ML.Tests/DatabaseLoaderTests.cs
index 9595d35af1..b4c3a02a10 100644
--- a/test/Microsoft.ML.Tests/DatabaseLoaderTests.cs
+++ b/test/Microsoft.ML.Tests/DatabaseLoaderTests.cs
@@ -79,6 +79,46 @@ public void IrisLightGbm()
}).PredictedLabel);
}
+ [LightGBMFact]
+ public void IrisLightGbmCommandTimeout()
+ {
+ if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
+ {
+ // https://github.com/dotnet/machinelearning/issues/4156
+ return;
+ }
+
+ var mlContext = new MLContext(seed: 1);
+
+ var connectionString = GetConnectionString(TestDatasets.irisDb.name);
+ var commandText = $@"WAITFOR DELAY '00:00:02'; SELECT * FROM ""{TestDatasets.irisDb.trainFilename}""";
+ var commandTimeout = 1;
+
+ var loaderColumns = new DatabaseLoader.Column[]
+ {
+ new DatabaseLoader.Column() { Name = "Label", Type = DbType.Int32 },
+ new DatabaseLoader.Column() { Name = "SepalLength", Type = DbType.Single },
+ new DatabaseLoader.Column() { Name = "SepalWidth", Type = DbType.Single },
+ new DatabaseLoader.Column() { Name = "PetalLength", Type = DbType.Single },
+ new DatabaseLoader.Column() { Name = "PetalWidth", Type = DbType.Single }
+ };
+
+ var loader = mlContext.Data.CreateDatabaseLoader(loaderColumns);
+
+ var databaseSource = new DatabaseSource(SqlClientFactory.Instance, connectionString, commandText, commandTimeout);
+
+ var trainingData = loader.Load(databaseSource);
+
+ IEstimator pipeline = mlContext.Transforms.Conversion.MapValueToKey("Label")
+ .Append(mlContext.Transforms.Concatenate("Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth"))
+ .AppendCacheCheckpoint(mlContext)
+ .Append(mlContext.MulticlassClassification.Trainers.LightGbm())
+ .Append(mlContext.Transforms.Conversion.MapKeyToValue("PredictedLabel"));
+
+ var ex = Assert.Throws(() => pipeline.Fit(trainingData));
+ Assert.Contains("Timeout expired.", ex.InnerException.Message);
+ }
+
[LightGBMFact]
public void IrisLightGbmWithLoadColumnName()
{