diff --git a/tensorboard/backend/application.py b/tensorboard/backend/application.py index d32c5e22c9b..1567d06dc4f 100644 --- a/tensorboard/backend/application.py +++ b/tensorboard/backend/application.py @@ -24,6 +24,7 @@ import atexit import collections +import functools import json import os import re @@ -36,6 +37,7 @@ import six from six.moves.urllib import parse as urlparse # pylint: disable=wrong-import-order +import werkzeug from werkzeug import wrappers from tensorboard.backend import http_util @@ -324,6 +326,33 @@ def _serve_plugins_listing(self, request): response[plugin.plugin_name] = plugin_metadata return http_util.Respond(request, response, 'application/json') + def _headers_with_colab_csp(self, headers): + """Add a Content-Security-Policy facilitating Colab output frames. + + This is intended for use with the `google.colab.kernel.proxyPort` + JavaScript function available from within a Colab output frame. + + If the headers already include an explicit CSP, they are returned + unchanged. + + Args: + headers: A list of WSGI headers (key-value tuples of `str`s). + + Returns: + A new list of WSGI headers; the original is unchanged. + """ + # use a Werkzeug `Headers` object for proper case-insensitivity + headers = werkzeug.Headers(headers) + csp_key = 'Content-Security-Policy' + if csp_key not in headers: + allowed_ancestors = ' '.join([ + 'https://*.googleusercontent.com', + 'https://*.google.com', + ]) + csp = 'frame-ancestors %s' % allowed_ancestors + headers[csp_key] = csp + return headers.to_wsgi_list() + def __call__(self, environ, start_response): # pylint: disable=invalid-name """Central entry point for the TensorBoard application. @@ -344,13 +373,17 @@ class are WSGI applications. parsed_url = urlparse.urlparse(request.path) clean_path = _clean_path(parsed_url.path, self._path_prefix) + @functools.wraps(start_response) + def new_start_response(status, headers): + return start_response(status, self._headers_with_colab_csp(headers)) + # pylint: disable=too-many-function-args if clean_path in self.data_applications: - return self.data_applications[clean_path](environ, start_response) + return self.data_applications[clean_path](environ, new_start_response) else: logger.warn('path %s not found, sending 404', clean_path) return http_util.Respond(request, 'Not found', 'text/plain', code=404)( - environ, start_response) + environ, new_start_response) # pylint: enable=too-many-function-args diff --git a/tensorboard/backend/application_test.py b/tensorboard/backend/application_test.py index a96ce478b95..a4ad93cab8d 100644 --- a/tensorboard/backend/application_test.py +++ b/tensorboard/backend/application_test.py @@ -36,6 +36,7 @@ except ImportError: import mock # pylint: disable=g-import-not-at-top,unused-import +import werkzeug from werkzeug import test as werkzeug_test from werkzeug import wrappers @@ -148,6 +149,8 @@ def setUp(self): is_active_value=True, routes_mapping={ '/esmodule': lambda req: None, + '/no_csp': functools.partial(self._serve, False), + '/csp': functools.partial(self._serve, True), }, es_module_path_value='/esmodule' ), @@ -155,6 +158,14 @@ def setUp(self): app = application.TensorBoardWSGI(plugins) self.server = werkzeug_test.Client(app, wrappers.BaseResponse) + @wrappers.Request.application + def _serve(self, include_csp, request): + assert isinstance(include_csp, bool), include_csp + response = wrappers.Response('hello\n') + if include_csp: + response.headers['CONTENT-sEcUrItY-POLICY'] = "frame-ancestors 'none'" + return response + def _get_json(self, path): response = self.server.get(path) self.assertEqual(200, response.status_code) @@ -206,6 +217,27 @@ def testPluginsListing(self): } ) + def testColabCsp_whenNoCspPresent(self): + response = self.server.get('/data/plugin/baz/no_csp') + self.assertEqual( + response.headers.get('Content-Security-Policy'), + 'frame-ancestors https://*.googleusercontent.com https://*.google.com', + ) + + def testColabCsp_whenExistingCspPresent(self): + response = self.server.get('/data/plugin/baz/csp') + self.assertEqual( + response.headers.get('Content-Security-Policy'), + "frame-ancestors 'none'", + ) + + def testColabCsp_on404(self): + response = self.server.get('/asdf') + self.assertEqual(404, response.status_code) + self.assertEqual( + response.headers.get('Content-Security-Policy'), + 'frame-ancestors https://*.googleusercontent.com https://*.google.com', + ) class ApplicationBaseUrlTest(tb_test.TestCase): path_prefix = '/test'