diff --git a/src/Microsoft.ML.Core/Utilities/ResourceManagerUtils.cs b/src/Microsoft.ML.Core/Utilities/ResourceManagerUtils.cs index 2c1f853431..319b39b9c2 100644 --- a/src/Microsoft.ML.Core/Utilities/ResourceManagerUtils.cs +++ b/src/Microsoft.ML.Core/Utilities/ResourceManagerUtils.cs @@ -88,6 +88,11 @@ public static string GetUrl(string suffix) /// Returns a that tries to download a resource from a specified url, and returns the path to which it was /// downloaded, and an exception if one was thrown. /// + /// + /// The function checks whether or not the absolute URL with the + /// default host "aka.ms" formed from redirects to the default Microsoft homepage. + /// As such, only absolute URLs with the host "aka.ms" are supported with . + /// /// The host environment. /// A channel to provide information about the download. /// The relative url from which to download. @@ -109,6 +114,8 @@ public async Task EnsureResourceAsync(IHostEnvironment return new ResourceDownloadResults(filePath, $"Could not create a valid URI from the base URI '{MlNetResourcesUrl}' and the relative URI '{relativeUrl}'"); } + if (absoluteUrl.Host != "aka.ms") + throw new NotSupportedException("The function ResourceManagerUtils.EnsureResourceAsync only supports downloading from URLs of the host \"aka.ms\""); return new ResourceDownloadResults(filePath, await DownloadFromUrlWithRetryAsync(env, ch, absoluteUrl.AbsoluteUri, fileName, timeout, filePath), absoluteUrl.AbsoluteUri); } @@ -160,27 +167,8 @@ private async Task DownloadFromUrlAsync(IHostEnvironment env, IChannel c deleteNeeded = true; return (await t).Message; } - - return CheckValidDownload(ch, filePath, url, ref deleteNeeded); - } - } - - private static string CheckValidDownload(IChannel ch, string filePath, string url, ref bool deleteNeeded) - { - // If the relative url does not exist, aka.ms redirects to www.microsoft.com. Make sure this did not happen. - // If the file is big then it is definitely not the redirect. 
- var info = new FileInfo(filePath); - if (info.Length > 4096) return null; - string error = null; - using (var r = new StreamReader(filePath)) - { - var text = r.ReadToEnd(); - if (text.Contains("") && text.Contains("") && text.Contains("microsoft.com")) - error = $"The url '{url}' does not exist. Url was redirected to www.microsoft.com."; } - deleteNeeded = error != null; - return error; } private static void TryDelete(IChannel ch, string filePath, bool warn = true) @@ -274,6 +262,8 @@ private Exception DownloadResource(IHostEnvironment env, IChannel ch, WebClient using (var ws = fh.CreateWriteStream()) { var headers = webClient.ResponseHeaders.GetValues("Content-Length"); + if (uri.Host == "aka.ms" && IsRedirectToDefaultPage(uri.AbsoluteUri)) + throw new NotSupportedException($"The provided url ({uri}) redirects to the default url ({DefaultUrl})"); if (Utils.Size(headers) == 0 || !long.TryParse(headers[0], out var size)) size = 10000000; @@ -311,6 +301,36 @@ private Exception DownloadResource(IHostEnvironment env, IChannel ch, WebClient } } + /// This method checks whether or not the provided aka.ms url redirects to + /// Microsoft's homepage, as the default faulty aka.ms URLs redirect to https://www.microsoft.com/?ref=aka . 
+        /// <param name="url"> The provided url to check </param>
+        /// <returns>True only when the answer is a 302 redirect whose Location is https://www.microsoft.com/?ref=aka .</returns>
+        public bool IsRedirectToDefaultPage(string url)
+        {
+            // Only HTTP(S) requests can be redirected; FileWebRequests cannot be redirected to default aka.ms webpage
+            if (!(WebRequest.Create(url) is HttpWebRequest httpWebRequest))
+                return false;
+            // Look at the redirect answer itself instead of silently following it.
+            httpWebRequest.AllowAutoRedirect = false;
+            try
+            {
+                // With auto-redirect disabled a 3xx answer may be handed back here rather than thrown.
+                using (var response = (HttpWebResponse)httpWebRequest.GetResponse())
+                    return IsDefaultRedirect(response);
+            }
+            catch (WebException e)
+            {
+                // e.Response is null when no HTTP response was received at all (e.g. name resolution failure, timeout).
+                using (var response = e.Response as HttpWebResponse)
+                    return IsDefaultRedirect(response);
+            }
+
+            bool IsDefaultRedirect(HttpWebResponse response)
+                => response != null
+                    && response.StatusCode == HttpStatusCode.Redirect
+                    && string.Equals(response.Headers["Location"], "https://www.microsoft.com/?ref=aka", StringComparison.Ordinal);
+        }
+
         public static ResourceDownloadResults GetErrorMessage(out string errorMessage, params ResourceDownloadResults[] result)
         {
             var errorResult = result.FirstOrDefault(res => !string.IsNullOrEmpty(res.ErrorMessage));
diff --git a/src/Microsoft.ML.Transforms/Text/WordEmbeddingsExtractor.cs b/src/Microsoft.ML.Transforms/Text/WordEmbeddingsExtractor.cs
index 6afa0e2390..6333e637f7 100644
--- a/src/Microsoft.ML.Transforms/Text/WordEmbeddingsExtractor.cs
+++ b/src/Microsoft.ML.Transforms/Text/WordEmbeddingsExtractor.cs
@@ -628,7 +628,7 @@ private string EnsureModelFile(IHostEnvironment env, out int linesToSkip, WordEm
                 {
                     string dir = kind == WordEmbeddingEstimator.PretrainedModelKind.SentimentSpecificWordEmbedding ? 
Path.Combine("Text", "Sswe") : "WordVectors";
                     var url = $"{dir}/{modelFileName}";
-                    var ensureModel = ResourceManagerUtils.Instance.EnsureResourceAsync(Host, ch, url, modelFileName, dir, Timeout);
+                    var ensureModel = ResourceManagerUtils.Instance.EnsureResourceAsync(env, ch, url, modelFileName, dir, Timeout);
                     ensureModel.Wait();
                     var errorResult = ResourceManagerUtils.GetErrorMessage(out var errorMessage, ensureModel.Result);
                     if (errorResult != null)
diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestResourceDownload.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestResourceDownload.cs
new file mode 100644
index 0000000000..ee7c102fcf
--- /dev/null
+++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestResourceDownload.cs
@@ -0,0 +1,149 @@
+using System;
+using System.Collections.Generic;
+using System.IO;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+using Microsoft.ML.Internal.Utilities;
+using Microsoft.ML.RunTests;
+using Microsoft.ML.Runtime;
+using Xunit;
+using Xunit.Abstractions;
+
+[assembly: CollectionBehavior(DisableTestParallelization = true)]
+
+namespace Microsoft.ML.Core.Tests.UnitTests
+{
+    /// <summary>
+    /// Tests for resource downloading through <see cref="ResourceManagerUtils"/>.
+    /// </summary>
+    public class TestResourceDownload : BaseTestBaseline
+    {
+        public TestResourceDownload(ITestOutputHelper helper)
+            : base(helper)
+        {
+        }
+
+        /// <summary>
+        /// Points the custom resources URL at two hosts other than "aka.ms" and expects EnsureResourceAsync to throw
+        /// <see cref="NotSupportedException"/> each time, then restores the original setting and downloads "sentiment.emd".
+        /// </summary>
+        [Fact]
+        [TestCategory("ResourceDownload")]
+        public async Task TestDownloadError()
+        {
+            var envVarOld = Environment.GetEnvironmentVariable(ResourceManagerUtils.CustomResourcesUrlEnvVariable);
+            var timeoutVarOld = Environment.GetEnvironmentVariable(ResourceManagerUtils.TimeoutEnvVariable);
+            var resourcePathVarOld = Environment.GetEnvironmentVariable(Utils.CustomSearchDirEnvVariable);
+            Environment.SetEnvironmentVariable(Utils.CustomSearchDirEnvVariable, null);
+
+            try
+            {
+                var envVar = Environment.GetEnvironmentVariable(ResourceManagerUtils.CustomResourcesUrlEnvVariable);
+                var saveToDir = GetOutputPath("copyto");
+                DeleteOutputPath("copyto", "breast-cancer.txt");
+                var sbOut = new StringBuilder();
+                var sbErr = new StringBuilder();
+
+                // Bad url: the host is not "aka.ms", so the download is expected to be rejected.
+                if (!Uri.TryCreate("https://fake-website/fake-model.model/", UriKind.Absolute, out var badUri))
+                    Fail("Uri could not be created");
+
+                Environment.SetEnvironmentVariable(ResourceManagerUtils.CustomResourcesUrlEnvVariable, badUri.AbsoluteUri);
+                envVar = Environment.GetEnvironmentVariable(ResourceManagerUtils.CustomResourcesUrlEnvVariable);
+                if (envVar != badUri.AbsoluteUri)
+                    Fail("Environment variable not set properly");
+
+                DeleteOutputPath("copyto", "ResNet_18_Updated.model");
+                sbOut.Clear();
+                sbErr.Clear();
+                using (var outWriter = new StringWriter(sbOut))
+                using (var errWriter = new StringWriter(sbErr))
+                {
+                    var env = new ConsoleEnvironment(42, outWriter: outWriter, errWriter: errWriter);
+                    using (var ch = env.Start("Downloading"))
+                    {
+                        var fileName = "test_bad_url.model";
+                        await Assert.ThrowsAsync<NotSupportedException>(() => ResourceManagerUtils.Instance.EnsureResourceAsync(env, ch, "Image/ResNet_18_Updated.model", fileName, saveToDir, 10 * 1000));
+
+                        Log("Bad url");
+                        Log($"out: {sbOut.ToString()}");
+                        Log($"error: {sbErr.ToString()}");
+
+                        if (File.Exists(Path.Combine(saveToDir, fileName)))
+                            Fail($"File '{Path.Combine(saveToDir, fileName)}' should have been deleted.");
+                    }
+                }
+
+                // Good url, bad page (the host is still not "aka.ms", so the same exception is expected).
+                if (!Uri.TryCreate("https://cnn.com/", UriKind.Absolute, out var cnn))
+                    Fail("Uri could not be created");
+                Environment.SetEnvironmentVariable(ResourceManagerUtils.CustomResourcesUrlEnvVariable, cnn.AbsoluteUri);
+                envVar = Environment.GetEnvironmentVariable(ResourceManagerUtils.CustomResourcesUrlEnvVariable);
+                if (envVar != cnn.AbsoluteUri)
+                    Fail("Environment variable not set properly");
+
+                DeleteOutputPath("copyto", "ResNet_18_Updated.model");
+                sbOut.Clear();
+                sbErr.Clear();
+                using (var outWriter = new StringWriter(sbOut))
+                using (var errWriter = new StringWriter(sbErr))
+                {
+                    var env = new ConsoleEnvironment(42, outWriter: outWriter, errWriter: errWriter);
+                    using (var ch = env.Start("Downloading"))
+                    {
+                        var fileName = "test_cnn_page_does_not_exist.model";
+                        await Assert.ThrowsAsync<NotSupportedException>(() => ResourceManagerUtils.Instance.EnsureResourceAsync(env, ch, "Image/ResNet_18_Updated.model", fileName, saveToDir, 10 * 1000));
+
+                        Log("Good url, bad page");
+                        Log($"out: {sbOut.ToString()}");
+                        Log($"error: {sbErr.ToString()}");
+
+                        if (File.Exists(Path.Combine(saveToDir, fileName)))
+                            Fail($"File '{Path.Combine(saveToDir, fileName)}' should have been deleted.");
+                    }
+                }
+
+                // Good url, good page: put the original custom-URL value back and download a real resource.
+                Environment.SetEnvironmentVariable(ResourceManagerUtils.CustomResourcesUrlEnvVariable, envVarOld);
+                envVar = Environment.GetEnvironmentVariable(ResourceManagerUtils.CustomResourcesUrlEnvVariable);
+                if (envVar != envVarOld)
+                    Fail("Environment variable not set properly");
+
+                DeleteOutputPath("copyto", "sentiment.emd");
+                sbOut.Clear();
+                sbErr.Clear();
+                using (var outWriter = new StringWriter(sbOut))
+                using (var errWriter = new StringWriter(sbErr))
+                {
+                    var env = new ConsoleEnvironment(42, outWriter: outWriter, errWriter: errWriter);
+                    using (var ch = env.Start("Downloading"))
+                    {
+                        var fileName = "sentiment.emd";
+                        var t = ResourceManagerUtils.Instance.EnsureResourceAsync(env, ch, "text/Sswe/sentiment.emd", fileName, saveToDir, 1 * 60 * 1000);
+                        var results = await t;
+
+                        if (results.ErrorMessage != null)
+                            Fail(String.Format("Expected zero length error string. Received error: {0}", results.ErrorMessage));
+                        if (t.Status != TaskStatus.RanToCompletion)
+                            Fail("Download did not complete succesfully");
+                        if (!File.Exists(GetOutputPath("copyto", "sentiment.emd")))
+                        {
+                            Fail($"File '{GetOutputPath("copyto", "sentiment.emd")}' does not exist. " +
+                                $"File was downloaded to '{results.FileName}' instead." +
+                                $"MICROSOFTML_RESOURCE_PATH is set to {Environment.GetEnvironmentVariable(Utils.CustomSearchDirEnvVariable)}");
+                        }
+                    }
+                }
+                Done();
+            }
+            finally
+            {
+                // Set environment variable back to its old value.
+                Environment.SetEnvironmentVariable(ResourceManagerUtils.CustomResourcesUrlEnvVariable, envVarOld);
+                Environment.SetEnvironmentVariable(ResourceManagerUtils.TimeoutEnvVariable, timeoutVarOld);
+                Environment.SetEnvironmentVariable(Utils.CustomSearchDirEnvVariable, resourcePathVarOld);
+            }
+        }
+    }
+}