diff --git a/conf/zeppelin-env.sh.template b/conf/zeppelin-env.sh.template
index a5beda7859f..e6f78fbeebc 100644
--- a/conf/zeppelin-env.sh.template
+++ b/conf/zeppelin-env.sh.template
@@ -25,6 +25,7 @@
# export ZEPPELIN_LOG_DIR # Where log files are stored. PWD by default.
# export ZEPPELIN_PID_DIR # The pid files are stored. /tmp by default.
+# export ZEPPELIN_SERVER_ORIGINS # comma separated domains to allow. Empty value autodetects. '*' for allow all. (default '*')
# export ZEPPELIN_NOTEBOOK_DIR # Where notebook saved
# export ZEPPELIN_NOTEBOOK_HOMESCREEN # Id of notebook to be displayed in homescreen. ex) 2A94M5J1Z
# export ZEPPELIN_NOTEBOOK_HOMESCREEN_HIDE # hide homescreen notebook from list when this value set to "true". default "false"
diff --git a/conf/zeppelin-site.xml.template b/conf/zeppelin-site.xml.template
index 57d1b23b2ce..3c1846a2d30 100644
--- a/conf/zeppelin-site.xml.template
+++ b/conf/zeppelin-site.xml.template
@@ -31,6 +31,12 @@
Server port.
+
+ zeppelin.server.origins
+ *
+ Comma separated list of domains to allow. Empty value autodetects. * for allow all. eg. '', 'www.domain.com, 'www.domain1.com,www.domain2.com', '*'
+
+
zeppelin.notebook.dir
notebook
diff --git a/zeppelin-server/src/main/java/org/apache/zeppelin/server/CorsFilter.java b/zeppelin-server/src/main/java/org/apache/zeppelin/server/CorsFilter.java
index c2a137f169d..1a65ae1e693 100644
--- a/zeppelin-server/src/main/java/org/apache/zeppelin/server/CorsFilter.java
+++ b/zeppelin-server/src/main/java/org/apache/zeppelin/server/CorsFilter.java
@@ -37,15 +37,17 @@
*
*/
public class CorsFilter implements Filter {
-
+
@Override
public void doFilter(ServletRequest request, ServletResponse response, FilterChain filterChain)
throws IOException, ServletException {
+ OriginValidator originValidator = OriginValidator.singleton();
+
String sourceHost = request.getServerName();
- String currentHost = java.net.InetAddress.getLocalHost().getHostName();
String origin = "";
- if (currentHost.equals(sourceHost) || "localhost".equals(sourceHost)) {
- origin = ((HttpServletRequest) request).getHeader("Origin");
+
+ if (originValidator.validate(sourceHost)) {
+ origin = ((HttpServletRequest) request).getHeader("Origin");
}
if (((HttpServletRequest) request).getMethod().equals("OPTIONS")) {
diff --git a/zeppelin-server/src/main/java/org/apache/zeppelin/server/OriginValidator.java b/zeppelin-server/src/main/java/org/apache/zeppelin/server/OriginValidator.java
new file mode 100644
index 00000000000..0a7cca24c61
--- /dev/null
+++ b/zeppelin-server/src/main/java/org/apache/zeppelin/server/OriginValidator.java
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.zeppelin.server;
+
+import java.net.URI;
+import java.net.URISyntaxException;
+import java.net.UnknownHostException;
+import java.util.LinkedList;
+import java.util.List;
+
+import org.apache.zeppelin.conf.ZeppelinConfiguration;
+import org.apache.zeppelin.conf.ZeppelinConfiguration.ConfVars;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Validates given host name is allowed
+ */
+public class OriginValidator {
+ Logger logger = LoggerFactory.getLogger(OriginValidator.class);
+
+ private static OriginValidator singletonInstance;
+ private ZeppelinConfiguration conf;
+ private final List allowedOrigins;
+ private final String ALLOW_ALL = "*";
+
+ public OriginValidator(ZeppelinConfiguration conf) {
+ this.conf = conf;
+ allowedOrigins = new LinkedList();
+ initAllowedOrigins();
+ singletonInstance = this;
+ }
+
+ /**
+ *
+ * @param origin origin to check
+ * @return
+ */
+ public boolean validate(String origin) {
+ try {
+ // just get host if origin is form of URI
+ URI sourceUri = new URI(origin);
+ String sourceHost = sourceUri.getHost();
+ if (sourceHost != null && !sourceHost.isEmpty()) {
+ origin = sourceHost;
+ }
+ } catch (URISyntaxException e) {
+ // we can silently ignore this error
+ }
+
+ if (origin == null) {
+ return false;
+ }
+
+ for (String p : allowedOrigins) {
+ if (p == null || p.trim().length() == 0) {
+ continue;
+ }
+ if (p.trim().compareToIgnoreCase(origin) == 0 || p.trim().equals(ALLOW_ALL)) {
+ return true;
+ }
+ }
+ return false;
+ }
+
+ private void initAllowedOrigins() {
+ String currentHost;
+ try {
+ currentHost = java.net.InetAddress.getLocalHost().getHostName();
+ allowedOrigins.add(currentHost);
+ } catch (UnknownHostException e) {
+ logger.error("Can't get hostname", e);
+ }
+
+
+ String origins = conf.getString(ConfVars.ZEPPELIN_SERVER_ORIGINS);
+ if (origins == null || origins.length() == 0) {
+ return;
+ } else {
+ for (String origin : origins.split(",")) {
+ allowedOrigins.add(origin);
+ }
+ }
+ }
+
+ public static OriginValidator singleton() {
+ return singletonInstance;
+ }
+}
diff --git a/zeppelin-server/src/main/java/org/apache/zeppelin/server/ZeppelinServer.java b/zeppelin-server/src/main/java/org/apache/zeppelin/server/ZeppelinServer.java
index ad1d9078952..860bceff46a 100644
--- a/zeppelin-server/src/main/java/org/apache/zeppelin/server/ZeppelinServer.java
+++ b/zeppelin-server/src/main/java/org/apache/zeppelin/server/ZeppelinServer.java
@@ -76,6 +76,8 @@ public class ZeppelinServer extends Application {
public static Server jettyServer;
+ private static OriginValidator originValidator;
+
private InterpreterFactory replFactory;
private NotebookRepo notebookRepo;
@@ -93,8 +95,10 @@ public static void main(String[] args) throws Exception {
*/
final ServletContextHandler swagger = setupSwaggerContextHandler(conf);
+ originValidator = new OriginValidator(conf);
+
// Notebook server
- final ServletContextHandler notebook = setupNotebookServer(conf);
+ final ServletContextHandler notebook = setupNotebookServer(conf, originValidator);
// Web UI
final WebAppContext webApp = setupWebAppContext(conf);
@@ -161,10 +165,11 @@ private static Server setupJettyServer(ZeppelinConfiguration conf)
return server;
}
- private static ServletContextHandler setupNotebookServer(ZeppelinConfiguration conf)
+ private static ServletContextHandler setupNotebookServer(
+ ZeppelinConfiguration conf, OriginValidator originValidator)
throws Exception {
- notebookServer = new NotebookServer();
+ notebookServer = new NotebookServer(originValidator);
final ServletHolder servletHolder = new ServletHolder(notebookServer);
servletHolder.setInitParameter("maxTextMessageSize", "1024000");
diff --git a/zeppelin-server/src/main/java/org/apache/zeppelin/socket/NotebookServer.java b/zeppelin-server/src/main/java/org/apache/zeppelin/socket/NotebookServer.java
index 5467fe69f52..1ce7114c658 100644
--- a/zeppelin-server/src/main/java/org/apache/zeppelin/socket/NotebookServer.java
+++ b/zeppelin-server/src/main/java/org/apache/zeppelin/socket/NotebookServer.java
@@ -41,6 +41,7 @@
import org.apache.zeppelin.scheduler.Job;
import org.apache.zeppelin.scheduler.Job.Status;
import org.apache.zeppelin.scheduler.JobListener;
+import org.apache.zeppelin.server.OriginValidator;
import org.apache.zeppelin.server.ZeppelinServer;
import org.apache.zeppelin.socket.Message.OP;
import org.eclipse.jetty.websocket.WebSocket;
@@ -63,31 +64,17 @@ public class NotebookServer extends WebSocketServlet implements
Gson gson = new Gson();
final Map> noteSocketMap = new HashMap<>();
final List connectedSockets = new LinkedList<>();
+ private OriginValidator originValidator;
+ public NotebookServer(OriginValidator originValidator) {
+ this.originValidator = originValidator;
+ }
private Notebook notebook() {
return ZeppelinServer.notebook;
}
@Override
public boolean checkOrigin(HttpServletRequest request, String origin) {
- URI sourceUri = null;
- String currentHost = null;
-
- try {
- sourceUri = new URI(origin);
- currentHost = java.net.InetAddress.getLocalHost().getHostName();
- } catch (UnknownHostException e) {
- e.printStackTrace();
- }
- catch (URISyntaxException e) {
- e.printStackTrace();
- }
-
- String sourceHost = sourceUri.getHost();
- if (currentHost.equals(sourceHost) || "localhost".equals(sourceHost)) {
- return true;
- }
-
- return false;
+ return originValidator.validate(origin);
}
@Override
diff --git a/zeppelin-server/src/test/java/org/apache/zeppelin/server/CorsFilterTests.java b/zeppelin-server/src/test/java/org/apache/zeppelin/server/CorsFilterTests.java
index 3c9152d35cc..d6bf3e2ddfc 100644
--- a/zeppelin-server/src/test/java/org/apache/zeppelin/server/CorsFilterTests.java
+++ b/zeppelin-server/src/test/java/org/apache/zeppelin/server/CorsFilterTests.java
@@ -19,6 +19,8 @@
*/
package org.apache.zeppelin.server;
+import org.apache.zeppelin.conf.ZeppelinConfiguration;
+import org.apache.zeppelin.conf.ZeppelinConfiguration.ConfVars;
import org.apache.zeppelin.socket.TestHttpServletRequest;
import org.junit.Assert;
import org.junit.Test;
@@ -28,6 +30,7 @@
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletResponse;
+
import java.io.IOException;
import static org.mockito.Matchers.anyString;
@@ -42,54 +45,64 @@
*/
public class CorsFilterTests {
- public static String[] headers = new String[8];
- public static Integer count = 0;
-
- @Test
- public void ValidCorsFilterTest() throws IOException, ServletException {
- CorsFilter filter = new CorsFilter();
- HttpServletResponse mockResponse = mock(HttpServletResponse.class);
- FilterChain mockedFilterChain = mock(FilterChain.class);
- TestHttpServletRequest mockRequest = mock(TestHttpServletRequest.class);
- when(mockRequest.getHeader("Origin")).thenReturn("http://localhost:8080");
- when(mockRequest.getMethod()).thenReturn("Empty");
- when(mockRequest.getServerName()).thenReturn("localhost");
-
-
- doAnswer(new Answer() {
- @Override
- public Object answer(InvocationOnMock invocationOnMock) throws Throwable {
- headers[count] = invocationOnMock.getArguments()[1].toString();
- count++;
- return null;
- }
- }).when(mockResponse).addHeader(anyString(), anyString());
-
- filter.doFilter(mockRequest, mockResponse, mockedFilterChain);
- Assert.assertTrue(headers[0].equals("http://localhost:8080"));
- }
-
- @Test
- public void InvalidCorsFilterTest() throws IOException, ServletException {
- CorsFilter filter = new CorsFilter();
- HttpServletResponse mockResponse = mock(HttpServletResponse.class);
- FilterChain mockedFilterChain = mock(FilterChain.class);
- TestHttpServletRequest mockRequest = mock(TestHttpServletRequest.class);
- when(mockRequest.getHeader("Origin")).thenReturn("http://evillocalhost:8080");
- when(mockRequest.getMethod()).thenReturn("Empty");
- when(mockRequest.getServerName()).thenReturn("evillocalhost");
-
-
- doAnswer(new Answer() {
- @Override
- public Object answer(InvocationOnMock invocationOnMock) throws Throwable {
- headers[count] = invocationOnMock.getArguments()[1].toString();
- count++;
- return null;
- }
- }).when(mockResponse).addHeader(anyString(), anyString());
-
- filter.doFilter(mockRequest, mockResponse, mockedFilterChain);
- Assert.assertTrue(headers[0].equals(""));
- }
+ public static String[] headers = new String[10];
+ public static Integer count = 0;
+
+ @Test
+ public void ValidCorsFilterTest() throws IOException, ServletException {
+ System.setProperty(ConfVars.ZEPPELIN_SERVER_ORIGINS.getVarName(), "localhost");
+
+ ZeppelinConfiguration conf = ZeppelinConfiguration.create();
+ OriginValidator originValidator = new OriginValidator(conf);
+
+ CorsFilter filter = new CorsFilter();
+ HttpServletResponse mockResponse = mock(HttpServletResponse.class);
+ FilterChain mockedFilterChain = mock(FilterChain.class);
+ TestHttpServletRequest mockRequest = mock(TestHttpServletRequest.class);
+ when(mockRequest.getHeader("Origin")).thenReturn("http://localhost:8080");
+ when(mockRequest.getMethod()).thenReturn("Empty");
+ when(mockRequest.getServerName()).thenReturn("localhost");
+ when(mockRequest.getScheme()).thenReturn("http");
+
+ doAnswer(new Answer() {
+ @Override
+ public Object answer(InvocationOnMock invocationOnMock) throws Throwable {
+ headers[count] = invocationOnMock.getArguments()[1].toString();
+ count++;
+ return null;
+ }
+ }).when(mockResponse).addHeader(anyString(), anyString());
+
+ filter.doFilter(mockRequest, mockResponse, mockedFilterChain);
+ Assert.assertEquals("http://localhost:8080", headers[0]);
+ }
+
+ @Test
+ public void InvalidCorsFilterTest() throws IOException, ServletException {
+ System.setProperty(ConfVars.ZEPPELIN_SERVER_ORIGINS.getVarName(), "");
+
+ ZeppelinConfiguration conf = ZeppelinConfiguration.create();
+ OriginValidator originValidator = new OriginValidator(conf);
+
+ CorsFilter filter = new CorsFilter();
+ HttpServletResponse mockResponse = mock(HttpServletResponse.class);
+ FilterChain mockedFilterChain = mock(FilterChain.class);
+ TestHttpServletRequest mockRequest = mock(TestHttpServletRequest.class);
+ when(mockRequest.getHeader("Origin")).thenReturn(
+ "http://evillocalhost:8080");
+ when(mockRequest.getMethod()).thenReturn("Empty");
+ when(mockRequest.getServerName()).thenReturn("evillocalhost");
+
+ doAnswer(new Answer() {
+ @Override
+ public Object answer(InvocationOnMock invocationOnMock) throws Throwable {
+ headers[count] = invocationOnMock.getArguments()[1].toString();
+ count++;
+ return null;
+ }
+ }).when(mockResponse).addHeader(anyString(), anyString());
+
+ filter.doFilter(mockRequest, mockResponse, mockedFilterChain);
+ Assert.assertEquals("", headers[0]);
+ }
}
diff --git a/zeppelin-server/src/test/java/org/apache/zeppelin/socket/NotebookServerTests.java b/zeppelin-server/src/test/java/org/apache/zeppelin/socket/NotebookServerTests.java
index c262593fdaf..48a42b8371a 100644
--- a/zeppelin-server/src/test/java/org/apache/zeppelin/socket/NotebookServerTests.java
+++ b/zeppelin-server/src/test/java/org/apache/zeppelin/socket/NotebookServerTests.java
@@ -19,6 +19,9 @@
*/
package org.apache.zeppelin.socket;
+import org.apache.zeppelin.conf.ZeppelinConfiguration;
+import org.apache.zeppelin.conf.ZeppelinConfiguration.ConfVars;
+import org.apache.zeppelin.server.OriginValidator;
import org.junit.Assert;
import org.junit.Test;
@@ -31,20 +34,28 @@
* @author joelz
*
*/
- public class NotebookServerTests {
+public class NotebookServerTests {
- @Test
- public void CheckOrigin() throws UnknownHostException {
- NotebookServer server = new NotebookServer();
- Assert.assertTrue(server.checkOrigin(new TestHttpServletRequest(),
- "http://" + java.net.InetAddress.getLocalHost().getHostName() + ":8080"));
- }
+ @Test
+ public void CheckOrigin() throws UnknownHostException {
+ System.setProperty(ConfVars.ZEPPELIN_SERVER_ORIGINS.getVarName(), "");
- @Test
- public void CheckInvalidOrigin(){
- NotebookServer server = new NotebookServer();
- Assert.assertFalse(server.checkOrigin(new TestHttpServletRequest(), "http://evillocalhost:8080"));
- }
+ ZeppelinConfiguration conf = ZeppelinConfiguration.create();
+ OriginValidator originValidator = new OriginValidator(conf);
+ NotebookServer server = new NotebookServer(originValidator);
+ Assert.assertTrue(server
+ .checkOrigin(new TestHttpServletRequest(), "http://"
+ + java.net.InetAddress.getLocalHost().getHostName() + ":8080"));
+ }
+ @Test
+ public void CheckInvalidOrigin() {
+ System.setProperty(ConfVars.ZEPPELIN_SERVER_ORIGINS.getVarName(), "");
+ ZeppelinConfiguration conf = ZeppelinConfiguration.create();
+ OriginValidator originValidator = new OriginValidator(conf);
+ NotebookServer server = new NotebookServer(originValidator);
+ Assert.assertFalse(server.checkOrigin(new TestHttpServletRequest(),
+ "http://evillocalhost:8080"));
+ }
}
diff --git a/zeppelin-zengine/src/main/java/org/apache/zeppelin/conf/ZeppelinConfiguration.java b/zeppelin-zengine/src/main/java/org/apache/zeppelin/conf/ZeppelinConfiguration.java
index 6fda2b2c600..db4cbb6cae2 100644
--- a/zeppelin-zengine/src/main/java/org/apache/zeppelin/conf/ZeppelinConfiguration.java
+++ b/zeppelin-zengine/src/main/java/org/apache/zeppelin/conf/ZeppelinConfiguration.java
@@ -370,6 +370,7 @@ public static enum ConfVars {
ZEPPELIN_HOME("zeppelin.home", "../"),
ZEPPELIN_ADDR("zeppelin.server.addr", "0.0.0.0"),
ZEPPELIN_PORT("zeppelin.server.port", 8080),
+ ZEPPELIN_SERVER_ORIGINS("zeppelin.server.origins", "*"),
ZEPPELIN_SSL("zeppelin.ssl", false),
ZEPPELIN_SSL_CLIENT_AUTH("zeppelin.ssl.client.auth", false),
ZEPPELIN_SSL_KEYSTORE_PATH("zeppelin.ssl.keystore.path", "keystore"),