AMP: Automatic Modality-Aware Parallelization with Hidden-Dimension Tensor Parallelism for Multi-Modal 3D Biological Models
Kailin Zhang, Hao Zheng, Lang Yuan
Three-dimensional (3D) spatial interaction data are fundamental to understanding genome architecture. Multi-modal deep learning models that jointly learn from 3D spatial data and orthogonal modalities, such as gene expression, face a critical computational challenge: the 3D spatial modality dominates computation by over an order of magnitude, creating a structural memory bottleneck that renders heavyweight model instances untrainable on a single GPU. Existing distributed training methods rely on cost-model search and treat model components uniformly, overlooking modality-specific memory asymmetries. We propose Automatic Modality-aware Parallelization (AMP), a framework that diagnoses memory bottlenecks from data configuration signals and prescribes a set of five strategies. At the core of this framework is a hidden-dimension tensor parallelism strategy (S5) that partitions the 3D decoder’s hidden dimension across GPUs, transforming five non-standard operators into sharded forms with formal equivalence proofs. Evaluated on Hi-C and RNA-seq data from the HiRES single-cell mouse brain dataset across lightweight and heavyweight configurations, AMP converts out-of-memory (OOM) failures into successful training runs. When scaling from four to eight GPUs under heavyweight configurations, the 500 kb and 100 kb variants achieve 2.0× and 3.8× training speedups, respectively, with mathematical equivalence to single-GPU computation guaranteed by formal proofs.
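To make the idea of partitioning a hidden dimension concrete, the following is a minimal, generic sketch and not the S5 implementation or any of its five operators, which the abstract does not specify. It assumes a plain two-layer ReLU block without biases: the first weight matrix is split by columns and the second by rows along the hidden axis, each hypothetical "device" computes a partial output, and the partials are summed (the role an all-reduce would play across GPUs). All function names here are illustrative assumptions, and the devices are simulated sequentially in pure Python.

```python
import random


def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(u * v for u, v in zip(row, col)) for col in zip(*b)] for row in a]


def relu(m):
    """Apply ReLU elementwise."""
    return [[max(v, 0.0) for v in row] for row in m]


def mlp_reference(x, w1, w2):
    """Unsharded computation: relu(x @ w1) @ w2."""
    return matmul(relu(matmul(x, w1)), w2)


def mlp_hidden_sharded(x, w1, w2, num_shards):
    """Same computation with the hidden dimension split across simulated devices."""
    hidden = len(w2)  # w1 is (d_in, hidden), w2 is (hidden, d_out)
    assert hidden % num_shards == 0, "hidden size must divide evenly"
    width = hidden // num_shards
    out = None
    for k in range(num_shards):
        lo, hi = k * width, (k + 1) * width
        w1_k = [row[lo:hi] for row in w1]  # column slice of the first weight
        w2_k = w2[lo:hi]                   # matching row slice of the second weight
        # ReLU is elementwise, so it can be applied to each hidden slice locally.
        partial = matmul(relu(matmul(x, w1_k)), w2_k)
        # Summing partial outputs stands in for an all-reduce across devices.
        if out is None:
            out = partial
        else:
            out = [[p + q for p, q in zip(r1, r2)] for r1, r2 in zip(out, partial)]
    return out


def random_matrix(rng, rows, cols):
    return [[rng.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


rng = random.Random(0)
x = random_matrix(rng, 4, 16)    # batch of 4 inputs, 16 features
w1 = random_matrix(rng, 16, 32)  # hidden size 32
w2 = random_matrix(rng, 32, 8)

ref = mlp_reference(x, w1, w2)
out = mlp_hidden_sharded(x, w1, w2, num_shards=4)

# Largest elementwise gap between the sharded and unsharded results.
max_abs_diff = max(abs(a - b) for r1, r2 in zip(ref, out) for a, b in zip(r1, r2))
```

In this toy setting the two results agree up to floating-point rounding, because the second matrix product is a sum over the hidden index and that sum can be regrouped by shard. Operators that mix values across the hidden axis (normalization, for example) would need additional communication, which is presumably where operator-specific sharded forms and equivalence proofs become necessary.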