Quick JAX 1

Posted on June 2, 2019
Tags: machinelearning

1 Basics

#imports
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
import jax.lax as lax

from jax import make_jaxpr #for AST

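Avoid Python loops (and Python `if` statements on traced values) in JAX code. JAX traces your function into a pure, functional program, so Python control flow is either unrolled at trace time or fails on traced values. Use JAX's structured control-flow primitives instead, such as `lax.fori_loop`, `lax.scan`, and `lax.cond`, as shown below.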

# lax.fori_loop: sum the elements of `array` with a functional loop.
array = jnp.arange(10)
# Derive the trip count from the array instead of hard-coding 10. JAX clamps
# out-of-bounds indices rather than raising, so a stale constant would silently
# give a wrong sum if the array's length changed.
n_steps = array.shape[0]
# Body signature is (loop_index, carry) -> new_carry; the carry starts at 0.
output = lax.fori_loop(0, n_steps, lambda i, acc: acc + array[i], 0)
print(output)  # expected result: 45

# lax.scan
def func11(arr, extra):
    """Scan over ``arr`` paired element-wise with ones, adding ``a * 1 + extra``.

    The carry starts at ``0.``. Each step adds ``value * weight + extra`` to it.

    Returns:
        ``(final_total, history)``. ``history[k]`` is the running total
        *before* step ``k`` was applied.
    """
    weights = jnp.ones(arr.shape)

    def step(total, pair):
        value, weight = pair
        new_total = total + value * weight + extra
        # The second element is stacked into `history`: the pre-update total.
        return new_total, total

    xs = (arr, weights)
    return lax.scan(step, 0., xs)
# Scan over 0..15 with extra=5. The final carry is sum(0..15) + 16 * 5 = 200, and
# the stacked outputs are the running totals *before* each step.
scan_xs = jnp.arange(16)
output = func11(scan_xs, 5.)
print(output)
# expected result (the printed repr format depends on the JAX version):
# (DeviceArray(200., dtype=float32), DeviceArray([  0.,   5.,  11.,  18.,  26.,  35.,  45.,  56.,  68.,  81.,
# 95., 110., 126., 143., 161., 180.],            dtype=float32, weak_type=True))


# lax.cond: the predicate is True, so the first branch (add one) is applied.
array_operand = jnp.array([0., 1., 0.])


def _add_one(x):
    return x + 1


def _sub_one(x):
    return x - 1


output = lax.cond(True, _add_one, _sub_one, array_operand)
print(output)
# expected result: [1. 2. 1.]

1.1 AST: inspecting the jaxpr (JAX's intermediate representation)

def f(x):
    """Return ``sin(cos(x))`` computed with ``jax.numpy``."""
    inner = jnp.cos(x)
    return jnp.sin(inner)


# make_jaxpr traces f with an example argument and returns its jaxpr.
# The Python int 3 is traced as i32, which is why the jaxpr below begins with a
# convert_element_type to float32.
trace_fn = make_jaxpr(f)
ExprAST = trace_fn(3)
print(ExprAST)
# { lambda ; a:i32[]. let
#     b:f32[] = convert_element_type[new_dtype=float32 weak_type=True] a
#     c:f32[] = cos b
#     d:f32[] = sin c
#   in (d,) }