25 changes: 14 additions & 11 deletions trumania/core/relationship.py
@@ -11,15 +11,15 @@

# There are a lot of somewhat ugly optimizations here like in-place mutations,
# caching, or usage of numpy instead of a more readable pandas alternative. The
# reason is the methods of this filetend to be called a large amount of time
# reason is the methods of this file tend to be called a large amount of time
# in inner loop of the simulation, optimizing them make the whole simulation
# faster.


class Relations(object):
class OutgoingRelations(object):
"""
This entity contains all the "to" sides of the relationships of a given
"from", together with the related weights.
For a given "from", this entity contains all the "to" sides of the
relationship with the related weights.

This data structure seems to be the most optimal since it corresponds to a cached
group-by result, and those group-by are expensive in the select_one
@@ -45,8 +45,10 @@ def from_tuples(from_ids, to_ids, weights):
a relationship is built here for each "line" read across those 3
arrays.

This methods builds one instance of Relations for each unique from_id
This methods builds one instance of OutgoingRelations for each unique from_id
value, containing all the to_id's it is related to.

:returns Dictionary { id1 -> OutgoingRelations1, id2 -> OutgoingRelations2, ... }
"""

from_ids = np.array(from_ids)
@@ -60,19 +62,20 @@ def from_tuples(from_ids, to_ids, weights):
order = from_ids.argsort()
ordered = zip(from_ids[order], to_ids[order], weights[order])

# Find for every unique id in from_ids their matching "to" relations
def _relations():
# itertools.groupby is much faster than pandas
for from_id, tuples in itertools.groupby(ordered, lambda t: t[0]):
to_ids, weights = list(zip(*tuples))[1: 3]
yield from_id, Relations(list(to_ids), list(weights))
to_ids, weights = list(zip(*tuples))[1:3]
yield from_id, OutgoingRelations(list(to_ids), list(weights))

return {from_id: relz for from_id, relz in _relations()}
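
(Reviewer note: a minimal usage sketch of from_tuples — the import path is assumed from the file header above, and the output shown is illustrative, not an exact repr.)

from trumania.core.relationship import OutgoingRelations  # assumed import path

from_ids = ["a", "a", "b"]
to_ids = ["x", "y", "z"]
weights = [1.0, 2.0, 1.0]

# One OutgoingRelations instance per unique from_id:
rels = OutgoingRelations.from_tuples(from_ids, to_ids, weights)
# rels["a"] -> to_ids ["x", "y"], weights [1.0, 2.0]
# rels["b"] -> to_ids ["z"],      weights [1.0]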

def plus(self, other):
"""
Merge function for 2 sets of relations all starting from the same "from"
"""
return Relations(
return OutgoingRelations(
np.hstack([self.to_ids, other.to_ids]),
np.hstack([self.weights, other.weights]))

@@ -83,7 +86,7 @@ def minus(self, other):
"""
removed_indices = np.argwhere(
[idx in other.to_ids for idx in self.to_ids])
return Relations(
return OutgoingRelations(
np.delete(self.to_ids, removed_indices),
np.delete(self.weights, removed_indices))
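
(Reviewer note: the plus/minus mechanics in plain numpy — a standalone sketch of the same logic, not the class itself.)

import numpy as np

to_ids = np.array(["x", "y", "y"])
weights = np.array([1.0, 2.0, 5.0])
other_to_ids = ["y"]

# plus() is a plain concatenation of both sides:
print(np.hstack([to_ids, np.array(["z"])]))  # ['x' 'y' 'y' 'z']

# minus() drops every occurrence of a to_id present on the other side:
removed = np.argwhere([t in other_to_ids for t in to_ids])
print(np.delete(to_ids, removed))   # ['x']
print(np.delete(weights, removed))  # [1.]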

@@ -157,7 +160,7 @@ def add_relations(self, from_ids, to_ids, weights=1):

self.grouped = utils.merge_2_dicts(
self.grouped,
Relations.from_tuples(from_ids, to_ids, weights),
OutgoingRelations.from_tuples(from_ids, to_ids, weights),
lambda r1, r2: r1.plus(r2))

def add_grouped_relations(self, from_ids, grouped_ids):
Expand Down Expand Up @@ -185,7 +188,7 @@ def remove_relations(self, from_ids, to_ids):

self.grouped = utils.merge_2_dicts(
self.grouped,
Relations.from_tuples(from_ids, to_ids, weights=0),
OutgoingRelations.from_tuples(from_ids, to_ids, weights=0),
lambda r1, r2: r1.minus(r2))
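
(Reviewer note: reading the two call sites together — a hedged walk-through; merge_2_dicts is defined in the next file, and weights=0 in remove_relations appears to be a placeholder, since minus() matches on to_id only and ignores the incoming weights.)

from trumania.core import util_functions as utils  # assumed import path

existing = {"a": OutgoingRelations(["x", "y"], [1.0, 2.0])}

# add_relations: merge incoming tuples, resolving collisions with plus()
incoming = OutgoingRelations.from_tuples(["a"], ["z"], [3.0])
merged = utils.merge_2_dicts(existing, incoming, lambda r1, r2: r1.plus(r2))
# merged["a"] -> to_ids ["x", "y", "z"], weights [1.0, 2.0, 3.0]

# remove_relations: the same merge, with collisions resolved by minus()
removal = OutgoingRelations.from_tuples(["a"], ["y"], weights=0)
pruned = utils.merge_2_dicts(merged, removal, lambda r1, r2: r1.minus(r2))
# pruned["a"] -> to_ids ["x", "z"], weights [1.0, 3.0]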

def get_relations(self, from_ids=None):
44 changes: 27 additions & 17 deletions trumania/core/util_functions.py
@@ -89,23 +89,33 @@ def merge_2_dicts(dict1, dict2, value_merge_func=None):
if dict1 is None:
return dict2

def merged_value(key):
if key not in dict1:
return dict2[key]
elif key not in dict2:
return dict1[key]
else:
if value_merge_func is None:
raise ValueError(
"Conflict in merged dictionaries: merge function not "
"provided but key {} exists in both dictionaries".format(
key))

return value_merge_func(dict1[key], dict2[key])

keys = set(dict1.keys()) | set(dict2.keys())

return {key: merged_value(key) for key in keys}
if dict1 == dict2:
for k, v in dict1.items():
dict1[k] = value_merge_func(v, v)
Contributor:

Why do we need to call value_merge_func()? Why not just return dict1? I did a quick check: Python's == seems to honor __eq__() on the dictionary values, so if == returns True, those dicts are equal and, I think, do not need to be merged.

Contributor Author:

If the merge function is, for example, "add 1", you still need to apply the merge function even when the two dicts are equal.
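
For illustration (hypothetical values, merge function from the example above):

d = {"a": 1}
add_one = lambda v1, v2: v1 + 1

# Even though d == d, the merge must still be applied to the shared keys:
# returning d unchanged would give {"a": 1}, while the intended result of
# merge_2_dicts(d, d, add_one) is {"a": 2}.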

return dict1

dict1_set = set(dict1)
dict2_set = set(dict2)

keys_to_merge = dict1_set.intersection(dict2_set)

if len(keys_to_merge) != 0 and value_merge_func is None:
raise ValueError(
"Conflict in merged dictionaries: merge function not "
"provided but keys {} exists in both dictionaries".format(
keys_to_merge))

values_merged = dict()

for key_to_merge in keys_to_merge:
old_value1 = dict1[key_to_merge]
old_value2 = dict2[key_to_merge]

new_value = value_merge_func(old_value1, old_value2)

values_merged[key_to_merge] = new_value

return {**dict1, **dict2, **values_merged}
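
(Reviewer note: a quick check of the semantics after this rewrite — disjoint keys pass through unchanged, shared keys go through the merge function.)

from trumania.core.util_functions import merge_2_dicts  # assumed import path

d1 = {"a": 1, "b": 2}
d2 = {"b": 10, "c": 3}

print(merge_2_dicts(d1, d2, value_merge_func=lambda x, y: x + y))
# {'a': 1, 'b': 12, 'c': 3}

# With overlapping keys and no merge function, a ValueError is raised.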


def df_concat(d1, d2):