Skip to content
This repository was archived by the owner on Jun 17, 2023. It is now read-only.

Commit 0a4cbba

Browse files
mdavis332 and wesyoung authored
adds support for provider lists and negations (#495)
Adds support for negation (via a prepended exclamation mark) on both providers and tags, e.g. 'tags: !botnet' to exclude botnet tags, or 'provider: !csirtg.io, !otherprovid.xyz' to exclude the listed providers. Also attempts to tune ES searches by favoring filters over queries. Co-authored-by: wes <[email protected]>
1 parent 6a17e55 commit 0a4cbba

File tree

2 files changed

+200
-30
lines changed

2 files changed

+200
-30
lines changed

cif/store/zelasticsearch/filters.py

Lines changed: 102 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -65,17 +65,36 @@ def filter_reporttime(s, filter):
6565
if not filter.get('reporttime'):
6666
return s
6767

68-
c = filter.pop('reporttime')
69-
if PYVERSION == 2:
70-
if type(c) == unicode:
71-
c = str(c)
68+
high = 'now/m'
69+
# if passed 'days' or 'hours', preferentially use that for ES filtering/caching
70+
if filter.get('days') or filter.get('hours'):
71+
if filter.get('hours'):
72+
lookback_amount = filter.pop('hours')
73+
lookback_unit = 'h'
74+
elif filter.get('days'):
75+
lookback_amount = filter.pop('days')
76+
lookback_unit = 'd'
77+
78+
try:
79+
lookback_amount = int(lookback_amount)
80+
except Exception as e:
81+
raise InvalidSearch('Lookback time filter {}{} is not a valid time'.format(lookback_amount, lookback_unit))
82+
83+
# don't put spaces in relative date math operator query to make it easier to read. ES hates that and will error.
84+
low = 'now/m-{}{}'.format(lookback_amount, lookback_unit)
85+
# no relative 'days' or 'hours' params, so fallback to 'reporttime'
86+
else:
87+
c = filter.pop('reporttime')
88+
if PYVERSION == 2:
89+
if type(c) == unicode:
90+
c = str(c)
7291

73-
low, high = c, arrow.utcnow()
74-
if isinstance(c, basestring) and ',' in c:
75-
low, high = c.split(',')
92+
if isinstance(c, basestring) and ',' in c:
93+
low, high = c.split(',')
94+
else:
95+
low = c
7696

77-
low = arrow.get(low).datetime
78-
high = arrow.get(high).datetime
97+
low = arrow.get(low).datetime
7998

8099
s = s.filter('range', reporttime={'gte': low, 'lte': high})
81100
return s
@@ -114,7 +133,7 @@ def filter_indicator(s, q_filters):
114133

115134
def filter_terms(s, q_filters):
116135
for f in q_filters:
117-
if f in ['nolog', 'days', 'hours', 'groups', 'limit', 'tags']:
136+
if f in ['nolog', 'days', 'hours', 'groups', 'limit', 'provider', 'reporttime', 'tags']:
118137
continue
119138

120139
kwargs = {f: q_filters[f]}
@@ -127,16 +146,70 @@ def filter_terms(s, q_filters):
127146

128147

129148
def filter_tags(s, q_filters):
    """Narrow a search by the 'tags' entry of ``q_filters``.

    The 'tags' value may be a comma-separated string or an iterable of tag
    strings. A tag prefixed with ``!`` is negated: it is passed to
    ``s.exclude`` instead of ``s.filter``. Surrounding whitespace is
    stripped and empty entries (e.g. from a trailing comma or a bare ``!``)
    are ignored so they never become empty-string terms.

    The 'tags' key is popped from ``q_filters`` when present.

    :param s: search object exposing ``filter`` and ``exclude``
    :param q_filters: dict of query filters; mutated (the 'tags' key is removed)
    :return: the narrowed search, or ``s`` unchanged when no tags were given
    """
    if not q_filters.get('tags'):
        return s

    tags = q_filters.pop('tags')

    if isinstance(tags, basestring):
        tags = tags.split(',')

    wanted = []
    unwanted = []
    for tag in tags:
        tag = tag.strip()
        if tag.startswith('!'):
            # a leading '!' marks the tag for exclusion/negation
            tag = tag[1:].strip()
            bucket = unwanted
        else:
            bucket = wanted

        # skip blanks left behind by stray commas or a lone '!'
        if tag:
            bucket.append(tag)

    # a single value keeps using a 'term' filter; several values go into one
    # 'terms' filter, whose elements are implicitly ORed
    if len(unwanted) == 1:
        s = s.exclude('term', tags=unwanted[0])
    elif unwanted:
        s = s.exclude('terms', tags=unwanted)

    if len(wanted) == 1:
        s = s.filter('term', tags=wanted[0])
    elif wanted:
        s = s.filter('terms', tags=wanted)

    return s
182+
183+
def filter_provider(s, q_filters):
    """Narrow a search by the 'provider' entry of ``q_filters``.

    The 'provider' value may be a comma-separated string or an iterable of
    provider names. A name prefixed with ``!`` is negated: it is passed to
    ``s.exclude`` instead of ``s.filter``. Surrounding whitespace is
    stripped and empty entries (e.g. from a trailing comma or a bare ``!``)
    are ignored so they never become empty-string terms.

    The 'provider' key is popped from ``q_filters`` when present.

    :param s: search object exposing ``filter`` and ``exclude``
    :param q_filters: dict of query filters; mutated (the 'provider' key is removed)
    :return: the narrowed search, or ``s`` unchanged when no provider was given
    """
    if not q_filters.get('provider'):
        return s

    provider = q_filters.pop('provider')

    if isinstance(provider, basestring):
        provider = provider.split(',')

    wanted = []
    unwanted = []
    for name in provider:
        name = name.strip()
        if name.startswith('!'):
            # a leading '!' marks the provider for exclusion/negation
            name = name[1:].strip()
            bucket = unwanted
        else:
            bucket = wanted

        # skip blanks left behind by stray commas or a lone '!'
        if name:
            bucket.append(name)

    # a single value keeps using a 'term' filter; several values go into one
    # 'terms' filter, whose elements are implicitly ORed
    if len(unwanted) == 1:
        s = s.exclude('term', provider=unwanted[0])
    elif unwanted:
        s = s.exclude('terms', provider=unwanted)

    if len(wanted) == 1:
        s = s.filter('term', provider=wanted[0])
    elif wanted:
        s = s.filter('terms', provider=wanted)

    return s
142215

@@ -145,16 +218,13 @@ def filter_groups(s, q_filters, token=None):
145218
if token:
146219
groups = token.get('groups', 'everyone')
147220
else:
148-
groups = q_filters['groups']
221+
groups = q_filters.pop('groups')
149222

150223
if isinstance(groups, basestring):
151224
groups = [groups]
152225

153-
gg = []
154-
for g in groups:
155-
gg.append(Q('term', group=g))
156-
157-
s.query = Q('bool', must=s.query, should=gg, minimum_should_match=1)
226+
# each array element is implicitly ORed (aka, 'should') using a terms filter
227+
s = s.filter('terms', group=groups)
158228

159229
return s
160230

@@ -171,32 +241,34 @@ def filter_id(s, q_filters):
171241
def filter_build(s, filters, token=None):
    """Build the full filtered search from the user-supplied ``filters``.

    Only keys listed in ``VALID_FILTERS`` with truthy values are considered.
    The individual ``filter_*`` helpers are applied in a fixed order; the
    same working dict is handed to each of them in turn.

    :param s: search object to narrow
    :param filters: dict of raw search parameters
    :param token: optional token dict; its groups are used when the request
        itself names no groups
    :return: the narrowed search
    :raises InvalidSearch: when 'limit' exceeds ``WINDOW_LIMIT``
    """
    limit = filters.get('limit')
    if limit and int(limit) > WINDOW_LIMIT:
        raise InvalidSearch(
            'Request limit should be <= server threshold of {} but was set to {}'.format(WINDOW_LIMIT, limit))

    q_filters = {f: filters[f] for f in VALID_FILTERS if filters.get(f)}

    s = filter_provider(s, q_filters)
    s = filter_confidence(s, q_filters)
    s = filter_id(s, q_filters)

    # indicator is special: it is transformed into a Search of its own
    s = filter_indicator(s, q_filters)
    s = filter_reporttime(s, q_filters)

    # every remaining filter becomes a plain term= filter
    s = filter_terms(s, q_filters)

    if q_filters.get('groups'):
        s = filter_groups(s, q_filters)
    elif token:
        s = filter_groups(s, {}, token=token)

    if q_filters.get('tags'):
        s = filter_tags(s, q_filters)

    return s

test/zelasticsearch/test_store_elasticsearch_indicators.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,18 @@ def indicator():
5858
)
5959

