Skip to content

Commit a1a8e06

Browse files
committed
More places to support quantize bf16.
1 parent a71a306 commit a1a8e06

File tree

2 files changed

+9
-1
lines changed

2 files changed

+9
-1
lines changed

lib/ccv_numeric.c

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1454,6 +1454,14 @@ void ccv_kmeans1d(const ccv_dense_matrix_t* const a, const int k, int* const clu
14541454
sorted_undos[i].value = f[i];
14551455
sorted_undos[i].index = i;
14561456
}
1457+
} else if (CCV_GET_DATA_TYPE(a->type) == CCV_16BF) {
1458+
float* f = (float*)sorted_undos;
1459+
ccv_bfloat_to_float((uint16_t*)a->data.f16, (float*)f, n);
1460+
for (i = n - 1; i >= 0; i--)
1461+
{
1462+
sorted_undos[i].value = f[i];
1463+
sorted_undos[i].index = i;
1464+
}
14571465
} else if (CCV_GET_DATA_TYPE(a->type) == CCV_32F) {
14581466
for (i = 0; i < n; i++)
14591467
{

lib/nnc/mfa/ccv_nnc_mfa_depalettize.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ std::size_t std::hash<mfa::depalettize::hash>::operator()(const mfa::depalettize
7979

8080
mfa::depalettize::pipeline::pipeline(mfa::context* context, mfa::depalettize::hash hash) {
8181
// FlashNorm not supported for group depalettize yet.
82-
CCV_NNC_MFA_PRECONDITION((hash.data_type == MTL::DataTypeFloat) || (hash.data_type == MTL::DataTypeHalf))
82+
CCV_NNC_MFA_PRECONDITION((hash.data_type == MTL::DataTypeFloat) || (hash.data_type == MTL::DataTypeHalf) || (hash.data_type == MTL::DataTypeBFloat))
8383

8484
auto* pool = NS::AutoreleasePool::alloc()->init();
8585

0 commit comments

Comments
 (0)