From ddf28c73060dd76aaf1a06d634664b4322b301c3 Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Wed, 29 Mar 2023 12:28:55 -0700 Subject: [PATCH 1/3] remove handler --- .../AutoMLExperiment/AutoMLExperiment.cs | 136 +++++++++--------- .../StopTrainingManagerTests.cs | 2 +- 2 files changed, 66 insertions(+), 72 deletions(-) diff --git a/src/Microsoft.ML.AutoML/AutoMLExperiment/AutoMLExperiment.cs b/src/Microsoft.ML.AutoML/AutoMLExperiment/AutoMLExperiment.cs index e2fa5b9716..cc1cd05ef3 100644 --- a/src/Microsoft.ML.AutoML/AutoMLExperiment/AutoMLExperiment.cs +++ b/src/Microsoft.ML.AutoML/AutoMLExperiment/AutoMLExperiment.cs @@ -248,94 +248,88 @@ public async Task RunAsync(CancellationToken ct = default) var parameter = tuner.Propose(trialSettings); trialSettings.Parameter = parameter; - using (var trialCancellationTokenSource = new CancellationTokenSource()) + var trialCancellationTokenSource = new CancellationTokenSource(); + monitor?.ReportRunningTrial(trialSettings); + var stopTrialManager = new CancellationTokenStopTrainingManager(trialCancellationTokenSource.Token, null); + aggregateTrainingStopManager.AddTrainingStopManager(stopTrialManager); + try { - monitor?.ReportRunningTrial(trialSettings); - - void handler(object o, EventArgs e) - { - trialCancellationTokenSource.Cancel(); - } - try + using (var performanceMonitor = serviceProvider.GetService()) + using (var runner = serviceProvider.GetRequiredService()) { - using (var performanceMonitor = serviceProvider.GetService()) - using (var runner = serviceProvider.GetRequiredService()) + performanceMonitor.PerformanceMetricsUpdated += (o, metrics) => { - aggregateTrainingStopManager.OnStopTraining += handler; - performanceMonitor.PerformanceMetricsUpdated += (o, metrics) => - { - performanceMonitor.OnPerformanceMetricsUpdatedHandler(trialSettings, metrics, trialCancellationTokenSource); - }; - - performanceMonitor.Start(); - logger.Trace($"trial setting - {JsonSerializer.Serialize(trialSettings)}"); - var trialResult = await runner.RunAsync(trialSettings, trialCancellationTokenSource.Token); - - var peakCpu = performanceMonitor?.GetPeakCpuUsage(); - var peakMemoryInMB = performanceMonitor?.GetPeakMemoryUsageInMegaByte(); - trialResult.PeakCpu = peakCpu; - trialResult.PeakMemoryInMegaByte = peakMemoryInMB; - trialResult.TrialSettings.EndedAtUtc = DateTime.UtcNow; - - performanceMonitor.Pause(); - monitor?.ReportCompletedTrial(trialResult); - tuner.Update(trialResult); - trialResultManager?.AddOrUpdateTrialResult(trialResult); - aggregateTrainingStopManager.Update(trialResult); - - var loss = trialResult.Loss; - if (loss < _bestLoss) - { - _bestTrialResult = trialResult; - _bestLoss = loss; - monitor?.ReportBestTrial(trialResult); - } - } - } - catch (Exception ex) when (aggregateTrainingStopManager.IsStopTrainingRequested() == false) - { - var exceptionMessage = $@" -Exception thrown during Trial {trialSettings.TrialId} with configuration {JsonSerializer.Serialize(trialSettings)} + performanceMonitor.OnPerformanceMetricsUpdatedHandler(trialSettings, metrics, trialCancellationTokenSource); + }; -Exception Details: ex.Message + performanceMonitor.Start(); + logger.Trace($"trial setting - {JsonSerializer.Serialize(trialSettings)}"); + var trialResult = await runner.RunAsync(trialSettings, trialCancellationTokenSource.Token); -Abandoning Trial {trialSettings.TrialId} and continue training. -"; - logger.Trace(exceptionMessage); - trialSettings.EndedAtUtc = DateTime.UtcNow; - monitor?.ReportFailTrial(trialSettings, ex); - var trialResult = new TrialResult - { - TrialSettings = trialSettings, - Loss = double.MaxValue, - }; + var peakCpu = performanceMonitor?.GetPeakCpuUsage(); + var peakMemoryInMB = performanceMonitor?.GetPeakMemoryUsageInMegaByte(); + trialResult.PeakCpu = peakCpu; + trialResult.PeakMemoryInMegaByte = peakMemoryInMB; + trialResult.TrialSettings.EndedAtUtc = DateTime.UtcNow; + performanceMonitor.Pause(); + monitor?.ReportCompletedTrial(trialResult); tuner.Update(trialResult); trialResultManager?.AddOrUpdateTrialResult(trialResult); aggregateTrainingStopManager.Update(trialResult); - if (ex is not OperationCanceledException && _bestTrialResult == null) + var loss = trialResult.Loss; + if (loss < _bestLoss) { - logger.Trace($"trial fatal error - {JsonSerializer.Serialize(trialSettings)}, stop training"); - - // TODO - // it's questionable on whether to abort the entire training process - // for a single fail trial. We should make it an option and only exit - // when error is fatal (like schema mismatch). - throw; + _bestTrialResult = trialResult; + _bestLoss = loss; + monitor?.ReportBestTrial(trialResult); } - continue; } - catch (Exception) when (aggregateTrainingStopManager.IsStopTrainingRequested()) + } + catch (Exception ex) when (aggregateTrainingStopManager.IsStopTrainingRequested() == false) + { + var exceptionMessage = $@" +Exception thrown during Trial {trialSettings.TrialId} with configuration {JsonSerializer.Serialize(trialSettings)} + +Exception Details: {ex.Message} + +Abandoning Trial {trialSettings.TrialId} and continue training. +"; + logger.Trace(exceptionMessage); + trialSettings.EndedAtUtc = DateTime.UtcNow; + monitor?.ReportFailTrial(trialSettings, ex); + var trialResult = new TrialResult { - logger.Trace($"trial cancelled - {JsonSerializer.Serialize(trialSettings)}, stop training"); + TrialSettings = trialSettings, + Loss = double.MaxValue, + }; - break; - } - finally + tuner.Update(trialResult); + trialResultManager?.AddOrUpdateTrialResult(trialResult); + aggregateTrainingStopManager.Update(trialResult); + + if (ex is not OperationCanceledException && _bestTrialResult == null) { - aggregateTrainingStopManager.OnStopTraining -= handler; + logger.Trace($"trial fatal error - {JsonSerializer.Serialize(trialSettings)}, stop training"); + + // TODO + // it's questionable on whether to abort the entire training process + // for a single fail trial. We should make it an option and only exit + // when error is fatal (like schema mismatch). + throw; } + continue; + } + catch (Exception) when (aggregateTrainingStopManager.IsStopTrainingRequested()) + { + logger.Trace($"trial cancelled - {JsonSerializer.Serialize(trialSettings)}, stop training"); + + break; + } + finally + { + aggregateTrainingStopManager.RemoveTrainingStopManagerIfExist(stopTrialManager); } } diff --git a/test/Microsoft.ML.AutoML.Tests/StopTrainingManagerTests.cs b/test/Microsoft.ML.AutoML.Tests/StopTrainingManagerTests.cs index 4e87bf238d..fd0748fd1b 100644 --- a/test/Microsoft.ML.AutoML.Tests/StopTrainingManagerTests.cs +++ b/test/Microsoft.ML.AutoML.Tests/StopTrainingManagerTests.cs @@ -70,7 +70,7 @@ public async Task TimeoutTrainingStopManager_isStopTrainingRequested_test() public async Task AggregateTrainingStopManager_isStopTrainingRequested_test() { var cts = new CancellationTokenSource(); - var timeoutManager = new TimeoutTrainingStopManager(TimeSpan.FromSeconds(1), null); + var timeoutManager = new TimeoutTrainingStopManager(TimeSpan.FromSeconds(0), null); var cancellationManager = new CancellationTokenStopTrainingManager(cts.Token, null); var aggregationManager = new AggregateTrainingStopManager(null, timeoutManager, cancellationManager); From a322df1a63fcbe219509b10ff1aa0f871e508c64 Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Wed, 29 Mar 2023 12:48:15 -0700 Subject: [PATCH 2/3] update --- .../AutoMLExperiment/AutoMLExperiment.cs | 7 ++++++- test/Microsoft.ML.AutoML.Tests/AutoMLExperimentTests.cs | 4 +++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/src/Microsoft.ML.AutoML/AutoMLExperiment/AutoMLExperiment.cs b/src/Microsoft.ML.AutoML/AutoMLExperiment/AutoMLExperiment.cs index cc1cd05ef3..349269c01f 100644 --- a/src/Microsoft.ML.AutoML/AutoMLExperiment/AutoMLExperiment.cs +++ b/src/Microsoft.ML.AutoML/AutoMLExperiment/AutoMLExperiment.cs @@ -252,11 +252,16 @@ public async Task RunAsync(CancellationToken ct = default) monitor?.ReportRunningTrial(trialSettings); var stopTrialManager = new CancellationTokenStopTrainingManager(trialCancellationTokenSource.Token, null); aggregateTrainingStopManager.AddTrainingStopManager(stopTrialManager); + void handler(object o, EventArgs e) + { + trialCancellationTokenSource.Cancel(); + } try { using (var performanceMonitor = serviceProvider.GetService()) using (var runner = serviceProvider.GetRequiredService()) { + aggregateTrainingStopManager.OnStopTraining += handler; performanceMonitor.PerformanceMetricsUpdated += (o, metrics) => { performanceMonitor.OnPerformanceMetricsUpdatedHandler(trialSettings, metrics, trialCancellationTokenSource); @@ -329,7 +334,7 @@ Abandoning Trial {trialSettings.TrialId} and continue training. } finally { - aggregateTrainingStopManager.RemoveTrainingStopManagerIfExist(stopTrialManager); + aggregateTrainingStopManager.OnStopTraining -= handler; } } diff --git a/test/Microsoft.ML.AutoML.Tests/AutoMLExperimentTests.cs b/test/Microsoft.ML.AutoML.Tests/AutoMLExperimentTests.cs index 415528ba9c..34c4f7c3a7 100644 --- a/test/Microsoft.ML.AutoML.Tests/AutoMLExperimentTests.cs +++ b/test/Microsoft.ML.AutoML.Tests/AutoMLExperimentTests.cs @@ -39,7 +39,7 @@ public async Task AutoMLExperiment_throw_timeout_exception_when_ct_is_canceled_a { var channel = serviceProvider.GetService(); var settings = serviceProvider.GetService(); - return new DummyTrialRunner(settings, 5, channel); + return new DummyTrialRunner(settings, 50, channel); }) .SetTuner(); @@ -47,8 +47,10 @@ public async Task AutoMLExperiment_throw_timeout_exception_when_ct_is_canceled_a context.Log += (o, e) => { + this.Output.WriteLine(e.RawMessage); if (e.RawMessage.Contains("Update Running Trial")) { + this.Output.WriteLine(e.RawMessage); cts.Cancel(); } }; From c4a86555fb902c0c0443f59aaeb13cb7982fee5d Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Wed, 29 Mar 2023 12:49:02 -0700 Subject: [PATCH 3/3] checkout test file --- test/Microsoft.ML.AutoML.Tests/AutoMLExperimentTests.cs | 4 +--- test/Microsoft.ML.AutoML.Tests/StopTrainingManagerTests.cs | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/test/Microsoft.ML.AutoML.Tests/AutoMLExperimentTests.cs b/test/Microsoft.ML.AutoML.Tests/AutoMLExperimentTests.cs index 34c4f7c3a7..415528ba9c 100644 --- a/test/Microsoft.ML.AutoML.Tests/AutoMLExperimentTests.cs +++ b/test/Microsoft.ML.AutoML.Tests/AutoMLExperimentTests.cs @@ -39,7 +39,7 @@ public async Task AutoMLExperiment_throw_timeout_exception_when_ct_is_canceled_a { var channel = serviceProvider.GetService(); var settings = serviceProvider.GetService(); - return new DummyTrialRunner(settings, 50, channel); + return new DummyTrialRunner(settings, 5, channel); }) .SetTuner(); @@ -47,10 +47,8 @@ public async Task AutoMLExperiment_throw_timeout_exception_when_ct_is_canceled_a context.Log += (o, e) => { - this.Output.WriteLine(e.RawMessage); if (e.RawMessage.Contains("Update Running Trial")) { - this.Output.WriteLine(e.RawMessage); cts.Cancel(); } }; diff --git a/test/Microsoft.ML.AutoML.Tests/StopTrainingManagerTests.cs b/test/Microsoft.ML.AutoML.Tests/StopTrainingManagerTests.cs index fd0748fd1b..4e87bf238d 100644 --- a/test/Microsoft.ML.AutoML.Tests/StopTrainingManagerTests.cs +++ b/test/Microsoft.ML.AutoML.Tests/StopTrainingManagerTests.cs @@ -70,7 +70,7 @@ public async Task TimeoutTrainingStopManager_isStopTrainingRequested_test() public async Task AggregateTrainingStopManager_isStopTrainingRequested_test() { var cts = new CancellationTokenSource(); - var timeoutManager = new TimeoutTrainingStopManager(TimeSpan.FromSeconds(0), null); + var timeoutManager = new TimeoutTrainingStopManager(TimeSpan.FromSeconds(1), null); var cancellationManager = new CancellationTokenStopTrainingManager(cts.Token, null); var aggregationManager = new AggregateTrainingStopManager(null, timeoutManager, cancellationManager);