
PyTorch implementation of “Learnable Prompt as Pseudo-Imputation (PAI)”


MrBlankness/PAI


PAI: Learnable Prompt as Pseudo-Imputation

Welcome to the official GitHub repository for PAI (Learnable Prompt as Pseudo-Imputation)!

This is the official code for the paper “Learnable Prompt as Pseudo-Imputation: Rethinking the Necessity of Traditional EHR Data Imputation in Downstream Clinical Prediction”.

Overview

Analyzing patients' health status from Electronic Health Records (EHR) is a fundamental research problem in medical informatics. Because EHR data contain extensive missing values, it is difficult for deep neural networks (DNNs) to model a patient's health status directly. Existing DNN training protocols, namely the Impute-then-Regress procedure and the jointly optimized Impute-n-Regress procedure, both require an additional imputation model to reconstruct the missing values. However, the Impute-then-Regress procedure risks injecting imputed, non-real data into downstream clinical prediction tasks, which can lead to power loss, biased estimation, and poorly performing models. The jointly optimized Impute-n-Regress procedure, in turn, is hard to generalize because of its complex optimization space and demanding data requirements.

Inspired by recent work on learnable prompts in NLP and CV, we rethink the necessity of an imputation model for downstream clinical tasks and propose Learnable Prompt as Pseudo-Imputation (PAI), a new training protocol for EHR analysis. PAI introduces no imputed data; instead, it constructs a learnable prompt that models the downstream model's implicit preferences for missing values. This yields a significant performance improvement for all evaluated state-of-the-art EHR analysis models on four real-world datasets across two clinical prediction tasks. Further experimental analysis indicates that PAI is more robust when data are insufficient or missing rates are high. More importantly, as a plug-and-play protocol, PAI can be easily integrated into any existing EHR analysis model, and even into future models not yet developed.
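The core idea, replacing missing entries with a learnable prompt instead of imputed values, can be sketched in PyTorch as follows. This is a minimal illustration under stated assumptions, not the repository's implementation: `PromptFill` and `PAIClassifier` are hypothetical names, the prompt is assumed to hold one learnable value per feature (initialised to zero and shared across patients and time steps), and a GRU stands in for the downstream EHR analysis model. The prompt is trained by the same loss as the downstream model.

```python
import torch
import torch.nn as nn


class PromptFill(nn.Module):
    """Hypothetical sketch: fill missing EHR entries with a learnable prompt.

    Assumption: one learnable value per feature, initialised to zero and
    shared across patients and time steps. The actual PAI code may differ.
    """

    def __init__(self, num_features: int):
        super().__init__()
        self.prompt = nn.Parameter(torch.zeros(num_features))

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        # x:    (batch, time, features); missing positions may hold NaN
        # mask: same shape, 1.0 where observed, 0.0 where missing
        # Observed values pass through unchanged; missing ones take the prompt.
        return torch.where(mask.bool(), x, self.prompt)


class PAIClassifier(nn.Module):
    """Hypothetical downstream model: prompt fill followed by a GRU backbone."""

    def __init__(self, num_features: int, hidden_size: int = 32):
        super().__init__()
        self.fill = PromptFill(num_features)
        self.backbone = nn.GRU(num_features, hidden_size, batch_first=True)
        self.head = nn.Linear(hidden_size, 1)

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        filled = self.fill(x, mask)
        _, h_n = self.backbone(filled)          # h_n: (num_layers, batch, hidden)
        return self.head(h_n[-1]).squeeze(-1)   # one logit per patient


# Toy data with roughly 30% of entries missing (synthetic, for illustration only).
torch.manual_seed(0)
batch, steps, feats = 8, 5, 4
x = torch.randn(batch, steps, feats)
mask = (torch.rand(batch, steps, feats) > 0.3).float()
x = x.masked_fill(mask == 0, float("nan"))      # simulate missing values
y = torch.randint(0, 2, (batch,)).float()       # binary outcome labels

model = PAIClassifier(feats)
opt = torch.optim.Adam(model.parameters(), lr=1e-2)

# One joint training step: the prompt and the backbone share the same loss.
loss = nn.functional.binary_cross_entropy_with_logits(model(x, mask), y)
opt.zero_grad()
loss.backward()
opt.step()
```

Because the prompt is an ordinary `nn.Parameter`, it is updated by the same optimizer as the backbone, which is what makes the approach plug-and-play in this sketch: any model that accepts a `(batch, time, features)` tensor could replace the GRU.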

Install Environment

We use conda to manage the environment. Follow these steps to set it up:

conda create -n PAI python=3.11 -y
conda activate PAI
pip install -r requirements.txt

Download Datasets

Running

To run the code, simply execute the following command:

python train_model.py

We also provide arguments for running the code with different EHR analysis models, datasets, and prediction tasks. You can set all of them with the following command:

python train_model.py --task "your_task_name" --model "your_model_name" --data "your_dataset_name" --fill

The following table lists the available arguments and their options:

Argument    Options
--task      outcome (mortality), los (length of stay), readmission
--model     rnn, lstm, gru, transformer, retain, concare, m3care, safari
--data      mimic (MIMIC-IV), cdsl, sepsis, eicu, mimiciii (MIMIC-III)

Omit --fill from the command to disable PAI.
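To compare each model with and without PAI across several configurations, the commands can be generated programmatically. The helper below is a hypothetical sketch, not part of the repository: `build_command` and `sweep` are illustrative names, the option lists are copied from the table of arguments, and it is an assumption that `train_model.py` supports every task/model/dataset combination. It only prints the commands; launching them (for example with `subprocess.run(cmd, check=True)`) is left to the reader.

```python
import itertools
import shlex

# Options copied from the README's argument table.
TASKS = ["outcome", "los", "readmission"]
MODELS = ["rnn", "lstm", "gru", "transformer", "retain", "concare", "m3care", "safari"]
DATASETS = ["mimic", "cdsl", "sepsis", "eicu", "mimiciii"]


def build_command(task: str, model: str, data: str, fill: bool = True) -> list[str]:
    """Return the argv list for one train_model.py run (hypothetical helper)."""
    for value, allowed in ((task, TASKS), (model, MODELS), (data, DATASETS)):
        if value not in allowed:
            raise ValueError(f"unsupported option: {value!r}")
    cmd = ["python", "train_model.py", "--task", task, "--model", model, "--data", data]
    if fill:
        cmd.append("--fill")  # enables PAI; leave it out to disable PAI
    return cmd


def sweep(models, data: str = "mimic", task: str = "outcome"):
    """Yield a baseline command followed by a PAI command for each model."""
    for model, fill in itertools.product(models, (False, True)):
        yield build_command(task, model, data, fill=fill)


if __name__ == "__main__":
    for cmd in sweep(["gru", "transformer"]):
        print(shlex.join(cmd))
```

Pairing each baseline run with its `--fill` counterpart keeps the comparison controlled: only the PAI switch changes between the two commands.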
