"""
Survival Analysis Utilities Module.
This module provides functions for performing survival analysis, including Cox proportional hazards
regression models and Kaplan-Meier survival curves. It includes utilities for data preprocessing,
multicollinearity checking, and visualization of results.
Requires: pandas, numpy, scikit-learn, lifelines, matplotlib, statsmodels, seaborn, tqdm
"""
import os
import warnings
from pathlib import Path
from typing import List, Optional, Tuple, Union
import lifelines
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from lifelines import CoxPHFitter, KaplanMeierFitter
from lifelines.exceptions import ConvergenceError
from lifelines.statistics import logrank_test
from sklearn.preprocessing import StandardScaler
from statsmodels.stats.outliers_influence import variance_inflation_factor
from tqdm.auto import tqdm
from bca_survival.utils import make_quantile_split
[docs]
def standardize_columns(
    df: pd.DataFrame, columns: List[str], nan_threshold: float = 0.7
) -> pd.DataFrame:
    """
    Standardizes the numeric columns among ``columns`` and leaves everything else as-is.

    Args:
        df (pd.DataFrame): The input dataframe. It is not modified.
        columns (list): Column names to consider for standardization. Any iterable of
            column labels is accepted (list, tuple, pandas Index); duplicates are
            processed once.
        nan_threshold (float, optional): Maximum allowed fraction of NaNs. Columns whose
            NaN fraction exceeds this value are excluded from standardization. They are
            NOT removed from the returned dataframe; they are kept unchanged.
            Defaults to 0.7.

    Returns:
        pd.DataFrame: A copy of ``df`` in which the selected numeric columns have been
        transformed with ``StandardScaler``.

    Note:
        Non-numeric columns are reported and left unchanged. A dataframe with no rows
        is returned unchanged (as a copy), because ``StandardScaler`` cannot be fitted
        on zero samples.
    """
    df_copy = df.copy()

    # dict.fromkeys de-duplicates while keeping the caller's order, and works for any
    # iterable of labels (``columns.copy()`` only worked for lists).
    candidates = list(dict.fromkeys(columns))

    # Exclude columns with too many NaNs from scaling (they stay in the output as-is).
    columns_to_process = []
    for column in candidates:
        nan_ratio = df[column].isna().mean()
        if nan_ratio > nan_threshold:
            print(f"Excluding column {column} from standardization due to {nan_ratio:.2%} NaNs")
        else:
            columns_to_process.append(column)

    # Only numeric columns can be standardized; others are reported and kept unchanged.
    numeric_columns = []
    for column in columns_to_process:
        if pd.api.types.is_numeric_dtype(df[column]):
            numeric_columns.append(column)
        else:
            print(f"Skipping non-numeric column {column} for standardization")

    # StandardScaler raises on zero samples, so an empty frame is returned unscaled.
    if numeric_columns and len(df_copy) > 0:
        scaler = StandardScaler()
        df_copy[numeric_columns] = scaler.fit_transform(df_copy[numeric_columns])
    return df_copy
[docs]
def check_multicollinearity(df: pd.DataFrame, columns: List[str]) -> pd.DataFrame:
    """
    Computes the pairwise correlation matrix of ``columns`` and plots it as a heatmap.

    Args:
        df (pd.DataFrame): Dataframe holding the variables.
        columns (list): Names of the columns whose pairwise correlations are computed.

    Returns:
        pd.DataFrame: ``df[columns].corr()``, i.e. pandas' default correlation of the
        selected columns.

    Note:
        As a side effect, an annotated heatmap (values shown with two decimals) titled
        "Correlation Matrix" is drawn with seaborn and displayed via ``plt.show()``.
    """
    selected = df[columns]
    correlations = selected.corr()

    # Visual check: annotate every cell so strongly correlated pairs stand out.
    sns.heatmap(correlations, annot=True, fmt=".2f")
    plt.title("Correlation Matrix")
    plt.show()

    return correlations
