Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions devito/ir/clusters/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,10 +259,17 @@ def guard(clusters):
# Separate out the indirect ConditionalDimensions, which only serve
# the purpose of protecting from OOB accesses
cds = [d for d in cds if not d.indirect]
modes = [cd.relation for cd in cds]
if modes.count('strict') > 1:
raise CompilationError("Only one `strict` condition"
"can be used in an equation")
elif 'strict' in modes:
mode = 'strict'
else:
mode = sympy.And if sympy.And in modes else sympy.Or

# Chain together all `cds` conditions from all expressions in `c`
guards = {}
mode = sympy.Or
for cd in cds:
# `BOTTOM` parent implies a guard that lives outside of
# any iteration space, which corresponds to the placeholder None
Expand All @@ -279,7 +286,6 @@ def guard(clusters):

# Pull `cd` from any expr
condition = guards.setdefault(k, [])
mode = mode and cd.relation
for e in exprs:
try:
condition.append(e.conditionals[cd])
Expand All @@ -296,7 +302,10 @@ def guard(clusters):

# Combination `mode` is And by default.
# If all conditions are Or then Or combination `mode` is used.
guards = {d: mode(*v, evaluate=False) for d, v in guards.items()}
if mode == 'strict':
guards = {d: v[0] for d, v in guards.items()}
else:
guards = {d: mode(*v, evaluate=False) for d, v in guards.items()}

# Construct a guarded Cluster
processed.append(c.rebuild(exprs=exprs, guards=guards))
Expand Down
32 changes: 26 additions & 6 deletions devito/ir/equations/equation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
)
from devito.symbolics import IntDiv, limits_mapper, uxreplace
from devito.tools import Pickable, Tag, frozendict
from devito.types import Eq, Inc, ReduceMax, ReduceMin, ReduceMinMax, relational_min
from devito.types import (
Eq, Inc, ReduceMax, ReduceMin, ReduceMinMax, relational_min, relational_shift
)

__all__ = [
'ClusterizedEq',
Expand Down Expand Up @@ -222,7 +224,7 @@ def __new__(cls, *args, **kwargs):
relations=ordering.relations, mode='partial')
ispace = IterationSpace(intervals, iterators)

# Construct the conditionals and replace the ConditionalDimensions in `expr`
# Construct the conditionals

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should place this whole block of code, which constructs/lowers the conditionals, into its own separate functions, and a docstring with some examples

conditionals = {}
for d in ordering:
if not d.is_Conditional:
Expand All @@ -234,13 +236,31 @@ def __new__(cls, *args, **kwargs):
if d._factor is not None:
cond = d.relation(cond, GuardFactor(d))
conditionals[d] = cond

# Merge conditionals when possible. E.g if we have an implicit_dim

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

btw this block imho deserves its own function

# and there is a dimension with the same parent, we ca merged

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Dimension

"ca merged"

"their conditions"

you could also make the example a bit more practical

# its condition
for d in input_expr.implicit_dims:
if d not in conditionals:
continue
for cd in dict(conditionals):

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

list(...) is fine

if cd.parent == d.parent and cd != d:
cond = conditionals.pop(d)
if d.relation == 'strict':
conditionals[cd] = conditionals[d] = cond
else:
mode = cd.relation and d.relation
conditionals[cd] = mode(cond, conditionals[cd])
break

# Replace the ConditionalDimensions in `expr`
for d, cond in conditionals.items():
# Replace dimension with index
index = d.index
if d.condition is not None and d in expr.free_symbols:
index = index - relational_min(d.condition, d.parent)
expr = uxreplace(expr, {d: IntDiv(index, d.symbolic_factor)})

conditionals = frozendict(conditionals)
index = index - relational_min(cond, d.parent)
shift = relational_shift(cond, d.parent)
expr = uxreplace(expr, {d: IntDiv(index, d.symbolic_factor) + shift})

# Lower all Differentiable operations into SymPy operations
rhs = diff2sympy(expr.rhs)
Expand Down
84 changes: 58 additions & 26 deletions devito/ir/support/guards.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
from sympy.logic.boolalg import BooleanFunction

from devito.ir.support.space import Forward, IterationDirection
from devito.symbolics import CondEq, CondNe, search
from devito.symbolics import CondEq, CondNe, IntDiv, search
from devito.symbolics.manipulation import _uxreplace_handle, _uxreplace_registry
from devito.tools import Pickable, as_tuple, frozendict, split
from devito.types import Dimension, LocalObject

