Compute Shader Tutorial
Using a compute shader, you can run calculations on the GPU, which for highly parallel work like this can be thousands of times faster than using the CPU alone.
In this example, we will simulate a star field using an ‘N-Body simulation’. Each star is affected by every other star’s gravity. For 1,000 stars, this means we have 1,000 x 1,000 = 1,000,000 (one million) calculations to perform for each frame. The video has 65,000 stars, requiring about 4.2 billion gravity force calculations per frame (the program below defaults to 40,000 stars). On high-end hardware it can still run at 60 fps!
How does this work? There are three major parts to this program:
The Python code, which glues everything together.
The visualization shaders, which let us see the data.
The compute shader, which moves everything.
Visualization Shaders
There are multiple visualization shaders, which operate in this order:
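The Python program creates a shader storage buffer object (SSBO) of floating point numbers. This buffer holds the x, y, z and radius of each star, which the vertex shader receives in `in_vertex`. It also stores the color of each star, which arrives in `in_color`. (The velocity values stored between them are skipped when drawing.)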
The vertex shader doesn’t do much more than separate out the radius variable from the group of floats used to store position.
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 | #version 330
// Per-star attributes pulled from the storage buffer:
//   in_vertex = (x, y, z, radius)
//   in_color  = (r, g, b, a)
in vec4 in_vertex;
in vec4 in_color;

// Values handed on to the geometry shader
out vec2 vertex_pos;
out float vertex_radius;
out vec4 vertex_color;

// Pass-through stage. gl_Position is intentionally not written here: the
// geometry shader computes the final corner positions of each star's quad.
void main()
{
    vertex_color = in_color;
    vertex_radius = in_vertex[3];
    vertex_pos = vec2(in_vertex.x, in_vertex.y);
}
|
The geometry shader converts the single point (which we can’t render as a star) into a square, which we can render. It turns the one point into the four corners of a quad, emitted as a triangle strip.
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 | #version 330
layout (points) in;
layout (triangle_strip, max_vertices = 4) out;

// Use arcade's global projection UBO
uniform Projection {
    uniform mat4 matrix;
} proj;

in vec2 vertex_pos[];
in vec4 vertex_color[];
in float vertex_radius[];

out vec2 g_uv;
out vec3 g_color;

// Quad corners in triangle-strip order: top-left, bottom-left, top-right,
// bottom-right. Scaled by the radius they give the vertex offset; mapped
// from [-1, 1] to [0, 1] they give the texture coordinate.
const vec2 corners[4] = vec2[4](
    vec2(-1.0,  1.0),
    vec2(-1.0, -1.0),
    vec2( 1.0,  1.0),
    vec2( 1.0, -1.0)
);

// Expand one star (a point) into a quad of half-size = radius around it.
void main() {
    vec2 center = vertex_pos[0];
    vec2 hsize = vec2(vertex_radius[0]);
    vec3 color = vertex_color[0].rgb;

    for (int i = 0; i < 4; i++) {
        // Output variables are undefined after each EmitVertex(), so every
        // output -- including the per-star color -- is written again for
        // each vertex instead of only once before the first.
        g_color = color;
        g_uv = corners[i] * 0.5 + 0.5;
        gl_Position = proj.matrix * vec4(corners[i] * hsize + center, 0.0, 1.0);
        EmitVertex();
    }
    EndPrimitive();
}
|
The fragment shader runs for each pixel. It produces the soft glow effect of the star, and rounds off the quad into a circle.
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 | #version 330
in vec2 g_uv;
in vec3 g_color;

out vec4 out_color;

// Shade one pixel of a star's quad: round the quad into a circle and give it
// a soft, fading glow.
void main()
{
    // Distance of this fragment from the center of the quad (uv 0.5, 0.5)
    float l = length(vec2(0.5, 0.5) - g_uv.xy);

    // Cut the square quad down to a circle
    if (l > 0.5)
    {
        discard;
    }

    // Glow: alpha falls off linearly from 0.6 near the center to 0 at l = 0.3.
    // Clamp so the outer ring gets alpha 0 rather than a negative value
    // (0.6 - 2 * 0.5 = -0.4 at the rim), which is not a valid opacity.
    float alpha;
    if (l == 0.0)
        alpha = 1.0;  // the exact center is drawn fully opaque
    else
        alpha = clamp(0.6 - l * 2.0, 0.0, 1.0);

    out_color = vec4(g_color.rgb, alpha);
}
|
Compute Shaders
This program uses two buffers. We have an input buffer with all our current data. We perform calculations on that data and write the results to the output buffer. We then swap those buffers for the next frame, so the output of the previous frame becomes the input of the next frame.
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 | #version 430
// Set up our compute groups.
// COMPUTE_SIZE_X / COMPUTE_SIZE_Y are placeholders that the Python program
// replaces with numbers before compiling this shader.
layout(local_size_x=COMPUTE_SIZE_X, local_size_y=COMPUTE_SIZE_Y) in;

// Input uniforms go here if you need them.
// Some examples:
//uniform vec2 screen_size;
//uniform vec2 force;
//uniform float frame_time;

// Structure of the ball data (12 floats per ball, matching the Python buffer):
//   pos   = x, y, z, radius
//   vel   = velocity (only xy are used)
//   color = r, g, b, a
struct Ball
{
    vec4 pos;
    vec4 vel;
    vec4 color;
};

// Input buffer: the state from the previous frame
layout(std430, binding=0) buffer balls_in
{
    Ball balls[];
} In;

// Output buffer: the state for the next frame
layout(std430, binding=1) buffer balls_out
{
    Ball balls[];
} Out;

// Simulation tuning constants.
// gravityStrength / distanceSquared is capped at maxForce, so the pull stops
// growing once two stars are closer than sqrt(gravityStrength / maxForce)
// (about 3.9 units). Without the cap, stars that get too close fly off into
// never-never land.
const float maxForce = 0.02;
const float gravityStrength = 0.3;
const float simulationSpeed = 0.002;

// Advance one ball by one step: move it by its velocity, then add the pull
// of every ball in the input buffer to that velocity.
void main()
{
    int curBallIndex = int(gl_GlobalInvocationID.x);
    int ballCount = In.balls.length();

    // Work is dispatched in whole work groups, so the total number of
    // invocations can exceed the number of balls. Those extra invocations
    // must not read or write past the end of the buffers.
    if (curBallIndex >= ballCount)
        return;

    Ball in_ball = In.balls[curBallIndex];
    vec4 p = in_ball.pos;
    vec4 v = in_ball.vel;

    // Move the ball according to the current velocity
    p.xy += v.xy;

    // Calculate the new velocity based on all the other bodies
    for (int i = 0; i < ballCount; i++) {
        // If enabled, this keeps the star from calculating gravity on itself.
        // However, the extra check slows down the calculations.
        // if (i == curBallIndex)
        //     continue;

        vec2 diff = p.xy - In.balls[i].pos.xy;
        float distanceSquared = dot(diff, diff);

        // Keep the divisor away from zero so coincident stars don't divide
        // by zero. Any distanceSquared below gravityStrength / maxForce is
        // already capped at maxForce, so this does not change the result.
        float force = min(maxForce, gravityStrength / max(distanceSquared, 1e-6))
                      * -simulationSpeed;

        // diff is deliberately not normalized (normalizing it did not work
        // for this simulation), so the pull also scales with distance.
        v.xy += diff * force;
    }

    Out.balls[curBallIndex] = Ball(p, v, in_ball.color);
}
|
Python Program
Read through the code below; I’ve tried hard to explain all the parts in the comments.
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 | """
Compute shader with buffers
"""
import random
from array import array
import arcade
from arcade.gl import BufferDescription
# Size of the application window
WINDOW_WIDTH, WINDOW_HEIGHT = 2300, 1300

# Size of the on-screen performance (FPS) graph; GRAPH_MARGIN is a spacing
# value that is not referenced in this listing.
GRAPH_WIDTH, GRAPH_HEIGHT, GRAPH_MARGIN = 200, 120, 5
class MyWindow(arcade.Window):
    """Window that runs an N-body star simulation in a compute shader and draws it."""

    def __init__(self):
        # Call parent constructor
        # Ask for OpenGL 4.3 context, as we need that for compute shader support.
        super().__init__(WINDOW_WIDTH, WINDOW_HEIGHT,
                         "Compute Shader",
                         gl_version=(4, 3),
                         resizable=True)
        self.center_window()

        # --- Class instance variables

        # Number of balls to move
        self.num_balls = 40000

        # Work-group size of the compute shader. These numbers are substituted
        # into the shader source (local_size_x / local_size_y) below; each
        # shader invocation updates one ball.
        self.group_x = 256
        self.group_y = 1

        # --- Create buffers

        # Format of the buffer data.
        # 4f = position and size -> x, y, z, radius
        # 4x4 = Four floats used for calculating velocity. Not needed for visualization.
        # 4f = color -> rgba
        buffer_format = "4f 4x4 4f"

        # Generate the initial data that we will put in buffer 1.
        initial_data = self.gen_initial_data()

        # Create data buffers for the compute shader
        # We ping-pong render between these two buffers
        # ssbo = shader storage buffer object
        self.ssbo_1 = self.ctx.buffer(data=array('f', initial_data))
        self.ssbo_2 = self.ctx.buffer(reserve=self.ssbo_1.size)

        # Attribute variable names for the vertex shader
        attributes = ["in_vertex", "in_color"]
        self.vao_1 = self.ctx.geometry(
            [BufferDescription(self.ssbo_1, buffer_format, attributes)],
            mode=self.ctx.POINTS,
        )
        self.vao_2 = self.ctx.geometry(
            [BufferDescription(self.ssbo_2, buffer_format, attributes)],
            mode=self.ctx.POINTS,
        )

        # --- Create shaders

        # Load in the shader source code
        compute_shader_source = self._read_shader("shaders/compute_shader.glsl")
        vertex_shader_source = self._read_shader("shaders/vertex_shader.glsl")
        fragment_shader_source = self._read_shader("shaders/fragment_shader.glsl")
        geometry_shader_source = self._read_shader("shaders/geometry_shader.glsl")

        # Create our compute shader.
        # Search/replace to set up our compute groups
        compute_shader_source = compute_shader_source.replace("COMPUTE_SIZE_X",
                                                              str(self.group_x))
        compute_shader_source = compute_shader_source.replace("COMPUTE_SIZE_Y",
                                                              str(self.group_y))
        self.compute_shader = self.ctx.compute_shader(source=compute_shader_source)

        # Program for visualizing the balls
        self.program = self.ctx.program(
            vertex_shader=vertex_shader_source,
            geometry_shader=geometry_shader_source,
            fragment_shader=fragment_shader_source,
        )

        # --- Create FPS graph

        # Enable timings for the performance graph
        arcade.enable_timings()

        # Create a sprite list to put the performance graph into
        self.perf_graph_list = arcade.SpriteList()

        # Create the FPS performance graph
        graph = arcade.PerfGraph(GRAPH_WIDTH, GRAPH_HEIGHT, graph_data="FPS")
        graph.center_x = GRAPH_WIDTH / 2
        graph.center_y = self.height - GRAPH_HEIGHT / 2
        self.perf_graph_list.append(graph)

    @staticmethod
    def _read_shader(path):
        """Return the text of the shader file at ``path``, closing the file afterwards."""
        with open(path) as file:
            return file.read()

    def on_draw(self):
        # Clear the screen
        self.clear()

        # Enable blending so our alpha channel works
        self.ctx.enable(self.ctx.BLEND)

        # Bind buffers: binding 0 is the compute shader's input, binding 1 its output
        self.ssbo_1.bind_to_storage_buffer(binding=0)
        self.ssbo_2.bind_to_storage_buffer(binding=1)

        # Set input variables for compute shader
        # These are examples, although this example doesn't use them
        # self.compute_shader["screen_size"] = self.get_size()
        # self.compute_shader["force"] = force
        # self.compute_shader["frame_time"] = self.run_time

        # Run compute shader.
        # run() takes the NUMBER of work groups, while group_x is the SIZE of
        # each group, so launch just enough groups to cover every ball
        # (integer ceiling division). The last group may still overshoot
        # num_balls, so the shader should bounds-check its index.
        num_groups_x = (self.num_balls + self.group_x - 1) // self.group_x
        # Only the x dimension indexes balls, so a single group is used in y.
        self.compute_shader.run(group_x=num_groups_x, group_y=1)

        # Draw the balls from the buffer the compute shader just wrote
        self.vao_2.render(self.program)

        # Swap the buffers around (we are ping-pong rendering between two buffers)
        self.ssbo_1, self.ssbo_2 = self.ssbo_2, self.ssbo_1
        # Swap what geometry we draw
        self.vao_1, self.vao_2 = self.vao_2, self.vao_1

        # Draw the graphs
        self.perf_graph_list.draw()

    def gen_initial_data(self):
        """Yield the initial buffer contents: 12 floats per ball.

        Each ball gets a random integer position inside the window, radius 6,
        zero velocity and opaque white color.
        """
        for i in range(self.num_balls):
            # Position/radius
            yield random.randrange(0, self.width)
            yield random.randrange(0, self.height)
            yield 0.0  # z (padding)
            yield 6.0
            # Velocity
            yield 0.0
            yield 0.0
            yield 0.0  # vz (padding)
            yield 0.0  # vw (padding)
            # Color
            yield 1.0  # r
            yield 1.0  # g
            yield 1.0  # b
            yield 1.0  # a
# Create the window (this builds the buffers and shaders in MyWindow.__init__),
# then start arcade's main loop.
app = MyWindow()
arcade.run()
|
An expanded version of this, with support for 3D, is available at: https://github.com/pvcraven/n-body