Flutter+SpringBoot实现ChatGPT流式输出、上下文连续对话
最终实现Flutter的流式输出+上下文连续对话。
这里提供的是一个简单版的工具类和使用案例,页面部分仅供参考。
服务端
这里直接封装提供工具类,修改自己的apiKey即可使用,支持连续对话
工具类及使用
http依赖这里使用okHttp
<dependency>
    <groupId>com.squareup.okhttp3</groupId>
    <artifactId>okhttp</artifactId>
    <version>4.9.3</version>
</dependency>
import com.alibaba.fastjson2.JSON;import com.squareup.okhttp.Call;import com.squareup.okhttp.MediaType;import com.squareup.okhttp.OkHttpClient;import com.squareup.okhttp.Request;import com.squareup.okhttp.RequestBody;import com.squareup.okhttp.Response;import com.squareup.okhttp.ResponseBody;import lombok.AllArgsConstructor;import lombok.Builder;import lombok.Data;import lombok.Getter;import lombok.NoArgsConstructor;import lombok.extern.slf4j.Slf4j;import org.springframework.stereotype.Component;import org.springframework.util.StringUtils;import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;import vip.ailtw.common.utils.StringUtil;import javax.annotation.PostConstruct;import java.io.BufferedReader;import java.io.IOException;import java.io.InputStream;import java.io.InputStreamReader;import java.io.Serializable;import java.util.List;import java.util.concurrent.TimeUnit;import java.util.function.Consumer;import java.util.regex.Matcher;import java.util.regex.Pattern;@Slf4j@Componentpublic class ChatGptStreamUtil { private final String apiKey = "xxxxxxxxxxxxxx"; public final String gptCompletionsUrl = "https://api.openai.com/v1/chat/completions"; private static final OkHttpClient client = new OkHttpClient(); private static MediaType mediaType; private static Request.Builder requestBuilder; public final static Pattern contentPattern = Pattern.compile("\"content\":\"(.*?)\"}"); public final static String EVENT_DATA = "d"; public final static String EVENT_ERROR = "e"; public final static String END = "<>" ; @PostConstruct public void init() { client.setConnectTimeout(60, TimeUnit.SECONDS); client.setReadTimeout(60, TimeUnit.SECONDS); mediaType = MediaType.parse("application/json; charset=utf-8"); requestBuilder = new Request.Builder() .url(gptCompletionsUrl) .header("Content-Type", "application/json") .header("Authorization", "Bearer " + apiKey); } public GptChatResultDTO chatStream(List<ChatGptDTO> talkList, Consumer<String> callable) throws Exception 
{ long start = System.currentTimeMillis(); StringBuilder resp = new StringBuilder(); Response response = chatStream(talkList); //解析对话内容 try (ResponseBody responseBody = response.body(); InputStream inputStream = responseBody.byteStream(); BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(inputStream))) { String line; while ((line = bufferedReader.readLine()) != null) { if (!StringUtils.hasLength(line)) { continue; } Matcher matcher = contentPattern.matcher(line); if (matcher.find()) { String content = matcher.group(1); resp.append(content); callable.accept(content); } } } int wordSize = 0; for (ChatGptDTO dto : talkList) { String content = dto.getContent(); wordSize += content.toCharArray().length; } wordSize += resp.toString().toCharArray().length; long end = System.currentTimeMillis(); return GptChatResultDTO.builder().resContent(resp.toString()).time(end - start).wordSize(wordSize).build(); } private Response chatStream(List<ChatGptDTO> talkList) throws Exception { ChatStreamDTO chatStreamDTO = new ChatStreamDTO(talkList); RequestBody bodyOk = RequestBody.create(mediaType, chatStreamDTO.toString()); Request requestOk = requestBuilder.post(bodyOk).build(); Call call = client.newCall(requestOk); Response response; try { response = call.execute(); } catch (IOException e) { throw new IOException("请求时IO异常: " + e.getMessage()); } if (response.isSuccessful()) { return response; } try (ResponseBody body = response.body()) { if (429 == response.code()) { String msg = "Open Api key 已过期,msg: " + body.string(); log.error(msg); } throw new RuntimeException("chat api 请求异常, code: " + response.code() + "body: " + body.string()); } } private boolean sendToClient(String event, String data, SseEmitter emitter) { try { emitter.send(SseEmitter.event().name(event).data("{" + data + "}")); return true; } catch (IOException e) { log.error("向客户端发送消息时出现异常", e); } return false; } public boolean sendData(String data, SseEmitter emitter) { if 
(StringUtil.isBlank(data)) { return true; } return sendToClient(EVENT_DATA, data, emitter); } public void sendEnd(SseEmitter emitter) { try { sendToClient(EVENT_DATA, END, emitter); } finally { emitter.complete(); } } public void sendError(SseEmitter emitter) { try { sendToClient(EVENT_ERROR, "我累垮了", emitter); } finally { emitter.complete(); } } @Data @NoArgsConstructor @AllArgsConstructor @Builder public static class GptChatResultDTO implements Serializable { private String resContent; private int wordSize; private long time; } @Data @Builder @NoArgsConstructor @AllArgsConstructor public static class ChatGptDTO implements Serializable { private String content; private String role; } @Getter public static enum GptRoleEnum { USER_ROLE("user", "用户"), GPT_ROLE("assistant", "ChatGPT本身"), SYSTEM_ROLE("system", "对话设定"), ; private final String value; private final String desc; GptRoleEnum(String value, String desc) { this.value = value; this.desc = desc; } } @Data public static class ChatStreamDTO { private static final String model = "gpt-3.5-turbo"; private static final boolean stream = true; private List<ChatGptDTO> messages; public ChatStreamDTO(List<ChatGptDTO> messages) { this.messages = messages; } @Override public String toString() { return "{\"model\":\"" + model + "\"," + "\"messages\":" + JSON.toJSONString(messages) + "," + "\"stream\":" + stream + "}"; } }}
使用案例:
public static void main(String[] args) throws Exception { ChatGptStreamUtil chatGptStreamUtil = new ChatGptStreamUtil(); chatGptStreamUtil.init(); //构建一个上下文对话情景 List<ChatGptDTO> talkList = new ArrayList<>(); //设定gpt talkList.add(ChatGptDTO.builder().content("你是chatgpt助手,能过帮助我查阅资料,编写教学报告。").role(GptRoleEnum.GPT_ROLE.getValue()).build()); //开始提问 talkList.add(ChatGptDTO.builder().content("请帮我写一篇小学数学加法运算教案").role(GptRoleEnum.USER_ROLE.getValue()).build()); chatGptStreamUtil.chatStream(talkList, (respContent) -> { //这里是gpt每次流式返回的内容 System.out.println("gpt返回:" + respContent); }); }
SpringBoot接口
基于SpringBoot工程,提供接口,供Flutter端使用。
通过上面的工具类的使用,可以知道gpt返回给我们的内容是一段一段的,因此如果我们服务端也要提供类似的效果,提供两个思路和实现:
- WebSocket,服务端接收gpt返回的内容时推送内容给flutter
- 使用HTTP长连接,也就是 SseEmitter,本文采用的就是这种方式。
代码:
@RestController@RequestMapping("/chat")@Slf4jpublic class ChatController { @Autowired private ChatGptStreamUtil chatGptStreamUtil; @PostMapping(value = "/chatStream") @ApiOperation("流式对话") public SseEmitter chatStream() { SseEmitter emitter = new SseEmitter(80000L); //构建一个上下文对话情景 List<ChatGptDTO> talkList = new ArrayList<>(); //设定gpt talkList.add(ChatGptDTO.builder().content("你是chatgpt助手,能过帮助我查阅资料,编写教学报告。").role(GptRoleEnum.GPT_ROLE.getValue()).build()); //开始提问 talkList.add(ChatGptDTO.builder().content("请帮我写一篇小学数学加法运算教案").role(GptRoleEnum.USER_ROLE.getValue()).build()); GptChatResultDTO gptChatResultDTO = chatGptStreamUtil.chatStream(talkList, (content) -> { //这里服务端接收到消息就发送给Flutter chatGptStreamUtil.sendData(content, emitter); }); return emitter; }}
Flutter端
这里使用dio作为网络请求的工具
依赖
dio: ^5.2.1+1
工具类
import 'dart:async';
import 'dart:convert';

import 'package:dio/dio.dart';
import 'package:flutter/cupertino.dart';
import 'package:flutter/foundation.dart';
import 'package:get/get.dart' hide Response;

/// HTTP helper wrapping a configured [Dio] client.
class HttpUtil {
  Dio? client;

  /// Returns a helper with a freshly configured [client].
  static HttpUtil of() {
    return HttpUtil.init();
  }

  /// Configures [client] with the base URL, the timeouts and [OnReqResInterceptors].
  HttpUtil.init() {
    if (client == null) {
      var options = BaseOptions(
          baseUrl: Config.baseUrl,
          connectTimeout: const Duration(seconds: 100),
          receiveTimeout: const Duration(seconds: 100));
      client = Dio(options);
      // Request / response / error interceptors.
      client?.interceptors.add(OnReqResInterceptors());
    }
  }

  /// POSTs [params] to [path] and returns the response body as a stream of text
  /// lines (the server-sent-event lines), or null if the response has no body.
  ///
  /// Goes through the configured [client], so the base URL and the auth
  /// interceptor also apply to streaming requests.
  Future<Stream<String>?> postStream(String path,
      [Map<String, dynamic>? params]) async {
    Response<ResponseBody> rs = await client!.post<ResponseBody>(path,
        options: Options(headers: {
          "Accept": "text/event-stream",
          "Cache-Control": "no-cache"
        }, responseType: ResponseType.stream),
        data: params);
    // The byte chunks are Uint8List (a List<int>). The chunked UTF-8 decoder
    // correctly handles multi-byte characters split across chunk boundaries.
    return rs.data?.stream
        .cast<List<int>>()
        .transform(utf8.decoder)
        .transform(const LineSplitter());
  }
}

/// Dio interceptor: attaches the auth header and hooks errors and responses.
class OnReqResInterceptors extends InterceptorsWrapper {
  @override
  void onRequest(RequestOptions options, RequestInterceptorHandler handler) {
    // Attach the token to every request.
    options.headers['Authorization'] = '请求头token';
    super.onRequest(options, handler);
  }

  @override
  void onError(DioException err, ErrorInterceptorHandler handler) {
    if (err.type == DioExceptionType.unknown) {
      // Network unavailable, please try again later.
    }
    super.onError(err, handler);
  }

  @override
  void onResponse(
      Response<dynamic> response, ResponseInterceptorHandler handler) {
    super.onResponse(response, handler);
  }
}
使用
/// Builds an article via a streamed chat: listens to the SSE lines sent by the
/// server and appends each content fragment to `respContent`.
///
/// The server sends each fragment as `data:{<fragment>}`. The preceding
/// `event:` line is `d` for data and `e` for errors, and the end marker is `<>`.
/// NOTE: fragments are still JSON-escaped (e.g. a newline arrives as `\n`).
Future<void> chatStream() async {
  final stream = await HttpUtil.of().postStream("/api/chat/chatStream");
  String respContent = "";
  String event = "";
  stream?.listen((line) {
    debugPrint(line);
    if (line.isEmpty) {
      event = ""; // a blank line terminates one SSE event
      return;
    }
    if (line.startsWith("event:")) {
      event = line.substring("event:".length).trim();
      return;
    }
    if (!line.startsWith("data:")) {
      return;
    }
    // The payload sits between the first '{' and the LAST '}', because the
    // fragment itself may contain braces.
    final start = line.indexOf("{");
    final end = line.lastIndexOf("}");
    if (start < 0 || end <= start) {
      return;
    }
    final content = line.substring(start + 1, end);
    if (event == "e") {
      debugPrint("服务端异常:$content");
      return;
    }
    if (content == "<>") {
      return; // end-of-reply marker, not part of the answer
    }
    respContent += content;
    print("返回的内容:$content");
  }, onError: (Object e) {
    debugPrint("流式请求异常:$e");
  });
}