
Building an AI That Masters Snake — A Deep Reinforcement Learning Project from Scratch

A neural network, a shaped reward signal, and a lot of virtual trial and error. How DeepSnake learned to average 44 points, peak at 75, and never read a rulebook.

Imagine dropping a toddler into the middle of a board game with no rulebook. No teacher. No examples. Just the raw experience of "that worked" and "that didn't." Now imagine that toddler learns to play better than most humans — in a matter of minutes. That's essentially what Deep Reinforcement Learning does, and it's one of the most fascinating corners of modern AI.

The challenge sounds deceptively simple: teach an AI to play Snake, the classic grid-based game where a serpent chases food and grows longer while trying not to crash into itself or the walls. Simple rules, sure — but the decision space explodes as the snake grows. Every single frame demands a choice: go straight, turn left, or turn right. One bad move, and it's game over.

DeepSnake is an open-source project that tackles exactly this problem, and the results speak for themselves: an agent that consistently scores 30+ points and averages nearly 44 points over 200 games, with a personal best of 75. No hand-coded strategy. No search algorithms. Just a neural network, a reward signal, and a lot of virtual trial and error.

A 20×20 grid world. The agent picks one of three relative actions every frame: straight, turn right, or turn left.

Let's break down how it works.

What Is DeepSnake, and Why Reinforcement Learning?

DeepSnake is a Python project that trains an AI agent to play Snake using a technique called Deep Q-Learning (DQN). The entire project is built on PyTorch for the neural network and Pygame for visualization, making it both a practical ML exercise and a satisfying thing to watch.

So why Reinforcement Learning instead of, say, supervised learning or a scripted pathfinding algorithm?

The answer comes down to the nature of the problem. Supervised learning requires labeled training data — thousands of examples of "given this board state, the correct move is X." For a game as dynamic as Snake, generating that dataset would be impractical and brittle. A hard-coded algorithm like A* pathfinding could work for small boards, but it doesn't learn or adapt, and it breaks down as the snake's body creates increasingly complex obstacles.

Reinforcement Learning is the natural fit because the Snake problem has exactly the structure RL was designed for: an agent (the snake) takes actions (turn left, turn right, go straight) in an environment (the game grid), and receives rewards (positive for eating food, negative for dying). The agent's only job is to figure out which actions maximize its cumulative reward over time. No labels required — just experience.

The Environment: A 20×20 Grid World

The Snake game environment is a 20×20 grid. The snake starts at the center of the board with a length of 3, moving to the right. Food spawns at a random unoccupied cell. The game ends when the snake collides with a wall or its own body.

What makes the environment design clever is how actions are defined. Instead of absolute directions (up, down, left, right), the agent chooses from relative actions: go straight, turn right, or turn left. This simplifies the decision space from four options to three and removes the problem of the agent choosing to reverse directly into itself — a bug you'd otherwise have to code around explicitly.
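
To make the relative-action idea concrete, here's a minimal sketch of how a relative action can be resolved into an absolute direction. The function and direction names are illustrative, not the project's actual API:

```python
# Clockwise ordering of absolute directions. "Turn right" moves one step
# clockwise through this list; "turn left" moves one step counter-clockwise.
DIRECTIONS = ["up", "right", "down", "left"]

def apply_action(current_direction: str, action: int) -> str:
    """action: 0 = straight, 1 = turn right, 2 = turn left."""
    idx = DIRECTIONS.index(current_direction)
    if action == 1:          # turn right: next clockwise direction
        idx = (idx + 1) % 4
    elif action == 2:        # turn left: previous clockwise direction
        idx = (idx - 1) % 4
    return DIRECTIONS[idx]
```

Notice that a 180-degree reversal simply isn't expressible: from "right" the agent can reach "up", "down", or stay on "right", but never "left". The illegal move is ruled out by construction rather than by a special-case check.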

The Agent's Brain: A Deep Q-Network

At the heart of DeepSnake is a Deep Q-Network (DQN) — a neural network that learns to estimate the expected future reward (called the Q-value) for each possible action given the current state of the game.
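
Concretely, "expected future reward" means the network is trained toward the standard Bellman target, r + γ · max Q(s′, a′), with the future term zeroed out at terminal states. A minimal sketch of that target computation for a batch of transitions (function name and batch layout are my own, not necessarily the project's):

```python
import torch

def td_targets(rewards, next_q_values, dones, gamma=0.99):
    """Bellman targets for a batch: r + gamma * max_a' Q(s', a').

    rewards:       (batch,) tensor of immediate rewards
    next_q_values: (batch, n_actions) Q-values for the next states
    dones:         (batch,) float tensor, 1.0 where the episode ended
    """
    max_next_q = next_q_values.max(dim=1).values
    # No future reward beyond a terminal state, so mask it out.
    return rewards + gamma * max_next_q * (1.0 - dones)
```

During training, the network's predicted Q-value for the action actually taken is regressed toward this target.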

The network architecture is compact and efficient:

# snake_dqn/model.py (simplified)
import torch.nn as nn
import torch.nn.functional as F

class DQN(nn.Module):
    def __init__(self, state_size=24, action_size=3):
        super().__init__()
        self.fc1 = nn.Linear(state_size, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, action_size)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)
A compact feedforward network. Input features go in, a Q-value comes out for each of the three actions, and the agent picks the highest.

Those 24 input features are where much of the design work lives. Rather than feeding the entire raw grid to the network (which would be computationally expensive and slow to learn), the state is encoded as a compact set of hand-crafted features summarizing what the agent needs to know about its immediate surroundings and the food's position.

This compact representation gives the network everything it needs to make good decisions without drowning it in irrelevant data. The three outputs correspond to Q-values for each action (straight, right, left). During gameplay, the agent picks the action with the highest Q-value.
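
Greedy action selection amounts to a forward pass and an argmax. A short sketch, assuming any PyTorch module that maps a 24-feature state vector to three Q-values (the function name is illustrative):

```python
import torch
import torch.nn as nn

def select_action(model: nn.Module, state) -> int:
    """Return the index of the highest-Q action: 0=straight, 1=right, 2=left."""
    state_t = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)
    with torch.no_grad():          # inference only, no gradients needed
        q_values = model(state_t)  # shape: (1, 3)
    return int(q_values.argmax(dim=1).item())
```

During training this greedy choice is mixed with random exploration; at evaluation time the argmax alone drives the snake.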

The Reward System: How the AI Learns Right from Wrong

The reward signal is the only feedback the agent gets, so its design is critical. DeepSnake uses a shaped reward system:

| Event | Reward |
| --- | --- |
| Eats food | +10 |
| Dies (wall or self-collision) | −10 |
| Moves closer to food | +1 |
| Moves farther from food | −1 |

The +10 / −10 for eating and dying are straightforward. The interesting addition is the +1 / −1 shaping reward for moving closer to or farther from the food. Without this, the agent would receive very sparse feedback — it might wander aimlessly for hundreds of steps before stumbling onto food or dying, making learning painfully slow. The proximity reward gives the agent a continuous gradient to follow, dramatically accelerating training.
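
Put together, the per-step reward logic might look like the following sketch. I'm assuming Manhattan distance from head to food, which matches grid movement but is a guess about the project's exact metric, and the function name is illustrative:

```python
def step_reward(ate_food: bool, died: bool,
                old_dist: int, new_dist: int) -> int:
    """Shaped reward for one step, per the table above.

    old_dist / new_dist: Manhattan distance from head to food
    before and after the move.
    """
    if died:
        return -10
    if ate_food:
        return +10
    # Dense shaping signal: nudge the agent toward the food every step.
    return +1 if new_dist < old_dist else -1
```

The ordering matters: death and eating are checked first, so the ±1 shaping only applies to "ordinary" moves.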

DQN Improvements: Beyond Vanilla Q-Learning

DeepSnake doesn't just implement a basic DQN. It layers in several important improvements that stabilize and accelerate learning:

  1. Experience replay — transitions are stored in a buffer and sampled randomly at training time, breaking the correlation between consecutive frames.
  2. Target network — a slowly updated copy of the Q-network provides stable targets for the Bellman update instead of chasing a moving target.
  3. Epsilon-greedy exploration with decay — the agent starts out acting mostly at random and gradually shifts toward exploiting what it has learned.
  4. Reward shaping — the proximity reward described above keeps the feedback signal dense.

[Training curve: score per episode over 2,000 episodes, peaking at 75 with a rolling average settling around 44.]
A rough sketch of a training run. Score climbs fast through the first few hundred episodes as epsilon decays, then the rolling average settles in the mid-forties.

These are all well-established techniques from the RL literature, and seeing them applied together in a single, readable codebase is one of the most valuable aspects of this project for anyone learning the field.
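
Two of those staples, experience replay and epsilon decay, can be sketched in a few lines. Class names, defaults, and the decay schedule below are illustrative, not the project's actual code:

```python
import random
from collections import deque

class ReplayBuffer:
    """Uniform experience replay: store transitions, sample random
    minibatches so consecutive (correlated) frames aren't trained on
    back to back."""

    def __init__(self, capacity: int = 100_000):
        # deque with maxlen silently evicts the oldest transitions.
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size: int):
        return random.sample(self.buffer, batch_size)

def decayed_epsilon(episode: int, eps_start: float = 1.0,
                    eps_end: float = 0.01, decay: float = 0.995) -> float:
    """Multiplicative decay: explore heavily early, exploit later,
    never dropping below a small floor."""
    return max(eps_end, eps_start * decay ** episode)
```

The floor on epsilon matters: keeping a sliver of randomness forever prevents the agent from locking into a policy it can no longer escape.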

The Tech Stack

DeepSnake is built entirely in Python with a focused set of dependencies:

  1. PyTorch — the neural network and training loop
  2. NumPy — array math for states and batches
  3. Pygame — rendering and the game loop
  4. Matplotlib — plotting the training curves

No bloated frameworks or unnecessary dependencies. Everything you need fits in a single pip install command.

How to Get Started

Getting DeepSnake running on your machine takes about two minutes.

Step 1 — Clone the repository:

git clone https://github.com/josephgec/deepsnake.git
cd deepsnake

Step 2 — Install dependencies:

pip install torch numpy pygame matplotlib

Step 3 — Watch the pre-trained agent play:

cd snake_dqn
python play.py

Use SPACE to pause/unpause, R to reset, and Q to quit.

Step 4 (optional) — Train from scratch:

python train.py
# 2,000 episodes with automatic checkpointing
python plot_training.py  # render the score curve

The project structure is clean and modular — snake_env.py handles the game logic, model.py defines the neural network, agent.py manages the DQN training logic, and train.py ties everything together. Each file has a clear, single responsibility, making it easy to read, modify, and experiment with.

Why This Project Is Worth Your Time

DeepSnake sits in a sweet spot that's rare in ML education projects. It's complex enough to demonstrate real RL concepts — experience replay, target networks, epsilon decay, reward shaping — but simple enough that you can read the entire codebase in an afternoon and understand every line. The Snake game provides an intuitive, visual feedback loop that makes the learning process tangible in a way that abstract environments don't.

If you're an intermediate developer looking to move from "I've read about Reinforcement Learning" to "I've actually built something with it," this is a solid starting point.

What to Try Next

A few experiments to deepen your understanding once you've got the project running:

  1. Tweak the reward system. What happens if you remove the proximity reward (+1 / −1) and only keep the +10 / −10? How much slower does the agent learn?
  2. Change the network architecture. Add a layer, double the hidden units, or try a different activation function. Does performance improve, or does the extra capacity just memorize noise?
  3. Modify the grid size. A 10×10 grid is a very different challenge than a 30×30 grid. How does the agent adapt?
  4. Swap relative actions for absolute actions. Does the agent still learn effectively with four action choices instead of three, or does the self-reversal problem come back?

Get Involved

If you find DeepSnake useful or interesting, head over to the GitHub repository, give it a star, and start experimenting. Fork it, break it, improve it — that's how the best learning happens.

Deep Reinforcement Learning is one of the most exciting subfields in AI, and there's no better way to understand it than by watching a neural network teach itself to chase pixels around a grid. Happy training.
