diff --git a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs index c4958ee3ea..12d1ecf518 100644 --- a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs +++ b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs @@ -921,6 +921,7 @@ internal TensorFlowEstimator(IHostEnvironment env, Options options, TensorFlowMo _host = Contracts.CheckRef(env, nameof(env)).Register(nameof(TensorFlowEstimator)); _options = options; _tensorFlowModel = tensorFlowModel; + tensorFlowModel.Session.graph.as_default(); var inputTuple = TensorFlowTransformer.GetInputInfo(_host, tensorFlowModel.Session, options.InputColumns); _tfInputTypes = inputTuple.tfInputTypes; var outputTuple = TensorFlowTransformer.GetOutputInfo(_host, tensorFlowModel.Session, options.OutputColumns); diff --git a/test/Microsoft.ML.Tests/TensorFlowEstimatorTests.cs b/test/Microsoft.ML.Tests/TensorFlowEstimatorTests.cs index de1d8427f1..8aba9b08a5 100644 --- a/test/Microsoft.ML.Tests/TensorFlowEstimatorTests.cs +++ b/test/Microsoft.ML.Tests/TensorFlowEstimatorTests.cs @@ -231,6 +231,21 @@ public void TestTensorFlowWithSchema() } } + [TensorFlowFact] + public void TestLoadMultipleModel() + { + var modelFile1 = "model_matmul/frozen_saved_model.pb"; + var modelFile2 = "cifar_model/frozen_model.pb"; + + MLContext context = new MLContext(seed: 1); + + TensorFlowModel model1 = context.Model.LoadTensorFlowModel(modelFile1); + TensorFlowModel model2 = context.Model.LoadTensorFlowModel(modelFile2); + + model1.ScoreTensorFlowModel(new[] { "c" }, new[] { "a", "b" }); + model2.ScoreTensorFlowModel("Output", "Input"); + } + private void ValidateTensorFlowTransformer(IDataView result) { using (var cursor = result.GetRowCursorForAllColumns())