diff --git a/docs/gpu_plan.md b/docs/gpu_plan.md index 0741261..b72431b 100644 --- a/docs/gpu_plan.md +++ b/docs/gpu_plan.md @@ -17,6 +17,27 @@ the inputs are *tiny*, so the design optimises for zero idle cost, not for a per The 27 GB LINCS data is **not** part of this track — nothing big to upload. The only thing worth persisting is the model-weights cache (so we don't re-download = re-pay GPU time every run). +## How the model weights persist (the cost-saver) + +A `modal.Volume` is a **named, cloud-backed filesystem that lives independently of any container +or GPU** — it survives every teardown. Mounted into the function at `/weights`: + +- **Run 1:** `/weights` is empty → the model downloads weights there (the one-time slow cost). +- **Run 2+:** the same Volume mounts with the files already present → download skipped → **no + GPU-billed seconds wasted re-fetching 5 GB.** + +Two things make it actually cache: +1. **Point the downloader at the mount** (weights only persist if written under `/weights`): + `HF_HOME=/weights/hf` (HuggingFace), `TORCH_HOME=/weights/torch`, `boltz --cache /weights/boltz`. +2. **Commit semantics:** writes persist on `weights.commit()` (modern Modal also auto-commits on a + clean exit); other containers see them after `weights.reload()`. Pattern: `reload()` → run → + `commit()`. + +The Volume itself costs pennies (~$/GB-month of storage), *separate from the GPU* — so caching ~5 GB +of weights is near-free and saves real GPU time on every subsequent run. +(Alternative: bake weights into the image at build time via `image.run_function(download)` — fastest +cold start, but the image rebuilds when weights change. The skeleton uses the Volume approach.) + ## Provider choice | Option | Billing | Idle cost | "Kill" model | Best for | diff --git a/gpu/modal_app.py b/gpu/modal_app.py index 94b5a6f..f735828 100644 --- a/gpu/modal_app.py +++ b/gpu/modal_app.py @@ -37,17 +37,37 @@ KNOWN = { } -@app.function(gpu="L4", image=image, volumes={"/weights": weights}, timeout=3600) -def cofold(protein_seq: str, ligand_smiles: str, weights_dir: str = "/weights") -> dict: +# Cache locations on the persistent Volume — the model downloads here ONCE and reuses forever. +WEIGHTS = "/weights" + + +@app.function(gpu="L4", image=image, volumes={WEIGHTS: weights}, timeout=3600) +def cofold(protein_seq: str, ligand_smiles: str) -> dict: """Co-fold one protein+ligand complex and return predicted affinity + pose (PDB string). - Runs on the GPU only for this call, then the GPU is released. TODO: replace the stub with the - actual Boltz-2 invocation (write the YAML/FASTA input spec, call `boltz predict - --use_msa_server --out_dir ... --cache /weights`, parse the predicted structure + affinity). + Runs on the GPU only for this call, then the GPU is released. Model weights persist on the + mounted Volume across runs (see HF_HOME / --cache below), so we never re-pay GPU time to + re-download them. """ + import os import subprocess # noqa: F401 (used once boltz is wired) - # TODO: build boltz input (protein_seq + ligand_smiles), run, parse pose+affinity. + # Point every weight downloader at the persistent Volume so the cache survives teardown. + os.environ["HF_HOME"] = f"{WEIGHTS}/hf" # huggingface_hub cache + os.environ["TORCH_HOME"] = f"{WEIGHTS}/torch" # torch.hub cache + boltz_cache = f"{WEIGHTS}/boltz" # boltz --cache target + os.makedirs(boltz_cache, exist_ok=True) + + # See what's already cached (run 2+ finds weights here and skips the download). + weights.reload() + + # TODO: build boltz input (protein_seq + ligand_smiles), then: + # subprocess.run(["boltz", "predict", input_yaml, "--use_msa_server", + # "--cache", boltz_cache, "--out_dir", "/tmp/out"], check=True) + # parse predicted structure + affinity from /tmp/out. + + # Persist anything newly downloaded into the cache so the NEXT run reuses it. + weights.commit() raise NotImplementedError("Wire Boltz-2 here; see docs/gpu_plan.md Phase 1.")