[docs]
def generate_kaplan_meier_plot(
    df: pd.DataFrame,
    column: str,
    split_strategy: str = "median",
    fixed_value: Optional[float] = None,
    percentage: Optional[float] = None,
    output_path: Optional[Union[os.PathLike[str], str]] = None,
    dpi: int = 600,
    custom_title: Optional[str] = None,
    display_plot: bool = False,
    custom_high_low_names: Tuple[str, str] = ("low", "high"),
) -> dict:
    """
    Generates a Kaplan-Meier survival plot for a specified variable.

    Args:
        df (pd.DataFrame): The input dataframe. Must contain 'days' and 'event' columns.
        column (str): Column name to use for grouping. Rows where it is NaN are ignored.
        split_strategy (str, optional): Strategy for splitting data into high/low groups.
            Options: 'mean', 'median', 'percentage', 'fixed', 'quantile'. Defaults to 'median'.
        fixed_value (float, optional): Fixed threshold value when split_strategy is 'fixed'.
            You can use this when you have found cutoff values from literature.
            Defaults to None.
        percentage (float, optional): Quantile (passed to ``Series.quantile``) used as the
            threshold when split_strategy is 'percentage'. Defaults to None.
        output_path (str or PathLike, optional): Directory to save the plot in (created if
            missing). If None, saves in the current directory. Defaults to None.
        dpi (int, optional): Resolution of the output image in dots per inch. Higher values
            result in better quality but larger file sizes. Defaults to 600.
        custom_title (str, optional): Custom title for the plot. If None, a default title
            is generated from the column and split strategy. Defaults to None.
        display_plot (bool, optional): If True, the figure is shown with ``plt.show()``
            after saving and is left for the backend/user to close. If False, interactive
            mode is switched off while the figure is built and the figure is closed after
            saving, so it is only written to file. Defaults to False.
        custom_high_low_names (Tuple[str, str], optional): Labels for the (low, high)
            groups. Defaults to ("low", "high").

    Returns:
        dict: ``{"p-value": ..., "plot_filename": ..., "metrics": ...}`` holding the
        log-rank p-value, the file name (without directory) of the saved plot, and the
        log-rank test statistic.

    Raises:
        ValueError: If an invalid split_strategy is provided, if a parameter required by
            the chosen strategy is missing, or if the split leaves one group empty.

    Note:
        The data is split into "high" (value > threshold) and "low" groups, one
        Kaplan-Meier curve is drawn per group, and a log-rank test compares them.
        The caller's matplotlib interactive mode is restored before returning, even
        when an error occurs.
    """
    # Resolve the threshold before touching matplotlib state, so invalid arguments
    # cannot leave interactive mode altered.
    threshold = _km_split_threshold(df, column, split_strategy, fixed_value, percentage)
    low_name, high_name = custom_high_low_names

    df_tmp = df.dropna(subset=[column]).copy()
    if threshold is None:
        # NOTE(review): assumes make_quantile_split adds a "group" column whose labels
        # match custom_high_low_names; confirm in bca_survival.utils.
        df_tmp = make_quantile_split(df_tmp, column)
    else:
        df_tmp["group"] = np.where(df_tmp[column] > threshold, high_name, low_name)

    results_high = df_tmp[df_tmp["group"] == high_name]
    results_low = df_tmp[df_tmp["group"] == low_name]
    # Fail early with a clear message rather than deep inside lifelines.
    for name, group in ((high_name, results_high), (low_name, results_low)):
        if group.empty:
            raise ValueError(
                f"The '{name}' group is empty after splitting {column!r} with "
                f"strategy '{split_strategy}'; cannot compare survival curves."
            )

    # Make the file name safe: spaces/colons -> "_", newlines and slashes removed.
    plot_filename = f"km_plot_{column}_{split_strategy}.png".translate(
        str.maketrans({" ": "_", "\n": None, "/": None, ":": "_"})
    )

    was_interactive = plt.isinteractive()
    if not display_plot:
        # Interactive mode off keeps notebooks from rendering the figure implicitly.
        plt.ioff()

    fig, ax = plt.subplots()
    keep_open = False
    try:
        kmf = KaplanMeierFitter()
        for name, group in ((high_name, results_high), (low_name, results_low)):
            kmf.fit(durations=group["days"], event_observed=group["event"], label=name)
            kmf.plot_survival_function(ax=ax)

        if custom_title:
            ax.set_title(custom_title)
        else:
            ax.set_title(f"Survival function by {column} ({split_strategy} split)")
        ax.set_xlabel("Days")
        ax.set_ylabel("Survival probability")

        logrank_results = logrank_test(
            results_high["days"],
            results_low["days"],
            event_observed_A=results_high["event"],
            event_observed_B=results_low["event"],
        )
        p_value = logrank_results.p_value
        fig.text(0.15, 0.2, f"p-value: {p_value:.4f}", fontsize=12, ha="left")

        if output_path:
            Path(output_path).mkdir(exist_ok=True, parents=True)
            fig.savefig(str(Path(output_path, plot_filename)), dpi=dpi)
        else:
            fig.savefig(str(plot_filename), dpi=dpi)

        if display_plot:
            # Show before any close; a closed figure is never rendered.
            plt.show()
            keep_open = True
    finally:
        if not keep_open:
            plt.close(fig)
        # Restore the caller's interactive mode instead of forcing it on.
        if was_interactive:
            plt.ion()
        else:
            plt.ioff()

    return {
        "p-value": p_value,
        "plot_filename": plot_filename,
        "metrics": logrank_results.test_statistic,
    }


