
Roadmap: GPU/TPU-Agnostic Sparse Linear Solvers in JAX

Date: 2025-12-20
Status: Draft
Authors: Robert Taylor, Claude

Goal

Enable efficient sparse linear solving in VAJAX that works across CPU, GPU (NVIDIA, AMD), and TPU, with cached symbolic factorization for Newton-Raphson iterations.

Current State of JAX Sparse Support

What Exists

Feature                     Status        Source
BCOO format                 Stable        PR #6824
BCSR format                 Stable        jax.experimental.sparse
spsolve                     Experimental  Uses cuSOLVER on GPU, no caching
gmres, cg, bicgstab         Experimental  Issue #11376
Sparse matmul (SpMV, SpMM)  Stable        Uses cuSPARSE on GPU

What's Missing

  1. Cached symbolic factorization - Every spsolve call redoes METIS ordering
  2. TPU sparse support - Limited to dense operations
  3. Sparse triangular solve - Not exposed as primitive
  4. Pre-factorization API - JAX Issue #22500
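On CPU, SciPy already exposes the factor-once/solve-many pattern that item 1 asks for. A minimal sketch of the behavior JAX's spsolve currently lacks (the 3x3 matrix is illustrative, not VAJAX data; note splu reuses the full factorization across right-hand sides, while Newton steps with changing values would only reuse the symbolic analysis, which libraries like KLU or cuDSS separate out):

```python
import numpy as np
from scipy.sparse import csc_matrix
from scipy.sparse.linalg import splu

# Illustrative 3x3 sparse system.
A = csc_matrix(np.array([[4.0, 1.0, 0.0],
                         [1.0, 3.0, 1.0],
                         [0.0, 1.0, 2.0]]))

lu = splu(A)                               # ordering + factorization, done once
x1 = lu.solve(np.array([1.0, 0.0, 0.0]))   # each solve reuses the factor
x2 = lu.solve(np.array([0.0, 1.0, 0.0]))
```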

Relevant JAX/XLA/MLIR Work

Closed/Stale PRs (Potential Revival)

PR     Description                           Status          Relevance
#6555  MLIR/TACO-like sparse representation  Closed (stale)  HIGH - Explored MLIR sparse tensor integration
#4422  Add experimental sparse support       Merged          Foundation for current sparse module
#2566  scipy.sparse.linalg.cg                Merged          Iterative solver baseline

Active/Recent Work

PR/Issue      Description                           Status             Relevance
#25958        Performance note on sparse docs       Merged (Jan 2025)  Documents current limitations
#11376        Development of scipy.sparse.linalg    Open               Tracking issue for sparse solvers
LLVM #151885  MLIR sparse loop ordering heuristics  Active             30% speedup on sparse workloads

MLIR Sparse Tensor Dialect

The MLIR Sparse Tensor Dialect provides:

  • ~40 sparse tensor operations
  • Multiple storage formats (COO, CSR, CSC, etc.)
  • Automatic code generation from sparsity-agnostic Linalg ops
  • GPU codegen (experimental)

Gap: No direct LU factorization or triangular solve ops. Focus is on SpMM/SpMV.

Reference: Compiler Support for Sparse Tensor Computations in MLIR

XLA Sparse Support

XLA has limited sparse support:

  • SparseTensorDotGeneral for sparse matmul
  • Custom calls to cuSPARSE for GPU
  • No native sparse LU or triangular solve

Relevant issues:

  • openxla/xla#6834 - Triangular solve integer overflow (fixed)
  • openxla/xla#6871 - Fix for above

Proposed Architecture

Phase 1: Hybrid Backend (Immediate)

                    ┌─────────────────┐
                    │    VAJAX NR     │
                    │    Solver       │
                    └────────┬────────┘
              ┌──────────────┼──────────────┐
              │              │              │
              ▼              ▼              ▼
        ┌─────────┐   ┌─────────┐   ┌─────────┐
        │ Spineax │   │  GMRES  │   │  Dense  │
        │ (cuDSS) │   │ + Prec  │   │  BLAS   │
        └─────────┘   └─────────┘   └─────────┘
            GPU          TPU/CPU       CPU

Implementation: Done (Spineax integration committed)

Phase 2: Pure JAX Iterative Solver (Short-term)

Add GMRES with block-Jacobi preconditioner as agnostic fallback:

# vajax/analysis/iterative_solver.py

import jax
import jax.numpy as jnp


def build_block_jacobi_preconditioner(J_diag_blocks):
    """Build an apply-M^-1 closure from the diagonal blocks of the Jacobian."""
    inv_blocks = [jnp.linalg.inv(b) for b in J_diag_blocks]
    # Offsets of each block within the flat vector.
    offsets = [0]
    for b in J_diag_blocks:
        offsets.append(offsets[-1] + b.shape[0])

    def apply_M_inv(v):
        # Apply the block-diagonal inverse block by block.
        return jnp.concatenate(
            [inv @ v[offsets[i]:offsets[i + 1]]
             for i, inv in enumerate(inv_blocks)])

    return apply_M_inv


def gmres_solve(J, b, preconditioner=None, restart=30, maxiter=100):
    """Solve Jx = b with preconditioned GMRES; returns (x, info)."""
    return jax.scipy.sparse.linalg.gmres(
        lambda v: J @ v,
        b,
        M=preconditioner,
        restart=restart,
        maxiter=maxiter,
    )

Effort: 1-2 weeks
Benefit: Works on all backends, no custom ops
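As a sanity check of the API, the preconditioned call can be exercised end to end on a toy system. The 4x4 matrix and the 2x2 block split below are illustrative, not from VAJAX:

```python
import jax
import jax.numpy as jnp

# Toy 4x4 Jacobian with two 2x2 diagonal blocks (sizes are illustrative).
J = jnp.array([[4.0, 1.0, 0.0, 0.5],
               [1.0, 3.0, 0.5, 0.0],
               [0.0, 0.5, 5.0, 1.0],
               [0.5, 0.0, 1.0, 4.0]])
b = jnp.array([1.0, 2.0, 3.0, 4.0])

# Block-Jacobi preconditioner: invert each diagonal block once.
inv0 = jnp.linalg.inv(J[:2, :2])
inv1 = jnp.linalg.inv(J[2:, 2:])
M = lambda v: jnp.concatenate([inv0 @ v[:2], inv1 @ v[2:]])

# Matrix-free GMRES: only J @ v products are needed.
x, info = jax.scipy.sparse.linalg.gmres(
    lambda v: J @ v, b, M=M, restart=4, maxiter=50)
```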

Phase 3: XLA Sparse Primitive (Medium-term)

Add sparse triangular solve to XLA via custom call:

XLA HLO:
  SparseLU(sparse_matrix) -> (L, U, perm)  # Factorize
  SparseTriangularSolve(L, b) -> x         # Forward sub
  SparseTriangularSolve(U, x) -> y         # Back sub

This requires:

  1. Define HLO ops for sparse LU components
  2. Implement CPU lowering (via SuiteSparse/KLU)
  3. Implement GPU lowering (via cuDSS)
  4. Add JAX bindings

Effort: 2-3 months (significant XLA contribution)
Benefit: Native caching, works with XLA optimizations

Phase 4: MLIR Sparse Integration (Long-term)

Leverage MLIR Sparse Tensor Dialect:

  1. Define sparse LU in Linalg dialect - Express as sparse tensor operations
  2. Use MLIR sparsifier - Generate optimized code automatically
  3. Multi-backend codegen - CPU, GPU, TPU from single definition

Reference: MLIR Sparsifier JAX Colab

Effort: 6-12 months (research project)
Benefit: True platform independence, compiler-level optimization

Technical Challenges

1. Symbolic Factorization Caching in XLA

XLA's functional model doesn't naturally support mutable state (cached factorization).

Options:

  • Use XLA state tokens (like RNG state)
  • Store the factorization as a "constant" in the compiled program
  • Use external state via custom call
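The external-state option can be prototyped today with jax.pure_callback, which plays the role of a custom call with host-side state. This is a CPU-only sketch using SciPy's splu; the fixed 3x3 pattern is made up for illustration, and a real implementation would keep a cached symbolic analysis (e.g. in KLU or cuDSS) rather than refactoring each call:

```python
import numpy as np
import jax
import jax.numpy as jnp
from scipy.sparse import csr_matrix
from scipy.sparse.linalg import splu

# Fixed CSR sparsity pattern for a made-up 3x3 test matrix:
# [[4, 1, 0], [1, 3, 1], [0, 1, 2]] -> 7 stored entries.
n = 3
indices = np.array([0, 1, 0, 1, 2, 1, 2], dtype=np.int32)
indptr = np.array([0, 2, 5, 7], dtype=np.int32)


def _host_solve(data, b):
    # Runs on the host outside XLA. A real implementation would hold a
    # cached symbolic analysis here, keyed by the fixed pattern; splu
    # redoes the whole factorization and is only illustrative.
    A = csr_matrix((np.asarray(data, dtype=np.float64), indices, indptr),
                   shape=(n, n))
    x = splu(A.tocsc()).solve(np.asarray(b, dtype=np.float64))
    return x.astype(np.asarray(b).dtype)


def sparse_solve(data, b):
    """Jittable wrapper: ship the solve to the host via pure_callback."""
    return jax.pure_callback(
        _host_solve, jax.ShapeDtypeStruct(b.shape, b.dtype), data, b)
```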

2. TPU Sparse Support

TPUs are optimized for dense, regular computation. Sparse operations have overhead.

Options:

  • Dense solver for small matrices (our current approach works)
  • Padded block-sparse for structured sparsity
  • Iterative solvers (GMRES), which are more TPU-friendly
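The padding option can be sketched in a few lines. `multiple=8` keeps the example small; the right value for TPUs (e.g. 128 to match lane width) is an assumption to benchmark, and `pad_system` is a hypothetical helper, not existing VAJAX code:

```python
import jax.numpy as jnp


def pad_system(J, b, multiple=8):
    """Pad an n x n system up to the next multiple of `multiple`.

    Identity on the padded diagonal keeps the extra unknowns decoupled
    from the original system, so they solve to zero.
    """
    n = J.shape[0]
    m = (-n) % multiple
    if m == 0:
        return J, b
    Jp = jnp.pad(J, ((0, m), (0, m)))      # zero-pad rows and columns
    idx = jnp.arange(n, n + m)
    Jp = Jp.at[idx, idx].set(1.0)          # identity on padded diagonal
    return Jp, jnp.pad(b, (0, m))
```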

3. Dynamic Sparsity Patterns

Some applications have varying sparsity. Our circuit simulation has fixed patterns.

Options:

  • Re-analyze on pattern change (expensive but rare)
  • Approximate with a superset pattern
  • Fall back to an iterative solver
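The superset-pattern option amounts to taking the union of the expected COO index sets, so a single symbolic analysis covers all of them. A small sketch (the helper name is ours, not an existing API):

```python
import numpy as np


def superset_pattern(patterns):
    """Union of COO (row, col) index sets.

    `patterns` is a list of (rows, cols) arrays; the result is one
    (k, 2) array of unique coordinates. A matrix that lacks some entry
    of the superset simply stores an explicit zero there.
    """
    stacked = np.concatenate(
        [np.stack([r, c], axis=1) for r, c in patterns], axis=0)
    return np.unique(stacked, axis=0)
```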

Recommendations

For VAJAX (Immediate)

  1. ✅ Use Spineax on NVIDIA GPUs (done)
  2. Add GMRES fallback for TPU/other GPUs
  3. Keep dense option for small circuits

For JAX Ecosystem (Contribution)

  1. Propose spsolve_prefactor API to JAX (Issue #22500)
  2. Add sparse triangular solve primitive
  3. Document hybrid sparse strategy

For Long-term Research

  1. Explore MLIR sparsifier for LU
  2. Work with XLA team on sparse primitive design
  3. Investigate AMD ROCm sparse solver integration

TODO: TPU and Non-NVIDIA GPU Support

Open Issues

  • [ ] TPU CI failing - PR #1 adds TPU CI, but tests are timing out or being cancelled
      • TPU tests ran for 6+ hours before being cancelled
      • Need to investigate: is it a VM provisioning issue or an actual test hang?

TPU Support Tasks

  • [ ] Implement GMRES + block-Jacobi fallback solver
      • Create vajax/analysis/iterative_solver.py
      • Extract diagonal blocks from Jacobian for preconditioner
      • Use jax.scipy.sparse.linalg.gmres with matrix-free matvec
      • Benchmark on TPU vs dense solver

  • [ ] Test dense solver on TPU
      • Verify current dense path works on TPU
      • Measure performance for c6288 benchmark
      • Document TPU-specific tuning (e.g., padding for XLA efficiency)

  • [ ] Fix TPU CI workflow
      • Debug why TPU tests time out
      • Add proper timeout handling
      • Consider running smaller benchmarks for CI

Non-NVIDIA GPU Support Tasks

  • [ ] AMD ROCm investigation
      • Check if JAX spsolve works on ROCm
      • Investigate hipSPARSE/rocSOLVER for cached factorization
      • Look for a ROCm equivalent to cuDSS

  • [ ] Intel GPU investigation
      • Check JAX Intel plugin status
      • Investigate oneMKL sparse solver support

Backend Detection Strategy

Current auto-detection in runner.py:

if jax.default_backend() == 'gpu':
    try:
        from spineax.cudss.solver import CuDSSSolver
        use_spineax = True  # NVIDIA GPU with Spineax
    except ImportError:
        use_spineax = False  # Non-NVIDIA GPU, fall back to spsolve
elif jax.default_backend() == 'tpu':
    # TODO: Use GMRES or dense
    pass
else:
    # CPU: dense solver is fast enough
    pass

Need to extend this to:

  1. Detect NVIDIA vs AMD vs Intel GPU
  2. Select the appropriate solver for each
  3. Fall back gracefully when the optimal solver is unavailable
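One possible extension is sketched below. Matching on `device_kind` strings is a heuristic, not a documented contract, and the Spineax import mirrors the existing probe in runner.py:

```python
import jax


def select_solver():
    """Pick a solver family from the detected backend (heuristic sketch)."""
    backend = jax.default_backend()
    if backend == "gpu":
        kind = jax.devices()[0].device_kind.lower()
        # device_kind is a free-form string like "NVIDIA A100-SXM4-40GB";
        # substring matching here is an assumption, not an API guarantee.
        if "nvidia" in kind or "tesla" in kind or "rtx" in kind:
            try:
                from spineax.cudss.solver import CuDSSSolver  # noqa: F401
                return "spineax"   # NVIDIA GPU with Spineax/cuDSS
            except ImportError:
                return "gmres"     # NVIDIA GPU without Spineax installed
        return "gmres"             # AMD/Intel GPU: iterative fallback
    if backend == "tpu":
        return "gmres"             # or dense for small systems
    return "dense"                 # CPU: dense is fast enough for now
```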

References

JAX Sparse

MLIR Sparse

GPU Sparse Solvers

Iterative Solvers