Commit
Merge pull request #189 from Adamits/cuda-numpy-ser-hotfix
Cuda numpy ser hotfix
kylebgorman authored May 8, 2024
2 parents 1425735 + 5ed2159 commit a71a46e
Showing 1 changed file with 11 additions and 15 deletions.
26 changes: 11 additions & 15 deletions yoyodyne/evaluators.py
@@ -197,15 +197,15 @@ class SEREvaluator(Evaluator):

def _compute_ser(
self,
- preds: List[str],
- target: List[str],
+ preds: torch.Tensor,
+ target: torch.Tensor,
) -> float:
errors = self._edit_distance(preds, target)
total = len(target)
return errors / total

@staticmethod
- def _edit_distance(x: List[str], y: List[str]) -> int:
+ def _edit_distance(x: torch.Tensor, y: torch.Tensor) -> int:
idim = len(x) + 1
jdim = len(y) + 1
table = numpy.zeros((idim, jdim), dtype=numpy.uint16)
@@ -224,8 +224,8 @@ def _edit_distance(x: List[str], y: List[str]) -> int:

def get_eval_item(
self,
- predictions: List[List[str]],
- golds: List[List[str]],
+ predictions: torch.Tensor,
+ golds: torch.Tensor,
pad_idx: int,
) -> EvalItem:
sers = [self._compute_ser(p, g) for p, g in zip(predictions, golds)]
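The hunks above switch the SER machinery from lists of strings to integer tensors. The diff only shows the edit-distance table being allocated; a minimal standalone sketch of the Wagner-Fischer recurrence it implies could look like the following (the recurrence and the helper name edit_distance are assumptions, since the rest of the function body is collapsed here):

# Hedged sketch: Wagner-Fischer edit distance over two 1-D integer tensors,
# mirroring the table setup visible in the diff (numpy table, uint16 dtype).
# The recurrence and the helper name `edit_distance` are assumptions.
import numpy
import torch


def edit_distance(x: torch.Tensor, y: torch.Tensor) -> int:
    idim = len(x) + 1
    jdim = len(y) + 1
    table = numpy.zeros((idim, jdim), dtype=numpy.uint16)
    table[:, 0] = range(idim)
    table[0, :] = range(jdim)
    for i in range(1, idim):
        for j in range(1, jdim):
            if x[i - 1] == y[j - 1]:
                table[i, j] = table[i - 1, j - 1]
            else:
                table[i, j] = 1 + min(
                    table[i - 1, j],  # deletion
                    table[i, j - 1],  # insertion
                    table[i - 1, j - 1],  # substitution
                )
    return int(table[idim - 1, jdim - 1])


pred = torch.tensor([4, 7, 9])
gold = torch.tensor([4, 9, 9])
print(edit_distance(pred, gold) / len(gold))  # one error over three symbols

With that in hand, the per-pair SER is simply the edit distance divided by the length of the gold sequence, which is what _compute_ser returns above.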
@@ -236,13 +236,10 @@ def _finalize_tensor(
tensor: torch.Tensor,
end_idx: int,
pad_idx: int,
- ) -> List[List[str]]:
+ ) -> torch.Tensor:
# Not necessary if batch size is 1.
if tensor.size(0) == 1:
- # Returns a list of a numpy char vector.
- # This is allows evaluation over strings without converting
- # integer indices back to symbols.
- return [numpy.char.mod("%d", tensor.cpu().numpy())]
+ return [tensor]
out = []
for prediction in tensor:
# Gets first instance of EOS.
@@ -253,23 +250,22 @@ def _finalize_tensor(
eos = eos[0]
else:
# Leaves tensor[i] alone.
- out.append(numpy.char.mod("%d", prediction))
+ out.append(prediction)
continue
# Hack in case the first prediction is EOS. In this case
# torch.split will result in an error, so we change these 0's to
# 1's, which will make the entire sequence EOS as intended.
eos[eos == 0] = 1
symbols, *_ = torch.split(prediction, eos)
- # Accumulates a list of numpy char vectors.
- out.append(numpy.char.mod("%d", symbols))
+ out.append(symbols)
return out

def finalize_predictions(
self,
predictions: torch.Tensor,
end_idx: int,
pad_idx: int,
- ) -> List[List[str]]:
+ ) -> torch.Tensor:
"""Finalizes predictions.
Args:
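After this change, _finalize_tensor accumulates plain tensors instead of numpy char vectors. A small hedged sketch of the intended trimming behavior, using simple indexing rather than the torch.split path shown in the diff (the helper name and the toy batch are illustrative only):

# Hedged sketch (not the diff's implementation): keep each prediction up to,
# but not including, its first EOS symbol, returning plain tensors.
from typing import List

import torch


def trim_at_eos(predictions: torch.Tensor, end_idx: int) -> List[torch.Tensor]:
    out = []
    for row in predictions:
        hits = (row == end_idx).nonzero(as_tuple=True)[0]
        if len(hits) == 0:
            # No EOS found: keep the row unchanged, as the diff does.
            out.append(row)
        else:
            out.append(row[: int(hits[0])])
    return out


batch = torch.tensor([[5, 6, 2, 0], [7, 2, 0, 0]])
print(trim_at_eos(batch, end_idx=2))  # [tensor([5, 6]), tensor([7])]

Unlike the eos[eos == 0] = 1 hack above, which keeps a leading EOS so that torch.split never receives a zero-length first chunk, this sketch would return an empty tensor when EOS is the very first symbol.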
@@ -287,7 +283,7 @@ def finalize_golds(
golds: torch.Tensor,
end_idx: int,
pad_idx: int,
- ):
+ ) -> torch.Tensor:
return self._finalize_tensor(golds, end_idx, pad_idx)

@property
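Beyond the diff itself, the commit title points at CUDA: numpy.char.mod() triggers an implicit tensor-to-numpy conversion, which raises for CUDA-resident tensors unless .cpu() is called first, while keeping plain tensors (as the new code does) avoids the conversion altogether. A hedged repro sketch of that failure mode follows; the reading of the motivating bug and the toy values are assumptions.

# Hedged repro sketch (assumption about the motivating bug): converting a
# CUDA tensor to numpy without .cpu() raises a TypeError.
import numpy
import torch

pred = torch.tensor([4, 7, 9, 2])
if torch.cuda.is_available():
    pred = pred.cuda()
    try:
        numpy.char.mod("%d", pred)  # implicit tensor -> numpy conversion
    except TypeError as error:
        print(error)  # can't convert cuda:0 device type tensor to numpy ...

# Element-wise tensor comparison works on any device, no conversion needed.
gold = torch.tensor([4, 9, 9, 2], device=pred.device)
print((pred != gold).sum().item())  # 1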
