"""Ephemeral GPU runner for AF3-class co-folding (PLAN §12, docs/gpu_plan.md). Serverless: `modal run gpu/modal_app.py` provisions a GPU, runs Phase 1, releases the GPU — zero idle cost. Model weights cache in a persistent Volume so we never re-pay GPU time to re-download. Phase 1 (positive-control recovery test, §12.4): co-fold each known binder + a couple of negative controls against each sickle target and rank by Boltz-2 predicted affinity. The known binder should win its own target — the test Vina couldn't pass on metal/covalent/allosteric modes. Affinity ranking avoids the receptor-alignment that pose-RMSD would need (a later refinement). Setup (one-time): `pip install modal && modal token new`. Run: `modal run gpu/modal_app.py`. Helpers below the GPU function run locally (no GPU) and are import-safe for testing. """ from __future__ import annotations import json import subprocess from pathlib import Path import modal app = modal.App("reverso-binding") image = ( modal.Image.debian_slim(python_version="3.12") .apt_install("git", "wget") .pip_install("boltz", "rdkit", "numpy") ) weights = modal.Volume.from_name("reverso-binding-weights", create_if_missing=True) WEIGHTS = "/weights" # target name -> (PDB id, crystal-ligand resname, drug name). Plus negatives co-folded into each. 
TARGETS = {
    "hemoglobin": ("5E83", "5L7", "voxelotor"),
    "PKR": ("8XFD", "WV2", "mitapivat"),
    "HDAC2": ("4LXZ", "SHH", "vorinostat"),
}
NEGATIVES = ["caffeine", "hydroxyurea"]  # negative controls co-folded against every target


# --------------------------------------------------------------------------- local helpers (CPU)

def fetch_pdb(pdb: str, struct_dir: Path = Path("data/raw/structures")) -> Path:
    """Return a local PDB path, downloading from RCSB if absent (keeps the GPU run self-contained).

    Raises:
        requests.HTTPError: if RCSB does not serve ``<pdb>.pdb``. Checking the status here keeps an
            error page from being cached on disk as if it were a structure file.
    """
    import requests

    p = struct_dir / f"{pdb}.pdb"
    if not p.exists():
        resp = requests.get(f"https://files.rcsb.org/download/{pdb}.pdb", timeout=60)
        resp.raise_for_status()
        p.parent.mkdir(parents=True, exist_ok=True)
        p.write_bytes(resp.content)
    return p


def binding_chain_sequence(pdb: str, lig_resname: str) -> str:
    """One-letter sequence of the protein chain nearest the crystal ligand (the binding chain).

    "Nearest" means the smallest distance from the centroid of the first copy of ``lig_resname``
    to any atom of a chain's polymer part.

    - Only polymer residues are measured. The ligand's own HETATM residue and waters can carry
      the protein's chain ID, and counting them would make the chain that merely *owns* the
      ligand look closest (at ~0 Å).
    - A single ligand copy is used so that, when the file holds several copies, the centroid
      stays inside one real site instead of averaging between sites.
    - Polymer parts shorter than 20 residues are skipped.

    Raises:
        ValueError: if the ligand is absent or no polymer chain of >= 20 residues is found.
    """
    import gemmi

    st = gemmi.read_structure(str(fetch_pdb(pdb)))
    model = st[0]
    lig = next(
        (res for ch in model for res in ch if res.name == lig_resname and len(res) > 0), None
    )
    if lig is None:
        raise ValueError(f"ligand {lig_resname} not found in {pdb}")
    n = len(lig)
    lig_center = gemmi.Position(
        sum(a.pos.x for a in lig) / n,
        sum(a.pos.y for a in lig) / n,
        sum(a.pos.z for a in lig) / n,
    )

    best, best_d = None, 1e9
    for ch in model:
        poly = ch.get_polymer()
        if len(poly) < 20:
            continue
        # Measure against exactly the residues whose sequence would be returned.
        d = min((a.pos.dist(lig_center) for res in poly for a in res), default=1e9)
        if d < best_d:
            best, best_d = poly, d
    if best is None:
        raise ValueError(f"no protein chain in {pdb}")
    return gemmi.one_letter_code(best.extract_sequence()).upper().replace("X", "")


