文章详情

短信预约-IT技能 免费直播动态提醒

请输入下面的图形验证码

提交验证

短信预约提醒成功

LangGraph实战:从零分阶打造人工智能航空客服助手

2024-11-29 22:38

关注

完成本教程后,你不仅会拥有一个功能完备的机器人,还将深入理解LangGraph的核心理念和架构设计。这些知识将帮助你在其他人工智能项目中运用相似的设计模式。

由于内容较多,本文将由浅入深,分四个阶段进行讲解,每个阶段都将打造出一个具备以上描述所有能力的机器人。但受限于LLM的能力,初期阶段的机器人的运行可能存在各类问题,但都将在后续阶段得到解决。

你最终完成的聊天机器人将类似于以下示意图:

最终示意图

现在,让我们开启这第一阶段段的学习之旅吧!

准备工作

在开始之前,我们需要搭建好环境。本教程将安装一些必要的先决条件,包括下载测试用的数据库,并定义一些在后续各部分中会用到的工具。

我们会使用 Claude 作为语言模型(LLM),并创建一些定制化的工具。这些工具大多数会连接到本地的 SQLite 数据库,无需额外依赖。此外,我们还会通过 Tavily 为代理提供网络搜索功能。

%%capture --no-stderr
% pip install -U langgraph langchain-community langchain-anthropic tavily-python pandas

数据库初始化

接下来,执行下面的脚本来获取我们为这个教程准备的 SQLite 数据库,并更新它以反映当前的数据状态。具体细节不是重点。

import os
import requests
import sqlite3
import pandas as pd
import shutil

# 下载数据库文件
db_url = "https://storage.googleapis.com/benchmarks-artifacts/travel-db/travel2.sqlite"
local_file = "travel2.sqlite"
backup_file = "travel2.backup.sqlite"
overwrite = False
if not os.path.exists(local_file) or overwrite:
    response = requests.get(db_url)
    response.raise_for_status()  # 确保请求成功
    with open(local_file, "wb") as file:
        file.write(response.content)

# 创建数据库备份,以便在每个教程部分开始时重置数据库状态
shutil.copy(local_file, backup_file)

# 将航班数据更新为当前时间,以适应我们的教程
conn = sqlite3.connect(local_file)
cursor = conn.cursor()

# 读取数据库中的所有表
tables = pd.read_sql(
    "SELECT name FROM sqlite_master WHERE type='table';", conn
).name.tolist()
tdf = {}
for table_name in tables:
    tdf[table_name] = pd.read_sql(f"SELECT * from {table_name}", conn)

# 找到最早的出发时间,并计算时间差
example_time = pd.to_datetime(
    tdf["flights"]["actual_departure"].replace("\\N", pd.NaT)
).max()
current_time = pd.to_datetime("now").tz_localize(example_time.tz)
time_diff = current_time - example_time

# 更新预订日期和航班时间
for column in ["book_date", "scheduled_departure", "scheduled_arrival", "actual_departure", "actual_arrival"]:
    tdf["flights"][column] = pd.to_datetime(
        tdf["flights"][column].replace("\\N", pd.NaT)
    ) + time_diff

# 将更新后的数据写回数据库
for table_name, df in tdf.items():
    df.to_sql(table_name, conn, if_exists="replace", index=False)
conn.commit()
conn.close()

# 在本教程中,我们将使用这个本地文件作为数据库
db = local_file

工具定义

现在,我们来定义一些工具,以便助手可以搜索航空公司的政策手册,以及搜索和管理航班、酒店、租车和远足活动的预订。这些工具将在教程的各个部分中重复使用,具体的实现细节不是关键。

查询公司政策

助手需要检索政策信息来回答用户的问题。请注意,这些政策的实施还需要在工具或 API 中进行,因为语言模型可能会忽略这些信息。以下工具受限于篇幅将仅提供定义及描述,详细代码[1]可在github上获取。

import re
import numpy as np
import openai
from langchain_core.tools import tool


@tool
def lookup_policy(query):
    """查询公司政策,以确定某些选项是否允许。"""

航班管理

定义一个工具来获取用户的航班信息,然后定义一些工具来搜索航班和管理用户的预订信息,这些信息存储在 SQL 数据库中。

我们使用 ensure_config 来通过配置参数传递 passenger_id。语言模型不需要显式提供这些信息,它们会在图的每次调用中提供,以确保每个用户无法访问其他乘客的预订信息。

from langchain_core.runnables import ensure_config
from typing import Optional
import sqlite3
import pytz
from datetime import datetime, timedelta, date

