Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions src/common/compute/interactive/message.js
Original file line number Diff line number Diff line change
Expand Up @@ -20,25 +20,26 @@
}

class Message {
constructor(type, data) {
constructor(sessionID, type, data) {
this.sessionID = sessionID;
this.type = type;
this.data = data;
}

static decode(serialized) {
const {type, data} = JSON.parse(serialized);
return new Message(type, data);
const {sessionID, type, data} = JSON.parse(serialized);
return new Message(sessionID, type, data);
}

encode() {
return Message.encode(this.type, this.data);
return Message.encode(this.sessionID, this.type, this.data);
}

static encode(type, data=0) {
static encode(sessionID, type, data=0) {
if (typeof Buffer !== 'undefined' && data instanceof Buffer) {
data = data.toString();
}
return JSON.stringify({type, data});
return JSON.stringify({sessionID, type, data});
}
}
Object.assign(Message, Constants);
Expand Down
145 changes: 89 additions & 56 deletions src/common/compute/interactive/session.js
Original file line number Diff line number Diff line change
Expand Up @@ -16,53 +16,14 @@ define([
const {CommandFailedError} = Errors;
const isNodeJs = typeof window === 'undefined';
const WebSocket = isNodeJs ? require('ws') : window.WebSocket;
let numSessions = 1;

class InteractiveSession {
constructor(computeID, config={}) {
constructor(channel) {
this.currentTask = null;
const address = gmeConfig.extensions.InteractiveComputeHost ||
this.getDefaultServerURL();
this.ws = new WebSocket(address);
this.connected = defer();
this.ws.onopen = () => {
this.ws.send(JSON.stringify([computeID, config, this.getGMEToken()]));
this.ws.onmessage = async (wsMsg) => {
const data = await Task.getMessageData(wsMsg);

const msg = Message.decode(data);
if (msg.type === Message.COMPLETE) {
const err = msg.data;
this.channel = new MessageChannel(this.ws);
if (err) {
this.connected.reject(err);
} else {
this.connected.resolve();
this.checkReady();
}
}
};
};

this.ready = null;
}

getDefaultServerURL() {
const isSecure = !isNodeJs && location.protocol.includes('s');
const protocol = isSecure ? 'wss' : 'ws';
const defaultHost = isNodeJs ? '127.0.0.1' :
location.origin
.replace(location.protocol + '//', '')
.replace(/:[0-9]+$/, '');
return `${protocol}://${defaultHost}:${gmeConfig.server.port + 1}`;
}

getGMEToken() {
if (isNodeJs) {
return '';
}

const [, token] = (document.cookie || '').split('=');
return token;
this.id = numSessions++;
this.channel = channel;
this.channel.onClientConnect(this.id);
}

checkReady() {
Expand All @@ -72,7 +33,7 @@ define([
}

isIdle() {
return !this.currentTask && this.ws.readyState === WebSocket.OPEN;
return !this.currentTask && this.channel.isOpen();
}

ensureIdle(action) {
Expand All @@ -84,7 +45,7 @@ define([
spawn(cmd) {
this.ensureIdle('spawn a task');

const msg = new Message(Message.RUN, cmd);
const msg = new Message(this.id, Message.RUN, cmd);
const task = new Task(this.channel, msg);
this.runTask(task);
return task;
Expand All @@ -111,7 +72,7 @@ define([

async exec(cmd) {
this.ensureIdle('exec a task');
const msg = new Message(Message.RUN, cmd);
const msg = new Message(this.id, Message.RUN, cmd);
const task = new Task(this.channel, msg);
const result = {
stdout: '',
Expand All @@ -131,14 +92,14 @@ define([
async addArtifact(name, dataInfo, type, auth) {
auth = auth || {};
this.ensureIdle('add artifact');
const msg = new Message(Message.ADD_ARTIFACT, [name, dataInfo, type, auth]);
const msg = new Message(this.id, Message.ADD_ARTIFACT, [name, dataInfo, type, auth]);
const task = new Task(this.channel, msg);
await this.runTask(task);
}

async saveArtifact(/*path, name, storageId, config*/) {
this.ensureIdle('save artifact');
const msg = new Message(Message.SAVE_ARTIFACT, [...arguments]);
const msg = new Message(this.id, Message.SAVE_ARTIFACT, [...arguments]);
const task = new Task(this.channel, msg);
const [exitCode, dataInfo] = await this.runTask(task);
if (exitCode) {
Expand All @@ -149,21 +110,21 @@ define([

async addFile(filepath, content) {
this.ensureIdle('add file');
const msg = new Message(Message.ADD_FILE, [filepath, content]);
const msg = new Message(this.id, Message.ADD_FILE, [filepath, content]);
const task = new Task(this.channel, msg);
await this.runTask(task);
}

async removeFile(filepath) {
this.ensureIdle('remove file');
const msg = new Message(Message.REMOVE_FILE, [filepath]);
const msg = new Message(this.id, Message.REMOVE_FILE, [filepath]);
const task = new Task(this.channel, msg);
await this.runTask(task);
}

async setEnvVar(name, value) {
this.ensureIdle('set env var');
const msg = new Message(Message.SET_ENV, [name, value]);
const msg = new Message(this.id, Message.SET_ENV, [name, value]);
const task = new Task(this.channel, msg);
await this.runTask(task);
}
Expand All @@ -174,24 +135,75 @@ define([
'Cannot kill task. Must be a RUN task.'
);
if (task === this.currentTask) {
const msg = new Message(Message.KILL, task.msg.data);
const msg = new Message(this.id, Message.KILL, task.msg.data);
const killTask = new Task(this.channel, msg);
await killTask.run();
this.checkReady();
}
}

close() {
this.ws.close();
this.channel.onClientExit(this.id);
}

fork() {
const Session = this.constructor;
return new Session(this.channel);
}

static async new(computeID, config={}, SessionClass=InteractiveSession) {
const session = new SessionClass(computeID, config);
await session.whenConnected();
const channel = await createMessageChannel(computeID, config);
const session = new SessionClass(channel);
return session;
}
}

async function createMessageChannel(computeID, config) {
const address = gmeConfig.extensions.InteractiveComputeHost ||
getDefaultServerURL();

const connectedWs = await new Promise((resolve, reject) => {
const ws = new WebSocket(address);
ws.onopen = () => {
ws.send(JSON.stringify([computeID, config, getGMEToken()]));
ws.onmessage = async (wsMsg) => {
const data = await Task.getMessageData(wsMsg);

const msg = Message.decode(data);
if (msg.type === Message.COMPLETE) {
const err = msg.data;
if (err) {
reject(err);
} else {
resolve(ws);
}
}
};
};
});

return new MessageChannel(connectedWs);
}

function getDefaultServerURL() {
const isSecure = !isNodeJs && location.protocol.includes('s');
const protocol = isSecure ? 'wss' : 'ws';
const defaultHost = isNodeJs ? '127.0.0.1' :
location.origin
.replace(location.protocol + '//', '')
.replace(/:[0-9]+$/, '');
return `${protocol}://${defaultHost}:${gmeConfig.server.port + 1}`;
}

function getGMEToken() {
if (isNodeJs) {
return '';
}

const [, token] = (document.cookie || '').split('=');
return token;
}

function assert(cond, msg) {
if (!cond) {
throw new Error(msg);
Expand All @@ -208,6 +220,7 @@ define([
this.ws.onmessage = message => {
this.listeners.forEach(fn => fn(message));
};
this.clients = [];
}

send(data) {
Expand All @@ -224,6 +237,26 @@ define([
this.listeners.splice(index, 1);
}
}

isOpen() {
return this.ws.readyState === WebSocket.OPEN;
}

onClientConnect(id) {
this.clients.push(id);
}

onClientExit(id) {
const index = this.clients.indexOf(id);
if (index === -1) {
throw new Error(`Client not found: ${id}`);
}
this.clients.splice(index, 1);

if (this.clients.length === 0) {
this.ws.close();
}
}
}

return InteractiveSession;
Expand Down
2 changes: 1 addition & 1 deletion src/routers/InteractiveCompute/InteractiveCompute.js
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class ComputeBroker {
const session = new InteractiveSession(blobClient, client, ws);
this.initSessions.push(session);
} catch (err) {
ws.send(Message.encode(Message.COMPLETE, err.message));
ws.send(Message.encode(-1, Message.COMPLETE, err.message));
this.logger.warn(`Error creating session: ${err}`);
ws.close();
}
Expand Down
2 changes: 1 addition & 1 deletion src/routers/InteractiveCompute/Session.js
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class Session extends EventEmitter {
this.workerSocket = socket;
this.emit('connected');

this.clientSocket.send(Message.encode(Message.COMPLETE));
this.clientSocket.send(Message.encode(-1, Message.COMPLETE));
this.queuedMsgs.forEach(msg => this.workerSocket.send(msg));
this.wsChannel = new Channel(this.clientSocket, this.workerSocket);
this.wsChannel.on(Channel.CLOSE, () => this.close());
Expand Down
Loading