文章详情

短信预约-IT技能 免费直播动态提醒

请输入下面的图形验证码

提交验证

短信预约提醒成功

sentence-transformers(SBert)中文文本相似度预测(附代码)

2023-09-20 14:34

关注

在这里插入图片描述

前言

训练模型

  1. 创建网络:使用Sbert官方给出的预训练模型sentence_hfl_chinese-roberta-wwm-ext,先载入embedding层进行分词,再载入池化层并传入嵌入后的维度,对模型进行降维压缩,最后载入密集层,选择Than激活函数,输出维度大小为256维。
  2. 获取训练数据:构建出新模型后使用InputExample类存储训练数据,它接受文本对字符串列表和用于指示语义相似性的标签,用标准的Pytorch Dataloader包装train_examples,作用是打乱数据并生成特定大小的批次。
  3. 计算损失函数:对于每个句子对,通过网络传递句子A和句子B,从而产生嵌入u和v,使用余弦相似度计算相似性,并将结果与标准相似度得分进行比较。这样网络就能够进行微调,更好地识别句子的相似性。
  4. 模型调优:通过调用model.fit()来调优模型。向model.fit()中传递train_objective列表(由元组(dataloader, loss_function))组成。也可以传递多个元组,以便在具有不同损失函数的多个数据集上执行多任务学习。在训练过程需要使用sentence_transformers.evaluation评估表现是否有所改善,它包含各种可以传递给fit方法的evaluators。Evaluators会在训练期间定期运行,并且会返回分数,只有得分最高的模型才会存储在磁盘上。

首先运行preprocess.py获取数据,并划分训练集和测试集,之后运行train_sentence_bert.py,使用预训练模型, sbert将数据集用sbert训练相似度任务,得到训练好的模型,最后运行evaluate.py评估训练好的模型,将结果保存在predict.txt中,并输出预测结果。

这部分在详细代码里注释得很全。

后端部分

使用flask编写post接口,接收的数据格式为application/json,将前端传来的两个句子使用训练好的模型对其进行相似度预测,将得到的相似度类型从无法序列化存入json的tensor转成list,并将状态码,信息,数据返回给前端。

from sentence_transformers import SentenceTransformer, util# 后端接口from flask import Flask, jsonify, requestimport re# 用当前脚本名称实例化Flask对象,方便flask从该脚本文件中获取需要的内容app = Flask(__name__)# 使通过jsonify返回的中文显示正常,否则显示为ASCII码app.config["JSON_AS_ASCII"] = Falsemodel_path = 'D:/xxx模型路径/'model = SentenceTransformer(model_path)@app.route("/evaluate",methods=['POST'])def evalute_sentence():    s1 = request.json.get("s1")    s2 = request.json.get("s2")    if s1 and s2:        embedding1 = model.encode(s1, convert_to_tensor=True)        embedding2 = model.encode(s2, convert_to_tensor=True)        similarity = util.cos_sim(embedding1, embedding2).tolist()        return jsonify({"code": 200, "msg": "预测成功", "data": similarity})    else:        return jsonify({"code": 400, "msg": "缺少字段"})if __name__ == '__main__':    app.run(debug=True)

前端部分

框架使用Vue2,UI框架使用elementui。组件校验用户输入的表单(内容为中文,字数限制32个字,两个句子不为空),只有符合规则的字段才能提交表单。将数据通过Axios调用接口传递给后端,再根据后端接口响应状态进行相应的处理,如果返回状态码200,说明接口调用成功,展示返回的预测值,否则调用失败,页面弹出失败消息提示。

