GPU plan: make weight persistence concrete (Modal Volume cache)

Document and wire the weight-caching mechanism:
- modal.Volume is a cloud-backed FS independent of the GPU/container; run 1 downloads weights
  into /weights, run 2+ reuses them (no GPU time wasted re-downloading).
- Point downloaders at the mount: HF_HOME/TORCH_HOME/boltz --cache; persist via
  weights.commit(), see updates via weights.reload().
- Volume storage costs pennies, separate from GPU = near-free caching.

modal_app.py cofold(): set cache env vars to /weights, reload()/commit() around the
(stubbed) boltz call.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

@@ -17,6 +17,27 @@ the inputs are *tiny*, so the design optimises for zero idle cost, not for a per
 The 27 GB LINCS data is **not** part of this track: nothing big to upload. The only thing worth
 persisting is the model-weights cache (so we don't re-download = re-pay GPU time every run).
 
+## How the model weights persist (the cost-saver)
+
+A `modal.Volume` is a **named, cloud-backed filesystem that lives independently of any container
+or GPU**, so it survives every teardown. Mounted into the function at `/weights`:
+
+- **Run 1:** `/weights` is empty → the model downloads weights there (the one-time slow cost).
+- **Run 2+:** the same Volume mounts with the files already present → download skipped → **no
+  GPU-billed seconds wasted re-fetching 5 GB.**
+
+Two things make it actually cache:
+1. **Point the downloader at the mount** (weights only persist if written under `/weights`):
+   `HF_HOME=/weights/hf` (HuggingFace), `TORCH_HOME=/weights/torch`, `boltz --cache /weights/boltz`.
+2. **Commit semantics:** writes persist on `weights.commit()` (modern Modal also auto-commits on a
+   clean exit); other containers see them after `weights.reload()`. Pattern: `reload()` → run →
+   `commit()`.
+
+The Volume itself costs pennies (billed per GB-month of storage), *separate from the GPU*, so caching ~5 GB
+of weights is near-free and saves real GPU time on every subsequent run.
+(Alternative: bake weights into the image at build time via `image.run_function(download)`. That gives the
+fastest cold start, but the image rebuilds when weights change. The skeleton uses the Volume approach.)
+
 ## Provider choice
 
 | Option | Billing | Idle cost | "Kill" model | Best for |
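
The `reload()` → run → `commit()` pattern described in the added section can be sketched as a small context manager. This is an illustrative sketch, not code from the commit: `VolumeLike` and `cached_weights` are hypothetical names, and the only assumption about the volume object is that it exposes `reload()` and `commit()`, as the text above says `modal.Volume` does. Unlike the skeleton in the diff, this sketch skips `commit()` when the body raises, on the assumption that a failed or partial download should not be persisted.

```python
from contextlib import contextmanager
from typing import Iterator, Protocol


class VolumeLike(Protocol):
    """Hypothetical structural type: anything exposing reload() and commit()."""

    def reload(self) -> None: ...

    def commit(self) -> None: ...


@contextmanager
def cached_weights(volume: VolumeLike) -> Iterator[None]:
    """Wrap a block of work in the reload() -> run -> commit() pattern.

    reload() runs first so this container sees files committed by earlier runs.
    commit() runs only if the body finishes without raising; if the body
    raises, the exception propagates from the yield and commit() is skipped.
    """
    volume.reload()
    yield
    volume.commit()
```

Inside `cofold()` this would read `with cached_weights(weights): ...`, with the model download and inference in the body. Whether to commit after a failure is a design choice: committing unconditionally keeps any weights that did download, at the risk of persisting an incomplete cache.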
@@ -37,17 +37,37 @@ KNOWN = {
 }
 
 
-@app.function(gpu="L4", image=image, volumes={"/weights": weights}, timeout=3600)
-def cofold(protein_seq: str, ligand_smiles: str, weights_dir: str = "/weights") -> dict:
+# Cache locations on the persistent Volume: the model downloads here ONCE and reuses forever.
+WEIGHTS = "/weights"
+
+
+@app.function(gpu="L4", image=image, volumes={WEIGHTS: weights}, timeout=3600)
+def cofold(protein_seq: str, ligand_smiles: str) -> dict:
     """Co-fold one protein+ligand complex and return predicted affinity + pose (PDB string).
 
-    Runs on the GPU only for this call, then the GPU is released. TODO: replace the stub with the
-    actual Boltz-2 invocation (write the YAML/FASTA input spec, call `boltz predict
-    --use_msa_server --out_dir ... --cache /weights`, parse the predicted structure + affinity).
+    Runs on the GPU only for this call, then the GPU is released. Model weights persist on the
+    mounted Volume across runs (see HF_HOME / --cache below), so we never re-pay GPU time to
+    re-download them.
     """
+    import os
     import subprocess  # noqa: F401  (used once boltz is wired)
 
-    # TODO: build boltz input (protein_seq + ligand_smiles), run, parse pose+affinity.
+    # Point every weight downloader at the persistent Volume so the cache survives teardown.
+    os.environ["HF_HOME"] = f"{WEIGHTS}/hf"          # huggingface_hub cache
+    os.environ["TORCH_HOME"] = f"{WEIGHTS}/torch"    # torch.hub cache
+    boltz_cache = f"{WEIGHTS}/boltz"                 # boltz --cache target
+    os.makedirs(boltz_cache, exist_ok=True)
+
+    # See what's already cached (run 2+ finds weights here and skips the download).
+    weights.reload()
+
+    # TODO: build boltz input (protein_seq + ligand_smiles), then:
+    #   subprocess.run(["boltz", "predict", input_yaml, "--use_msa_server",
+    #                   "--cache", boltz_cache, "--out_dir", "/tmp/out"], check=True)
+    #   parse predicted structure + affinity from /tmp/out.
+
+    # Persist anything newly downloaded into the cache so the NEXT run reuses it.
+    weights.commit()
     raise NotImplementedError("Wire Boltz-2 here; see docs/gpu_plan.md Phase 1.")
 
 
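
The cache routing and the commented-out `boltz predict` call in the hunk above could be factored into two small helpers, sketched below. Both function names are hypothetical and do not appear in the commit. The environment variable names and the `boltz` flags are copied from the diff and have not been checked against the boltz CLI. The sketch only builds the argument list; it does not run anything.

```python
import os
from pathlib import Path
from typing import Dict, List


def configure_weight_cache(root: str) -> Dict[str, str]:
    """Route downloader caches under `root` and return the chosen paths.

    Hypothetical helper: mirrors the env-var setup in cofold(). Calling it
    before the downloading libraries are imported is the safer order, since
    some libraries read these variables at import time.
    """
    paths = {
        "HF_HOME": f"{root}/hf",          # huggingface_hub cache
        "TORCH_HOME": f"{root}/torch",    # torch.hub cache
        "boltz_cache": f"{root}/boltz",   # target for boltz --cache
    }
    os.environ["HF_HOME"] = paths["HF_HOME"]
    os.environ["TORCH_HOME"] = paths["TORCH_HOME"]
    Path(paths["boltz_cache"]).mkdir(parents=True, exist_ok=True)
    return paths


def build_boltz_command(input_yaml: str, cache_dir: str, out_dir: str) -> List[str]:
    """Assemble the argv from the TODO comment (flags taken from the diff as-is)."""
    return [
        "boltz", "predict", input_yaml,
        "--use_msa_server",
        "--cache", cache_dir,
        "--out_dir", out_dir,
    ]
```

Keeping the command as a list, rather than a shell string, matches the `subprocess.run([...], check=True)` form in the TODO and avoids quoting problems with paths.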