def gumbel_sample(
    logits: torch.Tensor,  # [num_reqs, vocab_size]
    temperature: torch.Tensor,  # [num_reqs]
    seed: torch.Tensor,  # [num_reqs]
    pos: torch.Tensor,  # [num_reqs]
    apply_temperature: bool,
) -> torch.Tensor:
    """Select one index per request using a blocked, two-stage argmax.

    The vocabulary axis is cut into chunks of 1024 entries and
    ``_gumbel_sample_kernel`` is launched on a ``(num_reqs, num_blocks)``
    grid with two scratch buffers of that same shape. After the launch, the
    block holding the largest value in the float buffer is located for each
    request, and the matching entry of the int64 buffer is returned.

    NOTE(review): the kernel body is not visible from here. Presumably each
    program perturbs its chunk of ``logits`` with Gumbel noise derived from
    ``seed``/``pos`` (optionally after temperature scaling) and writes the
    chunk-local max and its vocabulary index -- confirm against the kernel.

    Args:
        logits: 2-D tensor; its shape is unpacked as (num_reqs, vocab_size).
        temperature: Forwarded to the kernel unchanged.
        seed: Forwarded to the kernel unchanged.
        pos: Forwarded to the kernel unchanged.
        apply_temperature: Forwarded to the kernel as the
            ``APPLY_TEMPERATURE`` keyword argument.

    Returns:
        A 1-D int64 tensor with one entry per request, on ``logits.device``.
    """
    batch, vocab = logits.shape
    block = 1024
    # One kernel program per (request, vocabulary chunk); the scratch buffers
    # share this shape so each program has its own output slot.
    grid = (batch, triton.cdiv(vocab, block))
    dev = logits.device

    # NOTE(woosuk): Use int64 for later indexing.
    block_winner_ids = torch.empty(grid, dtype=torch.int64, device=dev)
    block_winner_vals = torch.empty(grid, dtype=torch.float32, device=dev)

    _gumbel_sample_kernel[grid](
        block_winner_ids,
        block_winner_ids.stride(0),
        block_winner_vals,
        block_winner_vals.stride(0),
        logits,
        logits.stride(0),
        seed,
        pos,
        temperature,
        vocab,
        BLOCK_SIZE=block,
        APPLY_TEMPERATURE=apply_temperature,
    )

    # Second stage: for every request, find which block produced the overall
    # maximum, then read the index that block recorded.
    best_block = torch.argmax(block_winner_vals, dim=-1, keepdim=True)
    return torch.gather(block_winner_ids, -1, best_block).view(-1)