diff --git a/src/Microsoft.ML.Data/DataLoadSave/Database/DatabaseLoaderCursor.cs b/src/Microsoft.ML.Data/DataLoadSave/Database/DatabaseLoaderCursor.cs index b8e7ea87e5..d831383af3 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Database/DatabaseLoaderCursor.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Database/DatabaseLoaderCursor.cs @@ -71,6 +71,7 @@ public DbCommand Command { _command = Connection.CreateCommand(); _command.CommandText = _source.CommandText; + _command.CommandTimeout = _source.CommandTimeoutInSeconds; } return _command; } diff --git a/src/Microsoft.ML.Data/DataLoadSave/Database/DatabaseSource.cs b/src/Microsoft.ML.Data/DataLoadSave/Database/DatabaseSource.cs index 34f8973c81..a2606ea755 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Database/DatabaseSource.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Database/DatabaseSource.cs @@ -11,7 +11,7 @@ namespace Microsoft.ML.Data public sealed class DatabaseSource { /// Creates a new instance of the class. - /// The factory used to create the .. + /// The factory used to create the . /// The string used to open the connection. /// The text command to run against the data source. public DatabaseSource(DbProviderFactory providerFactory, string connectionString, string commandText) @@ -25,9 +25,29 @@ public DatabaseSource(DbProviderFactory providerFactory, string connectionString CommandText = commandText; } + /// Creates a new instance of the class. + /// The factory used to create the . + /// The string used to open the connection. + /// The text command to run against the data source. + /// The time in seconds to wait for the command to execute. + public DatabaseSource(DbProviderFactory providerFactory, string connectionString, string commandText, int commandTimeoutInSeconds) + { + Contracts.CheckValue(providerFactory, nameof(providerFactory)); + Contracts.CheckValue(connectionString, nameof(connectionString)); + Contracts.CheckValue(commandText, nameof(commandText)); + + ProviderFactory = providerFactory; + ConnectionString = connectionString; + CommandText = commandText; + CommandTimeoutInSeconds = commandTimeoutInSeconds; + } + /// Gets the text command to run against the data source. public string CommandText { get; } + /// Gets the command timeout. + public int CommandTimeoutInSeconds { get; } + /// Gets the string used to open the connection. public string ConnectionString { get; } diff --git a/test/Microsoft.ML.Tests/DatabaseLoaderTests.cs b/test/Microsoft.ML.Tests/DatabaseLoaderTests.cs index 9595d35af1..b4c3a02a10 100644 --- a/test/Microsoft.ML.Tests/DatabaseLoaderTests.cs +++ b/test/Microsoft.ML.Tests/DatabaseLoaderTests.cs @@ -79,6 +79,46 @@ public void IrisLightGbm() }).PredictedLabel); } + [LightGBMFact] + public void IrisLightGbmCommandTimeout() + { + if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + // https://github.com/dotnet/machinelearning/issues/4156 + return; + } + + var mlContext = new MLContext(seed: 1); + + var connectionString = GetConnectionString(TestDatasets.irisDb.name); + var commandText = $@"WAITFOR DELAY '00:00:02'; SELECT * FROM ""{TestDatasets.irisDb.trainFilename}"""; + var commandTimeout = 1; + + var loaderColumns = new DatabaseLoader.Column[] + { + new DatabaseLoader.Column() { Name = "Label", Type = DbType.Int32 }, + new DatabaseLoader.Column() { Name = "SepalLength", Type = DbType.Single }, + new DatabaseLoader.Column() { Name = "SepalWidth", Type = DbType.Single }, + new DatabaseLoader.Column() { Name = "PetalLength", Type = DbType.Single }, + new DatabaseLoader.Column() { Name = "PetalWidth", Type = DbType.Single } + }; + + var loader = mlContext.Data.CreateDatabaseLoader(loaderColumns); + + var databaseSource = new DatabaseSource(SqlClientFactory.Instance, connectionString, commandText, commandTimeout); + + var trainingData = loader.Load(databaseSource); + + IEstimator pipeline = mlContext.Transforms.Conversion.MapValueToKey("Label") + .Append(mlContext.Transforms.Concatenate("Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth")) + .AppendCacheCheckpoint(mlContext) + .Append(mlContext.MulticlassClassification.Trainers.LightGbm()) + .Append(mlContext.Transforms.Conversion.MapKeyToValue("PredictedLabel")); + + var ex = Assert.Throws(() => pipeline.Fit(trainingData)); + Assert.Contains("Timeout expired.", ex.InnerException.Message); + } + [LightGBMFact] public void IrisLightGbmWithLoadColumnName() {