1- // Licensed to the .NET Foundation under one or more agreements.
1+ // Licensed to the .NET Foundation under one or more agreements.
22// The .NET Foundation licenses this file to you under the MIT license.
33// See the LICENSE file in the project root for more information.
44
55using System ;
66using System . Collections . Concurrent ;
77using System . Collections . Generic ;
88using System . IO ;
9+ using Microsoft . ML ;
10+ using Microsoft . ML . Internal . Utilities ;
911
1012namespace Microsoft . ML . Runtime ;
1113
@@ -121,8 +123,8 @@ public abstract class HostBase : HostEnvironmentBase<TEnv>, IHost
121123
122124 public Random Rand => _rand ;
123125
124- public HostBase ( HostEnvironmentBase < TEnv > source , string shortName , string parentFullName , Random rand , bool verbose )
125- : base ( source , rand , verbose , shortName , parentFullName )
126+ public HostBase ( HostEnvironmentBase < TEnv > source , string shortName , string parentFullName , Random rand , IRandomSource randomSource , bool verbose )
127+ : base ( source , rand , randomSource ?? new RandomSourceAdapter ( rand ) , verbose , shortName , parentFullName )
126128 {
127129 Depth = source . Depth + 1 ;
128130 }
@@ -140,7 +142,8 @@ public HostBase(HostEnvironmentBase<TEnv> source, string shortName, string paren
140142 {
141143 _children . RemoveAll ( r => r . TryGetTarget ( out IHost _ ) == false ) ;
142144 Random rand = ( seed . HasValue ) ? RandomUtils . Create ( seed . Value ) : RandomUtils . Create ( _rand ) ;
143- host = RegisterCore ( this , name , Master ? . FullName , rand , verbose ?? Verbose ) ;
145+ IRandomSource randomSource = new RandomSourceAdapter ( rand ) ;
146+ host = RegisterCore ( this , name , Master ? . FullName , rand , randomSource , verbose ?? Verbose ) ;
144147 if ( ! IsCanceled )
145148 _children . Add ( new WeakReference < IHost > ( host ) ) ;
146149 }
@@ -338,6 +341,8 @@ public void RemoveListener(Action<IMessageSource, TMessage> listenerFunc)
338341 protected Dictionary < string , object > Options { get ; } = [ ] ;
339342#pragma warning restore MSML_NoInstanceInitializers
340343
344+ public IRandomSource RandomSource => _randomSource ;
345+
341346 protected readonly TEnv Root ;
342347 // This is non-null iff this environment was a fork of another. Disposing a fork
343348 // doesn't free temp files. That is handled when the master is disposed.
@@ -348,6 +353,7 @@ public void RemoveListener(Action<IMessageSource, TMessage> listenerFunc)
348353
349354 // The random number generator for this host.
350355 private readonly Random _rand ;
356+ private readonly IRandomSource _randomSource ;
351357
352358 public int ? Seed { get ; }
353359
@@ -369,11 +375,22 @@ public void RemoveListener(Action<IMessageSource, TMessage> listenerFunc)
369375 /// The main constructor.
370376 /// </summary>
371377 protected HostEnvironmentBase ( int ? seed , bool verbose ,
378+ IRandomSource randomSource = null ,
372379 string shortName = null , string parentFullName = null )
373380 : base ( shortName , parentFullName , verbose )
374381 {
375382 Seed = seed ;
376- _rand = RandomUtils . Create ( Seed ) ;
383+ if ( randomSource is null )
384+ {
385+ var baseRandom = RandomUtils . Create ( Seed ) ;
386+ _rand = baseRandom ;
387+ _randomSource = new RandomSourceAdapter ( baseRandom ) ;
388+ }
389+ else
390+ {
391+ _randomSource = randomSource ;
392+ _rand = randomSource as Random ?? new RandomFromRandomSource ( randomSource ) ;
393+ }
377394 ListenerDict = new ConcurrentDictionary < Type , Dispatcher > ( ) ;
378395 ProgressTracker = new ProgressReporting . ProgressTracker ( this ) ;
379396 _cancelLock = new object ( ) ;
@@ -385,13 +402,14 @@ protected HostEnvironmentBase(int? seed, bool verbose,
385402 /// <summary>
386403 /// This constructor is for forking.
387404 /// </summary>
388- protected HostEnvironmentBase ( HostEnvironmentBase < TEnv > source , Random rand , bool verbose ,
405+ protected HostEnvironmentBase ( HostEnvironmentBase < TEnv > source , Random rand , IRandomSource randomSource , bool verbose ,
389406 string shortName = null , string parentFullName = null )
390407 : base ( shortName , parentFullName , verbose )
391408 {
392409 Contracts . CheckValue ( source , nameof ( source ) ) ;
393410 Contracts . CheckValueOrNull ( rand ) ;
394- _rand = rand ?? RandomUtils . Create ( ) ;
411+ _randomSource = randomSource ?? ( rand != null ? new RandomSourceAdapter ( rand ) : new RandomSourceAdapter ( RandomUtils . Create ( ) ) ) ;
412+ _rand = rand ?? ( _randomSource as Random ?? new RandomFromRandomSource ( _randomSource ) ) ;
395413 _cancelLock = new object ( ) ;
396414
397415 // This fork shares some stuff with the master.
@@ -419,7 +437,8 @@ public IHost Register(string name, int? seed = null, bool? verbose = null)
419437 {
420438 _children . RemoveAll ( r => r . TryGetTarget ( out IHost _ ) == false ) ;
421439 Random rand = ( seed . HasValue ) ? RandomUtils . Create ( seed . Value ) : RandomUtils . Create ( _rand ) ;
422- host = RegisterCore ( this , name , Master ? . FullName , rand , verbose ?? Verbose ) ;
440+ IRandomSource randomSource = new RandomSourceAdapter ( rand ) ;
441+ host = RegisterCore ( this , name , Master ? . FullName , rand , randomSource , verbose ?? Verbose ) ;
423442
424443 // Need to manually copy over the parameters
425444 //((IHostEnvironmentInternal)host).Seed = this.Seed;
@@ -433,7 +452,7 @@ public IHost Register(string name, int? seed = null, bool? verbose = null)
433452 }
434453
435454 protected abstract IHost RegisterCore ( HostEnvironmentBase < TEnv > source , string shortName ,
436- string parentFullName , Random rand , bool verbose ) ;
455+ string parentFullName , Random rand , IRandomSource randomSource , bool verbose ) ;
437456
438457 public IProgressChannel StartProgressChannel ( string name )
439458 {
@@ -659,3 +678,8 @@ public bool RemoveOption(string name)
659678 return Options . Remove ( name ) ;
660679 }
661680}
681+
682+
683+
684+
685+
0 commit comments