Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,17 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.zeppelin.socket;

import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.URI;
import java.net.URISyntaxException;
import java.net.UnknownHostException;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;

import javax.servlet.http.HttpServletRequest;

import org.apache.zeppelin.display.AngularObject;
import org.apache.zeppelin.display.AngularObjectRegistry;
import org.apache.zeppelin.display.AngularObjectRegistryListener;
Expand All @@ -46,28 +44,46 @@
import org.quartz.SchedulerException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.google.common.base.Strings;
import com.google.gson.Gson;

/**
* Zeppelin websocket service.
*
* @author anthonycorbacho
*/
public class NotebookServer extends WebSocketServlet implements
NotebookSocketListener, JobListenerFactory, AngularObjectRegistryListener {

NotebookSocketListener, JobListenerFactory, AngularObjectRegistryListener {
private static final Logger LOG = LoggerFactory
.getLogger(NotebookServer.class);

.getLogger(NotebookServer.class);
Gson gson = new Gson();
Map<String, List<NotebookSocket>> noteSocketMap = new HashMap<String, List<NotebookSocket>>();
List<NotebookSocket> connectedSockets = new LinkedList<NotebookSocket>();
final Map<String, List<NotebookSocket>> noteSocketMap = new HashMap<>();
final List<NotebookSocket> connectedSockets = new LinkedList<>();

private Notebook notebook() {
return ZeppelinServer.notebook;
}
@Override
public boolean checkOrigin(HttpServletRequest request, String origin) {
URI sourceUri = null;
String currentHost = null;

try {
sourceUri = new URI(origin);
currentHost = java.net.InetAddress.getLocalHost().getHostName();
} catch (UnknownHostException e) {
e.printStackTrace();
}
catch (URISyntaxException e) {
e.printStackTrace();
}

String sourceHost = sourceUri.getHost();
if (currentHost.equals(sourceHost) || "localhost".equals(sourceHost)) {
return true;
}

return false;
}

@Override
public WebSocket doWebSocketConnect(HttpServletRequest req, String protocol) {
Expand Down Expand Up @@ -153,8 +169,7 @@ public void onClose(NotebookSocket conn, int code, String reason) {
}

private Message deserializeMessage(String msg) {
Message m = gson.fromJson(msg, Message.class);
return m;
return gson.fromJson(msg, Message.class);
}

private String serializeMessage(Message m) {
Expand All @@ -164,14 +179,13 @@ private String serializeMessage(Message m) {
private void addConnectionToNote(String noteId, NotebookSocket socket) {
synchronized (noteSocketMap) {
removeConnectionFromAllNote(socket); // make sure a socket relates only a
// single note.
// single note.
List<NotebookSocket> socketList = noteSocketMap.get(noteId);
if (socketList == null) {
socketList = new LinkedList<NotebookSocket>();
socketList = new LinkedList<>();
noteSocketMap.put(noteId, socketList);
}

if (socketList.contains(socket) == false) {
if (!socketList.contains(socket)) {
socketList.add(socket);
}
}
Expand Down Expand Up @@ -212,6 +226,7 @@ private String getOpenNoteId(NotebookSocket socket) {
}
}
}

return id;
}

Expand All @@ -235,9 +250,7 @@ private void broadcast(String noteId, Message m) {
if (socketLists == null || socketLists.size() == 0) {
return;
}

LOG.info("SEND >> " + m.op);

for (NotebookSocket conn : socketLists) {
try {
conn.send(serializeMessage(m));
Expand Down Expand Up @@ -267,13 +280,14 @@ private void broadcastNote(Note note) {
private void broadcastNoteList() {
Notebook notebook = notebook();
List<Note> notes = notebook.getAllNotes();
List<Map<String, String>> notesInfo = new LinkedList<Map<String, String>>();
List<Map<String, String>> notesInfo = new LinkedList<>();
for (Note note : notes) {
Map<String, String> info = new HashMap<String, String>();
Map<String, String> info = new HashMap<>();
info.put("id", note.id());
info.put("name", note.getName());
notesInfo.add(info);
}

broadcastAll(new Message(OP.NOTES_INFO).put("notes", notesInfo));
}

Expand All @@ -283,8 +297,8 @@ private void sendNote(NotebookSocket conn, Notebook notebook,
if (noteId == null) {
return;
}
Note note = notebook.getNote(noteId);

Note note = notebook.getNote(noteId);
if (note != null) {
addConnectionToNote(note.id(), conn);
conn.send(serializeMessage(new Message(OP.NOTE).put("note", note)));
Expand All @@ -304,17 +318,17 @@ private void updateNote(WebSocket conn, Notebook notebook, Message fromMessage)
if (config == null) {
return;
}

Note note = notebook.getNote(noteId);
if (note != null) {
boolean cronUpdated = isCronUpdated(config, note.getConfig());
note.setName(name);
note.setConfig(config);

if (cronUpdated) {
notebook.refreshCron(note.id());
}
note.persist();

note.persist();
broadcastNote(note);
broadcastNoteList();
}
Expand All @@ -331,17 +345,19 @@ private boolean isCronUpdated(Map<String, Object> configA,
} else if (configA.get("cron") != null || configB.get("cron") != null) {
cronUpdated = true;
}

return cronUpdated;
}

private void createNote(WebSocket conn, Notebook notebook, Message message) throws IOException {
Note note = notebook.createNote();
note.addParagraph(); // it's an empty note. so add one paragraph
if (message != null) {
String noteName = (String) message.get("name");
if (noteName != null && !noteName.isEmpty())
if (noteName != null && !noteName.isEmpty()){
note.setName(noteName);
}
}

note.persist();
broadcastNote(note);
broadcastNoteList();
Expand All @@ -353,6 +369,7 @@ private void removeNote(WebSocket conn, Notebook notebook, Message fromMessage)
if (noteId == null) {
return;
}

Note note = notebook.getNote(noteId);
notebook.removeNote(noteId);
removeNote(noteId);
Expand All @@ -365,6 +382,7 @@ private void updateParagraph(NotebookSocket conn, Notebook notebook,
if (paragraphId == null) {
return;
}

Map<String, Object> params = (Map<String, Object>) fromMessage
.get("params");
Map<String, Object> config = (Map<String, Object>) fromMessage
Expand All @@ -385,6 +403,7 @@ private void removeParagraph(NotebookSocket conn, Notebook notebook,
if (paragraphId == null) {
return;
}

final Note note = notebook.getNote(getOpenNoteId(conn));
/** We dont want to remove the last paragraph */
if (!note.isLastParagraph(paragraphId)) {
Expand All @@ -400,7 +419,6 @@ private void completion(NotebookSocket conn, Notebook notebook,
String buffer = (String) fromMessage.get("buf");
int cursor = (int) Double.parseDouble(fromMessage.get("cursor").toString());
Message resp = new Message(OP.COMPLETION_LIST).put("id", paragraphId);

if (paragraphId == null) {
conn.send(serializeMessage(resp));
return;
Expand All @@ -414,21 +432,19 @@ private void completion(NotebookSocket conn, Notebook notebook,

/**
* When angular object updated from client
*
* @param conn
* @param notebook
* @param fromMessage
*
* @param conn the web socket.
* @param notebook the notebook.
* @param fromMessage the message.
*/
private void angularObjectUpdated(WebSocket conn, Notebook notebook,
Message fromMessage) {
String noteId = (String) fromMessage.get("noteId");
String interpreterGroupId = (String) fromMessage.get("interpreterGroupId");
String varName = (String) fromMessage.get("name");
Object varValue = fromMessage.get("value");

AngularObject ao = null;
boolean global = false;

// propagate change to (Remote) AngularObjectRegistry
Note note = notebook.getNote(noteId);
if (note != null) {
Expand All @@ -438,11 +454,9 @@ private void angularObjectUpdated(WebSocket conn, Notebook notebook,
if (setting.getInterpreterGroup() == null) {
continue;
}

if (interpreterGroupId.equals(setting.getInterpreterGroup().getId())) {
AngularObjectRegistry angularObjectRegistry = setting
.getInterpreterGroup().getAngularObjectRegistry();

// first trying to get local registry
ao = angularObjectRegistry.get(varName, noteId);
if (ao == null) {
Expand All @@ -460,22 +474,20 @@ private void angularObjectUpdated(WebSocket conn, Notebook notebook,
ao.set(varValue, false);
global = false;
}

break;
}
}
}

if (global) { // broadcast change to all web session that uses related
// interpreter.
// interpreter.
for (Note n : notebook.getAllNotes()) {
List<InterpreterSetting> settings = note.getNoteReplLoader()
.getInterpreterSettings();
for (InterpreterSetting setting : settings) {
if (setting.getInterpreterGroup() == null) {
continue;
}

if (interpreterGroupId.equals(setting.getInterpreterGroup().getId())) {
AngularObjectRegistry angularObjectRegistry = setting
.getInterpreterGroup().getAngularObjectRegistry();
Expand Down Expand Up @@ -514,8 +526,7 @@ private void moveParagraph(NotebookSocket conn, Notebook notebook,
private void insertParagraph(NotebookSocket conn, Notebook notebook,
Message fromMessage) throws IOException {
final int index = (int) Double.parseDouble(fromMessage.get("index")
.toString());

.toString());
final Note note = notebook.getNote(getOpenNoteId(conn));
note.insertParagraph(index);
note.persist();
Expand All @@ -540,27 +551,27 @@ private void runParagraph(NotebookSocket conn, Notebook notebook,
if (paragraphId == null) {
return;
}

final Note note = notebook.getNote(getOpenNoteId(conn));
Paragraph p = note.getParagraph(paragraphId);
String text = (String) fromMessage.get("paragraph");
p.setText(text);
p.setTitle((String) fromMessage.get("title"));
Map<String, Object> params = (Map<String, Object>) fromMessage
.get("params");
.get("params");
p.settings.setParams(params);
Map<String, Object> config = (Map<String, Object>) fromMessage
.get("config");
.get("config");
p.setConfig(config);

// if it's the last paragraph, let's add a new one
boolean isTheLastParagraph = note.getLastParagraph().getId()
.equals(p.getId());
if (!Strings.isNullOrEmpty(text) && isTheLastParagraph) {
note.addParagraph();
}

note.persist();
broadcastNote(note);

try {
note.run(paragraphId);
} catch (Exception ex) {
Expand All @@ -581,7 +592,6 @@ private void runParagraph(NotebookSocket conn, Notebook notebook,
public static class ParagraphJobListener implements JobListener {
private NotebookServer notebookServer;
private Note note;

public ParagraphJobListener(NotebookServer notebookServer, Note note) {
this.notebookServer = notebookServer;
this.note = note;
Expand All @@ -606,6 +616,7 @@ public void afterStatusChange(Job job, Status before, Status after) {
LOG.error("Error", job.getException());
}
}

if (job.isTerminated()) {
LOG.info("Job {} is finished", job.getId());
try {
Expand All @@ -614,6 +625,7 @@ public void afterStatusChange(Job job, Status before, Status after) {
e.printStackTrace();
}
}

notebookServer.broadcastNote(note);
}
}
Expand All @@ -622,7 +634,6 @@ public void afterStatusChange(Job job, Status before, Status after) {
public JobListener getParagraphJobListener(Note note) {
return new ParagraphJobListener(this, note);
}

private void pong() {
}

Expand Down Expand Up @@ -667,10 +678,8 @@ public void onUpdate(String interpreterGroupId, AngularObject object) {

List<InterpreterSetting> intpSettings = note.getNoteReplLoader()
.getInterpreterSettings();

if (intpSettings.isEmpty())
continue;

for (InterpreterSetting setting : intpSettings) {
if (setting.getInterpreterGroup().getId().equals(interpreterGroupId)) {
broadcast(
Expand Down Expand Up @@ -699,9 +708,10 @@ public void onRemove(String interpreterGroupId, String name, String noteId) {
broadcast(
note.id(),
new Message(OP.ANGULAR_OBJECT_REMOVE).put("name", name).put(
"noteId", noteId));
"noteId", noteId));
}
}
}
}
}

Loading