22import random
33import numpy as np
44import tensorflow as tf
5+ import torch
56from pymilvus import (
67 connections ,
78 utility ,
@@ -20,6 +21,11 @@ def gen_bf16_vectors(num, dim):
2021 for _ in range (num ):
2122 raw_vector = [random .random () for _ in range (dim )]
2223 raw_vectors .append (raw_vector )
24+ # Numpy itself does not support bfloat16, use TensorFlow extension instead.
25+ # PyTorch does not support converting bfloat16 vector to numpy array.
26+ # See:
27+ # - https://github.com/numpy/numpy/issues/19808
28+ # - https://github.com/pytorch/pytorch/issues/90574
2329 bf16_vector = tf .cast (raw_vector , dtype = tf .bfloat16 ).numpy ()
2430 bf16_vectors .append (bf16_vector )
2531 return raw_vectors , bf16_vectors
@@ -57,8 +63,10 @@ def bf16_vector_search():
5763 index_params = {"index_type" : index_type , "params" : index_params , "metric_type" : "L2" })
5864 hello_milvus .load ()
5965 print ("index_type = " , index_type )
60- res = hello_milvus .search (vectors [0 :10 ], vector_field_name , {"metric_type" : "L2" }, limit = 1 )
61- print (res )
66+ res = hello_milvus .search (vectors [0 :10 ], vector_field_name , {"metric_type" : "L2" }, limit = 1 , output_fields = ["bfloat16_vector" ])
67+ print ("raw bytes: " , res [0 ][0 ].get ("bfloat16_vector" ))
68+ print ("tensorflow Tensor: " , tf .io .decode_raw (res [0 ][0 ].get ("bfloat16_vector" ), tf .bfloat16 , little_endian = True ))
69+ print ("pytorch Tensor: " , torch .frombuffer (res [0 ][0 ].get ("bfloat16_vector" ), dtype = torch .bfloat16 ))
6270 hello_milvus .release ()
6371 hello_milvus .drop_index ()
6472
0 commit comments