Skip to content

Commit f42e614

Browse files
committed
Rep4, SPDZ-wise, MNIST training.
1 parent 53f9b02 commit f42e614

File tree

184 files changed

+5833
-816
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

184 files changed

+5833
-816
lines changed

BMR/RealProgramParty.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ RealProgramParty<T>::RealProgramParty(int argc, const char** argv) :
104104
}
105105
else
106106
{
107-
T::read_or_generate_mac_key(prep_dir, N, mac_key);
107+
T::read_or_generate_mac_key(prep_dir, *P, mac_key);
108108
prep = new Sub_Data_Files<T>(N, prep_dir, usage);
109109
}
110110

BMR/Register.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,6 @@
2424

2525
#include <unistd.h>
2626

27-
ostream& EvalRegister::out = cout;
28-
2927
int Register::counter = 0;
3028

3129
void Register::init(int n_parties)

BMR/Register.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ using namespace std;
2222
#include "Tools/FlexBuffer.h"
2323
#include "Tools/PointerVector.h"
2424
#include "Tools/Bundle.h"
25+
#include "Tools/SwitchableOutput.h"
2526

2627
//#define PAD_TO_8(n) (n+8-n%8)
2728
#define PAD_TO_8(n) (n)
@@ -199,6 +200,7 @@ class BlackHole
199200
BlackHole& operator<<(T) { return *this; }
200201
BlackHole& operator<<(BlackHole& (*__pf)(BlackHole&)) { (void)__pf; return *this; }
201202
void activate(bool) {}
203+
void redirect_to_file(ostream&) {}
202204
};
203205
inline BlackHole& endl(BlackHole& b) { return b; }
204206
inline BlackHole& flush(BlackHole& b) { return b; }
@@ -211,7 +213,6 @@ class Phase
211213
typedef NoMemory DynamicMemory;
212214

213215
typedef BlackHole out_type;
214-
static BlackHole out;
215216

216217
static const bool actual_inputs = true;
217218

@@ -353,8 +354,7 @@ class EvalRegister : public ProgramRegister
353354

354355
typedef EvalInputter Input;
355356

356-
typedef ostream& out_type;
357-
static ostream& out;
357+
typedef SwitchableOutput out_type;
358358

359359
static const bool actual_inputs = true;
360360

CHANGELOG.md

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,32 @@
11
The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here.
22

3+
## 0.2.0 (Oct 28, 2020)
4+
5+
- Rep4: honest-majority four-party computation with malicious security
6+
- SY/SPDZ-wise: honest-majority computation with malicious security based on replicated or Shamir secret sharing
7+
- Training with a sequence of dense layers
8+
- Training and inference for multi-class classification
9+
- Local share conversion for semi-honest protocols based on additive secret sharing modulo a power of two
10+
- edaBit generation based on local share conversion
11+
- Optimize exponentation with local share conversion
12+
- Optimize Shamir pseudo-random secret sharing using a hyper-invertible matrix
13+
- Mathematical functions (exponentation, logarithm, square root, and trigonometric functions) with binary circuits
14+
- Direct construction of fixed-point values from any type, breaking `sfix(x)` where `x` is the integer representation of a fixed-point number. Use `sfix._new(x)` instead.
15+
- Optimized dot product for `sfix`
16+
- Matrix multiplication via operator overloading uses VM-optimized multiplication.
17+
- Fake preprocessing for daBits and edaBits
18+
- Fixed security bug: insufficient randomness in SemiBin random bit generation.
19+
- Fixed security bug: insufficient randomization of FKOS15 inputs.
20+
- Fixed security bug in binary computation with SPDZ(2k).
21+
322
## 0.1.9 (Aug 24, 2020)
423

524
- Streamline inputs to binary circuits
625
- Improved private output
726
- Emulator for arithmetic circuits
827
- Efficient dot product with Shamir's secret sharing
928
- Lower memory usage for TensorFlow inference
10-
- This version breaks bytecode compatibilty.
29+
- This version breaks bytecode compatibility.
1130

