AI 系列:argmax 函数
2025-04-12
torch.argmax() 是 PyTorch 中的一个函数,它的含义是返回张量中最大值的索引位置。
1.基本概念
- argmax = argument of maximum(最大值的参数/索引),argument 自变量
- 它不返回最大值本身,而是返回最大值在张量中的位置索引
argmax: arg + max,找到让max(·)取得最大值的自变量(的位置/取值)。
2.简单示例
import torch
# 创建一个张量
tensor = torch.tensor([0.1, 0.8, 0.3, 0.9, 0.2])
# 找到最大值的索引
max_index = torch.argmax(tensor)
print(f"张量: {tensor}")
print(f"最大值索引: {max_index}")  # 输出: 3
print(f"最大值: {tensor[max_index]}")  # 输出: 0.9
3.在你的代码中的应用
next_token_id = torch.argmax(probas).item()
这行代码的含义是:
- probas是一个概率分布张量(通常形状为- [vocab_size])
- torch.argmax(probas)找到概率最高的那个token的索引
- .item()将张量转换为Python标量
- next_token_id就是下一个要生成的token的ID
4.实际应用场景
在语言模型中,这通常用于贪婪解码:
- 模型输出每个可能token的概率分布
- argmax选择概率最高的token作为下一个预测
- 这是最简单的文本生成策略(还有其他策略如采样、束搜索等)
5.多维张量的情况
# 2D张量
tensor_2d = torch.tensor([[1, 5, 3], [4, 2, 6]])
print(torch.argmax(tensor_2d))  # 输出: 5 (第1行第2列,值为6)
# 指定维度
print(torch.argmax(tensor_2d, dim=0))  # 每列的最大值索引: [1, 0, 1]
print(torch.argmax(tensor_2d, dim=1))  # 每行的最大值索引: [1, 2]
简单来说,argmax 就是”告诉我最大值在哪里”,而不是”告诉我最大值是多少”。
原文地址:https://ningg.top/ai-series-argmax-intro/

