百度NLP赛事实践解读!

news/2024/7/2 23:42:14

开放领域的搜索场景下得到的网页数据会非常复杂,其中往往存在着网页文档质量参差不齐、长短不一,问题答案分布零散、长度较长等问题,给答案抽取答案置信度计算带来了较大挑战。

本文基于百度搜索技术创新挑战赛中的搜索问答赛题,对搜索问答类任务做详细解读。本文思路如图:

00b306920c9afb82f3dafea133e9aacc.png

(本文实践讲解框架,其中改进思路见文末)

赛题背景

本赛题希望调研真实网络环境下的文档级机器阅读理解技术,共分为两个子任务,涉及基于复杂网页文档内容的答案抽取答案检验技术(详细任务定义可参考赛事官网)。

赛事官网:

https://aistudio.baidu.com/aistudio/competition/detail/660/0/introduction

难点分析

如何在文档长度不定答案长度较长的数据环境中取得良好且鲁棒的答案抽取效果是子任务1关注的重点。

方案介绍

赛题可以视为基础的信息抽取任务,也可以直接视为问答类型的信息抽取问题。我们需要构建一个模型,根据query从document中找到想要的答案。

思路一:BERT或ERNIE

如果我们使用BERT 或者 ERNIE 可以直接参考如下思路,模型的输出可以为对应的两个位置,分别是回答的开始位置和结束位置。

4c52b79d4147efec30fb9fa8088b153a.png

这里需要深入一下模型的实现细节:

  • query和documnet是一起输入给模型,一般情况下query在前面。

  • 回答对应的输出可以通过模型输出后的全连接层完成分类,当然回归也可以。

思路二:QA

如果采用QA的思路,则需要将比赛数据集转换为QA的格式,特别是文本的处理:长文本需要进行截断。

方案代码

方案借助QA的思路,使用ERNIE快速完成模型训练与预测。同时,文末给出了提分改进方案的思路。

详情可参考源Notebook:

https://aistudio.baidu.com/aistudio/projectdetail/5013840(一键运行提交

步骤1:解压数据集

!pip install paddle-ernie > log.log
# !cp data/data174963/data_task1.tar /home/aistudio/
!tar -xf /home/aistudio/data_task1.tar

步骤2:读取数据集

# 导入常见的库
import numpy as np
import pandas as pd
import os, sys, json
# 读取训练集、测试集和验证集
train_json = pd.read_json('data_task1/train_data/train.json', lines=True)
test_json = pd.read_json('data_task1/test_data/test.json', lines=True)
dev_json = pd.read_json('data_task1/dev_data/dev.json', lines=True)
# 查看数据集样例
train_json.head(1)
train_json.iloc[0]
test_json.iloc[0]

步骤3:加载ERNIE模型

这里我们使用paddlenlp==2.0.7,当然你也可以选择更高的版本。更高的版本会将损失计算也封装进去,其他的部分区别不大。

import paddle
import paddlenlp
print('paddle version', paddle.__version__)
print('paddlenlp version', paddlenlp.__version__)from paddlenlp.transformers import ErnieForQuestionAnswering, ErnieTokenizer
tokenizer = ErnieTokenizer.from_pretrained('ernie-1.0')
model = ErnieForQuestionAnswering.from_pretrained('ernie-1.0')
# 对文档的文档进行划分、计算文档的长度
train_json['doc_sentence'] = train_json['doc_text'].str.split('。')
train_json['doc_sentence_length'] = train_json['doc_sentence'].apply(lambda doc: [len(sentence) for sentence in doc])
train_json['doc_sentence_length_max'] = train_json['doc_sentence_length'].apply(max)
train_json = train_json[train_json['doc_sentence_length_max'] < 10000] # 删除了部分超长文档# 对文档的文档进行划分、计算文档的长度
dev_json['doc_sentence'] = dev_json['doc_text'].str.split('。')
dev_json['doc_sentence_length'] = dev_json['doc_sentence'].apply(lambda doc: [len(sentence) for sentence in doc])
dev_json['doc_sentence_length_max'] = dev_json['doc_sentence_length'].apply(max)
dev_json = dev_json[dev_json['doc_sentence_length_max'] < 10000] # 删除了部分超长文档# 对文档的文档进行划分、计算文档的长度
test_json['doc_sentence'] = test_json['doc_text'].str.split('。')
test_json['doc_sentence_length'] = test_json['doc_sentence'].apply(lambda doc: [len(sentence) for sentence in doc])
test_json['doc_sentence_length_max'] = test_json['doc_sentence_length'].apply(max)
train_json.iloc[10]
test_json.iloc[10]

