JAX (software)

Machine learning framework designed for parallelization and automatic differentiation

From Wikipedia, the free encyclopedia

JAX is a Python library for accelerator-oriented array computation and program transformation, designed for high-performance numerical computing and large-scale machine learning. It is developed by Google with contributions from Nvidia and other community contributors.[1][2][3]

It is described as bringing together a modified version of autograd (which automatically obtains the gradient function of a given function by differentiating it) and OpenXLA's XLA (Accelerated Linear Algebra) compiler. It is designed to follow the structure and workflow of NumPy as closely as possible and works with various existing frameworks such as TensorFlow and PyTorch.[4][5] The primary features of JAX are:[6]

  1. A unified NumPy-like interface to computations that run on CPU, GPU, or TPU, in local or distributed settings.
  2. Built-in just-in-time (JIT) compilation via OpenXLA, an open-source machine learning compiler ecosystem.
  3. Efficient evaluation of gradients via its automatic differentiation transformations.
  4. Automatic vectorization of functions, so that they can be efficiently mapped over arrays representing batches of inputs.
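
The features listed above can be illustrated with a minimal sketch that composes `jax.grad`, `jax.vmap`, and `jax.jit` on a small function written against the NumPy-like `jax.numpy` interface. The function `loss` and the arrays `w`, `xs`, and `ys` are hypothetical names chosen for illustration and do not come from the JAX documentation; the code runs on whichever backend (CPU, GPU, or TPU) is available.

```python
import jax
import jax.numpy as jnp


def loss(w, x, y):
    """Squared error of a linear prediction for a single example (illustrative)."""
    pred = jnp.dot(w, x)
    return (pred - y) ** 2


# Automatic differentiation: gradient of loss with respect to its first argument, w.
grad_loss = jax.grad(loss)

# Automatic vectorization: map the per-example gradient over a batch of (x, y)
# pairs while holding w fixed (in_axes=None for w, axis 0 for x and y).
batched_grad = jax.vmap(grad_loss, in_axes=(None, 0, 0))

# JIT compilation: compile the batched gradient function with XLA.
fast_batched_grad = jax.jit(batched_grad)

w = jnp.array([1.0, 2.0])
xs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
ys = jnp.array([1.0, 1.0, 1.0])

grads = fast_batched_grad(w, xs, ys)
print(grads.shape)  # (3, 2): one gradient vector per example in the batch
```

Because each transformation takes a function and returns a function, they can be applied in a different order (for example, differentiating a vectorized function) without changing how `loss` itself is written.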