Pytorch中torch.cat()函数的使用及说明

01-03 53阅读 0评论

?=一. torch.cat()函数解析

1. 函数说明

1.1 官网:torch.cat()

函数定义及参数说明如下图所示:

Pytorch中torch.cat()函数的使用及说明

1.2 函数功能

函数将两个张量(Tensor)按指定维度拼接在一起,注意:除拼接维数dim数值可不同外其余维数数值需相同,方能对齐,如下面例子所示。

torch.cat()函数不会新增维度,而torch.stack()函数会新增一个维度,相同的是两个都是对张量进行拼接

2. 代码举例

2.1 输入两个二维张量(DIm=0):dim=0对行进行拼接

a = torch.randn(2,3) b =  torch.randn(3,3) c = torch.cat((a,b),dim=0) a,b,c 

输出结果如下:

(tensor([[-0.90, -0.37, 1.96],[-2.65, -0.60, 0.05]]),tensor([[ 1.30, 0.24, 0.27],[-1.99, -1.09, 1.67],[-1.62, 1.54, -0.14]]),tensor([[-0.90, -0.37, 1.96],[-2.65, -0.60, 0.05],[ 1.30, 0.24, 0.27],[-1.99, -1.09, 1.67],[-1.62, 1.54, -0.14]]))

2.2 输入两个二维张量(dim=1): dim=1对列进行拼接

a = torch.randn(2,3) b =  torch.randn(2,4) c = torch.cat((a,b),dim=1) a,b,c 

输出结果如下:

(tensor([[-0.55, -0.84, -1.60],[ 0.39, -0.96, 1.02]]),tensor([[-0.83, -0.09, 0.05, 0.17],[ 0.28, -0.74, -0.27, -0.85]]),tensor([[-0.55, -0.84, -1.60, -0.83, -0.09, 0.05, 0.17],[ 0.39, -0.96, 1.02, 0.28, -0.74, -0.27, -0.85]]))

2.3 输入两个三维张量:dim=0 对通道进行拼接

a = torch.randn(2,3,4) b =  torch.randn(1,3,4) c = torch.cat((a,b),dim=0) a,b,c 

输出结果如下:

(tensor([[[ 0.51, -0.72, -0.02, 0.76], [ 0.72, 1.01, 0.39, -0.13], [ 0.37, -0.63, -2.69, 0.74]],[[ 0.72, -0.31, -0.27, 0.10], [ 1.66, -0.06, 1.91, -0.66], [ 0.34, -0.23, -0.18, -1.22]]]),tensor([[[ 0.94, 0.77, -0.41, -1.20], [-0.23, -1.03, -0.25, 1.67], [-1.00, -0.68, -0.35, -0.50]]]),tensor([[[ 0.51, -0.72, -0.02, 0.76], [ 0.72, 1.01, 0.39, -0.13], [ 0.37, -0.63, -2.69, 0.74]],[[ 0.72, -0.31, -0.27, 0.10], [ 1.66, -0.06, 1.91, -0.66], [ 0.34, -0.23, -0.18, -1.22]],[[ 0.94, 0.77, -0.41, -1.20], [-0.23, -1.03, -0.25, 1.67], [-1.00, -0.68, -0.35, -0.50]]]))

2.4 输入两个三维张量:dim=1对行进行拼接

a = torch.randn(2,3,4) b =  torch.randn(2,4,4) c = torch.cat((a,b),dim=1) a,b,c 

输出结果如下:

(tensor([[[-0.86, 0.00, -1.26, 1.20], [-0.46, -1.08, -0.82, 2.03], [-0.89, 0.43, 1.92, 0.49]],[[ 0.24, -0.02, 0.32, 0.97], [ 0.33, -1.34, 0.76, -1.55], [ 0.38, 1.45, 0.27, -0.64]]]),tensor([[[ 0.82, 0.85, -0.30, -0.58], [-0.09, 0.40, 0.02, 0.75], [-0.70, 0.67, -0.88, -0.50], [-0.62, -1.65, -1.10, -1.39]],[[-0.85, -1.61, -0.35, -0.56], [ 0.00, 1.40, 0.41, 0.39], [-0.01, 0.04, 0.80, 0.41], [-1.21, -0.64, 1.14, 1.64]]]),tensor([[[-0.86, 0.00, -1.26, 1.20], [-0.46, -1.08, -0.82, 2.03], [-0.89, 0.43, 1.92, 0.49], [ 0.82, 0.85, -0.30, -0.58], [-0.09, 0.40, 0.02, 0.75], [-0.70, 0.67, -0.88, -0.50], [-0.62, -1.65, -1.10, -1.39]],[[ 0.24, -0.02, 0.32, 0.97], [ 0.33, -1.34, 0.76, -1.55], [ 0.38, 1.45, 0.27, -0.64], [-0.85, -1.61, -0.35, -0.56], [ 0.00, 1.40, 0.41, 0.39], [-0.01, 0.04, 0.80, 0.41], [-1.21, -0.64, 1.14, 1.64]]]))

2.5 输入两个三维张量:dim=2对列进行拼接

a = torch.randn(2,3,4) b =  torch.randn(2,3,5) c = torch.cat((a,b),dim=2) a,b,c 

输出结果如下:

(tensor([[[ 0.13, -0.02, 0.13, -0.25], [ 1.42, -0.22, -0.87, 0.27], [-0.07, 1.04, -0.06, 0.91]],[[ 0.88, -1.46, 0.04, 0.35], [ 1.36, 0.64, 0.75, 0.39], [ 0.36, 1.13, 0.83, 0.56]]]),tensor([[[-0.47, -2.30, -0.49, -1.02, 1.74], [ 0.71, 0.89, 0.80, -0.05, -1.35], [-0.40, 0.26, -0.78, -1.50, -0.92]],[[-0.77, -0.01, 1.23, 0.70, -0.66], [ 0.28, -0.18, -0.91, 2.23, 1.14], [-1.93, -0.17, 0.15, 0.40, 0.32]]]),tensor([[[ 0.13, -0.02, 0.13, -0.25, -0.47, -2.30, -0.49, -1.02, 1.74], [ 1.42, -0.22, -0.87, 0.27, 0.71, 0.89, 0.80, -0.05, -1.35], [-0.07, 1.04, -0.06, 0.91, -0.40, 0.26, -0.78, -1.50, -0.92]],[[ 0.88, -1.46, 0.04, 0.35, -0.77, -0.01, 1.23, 0.70, -0.66], [ 1.36, 0.64, 0.75, 0.39, 0.28, -0.18, -0.91, 2.23, 1.14], [ 0.36, 1.13, 0.83, 0.56, -1.93, -0.17, 0.15, 0.40, 0.32]]]))

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持云初冀北。

免责声明
本站提供的资源,都来自网络,版权争议与本站无关,所有内容及软件的文章仅限用于学习和研究目的。不得将上述内容用于商业或者非法用途,否则,一切后果请用户自负,我们不保证内容的长久可用性,通过使用本站内容随之而来的风险与本站无关,您必须在下载后的24个小时之内,从您的电脑/手机中彻底删除上述内容。如果您喜欢该程序,请支持正版软件,购买注册,得到更好的正版服务。侵删请致信E-mail:Goliszhou@gmail.com
$

发表评论

表情:
评论列表 (暂无评论,53人围观)

还没有评论,来说两句吧...