# torchvision源码解读之transform.docx
# 《torchvision源码解读之transform.docx》由会员分享,可在线阅读,更多相关《torchvision源码解读之transform.docx(21页珍藏版)》请在冰豆网上搜索。
# ![torchvision源码解读之transform.docx](https://file1.bdocx.com/fileroot1/2023-2/1/6d83ce00-0ba4-4495-b2e7-8561f72de01a/6d83ce00-0ba4-4495-b2e7-8561f72de01a1.gif)
# torchvision源码解读之transform
from __future__ import division

import collections
import collections.abc
import math
import numbers
import random
import types
import warnings

import numpy as np
import torch
from PIL import Image, ImageOps, ImageEnhance

try:
    import accimage
except ImportError:
    accimage = None

from . import functional as F
# Public transform classes exported by ``from transforms import *``.
__all__ = [
    "Compose",
    "ToTensor",
    "ToPILImage",
    "Normalize",
    "Resize",
    "Scale",
    "CenterCrop",
    "Pad",
    "Lambda",
    "RandomCrop",
    "RandomHorizontalFlip",
    "RandomVerticalFlip",
    "RandomResizedCrop",
    "RandomSizedCrop",
    "FiveCrop",
    "TenCrop",
    "LinearTransformation",
    "ColorJitter",
    "RandomRotation",
    "Grayscale",
    "RandomGrayscale",
]
#Compose这个类是用来管理各个transform的,可以看到主要的__call__方法就是对输入图像img循环所有的transform操作
class Compose(object):
    """Composes several transforms together.

    Args:
        transforms (list of ``Transform`` objects): list of transforms to compose.
            They are applied in list order; each one receives the output of
            the previous one.

    Example:
        >>> transforms.Compose([
        ...     transforms.CenterCrop(10),
        ...     transforms.ToTensor(),
        ... ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        # Thread the image through every transform in order; an empty list
        # returns the input unchanged.
        for t in self.transforms:
            img = t(img)
        return img
# ToTensor类实现的是:
# Convert a PIL Image or numpy.ndarray to tensor 的过程,
#在PyTorch中常用PIL库来读取图像数据,因此这个方法相当于搭建了PILImage和Tensor的桥梁。
#在做数据归一化之前必须要把PILImage转成Tensor,而其他resize或crop操作则不需要。
class ToTensor(object):
    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.

    Converts a PIL Image or numpy.ndarray (H x W x C) in the range
    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
    """

    def __call__(self, pic):
        """
        Args:
            pic (PIL Image or numpy.ndarray): Image to be converted to tensor.

        Returns:
            Tensor: Converted image.
        """
        # The conversion itself is implemented in the functional module.
        return F.to_tensor(pic)
#ToPILImage顾名思义是从Tensor到PILImage的过程,和前面ToTensor类的相反的操作
class ToPILImage(object):
    """Convert a tensor or an ndarray to PIL Image.

    Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape
    H x W x C to a PIL Image while preserving the value range.

    Args:
        mode (`PIL.Image mode`_): color space and pixel depth of input data (optional).
            If ``mode`` is ``None`` (default) there are some assumptions made about the input data:
            1. If the input has 3 channels, the ``mode`` is assumed to be ``RGB``.
            2. If the input has 4 channels, the ``mode`` is assumed to be ``RGBA``.
            3. If the input has 1 channel, the ``mode`` is determined by the data type (i.e.
            ``int``, ``float``, ``short``).

    .. _PIL.Image mode: http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#modes
    """

    def __init__(self, mode=None):
        # Stored as-is and forwarded to the functional implementation on each call.
        self.mode = mode

    def __call__(self, pic):
        """
        Args:
            pic (Tensor or numpy.ndarray): Image to be converted to PIL Image.

        Returns:
            PIL Image: Image converted to PIL Image.
        """
        return F.to_pil_image(pic, self.mode)
#Normalize类是做数据归一化的,一般都会对输入数据做这样的操作,公式也在注释中给出了,比较容易理解。
#前面提到在调用Normalize的时候,输入得是Tensor,这个从__call__方法的输入也可以看出来了。
class Normalize(object):
    """Normalize a tensor image with mean and standard deviation.

    Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform
    will normalize each channel of the input ``torch.*Tensor`` i.e.
    ``input[channel] = (input[channel] - mean[channel]) / std[channel]``

    Args:
        mean (sequence): Sequence of means for each channel.
        std (sequence): Sequence of standard deviations for each channel.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Tensor image of size (C, H, W) to be normalized.

        Returns:
            Tensor: Normalized Tensor image.
        """
        # The per-channel arithmetic is delegated to the functional module.
        return F.normalize(tensor, self.mean, self.std)
