GPU Phase 1 runnable: real Boltz-2 co-folding + alignment review
Flesh out the Modal app into a runnable Phase-1 positive-control test and reconcile it with the plan: - cofold() GPU fn: build Boltz-2 YAML (protein+ligand+affinity), run `boltz predict --use_msa_server --cache /weights/boltz`, parse affinity JSON + predicted pose; weights persist via Volume. - Local helpers (CPU, import-tested against our PDBs): binding_chain_sequence (gemmi -- correctly picks the binding chain, e.g. alpha-globin for 5E83), pubchem_smiles, build_boltz_yaml, fetch_pdb (RCSB). - main(): fan out cofold.starmap over 3 targets x (known binder + 2 negatives); tabulate; PASS if known binder has top P(binder) for its target. Alignment fixes: - Rank by P(binder) (higher=better), NOT raw affinity_pred_value whose sign (~log IC50) is version-dependent -- avoids a backwards positive-control test. - gpu_plan.md Phase 1 updated to affinity/P(binder) ranking; pose-RMSD noted as a later refinement (needs receptor superposition). Local half verified (sequence/SMILES/YAML); cofold() needs a live `modal run` (account + `modal token new`) to validate end-to-end. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -68,10 +68,13 @@ forgotten box can't bleed money.
|
|||||||
|
|
||||||
## What runs on the GPU (in cost order — cheap validation first)
|
## What runs on the GPU (in cost order — cheap validation first)
|
||||||
|
|
||||||
- **Phase 1 — modality validation (~minutes, ~$1):** co-fold the 3 known binders into their
|
- **Phase 1 — modality validation (~minutes, ~$1):** co-fold each known binder + 2 negative
|
||||||
targets (voxelotor/Hb, mitapivat/PKR, vorinostat/HDAC2) and check it reproduces the crystal pose
|
controls (caffeine, hydroxyurea) into each target (Hb, PKR, HDAC2) and check the **known binder
|
||||||
(RMSD <2 Å) where Vina failed on metal/covalent/allosteric modes. If this passes, the modality is
|
has the highest Boltz-2 P(binder)** for its own target — the discrimination Vina couldn't manage
|
||||||
real; if not, stop before spending on a screen.
|
on metal/covalent/allosteric modes. (Ranking uses P(binder), not the raw affinity value, whose
|
||||||
|
sign is version-dependent.) Pose-RMSD vs crystal is a deeper check but needs receptor
|
||||||
|
superposition (align predicted protein to crystal, transform ligand) — a later refinement. If
|
||||||
|
Phase 1 passes, the modality is real; if not, stop before paying for a screen.
|
||||||
- **Phase 2 — screen (~tens of minutes, a few $):** run the ~300-drug set (or a focused subset)
|
- **Phase 2 — screen (~tens of minutes, a few $):** run the ~300-drug set (or a focused subset)
|
||||||
against the sickle targets; rank by Boltz-2 predicted affinity; redo the §12.4 positive-control
|
against the sickle targets; rank by Boltz-2 predicted affinity; redo the §12.4 positive-control
|
||||||
recovery test. Output a ranked CSV, same shape as the connectivity `ranked_candidates`.
|
recovery test. Output a ranked CSV, same shape as the connectivity `ranked_candidates`.
|
||||||
|
|||||||
209
gpu/modal_app.py
209
gpu/modal_app.py
@@ -1,87 +1,190 @@
|
|||||||
"""Ephemeral GPU runner for AF3-class co-folding (PLAN §12, docs/gpu_plan.md).
|
"""Ephemeral GPU runner for AF3-class co-folding (PLAN §12, docs/gpu_plan.md).
|
||||||
|
|
||||||
Serverless: `modal run gpu/modal_app.py` provisions a GPU, runs, and releases it — zero idle cost,
|
Serverless: `modal run gpu/modal_app.py` provisions a GPU, runs Phase 1, releases the GPU — zero
|
||||||
nothing to remember to kill. Model weights are cached in a persistent Volume so we never re-pay GPU
|
idle cost. Model weights cache in a persistent Volume so we never re-pay GPU time to re-download.
|
||||||
time to re-download them. Prep (Meeko/RDKit) and RMSD scoring (spyrmsd) stay light; only the model
|
|
||||||
forward pass needs the GPU.
|
Phase 1 (positive-control recovery test, §12.4): co-fold each known binder + a couple of negative
|
||||||
|
controls against each sickle target and rank by Boltz-2 predicted affinity. The known binder
|
||||||
|
should win its own target — the test Vina couldn't pass on metal/covalent/allosteric modes.
|
||||||
|
Affinity ranking avoids the receptor-alignment that pose-RMSD would need (a later refinement).
|
||||||
|
|
||||||
Setup (one-time): `pip install modal && modal token new`.
|
Setup (one-time): `pip install modal && modal token new`.
|
||||||
Run Phase 1 (validate on 3 known binders): `modal run gpu/modal_app.py`.
|
Run: `modal run gpu/modal_app.py`.
|
||||||
|
|
||||||
STATUS: scaffold. The boltz invocation (input spec + output parsing) is stubbed where marked TODO;
|
Helpers below the GPU function run locally (no GPU) and are import-safe for testing.
|
||||||
wire it after a first `modal run` confirms the image builds and the GPU is reachable.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import subprocess
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
import modal
|
import modal
|
||||||
|
|
||||||
app = modal.App("reverso-binding")
|
app = modal.App("reverso-binding")
|
||||||
|
|
||||||
# CUDA image + AF3-class model (Boltz-2) + light prep/scoring deps.
|
|
||||||
image = (
|
image = (
|
||||||
modal.Image.debian_slim(python_version="3.12")
|
modal.Image.debian_slim(python_version="3.12")
|
||||||
.apt_install("git", "wget")
|
.apt_install("git", "wget")
|
||||||
.pip_install("boltz", "rdkit", "meeko", "spyrmsd", "gemmi", "numpy")
|
.pip_install("boltz", "rdkit", "numpy")
|
||||||
)
|
)
|
||||||
|
|
||||||
# Persist model weights across runs so we download them once, not every GPU-billed run.
|
|
||||||
weights = modal.Volume.from_name("reverso-binding-weights", create_if_missing=True)
|
weights = modal.Volume.from_name("reverso-binding-weights", create_if_missing=True)
|
||||||
|
|
||||||
# Known binders -> (PDB id, crystal ligand resname, SMILES placeholder filled by caller).
|
|
||||||
# Phase 1 validation: does co-folding reproduce these crystal poses where Vina failed?
|
|
||||||
KNOWN = {
|
|
||||||
"voxelotor_Hb": ("5E83", "5L7"),
|
|
||||||
"mitapivat_PKR": ("8XFD", "WV2"),
|
|
||||||
"vorinostat_HDAC2": ("4LXZ", "SHH"),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# Cache locations on the persistent Volume — the model downloads here ONCE and reuses forever.
|
|
||||||
WEIGHTS = "/weights"
|
WEIGHTS = "/weights"
|
||||||
|
|
||||||
|
# target name -> (PDB id, crystal-ligand resname, drug name). Plus negatives co-folded into each.
|
||||||
|
TARGETS = {
|
||||||
|
"hemoglobin": ("5E83", "5L7", "voxelotor"),
|
||||||
|
"PKR": ("8XFD", "WV2", "mitapivat"),
|
||||||
|
"HDAC2": ("4LXZ", "SHH", "vorinostat"),
|
||||||
|
}
|
||||||
|
NEGATIVES = ["caffeine", "hydroxyurea"]
|
||||||
|
|
||||||
|
|
||||||
|
# --------------------------------------------------------------------------- local helpers (CPU)
|
||||||
|
|
||||||
|
def fetch_pdb(pdb: str, struct_dir: Path = Path("data/raw/structures")) -> Path:
|
||||||
|
"""Return a local PDB path, downloading from RCSB if absent (keeps the GPU run self-contained)."""
|
||||||
|
import requests
|
||||||
|
p = struct_dir / f"{pdb}.pdb"
|
||||||
|
if not p.exists():
|
||||||
|
p.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
p.write_bytes(requests.get(f"https://files.rcsb.org/download/{pdb}.pdb", timeout=60).content)
|
||||||
|
return p
|
||||||
|
|
||||||
|
|
||||||
|
def binding_chain_sequence(pdb: str, lig_resname: str) -> str:
|
||||||
|
"""One-letter sequence of the protein chain nearest the crystal ligand (the binding chain)."""
|
||||||
|
import gemmi
|
||||||
|
st = gemmi.read_structure(str(fetch_pdb(pdb)))
|
||||||
|
model = st[0]
|
||||||
|
lig_atoms = [a.pos for ch in model for res in ch if res.name == lig_resname for a in res]
|
||||||
|
if not lig_atoms:
|
||||||
|
raise ValueError(f"ligand {lig_resname} not found in {pdb}")
|
||||||
|
lig_center = gemmi.Position(
|
||||||
|
sum(p.x for p in lig_atoms) / len(lig_atoms),
|
||||||
|
sum(p.y for p in lig_atoms) / len(lig_atoms),
|
||||||
|
sum(p.z for p in lig_atoms) / len(lig_atoms),
|
||||||
|
)
|
||||||
|
best, best_d = None, 1e9
|
||||||
|
for ch in model:
|
||||||
|
poly = ch.get_polymer()
|
||||||
|
if len(poly) < 20:
|
||||||
|
continue
|
||||||
|
d = min((a.pos.dist(lig_center) for res in ch for a in res), default=1e9)
|
||||||
|
if d < best_d:
|
||||||
|
best, best_d = poly, d
|
||||||
|
if best is None:
|
||||||
|
raise ValueError(f"no protein chain in {pdb}")
|
||||||
|
return gemmi.one_letter_code(best.extract_sequence()).upper().replace("X", "")
|
||||||
|
|
||||||
|
|
||||||
|
def pubchem_smiles(name: str) -> str:
|
||||||
|
import requests
|
||||||
|
for prop in ("SMILES", "ConnectivitySMILES", "IsomericSMILES", "CanonicalSMILES"):
|
||||||
|
try:
|
||||||
|
d = requests.get(f"https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/name/{name}"
|
||||||
|
f"/property/{prop}/JSON", timeout=30).json()["PropertyTable"]["Properties"][0]
|
||||||
|
if prop in d:
|
||||||
|
return d[prop]
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
raise ValueError(f"no SMILES for {name}")
|
||||||
|
|
||||||
|
|
||||||
|
def build_boltz_yaml(protein_seq: str, ligand_smiles: str) -> str:
|
||||||
|
"""Boltz-2 YAML for a protein+ligand complex with an affinity prediction on the ligand."""
|
||||||
|
return (
|
||||||
|
"version: 1\n"
|
||||||
|
"sequences:\n"
|
||||||
|
" - protein:\n"
|
||||||
|
" id: A\n"
|
||||||
|
f" sequence: {protein_seq}\n"
|
||||||
|
" - ligand:\n"
|
||||||
|
" id: L\n"
|
||||||
|
f" smiles: '{ligand_smiles}'\n"
|
||||||
|
"properties:\n"
|
||||||
|
" - affinity:\n"
|
||||||
|
" binder: L\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------------------- GPU function
|
||||||
|
|
||||||
@app.function(gpu="L4", image=image, volumes={WEIGHTS: weights}, timeout=3600)
|
@app.function(gpu="L4", image=image, volumes={WEIGHTS: weights}, timeout=3600)
|
||||||
def cofold(protein_seq: str, ligand_smiles: str) -> dict:
|
def cofold(label: str, protein_seq: str, ligand_smiles: str) -> dict:
|
||||||
"""Co-fold one protein+ligand complex and return predicted affinity + pose (PDB string).
|
"""Co-fold one complex on the GPU; return predicted affinity + binder probability.
|
||||||
|
|
||||||
Runs on the GPU only for this call, then the GPU is released. Model weights persist on the
|
Weights persist on the mounted Volume (HF_HOME/boltz --cache under /weights), so run 2+ skips
|
||||||
mounted Volume across runs (see HF_HOME / --cache below), so we never re-pay GPU time to
|
the download. GPU is released the moment this returns.
|
||||||
re-download them.
|
|
||||||
"""
|
"""
|
||||||
import os
|
import os
|
||||||
import subprocess # noqa: F401 (used once boltz is wired)
|
os.environ["HF_HOME"] = f"{WEIGHTS}/hf"
|
||||||
|
boltz_cache = f"{WEIGHTS}/boltz"
|
||||||
# Point every weight downloader at the persistent Volume so the cache survives teardown.
|
|
||||||
os.environ["HF_HOME"] = f"{WEIGHTS}/hf" # huggingface_hub cache
|
|
||||||
os.environ["TORCH_HOME"] = f"{WEIGHTS}/torch" # torch.hub cache
|
|
||||||
boltz_cache = f"{WEIGHTS}/boltz" # boltz --cache target
|
|
||||||
os.makedirs(boltz_cache, exist_ok=True)
|
os.makedirs(boltz_cache, exist_ok=True)
|
||||||
|
|
||||||
# See what's already cached (run 2+ finds weights here and skips the download).
|
|
||||||
weights.reload()
|
weights.reload()
|
||||||
|
|
||||||
# TODO: build boltz input (protein_seq + ligand_smiles), then:
|
work = Path("/tmp") / label
|
||||||
# subprocess.run(["boltz", "predict", input_yaml, "--use_msa_server",
|
work.mkdir(parents=True, exist_ok=True)
|
||||||
# "--cache", boltz_cache, "--out_dir", "/tmp/out"], check=True)
|
(work / "in.yaml").write_text(build_boltz_yaml(protein_seq, ligand_smiles))
|
||||||
# parse predicted structure + affinity from /tmp/out.
|
out = work / "out"
|
||||||
|
subprocess.run(
|
||||||
|
["boltz", "predict", str(work / "in.yaml"), "--use_msa_server",
|
||||||
|
"--cache", boltz_cache, "--out_dir", str(out), "--output_format", "pdb"],
|
||||||
|
check=True,
|
||||||
|
)
|
||||||
|
weights.commit() # persist anything newly downloaded
|
||||||
|
|
||||||
# Persist anything newly downloaded into the cache so the NEXT run reuses it.
|
# Affinity is written to a JSON under out/predictions/<name>/; parse defensively (keys vary).
|
||||||
weights.commit()
|
aff = {"affinity_pred_value": None, "affinity_probability_binary": None}
|
||||||
raise NotImplementedError("Wire Boltz-2 here; see docs/gpu_plan.md Phase 1.")
|
for jf in out.rglob("affinity*.json"):
|
||||||
|
data = json.loads(jf.read_text())
|
||||||
|
for k in aff:
|
||||||
|
if k in data:
|
||||||
|
aff[k] = data[k]
|
||||||
|
break
|
||||||
|
cif = next(out.rglob("*_model_0.pdb"), None) or next(out.rglob("*.pdb"), None)
|
||||||
|
return {"label": label, "affinity": aff["affinity_pred_value"],
|
||||||
|
"prob_binder": aff["affinity_probability_binary"],
|
||||||
|
"structure": cif.read_text() if cif else None}
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------------------- driver (local)
|
||||||
|
|
||||||
@app.local_entrypoint()
|
@app.local_entrypoint()
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
"""Phase 1 driver (runs locally; only cofold() touches the GPU).
|
"""Fan out one GPU call per (target, ligand) pair; tabulate affinities; positive-control test."""
|
||||||
|
jobs = [] # (target, ligand_name, protein_seq, smiles)
|
||||||
|
for target, (pdb, res, drug) in TARGETS.items():
|
||||||
|
seq = binding_chain_sequence(pdb, res)
|
||||||
|
for lig in [drug, *NEGATIVES]:
|
||||||
|
jobs.append((target, lig, seq, pubchem_smiles(lig)))
|
||||||
|
print(f"co-folding {len(jobs)} complexes ({len(TARGETS)} targets x {1+len(NEGATIVES)} ligands)")
|
||||||
|
|
||||||
Pulls target sequences + ligand SMILES from the repo, fans out one GPU call per known binder,
|
results = list(cofold.starmap([(f"{t}_{l}", s, smi) for t, l, s, smi in jobs]))
|
||||||
scores redocking RMSD vs the crystal pose locally (spyrmsd), and prints pass/fail. Results are
|
by = {r["label"]: r for r in results}
|
||||||
tiny — commit a summary into data/processed/binding/.
|
|
||||||
"""
|
print(f"\n{'target':12s}{'ligand':14s}{'affinity':>10s}{'P(binder)':>11s}")
|
||||||
# TODO: load protein sequences from data/raw/structures/<pdb>.pdb (gemmi) and ligand SMILES
|
out_rows = []
|
||||||
# (PubChem / drug_set), then:
|
for target, (pdb, res, drug) in TARGETS.items():
|
||||||
# results = list(cofold.map(seqs, smiles))
|
prob = {} # rank by P(binder): unambiguous (higher = more likely a binder). Boltz
|
||||||
# and compute in-place spyrmsd RMSD vs the crystal ligand for each.
|
for lig in [drug, *NEGATIVES]: # affinity_pred_value is ~log(IC50) (lower=stronger) -- sign
|
||||||
print("Scaffold: fill in sequence/SMILES loading + cofold.map, then score RMSD. "
|
r = by.get(f"{target}_{lig}", {}) # is version-dependent, so don't rank on it.
|
||||||
"See docs/gpu_plan.md.")
|
a, p = r.get("affinity"), r.get("prob_binder")
|
||||||
|
if p is not None:
|
||||||
|
prob[lig] = p
|
||||||
|
print(f"{target:12s}{lig:14s}{(f'{a:.2f}' if a is not None else 'NA'):>10s}"
|
||||||
|
f"{(f'{p:.2f}' if p is not None else 'NA'):>11s}")
|
||||||
|
out_rows.append({"target": target, "ligand": lig, "affinity": a, "prob_binder": p,
|
||||||
|
"is_known_binder": lig == drug})
|
||||||
|
if prob:
|
||||||
|
best = max(prob, key=prob.get) # highest P(binder)
|
||||||
|
print(f" -> {target}: best P(binder) = {best} (known = {drug}) "
|
||||||
|
f"{'PASS' if best == drug else 'FAIL'}\n")
|
||||||
|
|
||||||
|
outdir = Path("data/processed/binding"); outdir.mkdir(parents=True, exist_ok=True)
|
||||||
|
import csv
|
||||||
|
with (outdir / "phase1_affinity.csv").open("w", newline="") as f:
|
||||||
|
w = csv.DictWriter(f, fieldnames=["target", "ligand", "affinity", "prob_binder", "is_known_binder"])
|
||||||
|
w.writeheader(); w.writerows(out_rows)
|
||||||
|
print(f"wrote {outdir/'phase1_affinity.csv'}")
|
||||||
|
|||||||
Reference in New Issue
Block a user