热门问题
时间线
聊天
视角

JAX

来自维基百科,自由的百科全书

JAX
Remove ads

JAX,是用於變換數值函數的Python機器學習框架,它由Google開發並具有來自Nvidia的一些貢獻[4][5][6]。它結合了修改版本的Autograd(自動通過函數的微分獲得其梯度函數)[7],和OpenXLA的XLA英語Accelerated Linear Algebra(加速線性代數[8]。它被設計為儘可能的遵從NumPy的結構和工作流程,並協同工作於各種現存的框架如TensorFlowPyTorch[9][10]

快速預覽 開發者, 首次發布 ...
Remove ads

主要功能

JAX的主要功能是[4]

grad

下面的代碼演示grad函數的自動微分。

# 导入库
from jax import grad
import jax.numpy as jnp

# 定义logistic函数
def logistic(x):  
    return jnp.exp(x) / (jnp.exp(x) + 1)

# 获得logistic函数的梯度函数
grad_logistic = grad(logistic)

# 求值logistic函数在x = 1处的梯度 
grad_log_out = grad_logistic(1.0)   
print(grad_log_out)

最終的輸出為:

0.19661194

jit

下面的代碼演示jit函數的優化。

# 导入库
from jax import jit
import jax.numpy as jnp

# 定义cube函数
def cube(x):
    return x * x * x

# 生成数据
x = jnp.ones((10000, 10000))

# 创建cube函数的jit版本
jit_cube = jit(cube)

# 应用cube函数和jit_cube函数于相同数据来比较其速度
cube(x)
jit_cube(x)

可見jit_cube的運行時間顯著的短於cube

vmap

下面的代碼展示vmap函數的通過SIMD的向量化。

# 导入库
from functools import partial
from jax import vmap
import jax.numpy as jnp

# 定义函数
def grads(self, inputs):
    in_grad_partial = partial(self._net_grads, self._net_params)
    grad_vmap = vmap(in_grad_partial)
    rich_grads = grad_vmap(inputs)
    flat_grads = np.asarray(self._flatten_batch(rich_grads))
    assert flat_grads.ndim == 2 and flat_grads.shape[0] == inputs.shape[0]
    return flat_grads
Remove ads

pmap

下面的代碼展示pmap函數的對矩陣乘法的並行化。

# 从JAX导入pmap和random;导入JAX NumPy
from jax import pmap, random
import jax.numpy as jnp

# 生成2个维度为5000 x 6000的随机数矩阵,每设备一个
random_keys = random.split(random.PRNGKey(0), 2)
matrices = pmap(lambda key: random.normal(key, (5000, 6000)))(random_keys)

# 没有数据传输,并行的在每个CPU/GPU上进行局部矩阵乘法 
outputs = pmap(lambda x: jnp.dot(x, x.T))(matrices)

# 没有数据传输,并行的在每个CPU/GPU上分别求取这两个矩阵的均值
means = pmap(jnp.mean)(outputs)
print(means)

最終的輸出為:

[1.1566595 1.1805978]
Remove ads

使用JAX的庫

一些Python庫使用JAX作為後端,這包括:

Remove ads

參見

引用

Loading content...

外部連結

Loading related searches...

Wikiwand - on

Seamless Wikipedia browsing. On steroids.

Remove ads