步骤4:构建数据集

接下来需要构建QA任务的数据集,这里的数据集需要处理为如下的格式:

query [SEP] sentence of document
  • 训练集数据集处理

train_encoding = []# for idx in range(len(train_json)):
for idx in range(10000):# 读取原始数据的一条样本title = train_json.iloc[idx]['title']answer_start_list = train_json.iloc[idx]['answer_start_list']answer_list = train_json.iloc[idx]['answer_list']doc_text = train_json.iloc[idx]['doc_text']query = train_json.iloc[idx]['query']doc_sentence = train_json.iloc[idx]['doc_sentence']#  对于文章中的每个句子for sentence in set(doc_sentence):# 如果存在答案for answer in answer_list:answer = answer.strip("。")# 如果问题 + 答案 太长,跳过if len(query + sentence) > 512:continue# 对问题 + 答案进行编码encoding = tokenizer.encode(query, sentence, max_seq_len=512, return_length=True, return_position_ids=True, pad_to_max_seq_len=True, return_attention_mask=True)# 如果答案在这个句子中,找到start 和 end的 位置if answer in sentence:            encoding['start_positions'] = len(query) + 2 + sentence.index(answer)encoding['end_positions'] = len(query) + 2 + sentence.index(answer) + len(answer)# 如果不存在,则位置设置为0else:encoding['start_positions'] = 0encoding['end_positions'] = 0# 存储正样本if encoding['start_positions'] != 0:train_encoding.append(encoding)# 对负样本进行采样,因为负样本太多# 正样本:query + sentence -> answer 的情况# 负样本:query + sentence -> No answer 的情况if encoding['start_positions'] == 0 and np.random.randint(0, 100) > 99:train_encoding.append(encoding)if len(train_encoding) % 500 == 0:print(len(train_encoding))
  • 验证集数据集处理

val_encoding = []for idx in range(len(dev_json)):
# for idx in range(200):title = dev_json.iloc[idx]['title']answer_start_list = dev_json.iloc[idx]['answer_start_list']answer_list = dev_json.iloc[idx]['answer_list']doc_text = dev_json.iloc[idx]['doc_text']query = dev_json.iloc[idx]['query']doc_sentence = dev_json.iloc[idx]['doc_sentence']for sentence in set(doc_sentence):for answer in answer_list:answer = answer.strip("。")if len(query + sentence) > 512:continueencoding = tokenizer.encode(query, sentence, max_seq_len=512, return_length=True, return_position_ids=True, pad_to_max_seq_len=True, return_attention_mask=True)if answer in sentence:            encoding['start_positions'] = len(query) + 2 + sentence.index(answer)encoding['end_positions'] = len(query) + 2 + sentence.index(answer) + len(answer)else:encoding['start_positions'] = 0encoding['end_positions'] = 0if encoding['start_positions'] != 0:val_encoding.append(encoding)if encoding['start_positions'] == 0 and np.random.randint(0, 100) > 99:val_encoding.append(encoding)
  • 测试集数据集处理

test_encoding = []
test_raw_txt = []
for idx in range(len(test_json)):title = test_json.iloc[idx]['title']doc_text = test_json.iloc[idx]['doc_text']query = test_json.iloc[idx]['query']doc_sentence = test_json.iloc[idx]['doc_sentence']for sentence in set(doc_sentence):if len(query + sentence) > 512:continueencoding = tokenizer.encode(query, sentence, max_seq_len=512, return_length=True, return_position_ids=True, pad_to_max_seq_len=True, return_attention_mask=True)test_encoding.append(encoding)test_raw_txt.append([idx, query, sentence])

