Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 26 additions & 77 deletions junctiontree/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,22 +71,25 @@ def get_message(sepset_ix, tree, beliefs, clique_vars):
for var in clique_vars[ss_ix]
]

neighbor_vars = list(set(neighbor_vars))
neighbor_vars = np.unique(neighbor_vars)

# multiply neighbor messages
messages = messages if len(messages) else [1]

msg_prod = dl.einsum(
*messages,
neighbor_vars
msg_prod = 1 if len(messages) == 0 else dl.einsum(
*messages,
neighbor_vars
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this special case 1 related only to the normal sum-product? If some other distributive law was used, would the value be different then? If so, should this if-else be inside the specific "sum-product distributive law" so it wouldn't affect all distributive laws? Or, alternatively, should each distributive law define an identity element ("empty value") that is used in this kind of cases? My gut feeling is that this latter option would make perfect sense.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, now that this if-else is only here, it won't fix all other places where dl.einsum is used. I suppose this special-case handling should be used always when dl.einsum is used, right? Therefore, even more so, I would suggest defining an identity element for the distributive laws.

Does this make sense or am I misunderstanding something?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've been trying to come up with a simple unit test for the inconsistent results that were identified in the original issue. However, the issue derives from the lack of a guaranteed order when creating Python sets. There is an assumption about the ordering of the indices within the remove_messages function implementation with regard to the ordering of the indices. The use of the set function prior to calling remove_messages produces inconsistent behavior. So, the incorrect results are only produced occasionally.

There are 2 possible fixes:

  1. Use NumPy's set functions in place of the standard Python set function so that the indices are always ordered prior to being used as an argument to remove_message.

  2. Rely on numpy.einsum to compute the product of the messages with the message from the current neighbor being excluded. I originally wrote the remove_message function in order to avoid repeated calculations of the product of the messages. But the complexity of the code for performing the calculation to divide out the neighbor's message may not be worth any small efficiency gains from avoiding the repeated calculations.

My commit included both replacing the standard set function calls with NumPy set implementations and removing the remove_message function to simply the code.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added an additional commit in this branch to support an identity element for SumProduct distributive law

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok! If it's not easy to add a unit test, then that can be left out. 👍

Your most recent commit added new files computation.py and sum_product.py that are in the root of the repository, not inside junctiontree package. I suppose this was a mistake, right? Being outside the package, they aren't used anywhere. Perhaps you meant to replace junctiontree/computation.py and junctiontree/sum_product.py?

)



args = [msg_prod, neighbor_vars] + [beliefs[tree[0]], clique_vars[tree[0]], clique_vars[sepset_ix]]

# compute message as marginalization over non-sepset values
# multiplied by product of messages with output being vars in input sepset

message = dl.einsum(*args)


try:
# attempt to update belief
beliefs[sepset_ix] = message
Expand All @@ -96,47 +99,6 @@ def get_message(sepset_ix, tree, beliefs, clique_vars):
return None


def remove_message(msg_prod, prod_ixs, msg, msg_ixs, out_ixs):
'''Removes (divides out) sepset message from
product of all neighbor sepset messages for a clique

:param msg_prod: product of all messages for clique
:param prod_ixs: variable indices in clique
:param msg: sepset message to be removed from product
:param msg_ixs: variable indices in sepset
:param out_ixs: variables indices expected in result
:return: the product of messages with sepset msg removed (divided out)
'''

exp_mask = np.in1d(prod_ixs, msg_ixs)

# use mask to specify expanded dimensions in message
exp_ixs = np.full(msg_prod.ndim, None)
exp_ixs[exp_mask] = slice(None)

# use mask to select slice dimensions
slice_mask = np.in1d(prod_ixs, out_ixs)
slice_ixs = np.full(msg_prod.ndim, slice(None))
slice_ixs[~slice_mask] = 0

if all(exp_mask) and msg_ixs != prod_ixs:
# axis must be labeled starting at 0
var_map = {var:i for i, var in enumerate(set(msg_ixs + prod_ixs))}

# axis must be re-ordered if all variables shared but order is different
msg = np.moveaxis(msg, [var_map[var] for var in prod_ixs], [var_map[var] for var in msg_ixs])

# create dummy dimensions for performing division (with exp_ix)
# slice out dimensions of sepset variables from division result (with slice_ixs)
return np.divide(
msg_prod,
msg[ tuple(exp_ixs) ],
out = np.zeros_like(msg_prod),
where = msg[ tuple(exp_ixs) ] != 0
)[ tuple(slice_ixs) ]



def send_message(message, sepset_ix, tree, beliefs, clique_vars):
'''Sends message from clique at root of tree

Expand All @@ -157,13 +119,14 @@ def send_message(message, sepset_ix, tree, beliefs, clique_vars):
# adding message sent
] + [message, clique_vars[sepset_ix]]


all_neighbor_vars = [
var
for vars in messages[1::2]
for var in vars
]

neighbor_vars = list(set(all_neighbor_vars))
neighbor_vars = np.unique(all_neighbor_vars)

# multiply neighbor messages
msg_prod = dl.einsum(
Expand All @@ -175,34 +138,20 @@ def send_message(message, sepset_ix, tree, beliefs, clique_vars):
ss_num = 0
for ss_ix, subtree in tree[1:]:

# divide product of messages by current sepset message for this neighbor
output_vars = list(
set(
[
var
for vars in messages[1::2][0:ss_num] + messages[1::2][ss_num+1:]
for var in vars
]
)
)

mask = np.in1d(
neighbor_vars,
output_vars

)

mod_neighbor_vars = np.array(neighbor_vars)[mask].tolist()

mod_msg_prod = remove_message(
msg_prod,
neighbor_vars,
beliefs[ss_ix],
clique_vars[ss_ix],
mod_neighbor_vars
)

args = [mod_msg_prod, mod_neighbor_vars] + [beliefs[tree[0]], clique_vars[tree[0]], clique_vars[ss_ix]]
# remove sepset ix vars from neighbor vars
mod_neighbor_vars = np.setdiff1d(neighbor_vars, clique_vars[ss_ix])




# create product of messages that excludes the message from this sepset
mod_messages = [
comp
for i in range(1,len(messages), 2)
for comp in messages[i-1:i+1] if messages[i] != clique_vars[ss_ix]
]
args = [dl.einsum(*mod_messages, mod_neighbor_vars), mod_neighbor_vars] + [beliefs[tree[0]], clique_vars[tree[0]], clique_vars[ss_ix]]

# calculate message to be sent
message = dl.einsum( *args )

Expand All @@ -221,6 +170,7 @@ def send_message(message, sepset_ix, tree, beliefs, clique_vars):
clique_vars[tree[0]]
]


beliefs[tree[0]] = dl.einsum(*args)


Expand All @@ -242,7 +192,6 @@ def __run(tree, beliefs, clique_vars):

return beliefs


beliefs = [np.copy(p) for p in potentials]
return __run(tree, beliefs, clique_vars)