Multi-GPU inference returns AttributeError #177

@arturnn

🐛 Bug

Scoring with COMET models via comet-score on multiple GPUs fails with the error below. The same happens with comet-mbr when pre-scoring with QE models on multiple GPUs. Scoring works fine with --gpus 0 (running on CPU) or --gpus 1. I tested with the wmt22-comet-da and wmt21-comet-mqm models.

Traceback (most recent call last):
  File "/opt/conda/envs/test-3.11/bin/comet-score", line 8, in <module>
    sys.exit(score_command())
             ^^^^^^^^^^^^^^^
  File "/opt/conda/envs/test-3.11/lib/python3.11/site-packages/comet/cli/score.py", line 192, in score_command
    outputs = model.predict(
              ^^^^^^^^^^^^^^
  File "/opt/conda/envs/test-3.11/lib/python3.11/site-packages/comet/models/base.py", line 643, in predict
    predictions = pred_writer.gather_all_predictions()
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/test-3.11/lib/python3.11/site-packages/comet/models/predict_writer.py", line 99, in gather_all_predictions
    [
  File "/opt/conda/envs/test-3.11/lib/python3.11/site-packages/comet/models/predict_writer.py", line 100, in <listcomp>
    flatten_predictions(torch.load(os.path.join(self.output_dir, f))[0])
  File "/opt/conda/envs/test-3.11/lib/python3.11/site-packages/comet/models/predict_writer.py", line 89, in flatten_predictions
    scores=torch.cat([pred.scores for pred in predictions], dim=0)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/test-3.11/lib/python3.11/site-packages/comet/models/predict_writer.py", line 89, in <listcomp>
    scores=torch.cat([pred.scores for pred in predictions], dim=0)
                      ^^^^^^^^^^^
AttributeError: 'str' object has no attribute 'scores'
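
The last frame shows that iterating over `predictions` yielded `str` items rather than prediction objects. One way this can happen, offered as a guess and not a confirmed diagnosis, is that the loaded object is a single dict-like prediction instead of a list of them, so iteration walks its keys. The sketch below reproduces the same message in plain Python. `Prediction` and `flatten_scores` are hypothetical stand-ins written for this illustration, not COMET's actual classes or functions, and plain lists stand in for tensors.

```python
class Prediction(dict):
    """Hypothetical stand-in: a dict whose keys are also readable as attributes."""

    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError as exc:
            raise AttributeError(name) from exc


def flatten_scores(predictions):
    """Concatenate per-batch score lists (lists stand in for tensors here)."""
    return [score for pred in predictions for score in pred.scores]


batches = [Prediction(scores=[0.81, 0.42]), Prediction(scores=[0.93])]

# A list of predictions flattens as expected.
flat = flatten_scores(batches)

# A single prediction passed where a list is expected: iterating a dict
# yields its string keys, so `pred.scores` is looked up on a str.
try:
    flatten_scores(batches[0])
    error_message = None
except AttributeError as exc:
    error_message = str(exc)

print(flat)
print(error_message)
```

If this guess is right, the shape of what each worker writes to disk would differ from what the gathering step expects, which would also fit the fact that single-GPU runs are unaffected.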

To Reproduce

Run the comet-score command with --gpus set to a value greater than 1, using the latest version of the package.

Expected behaviour

Segment and system scores should be returned, as they are with single-GPU inference.

Environment

OS: Debian 11 (bullseye)
Packaging: tried both conda (Python 3.11.6) and a standard virtual environment with installation via pip (Python 3.9.2)
Version: latest master & current PyPI (2.2.0)
