Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions conf/zeppelin-env.sh.template
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

# export ZEPPELIN_LOG_DIR # Where log files are stored. PWD by default.
# export ZEPPELIN_PID_DIR # The pid files are stored. /tmp by default.
# export ZEPPELIN_SERVER_ORIGINS # comma separated domains to allow. Empty value autodetects. '*' for allow all. (default '*')
# export ZEPPELIN_NOTEBOOK_DIR # Where notebook saved
# export ZEPPELIN_NOTEBOOK_HOMESCREEN # Id of notebook to be displayed in homescreen. ex) 2A94M5J1Z
# export ZEPPELIN_NOTEBOOK_HOMESCREEN_HIDE # hide homescreen notebook from list when this value set to "true". default "false"
Expand Down
6 changes: 6 additions & 0 deletions conf/zeppelin-site.xml.template
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,12 @@
<description>Server port.</description>
</property>

<property>
<name>zeppelin.server.origins</name>
<value>*</value>
<description>Comma separated list of domains to allow. Empty value autodetects. * for allow all. eg. '', 'www.domain.com, 'www.domain1.com,www.domain2.com', '*' </description>
</property>

<property>
<name>zeppelin.notebook.dir</name>
<value>notebook</value>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,17 @@
*
*/
public class CorsFilter implements Filter {

@Override
public void doFilter(ServletRequest request, ServletResponse response, FilterChain filterChain)
throws IOException, ServletException {
OriginValidator originValidator = OriginValidator.singleton();

String sourceHost = request.getServerName();
String currentHost = java.net.InetAddress.getLocalHost().getHostName();
String origin = "";
if (currentHost.equals(sourceHost) || "localhost".equals(sourceHost)) {
origin = ((HttpServletRequest) request).getHeader("Origin");

if (originValidator.validate(sourceHost)) {
origin = ((HttpServletRequest) request).getHeader("Origin");
}

if (((HttpServletRequest) request).getMethod().equals("OPTIONS")) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.zeppelin.server;

import java.net.URI;
import java.net.URISyntaxException;
import java.net.UnknownHostException;
import java.util.LinkedList;
import java.util.List;

import org.apache.zeppelin.conf.ZeppelinConfiguration;
import org.apache.zeppelin.conf.ZeppelinConfiguration.ConfVars;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
* Validates given host name is allowed
*/
public class OriginValidator {
Logger logger = LoggerFactory.getLogger(OriginValidator.class);

private static OriginValidator singletonInstance;
private ZeppelinConfiguration conf;
private final List<String> allowedOrigins;
private final String ALLOW_ALL = "*";

public OriginValidator(ZeppelinConfiguration conf) {
this.conf = conf;
allowedOrigins = new LinkedList<String>();
initAllowedOrigins();
singletonInstance = this;
}

/**
*
* @param origin origin to check
* @return
*/
public boolean validate(String origin) {
try {
// just get host if origin is form of URI
URI sourceUri = new URI(origin);
String sourceHost = sourceUri.getHost();
if (sourceHost != null && !sourceHost.isEmpty()) {
origin = sourceHost;
}
} catch (URISyntaxException e) {
// we can silently ignore this error
}

if (origin == null) {
return false;
}

for (String p : allowedOrigins) {
if (p == null || p.trim().length() == 0) {
continue;
}
if (p.trim().compareToIgnoreCase(origin) == 0 || p.trim().equals(ALLOW_ALL)) {
return true;
}
}
return false;
}

private void initAllowedOrigins() {
String currentHost;
try {
currentHost = java.net.InetAddress.getLocalHost().getHostName();
allowedOrigins.add(currentHost);
} catch (UnknownHostException e) {
logger.error("Can't get hostname", e);
}


String origins = conf.getString(ConfVars.ZEPPELIN_SERVER_ORIGINS);
if (origins == null || origins.length() == 0) {
return;
} else {
for (String origin : origins.split(",")) {
allowedOrigins.add(origin);
}
}
}

public static OriginValidator singleton() {
return singletonInstance;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ public class ZeppelinServer extends Application {

public static Server jettyServer;

private static OriginValidator originValidator;

private InterpreterFactory replFactory;

private NotebookRepo notebookRepo;
Expand All @@ -93,8 +95,10 @@ public static void main(String[] args) throws Exception {
*/
final ServletContextHandler swagger = setupSwaggerContextHandler(conf);

originValidator = new OriginValidator(conf);

// Notebook server
final ServletContextHandler notebook = setupNotebookServer(conf);
final ServletContextHandler notebook = setupNotebookServer(conf, originValidator);

// Web UI
final WebAppContext webApp = setupWebAppContext(conf);
Expand Down Expand Up @@ -161,10 +165,11 @@ private static Server setupJettyServer(ZeppelinConfiguration conf)
return server;
}

private static ServletContextHandler setupNotebookServer(ZeppelinConfiguration conf)
private static ServletContextHandler setupNotebookServer(
ZeppelinConfiguration conf, OriginValidator originValidator)
throws Exception {

notebookServer = new NotebookServer();
notebookServer = new NotebookServer(originValidator);
final ServletHolder servletHolder = new ServletHolder(notebookServer);
servletHolder.setInitParameter("maxTextMessageSize", "1024000");

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import org.apache.zeppelin.scheduler.Job;
import org.apache.zeppelin.scheduler.Job.Status;
import org.apache.zeppelin.scheduler.JobListener;
import org.apache.zeppelin.server.OriginValidator;
import org.apache.zeppelin.server.ZeppelinServer;
import org.apache.zeppelin.socket.Message.OP;
import org.eclipse.jetty.websocket.WebSocket;
Expand All @@ -63,31 +64,17 @@ public class NotebookServer extends WebSocketServlet implements
Gson gson = new Gson();
final Map<String, List<NotebookSocket>> noteSocketMap = new HashMap<>();
final List<NotebookSocket> connectedSockets = new LinkedList<>();
private OriginValidator originValidator;

public NotebookServer(OriginValidator originValidator) {
this.originValidator = originValidator;
}
private Notebook notebook() {
return ZeppelinServer.notebook;
}
@Override
public boolean checkOrigin(HttpServletRequest request, String origin) {
URI sourceUri = null;
String currentHost = null;

try {
sourceUri = new URI(origin);
currentHost = java.net.InetAddress.getLocalHost().getHostName();
} catch (UnknownHostException e) {
e.printStackTrace();
}
catch (URISyntaxException e) {
e.printStackTrace();
}

String sourceHost = sourceUri.getHost();
if (currentHost.equals(sourceHost) || "localhost".equals(sourceHost)) {
return true;
}

return false;
return originValidator.validate(origin);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
*/
package org.apache.zeppelin.server;

import org.apache.zeppelin.conf.ZeppelinConfiguration;
import org.apache.zeppelin.conf.ZeppelinConfiguration.ConfVars;
import org.apache.zeppelin.socket.TestHttpServletRequest;
import org.junit.Assert;
import org.junit.Test;
Expand All @@ -28,6 +30,7 @@
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletResponse;

import java.io.IOException;

import static org.mockito.Matchers.anyString;
Expand All @@ -42,54 +45,64 @@
*/
public class CorsFilterTests {

public static String[] headers = new String[8];
public static Integer count = 0;

@Test
public void ValidCorsFilterTest() throws IOException, ServletException {
CorsFilter filter = new CorsFilter();
HttpServletResponse mockResponse = mock(HttpServletResponse.class);
FilterChain mockedFilterChain = mock(FilterChain.class);
TestHttpServletRequest mockRequest = mock(TestHttpServletRequest.class);
when(mockRequest.getHeader("Origin")).thenReturn("http://localhost:8080");
when(mockRequest.getMethod()).thenReturn("Empty");
when(mockRequest.getServerName()).thenReturn("localhost");


doAnswer(new Answer() {
@Override
public Object answer(InvocationOnMock invocationOnMock) throws Throwable {
headers[count] = invocationOnMock.getArguments()[1].toString();
count++;
return null;
}
}).when(mockResponse).addHeader(anyString(), anyString());

filter.doFilter(mockRequest, mockResponse, mockedFilterChain);
Assert.assertTrue(headers[0].equals("http://localhost:8080"));
}

@Test
public void InvalidCorsFilterTest() throws IOException, ServletException {
CorsFilter filter = new CorsFilter();
HttpServletResponse mockResponse = mock(HttpServletResponse.class);
FilterChain mockedFilterChain = mock(FilterChain.class);
TestHttpServletRequest mockRequest = mock(TestHttpServletRequest.class);
when(mockRequest.getHeader("Origin")).thenReturn("http://evillocalhost:8080");
when(mockRequest.getMethod()).thenReturn("Empty");
when(mockRequest.getServerName()).thenReturn("evillocalhost");


doAnswer(new Answer() {
@Override
public Object answer(InvocationOnMock invocationOnMock) throws Throwable {
headers[count] = invocationOnMock.getArguments()[1].toString();
count++;
return null;
}
}).when(mockResponse).addHeader(anyString(), anyString());

filter.doFilter(mockRequest, mockResponse, mockedFilterChain);
Assert.assertTrue(headers[0].equals(""));
}
public static String[] headers = new String[10];
public static Integer count = 0;

@Test
public void ValidCorsFilterTest() throws IOException, ServletException {
System.setProperty(ConfVars.ZEPPELIN_SERVER_ORIGINS.getVarName(), "localhost");

ZeppelinConfiguration conf = ZeppelinConfiguration.create();
OriginValidator originValidator = new OriginValidator(conf);

CorsFilter filter = new CorsFilter();
HttpServletResponse mockResponse = mock(HttpServletResponse.class);
FilterChain mockedFilterChain = mock(FilterChain.class);
TestHttpServletRequest mockRequest = mock(TestHttpServletRequest.class);
when(mockRequest.getHeader("Origin")).thenReturn("http://localhost:8080");
when(mockRequest.getMethod()).thenReturn("Empty");
when(mockRequest.getServerName()).thenReturn("localhost");
when(mockRequest.getScheme()).thenReturn("http");

doAnswer(new Answer() {
@Override
public Object answer(InvocationOnMock invocationOnMock) throws Throwable {
headers[count] = invocationOnMock.getArguments()[1].toString();
count++;
return null;
}
}).when(mockResponse).addHeader(anyString(), anyString());

filter.doFilter(mockRequest, mockResponse, mockedFilterChain);
Assert.assertEquals("http://localhost:8080", headers[0]);
}

@Test
public void InvalidCorsFilterTest() throws IOException, ServletException {
System.setProperty(ConfVars.ZEPPELIN_SERVER_ORIGINS.getVarName(), "");

ZeppelinConfiguration conf = ZeppelinConfiguration.create();
OriginValidator originValidator = new OriginValidator(conf);

CorsFilter filter = new CorsFilter();
HttpServletResponse mockResponse = mock(HttpServletResponse.class);
FilterChain mockedFilterChain = mock(FilterChain.class);
TestHttpServletRequest mockRequest = mock(TestHttpServletRequest.class);
when(mockRequest.getHeader("Origin")).thenReturn(
"http://evillocalhost:8080");
when(mockRequest.getMethod()).thenReturn("Empty");
when(mockRequest.getServerName()).thenReturn("evillocalhost");

doAnswer(new Answer() {
@Override
public Object answer(InvocationOnMock invocationOnMock) throws Throwable {
headers[count] = invocationOnMock.getArguments()[1].toString();
count++;
return null;
}
}).when(mockResponse).addHeader(anyString(), anyString());

filter.doFilter(mockRequest, mockResponse, mockedFilterChain);
Assert.assertEquals("", headers[0]);
}
}
Loading