[PYTHON] Perform path calculation on a 2D grid with NetworkX

The Python library NetworkX looked like an easy way to do simple path calculations, so I tried it out.

First, consider the following two-dimensional grid environment. The grid size is 20x20.

import networkx as nx
import matplotlib.pyplot as plt

ngrid = 20
gp = nx.grid_graph(dim=[ngrid, ngrid])
# Each node label is its (x, y) grid coordinate, so it is also its drawing
# position. Iterating gp / gp.nodes(data=True) works on NetworkX 1.x and 2.x+,
# unlike nodes_iter() and gp.node, which no longer exist in current NetworkX.
nx.draw(gp,
        pos={n: n for n in gp},
        node_color=[d.get('color', 'white') for _, d in gp.nodes(data=True)],
        node_size=200)
plt.axis('equal')
plt.show()

2dgrid_20x20.png

I also want paths to be able to move diagonally between grid cells, so I'll add diagonal edges to each node.

import networkx as nx
import matplotlib.pyplot as plt

def add_cross_edge(gp, shape):
    """
    Add diagonal edges to a 2D grid graph (in place).

    For every node ``(i, j)`` an edge is added to ``(i + 1, j + 1)`` and to
    ``(i + 1, j - 1)`` when that neighbour lies inside ``shape``. The other
    two diagonals of each node are covered from the neighbouring node.

    :param gp: graph whose nodes are assumed to be integer coordinate
        tuples, e.g. the output of ``nx.grid_graph(dim=[n, n])``.
    :param shape: ``(size0, size1)`` upper bounds for ``node[0]`` and ``node[1]``.
    """
    # Snapshot the nodes: nodes_iter() was removed in NetworkX 2.0, and a list
    # stays safe to iterate while edges are being added.
    for node in list(gp.nodes()):
        for nb in ((node[0] + 1, node[1] + 1), (node[0] + 1, node[1] - 1)):
            inside = nb[0] < shape[0] and 0 <= nb[1] < shape[1]
            # Only link to nodes that already exist, so a shape that does not
            # match the graph's coordinate layout cannot add stray nodes.
            if inside and nb in gp:
                gp.add_edge(node, nb)

ngrid = 20
gp = nx.grid_graph(dim=[ngrid, ngrid])
add_cross_edge(gp, [ngrid, ngrid])
# Node labels are (x, y) coordinates and double as drawing positions.
# gp / gp.nodes(data=True) work on NetworkX 1.x and 2.x+ (nodes_iter() and
# gp.node are gone in current NetworkX).
nx.draw(gp,
        pos={n: n for n in gp},
        node_color=[d.get('color', 'white') for _, d in gp.nodes(data=True)],
        node_size=200)
plt.axis('equal')
plt.show()

2dgrid_20x20_with_cross.png

Finally, place the obstacles and choose the start and goal nodes; that completes the environment itself. Here, 20% of the cells are picked at random: the first becomes the start (colored green), the second becomes the goal (colored red), and the rest become obstacles (colored black). Using this environment, we will calculate a path that connects the start and the goal while avoiding obstacles.

import networkx as nx
import matplotlib.pyplot as plt

def add_cross_edge(gp, shape):
    """
    Add diagonal edges to a 2D grid graph (in place).

    For every node ``(i, j)`` an edge is added to ``(i + 1, j + 1)`` and to
    ``(i + 1, j - 1)`` when that neighbour lies inside ``shape``. The other
    two diagonals of each node are covered from the neighbouring node.

    :param gp: graph whose nodes are assumed to be integer coordinate
        tuples, e.g. the output of ``nx.grid_graph(dim=[n, n])``.
    :param shape: ``(size0, size1)`` upper bounds for ``node[0]`` and ``node[1]``.
    """
    # Snapshot the nodes: nodes_iter() was removed in NetworkX 2.0, and a list
    # stays safe to iterate while edges are being added.
    for node in list(gp.nodes()):
        for nb in ((node[0] + 1, node[1] + 1), (node[0] + 1, node[1] - 1)):
            inside = nb[0] < shape[0] and 0 <= nb[1] < shape[1]
            # Only link to nodes that already exist, so a shape that does not
            # match the graph's coordinate layout cannot add stray nodes.
            if inside and nb in gp:
                gp.add_edge(node, nb)

import numpy as np  # needed for the random sampling below

ngrid = 20
gp = nx.grid_graph(dim=[ngrid, ngrid])
add_cross_edge(gp, [ngrid, ngrid])
# Positional snapshot of the nodes: in NetworkX 2.x+ gp.nodes() is a NodeView
# indexed by node label, not by position.
nodes = list(gp.nodes())
# 20% of the cells are drawn without repetition: the first becomes the start,
# the second the goal, and the rest obstacles.
idcs = np.random.choice(len(nodes), int(ngrid * ngrid * 0.2), replace=False)
#Set start / goal / obstacle
st, gl, obs = nodes[idcs[0]], nodes[idcs[1]], [nodes[i] for i in idcs[2:]]
# add_node on an existing node only updates its attributes; unlike gp.node it
# works on both NetworkX 1.x and 2.x+.
gp.add_node(st, color='green')
gp.add_node(gl, color='red')
for o in obs:
    gp.add_node(o, color='black')
nx.draw(gp,
        pos={n: n for n in gp},
        node_color=[d.get('color', 'white') for _, d in gp.nodes(data=True)],
        node_size=200)
plt.axis('equal')
plt.show()

2dgrid_20x20_full.png

Next, assign a weight to each edge so that paths avoid obstacles. The weight is the Euclidean length of the edge, plus a large penalty (1.0e6) when either of its endpoints is an obstacle.

import numpy as np

def cost(a, b, obstacles=None):
    """
    Cost function: weight of the edge between grid nodes ``a`` and ``b``.

    The weight is the Euclidean length of the edge plus a penalty of 1.0e6
    when either endpoint is an obstacle, which strongly discourages paths
    through obstacles.

    :param a: first endpoint, a coordinate tuple such as ``(x, y)``.
    :param b: second endpoint.
    :param obstacles: collection of obstacle nodes; defaults to the
        module-level ``obs`` list.
    :return: the edge weight.
    """
    if obstacles is None:
        obstacles = obs
    x1 = np.array(a, dtype=np.float32)
    x2 = np.array(b, dtype=np.float32)
    edge_len = np.linalg.norm(x1 - x2)
    # A plain membership test instead of building a throwaway list in any([...]).
    penalty = 1.0e6 if (a in obstacles or b in obstacles) else 0.0
    return edge_len + penalty

# edges(data=True) yields (u, v, attr_dict) on NetworkX 1.x and 2.x+, and the
# dict is the live edge data, so assigning into it sets the weight.
# (edges_iter() was removed in NetworkX 2.0.)
for u, v, d in gp.edges(data=True):
    d['weight'] = cost(u, v)

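NetworkX provides several path-search functions; this time we will use the A* algorithm. A* needs a heuristic function that estimates the remaining cost from a node to the goal. Here we use the straight-line (Euclidean) distance, defined as follows. It never overestimates the true cost, because every edge weight is at least the Euclidean length of that edge.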

def dist(a, b):
    """
    Heuristic function: straight-line (Euclidean) distance between two nodes.

    Both coordinates are converted to float32 before subtracting, so the
    result is a float32 scalar.
    """
    delta = np.asarray(a, dtype=np.float32) - np.asarray(b, dtype=np.float32)
    return np.linalg.norm(delta)

# Run A* once. The weighted path length is just the sum of the edge weights
# along the returned path; calling astar_path_length would repeat the search.
path = nx.astar_path(gp, st, gl, dist)
length = sum(gp[u][v]['weight'] for u, v in zip(path[:-1], path[1:]))
print(path)
print(length)

Finally, let's draw the path search result. The nodes on the path are shown in blue.

# node -> attribute dict, readable on NetworkX 1.x and 2.x+ (gp.node is gone
# in current NetworkX).
node_attrs = dict(gp.nodes(data=True))
for p in path[1:-1]:  # skip start and goal so they keep their colours
    if node_attrs[p].get('color', '') == 'black':
        # Obstacles are still in the graph (only penalised), so A* can be
        # forced through one.
        print('Invalid path')
        continue
    gp.add_node(p, color='blue')

nx.draw(gp,
        pos={n: n for n in gp},
        node_color=[d.get('color', 'white') for _, d in gp.nodes(data=True)],
        node_size=200)
plt.axis('equal')
plt.show()

2dgrid_20x20_astar_path.png

The whole code looks like this:

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt

def add_cross_edge(gp, shape):
    """
    Add diagonal edges to a 2D grid graph (in place).

    For every node ``(i, j)`` an edge is added to ``(i + 1, j + 1)`` and to
    ``(i + 1, j - 1)`` when that neighbour lies inside ``shape``. The other
    two diagonals of each node are covered from the neighbouring node.

    :param gp: graph whose nodes are assumed to be integer coordinate
        tuples, e.g. the output of ``nx.grid_graph(dim=[n, n])``.
    :param shape: ``(size0, size1)`` upper bounds for ``node[0]`` and ``node[1]``.
    """
    # Snapshot the nodes: nodes_iter() was removed in NetworkX 2.0, and a list
    # stays safe to iterate while edges are being added.
    for node in list(gp.nodes()):
        for nb in ((node[0] + 1, node[1] + 1), (node[0] + 1, node[1] - 1)):
            inside = nb[0] < shape[0] and 0 <= nb[1] < shape[1]
            # Only link to nodes that already exist, so a shape that does not
            # match the graph's coordinate layout cannot add stray nodes.
            if inside and nb in gp:
                gp.add_edge(node, nb)

ngrid = 20
gp = nx.grid_graph(dim=[ngrid, ngrid])
add_cross_edge(gp, [ngrid, ngrid])
# Positional snapshot of the nodes: in NetworkX 2.x+ gp.nodes() is a NodeView
# indexed by node label, not by position.
nodes = list(gp.nodes())
# 20% of the cells are drawn without repetition: the first becomes the start,
# the second the goal, and the rest obstacles.
idcs = np.random.choice(len(nodes), int(ngrid * ngrid * 0.2), replace=False)
#Set start / goal / obstacle
st, gl, obs = nodes[idcs[0]], nodes[idcs[1]], [nodes[i] for i in idcs[2:]]
# add_node on an existing node only updates its attributes; unlike gp.node it
# works on both NetworkX 1.x and 2.x+.
gp.add_node(st, color='green')
gp.add_node(gl, color='red')
for o in obs:
    gp.add_node(o, color='black')

def dist(a, b):
    """
    Heuristic function: straight-line (Euclidean) distance between two nodes.

    Both coordinates are converted to float32 before subtracting, so the
    result is a float32 scalar.
    """
    delta = np.asarray(a, dtype=np.float32) - np.asarray(b, dtype=np.float32)
    return np.linalg.norm(delta)

def cost(a, b, k1=1.0, k2=10.0, kind='intsct', obstacles=None):
    """
    Cost function: weight of the edge between grid nodes ``a`` and ``b``.

    Returns the Euclidean length of the edge plus a penalty of 1.0e6 when
    either endpoint is an obstacle, which strongly discourages A* from
    routing through obstacles.

    :param a: first endpoint (coordinate tuple).
    :param b: second endpoint (coordinate tuple).
    :param k1: currently unused; kept so existing calls keep working.
    :param k2: currently unused; kept so existing calls keep working.
    :param kind: currently unused; kept so existing calls keep working.
    :param obstacles: collection of obstacle nodes; defaults to the
        module-level ``obs`` list.
    :return: the edge weight.
    """
    if obstacles is None:
        obstacles = obs
    x1 = np.array(a, dtype=np.float32)
    x2 = np.array(b, dtype=np.float32)
    # Named edge_len so it does not shadow the module-level dist() heuristic.
    edge_len = np.linalg.norm(x1 - x2)
    # A plain membership test instead of building a throwaway list in any([...]).
    penalty = 1.0e6 if (a in obstacles or b in obstacles) else 0.0
    return edge_len + penalty

# edges(data=True) yields the live edge attribute dicts on NetworkX 1.x and
# 2.x+ (edges_iter() was removed in NetworkX 2.0).
for u, v, d in gp.edges(data=True):
    d['weight'] = cost(u, v)

# Run A* once. The weighted path length is the sum of the edge weights along
# the path; astar_path_length would repeat the whole search.
path = nx.astar_path(gp, st, gl, dist)
length = sum(gp[u][v]['weight'] for u, v in zip(path[:-1], path[1:]))
print(path)
print(length)
# node -> attribute dict, readable on NetworkX 1.x and 2.x+ (gp.node is gone
# in current NetworkX).
node_attrs = dict(gp.nodes(data=True))
for p in path[1:-1]:  # skip start and goal so they keep their colours
    if node_attrs[p].get('color', '') == 'black':
        # Obstacles are only penalised, not removed, so A* can be forced
        # through one.
        print('Invalid path')
        continue
    gp.add_node(p, color='blue')

nx.draw(gp,
        pos={n: n for n in gp},
        node_color=[d.get('color', 'white') for _, d in gp.nodes(data=True)],
        node_size=200)
plt.axis('equal')
plt.show()

NetworkX is convenient. The 2D grid environment with obstacles built here also looks usable for reinforcement learning.

Recommended Posts

Perform path calculation on 2D grid with Networkx
Perform DFT calculation with ASE and GPAW
Immediate page rank calculation (with commentary on all lines)