文章详情

短信预约-IT技能 免费直播动态提醒

请输入下面的图形验证码

提交验证

短信预约提醒成功

LLM模型贪婪、温度、Top-k、核采样方式的区别(附代码与示例)

2024-11-29 18:21

关注

1. 贪婪采样 (Greedy Sampling)

贪婪采样是一种直接选择最可能的下一个词的策略。

具体步骤为:从模型输出的logits中,找到概率最大的那个词,直接选择它作为输出。

实现代码:

class GreedySampler(Sampler):
    def __call__(self, logits: torch.Tensor):
        return logits.argmax(dim=-1)

优点:

缺点:

2. 带温度的采样 (Temperature Sampling)

温度采样通过引入一个温度参数来调整输出概率的分布,以控制生成文本的多样性。温度 T 的作用是平滑或锐化概率分布:

实现代码

class TemperatureSampler(Sampler):
    def __init__(self, temperature: float = 1.0):
        self.temperature = temperature

    def __call__(self, logits: torch.Tensor):
        dist = Categorical(logits=logits / self.temperature)
        return dist.sample()

优点:

缺点:

3. Top-k采样

Top-k采样限制了每次生成时候的候选词数量,模型只会从概率前k个最高的词中进行采样,而忽略其他可能性较小的词。

实现代码:

class TopKSampler(Sampler):
    def __init__(self, k: int, sampler: Sampler):
        self.k = k
        self.sampler = sampler

    def __call__(self, logits: torch.Tensor):
        zeros = logits.new_ones(logits.shape) * float('-inf')
        values, indices = torch.topk(logits, self.k, dim=-1)
        zeros.scatter_(-1, indices, values)
        return self.sampler(zeros)

优点:

缺点:

4. 核采样 (Nucleus Sampling)

核采样是一种自适应的采样方法,它选择的候选词集合 V(p) 是满足累计概率和大于或等于给定阈值 p 的最小词汇子集。与Top-k采样不同,核采样的候选词数量不是固定的,而是基于累计概率动态确定的。

示例

假设同样的语境:“今天的天气很”,但这次我们将会有不同的词汇及其概率分布,我们也会使用不同的阈值 ( p ) 来展示如何动态确定选词数量。

(1) 模型预测的词汇概率

(2) 排序与累积概率

按概率从高到低排序并计算累积概率:

(3) 确定核集合

这次,我们将选择不同的阈值 ( p ) 来观察核集合如何变化:

(4) 抽样

在每种情况下,我们从对应的核集合中随机选取一个词作为下一个词。选择的范围和多样性取决于 ( p ) 值的大小,而词的数量是根据这个阈值动态确定的,不是固定的。

实现代码

class NucleusSampler(Sampler):
    """
    ## Nucleus 采样器

    Nucleus 采样器根据给定的概率 p 选择词汇的一个子集,并从中进行采样。
    """

    def __init__(self, p: float, sampler: Sampler):
        """
        ### 初始化

        :param p: 要选择的令牌概率之和,即 p 值。
        :param sampler: 用于从选定令牌中进行采样的采样器。
        """
        # 保存 p 值
        self.p = p
        # 保存采样器
        self.sampler = sampler
        # 初始化 softmax 层,用于将 logits 转换为概率
        self.softmax = nn.Softmax(dim=-1)

    def __call__(self, logits: torch.Tensor):
        """
        ### 从 logits 中进行 Nucleus 采样

        :param logits: 输入的 logits 张量,形状为 (batch_size, num_tokens)。
        :return: 采样得到的令牌索引,形状为 (batch_size,)。
        """
        # 获取概率 P(x_i | x_1:i-1)
        probs = self.softmax(logits)

        # 按降序对概率进行排序,并获取排序后的索引
        sorted_probs, indices = torch.sort(probs, dim=-1, descending=True)

        # 按排序顺序获取概率的累积总和
        cum_sum_probs = torch.cumsum(sorted_probs, dim=-1)

        # 找出累积总和小于 p 的令牌
        nucleus = cum_sum_probs < self.p

        # 在前面加一个 True,这样我们可以在累积概率小于 p 的最小令牌数量之后添加一个令牌
        nucleus = torch.cat([nucleus.new_ones(nucleus.shape[:-1] + (1,)), nucleus[..., :-1]], dim=-1)

        # 获取对数概率并掩盖非核部分
        sorted_log_probs = torch.log(sorted_probs)
        sorted_log_probs[~nucleus] = float('-inf')

        # 使用采样器从排序后的对数概率中进行采样
        sampled_sorted_indexes = self.sampler(sorted_log_probs)

        # 获取实际的索引
        res = indices.gather(-1, sampled_sorted_indexes.unsqueeze(-1))

        # 返回采样得到的令牌索引
        return res.squeeze(-1)

优点:

缺点:

总结

采样方法

优点

缺点

贪婪采样

简单、高效,始终选择最有可能的词

文本生成可能单一,缺乏多样性

温度采样

通过调整温度控制多样性,适应性强

温度的调节需要谨慎,过高或过低的温度可能产生不理想的结果

Top-k采样

控制候选词数量,避免选择低概率词

k 值选择需要调节,k 太小可能导致文本单一

核采样

动态选择候选词集合,更灵活,生成文本质量较高

参数 p 需要调节,计算复杂度较高

每种采样方式都有其适用的场景,根据具体的应用和对生成文本的要求,可以选择不同的采样策略。

来源:coding日记内容投诉

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

软考中级精品资料免费领

  • 历年真题答案解析
  • 备考技巧名师总结
  • 高频考点精准押题
  • 2024年上半年信息系统项目管理师第二批次真题及答案解析(完整版)

    难度     813人已做
    查看
  • 【考后总结】2024年5月26日信息系统项目管理师第2批次考情分析

    难度     354人已做
    查看
  • 【考后总结】2024年5月25日信息系统项目管理师第1批次考情分析

    难度     318人已做
    查看
  • 2024年上半年软考高项第一、二批次真题考点汇总(完整版)

    难度     435人已做
    查看
  • 2024年上半年系统架构设计师考试综合知识真题

    难度     224人已做
    查看

相关文章

发现更多好内容

猜你喜欢

AI推送时光机
位置:首页-资讯-后端开发
咦!没有更多了?去看看其它编程学习网 内容吧
首页课程
资料下载
问答资讯