JAX中jit使用bug

  • Post author:
  • Post category:其他




jit只传入数据的shape,dtype,而不是数据的value

 import numpy as np
 import time
 import jax
 from jax import numpy as jnp
 from jax import random, device_put, jit, grad, vmap, lax, make_jaxpr
 from functools import partial
  
 @jit
 def f(x):
      return x.reshape(jnp.array(x.shape).prod())
  
 @jit
 def m(x):
      return x.reshape((np.prod(x.shape),))
  
 x = jnp.ones((2, 3))
 m(x)

以上,调用

m(x)

是可行的,不会报错。而调用

f(x)

会报错!!!

在这里插入图片描述



版权声明:本文为qq_39909808原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。