Skip to content

Commit 1b371b6

Browse files
authored
BFVVector: sum_vector support (microsoft#283)
* add sum_vector for BFVVector * add python tests
1 parent f57e286 commit 1b371b6

File tree

6 files changed

+294
-6
lines changed

6 files changed

+294
-6
lines changed

tenseal/binding.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,12 @@ void bind_bfv_vector(py::module &m) {
220220
})
221221
.def("mul_plain_",
222222
py::overload_cast<const int64_t &>(&BFVVector::mul_plain_inplace))
223+
.def("dot", &BFVVector::dot)
224+
.def("dot", &BFVVector::dot_plain)
225+
.def("dot_", &BFVVector::dot_inplace)
226+
.def("dot_", &BFVVector::dot_plain_inplace)
227+
.def("sum", &BFVVector::sum, py::arg("axis") = 0)
228+
.def("sum_", &BFVVector::sum_inplace, py::arg("axis") = 0)
223229
// python arithmetic
224230
.def("__add__", &BFVVector::add)
225231
.def("__add__", py::overload_cast<const int64_t &>(

tenseal/cpp/tensors/utils/utils.cpp

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,23 +21,43 @@ inline size_t below_power2(size_t n) {
2121
return 1 << count;
2222
}
2323

24-
Ciphertext& sum_vector(shared_ptr<TenSEALContext> tenseal_context,
25-
Ciphertext& vector, size_t size) {
24+
Ciphertext &sum_vector(shared_ptr<TenSEALContext> tenseal_context,
25+
Ciphertext &vector, size_t size) {
2626
// Nothing to do
2727
if (size == 1) return vector;
2828

29+
auto rotate = [&](const Ciphertext &encrypted, int steps,
30+
const GaloisKeys &galois_keys, Ciphertext &destination) {
31+
switch (tenseal_context->seal_context()
32+
->key_context_data()
33+
->parms()
34+
.scheme()) {
35+
case scheme_type::ckks: {
36+
tenseal_context->evaluator->rotate_vector(
37+
encrypted, steps, galois_keys, destination);
38+
break;
39+
}
40+
case scheme_type::bfv: {
41+
tenseal_context->evaluator->rotate_rows(
42+
encrypted, steps, galois_keys, destination);
43+
break;
44+
}
45+
default:
46+
throw invalid_argument("unsupported scheme for sum_vector");
47+
}
48+
};
2949
auto galois_keys = tenseal_context->galois_keys();
50+
3051
Ciphertext rest, tmp;
3152
size_t bp2 = below_power2(size);
3253

3354
if (bp2 != size) {
34-
tenseal_context->evaluator->rotate_vector(vector, bp2, *galois_keys,
35-
rest);
55+
rotate(vector, bp2, *galois_keys, rest);
3656
sum_vector(tenseal_context, rest, size - bp2);
3757
}
3858

3959
for (size_t i = bp2 / 2; i > 0; i /= 2) {
40-
tenseal_context->evaluator->rotate_vector(vector, i, *galois_keys, tmp);
60+
rotate(vector, i, *galois_keys, tmp);
4161
tenseal_context->evaluator->add_inplace(vector, tmp);
4262
tmp = vector;
4363
}

tenseal/cpp/tensors/utils/utils.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ void replicate_vector(vector<T>& vec, size_t final_size) {
2929

3030
/*
3131
Sum the values in the vector.
32-
IMPORTANT: Tested only with CKKS.
3332
*/
3433
Ciphertext& sum_vector(shared_ptr<TenSEALContext> tenseal_context,
3534
Ciphertext& vector, size_t size);

tenseal/tensors/bfvvector.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,3 +101,25 @@ def sub_(self, other) -> "BFVVector":
101101
other = self._get_operand(other, dtype="int")
102102
self.data -= other
103103
return self
104+
105+
@classmethod
106+
def _dot(cls, other):
107+
if isinstance(other, (cls)):
108+
return other.data
109+
if not isinstance(other, ts.PlainTensor):
110+
try:
111+
other = ts.plain_tensor(other, dtype="int")
112+
except TypeError:
113+
raise TypeError(f"can't operate with object of type {type(other)}")
114+
if len(other.shape) != 1:
115+
raise ValueError("can only operate with a vector")
116+
return other.data
117+
118+
def dot(self, other) -> "BFVVector":
119+
other = self._dot(other)
120+
return self._wrap(self.data.dot(other))
121+
122+
def dot_(self, other) -> "BFVVector":
123+
other = self._dot(other)
124+
self.data.dot_(other)
125+
return self

tests/cpp/tensors/bfvvector_test.cpp

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,67 @@ TEST_P(BFVVectorTest, TestBFVAddBigVector) {
328328
EXPECT_THAT(decr.data(), ElementsAreArray(expected));
329329
}
330330

331+
TEST_P(BFVVectorTest, TestSum) {
332+
auto enc_type = get<1>(GetParam());
333+
334+
int poly_mod = 8192;
335+
int input_size = 100000;
336+
337+
auto ctx = TenSEALContext::Create(scheme_type::bfv, poly_mod, 1032193, {},
338+
enc_type);
339+
ASSERT_TRUE(ctx != nullptr);
340+
ctx->generate_galois_keys();
341+
342+
auto ldata = PlainTensor(vector<int64_t>({1, 2, 3, 4, 5, 6, 7, 8, 9}));
343+
344+
auto l = BFVVector::Create(ctx, ldata);
345+
346+
l->sum_inplace();
347+
auto decr = l->decrypt();
348+
EXPECT_THAT(decr.data(), ElementsAreArray({45}));
349+
}
350+
351+
TEST_P(BFVVectorTest, TestDot) {
352+
auto enc_type = get<1>(GetParam());
353+
354+
int poly_mod = 8192;
355+
int input_size = 100000;
356+
357+
auto ctx = TenSEALContext::Create(scheme_type::bfv, poly_mod, 1032193, {},
358+
enc_type);
359+
ASSERT_TRUE(ctx != nullptr);
360+
ctx->generate_galois_keys();
361+
362+
vector<int64_t> lraw({1, 2, 3, 4, 5, 6, 7, 8, 9});
363+
vector<int64_t> rraw({11, 22, 33, 11, 22, 33, 11, 22, 33});
364+
365+
int64_t expected = 0;
366+
for (auto idx = 0; idx < lraw.size(); ++idx)
367+
expected += lraw[idx] * rraw[idx];
368+
369+
auto ldata = PlainTensor(lraw);
370+
auto rdata = PlainTensor(rraw);
371+
372+
auto l = BFVVector::Create(ctx, ldata);
373+
auto r = BFVVector::Create(ctx, rdata);
374+
375+
auto res = l->dot(r);
376+
auto decr = res->decrypt();
377+
EXPECT_THAT(decr.data(), ElementsAreArray({expected}));
378+
379+
res = l->dot_plain(rdata);
380+
decr = res->decrypt();
381+
EXPECT_THAT(decr.data(), ElementsAreArray({expected}));
382+
383+
res = r->dot(l);
384+
decr = res->decrypt();
385+
EXPECT_THAT(decr.data(), ElementsAreArray({expected}));
386+
387+
res = r->dot_plain(ldata);
388+
decr = res->decrypt();
389+
EXPECT_THAT(decr.data(), ElementsAreArray({expected}));
390+
}
391+
331392
INSTANTIATE_TEST_CASE_P(
332393
TestBFVVector, BFVVectorTest,
333394
::testing::Values(make_tuple(false, encryption_type::asymmetric),

tests/python/tenseal/tensors/test_bfv_vector.py

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -416,3 +416,183 @@ def test_size(context):
416416
for size in range(1, 10):
417417
vec = ts.bfv_vector(context, [1] * size)
418418
assert vec.size() == size, "Size of encrypted vector is incorrect."
419+
420+
421+
@pytest.mark.parametrize(
422+
"vec1, vec2",
423+
[
424+
([0], [0]),
425+
([1], [0]),
426+
([-1], [0]),
427+
([-1], [-1]),
428+
([1], [1]),
429+
([-1], [1]),
430+
([-1, -2], [-73, -10]),
431+
([1, 2], [-73, -10]),
432+
([1, 2, 3], [4, 3, 2]),
433+
([1, 2, 3, 4], [4, 3, 2, 1]),
434+
([1, 2, 3, 4, 5], [5, 4, 3, 2, 1]),
435+
([1, 2, 3, 4, 5, 6], [6, 5, 4, 3, 2, 1]),
436+
([1, 2, 3, 4, 5, 6, 7], [7, 6, 5, 4, 3, 2, 1]),
437+
([1, 2, 3, 4, 5, 6, 7, 8], [8, 7, 6, 5, 4, 3, 2, 1]),
438+
],
439+
)
440+
def test_dot_product(context, vec1, vec2):
441+
context.generate_galois_keys()
442+
first_vec = ts.bfv_vector(context, vec1)
443+
second_vec = ts.bfv_vector(context, vec2)
444+
445+
result = first_vec.dot(second_vec)
446+
447+
expected = [sum([v1 * v2 for v1, v2 in zip(vec1, vec2)])]
448+
449+
# Decryption
450+
assert result.decrypt() == expected, "Multiplication of vectors is incorrect."
451+
assert first_vec.decrypt() == vec1, "Something went wrong in memory."
452+
assert second_vec.decrypt() == vec2, "Something went wrong in memory."
453+
454+
455+
@pytest.mark.parametrize(
456+
"vec1, vec2",
457+
[
458+
([0], [0]),
459+
([1], [0]),
460+
([-1], [0]),
461+
([-1], [-1]),
462+
([1], [1]),
463+
([-1], [1]),
464+
([-1, -2], [-73, -10]),
465+
([1, 2], [-73, -10]),
466+
([1, 2, 3], [4, 3, 2]),
467+
([1, 2, 3, 4], [4, 3, 2, 1]),
468+
([1, 2, 3, 4, 5], [5, 4, 3, 2, 1]),
469+
([1, 2, 3, 4, 5, 6], [6, 5, 4, 3, 2, 1]),
470+
([1, 2, 3, 4, 5, 6, 7], [7, 6, 5, 4, 3, 2, 1]),
471+
([1, 2, 3, 4, 5, 6, 7, 8], [8, 7, 6, 5, 4, 3, 2, 1]),
472+
],
473+
)
474+
def test_dot_product_inplace(context, vec1, vec2):
475+
context.generate_galois_keys()
476+
first_vec = ts.bfv_vector(context, vec1)
477+
second_vec = ts.bfv_vector(context, vec2)
478+
first_vec.dot_(second_vec)
479+
expected = [sum([v1 * v2 for v1, v2 in zip(vec1, vec2)])]
480+
481+
# Decryption
482+
assert first_vec.decrypt() == expected, "Dot product of vectors is incorrect."
483+
assert second_vec.decrypt() == vec2, "Something went wrong in memory."
484+
485+
486+
@pytest.mark.parametrize(
487+
"vec1, vec2",
488+
[
489+
([0], [0]),
490+
([1], [0]),
491+
([-1], [0]),
492+
([-1], [-1]),
493+
([1], [1]),
494+
([-1], [1]),
495+
([-1, -2], [-73, -10]),
496+
([1, 2], [-73, -10]),
497+
([1, 2, 3], [4, 3, 2]),
498+
([1, 2, 3, 4], [4, 3, 2, 1]),
499+
([1, 2, 3, 4, 5], [5, 4, 3, 2, 1]),
500+
([1, 2, 3, 4, 5, 6], [6, 5, 4, 3, 2, 1]),
501+
([1, 2, 3, 4, 5, 6, 7], [7, 6, 5, 4, 3, 2, 1]),
502+
([1, 2, 3, 4, 5, 6, 7, 8], [8, 7, 6, 5, 4, 3, 2, 1]),
503+
],
504+
)
505+
def test_dot_product_plain(context, vec1, vec2):
506+
context.generate_galois_keys()
507+
first_vec = ts.bfv_vector(context, vec1)
508+
second_vec = ts.plain_tensor(vec2, dtype="int")
509+
result = first_vec.dot(second_vec)
510+
expected = [sum([v1 * v2 for v1, v2 in zip(vec1, vec2)])]
511+
512+
# Decryption
513+
assert result.decrypt() == expected, "Dot product of vectors is incorrect."
514+
assert first_vec.decrypt() == vec1, "Something went wrong in memory."
515+
516+
517+
@pytest.mark.parametrize(
518+
"vec1, vec2",
519+
[
520+
([0], [0]),
521+
([1], [0]),
522+
([-1], [0]),
523+
([-1], [-1]),
524+
([1], [1]),
525+
([-1], [1]),
526+
([-1, -2], [-73, -10]),
527+
([1, 2], [-73, -10]),
528+
([1, 2, 3], [4, 3, 2]),
529+
([1, 2, 3, 4], [4, 3, 2, 1]),
530+
([1, 2, 3, 4, 5], [5, 4, 3, 2, 1]),
531+
([1, 2, 3, 4, 5, 6], [6, 5, 4, 3, 2, 1]),
532+
([1, 2, 3, 4, 5, 6, 7], [7, 6, 5, 4, 3, 2, 1]),
533+
([1, 2, 3, 4, 5, 6, 7, 8], [8, 7, 6, 5, 4, 3, 2, 1]),
534+
],
535+
)
536+
def test_dot_product_plain_inplace(context, vec1, vec2):
537+
context.generate_galois_keys()
538+
first_vec = ts.bfv_vector(context, vec1)
539+
second_vec = ts.plain_tensor(vec2, dtype="int")
540+
first_vec.dot_(second_vec)
541+
expected = [sum([v1 * v2 for v1, v2 in zip(vec1, vec2)])]
542+
543+
# Decryption
544+
assert first_vec.decrypt() == expected, "Dot product of vectors is incorrect."
545+
546+
547+
@pytest.mark.parametrize(
548+
"vec1",
549+
[
550+
([0]),
551+
([1]),
552+
([-1]),
553+
([-1, -2]),
554+
([1, 2]),
555+
([1, 2, 3]),
556+
([1, 2, 3, 4]),
557+
([1, 2, 3, 4, 5]),
558+
([1, 2, 3, 4, 5, 6]),
559+
([1, 2, 3, 4, 5, 6, 7]),
560+
([1, 2, 3, 4, 5, 6, 7, 8]),
561+
],
562+
)
563+
def test_sum(context, vec1):
564+
context.generate_galois_keys()
565+
first_vec = ts.bfv_vector(context, vec1)
566+
result = first_vec.sum()
567+
expected = [sum(vec1)]
568+
569+
# Decryption
570+
assert result.decrypt() == expected, "Sum of vector is incorrect."
571+
assert first_vec.decrypt() == vec1, "Something went wrong in memory."
572+
573+
574+
@pytest.mark.parametrize(
575+
"vec1",
576+
[
577+
([0]),
578+
([1]),
579+
([-1]),
580+
([-1, -2]),
581+
([1, 2]),
582+
([1, 2, 3]),
583+
([1, 2, 3, 4]),
584+
([1, 2, 3, 4, 5]),
585+
([1, 2, 3, 4, 5, 6]),
586+
([1, 2, 3, 4, 5, 6, 7]),
587+
([1, 2, 3, 4, 5, 6, 7, 8]),
588+
],
589+
)
590+
def test_sum_inplace(context, vec1):
591+
context.generate_galois_keys()
592+
first_vec = ts.bfv_vector(context, vec1)
593+
result = first_vec.sum()
594+
expected = [sum(vec1)]
595+
596+
# Decryption
597+
decrypted_result = result.decrypt()
598+
assert decrypted_result == expected, "Sum of vector is incorrect."

0 commit comments

Comments
 (0)