#Resize类是对PILImage做resize操作的,几乎都要用到。
#这里输入可以是int,此时表示将输入图像的短边resize到这个int数,长边则根据对应比例调整,图像的长宽比不变。
#如果输入是个(h,w)的序列,h和w都是int,则直接将输入图像resize到这个(h,w)尺寸,
#相当于forceresize,所以一般最后图像的长宽比会变化,也就是图像内容被拉长或缩短。
#注意,在__call__方法中调用了functional.py脚本中的resize函数来完成resize操作,
#因为输入是PILImage,所以resize函数基本是在调用Image的各种方法。
#如果输入是Tensor,则对应函数基本是在调用Tensor的各种方法,这就是functional.py中的主要内容。
class Resize(object):
    """Resize the input PIL Image to the given size.

    Args:
        size (sequence or int): Desired output size. If size is a sequence like
            (h, w), output size will be matched to this. If size is an int,
            smaller edge of the image will be matched to this number.
            i.e, if height > width, then image will be rescaled to
            (size * height / width, size)
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        # ``collections.Iterable`` was removed in Python 3.10; the ABC lives in
        # ``collections.abc`` (available since Python 3.3).
        assert isinstance(size, int) or (isinstance(size, collections.abc.Iterable) and len(size) == 2), \
            "size must be an int or a 2-element sequence, got {!r}".format(size)
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be scaled.

        Returns:
            PIL Image: Rescaled image.
        """
        return F.resize(img, self.size, self.interpolation)
class Scale(Resize):
    """
    Note: This transform is deprecated in favor of Resize.
    """

    def __init__(self, *args, **kwargs):
        # stacklevel=2 attributes the warning to the code constructing Scale,
        # not to this module.
        warnings.warn("The use of the transforms.Scale transform is deprecated, " +
                      "please use transforms.Resize instead.", stacklevel=2)
        super(Scale, self).__init__(*args, **kwargs)
#CenterCrop是以输入图的中心点为中心点做指定size的crop操作,
#一般数据增强不会采用这个,因为当size固定的时候,在相同输入图像的情况下,N次CenterCrop的结果都是一样的。
#注释里面说明了size为int和序列时候尺寸的定义。
class CenterCrop(object):
    """Crops the given PIL Image at the center.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
    """

    def __init__(self, size):
        # A scalar size means a square crop; it is coerced to int so float
        # inputs such as 5.0 are accepted.
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be cropped.

        Returns:
            PIL Image: Cropped image.
        """
        return F.center_crop(img, self.size)
class Pad(object):
    """Pad the given PIL Image on all sides with the given "pad" value.

    Args:
        padding (int or tuple): Padding on each border. If a single int is provided this
            is used to pad all borders. If tuple of length 2 is provided this is the padding
            on left/right and top/bottom respectively. If a tuple of length 4 is provided
            this is the padding for the left, top, right and bottom borders
            respectively.
        fill: Pixel fill value. Default is 0. If a tuple of
            length 3, it is used to fill R, G, B channels respectively.

    Raises:
        ValueError: if ``padding`` is a tuple whose length is not 2 or 4.
    """

    def __init__(self, padding, fill=0):
        assert isinstance(padding, (numbers.Number, tuple))
        assert isinstance(fill, (numbers.Number, str, tuple))
        # ``collections.Sequence`` was removed in Python 3.10; use the ABC from
        # ``collections.abc`` instead.
        if isinstance(padding, collections.abc.Sequence) and len(padding) not in [2, 4]:
            raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " +
                             "{} element tuple".format(len(padding)))
        self.padding = padding
        self.fill = fill

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be padded.

        Returns:
            PIL Image: Padded image.
        """
        return F.pad(img, self.padding, self.fill)
class Lambda(object):
    """Apply a user-defined lambda as a transform.

    Args:
        lambd (callable): Lambda/function to be used for transform. Any callable
            is accepted (lambdas, named functions, builtins, ``functools.partial``
            objects, instances defining ``__call__``).
    """

    def __init__(self, lambd):
        # ``types.LambdaType`` only matches Python-level functions, which wrongly
        # rejected builtins and other callables; ``callable`` accepts all of them
        # while still failing with AssertionError on non-callables.
        assert callable(lambd), repr(type(lambd).__name__) + " object is not callable"
        self.lambd = lambd

    def __call__(self, img):
        return self.lambd(img)
