Skip to content

Commit

Permalink
Add no PCST option in WebQSPDataset (#9722)
Browse files Browse the repository at this point in the history
This PR adds an option to disable PCST in WebQSPDataset, it allows the
user to use some custom retrieval method (such as GNN-RAG).

---------

Co-authored-by: Matthias Fey <[email protected]>
  • Loading branch information
Kh4L and rusty1s authored Oct 28, 2024
1 parent 78e3f39 commit b823c7e
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 15 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added the `use_pcst` option to `WebQSPDataset` ([#9722](https://github.com/pyg-team/pytorch_geometric/pull/9722))
- Allowed users to pass `edge_weight` to `GraphUNet` models ([#9737](https://github.com/pyg-team/pytorch_geometric/pull/9737))
- Consolidated `examples/ogbn_{papers_100m,products_gat,products_sage}.py` into `examples/ogbn_train.py` ([#9467](https://github.com/pyg-team/pytorch_geometric/pull/9467))

Expand Down
35 changes: 20 additions & 15 deletions torch_geometric/datasets/web_qsp_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,6 @@ def retrieval_via_pcst(
cost_e: float = 0.5,
) -> Tuple[Data, str]:
c = 0.01
if len(textual_nodes) == 0 or len(textual_edges) == 0:
desc = textual_nodes.to_csv(index=False) + "\n" + textual_edges.to_csv(
index=False,
columns=["src", "edge_attr", "dst"],
)
return data, desc

from pcst_fast import pcst_fast

Expand Down Expand Up @@ -135,13 +129,17 @@ class WebQSPDataset(InMemoryDataset):
If :obj:`"test"`, loads the test dataset. (default: :obj:`"train"`)
force_reload (bool, optional): Whether to re-process the dataset.
(default: :obj:`False`)
use_pcst (bool, optional): Whether to preprocess the dataset's graph
with PCST or return the full graphs. (default: :obj:`True`)
"""
def __init__(
self,
root: str,
split: str = "train",
force_reload: bool = False,
use_pcst: bool = True,
) -> None:
self.use_pcst = use_pcst
super().__init__(root, force_reload=force_reload)

if split not in {'train', 'val', 'test'}:
Expand Down Expand Up @@ -224,15 +222,22 @@ def process(self) -> None:
edge_index=edge_index,
edge_attr=edge_attr,
)
data, desc = retrieval_via_pcst(
data,
question_embs[i],
nodes,
edges,
topk=3,
topk_e=5,
cost_e=0.5,
)
if self.use_pcst and len(nodes) > 0 and len(edges) > 0:
data, desc = retrieval_via_pcst(
data,
question_embs[i],
nodes,
edges,
topk=3,
topk_e=5,
cost_e=0.5,
)
else:
desc = nodes.to_csv(index=False) + "\n" + edges.to_csv(
index=False,
columns=["src", "edge_attr", "dst"],
)

data.question = question
data.label = label
data.desc = desc
Expand Down

0 comments on commit b823c7e

Please sign in to comment.