diff --git a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs index c83e9e4d9f..2632d11a0e 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs @@ -1453,6 +1453,23 @@ internal static TextLoader CreateTextLoader(IHostEnvironment host, bool trimWhitespace = Defaults.TrimWhitespace, IMultiStreamSource dataSample = null) { + Options options = new Options + { + HasHeader = hasHeader, + Separators = new[] { separator }, + AllowQuoting = allowQuoting, + AllowSparse = supportSparse, + TrimWhitespace = trimWhitespace + }; + + return CreateTextLoader(host, options, dataSample); + } + + internal static TextLoader CreateTextLoader(IHostEnvironment host, + Options options = null, + IMultiStreamSource dataSample = null) + { + options = options ?? new Options(); var userType = typeof(TInput); var fieldInfos = userType.GetFields(BindingFlags.Public | BindingFlags.Instance); @@ -1506,15 +1523,7 @@ internal static TextLoader CreateTextLoader(IHostEnvironment host, columns.Add(column); } - Options options = new Options - { - HasHeader = hasHeader, - Separators = new[] { separator }, - AllowQuoting = allowQuoting, - AllowSparse = supportSparse, - TrimWhitespace = trimWhitespace, - Columns = columns.ToArray() - }; + options.Columns = columns.ToArray(); return new TextLoader(host, options, dataSample: dataSample); } diff --git a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderSaverCatalog.cs b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderSaverCatalog.cs index bc1bd789fe..a3e6446fce 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderSaverCatalog.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderSaverCatalog.cs @@ -97,6 +97,19 @@ public static TextLoader CreateTextLoader(this DataOperationsCatalog cat => TextLoader.CreateTextLoader(CatalogUtils.GetEnvironment(catalog), hasHeader, separatorChar, allowQuoting, allowSparse, trimWhitespace, dataSample: dataSample); + /// + /// Create a text loader by inferencing the dataset schema from a data model type. + /// + /// The catalog. + /// Defines the settings of the load operation. Defines the settings of the load operation. No need to specify a Columns field, + /// as columns will be infered by this method. + /// The optional location of a data sample. The sample can be used to infer information + /// about the columns, such as slot names. + public static TextLoader CreateTextLoader(this DataOperationsCatalog catalog, + TextLoader.Options options, + IMultiStreamSource dataSample = null) + => TextLoader.CreateTextLoader(CatalogUtils.GetEnvironment(catalog), options, dataSample); + /// /// Load a from a text file using . /// Note that 's are lazy, so no actual loading happens here, just schema validation. @@ -143,6 +156,35 @@ public static IDataView LoadFromTextFile(this DataOperationsCatalog catalog, return loader.Load(new MultiFileSource(path)); } + /// + /// Load a from a text file using . + /// Note that 's are lazy, so no actual loading happens here, just schema validation. + /// + /// The catalog. + /// Specifies a file from which to load. + /// Defines the settings of the load operation. + /// + /// + /// + /// + /// + public static IDataView LoadFromTextFile(this DataOperationsCatalog catalog, string path, + TextLoader.Options options = null) + { + Contracts.CheckNonEmpty(path, nameof(path)); + if (!File.Exists(path)) + { + throw Contracts.ExceptParam(nameof(path), "File does not exist at path: {0}", path); + } + + var env = catalog.GetEnvironment(); + var source = new MultiFileSource(path); + + return new TextLoader(env, options, dataSample: source).Load(source); + } + /// /// Load a from a text file using . /// Note that 's are lazy, so no actual loading happens here, just schema validation. @@ -191,16 +233,11 @@ public static IDataView LoadFromTextFile(this DataOperationsCatalog cata /// /// The catalog. /// Specifies a file from which to load. - /// Defines the settings of the load operation. - /// - /// - /// - /// - /// - public static IDataView LoadFromTextFile(this DataOperationsCatalog catalog, string path, - TextLoader.Options options = null) + /// Defines the settings of the load operation. No need to specify a Columns field, + /// as columns will be infered by this method. + /// The data view. + public static IDataView LoadFromTextFile(this DataOperationsCatalog catalog, string path, + TextLoader.Options options) { Contracts.CheckNonEmpty(path, nameof(path)); if (!File.Exists(path)) @@ -208,10 +245,8 @@ public static IDataView LoadFromTextFile(this DataOperationsCatalog catalog, str throw Contracts.ExceptParam(nameof(path), "File does not exist at path: {0}", path); } - var env = catalog.GetEnvironment(); - var source = new MultiFileSource(path); - - return new TextLoader(env, options, dataSample: source).Load(source); + return TextLoader.CreateTextLoader(CatalogUtils.GetEnvironment(catalog), options) + .Load(new MultiFileSource(path)); } /// diff --git a/test/Microsoft.ML.Functional.Tests/Prediction.cs b/test/Microsoft.ML.Functional.Tests/Prediction.cs index 8292ff709a..657cb20fcb 100644 --- a/test/Microsoft.ML.Functional.Tests/Prediction.cs +++ b/test/Microsoft.ML.Functional.Tests/Prediction.cs @@ -36,9 +36,14 @@ public void ReconfigurablePrediction() { var mlContext = new MLContext(seed: 1); + var options = new TextLoader.Options + { + HasHeader = TestDatasets.Sentiment.fileHasHeader, + Separators = new[] { TestDatasets.Sentiment.fileSeparator } + }; + var data = mlContext.Data.LoadFromTextFile(TestCommon.GetDataPath(DataDir, TestDatasets.Sentiment.trainFilename), - hasHeader: TestDatasets.Sentiment.fileHasHeader, - separatorChar: TestDatasets.Sentiment.fileSeparator); + options); // Create a training pipeline. var pipeline = mlContext.Transforms.Text.FeaturizeText("Features", "SentimentText") diff --git a/test/Microsoft.ML.Tests/TextLoaderTests.cs b/test/Microsoft.ML.Tests/TextLoaderTests.cs index d7bd2fad2f..01492401bf 100644 --- a/test/Microsoft.ML.Tests/TextLoaderTests.cs +++ b/test/Microsoft.ML.Tests/TextLoaderTests.cs @@ -704,8 +704,10 @@ public class IrisColumnIndices public string Type; } - [Fact] - public void LoaderColumnsFromIrisData() + [Theory] + [InlineData(true)] + [InlineData(false)] + public void LoaderColumnsFromIrisData(bool useOptionsObject) { var dataPath = GetDataPath(TestDatasets.irisData.trainFilename); var mlContext = new MLContext(1); @@ -719,7 +721,12 @@ public void LoaderColumnsFromIrisData() var irisFirstRowValues = irisFirstRow.Values.GetEnumerator(); // Simple load - var dataIris = mlContext.Data.CreateTextLoader(separatorChar: ',').Load(dataPath); + IDataView dataIris; + if (useOptionsObject) + dataIris = mlContext.Data.CreateTextLoader(new TextLoader.Options() { Separator = ",", AllowQuoting = false }).Load(dataPath); + else + dataIris = mlContext.Data.CreateTextLoader(separatorChar: ',').Load(dataPath); + var previewIris = dataIris.Preview(1); Assert.Equal(5, previewIris.ColumnView.Length); @@ -735,7 +742,12 @@ public void LoaderColumnsFromIrisData() Assert.Equal("Iris-setosa", previewIris.RowView[0].Values[index].Value.ToString()); // Load with start and end indexes - var dataIrisStartEnd = mlContext.Data.CreateTextLoader(separatorChar: ',').Load(dataPath); + IDataView dataIrisStartEnd; + if (useOptionsObject) + dataIrisStartEnd = mlContext.Data.CreateTextLoader(new TextLoader.Options() { Separator = ",", AllowQuoting = false }).Load(dataPath); + else + dataIrisStartEnd = mlContext.Data.CreateTextLoader(separatorChar: ',').Load(dataPath); + var previewIrisStartEnd = dataIrisStartEnd.Preview(1); Assert.Equal(2, previewIrisStartEnd.ColumnView.Length); @@ -752,7 +764,12 @@ public void LoaderColumnsFromIrisData() } // load setting the distinct columns. Loading column 0 and 2 - var dataIrisColumnIndices = mlContext.Data.CreateTextLoader(separatorChar: ',').Load(dataPath); + IDataView dataIrisColumnIndices; + if (useOptionsObject) + dataIrisColumnIndices = mlContext.Data.CreateTextLoader(new TextLoader.Options() { Separator = ",", AllowQuoting = false }).Load(dataPath); + else + dataIrisColumnIndices = mlContext.Data.CreateTextLoader(separatorChar: ',').Load(dataPath); + var previewIrisColumnIndices = dataIrisColumnIndices.Preview(1); Assert.Equal(2, previewIrisColumnIndices.ColumnView.Length);