@tool
def fetch_user_flight_information():
    """获取用户的所有机票信息,包括航班详情和座位分配。"""

@tool
def search_flights(
    departure_airport=None,
    arrival_airport=None,
    start_time=None,
    end_time=None,
    limit=20,
):
    """根据出发机场、到达机场和出发时间范围来搜索航班。"""

@tool
def update_ticket_to_new_flight(ticket_no, new_flight_id):
    """将用户的机票更新到一个新的有效航班上。"""

@tool
def cancel_ticket(ticket_no):
    """取消用户的机票,并从数据库中移除。"""

租车服务

用户预订了航班后,可能需要租车服务。定义一些工具,让用户能够在目的地搜索和预订汽车。

from typing import Optional, Union
from datetime import datetime, date

@tool
def search_car_rentals(
    locatinotallow=None,
    name=None,
    price_tier=None,
    start_date=None,
    end_date=None,
):
    """
    根据位置、公司名称、价格等级、开始日期和结束日期来搜索租车服务。

    参数:
        location (Optional[str]): 租车服务的位置。
        name (Optional[str]): 租车公司的名称。
        price_tier (Optional[str]): 租车的价格等级。
        start_date (Optional[Union[datetime, date]]): 租车的开始日期。
        end_date (Optional[Union[datetime, date]]): 租车的结束日期。

    返回:
        list[dict]: 匹配搜索条件的租车服务列表。
    """

@tool
def book_car_rental(rental_id):
    """
    通过租车ID来预订租车服务。

    参数:
        rental_id (int): 要预订的租车服务的ID。

    返回:
        str: 预订成功与否的消息。
    """

@tool
def update_car_rental(
    rental_id,
    start_date=None,
    end_date=None,
):
    """
    通过租车ID来更新租车服务的开始和结束日期。

    参数:
        rental_id (int): 要更新的租车服务的ID。
        start_date (Optional[Union[datetime, date]]): 新的租车开始日期。
        end_date (Optional[Union[datetime, date]]): 新的租车结束日期。

    返回:
        str: 更新成功与否的消息。
    """

@tool
def cancel_car_rental(rental_id):
    """
    通过租车ID来取消租车服务。

    参数:
        rental_id (int): 要取消的租车服务的ID。

    返回:
        str: 取消成功与否的消息。
    """

酒店预订

用户需要住宿,因此定义一些工具来搜索和管理酒店预订。

@tool
def search_hotels(
    locatinotallow=None,
    name=None,
    price_tier=None,
    checkin_date=None,
    checkout_date=None,
):
    """
    根据位置、名称、价格等级、入住日期和退房日期来搜索酒店。

    参数:
        location (Optional[str]): 酒店的位置。
        name (Optional[str]): 酒店的名称。
        price_tier (Optional[str]): 酒店的价格等级。
        checkin_date
        
        # 入住日期和退房日期,用于搜索酒店
        checkin_date (Optional[Union[datetime, date]]): 酒店的入住日期。
        checkout_date (Optional[Union[datetime, date]]): 酒店的退房日期。

    返回:
        list[dict]: 符合搜索条件的酒店列表。
    """
    
@tool
def book_hotel(hotel_id):
    """
    通过酒店ID进行预订。

    参数:
        hotel_id (int): 要预订的酒店的ID。

    返回:
        str: 预订成功与否的消息。
    """
    
@tool
def update_hotel(
    hotel_id,
    checkin_date=None,
    checkout_date=None,
):
    """
    通过酒店ID更新酒店预订的入住和退房日期。

    参数:
        hotel_id (int): 要更新预订的酒店的ID。
        checkin_date (Optional[Union[datetime, date]]): 新的入住日期。
        checkout_date (Optional[Union[datetime, date]]): 新的退房日期。

    返回:
        str: 更新成功与否的消息。
    """
    
@tool
def cancel_hotel(hotel_id):
    """
    通过酒店ID取消酒店预订。

    参数:
        hotel_id (int): 要取消预订的酒店的ID。

    返回:
        str: 取消成功与否的消息。
    """

远足活动

最后,定义一些工具,让用户在到达目的地后搜索活动并进行预订。

@tool
def search_trip_recommendations(
    locatinotallow=None,
    name=None,
    keywords=None,
):
    """
    根据位置、名称和关键词搜索旅行推荐。

    参数:
        location (Optional[str]): 旅行推荐的地点。
        name (Optional[str]): 旅行推荐的名字。
        keywords (Optional[str]): 与旅行推荐相关的关键词。

    返回:
        list[dict]: 符合搜索条件的旅行推荐列表。
    """
    
@tool
def book_excursion(recommendation_id):
    """
    通过推荐ID预订远足活动。

    参数:
        recommendation_id (int): 要预订的旅行推荐的ID。

    返回:
        str: 预订成功与否的消息。
    """
    
@tool
def update_excursion(recommendation_id, details):
    """
    通过推荐ID更新旅行推荐的细节。

    参数:
        recommendation_id (int): 要更新的旅行推荐的ID。
        details (str): 旅行推荐的新细节。

    返回:
        str: 更新成功与否的消息。
    """
    
@tool
def cancel_excursion(recommendation_id):
    """
    通过推荐ID取消旅行推荐。

    参数:
        recommendation_id (int): 要取消的旅行推荐的ID。

    返回:
        str: 取消成功与否的消息。
    """

实用工具

定义一些辅助函数,以便在调试过程中美化图形中的消息显示,并为工具节点添加错误处理(通过将错误添加到聊天记录中)。

from langgraph.prebuilt import ToolNode
from langchain_core.runnables import RunnableLambda

def handle_tool_error(state):
    error = state.get("error")
    tool_calls = state["messages"][-1].tool_calls
    return {
        "messages": [
            ToolMessage(
                cnotallow=f"错误: {repr(error)}\n请修正你的错误。",
                tool_call_id=tc["id"],
            )
            for tc in tool_calls
        ]
    }

def create_tool_node_with_fallback(tools):
    return ToolNode(tools).with_fallbacks(
        [RunnableLambda(handle_tool_error)], exception_key="error"
    )

def _print_event(event, _printed, max_length=1500):
    current_state = event.get("dialog_state")
    if current_state:
        print(f"当前状态: ", current_state[-1])
    message = event.get("messages")
    if message:
        if isinstance(message, list):
            message = message[-1]
        if message.id not in _printed:
            msg_repr = message.pretty_repr(html=True)
            if len(msg_repr) > max_length:
                msg_repr = msg_repr[:max_length] + " ... (内容已截断)"
            print(msg_repr)
            _printed.add(message.id)

第一部分:零样本代理

在构建任何系统时,最佳实践是从最简单的可行方案开始,并通过使用类似LangSmith这样的评估工具来测试其有效性。在条件相同的情况下,我们倾向于选择简单且可扩展的解决方案,而不是复杂的方案。然而,单一图谱方法存在一些限制,比如机器人可能在未经用户确认的情况下执行不希望的操作,处理复杂查询时可能遇到困难,或者在回答时缺乏针对性。这些问题我们会在后续进行改进。 在这部分,我们将定义一个简单的零样本代理作为用户的助手,并将所有工具赋予给它。我们的目标是引导它明智地使用这些工具来帮助用户。 我们的简单两节点图如下所示:

第一部分图解

首先,我们定义状态。

状态

我们将StateGraph的状态定义为一个包含消息列表的类型化字典。这些消息构成了聊天的记录,也就是我们简单助手所需要的全部状态信息。

from langgraph.graph.message import add_messages, AnyMessage
from typing_extensions import TypedDict
from typing import Annotated


class State(TypedDict):
    messages: Annotated[list[AnyMessage], add_messages]

代理

然后,我们定义助手函数。这个函数接收图的状态,将其格式化为提示,然后调用一个大型语言模型(LLM)来预测最佳的响应。

from langchain_core.runnables import Runnable, RunnableConfig
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_anthropic import ChatAnthropic
from langchain_core.prompts import ChatPromptTemplate


class Assistant:
    def __init__(self, runnable: Runnable):
        self.runnable = runnable

    def __call__(self, state: State, config: RunnableConfig):
        while True:
            passenger_id = config.get("passenger_id", None)
            state = {**state, "user_info": passenger_id}
            result = self.runnable.invoke(state)
            # 如果大型语言模型返回了一个空响应,我们将重新提示它给出一个实际的响应。
            if (
                not result.content
                or isinstance(result.content, list)
                and not result.content[0].get("text")
            ):
                messages = state["messages"] + [("user", "请给出一个真实的输出。")]
                state = {**state, "messages": messages}
            else:
                break
        return {"messages": result}


# Haiku模型更快、成本更低,但准确性稍差
# llm = ChatAnthropic(model="claude-3-haiku-20240307")
llm = ChatAnthropic(model="claude-3-sonnet-20240229", temperature=1)

primary_assistant_prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "你是一个为瑞士航空提供帮助的客户支持助手。"
            "使用提供的工具来搜索航班、公司政策和其他信息以帮助回答用户的查询。"
            "在搜索时,要有毅力。如果第一次搜索没有结果,就扩大你的查询范围。"
            "如果搜索结果为空,不要放弃,先扩大搜索范围。"
            "\n\n当前用户:\n\n{user_info}\n"
            "\n当前时间:{time}。",
        ),
        ("placeholder", "{messages}"),
    ]
).partial(time=datetime.now())