<template>  <div class="recommend">    <el-card class="box">      <h2 class="title">中文文本相似度预测</h2>      <el-form :model="evaluateForm" :rules="evaluateRules" ref="evaluateForm" class="form">        <el-form-item prop="s1">          <el-input            placeholder="请输入句子一"            maxlength="32"            show-word-limit            v-model="evaluateForm.s1"            autocomplete="false"            prefix-icon="el-icon-edit-outline"          ></el-input>        </el-form-item>        <el-form-item prop="s2">          <el-input            maxlength="32"            placeholder="请输入句子二"            v-model="evaluateForm.s2"            show-word-limit            autocomplete="false"            prefix-icon="el-icon-edit-outline"          ></el-input>        </el-form-item>        <el-form-item class="btn-container">          <el-button            type="primary"            @click="submitForm('evaluateForm')"            class="btn"            id="queryButton"          >开始预测</el-button>        </el-form-item>      </el-form>      <div v-show="result" style="margin-top: 20px">        <el-progress          :text-inside="true"          :stroke-width="26"          :percentage="result*100 ? result*100 : 0"          class="el-bg-inner-running"        ></el-progress>        <p>预测结果:{{result}}</p>      </div>    </el-card>  </div></template><script>import api from "@/api/index"export default {  data () {    return {      evaluateForm: {        s1: "",        s2: ""      },      evaluateRules: { // 评估表单校验规则        s1: [          { required: true, message: '请输入中文句子', trigger: 'blur', pattern: /^[\u4E00-\u9FA5]+$/ },        ],        s2: [          { required: true, message: '请输入中文句子', trigger: 'blur', pattern: /^[\u4E00-\u9FA5]+$/ },        ],      },      result: undefined,    }  },  methods: {    postEvaluate () { // 调用接口      api.postEvaluate(this.evaluateForm)        .then((res) => {          if (!res) {            return          }          console.log("res", res)          if (res.data.code !== 200) {            this.$message({              message: "请求失败",              type: "error"            })            return          }          let data = res.data.data[0]          this.result = data[0]          console.log("this.result", this.result)          this.$message({            message: "预测成功!",            type: "success"          })        })        .catch((error) => {          this.$message.error('资源获取错误!')        })    },    submitForm (formName) { // 提交表单      this.$refs[formName].validate((valid) => {        if (valid) {          this.postEvaluate()        } else {          this.$message({            message: "请按要求填写",            type: "warning"          })          console.log('error in submit form')          return false        }      })      document.getElementById("queryButton").blur()    },  }}</script><style lang="scss" scoped>.recommend {  width: 100%;  height: 100%;  text-align: center;  display: flex;  text-align: center;  flex-direction: column;  align-items: center;  justify-content: center;  overflow: hidden;  background: #00416a 0 / cover fixed;   background: -webkit-linear-gradient(    to right,    #00416a,    #e4e5e6  );   background: linear-gradient(    to right,    #00416a,    #e4e5e6  );   .box {    width: 48%;    height: 60%;    position: relative;    background: hsla(0, 0%, 100%, 0.3);    z-index: 5;    padding: 10px 20px;    // display: flex;    // flex-direction: column;    // justify-content: center;    box-sizing: border-box;    &::before {      content: '';      position: absolute;      top: 0;      right: 0;      bottom: 0;      left: 0;      filter: blur(20px);    }    .title {      color: #143b54;    }    .btn-container {      margin: 10px auto;      .btn {        width: 100%;        border-radius: 20px;      }    }  }}::v-deep .el-card {  border: 0;  box-shadow: 0 5px 16px 0 rgb(0 0 0 / 30%);}::v-deep .el-progress-bar__outer {  border: 0;  background-color: transparent;  // background-color: #abcbe0;}::v-deep .el-bg-inner-running .el-progress-bar__inner {  background: #9cecfb;   background: -webkit-linear-gradient(    to left,    #0052d4,    #65c7f7,    #9cecfb  );   background: linear-gradient(    to left,    #0052d4,    #65c7f7,    #9cecfb  ); }</style>

预训练模型比较

paraphrase-multilingual-MiniLM-L12-v2
参数设置:epochs=1,batch_size=16
特点:作为sbert官方多语言预训练模型,已带有BERT层和池化层,可直接用数据评估,但未经纯中文文本训练,准确率较低

在这里插入图片描述

chinese-electra-180g-small-discriminator
参数设置:epochs=1, batch_size=16
特点:运行时间快,准确率尚可

在这里插入图片描述

chinese-electra-180g-small-discriminator
参数设置:epochs=20, batch_size=16
特点:20次迭代比1次迭代有效果,但差别不大

在这里插入图片描述

chinese-electra-180g-small-discriminator
参数设置:epochs=1,batch_size=8
特点:比batch_size=16时效果更好

在这里插入图片描述

chinese-roberta-wwm-ext
参数设置:epochs=1,batch_size=8
特点:迭代1次和20次准确率无差别,稳定且效果在所有模型中最好,缺点是体积大运行速度慢

在这里插入图片描述

最后

代码已上传至sbert中文文本相似度预测,欢迎star!

来源地址:https://blog.csdn.net/weixin_54218079/article/details/128687878

阅读原文内容投诉

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

软考中级精品资料免费领

  • 历年真题答案解析
  • 备考技巧名师总结
  • 高频考点精准押题
  • 2024年上半年信息系统项目管理师第二批次真题及答案解析(完整版)

    难度     813人已做
    查看
  • 【考后总结】2024年5月26日信息系统项目管理师第2批次考情分析

    难度     354人已做
    查看
  • 【考后总结】2024年5月25日信息系统项目管理师第1批次考情分析

    难度     318人已做
    查看
  • 2024年上半年软考高项第一、二批次真题考点汇总(完整版)

    难度     435人已做
    查看
  • 2024年上半年系统架构设计师考试综合知识真题

    难度     224人已做
    查看

相关文章

发现更多好内容

猜你喜欢

AI推送时光机
位置:首页-资讯-后端开发
咦!没有更多了?去看看其它编程学习网 内容吧
首页课程
资料下载
问答资讯