求多个torch tensor的平均值

  统计/机器学习 Python    浏览次数:16843        分享
0

我有一串torch tensor,shape都是一样的,假设都是5x3的,我怎么对它们求均值呢?就是各个位置的均值,最后得到的还是一个5x3的torch tensor。

假设我有四个tensor,t1, t2, t3, t4。我进行下面的操作

torch.mean([t1, t2, t3, t4], 0)

结果报错

TypeError: mean() received an invalid combination of arguments - got (list, int), but expected one of:
 * (Tensor input)
 * (Tensor input, torch.dtype dtype)
      didn't match because some of the arguments have invalid types: (list, int)
 * (Tensor input, int dim, torch.dtype dtype, Tensor out)
 * (Tensor input, int dim, bool keepdim, torch.dtype dtype, Tensor out)
 * (Tensor input, int dim, bool keepdim, Tensor out)


 

zzzz   2019-05-23 11:45



   1个回答 
3

需要stack之后再mean

torch.mean(torch.stack([t1, t2, t3, t4]), 0)
SofaSofa数据科学社区DS面试题库 DS面经

xiaosu   2019-05-23 14:53

多谢大神 - zzzz   2019-05-24 10:06


  相关讨论

python怎么只保留某列前几个数字

获取DataFrame所占空间的大小

如何用python获取一个文件的最后修改时间?

map lambda if逗号报错,invalid syntax

python里@property有什么用

怎么check dataframe 中的某个元素是否字符串?

怎么合并(串联)两个dataframe?

用python生成一个取值在a到b之间的随机矩阵

python怎么去除字符串中的连字符?

python里的<<或者>>符号是什么意思?

  随便看看

怎么利用permutation importance来解释xgboost模型的特征?

sklearn中的predict_proba方法的返回值的意义

plt.show()之后matplotlib图像依然不展示

如何获取pyspark DataFrame的行数和列数?

训练神经网络中经常提到的epoch是什么意思