diff --git a/conf/zeppelin-env.sh.template b/conf/zeppelin-env.sh.template index a5beda7859f..e6f78fbeebc 100644 --- a/conf/zeppelin-env.sh.template +++ b/conf/zeppelin-env.sh.template @@ -25,6 +25,7 @@ # export ZEPPELIN_LOG_DIR # Where log files are stored. PWD by default. # export ZEPPELIN_PID_DIR # The pid files are stored. /tmp by default. +# export ZEPPELIN_SERVER_ORIGINS # comma separated domains to allow. Empty value autodetects. '*' for allow all. (default '*') # export ZEPPELIN_NOTEBOOK_DIR # Where notebook saved # export ZEPPELIN_NOTEBOOK_HOMESCREEN # Id of notebook to be displayed in homescreen. ex) 2A94M5J1Z # export ZEPPELIN_NOTEBOOK_HOMESCREEN_HIDE # hide homescreen notebook from list when this value set to "true". default "false" diff --git a/conf/zeppelin-site.xml.template b/conf/zeppelin-site.xml.template index 57d1b23b2ce..3c1846a2d30 100644 --- a/conf/zeppelin-site.xml.template +++ b/conf/zeppelin-site.xml.template @@ -31,6 +31,12 @@ Server port. + + zeppelin.server.origins + * + Comma separated list of domains to allow. Empty value autodetects. * for allow all. eg. '', 'www.domain.com, 'www.domain1.com,www.domain2.com', '*' + + zeppelin.notebook.dir notebook diff --git a/zeppelin-server/src/main/java/org/apache/zeppelin/server/CorsFilter.java b/zeppelin-server/src/main/java/org/apache/zeppelin/server/CorsFilter.java index c2a137f169d..1a65ae1e693 100644 --- a/zeppelin-server/src/main/java/org/apache/zeppelin/server/CorsFilter.java +++ b/zeppelin-server/src/main/java/org/apache/zeppelin/server/CorsFilter.java @@ -37,15 +37,17 @@ * */ public class CorsFilter implements Filter { - + @Override public void doFilter(ServletRequest request, ServletResponse response, FilterChain filterChain) throws IOException, ServletException { + OriginValidator originValidator = OriginValidator.singleton(); + String sourceHost = request.getServerName(); - String currentHost = java.net.InetAddress.getLocalHost().getHostName(); String origin = ""; - if (currentHost.equals(sourceHost) || "localhost".equals(sourceHost)) { - origin = ((HttpServletRequest) request).getHeader("Origin"); + + if (originValidator.validate(sourceHost)) { + origin = ((HttpServletRequest) request).getHeader("Origin"); } if (((HttpServletRequest) request).getMethod().equals("OPTIONS")) { diff --git a/zeppelin-server/src/main/java/org/apache/zeppelin/server/OriginValidator.java b/zeppelin-server/src/main/java/org/apache/zeppelin/server/OriginValidator.java new file mode 100644 index 00000000000..0a7cca24c61 --- /dev/null +++ b/zeppelin-server/src/main/java/org/apache/zeppelin/server/OriginValidator.java @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.zeppelin.server; + +import java.net.URI; +import java.net.URISyntaxException; +import java.net.UnknownHostException; +import java.util.LinkedList; +import java.util.List; + +import org.apache.zeppelin.conf.ZeppelinConfiguration; +import org.apache.zeppelin.conf.ZeppelinConfiguration.ConfVars; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Validates given host name is allowed + */ +public class OriginValidator { + Logger logger = LoggerFactory.getLogger(OriginValidator.class); + + private static OriginValidator singletonInstance; + private ZeppelinConfiguration conf; + private final List allowedOrigins; + private final String ALLOW_ALL = "*"; + + public OriginValidator(ZeppelinConfiguration conf) { + this.conf = conf; + allowedOrigins = new LinkedList(); + initAllowedOrigins(); + singletonInstance = this; + } + + /** + * + * @param origin origin to check + * @return + */ + public boolean validate(String origin) { + try { + // just get host if origin is form of URI + URI sourceUri = new URI(origin); + String sourceHost = sourceUri.getHost(); + if (sourceHost != null && !sourceHost.isEmpty()) { + origin = sourceHost; + } + } catch (URISyntaxException e) { + // we can silently ignore this error + } + + if (origin == null) { + return false; + } + + for (String p : allowedOrigins) { + if (p == null || p.trim().length() == 0) { + continue; + } + if (p.trim().compareToIgnoreCase(origin) == 0 || p.trim().equals(ALLOW_ALL)) { + return true; + } + } + return false; + } + + private void initAllowedOrigins() { + String currentHost; + try { + currentHost = java.net.InetAddress.getLocalHost().getHostName(); + allowedOrigins.add(currentHost); + } catch (UnknownHostException e) { + logger.error("Can't get hostname", e); + } + + + String origins = conf.getString(ConfVars.ZEPPELIN_SERVER_ORIGINS); + if (origins == null || origins.length() == 0) { + return; + } else { + for (String origin : origins.split(",")) { + allowedOrigins.add(origin); + } + } + } + + public static OriginValidator singleton() { + return singletonInstance; + } +} diff --git a/zeppelin-server/src/main/java/org/apache/zeppelin/server/ZeppelinServer.java b/zeppelin-server/src/main/java/org/apache/zeppelin/server/ZeppelinServer.java index ad1d9078952..860bceff46a 100644 --- a/zeppelin-server/src/main/java/org/apache/zeppelin/server/ZeppelinServer.java +++ b/zeppelin-server/src/main/java/org/apache/zeppelin/server/ZeppelinServer.java @@ -76,6 +76,8 @@ public class ZeppelinServer extends Application { public static Server jettyServer; + private static OriginValidator originValidator; + private InterpreterFactory replFactory; private NotebookRepo notebookRepo; @@ -93,8 +95,10 @@ public static void main(String[] args) throws Exception { */ final ServletContextHandler swagger = setupSwaggerContextHandler(conf); + originValidator = new OriginValidator(conf); + // Notebook server - final ServletContextHandler notebook = setupNotebookServer(conf); + final ServletContextHandler notebook = setupNotebookServer(conf, originValidator); // Web UI final WebAppContext webApp = setupWebAppContext(conf); @@ -161,10 +165,11 @@ private static Server setupJettyServer(ZeppelinConfiguration conf) return server; } - private static ServletContextHandler setupNotebookServer(ZeppelinConfiguration conf) + private static ServletContextHandler setupNotebookServer( + ZeppelinConfiguration conf, OriginValidator originValidator) throws Exception { - notebookServer = new NotebookServer(); + notebookServer = new NotebookServer(originValidator); final ServletHolder servletHolder = new ServletHolder(notebookServer); servletHolder.setInitParameter("maxTextMessageSize", "1024000"); diff --git a/zeppelin-server/src/main/java/org/apache/zeppelin/socket/NotebookServer.java b/zeppelin-server/src/main/java/org/apache/zeppelin/socket/NotebookServer.java index 5467fe69f52..1ce7114c658 100644 --- a/zeppelin-server/src/main/java/org/apache/zeppelin/socket/NotebookServer.java +++ b/zeppelin-server/src/main/java/org/apache/zeppelin/socket/NotebookServer.java @@ -41,6 +41,7 @@ import org.apache.zeppelin.scheduler.Job; import org.apache.zeppelin.scheduler.Job.Status; import org.apache.zeppelin.scheduler.JobListener; +import org.apache.zeppelin.server.OriginValidator; import org.apache.zeppelin.server.ZeppelinServer; import org.apache.zeppelin.socket.Message.OP; import org.eclipse.jetty.websocket.WebSocket; @@ -63,31 +64,17 @@ public class NotebookServer extends WebSocketServlet implements Gson gson = new Gson(); final Map> noteSocketMap = new HashMap<>(); final List connectedSockets = new LinkedList<>(); + private OriginValidator originValidator; + public NotebookServer(OriginValidator originValidator) { + this.originValidator = originValidator; + } private Notebook notebook() { return ZeppelinServer.notebook; } @Override public boolean checkOrigin(HttpServletRequest request, String origin) { - URI sourceUri = null; - String currentHost = null; - - try { - sourceUri = new URI(origin); - currentHost = java.net.InetAddress.getLocalHost().getHostName(); - } catch (UnknownHostException e) { - e.printStackTrace(); - } - catch (URISyntaxException e) { - e.printStackTrace(); - } - - String sourceHost = sourceUri.getHost(); - if (currentHost.equals(sourceHost) || "localhost".equals(sourceHost)) { - return true; - } - - return false; + return originValidator.validate(origin); } @Override diff --git a/zeppelin-server/src/test/java/org/apache/zeppelin/server/CorsFilterTests.java b/zeppelin-server/src/test/java/org/apache/zeppelin/server/CorsFilterTests.java index 3c9152d35cc..d6bf3e2ddfc 100644 --- a/zeppelin-server/src/test/java/org/apache/zeppelin/server/CorsFilterTests.java +++ b/zeppelin-server/src/test/java/org/apache/zeppelin/server/CorsFilterTests.java @@ -19,6 +19,8 @@ */ package org.apache.zeppelin.server; +import org.apache.zeppelin.conf.ZeppelinConfiguration; +import org.apache.zeppelin.conf.ZeppelinConfiguration.ConfVars; import org.apache.zeppelin.socket.TestHttpServletRequest; import org.junit.Assert; import org.junit.Test; @@ -28,6 +30,7 @@ import javax.servlet.FilterChain; import javax.servlet.ServletException; import javax.servlet.http.HttpServletResponse; + import java.io.IOException; import static org.mockito.Matchers.anyString; @@ -42,54 +45,64 @@ */ public class CorsFilterTests { - public static String[] headers = new String[8]; - public static Integer count = 0; - - @Test - public void ValidCorsFilterTest() throws IOException, ServletException { - CorsFilter filter = new CorsFilter(); - HttpServletResponse mockResponse = mock(HttpServletResponse.class); - FilterChain mockedFilterChain = mock(FilterChain.class); - TestHttpServletRequest mockRequest = mock(TestHttpServletRequest.class); - when(mockRequest.getHeader("Origin")).thenReturn("http://localhost:8080"); - when(mockRequest.getMethod()).thenReturn("Empty"); - when(mockRequest.getServerName()).thenReturn("localhost"); - - - doAnswer(new Answer() { - @Override - public Object answer(InvocationOnMock invocationOnMock) throws Throwable { - headers[count] = invocationOnMock.getArguments()[1].toString(); - count++; - return null; - } - }).when(mockResponse).addHeader(anyString(), anyString()); - - filter.doFilter(mockRequest, mockResponse, mockedFilterChain); - Assert.assertTrue(headers[0].equals("http://localhost:8080")); - } - - @Test - public void InvalidCorsFilterTest() throws IOException, ServletException { - CorsFilter filter = new CorsFilter(); - HttpServletResponse mockResponse = mock(HttpServletResponse.class); - FilterChain mockedFilterChain = mock(FilterChain.class); - TestHttpServletRequest mockRequest = mock(TestHttpServletRequest.class); - when(mockRequest.getHeader("Origin")).thenReturn("http://evillocalhost:8080"); - when(mockRequest.getMethod()).thenReturn("Empty"); - when(mockRequest.getServerName()).thenReturn("evillocalhost"); - - - doAnswer(new Answer() { - @Override - public Object answer(InvocationOnMock invocationOnMock) throws Throwable { - headers[count] = invocationOnMock.getArguments()[1].toString(); - count++; - return null; - } - }).when(mockResponse).addHeader(anyString(), anyString()); - - filter.doFilter(mockRequest, mockResponse, mockedFilterChain); - Assert.assertTrue(headers[0].equals("")); - } + public static String[] headers = new String[10]; + public static Integer count = 0; + + @Test + public void ValidCorsFilterTest() throws IOException, ServletException { + System.setProperty(ConfVars.ZEPPELIN_SERVER_ORIGINS.getVarName(), "localhost"); + + ZeppelinConfiguration conf = ZeppelinConfiguration.create(); + OriginValidator originValidator = new OriginValidator(conf); + + CorsFilter filter = new CorsFilter(); + HttpServletResponse mockResponse = mock(HttpServletResponse.class); + FilterChain mockedFilterChain = mock(FilterChain.class); + TestHttpServletRequest mockRequest = mock(TestHttpServletRequest.class); + when(mockRequest.getHeader("Origin")).thenReturn("http://localhost:8080"); + when(mockRequest.getMethod()).thenReturn("Empty"); + when(mockRequest.getServerName()).thenReturn("localhost"); + when(mockRequest.getScheme()).thenReturn("http"); + + doAnswer(new Answer() { + @Override + public Object answer(InvocationOnMock invocationOnMock) throws Throwable { + headers[count] = invocationOnMock.getArguments()[1].toString(); + count++; + return null; + } + }).when(mockResponse).addHeader(anyString(), anyString()); + + filter.doFilter(mockRequest, mockResponse, mockedFilterChain); + Assert.assertEquals("http://localhost:8080", headers[0]); + } + + @Test + public void InvalidCorsFilterTest() throws IOException, ServletException { + System.setProperty(ConfVars.ZEPPELIN_SERVER_ORIGINS.getVarName(), ""); + + ZeppelinConfiguration conf = ZeppelinConfiguration.create(); + OriginValidator originValidator = new OriginValidator(conf); + + CorsFilter filter = new CorsFilter(); + HttpServletResponse mockResponse = mock(HttpServletResponse.class); + FilterChain mockedFilterChain = mock(FilterChain.class); + TestHttpServletRequest mockRequest = mock(TestHttpServletRequest.class); + when(mockRequest.getHeader("Origin")).thenReturn( + "http://evillocalhost:8080"); + when(mockRequest.getMethod()).thenReturn("Empty"); + when(mockRequest.getServerName()).thenReturn("evillocalhost"); + + doAnswer(new Answer() { + @Override + public Object answer(InvocationOnMock invocationOnMock) throws Throwable { + headers[count] = invocationOnMock.getArguments()[1].toString(); + count++; + return null; + } + }).when(mockResponse).addHeader(anyString(), anyString()); + + filter.doFilter(mockRequest, mockResponse, mockedFilterChain); + Assert.assertEquals("", headers[0]); + } } diff --git a/zeppelin-server/src/test/java/org/apache/zeppelin/socket/NotebookServerTests.java b/zeppelin-server/src/test/java/org/apache/zeppelin/socket/NotebookServerTests.java index c262593fdaf..48a42b8371a 100644 --- a/zeppelin-server/src/test/java/org/apache/zeppelin/socket/NotebookServerTests.java +++ b/zeppelin-server/src/test/java/org/apache/zeppelin/socket/NotebookServerTests.java @@ -19,6 +19,9 @@ */ package org.apache.zeppelin.socket; +import org.apache.zeppelin.conf.ZeppelinConfiguration; +import org.apache.zeppelin.conf.ZeppelinConfiguration.ConfVars; +import org.apache.zeppelin.server.OriginValidator; import org.junit.Assert; import org.junit.Test; @@ -31,20 +34,28 @@ * @author joelz * */ - public class NotebookServerTests { +public class NotebookServerTests { - @Test - public void CheckOrigin() throws UnknownHostException { - NotebookServer server = new NotebookServer(); - Assert.assertTrue(server.checkOrigin(new TestHttpServletRequest(), - "http://" + java.net.InetAddress.getLocalHost().getHostName() + ":8080")); - } + @Test + public void CheckOrigin() throws UnknownHostException { + System.setProperty(ConfVars.ZEPPELIN_SERVER_ORIGINS.getVarName(), ""); - @Test - public void CheckInvalidOrigin(){ - NotebookServer server = new NotebookServer(); - Assert.assertFalse(server.checkOrigin(new TestHttpServletRequest(), "http://evillocalhost:8080")); - } + ZeppelinConfiguration conf = ZeppelinConfiguration.create(); + OriginValidator originValidator = new OriginValidator(conf); + NotebookServer server = new NotebookServer(originValidator); + Assert.assertTrue(server + .checkOrigin(new TestHttpServletRequest(), "http://" + + java.net.InetAddress.getLocalHost().getHostName() + ":8080")); + } + @Test + public void CheckInvalidOrigin() { + System.setProperty(ConfVars.ZEPPELIN_SERVER_ORIGINS.getVarName(), ""); + ZeppelinConfiguration conf = ZeppelinConfiguration.create(); + OriginValidator originValidator = new OriginValidator(conf); + NotebookServer server = new NotebookServer(originValidator); + Assert.assertFalse(server.checkOrigin(new TestHttpServletRequest(), + "http://evillocalhost:8080")); + } } diff --git a/zeppelin-zengine/src/main/java/org/apache/zeppelin/conf/ZeppelinConfiguration.java b/zeppelin-zengine/src/main/java/org/apache/zeppelin/conf/ZeppelinConfiguration.java index 6fda2b2c600..db4cbb6cae2 100644 --- a/zeppelin-zengine/src/main/java/org/apache/zeppelin/conf/ZeppelinConfiguration.java +++ b/zeppelin-zengine/src/main/java/org/apache/zeppelin/conf/ZeppelinConfiguration.java @@ -370,6 +370,7 @@ public static enum ConfVars { ZEPPELIN_HOME("zeppelin.home", "../"), ZEPPELIN_ADDR("zeppelin.server.addr", "0.0.0.0"), ZEPPELIN_PORT("zeppelin.server.port", 8080), + ZEPPELIN_SERVER_ORIGINS("zeppelin.server.origins", "*"), ZEPPELIN_SSL("zeppelin.ssl", false), ZEPPELIN_SSL_CLIENT_AUTH("zeppelin.ssl.client.auth", false), ZEPPELIN_SSL_KEYSTORE_PATH("zeppelin.ssl.keystore.path", "keystore"),