6060

61+
@pytest.fixture
def indicator_alt_provider():
    """Indicator fixture for 'example2.com' tagged 'phishing' from provider 'notcsirtg.io'."""
    fields = dict(
        indicator='example2.com',
        tags='phishing',
        provider='notcsirtg.io',
        group='everyone',
        lasttime=arrow.utcnow().datetime,
        reporttime=arrow.utcnow().datetime,
    )
    return Indicator(**fields)
71+
72+
6173
@pytest.fixture
6274
def indicator_email():
6375
return Indicator(
@@ -232,3 +244,89 @@ def test_store_elasticsearch_indicators_malware(store, token, indicator_malware)
232244
assert (x[0]['reporttime'])
233245

234246
assert x[0]['indicator'] == indicator_malware.indicator
247+
248+
249+
## test returning a list of providers
250+
@pytest.mark.skipif(DISABLE_TESTS, reason='need to set CIF_ELASTICSEARCH_TEST=1 to run')
def test_store_elasticsearch_indicators_provider_list(store, token, indicator, indicator_alt_provider):
    """A comma-separated provider list matches indicators from each listed provider."""
    created = store.handle_indicators_create(token, indicator.__dict__(), flush=True)
    assert created == 1

    created_alt = store.handle_indicators_create(token, indicator_alt_provider.__dict__(), flush=True)
    assert created_alt == 1

    # the separator deliberately contains a space after the comma
    providers = '{}, {}'.format(indicator.provider, indicator_alt_provider.provider)
    raw = store.handle_indicators_search(token, {'provider': providers})

    response = json.loads(raw)
    pprint(response)

    sources = [hit['_source'] for hit in response['hits']['hits']]

    assert len(sources) == 2
268+
269+
270+
## test provider negation
271+
@pytest.mark.skipif(DISABLE_TESTS, reason='need to set CIF_ELASTICSEARCH_TEST=1 to run')
def test_store_elasticsearch_indicators_provider_negation(store, token, indicator, indicator_alt_provider):
    """A '!'-prefixed provider excludes that provider's indicators from the results."""
    created = store.handle_indicators_create(token, indicator.__dict__(), flush=True)
    assert created == 1

    created_alt = store.handle_indicators_create(token, indicator_alt_provider.__dict__(), flush=True)
    assert created_alt == 1

    negated = '!{}'.format(indicator.provider)
    raw = store.handle_indicators_search(token, {'provider': negated})

    response = json.loads(raw)
    pprint(response)

    sources = [hit['_source'] for hit in response['hits']['hits']]

    assert len(sources) == 1

    assert sources[0]['provider'] == indicator_alt_provider.provider
291+
292+
293+
## test tags negation
294+
@pytest.mark.skipif(DISABLE_TESTS, reason='need to set CIF_ELASTICSEARCH_TEST=1 to run')
def test_store_elasticsearch_indicators_tags_negation(store, token, indicator, indicator_alt_provider):
    """A '!'-prefixed tag excludes indicators carrying that tag from the results."""
    created = store.handle_indicators_create(token, indicator.__dict__(), flush=True)
    assert created == 1

    created_alt = store.handle_indicators_create(token, indicator_alt_provider.__dict__(), flush=True)
    assert created_alt == 1

    negated = '!{}'.format(indicator_alt_provider.tags[0])
    raw = store.handle_indicators_search(token, {'tags': negated})

    response = json.loads(raw)
    pprint(response)

    sources = [hit['_source'] for hit in response['hits']['hits']]

    assert len(sources) == 1

    assert sources[0]['tags'] == indicator.tags
314+
315+
316+
## test multi-tags negation
317+
@pytest.mark.skipif(DISABLE_TESTS, reason='need to set CIF_ELASTICSEARCH_TEST=1 to run')
def test_store_elasticsearch_indicators_multi_tags_negation(store, token, indicator, indicator_alt_provider):
    """Negating the tags of both stored indicators leaves no matching hits."""
    x = store.handle_indicators_create(token, indicator.__dict__(), flush=True)
    assert x == 1

    y = store.handle_indicators_create(token, indicator_alt_provider.__dict__(), flush=True)
    assert y == 1

    x = store.handle_indicators_search(token, {
        'tags': '!{},!{}'.format(indicator.tags[0], indicator_alt_provider.tags[0])
    })

    x = json.loads(x)
    pprint(x)

    # count the hits, not the keys of the decoded response envelope; the
    # sibling tests read the matches from x['hits']['hits'] in the same way
    x = [i['_source'] for i in x['hits']['hits']]

    assert len(x) == 0

0 commit comments

Comments
 (0)