Google JAX
Biblioteca Python From Wikipedia, the free encyclopedia
Remove ads
Google JAX és un marc d'aprenentatge automàtic per transformar funcions numèriques.[1][2] Es descriu com reunir una versió modificada d'autograd (obtenció automàtica de la funció de gradient mitjançant la diferenciació d'una funció) i XLA de TensorFlow (àlgebra lineal accelerada). Està dissenyat per seguir l'estructura i el flux de treball de NumPy tan de prop com sigui possible i funciona amb diversos marcs existents com TensorFlow i PyTorch.[3][4] Les funcions principals de JAX són:
- grau: diferenciació automàtica
- jit: compilació
- vmap: vectorització automàtica
- pmap: programació SPMD
Remove ads
Funció grau
El codi següent mostra la diferenciació automàtica de la funció de graduació .
# imports
from jax import grad
import jax.numpy as jnp
# define the logistic function
def logistic(x):
return jnp.exp(x) / (jnp.exp(x) + 1)
# obtain the gradient function of the logistic function
grad_logistic = grad(logistic)
# evaluate the gradient of the logistic function at x = 1
grad_log_out = grad_logistic(1.0)
print(grad_log_out)
Remove ads
Funcio jit
El codi següent mostra l'optimització de la funció jit mitjançant la fusió.
# imports
from jax import jit
import jax.numpy as jnp
# define the cube function
def cube(x):
return x * x * x
# generate data
x = jnp.ones((10000, 10000))
# create the jit version of the cube function
jit_cube = jit(cube)
# apply the cube and jit_cube functions to the same data for speed comparison
cube(x)
jit_cube(x)
Remove ads
Funció vmap
El codi següent mostra la vectorització de la funció vmap.
# imports
from functools import partial
from jax import vmap
import jax.numpy as jnp
# define function
def grads(self, inputs):
in_grad_partial = partial(self._net_grads, self._net_params)
grad_vmap = vmap(in_grad_partial)
rich_grads = grad_vmap(inputs)
flat_grads = np.asarray(self._flatten_batch(rich_grads))
assert flat_grads.ndim == 2 and flat_grads.shape[0] == inputs.shape[0]
return flat_grads
Funció pmap
El codi següent mostra la paral·lelització de la funció pmap per a la multiplicació de matrius.
# import pmap and random from JAX; import JAX NumPy
from jax import pmap, random
import jax.numpy as jnp
# generate 2 random matrices of dimensions 5000 x 6000, one per device
random_keys = random.split(random.PRNGKey(0), 2)
matrices = pmap(lambda key: random.normal(key, (5000, 6000)))(random_keys)
# without data transfer, in parallel, perform a local matrix multiplication on each CPU/GPU
outputs = pmap(lambda x: jnp.dot(x, x.T))(matrices)
# without data transfer, in parallel, obtain the mean for both matrices on each CPU/GPU separately
means = pmap(jnp.mean)(outputs)
print(means)
Remove ads
Biblioteques que utilitzen JAX
Diverses biblioteques de Python utilitzen JAX com a backend, incloent:
- Flax, una biblioteca de xarxes neuronals d'alt nivell desenvolupada inicialment per Google Brain.
- Equinox, una biblioteca que gira al voltant de la idea de representar funcions parametritzades (incloses les xarxes neuronals) com a PyTrees. Va ser creat per Patrick Kidger.
- Diffrax, una biblioteca per a la solució numèrica d'equacions diferencials, com ara equacions diferencials ordinàries i equacions diferencials estocàstiques.
- Optax, una biblioteca per al processament i optimització de gradients desenvolupada per DeepMind.
- Lineax, una biblioteca per resoldre numèricament sistemes lineals i mínims quadrats lineals.
- RLax, una biblioteca per desenvolupar agents d'aprenentatge de reforç desenvolupada per DeepMind.
- jraph, una biblioteca per a xarxes neuronals gràfics, desenvolupada per DeepMind.
- jaxtyping, una biblioteca per afegir anotacions de tipus per a la forma i el tipus de dades ("dtype") de matrius o tensors.
Remove ads
Referències
Wikiwand - on
Seamless Wikipedia browsing. On steroids.
Remove ads
