class TextConverter(object):
    """Character-level vocabulary built from a text file.

    Reads the file at ``text_path``, counts every character, and keeps the
    ``max_vocab`` most frequent ones.  Exposes:

    Attributes:
        vocab: list of characters, most frequent first.
        word_to_int_table: dict mapping character -> integer index.
        int_to_word_table: dict mapping integer index -> character.
    """

    def __init__(self, text_path, max_vocab=5000):
        """Build the vocabulary.

        Args:
            text_path: path to a UTF-8 text file.
            max_vocab: maximum vocabulary size; less frequent characters
                beyond this limit are dropped.
        """
        with codecs.open(text_path, mode='r', encoding='utf-8') as f:
            lines = f.readlines()
        word_list = [ch for line in lines for ch in line]
        # Single-pass frequency count (the original zero-initialised a dict
        # from a set and then counted in a second loop; this also makes tie
        # order deterministic: first occurrence in the text).
        counts = {}
        for ch in word_list:
            counts[ch] = counts.get(ch, 0) + 1
        # If the vocabulary exceeds the limit, drop the least frequent
        # characters.
        ranked = sorted(counts.items(), key=lambda kv: kv[1], reverse=True)
        self.vocab = [ch for ch, _ in ranked[:max_vocab]]
        self.word_to_int_table = {c: i for i, c in enumerate(self.vocab)}
        self.int_to_word_table = dict(enumerate(self.vocab))
def forward(self, x, hs=None):
    """Run the RNN over a batch of token-index sequences.

    Args:
        x: integer token indices, shape (batch, seq_len) — implied by the
            original per-line shape comments.
        hs: optional initial hidden state of shape
            (num_layers, batch, hidden_size).  When None, a zero state is
            created on the same device as ``x`` (previously hard-coded to
            ``mx.gpu()``, which failed on CPU-only machines and ignored
            which GPU ``x`` actually lives on).

    Returns:
        Tuple of (projected outputs flattened to (batch * seq_len, proj_dim),
        final hidden state from the RNN).
    """
    batch = x.shape[0]
    if hs is None:
        # Allocate the initial hidden state on x's context instead of
        # unconditionally on GPU 0.
        hs = nd.zeros(
            (self.num_layers, batch, self.hidden_size), ctx=x.context)
    word_embed = self.word_to_vec(x)               # batch x len x embed
    word_embed = word_embed.transpose((1, 0, 2))   # len x batch x embed
    out, h0 = self.rnn(word_embed, hs)             # len x batch x hidden
    le, mb, hd = out.shape
    out = self.proj(out.reshape((le * mb, hd)))    # project each timestep
    out = out.reshape((le, mb, -1))
    out = out.transpose((1, 0, 2))                 # batch x len x proj
    return out.reshape((-1, out.shape[2])), h0