步骤5:批量数据读取

# 手动将数据集进行批量打包
def data_generator(data_encoding, batch_size = 6):for idx in range(len(data_encoding) // batch_size):batch_data = data_encoding[idx * batch_size : (idx+1) * batch_size]batch_encoding = {}for key in batch_data[0].keys():if key == 'seq_len':continuebatch_encoding[key] = paddle.to_tensor(np.array([x[key] for x in batch_data]))yield batch_encoding

步骤6:模型训练与验证

# 优化器
optimizer = paddle.optimizer.SGD(0.0005, parameters=model.parameters())# 损失函数
loss_fct = paddle.nn.CrossEntropyLoss()
best_val_start_acc = 0for epoch in range(10):# 每次打乱训练集,防止过拟合np.random.shuffle(train_encoding)# 训练部分train_loss = []for batch_encoding in data_generator(train_encoding, 10):# ERNIE正向传播start_logits, end_logits = model(batch_encoding['input_ids'], batch_encoding['token_type_ids'])# 计算损失start_loss = loss_fct(start_logits, batch_encoding['start_positions'])end_loss = loss_fct(end_logits, batch_encoding['end_positions'])total_loss = (start_loss + end_loss) / 2# 参数更新total_loss.backward()train_loss.append(total_loss)optimizer.step()optimizer.clear_gradients()# 验证部分val_start_acc = []val_end_acc = []with paddle.no_grad():for batch_encoding in data_generator(val_encoding, 10):# ERNIE正向传播start_logits, end_logits = model(batch_encoding['input_ids'], batch_encoding['token_type_ids'])# 计算识别精度start_acc = paddle.mean((start_logits.argmax(1) == batch_encoding['start_positions']).astype(float))end_acc = paddle.mean((end_logits.argmax(1) == batch_encoding['end_positions']).astype(float))val_start_acc.append(start_acc)val_end_acc.append(end_acc)# 转换数据格式为floattrain_loss = paddle.to_tensor(train_loss).mean().item()val_start_acc = paddle.to_tensor(val_start_acc).mean().item()val_end_acc = paddle.to_tensor(val_end_acc).mean().item()# 存储最优模型if val_start_acc > best_val_start_acc:paddle.save(model.state_dict(), 'model.pkl')best_val_start_acc = val_start_acc# 每个epoch打印输出结果print(f'Epoch {epoch}, {train_loss:3f}, {val_start_acc:3f}/{val_end_acc:3f}')
# 关闭dropout
model = model.eval()

步骤7:模型预测

test_start_idx = []
test_end_idx = []# 对测试集中query 和 sentence的情况进行预测
with paddle.no_grad():for batch_encoding in data_generator(test_encoding, 12):start_logits, end_logits = model(batch_encoding['input_ids'], batch_encoding['token_type_ids'])test_start_idx += start_logits.argmax(1).tolist()test_end_idx += end_logits.argmax(1).tolist()if len(test_start_idx) % 500 == 0:print(len(test_start_idx), len(test_encoding))
test_submit = [''] * len(test_json)# 对预测结果进行后处理
for (idx, query, sentence), st_idx, end_idx in zip(test_raw_txt, test_start_idx, test_end_idx):# 如果start 或 end位置识别失败,或 start位置 晚于 end位置if st_idx == 0 or end_idx == 0 or st_idx >= end_idx:continue# 如果start位置在query部分if st_idx - len(query) - 2 < 0:continuetest_submit[idx] += sentence[st_idx - len(query) - 2: end_idx - len(query) - 2]
# 生成提交结果
with open('subtask1_test_pred.txt', 'w') as up:for x in test_submit:if x == '':up.write('1\tNoAnswer\n')else:up.write('1\t'+x+'\n')

改进方向

从精度改变大小,可以从以下几个角度改进:训练数据 > 数据处理 > 模型与预训练 > 模型集成

  • 训练数据:使用全量的训练数据

  • 数据处理:对文档进行切分,现在使用进行切分,后续也可以尝试其他。

  • 模型与预处理:尝试ERNIE版本,或者进行预训练。

  • 模型集成:

    • 尝试不同的数据划分得到不同的模型

    • 尝试不同的文本处理方法得到不同的模型

当然也可以考虑其他数据,如不同的网页拥有答案的概率不同,以及从标题可以判断是否包含答案。

完整代码也可以点击左下角“原文链接”进行查看。

8a3a6a9d6b9346118668e959e333018e.png

整理不易,三连


http://lihuaxi.xjx100.cn/news/254294.html

相关文章

在 Python 中如何实现类的继承,方法重载及重写?

作者 | 苏凉.py来源 | CSDN博客今天我们将进入类的继承以及对类的方法重写及重载的学习&#xff01;话不多说直接进入正题&#xff01;&#xff01;类的继承如果要编写的类是另一个现成类的特殊版本&#xff0c;那我们就可以使用继承。一个类继承另一个类时&#xff0c;将自动获…

是时候了解下 mmap 了

欢迎关注方志朋的博客&#xff0c;回复”666“获面试宝典作者&#xff1a;逗逼程序员链接&#xff1a;https://www.jianshu.com/p/0ce91e10d0261、mmap基础概念mmap 是一种内存映射文件的方法&#xff0c;即将一个文件或者其他对象映射到进程的地址空间&#xff0c;实现文件磁盘…

2018 蓝桥杯省赛 B 组模拟赛(一)-年龄

今天蒜头君带着花椰妹和朋友们一起聚会&#xff0c;当朋友们问起年龄的时候&#xff0c;蒜头君打了一个哑谜&#xff08;毕竟年龄是女孩子的隐私&#xff09;说&#xff1a;“我的年龄是花椰妹年龄个位数和十位数之和的二倍”。 花椰妹看大家一脸懵逼&#xff0c;就知道大家也不…

Numpy入门教程:04. 数学函数

背景 什么是 NumPy 呢&#xff1f; NumPy 这个词来源于两个单词 – Numerical和Python。其是一个功能强大的 Python 库&#xff0c;可以帮助程序员轻松地进行数值计算&#xff0c;通常应用于以下场景&#xff1a; 执行各种数学任务&#xff0c;如&#xff1a;数值积分、微分、…

PHP生成PDF文档的FPDF类

以前在PHP4的早期版本中用PDFlib生成PDF文档比较容易&#xff0c;现在升级到PHP5了&#xff0c;发现更麻烦了&#xff0c;装的PHP 5.2.4默认没有PHPlib&#xff0c;从php.net上找了一个&#xff0c;装上竟一直报错&#xff0c;开始以为是版本兼容问题&#xff0c;后来在租来的服…

Gartner发布:2023年十大重要技术趋势

Datawhale干货 发布&#xff1a;Gartner&#xff0c;方向&#xff1a;技术趋势前言Gartner是全球领先的信息技术研究和咨询公司。日前&#xff0c;Gartner发布企业机构在2023年需要探索的十大战略技术趋势。Gartner杰出研究副总裁Frances Karamouzis表示&#xff1a;“为了在经…

ISA---不能访问网址或是多次刷新才能访问的解决方法一则

当你安装ISA2006在WINDOWS 2003 SERVER上&#xff0c;并打上SP2补订时。遇SNAT客户端不能访问WEB&#xff0c;但能PING通&#xff0c;能TELNET通&#xff0c;也能访问QQ或是MSN的问题时可以利用以下方法解决。同时&#xff0c;如果你遇到在此环境下&#xff0c;客户端访问外部网…

Pandas 表格样式设置指南,看这一篇就够了!

作者 | 阳哥来源 | Python数据之道最近这些年&#xff0c;Python在数据分析以及人工智能领域是越来越火。这离不开pandas、numpy、sklearn、TensorFlow、PyTorch等数据科学包&#xff0c;尤其是 Pandas&#xff0c;几乎是每一个从事Python数据科学相关的同学都绕不过去的。Pand…