Train a GradientBoosting classifier on 300 drugs x 839 GEO disease signatures, with Repurposing-Hub indications as labels (432 positives) and disease-grouped CV. Finding: the 0.925 CV AUC looks like a win but is a MIRAGE. Feature importances are all drug-level (drug_std 0.33, drug_mean 0.30, broadness 0.17); drug-disease connectivity importance is only 0.01. The model learned a drug-POPULARITY prior, not disease-specific matching. On held-out sickle it ranks hydroxyurea 231/300 (worse than baseline), and its top ranks are dominated by promiscuous drugs (dexamethasone, methotrexate) — the classic degree-bias trap. Connectivity alone also has near-chance AUC (0.51) for predicting approved indications. Both obvious approaches now fail instructively: unsupervised hits the specificity ceiling; naive supervised hits degree bias. Real progress needs degree-debiased training plus a much larger, clean label set (a research effort). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
138 lines
6.1 KiB
Python
138 lines
6.1 KiB
Python
"""Phase D: supervised cross-disease repurposing — can labels break the specificity ceiling?
|
|
|
|
Connectivity alone can't tell therapeutic from coincidental reversal. A supervised model trained
|
|
on known drug-disease pairs CAN learn that pattern — given features that expose drug "broadness"
|
|
(a drug that reverses everything is non-specific). We train on 839 GEO disease signatures with
|
|
Repurposing-Hub indications as labels, evaluate with disease-grouped CV, then apply to HELD-OUT
|
|
sickle and check whether the coincidental reversers (norethindrone, ciprofloxacin) finally drop.
|
|
|
|
Baseline = rank by raw connectivity (the single conn feature). Win = model down-ranks the
|
|
negative controls vs baseline while keeping hydroxyurea high.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import re
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
import pandas as pd
|
|
from sklearn.ensemble import GradientBoostingClassifier
|
|
from sklearn.model_selection import GroupKFold, cross_val_predict
|
|
from sklearn.metrics import roc_auc_score
|
|
|
|
import sys
|
|
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
|
from src.scoring import _ks_connectivity # noqa: E402
|
|
|
|
# Input locations, resolved relative to the directory the script is run from.
PROCESSED = Path("data") / "processed"
SIGS = Path("data") / "raw" / "disease_sigs"

# Negative-control drugs: main() reports how many of them land in the bottom
# half of the held-out sickle ranking, for the model and for the baseline.
NEG5 = [
    "clotrimazole",
    "astemizole",
    "azithromycin",
    "ethinyl-estradiol",
    "caffeine",
]

# Identifier tokens stripped from GEO term names, together with surrounding
# whitespace: "C" + 6 or more digits (presumably UMLS CUIs), DOID ids, GSE
# accessions and "umls..." tokens. Matching is case-insensitive.
ID_RE = re.compile(r"\s*(c\d{6,}|doid[-\s]?\d+|gse\d+|umls\S+)\s*", flags=re.IGNORECASE)
|
|
|
|
|
|
def clean_disease(term: str) -> str:
    """Normalize a GEO signature term into a lower-case disease name.

    Every identifier token matched by ``ID_RE`` is replaced with a space. All
    runs of whitespace are then collapsed to single spaces, and the result is
    stripped and lower-cased.

    Collapsing matters for labeling. Two adjacent identifiers in the middle of
    a term (e.g. ``"a C0000001 GSE1 b"``) would otherwise leave a double space
    (``"a  b"``), which can never equal or contain a single-spaced indication
    string.
    """
    return " ".join(ID_RE.sub(" ", term).split()).lower()
|
|
|
|
|
|
def load_disease_sigs(sig_dir=None):
    """Load paired up/down GEO disease signatures from the two GMT files.

    Each GMT line is tab-separated: ``term``, a second field that is ignored,
    then one gene token per field. Lines with fewer than three fields are
    skipped. A gene token is cut at its first comma, so any comma-suffixed
    value (presumably a weight) is dropped. The token is then stripped and
    upper-cased, and empty tokens are skipped. If a term repeats within a
    file, the last line wins.

    Only terms present in *both* the up and the down file are returned, sorted
    by term.

    Args:
        sig_dir: directory holding ``Disease_Perturbations_from_GEO_up.gmt`` and
            ``Disease_Perturbations_from_GEO_down.gmt``. Defaults to the
            module-level ``SIGS``.

    Returns:
        A list of dicts with keys ``term`` (raw GMT term), ``disease`` (term
        passed through ``clean_disease``), ``up`` and ``down`` (gene lists).
    """
    base = SIGS if sig_dir is None else Path(sig_dir)

    def parse(name):
        out = {}
        # Explicit UTF-8: Path.read_text() otherwise uses the locale encoding,
        # which differs between platforms.
        for line in (base / name).read_text(encoding="utf-8").splitlines():
            p = line.split("\t")
            if len(p) < 3:
                continue
            out[p[0]] = [g.split(",")[0].strip().upper() for g in p[2:] if g.strip()]
        return out

    up = parse("Disease_Perturbations_from_GEO_up.gmt")
    down = parse("Disease_Perturbations_from_GEO_down.gmt")
    terms = sorted(set(up) & set(down))
    return [{"term": t, "disease": clean_disease(t), "up": up[t], "down": down[t]} for t in terms]
|
|
|
|
|
|
def cols_for(genes, g2c):
    """Map gene symbols to sorted, de-duplicated column indices.

    Genes missing from ``g2c`` are ignored. The result is sorted so it does not
    depend on ``set`` iteration order. That order is hash-randomized per
    process for strings, so the previous output order was not reproducible
    across runs.

    Args:
        genes: iterable of gene symbols (may contain duplicates).
        g2c: mapping from gene symbol to column index.

    Returns:
        A 1-D ``int`` numpy array, possibly empty.
    """
    return np.array(sorted({g2c[g] for g in genes if g in g2c}), dtype=int)
|
|
|
|
|
|
def featurize(conn_col, C, *, reverse_thresh=-0.10):
    """Per-(drug,disease) features for one disease column conn_col, given full matrix C.

    Rows of ``C`` are drugs and columns are disease signatures. ``conn_col``
    holds one connectivity score per drug for the disease being featurized.
    Output columns follow ``FEATS``.

    The four drug-level columns (``drug_mean``, ``drug_std``, ``drug_min``,
    ``broadness``) depend only on ``C``, so they are identical for every
    disease. ``disease_mean`` is constant within one call. Only ``conn`` and the
    two z-scores carry drug-disease-specific signal.

    NOTE(review): if ``conn_col`` is itself a column of ``C`` (as during
    training), the drug-level stats include that disease. For a held-out
    signature scored against the same ``C`` they do not. This is a small
    train/apply asymmetry.

    Args:
        conn_col: 1-D connectivity scores, one per row of ``C``.
        C: 2-D drug x disease connectivity matrix.
        reverse_thresh: connectivity strictly below this counts as a "strong
            reversal" for the ``broadness`` feature. The default, -0.10, is
            the original hard-coded cutoff.

    Returns:
        A float array of shape ``(len(conn_col), len(FEATS))``.
    """
    drug_mean = C.mean(axis=1)
    drug_std = C.std(axis=1)
    drug_min = C.min(axis=1)
    # Fraction of diseases this drug strongly reverses: a drug that reverses
    # everything is non-specific.
    broad = (C < reverse_thresh).mean(axis=1)

    dmean = conn_col.mean()
    dstd = conn_col.std()
    if dstd == 0:  # constant disease column: avoid dividing by zero
        dstd = 1.0
    safe_drug_std = np.where(drug_std == 0, 1, drug_std)

    return np.column_stack([
        conn_col, drug_mean, drug_std, drug_min, broad,
        np.full_like(conn_col, dmean),
        (conn_col - drug_mean) / safe_drug_std,  # specificity within drug
        (conn_col - dmean) / dstd,               # specificity within disease
    ])


# Column names for featurize() output, in order.
FEATS = ["conn", "drug_mean", "drug_std", "drug_min", "broadness",
         "disease_mean", "z_within_drug", "z_within_disease"]
|
|
|
|
|
|
def _reference_connectivity(R, g2c, n):
    """Score every drug against each usable GEO disease signature.

    A signature is kept when its up and down genes together map to at least 10
    LINCS columns.

    Returns:
        ``(refs, C)``: the kept signature dicts (from ``load_disease_sigs``) and
        a matrix with one ``_ks_connectivity`` result per kept signature, as
        columns in ``refs`` order.

    Raises:
        ValueError: if no signature passes the gene-count filter.
    """
    refs = [d for d in load_disease_sigs()
            if len(cols_for(d["up"], g2c)) + len(cols_for(d["down"], g2c)) >= 10]
    if not refs:
        raise ValueError("no GEO disease signature maps >= 10 genes onto the LINCS columns")
    C = np.column_stack([
        _ks_connectivity(R, cols_for(d["up"], g2c), cols_for(d["down"], g2c), n)
        for d in refs
    ])
    return refs, C


def _load_indications():
    """Map Repurposing-Hub ``pert_iname`` to a list of lower-cased indication strings.

    The ``indication`` field is split on ``|`` and ``,``. Fragments of three
    characters or fewer after stripping are dropped. Rows whose indication is
    not a string (e.g. missing) are skipped. For a repeated ``pert_iname``, the
    last row wins.
    """
    hub = pd.read_csv("data/raw/repurposing_drugs.txt", sep="\t", comment="!", low_memory=False)
    return {
        r.pert_iname: [i.strip().lower() for i in re.split(r"[|,]", r.indication) if len(i.strip()) > 3]
        for r in hub.itertuples()
        if isinstance(r.indication, str)
    }


def _is_indicated(disease, indications):
    """Return True if the cleaned disease name fuzzily matches any indication.

    A match is an exact string match, or substring containment in either
    direction. Containment only counts when the contained string is longer than
    four characters, which guards against short-token false hits.
    """
    return any(
        disease == k
        or (len(disease) > 4 and disease in k)
        or (len(k) > 4 and k in disease)
        for k in indications
    )


def _build_pairs(refs, C, drugs, ind):
    """Build the (drug, disease) training set in disease-major row order.

    Row ``j * len(drugs) + i`` is drug ``i`` against reference ``j``.

    Returns:
        ``(X, y, grp)``: the ``featurize`` matrix, 0/1 indication labels, and
        the reference index of each row (used as the CV group).
    """
    X = np.vstack([featurize(C[:, j], C) for j in range(len(refs))])
    y = np.array([int(_is_indicated(d["disease"], ind.get(drug, [])))
                  for d in refs for drug in drugs])
    grp = np.repeat(np.arange(len(refs)), len(drugs))
    return X, y, grp


def _evaluate_and_fit(X, y, grp):
    """Report disease-grouped CV AUC against a conn-only baseline, then fit on all pairs."""
    clf = GradientBoostingClassifier(random_state=0)
    # Disease-grouped CV: each fold holds out whole diseases (generalize to unseen diseases).
    proba = cross_val_predict(clf, X, y, cv=GroupKFold(5), groups=grp, method="predict_proba")[:, 1]
    cv_auc = roc_auc_score(y, proba)
    # Negate conn: more negative connectivity = stronger reversal = more likely indicated.
    conn_auc = roc_auc_score(y, -X[:, 0])
    print(f"disease-grouped CV AUC: {cv_auc:.3f} (conn-only AUC: {conn_auc:.3f})")

    clf.fit(X, y)
    ranked = sorted(zip(FEATS, clf.feature_importances_), key=lambda t: -t[1])
    print("feature importances: " + ", ".join(f"{f}={imp:.2f}" for f, imp in ranked))
    return clf


def _report_sickle(clf, R, C, g2c, n, drugs):
    """Score the held-out sickle signature and compare model ranks with raw-connectivity ranks."""
    sig = json.loads((PROCESSED / "sickle_cell_signature_v1.json").read_text(encoding="utf-8"))
    sk = _ks_connectivity(R, cols_for([g["gene"] for g in sig["up_regulated"]], g2c),
                          cols_for([g["gene"] for g in sig["down_regulated"]], g2c), n)
    p_sickle = clf.predict_proba(featurize(sk, C))[:, 1]

    res = pd.DataFrame({"model_p": p_sickle, "conn": sk}, index=drugs)
    # method="min" gives tied drugs the same integer rank. The default
    # ("average") yields x.5 ranks that astype(int) silently truncated, and
    # tree-ensemble probabilities tie often.
    res["model_rank"] = res["model_p"].rank(ascending=False, method="min").astype(int)
    res["conn_rank"] = res["conn"].rank(ascending=True, method="min").astype(int)  # most negative = best
    N = len(res)

    print(f"\n{'drug':14s} {'model_rank':>10s} {'conn_rank(base)':>16s}")
    for d in ["hydroxyurea", "glutamine"] + NEG5 + ["norethindrone", "ciprofloxacin"]:
        if d in res.index:
            r = res.loc[d]
            print(f"  {d:14s} {int(r['model_rank']):6d}/{N}   {int(r['conn_rank']):6d}/{N}")

    # Only controls actually present in LINCS count toward the denominator.
    neg = [d for d in NEG5 if d in res.index]
    nb_model = int((res.loc[neg, "model_rank"] > N / 2).sum())
    nb_conn = int((res.loc[neg, "conn_rank"] > N / 2).sum())
    print(f"\nneg controls bottom-half: model {nb_model}/{len(neg)} vs baseline {nb_conn}/{len(neg)}")
    print("model top 10:", ", ".join(res.sort_values('model_p', ascending=False).head(10).index))


def main():
    """Train on GEO disease signatures with Repurposing-Hub labels, then score held-out sickle.

    Pipeline:

    1. Rank LINCS drug signatures and compute the drug x disease connectivity
       matrix.
    2. Label each (drug, disease) pair from Repurposing-Hub indications.
    3. Report disease-grouped CV AUC and feature importances.
    4. Compare model ranks with raw-connectivity ranks on the sickle signature.
    """
    lincs = pd.read_parquet(PROCESSED / "lincs_signatures_v1.parquet")
    drugs = list(lincs.index)
    g2c = {g: i for i, g in enumerate(lincs.columns)}
    n = len(lincs.columns)
    R = lincs.rank(axis=1, ascending=False).to_numpy()

    refs, C = _reference_connectivity(R, g2c, n)
    print(f"{len(drugs)} drugs x {len(refs)} disease signatures; connectivity matrix {C.shape}")

    X, y, grp = _build_pairs(refs, C, drugs, _load_indications())
    print(f"pairs: {len(y)}, positives: {y.sum()} ({100*y.mean():.2f}%)")

    clf = _evaluate_and_fit(X, y, grp)
    _report_sickle(clf, R, C, g2c, n, drugs)


if __name__ == "__main__":
    main()
|