def pubchem_smiles(name: str) -> str:
    """Return a SMILES string for compound ``name`` via PubChem PUG REST.

    Tries the property names ``SMILES`` and ``ConnectivitySMILES`` first, then ``IsomericSMILES``
    and ``CanonicalSMILES``. The first one present in the response wins.

    Raises:
        ValueError: if no property yields a SMILES. The last lookup error, if any, is chained as
            the cause.
    """
    from urllib.parse import quote

    import requests

    last_err: Exception | None = None
    for prop in ("SMILES", "ConnectivitySMILES", "IsomericSMILES", "CanonicalSMILES"):
        url = (f"https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/name/{quote(name, safe='')}"
               f"/property/{prop}/JSON")
        try:
            props = requests.get(url, timeout=30).json()["PropertyTable"]["Properties"][0]
        except (requests.RequestException, ValueError, KeyError, IndexError, TypeError) as err:
            # Network error, non-JSON body, or a PubChem "Fault" payload: try the next property.
            last_err = err
            continue
        if prop in props:
            return props[prop]
    raise ValueError(f"no SMILES for {name}") from last_err


def build_boltz_yaml(protein_seq: str, ligand_smiles: str) -> str:
    """Boltz-2 YAML for a protein+ligand complex with an affinity prediction on the ligand.

    The protein is chain ``A`` and the ligand is chain ``L``. ``properties.affinity.binder``
    points at ``L``.
    """
    # The SMILES goes into a YAML single-quoted scalar. A literal ' there must be doubled;
    # backslashes (SMILES stereo bonds) are already literal inside single quotes.
    smiles = ligand_smiles.replace("'", "''")
    return (
        "version: 1\n"
        "sequences:\n"
        "  - protein:\n"
        "      id: A\n"
        f"      sequence: {protein_seq}\n"
        "  - ligand:\n"
        "      id: L\n"
        f"      smiles: '{smiles}'\n"
        "properties:\n"
        "  - affinity:\n"
        "      binder: L\n"
    )


# ------------------------------------------------------------------------------- GPU function

@app.function(gpu="L4", image=image, volumes={WEIGHTS: weights}, timeout=3600)
def cofold(label: str, protein_seq: str, ligand_smiles: str) -> dict:
    """Co-fold one complex on the GPU; return predicted affinity + binder probability.

    Weights persist on the mounted Volume (HF_HOME and the boltz ``--cache`` dir live under
    /weights), so run 2+ skips the download. The GPU is released the moment this returns.

    Returns:
        dict with keys ``label``, ``affinity`` (Boltz ``affinity_pred_value`` or None),
        ``prob_binder`` (``affinity_probability_binary`` or None) and ``structure`` (text of the
        predicted PDB file, or None if none was written).

    Raises:
        subprocess.CalledProcessError: if ``boltz predict`` exits non-zero. Volume changes are
            still committed first.
    """
    import os

    os.environ["HF_HOME"] = f"{WEIGHTS}/hf"
    boltz_cache = f"{WEIGHTS}/boltz"
    os.makedirs(boltz_cache, exist_ok=True)
    weights.reload()  # see weights committed by other containers since this one started

    work = Path("/tmp") / label
    work.mkdir(parents=True, exist_ok=True)
    yaml_path = work / "in.yaml"
    yaml_path.write_text(build_boltz_yaml(protein_seq, ligand_smiles))
    out = work / "out"
    try:
        subprocess.run(
            ["boltz", "predict", str(yaml_path), "--use_msa_server", "--cache", boltz_cache,
             "--out_dir", str(out), "--output_format", "pdb"],
            check=True,
        )
    finally:
        weights.commit()  # persist anything newly downloaded, even if the prediction failed

    # Affinity is written to a JSON under out/predictions/<input name>/; parse defensively
    # (keys vary).
    aff = {"affinity_pred_value": None, "affinity_probability_binary": None}
    aff_json = next(out.rglob("affinity*.json"), None)
    if aff_json is not None:
        data = json.loads(aff_json.read_text())
        for key in aff:
            if key in data:
                aff[key] = data[key]

    pdb_path = next(out.rglob("*_model_0.pdb"), None) or next(out.rglob("*.pdb"), None)
    return {"label": label,
            "affinity": aff["affinity_pred_value"],
            "prob_binder": aff["affinity_probability_binary"],
            "structure": pdb_path.read_text() if pdb_path else None}


