Skip to content

Commit

Permalink
fixed preprocessing bug
Browse files Browse the repository at this point in the history
  • Loading branch information
Lucas Camillo authored and Lucas Camillo committed Dec 25, 2023
1 parent ef27ade commit b9afc3b
Show file tree
Hide file tree
Showing 2 changed files with 261 additions and 256 deletions.
36 changes: 13 additions & 23 deletions pyaging/predict/_pred_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,13 +452,6 @@ def preprocess_data(
np.array([...])
"""
# Skip if it there is no preprocessing to be done
if preprocessing is None:
logger.info(
"There is no preprocessing to be done",
indent_level=3,
)
return adata

# Move to adata.X for preprocessing
adata.X = (
Expand All @@ -467,6 +460,14 @@ def preprocess_data(
else adata.layers["X_original"].copy()
)

# Skip if it there is no preprocessing to be done
if preprocessing is None:
logger.info(
"There is no preprocessing to be done",
indent_level=3,
)
return adata

logger.info(f"Preprocessing data with function {preprocessing}", indent_level=3)
# Apply specified preprocessing method
if preprocessing == "tpm_norm_log1p":
Expand Down Expand Up @@ -596,12 +597,7 @@ def postprocess_data(

@progress("Predict ages with model")
def predict_ages_with_model(
model: torch.nn.Module,
adata: torch.Tensor,
features: List[str],
device: str,
logger,
indent_level: int = 2,
model: torch.nn.Module, adata: torch.Tensor, features: List[str], device: str, logger, indent_level: int = 2
) -> torch.Tensor:
"""
Predict biological ages using a trained model and input data.
Expand Down Expand Up @@ -655,7 +651,7 @@ def predict_ages_with_model(
"""
# Create an AnnLoader
use_cuda = device == "cuda"
use_cuda = device == 'cuda'
dataloader = AnnLoader(adata, batch_size=1024, use_cuda=use_cuda)

# Use the AnnLoader for batched prediction
Expand Down Expand Up @@ -826,17 +822,11 @@ def filter_missing_features(
"""
n_missing_features = sum(adata.var["percent_na"] == 1)
if n_missing_features > 0:
logger.info(
f"Removing {n_missing_features} added features",
indent_level=indent_level + 1,
)
logger.info(f"Removing {n_missing_features} added features", indent_level=indent_level+1)
adata = adata[:, adata.var["percent_na"] < 1].copy()
else:
logger.info(
"No missing features, so adata size did not change",
indent_level=indent_level + 1,
)

logger.info("No missing features, so adata size did not change", indent_level=indent_level+1)

return adata


Expand Down
Loading

0 comments on commit b9afc3b

Please sign in to comment.