const std = @import("std");
const tau = std.math.tau;

const core = @import("mach-core");
const gpu = core.gpu;

const Renderer = @import("./renderer.zig");
const createAndWriteBuffer = Renderer.createAndWriteBuffer;

/// Describes the layout of each vertex that a primitive is made of.
///
/// Declared `extern` so the in-memory layout is fixed: `position` at byte
/// offset 0, `uv` at byte offset 12, 20 bytes per vertex. Slices of this type
/// are uploaded into GPU vertex buffers by `createPrimitive`, and a plain Zig
/// `struct` has an unspecified field order and padding, which would make the
/// byte layout the GPU reads an accident of the compiler.
pub const VertexData = extern struct {
    /// Vertex position `(x, y, z)`.
    position: [3]f32,
    /// Texture coordinate `(u, v)`.
    uv: [2]f32,
};

/// Contains the data to render a primitive (3D shape or model).
///
/// `createPrimitive` fills one of these in from CPU-side vertex and index
/// slices. The struct only stores pointers to the GPU buffers and their
/// element counts; it defines no method of its own to release the buffers.
pub const PrimitiveData = struct {
    /// Vertices describe the "points" that a primitive is made out of.
    /// This buffer holds `vertex_count` elements of type `VertexData`.
    vertex_buffer: *gpu.Buffer,
    /// Number of `VertexData` elements stored in `vertex_buffer`.
    vertex_count: u32,

    /// Indices describe what vertices make up the triangles in a primitive.
    /// This buffer holds `index_count` elements of type `u32`.
    ///
    /// For example, `vertex_buffer` may have 4 points defining a square, but
    /// since it needs to be rendered using 2 triangles, `index_buffer` will
    /// contain 6 entries, `0, 1, 2` and `3, 2, 1` making up one triangle each.
    /// (Every primitive built in this file lists indices in groups of 3;
    /// presumably the renderer draws them as a triangle list — confirm there.)
    index_buffer: *gpu.Buffer,
    /// Number of `u32` entries stored in `index_buffer`.
    index_count: u32,
};

/// Builds a `PrimitiveData` by uploading `vertices` into a vertex buffer and
/// `indices` into an index buffer (both also flagged as copy destinations).
///
/// The element counts are stored as `u32`; `@intCast` panics in safe build
/// modes if either slice has more than `maxInt(u32)` elements.
pub fn createPrimitive(
    vertices: []const VertexData,
    indices: []const u32,
) PrimitiveData {
    // Same order as the struct fields: vertex data first, then index data.
    const vertex_buffer = createAndWriteBuffer(VertexData, vertices, .{ .vertex = true, .copy_dst = true });
    const vertex_count: u32 = @intCast(vertices.len);
    const index_buffer = createAndWriteBuffer(u32, indices, .{ .index = true, .copy_dst = true });
    const index_count: u32 = @intCast(indices.len);

    return .{
        .vertex_buffer = vertex_buffer,
        .vertex_count = vertex_count,
        .index_buffer = index_buffer,
        .index_count = index_count,
    };
}

/// Shorthand constructor: `(x, y, z)` becomes the position and `(tx, ty)` the UV.
fn vert(x: f32, y: f32, z: f32, tx: f32, ty: f32) VertexData {
    const result: VertexData = .{
        .position = [3]f32{ x, y, z },
        .uv = [2]f32{ tx, ty },
    };
    return result;
}

/// Creates an equilateral triangle with sides of `length`, centered on the
/// origin in the XY plane (z = 0).
///
/// Corners are placed on a circle around the origin at angles 0, tau/3 and
/// 2*tau/3, measured from +Y towards +X, so corner 0 sits at `(0, r)`.
pub fn createTrianglePrimitive(length: f32) PrimitiveData {
    // Distance from the center to each corner (circumradius) for side `length`.
    const circumradius = length / @sqrt(3.0);
    const top = 0.0;
    const third = tau / 3.0;
    const two_thirds = tau / 3.0 * 2.0;

    // A triangle is made up of 3 vertices.
    //
    //        0
    //       / \
    //      /   \
    //     1-----2
    const vertices = [_]VertexData{
        vert(@sin(top) * circumradius, @cos(top) * circumradius, 0.0, 0.5, 0.0),
        vert(@sin(third) * circumradius, @cos(third) * circumradius, 0.0, 0.0, 1.0),
        vert(@sin(two_thirds) * circumradius, @cos(two_thirds) * circumradius, 0.0, 1.0, 1.0),
    };

    // Corners are listed counter-clockwise as drawn above, which the renderer
    // is expected to treat as the front face.
    const indices = [_]u32{ 0, 1, 2 };

    return createPrimitive(&vertices, &indices);
}

