Advent of Code 2022 in pure TensorFlow - Day 9


In this article, we’ll show two different solutions to the Advent of Code 2022 day 9 problem. Both of them are pure TensorFlow solutions. The first one, more traditional, simply implements a solution algorithm using only TensorFlow’s primitive operations. Of course, due to some TensorFlow limitations, this solution contains some details worth reading (e.g. using a pairing function to be able to use tf.Tensor coordinates as keys of a mutable hashmap). The second one, instead, demonstrates how a different interpretation of the problem paves the way to a completely different solution. In particular, this solution is Keras based and uses a multi-layer convolutional model to simulate the rope movements.

Day 9: Rope Bridge

You can click on the title above to read the full text of the puzzle. The TLDR version is: we need to model the movement of a rope. In part 1, there are only 2 knots in the rope: the Head (H) and the Tail (T). Part 2 makes the problem a little bit more complicated, asking us to model the movement of a rope made of 10 knots.

The rope moves following a series of motions (the puzzle input) and it always respects a simple rule: the head moves and the tail follows it. In the first part of the problem, where the rope is made of only 2 knots, the head and tail must always be touching (diagonally adjacent and even overlapping both count as touching).

The puzzle explains in detail how to model the movement in the 2-knot scenario. We report the relevant parts below:

....
.TH.
....

....
.H..
..T.
....

...
.H. (H covers T)
...

If the head is ever two steps directly up, down, left, or right from the tail, the tail must also move one step in that direction so it remains close enough:

.....    .....    .....
.TH.. -> .T.H. -> ..TH.
.....    .....    .....

...    ...    ...
.T.    .T.    ...
.H. -> ... -> .T.
...    .H.    .H.
...    ...    ...

Otherwise, if the head and tail aren’t touching and aren’t in the same row or column, the tail always moves one step diagonally to keep up:

.....    .....    .....
.....    ..H..    ..H..
..H.. -> ..... -> ..T..
.T...    .T...    .....
.....    .....    .....

.....    .....    .....
.....    .....    .....
..H.. -> ...H. -> ..TH.
.T...    .T...    .....
.....    .....    .....

The puzzle asks us to work out where the tail goes as the head follows a series of motions, assuming the head and the tail both start at the same position (the origin), overlapping.

Thus, given an input like

R 4
U 4
L 3
D 1
R 4
D 1
L 5
R 2

we need to simulate the rope movement and count all the positions the tail visited at least once. In this diagram, s marks the starting position (which the tail also visited), and # marks other positions the tail visited:

..##..
...##.
.####.
....#.
s###..

So, there are 13 positions the tail visited at least once (the puzzle answer).

Imperative solution: prerequisites

We can solve both parts of the puzzle with a simple observation: a rope is made of a Head, a Tail, and a variable number of knots in between (part 1: 0, part 2: 8). In part 1, the tail directly follows the head, while in part 2 every movement of the head has to propagate through the 8 intermediate knots before it can reach the tail.

This observation allows us to define the problem structure in a generic way: we’ll define a function (get_play(nodes)) that models the rope movement depending on the number of knots in the rope.

Before doing so, we need a way to check whether two knots are close on the 2D grid. As usual, TensorFlow comes with a ready-to-use function: tf.norm with its ord parameter set to infinity (here, tf.experimental.numpy.inf) computes the L∞ norm, so tf.norm(a - b, ord=inf) is the Chebyshev distance between a and b. That’s the perfect choice for measuring distances on a 2D grid where diagonal neighbors count as adjacent.

import tensorflow as tf

def are_neigh(a, b):
    # True when the Chebyshev (L∞) distance between a and b is at most 1
    return tf.math.less_equal(tf.norm(a - b, ord=tf.experimental.numpy.inf), 1)

This function returns a boolean scalar tensor that is True if the two tensors a and b are neighbors according to the L∞ metric.
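To check the semantics without TensorFlow, here is a minimal plain-Python equivalent of are_neigh (chebyshev and are_neigh_py are our own names, not part of the original code): the L∞ norm of a - b is the largest absolute coordinate difference, and two knots touch when it is at most 1.

```python
def chebyshev(a, b):
    """L-infinity (Chebyshev) distance between two 2D integer points."""
    return max(abs(a[0] - b[0]), abs(a[1] - b[1]))


def are_neigh_py(a, b):
    """True when two knots touch: overlapping, side by side, or diagonal."""
    return chebyshev(a, b) <= 1


print(are_neigh_py((0, 0), (0, 0)))  # True: overlapping
print(are_neigh_py((0, 0), (1, 1)))  # True: diagonally adjacent
print(are_neigh_py((0, 0), (2, 0)))  # False: two steps apart
```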

Since we are interested in tracking the positions of the tail, we need a way to save all the visited positions. In pure Python, we could use a tuple with the (x,y) coordinates of the visited point as a key of a dict, or as an element of a set (since tuples are hashable). In pure TensorFlow this is not possible, since tf.Tensor objects are not hashable.

tf.lookup.experimental.MutableHashTable only accepts rank-0 elements as keys (e.g. the scalar value 1 is a valid key, but the pair (1,2) is not). To work around this limitation, we need a way to map a pair of numbers to a scalar value.

This way, we can use a MutableHashTable to store the visited coordinates. The tool that perfectly solves this problem is a pairing function.

In particular, we implement the Cantor pairing function, which assigns a unique natural number to each pair of natural numbers. Being a bijection, it is not symmetric: (a,b) and (b,a) are mapped to different values whenever a ≠ b (e.g. (1,0) is mapped to 1 and (0,1) to 2), so swapped coordinates never collide.

Moreover, this bijection works on natural numbers, but coordinates can be negative, so we also need a way to map integers to naturals.

# integers to naturals
def to_natural(z):
    if tf.greater_equal(z, 0):
        return tf.cast(2 * z, tf.int64)
    return tf.cast(-2 * z - 1, tf.int64)
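This mapping is a bijection between integers and naturals. The plain-Python sketch below (to_natural_py and from_natural_py are our own names) reproduces it together with its inverse, which is handy for checking this property outside TensorFlow.

```python
def to_natural_py(z):
    """Map an integer to a natural: 0, -1, 1, -2, 2, ... -> 0, 1, 2, 3, 4, ..."""
    return 2 * z if z >= 0 else -2 * z - 1


def from_natural_py(n):
    """Inverse of to_natural_py: even naturals -> non-negative, odd -> negative."""
    return n // 2 if n % 2 == 0 else -(n + 1) // 2


print([to_natural_py(z) for z in range(-3, 4)])  # [5, 3, 1, 0, 2, 4, 6]
```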

This function maps the integers 0, -1, 1, -2, 2, ... to the naturals 0, 1, 2, 3, 4, and so on. Thus, the Cantor pairing function \(\pi(i, j) = \frac{(i + j)(i + j + 1)}{2} + j\) can be easily implemented by applying the definition found on Wikipedia.

def pairing_fn(i, j):
    i, j = to_natural(i), to_natural(j)
    return (i + j) * (i + j + 1) // 2 + j
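The plain-Python sketch below (no TensorFlow; cantor, cantor_inverse and pair_coordinates are our own names) checks two properties of this construction on a small grid: the Cantor function is invertible, and composing it with the integer-to-natural mapping gives every integer coordinate pair its own key, including swapped pairs such as (2, -1) and (-1, 2).

```python
import math


def _to_natural(z):
    """Integers -> naturals, as to_natural above."""
    return 2 * z if z >= 0 else -2 * z - 1


def cantor(i, j):
    """Cantor pairing of two naturals (same formula as pairing_fn)."""
    return (i + j) * (i + j + 1) // 2 + j


def cantor_inverse(c):
    """Recover (i, j) from c = cantor(i, j)."""
    # w = i + j is the largest integer with w * (w + 1) / 2 <= c
    w = (math.isqrt(8 * c + 1) - 1) // 2
    j = c - w * (w + 1) // 2
    return w - j, j


def pair_coordinates(x, y):
    """Integer grid coordinates -> unique natural key."""
    return cantor(_to_natural(x), _to_natural(y))


print(cantor(1, 0), cantor(0, 1))  # 1 2
```

math.isqrt requires Python 3.8 or later.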

Alright, we now have something to start developing our solution.

Imperative solution: input & play

Reading the input data is trivial. As usual, tf.data.Dataset simplifies this data transformation and gives us a dataset of pairs: a tf.string scalar with the direction and a tf.int64 scalar with the amount.

# input_path: a pathlib.Path pointing to the puzzle input file
dataset = (
    tf.data.TextLineDataset(input_path.as_posix())
    .map(lambda line: tf.strings.split(line, " "))
    .map(lambda pair: (pair[0], tf.strings.to_number(pair[1], tf.int64)))
)
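Each element of this dataset is a (direction, amount) pair. For testing the rope logic outside the graph, the same parsing can be sketched in plain Python (parse_motions is a hypothetical helper, not part of the original code):

```python
def parse_motions(text):
    """Parse lines like 'R 4' into (direction, amount) tuples."""
    motions = []
    for line in text.splitlines():
        if not line.strip():
            continue  # skip blank lines, e.g. a trailing newline
        direction, amount = line.split(" ")
        motions.append((direction, int(amount)))
    return motions


print(parse_motions("R 4\nU 4\nL 3\n"))  # [('R', 4), ('U', 4), ('L', 3)]
```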

In the previous section we prepared everything we need to use a MutableHashTable with coordinates as keys, thus we can declare it:

# keys: paired coordinates; values: (visited flag, x, y); (-1, 0, 0) means "not visited yet"
pos = tf.lookup.experimental.MutableHashTable(tf.int64, tf.int64, (-1, 0, 0))

The pos MutableHashTable is used in this way:

  1. Map the coordinates to a natural number using pairing_fn.
  2. Check if this number is present as a key in the map.
  3. If it’s not present, insert the tuple (1, x, y) as value: 1 means that the point with coordinates (x,y) has been visited.
  4. If it’s already present, the point has already been visited and the entry is left untouched, so every position is counted only once.
  5. The x,y coordinates are stored only for debugging purposes; they are de facto unused in this solution.

Alright, we now know how to use the pos hash table correctly. However, as discussed in the article Analyzing tf.function to discover AutoGraph strengths and subtleties (published three years ago), we can’t create a tf.Variable wherever we want while working in graph mode, because variables are special nodes in the graph that should be defined once and then reused. But we need a tf.Variable to model the rope, and its shape depends on the number of knots: we should declare a different variable with a different shape for each part, and this is not possible when working in graph mode.

For this reason, we define the get_play function as a configurator for the play function defined in its body (thus, we are defining and returning a closure). The scope of get_play defines a separate lexical environment we can use to create a new tf.Variable whose lifetime is bound to the lifetime of the returned closure. In short, every call to get_play(n) creates a new tf.Variable and a new play function, which tf.data.Dataset.map then traces into its own static graph.

def get_play(nodes):
    rope = tf.Variable(tf.zeros((nodes, 2), tf.int64))

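We can now define the play closure, which implements the head-tail movement as described in the puzzle. Note that play is defined inside the body of get_play, so it captures rope and nodes from the enclosing scope.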

def play(direction, amount):

    sign = tf.constant(-1, tf.int64)
    if tf.logical_or(tf.equal(direction, "U"), tf.equal(direction, "R")):
        sign = tf.constant(1, tf.int64)

    axis = tf.constant((0, 1), tf.int64)
    if tf.logical_or(tf.equal(direction, "R"), tf.equal(direction, "L")):
        axis = tf.constant((1, 0), tf.int64)

    for _ in tf.range(amount):
        rope.assign(tf.tensor_scatter_nd_add(rope, [[0]], [sign * axis]))
        for i in tf.range(1, nodes):
            if tf.logical_not(are_neigh(rope[i - 1], rope[i])):
                distance = rope[i - 1] - rope[i]

                rope.assign(
                    tf.tensor_scatter_nd_add(
                        rope, [[i]], [tf.math.sign(distance)]
                    )
                )

                if tf.equal(i, nodes - 1):
                    mapped = pairing_fn(rope[i][0], rope[i][1])
                    info = pos.lookup([mapped])[0]
                    visited, first_coord, second_coord = (
                        info[0],
                        info[1],
                        info[2],
                    )
                    if tf.equal(visited, -1):
                        # first time visited
                        pos.insert(
                            [mapped],
                            [
                                tf.stack(
                                    [
                                        tf.constant(1, tf.int64),
                                        rope[i][0],
                                        rope[i][1],
                                    ]
                                )
                            ],
                        )

    return 0

I suggest readers take their time to go through this snippet. Of course, get_play must return the closure, so the body of our configuration function ends with

return play
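To cross-check the TensorFlow graph, here is a plain-Python reference sketch that mirrors the structure above (get_play_py, count_tail_positions and the underscore helpers are our own names, not part of the original code): a factory that owns the rope state and returns a play closure, and a dict keyed by the paired coordinates that plays the role of the pos hash table. It is meant for testing, not as the article's implementation.

```python
def _to_natural(z):
    """Integers -> naturals, as to_natural above."""
    return 2 * z if z >= 0 else -2 * z - 1


def _pairing(x, y):
    """Unique natural key for integer coordinates, as pairing_fn above."""
    i, j = _to_natural(x), _to_natural(y)
    return (i + j) * (i + j + 1) // 2 + j


def _sign(v):
    return (v > 0) - (v < 0)


def get_play_py(nodes, visited):
    """Return a play(direction, amount) closure that owns its rope state.

    `visited` maps a paired key to (flag, x, y), like the pos hash table.
    """
    rope = [[0, 0] for _ in range(nodes)]  # rope[0] is the head, rope[-1] the tail
    steps = {"R": (1, 0), "L": (-1, 0), "U": (0, 1), "D": (0, -1)}

    def play(direction, amount):
        dx, dy = steps[direction]
        for _ in range(amount):
            rope[0][0] += dx
            rope[0][1] += dy
            for i in range(1, nodes):
                lead, knot = rope[i - 1], rope[i]
                if max(abs(lead[0] - knot[0]), abs(lead[1] - knot[1])) <= 1:
                    continue  # still touching: this knot does not move
                knot[0] += _sign(lead[0] - knot[0])
                knot[1] += _sign(lead[1] - knot[1])
                if i == nodes - 1:  # the tail moved: record its position
                    key = _pairing(knot[0], knot[1])
                    if key not in visited:
                        visited[key] = (1, knot[0], knot[1])
        return 0

    return play


def count_tail_positions(motions, nodes):
    """Simulate all motions and count the cells the tail visited at least once."""
    visited = {_pairing(0, 0): (1, 0, 0)}  # the start position counts as visited
    play = get_play_py(nodes, visited)
    for direction, amount in motions:
        play(direction, amount)
    return sum(flag for flag, _, _ in visited.values())


example = [("R", 4), ("U", 4), ("L", 3), ("D", 1), ("R", 4), ("D", 1), ("L", 5), ("R", 2)]
print(count_tail_positions(example, 2))   # 13
print(count_tail_positions(example, 10))  # 1
```

Each call to get_play_py gets its own rope list, which is the same reason get_play creates a fresh tf.Variable per configuration.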

Imperative solution: conclusion

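Both parts can be solved by instantiating two different graphs (through get_play(n)) and looping over the dataset, executing the rope movements step by step. Since tf.data.Dataset is lazy, wrapping the mapped dataset in list() is what forces every motion to be executed. Between the two parts, all the entries of pos are removed and the starting position is inserted again.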

tf.print("Part 1: ")
pos.insert([pairing_fn(0, 0)], [(1, 0, 0)])
list(dataset.map(get_play(2)))
tail_positions = pos.export()[1]
visited_count = tf.reduce_sum(tail_positions[:, 0])
tf.print(visited_count)

tf.print("Part 2: ")
pos.remove(pos.export()[0])
pos.insert([pairing_fn(0, 0)], [(1, 0, 0)])
list(dataset.map(get_play(10)))
tail_positions = pos.export()[1]
visited_count = tf.reduce_sum(tail_positions[:, 0])
tf.print(visited_count)

So far so good: we’ve demonstrated how TensorFlow can be used to solve this coding challenge, using a pairing function to work around some of its limitations.

However, as anticipated, this article also contains a different solution, implemented in a completely different way. In fact, the fun part about solving coding challenges is seeing how a problem can be modeled differently, and how a different model can lead to a completely different (but still correct) solution!

Keras & convolutional solution

An alternative way to solve this puzzle is to use a convolutional neural network to simulate the rope movement. Observe that every non-head knot follows the knot in front of it and they always end up touching each other (diagonally adjacent and even overlapping both count as touching, i.e., their Chebyshev distance is at most 1). Moreover, when a knot moves, its new position is touching its original position as well. Thus, we have the following two observations:

  • The movement of a non-head knot depends only on its position and the position of the knot in front of it.
  • A non-head knot is never too far away from the knot in front of it: their Chebyshev distance is at most 2. In particular, we only need finitely many local patterns to describe the movement. For example, suppose we have two knots H and T. The following are a few patterns:

.H.
...
T..

.H.
...
.T.

.H.
.T.
...

..H
.T.
...

..H
...
T..

For each of these patterns, the tail is pulled to the center of the 3x3 grid. In fact, these are all of the patterns up to rotation and flipping. All these patterns can be encoded into a few 3x3 convolution kernels. Suppose we are simulating on a 5x5 grid. The position of the head knot can be represented as a one-hot 5x5 array, e.g.

0 0 0 0 0
0 0 0 1 0
0 0 0 0 0
0 0 0 0 0
0 0 0 0 0

and the position of the tail may look something like the following:

0 0 0 0 0
0 0 0 0 0
0 1 0 0 0
0 0 0 0 0
0 0 0 0 0

These two arrays can be stacked into a 1x5x5x2 array (channels-last convention; the leading 1 is the batch size), which is the input of our convolution layer. Suppose we want to match the pattern:

...
..H
T..

Then we design a 3x3 convolution kernel with 2 input channels and 1 output channel. The first part of the kernel looks like

0 0 0
0 0 1
0 0 0

which applies to the first channel (the head), and the second part of the kernel looks like

0 0 0
0 0 0
1 0 0

which applies to the second channel (the tail). The result (before adding any bias) looks like

0 0 0 0 0
0 0 2 0 0
0 0 0 0 0
0 0 0 0 0
0 0 0 0 0

which is the representation of the desired position of the tail in the next step. If an incompatible pattern is applied in a similar manner, say

...
T.H
...

The result would look like

0 0 0 0 0
0 0 1 0 0
0 0 1 0 0
0 0 0 0 0
0 0 0 0 0

That is, only the correct pattern outputs the value \(2\) at the desired position; all other positions have value 0 or 1. By adding a bias of \(-1\) (subtracting 1 from every output value) and applying the ReLU non-linearity, the value \(2\) becomes \(1\) and all other values become \(0\). In principle, we can construct a convolution kernel for each pattern: only one of them matches, and all the others output \(0\). Summing all the outputs then gives a one-hot representation of the new tail position. This sum can again be expressed as a 2D convolution with a 1x1 kernel: we call it the collect layer, in contrast to the move layers that match the different movement patterns.

However, considering all rotations and flips, there are 25 patterns. We can further simplify them. Consider the following 4 patterns:

.H.
...
T..

.H.
...
.T.

.H.
...
..T

.H.
.T.
...

Since there will be only one tail, we can combine the above 4 patterns into one pattern:

.H.
.T.
TTT

Therefore, we only need 9 patterns. The original implementation groups the 25 patterns into 9 in a slightly different way. Finally, to track which grid cells have been visited by the tail knot, we use an additional channel that keeps this information and is updated in a similar manner.
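The worked 5x5 example above can be reproduced in plain Python. In this sketch, conv3x3_same and one_hot are our own helpers standing in for a convolution with a single output channel, "same" zero padding, bias -1 and ReLU (it is not the article's model). The matching kernel yields a one-hot grid with the new tail position, the non-matching kernel yields all zeros, and summing the two outputs, as the collect layer does, preserves the one-hot grid.

```python
def conv3x3_same(channels, kernels, bias):
    """Cross-correlate C grids with C 3x3 kernels, sum, add bias, apply ReLU.

    Uses zero "same" padding and no kernel flip, like Keras Conv2D.
    """
    rows, cols = len(channels[0]), len(channels[0][0])
    out = [[0] * cols for _ in range(rows)]
    for r in range(rows):
        for c in range(cols):
            acc = bias
            for grid, kernel in zip(channels, kernels):
                for dr in range(3):
                    for dc in range(3):
                        rr, cc = r + dr - 1, c + dc - 1
                        if 0 <= rr < rows and 0 <= cc < cols:
                            acc += grid[rr][cc] * kernel[dr][dc]
            out[r][c] = max(acc, 0)  # ReLU
    return out


def one_hot(row, col, size=5):
    grid = [[0] * size for _ in range(size)]
    grid[row][col] = 1
    return grid


head = one_hot(1, 3)
tail = one_hot(2, 1)

# Pattern "... / ..H / T..": head kernel has a 1 at (1, 2), tail kernel at (2, 0).
match = [one_hot(1, 2, 3), one_hot(2, 0, 3)]
# Pattern "... / T.H / ...": head kernel has a 1 at (1, 2), tail kernel at (1, 0).
no_match = [one_hot(1, 2, 3), one_hot(1, 0, 3)]

moved = conv3x3_same([head, tail], match, bias=-1)
ignored = conv3x3_same([head, tail], no_match, bias=-1)

# The collect step: summing all pattern outputs keeps the one-hot grid.
collected = [[a + b for a, b in zip(r1, r2)] for r1, r2 in zip(moved, ignored)]
```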

The code can be illustrated as follows, where L is the length of the rope (the number of knots):

# The following layers are for non-head knots movement
for i in range(1, L):
    # there are a 'move layer' and a 'collect layer' for each knot.
    move, collect = model.layers[i * 2 - 1:i * 2 + 1]
    W, b = move.get_weights()  # assumes zero-initialized kernels and bias = -1 for every channel
    # First 1+L channels are unmodified.
    for t in range(1 + L):
        W[1, 1, t, t] = 2  # copy all, note that b=-1, so 1=2*1-1 unchanged.
    # The new position of knot j=i+1 depends on the current position of knot i, and knot j.
    j = i + 1  # knot j follows knot i
    # If knot i is adjacent to knot j (Chebyshev distance <= 1), knot j keeps its position.
    W[:, :, i, 1 + L] = W[1, 1, j, 1 + L] = 1
    # the following kernels will match patterns like
    # X X X
    # _ _ _
    # _ i _
    # where knot j is at one of the X positions and is expected to move to the center position.
    for n, k in enumerate([0, 2]):
        W[:, k, j, 1 + L + 1 + n] = W[1, 2 - k, i, 1 + L + 1 + n] = 1
        W[k, :, j, 1 + L + 3 + n] = W[2 - k, 1, i, 1 + L + 3 + n] = 1
    # the following kernels match the patterns like
    # j _ _
    # _ _ _
    # _ _ i
    # knot j is expected to move to the center position.
    for n, (y, x) in enumerate(zip([0, 0, 2, 2], [0, 2, 0, 2])):
        W[y, x, j, 1 + L + 5 + n] = W[2 - y, 2 - x, i, 1 + L + 5 + n] = 1
    move.set_weights([W, b])
    # The collect layer collects the results matched by the patterns above
    W, = collect.get_weights()
    # Copy the first 1+L channels, except channel j for knot j.
    for t in range(1 + L):
        W[..., t, t] = 1  # copy
    W[..., j, j] = 0
    # For channel j, sum up the last 9 channels: exactly one position has value 1, all the others are 0.
    W[..., 1 + L:, j] = 1  # collect moves
    collect.set_weights([W])
# For the last layer, also update the map of unvisited positions (channel 0): 1 means unvisited, 0 means visited.
# Subtracting the new tail position and applying ReLU clips the negative values to 0.
W[..., 1 + L:, 0] = -1  # collect unvisited
collect.set_weights([W])
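As a brute-force check of the channel layout above, the plain-Python sketch below (classify and step are our own names, not part of the original code) enumerates the 25 possible offsets (dy, dx) of knot i relative to knot j within Chebyshev distance 2 and groups them like the 9 move channels: channel 1 + L for "stay", channels 1 + L + 1 to 1 + L + 4 for the edge patterns, and channels 1 + L + 5 to 1 + L + 8 for the corner patterns. It then verifies that, within each moving group, knot i always lands in the same cell relative to the new position of knot j, which is why a single 3x3 kernel per group is enough.

```python
from collections import Counter


def sign(v):
    return (v > 0) - (v < 0)


def step(dy, dx):
    """Move of the follower when the leader is at offset (dy, dx)."""
    if max(abs(dy), abs(dx)) <= 1:
        return (0, 0)  # touching: stay
    return (sign(dy), sign(dx))


def classify(dy, dx):
    """Name of the pattern group covering offset (dy, dx)."""
    if max(abs(dy), abs(dx)) <= 1:
        return "stay"
    if abs(dy) == 2 and abs(dx) == 2:
        return f"corner {dy:+d},{dx:+d}"
    if abs(dx) == 2:
        return f"edge dx={dx:+d}"
    return f"edge dy={dy:+d}"


offsets = [(dy, dx) for dy in range(-2, 3) for dx in range(-2, 3)]
groups = Counter(classify(dy, dx) for dy, dx in offsets)
print(len(offsets), len(groups))  # 25 9

# Within every moving group, the leader ends up in one fixed cell
# relative to the follower's new position.
for name in groups:
    if name == "stay":
        continue
    landing = set()
    for dy, dx in offsets:
        if classify(dy, dx) == name:
            sy, sx = step(dy, dx)
            landing.add((dy - sy, dx - sx))
    assert len(landing) == 1, name
```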

The head knot is moved in one of the up, down, left, right directions, according to the input data. We use a 3x3 2d convolution similar to the above move layer for each direction and rotate the kernel dynamically according to the input data. We simulate the rope movement on an NxN grid using the following code:

# %% run the simulation
state = tf.zeros((1, N, N, 1 + L), dtype=tf.float32).numpy()
# Start with every knot at the center position
state[0, N // 2, N // 2, :] = 1
# Every position except the starting one is marked as unvisited.
state[..., 0] = 1 - state[..., 0]
for n, line in enumerate(open('input.txt').read().splitlines()):
    tf.print(n, line)
    direction, num = line.split(' ')
    # Rotate the kernel of the first layer according to the direction.
    angle = {'R': 0, 'U': 1, 'L': 2, 'D': 3}[direction]
    head_move.set_weights([tf.transpose(tf.image.rot90(head_WT, angle), (1, 2, 3, 0))])
    # Simulate the movement num times
    for i in range(int(num)):
        state = model(state)
# Count visited positions.
print('Ans:', int(tf.reduce_sum(1 - state[..., 0])))
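The direction trick relies on tf.image.rot90, which rotates counter-clockwise. The plain-Python sketch below (rot90_ccw, shift, position and the move_right kernel are our own assumptions, not the article's head_WT) takes a 3x3 kernel that moves a one-hot head one cell to the right under "same" cross-correlation, rotates it k times following the {'R': 0, 'U': 1, 'L': 2, 'D': 3} mapping, and checks that the head moves right, up, left and down respectively (row 0 is the top of the grid).

```python
def rot90_ccw(kernel):
    """Rotate a square kernel 90 degrees counter-clockwise (like np.rot90)."""
    return [list(row) for row in zip(*kernel)][::-1]


def shift(grid, kernel):
    """'Same' 3x3 cross-correlation of a single-channel grid (no bias)."""
    n = len(grid)
    out = [[0] * n for _ in range(n)]
    for r in range(n):
        for c in range(n):
            for dr in range(3):
                for dc in range(3):
                    rr, cc = r + dr - 1, c + dc - 1
                    if 0 <= rr < n and 0 <= cc < n:
                        out[r][c] += grid[rr][cc] * kernel[dr][dc]
    return out


def position(grid):
    """(row, col) of the single 1 in a one-hot grid."""
    return next((r, c) for r, row in enumerate(grid) for c, v in enumerate(row) if v == 1)


# Output cell (r, c) copies input cell (r, c - 1): the head moves one step right.
move_right = [[0, 0, 0], [1, 0, 0], [0, 0, 0]]
start = [[0] * 5 for _ in range(5)]
start[2][2] = 1

angle = {"R": 0, "U": 1, "L": 2, "D": 3}
new_positions = {}
for direction, k in angle.items():
    kernel = move_right
    for _ in range(k):
        kernel = rot90_ccw(kernel)
    new_positions[direction] = position(shift(start, kernel))

print(new_positions)  # {'R': (2, 3), 'U': (1, 2), 'L': (2, 1), 'D': (3, 2)}
```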

Rotating the kernel feels a bit like cheating. The network could be modified to take an additional conditional input encoding the direction. Furthermore, the whole network could be turned into an RNN cell, and the simulation loop replaced by RNN inference on the sequence of directions. We keep the code in its current form because it explains the process of transforming a local pattern-matching problem into a convolutional neural network. The above-mentioned modifications, though fairly standard, would still be very interesting, as they would be applied to a non-standard problem.

Conclusion

You can see both solutions in folder 9 in the dedicated GitHub repository (in the 2022 folder): https://github.com/galeone/tf-aoc.

This article demonstrated how to solve this coding challenge in pure TensorFlow in two completely different ways: the first one models it as a standard programming problem solved using only TensorFlow primitives, while the second one models the problem completely differently and exploits the properties of convolutions to solve it in a very cool way!