def _km_split_threshold(
    df: pd.DataFrame,
    column: str,
    split_strategy: str,
    fixed_value: Optional[float],
    percentage: Optional[float],
) -> Optional[float]:
    """
    Resolves the high/low cutoff for ``generate_kaplan_meier_plot``.

    Returns:
        Optional[float]: The threshold for ``column``, or None for the 'quantile'
        strategy, whose grouping is delegated to ``make_quantile_split``.

    Raises:
        ValueError: If ``split_strategy`` is unknown, or if 'fixed'/'percentage' is
            requested without ``fixed_value``/``percentage``.
    """
    if split_strategy == "mean":
        return df[column].mean()
    if split_strategy == "median":
        return df[column].median()
    if split_strategy == "percentage" and percentage is not None:
        return df[column].quantile(percentage)
    if split_strategy == "fixed" and fixed_value is not None:
        return fixed_value
    if split_strategy == "quantile":
        return None
    raise ValueError(
        "Invalid split_strategy. Use 'mean', 'median', 'percentage', 'fixed', or 'quantile'. "
        "For 'fixed', provide fixed_value. For 'percentage', provide percentage."
    )
[docs]
def calculate_vif(df: pd.DataFrame, columns: List[str]) -> pd.DataFrame:
    """
    Calculates the Variance Inflation Factor (VIF) for each variable.

    Args:
        df (pd.DataFrame): The input dataframe. It is not modified.
        columns (list): Column names to calculate VIF for (any sequence of labels).

    Returns:
        pd.DataFrame: One row per entry of ``columns``, in the given order, with a
        ``Variable`` column and the corresponding ``VIF`` value.

    Note:
        Rows with a NaN in any of ``columns`` are dropped before computing the VIFs.
        A column of ones is appended to the design matrix as an intercept; its own
        VIF is not reported. VIF is a measure of multicollinearity: higher values
        indicate stronger correlation with the other variables, and VIF > 10 is
        often considered problematic.
    """
    column_list = list(columns)
    complete_rows = df[column_list].dropna()

    # Build the design matrix once, outside the per-variable loop. The intercept is
    # appended as a bare array column, not a dataframe column named "constant", so a
    # user column with that name is not overwritten. The float cast avoids an
    # object-dtype array when bool and numeric columns are mixed.
    design = np.column_stack(
        [complete_rows.to_numpy(dtype=float), np.ones(len(complete_rows))]
    )

    vif_values = [variance_inflation_factor(design, i) for i in range(len(column_list))]
    return pd.DataFrame({"Variable": column_list, "VIF": vif_values})