Advent of Code 2022 in pure TensorFlow - Day 10


Solving problem 10 of the AoC 2022 in pure TensorFlow is an interesting challenge. This problem involves simulating a simple CPU, driven by a clock circuit, that executes a short program and updates a single register X, and then using the value of X to draw on a small CRT screen. TensorFlow’s tf.data API, AutoGraph’s conversion of Python control flow into graph operations (applied to functions decorated with @tf.function and to the functions passed to Dataset.map), and stateful tf.Variable objects make it a reasonable fit for this problem. By using Dataset transformations, filtering, and tensor operations, we can write a clean and efficient solution to this puzzle.

Day 10: Cathode-Ray Tube

You can click on the title above to read the full text of the puzzle. The TLDR version is: the puzzle input is a program made of two instructions. noop takes one cycle and does nothing, while addx V takes two cycles, after which the register X (initially 1) is increased by V. The signal strength is the cycle number multiplied by the value of X during that cycle. Part one asks for the sum of the signal strengths during cycles 20, 60, 100, 140, 180, and 220. Part two asks us to render the image drawn by a 40x6 CRT, where the pixel drawn in a cycle is lit if the 3-pixel-wide sprite centered on X overlaps it.
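To make the cycle semantics concrete, here is a plain-Python reference sketch (no TensorFlow, and not part of the original solution) that records the value of X during each cycle. The helper names x_during_cycles and signal_strength_sum are hypothetical. The three-instruction program is the tiny example from the puzzle text, where X is 1 during cycles 1 to 3, 4 during cycles 4 and 5, and -1 once the program ends.

```python
def x_during_cycles(program):
    """Plain-Python reference: value of X during each cycle, plus the final X."""
    x = 1
    values = []
    for line in program:
        op, *arg = line.split()
        if op == "noop":
            values.append(x)  # one cycle, X unchanged
        else:  # "addx V": two cycles, X changes only after the second one
            values.extend([x, x])
            x += int(arg[0])
    return values, x


def signal_strength_sum(values, probes=range(20, 221, 40)):
    # values[c - 1] is X during cycle c (cycles are 1-indexed)
    return sum(c * values[c - 1] for c in probes)


values, final_x = x_during_cycles(["noop", "addx 3", "addx -5"])
print(values, final_x)  # [1, 1, 1, 4, 4] -1
```

A reference like this can be handy for cross-checking the values produced by the TensorFlow pipeline.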

Parsing the input

First, let’s use tf.data.TextLineDataset to read the input file line by line:

dataset = tf.data.TextLineDataset(input_path.as_posix())

Now, split each line into a list of strings (the operation and, for addx, its value):

dataset = dataset.map(lambda line: tf.strings.split(line, " "))

Then, we define a function opval that converts each list of strings into an (op, val) tuple, parsing the value as an int32. If the operation is "noop", the value is set to 0.

@tf.function
def opval(pair):
    if tf.equal(tf.shape(pair)[0], 1):
        return pair[0], tf.constant(0, tf.int32)

    return pair[0], tf.strings.to_number(pair[1], tf.int32)

dataset = dataset.map(opval)
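The two map calls can be checked on a small file. The sketch below assumes a hypothetical input_path pointing to a temporary file that contains the puzzle’s tiny example program; everything else mirrors the code above.

```python
import tempfile
from pathlib import Path

import tensorflow as tf

# Hypothetical input file holding the puzzle's tiny example program
input_path = Path(tempfile.mkdtemp()) / "input"
input_path.write_text("noop\naddx 3\naddx -5")


@tf.function
def opval(pair):
    # "noop" has no argument: pair it with 0
    if tf.equal(tf.shape(pair)[0], 1):
        return pair[0], tf.constant(0, tf.int32)
    # "addx V": parse V as a 32-bit integer
    return pair[0], tf.strings.to_number(pair[1], tf.int32)


dataset = tf.data.TextLineDataset(input_path.as_posix())
dataset = dataset.map(lambda line: tf.strings.split(line, " ")).map(opval)

pairs = [(op.decode(), int(val)) for op, val in dataset.as_numpy_iterator()]
print(pairs)  # [('noop', 0), ('addx', 3), ('addx', -5)]
```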

As usual, when working with a tf.data.Dataset, the functions passed to transformations such as map are traced and executed as graphs, regardless of eager mode. That’s why we explicitly added the @tf.function decorator on top of the opval function: it is not required, but it helps us remember that we need to think in graph mode.

We’ll use a lookup table (tf.lookup.StaticHashTable) to map instruction strings to integer values. This allows us to work with numerical values, which is more convenient for processing in TensorFlow.

lut = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(
        tf.constant(["noop", "addx"]), tf.constant([0, 1])
    ),
    default_value=-1,
)
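A quick check of the table, as a sketch: known instructions map to their ids, while any other key falls back to default_value. The "jmp" key is made up and only illustrates the fallback.

```python
import tensorflow as tf

lut = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(
        tf.constant(["noop", "addx"]), tf.constant([0, 1])
    ),
    default_value=-1,
)
# "jmp" is not a puzzle instruction: it only shows the default_value fallback
ids = lut.lookup(tf.constant(["addx", "noop", "jmp"]))
print(ids.numpy().tolist())  # [1, 0, -1]
```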

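Next, we need to process the dataset so that each element represents a single clock cycle. Since addx takes two cycles while noop takes one, we define a prepend_noop function that inserts a "noop" instruction before each "addx": the noop models the first cycle of addx, and the addx element models the second cycle, at the end of which X is updated. Because both branches of the function must return tensors with the same shape, a "noop" input is paired with an invalid placeholder that we filter out right after unbatching.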

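# [op_id, val] pairs: one idle cycle, and a placeholder to be filtered out
noop_id = lut.lookup(tf.constant("noop"))
noop = tf.stack([noop_id, 0])
invalid = tf.constant([-1, -1], tf.int32)

@tf.function
def prepend_noop(op, val):
    if tf.equal(op, "noop"):
        return tf.stack([noop, invalid], axis=0)

    return tf.stack(
        [
            noop,
            tf.stack((lut.lookup(op), val), axis=0),
        ],
        axis=0,
    )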

dataset = (
    dataset.map(prepend_noop)
    .unbatch()
    .filter(lambda op_val: tf.not_equal(op_val[0], -1))  # remove invalid
    .map(lambda op_val: (op_val[0], op_val[1]))
)
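Putting the parsing steps together on the puzzle’s tiny example shows the effect of prepend_noop: three instructions become five per-cycle elements, one per clock cycle. This is a self-contained sketch that reads the program from memory with Dataset.from_tensor_slices instead of the input file; op id 0 is noop and 1 is addx.

```python
import tensorflow as tf

lut = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(
        tf.constant(["noop", "addx"]), tf.constant([0, 1])
    ),
    default_value=-1,
)
noop_id = lut.lookup(tf.constant("noop"))
noop = tf.stack([noop_id, 0])              # [op_id, val] of an idle cycle
invalid = tf.constant([-1, -1], tf.int32)  # placeholder, removed by filter


@tf.function
def opval(pair):
    if tf.equal(tf.shape(pair)[0], 1):
        return pair[0], tf.constant(0, tf.int32)
    return pair[0], tf.strings.to_number(pair[1], tf.int32)


@tf.function
def prepend_noop(op, val):
    # Both branches must return a [2, 2] int32 tensor
    if tf.equal(op, "noop"):
        return tf.stack([noop, invalid], axis=0)
    return tf.stack([noop, tf.stack([lut.lookup(op), val], axis=0)], axis=0)


# The puzzle's tiny example program, used instead of the input file
lines = tf.data.Dataset.from_tensor_slices(["noop", "addx 3", "addx -5"])
cycles = (
    lines.map(lambda line: tf.strings.split(line, " "))
    .map(opval)
    .map(prepend_noop)
    .unbatch()
    .filter(lambda op_val: tf.not_equal(op_val[0], -1))
    .map(lambda op_val: (op_val[0], op_val[1]))
)
per_cycle = [(int(op), int(val)) for op, val in cycles.as_numpy_iterator()]
print(per_cycle)  # [(0, 0), (0, 0), (1, 3), (0, 0), (1, -5)]
```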

Now that we have the dataset correctly formatted, we can proceed with simulating the clock signal’s behavior.

Simulating the clock signal

We’ll use tf.Variable objects to keep track of the current cycle and of the value of the register X. Initialize these variables as follows:

cycle = tf.Variable(0, dtype=tf.int32)
X = tf.Variable(1, dtype=tf.int32)

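To simulate the CPU, we define a clock function that processes one per-cycle element of the dataset at a time. It updates the cycle counter and X, and it stores in prev_x the value X had during the cycle: addx only changes X at the end of its second cycle, and the signal strength must be computed with the value during the cycle. At the probed cycles it returns the signal strength prev_x * cycle; everywhere else it returns -1 as a marker.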

prev_x = tf.Variable(X)

def clock(op, val):
    prev_x.assign(X)
    if tf.equal(op, noop_id):
        pass
    else:  # addx
        X.assign_add(val)

    cycle.assign_add(1)

    if tf.reduce_any([tf.equal(cycle, value) for value in range(20, 221, 40)]):
        return [cycle, prev_x, prev_x * cycle]
    return [cycle, prev_x, -1]
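The sketch below exercises a variant of clock on per-cycle (op, val) pairs fed directly with from_tensor_slices. It is not the article’s exact code: it reads the variables explicitly with read_value before returning them, and it replaces the list of equality checks with an equivalent modulo test. The run helper, which resets the state between simulations, is hypothetical.

```python
import tensorflow as tf

noop_id = tf.constant(0, tf.int32)  # id the lookup table assigns to "noop"
cycle = tf.Variable(0, dtype=tf.int32)
X = tf.Variable(1, dtype=tf.int32)
prev_x = tf.Variable(1, dtype=tf.int32)


@tf.function
def clock(op, val):
    prev_x.assign(X)  # value of X *during* this cycle
    if tf.not_equal(op, noop_id):  # 2nd cycle of addx: X changes at its end
        X.assign_add(val)
    cycle.assign_add(1)
    c = cycle.read_value()
    x = prev_x.read_value()
    # cycles 20, 60, ..., 220 are the probed ones; -1 marks every other cycle
    probed = tf.logical_and(tf.equal(tf.math.mod(c - 20, 40), 0), c <= 220)
    return c, x, tf.where(probed, c * x, tf.constant(-1, tf.int32))


def run(ops, vals):
    """Hypothetical helper: reset the state, then map clock over the cycles."""
    cycle.assign(0)
    X.assign(1)
    per_cycle = tf.data.Dataset.from_tensor_slices(
        (tf.constant(ops, tf.int32), tf.constant(vals, tf.int32))
    )
    return [
        tuple(int(t) for t in element)
        for element in per_cycle.map(clock).as_numpy_iterator()
    ]


# Per-cycle stream of the tiny example: noop, (noop, addx 3), (noop, addx -5)
print(run([0, 0, 1, 0, 1], [0, 0, 3, 0, -5]))
# [(1, 1, -1), (2, 1, -1), (3, 1, -1), (4, 4, -1), (5, 4, -1)]
```

The output matches the puzzle statement: X is 1 during the first three cycles and 4 during the next two, and after the last addx the variable X holds -1.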

Next, we’ll create a dataset of the signal strengths at the cycles requested by the problem (20, 60, 100, 140, 180, and 220). We do this by mapping the clock function over the dataset and filtering out the elements whose signal strength is -1.

strengths_dataset = dataset.map(clock).filter(
    lambda c, x, strength: tf.not_equal(strength, -1)
)

strengths = tf.convert_to_tensor(list(strengths_dataset.as_numpy_iterator()))

Now, we can calculate the sum of the six signal strengths:

sumsix = tf.reduce_sum(strengths[:, -1])
tf.print("Sum of six signal strengths: ", sumsix)
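As an alternative sketch, the sum can be computed without materializing the dataset into a Python list by using Dataset.reduce, which folds the elements into a state tensor. The strengths_dataset below is a hypothetical stand-in built from the six values listed in the puzzle’s larger example (420, 1140, 1800, 2940, 2880, and 3960).

```python
import tensorflow as tf

# Stand-in for strengths_dataset: (cycle, X, strength) triples
strengths_dataset = tf.data.Dataset.from_tensor_slices(
    (
        [20, 60, 100, 140, 180, 220],
        [21, 19, 18, 21, 16, 18],
        [420, 1140, 1800, 2940, 2880, 3960],
    )
)
# reduce_func receives (state, element); element is the (cycle, X, strength) tuple
sumsix = strengths_dataset.reduce(
    tf.constant(0, tf.int32), lambda acc, elem: acc + elem[2]
)
tf.print("Sum of six signal strengths: ", sumsix)  # 13140
```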

In the provided solution, we used the @tf.function decorator on the opval and prepend_noop functions. In general, this decorator converts a Python function into a TensorFlow graph: TensorFlow traces the function once per input signature and then reuses the graph, which lets it optimize the computation and avoids Python overhead on repeated calls.

Inside Dataset.map, however, the mapped function is traced into a graph whether or not it is decorated, so here the decorator does not change performance; its main value is reminding us that this code runs in graph mode.
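A small experiment makes the tracing behavior visible: a Python side effect inside a traced function runs only while tracing, not on every call. The functions double and plus_one are hypothetical examples.

```python
import tensorflow as tf

traces = []


@tf.function
def double(x):
    traces.append("traced")  # Python side effect: runs only while tracing
    return x * 2


for i in range(5):
    double(tf.constant(i))
print(len(traces))  # 1: same input signature, so the graph is reused

# Functions passed to Dataset.map are traced too, even without @tf.function
map_traces = []


def plus_one(x):
    map_traces.append("traced")
    return x + 1


ds = tf.data.Dataset.range(100).map(plus_one)
print([int(v) for v in ds.as_numpy_iterator()][:3])  # [1, 2, 3]
print(len(map_traces))  # far fewer than 100: not called once per element
```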

Part one solved!

Visualizing the clock signal

The second part of the puzzle asks us to render the image drawn by the CRT: 6 rows of 40 pixels, drawn one pixel per cycle, left to right and top to bottom. To do so, we’ll create a tf.Variable that stores the CRT image as a 2D grid of strings (with an extra unary dimension). Initialize this variable as follows:

crt = tf.Variable(tf.zeros((6, 40, 1), tf.string))
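A single cell of the grid is written with tf.tensor_scatter_nd_update: the indices tensor holds one (row, column) pair, and the updates tensor has shape [1, 1], that is, the number of indices followed by the trailing dimension of the grid. A sketch on a hypothetical 2x3 grid, which also shows that tf.zeros with dtype tf.string fills the grid with empty strings:

```python
import tensorflow as tf

# Small (rows, cols, 1) grid of empty strings, same layout as the 6x40 crt
grid = tf.Variable(tf.zeros((2, 3, 1), tf.string))
# indices: one (row, col) pair; updates: shape [1] + grid.shape[2:] = [1, 1]
grid.assign(tf.tensor_scatter_nd_update(grid, [[1, 2]], [["#"]]))
cells = tf.squeeze(grid, axis=-1).numpy().tolist()
print(cells)  # [[b'', b'', b''], [b'', b'', b'#']]
```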

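Next, we’ll define a clock2 function that updates the grid at each cycle. The pixel drawn during a cycle is in column cycle mod 40 (with cycle counted from 0), and it is lit ("#") if the sprite, which is 3 pixels wide and centered on the value X had during the cycle, covers that column; otherwise it is dark (".").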

# part one left cycle and X at their final values: reset them
cycle.assign(0)
X.assign(1)
row = tf.Variable(0, dtype=tf.int32)

def clock2(op, val):
    prev_x.assign(X)
    if tf.equal(op, noop_id):
        pass
    else:  # addx
        X.assign_add(val)

    modcycle = tf.math.mod(cycle, 40)
    if tf.reduce_any(
        [
            tf.equal(modcycle, prev_x),
            tf.equal(modcycle, prev_x - 1),
            tf.equal(modcycle, prev_x + 1),
        ]
    ):
        crt.assign(
            tf.tensor_scatter_nd_update(crt, [[row, modcycle]], [["#"]])
        )
    else:
        crt.assign(
            tf.tensor_scatter_nd_update(crt, [[row, modcycle]], [["."]])
        )

    cycle.assign_add(1)

    if tf.equal(tf.math.mod(cycle, 40), 0):
        row.assign_add(1)
    return ""
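The same rule can also be expressed without per-cycle state, as a vectorized sketch: given the value of X during every cycle (for instance, the prev_x values returned by the clock function), the column drawn in each cycle is the 0-based cycle index modulo 40, and a pixel is lit when it differs from X by at most one. The render helper is hypothetical and is not the approach used in the article.

```python
import tensorflow as tf


def render(x_during, width=40):
    """Hypothetical helper: CRT rows from the per-cycle values of X."""
    x_during = tf.convert_to_tensor(x_during, tf.int32)
    # column drawn in each cycle: 0-based cycle index modulo the row width
    pixel = tf.math.mod(tf.range(tf.size(x_during)), width)
    # lit if the 3-pixel sprite centered on X covers the column
    lit = tf.abs(pixel - x_during) <= 1
    chars = tf.where(lit, tf.constant("#"), tf.constant("."))
    # one string per row of `width` pixels
    return tf.strings.reduce_join(tf.reshape(chars, (-1, width)), axis=1)


rows = [r.decode() for r in render([1] * 40 + [20] * 40).numpy()]
for r in rows:
    print(r)
```

With X fixed at 1 the first row lights columns 0 to 2, and with X fixed at 20 the second row lights columns 19 to 21.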

Finally, we map the clock2 function over the dataset and consume the result (the list call only forces the iteration, and thus the side effects on crt), then we print the resulting image:

list(dataset.map(clock2).as_numpy_iterator())

tf.print(tf.squeeze(crt), summarize=-1)

Squeezing the unary dimension is necessary to display the 2D grid without square brackets around every character. The summarize=-1 parameter of tf.print disables the default behavior of printing only the first and last elements of each dimension with ... in between. In this way, we can read the letters directly in the terminal.
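If the quotes around each character still get in the way, one option (a sketch, not the article’s code) is to join each row of the squeezed grid into a single string with tf.strings.reduce_join. Here a hypothetical 2x3 grid stands in for crt:

```python
import tensorflow as tf

# Small stand-in for the (6, 40, 1) crt variable
crt = tf.Variable(tf.constant([[["#"], ["."], ["#"]], [["."], ["#"], ["."]]]))
grid = tf.squeeze(crt, axis=-1)  # (2, 3, 1) -> (2, 3)
tf.print(grid, summarize=-1)

rows = tf.strings.reduce_join(grid, axis=1)  # one string per row
for row in rows.numpy():
    print(row.decode())
```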

Part 2 solved!

Conclusion

You can find the solution in folder 10 of the dedicated GitHub repository (in the 2022 folder): https://github.com/galeone/tf-aoc.

This TensorFlow-based solution demonstrates the flexibility of the TensorFlow library, allowing us to solve the AoC 2022 problem 10. We used various TensorFlow operations, data structures, and functions to parse the input, simulate the CPU cycle by cycle, and render the CRT image. This approach shows how TensorFlow can be used beyond its primary use case of deep learning, to solve a wide range of computational problems.

By employing TensorFlow’s built-in operations and data structures, we were able to process the input data, handle branching logic, and maintain state throughout the simulation. The sum of the six signal strengths is the answer to part one, and the rendered CRT image spells out the letters that answer part two.

As the AoC 2022 puzzles continue to challenge participants with new and diverse problems, this solution demonstrates that TensorFlow can be a powerful tool in the problem solver’s toolkit. It serves as an example of how TensorFlow’s flexibility extends beyond deep learning applications and can be effectively used to tackle complex problems in a variety of domains.

If you missed the article about the previous days’ solutions, here’s a handy list
