第四章 基于PyTorch的可视化工具
1. 网络结构的可视化
本节定义一个简单的 CNN 对 MNIST 手写数字数据集进行分类，并借助相关的可视化库对其网络结构进行可视化。
模块导入
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
import torchvision.utils as vutils
from sklearn.metrics import accuracy_score
from torch.optim import SGD
数据准备
# Load the MNIST training set; ToTensor() turns each image into a float
# tensor in [0, 1].
train_data = torchvision.datasets.MNIST(
    root="./data/MNIST",
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=True,  # skipped when the files already exist under root
)
train_loader = Data.DataLoader(
    dataset=train_data,
    batch_size=256,
    shuffle=True,
)

# Peek at a single batch to check the tensor shapes (the previous
# loop-and-break fetched two batches just to inspect one).
b_x, b_y = next(iter(train_loader))
print(b_x.shape)
print(b_y.shape)

# The whole test set is evaluated at once, so build one tensor from the raw
# uint8 images instead of going through a DataLoader.
test_data = torchvision.datasets.MNIST(
    root="./data/MNIST",
    train=False,
    download=True,
)
test_data_x = test_data.data.float() / 255
# Insert the channel axis at dim 1 so the layout matches the training batches.
test_data_x = torch.unsqueeze(test_data_x, dim=1)
test_data_y = test_data.targets
print("Test_x", test_data_x.shape)
print("Test_y", test_data_y.shape)
结果如下:
'''
torch.Size([256, 1, 28, 28])
torch.Size([256])
Test_x torch.Size([10000, 1, 28, 28])
Test_y torch.Size([10000])
'''
网络搭建
class ConvNet(nn.Module):
    """Small CNN classifier for single-channel 28x28 images.

    Two convolution stages (3x3 conv + ReLU + 2x2 pooling) reduce the input
    to a 32x7x7 feature map. An MLP head then maps that map to 10 logits.
    """

    def __init__(self):
        super().__init__()
        # Stage 1: 1 -> 16 channels, spatial size halved by average pooling.
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 16, 3, padding=1, stride=1),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=2, stride=2),
        )
        # Stage 2: 16 -> 32 channels, spatial size halved by max pooling.
        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )
        # Fully connected head over the flattened 32*7*7 feature map.
        self.mlp = nn.Sequential(
            nn.Linear(in_features=32 * 7 * 7, out_features=128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
        )
        self.out = nn.Linear(64, 10)

    def forward(self, x):
        """Return class logits of shape (batch, 10)."""
        features = self.conv1(x)
        features = self.conv2(features)
        flat = features.view(features.size(0), -1)
        hidden = self.mlp(flat)
        return self.out(hidden)


Conv1 = ConvNet()
print(Conv1)
HiddenLayer可视化
import hiddenlayer as hl

# Build the graph of Conv1 by running it on a dummy 1x1x28x28 input.
# NOTE(review): hiddenlayer is no longer actively maintained; confirm that
# build_graph works with the installed PyTorch version.
hl_graph = hl.build_graph(Conv1, torch.zeros([1, 1, 28, 28]))
# Work on a copy so the shared built-in theme dict is left untouched.
hl_graph.theme = hl.graph.THEMES["blue"].copy()
# Save relative to the working directory instead of a user-specific absolute
# path (the original pointed at one author's Windows Desktop).
hl_graph.save("./0001.png", format="png")
PyTorchViz可视化
from torchviz import make_dot

# A random input that requires grad, so it shows up as a node in the graph.
x = torch.randn(1, 1, 28, 28).requires_grad_(True)
y = Conv1(x)
# Label every parameter node with its name, and the input node with "x".
ConvVis = make_dot(y, params={**dict(Conv1.named_parameters()), "x": x})
ConvVis.format = "png"
# Render into the working directory rather than a user-specific absolute path.
ConvVis.directory = "."
ConvVis.view()
2. 训练过程可视化
2.1 tensorboardX
tensorboardX的部分API如下:
函数
功能
用法
SummaryWriter()
创建编写器,保存日志
writer=SummaryWriter()
writer.add_scalar()
添加标量
writer.add_scalar('myscalar',value,iteration)
writer.add_image()
添加图像
writer.add_image('imresult',x,iteration)
writer.add_histogram()
添加直方图
writer.add_histogram('hist',array,iteration)
writer.add_graph()
添加网络结构
writer.add_graph(model,input_to_model=None)
writer.add_audio()
添加音频
writer.add_audio(tag,audio,iteration,sample_rate)
writer.add_text()
添加文本
writer.add_text(tag,text_string,global_step=None)
以下案例展示了 tensorboardX 的使用方式：
from tensorboardX import SummaryWriter

# Train Conv1 (the ConvNet built above) and log loss, test accuracy, a batch
# of input images and parameter histograms to TensorBoard every print_step
# iterations.
Sumwriter = SummaryWriter(log_dir="data/log/chap4")
optimizer = torch.optim.Adam(Conv1.parameters(), lr=0.0003)
loss_func = nn.CrossEntropyLoss()
train_loss = 0.0
print_step = 100

for epoch in range(5):
    for step, (b_x, b_y) in enumerate(train_loader):
        output = Conv1(b_x)
        loss = loss_func(output, b_y)
        optimizer.zero_grad()
        # Without backward() the optimizer has no gradients and nothing trains.
        loss.backward()
        optimizer.step()
        # Accumulate a Python float: adding the loss tensor itself would keep
        # every iteration's autograd graph alive.
        train_loss += loss.item()
        niter = epoch * len(train_loader) + step + 1
        if niter % print_step == 0:
            # Running mean of the training loss over all iterations so far.
            Sumwriter.add_scalar(tag="train_loss",
                                 scalar_value=train_loss / niter,
                                 global_step=niter)
            # Evaluate on the full test set without building an autograd graph.
            with torch.no_grad():
                output = Conv1(test_data_x)
            _, pre_lab = torch.max(output, 1)
            acc = accuracy_score(test_data_y.numpy(), pre_lab.numpy())
            # float() works whether sklearn returns a Python or NumPy float.
            Sumwriter.add_scalar(tag="test acc", scalar_value=float(acc),
                                 global_step=niter)
            # Log the current training batch as an image grid, 12 per row.
            b_x_im = vutils.make_grid(b_x, nrow=12)
            Sumwriter.add_image(tag="train image sample", img_tensor=b_x_im,
                                global_step=niter)
            # One histogram per parameter tensor.
            for name, param in Conv1.named_parameters():
                Sumwriter.add_histogram(name, param.detach().cpu().numpy(), niter)

# Flush pending events to disk and release the log file.
Sumwriter.close()
3. Visdom可视化
Visdom 是 Facebook 专门为 PyTorch 开发的一款可视化工具，同时支持 Tensor 和 NumPy 两种数据格式。
函数
功能
vis.image
可视化一张图像
vis.images
可视化一个batch的图像,或者一个图像列表
vis.text
可视化文本
vis.audio
播放音频
vis.video
播放视频
vis.matplot
可视化matplotlib的图像
vis.scatter
2D或3D的散点图
vis.line
线图
vis.stem
茎叶图
vis.heatmap
热力图
vis.bar
条形图
vis.histogram
直方图
vis.boxplot
箱线图
vis.surf
曲面图
vis.contour
等高线图
vis.quiver
箭头图
vis.mesh
网格图
使用前需要先启动 Visdom 服务器（例如在命令行中运行 `python -m visdom.server`）。
以下使用一个案例来运用Visdom
# The Visdom server must already be running before the Visdom() client is
# created in the next snippet.
from visdom import Visdom
from sklearn.datasets import load_iris

# Iris feature matrix (ix) and integer class labels (iy).
iris = load_iris()
ix, iy = iris.data, iris.target
散点图
# Connect to the running Visdom server.
vis = Visdom()

# 2-D scatter of the first two iris features. Labels are shifted by one
# because visdom's scatter expects class labels numbered from 1.
vis.scatter(ix[:, 0:2], Y=iy + 1, win="windows1", env="main")

# 3-D scatter of the first three features, with small markers.
scatter3d_opts = dict(markersize=4, xlabel="特征1", ylabel="特征2")
vis.scatter(ix[:, 0:3], Y=iy + 1, win="3D 散点图", env="main",
            opts=scatter3d_opts)
折线图
# Compare sigmoid, tanh and ReLU on 100 points spanning [-6, 6].
x = torch.linspace(-6, 6, 100).view((-1, 1))  # column vector
sigmoid = torch.nn.Sigmoid()
tanh = torch.nn.Tanh()
relu = torch.nn.ReLU()
sigmoidy, tanhy, reluy = (act(x) for act in (sigmoid, tanh, relu))

# One column per curve; the shared x values are repeated for each column.
ploty = torch.cat((sigmoidy, tanhy, reluy), dim=1)
plotx = torch.cat((x, x, x), dim=1)
line_opts = dict(
    dash=np.array(["solid", "dash", "dashdot"]),
    legend=["sigmoid", "tanh", "relu"],
)
vis.line(Y=ploty, X=plotx, win="line plot", env="main", opts=line_opts)
茎叶图
# Stem plot of sin(x) and cos(x). visdom.stem takes the plotted values as X
# (one column per series) and the horizontal positions as Y. The names
# plotx/ploty therefore look swapped, but the arguments are passed correctly.
x = torch.linspace(-6, 6, 100).view((-1, 1))
y1 = torch.sin(x)
y2 = torch.cos(x)
plotx = torch.cat([y1, y2], dim=1)
ploty = torch.cat([x, x], dim=1)
stem_opts = dict(legend=["sin", "cos"], title="茎叶图")
vis.stem(X=plotx, Y=ploty, win="stem plot", env="main", opts=stem_opts)
热力图
# Correlation matrix of the iris features (rowvar=False: columns are the
# variables), shown as a heatmap.
feature_names = ["x1", "x2", "x3", "x4"]
irs_cor = torch.from_numpy(np.corrcoef(ix, rowvar=False))
heatmap_opts = dict(
    rownames=feature_names,
    columnnames=feature_names,
    title="热力图",
)
vis.heatmap(irs_cor, win="heatmap", env="main", opts=heatmap_opts)
图片
# Take one training batch to display. The previous loop loaded two batches
# and kept the second one; a single next() call is enough.
bx, _ = next(iter(train_loader))

# First image of the batch on its own, then the whole batch as a grid with
# 16 images per row.
vis.image(bx[0, ...], win="one image", env="Myimage",
          opts=dict(title="图片"))
vis.images(bx, win="batch image", env="Myimage", nrow=16,
           opts=dict(title="一批图片"))
文本
# Arbitrary sample text shown in a Visdom text pane.
text = '''
asdsadasfhoisauew
'''
text_opts = {"title": "文本"}
vis.text(text, win="text plot", env="Myimage", opts=text_opts)