Roadmap: GPU/TPU-Agnostic Sparse Linear Solvers in JAX¶
Date: 2025-12-20
Status: Draft
Authors: Robert Taylor, Claude
Goal¶
Enable efficient sparse linear solving in VAJAX that works across CPU, GPU (NVIDIA, AMD), and TPU, with cached symbolic factorization for Newton-Raphson iterations.
Current State of JAX Sparse Support¶
What Exists¶
| Feature | Status | Source |
|---|---|---|
| BCOO format | Stable | PR #6824 |
| BCSR format | Stable | `jax.experimental.sparse` |
| `spsolve` | Experimental | Uses cuSOLVER on GPU, no caching |
| `gmres`, `cg`, `bicgstab` | Experimental | Issue #11376 |
| Sparse matmul (SpMV, SpMM) | Stable | Uses cuSPARSE on GPU |
What's Missing¶
- Cached symbolic factorization - every `spsolve` call redoes METIS ordering
- TPU sparse support - limited to dense operations
- Sparse triangular solve - not exposed as a primitive
- Pre-factorization API - JAX Issue #22500
Relevant JAX/XLA/MLIR Work¶
Closed/Stale PRs (Potential Revival)¶
| PR | Description | Status | Relevance |
|---|---|---|---|
| #6555 | MLIR/TACO-like sparse representation | Closed (stale) | HIGH - Explored MLIR sparse tensor integration |
| #4422 | Add experimental sparse support | Merged | Foundation for current sparse module |
| #2566 | scipy.sparse.linalg.cg | Merged | Iterative solver baseline |
Active/Recent Work¶
| PR/Issue | Description | Status | Relevance |
|---|---|---|---|
| #25958 | Performance note on sparse docs | Merged (Jan 2025) | Documents current limitations |
| #11376 | Development of scipy.sparse.linalg | Open | Tracking issue for sparse solvers |
| LLVM #151885 | MLIR sparse loop ordering heuristics | Active | 30% speedup on sparse workloads |
MLIR Sparse Tensor Dialect¶
The MLIR Sparse Tensor Dialect provides:
- ~40 sparse tensor operations
- Multiple storage formats (COO, CSR, CSC, etc.)
- Automatic code generation from sparsity-agnostic Linalg ops
- GPU codegen (experimental)
Gap: No direct LU factorization or triangular solve ops. Focus is on SpMM/SpMV.
Reference: Compiler Support for Sparse Tensor Computations in MLIR
XLA Sparse Support¶
XLA has limited sparse support:
- SparseTensorDotGeneral for sparse matmul
- Custom calls to cuSPARSE for GPU
- No native sparse LU or triangular solve
Relevant issues:
- openxla/xla#6834 - Triangular solve integer overflow (fixed)
- openxla/xla#6871 - Fix for the above
Proposed Architecture¶
Phase 1: Hybrid Backend (Immediate)¶
┌─────────────────┐
│ VAJAX NR │
│ Solver │
└────────┬────────┘
│
┌──────────────┼──────────────┐
│ │ │
▼ ▼ ▼
┌─────────┐ ┌─────────┐ ┌─────────┐
│ Spineax │ │ GMRES │ │ Dense │
│ (cuDSS) │ │ + Prec │ │ BLAS │
└─────────┘ └─────────┘ └─────────┘
GPU TPU/CPU CPU
Implementation: Done (Spineax integration committed)
Phase 2: Pure JAX Iterative Solver (Short-term)¶
Add GMRES with block-Jacobi preconditioner as agnostic fallback:
```python
# vajax/analysis/iterative_solver.py
import itertools

import jax.numpy as jnp
from jax.scipy.sparse.linalg import gmres


def build_block_jacobi_preconditioner(J_diag_blocks):
    """Build M^-1 from the diagonal blocks of the Jacobian."""
    inv_blocks = [jnp.linalg.inv(b) for b in J_diag_blocks]
    # Static split points between consecutive blocks (jnp.split needs them
    # at trace time, so compute them from the block shapes in Python)
    offsets = list(itertools.accumulate(b.shape[0] for b in J_diag_blocks))[:-1]

    def apply_M_inv(v):
        # Apply the block-diagonal inverse block by block
        pieces = jnp.split(v, offsets)
        return jnp.concatenate([inv @ p for inv, p in zip(inv_blocks, pieces)])

    return apply_M_inv


def gmres_solve(J, b, preconditioner=None, restart=30, maxiter=100):
    """Solve Jx = b using preconditioned GMRES with a matrix-free matvec."""
    x, _ = gmres(
        lambda v: J @ v,
        b,
        M=preconditioner,
        restart=restart,
        maxiter=maxiter,
    )
    return x
```
Effort: 1-2 weeks. Benefit: works on all backends, no custom ops.
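As a sanity check for this fallback, the public `jax.scipy.sparse.linalg.gmres` API can be driven directly with a matrix-free matvec over a BCOO matrix; the 3x3 system and diagonal (Jacobi) preconditioner below are hypothetical stand-ins for a circuit Jacobian:

```python
import jax.numpy as jnp
from jax.experimental import sparse
from jax.scipy.sparse.linalg import gmres

# Small SPD test system standing in for a sparse circuit Jacobian
A_dense = jnp.array([[4.0, 1.0, 0.0],
                     [1.0, 3.0, 1.0],
                     [0.0, 1.0, 2.0]])
A = sparse.BCOO.fromdense(A_dense)
b = jnp.array([1.0, 2.0, 3.0])

# Jacobi (diagonal) preconditioner: M^-1 v = v / diag(A)
d_inv = 1.0 / jnp.diag(A_dense)

# Matrix-free matvec: GMRES only needs v -> A @ v, which BCOO supports
x, _ = gmres(lambda v: A @ v, b, M=lambda v: d_inv * v,
             restart=30, maxiter=100)
residual = float(jnp.linalg.norm(A_dense @ x - b))
```

Because GMRES only touches the operator through `A @ v`, the same code runs unmodified on CPU, GPU, and TPU backends.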
Phase 3: XLA Sparse Primitive (Medium-term)¶
Add sparse triangular solve to XLA via custom call:
XLA HLO:
```
SparseLU(sparse_matrix) -> (L, U, perm)   # Factorize
SparseTriangularSolve(L, b) -> y          # Forward substitution
SparseTriangularSolve(U, y) -> x          # Back substitution
```
This requires:
1. Define HLO ops for the sparse LU components
2. Implement CPU lowering (via SuiteSparse/KLU)
3. Implement GPU lowering (via cuDSS)
4. Add JAX bindings
Effort: 2-3 months (significant XLA contribution). Benefit: native caching, works with XLA optimizations.
Phase 4: MLIR Sparse Integration (Long-term)¶
Leverage MLIR Sparse Tensor Dialect:
- Define sparse LU in Linalg dialect - Express as sparse tensor operations
- Use MLIR sparsifier - Generate optimized code automatically
- Multi-backend codegen - CPU, GPU, TPU from single definition
Reference: MLIR Sparsifier JAX Colab
Effort: 6-12 months (research project). Benefit: true platform independence, compiler-level optimization.
Technical Challenges¶
1. Symbolic Factorization Caching in XLA¶
XLA's functional model doesn't naturally support mutable state (cached factorization).
Options:
- Use XLA state tokens (like RNG state)
- Store the factorization as a "constant" in the compiled program
- Use external state via a custom call
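The external-state option can be approximated today without touching XLA: factor once on the host with SciPy's SuperLU and reuse the factorization from inside a jitted function via `jax.pure_callback`. This is a sketch of the idea, not the proposed XLA mechanism, and the matrix here is a hypothetical example:

```python
import jax
import jax.numpy as jnp
import numpy as np
import scipy.sparse
import scipy.sparse.linalg

A = scipy.sparse.csc_matrix(np.array([[4.0, 1.0, 0.0],
                                      [1.0, 3.0, 1.0],
                                      [0.0, 1.0, 2.0]]))
# Symbolic + numeric factorization happens exactly once, outside the
# traced program; `lu` is the external state the options above describe.
lu = scipy.sparse.linalg.splu(A)

@jax.jit
def solve(b):
    # Only the triangular solves cross the host boundary on each call;
    # the (expensive) ordering and factorization are never redone.
    return jax.pure_callback(
        lambda b_: lu.solve(np.asarray(b_, dtype=np.float64)).astype(b_.dtype),
        jax.ShapeDtypeStruct(b.shape, b.dtype),
        b,
    )

x = solve(jnp.array([1.0, 2.0, 3.0]))
```

The host round-trip costs bandwidth, so this is a stopgap; a native primitive would keep the triangular solves on-device.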
2. TPU Sparse Support¶
TPUs are optimized for dense, regular computation. Sparse operations have overhead.
Options:
- Dense solver for small matrices (our current approach works)
- Padded block-sparse for structured sparsity
- Iterative solvers (GMRES), which are more TPU-friendly
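The padding idea can be sketched as follows: embed the n x n system in a larger lane-aligned one (identity on the padded diagonal, zeros in the padded right-hand side), solve densely, and truncate. The `multiple=128` default is an assumption about TPU lane width, to be verified by benchmarking:

```python
import jax.numpy as jnp

def padded_dense_solve(J, b, multiple=128):
    """Solve Jx = b after padding the system up to a multiple of `multiple`."""
    n = J.shape[0]
    m = (-n) % multiple  # rows/cols needed to reach the next multiple
    if m:
        # Identity padding leaves the original solution unchanged
        J = jnp.block([[J, jnp.zeros((n, m))],
                       [jnp.zeros((m, n)), jnp.eye(m)]])
        b = jnp.concatenate([b, jnp.zeros(m)])
    return jnp.linalg.solve(J, b)[:n]

J = jnp.array([[4.0, 1.0], [1.0, 3.0]])
b = jnp.array([1.0, 2.0])
x = padded_dense_solve(J, b)
```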
3. Dynamic Sparsity Patterns¶
Some applications have varying sparsity. Our circuit simulation has fixed patterns.
Options:
- Re-analyze on pattern change (expensive but rare)
- Approximate with a superset pattern
- Fall back to an iterative solver
Recommendations¶
For VAJAX (Immediate)¶
- ✅ Use Spineax on NVIDIA GPUs (done)
- Add GMRES fallback for TPU/other GPUs
- Keep dense option for small circuits
For JAX Ecosystem (Contribution)¶
- Propose a `spsolve_prefactor` API to JAX (Issue #22500)
- Add a sparse triangular solve primitive
- Document hybrid sparse strategy
For Long-term Research¶
- Explore MLIR sparsifier for LU
- Work with XLA team on sparse primitive design
- Investigate AMD ROCm sparse solver integration
TODO: TPU and Non-NVIDIA GPU Support¶
Open Issues¶
- [ ] TPU CI failing - PR #1 adds TPU CI but tests are timing out/cancelled
  - TPU tests ran for 6+ hours before being cancelled
  - Need to investigate: is it a VM provisioning issue or an actual test hang?
TPU Support Tasks¶
- [ ] Implement GMRES + block-Jacobi fallback solver
  - Create `vajax/analysis/iterative_solver.py`
  - Extract diagonal blocks from the Jacobian for the preconditioner
  - Use `jax.scipy.sparse.linalg.gmres` with a matrix-free matvec
  - Benchmark on TPU vs the dense solver
- [ ] Test dense solver on TPU
  - Verify the current dense path works on TPU
  - Measure performance for the c6288 benchmark
  - Document TPU-specific tuning (e.g., padding for XLA efficiency)
- [ ] Fix TPU CI workflow
  - Debug why TPU tests time out
  - Add proper timeout handling
  - Consider running smaller benchmarks in CI
Non-NVIDIA GPU Support Tasks¶
- [ ] AMD ROCm investigation
  - Check if JAX `spsolve` works on ROCm
  - Investigate hipSPARSE/rocSOLVER for cached factorization
  - Look for a ROCm equivalent to cuDSS
- [ ] Intel GPU investigation
  - Check JAX Intel plugin status
  - Investigate oneMKL sparse solver support
Backend Detection Strategy¶
Current auto-detection in `runner.py`:

```python
if jax.default_backend() == 'gpu':
    try:
        from spineax.cudss.solver import CuDSSSolver
        use_spineax = True   # NVIDIA GPU with Spineax
    except ImportError:
        use_spineax = False  # Non-NVIDIA GPU, fall back to spsolve
elif jax.default_backend() == 'tpu':
    # TODO: Use GMRES or dense
    pass
else:
    # CPU: dense solver is fast enough
    pass
```
Need to extend this to:
1. Detect NVIDIA vs AMD vs Intel GPU
2. Select the appropriate solver for each
3. Fall back gracefully when the optimal solver is unavailable
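One way to get finer-grained detection is to inspect `device_kind` on the first JAX device; the substring checks below are assumptions about vendor naming (e.g. NVIDIA parts report names like "NVIDIA A100", AMD parts report "gfx..." architecture strings) and should be verified per platform:

```python
import jax

def select_solver():
    """Pick a solver name based on the detected backend (hedged sketch)."""
    dev = jax.devices()[0]
    if dev.platform == "gpu":
        kind = dev.device_kind.lower()
        if "nvidia" in kind or "tesla" in kind:
            return "spineax"  # cuDSS path, NVIDIA only
        return "gmres"        # non-NVIDIA GPU: iterative fallback
    if dev.platform == "tpu":
        return "gmres"        # or dense for small systems
    return "dense"            # CPU: dense solver is fast enough

solver = select_solver()
```

The string "spineax"/"gmres"/"dense" would then index into a solver registry, so adding an Intel or ROCm path is a one-line change.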
References¶
JAX Sparse¶
- jax.experimental.sparse docs
- JAX Issue #6544 - Sparse matrices (CLOSED)
- JAX Issue #11376 - scipy.sparse.linalg (OPEN)
- JAX PR #6555 - MLIR/TACO sparse (CLOSED - stale)