GRU情感分类问题再战

news/2024/7/2 23:32:23
import tensorflow as tf
import numpy as np 
import tensorflow as keras 
from tensorflow.keras import losses,Sequential,optimizers,layers,datasetsbatchsz=128#批量大小
total_words=10000#词汇表大小N_vocab
max_review_len=80#句子最大长度 s,大于的句子部分将截断,小于的将填充
embedding_len=100#词向量特征长度(x_train,y_train),(x_test,y_test)=datasets.imdb.load_data(num_words=total_words)
#print(x_train.shape,len(x_train[0]),y_train.shape)
#print(x_test.shape,len(x_test[0]),y_test.shape)x_train=tf.keras.preprocessing.sequence.pad_sequences(x_train,maxlen=max_review_len)
x_test=tf.keras.preprocessing.sequence.pad_sequences(x_test,maxlen=max_review_len)
train_db=tf.data.Dataset.from_tensor_slices((x_train,y_train))
test_db=tf.data.Dataset.from_tensor_slices((x_test,y_test))
train_db=train_db.shuffle(1000).batch(batchsz,drop_remainder=True)
test_db=test_db.batch(batchsz,drop_remainder=True)#统计数据集属性
#print('x_train shape: ',x_train.shape,tf.reduce_max(y_train),tf.reduce_min(y_train))
#print('x_test shape: ',x_test.shape)class MyLMST(tf.keras.Model):def __init__(self,units):super(MyLMST,self).__init__()self.state0=[tf.zeros([batchsz,units])]self.state1=[tf.zeros([batchsz,units])]#词向量编码[b,80]=>[b,80,100]self.embedding=layers.Embedding(total_words,embedding_len,input_length=max_review_len)#构建2个cell,使用dropout技术防止过拟合self.run_cell0=layers.GRUCell(units,dropout=0.5)self.run_cell1=layers.GRUCell(units,dropout=0.5)self.outlayer = Sequential([layers.Dense(units),layers.Dropout(rate=0.5),layers.ReLU(),layers.Dense(1)])def call(self,inputs,training=None):x=inputs#获取词向量[b,80]=>[b,80,100]x=self.embedding(x)#通过2个LSTM CELL,[b,80,100]=>[b,64]state0=self.state0state1=self.state1for word in tf.unstack(x,axis=1):out0,state0=self.run_cell0(word,state0,training)out1,state1=self.run_cell1(out0,state1,training)#末层最后一个输出作为分类网络的输入:[6,64]=>[b,1]x=self.outlayer(out1,training)#通过激活函数p(y is pos[x])prob=tf.sigmoid(x)return probdef main():units=64epochs=6learning_rate=0.001model=MyLMST(units)#装配model.compile(optimizer=optimizers.Adam(learning_rate),loss=losses.BinaryCrossentropy(),metrics=['accuracy'],experimental_run_tf_function=False)#训练和验证model.fit(train_db,epochs=epochs,validation_data=test_db)#测试model.evaluate(test_db)if __name__=='__main__':main()
Epoch 1/6
195/195 [==============================] - 272s 1s/step - loss: 0.5361 - accuracy: 0.6149 - val_loss: 0.3710 - val_accuracy: 0.8345
Epoch 2/6
195/195 [==============================] - 260s 1s/step - loss: 0.3333 - accuracy: 0.8485 - val_loss: 0.3599 - val_accuracy: 0.8411
Epoch 3/6
195/195 [==============================] - 262s 1s/step - loss: 0.2674 - accuracy: 0.8832 - val_loss: 0.4213 - val_accuracy: 0.8343
Epoch 4/6
195/195 [==============================] - 265s 1s/step - loss: 0.2288 - accuracy: 0.9078 - val_loss: 0.4900 - val_accuracy: 0.8303
Epoch 5/6
195/195 [==============================] - 272s 1s/step - loss: 0.1968 - accuracy: 0.9234 - val_loss: 0.4947 - val_accuracy: 0.8196
Epoch 6/6
195/195 [==============================] - 268s 1s/step - loss: 0.1711 - accuracy: 0.9320 - val_loss: 0.5421 - val_accuracy: 0.8242
195/195 [==============================] - 73s 373ms/step - loss: 0.5421 - accuracy: 0.8242

http://lihuaxi.xjx100.cn/news/286089.html

相关文章

图论 ---- CF1209F. Koala and Notebook(多位数字拆边+BFS)

题目链接 题目大意: 给你一个nnn个点mmm条边的无向联通图,每条边上面都有一条权值就是输入时候的位置,然后问你从111号点出发到其他n−1n-1n−1个点最小权值是多少?路径权值是路径上面的数字拼接起来的结果,不是相加 …

32岁程序员,失业4个月45次面试经历,与君共勉

程序员求职面试(微信号:CoderJob)整理内容综合自:网络一个32岁的程序员,失业4个月的45次面试,终于入职了,与君共勉。看到这么多面试经历,网友们也激动了。有网友说:面试确…

公开课 | 用AI给旧时光上色!详解GAN在黑白照片上色中的应用

在改革开放40周年之际,百度联合新华社推出了一个刷屏级的H5应用——用AI技术为黑白老照片上色,浓浓的怀旧风勾起了心底快被遗忘的时光。想了解如何给老照片上色?本次公开课中,我们邀请到了百度高级研发工程师李超,他的…

iOS开发中的 地区转经纬 经纬度转地区

2019独角兽企业重金招聘Python工程师标准>>> 参考 iOS 根据地名获取经纬度 iOS 根据经纬度显示地名 - (void)setCity {[[LoginUserInfo sharedLoginUserInfo] latitude];NSLog("%",[[LoginUserInfo sharedLoginUserInfo] latitude]);NSLog("------%…

求一个数的因子个数/因子和/质因子 C/C++实现

求一个数的因子个数时间复杂度O√n ll get_number(ll x){ll num0;for(ll i1;i*i<x;i){if(x%i0) num2; if(i*ix) num1;}return num; }求一个数的因子和时间复杂度O√n ll get_number(ll x){ll num0;for(ll i1;i*i<x;i){if(x%i0) numix/i;if(i*ix) numi;}return num…

poj1422(最小路径覆盖问题)

最小路径覆盖数: 对于一个DAG&#xff08;有向无环图&#xff09;&#xff0c;选取最少条路径&#xff0c;使得每个 顶点属于且仅属于一条路径。路径长度可以为零&#xff1b;&#xff08;有向图中找一些路径&#xff0c;使之覆盖了图中的所有顶点&#xff0c;就是任意一个顶点…

主席树 ---- CF 1422F. Boring Queries(由离线推出在线如何求的 ,求解多次询问的区间LCM)

题目链接 题目大意&#xff1a; 给你nnn个数&#xff0c; 每次往第iii个数里面里面乘aaa&#xff0c;问你这nnn个数的LCM\text{LCM}LCM是多少&#xff1f; 解题思路&#xff1a; 多个数的lcm不是所有数的乘积除以所有数的gcd&#xff0c;如 4 8 3 正确求法是每个数分解质因数…

春节停车难?用Python找空车位

作者 | Adam Geitgey译者 | 风车云马整理 | Jane出品 | AI科技大本营&#xff08;ID:rgznai100&#xff09;【导语】今天这篇文章的选题非常贴近生活。营长生活在北京&#xff0c;深知开车出门最怕的就是堵车和找不到停车位。记得冬至那个周末&#xff0c;几个小伙伴滑雪回来找…