490 dependents
| Package | Description | Downloads/month |
|---|---|---|
| | Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.... | 1.1M |
| | Orbax provides common checkpointing and persistence utilities for JAX users | 895K |
| | Linear solvers in JAX and Equinox. https://docs.kidger.site/lineax | 272K |
| | | 228K |
| | A Pallas Custom Kernel Library. | 227K |
| | Nonlinear optimisation (root-finding, least squares, ...) in JAX+Equinox. https:... | 208K |
| | Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capab... | 180K |
| | a KLU solver for JAX | 149K |
| | Autograd and XLA for S-parameters | 103K |
| | CUDA accelerated rasterization of gaussian splatting | 98K |
| | This API provides programmatic access to the AlphaGenome model developed by Goog... | 89K |
| | An implementation of transformers tailored for mechanistic interpretability. | 89K |
| | Medical imaging processing for AI applications. | 81K |
| | Named Tensors for Legible Deep Learning in JAX | 75K |
| | A Lightweight LLM Post-Training Library | 68K |
| | Interpolation and function approximation with JAX | 56K |
| | Support PyTorch model conversion with LiteRT. | 43K |
| | Open-source deep-learning framework for building, training, and fine-tuning deep... | 40K |
| | Legible, Scalable, Reproducible Foundation Models with Named Tensors and Jax | 34K |
| | All-in-one repository for state-of-the-art NeRFs | 34K |
| | Functionality such as layers for building neural networks in Jax. | 31K |
| | Implementation of Alphafold 3 from Google Deepmind in Pytorch | 28K |
| | (EasyDel Former) is a utility library designed to simplify and enhance the devel... | 24K |
| | Modular, scalable library to train ML models | 24K |
| | A Python package for probabilistic state space modeling with JAX | 23K |
| | Support PyTorch model conversion with LiteRT. | 21K |
| | Gemma open-weight LLM library, from Google DeepMind | 19K |
| | A high-level functional language for writing mathematically-precise specificatio... | 19K |
| | Generalizable Perception Stack for all things 3D, 4D & Scene Understanding | 18K |
| | Differentiable Ray Tracing Toolbox for Radio Propagation Simulations | 17K |
| | easydel jax kernels written in triton for gpus and pallas for tpus | 17K |
| | Sequential Least Squares Programming (SLSQP) optimizer implemented in pure JAX | 16K |
| | Implementation of MeshGPT, SOTA Mesh generation using Attention, in Pytorch | 16K |
| | From-scratch C++ and Python reimplementation of the Variational Moments Equilib... | 15K |
| | Unified Training of Universal Time Series Forecasting Transformers | 15K |
| | High-performance quantum systems simulation with JAX (GPU-accelerated & differen... | 14K |
| | Gaussian processes in JAX and Equinox. | 13K |
| | Efficient Differentiable n-d PDE Solvers in JAX. | 13K |
| | Support PyTorch model conversion with LiteRT. | 13K |
| | Neural Emulator Architectures in JAX. | 13K |
| | [Neurips 2024] A benchmark suite for autoregressive neural emulation of PDEs. (≥... | 13K |
| | Neural Emulator Architectures in JAX. | 13K |
| | GLM-HMM and GLM-HMMT tooling for behavioural task analysis. | 12K |
| | Normalizing-flow enhanced sampling package for probabilistic inference in JAX | 10K |
| | Fourier interpolation and function approximation with JAX | 9K |
| | Multiple dispatch over abstract array types in JAX. | 9K |
| | A small library of parameterizations and parameter constraints for PyTrees. | 8K |
| | The Theory of Functional Connections: A functional interpolation method with app... | 8K |
| | Unitful Quantities in JAX | 8K |
| | Parametric modeling in JAX + Equinox | 7K |