diff --git a/src/GraphRequest.ts b/src/GraphRequest.ts index 743af0a28..e7f5999df 100644 --- a/src/GraphRequest.ts +++ b/src/GraphRequest.ts @@ -283,7 +283,7 @@ export class GraphRequest { */ private parseQueryParamenterString(queryParameter: string): void { /* The query key-value pair must be split on the first equals sign to avoid errors in parsing nested query parameters. - Example-> "/me?$expand=home($select=city)" */ + Example-> "/me?$expand=home($select=city)" */ if (this.isValidQueryKeyValuePair(queryParameter)) { const indexOfFirstEquals = queryParameter.indexOf("="); const paramKey = queryParameter.substring(0, indexOfFirstEquals); @@ -291,7 +291,7 @@ export class GraphRequest { this.setURLComponentsQueryParamater(paramKey, paramValue); } else { /* Push values which are not of key-value structure. - Example-> Handle an invalid input->.query(test), .query($select($select=name)) and let the Graph API respond with the error in the URL*/ + Example-> Handle an invalid input->.query(test), .query($select($select=name)) and let the Graph API respond with the error in the URL*/ this.urlComponents.otherURLQueryOptions.push(queryParameter); } } @@ -367,12 +367,15 @@ export class GraphRequest { let rawResponse: Response; const middlewareControl = new MiddlewareControl(this._middlewareOptions); this.updateRequestOptions(options); + const customHosts = this.config?.customHosts; try { const context: Context = await this.httpClient.sendRequest({ request, options, middlewareControl, + customHosts, }); + rawResponse = context.response; const response: any = await GraphResponseHandler.getResponse(rawResponse, this._responseType, callback); return response; diff --git a/src/GraphRequestUtil.ts b/src/GraphRequestUtil.ts index 78e51f021..47fc6e17f 100644 --- a/src/GraphRequestUtil.ts +++ b/src/GraphRequestUtil.ts @@ -9,6 +9,7 @@ * @module GraphRequestUtil */ import { GRAPH_URLS } from "./Constants"; +import { GraphClientError } from "./GraphClientError"; /** * To hold list of OData query params */ @@ -65,6 +66,27 @@ export const serializeContent = (content: any): any => { * @returns {boolean} - Returns true if the url is a Graph URL */ export const isGraphURL = (url: string): boolean => { + return isValidEndpoint(url); +}; + +/** + * Checks if the url is for one of the custom hosts provided during client initialization + * @param {string} url - The url to be verified + * @param {Set} customHosts - The url to be verified + * @returns {boolean} - Returns true if the url is a for a custom host + */ +export const isCustomHost = (url: string, customHosts: Set): boolean => { + customHosts.forEach((host) => isCustomHostValid(host)); + return isValidEndpoint(url, customHosts); +}; + +/** + * Checks if the url is for one of the provided hosts. + * @param {string} url - The url to be verified + * @param {Set} allowedHosts - A set of hosts. + * @returns {boolean} - Returns true is for one of the provided endpoints. + */ +const isValidEndpoint = (url: string, allowedHosts: Set = GRAPH_URLS): boolean => { // Valid Graph URL pattern - https://graph.microsoft.com/{version}/{resource}?{query-parameters} // Valid Graph URL example - https://graph.microsoft.com/v1.0/ url = url.toLowerCase(); @@ -79,13 +101,23 @@ export const isGraphURL = (url: string): boolean => { if (endOfHostStrPos !== -1) { if (startofPortNoPos !== -1 && startofPortNoPos < endOfHostStrPos) { hostName = url.substring(0, startofPortNoPos); - return GRAPH_URLS.has(hostName); + return allowedHosts.has(hostName); } // Parse out the host hostName = url.substring(0, endOfHostStrPos); - return GRAPH_URLS.has(hostName); + return allowedHosts.has(hostName); } } return false; }; + +/** + * Throws error if the string is not a valid host/hostname and contains other url parts. + * @param {string} url - The host to be verified + */ +const isCustomHostValid = (host: string) => { + if (host.indexOf("/") !== -1) { + throw new GraphClientError("Please add only hosts or hostnames to the CustomHosts config. If the url is `http://example.com:3000/`, host is `example:3000`"); + } +}; diff --git a/src/IClientOptions.ts b/src/IClientOptions.ts index c3b0c2ea2..ee9f62059 100644 --- a/src/IClientOptions.ts +++ b/src/IClientOptions.ts @@ -18,7 +18,9 @@ import { Middleware } from "./middleware/IMiddleware"; * @property {string} [defaultVersion] - The default version that needs to be used while making graph api request * @property {FetchOptions} [fetchOptions] - The options for fetch request * @property {Middleware| Middleware[]} [middleware] - The first middleware of the middleware chain or an array of the Middleware handlers + * @property {Set}[customHosts] - A set of custom host names. Should contain hostnames only. */ + export interface ClientOptions { authProvider?: AuthenticationProvider; baseUrl?: string; @@ -26,4 +28,8 @@ export interface ClientOptions { defaultVersion?: string; fetchOptions?: FetchOptions; middleware?: Middleware | Middleware[]; + /** + * Example - If URL is "https://test_host/v1.0", then set property "customHosts" as "customHosts: Set(["test_host"])" + */ + customHosts?: Set; } diff --git a/src/IContext.ts b/src/IContext.ts index 35695a257..c657cce75 100644 --- a/src/IContext.ts +++ b/src/IContext.ts @@ -14,6 +14,8 @@ import { MiddlewareControl } from "./middleware/MiddlewareControl"; * @property {FetchOptions} [options] - The options for the request * @property {Response} [response] - The response content * @property {MiddlewareControl} [middlewareControl] - The options for the middleware chain + * @property {Set}[customHosts] - A set of custom host names. Should contain hostnames only. + * */ export interface Context { @@ -21,4 +23,8 @@ export interface Context { options?: FetchOptions; response?: Response; middlewareControl?: MiddlewareControl; + /** + * Example - If URL is "https://test_host", then set property "customHosts" as "customHosts: Set(["test_host"])" + */ + customHosts?: Set; } diff --git a/src/IOptions.ts b/src/IOptions.ts index 460a04301..967c14ac1 100644 --- a/src/IOptions.ts +++ b/src/IOptions.ts @@ -16,6 +16,7 @@ import { FetchOptions } from "./IFetchOptions"; * @property {boolean} [debugLogging] - The boolean to enable/disable debug logging * @property {string} [defaultVersion] - The default version that needs to be used while making graph api request * @property {FetchOptions} [fetchOptions] - The options for fetch request + * @property {Set}[customHosts] - A set of custom host names. Should contain hostnames only. */ export interface Options { authProvider: AuthProvider; @@ -23,4 +24,8 @@ export interface Options { debugLogging?: boolean; defaultVersion?: string; fetchOptions?: FetchOptions; + /** + * Example - If URL is "https://test_host/v1.0", then set property "customHosts" as "customHosts: Set(["test_host"])" + */ + customHosts?: Set; } diff --git a/src/middleware/AuthenticationHandler.ts b/src/middleware/AuthenticationHandler.ts index 267f8543f..3e25f7cdf 100644 --- a/src/middleware/AuthenticationHandler.ts +++ b/src/middleware/AuthenticationHandler.ts @@ -9,7 +9,7 @@ * @module AuthenticationHandler */ -import { isGraphURL } from "../GraphRequestUtil"; +import { isCustomHost, isGraphURL } from "../GraphRequestUtil"; import { AuthenticationProvider } from "../IAuthenticationProvider"; import { AuthenticationProviderOptions } from "../IAuthenticationProviderOptions"; import { Context } from "../IContext"; @@ -62,7 +62,7 @@ export class AuthenticationHandler implements Middleware { */ public async execute(context: Context): Promise { const url = typeof context.request === "string" ? context.request : context.request.url; - if (isGraphURL(url)) { + if (isGraphURL(url) || (context.customHosts && isCustomHost(url, context.customHosts))) { let options: AuthenticationHandlerOptions; if (context.middlewareControl instanceof MiddlewareControl) { options = context.middlewareControl.getMiddlewareOptions(AuthenticationHandlerOptions) as AuthenticationHandlerOptions; diff --git a/src/middleware/TelemetryHandler.ts b/src/middleware/TelemetryHandler.ts index 3b609a4a0..77278437c 100644 --- a/src/middleware/TelemetryHandler.ts +++ b/src/middleware/TelemetryHandler.ts @@ -8,7 +8,7 @@ /** * @module TelemetryHandler */ -import { isGraphURL } from "../GraphRequestUtil"; +import { isCustomHost, isGraphURL } from "../GraphRequestUtil"; import { Context } from "../IContext"; import { PACKAGE_VERSION } from "../Version"; import { Middleware } from "./IMiddleware"; @@ -65,7 +65,7 @@ export class TelemetryHandler implements Middleware { */ public async execute(context: Context): Promise { const url = typeof context.request === "string" ? context.request : context.request.url; - if (isGraphURL(url)) { + if (isGraphURL(url) || (context.customHosts && isCustomHost(url, context.customHosts))) { // Add telemetry only if the request url is a Graph URL. // Errors are reported as in issue #265 if headers are present when redirecting to a non Graph URL let clientRequestId: string = getRequestHeader(context.request, context.options, TelemetryHandler.CLIENT_REQUEST_ID_HEADER); diff --git a/test/common/core/Client.ts b/test/common/core/Client.ts index ce655b427..eae57d88f 100644 --- a/test/common/core/Client.ts +++ b/test/common/core/Client.ts @@ -8,6 +8,7 @@ import "isomorphic-fetch"; import { assert } from "chai"; +import * as sinon from "sinon"; import { CustomAuthenticationProvider, TelemetryHandler } from "../../../src"; import { Client } from "../../../src/Client"; @@ -148,6 +149,63 @@ describe("Client.ts", () => { assert.equal(error.customError, customError); } }); + + it("Init middleware with custom hosts", async () => { + const accessToken = "DUMMY_TOKEN"; + const provider: AuthProvider = (done) => { + done(null, "DUMMY_TOKEN"); + }; + + const options = new ChaosHandlerOptions(ChaosStrategy.MANUAL, "Testing chained middleware array", 200, 100, ""); + const chaosHandler = new ChaosHandler(options); + + const authHandler = new AuthenticationHandler(new CustomAuthenticationProvider(provider)); + + const telemetry = new TelemetryHandler(); + const middleware = [authHandler, telemetry, chaosHandler]; + + const customHost = "test_custom"; + const customHosts = new Set([customHost]); + const client = Client.initWithMiddleware({ middleware, customHosts }); + + const spy = sinon.spy(telemetry, "execute"); + // eslint-disable-next-line @typescript-eslint/no-unused-vars + const response = await client.api(`https://${customHost}/v1.0/me`).get(); + const context = spy.getCall(0).args[0]; + + assert.equal(context.options.headers["Authorization"], `Bearer ${accessToken}`); + }); + + it("Pass invalid custom hosts", async () => { + try { + // eslint-disable-next-line @typescript-eslint/no-unused-vars + const accessToken = "DUMMY_TOKEN"; + const provider: AuthProvider = (done) => { + done(null, "DUMMY_TOKEN"); + }; + + const options = new ChaosHandlerOptions(ChaosStrategy.MANUAL, "Testing chained middleware array", 200, 100, ""); + const chaosHandler = new ChaosHandler(options); + + const authHandler = new AuthenticationHandler(new CustomAuthenticationProvider(provider)); + + const telemetry = new TelemetryHandler(); + const middleware = [authHandler, telemetry, chaosHandler]; + + const customHost = "https://test_custom"; + const customHosts = new Set([customHost]); + const client = Client.initWithMiddleware({ middleware, customHosts }); + + // eslint-disable-next-line @typescript-eslint/no-unused-vars + const response = await client.api(`https://${customHost}/v1.0/me`).get(); + + throw new Error("Test fails - Error expected when custom host is not valid"); + } catch (error) { + assert.isDefined(error); + assert.isDefined(error.message); + assert.equal(error.message, "Please add only hosts or hostnames to the CustomHosts config. If the url is `http://example.com:3000/`, host is `example:3000`"); + } + }); }); describe("init", () => { diff --git a/test/common/middleware/AuthenticationHandler.ts b/test/common/middleware/AuthenticationHandler.ts index a51a847d5..677fdc33c 100644 --- a/test/common/middleware/AuthenticationHandler.ts +++ b/test/common/middleware/AuthenticationHandler.ts @@ -7,11 +7,15 @@ import { assert } from "chai"; +import { ChaosHandler, ChaosHandlerOptions, ChaosStrategy } from "../../../src"; +import { GRAPH_BASE_URL } from "../../../src/Constants"; +import { Context } from "../../../src/IContext"; import { AuthenticationHandler } from "../../../src/middleware/AuthenticationHandler"; import { DummyAuthenticationProvider } from "../../DummyAuthenticationProvider"; const dummyAuthProvider = new DummyAuthenticationProvider(); const authHandler = new AuthenticationHandler(dummyAuthProvider); +const chaosHandler = new ChaosHandler(new ChaosHandlerOptions(ChaosStrategy.MANUAL, "TEST_MESSAGE", 200)); describe("AuthenticationHandler.ts", async () => { describe("Constructor", () => { @@ -20,4 +24,49 @@ describe("AuthenticationHandler.ts", async () => { assert.equal(authHandler["authenticationProvider"], dummyAuthProvider); }); }); + describe("Auth Headers", () => { + it("Should delete Auth header when Request object is passed with non Graph URL", async () => { + const request = new Request("test_url"); + const context: Context = { + request, + options: { + headers: { + Authorization: "TEST_VALUE", + }, + }, + }; + authHandler.setNext(chaosHandler); + await authHandler.execute(context); + assert.equal(context.options.headers["Authorization"], undefined); + }); + + it("Should contain Auth header when Request object is passed with custom URL", async () => { + const request = new Request("https://custom/"); + const context: Context = { + request, + customHosts: new Set(["custom"]), + options: { + headers: {}, + }, + }; + const accessToken = "Bearer DUMMY_TOKEN"; + + await authHandler.execute(context); + assert.equal((request as Request).headers.get("Authorization"), accessToken); + }); + + it("Should contain Auth header when Request object is passed with a valid Graph URL", async () => { + const request = new Request(GRAPH_BASE_URL); + const context: Context = { + request, + customHosts: new Set(["custom"]), + options: { + headers: {}, + }, + }; + const accessToken = "Bearer DUMMY_TOKEN"; + await authHandler.execute(context); + assert.equal((request as Request).headers.get("Authorization"), accessToken); + }); + }); });