#相比前面的CenterCrop,这个RandomCrop更常用,差别就在于crop时的中心点坐标是随机的,并不是输入图像的中心点坐标,
#因此基本上每次crop生成的图像都是有差异的。
# 具体是通过 i = random.randint(0, h - th) 和 j = random.randint(0, w - tw) 两行随机生成crop区域左上角的坐标(i为行方向偏移, j为列方向偏移),取值范围保证整个crop窗口落在图像内,因此crop区域的中心点也随之随机变化。
#注意到在__call__中最后是调用了F.crop(img,i,j,h,w)来完成crop操作,
#其实前面CenterCrop中虽然是调用F.center_crop(img,self.size),
#但是在F.center_crop()函数中只是先计算了中心点坐标,最后还是调用F.crop(img,i,j,h,w)完成crop操作。
class RandomCrop(object):
    """Crop the given PIL Image at a random location.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
        padding (int or sequence, optional): Optional padding on each border
            of the image. Default is 0, i.e no padding. If a sequence of length
            4 is provided, it is used to pad left, top, right, bottom borders
            respectively.
    """

    def __init__(self, size, padding=0):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.padding = padding

    @staticmethod
    def get_params(img, output_size):
        """Get parameters for ``crop`` for a random crop.

        Args:
            img (PIL Image): Image to be cropped.
            output_size (tuple): Expected output size of the crop, as (h, w).

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
                ``i`` and ``j`` are the row and column offsets of the crop
                window's top-left corner.

        Raises:
            ValueError: if the requested crop is larger than the image in
                either dimension.
        """
        w, h = img.size  # img.size is unpacked as (width, height)
        th, tw = output_size
        if w == tw and h == th:
            return 0, 0, h, w
        if th > h or tw > w:
            # Previously random.randint raised an opaque "empty range" ValueError.
            raise ValueError("Required crop size {} is larger than input image size {}".format(
                (th, tw), (h, w)))
        # Offsets are bounded so the whole crop window stays inside the image.
        i = random.randint(0, h - th)
        j = random.randint(0, w - tw)
        return i, j, th, tw

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be cropped.

        Returns:
            PIL Image: Cropped image.
        """
        padding = self.padding
        if padding is None:
            needs_pad = False
        elif isinstance(padding, numbers.Number):
            needs_pad = padding > 0
        else:
            # ``sequence > 0`` raises TypeError on Python 3, so the documented
            # sequence form is checked element-wise: pad if any side is non-zero.
            needs_pad = any(padding)
        if needs_pad:
            img = F.pad(img, padding)
        i, j, h, w = self.get_params(img, self.size)
        return F.crop(img, i, j, h, w)
class RandomHorizontalFlip(object):
    """Horizontally flip the given PIL Image randomly with a given probability.

    Args:
        p (float): probability of the image being flipped. Default value is 0.5,
            matching the previous hard-coded behavior.
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be flipped.

        Returns:
            PIL Image: Randomly flipped image.
        """
        # random.random() is uniform on [0, 1), so the flip happens with probability p.
        if random.random() < self.p:
            return F.hflip(img)
        return img
class RandomVerticalFlip(object):
    """Vertically flip the given PIL Image randomly with a given probability.

    Args:
        p (float): probability of the image being flipped. Default value is 0.5,
            matching the previous hard-coded behavior.
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be flipped.

        Returns:
            PIL Image: Randomly flipped image.
        """
        # random.random() is uniform on [0, 1), so the flip happens with probability p.
        if random.random() < self.p:
            return F.vflip(img)
        return img
#RandomResizedCrop类也是比较常用的,个人非常喜欢用。
#前面不管是CenterCrop还是RandomCrop,在crop的时候其尺寸是固定的,而这个类则是randomsize的crop。
# 该类主要用到3个参数:
# size、scale和ratio,总的来讲就是先做crop(用到scale和ratio),再resize到指定尺寸(用到size)。
#做crop的时候,其中心点坐标和长宽是由get_params方法得到的,
# 在get_params方法中主要用到两个参数:
# scale和ratio,首先在scale限定的数值范围内随机生成一个数,
#用这个数乘以输入图像的面积作为crop后图像的面积;
#然后在ratio限定的数值范围内随机生成一个数,表示长宽的比值,根据这两个值就可以得到c