问题内容
我发现 jax 中的 vmap
在应用于多个参数时不会按预期运行。例如,考虑下面的函数:
def f1(x, y, z):
f = x[:, none, none] * z[none, none, :] + y[none, :, none]
return f
对于 x = jnp.arange(7), y = jnp.arange(5), z = jnp.arange(3)
,该函数的输出形状为 (7, 5, 3)
。但是,对于以下 vmap 版本:
@partial(vmap, in_axes=(none, 0, 0), out_axes=(1, 2))
def f2(x, y, z):
f = x*z + y
return f
它输出此错误:
ValueError: vmap got inconsistent sizes for array axes to be mapped:
* one axis had size 5: axis 0 of argument y of type int32[5];
* one axis had size 3: axis 0 of argument z of type int32[3]
有人可以解释一下这个错误背后的原因吗?
正确答案
vmap
的语义是它对一个或多个数组执行单个批处理操作。当您指定 in_axes=(none, 0, 0)
时,含义是“同时沿 y
和 z
的前导维度映射”:您看到的错误告诉您 y
和 y
的前导维度具有不同的大小,因此它们不兼容批处理。
您的函数 f1
本质上使用广播来编码三个批处理操作,因此要使用 vmap
复制该逻辑,您将需要 vmap
的三个应用程序。您可以这样表达:
@partial(vmap, in_axes=(0, None, None))
@partial(vmap, in_axes=(None, 0, None))
@partial(vmap, in_axes=(None, None, 0))
def f2(x, y, z):
f = x*z + y
return f
以上就是JAX `vmap` 对于多个参数的意外行为的详细内容,更多请关注编程网其它相关文章!
免责声明:
① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。
② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341
软考中级精品资料免费领
- 历年真题答案解析
- 备考技巧名师总结
- 高频考点精准押题
- 资料下载
- 历年真题
193.9 KB下载数265
191.63 KB下载数245
143.91 KB下载数1148
183.71 KB下载数642
644.84 KB下载数2756
相关文章
发现更多好内容猜你喜欢
AI推送时光机 咦!没有更多了?去看看其它编程学习网 内容吧