Quick JAX 1
Posted on June 2, 2019
Tags: machinelearning
1 Basics
#imports
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
import jax.lax as lax
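from jax import make_jaxpr  # for printing the traced program (jaxpr, used as an "AST" below)

DO NOT use Python loops or `if` statements on traced values. JAX traces your function into a pure, functional program: Python control flow cannot depend on runtime array values, and Python loops are unrolled at trace time. Use JAX's structured control-flow primitives instead, such as:
`lax.fori_loop`, `lax.scan`, `lax.cond`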
# lax.fori_loop(lower, upper, body, init) calls body(i, carry) for each i in
# [lower, upper) and returns the final carry. Here the carry is a running sum,
# so the result equals array.sum().
array = jnp.arange(10)
# Derive the upper bound from the array rather than repeating the literal 10,
# so resizing the array cannot silently desynchronize the loop bound.
output = lax.fori_loop(0, array.shape[0], lambda i, x: x + array[i], 0)
print(output)  # expected result: 45
# lax.scan
def func11(arr, extra):
    """Demonstrate ``lax.scan`` over two arrays scanned in lockstep.

    The carry starts at 0.0. At step i it becomes
    ``carry + arr[i] * 1.0 + extra``, and the carry value from *before* that
    update is emitted as step i's output.

    Args:
        arr: array scanned along its leading axis. The carry is a scalar, so
            each scanned element must be a scalar (i.e. ``arr`` is 1-D).
        extra: value added to the carry on every step.

    Returns:
        The ``(final_carry, stacked_per_step_outputs)`` pair from ``lax.scan``.
    """
    # A companion array of ones, so the scanned input is a tuple of arrays.
    unit = jnp.ones(arr.shape)

    def step(total, pair):
        value, weight = pair
        updated = total + value * weight + extra
        # (new carry, output for this step) -- the output is the old carry.
        return updated, total

    return lax.scan(step, 0., (arr, unit))
# Scan over 0..15, adding an extra 5 to the carry on every step.
output = func11(jnp.arange(16), extra=5.)
print(output)
# expected result (older JAX prints `DeviceArray(...)`; newer versions print `Array(...)`):
# (DeviceArray(200., dtype=float32), DeviceArray([ 0., 5., 11., 18., 26., 35., 45., 56., 68., 81.,
# 95., 110., 126., 143., 161., 180.], dtype=float32, weak_type=True))
# lax.cond(pred, true_fun, false_fun, operand): apply one of two branches.
array_operand = jnp.array([0., 1., 0.])
# The predicate is True, so the increment branch's result is returned.
output = lax.cond(
    True,
    lambda v: v + 1,  # branch used when the predicate is true
    lambda v: v - 1,  # branch used when the predicate is false
    array_operand,
)
print(output)
# expected result: [1. 2. 1.]

1.1 AST
def f(x):
    """Return sin(cos(x)), computed with jax.numpy."""
    inner = jnp.cos(x)
    return jnp.sin(inner)
# make_jaxpr(f) returns a function that traces f on its arguments and yields
# the traced program (a jaxpr). Passing the Python int 3 makes the traced
# input an int32 scalar, which is then converted to float32 before cos/sin.
ExprAST = make_jaxpr(f)(3)
print(ExprAST)
# Expected output:
# { lambda ; a:i32[]. let
#     b:f32[] = convert_element_type[new_dtype=float32 weak_type=True] a
#     c:f32[] = cos b
#     d:f32[] = sin c
#   in (d,) }