Expand All @@ -31,6 +32,34 @@
]


@singledispatch
def bound_index(expr, dim, dir):
if dir == Forward:
return expr._subs(dim, dim + 1)
else:
return expr._subs(dim, dim - 1)


@bound_index.register(Expr)
def _(expr, dim, dir):
if not expr.args:
if dir == Forward:
return expr._subs(dim, dim + 1)
else:
return expr._subs(dim, dim - 1)
return expr.func(*[bound_index(a, dim, dir) for a in expr.args])


@bound_index.register(IntDiv)
def _(expr, dim, dir):
v = dim.symbolic_factor
p0 = dim.root
if dir == Forward:
return Mul((((p0 + 1) + v - 1) / v), v, evaluate=False)
else:
return (p0 - 1) - abs(p0 - 1) % v


class AbstractGuard:
pass

Expand Down Expand Up @@ -138,37 +167,29 @@ class BaseGuardBoundNext(Guard, Pickable):
given `direction`.
"""

__rargs__ = ('d', 'direction')
__rargs__ = ('d', 'index', 'direction')
__rkwargs__ = ('d_min', 'd_max')

def __new__(cls, d, direction, **kwargs):
def __new__(cls, d, index, direction,
d_min=None, d_max=None, **kwargs):
assert isinstance(d, Dimension)
assert isinstance(direction, IterationDirection)

if direction == Forward:
p0 = d.root
p1 = d.root.symbolic_max
# Always take the next index in the iteration direction
next_index = bound_index(index, d, direction)

if d.is_Conditional:
v = d.symbolic_factor
# Round `p0 + 1` up to the nearest multiple of `v`
p0 = Mul((((p0 + 1) + v - 1) / v), v, evaluate=False)
else:
p0 = p0 + 1
# The direction might be forward but accessing c - d
# making the access backward w.r.t
# Update direction according to access direction for valid guard
if index.has(-d):
direction = -direction

if direction == Forward:
p0 = next_index
p1 = d_max or d.root.symbolic_max
else:
p0 = d.root.symbolic_min
p1 = d.root

if d.is_Conditional:
v = d.symbolic_factor
# Round `p1 - 1` down to the nearest sub-multiple of `v`
# NOTE: we use ABS to make sure we handle negative values properly.
# Once `p1 - 1` is negative (e.g. `iteration=time - 1` and `time=0`),
# as long as we get a negative number, rather than 0 and even if it's
# not `-v`, we're good
p1 = (p1 - 1) - abs(p1 - 1) % v
else:
p1 = p1 - 1
p0 = d_min if d_min is not None else d.root.symbolic_min
p1 = next_index

try:
if cls.__base__._eval_relation(p0, p1) is true:
Expand All @@ -180,12 +201,15 @@ def __new__(cls, d, direction, **kwargs):

obj.d = d
obj.direction = direction
obj.index = index
obj.d_min = d_min
obj.d_max = d_max

return obj

@property
def _args_rebuild(self):
return (self.d, self.direction)
return (self.d, self.index, self.direction)


class GuardBoundNextLe(BaseGuardBoundNext, Le):
Expand Down Expand Up @@ -541,3 +565,11 @@ def pairwise_or(*guards):
pass

return guard


_uxreplace_registry.register(BaseGuardBoundNext)


@_uxreplace_handle.register(BaseGuardBoundNext)
def _(expr, args, kwargs):
return expr.func(expr.d, expr.index, expr.direction, **kwargs)
8 changes: 8 additions & 0 deletions devito/ir/support/space.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,6 +601,14 @@ def __repr__(self):
def __hash__(self):
return hash(self._name)

def __neg__(self):
if self._name == '++':
return Backward
elif self._name == '--':
return Forward
else:
return Any


Forward = IterationDirection('++')
"""Forward iteration direction ('++')."""
Expand Down
3 changes: 3 additions & 0 deletions devito/ir/support/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ def __lt__(self, other):
return True
elif q_positive(i):
return False

raise TypeError("Non-comparable index functions") from e

return False
Expand Down Expand Up @@ -164,6 +165,7 @@ def __gt__(self, other):
return True
elif q_negative(i):
return False

raise TypeError("Non-comparable index functions") from e

return False
Expand Down Expand Up @@ -203,6 +205,7 @@ def __le__(self, other):
return True
elif q_positive(i):
return False

raise TypeError("Non-comparable index functions") from e

# Note: unlike `__lt__`, if we end up here, then *it is* <=. For example,
Expand Down
Loading
Loading