-
Notifications
You must be signed in to change notification settings - Fork 45
/
Copy pathrdkit_workflow.py
71 lines (53 loc) · 1.8 KB
/
rdkit_workflow.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
"""
Adds an "inchi" column to a Spark DataFrame, computed with RDKit from a
column of SMILES strings.

Usage:

    redun run rdkit_workflow.py calculate_inchi \
        --input-dir <s3_path> \
        --output-dir <s3_path> \
        --smiles-col "smiles"
"""
import logging
from redun import ShardedS3Dataset, glue, task
@glue.udf
def get_inchi(smiles: str) -> str:
    """
    Convert a single SMILES string to its InChI representation.

    Wrapped with ``glue.udf`` so it can be applied column-wise to a Spark
    DataFrame. RDKit is imported inside the body, presumably so the
    dependency is only needed where the UDF actually executes.

    Parameters
    ----------
    smiles : str
        SMILES string describing one molecule.

    Returns
    -------
    str
        The string returned by ``Chem.MolToInchi``, or the sentinel
        ``"ERROR"`` when the SMILES cannot be parsed or conversion raises.
    """
    from rdkit import Chem

    logger = logging.getLogger()
    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            # MolFromSmiles reports a parse failure by returning None instead of
            # raising; passing None on to MolToInchi would only surface as an
            # opaque Boost.Python argument-type error, hiding the bad input.
            logger.error("PROBLEM: could not parse SMILES %r", smiles)
            return "ERROR"
        return Chem.MolToInchi(mol)
    except Exception as e:
        # Deliberately best-effort per row: one bad molecule should yield the
        # "ERROR" sentinel rather than fail the whole Spark job.
        logger.error("PROBLEM: %s", e)
        return "ERROR"
@task(
    executor="glue",
    workers=10,
    worker_type="G.1X",
    # NOTE(review): "and" is not valid PEP 440 specifier syntax. It was
    # presumably used because Glue's --additional-python-modules is a
    # comma-separated list, which rules out ">=2021,<2023.09.3". Confirm that
    # pip on Glue actually accepts this requirement string.
    additional_libs=["rdkit>=2021and<2023.09.3"],
)
def calculate_inchi(
    input_dir: ShardedS3Dataset, output_dir: str, smiles_col: str = "smiles"
) -> ShardedS3Dataset:
    """
    Add an "inchi" column to a dataset, computed from a column of SMILES strings.

    Parameters
    ----------
    input_dir : ShardedS3Dataset
        Location on S3 containing the input dataframes, in one or more shards.
        Should be in parquet format for this example.
    output_dir : str
        Location on S3 where the output dataframes are written as
        parquet-format shards. ``purge_spark(remove_older_than=0)`` is called
        on this location before writing, presumably clearing prior output.
    smiles_col : str
        Name of the column containing SMILES strings. Defaults to "smiles".

    Returns
    -------
    ShardedS3Dataset
        The parquet-format dataset written to ``output_dir``.
    """
    logger = logging.getLogger()
    logger.setLevel(logging.DEBUG)

    # Load dataset. 10 partitions matches workers=10 in the task options
    # (presumably to spread the UDF work evenly across workers).
    logger.info("Loading data")
    dataset = input_dir.load_spark().repartition(10)

    # Generate InChI column; rows that fail conversion get the "ERROR" value
    # produced by get_inchi.
    logger.info("Running UDF")
    with_inchi = dataset.withColumn("inchi", get_inchi(smiles_col))

    # Save dataset, purging the output location first.
    logger.info("Saving output")
    output_ds = ShardedS3Dataset(output_dir, format="parquet")
    output_ds.purge_spark(remove_older_than=0)
    output_ds.save_spark(with_inchi)
    return output_ds