1231
## 0.1.8 (June 15, 2020)
1332

CONFIG

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@ USE_GF2N_LONG = 1
2424
# AVX/AVX2 is required for replicated binary secret sharing
2525
# BMI2 is used to optimize multiplication modulo a prime
2626
# ADX is used to optimize big integer additions
27+
# delete the second line to compile for a platform that supports everything
2728
ARCH = -mtune=native -msse4.1 -msse4.2 -maes -mpclmul -mavx -mavx2 -mbmi2 -madx
29+
ARCH = -march=native
2830

2931
# allow to set compiler in CONFIG.mine
3032
CXX = g++
@@ -60,7 +62,7 @@ else
6062
BOOST = -lboost_thread $(MY_BOOST)
6163
endif
6264

63-
CFLAGS += $(ARCH) $(MY_CFLAGS) $(GDEBUG) -Wextra -Wall $(OPTIM) -I$(ROOT) -pthread $(PROF) $(DEBUG) $(MOD) $(MEMPROTECT) $(GF2N_LONG) $(PREP_DIR) -std=c++11 -Werror
65+
CFLAGS += $(ARCH) $(MY_CFLAGS) $(GDEBUG) -Wextra -Wall $(OPTIM) -I$(ROOT) -pthread $(PROF) $(DEBUG) $(MOD) $(MEMPROTECT) $(GF2N_LONG) $(PREP_DIR) $(SECURE) -std=c++11 -Werror
6466
CPPFLAGS = $(CFLAGS)
6567
LD = $(CXX)
6668

Compiler/GC/types.py

