Skip to content

Commit

Permalink
discojs: fail on fetch error
Browse files Browse the repository at this point in the history
  • Loading branch information
tharvik authored and JulienVig committed Dec 3, 2024
1 parent 0be927e commit 75a67ec
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 31 deletions.
2 changes: 2 additions & 0 deletions discojs/src/client/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,8 @@ export abstract class Client extends EventEmitter<{'status': RoundStatus}>{
url.pathname += `tasks/${this.task.id}/model.json`

const response = await fetch(url);
if (!response.ok) throw new Error(`fetch: HTTP status ${response.status}`);

const encoded = new Uint8Array(await response.arrayBuffer())
return await serialization.model.decode(encoded)
}
Expand Down
4 changes: 3 additions & 1 deletion discojs/src/task/task_handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,22 @@ export async function pushTask<D extends DataType>(
task: Task<D>,
model: Model<D>,
): Promise<void> {
await fetch(urlToTasks(base), {
const response = await fetch(urlToTasks(base), {
method: "POST",
body: JSON.stringify({
task,
model: await serialization.model.encode(model),
weights: await serialization.weights.encode(model.weights),
}),
});
if (!response.ok) throw new Error(`fetch: HTTP status ${response.status}`);
}

export async function fetchTasks(
base: URL,
): Promise<Map<TaskID, Task<DataType>>> {
const response = await fetch(urlToTasks(base));
if (!response.ok) throw new Error(`fetch: HTTP status ${response.status}`);
const tasks: unknown = await response.json();

if (!Array.isArray(tasks)) {
Expand Down
2 changes: 1 addition & 1 deletion webapp/src/components/testing/__tests__/Testing.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ it("shows stored models", async () => {
it("allows to download server's models", async () => {
vi.stubGlobal("fetch", async (url: string | URL) => {
if (url.toString() === "http://localhost:8080/tasks")
return { json: () => Promise.resolve([TASK]) };
return new Response(JSON.stringify([TASK]));
throw new Error(`unhandled get: ${url}`);
});
afterEach(() => {
Expand Down
34 changes: 5 additions & 29 deletions webapp/src/components/training/__tests__/Trainer.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,39 +10,15 @@ import { loadCSV } from "@epfml/discojs-web";
import Trainer from "../Trainer.vue";
import TrainingInformation from "../TrainingInformation.vue";

vi.mock("axios", async () => {
async function get(url: string) {
if (url === "http://localhost:8080/tasks/titanic/model.json") {
return {
data: await serialization.model.encode(
await defaultTasks.titanic.getModel(),
),
};
}
throw new Error("unhandled get");
}

const axios = await vi.importActual<typeof import("axios")>("axios");
return {
...axios,
default: {
...axios.default,
get,
},
};
});

async function setupForTask() {
const provider = defaultTasks.titanic;

vi.stubGlobal("fetch", async (url: string | URL) => {
if (url.toString() === "http://localhost:8080/tasks/titanic/model.json")
return {
arrayBuffer: async () => {
const model = await provider.getModel();
return await serialization.model.encode(model);
},
};
if (url.toString() === "http://localhost:8080/tasks/titanic/model.json") {
const model = await provider.getModel();
const encoded = await serialization.model.encode(model);
return new Response(encoded);
}
throw new Error(`unhandled get: ${url}`);
});
afterEach(() => {
Expand Down

0 comments on commit 75a67ec

Please sign in to comment.