1,713 dependents
| Description | Downloads/month |
|---|---|
| Orbax provides common checkpointing and persistence utilities for JAX users | 11.1M |
| Flax is a neural network library for JAX that is designed for flexibility. | 5.4M |
|  | 4.4M |
| Optax is a gradient processing and optimization library for JAX. | 4.1M |
| Chex: Testing made fun, in JAX! | 2.2M |
| Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.... | 1.1M |
| Orbax provides common checkpointing and persistence utilities for JAX users | 895K |
| CLU lets you write beautiful training loops in JAX. | 830K |
| Accurate Quantized Training library. | 581K |
| Probabilistic programming with NumPy powered by JAX for autograd and JIT compila... | 577K |
| Hardware accelerated, batchable and differentiable optimizers in JAX. | 554K |
| a Jax quantization library | 494K |
| Package of Pathways-on-Cloud utilities | 410K |
| Task-based datasets, preprocessing, and evaluation for sequence models. | 347K |
| Linear solvers in JAX and Equinox. https://docs.kidger.site/lineax | 272K |
| Mathematical operations for JAX pytrees | 239K |
| jax library for E3 Equivariant Neural Networks | 237K |
|  | 228K |
| A Pallas Custom Kernel Library. | 227K |
| JAX compatible datetime and timedelta types | 221K |
| Task-based datasets, preprocessing, and evaluation for sequence models. | 216K |
| Nonlinear optimisation (root-finding, least squares, ...) in JAX+Equinox. https:... | 208K |
| Distrax: Probability distributions in JAX. | 206K |
| Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capab... | 180K |
| DrJAX - Scalable and Differentiable MapReduce Primitives in JAX. | 160K |
| a KLU solver for JAX | 149K |
| Multi-Joint dynamics with Contact. A general purpose physics simulator. | 138K |
| BlackJAX is a Bayesian Inference library designed for ease of use, speed and mod... | 135K |
| GDSFactory+: adds powerful features such as foundry PDKs, simulations, and verif... | 117K |
| Autograd and XLA for S-parameters | 103K |
| Differentiable, Hardware Accelerated, Molecular Dynamics | 95K |
| Kaggle Environments | 87K |
| [ICML'26] Phonon fine-tuning (PFT) and [NeurIPS'25 AI4Mat] Nequix: Training a fo... | 85K |
| Pytrees + dataclasses ❤️ | 76K |
| Named Tensors for Legible Deep Learning in JAX | 75K |
| Efficiently Composable Data Augmentation on the GPU with Jax | 75K |
| Open source code for AlphaFold. | 73K |
| Differentiable neuron simulations. | 69K |
| Solvers for tridiagonal systems in JAX. | 66K |
| Rigid transforms + Lie groups for JAX | 66K |
| Rax is a Learning-to-Rank library written in JAX. | 65K |
| Interpolation and function approximation with JAX | 56K |
| Massively parallel rigidbody physics simulation on accelerator hardware. | 53K |
| Building a single-cell transcriptome-based coordinate system | 49K |
| RL Environments in JAX 🌍 | 48K |
| Support PyTorch model conversion with LiteRT. | 43K |
| Extending JAX with xDSL. | 39K |
| export JAX to ONNX | 37K |
| A library of reinforcement learning building blocks in JAX. | 37K |
| A JAX-native High Performance Eval Metrics Library | 35K |