Python comprehension with multiple ‘for’ clauses and a single ‘if’

Imagine a discrete x, y, z space. I am trying to create an iterator that returns all points lying within a sphere of some radial distance from a given point.

My approach was to first look at all points within a larger cube which is guaranteed to contain all the points needed and then cull or skip points which are too far away.

My first attempt was:

#this doesn't work
it_0=((x+xp,y+yp,z+zp) for xp in range(-dist,dist+1) for yp in range(-dist,dist+1) for zp in range(-dist,dist+1) if ( ((x-xp)**2+(y-yp)**2+(z-zp)**2) <= dist**2+sys.float_info.epsilon ) )

A simple

for d,e,f in it_0:
    print( ((x-d)**2+(y-e)**2+(z-f)**2) <= dist**2+sys.float_info.epsilon,  d,e,f)

verifies that it_0 does not produce correct results. I believe it is applying the conditional only to the third (i.e., z) ‘for’ clause.

The following works:

it_1=((x+xp,y+yp,z+zp) for xp in range(-dist,dist+1) for yp in range(-dist,dist+1) for zp in range(-dist,dist+1))
it_2=filter( lambda p: ((x-p[0])**2+(y-p[1])**2+(z-p[2])**2) <= dist**2+sys.float_info.epsilon, it_1)

It collects all the points, then filters out those that don’t fit the conditional.

I was hoping there might be a way to correct the first attempted implementation, or make these expressions more readable or compact.

Best answer

First of all, I suggest you replace the triply-nested for loop with itertools.product(), like so:

import itertools as it
it_1 = it.product(range(-dist, dist+1), repeat=3)
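
To see that nothing is lost by this swap, here is a minimal sketch comparing the two forms. The value of dist is arbitrary and chosen only to keep the output small; product() yields its tuples in the same order as the nested loops, with the rightmost range advancing fastest.

```python
import itertools as it

dist = 1  # arbitrary small value, for illustration only

# The triply-nested loops, written out as a comprehension.
nested = [(xp, yp, zp)
          for xp in range(-dist, dist + 1)
          for yp in range(-dist, dist + 1)
          for zp in range(-dist, dist + 1)]

# The same offsets from a single product() call.
flat = list(it.product(range(-dist, dist + 1), repeat=3))

print(nested == flat)  # True
print(len(flat))       # 27
```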

If you are using Python 2.x, you should use xrange() here instead of range().

Next, instead of using filter() you could just use a generator expression:

it_2 = ((x+xp, y+yp, z+zp) for xp, yp, zp in it_1 if (xp**2 + yp**2 + zp**2) <= dist**2 + sys.float_info.epsilon)

This would avoid some overhead in Python 2.x (since filter() builds a list there), but in Python 3.x, where filter() returns a lazy iterator, it would be about the same; and even in Python 2.x you could use itertools.ifilter().

But for readability, I would package the whole thing up into a generator, like so:

import itertools as it
import sys

def sphere_points(radius=0, origin=(0,0,0), epsilon=sys.float_info.epsilon):
    x0, y0, z0 = origin
    limit = radius**2 + epsilon
    for x, y, z in it.product(range(-radius, radius+1), repeat=3):
        if (x**2 + y**2 + z**2) <= limit:
            yield (x+x0, y+y0, z+z0)

This differs slightly from your original code: the ranges for x, y, and z are centered on zero, and each point is shifted by the origin point only when it is yielded. When I test this code with a radius of 0, I correctly get back a single point, the origin point.
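
As a quick check of the radius-0 case, here is a small sketch that repeats the function and calls it twice. The origin (1, 2, 3) is just an arbitrary value picked for the example, and the names single and unit are only for illustration.

```python
import itertools as it
import sys

def sphere_points(radius=0, origin=(0, 0, 0), epsilon=sys.float_info.epsilon):
    x0, y0, z0 = origin
    limit = radius**2 + epsilon
    # Offsets are generated around zero, then shifted by the origin.
    for x, y, z in it.product(range(-radius, radius + 1), repeat=3):
        if (x**2 + y**2 + z**2) <= limit:
            yield (x + x0, y + y0, z + z0)

# Radius 0: only the origin itself survives the distance test.
single = list(sphere_points(0, (1, 2, 3)))
print(single)     # [(1, 2, 3)]

# Radius 1: the center plus its six axis-aligned neighbours.
unit = list(sphere_points(1))
print(len(unit))  # 7
```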

Note that the function takes arguments for the radius, the origin point, and even the value to use for epsilon, with defaults for each. I also unpacked the origin point tuple into explicit variables; CPython does not optimize away repeated indexing, so this way we know there won’t be any indexing going on inside the loop. (Likewise, CPython does not hoist loop-invariant calculations out of a loop, so limit is computed once before the loop starts; I also prefer it on its own line as shown here, for readability.)

I think the above is about as fast as you can write it in native Python, and I think it is a big improvement in readability.

P.S. This code would probably run a lot faster if it were redone using Cython.

EDIT: Code simplified as suggested by @eryksun in comments.
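
As for the first attempt in the question: the single ‘if’ in a comprehension with several ‘for’ clauses is evaluated once for every full combination of the loop variables, not only for the last clause. It looks like the actual problem in it_0 is the test itself, which uses x-xp, y-yp and z-zp where xp, yp and zp are already the offsets from the center, so it only works when the center is (0, 0, 0). Below is a minimal sketch of a corrected version next to the equivalent nested loops; the values of x, y, z and dist are made up for the example, and epsilon is left out because everything here is an integer.

```python
# Hypothetical center and distance, chosen only for illustration.
x, y, z = 5, 5, 5
dist = 2

# Equivalent nested loops: the `if` sits inside the innermost loop,
# so it runs for every (xp, yp, zp) combination.
expected = []
for xp in range(-dist, dist + 1):
    for yp in range(-dist, dist + 1):
        for zp in range(-dist, dist + 1):
            if xp**2 + yp**2 + zp**2 <= dist**2:
                expected.append((x + xp, y + yp, z + zp))

# The same thing as one comprehension. The test uses the offsets
# themselves (xp, yp, zp), not x - xp and so on.
points = [(x + xp, y + yp, z + zp)
          for xp in range(-dist, dist + 1)
          for yp in range(-dist, dist + 1)
          for zp in range(-dist, dist + 1)
          if xp**2 + yp**2 + zp**2 <= dist**2]

print(points == expected)  # True
print(len(points))         # 33
```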