/// Creates a square with sides of `width`, centered on the origin in the XY
/// plane (z = 0).
///
/// Positions, UVs and winding are identical to the front (+Z) face of
/// `createCubePrimitive`. They also match the orientation used by
/// `createTrianglePrimitive`: the corner at +Y/+X gets UV `(0, 0)`, `u` grows
/// towards -X and `v` grows towards -Y. (Previously the square mapped UV
/// `(0, 0)` to the -X/-Y corner, so its texture showed up rotated 180 degrees
/// relative to every other primitive with the same facing.)
pub fn createSquarePrimitive(width: f32) PrimitiveData {
    const half_width = width / 2.0;
    return createPrimitive(
        // A square is made up of 4 vertices, ...
        //
        //     0---2
        //     |   |
        //     |   |
        //     1---3
        &.{
            // zig fmt: off
            vert( half_width,  half_width, 0.0, 0.0, 0.0),
            vert( half_width, -half_width, 0.0, 0.0, 1.0),
            vert(-half_width,  half_width, 0.0, 1.0, 0.0),
            vert(-half_width, -half_width, 0.0, 1.0, 1.0),
            // zig fmt: on
        },
        // ... but it has to be split up into 2 triangles. The numbers below
        // are positions in the index list (not vertex numbers): entries 3, 4
        // and 5 refer to vertices 3, 2 and 1.
        //
        //     0--2  4
        //     | /  /|
        //     |/  / |
        //     1  5--3
        &.{
            0, 1, 2,
            3, 2, 1,
        },
    );
}

/// Creates a regular polygon with `sides` corners approximating a circle of
/// `radius`, centered on the origin in the XY plane (z = 0).
///
/// Corner `i` sits at angle `i * tau / sides`, measured from +Y towards +X
/// (corner 0 is at `(0, radius)`), the same layout `createTrianglePrimitive`
/// uses. UVs are a planar projection of the position onto the unit square,
/// oriented like the other primitives: +Y maps towards `v = 0` and +X maps
/// towards `u = 0`, so corner 0 gets UV `(0.5, 0.0)`.
///
/// The polygon is triangulated as a fan around corner 0, giving `sides - 2`
/// triangles wound in the same order as `createTrianglePrimitive`.
/// `sides` must be at least 3 (checked at compile time).
pub fn createCirclePrimitive(radius: f32, comptime sides: usize) PrimitiveData {
    if (sides < 3) @compileError("sides must be at least 3");

    // Angle between two neighboring corners; loop-invariant.
    const step = tau / @as(f32, @floatFromInt(sides));

    var vertices: [sides]VertexData = undefined;
    for (&vertices, 0..) |*vertex, i| {
        const angle = step * @as(f32, @floatFromInt(i));
        const s = @sin(angle);
        const c = @cos(angle);
        // `vert` needs a UV as well; project the unit-circle point onto [0, 1]^2.
        vertex.* = vert(s * radius, c * radius, 0.0, 0.5 - s * 0.5, 0.5 - c * 0.5);
    }

    // Fan triangulation: triangle `i` is (0, i + 1, i + 2).
    var indices: [(sides - 2) * 3]u32 = undefined;
    for (0..(sides - 2)) |i| {
        indices[i * 3 + 0] = 0;
        indices[i * 3 + 1] = @as(u32, @intCast(i)) + 1;
        indices[i * 3 + 2] = @as(u32, @intCast(i)) + 2;
    }

    return createPrimitive(&vertices, &indices);
}

