Building an AI That Masters Snake — A Deep Reinforcement Learning Project from Scratch
A neural network, a shaped reward signal, and a lot of virtual trial and error. How DeepSnake learned to average 44 points, peak at 75, and never read a rulebook.
Imagine dropping a toddler into the middle of a board game with no rulebook. No teacher. No examples. Just the raw experience of "that worked" and "that didn't." Now imagine that toddler learns to play better than most humans — in a matter of minutes. That's essentially what Deep Reinforcement Learning does, and it's one of the most fascinating corners of modern AI.
The challenge sounds deceptively simple: teach an AI to play Snake, the classic grid-based game where a serpent chases food and grows longer while trying not to crash into itself or the walls. Simple rules, sure — but the decision space explodes as the snake grows. Every single frame demands a choice: go straight, turn left, or turn right. One bad move, and it's game over.
DeepSnake is an open-source project that tackles exactly this problem, and the results speak for themselves: an agent that consistently scores 30+ points and averages nearly 44 points over 200 games, with a personal best of 75. No hand-coded strategy. No search algorithms. Just a neural network, a reward signal, and a lot of virtual trial and error.
Before deep RL, classical reinforcement learning kept its values in a table, with one row per state and one column per action, and updated cells with the Bellman equation. That works for tic-tac-toe. It falls apart the moment the state is anything richer, because the table grows combinatorially and the agent has no way to generalize between similar situations. The 2013 DeepMind paper Playing Atari with Deep Reinforcement Learning replaced the table with a neural network and trained it on raw pixels: same update rule, learned weights instead of cell assignments. That substitution let RL handle states that can be encoded as a vector or an image, far beyond what any table could enumerate. DeepSnake is a small, readable instance of that recipe.
Let's break down how it works.
What Is DeepSnake, and Why Reinforcement Learning?
DeepSnake is a Python project that trains an AI agent to play Snake using Deep Q-Learning, built around a Deep Q-Network (DQN). The entire project is built on PyTorch for the neural network and Pygame for visualization, making it both a practical ML exercise and a satisfying thing to watch.
So why Reinforcement Learning instead of, say, supervised learning or a scripted pathfinding algorithm?
The answer comes down to the nature of the problem. Supervised learning requires labeled training data — thousands of examples of "given this board state, the correct move is X." For a game as dynamic as Snake, generating that dataset would be impractical and brittle. A hard-coded algorithm like A* pathfinding could work for small boards, but it doesn't learn or adapt, and it breaks down as the snake's body creates increasingly complex obstacles.
Reinforcement Learning is the natural fit because the Snake problem has exactly the structure RL was designed for: an agent (the snake) takes actions (turn left, turn right, go straight) in an environment (the game grid), and receives rewards (positive for eating food, negative for dying). The agent's only job is to figure out which actions maximize its cumulative reward over time. No labels required — just experience.
A few terms the rest of this article relies on:
- State (s): what the agent sees right now. For DeepSnake, a 24-number summary of the board.
- Action (a): what the agent does. Here: straight, turn right, or turn left.
- Reward (r): a single number the environment hands back after each step. Positive = good, negative = bad.
- Q-value Q(s, a): the expected sum of future rewards if the agent takes action a from state s and plays well afterward, with each step further into the future discounted by a factor γ (gamma, between 0 and 1) so that near-term rewards count more. Bigger Q = better move.
- DQN: a neural network that approximates Q(s, a). Input: state vector. Output: one Q-value per action.
- ε-greedy: pick the highest-Q action most of the time, but with probability ε pick a random action. The random fraction is how the agent discovers things it hasn't tried.
- Replay buffer: a rolling log of past (state, action, reward, next state, done) transitions, where done marks whether the episode ended. Training samples random minibatches from this log so the network doesn't overfit to whatever just happened.
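A replay buffer can be little more than a bounded deque. The sketch below is plain Python and purely illustrative. The `ReplayBuffer` and `Transition` names, the `push` method, and the default capacity are assumptions, not DeepSnake's own code, although `sample(batch_size)` returns the same five fields that the `learn()` method unpacks later in this article. A real implementation would also convert each field to a tensor.

```python
import random
from collections import deque, namedtuple

# One transition: what the agent saw, did, got back, saw next, and whether the game ended.
Transition = namedtuple("Transition", ["state", "action", "reward", "next_state", "done"])


class ReplayBuffer:
    """Fixed-size FIFO log of transitions; the oldest entries fall off automatically."""

    def __init__(self, capacity=100_000):  # capacity is an illustrative default
        self.memory = deque(maxlen=capacity)

    def push(self, *args):
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        # Uniform random minibatch drawn without replacement; this is what
        # breaks the correlation between consecutive game frames.
        batch = random.sample(self.memory, batch_size)
        # Regroup a list of transitions into one tuple per field.
        return Transition(*zip(*batch))

    def __len__(self):
        return len(self.memory)


buf = ReplayBuffer(capacity=3)
for t in range(5):
    buf.push(t, 0, 1.0, t + 1, False)  # (state, action, reward, next_state, done)
batch = buf.sample(2)
print(len(buf))       # 3: the capacity caps the log, so the two oldest were evicted
print(batch.state)    # two distinct states drawn from {2, 3, 4}
```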
The Environment: A 20×20 Grid World
The Snake game environment is a 20×20 grid. The snake starts at the center of the board with a length of 3, moving to the right. Food spawns at a random unoccupied cell. The game ends when the snake collides with a wall or its own body.
What makes the environment design clever is how actions are defined. Instead of absolute directions (up, down, left, right), the agent chooses from relative actions: go straight, turn right, or turn left. This simplifies the decision space from four options to three and removes the problem of the agent choosing to reverse directly into itself — a bug you'd otherwise have to code around explicitly.
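The relative-action scheme reduces to rotating an index into a clockwise list of headings. Here is a minimal sketch. It assumes a y-axis that grows downward and the action order straight, right, left. The names `next_heading` and `step_head`, and the (10, 10) starting cell, are hypothetical and not taken from snake_env.py.

```python
# Absolute headings in clockwise order, so "turn right" is one step clockwise.
DIRECTIONS = ["UP", "RIGHT", "DOWN", "LEFT"]
# Unit moves as (dx, dy); y grows downward, as in Pygame screen coordinates (assumed).
MOVES = {"UP": (0, -1), "RIGHT": (1, 0), "DOWN": (0, 1), "LEFT": (-1, 0)}
# Relative action indices (ordering assumed to follow the article: straight, right, left).
STRAIGHT, TURN_RIGHT, TURN_LEFT = 0, 1, 2


def next_heading(heading: str, action: int) -> str:
    """Rotate the current heading by a relative action."""
    i = DIRECTIONS.index(heading)
    if action == TURN_RIGHT:
        i = (i + 1) % 4
    elif action == TURN_LEFT:
        i = (i - 1) % 4
    return DIRECTIONS[i]


def step_head(head: tuple, heading: str, action: int):
    """Return the new head cell and heading after one relative move."""
    new_heading = next_heading(heading, action)
    dx, dy = MOVES[new_heading]
    return (head[0] + dx, head[1] + dy), new_heading


# No relative action can point the snake straight back along its own body.
OPPOSITE = {"UP": "DOWN", "DOWN": "UP", "LEFT": "RIGHT", "RIGHT": "LEFT"}
assert all(next_heading(h, a) != OPPOSITE[h]
           for h in DIRECTIONS for a in (STRAIGHT, TURN_RIGHT, TURN_LEFT))

# Starting near the board centre heading right, a left turn moves the head up one row.
print(step_head((10, 10), "RIGHT", TURN_LEFT))  # ((10, 9), 'UP')
```

The reversal check at the bottom is the point of the design. With four absolute actions, one of them is always "back into your own neck". With three relative ones, that move simply does not exist.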
The Agent's Brain: A Deep Q-Network
At the heart of DeepSnake is a Deep Q-Network (DQN) — a neural network that learns to estimate the expected future reward (called the Q-value) for each possible action given the current state of the game.
The network architecture is compact and efficient:
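```python
# snake_dqn/model.py (simplified)
import torch.nn as nn
import torch.nn.functional as F


class DQN(nn.Module):
    """Maps a 24-feature state vector to one Q-value per action (straight, right, left)."""

    def __init__(self, state_size=24, action_size=3):
        super().__init__()
        self.fc1 = nn.Linear(state_size, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, action_size)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)  # raw Q-values: no activation, since Q can be any real number
```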
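"Compact" is easy to quantify. The short calculation below counts weights and biases layer by layer for the 24 → 256 → 128 → 3 network. It uses plain Python so it runs without PyTorch installed, and `count_linear_params` is a hypothetical helper name.

```python
# Widths of the DQN layers: 24 state features -> 256 -> 128 -> 3 Q-values.
LAYER_SIZES = [24, 256, 128, 3]


def count_linear_params(sizes):
    """Each nn.Linear(i, o) holds an i x o weight matrix plus o bias terms."""
    return sum(i * o + o for i, o in zip(sizes, sizes[1:]))


for i, o in zip(LAYER_SIZES, LAYER_SIZES[1:]):
    print(f"Linear({i}, {o}): {i * o + o} parameters")
print("total:", count_linear_params(LAYER_SIZES))  # total: 39683
```

With PyTorch available, `sum(p.numel() for p in DQN().parameters())` on the class above should report the same 39,683. Most of them sit in the middle 256 × 128 layer.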
Those 24 input features are where a lot of the design work lives. Rather than feeding the entire raw grid to the network (which would be computationally expensive and slow to learn), the state is encoded as a carefully chosen set of features:
- Immediate danger in three directions (straight, right, left) — is there a wall or body segment one step away?
- Current direction as a one-hot vector (4 values).
- Food direction — four binary values indicating whether the food is to the north, south, east, or west.
- Normalized head position (x, y) and food offset (dx, dy).
- Snake length, normalized.
- Distance to the nearest obstacle in 8 directions, calculated via raycasting.
This compact representation gives the network everything it needs to make good decisions without drowning it in irrelevant data. The three outputs correspond to Q-values for each action (straight, right, left). During gameplay, the agent picks the action with the highest Q-value.
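To make the feature list concrete, here is one way to compute all 24 numbers in plain Python. This is a sketch under stated assumptions, not DeepSnake's encoder. The feature order, the normalization constants, the y-down coordinate system, and the choice to count every body segment (tail included) as an obstacle are all guesses. The function names are hypothetical.

```python
GRID_W, GRID_H = 20, 20
DIRECTIONS = ["UP", "RIGHT", "DOWN", "LEFT"]  # clockwise
MOVES = {"UP": (0, -1), "RIGHT": (1, 0), "DOWN": (0, 1), "LEFT": (-1, 0)}
# Eight ray directions, clockwise from straight up (y grows downward).
RAYS = [(0, -1), (1, -1), (1, 0), (1, 1), (0, 1), (-1, 1), (-1, 0), (-1, -1)]


def is_blocked(cell, body):
    """True if the cell is off the grid or occupied by the snake's body."""
    x, y = cell
    return not (0 <= x < GRID_W and 0 <= y < GRID_H) or cell in body


def ray_distance(head, step, body):
    """Cells travelled before the ray hits a wall or body segment, scaled by grid size."""
    x, y = head
    dist = 0
    while True:
        x, y = x + step[0], y + step[1]
        dist += 1
        if is_blocked((x, y), body):
            return dist / max(GRID_W, GRID_H)


def encode_state(snake, heading, food):
    """snake: list of (x, y) cells, head first. Returns a 24-element feature list."""
    head, body = snake[0], set(snake[1:])
    hx, hy = head
    fx, fy = food
    i = DIRECTIONS.index(heading)
    relative = [DIRECTIONS[i], DIRECTIONS[(i + 1) % 4], DIRECTIONS[(i - 1) % 4]]

    danger = [float(is_blocked((hx + MOVES[d][0], hy + MOVES[d][1]), body))
              for d in relative]                               # 3: straight, right, left
    direction = [float(d == heading) for d in DIRECTIONS]      # 4: one-hot heading
    food_dir = [float(fy < hy), float(fy > hy),
                float(fx > hx), float(fx < hx)]                # 4: N, S, E, W
    position = [hx / (GRID_W - 1), hy / (GRID_H - 1)]          # 2: head scaled to [0, 1]
    offset = [(fx - hx) / GRID_W, (fy - hy) / GRID_H]          # 2: food offset (dx, dy)
    length = [len(snake) / (GRID_W * GRID_H)]                  # 1: normalized length
    rays = [ray_distance(head, s, body) for s in RAYS]         # 8: raycast distances

    return danger + direction + food_dir + position + offset + length + rays


state = encode_state([(10, 10), (9, 10), (8, 10)], "RIGHT", food=(15, 5))
print(len(state))  # 24
```

Every feature is a small, bounded number. That is a large part of why a two-hidden-layer network can learn from this input quickly.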
In tabular Q-learning, each experience overwrites one cell with the update Q(s, a) ← Q(s, a) + α · [r + γ · max_a′ Q(s′, a′) − Q(s, a)], where α is a learning rate and γ the discount factor. In a DQN, the quantity r + γ · max_a′ Q(s′, a′) becomes a target instead: the network is trained by gradient descent to make its prediction Q(s, a) match it. That target is often called the Bellman target, and the gap between prediction and target is the TD (temporal-difference) error, the signal that drives every weight update.
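To make the contrast concrete, here is a plain-Python sketch of the tabular version. The dictionary `q_table`, the helper names, and the values of `ALPHA` and `GAMMA` are illustrative assumptions, not DeepSnake's code or settings.

```python
from collections import defaultdict

GAMMA = 0.9  # discount factor (illustrative value)
ALPHA = 0.5  # tabular learning rate (illustrative value)


def bellman_target(reward, next_q_values, done, gamma=GAMMA):
    """r + gamma * max_a' Q(s', a'), with the future term dropped on terminal steps."""
    if done:
        return reward
    return reward + gamma * max(next_q_values)


def tabular_update(q_table, s, a, reward, s_next, done, n_actions=3):
    """Classic Q-learning: move one table cell a fraction ALPHA toward the target."""
    next_q = [q_table[(s_next, b)] for b in range(n_actions)]
    target = bellman_target(reward, next_q, done)
    td_error = target - q_table[(s, a)]
    q_table[(s, a)] += ALPHA * td_error
    return td_error


q = defaultdict(float)  # unseen (state, action) pairs start at 0.0
q[("B", 0)] = 4.0       # pretend action 0 in state "B" is already valued at 4.0
err = tabular_update(q, "A", 1, reward=1.0, s_next="B", done=False)
print(err, q[("A", 1)])  # TD error 1 + 0.9 * 4 = 4.6; the cell moves halfway, to 2.3
```

In a DQN, the dictionary is replaced by the network, and the `ALPHA` step is replaced by an optimizer step on the loss between the prediction and `bellman_target(...)`. Because nearby states share weights, one update also nudges the values of states the agent has never visited.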
The Reward System: How the AI Learns Right from Wrong
The reward signal is the only feedback the agent gets, so its design is critical. DeepSnake uses a shaped reward system:
| Event | Reward |
|---|---|
| Eats food | +10 |
| Dies (wall or self-collision) | −10 |
| Moves closer to food | +1 |
| Moves farther from food | −1 |
The +10 / −10 for eating and dying are straightforward. The interesting addition is the +1 / −1 shaping reward for moving closer to or farther from the food.
Why bother? Imagine the reward signal were only +10 for food and −10 for death. The agent starts as a random policy. On a 20×20 grid that's 400 cells with a single food cell, so the chance of stumbling onto the food on any given step is on the order of 1/400, while a randomly moving snake usually crashes within a few dozen steps. Nearly every early episode therefore ends with a −10 and nothing else. With almost no positive signal to learn from, the network mostly learns "everything is bad" and never finds the gradient pointing toward food. This is the sparse-reward problem, and it's one of the most common reasons RL projects stall.
The +1 / −1 distance shaping fixes this by handing the agent a small, dense signal on every single step. Every move toward the food is worth a tenth of a meal; every move away costs the same. The agent can now learn "go toward food" long before it has ever eaten one — and once it does eat, the +10 reinforces what was already a sensible policy. The proximity reward isn't strictly necessary in theory; in practice, it cuts training time from "maybe never" to "a couple thousand episodes."
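The reward table translates into a few lines of code. In this sketch, three details are assumptions: the distance metric is Manhattan, death and food take precedence over the shaping term, and the function names are chosen for illustration. The article does not say how snake_env.py measures "closer".

```python
def manhattan(a, b):
    """Grid distance between two (x, y) cells."""
    return abs(a[0] - b[0]) + abs(a[1] - b[1])


def shaped_reward(old_head, new_head, food, ate_food, died):
    """Rewards from the table above: +10 food, -10 death, +1 closer, -1 farther."""
    if died:
        return -10.0
    if ate_food:
        return 10.0
    # A single grid move changes the Manhattan distance by exactly 1,
    # so every non-terminal step is either "closer" or "farther", never a tie.
    if manhattan(new_head, food) < manhattan(old_head, food):
        return 1.0
    return -1.0


print(shaped_reward((10, 10), (11, 10), (15, 10), False, False))  # 1.0: one step closer
print(shaped_reward((10, 10), (10, 9), (15, 10), False, False))   # -1.0: one step farther
```

One side effect worth knowing: a long detour around the snake's own body costs −1 per step. The +10 for eating has to outweigh those penalties, or the agent may learn to prefer risky shortcuts.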
DQN Improvements: Beyond Vanilla Q-Learning
DeepSnake doesn't just implement a basic DQN — it incorporates several important improvements that stabilize and accelerate learning:
- Double DQN reduces Q-value overestimation, a common problem where the network becomes overly optimistic about certain actions (van Hasselt et al., 2016).
- Huber Loss (SmoothL1) replaces standard Mean Squared Error for more stable training gradients when the temporal-difference error is large.
- Soft target updates (τ = 0.005) blend a small fraction of the online network's weights into the target network on every update rather than copying them over wholesale, producing smoother learning curves.
- Gradient clipping (max norm = 1.0) prevents exploding gradients from destabilizing the network.
- Epsilon-greedy exploration decays from 1.0 to 0.01, starting with nearly random actions and gradually shifting to exploitation as the agent becomes more confident.
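The article gives ε's endpoints (1.0 and 0.01) but not its schedule. The sketch below assumes a per-episode exponential decay with a hypothetical factor of 0.995, and pairs it with the ε-greedy rule defined earlier. Both helper names are illustrative.

```python
import random

EPS_START, EPS_END = 1.0, 0.01  # endpoints stated in the article
EPS_DECAY = 0.995               # per-episode multiplicative factor: an assumption


def epsilon_at(episode: int) -> float:
    """Exponential decay from EPS_START, clamped at the EPS_END floor."""
    return max(EPS_END, EPS_START * EPS_DECAY ** episode)


def select_action(q_values, epsilon: float, rng=random) -> int:
    """Epsilon-greedy: explore uniformly with probability epsilon, else take argmax Q."""
    if rng.random() < epsilon:
        return rng.randrange(len(q_values))
    return max(range(len(q_values)), key=lambda i: q_values[i])


for ep in (0, 100, 500, 1000, 2000):
    print(ep, round(epsilon_at(ep), 3))

print(select_action([0.1, 2.0, -1.0], epsilon=0.0))  # 1: pure exploitation picks the argmax
```

With these assumed numbers, ε reaches its 0.01 floor around episode 919, a little under halfway through a 2,000-episode run. After that, about one move in a hundred is still random.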
One step of training, end to end
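It's worth tracing what actually happens between the moment the agent observes a state and the moment its weights are updated. Every game step is also a training step: the agent observes the state s, picks an action a with ε-greedy, the environment returns a reward r and the next state s′, the transition (s, a, r, s′, done) is pushed into the replay buffer, and the network takes one gradient step on a random minibatch sampled from that buffer.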
The two pieces that do the actual work are the replay buffer sample and the loss computation. They're short:
```python
# snake_dqn/agent.py: one training step (annotated)
import torch
import torch.nn as nn
import torch.nn.functional as F


# Method of the agent class; self.replay, self.online, self.target,
# self.optimizer, self.gamma and self.tau are set up in __init__.
def learn(self):
    # 1. Pull a random minibatch of past experiences.
    #    Random sampling breaks the temporal correlation between consecutive frames,
    #    which would otherwise destabilize gradient descent.
    batch = self.replay.sample(64)
    s, a, r, s_next, done = batch  # tensors; `a` must be int64 for gather(), `done` float 0/1

    # 2. What the network currently thinks Q(s, a) is, for the action we took.
    q_pred = self.online(s).gather(1, a.unsqueeze(1)).squeeze(1)

    # 3. The Bellman target: what Q(s, a) *should* be, given what happened next.
    #    Double-DQN trick: pick the next action with the online net, evaluate it
    #    with the target net. Reduces the overestimation bias of vanilla DQN.
    with torch.no_grad():
        a_next = self.online(s_next).argmax(dim=1)
        q_next = self.target(s_next).gather(1, a_next.unsqueeze(1)).squeeze(1)
        y = r + self.gamma * q_next * (1 - done)  # zero out future term on terminal steps

    # 4. Huber loss is MSE-like for small errors, MAE-like for big ones: robust to outliers.
    loss = F.smooth_l1_loss(q_pred, y)

    # 5. Standard PyTorch update, with gradient clipping to keep things stable.
    self.optimizer.zero_grad()
    loss.backward()
    nn.utils.clip_grad_norm_(self.online.parameters(), max_norm=1.0)
    self.optimizer.step()

    # 6. Soft-update the target network: a slow exponential moving average of the online weights.
    with torch.no_grad():
        for p_t, p_o in zip(self.target.parameters(), self.online.parameters()):
            p_t.copy_(self.tau * p_o + (1 - self.tau) * p_t)
```
That's the entire learning algorithm. Six numbered steps, repeated tens of thousands of times.
These are all well-established techniques from the RL literature, and seeing them applied together in a single, readable codebase is one of the most valuable aspects of this project for anyone learning the field.
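The two constants in steps 4 and 6 are easier to appreciate with numbers. The plain-Python sketch below re-implements the Huber loss and the soft update on scalars purely for illustration. The helper names are hypothetical. In the real training step, PyTorch's `F.smooth_l1_loss` and the parameter loop do this work on whole tensors.

```python
import math


def huber(error: float, delta: float = 1.0) -> float:
    """Quadratic for |error| <= delta, linear beyond. With delta = 1 this equals
    PyTorch's F.smooth_l1_loss at its default beta = 1.0."""
    a = abs(error)
    return 0.5 * a * a if a <= delta else delta * (a - 0.5 * delta)


def huber_grad(error: float, delta: float = 1.0) -> float:
    """Gradient of huber() with respect to the error: its magnitude never exceeds delta."""
    return error if abs(error) <= delta else math.copysign(delta, error)


for e in (0.5, 3.0, 10.0):
    # 0.5 * e**2 has gradient e, so a large TD error would produce an equally large step.
    print(f"error={e}: squared={0.5 * e * e}, huber={huber(e)}, "
          f"grad squared={e}, grad huber={huber_grad(e)}")

# Soft target update with tau = 0.005 (the article's value), applied to one scalar weight.
TAU = 0.005


def soft_update(target: float, online: float, tau: float = TAU) -> float:
    return tau * online + (1 - tau) * target


# How many updates until the target closes half the gap to a fixed online weight?
half_life = math.log(0.5) / math.log(1 - TAU)
print(f"half-life: {half_life:.1f} updates")  # about 138 updates

w = 0.0
for _ in range(139):
    w = soft_update(w, 1.0)
print(round(w, 3))  # just past 0.5
```

Huber caps the gradient from any single bad prediction, and clipping caps the total gradient norm. Meanwhile, τ = 0.005 means the target network trails the online network by roughly 140 training steps. That lag keeps the Bellman target from shifting under the network's feet.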
The Tech Stack
DeepSnake is built entirely in Python with a focused set of dependencies:
- Python 3.10+ — the runtime.
- PyTorch — powers the DQN neural network and the entire training loop.
- Pygame — provides the visual rendering so you can watch the trained agent play in real time.
- NumPy — handles numerical operations and array manipulation.
- Matplotlib — generates the training performance plots.
No bloated frameworks or unnecessary dependencies. Everything you need fits in a single pip install command.
How to Get Started
Getting DeepSnake running on your machine takes about two minutes.
Step 1: Clone the repository.
```bash
git clone https://github.com/josephgec/deepsnake.git
cd deepsnake
```
Step 2: Install dependencies.
```bash
pip install torch numpy pygame matplotlib
```
Step 3: Watch the pre-trained agent play.
```bash
cd snake_dqn
python play.py
```
Use SPACE to pause/unpause, R to reset, and Q to quit.
Step 4 (optional): Train from scratch.
```bash
python train.py          # 2,000 episodes with automatic checkpointing
python plot_training.py  # render the score curve
```
The project structure is clean and modular — snake_env.py handles the game logic, model.py defines the neural network, agent.py manages the DQN training logic, and train.py ties everything together. Each file has a clear, single responsibility, making it easy to read, modify, and experiment with.
Why This Project Is Worth Your Time
DeepSnake sits in a sweet spot that's rare in ML education projects. It's complex enough to demonstrate real RL concepts — experience replay, target networks, epsilon decay, reward shaping — but simple enough that you can read the entire codebase in an afternoon and understand every line. The Snake game provides an intuitive, visual feedback loop that makes the learning process tangible in a way that abstract environments don't.
If you're an intermediate developer looking to move from "I've read about Reinforcement Learning" to "I've actually built something with it," this is a solid starting point.
What to Try Next
A few experiments to deepen your understanding once you've got the project running:
- Tweak the reward system. What happens if you remove the proximity reward (+1 / −1) and only keep the +10 / −10? How much slower does the agent learn?
- Change the network architecture. Add a layer, double the hidden units, or try a different activation function. Does performance improve, or does the extra capacity just memorize noise?
- Modify the grid size. A 10×10 grid is a very different challenge than a 30×30 grid. How does the agent adapt?
- Swap relative actions for absolute actions. Does the agent still learn effectively with four action choices instead of three, or does the self-reversal problem come back?
Get Involved
If you find DeepSnake useful or interesting, head over to the GitHub repository, give it a star, and start experimenting. Fork it, break it, improve it — that's how the best learning happens.
Deep Reinforcement Learning is one of the most exciting subfields in AI, and there's no better way to understand it than by watching a neural network teach itself to chase pixels around a grid. Happy training.
References
The core ideas behind DeepSnake, and where to read more.
Foundational DQN papers
- Mnih et al. Playing Atari with Deep Reinforcement Learning. NeurIPS Workshop 2013. arXiv:1312.5602
- Mnih et al. Human-level control through deep reinforcement learning. Nature 2015. nature.com
- van Hasselt, Guez & Silver. Deep Reinforcement Learning with Double Q-learning. AAAI 2016. arXiv:1509.06461
- Schaul et al. Prioritized Experience Replay. ICLR 2016. arXiv:1511.05952
Learning materials
- Sutton & Barto. Reinforcement Learning: An Introduction (2nd ed.). MIT Press, 2018. incompleteideas.net
- OpenAI. Spinning Up in Deep RL. spinningup.openai.com
- PyTorch. Reinforcement Learning (DQN) Tutorial. pytorch.org
Project source
- DeepSnake on GitHub — github.com/josephgec/deepsnake