jit只传入数据的shape,dtype,而不是数据的value
import numpy as np
import time
import jax
from jax import numpy as jnp
from jax import random, device_put, jit, grad, vmap, lax, make_jaxpr
from functools import partial
@jit
def f(x):
return x.reshape(jnp.array(x.shape).prod())
@jit
def m(x):
return x.reshape((np.prod(x.shape),))
x = jnp.ones((2, 3))
m(x)
以上,调用
m(x)
是可行的,不会报错。而调用
f(x)
会报错!!!
版权声明:本文为qq_39909808原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。