Advent of Code 2022 in pure TensorFlow - Day 12


Solving problem 12 of the AoC 2022 in pure TensorFlow is a great exercise in graph theory and, more specifically, in using the Breadth-First Search (BFS) algorithm. The problem requires working with a grid of characters representing a graph, and BFS lets us traverse that graph level by level, which guarantees that the first time we reach the destination we have found a shortest path.

Day 12: Hill Climbing Algorithm

You can click on the title above to read the full text of the puzzle. The TLDR version is: you are given a grid representing a maze, where each cell contains a letter from the English alphabet (lowercase), an ‘S’ to indicate the starting point, or an ‘E’ to indicate the ending point. The goal is to find the shortest path from the starting point to the ending point, following specific rules for navigating the maze.

Here’s an example of the input grid:

Sabc
defg
hEij

To move from the starting point to the ending point, you can only move to adjacent cells whose elevation (letter) is at most one step higher than the current one. In this case, the shortest path would be “S -> a -> b -> c -> d -> e -> f -> g -> h -> E”, with a total of 9 steps.

NOTE: the goal is not to reach precisely the endpoint; you need to reach a point at the same elevation as E (z in the input data; h in the example above).

Part 2 of this problem can be framed as the inverse problem: you start from the E point and you need to reach a point at the same elevation as S (thus, any possible a value in the grid) via the shortest path.

Design Phase

The problem can be tackled using a Breadth-First Search (BFS) algorithm to traverse the graph represented by the input grid. BFS is ideal for this task because it explores cells in order of increasing distance from the start, so the first time we reach the destination we are guaranteed to have found a shortest path.

We’ll implement the BFS algorithm using TensorFlow’s tf.queue.FIFOQueue to maintain the order of the nodes to visit. In addition, we’ll use a visited tensor to keep track of the cells we’ve already visited, which will help us avoid visiting the same cell multiple times and prevent infinite loops.

The provided Python code supports solving both part 1 and part 2 of the problem, with slight differences in the BFS traversal. The main difference between the two parts is the condition for moving from one cell to another. In part 1, you can only move to cells whose elevation is at most one step higher than the current one, while in part 2 (where we walk backwards from E) you can only move to cells whose elevation is at most one step lower.
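
To make the difference concrete, here is a minimal sketch of the two reachability predicates. The names can_move_part1 and can_move_part2 are purely illustrative: the actual solution embeds these checks in a condition closure, shown later inside the bfs function.

import tensorflow as tf


def can_move_part1(neighbor_val, current_val):
    # Part 1: a neighbor is reachable if its elevation is at most one step higher.
    return tf.less_equal(neighbor_val, current_val + 1)


def can_move_part2(neighbor_val, current_val):
    # Part 2 (walking backwards from E): at most one step lower.
    return tf.greater_equal(neighbor_val, current_val - 1)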

Part 1 and Part 2 Solution

The code below comes from the main function that reads the input file and sets up the data. We first preprocess the input by converting the characters to integers for easier processing: we create a lookup table that maps characters to integers and apply this mapping to the dataset.

dataset = tf.data.TextLineDataset(input_path.as_posix())
dataset = dataset.map(tf.strings.bytes_split)  # one line per grid row, split into characters

# Map a..z to 0..25, S to -1, and E to 26.
keys_tensor = tf.concat(
    [tf.strings.bytes_split(string.ascii_lowercase), tf.constant(["S", "E"])],
    axis=0,
)
values_tensor = tf.concat([tf.range(0, 26), tf.constant([-1, 26])], axis=0)
lut = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(keys_tensor, values_tensor),
    default_value=-1,
)

dataset = dataset.map(lut.lookup)

grid = tf.convert_to_tensor(list(dataset))
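
As a quick sanity check (not part of the solution), this is what the lookup table produces for the first row of the example grid shown earlier:

tf.print(lut.lookup(tf.strings.bytes_split("Sabc")))  # prints [-1 0 1 2]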

The grid tensor now contains our 2D world, so we can go straight to the BFS implementation. Implementing the BFS algorithm requires just a simple data structure (a queue) and a support variable (visited) that we use to keep track of the already visited cells and, thus, avoid useless recomputations.

queue = tf.queue.FIFOQueue(
    tf.cast(tf.reduce_prod(tf.shape(grid)), tf.int32),  # capacity: one slot per grid cell
    tf.int32,
    (3,),  # each element is a (y, x, distance) triple
)

visited = tf.Variable(tf.zeros_like(grid))  # 1 where a cell has already been visited
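
Just to illustrate the element layout (again, not part of the solution), every queue element packs the cell coordinates and its distance from the start into a single rank-1 tensor:

queue.enqueue(tf.constant([0, 0, 0]))  # cell (0, 0) at distance 0
tf.print(queue.dequeue())  # prints [0 0 0]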

The bfs function is the core of our solution. It takes an optional argument part2, which defaults to False for solving part 1. To solve part 2, we simply call the function with part2=True.

The BFS algorithm starts by enqueuing the starting point (or the ending point for part 2) into the queue, along with an initial distance of 0. Then, while the queue is not empty, we dequeue the next cell to visit, along with its distance from the starting point, and check whether this cell has been visited before. If it has not, we check whether the dequeued cell is at the destination elevation for the part we are solving (see the note above) and, if so, return the distance as the shortest path length. Otherwise, we mark the cell as visited and continue exploring the neighboring cells that satisfy the traversal condition of the part we are solving.

Of course, since we are working in a 2D world, we need to be able to move and “look around”. We can thus define a _neighs function that, given a point on the 2D grid, returns its 4-neighborhood.

@tf.function
def _neighs(grid: tf.Tensor, center: tf.Tensor):
    y, x = center[0], center[1]

    shape = tf.shape(grid) - 1

    if tf.logical_and(tf.less(y, 1), tf.less(x, 1)):  # 0,0
        mask = tf.constant([(1, 0), (0, 1)])
    elif tf.logical_and(tf.equal(y, shape[0]), tf.equal(x, shape[1])):  # h,w
        mask = tf.constant([(-1, 0), (0, -1)])
    elif tf.logical_and(tf.less(y, 1), tf.equal(x, shape[1])):  # top right
        mask = tf.constant([(0, -1), (1, 0)])
    elif tf.logical_and(tf.less(x, 1), tf.equal(y, shape[0])):  # bottom left
        mask = tf.constant([(-1, 0), (0, 1)])
    elif tf.less(x, 1):  # left
        mask = tf.constant([(1, 0), (-1, 0), (0, 1)])
    elif tf.equal(x, shape[1]):  # right
        mask = tf.constant([(-1, 0), (1, 0), (0, -1)])
    elif tf.less(y, 1):  # top
        mask = tf.constant([(0, -1), (0, 1), (1, 0)])
    elif tf.equal(y, shape[0]):  # bottom
        mask = tf.constant([(0, -1), (0, 1), (-1, 0)])
    else:  # generic
        mask = tf.constant([(-1, 0), (0, -1), (1, 0), (0, 1)])

    coords = center + mask
    neighborhood = tf.gather_nd(grid, coords)
    return neighborhood, coords

The function is pretty boring to read: it handles all the cases in which the passed center parameter is a point along the border of the grid.
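
For example, assuming grid holds the small 3x4 example grid shown earlier, querying the top-left corner returns only its two valid neighbors, the cell below and the cell to the right:

n_vals, n_coords = _neighs(grid, tf.constant([0, 0]))
tf.print(n_coords)  # the coordinates (1, 0) and (0, 1)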

Breadth-First Search using tf.queue.FIFOQueue

The key to our BFS implementation is the use of the tf.queue.FIFOQueue for maintaining the order of the nodes to visit. The FIFO (first-in, first-out) property ensures that we visit the nodes in the correct order, always visiting the closest nodes to the starting point first. This guarantees that we find the shortest path to the destination.

We initialize the queue with the starting point (or the ending point for part 2) and its distance from the starting point. While the queue is not empty, we dequeue the next cell to visit, along with its distance. We then check if the cell has been visited before and update the visited tensor accordingly. If the dequeued cell is the destination, we return the distance as the shortest path length. Otherwise, we enqueue the neighboring cells that satisfy the condition for traversal, along with their distances from the starting point.

def bfs(part2=tf.constant(False)):
    if tf.logical_not(part2):
        # Part 1: start from S (mapped to -1) and stop at elevation 25 (z).
        start = tf.cast(tf.where(tf.equal(grid, -1))[0], tf.int32)
        queue.enqueue(tf.concat([start, tf.constant([0])], axis=0))
        dest_val = 25

        def condition(n_vals, me_val):
            # Reachable neighbors: at most one step higher than the current cell.
            return tf.where(tf.less_equal(n_vals, me_val + 1))

    else:
        # Part 2: start from E (mapped to 26) and walk the grid backwards.
        end = tf.cast(tf.where(tf.equal(grid, 26)), tf.int32)[0]
        queue.enqueue(tf.concat([end, tf.constant([0])], axis=0))
        dest_val = 1

        def condition(n_vals, me_val):
            # Reachable neighbors: at most one step lower than the current cell.
            return tf.where(tf.greater_equal(n_vals, me_val - 1))

    while tf.greater(queue.size(), 0):
        v = queue.dequeue()
        me, distance = v[:2], v[2]  # (y, x) coordinates and distance from the start
        me_val = tf.gather_nd(grid, [me])
        already_visited = tf.squeeze(tf.cast(tf.gather_nd(visited, [me]), tf.bool))
        if tf.logical_not(already_visited):
            if tf.reduce_all(tf.equal(me_val, dest_val)):
                return distance - 1
            visited.assign(tf.tensor_scatter_nd_add(visited, [me], [1]))

            n_vals, n_coords = _neighs(grid, me)
            potential_dests = tf.gather_nd(
                n_coords,
                condition(n_vals, me_val),
            )

            # Keep only the reachable neighbors that have not been visited yet.
            not_visited = tf.equal(tf.gather_nd(visited, potential_dests), 0)
            neigh_not_visited = tf.gather_nd(potential_dests, tf.where(not_visited))

            # Pair every neighbor to visit with its distance from the start (distance + 1).
            to_visit = tf.concat(
                [
                    neigh_not_visited,
                    tf.reshape(
                        tf.repeat(distance + 1, tf.shape(neigh_not_visited)[0]),
                        (-1, 1),
                    ),
                ],
                axis=1,
            )
            queue.enqueue_many(to_visit)

    return -1

The use of tf.queue.FIFOQueue in our BFS implementation allows us to efficiently explore the graph while maintaining the traversal order, enabling us to find the shortest path between the starting point and the destination.

Finally, we call the bfs function for both parts, reset the queue and visited tensor in between, and print the results.

tf.print("Steps: ", bfs())
queue.dequeue_many(queue.size())
visited.assign(tf.zeros_like(visited))

tf.print("Part 2: ", bfs(True))

Conclusion

You can find the solution in the 12 folder of the dedicated GitHub repository (in the 2022 folder): https://github.com/galeone/tf-aoc.

In this article, we have demonstrated how to solve problem 12 of the AoC 2022 using TensorFlow, focusing on the implementation of the Breadth-First Search algorithm with tf.queue.FIFOQueue. We have also shown how the provided Python code supports solving both part 1 and part 2 of the problem, highlighting the differences between the two parts. The BFS algorithm, along with the use of TensorFlow’s tf.queue.FIFOQueue, provides an efficient and elegant solution to this graph traversal problem.

If you missed the articles about the previous days’ solutions, here’s a handy list

For any feedback or comment, please use the Disqus form below - thanks!

Don't you want to miss the next article? Do you want to be kept updated?
Subscribe to the newsletter!
