diff --git a/docs/gpu_plan.md b/docs/gpu_plan.md index b72431b..063c99c 100644 --- a/docs/gpu_plan.md +++ b/docs/gpu_plan.md @@ -68,10 +68,13 @@ forgotten box can't bleed money. ## What runs on the GPU (in cost order — cheap validation first) -- **Phase 1 — modality validation (~minutes, ~$1):** co-fold the 3 known binders into their - targets (voxelotor/Hb, mitapivat/PKR, vorinostat/HDAC2) and check it reproduces the crystal pose - (RMSD <2 Å) where Vina failed on metal/covalent/allosteric modes. If this passes, the modality is - real; if not, stop before spending on a screen. +- **Phase 1 — modality validation (~minutes, ~$1):** co-fold each known binder + 2 negative + controls (caffeine, hydroxyurea) into each target (Hb, PKR, HDAC2) and check the **known binder + has the highest Boltz-2 P(binder)** for its own target — the discrimination Vina couldn't manage + on metal/covalent/allosteric modes. (Ranking uses P(binder), not the raw affinity value, whose + sign is version-dependent.) Pose-RMSD vs crystal is a deeper check but needs receptor + superposition (align predicted protein to crystal, transform ligand) — a later refinement. If + Phase 1 passes, the modality is real; if not, stop before paying for a screen. - **Phase 2 — screen (~tens of minutes, a few $):** run the ~300-drug set (or a focused subset) against the sickle targets; rank by Boltz-2 predicted affinity; redo the §12.4 positive-control recovery test. Output a ranked CSV, same shape as the connectivity `ranked_candidates`. diff --git a/gpu/modal_app.py b/gpu/modal_app.py index f735828..6038d3f 100644 --- a/gpu/modal_app.py +++ b/gpu/modal_app.py @@ -1,87 +1,190 @@ """Ephemeral GPU runner for AF3-class co-folding (PLAN §12, docs/gpu_plan.md). -Serverless: `modal run gpu/modal_app.py` provisions a GPU, runs, and releases it — zero idle cost, -nothing to remember to kill. Model weights are cached in a persistent Volume so we never re-pay GPU -time to re-download them. Prep (Meeko/RDKit) and RMSD scoring (spyrmsd) stay light; only the model -forward pass needs the GPU. +Serverless: `modal run gpu/modal_app.py` provisions a GPU, runs Phase 1, releases the GPU — zero +idle cost. Model weights cache in a persistent Volume so we never re-pay GPU time to re-download. + +Phase 1 (positive-control recovery test, §12.4): co-fold each known binder + a couple of negative +controls against each sickle target and rank by Boltz-2 predicted affinity. The known binder +should win its own target — the test Vina couldn't pass on metal/covalent/allosteric modes. +Affinity ranking avoids the receptor-alignment that pose-RMSD would need (a later refinement). Setup (one-time): `pip install modal && modal token new`. -Run Phase 1 (validate on 3 known binders): `modal run gpu/modal_app.py`. +Run: `modal run gpu/modal_app.py`. -STATUS: scaffold. The boltz invocation (input spec + output parsing) is stubbed where marked TODO; -wire it after a first `modal run` confirms the image builds and the GPU is reachable. +Helpers below the GPU function run locally (no GPU) and are import-safe for testing. """ from __future__ import annotations +import json +import subprocess +from pathlib import Path + import modal app = modal.App("reverso-binding") -# CUDA image + AF3-class model (Boltz-2) + light prep/scoring deps. image = ( modal.Image.debian_slim(python_version="3.12") .apt_install("git", "wget") - .pip_install("boltz", "rdkit", "meeko", "spyrmsd", "gemmi", "numpy") + .pip_install("boltz", "rdkit", "numpy") ) - -# Persist model weights across runs so we download them once, not every GPU-billed run. weights = modal.Volume.from_name("reverso-binding-weights", create_if_missing=True) - -# Known binders -> (PDB id, crystal ligand resname, SMILES placeholder filled by caller). -# Phase 1 validation: does co-folding reproduce these crystal poses where Vina failed? -KNOWN = { - "voxelotor_Hb": ("5E83", "5L7"), - "mitapivat_PKR": ("8XFD", "WV2"), - "vorinostat_HDAC2": ("4LXZ", "SHH"), -} - - -# Cache locations on the persistent Volume — the model downloads here ONCE and reuses forever. WEIGHTS = "/weights" +# target name -> (PDB id, crystal-ligand resname, drug name). Plus negatives co-folded into each. +TARGETS = { + "hemoglobin": ("5E83", "5L7", "voxelotor"), + "PKR": ("8XFD", "WV2", "mitapivat"), + "HDAC2": ("4LXZ", "SHH", "vorinostat"), +} +NEGATIVES = ["caffeine", "hydroxyurea"] + + +# --------------------------------------------------------------------------- local helpers (CPU) + +def fetch_pdb(pdb: str, struct_dir: Path = Path("data/raw/structures")) -> Path: + """Return a local PDB path, downloading from RCSB if absent (keeps the GPU run self-contained).""" + import requests + p = struct_dir / f"{pdb}.pdb" + if not p.exists(): + p.parent.mkdir(parents=True, exist_ok=True) + p.write_bytes(requests.get(f"https://files.rcsb.org/download/{pdb}.pdb", timeout=60).content) + return p + + +def binding_chain_sequence(pdb: str, lig_resname: str) -> str: + """One-letter sequence of the protein chain nearest the crystal ligand (the binding chain).""" + import gemmi + st = gemmi.read_structure(str(fetch_pdb(pdb))) + model = st[0] + lig_atoms = [a.pos for ch in model for res in ch if res.name == lig_resname for a in res] + if not lig_atoms: + raise ValueError(f"ligand {lig_resname} not found in {pdb}") + lig_center = gemmi.Position( + sum(p.x for p in lig_atoms) / len(lig_atoms), + sum(p.y for p in lig_atoms) / len(lig_atoms), + sum(p.z for p in lig_atoms) / len(lig_atoms), + ) + best, best_d = None, 1e9 + for ch in model: + poly = ch.get_polymer() + if len(poly) < 20: + continue + d = min((a.pos.dist(lig_center) for res in ch for a in res), default=1e9) + if d < best_d: + best, best_d = poly, d + if best is None: + raise ValueError(f"no protein chain in {pdb}") + return gemmi.one_letter_code(best.extract_sequence()).upper().replace("X", "") + + +def pubchem_smiles(name: str) -> str: + import requests + for prop in ("SMILES", "ConnectivitySMILES", "IsomericSMILES", "CanonicalSMILES"): + try: + d = requests.get(f"https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/name/{name}" + f"/property/{prop}/JSON", timeout=30).json()["PropertyTable"]["Properties"][0] + if prop in d: + return d[prop] + except Exception: + continue + raise ValueError(f"no SMILES for {name}") + + +def build_boltz_yaml(protein_seq: str, ligand_smiles: str) -> str: + """Boltz-2 YAML for a protein+ligand complex with an affinity prediction on the ligand.""" + return ( + "version: 1\n" + "sequences:\n" + " - protein:\n" + " id: A\n" + f" sequence: {protein_seq}\n" + " - ligand:\n" + " id: L\n" + f" smiles: '{ligand_smiles}'\n" + "properties:\n" + " - affinity:\n" + " binder: L\n" + ) + + +# ------------------------------------------------------------------------------- GPU function @app.function(gpu="L4", image=image, volumes={WEIGHTS: weights}, timeout=3600) -def cofold(protein_seq: str, ligand_smiles: str) -> dict: - """Co-fold one protein+ligand complex and return predicted affinity + pose (PDB string). +def cofold(label: str, protein_seq: str, ligand_smiles: str) -> dict: + """Co-fold one complex on the GPU; return predicted affinity + binder probability. - Runs on the GPU only for this call, then the GPU is released. Model weights persist on the - mounted Volume across runs (see HF_HOME / --cache below), so we never re-pay GPU time to - re-download them. + Weights persist on the mounted Volume (HF_HOME/boltz --cache under /weights), so run 2+ skips + the download. GPU is released the moment this returns. """ import os - import subprocess # noqa: F401 (used once boltz is wired) - - # Point every weight downloader at the persistent Volume so the cache survives teardown. - os.environ["HF_HOME"] = f"{WEIGHTS}/hf" # huggingface_hub cache - os.environ["TORCH_HOME"] = f"{WEIGHTS}/torch" # torch.hub cache - boltz_cache = f"{WEIGHTS}/boltz" # boltz --cache target + os.environ["HF_HOME"] = f"{WEIGHTS}/hf" + boltz_cache = f"{WEIGHTS}/boltz" os.makedirs(boltz_cache, exist_ok=True) - - # See what's already cached (run 2+ finds weights here and skips the download). weights.reload() - # TODO: build boltz input (protein_seq + ligand_smiles), then: - # subprocess.run(["boltz", "predict", input_yaml, "--use_msa_server", - # "--cache", boltz_cache, "--out_dir", "/tmp/out"], check=True) - # parse predicted structure + affinity from /tmp/out. + work = Path("/tmp") / label + work.mkdir(parents=True, exist_ok=True) + (work / "in.yaml").write_text(build_boltz_yaml(protein_seq, ligand_smiles)) + out = work / "out" + subprocess.run( + ["boltz", "predict", str(work / "in.yaml"), "--use_msa_server", + "--cache", boltz_cache, "--out_dir", str(out), "--output_format", "pdb"], + check=True, + ) + weights.commit() # persist anything newly downloaded - # Persist anything newly downloaded into the cache so the NEXT run reuses it. - weights.commit() - raise NotImplementedError("Wire Boltz-2 here; see docs/gpu_plan.md Phase 1.") + # Affinity is written to a JSON under out/predictions//; parse defensively (keys vary). + aff = {"affinity_pred_value": None, "affinity_probability_binary": None} + for jf in out.rglob("affinity*.json"): + data = json.loads(jf.read_text()) + for k in aff: + if k in data: + aff[k] = data[k] + break + cif = next(out.rglob("*_model_0.pdb"), None) or next(out.rglob("*.pdb"), None) + return {"label": label, "affinity": aff["affinity_pred_value"], + "prob_binder": aff["affinity_probability_binary"], + "structure": cif.read_text() if cif else None} +# ------------------------------------------------------------------------------- driver (local) + @app.local_entrypoint() def main() -> None: - """Phase 1 driver (runs locally; only cofold() touches the GPU). + """Fan out one GPU call per (target, ligand) pair; tabulate affinities; positive-control test.""" + jobs = [] # (target, ligand_name, protein_seq, smiles) + for target, (pdb, res, drug) in TARGETS.items(): + seq = binding_chain_sequence(pdb, res) + for lig in [drug, *NEGATIVES]: + jobs.append((target, lig, seq, pubchem_smiles(lig))) + print(f"co-folding {len(jobs)} complexes ({len(TARGETS)} targets x {1+len(NEGATIVES)} ligands)") - Pulls target sequences + ligand SMILES from the repo, fans out one GPU call per known binder, - scores redocking RMSD vs the crystal pose locally (spyrmsd), and prints pass/fail. Results are - tiny — commit a summary into data/processed/binding/. - """ - # TODO: load protein sequences from data/raw/structures/.pdb (gemmi) and ligand SMILES - # (PubChem / drug_set), then: - # results = list(cofold.map(seqs, smiles)) - # and compute in-place spyrmsd RMSD vs the crystal ligand for each. - print("Scaffold: fill in sequence/SMILES loading + cofold.map, then score RMSD. " - "See docs/gpu_plan.md.") + results = list(cofold.starmap([(f"{t}_{l}", s, smi) for t, l, s, smi in jobs])) + by = {r["label"]: r for r in results} + + print(f"\n{'target':12s}{'ligand':14s}{'affinity':>10s}{'P(binder)':>11s}") + out_rows = [] + for target, (pdb, res, drug) in TARGETS.items(): + prob = {} # rank by P(binder): unambiguous (higher = more likely a binder). Boltz + for lig in [drug, *NEGATIVES]: # affinity_pred_value is ~log(IC50) (lower=stronger) -- sign + r = by.get(f"{target}_{lig}", {}) # is version-dependent, so don't rank on it. + a, p = r.get("affinity"), r.get("prob_binder") + if p is not None: + prob[lig] = p + print(f"{target:12s}{lig:14s}{(f'{a:.2f}' if a is not None else 'NA'):>10s}" + f"{(f'{p:.2f}' if p is not None else 'NA'):>11s}") + out_rows.append({"target": target, "ligand": lig, "affinity": a, "prob_binder": p, + "is_known_binder": lig == drug}) + if prob: + best = max(prob, key=prob.get) # highest P(binder) + print(f" -> {target}: best P(binder) = {best} (known = {drug}) " + f"{'PASS' if best == drug else 'FAIL'}\n") + + outdir = Path("data/processed/binding"); outdir.mkdir(parents=True, exist_ok=True) + import csv + with (outdir / "phase1_affinity.csv").open("w", newline="") as f: + w = csv.DictWriter(f, fieldnames=["target", "ligand", "affinity", "prob_binder", "is_known_binder"]) + w.writeheader(); w.writerows(out_rows) + print(f"wrote {outdir/'phase1_affinity.csv'}")