-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathvocabulary.py
More file actions
206 lines (178 loc) · 5.91 KB
/
vocabulary.py
File metadata and controls
206 lines (178 loc) · 5.91 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
import logging
from collections import Counter
from common import save_pickle
from common import load_pickle
# Module-level logger named after this module. NOTE(review): no code shown in
# this file currently logs through it.
logger = logging.getLogger(__name__)
"""
创建一个词表
"""
class Vocabulary(object):
def __init__(self, max_size=None,
min_freq=None,
pad_token="[PAD]",
unk_token = "[UNK]",
cls_token = "[CLS]",
sep_token = "[SEP]",
mask_token = "[MASK]",
add_unused = False):
self.max_size = max_size
self.min_freq = min_freq
self.cls_token = cls_token
self.sep_token = sep_token
self.pad_token = pad_token
self.mask_token = mask_token
self.unk_token = unk_token
self.word2idx = {}
self.idx2word = None
self.rebuild = True
self.add_unused = add_unused
self.word_counter = Counter()
self.reset()
def reset(self):
"""
将特殊标签先加入词表
"""
ctrl_symbols = [self.pad_token,self.unk_token,self.cls_token,self.sep_token,self.mask_token]
for index,syb in enumerate(ctrl_symbols):
self.word2idx[syb] = index
if self.add_unused:
for i in range(20):
self.word2idx[f'[UNUSED{i}]'] = len(self.word2idx)
def update(self, word_list):
'''
依次增加序列中词在词典中的出现频率
:param word_list:
:return:
'''
#Counter.update() 如果要更新的关键字已存在,则对它的值进行求和;如果不存在,则添加
self.word_counter.update(word_list)
def add(self, word):
'''
增加一个新词在词典中的出现频率
:param word:
:return:
'''
self.word_counter[word] += 1
def has_word(self, word):
'''
检查词是否被记录
:param word:
:return:
'''
return word in self.word2idx
def to_index(self, word):
'''
将词转为数字. 若词不再词典中被记录, 将视为 unknown, 若 ``unknown=None`` , 将抛出
:param word:
:return:
'''
if word in self.word2idx:
return self.word2idx[word]
if self.unk_token is not None:
return self.word2idx[self.unk_token]
else:
raise ValueError("word {} not in vocabulary".format(word))
def unknown_idx(self):
"""
unknown 对应的数字.
"""
if self.unk_token is None:
return None
return self.word2idx[self.unk_token]
def padding_idx(self):
"""
padding 对应的数字
"""
if self.pad_token is None:
return None
return self.word2idx[self.pad_token]
def to_word(self, idx):
"""
给定一个数字, 将其转为对应的词.
:param int idx: the index
:return str word: the word
"""
return self.idx2word[idx]
def build_vocab(self):
max_size = min(self.max_size, len(self.word_counter)) if self.max_size else None
# max_size参数表示输出max_size个次数最多的元素 默认为None,表示输出全部
words = self.word_counter.most_common(max_size)
if self.min_freq is not None:
words = filter(lambda kv: kv[1] >= self.min_freq, words) #筛选出次数大于min_freq的词
if self.word2idx:
words = filter(lambda kv: kv[0] not in self.word2idx, words)#筛选出不在词表的词
start_idx = len(self.word2idx)
self.word2idx.update({w: i + start_idx for i, (w, _) in enumerate(words)})
self.build_reverse_vocab() #id2word
self.rebuild = False
def save(self, file_path):
'''
保存vocab
:param file_name:
:param pickle_path:
:return:
'''
mappings = {
"word2idx": self.word2idx,
'idx2word': self.idx2word
}
save_pickle(data=mappings, file_path=file_path)
def save_bert_vocab(self,file_path):
bert_vocab = [x for x,y in self.word2idx.items()]
with open(str(file_path),'w') as fo:
for token in bert_vocab:
fo.write(token+"\n")
def load_from_file(self, file_path):
'''
从文件组红加载vocab
:param file_name:
:param pickle_path:
:return:
'''
mappings = load_pickle(input_file=file_path)
self.idx2word = mappings['idx2word']
self.word2idx = mappings['word2idx']
def build_reverse_vocab(self):
self.idx2word = {i: w for w, i in self.word2idx.items()}
def read_data(self,data_path):
if data_path.is_dir(): #如果路径是一个目录,则遍历整个目录
files = sorted([f for f in data_path.iterdir() if f.exists()])
else:
files = [data_path]
for file in files:
f = open(file, 'r')
lines = f.readlines() # 读取数据集
for line in lines:
line = line.strip("\n")
words = line.split(" ")
self.update(words)
def clear(self):
"""
删除Vocabulary中的词表数据。相当于重新初始化一下。
:return:
"""
self.word_counter.clear()
self.word2idx = None
self.idx2word = None
self.rebuild = True
self.reset()
def __len__(self):
return len(self.idx2word)
def tag2id(data_dir, file_name='train.txt'):
    """Build tag <-> id mappings from the tags found in a training file.

    Reads ``data_dir / file_name`` as UTF-8. It collects the second field of
    every line that splits into exactly two space-separated fields; other
    lines (blank separators, malformed rows) are ignored. Tags are sorted
    before numbering so the mapping is identical across runs. Iterating a
    ``set`` of strings is not stable between interpreter runs because of
    hash randomization, which would otherwise give training and inference
    different ids.

    :param data_dir: directory (str or pathlib.Path) containing the file.
    :param file_name: name of the file to scan; defaults to ``train.txt``.
    :return: tag2id, id2tag -> Dict
    """
    from pathlib import Path

    train_path = Path(data_dir) / file_name
    tags = set()
    with open(train_path, 'r', encoding='utf-8') as f:
        for line in f:
            fields = line.strip().split(' ')
            if len(fields) == 2:
                tags.add(fields[-1])
    tag_to_id = {tag: idx for idx, tag in enumerate(sorted(tags))}
    id_to_tag = {idx: tag for tag, idx in tag_to_id.items()}
    return tag_to_id, id_to_tag