From ea54b55bfadf6a1ab777866c2e1d03979dc049d6 Mon Sep 17 00:00:00 2001 From: joelz Date: Wed, 12 Aug 2015 12:16:29 -0700 Subject: [PATCH 1/3] Fixing issue with ZEPPELIN-173: Zeppelin websocket server is vulnerable to Cross-Site WebSocket Hijacking --- .../zeppelin/socket/NotebookServer.java | 151 +++---- .../zeppelin/socket/NotebookServerTests.java | 53 +++ .../socket/TestHttpServletRequest.java | 372 ++++++++++++++++++ .../socket/TestNotebookSocketListener.java | 41 ++ 4 files changed, 516 insertions(+), 101 deletions(-) create mode 100644 zeppelin-server/src/test/java/org/apache/zeppelin/socket/NotebookServerTests.java create mode 100644 zeppelin-server/src/test/java/org/apache/zeppelin/socket/TestHttpServletRequest.java create mode 100644 zeppelin-server/src/test/java/org/apache/zeppelin/socket/TestNotebookSocketListener.java diff --git a/zeppelin-server/src/main/java/org/apache/zeppelin/socket/NotebookServer.java b/zeppelin-server/src/main/java/org/apache/zeppelin/socket/NotebookServer.java index fe0d3912bb8..2f11e8a5e53 100644 --- a/zeppelin-server/src/main/java/org/apache/zeppelin/socket/NotebookServer.java +++ b/zeppelin-server/src/main/java/org/apache/zeppelin/socket/NotebookServer.java @@ -14,19 +14,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package org.apache.zeppelin.socket; - import java.io.IOException; -import java.net.InetSocketAddress; +import java.net.URI; +import java.net.URISyntaxException; +import java.net.UnknownHostException; import java.util.HashMap; import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.Set; - import javax.servlet.http.HttpServletRequest; - import org.apache.zeppelin.display.AngularObject; import org.apache.zeppelin.display.AngularObjectRegistry; import org.apache.zeppelin.display.AngularObjectRegistryListener; @@ -46,34 +44,51 @@ import org.quartz.SchedulerException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; - import com.google.common.base.Strings; import com.google.gson.Gson; - /** * Zeppelin websocket service. * * @author anthonycorbacho */ public class NotebookServer extends WebSocketServlet implements - NotebookSocketListener, JobListenerFactory, AngularObjectRegistryListener { - + NotebookSocketListener, JobListenerFactory, AngularObjectRegistryListener { private static final Logger LOG = LoggerFactory - .getLogger(NotebookServer.class); - + .getLogger(NotebookServer.class); Gson gson = new Gson(); - Map> noteSocketMap = new HashMap>(); - List connectedSockets = new LinkedList(); + final Map> noteSocketMap = new HashMap<>(); + final List connectedSockets = new LinkedList<>(); private Notebook notebook() { return ZeppelinServer.notebook; } + @Override + public boolean checkOrigin(HttpServletRequest request, String origin) { + URI sourceUri = null; + String currentHost = null; + + try { + sourceUri = new URI(origin); + currentHost = java.net.InetAddress.getLocalHost().getHostName(); + } catch (UnknownHostException e) { + e.printStackTrace(); + } + catch (URISyntaxException e) { + e.printStackTrace(); + } + + String sourceHost = sourceUri.getHost(); + if (currentHost.equals(sourceHost) || "localhost".equals(sourceHost)) { + return true; + } + + return false; + } @Override public WebSocket doWebSocketConnect(HttpServletRequest req, String protocol) { return new NotebookSocket(req, protocol, this); } - @Override public void onOpen(NotebookSocket conn) { LOG.info("New connection from {} : {}", conn.getRequest().getRemoteAddr(), @@ -82,7 +97,6 @@ public void onOpen(NotebookSocket conn) { connectedSockets.add(conn); } } - @Override public void onMessage(NotebookSocket conn, String msg) { Notebook notebook = notebook(); @@ -98,7 +112,7 @@ public void onMessage(NotebookSocket conn, String msg) { sendNote(conn, notebook, messagereceived); break; case NEW_NOTE: - createNote(conn, notebook, messagereceived); + createNote(conn, notebook); break; case DEL_NOTE: removeNote(conn, notebook, messagereceived); @@ -141,7 +155,6 @@ public void onMessage(NotebookSocket conn, String msg) { LOG.error("Can't handle message", e); } } - @Override public void onClose(NotebookSocket conn, int code, String reason) { LOG.info("Closed connection to {} : {}. ({}) {}", conn.getRequest() @@ -151,32 +164,26 @@ public void onClose(NotebookSocket conn, int code, String reason) { connectedSockets.remove(conn); } } - private Message deserializeMessage(String msg) { - Message m = gson.fromJson(msg, Message.class); - return m; + return gson.fromJson(msg, Message.class); } - private String serializeMessage(Message m) { return gson.toJson(m); } - private void addConnectionToNote(String noteId, NotebookSocket socket) { synchronized (noteSocketMap) { removeConnectionFromAllNote(socket); // make sure a socket relates only a - // single note. + // single note. List socketList = noteSocketMap.get(noteId); if (socketList == null) { - socketList = new LinkedList(); + socketList = new LinkedList<>(); noteSocketMap.put(noteId, socketList); } - - if (socketList.contains(socket) == false) { + if (!socketList.contains(socket)) { socketList.add(socket); } } } - private void removeConnectionFromNote(String noteId, NotebookSocket socket) { synchronized (noteSocketMap) { List socketList = noteSocketMap.get(noteId); @@ -185,13 +192,11 @@ private void removeConnectionFromNote(String noteId, NotebookSocket socket) { } } } - private void removeNote(String noteId) { synchronized (noteSocketMap) { List socketList = noteSocketMap.remove(noteId); } } - private void removeConnectionFromAllNote(NotebookSocket socket) { synchronized (noteSocketMap) { Set keys = noteSocketMap.keySet(); @@ -200,7 +205,6 @@ private void removeConnectionFromAllNote(NotebookSocket socket) { } } } - private String getOpenNoteId(NotebookSocket socket) { String id = null; synchronized (noteSocketMap) { @@ -214,7 +218,6 @@ private String getOpenNoteId(NotebookSocket socket) { } return id; } - private void broadcastToNoteBindedInterpreter(String interpreterGroupId, Message m) { Notebook notebook = notebook(); @@ -228,16 +231,13 @@ private void broadcastToNoteBindedInterpreter(String interpreterGroupId, } } } - private void broadcast(String noteId, Message m) { synchronized (noteSocketMap) { List socketLists = noteSocketMap.get(noteId); if (socketLists == null || socketLists.size() == 0) { return; } - LOG.info("SEND >> " + m.op); - for (NotebookSocket conn : socketLists) { try { conn.send(serializeMessage(m)); @@ -247,7 +247,6 @@ private void broadcast(String noteId, Message m) { } } } - private void broadcastAll(Message m) { synchronized (connectedSockets) { for (NotebookSocket conn : connectedSockets) { @@ -259,24 +258,21 @@ private void broadcastAll(Message m) { } } } - private void broadcastNote(Note note) { broadcast(note.id(), new Message(OP.NOTE).put("note", note)); } - private void broadcastNoteList() { Notebook notebook = notebook(); List notes = notebook.getAllNotes(); - List> notesInfo = new LinkedList>(); + List> notesInfo = new LinkedList<>(); for (Note note : notes) { - Map info = new HashMap(); + Map info = new HashMap<>(); info.put("id", note.id()); info.put("name", note.getName()); notesInfo.add(info); } broadcastAll(new Message(OP.NOTES_INFO).put("notes", notesInfo)); } - private void sendNote(NotebookSocket conn, Notebook notebook, Message fromMessage) throws IOException { String noteId = (String) fromMessage.get("id"); @@ -284,14 +280,12 @@ private void sendNote(NotebookSocket conn, Notebook notebook, return; } Note note = notebook.getNote(noteId); - if (note != null) { addConnectionToNote(note.id(), conn); conn.send(serializeMessage(new Message(OP.NOTE).put("note", note))); sendAllAngularObjects(note, conn); } } - private void updateNote(WebSocket conn, Notebook notebook, Message fromMessage) throws SchedulerException, IOException { String noteId = (String) fromMessage.get("id"); @@ -309,17 +303,14 @@ private void updateNote(WebSocket conn, Notebook notebook, Message fromMessage) boolean cronUpdated = isCronUpdated(config, note.getConfig()); note.setName(name); note.setConfig(config); - if (cronUpdated) { notebook.refreshCron(note.id()); } note.persist(); - broadcastNote(note); broadcastNoteList(); } } - private boolean isCronUpdated(Map configA, Map configB) { boolean cronUpdated = false; @@ -333,20 +324,13 @@ private boolean isCronUpdated(Map configA, } return cronUpdated; } - - private void createNote(WebSocket conn, Notebook notebook, Message message) throws IOException { + private void createNote(WebSocket conn, Notebook notebook) throws IOException { Note note = notebook.createNote(); note.addParagraph(); // it's an empty note. so add one paragraph - if (message != null) { - String noteName = (String) message.get("name"); - if (noteName != null && !noteName.isEmpty()) - note.setName(noteName); - } note.persist(); broadcastNote(note); broadcastNoteList(); } - private void removeNote(WebSocket conn, Notebook notebook, Message fromMessage) throws IOException { String noteId = (String) fromMessage.get("id"); @@ -358,7 +342,6 @@ private void removeNote(WebSocket conn, Notebook notebook, Message fromMessage) removeNote(noteId); broadcastNoteList(); } - private void updateParagraph(NotebookSocket conn, Notebook notebook, Message fromMessage) throws IOException { String paragraphId = (String) fromMessage.get("id"); @@ -378,7 +361,6 @@ private void updateParagraph(NotebookSocket conn, Notebook notebook, note.persist(); broadcast(note.id(), new Message(OP.PARAGRAPH).put("paragraph", p)); } - private void removeParagraph(NotebookSocket conn, Notebook notebook, Message fromMessage) throws IOException { final String paragraphId = (String) fromMessage.get("id"); @@ -393,31 +375,27 @@ private void removeParagraph(NotebookSocket conn, Notebook notebook, broadcastNote(note); } } - private void completion(NotebookSocket conn, Notebook notebook, Message fromMessage) throws IOException { String paragraphId = (String) fromMessage.get("id"); String buffer = (String) fromMessage.get("buf"); int cursor = (int) Double.parseDouble(fromMessage.get("cursor").toString()); Message resp = new Message(OP.COMPLETION_LIST).put("id", paragraphId); - if (paragraphId == null) { conn.send(serializeMessage(resp)); return; } - final Note note = notebook.getNote(getOpenNoteId(conn)); List candidates = note.completion(paragraphId, buffer, cursor); resp.put("completions", candidates); conn.send(serializeMessage(resp)); } - /** * When angular object updated from client - * - * @param conn - * @param notebook - * @param fromMessage + * + * @param conn the web socket. + * @param notebook the notebook. + * @param fromMessage the message. */ private void angularObjectUpdated(WebSocket conn, Notebook notebook, Message fromMessage) { @@ -425,10 +403,8 @@ private void angularObjectUpdated(WebSocket conn, Notebook notebook, String interpreterGroupId = (String) fromMessage.get("interpreterGroupId"); String varName = (String) fromMessage.get("name"); Object varValue = fromMessage.get("value"); - AngularObject ao = null; boolean global = false; - // propagate change to (Remote) AngularObjectRegistry Note note = notebook.getNote(noteId); if (note != null) { @@ -438,11 +414,9 @@ private void angularObjectUpdated(WebSocket conn, Notebook notebook, if (setting.getInterpreterGroup() == null) { continue; } - if (interpreterGroupId.equals(setting.getInterpreterGroup().getId())) { AngularObjectRegistry angularObjectRegistry = setting .getInterpreterGroup().getAngularObjectRegistry(); - // first trying to get local registry ao = angularObjectRegistry.get(varName, noteId); if (ao == null) { @@ -460,14 +434,12 @@ private void angularObjectUpdated(WebSocket conn, Notebook notebook, ao.set(varValue, false); global = false; } - break; } } } - if (global) { // broadcast change to all web session that uses related - // interpreter. + // interpreter. for (Note n : notebook.getAllNotes()) { List settings = note.getNoteReplLoader() .getInterpreterSettings(); @@ -475,7 +447,6 @@ private void angularObjectUpdated(WebSocket conn, Notebook notebook, if (setting.getInterpreterGroup() == null) { continue; } - if (interpreterGroupId.equals(setting.getInterpreterGroup().getId())) { AngularObjectRegistry angularObjectRegistry = setting .getInterpreterGroup().getAngularObjectRegistry(); @@ -495,14 +466,12 @@ private void angularObjectUpdated(WebSocket conn, Notebook notebook, .put("noteId", note.id())); } } - private void moveParagraph(NotebookSocket conn, Notebook notebook, Message fromMessage) throws IOException { final String paragraphId = (String) fromMessage.get("id"); if (paragraphId == null) { return; } - final int newIndex = (int) Double.parseDouble(fromMessage.get("index") .toString()); final Note note = notebook.getNote(getOpenNoteId(conn)); @@ -510,30 +479,25 @@ private void moveParagraph(NotebookSocket conn, Notebook notebook, note.persist(); broadcastNote(note); } - private void insertParagraph(NotebookSocket conn, Notebook notebook, Message fromMessage) throws IOException { final int index = (int) Double.parseDouble(fromMessage.get("index") - .toString()); - + .toString()); final Note note = notebook.getNote(getOpenNoteId(conn)); note.insertParagraph(index); note.persist(); broadcastNote(note); } - private void cancelParagraph(NotebookSocket conn, Notebook notebook, Message fromMessage) throws IOException { final String paragraphId = (String) fromMessage.get("id"); if (paragraphId == null) { return; } - final Note note = notebook.getNote(getOpenNoteId(conn)); Paragraph p = note.getParagraph(paragraphId); p.abort(); } - private void runParagraph(NotebookSocket conn, Notebook notebook, Message fromMessage) throws IOException { final String paragraphId = (String) fromMessage.get("id"); @@ -546,12 +510,11 @@ private void runParagraph(NotebookSocket conn, Notebook notebook, p.setText(text); p.setTitle((String) fromMessage.get("title")); Map params = (Map) fromMessage - .get("params"); + .get("params"); p.settings.setParams(params); Map config = (Map) fromMessage - .get("config"); + .get("config"); p.setConfig(config); - // if it's the last paragraph, let's add a new one boolean isTheLastParagraph = note.getLastParagraph().getId() .equals(p.getId()); @@ -560,7 +523,6 @@ private void runParagraph(NotebookSocket conn, Notebook notebook, } note.persist(); broadcastNote(note); - try { note.run(paragraphId); } catch (Exception ex) { @@ -573,7 +535,6 @@ private void runParagraph(NotebookSocket conn, Notebook notebook, } } } - /** * Need description here. * @@ -581,12 +542,10 @@ private void runParagraph(NotebookSocket conn, Notebook notebook, public static class ParagraphJobListener implements JobListener { private NotebookServer notebookServer; private Note note; - public ParagraphJobListener(NotebookServer notebookServer, Note note) { this.notebookServer = notebookServer; this.note = note; } - @Override public void onProgressUpdate(Job job, int progress) { notebookServer.broadcast( @@ -594,11 +553,9 @@ public void onProgressUpdate(Job job, int progress) { new Message(OP.PROGRESS).put("id", job.getId()).put("progress", job.progress())); } - @Override public void beforeStatusChange(Job job, Status before, Status after) { } - @Override public void afterStatusChange(Job job, Status before, Status after) { if (after == Status.ERROR) { @@ -617,22 +574,18 @@ public void afterStatusChange(Job job, Status before, Status after) { notebookServer.broadcastNote(note); } } - @Override public JobListener getParagraphJobListener(Note note) { return new ParagraphJobListener(this, note); } - private void pong() { } - private void sendAllAngularObjects(Note note, NotebookSocket conn) throws IOException { List settings = note.getNoteReplLoader() .getInterpreterSettings(); if (settings == null || settings.size() == 0) { return; } - for (InterpreterSetting intpSetting : settings) { AngularObjectRegistry registry = intpSetting.getInterpreterGroup() .getAngularObjectRegistry(); @@ -646,31 +599,25 @@ private void sendAllAngularObjects(Note note, NotebookSocket conn) throws IOExce } } } - @Override public void onAdd(String interpreterGroupId, AngularObject object) { onUpdate(interpreterGroupId, object); } - @Override public void onUpdate(String interpreterGroupId, AngularObject object) { Notebook notebook = notebook(); if (notebook == null) { return; } - List notes = notebook.getAllNotes(); for (Note note : notes) { if (object.getNoteId() != null && !note.id().equals(object.getNoteId())) { continue; } - List intpSettings = note.getNoteReplLoader() .getInterpreterSettings(); - if (intpSettings.isEmpty()) continue; - for (InterpreterSetting setting : intpSettings) { if (setting.getInterpreterGroup().getId().equals(interpreterGroupId)) { broadcast( @@ -683,7 +630,6 @@ public void onUpdate(String interpreterGroupId, AngularObject object) { } } } - @Override public void onRemove(String interpreterGroupId, String name, String noteId) { Notebook notebook = notebook(); @@ -692,16 +638,19 @@ public void onRemove(String interpreterGroupId, String name, String noteId) { if (noteId != null && !note.id().equals(noteId)) { continue; } - List ids = note.getNoteReplLoader().getInterpreters(); for (String id : ids) { if (id.equals(interpreterGroupId)) { broadcast( note.id(), new Message(OP.ANGULAR_OBJECT_REMOVE).put("name", name).put( - "noteId", noteId)); + "noteId", noteId)); } } } } + private String getOrigin(NotebookSocket conn) { + return conn.getRequest().getHeader("Origin"); + } } + diff --git a/zeppelin-server/src/test/java/org/apache/zeppelin/socket/NotebookServerTests.java b/zeppelin-server/src/test/java/org/apache/zeppelin/socket/NotebookServerTests.java new file mode 100644 index 00000000000..3ab06f0e2c5 --- /dev/null +++ b/zeppelin-server/src/test/java/org/apache/zeppelin/socket/NotebookServerTests.java @@ -0,0 +1,53 @@ +/** + * Created by joelz on 8/6/15. + * + * + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.zeppelin.socket; + +import org.apache.zeppelin.notebook.Note; +import org.apache.zeppelin.server.ZeppelinServer; +import org.junit.Assert; +import org.junit.FixMethodOrder; +import org.junit.Test; +import org.junit.runners.MethodSorters; + +import java.io.IOException; +import java.net.UnknownHostException; + +/** + * BASIC Zeppelin rest api tests + * + * + * @author joelz + * + */ + public class NotebookServerTests { + + @Test + public void CheckOrigin() throws UnknownHostException { + NotebookServer server = new NotebookServer(); + Assert.assertTrue(server.checkOrigin(new TestHttpServletRequest(), + "http://" + java.net.InetAddress.getLocalHost().getHostName() + ":8080")); + } + + @Test + public void CheckInvalidOrigin(){ + NotebookServer server = new NotebookServer(); + Assert.assertFalse(server.checkOrigin(new TestHttpServletRequest(), "http://evillocalhost:8080")); + } +} diff --git a/zeppelin-server/src/test/java/org/apache/zeppelin/socket/TestHttpServletRequest.java b/zeppelin-server/src/test/java/org/apache/zeppelin/socket/TestHttpServletRequest.java new file mode 100644 index 00000000000..9ec54baa95a --- /dev/null +++ b/zeppelin-server/src/test/java/org/apache/zeppelin/socket/TestHttpServletRequest.java @@ -0,0 +1,372 @@ +/** + * Created by joelz on 8/6/15. + * + * + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.zeppelin.socket; + +import javax.servlet.*; +import javax.servlet.http.*; +import java.io.BufferedReader; +import java.io.IOException; +import java.io.UnsupportedEncodingException; +import java.security.Principal; +import java.util.Collection; +import java.util.Enumeration; +import java.util.Locale; +import java.util.Map; + +/** + * Created by joelz on 8/6/15. + * Helps mocking a http servlet request + */ +public class TestHttpServletRequest implements HttpServletRequest { + @Override + public boolean authenticate(HttpServletResponse httpServletResponse) throws IOException, ServletException { + return false; + } + + @Override + public String getAuthType() { + return null; + } + + @Override + public String getContextPath() { + return null; + } + + @Override + public Cookie[] getCookies() { + return new Cookie[0]; + } + + @Override + public long getDateHeader(String s) { + return 0; + } + + @Override + public String getHeader(String s) { + switch (s) { + case "Origin": + return "http://localhost:8080"; + } + + return null; + } + + @Override + public Enumeration getHeaderNames() { + return null; + } + + @Override + public Enumeration getHeaders(String s) { + return null; + } + + @Override + public int getIntHeader(String s) { + return 0; + } + + @Override + public String getMethod() { + return null; + } + + @Override + public Part getPart(String s) throws IOException, ServletException { + return null; + } + + @Override + public Collection getParts() throws IOException, ServletException { + return null; + } + + @Override + public String getPathInfo() { + return null; + } + + @Override + public String getPathTranslated() { + return null; + } + + @Override + public String getQueryString() { + return null; + } + + @Override + public String getRemoteUser() { + return null; + } + + @Override + public String getRequestedSessionId() { + return null; + } + + @Override + public String getRequestURI() { + return null; + } + + @Override + public StringBuffer getRequestURL() { + return null; + } + + @Override + public String getServletPath() { + return null; + } + + @Override + public HttpSession getSession() { + return null; + } + + @Override + public HttpSession getSession(boolean b) { + return null; + } + + @Override + public Principal getUserPrincipal() { + return null; + } + + @Override + public boolean isRequestedSessionIdFromCookie() { + return false; + } + + @Override + public boolean isRequestedSessionIdFromUrl() { + return false; + } + + @Override + public boolean isRequestedSessionIdFromURL() { + return false; + } + + @Override + public boolean isRequestedSessionIdValid() { + return false; + } + + @Override + public boolean isUserInRole(String s) { + return false; + } + + @Override + public void login(String s, String s1) throws ServletException { + + } + + @Override + public void logout() throws ServletException { + + } + + @Override + public AsyncContext getAsyncContext() { + return null; + } + + @Override + public Object getAttribute(String s) { + return null; + } + + @Override + public Enumeration getAttributeNames() { + return null; + } + + @Override + public String getCharacterEncoding() { + return null; + } + + @Override + public int getContentLength() { + return 0; + } + + @Override + public String getContentType() { + return null; + } + + @Override + public DispatcherType getDispatcherType() { + return null; + } + + @Override + public ServletInputStream getInputStream() throws IOException { + return null; + } + + @Override + public String getLocalAddr() { + return null; + } + + @Override + public Locale getLocale() { + return null; + } + + @Override + public Enumeration getLocales() { + return null; + } + + @Override + public String getLocalName() { + return null; + } + + @Override + public int getLocalPort() { + return 0; + } + + @Override + public String getParameter(String s) { + return null; + } + + @Override + public Map getParameterMap() { + return null; + } + + @Override + public Enumeration getParameterNames() { + return null; + } + + @Override + public String[] getParameterValues(String s) { + return new String[0]; + } + + @Override + public String getProtocol() { + return null; + } + + @Override + public BufferedReader getReader() throws IOException { + return null; + } + + @Override + public String getRealPath(String s) { + return null; + } + + @Override + public String getRemoteAddr() { + return null; + } + + @Override + public String getRemoteHost() { + return null; + } + + @Override + public int getRemotePort() { + return 0; + } + + @Override + public RequestDispatcher getRequestDispatcher(String s) { + return null; + } + + @Override + public String getScheme() { + return null; + } + + @Override + public String getServerName() { + return null; + } + + @Override + public int getServerPort() { + return 0; + } + + @Override + public ServletContext getServletContext() { + return null; + } + + @Override + public boolean isAsyncStarted() { + return false; + } + + @Override + public boolean isAsyncSupported() { + return false; + } + + @Override + public boolean isSecure() { + return false; + } + + @Override + public void removeAttribute(String s) { + + } + + @Override + public void setAttribute(String s, Object o) { + + } + + @Override + public void setCharacterEncoding(String s) throws UnsupportedEncodingException { + + } + + @Override + public AsyncContext startAsync() { + return null; + } + + @Override + public AsyncContext startAsync(ServletRequest servletRequest, ServletResponse servletResponse) { + return null; + } +} diff --git a/zeppelin-server/src/test/java/org/apache/zeppelin/socket/TestNotebookSocketListener.java b/zeppelin-server/src/test/java/org/apache/zeppelin/socket/TestNotebookSocketListener.java new file mode 100644 index 00000000000..8d0637cdbce --- /dev/null +++ b/zeppelin-server/src/test/java/org/apache/zeppelin/socket/TestNotebookSocketListener.java @@ -0,0 +1,41 @@ +/** + * Created by joelz on 8/6/15. + * + * + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.zeppelin.socket; + +/** + * Created by joelz on 8/6/15. + * This enables mocking a socket listener. + */ +public class TestNotebookSocketListener implements NotebookSocketListener { + @Override + public void onClose(NotebookSocket socket, int code, String message) { + + } + + @Override + public void onOpen(NotebookSocket socket) { + + } + + @Override + public void onMessage(NotebookSocket socket, String message) { + + } +} From 013f22da77e44ee9285a6997d8c4976ca178be98 Mon Sep 17 00:00:00 2001 From: joelz Date: Thu, 13 Aug 2015 11:21:31 -0700 Subject: [PATCH 2/3] Fixing issue with ZEPPELIN-173: Zeppelin websocket server is vulnerable to Cross-Site WebSocket Hijacking --- .../zeppelin/socket/NotebookServer.java | 71 +++++++++++++++++-- 1 file changed, 66 insertions(+), 5 deletions(-) diff --git a/zeppelin-server/src/main/java/org/apache/zeppelin/socket/NotebookServer.java b/zeppelin-server/src/main/java/org/apache/zeppelin/socket/NotebookServer.java index 2f11e8a5e53..8c8b60089a2 100644 --- a/zeppelin-server/src/main/java/org/apache/zeppelin/socket/NotebookServer.java +++ b/zeppelin-server/src/main/java/org/apache/zeppelin/socket/NotebookServer.java @@ -89,6 +89,7 @@ public boolean checkOrigin(HttpServletRequest request, String origin) { public WebSocket doWebSocketConnect(HttpServletRequest req, String protocol) { return new NotebookSocket(req, protocol, this); } + @Override public void onOpen(NotebookSocket conn) { LOG.info("New connection from {} : {}", conn.getRequest().getRemoteAddr(), @@ -97,6 +98,7 @@ public void onOpen(NotebookSocket conn) { connectedSockets.add(conn); } } + @Override public void onMessage(NotebookSocket conn, String msg) { Notebook notebook = notebook(); @@ -112,7 +114,7 @@ public void onMessage(NotebookSocket conn, String msg) { sendNote(conn, notebook, messagereceived); break; case NEW_NOTE: - createNote(conn, notebook); + createNote(conn, notebook, messagereceived); break; case DEL_NOTE: removeNote(conn, notebook, messagereceived); @@ -155,6 +157,7 @@ public void onMessage(NotebookSocket conn, String msg) { LOG.error("Can't handle message", e); } } + @Override public void onClose(NotebookSocket conn, int code, String reason) { LOG.info("Closed connection to {} : {}. ({}) {}", conn.getRequest() @@ -164,12 +167,15 @@ public void onClose(NotebookSocket conn, int code, String reason) { connectedSockets.remove(conn); } } + private Message deserializeMessage(String msg) { return gson.fromJson(msg, Message.class); } + private String serializeMessage(Message m) { return gson.toJson(m); } + private void addConnectionToNote(String noteId, NotebookSocket socket) { synchronized (noteSocketMap) { removeConnectionFromAllNote(socket); // make sure a socket relates only a @@ -184,6 +190,7 @@ private void addConnectionToNote(String noteId, NotebookSocket socket) { } } } + private void removeConnectionFromNote(String noteId, NotebookSocket socket) { synchronized (noteSocketMap) { List socketList = noteSocketMap.get(noteId); @@ -192,11 +199,13 @@ private void removeConnectionFromNote(String noteId, NotebookSocket socket) { } } } + private void removeNote(String noteId) { synchronized (noteSocketMap) { List socketList = noteSocketMap.remove(noteId); } } + private void removeConnectionFromAllNote(NotebookSocket socket) { synchronized (noteSocketMap) { Set keys = noteSocketMap.keySet(); @@ -205,6 +214,7 @@ private void removeConnectionFromAllNote(NotebookSocket socket) { } } } + private String getOpenNoteId(NotebookSocket socket) { String id = null; synchronized (noteSocketMap) { @@ -216,8 +226,10 @@ private String getOpenNoteId(NotebookSocket socket) { } } } + return id; } + private void broadcastToNoteBindedInterpreter(String interpreterGroupId, Message m) { Notebook notebook = notebook(); @@ -231,6 +243,7 @@ private void broadcastToNoteBindedInterpreter(String interpreterGroupId, } } } + private void broadcast(String noteId, Message m) { synchronized (noteSocketMap) { List socketLists = noteSocketMap.get(noteId); @@ -247,6 +260,7 @@ private void broadcast(String noteId, Message m) { } } } + private void broadcastAll(Message m) { synchronized (connectedSockets) { for (NotebookSocket conn : connectedSockets) { @@ -258,9 +272,11 @@ private void broadcastAll(Message m) { } } } + private void broadcastNote(Note note) { broadcast(note.id(), new Message(OP.NOTE).put("note", note)); } + private void broadcastNoteList() { Notebook notebook = notebook(); List notes = notebook.getAllNotes(); @@ -271,14 +287,17 @@ private void broadcastNoteList() { info.put("name", note.getName()); notesInfo.add(info); } + broadcastAll(new Message(OP.NOTES_INFO).put("notes", notesInfo)); } + private void sendNote(NotebookSocket conn, Notebook notebook, Message fromMessage) throws IOException { String noteId = (String) fromMessage.get("id"); if (noteId == null) { return; } + Note note = notebook.getNote(noteId); if (note != null) { addConnectionToNote(note.id(), conn); @@ -286,6 +305,7 @@ private void sendNote(NotebookSocket conn, Notebook notebook, sendAllAngularObjects(note, conn); } } + private void updateNote(WebSocket conn, Notebook notebook, Message fromMessage) throws SchedulerException, IOException { String noteId = (String) fromMessage.get("id"); @@ -298,6 +318,7 @@ private void updateNote(WebSocket conn, Notebook notebook, Message fromMessage) if (config == null) { return; } + Note note = notebook.getNote(noteId); if (note != null) { boolean cronUpdated = isCronUpdated(config, note.getConfig()); @@ -306,11 +327,13 @@ private void updateNote(WebSocket conn, Notebook notebook, Message fromMessage) if (cronUpdated) { notebook.refreshCron(note.id()); } + note.persist(); broadcastNote(note); broadcastNoteList(); } } + private boolean isCronUpdated(Map configA, Map configB) { boolean cronUpdated = false; @@ -322,32 +345,44 @@ private boolean isCronUpdated(Map configA, } else if (configA.get("cron") != null || configB.get("cron") != null) { cronUpdated = true; } + return cronUpdated; } - private void createNote(WebSocket conn, Notebook notebook) throws IOException { + private void createNote(WebSocket conn, Notebook notebook, Message message) throws IOException { Note note = notebook.createNote(); note.addParagraph(); // it's an empty note. so add one paragraph + if (message != null) { + String noteName = (String) message.get("name"); + if (noteName != null && !noteName.isEmpty()){ + note.setName(noteName); + } + } + note.persist(); broadcastNote(note); broadcastNoteList(); } + private void removeNote(WebSocket conn, Notebook notebook, Message fromMessage) throws IOException { String noteId = (String) fromMessage.get("id"); if (noteId == null) { return; } + Note note = notebook.getNote(noteId); notebook.removeNote(noteId); removeNote(noteId); broadcastNoteList(); } + private void updateParagraph(NotebookSocket conn, Notebook notebook, Message fromMessage) throws IOException { String paragraphId = (String) fromMessage.get("id"); if (paragraphId == null) { return; } + Map params = (Map) fromMessage .get("params"); Map config = (Map) fromMessage @@ -361,12 +396,14 @@ private void updateParagraph(NotebookSocket conn, Notebook notebook, note.persist(); broadcast(note.id(), new Message(OP.PARAGRAPH).put("paragraph", p)); } + private void removeParagraph(NotebookSocket conn, Notebook notebook, Message fromMessage) throws IOException { final String paragraphId = (String) fromMessage.get("id"); if (paragraphId == null) { return; } + final Note note = notebook.getNote(getOpenNoteId(conn)); /** We dont want to remove the last paragraph */ if (!note.isLastParagraph(paragraphId)) { @@ -375,6 +412,7 @@ private void removeParagraph(NotebookSocket conn, Notebook notebook, broadcastNote(note); } } + private void completion(NotebookSocket conn, Notebook notebook, Message fromMessage) throws IOException { String paragraphId = (String) fromMessage.get("id"); @@ -385,11 +423,13 @@ private void completion(NotebookSocket conn, Notebook notebook, conn.send(serializeMessage(resp)); return; } + final Note note = notebook.getNote(getOpenNoteId(conn)); List candidates = note.completion(paragraphId, buffer, cursor); resp.put("completions", candidates); conn.send(serializeMessage(resp)); } + /** * When angular object updated from client * @@ -438,6 +478,7 @@ private void angularObjectUpdated(WebSocket conn, Notebook notebook, } } } + if (global) { // broadcast change to all web session that uses related // interpreter. for (Note n : notebook.getAllNotes()) { @@ -466,12 +507,14 @@ private void angularObjectUpdated(WebSocket conn, Notebook notebook, .put("noteId", note.id())); } } + private void moveParagraph(NotebookSocket conn, Notebook notebook, Message fromMessage) throws IOException { final String paragraphId = (String) fromMessage.get("id"); if (paragraphId == null) { return; } + final int newIndex = (int) Double.parseDouble(fromMessage.get("index") .toString()); final Note note = notebook.getNote(getOpenNoteId(conn)); @@ -479,6 +522,7 @@ private void moveParagraph(NotebookSocket conn, Notebook notebook, note.persist(); broadcastNote(note); } + private void insertParagraph(NotebookSocket conn, Notebook notebook, Message fromMessage) throws IOException { final int index = (int) Double.parseDouble(fromMessage.get("index") @@ -488,22 +532,26 @@ private void insertParagraph(NotebookSocket conn, Notebook notebook, note.persist(); broadcastNote(note); } + private void cancelParagraph(NotebookSocket conn, Notebook notebook, Message fromMessage) throws IOException { final String paragraphId = (String) fromMessage.get("id"); if (paragraphId == null) { return; } + final Note note = notebook.getNote(getOpenNoteId(conn)); Paragraph p = note.getParagraph(paragraphId); p.abort(); } + private void runParagraph(NotebookSocket conn, Notebook notebook, Message fromMessage) throws IOException { final String paragraphId = (String) fromMessage.get("id"); if (paragraphId == null) { return; } + final Note note = notebook.getNote(getOpenNoteId(conn)); Paragraph p = note.getParagraph(paragraphId); String text = (String) fromMessage.get("paragraph"); @@ -521,6 +569,7 @@ private void runParagraph(NotebookSocket conn, Notebook notebook, if (!Strings.isNullOrEmpty(text) && isTheLastParagraph) { note.addParagraph(); } + note.persist(); broadcastNote(note); try { @@ -535,6 +584,7 @@ private void runParagraph(NotebookSocket conn, Notebook notebook, } } } + /** * Need description here. * @@ -546,6 +596,7 @@ public ParagraphJobListener(NotebookServer notebookServer, Note note) { this.notebookServer = notebookServer; this.note = note; } + @Override public void onProgressUpdate(Job job, int progress) { notebookServer.broadcast( @@ -553,9 +604,11 @@ public void onProgressUpdate(Job job, int progress) { new Message(OP.PROGRESS).put("id", job.getId()).put("progress", job.progress())); } + @Override public void beforeStatusChange(Job job, Status before, Status after) { } + @Override public void afterStatusChange(Job job, Status before, Status after) { if (after == Status.ERROR) { @@ -563,6 +616,7 @@ public void afterStatusChange(Job job, Status before, Status after) { LOG.error("Error", job.getException()); } } + if (job.isTerminated()) { LOG.info("Job {} is finished", job.getId()); try { @@ -571,21 +625,25 @@ public void afterStatusChange(Job job, Status before, Status after) { e.printStackTrace(); } } + notebookServer.broadcastNote(note); } } + @Override public JobListener getParagraphJobListener(Note note) { return new ParagraphJobListener(this, note); } private void pong() { } + private void sendAllAngularObjects(Note note, NotebookSocket conn) throws IOException { List settings = note.getNoteReplLoader() .getInterpreterSettings(); if (settings == null || settings.size() == 0) { return; } + for (InterpreterSetting intpSetting : settings) { AngularObjectRegistry registry = intpSetting.getInterpreterGroup() .getAngularObjectRegistry(); @@ -599,21 +657,25 @@ private void sendAllAngularObjects(Note note, NotebookSocket conn) throws IOExce } } } + @Override public void onAdd(String interpreterGroupId, AngularObject object) { onUpdate(interpreterGroupId, object); } + @Override public void onUpdate(String interpreterGroupId, AngularObject object) { Notebook notebook = notebook(); if (notebook == null) { return; } + List notes = notebook.getAllNotes(); for (Note note : notes) { if (object.getNoteId() != null && !note.id().equals(object.getNoteId())) { continue; } + List intpSettings = note.getNoteReplLoader() .getInterpreterSettings(); if (intpSettings.isEmpty()) @@ -630,6 +692,7 @@ public void onUpdate(String interpreterGroupId, AngularObject object) { } } } + @Override public void onRemove(String interpreterGroupId, String name, String noteId) { Notebook notebook = notebook(); @@ -638,6 +701,7 @@ public void onRemove(String interpreterGroupId, String name, String noteId) { if (noteId != null && !note.id().equals(noteId)) { continue; } + List ids = note.getNoteReplLoader().getInterpreters(); for (String id : ids) { if (id.equals(interpreterGroupId)) { @@ -649,8 +713,5 @@ public void onRemove(String interpreterGroupId, String name, String noteId) { } } } - private String getOrigin(NotebookSocket conn) { - return conn.getRequest().getHeader("Origin"); - } } From 08ff36956f9c641949bfe2ee9c982a536863555e Mon Sep 17 00:00:00 2001 From: djoelz Date: Thu, 13 Aug 2015 11:31:15 -0700 Subject: [PATCH 3/3] unecessary file --- .../socket/TestNotebookSocketListener.java | 41 ------------------- 1 file changed, 41 deletions(-) delete mode 100644 zeppelin-server/src/test/java/org/apache/zeppelin/socket/TestNotebookSocketListener.java diff --git a/zeppelin-server/src/test/java/org/apache/zeppelin/socket/TestNotebookSocketListener.java b/zeppelin-server/src/test/java/org/apache/zeppelin/socket/TestNotebookSocketListener.java deleted file mode 100644 index 8d0637cdbce..00000000000 --- a/zeppelin-server/src/test/java/org/apache/zeppelin/socket/TestNotebookSocketListener.java +++ /dev/null @@ -1,41 +0,0 @@ -/** - * Created by joelz on 8/6/15. - * - * - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.zeppelin.socket; - -/** - * Created by joelz on 8/6/15. - * This enables mocking a socket listener. - */ -public class TestNotebookSocketListener implements NotebookSocketListener { - @Override - public void onClose(NotebookSocket socket, int code, String message) { - - } - - @Override - public void onOpen(NotebookSocket socket) { - - } - - @Override - public void onMessage(NotebookSocket socket, String message) { - - } -}