Aller au contenu

Est-ce que vous me connaissez ? Je suis JAX

Un framework qui séduit les ingénieurs avec ses performances impressionnantes, une horreur qui vous poussera à redécouvrir les maths, ou un outil assez cool à découvrir ? Rencontrez JAX !

JAX par Google

JAX, JAX, ... ça ne vous dit rien ? Une recherche sur Google vous ressortira un jeu vidéo, un aéroport aux États-Unis, et une bibliothèque de calcul à haute performance conçue dès le départ pour répondre aux besoins du machine learning… Et pourtant, JAX a bien été créé par Google, en 2018.

Mais que signifie réellement l'acronyme JAX ? Et est-ce vraiment un acronyme ?

  • La signification de J restera probablement un mystère, mais il existe plusieurs théories sur ce sujet.
  • A - Autograd, qui permet de différencier automatiquement du code Numpy et Python natif.
  • X - XLA (Accelerated Linear Algebra), un compilateur pour l'algèbre linéaire, qui permet (surprise, surprise !!!) de compiler des expressions algébriques en code de bas niveau et de haute performance sur des accélérateurs tels que les GPU ou les TPU.

Naturellement, le framework a été principalement utilisé par les chercheurs en machine learning. Après la compétition MLPerf en 2020, il a commencé à captiver son public en surpassant même Tensorflow sur certains modèles, et PyTorch de manière significative.

Alors, quelles sont les caractéristiques les plus marquantes de JAX ?

  • numpy - à part quelques exceptions, la prise en main est assez facile, car il permet de manipuler les array de la même manière que dans numpy.
import numpy as np

arr_np = np.array([3, 1, 5, 2, 4])
sorted_arr_np = np.sort(arr_np)
# [1 2 3 4 5]
import jax
import jax.numpy as jnp

arr_jnp = jnp.array([3, 1, 5, 2, 4])
sorted_arr_jnp = jnp.sort(arr_jnp)
# [1 2 3 4 5]
  • Le code JAX est agnostique envers les accélérateurs et peut être exécuté sur le CPU, le GPU ou le TPU, pas besoin de préciser explicitement le périphérique, contrairement à PyTorch, par exemple. Il reste tout de même des packages nécessaires à installer :
pip install jaxlib

# CPU
pip install "jax[cpu]"

# GPU (CUDA)
pip install "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

⚠️ Du fait que JAX se démarque principalement sur les accélérateurs, l'exécution du code sur un CPU provoquera un warning :
WARNING:jax._src.xla_bridge:No GPU/TPU found, falling back to CPU.

  • grad() - différenciation automatique qui retourne une fonction. Dans la plupart des cas, la différenciation est utilisée afin d'optimiser la "loss function" ou lors de "backpropagation". Dans le but de faciliter les tests, examinons son comportement sur des fonctions polynomiales.
import jax.numpy as jnp

from jax import grad

poly_f1 = lambda x: x**3 + 2*x**2 + 3*x + 7
df1dx = grad(poly_f1) # manual derivative : 3*x**2 + 4*x + 3

x = 2. # grad requires real- or complex-valued inputs
print(df1dx(x)) # 23.0

poly_f2 = lambda x,y: x**3 + 2*x**2 + 3*x*y + 4*y**2 + 7
df2dy = grad(poly_f2, argnums=(1)) # manual derivative : 3*x + 8*y

y = 3.
print(df2dy(x, y)) # 30.0
  • vmap() - vectorize la fonction 🤩. En machine learning, nous ne travaillons jamais avec un seul point séparé. Au moins, il vaut mieux ne pas le faire. Nous avons des arrays, des matrices, et qui sait quels autres éléments dans les boîtes noires du machine learning. Et qui dit tout cela, implique souvent l'utilisation de boucles for, qui peuvent avoir un impact significatif sur les performances. vmap() permet de prendre une fonction qui s'applique à un seul élément et de la transformer en une fonction "vectorisée" qui peut être appliquée à un ensemble d'éléments.
import jax
import jax.numpy as jnp

from jax import vmap

def loss_func(y_true, y_pred):
  return jnp.mean((y_true - y_pred) ** 2)

batch_size = 1000
vec_dim = 4

# no model, so let's test on some random data
true_values = jax.random.normal(jax.random.PRNGKey(42), (batch_size, vec_dim))
predictions = jax.random.normal(jax.random.PRNGKey(43), (batch_size, vec_dim))

# apply loss function over the batch using vmap
losses = vmap(loss_func)(true_values.T, predictions.T)

print(losses) # [1.9995344 1.9299097 2.0429108 2.0652385]

Et la première surprise, il existe une petite différence entre la génération des nombres "aléatoires" avec numpy et jax. JAX utilise des PRNG (pseudorandom number generators) explicitement pour générer des nombres "aléatoires". Nous utilisons jax.random.PRNGKey(0) pour fixer un random state et éliminer toute la randomité résiduelle du processus.

  • @jit ou jit() - Just-in-Time compilateur, qui utilise XLA pour compiler et optimiser le code en fonction des types et des formes de données, lors de la première exécution. Un cache du code est ensuite généré afin d'éviter de le recompiler si les mêmes paramètres sont utilisés lors des appels ultérieurs. Il est à noter que dans certains cas, la compilation avec jit peut s'avérer plus longue qu'une simple exécution, et les avantages seront perceptibles lors des utilisations du cache.
import jax
import jax.numpy as jnp

from jax import jit

def leaky_relu(x, alpha=0.2):
    return jnp.where(x >= 0, x, alpha * x)

x = jnp.array([-1, 0, 1, 2])

%timeit leaky_relu(x).block_until_ready()
# 438 µs ± 13 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

leaky_relu_compiled = jit(leaky_relu)

%timeit leaky_relu_compiled(x).block_until_ready()
# 70.2 µs ± 5.92 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

Bien que la fonction d'activation Leaky ReLU ne soit pas très complexe, l'utilisation de jit() nous apporte une amélioration considérable des performances.

La deuxième surprise : JAX utilise une exécution asynchrone. Cela sous-entend qu'il n'attend pas que les opérations soient terminées avant de rendre le contrôle au programme Python. JAX crée un DeviceArray qui n'est pas forcément disponible immédiatement, mais qui peut être transmis à d'autres opérations sans attendre la fin du calcul. Pour cette raison, si nous voulons estimer le temps d'exécution, il est nécessaire d'ajouter block_until_ready() pour indiquer à JAX d'attendre, jusqu'à ce que l'exécution soit complète.

En l'occurrence, JAX n'est pas directement comparable aux frameworks tels que Tensorflow ou PyTorch, car nous exploitons essentiellement ses fonctionnalités de haut niveau. Mais est-ce que cela implique que vos projets pour le week-end ont changé et que vous devez maintenant réviser rapidement les mathématiques et les bases de machine learning pour coder votre prochain modèle, neurone par neurone ? Ce n'est pas nécessairement le cas, parce que JAX ne vient pas seul, mais plutôt accompagné de toute une gamme de bibliothèques :

  • flax, haiku ou equinox pour concevoir des réseaux de neurones ;
  • jraph (prononcé "giraffe") pour les réseaux de neurones graphiques ou GNN (Graph Neural Networks) ;
  • rlax (prononcé "relax") pour implementer des agents d'apprentissage par renforcement reinforcement learning.

Maitenant que vouz connaissez JAX, une question fondamentale se pose : avez-vous envie de l'essayer ? Si la réponse est oui, voici le lien de sa documentation : JAX: High-Performance Array Computing.

Dernier