diff --git a/src/Microsoft.ML.Core/Utilities/ResourceManagerUtils.cs b/src/Microsoft.ML.Core/Utilities/ResourceManagerUtils.cs
index 2c1f853431..319b39b9c2 100644
--- a/src/Microsoft.ML.Core/Utilities/ResourceManagerUtils.cs
+++ b/src/Microsoft.ML.Core/Utilities/ResourceManagerUtils.cs
@@ -88,6 +88,11 @@ public static string GetUrl(string suffix)
/// Returns a that tries to download a resource from a specified url, and returns the path to which it was
/// downloaded, and an exception if one was thrown.
///
+ ///
+ /// The function checks whether or not the absolute URL with the
+ /// default host "aka.ms" formed from redirects to the default Microsoft homepage.
+ /// As such, only absolute URLs with the host "aka.ms" is supported with .
+ ///
/// The host environment.
/// A channel to provide information about the download.
/// The relative url from which to download.
@@ -109,6 +114,8 @@ public async Task EnsureResourceAsync(IHostEnvironment
return new ResourceDownloadResults(filePath,
$"Could not create a valid URI from the base URI '{MlNetResourcesUrl}' and the relative URI '{relativeUrl}'");
}
+ if (absoluteUrl.Host != "aka.ms")
+ throw new NotSupportedException("The function ResourceManagerUtils.EnsureResourceAsync only supports downloading from URLs of the host \"aka.ms\"");
return new ResourceDownloadResults(filePath,
await DownloadFromUrlWithRetryAsync(env, ch, absoluteUrl.AbsoluteUri, fileName, timeout, filePath), absoluteUrl.AbsoluteUri);
}
@@ -160,27 +167,8 @@ private async Task DownloadFromUrlAsync(IHostEnvironment env, IChannel c
deleteNeeded = true;
return (await t).Message;
}
-
- return CheckValidDownload(ch, filePath, url, ref deleteNeeded);
- }
- }
-
- private static string CheckValidDownload(IChannel ch, string filePath, string url, ref bool deleteNeeded)
- {
- // If the relative url does not exist, aka.ms redirects to www.microsoft.com. Make sure this did not happen.
- // If the file is big then it is definitely not the redirect.
- var info = new FileInfo(filePath);
- if (info.Length > 4096)
return null;
- string error = null;
- using (var r = new StreamReader(filePath))
- {
- var text = r.ReadToEnd();
- if (text.Contains("") && text.Contains("") && text.Contains("microsoft.com"))
- error = $"The url '{url}' does not exist. Url was redirected to www.microsoft.com.";
}
- deleteNeeded = error != null;
- return error;
}
private static void TryDelete(IChannel ch, string filePath, bool warn = true)
@@ -274,6 +262,8 @@ private Exception DownloadResource(IHostEnvironment env, IChannel ch, WebClient
using (var ws = fh.CreateWriteStream())
{
var headers = webClient.ResponseHeaders.GetValues("Content-Length");
+ if (uri.Host == "aka.ms" && IsRedirectToDefaultPage(uri.AbsoluteUri))
+ throw new NotSupportedException($"The provided url ({uri}) redirects to the default url ({DefaultUrl})");
if (Utils.Size(headers) == 0 || !long.TryParse(headers[0], out var size))
size = 10000000;
@@ -311,6 +301,36 @@ private Exception DownloadResource(IHostEnvironment env, IChannel ch, WebClient
}
}
+ /// This method checks whether or not the provided aka.ms url redirects to
+ /// Microsoft's homepage, as the default faulty aka.ms URLs redirect to https://www.microsoft.com/?ref=aka .
+ /// The provided url to check
+ public bool IsRedirectToDefaultPage(string url)
+ {
+ try
+ {
+ var request = WebRequest.Create(url);
+ // FileWebRequests cannot be redirected to default aka.ms webpage
+ if (request.GetType() == typeof(FileWebRequest))
+ return false;
+ HttpWebRequest httpWebRequest = (HttpWebRequest)request;
+ httpWebRequest.AllowAutoRedirect = false;
+ HttpWebResponse httpWebResponse = (HttpWebResponse)httpWebRequest.GetResponse();
+ }
+ catch (WebException e)
+ {
+ HttpWebResponse webResponse = (HttpWebResponse)e.Response;
+ // Redirects to default url
+ if (webResponse.StatusCode == HttpStatusCode.Redirect && webResponse.Headers["Location"] == "https://www.microsoft.com/?ref=aka")
+ return true;
+ // Redirects to another url
+ else if (webResponse.StatusCode == HttpStatusCode.MovedPermanently)
+ return false;
+ else
+ return false;
+ }
+ return false;
+ }
+
public static ResourceDownloadResults GetErrorMessage(out string errorMessage, params ResourceDownloadResults[] result)
{
var errorResult = result.FirstOrDefault(res => !string.IsNullOrEmpty(res.ErrorMessage));
diff --git a/src/Microsoft.ML.Transforms/Text/WordEmbeddingsExtractor.cs b/src/Microsoft.ML.Transforms/Text/WordEmbeddingsExtractor.cs
index 6afa0e2390..6333e637f7 100644
--- a/src/Microsoft.ML.Transforms/Text/WordEmbeddingsExtractor.cs
+++ b/src/Microsoft.ML.Transforms/Text/WordEmbeddingsExtractor.cs
@@ -628,7 +628,7 @@ private string EnsureModelFile(IHostEnvironment env, out int linesToSkip, WordEm
{
string dir = kind == WordEmbeddingEstimator.PretrainedModelKind.SentimentSpecificWordEmbedding ? Path.Combine("Text", "Sswe") : "WordVectors";
var url = $"{dir}/{modelFileName}";
- var ensureModel = ResourceManagerUtils.Instance.EnsureResourceAsync(Host, ch, url, modelFileName, dir, Timeout);
+ var ensureModel = ResourceManagerUtils.Instance.EnsureResourceAsync(env, ch, url, modelFileName, dir, Timeout);
ensureModel.Wait();
var errorResult = ResourceManagerUtils.GetErrorMessage(out var errorMessage, ensureModel.Result);
if (errorResult != null)
diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestResourceDownload.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestResourceDownload.cs
new file mode 100644
index 0000000000..ee7c102fcf
--- /dev/null
+++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestResourceDownload.cs
@@ -0,0 +1,142 @@
+using System;
+using System.Collections.Generic;
+using System.IO;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+using Microsoft.ML.Internal.Utilities;
+using Microsoft.ML.RunTests;
+using Microsoft.ML.Runtime;
+using Xunit;
+using Xunit.Abstractions;
+
+[assembly: CollectionBehavior(DisableTestParallelization = true)]
+
+namespace Microsoft.ML.Core.Tests.UnitTests
+{
+ public class TestResourceDownload : BaseTestBaseline
+ {
+ public TestResourceDownload(ITestOutputHelper helper)
+ : base(helper)
+ {
+ }
+
+ [Fact]
+ [TestCategory("ResourceDownload")]
+ public async Task TestDownloadError()
+ {
+ var envVarOld = Environment.GetEnvironmentVariable(ResourceManagerUtils.CustomResourcesUrlEnvVariable);
+ var timeoutVarOld = Environment.GetEnvironmentVariable(ResourceManagerUtils.TimeoutEnvVariable);
+ var resourcePathVarOld = Environment.GetEnvironmentVariable(Utils.CustomSearchDirEnvVariable);
+ Environment.SetEnvironmentVariable(Utils.CustomSearchDirEnvVariable, null);
+
+ try
+ {
+ var envVar = Environment.GetEnvironmentVariable(ResourceManagerUtils.CustomResourcesUrlEnvVariable);
+ var saveToDir = GetOutputPath("copyto");
+ DeleteOutputPath("copyto", "breast-cancer.txt");
+ var sbOut = new StringBuilder();
+ var sbErr = new StringBuilder();
+
+ // Bad url.
+ if (!Uri.TryCreate("https://fake-website/fake-model.model/", UriKind.Absolute, out var badUri))
+ Fail("Uri could not be created");
+
+ Environment.SetEnvironmentVariable(ResourceManagerUtils.CustomResourcesUrlEnvVariable, badUri.AbsoluteUri);
+ envVar = Environment.GetEnvironmentVariable(ResourceManagerUtils.CustomResourcesUrlEnvVariable);
+ if (envVar != badUri.AbsoluteUri)
+ Fail("Environment variable not set properly");
+
+ DeleteOutputPath("copyto", "ResNet_18_Updated.model");
+ sbOut.Clear();
+ sbErr.Clear();
+ using (var outWriter = new StringWriter(sbOut))
+ using (var errWriter = new StringWriter(sbErr))
+ {
+ var env = new ConsoleEnvironment(42, outWriter: outWriter, errWriter: errWriter);
+ using (var ch = env.Start("Downloading"))
+ {
+ var fileName = "test_bad_url.model";
+ await Assert.ThrowsAsync(() => ResourceManagerUtils.Instance.EnsureResourceAsync(env, ch, "Image/ResNet_18_Updated.model", fileName, saveToDir, 10 * 1000));
+
+ Log("Bad url");
+ Log($"out: {sbOut.ToString()}");
+ Log($"error: {sbErr.ToString()}");
+
+ if (File.Exists(Path.Combine(saveToDir, fileName)))
+ Fail($"File '{Path.Combine(saveToDir, fileName)}' should have been deleted.");
+ }
+ }
+
+ // Good url, bad page.
+ if (!Uri.TryCreate("https://cnn.com/", UriKind.Absolute, out var cnn))
+ Fail("Uri could not be created");
+ Environment.SetEnvironmentVariable(ResourceManagerUtils.CustomResourcesUrlEnvVariable, cnn.AbsoluteUri);
+ envVar = Environment.GetEnvironmentVariable(ResourceManagerUtils.CustomResourcesUrlEnvVariable);
+ if (envVar != cnn.AbsoluteUri)
+ Fail("Environment variable not set properly");
+
+ DeleteOutputPath("copyto", "ResNet_18_Updated.model");
+ sbOut.Clear();
+ sbErr.Clear();
+ using (var outWriter = new StringWriter(sbOut))
+ using (var errWriter = new StringWriter(sbErr))
+ {
+ var env = new ConsoleEnvironment(42, outWriter: outWriter, errWriter: errWriter);
+ using (var ch = env.Start("Downloading"))
+ {
+ var fileName = "test_cnn_page_does_not_exist.model";
+ await Assert.ThrowsAsync(() => ResourceManagerUtils.Instance.EnsureResourceAsync(env, ch, "Image/ResNet_18_Updated.model", fileName, saveToDir, 10 * 1000));
+
+ Log("Good url, bad page");
+ Log($"out: {sbOut.ToString()}");
+ Log($"error: {sbErr.ToString()}");
+
+ if (File.Exists(Path.Combine(saveToDir, fileName)))
+ Fail($"File '{Path.Combine(saveToDir, fileName)}' should have been deleted.");
+ }
+ }
+
+ //Good url, good page
+ Environment.SetEnvironmentVariable(ResourceManagerUtils.CustomResourcesUrlEnvVariable, envVarOld);
+ envVar = Environment.GetEnvironmentVariable(ResourceManagerUtils.CustomResourcesUrlEnvVariable);
+ if (envVar != envVarOld)
+ Fail("Environment variable not set properly");
+
+ DeleteOutputPath("copyto", "sentiment.emd");
+ sbOut.Clear();
+ sbErr.Clear();
+ using (var outWriter = new StringWriter(sbOut))
+ using (var errWriter = new StringWriter(sbErr))
+ {
+ var env = new ConsoleEnvironment(42, outWriter: outWriter, errWriter: errWriter);
+ using (var ch = env.Start("Downloading"))
+ {
+ var fileName = "sentiment.emd";
+ var t = ResourceManagerUtils.Instance.EnsureResourceAsync(env, ch, "text/Sswe/sentiment.emd", fileName, saveToDir, 1 * 60 * 1000);
+ var results = await t;
+
+ if (results.ErrorMessage != null)
+ Fail(String.Format("Expected zero length error string. Received error: {0}", results.ErrorMessage));
+ if (t.Status != TaskStatus.RanToCompletion)
+ Fail("Download did not complete succesfully");
+ if (!File.Exists(GetOutputPath("copyto", "sentiment.emd")))
+ {
+ Fail($"File '{GetOutputPath("copyto", "sentiment.emd")}' does not exist. " +
+ $"File was downloaded to '{results.FileName}' instead." +
+ $"MICROSOFTML_RESOURCE_PATH is set to {Environment.GetEnvironmentVariable(Utils.CustomSearchDirEnvVariable)}");
+ }
+ }
+ }
+ Done();
+ }
+ finally
+ {
+ // Set environment variable back to its old value.
+ Environment.SetEnvironmentVariable(ResourceManagerUtils.CustomResourcesUrlEnvVariable, envVarOld);
+ Environment.SetEnvironmentVariable(ResourceManagerUtils.TimeoutEnvVariable, timeoutVarOld);
+ Environment.SetEnvironmentVariable(Utils.CustomSearchDirEnvVariable, resourcePathVarOld);
+ }
+ }
+ }
+}