# ------------------------------------------------------------------------------- driver (local)

@app.local_entrypoint()
def main() -> None:
    """Fan out one GPU call per (target, ligand) pair; tabulate affinities; positive-control test.

    For each target, the known drug PASSes if it has the highest P(binder) among itself and the
    NEGATIVES. The verdict is INCONCLUSIVE (rather than a spurious FAIL) when the drug's own call
    produced no P(binder).

    A failed GPU call is reported and its row is tabulated as NA; it does not abort the batch.
    All rows are written to data/processed/binding/phase1_affinity.csv.
    """
    import csv

    # Build jobs. SMILES are fetched once per ligand; the negatives recur for every target.
    smiles_for: dict[str, str] = {}
    jobs = []  # (target, ligand_name, protein_seq, smiles)
    for target, (pdb, res, drug) in TARGETS.items():
        seq = binding_chain_sequence(pdb, res)
        for lig in [drug, *NEGATIVES]:
            if lig not in smiles_for:
                smiles_for[lig] = pubchem_smiles(lig)
            jobs.append((target, lig, seq, smiles_for[lig]))
    print(f"co-folding {len(jobs)} complexes ({len(TARGETS)} targets x {1 + len(NEGATIVES)} ligands)")

    # Run. Outputs come back in input order, so they can be zipped with the labels.
    labels = [f"{t}_{l}" for t, l, _, _ in jobs]
    results = cofold.starmap(
        [(label, seq, smi) for label, (_, _, seq, smi) in zip(labels, jobs)],
        return_exceptions=True,
    )
    by = {}
    for label, r in zip(labels, results):
        if isinstance(r, BaseException):
            print(f"!! {label} failed: {r!r}")
            continue
        by[label] = r

    # Tabulate and run the positive-control test.
    print(f"\n{'target':12s}{'ligand':14s}{'affinity':>10s}{'P(binder)':>11s}")
    out_rows = []
    for target, (_pdb, _res, drug) in TARGETS.items():
        # Rank by P(binder): unambiguous (higher = more likely a binder). Boltz
        # affinity_pred_value is ~log(IC50) (lower = stronger) -- sign is version-dependent,
        # so don't rank on it.
        prob = {}
        for lig in [drug, *NEGATIVES]:
            r = by.get(f"{target}_{lig}", {})
            a, p = r.get("affinity"), r.get("prob_binder")
            if p is not None:
                prob[lig] = p
            a_txt = f"{a:.2f}" if a is not None else "NA"
            p_txt = f"{p:.2f}" if p is not None else "NA"
            print(f"{target:12s}{lig:14s}{a_txt:>10s}{p_txt:>11s}")
            out_rows.append({"target": target, "ligand": lig, "affinity": a,
                             "prob_binder": p, "is_known_binder": lig == drug})
        if prob:
            best = max(prob, key=prob.get)  # highest P(binder)
            if drug not in prob:
                verdict = "INCONCLUSIVE (no P(binder) for known binder)"
            else:
                verdict = "PASS" if best == drug else "FAIL"
            print(f"  -> {target}: best P(binder) = {best} (known = {drug})  {verdict}\n")

    # Write the CSV.
    outdir = Path("data/processed/binding")
    outdir.mkdir(parents=True, exist_ok=True)
    csv_path = outdir / "phase1_affinity.csv"
    with csv_path.open("w", newline="") as f:
        w = csv.DictWriter(
            f, fieldnames=["target", "ligand", "affinity", "prob_binder", "is_known_binder"]
        )
        w.writeheader()
        w.writerows(out_rows)
    print(f"wrote {csv_path}")