Quick JAX 1
Posted on June 2, 2019
Tags: machinelearning
1 Basics
#imports
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
import jax.lax as lax
from jax import make_jaxpr #for AST
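Avoid Python loops and branches inside code you intend to `jit`. JAX traces Python control flow, so a Python loop is unrolled at trace time, and an `if` on a traced value fails. Use JAX's structured control-flow primitives instead:
- lax.fori_loop — a counted loop that threads a carried value through each iteration
- lax.scan — a loop over an array's leading axis that carries state and stacks the per-step outputs
- lax.cond — a two-way branch selected by a predicate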
# lax.fori_loop: sum the elements of `array` with a carried accumulator.
# fori_loop(lower, upper, body_fun, init_val) calls body_fun(i, val) for each
# i in [lower, upper) and feeds the returned value into the next iteration.
array = jnp.arange(10)
# Derive the upper bound from the data instead of repeating the literal 10.
output = lax.fori_loop(0, array.shape[0], lambda i, acc: acc + array[i], 0)
print(output)  # expected result: 45
# lax.scan: carry a running value through the elements of `arr`.
def func11(arr, extra):
    """Scan over ``arr`` with a running-sum carry.

    At step ``i`` the carry becomes ``carry + arr[i] * 1 + extra`` and the
    per-step output is the carry *before* that update, so the stacked outputs
    are the running sums shifted by one step.

    Args:
        arr: array scanned along its leading axis. The initial carry is a
            scalar, so in practice this only works for 1-D input.
        extra: value added to the carry on every step.

    Returns:
        ``(final_carry, per_step_outputs)`` exactly as returned by ``lax.scan``.
    """
    ones = jnp.ones(arr.shape)

    def body(carry, elems):
        # `elems` holds one slice of each array in the scanned tuple (arr, ones).
        ae1, ae2 = elems
        return carry + ae1 * ae2 + extra, carry

    # The initial carry is the Python float 0.
    return lax.scan(body, 0., (arr, ones))


output = func11(jnp.arange(16), 5.)
print(output)
# expected result (repr format varies with the JAX version):
# (DeviceArray(200., dtype=float32), DeviceArray([ 0., 5., 11., 18., 26., 35., 45., 56., 68., 81.,
# 95., 110., 126., 143., 161., 180.], dtype=float32, weak_type=True))
# lax.cond: select one of two branch functions with a predicate; the chosen
# function is applied to the operand(s) passed after the two branches.
array_operand = jnp.array([0., 1., 0.])
output = lax.cond(True, lambda x: x + 1, lambda x: x - 1, array_operand)
print(output)
# expected result: [1. 2. 1.]
1.1 AST (inspecting the traced program as a jaxpr with make_jaxpr)
def f(x):
    """Return ``sin(cos(x))``."""
    return jnp.sin(jnp.cos(x))


# make_jaxpr traces `f` with the example argument 3 and returns its jaxpr
# (JAX's intermediate representation), which plays the role of an AST here.
ExprAST = make_jaxpr(f)(3)
print(ExprAST)
# Example output (exact formatting depends on the JAX version):
# { lambda ; a:i32[]. let
#     b:f32[] = convert_element_type[new_dtype=float32 weak_type=True] a
#     c:f32[] = cos b
#     d:f32[] = sin c
#   in (d,) }