4
4
5
5
import pytest
6
6
7
- from vllm .model_executor .layers .pooler import CLSPool , PoolingType
7
+ from vllm .model_executor .layers .pooler import CLSPool , MeanPool , PoolingType
8
8
from vllm .model_executor .models .bert import BertEmbeddingModel
9
9
from vllm .model_executor .models .roberta import RobertaEmbeddingModel
10
10
from vllm .platforms import current_platform
14
14
REVISION = os .environ .get ("REVISION" , "main" )
15
15
16
16
MODEL_NAME_ROBERTA = os .environ .get ("MODEL_NAME" ,
17
- "intfloat/multilingual-e5-small " )
17
+ "intfloat/multilingual-e5-base " )
18
18
REVISION_ROBERTA = os .environ .get ("REVISION" , "main" )
19
19
20
20
@@ -40,17 +40,15 @@ def test_model_loading_with_params(vllm_runner):
40
40
41
41
# asserts on the pooling config files
42
42
assert model_config .pooler_config .pooling_type == PoolingType .CLS .name
43
- assert model_config .pooler_config .pooling_norm
43
+ assert model_config .pooler_config .normalize
44
44
45
45
# asserts on the tokenizer loaded
46
46
assert model_tokenizer .tokenizer_id == "BAAI/bge-base-en-v1.5"
47
- assert model_tokenizer .tokenizer_config ["do_lower_case" ]
48
47
assert model_tokenizer .tokenizer .model_max_length == 512
49
48
50
49
def check_model (model ):
51
50
assert isinstance (model , BertEmbeddingModel )
52
- assert model ._pooler .pooling_type == PoolingType .CLS
53
- assert model ._pooler .normalize
51
+ assert isinstance (model ._pooler , CLSPool )
54
52
55
53
vllm_model .apply_model (check_model )
56
54
@@ -80,16 +78,15 @@ def test_roberta_model_loading_with_params(vllm_runner):
80
78
81
79
# asserts on the pooling config files
82
80
assert model_config .pooler_config .pooling_type == PoolingType .MEAN .name
83
- assert model_config .pooler_config .pooling_norm
81
+ assert model_config .pooler_config .normalize
84
82
85
83
# asserts on the tokenizer loaded
86
- assert model_tokenizer .tokenizer_id == "intfloat/multilingual-e5-small "
87
- assert not model_tokenizer .tokenizer_config [ "do_lower_case" ]
84
+ assert model_tokenizer .tokenizer_id == "intfloat/multilingual-e5-base "
85
+ assert model_tokenizer .tokenizer . model_max_length == 512
88
86
89
87
def check_model (model ):
90
88
assert isinstance (model , RobertaEmbeddingModel )
91
- assert model ._pooler .pooling_type == PoolingType .MEAN
92
- assert model ._pooler .normalize
89
+ assert isinstance (model ._pooler , MeanPool )
93
90
94
91
vllm_model .apply_model (check_model )
95
92
0 commit comments