Commit 24892520 authored by brkirch's avatar brkirch

`torch.empty` can create issues; use `torch.zeros`

For MPS, using a tensor created with `torch.empty()` can cause `torch.baddbmm()` to include NaNs in the tensor it returns, even though `beta=0`. However, with a tensor of shape [1,1,1], there should be a negligible performance difference between `torch.empty()` and `torch.zeros()` anyway, so it's better to just use `torch.zeros()` for this and avoid unnecessarily creating issues.
parent 87dd6852
......@@ -58,7 +58,7 @@ def _summarize_chunk(
scale: float,
) -> AttnChunk:
attn_weights = torch.baddbmm(
torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
torch.zeros(1, 1, 1, device=query.device, dtype=query.dtype),
query,
key.transpose(1,2),
alpha=scale,
......@@ -121,7 +121,7 @@ def _get_attention_scores_no_kv_chunking(
scale: float,
) -> Tensor:
attn_scores = torch.baddbmm(
torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
torch.zeros(1, 1, 1, device=query.device, dtype=query.dtype),
query,
key.transpose(1,2),
alpha=scale,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment