我有一串torch tensor,shape都是一样的,假设都是5x3的,我怎么对它们求均值呢?就是各个位置的均值,最后得到的还是一个5x3的torch tensor。
假设我有四个tensor,t1, t2, t3, t4。我进行下面的操作
torch.mean([t1, t2, t3, t4], 0)
结果报错
TypeError: mean() received an invalid combination of arguments - got (list, int), but expected one of:
* (Tensor input)
* (Tensor input, torch.dtype dtype)
didn't match because some of the arguments have invalid types: (list, int)
* (Tensor input, int dim, torch.dtype dtype, Tensor out)
* (Tensor input, int dim, bool keepdim, torch.dtype dtype, Tensor out)
* (Tensor input, int dim, bool keepdim, Tensor out)
1个回答
需要stack之后再mean
torch.mean(torch.stack([t1, t2, t3, t4]), 0)
SofaSofa数据科学社区DS面试题库 DS面经
多谢大神
-
zzzz
2019-05-24 10:06