1. 贪婪采样 (Greedy Sampling)
贪婪采样是一种直接选择最可能的下一个词的策略。
具体步骤为:从模型输出的logits中,找到概率最大的那个词,直接选择它作为输出。
实现代码:
class GreedySampler(Sampler):
def __call__(self, logits: torch.Tensor):
return logits.argmax(dim=-1)
优点:
- 简单且计算效率高。
- 保证每一步选择最有可能的结果。
缺点:
- 可能会导致生成的文本非常重复和缺乏多样性。
- 贪婪采样只关注当前概率最大的词,忽略了其他潜在的好选择,容易陷入局部最优解。
2. 带温度的采样 (Temperature Sampling)
温度采样通过引入一个温度参数来调整输出概率的分布,以控制生成文本的多样性。温度 T 的作用是平滑或锐化概率分布:
- 当 T = 1 时,采样为标准随机采样。
- 当 T < 1 时,概率分布变得更尖锐,模型更倾向于选择最可能的词。
- 当 T > 1 时,概率分布变得更加平滑,模型会更多地探索低概率的词。
实现代码
class TemperatureSampler(Sampler):
def __init__(self, temperature: float = 1.0):
self.temperature = temperature
def __call__(self, logits: torch.Tensor):
dist = Categorical(logits=logits / self.temperature)
return dist.sample()
优点:
- 提供了生成文本的多样性,尤其是在温度高时。
- 通过调整温度参数,可以控制探索(随机性)与利用(选择高概率词)之间的平衡。
缺点:
- 温度的选择需要仔细调节,不同任务或场景下对温度的需求可能不同。
- 温度过低时,生成的文本趋向于贪婪采样;温度过高时,生成的文本可能过于随机。
3. Top-k采样
Top-k采样限制了每次生成时候的候选词数量,模型只会从概率前k个最高的词中进行采样,而忽略其他可能性较小的词。
实现代码:
class TopKSampler(Sampler):
def __init__(self, k: int, sampler: Sampler):
self.k = k
self.sampler = sampler
def __call__(self, logits: torch.Tensor):
zeros = logits.new_ones(logits.shape) * float('-inf')
values, indices = torch.topk(logits, self.k, dim=-1)
zeros.scatter_(-1, indices, values)
return self.sampler(zeros)
优点:
- 提供了对生成词汇的严格控制,避免生成概率非常低的词。
- 通过限制候选词的数量,避免了一些罕见或不合逻辑的词被选中。
缺点:
- 需要设定一个合适的 k 值,如果 k 值太小,生成的文本可能会缺乏多样性;如果 k 值太大,则效果与标准采样相似。
4. 核采样 (Nucleus Sampling)
核采样是一种自适应的采样方法,它选择的候选词集合 V(p) 是满足累计概率和大于或等于给定阈值 p 的最小词汇子集。与Top-k采样不同,核采样的候选词数量不是固定的,而是基于累计概率动态确定的。
示例
假设同样的语境:“今天的天气很”,但这次我们将会有不同的词汇及其概率分布,我们也会使用不同的阈值 ( p ) 来展示如何动态确定选词数量。
(1) 模型预测的词汇概率
- 好:0.4
- 冷:0.3
- 热:0.2
- 潮湿:0.05
- 多变:0.03
- 干燥:0.02
(2) 排序与累积概率
按概率从高到低排序并计算累积概率:
- 好:0.4
- 冷:0.7 (0.4 + 0.3)
- 热:0.9 (0.7 + 0.2)
- 潮湿:0.95 (0.9 + 0.05)
- 多变:0.98 (0.95 + 0.03)
- 干燥:1.00 (0.98 + 0.02)
(3) 确定核集合
这次,我们将选择不同的阈值 ( p ) 来观察核集合如何变化:
- **当 ( p = 0.7 )**:核集合包括:“好”和“冷”,因为它们的累积概率首次超过 0.7。
- **当 ( p = 0.9 )**:核集合扩展到:“好”,“冷”,和“热”,因为它们的累积概率首次超过 0.9。
- **当 ( p = 0.95 )**:核集合进一步扩展到:“好”,“冷”,“热”和“潮湿”,因为这是累积概率首次超过 0.95。
(4) 抽样
在每种情况下,我们从对应的核集合中随机选取一个词作为下一个词。选择的范围和多样性取决于 ( p ) 值的大小,而词的数量是根据这个阈值动态确定的,不是固定的。
实现代码
class NucleusSampler(Sampler):
"""
## Nucleus 采样器
Nucleus 采样器根据给定的概率 p 选择词汇的一个子集,并从中进行采样。
"""
def __init__(self, p: float, sampler: Sampler):
"""
### 初始化
:param p: 要选择的令牌概率之和,即 p 值。
:param sampler: 用于从选定令牌中进行采样的采样器。
"""
# 保存 p 值
self.p = p
# 保存采样器
self.sampler = sampler
# 初始化 softmax 层,用于将 logits 转换为概率
self.softmax = nn.Softmax(dim=-1)
def __call__(self, logits: torch.Tensor):
"""
### 从 logits 中进行 Nucleus 采样
:param logits: 输入的 logits 张量,形状为 (batch_size, num_tokens)。
:return: 采样得到的令牌索引,形状为 (batch_size,)。
"""
# 获取概率 P(x_i | x_1:i-1)
probs = self.softmax(logits)
# 按降序对概率进行排序,并获取排序后的索引
sorted_probs, indices = torch.sort(probs, dim=-1, descending=True)
# 按排序顺序获取概率的累积总和
cum_sum_probs = torch.cumsum(sorted_probs, dim=-1)
# 找出累积总和小于 p 的令牌
nucleus = cum_sum_probs < self.p
# 在前面加一个 True,这样我们可以在累积概率小于 p 的最小令牌数量之后添加一个令牌
nucleus = torch.cat([nucleus.new_ones(nucleus.shape[:-1] + (1,)), nucleus[..., :-1]], dim=-1)
# 获取对数概率并掩盖非核部分
sorted_log_probs = torch.log(sorted_probs)
sorted_log_probs[~nucleus] = float('-inf')
# 使用采样器从排序后的对数概率中进行采样
sampled_sorted_indexes = self.sampler(sorted_log_probs)
# 获取实际的索引
res = indices.gather(-1, sampled_sorted_indexes.unsqueeze(-1))
# 返回采样得到的令牌索引
return res.squeeze(-1)
优点:
- 灵活性强,自动调整候选词集合,避免了固定的词数限制。
- 在生成文本时能够更好地平衡多样性与高概率词的利用,表现优于Top-k采样。
缺点:
- 参数 p 的选择需要调节,不同任务可能需要不同的 p 值。
- 计算复杂度较高,尤其是当处理较大的词汇表时。
总结
采样方法 | 优点 | 缺点 |
贪婪采样 | 简单、高效,始终选择最有可能的词 | 文本生成可能单一,缺乏多样性 |
温度采样 | 通过调整温度控制多样性,适应性强 | 温度的调节需要谨慎,过高或过低的温度可能产生不理想的结果 |
Top-k采样 | 控制候选词数量,避免选择低概率词 |
|
核采样 | 动态选择候选词集合,更灵活,生成文本质量较高 | 参数 |
每种采样方式都有其适用的场景,根据具体的应用和对生成文本的要求,可以选择不同的采样策略。