10、提取词向量,保存到hdf5
<pre><code>
import os
import pickle
#import numpy as np
import h5py
import hashlib
import json
from bert_serving.client import BertClient
'''
<HDF5 dataset "大波": shape (2, 1024), type "<f8">
[[ 0.73608518 0.29880694 -0.149615 ... -0.22505963 0.42839751
-0.18229473]
[-0.16124742 0.29405186 0.0544907 ... 0.2817606 0.70748979
-0.05054991]]
<HDF5 dataset "挣得": shape (3, 1024), type "<f8">
[[-0.24522281 0.45859608 -0.11046886 ... -0.30834651 0.29136491
0.62058806]
[-0.24838917 0.11798686 -0.16676699 ... 0.70589703 0.94294947
-0.01259621]
[ 0.17104931 0.7874096 0.12995088 ... 0.29047191 0.4176755
-0.06562544]
...
'''
'''
res = bc.sents2elmo(question_token)
Out[25]:
[array([[ 0.11186478, 0.19980264, -0.04579666, ..., -1.3476261 ,
0.24147476, -0.01047087]], dtype=float32),
array([[ 0.5765926 , 0.07752912, 0.2669233 , ..., -0.88939697,
0.57002026, 0.43547738]], dtype=float32),
array([[-0.24481785, 0.01219011, 0.9444544 , ..., 0.04462969,
-0.3311766 , 0.04817347]], dtype=float32),
array([[-0.1921978 , 0.5232609 , -0.02506225, ..., -0.5678321 ,
-0.523696 , 0.54723406]], dtype=float32)]
In [27]: res[0]
Out[27]:
array([[ 0.10416368, 0.21417041, -0.0221702 , ..., -0.6439903 ,
0.15619688, -0.44247976]], dtype=float32)
In [28]: res[0].shape
Out[28]: (1, 1024)
# 每个词的向量都是二维的
# 有几个词行数是几
In [32]: question_token
Out[32]: ['你好', '中', '国']
In [33]: res = bc.sents2elmo(question_token)
2019-05-08 15:23:55,711 INFO: 1 batches, avg len: 3.3
In [34]: res
Out[34]:
[array([[ 0.10398692, 0.21505527, -0.02499238, ..., -0.23105145,
0.08063925, -0.13005304],
[ 0.682537 , 0.4822921 , 0.43847165, ..., 0.07830879,
0.27263913, 0.18026733]], dtype=float32),
array([[-0.18547088, -0.06569155, 0.8559456 , ..., 0.379452 ,
0.16279131, -0.6330448 ]], dtype=float32),
array([[-0.17902641, 0.5306758 , 0.14992061, ..., -0.45520094,
-0.60934097, -0.01732008]], dtype=float32)]
In [35]: res[0]
Out[35]:
array([[ 0.10398692, 0.21505527, -0.02499238, ..., -0.23105145,
0.08063925, -0.13005304],
[ 0.682537 , 0.4822921 , 0.43847165, ..., 0.07830879,
0.27263913, 0.18026733]], dtype=float32)
In [36]: res[0].shape
Out[36]: (2, 1024)
In [39]: len(res[0].tolist()[0])
Out[39]: 1024
In [40]: res[0].shape
Out[40]: (2, 1024)
In [41]: len(res[0].tolist())
Out[41]: 2
In [42]:
'''
# np.array([ ]).tolist()
#
# i.tolist()[0] for i in bc.sents2elmo(question_token)
# with open('result.json') as fin:
# for lidx, line in enumerate(fin, 1):
# sample = json.loads(line.strip())
# print(type(sample['屋子']))
# print(sample['屋子'])
# Export step: encode every vocabulary token with the BERT service and store
# one dataset per token in 'new_bert_vocab_to_hdf5.h5'.
products_set = set()  # MD5 hex digests of the tokens already written
bc = BertClient()
with open(os.path.join('../data/vocab/', 'vocab.data'), 'rb') as fin:
    # NOTE(review): pickle.load can execute arbitrary code; only load a
    # vocab file that comes from a trusted source.
    vocab = pickle.load(fin)

# The context manager guarantees the HDF5 file is flushed and closed even if
# encoding or writing fails part-way through the vocabulary.
with h5py.File('new_bert_vocab_to_hdf5.h5', 'w') as f:
    # presumably vocab.token2id maps token -> id, so iterating it yields the
    # token strings — TODO confirm against the vocab class.
    for index, item in enumerate(vocab.token2id):
        hash_title = hashlib.md5(item.encode(encoding='UTF-8')).hexdigest()
        if hash_title in products_set:
            continue
        products_set.add(hash_title)

        # Encode once, outside the try block, so the fallback path below can
        # reuse the vector instead of asking the BERT service a second time.
        vector = bc.encode([item])[0].tolist()
        try:
            f[item] = vector
        except Exception as e:
            # The token could not be used as a dataset name as-is; store the
            # vector under "<token><index>" instead and report the problem.
            fallback_key = str(item) + str(index)
            f[fallback_key] = vector
            print('****************')
            print(fallback_key)
            print(e)
            continue

        # Lightweight progress indicator.
        if index % 50 == 1:
            print(index)
# Read-back step: load stored vectors from the HDF5 file into an in-memory
# dict keyed by dataset name, printing each key as it is loaded.
bert_vocab = {}
with open(os.path.join('../data/vocab/', 'vocab.data'), 'rb') as fin:
    # NOTE(review): pickle.load can execute arbitrary code; only load a
    # vocab file that comes from a trusted source.
    vocab = pickle.load(fin)

# NOTE(review): this opens 'bert_vocab_to_hdf5.h5', whereas the export step
# writes 'new_bert_vocab_to_hdf5.h5' — confirm which file is intended.
# The context manager closes the HDF5 file once the vectors have been read.
with h5py.File('bert_vocab_to_hdf5.h5', 'r') as f:
    for index, key in enumerate(f.keys()):
        # Only the first 10001 entries (indices 0..10000) are loaded.
        if index > 10000:
            break
        # [:] reads the dataset contents into memory, so the stored value
        # remains usable after the file is closed.
        bert_vocab[key] = f[key][:]
        print(key)
#
# for id, item in enumerate(vocab.token2id):
# if id > 10000:
# break
#
# print(f[item][:])
# print(f[item])
# if id >10020:
# break
# with open(os.path.join('../data/vocab/', 'vocab.data'), 'rb') as fin:
# vocab = pickle.load(fin)
# for id, item in enumerate(vocab.token2id):
# for word in item:
# print(word)
# print(item)
#
# if id >20:
# break
# # print(item)
# # if id >7500:
# # break
</code></pre>