Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 16 additions & 14 deletions src/Microsoft.ML.Data/Transforms/RowShufflingTransformer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -634,23 +634,25 @@ protected override bool MoveNextCore()
while (_liveCount < _poolRows && !_doneConsuming)
{
// We are under capacity. Try to get some more.
while (_toConsumeChannel.Reader.WaitToReadAsync().GetAwaiter().GetResult())
var hasReadItem = _toConsumeChannel.Reader.TryRead(out int got);
if (hasReadItem)
{
var hasReadItem = _toConsumeChannel.Reader.TryRead(out int got);
if (hasReadItem)
if (got == 0)
{
if (got == 0)
{
// We've reached the end of the Channel. There's no reason
// to attempt further communication with the producer.
// Check whether something horrible happened.
if (_producerTaskException != null)
throw Ch.Except(_producerTaskException, "Shuffle input cursor reader failed with an exception");
_doneConsuming = true;
break;
}
_liveCount += got;
// We've reached the end of the Channel. There's no reason
// to attempt further communication with the producer.
// Check whether something horrible happened.
if (_producerTaskException != null)
throw Ch.Except(_producerTaskException, "Shuffle input cursor reader failed with an exception");
_doneConsuming = true;
break;
}
_liveCount += got;
}
else
{
// Sleeping for one millisecond to stop the thread from spinning while waiting for the producer.
Thread.Sleep(1);
}
}
if (_liveCount == 0)
Expand Down
43 changes: 43 additions & 0 deletions src/Microsoft.ML.SamplesUtils/SamplesDatasetUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,49 @@ public static string DownloadTensorFlowSentimentModel()
return path;
}

public static string DownloadTaxiFareData()
{
string githubPath = "https://raw.githubusercontent.com/dotnet/machinelearning-samples/master/samples/csharp/getting-started/Regression_TaxiFarePrediction/TaxiFarePrediction/Data/taxi-fare-train.csv";

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We already have this file in the repo. See

var trainDataPath = GetDataPath("taxi-fare-train.csv");

public class TaxiTrip
{
[LoadColumn(0)]
public string VendorId;
[LoadColumn(1)]
public float RateCode;
[LoadColumn(2)]
public float PassengerCount;
[LoadColumn(3)]
public float TripTime;
[LoadColumn(4)]
public float TripDistance;
[LoadColumn(5)]
public string PaymentType;
[LoadColumn(6)]
public float FareAmount;
}
/// <summary>
/// Returns a DataView with a Features column which include HotEncodedData
/// </summary>
private IDataView GetOneHotEncodedData(int numberOfInstances = 100)
{
var trainDataPath = GetDataPath("taxi-fare-train.csv");

We shouldn't need to download the file again.

string dataFile = "taxi-fare-train.csv";

Download(githubPath, dataFile);

return dataFile;
}

/// <summary>
/// Loads the taxi fare data file into an <see cref="IDataView"/>.
/// </summary>
/// <param name="context">The <see cref="MLContext"/> whose data catalog performs the load.</param>
/// <returns>
/// A data view over the file returned by <see cref="DownloadTaxiFareData"/>, read as
/// comma-separated text with a header row and mapped onto <see cref="TaxiData"/>.
/// </returns>
public static IDataView LoadTaxiFareDataset(MLContext context)
{
    // Resolve the file first so the data is in place before the loader is created.
    string taxiFarePath = DownloadTaxiFareData();

    return context.Data.LoadFromTextFile<TaxiData>(taxiFarePath, hasHeader: true, separatorChar: ',');
}

/// <summary>
/// Row schema used when loading the taxi fare file as text. Each field is bound to a
/// zero-based column index through <c>LoadColumn</c>.
/// </summary>
public class TaxiData
{
    // Columns 0, 1 and 5 are read as strings; the remaining columns are read as floats.
    [LoadColumn(0)] public string VendorId;
    [LoadColumn(1)] public string RateCode;
    [LoadColumn(2)] public float PassengerCount;
    [LoadColumn(3)] public float TripTime;
    [LoadColumn(4)] public float TripDistance;
    [LoadColumn(5)] public string PaymentType;
    [LoadColumn(6)] public float FareAmount;
}

private static string Download(string baseGitPath, string dataFile)
{
if (File.Exists(dataFile))
Expand Down
40 changes: 40 additions & 0 deletions test/Microsoft.ML.Tests/Scenarios/RegressionTest.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
using Xunit;

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All code files should have a copyright at the top.


namespace Microsoft.ML.Scenarios
{
    public partial class ScenariosTests
    {
        /// <summary>
        /// End-to-end regression scenario on the taxi fare dataset: loads the data, splits it
        /// 80/20, one-hot encodes the string columns, normalizes the numeric columns, trains an
        /// SDCA regressor on the filtered training set and evaluates it on the held-out set.
        /// </summary>
        [Fact]
        public void TestRegressionScenario()
        {
            // A fixed seed makes the random train/test split and the trainer reproducible,
            // so a failing threshold below can be re-run and investigated deterministically.
            var context = new MLContext(seed: 1);

            IDataView taxiData = Microsoft.ML.SamplesUtils.DatasetUtils.LoadTaxiFareDataset(context);
            var splitData = context.Data.TrainTestSplit(taxiData, testFraction: 0.2);

            // Only the training rows are restricted to fares in [1, 150); the test set is evaluated unfiltered.
            IDataView trainingDataView = context.Data.FilterRowsByColumn(splitData.TrainSet, "FareAmount", lowerBound: 1, upperBound: 150);

            var dataProcessPipeline = context.Transforms.CopyColumns(outputColumnName: "Label", inputColumnName: "FareAmount")
                .Append(context.Transforms.Categorical.OneHotEncoding(outputColumnName: "VendorIdEncoded", inputColumnName: "VendorId"))
                .Append(context.Transforms.Categorical.OneHotEncoding(outputColumnName: "RateCodeEncoded", inputColumnName: "RateCode"))
                .Append(context.Transforms.Categorical.OneHotEncoding(outputColumnName: "PaymentTypeEncoded", inputColumnName: "PaymentType"))
                .Append(context.Transforms.NormalizeMeanVariance(outputColumnName: "PassengerCount"))
                .Append(context.Transforms.NormalizeMeanVariance(outputColumnName: "TripTime"))
                .Append(context.Transforms.NormalizeMeanVariance(outputColumnName: "TripDistance"))
                .Append(context.Transforms.Concatenate("Features", "VendorIdEncoded", "RateCodeEncoded", "PaymentTypeEncoded", "PassengerCount",
                    "TripTime", "TripDistance"));

            var trainer = context.Regression.Trainers.Sdca(labelColumnName: "Label", featureColumnName: "Features");
            var trainingPipeline = dataProcessPipeline.Append(trainer);

            var model = trainingPipeline.Fit(trainingDataView);

            var predictions = model.Transform(splitData.TestSet);

            var metrics = context.Regression.Evaluate(predictions);

            Assert.True(metrics.RSquared > .9);
            // NOTE(review): this asserts a lower bound on the error; confirm an upper bound was not intended.
            Assert.True(metrics.RootMeanSquaredError > 2);
        }
    }
}