Lines changed: 36 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def compose(cls, items, bit_length=1):
4545
return cls.bit_compose(sum([util.bit_decompose(item, bit_length) for item in items], []))
4646
@classmethod
4747
def bit_compose(cls, bits):
48+
bits = list(bits)
4849
if len(bits) == 1:
4950
return bits[0]
5051
bits = list(bits)
@@ -72,7 +73,7 @@ def bit_decompose(self, bit_length=None):
7273
res = [self.bit_type() for i in range(n)]
7374
self.bitdec(self, *res)
7475
else:
75-
res = self.trans([self])
76+
res = self.bit_type.trans([self])
7677
self.decomposed = res
7778
return res + suffix
7879
else:
@@ -83,8 +84,8 @@ def bit_decompose_clear(a, n_bits):
8384
cbits.conv_cint_vec(a, *res)
8485
return res
8586
@classmethod
86-
def malloc(cls, size):
87-
return Program.prog.malloc(size, cls)
87+
def malloc(cls, size, creator_tape=None):
88+
return Program.prog.malloc(size, cls, creator_tape=creator_tape)
8889
@staticmethod
8990
def n_elements():
9091
return 1
@@ -430,6 +431,8 @@ def reveal(self):
430431
def equal(self, other, n=None):
431432
bits = (~(self + other)).bit_decompose()
432433
return reduce(operator.mul, bits)
434+
def right_shift(self, m, k, security=None, signed=True):
435+
return self.TruncPr(k, m)
433436
def TruncPr(self, k, m, kappa=None):
434437
if k > self.n:
435438
raise Exception('TruncPr overflow: %d > %d' % (k, self.n))
@@ -481,8 +484,8 @@ class sbitvec(_vec):
481484
def get_type(cls, n):
482485
class sbitvecn(cls, _structure):
483486
@staticmethod
484-
def malloc(size):
485-
return sbit.malloc(size * n)
487+
def malloc(size, creator_tape=None):
488+
return sbit.malloc(size * n, creator_tape=creator_tape)
486489
@staticmethod
487490
def n_elements():
488491
return n
@@ -566,7 +569,8 @@ def __init__(self, elements=None, length=None):
566569
x = sbitintvec.from_vec(r_bits) + sbitintvec.from_vec(cb)
567570
v = x.v
568571
self.v = v[:length]
569-
elif elements is not None:
572+
elif elements is not None and not (util.is_constant(elements) and \
573+
elements == 0):
570574
self.v = sbits.trans(elements)
571575
def popcnt(self):
572576
res = sbitint.wallace_tree([[b] for b in self.v])
@@ -606,7 +610,10 @@ def conv(cls, other):
606610
return cls.from_vec(other.v)
607611
@property
608612
def size(self):
609-
return self.v[0].n
613+
if not self.v or util.is_constant(self.v[0]):
614+
return 1
615+
else:
616+
return self.v[0].n
610617
@property
611618
def n_bits(self):
612619
return len(self.v)
@@ -725,6 +732,8 @@ def cast(self, n):
725732
return self.get_type(n).bit_compose(bits)
726733
def round(self, k, m, kappa=None, nearest=None, signed=None):
727734
bits = self.bit_decompose()
735+
if signed:
736+
bits += [bits[-1]] * (k - len(bits))
728737
res_bits = self.bit_adder(bits[m:k], [bits[m-1]])
729738
return self.get_type(k - m).compose(res_bits)
730739
def int_div(self, other, bit_length=None):
@@ -781,7 +790,7 @@ def set_length(*args):
781790
@classmethod
782791
def bit_compose(cls, bits):
783792
# truncate and extend bits
784-
bits = bits[:cls.n]
793+
bits = list(bits)[:cls.n]
785794
bits += [0] * (cls.n - len(bits))
786795
return super(sbitint, cls).bit_compose(bits)
787796
def force_bit_decompose(self, n_bits=None):
@@ -801,6 +810,7 @@ def TruncMul(self, other, k, m, kappa=None, nearest=False):
801810
b = t.bit_compose(other_bits + [other_bits[-1]] * (k - len(other_bits)))
802811
product = a * b
803812
res_bits = product.bit_decompose()[m:k]
813+
res_bits += [res_bits[-1]] * (self.n - len(res_bits))
804814
t = self.combo_type(other)
805815
return t.bit_compose(res_bits)
806816
def __mul__(self, other):
@@ -824,6 +834,15 @@ def get_bit_matrix(cls, self_bits, other):
824834
else:
825835
res.append([(x & bit) for x in other.bit_decompose(n - i)])
826836
return res
837+
@classmethod
838+
def popcnt_bits(cls, bits):
839+
res = sbitvec.from_vec(bits).popcnt().elements()[0]
840+
res = cls.conv(res)
841+
return res
842+
def pow2(self, k):
843+
l = int(math.ceil(math.log(k, 2)))
844+
bits = [self.equal(i, l) for i in range(k)]
845+
return self.bit_compose(bits)
827846

828847
class sbitintvec(sbitvec, _number, _bitint, _sbitintbase):
829848
def __add__(self, other):
@@ -867,8 +886,11 @@ class cbitfix(object):
867886
conv = staticmethod(lambda x: x)
868887
load_mem = classmethod(lambda cls, *args: cls(cbits.load_mem(*args)))
869888
store_in_mem = lambda self, *args: self.v.store_in_mem(*args)
870-
def __init__(self, value):
871-
self.v = value
889+
@classmethod
890+
def _new(cls, value):
891+
res = cls()
892+
res.v = value
893+
return res
872894
def output(self):
873895
v = self.v
874896
if self.k < v.unit:
@@ -897,10 +919,10 @@ def get_input_from(cls, player):
897919
inst.inputb(player, cls.k, cls.f, v)
898920
return cls._new(v)
899921
def __xor__(self, other):
900-
return type(self)(self.v ^ other.v)
922+
return type(self)._new(self.v ^ other.v)
901923
def __mul__(self, other):
902924
if isinstance(other, sbit):
903-
return type(self)(self.int_type(other * self.v))
925+
return type(self)._new(self.int_type(other * self.v))
904926
elif isinstance(other, sbitfixvec):
905927
return other * self
906928
else:
@@ -911,10 +933,11 @@ def __mul__(self, other):
911933
def multipliable(other, k, f, size):
912934
class cls(_fix):
913935
int_type = sbitint.get_type(k)
936+
clear_type = cbitfix
914937
cls.set_precision(f, k)
915938
return cls._new(cls.int_type(other), k, f)
916939

