一、使用ImageFolder读取数据集时忽略特定文件
如果事先知道需要忽略哪些文件,当然直接从数据集里删除就行了。但如果需要在程序运行时动态确认,或者筛选规则比较复杂,人工不好做,就需要让ImageFolder
在读取时使用自定义的筛选规则。
ImageFolder有一个可选参数为is_valid_file
,参数类型为可调用的函数,该函数传入一个str参数,返回一个bool值。当返回值为True时保留该文件,否则忽略。
例如,读取时想要忽略所有文件名带‘invalid’的文件,
代码如下:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
|
import platform from torchvision.datasets import ImageFolder class Check( object ): def __init__( self , key_word: str ): self .key_word = key_word self .separator = '\\' if platform.system() == ' Windows ' else ' / ' def __call__( self , file_name: str ) - > bool : folders = file_name.split( self .separator) return folders[ - 1 ].find( self .key_word) < 0 dataset = ImageFolder( './data' , is_valid_file = Check( 'invalid' )) |
这里定义了一个实现了__call__
方法的Check类,相比于直接定义函数的好处在于可以在构造函数里指定想要忽略的字符,并且能够根据操作系统的不同把文件目录分隔符给确定了。
更加复杂的功能可以自行修改代码逻辑实现,但是要注意如果某个类别的所有文件都被筛选掉了,ImageFolder
会报FileNotFoundError
错误。
如果想要忽略整个类别可以使用下面方法!!!
二、ImageFolder只读取部分类别文件夹
直接继承并且重写ImageFolder
类的find_classes
方法即可
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
|
from torchvision.datasets.folder import * from typing import * class FilterableImageFolder(ImageFolder): def __init__( self , root: str , transform: Optional[ Callable ] = None , target_transform: Optional[ Callable ] = None , loader: Callable [[ str ], Any ] = default_loader, is_valid_file: Optional[ Callable [[ str ], bool ]] = None , valid_classes: List = None ): self .valid_classes = valid_classes super (FilterableImageFolder, self ).__init__(root, transform, target_transform, loader, is_valid_file) def find_classes( self , directory: str ) - > Tuple [ List [ str ], Dict [ str , int ]]: classes = sorted (entry.name for entry in os.scandir(directory) if entry.is_dir()) #增加了这下面这句 classes = [valid_class for valid_class in classes if valid_class in self .valid_classes] if not classes: raise FileNotFoundError(f "Couldn't find any class folder in {directory}." ) class_to_idx = {cls_name: i for i, cls_name in enumerate (classes)} return classes, class_to_idx |
使用时,例如有mouse
、cat
、dog
三个类别的数据集文件夹,只想读取cat
和dog
,
代码如下:
1
|
dataset = FilterableImageFolder( './data' , valid_classes = [ 'cat' , 'dog' ]) |
到此这篇关于Pytorch中使用ImageFolder读取数据集时忽略特定文件的文章就介绍到这了,更多相关ImageFolder读取数据集内容请搜索服务器之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持服务器之家!
原文链接:https://blog.csdn.net/qq_44980390/article/details/123439341