From 4b3ec642497b2b506579f166a526ebb163ce2201 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Louren=C3=A7o?= <145066347+ricardo-lourenco@users.noreply.github.com> Date: Thu, 1 Feb 2024 14:04:10 +0100 Subject: [PATCH] Update finemapping scripts to parse significant sumstats (#11) * Updated GWAS input dir and renamed type_id column in spark df * Updated finemapping scripts to work with extracted sumstats data --- 1_scan_input_parquets.py | 9 +++++++-- 2_make_manifest.py | 11 +---------- 2 files changed, 8 insertions(+), 12 deletions(-) diff --git a/1_scan_input_parquets.py b/1_scan_input_parquets.py index 8d992d0..254962d 100644 --- a/1_scan_input_parquets.py +++ b/1_scan_input_parquets.py @@ -46,11 +46,15 @@ def main(): os.makedirs(out_path) # Load GWAS dfs - strip_path_gwas = udf(lambda x: x.replace('file:', '').split('/part-')[0], StringType()) + strip_path_gwas = udf(lambda x: x.replace('file://', '').split('/part-')[0], StringType()) gwas_dfs = [] + # Create a spark df for every study directory (inf) + # in each batch and append it to gwas_dfs + # Note: run_coloc expects 'type_id' column to be named 'type' instead for inf in list(set([i.split("/part-")[0] for i in glob(gwas_pattern)])): df = ( spark.read.parquet(inf) + .withColumnRenamed("type_id", "type") .withColumn('pval_threshold', lit(gwas_pval_threshold)) .withColumn('input_name', strip_path_gwas(input_file_name())) .filter(col('pval') < col('pval_threshold')) @@ -63,11 +67,12 @@ def main(): # This has to be done separately, followed by unionByName as the hive # parititions differ across datasets due to different tissues # (bio_features) and chromosomes - strip_path_mol = udf(lambda x: x.replace('file:', '').split('/part-')[0], StringType()) + strip_path_mol = udf(lambda x: x.replace('file://', '').split('/part-')[0], StringType()) mol_dfs = [] for inf in list(set([i.split("/part-")[0] for i in glob(mol_pattern)])): df = ( spark.read.parquet(inf) + .withColumnRenamed("type_id", "type") 
.withColumn('pval_threshold', (0.05 / col('num_tests'))) .withColumn('pval_threshold', when(col('pval_threshold') > gwas_pval_threshold, col('pval_threshold')) diff --git a/2_make_manifest.py b/2_make_manifest.py index 3cb7165..b6a48d4 100644 --- a/2_make_manifest.py +++ b/2_make_manifest.py @@ -68,7 +68,7 @@ def main(): out_record['chrom'] = in_record.get('chrom') # Add input files - out_record['in_pq'] = parse_input_name(in_record.get('input_name')) + out_record['in_pq'] = in_record.get('input_name') out_record['in_ld'] = ld_ref # Add output files @@ -116,15 +116,6 @@ def read_json_records(in_pattern): in_record = json.loads(in_record.decode().rstrip()) yield in_record -def parse_input_name(s): - ''' Parses the required input name. Spark's input_file_name() returns the - nested parquet file, I need the top level parquet. - ''' - # Strip nested parquet - out_s = s.split('.parquet')[0] - # Stip file:// - out_s = out_s.replace('file://', '') - return out_s if __name__ == '__main__':