Skip to content

Commit 48eeea4

Browse files
Refactor Cancelling Logic To Use /cancel (#8370)
* Cancel refactor * add changeset * add changeset * types * Add code * Fix types --------- Co-authored-by: gradio-pr-bot <[email protected]>
1 parent 96d8de2 commit 48eeea4

File tree

12 files changed

+93
-85
lines changed

12 files changed

+93
-85
lines changed

.changeset/deep-weeks-show.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
---
2+
"@gradio/app": patch
3+
"@gradio/client": patch
4+
"gradio": patch
5+
---
6+
7+
feat:Refactor Cancelling Logic To Use /cancel

client/js/src/helpers/api_info.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ export function transform_api_info(
106106
dependencyIndex !== -1
107107
? config.dependencies.find((dep) => dep.id == dependencyIndex)
108108
?.types
109-
: { continuous: false, generator: false };
109+
: { continuous: false, generator: false, cancel: false };
110110

111111
if (
112112
dependencyIndex !== -1 &&

client/js/src/test/test_data.ts

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ export const transformed_api_info: ApiInfo<ApiData> = {
4646
component: "Textbox"
4747
}
4848
],
49-
type: { continuous: false, generator: false }
49+
type: { continuous: false, generator: false, cancel: false }
5050
}
5151
},
5252
unnamed_endpoints: {
@@ -68,7 +68,7 @@ export const transformed_api_info: ApiInfo<ApiData> = {
6868
component: "Textbox"
6969
}
7070
],
71-
type: { continuous: false, generator: false }
71+
type: { continuous: false, generator: false, cancel: false }
7272
}
7373
}
7474
};
@@ -395,7 +395,8 @@ export const config_response: Config = {
395395
cancels: [],
396396
types: {
397397
continuous: false,
398-
generator: false
398+
generator: false,
399+
cancel: false
399400
},
400401
collects_event_data: false,
401402
trigger_after: null,
@@ -421,7 +422,8 @@ export const config_response: Config = {
421422
cancels: [],
422423
types: {
423424
continuous: false,
424-
generator: false
425+
generator: false,
426+
cancel: false
425427
},
426428
collects_event_data: false,
427429
trigger_after: null,
@@ -447,7 +449,8 @@ export const config_response: Config = {
447449
cancels: [],
448450
types: {
449451
continuous: false,
450-
generator: false
452+
generator: false,
453+
cancel: false
451454
},
452455
collects_event_data: false,
453456
trigger_after: null,

client/js/src/types.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,7 @@ export interface Dependency {
235235
export interface DependencyTypes {
236236
continuous: boolean;
237237
generator: boolean;
238+
cancel: boolean;
238239
}
239240

240241
export interface Payload {

client/js/src/utils/submit.ts

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ export function submit(
122122
fn_index: fn_index
123123
});
124124

125+
let reset_request = {};
125126
let cancel_request = {};
126127
if (protocol === "ws") {
127128
if (websocket && websocket.readyState === 0) {
@@ -131,21 +132,30 @@ export function submit(
131132
} else {
132133
websocket.close();
133134
}
134-
cancel_request = { fn_index, session_hash };
135+
reset_request = { fn_index, session_hash };
135136
} else {
136137
stream?.close();
137-
cancel_request = { event_id };
138+
reset_request = { event_id };
139+
cancel_request = { event_id, session_hash, fn_index };
138140
}
139141

140142
try {
141143
if (!config) {
142144
throw new Error("Could not resolve app config");
143145
}
144146

147+
if ("event_id" in cancel_request) {
148+
await fetch(`${config.root}/cancel`, {
149+
headers: { "Content-Type": "application/json" },
150+
method: "POST",
151+
body: JSON.stringify(cancel_request)
152+
});
153+
}
154+
145155
await fetch(`${config.root}/reset`, {
146156
headers: { "Content-Type": "application/json" },
147157
method: "POST",
148-
body: JSON.stringify(cancel_request)
158+
body: JSON.stringify(reset_request)
149159
});
150160
} catch (e) {
151161
console.warn(

gradio/blocks.py

Lines changed: 6 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@
7373
TupleNoPrint,
7474
check_function_inputs_match,
7575
component_or_layout_class,
76-
get_cancel_function,
76+
get_cancelled_fn_indices,
7777
get_continuous_fn,
7878
get_package_version,
7979
get_upload_folder,
@@ -541,12 +541,7 @@ def __init__(
541541
self.rendered_in = rendered_in
542542

543543
# We need to keep track of which events are cancel events
544-
# in two places:
545-
# 1. So that we can skip postprocessing for cancel events.
546-
# They return event_ids that have been cancelled but there
547-
# are no output components
548-
# 2. So that we can place the ProcessCompletedMessage in the
549-
# event stream so that clients can close the stream when necessary
544+
# so that the client can call the /cancel route directly
550545
self.is_cancel_function = is_cancel_function
551546

552547
self.spaces_auto_wrap()
@@ -589,6 +584,7 @@ def get_config(self):
589584
"types": {
590585
"continuous": self.types_continuous,
591586
"generator": self.types_generator,
587+
"cancel": self.is_cancel_function,
592588
},
593589
"collects_event_data": self.collects_event_data,
594590
"trigger_after": self.trigger_after,
@@ -1377,7 +1373,7 @@ def render(self):
13771373
updated_cancels = [
13781374
root_context.fns[i].get_config() for i in dependency.cancels
13791375
]
1380-
dependency.fn = get_cancel_function(updated_cancels)[0]
1376+
dependency.cancels = get_cancelled_fn_indices(updated_cancels)
13811377
root_context.fns[root_context.fn_id] = dependency
13821378
root_context.fn_id += 1
13831379
Context.root_block.temp_file_sets.extend(self.temp_file_sets)
@@ -1694,17 +1690,9 @@ async def postprocess_data(
16941690
block_fn: BlockFunction,
16951691
predictions: list | dict,
16961692
state: SessionState | None,
1697-
) -> Any:
1693+
) -> list[Any]:
16981694
state = state or SessionState(self)
16991695

1700-
# If the function is a cancel function, 'predictions' are the ids of
1701-
# the event in the queue that has been cancelled. We need these
1702-
# so that the server can put the ProcessCompleted message in the event stream
1703-
# Cancel events have no output components, so we need to return early otherise the output
1704-
# be None.
1705-
if block_fn.is_cancel_function:
1706-
return predictions
1707-
17081696
if isinstance(predictions, dict) and len(predictions) > 0:
17091697
predictions = convert_component_dict_to_list(
17101698
[block._id for block in block_fn.outputs], predictions
@@ -1920,7 +1908,7 @@ async def process_api(
19201908
for o in zip(*preds)
19211909
]
19221910
if root_path is not None:
1923-
data = processing_utils.add_root_url(data, root_path, None)
1911+
data = processing_utils.add_root_url(data, root_path, None) # type: ignore
19241912
data = list(zip(*data))
19251913
is_generating, iterator = None, None
19261914
else:

gradio/events.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from gradio.blocks import Block, Component
2727

2828
from gradio.context import get_blocks_context
29-
from gradio.utils import get_cancel_function
29+
from gradio.utils import get_cancelled_fn_indices
3030

3131

3232
def set_cancel_events(
@@ -36,15 +36,15 @@ def set_cancel_events(
3636
if cancels:
3737
if not isinstance(cancels, list):
3838
cancels = [cancels]
39-
cancel_fn, fn_indices_to_cancel = get_cancel_function(cancels)
39+
fn_indices_to_cancel = get_cancelled_fn_indices(cancels)
4040

4141
root_block = get_blocks_context()
4242
if root_block is None:
4343
raise AttributeError("Cannot cancel outside of a gradio.Blocks context.")
4444

4545
root_block.set_event_trigger(
4646
triggers,
47-
cancel_fn,
47+
fn=None,
4848
inputs=None,
4949
outputs=None,
5050
queue=False,

gradio/routes.py

Lines changed: 14 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -624,13 +624,8 @@ async def file_deprecated(path: str, request: fastapi.Request):
624624

625625
@app.post("/reset/")
626626
@app.post("/reset")
627-
async def reset_iterator(body: ResetBody):
628-
if body.event_id not in app.iterators:
629-
return {"success": False}
630-
async with app.lock:
631-
del app.iterators[body.event_id]
632-
app.iterators_to_reset.add(body.event_id)
633-
await app.get_blocks()._queue.clean_events(event_id=body.event_id)
627+
async def reset_iterator(body: ResetBody): # noqa: ARG001
628+
# No-op, all the cancelling/reset logic handled by /cancel
634629
return {"success": True}
635630

636631
@app.get("/heartbeat/{session_hash}")
@@ -739,18 +734,6 @@ async def predict(
739734
fn=fn,
740735
root_path=root_path,
741736
)
742-
if fn.is_cancel_function:
743-
# Need to complete the job so that the client disconnects
744-
blocks = app.get_blocks()
745-
if body.session_hash in blocks._queue.pending_messages_per_session:
746-
for event_id in output["data"]:
747-
message = ProcessCompletedMessage(
748-
output={}, success=True, event_id=event_id
749-
)
750-
blocks._queue.pending_messages_per_session[ # type: ignore
751-
body.session_hash
752-
].put_nowait(message)
753-
754737
except BaseException as error:
755738
show_error = app.get_blocks().show_error or isinstance(error, Error)
756739
traceback.print_exc()
@@ -823,13 +806,24 @@ async def cancel_event(body: CancelBody):
823806
await cancel_tasks({f"{body.session_hash}_{body.fn_index}"})
824807
blocks = app.get_blocks()
825808
# Need to complete the job so that the client disconnects
826-
if body.session_hash in blocks._queue.pending_messages_per_session:
809+
session_open = (
810+
body.session_hash in blocks._queue.pending_messages_per_session
811+
)
812+
event_running = (
813+
body.event_id
814+
in blocks._queue.pending_event_ids_session.get(body.session_hash, {})
815+
)
816+
if session_open and event_running:
827817
message = ProcessCompletedMessage(
828818
output={}, success=True, event_id=body.event_id
829819
)
830820
blocks._queue.pending_messages_per_session[
831821
body.session_hash
832822
].put_nowait(message)
823+
if body.event_id in app.iterators:
824+
async with app.lock:
825+
del app.iterators[body.event_id]
826+
app.iterators_to_reset.add(body.event_id)
833827
return {"success": True}
834828

835829
@app.get("/call/{api_name}/{event_id}", dependencies=[Depends(login_check)])

gradio/utils.py

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -872,7 +872,7 @@ def after_fn():
872872

873873
async def cancel_tasks(task_ids: set[str]) -> list[str]:
874874
tasks = [(task, task.get_name()) for task in asyncio.all_tasks()]
875-
event_ids = []
875+
event_ids: list[str] = []
876876
matching_tasks = []
877877
for task, name in tasks:
878878
if "<gradio-sep>" not in name:
@@ -891,27 +891,19 @@ def set_task_name(task, session_hash: str, fn_index: int, event_id: str, batch:
891891
task.set_name(f"{session_hash}_{fn_index}<gradio-sep>{event_id}")
892892

893893

894-
def get_cancel_function(
894+
def get_cancelled_fn_indices(
895895
dependencies: list[dict[str, Any]],
896-
) -> tuple[Callable, list[int]]:
897-
fn_to_comp = {}
896+
) -> list[int]:
897+
fn_indices = []
898898
for dep in dependencies:
899899
root_block = get_blocks_context()
900900
if root_block:
901901
fn_index = next(
902902
i for i, d in root_block.fns.items() if d.get_config() == dep
903903
)
904-
fn_to_comp[fn_index] = [root_block.blocks[o] for o in dep["outputs"]]
904+
fn_indices.append(fn_index)
905905

906-
async def cancel(session_hash: str) -> list[str]:
907-
task_ids = {f"{session_hash}_{fn}" for fn in fn_to_comp}
908-
event_ids = await cancel_tasks(task_ids)
909-
return event_ids
910-
911-
return (
912-
cancel,
913-
list(fn_to_comp.keys()),
914-
)
906+
return fn_indices
915907

916908

917909
def get_type_hints(fn):

js/app/src/Blocks.svelte

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -207,15 +207,6 @@
207207
208208
const current_status = loading_status.get_status_for_fn(dep_index);
209209
messages = messages.filter(({ fn_index }) => fn_index !== dep_index);
210-
if (dep.cancels) {
211-
await Promise.all(
212-
dep.cancels.map(async (fn_index) => {
213-
const submission = submit_map.get(fn_index);
214-
submission?.cancel();
215-
return submission;
216-
})
217-
);
218-
}
219210
if (current_status === "pending" || current_status === "generating") {
220211
dep.pending_request = true;
221212
}
@@ -242,6 +233,14 @@
242233
handle_update(v, dep_index);
243234
}
244235
});
236+
} else if (dep.types.cancel && dep.cancels) {
237+
await Promise.all(
238+
dep.cancels.map(async (fn_index) => {
239+
const submission = submit_map.get(fn_index);
240+
submission?.cancel();
241+
return submission;
242+
})
243+
);
245244
} else {
246245
if (dep.backend_fn) {
247246
if (dep.trigger_mode === "once") {

0 commit comments

Comments
 (0)