diff --git a/src/Microsoft.ML.AutoML/AutoMLExperiment/AutoMLExperiment.cs b/src/Microsoft.ML.AutoML/AutoMLExperiment/AutoMLExperiment.cs index e2fa5b9716..349269c01f 100644 --- a/src/Microsoft.ML.AutoML/AutoMLExperiment/AutoMLExperiment.cs +++ b/src/Microsoft.ML.AutoML/AutoMLExperiment/AutoMLExperiment.cs @@ -248,94 +248,93 @@ public async Task RunAsync(CancellationToken ct = default) var parameter = tuner.Propose(trialSettings); trialSettings.Parameter = parameter; - using (var trialCancellationTokenSource = new CancellationTokenSource()) + var trialCancellationTokenSource = new CancellationTokenSource(); + monitor?.ReportRunningTrial(trialSettings); + var stopTrialManager = new CancellationTokenStopTrainingManager(trialCancellationTokenSource.Token, null); + aggregateTrainingStopManager.AddTrainingStopManager(stopTrialManager); + void handler(object o, EventArgs e) { - monitor?.ReportRunningTrial(trialSettings); - - void handler(object o, EventArgs e) - { - trialCancellationTokenSource.Cancel(); - } - try + trialCancellationTokenSource.Cancel(); + } + try + { + using (var performanceMonitor = serviceProvider.GetService()) + using (var runner = serviceProvider.GetRequiredService()) { - using (var performanceMonitor = serviceProvider.GetService()) - using (var runner = serviceProvider.GetRequiredService()) + aggregateTrainingStopManager.OnStopTraining += handler; + performanceMonitor.PerformanceMetricsUpdated += (o, metrics) => { - aggregateTrainingStopManager.OnStopTraining += handler; - performanceMonitor.PerformanceMetricsUpdated += (o, metrics) => - { - performanceMonitor.OnPerformanceMetricsUpdatedHandler(trialSettings, metrics, trialCancellationTokenSource); - }; - - performanceMonitor.Start(); - logger.Trace($"trial setting - {JsonSerializer.Serialize(trialSettings)}"); - var trialResult = await runner.RunAsync(trialSettings, trialCancellationTokenSource.Token); - - var peakCpu = performanceMonitor?.GetPeakCpuUsage(); - var peakMemoryInMB = performanceMonitor?.GetPeakMemoryUsageInMegaByte(); - trialResult.PeakCpu = peakCpu; - trialResult.PeakMemoryInMegaByte = peakMemoryInMB; - trialResult.TrialSettings.EndedAtUtc = DateTime.UtcNow; - - performanceMonitor.Pause(); - monitor?.ReportCompletedTrial(trialResult); - tuner.Update(trialResult); - trialResultManager?.AddOrUpdateTrialResult(trialResult); - aggregateTrainingStopManager.Update(trialResult); - - var loss = trialResult.Loss; - if (loss < _bestLoss) - { - _bestTrialResult = trialResult; - _bestLoss = loss; - monitor?.ReportBestTrial(trialResult); - } - } - } - catch (Exception ex) when (aggregateTrainingStopManager.IsStopTrainingRequested() == false) - { - var exceptionMessage = $@" -Exception thrown during Trial {trialSettings.TrialId} with configuration {JsonSerializer.Serialize(trialSettings)} + performanceMonitor.OnPerformanceMetricsUpdatedHandler(trialSettings, metrics, trialCancellationTokenSource); + }; -Exception Details: ex.Message + performanceMonitor.Start(); + logger.Trace($"trial setting - {JsonSerializer.Serialize(trialSettings)}"); + var trialResult = await runner.RunAsync(trialSettings, trialCancellationTokenSource.Token); -Abandoning Trial {trialSettings.TrialId} and continue training. -"; - logger.Trace(exceptionMessage); - trialSettings.EndedAtUtc = DateTime.UtcNow; - monitor?.ReportFailTrial(trialSettings, ex); - var trialResult = new TrialResult - { - TrialSettings = trialSettings, - Loss = double.MaxValue, - }; + var peakCpu = performanceMonitor?.GetPeakCpuUsage(); + var peakMemoryInMB = performanceMonitor?.GetPeakMemoryUsageInMegaByte(); + trialResult.PeakCpu = peakCpu; + trialResult.PeakMemoryInMegaByte = peakMemoryInMB; + trialResult.TrialSettings.EndedAtUtc = DateTime.UtcNow; + performanceMonitor.Pause(); + monitor?.ReportCompletedTrial(trialResult); tuner.Update(trialResult); trialResultManager?.AddOrUpdateTrialResult(trialResult); aggregateTrainingStopManager.Update(trialResult); - if (ex is not OperationCanceledException && _bestTrialResult == null) + var loss = trialResult.Loss; + if (loss < _bestLoss) { - logger.Trace($"trial fatal error - {JsonSerializer.Serialize(trialSettings)}, stop training"); - - // TODO - // it's questionable on whether to abort the entire training process - // for a single fail trial. We should make it an option and only exit - // when error is fatal (like schema mismatch). - throw; + _bestTrialResult = trialResult; + _bestLoss = loss; + monitor?.ReportBestTrial(trialResult); } - continue; } - catch (Exception) when (aggregateTrainingStopManager.IsStopTrainingRequested()) + } + catch (Exception ex) when (aggregateTrainingStopManager.IsStopTrainingRequested() == false) + { + var exceptionMessage = $@" +Exception thrown during Trial {trialSettings.TrialId} with configuration {JsonSerializer.Serialize(trialSettings)} + +Exception Details: {ex.Message} + +Abandoning Trial {trialSettings.TrialId} and continue training. +"; + logger.Trace(exceptionMessage); + trialSettings.EndedAtUtc = DateTime.UtcNow; + monitor?.ReportFailTrial(trialSettings, ex); + var trialResult = new TrialResult { - logger.Trace($"trial cancelled - {JsonSerializer.Serialize(trialSettings)}, stop training"); + TrialSettings = trialSettings, + Loss = double.MaxValue, + }; - break; - } - finally + tuner.Update(trialResult); + trialResultManager?.AddOrUpdateTrialResult(trialResult); + aggregateTrainingStopManager.Update(trialResult); + + if (ex is not OperationCanceledException && _bestTrialResult == null) { - aggregateTrainingStopManager.OnStopTraining -= handler; + logger.Trace($"trial fatal error - {JsonSerializer.Serialize(trialSettings)}, stop training"); + + // TODO + // it's questionable on whether to abort the entire training process + // for a single fail trial. We should make it an option and only exit + // when error is fatal (like schema mismatch). + throw; } + continue; + } + catch (Exception) when (aggregateTrainingStopManager.IsStopTrainingRequested()) + { + logger.Trace($"trial cancelled - {JsonSerializer.Serialize(trialSettings)}, stop training"); + + break; + } + finally + { + aggregateTrainingStopManager.OnStopTraining -= handler; } }