part_1_tools = [
    TavilySearchResults(max_results=1),
    fetch_user_flight_information,
    search_flights,
    lookup_policy,
    update_ticket_to_new_flight,
    cancel_ticket,
    search_car_rentals,
    book_car_rental,
    update_car_rental,
    cancel_car_rental,
    search_hotels,
    book_hotel,
    update_hotel,
    cancel_hotel,
    search_trip_recommendations,
    book_excursion,
    update_excursion,
    cancel_excursion,
]
part_1_assistant_runnable = primary_assistant_prompt | llm.bind_tools(part_1_tools)

定义图

现在,我们来创建图。这张图是我们这部分的最终助手。

from langgraph.checkpoint.sqlite import SqliteSaver
from langgraph.graph import StateGraph, END
from langgraph.prebuilt import tools_condition, ToolNode

builder = StateGraph(State)


# 定义节点:这些节点执行具体的工作
builder.add_node("assistant", Assistant(part_1_assistant_runnable))
builder.add_node("action", create_tool_node_with_fallback(part_1_tools))
# 定义边:这些边决定了控制流程如何移动
builder.set_entry_point("assistant")
builder.add_conditional_edges(
    "assistant",
    tools_condition,
    # "action"调用我们的工具之一。END导致图终止(并向用户做出响应)
    {"action": "action", END: END},
)
builder.add_edge("action", "assistant")

# 检查点器允许图保存其状态
# 这是整个图的完整记忆。
memory = SqliteSaver.from_conn_string(":memory:")
part_1_graph = builder.compile(checkpointer=memory)

from IPython.display import Image, display

try:
    display(Image(part_1_graph.get_graph(xray=True).draw_mermaid_png()))
except:
    # 这需要一些额外的依赖项,是可选的
    pass

示例对话

现在,让我们通过一系列对话示例来测试我们的聊天机器人。

import uuid
import shutil

# 假设这是用户与助手之间可能发生的对话示例
tutorial_questions = [
    "你好,我的航班是什么时候?",
    "我可以把我的航班改签到更早的时间吗?我想今天晚些时候离开。",
    "那就把我的航班改签到下周某个时间吧",
    "下一个可用的选项很好",
    "住宿和交通方面有什么建议?",
    "我想在为期一周的住宿中选择一个经济实惠的酒店(7天),并且我还想租一辆车。",
    "好的,你能为你推荐的酒店预订吗?听起来不错。",
    "是的,去预订任何中等价位且有可用性的酒店。",
    "对于汽车,我有哪些选择?",
    "太棒了,我们只选择最便宜的选项。预订7天。",
    "那么,你对我的旅行有什么建议?",
    "在我在那里的时候,有哪些活动是可用的?",
    "有趣 - 我喜欢博物馆,有哪些选择?",
    "好的,那就为我在那里的第二天预订一个。",
]

# 使用备份文件以便我们可以从每个部分的原始位置重新启动
shutil.copy(backup_file, db)
thread_id = str(uuid.uuid4())

config = {
    "configurable": {
        # passenger_id 在我们的航班工具中使用
        # 以获取用户的航班信息
        "passenger_id": "3442 587242",
        # 检查点通过 thread_id 访问
        "thread_id": thread_id,
    }
}


_printed = set()
for question in tutorial_questions:
    events = part_1_graph.stream(
        {"messages": ("user", question)}, config, stream_mode="values"
    )
    for event in events:
        _print_event(event, _printed)

来源:AI小智内容投诉

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

软考中级精品资料免费领

  • 历年真题答案解析
  • 备考技巧名师总结
  • 高频考点精准押题
  • 2024年上半年信息系统项目管理师第二批次真题及答案解析(完整版)

    难度     813人已做
    查看
  • 【考后总结】2024年5月26日信息系统项目管理师第2批次考情分析

    难度     354人已做
    查看
  • 【考后总结】2024年5月25日信息系统项目管理师第1批次考情分析

    难度     318人已做
    查看
  • 2024年上半年软考高项第一、二批次真题考点汇总(完整版)

    难度     435人已做
    查看
  • 2024年上半年系统架构设计师考试综合知识真题

    难度     224人已做
    查看

相关文章

发现更多好内容

猜你喜欢

AI推送时光机
位置:首页-资讯-后端开发
咦!没有更多了?去看看其它编程学习网 内容吧
首页课程
资料下载
问答资讯