Skip to content

Conversation

@Narsil
Copy link
Contributor

@Narsil Narsil commented Nov 7, 2022

What does this PR do?

This adds chunk_length_s to seq2seq algorithms.

Approach

Since we have no way of finding a matching between output and input with seq2seq
this is an alternative route.

This runs the pipeline on the various chunks and finds all generated output.
Then it tries to find the longest sequence of non special ids that could correspond
to the subsequences within the batch.

Pros

  • It should work on any seq2seq models
  • It should work decently when the stride is long enough to have good overlapping of tokens so that the stitching can work correctly
  • It should be slightly robust to few token errors
  • It should perform best on mostly continuous talk (so that there is model output that can overlap)

Cons

  • This method is unsound and will fail under some circumstances
  • It will fail when there is silence in the overlap. If there is silence then there is no overlapping tokens, and the stitching might get lost during the stitching process. By default it will concatenate, but it might be put off by boundaries in the stride.
  • It will fail spectacularly when something repeats a single word over and over. Then, we will have overlap that might be TOO large. This is impossible to distinguish without getting access to the timestamps (which only whisper can currently do, and it does come with caveats). The currently algorithm will favor long chain of matching tokens.
  • It will have issues with capitalization and out of domain areas. For instance "Yes, sir." , "Sir Thomas" might be 2 chunks, which have different capitalization. Since the current algorithm works at the token level, the 2 tokens "sir" and ¨Sir" are different and will fail to match leading to some `¨Yes, sir. Sir Thomas" stitching instead of the intended "Yes, Sir Thomas.".

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

Copy link
Collaborator

@sgugger sgugger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for working on this. Not sure if the PR is ready for (at least core maintainer) review yet?

Comment on lines 284 to 287
# if self.type not in {"ctc", "ctc_with_lm"}:
# raise ValueError(
# "`chunk_length_s` is only valid for CTC models, use other chunking options for other models"
# )
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To clean up?

Comment on lines 149 to 152
# self.assertEqual(
# str(v.exception),
# "`chunk_length_s` is only valid for CTC models, use other chunking options for other models",
# )
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To clean up as well?

Comment on lines 272 to 280
# waveform = np.tile(np.arange(1000, dtype=np.float32), 34)
# output = speech_recognizer(waveform)
# self.assertEqual(output, {"text": ""})

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
filename = ds[40]["file"]
# output = speech_recognizer(filename)
# self.assertEqual(output, {"text": " A man said to the universe, Sir, I exist."})
print(filename)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comments and print statements to clean up.

@Narsil
Copy link
Contributor Author

Narsil commented Nov 7, 2022

Thanks for working on this. Not sure if the PR is ready for (at least core maintainer) review yet?

Yup sorry it was slightly early for you.
The core idea is still there.

We chunk with stride. and we make a hopeful stitch to find the longest sequence from all the subsequences.

PROs:

  • It's extremely generic.
  • It should work in a lot of scenarios including repeating tokens

CONs:

  • It's technically unsound. Meaning if the model infers widely varying tokens, there's no way to reconstruct what the model would actually predict on the whole file.
  • I expect it can fail spectacularly in well crafted examples where someone repeats the same word over and over, where the longest match will be MUCH longer than the original voices thing.

@ArthurZucker
Copy link
Collaborator

ArthurZucker commented Nov 8, 2022

As we discussed offline with @Narsil , will be implementing the find_conmmon_sequence in O(N) 😉 Will open a new PR!

@Narsil
Copy link
Contributor Author

Narsil commented Nov 8, 2022

As we discussed offline with @Narsil , will be implementing the find_conmmon_sequence in O(N) wink Will open a new PR!

Seems it's going to be complex because of fault tolerance which does seem to be important.

You can try doing something like

#!wget https://www.archive.org/download/around_world_80_days_mfs_librivox/around_world_in_80_days_01_verne.mp3
from transformers import pipeline

speech_recognizer = pipeline(
    task="automatic-speech-recognition",
    model="openai/whisper-small",
    framework="pt",
    batch_size=2,
    device=0,
    chunk_length_s=30,
    generate_kwargs={"max_new_tokens": 1024},
)

out = speech_recognizer(["around_world_in_80_days_01_verne.mp3"])
print(out)

This will required some suboptimal stitches to work.

@Narsil Narsil requested a review from sgugger November 8, 2022 17:02
@Narsil
Copy link
Contributor Author

Narsil commented Nov 8, 2022

@sgugger it's now ready for review.

The TODO is left intentionnally. It might really become relevant on hour+ long files where the current naive algorithm might become too slow. However the code is likely to be orders of magnitude more complex (if a O(n) solution exists, I'm pretty sure we could find an expected O(n) algorithm, but not sure about worst case).
The current code works correctly, has the fault tolerance we need to be useful.

I added a warning because the current code Will fail in some know circumstances. I updated the PR description to reflect those. If those tradeoffs are not good enough, I'm happy to not merge this PR in this state.

The only other option I see is whisper specific with timestamps and it would only alleviate some of the issues.

@ArthurZucker
Copy link
Collaborator

Before merging, would love to try a little bit, otherwise LGTM (looking for a solution to the faults)

@Narsil
Copy link
Contributor Author

Narsil commented Nov 14, 2022

@ArthurZucker What are your conclusions ?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@ArthurZucker
Copy link
Collaborator

I think that including timestamp tokens in the process could help with the error tolerance as they are consistently predicted at the end of pauses in the speech. If the stride is big enough not at least include pauses in speech, it boils down to matching these.
Moreover, given that we know approximately the time between each tokens, we can use this information as some kind of guiding information. I am working on something, but we can merge for now and have an improved PR later on 😉

@Narsil
Copy link
Contributor Author

Narsil commented Nov 14, 2022

@sgugger would like your opinion on this if possible.

The results are pretty decent imo on regular speech. I'm still mentionning the caveats because they are real.

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM thanks a lot for working on this

Comment on lines +273 to +274
output = speech_recognizer([filename], chunk_length_s=5, batch_size=4)
self.assertEqual(output, [{"text": " A man said to the universe, Sir, I exist."}])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NIce

Copy link
Collaborator

@sgugger sgugger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just one comment on the warning, otherwise LGTM! Thanks!

Comment on lines 289 to 293
logger.warning(
"Using `chunk_length_s` is very experimental. The results will not necessarily be entirely"
" accurate and will have caveats. More information:"
" https://github.com/huggingface/transformers/pull/20104"
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add some logic to only throw this warning once? Users are complaining Transformers is too verbose.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there already a created way to do that ?

Otherwise I can create some tool for it.
Any other location we could add this "single" warning ? (Will add in a different PR)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We use a dict in the state like this one. No need to overengineer another solution IMO.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@Narsil Narsil merged commit 25c451e into huggingface:main Nov 14, 2022
@Narsil Narsil deleted the whisper_chunking branch November 14, 2022 22:57
@pearl-yu pearl-yu mentioned this pull request Mar 16, 2023
4 tasks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants