NingG +

AI 系列:argmax 函数

torch.argmax() 是 PyTorch 中的一个函数,它的含义是返回张量中最大值的索引位置

1.基本概念

2.简单示例

import torch

# 创建一个张量
tensor = torch.tensor([0.1, 0.8, 0.3, 0.9, 0.2])

# 找到最大值的索引
max_index = torch.argmax(tensor)
print(f"张量: {tensor}")
print(f"最大值索引: {max_index}")  # 输出: 3
print(f"最大值: {tensor[max_index]}")  # 输出: 0.9

3.在你的代码中的应用

next_token_id = torch.argmax(probas).item()

这行代码的含义是:

  1. probas 是一个概率分布张量(通常形状为 [vocab_size]
  2. torch.argmax(probas) 找到概率最高的那个token的索引
  3. .item() 将张量转换为Python标量
  4. next_token_id 就是下一个要生成的token的ID

4.实际应用场景

在语言模型中,这通常用于贪婪解码

5.多维张量的情况

# 2D张量
tensor_2d = torch.tensor([[1, 5, 3], [4, 2, 6]])
print(torch.argmax(tensor_2d))  # 输出: 5 (第1行第2列,值为6)

# 指定维度
print(torch.argmax(tensor_2d, dim=0))  # 每列的最大值索引: [1, 0, 1]
print(torch.argmax(tensor_2d, dim=1))  # 每行的最大值索引: [1, 2]

简单来说,argmax 就是”告诉我最大值在哪里”,而不是”告诉我最大值是多少”。

原文地址:https://ningg.top/ai-series-argmax-intro/
微信公众号 ningg, 联系我

同类文章:

微信搜索: 公众号 ningg, 联系我, 交个朋友.

Top