prompt_vector = torch.sum(prompt_embedding * attention_weights.unsqueeze(-1), dim=1) # [1, hidden_dim]
prompt_vector = torch.sum(prompt_embedding * attention_weights.unsqueeze(-1), dim=1) 主要作用是通过将 prompt_embedding 与 attention_weights 相乘后再按指定维度求和,得到一个新的张量 prompt_vector。
代码解释
prompt_embedding:这是一个包含提示词嵌入向量的张量,通常形状为[batch_size, seq_len, hidden_dim],表示批次大小、序列长度和隐藏层维度。attention_weights:这是一个注意力权重张量,形状通常为[batch_size, seq_len],表示每个位置的注意力权重。