/// Creates an axis-aligned cube with edges of `width`, centered on the origin.
///
/// Every face gets its own 4 vertices (24 in total) so each face can carry
/// its own UVs covering the whole `0..1` range. Each face quad `q..q+3` is
/// drawn as the triangles `(q, q+1, q+2)` and `(q+3, q+2, q+1)`, for
/// 36 indices in total.
pub fn createCubePrimitive(width: f32) PrimitiveData {
    // Half the edge length: each coordinate is either +h or -h.
    const h = width / 2.0;

    // zig fmt: off
    const vertices = [_]VertexData{
        // Right (+X)
        vert( h,  h, -h, 0.0, 0.0),
        vert( h, -h, -h, 0.0, 1.0),
        vert( h,  h,  h, 1.0, 0.0),
        vert( h, -h,  h, 1.0, 1.0),
        // Left (-X)
        vert(-h,  h,  h, 0.0, 0.0),
        vert(-h, -h,  h, 0.0, 1.0),
        vert(-h,  h, -h, 1.0, 0.0),
        vert(-h, -h, -h, 1.0, 1.0),
        // Top (+Y)
        vert( h,  h, -h, 0.0, 0.0),
        vert( h,  h,  h, 0.0, 1.0),
        vert(-h,  h, -h, 1.0, 0.0),
        vert(-h,  h,  h, 1.0, 1.0),
        // Bottom (-Y)
        vert(-h, -h, -h, 1.0, 0.0),
        vert(-h, -h,  h, 0.0, 0.0),
        vert( h, -h, -h, 1.0, 1.0),
        vert( h, -h,  h, 0.0, 1.0),
        // Front (+Z)
        vert( h,  h,  h, 0.0, 0.0),
        vert( h, -h,  h, 0.0, 1.0),
        vert(-h,  h,  h, 1.0, 0.0),
        vert(-h, -h,  h, 1.0, 1.0),
        // Back (-Z)
        vert(-h,  h, -h, 0.0, 0.0),
        vert(-h, -h, -h, 0.0, 1.0),
        vert( h,  h, -h, 1.0, 0.0),
        vert( h, -h, -h, 1.0, 1.0),
    };

    const indices = [_]u32{
         0,  1,  2,    3,  2,  1, // Right
         4,  5,  6,    7,  6,  5, // Left
         8,  9, 10,   11, 10,  9, // Top
        12, 13, 14,   15, 14, 13, // Bottom
        16, 17, 18,   19, 18, 17, // Front
        20, 21, 22,   23, 22, 21, // Back
    };
    // zig fmt: on

    return createPrimitive(&vertices, &indices);
}

/// Creates a square-based pyramid centered on the origin. The base is a
/// `width` x `width` square at y = -width/2 and the apex sits at y = +width/2.
///
/// Each of the 4 slanted sides is its own triangle with its own copy of the
/// apex, so every side can carry independent UVs. The base is a quad split
/// into 2 triangles. In total: 16 vertices, 18 indices.
pub fn createPyramidPrimitive(width: f32) PrimitiveData {
    // Half the base edge length (and half the height).
    const h = width / 2.0;
    // Tip of the pyramid, repeated at the start of every side face with UV (0.5, 0.0).
    const apex = vert(0.0, h, 0.0, 0.5, 0.0);

    // zig fmt: off
    const vertices = [_]VertexData{
        // Right (+X)
        apex,
        vert( h, -h, -h, 0.0, 1.0),
        vert( h, -h,  h, 1.0, 1.0),
        // Left (-X)
        apex,
        vert(-h, -h,  h, 0.0, 1.0),
        vert(-h, -h, -h, 1.0, 1.0),
        // Front (+Z)
        apex,
        vert( h, -h,  h, 0.0, 1.0),
        vert(-h, -h,  h, 1.0, 1.0),
        // Back (-Z)
        apex,
        vert(-h, -h, -h, 0.0, 1.0),
        vert( h, -h, -h, 1.0, 1.0),
        // Bottom (-Y)
        vert(-h, -h, -h, 0.0, 0.0),
        vert(-h, -h,  h, 0.0, 1.0),
        vert( h, -h, -h, 1.0, 0.0),
        vert( h, -h,  h, 1.0, 1.0),
    };

    const indices = [_]u32{
         0,  1,  2,               // Right
         3,  4,  5,               // Left
         6,  7,  8,               // Front
         9, 10, 11,               // Back
        12, 13, 14,  15, 14, 13,  // Bottom (two triangles)
    };
    // zig fmt: on

    return createPrimitive(&vertices, &indices);
}