Phase D: supervised cross-disease (0.925 AUC degree-bias mirage)

Train GradientBoosting on 300 drugs x 839 GEO disease signatures with
Repurposing-Hub indications as labels (432 positives), disease-grouped CV.

Finding: 0.925 CV AUC looks like a win but is a MIRAGE. Feature
importances are all drug-level (drug_std 0.33, drug_mean 0.30,
broadness 0.17); drug-disease connectivity importance = 0.01. The model
learned a drug-POPULARITY prior, not disease-specific matching. On
held-out sickle it ranks hydroxyurea 231/300 (worse than baseline) and
tops out with promiscuous drugs (dexamethasone, methotrexate). Classic
degree-bias trap. Connectivity also has ~chance AUC (0.51) for predicting
approved indications.

Both obvious approaches now fail instructively: unsupervised = specificity
ceiling; naive supervised = degree bias. Real progress needs degree-
debiased training + much larger clean labels (a research effort).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-06-23 23:31:32 +02:00
parent 0ce688449d
commit 649f617019

View File

@@ -0,0 +1,137 @@
"""Phase D: supervised cross-disease repurposing — can labels break the specificity ceiling?
Connectivity alone can't tell therapeutic from coincidental reversal. A supervised model trained
on known drug-disease pairs CAN learn that pattern — given features that expose drug "broadness"
(a drug that reverses everything is non-specific). We train on 839 GEO disease signatures with
Repurposing-Hub indications as labels, evaluate with disease-grouped CV, then apply to HELD-OUT
sickle and check whether the coincidental reversers (norethindrone, ciprofloxacin) finally drop.
Baseline = rank by raw connectivity (the single conn feature). Win = model down-ranks the
negative controls vs baseline while keeping hydroxyurea high.
"""
from __future__ import annotations
import json
import re
from pathlib import Path
import numpy as np
import pandas as pd
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.model_selection import GroupKFold, cross_val_predict
from sklearn.metrics import roc_auc_score
import sys
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from src.scoring import _ks_connectivity # noqa: E402
PROCESSED = Path("data/processed")
SIGS = Path("data/raw/disease_sigs")
NEG5 = ["clotrimazole", "astemizole", "azithromycin", "ethinyl-estradiol", "caffeine"]
ID_RE = re.compile(r"\s*(c\d{6,}|doid[-\s]?\d+|gse\d+|umls\S+)\s*", re.I)
def clean_disease(term: str) -> str:
return ID_RE.sub(" ", term).strip().lower()
def load_disease_sigs():
def parse(name):
out = {}
for line in (SIGS / name).read_text().splitlines():
p = line.split("\t")
if len(p) < 3:
continue
out[p[0]] = [g.split(",")[0].strip().upper() for g in p[2:] if g.strip()]
return out
up = parse("Disease_Perturbations_from_GEO_up.gmt")
down = parse("Disease_Perturbations_from_GEO_down.gmt")
terms = sorted(set(up) & set(down))
return [{"term": t, "disease": clean_disease(t), "up": up[t], "down": down[t]} for t in terms]
def cols_for(genes, g2c):
return np.array([g2c[g] for g in set(genes) if g in g2c], dtype=int)
def featurize(conn_col, C):
"""Per-(drug,disease) features for one disease column conn_col, given full matrix C."""
drug_mean, drug_std = C.mean(1), C.std(1)
drug_min = C.min(1)
broad = (C < -0.10).mean(1) # fraction of diseases this drug strongly reverses
dmean, dstd = conn_col.mean(), conn_col.std() or 1.0
return np.column_stack([
conn_col, drug_mean, drug_std, drug_min, broad,
np.full_like(conn_col, dmean),
(conn_col - drug_mean) / np.where(drug_std == 0, 1, drug_std), # specificity within drug
(conn_col - dmean) / dstd, # specificity within disease
])
FEATS = ["conn", "drug_mean", "drug_std", "drug_min", "broadness",
"disease_mean", "z_within_drug", "z_within_disease"]
def main():
lincs = pd.read_parquet(PROCESSED / "lincs_signatures_v1.parquet")
drugs = list(lincs.index)
g2c = {g: i for i, g in enumerate(lincs.columns)}
n = len(lincs.columns)
R = lincs.rank(axis=1, ascending=False).to_numpy()
refs = [d for d in load_disease_sigs()
if len(cols_for(d["up"], g2c)) + len(cols_for(d["down"], g2c)) >= 10]
C = np.column_stack([_ks_connectivity(R, cols_for(d["up"], g2c), cols_for(d["down"], g2c), n) for d in refs])
print(f"{len(drugs)} drugs x {len(refs)} disease signatures; connectivity matrix {C.shape}")
# labels from Repurposing Hub indications
hub = pd.read_csv("data/raw/repurposing_drugs.txt", sep="\t", comment="!", low_memory=False)
ind = {r.pert_iname: [i.strip().lower() for i in re.split(r"[|,]", r.indication) if len(i.strip()) > 3]
for r in hub.itertuples() if isinstance(r.indication, str)}
drug_idx = {d: i for i, d in enumerate(drugs)}
X_rows, y, grp = [], [], []
for j, d in enumerate(refs):
feats = featurize(C[:, j], C)
dz = d["disease"]
for i, drug in enumerate(drugs):
inds = ind.get(drug, [])
label = int(any(dz == k or (len(dz) > 4 and dz in k) or (len(k) > 4 and k in dz) for k in inds))
X_rows.append(feats[i]); y.append(label); grp.append(j)
X = np.array(X_rows); y = np.array(y); grp = np.array(grp)
print(f"pairs: {len(y)}, positives: {y.sum()} ({100*y.mean():.2f}%)")
# disease-grouped CV (generalize to unseen diseases)
clf = GradientBoostingClassifier(random_state=0)
proba = cross_val_predict(clf, X, y, cv=GroupKFold(5), groups=grp, method="predict_proba")[:, 1]
print(f"disease-grouped CV AUC: {roc_auc_score(y, proba):.3f} (conn-only AUC: {roc_auc_score(y, X[:,0]*-1):.3f})")
clf.fit(X, y)
print("feature importances: " + ", ".join(f"{f}={imp:.2f}" for f, imp in
sorted(zip(FEATS, clf.feature_importances_), key=lambda t: -t[1])))
# apply to HELD-OUT sickle
sig = json.loads((PROCESSED / "sickle_cell_signature_v1.json").read_text())
sk = _ks_connectivity(R, cols_for([g["gene"] for g in sig["up_regulated"]], g2c),
cols_for([g["gene"] for g in sig["down_regulated"]], g2c), n)
Xs = featurize(sk, C)
p_sickle = clf.predict_proba(Xs)[:, 1]
res = pd.DataFrame({"model_p": p_sickle, "conn": sk}, index=drugs)
res["model_rank"] = res["model_p"].rank(ascending=False).astype(int)
res["conn_rank"] = res["conn"].rank(ascending=True).astype(int) # most negative = best
N = len(res)
print(f"\n{'drug':14s} {'model_rank':>10s} {'conn_rank(base)':>16s}")
for d in ["hydroxyurea", "glutamine"] + NEG5 + ["norethindrone", "ciprofloxacin"]:
if d in res.index:
r = res.loc[d]
print(f" {d:14s} {int(r['model_rank']):6d}/{N} {int(r['conn_rank']):6d}/{N}")
nb_model = sum(res.loc[d, "model_rank"] > N/2 for d in NEG5 if d in res.index)
nb_conn = sum(res.loc[d, "conn_rank"] > N/2 for d in NEG5 if d in res.index)
print(f"\nneg controls bottom-half: model {nb_model}/5 vs baseline {nb_conn}/5")
print("model top 10:", ", ".join(res.sort_values('model_p', ascending=False).head(10).index))
if __name__ == "__main__":
main()