AI 系列:argmax 函数
2025-04-12
torch.argmax()
是 PyTorch 中的一个函数,它的含义是返回张量中最大值的索引位置。
1.基本概念
- argmax = argument of maximum(最大值的参数/索引)
- 它不返回最大值本身,而是返回
最大值
在张量中的位置索引
2.简单示例
import torch
# 创建一个张量
tensor = torch.tensor([0.1, 0.8, 0.3, 0.9, 0.2])
# 找到最大值的索引
max_index = torch.argmax(tensor)
print(f"张量: {tensor}")
print(f"最大值索引: {max_index}") # 输出: 3
print(f"最大值: {tensor[max_index]}") # 输出: 0.9
3.在你的代码中的应用
next_token_id = torch.argmax(probas).item()
这行代码的含义是:
probas
是一个概率分布张量(通常形状为[vocab_size]
)torch.argmax(probas)
找到概率最高的那个token的索引.item()
将张量转换为Python标量next_token_id
就是下一个要生成的token的ID
4.实际应用场景
在语言模型中,这通常用于贪婪解码:
- 模型输出每个可能token的概率分布
argmax
选择概率最高的token作为下一个预测- 这是最简单的文本生成策略(还有其他策略如采样、束搜索等)
5.多维张量的情况
# 2D张量
tensor_2d = torch.tensor([[1, 5, 3], [4, 2, 6]])
print(torch.argmax(tensor_2d)) # 输出: 5 (第1行第2列,值为6)
# 指定维度
print(torch.argmax(tensor_2d, dim=0)) # 每列的最大值索引: [1, 0, 1]
print(torch.argmax(tensor_2d, dim=1)) # 每行的最大值索引: [1, 2]
简单来说,argmax
就是”告诉我最大值在哪里”,而不是”告诉我最大值是多少”。
原文地址:https://ningg.top/ai-series-argmax-intro/