Skip to content

Commit

Permalink
Validate Resource commands and properly account for element size (#99)
Browse files Browse the repository at this point in the history
  • Loading branch information
Devon7925 authored Jan 18, 2025
1 parent 41327a4 commit 7e75e52
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 26 deletions.
51 changes: 42 additions & 9 deletions compiler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -73,21 +73,54 @@ type BindingDescriptor = {

export type Bindings = Map<string, GPUBindGroupLayoutEntry>;

export type ReflectionBinding = {
"kind": "uniform",
"offset": number,
"size": number,
} | {
"kind": "descriptorTableSlot",
"index": number,
};

export type ReflectionType = {
"kind": "struct",
"name": string,
"fields": ReflectionParameter[]
} | {
"kind": "vector",
"elementCount": number,
"elementType": ReflectionType,
} | {
"kind": "scalar",
"scalarType": `${"uint" | "int"}${8 | 16 | 32 | 64}` | `${"float"}${16 | 32 | 64}`,
} | {
"kind": "resource",
"baseShape": "structuredBuffer",
"access"?: "readWrite",
"resultType": ReflectionType
} | {
"kind": "resource",
"baseShape": "texture2D",
"access"?: "readWrite"
};

export type ReflectionParameter = {
"binding": ReflectionBinding,
"name": string,
"type": ReflectionType,
"userAttribs"?: {
"arguments": any[],
"name": string,
}[],
}

export type ReflectionJSON = {
"entryPoints": {
"name": string,
"semanticName": string,
"type": unknown
}[],
"parameters": {
"binding": unknown,
"name": string,
"type": unknown,
"userAttribs"?: {
"arguments": any[],
"name": string,
}[],
}[],
"parameters": ReflectionParameter[],
};


Expand Down
22 changes: 11 additions & 11 deletions try-slang.ts
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,8 @@ async function execFrame(timeMS: number) {

let resource = allocatedResources.get(command.resourceName);
if (resource instanceof GPUBuffer) {
size = [resource.size / 4, 1, 1];
let elementSize = command.elementSize || 4;
size = [resource.size / elementSize, 1, 1];
}
else if (resource instanceof GPUTexture) {
size = [resource.width, resource.height, 1];
Expand Down Expand Up @@ -496,7 +497,11 @@ function checkShaderType(userSource: string) {
}

export type ParsedCommand = {
"type": "ZEROS" | "RAND",
"type": "ZEROS",
"count": number,
"elementSize": number,
} | {
"type": "RAND",
"count": number,
} | {
"type": "BLACK",
Expand All @@ -521,7 +526,7 @@ async function processResourceCommands(pipeline: ComputePipeline | GraphicsPipel

for (const { resourceName, parsedCommand } of resourceCommands) {
if (parsedCommand.type === "ZEROS") {
const elementSize = 4; // Assuming 4 bytes per element (e.g., float) TODO: infer from type.
const elementSize = parsedCommand.elementSize;
const bindingInfo = resourceBindings.get(resourceName);
if (!bindingInfo) {
throw new Error(`Resource ${resourceName} is not defined in the bindings.`);
Expand All @@ -539,12 +544,7 @@ async function processResourceCommands(pipeline: ComputePipeline | GraphicsPipel
safeSet(allocatedResources, resourceName, buffer);

// Initialize the buffer with zeros.
let zeros: BufferSource;
if (elementSize == 4) {
zeros = new Float32Array(parsedCommand.count);
} else {
throw new Error("Element size isn't handled");
}
let zeros: BufferSource = new Uint8Array(parsedCommand.count * elementSize);
pipeline.device.queue.writeBuffer(buffer, 0, zeros);
} else if (parsedCommand.type === "BLACK") {
const size = parsedCommand.width * parsedCommand.height;
Expand Down Expand Up @@ -703,7 +703,7 @@ async function processResourceCommands(pipeline: ComputePipeline | GraphicsPipel
}
} else {
// exhaustiveness check
let x: never = parsedCommand.type;
let x: never = parsedCommand;
throw new Error("Invalid resource command type");
}
}
Expand Down Expand Up @@ -785,7 +785,7 @@ export let onRun = () => {
resourceCommands = getCommandsFromAttributes(ret.reflection);

try {
callCommands = parseCallCommands(userSource);
callCommands = parseCallCommands(userSource, ret.reflection);
}
catch (error: any) {
throw new Error("Error while parsing '//! CALL' commands: " + error.message);
Expand Down
79 changes: 73 additions & 6 deletions util.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { ReflectionJSON } from './compiler.js';
import { ReflectionJSON, ReflectionType } from './compiler.js';
import { ParsedCommand } from './try-slang.js';

export function configContext(device: GPUDevice, canvas: HTMLCanvasElement) {
Expand Down Expand Up @@ -42,6 +42,42 @@ function reinterpretUint32AsFloat(uint32: number) {
return float32View[0];
}

function roundUpToNearest(x: number, nearest: number){
return Math.ceil(x / nearest) * nearest;
}

function getSize(reflectionType: ReflectionType): number {
if(reflectionType.kind == "resource") {
throw new Error("unimplemented");
} else if(reflectionType.kind == "scalar") {
const bitsMatch = reflectionType.scalarType.match(/\d+$/);
if(bitsMatch == null) {
throw new Error("Could not get bit count out of scalar type");
}
return parseInt(bitsMatch[0]) / 8;
} else if(reflectionType.kind == "struct") {
const alignment = reflectionType.fields.map((f) => {
if(f.binding.kind == "uniform") return f.binding.size;
else throw new Error("Invalid state")
}).reduce((a, b) => Math.max(a, b));

const unalignedSize = reflectionType.fields.map((f) => {
if(f.binding.kind == "uniform") return f.binding.offset + f.binding.size;
else throw new Error("Invalid state")
}).reduce((a, b) => Math.max(a, b));

return roundUpToNearest(unalignedSize, alignment);
} else if(reflectionType.kind == "vector") {
if(reflectionType.elementCount == 3) {
return 4 * getSize(reflectionType.elementType);
}
return reflectionType.elementCount * getSize(reflectionType.elementType);
} else {
let x:never = reflectionType;
throw new Error("Cannot get size of unrecognized reflection type");
}
}

/**
* Here are some patterns we support:
*
Expand All @@ -63,20 +99,41 @@ export function getCommandsFromAttributes(reflection: ReflectionJSON): { resourc
if (!attribute.name.startsWith("playground_")) continue;

let playground_attribute_name = attribute.name.slice(11);
if (playground_attribute_name == "ZEROS" || playground_attribute_name == "RAND") {
if (playground_attribute_name == "ZEROS") {
if(parameter.type.kind != "resource" || parameter.type.baseShape != "structuredBuffer") {
throw new Error(`ZEROS attribute cannot be applied to ${parameter.name}, it only supports buffers`)
}
command = {
type: playground_attribute_name,
count: attribute.arguments[0] as number,
elementSize: getSize(parameter.type.resultType),
};
} else if (playground_attribute_name == "RAND") {
if(parameter.type.kind != "resource" || parameter.type.baseShape != "structuredBuffer") {
throw new Error(`RAND attribute cannot be applied to ${parameter.name}, it only supports buffers`)
}
if(parameter.type.resultType.kind != "scalar" || parameter.type.resultType.scalarType != "float32") {
throw new Error(`RAND attribute cannot be applied to ${parameter.name}, it only supports float buffers`)
}
command = {
type: playground_attribute_name,
count: attribute.arguments[0] as number
};
} else if (playground_attribute_name == "BLACK") {
if(parameter.type.kind != "resource" || parameter.type.baseShape != "texture2D") {
throw new Error(`BLACK attribute cannot be applied to ${parameter.name}, it only supports 2D textures`)
}
command = {
type: "BLACK",
type: playground_attribute_name,
width: attribute.arguments[0] as number,
height: attribute.arguments[1] as number,
};
} else if (playground_attribute_name == "URL") {
if(parameter.type.kind != "resource" || parameter.type.baseShape != "texture2D") {
throw new Error(`URL attribute cannot be applied to ${parameter.name}, it only supports 2D textures`)
}
command = {
type: "URL",
type: playground_attribute_name,
url: attribute.arguments[0] as string,
};
}
Expand All @@ -97,13 +154,14 @@ export type CallCommand = {
type: "RESOURCE_BASED",
fnName: string,
resourceName: string,
elementSize?: number,
} | {
type: "FIXED_SIZE",
fnName: string,
size: number[],
};

export function parseCallCommands(userSource: string): CallCommand[] {
export function parseCallCommands(userSource: string, reflection: ReflectionJSON): CallCommand[] {
// Look for commands of the form:
//
// 1. //! CALL(fn-name, SIZE_OF(<resource-name>)) ==> Dispatch a compute pass with the given
Expand All @@ -122,7 +180,16 @@ export function parseCallCommands(userSource: string): CallCommand[] {
const args = match[2].split(',').map(arg => arg.trim());

if (args[0].startsWith("SIZE_OF")) {
callCommands.push({ type: "RESOURCE_BASED", fnName, resourceName: args[0].slice(8, -1) });
let resourceName = args[0].slice(8, -1);
let resourceReflection = reflection.parameters.find((param) => param.name == resourceName);
if(resourceReflection == undefined) {
throw new Error(`Cannot find resource ${resourceName} for ${fnName} CALL command`)
}
let elementSize: number | undefined = undefined;
if(resourceReflection.type.kind == "resource" && resourceReflection.type.baseShape == "structuredBuffer") {
elementSize = getSize(resourceReflection.type.resultType);
}
callCommands.push({ type: "RESOURCE_BASED", fnName, resourceName, elementSize });
}
else {
callCommands.push({ type: "FIXED_SIZE", fnName, size: args.map(Number) });
Expand Down

0 comments on commit 7e75e52

Please sign in to comment.