917-
sbitfix.set_precision(20, 41)
940+
sbitfix.set_precision(16, 31)
918941

919942
class sbitfixvec(_fix):
920943
int_type = sbitintvec

Compiler/allocator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,7 @@ def __init__(self, block, options, merge_classes):
220220
else:
221221
self.max_parallel_open = float('inf')
222222
self.counter = defaultdict(lambda: 0)
223+
self.rounds = defaultdict(lambda: 0)
223224
self.dependency_graph(merge_classes)
224225

225226
def do_merge(self, merges_iter):
@@ -271,6 +272,7 @@ def longest_paths_merge(self):
271272
merge = merges[i]
272273
t = type(self.instructions[merge[0]])
273274
self.counter[t] += len(merge)
275+
self.rounds[t] += 1
274276
if len(merge) > 10000:
275277
print('Merging %d %s in round %d/%d' % \
276278
(len(merge), t.__name__, i, len(merges)))

Compiler/comparison.py

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -135,27 +135,37 @@ def Trunc(d, a, k, m, kappa, signed):
135135
mulm(d, t, c[2])
136136

137137
def TruncRing(d, a, k, m, signed):
138-
if program.use_split() == 3:
138+
program.curr_tape.require_bit_length(1)
139+
if program.use_split() in (2, 3):
140+
if signed:
141+
a += (1 << (k - 1))
139142
from Compiler.types import sint
140143
from .GC.types import sbitint
141144
length = int(program.options.ring)
142-
summands = a.split_to_n_summands(length, 3)
145+
summands = a.split_to_n_summands(length, program.use_split())
143146
x = sbitint.wallace_tree_without_finish(summands, True)
144-
if m == 1:
145-
low = x[1][1]
146-
high = sint.conv(CarryOutLE(x[1][:-1], x[0][:-1])) + \
147-
sint.conv(x[0][-1])
147+
if program.use_split() == 2:
148+
carries = sbitint.get_carries(*x)
149+
low = carries[m]
150+
high = sint.conv(carries[length])
148151
else:
149-
mid_carry = CarryOutRawLE(x[1][:m], x[0][:m])
150-
low = sint.conv(mid_carry) + sint.conv(x[0][m])
151-
tmp = util.tree_reduce(carry, (sbitint.half_adder(xx, yy)
152-
for xx, yy in zip(x[1][m:-1],
153-
x[0][m:-1])))
154-
top_carry = sint.conv(carry([None, mid_carry], tmp, False)[1])
155-
high = top_carry + sint.conv(x[0][-1])
152+
if m == 1:
153+
low = x[1][1]
154+
high = sint.conv(CarryOutLE(x[1][:-1], x[0][:-1])) + \
155+
sint.conv(x[0][-1])
156+
else:
157+
mid_carry = CarryOutRawLE(x[1][:m], x[0][:m])
158+
low = sint.conv(mid_carry) + sint.conv(x[0][m])
159+
tmp = util.tree_reduce(carry, (sbitint.half_adder(xx, yy)
160+
for xx, yy in zip(x[1][m:-1],
161+
x[0][m:-1])))
162+
top_carry = sint.conv(carry([None, mid_carry], tmp, False)[1])
163+
high = top_carry + sint.conv(x[0][-1])
156164
shifted = sint()
157165
shrsi(shifted, a, m)
158166
res = shifted + sint.conv(low) - (high << (length - m))
167+
if signed:
168+
res -= (1 << (k - m - 1))
159169
else:
160170
a_prime = Mod2mRing(None, a, k, m, signed)
161171
a -= a_prime

0 commit comments

Comments
 (0)