Document and wire the weight-caching mechanism: - modal.Volume is a cloud-backed FS independent of the GPU/container; run 1 downloads weights into /weights, run 2+ reuses them (no GPU time wasted re-downloading). - Point downloaders at the mount: HF_HOME/TORCH_HOME/boltz --cache; persist via weights.commit(), see updates via weights.reload(). - Volume storage costs pennies, separate from GPU = near-free caching. modal_app.py cofold(): set cache env vars to /weights, reload()/commit() around the (stubbed) boltz call. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
88 lines
3.8 KiB
Python
88 lines
3.8 KiB
Python
"""Ephemeral GPU runner for AF3-class co-folding (PLAN §12, docs/gpu_plan.md).
|
|
|
|
Serverless: `modal run gpu/modal_app.py` provisions a GPU, runs, and releases it — zero idle cost,
|
|
nothing to remember to kill. Model weights are cached in a persistent Volume so we never re-pay GPU
|
|
time to re-download them. Prep (Meeko/RDKit) and RMSD scoring (spyrmsd) stay light; only the model
|
|
forward pass needs the GPU.
|
|
|
|
Setup (one-time): `pip install modal && modal token new`.
|
|
Run Phase 1 (validate on 3 known binders): `modal run gpu/modal_app.py`.
|
|
|
|
STATUS: scaffold. The boltz invocation (input spec + output parsing) is stubbed where marked TODO;
|
|
wire it after a first `modal run` confirms the image builds and the GPU is reachable.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import modal
|
|
|
|
app = modal.App("reverso-binding")

# Container image, layered one step at a time: a Debian-slim base on Python 3.12, then the system
# tools, then Boltz-2 (the AF3-class model) together with the light prep/scoring packages.
image = modal.Image.debian_slim(python_version="3.12")
image = image.apt_install("git", "wget")
image = image.pip_install("boltz", "rdkit", "meeko", "spyrmsd", "gemmi", "numpy")

# Named Volume that holds the model weights. It is created on first use, so later runs can reuse
# the download instead of repeating it on GPU-billed time.
weights = modal.Volume.from_name("reverso-binding-weights", create_if_missing=True)
|
|
|
|
# Known binders -> (PDB id, crystal ligand resname). Ligand SMILES are NOT stored here; the caller
# supplies them alongside the protein sequence.
# Phase 1 validation: does co-folding reproduce these crystal poses where Vina failed?
KNOWN: dict[str, tuple[str, str]] = {
    "voxelotor_Hb": ("5E83", "5L7"),
    "mitapivat_PKR": ("8XFD", "WV2"),
    "vorinostat_HDAC2": ("4LXZ", "SHH"),
}


# Mount point of the persistent Volume inside the container. Every weight cache lives under this
# path, so the model downloads here ONCE and is reused by later runs.
WEIGHTS = "/weights"
|
|
|
|
|
|
@app.function(gpu="L4", image=image, volumes={WEIGHTS: weights}, timeout=3600)
def cofold(protein_seq: str, ligand_smiles: str) -> dict:
    """Co-fold one protein+ligand complex and return predicted affinity + pose (PDB string).

    Runs on the GPU only for this call, then the GPU is released. Model weights persist on the
    mounted Volume across runs (see HF_HOME / --cache below), so we never re-pay GPU time to
    re-download them.

    Args:
        protein_seq: Protein sequence of the target (not consumed yet; the Boltz call is stubbed).
        ligand_smiles: SMILES string of the ligand (not consumed yet; the Boltz call is stubbed).

    Returns:
        Intended to be a dict carrying the predicted affinity and pose. The scaffold never
        returns; see Raises.

    Raises:
        NotImplementedError: Always, until the Boltz-2 invocation is wired in.
    """
    import os
    import subprocess  # noqa: F401 (used once boltz is wired)

    # Sync to the latest committed Volume state BEFORE writing anything under the mount, so
    # run 2+ sees the cached weights and the reload never has to reconcile fresh local writes.
    weights.reload()

    # Point every weight downloader at the persistent Volume so the cache survives teardown.
    os.environ["HF_HOME"] = f"{WEIGHTS}/hf"  # huggingface_hub cache
    os.environ["TORCH_HOME"] = f"{WEIGHTS}/torch"  # torch.hub cache
    boltz_cache = f"{WEIGHTS}/boltz"  # boltz --cache target
    os.makedirs(boltz_cache, exist_ok=True)

    # TODO: build boltz input (protein_seq + ligand_smiles), then:
    #   subprocess.run(["boltz", "predict", input_yaml, "--use_msa_server",
    #                   "--cache", boltz_cache, "--out_dir", "/tmp/out"], check=True)
    #   parse predicted structure + affinity from /tmp/out.

    # Persist anything newly downloaded into the cache so the NEXT run reuses it. This stays
    # after the prediction step on purpose: a failed run should not commit a partial download.
    weights.commit()
    raise NotImplementedError("Wire Boltz-2 here; see docs/gpu_plan.md Phase 1.")
|
|
|
|
|
|
@app.local_entrypoint()
def main() -> None:
    """Drive Phase 1 from the local machine; only cofold() runs on the GPU.

    Planned flow: read the target sequences and ligand SMILES from the repo, issue one GPU call
    per known binder, score redocking RMSD against the crystal pose locally (spyrmsd), and print
    pass/fail. The results are tiny, so a summary gets committed into data/processed/binding/.
    For now this is a scaffold that only prints a reminder of what is left to fill in.
    """
    # TODO: load protein sequences from data/raw/structures/<pdb>.pdb (gemmi) and ligand SMILES
    # (PubChem / drug_set), then:
    #   results = list(cofold.map(seqs, smiles))
    # and compute in-place spyrmsd RMSD vs the crystal ligand for each.
    scaffold_notice = (
        "Scaffold: fill in sequence/SMILES loading + cofold.map, then score RMSD. "
        "See docs/gpu_plan.md."
    )
    print(scaffold_notice)
|