Skip to content

Commit

Permalink
improve runnable entry point code
Browse files Browse the repository at this point in the history
  • Loading branch information
Devon7925 committed Jan 14, 2025
1 parent f82c134 commit 6dad1a8
Show file tree
Hide file tree
Showing 4 changed files with 158 additions and 310 deletions.
98 changes: 48 additions & 50 deletions compiler.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import {SpirvTools, default as spirvTools} from "./spirv-tools.js";
import { SpirvTools, default as spirvTools } from "./spirv-tools.js";
import { ModuleType } from './try-slang.js';
import type { ComponentType, EmbindString, GlobalSession, Module, ProgramLayout, Session, ThreadGroupSize, VariableLayoutReflection } from './slang-wasm.js';
import { playgroundSource } from "./playgroundShader.js";
Expand All @@ -10,7 +10,12 @@ export function isWholeProgramTarget(compileTarget: string) {
return compileTarget == "METAL" || compileTarget == "SPIRV";
}

const imageMainSource = `
export const RUNNABLE_ENTRY_POINT_NAMES = ['imageMain', 'printMain'] as const;
export type RunnableShaderType = typeof RUNNABLE_ENTRY_POINT_NAMES[number]
export type ShaderType = RunnableShaderType | null

const RUNNABLE_ENTRY_POINT_SOURCE_MAP: { [key in RunnableShaderType]: string } = {
'imageMain': `
import user;
import playground;
Expand All @@ -34,10 +39,8 @@ void imageMain(uint3 dispatchThreadID : SV_DispatchThreadID)
outputTexture.Store(dispatchThreadID.xy, color);
}
`;


const printMainSource = `
`,
'printMain': `
import user;
import playground;
Expand All @@ -52,7 +55,9 @@ void printMain(uint3 dispatchThreadID : SV_DispatchThreadID)
{
printMain();
}
`;
`,
}

type BindingDescriptor = {
storageTexture: {
access: "write-only" | "read-write",
Expand Down Expand Up @@ -91,18 +96,14 @@ export class SlangCompiler {
static SLANG_STAGE_FRAGMENT = 5;
static SLANG_STAGE_COMPUTE = 6;

static RENDER_SHADER = 0;
static PRINT_SHADER = 1;
static NON_RUNNABLE_SHADER = 2;

globalSlangSession: GlobalSession | null = null;
// slangSession = null;

compileTargetMap: { name: string, value: number }[] | null = null;

slangWasmModule;
diagnosticsMsg;
shaderType;
shaderType: ShaderType;

spirvToolsModule: SpirvTools | null = null;

Expand All @@ -111,9 +112,10 @@ export class SlangCompiler {
constructor(module: ModuleType) {
this.slangWasmModule = module;
this.diagnosticsMsg = "";
this.shaderType = SlangCompiler.NON_RUNNABLE_SHADER;
this.mainModules.set('imageMain', { source: imageMainSource });
this.mainModules.set('printMain', { source: printMainSource });
this.shaderType = null;
for (let runnableEntryPoint of RUNNABLE_ENTRY_POINT_NAMES) {
this.mainModules.set(runnableEntryPoint, { source: RUNNABLE_ENTRY_POINT_SOURCE_MAP[runnableEntryPoint] });
}
FS.createDataFile("/", "user.slang", new DataView(new ArrayBuffer(0)), true, true, false);
FS.createDataFile("/", "playground.slang", new DataView(new ArrayBuffer(0)), true, true, false);
}
Expand Down Expand Up @@ -149,15 +151,10 @@ export class SlangCompiler {

// In our playground, we only allow to run shaders with two entry points: renderMain and printMain
findRunnableEntryPoint(module: Module) {
const runnableEntryPointNames = ['imageMain', 'printMain'];
for (let i = 0; i < runnableEntryPointNames.length; i++) {
const entryPointName = runnableEntryPointNames[i];
for (let entryPointName of RUNNABLE_ENTRY_POINT_NAMES) {
let entryPoint = module.findAndCheckEntryPoint(entryPointName, SlangCompiler.SLANG_STAGE_COMPUTE);
if (entryPoint) {
if (i == 0)
this.shaderType = SlangCompiler.RENDER_SHADER;
else
this.shaderType = SlangCompiler.PRINT_SHADER;
this.shaderType = entryPointName;
return entryPoint;
}
}
Expand Down Expand Up @@ -187,8 +184,7 @@ export class SlangCompiler {
}

async initSpirvTools() {
if (!this.spirvToolsModule)
{
if (!this.spirvToolsModule) {
this.spirvToolsModule = await spirvTools();
}
}
Expand Down Expand Up @@ -218,11 +214,10 @@ export class SlangCompiler {
// we will also add them to the dropdown list.
findDefinedEntryPoints(shaderSource: string): string[] {
let result: string[] = [];
if (shaderSource.match("imageMain")) {
result.push("imageMain");
}
if (shaderSource.match("printMain")) {
result.push("printMain");
for (let entryPointName of RUNNABLE_ENTRY_POINT_NAMES) {
if (shaderSource.match(entryPointName)) {
result.push(entryPointName);
}
}
let slangSession: Session | null | undefined;
try {
Expand All @@ -232,9 +227,13 @@ export class SlangCompiler {
return [];
}
let module: Module | null = null;

if (result.length > 0) {
slangSession.loadModuleFromSource(playgroundSource, "playground", "/playground.slang");
}
module = slangSession.loadModuleFromSource(shaderSource, "user", "/user.slang");
if (!module) {
const error = this.slangWasmModule.getLastError();
console.error(error.type + " error: " + error.message);
return result;
}

Expand All @@ -256,16 +255,16 @@ export class SlangCompiler {
// If user entrypoint name imageMain or printMain, we will load the pre-built main modules because they
// are defined in those modules. Otherwise, we will only need to load the user module and find the entry
// point in the user module.
shouldLoadMainModule(entryPointName: string) {
return entryPointName == "imageMain" || entryPointName == "printMain";
isRunnableEntryPoint(entryPointName: string): entryPointName is RunnableShaderType {
return RUNNABLE_ENTRY_POINT_NAMES.includes(entryPointName as any);
}

// Since we will not let user to change the entry point code, we can precompile the entry point module
// and reuse it for every compilation.

compileEntryPointModule(slangSession: Session, moduleName: string) {
let source = this.mainModules.get(moduleName)?.source;
if(source == undefined) {
if (source == undefined) {
throw new Error(`Could not get module ${moduleName}`)
}
let module: Module | null = slangSession.loadModuleFromSource(source, moduleName, '/' + moduleName + '.slang');
Expand All @@ -287,12 +286,12 @@ export class SlangCompiler {
}

getPrecompiledProgram(slangSession: Session, moduleName: string) {
if (moduleName != "printMain" && moduleName != "imageMain")
if (!this.isRunnableEntryPoint(moduleName))
return null;

let mainModule = this.compileEntryPointModule(slangSession, moduleName);

this.shaderType = SlangCompiler.RENDER_SHADER;
this.shaderType = moduleName;
return mainModule;
}

Expand All @@ -306,23 +305,22 @@ export class SlangCompiler {
const count = userModule.getDefinedEntryPointCount();
for (let i = 0; i < count; i++) {
const name = userModule.getDefinedEntryPoint(i).getName();
if (name == "imageMain" || name == "printMain") {
this.diagnosticsMsg += ("error: Entry point name 'imageMain' or 'printMain' is reserved");
if (this.isRunnableEntryPoint(name)) {
this.diagnosticsMsg += `error: Entry point name ${name} is reserved`;
return false;
}
}

// If entry point is provided, we know for sure this is not a whole program compilation,
// so we will just go to find the correct module to include in the compilation.
if (entryPointName != "") {
if (this.shouldLoadMainModule(entryPointName)) {
if (this.isRunnableEntryPoint(entryPointName)) {
// we use the same entry point name as module name
const mainProgram = this.getPrecompiledProgram(slangSession, entryPointName);
if (!mainProgram)
return false;

this.shaderType = entryPointName == "imageMain" ?
SlangCompiler.RENDER_SHADER : SlangCompiler.PRINT_SHADER;
this.shaderType = entryPointName;

componentList.push(mainProgram.module);
componentList.push(mainProgram.entryPoint);
Expand All @@ -341,7 +339,7 @@ export class SlangCompiler {
else {
const results = this.findDefinedEntryPoints(shaderSource);
for (let i = 0; i < results.length; i++) {
if (results[i] == "imageMain" || results[i] == "printMain") {
if (this.isRunnableEntryPoint(results[i])) {
const mainProgram = this.getPrecompiledProgram(slangSession, results[i]);
if (!mainProgram)
return false;
Expand All @@ -361,10 +359,10 @@ export class SlangCompiler {
return true;
}

getBindingDescriptor(index: number, programReflection: ProgramLayout, parameter: VariableLayoutReflection): BindingDescriptor|null {
getBindingDescriptor(index: number, programReflection: ProgramLayout, parameter: VariableLayoutReflection): BindingDescriptor | null {
const globalLayout = programReflection.getGlobalParamsTypeLayout();

if(globalLayout == null) {
if (globalLayout == null) {
throw new Error("Could not get layout")
}

Expand Down Expand Up @@ -396,7 +394,7 @@ export class SlangCompiler {
getResourceBindings(linkedProgram: ComponentType): Bindings {
const reflection: ProgramLayout | null = linkedProgram.getLayout(0); // assume target-index = 0

if(reflection == null) {
if (reflection == null) {
throw new Error("Could not get reflection!")
}

Expand All @@ -405,7 +403,7 @@ export class SlangCompiler {
let resourceDescriptors = new Map();
for (let i = 0; i < count; i++) {
const parameter = reflection.getParameterByIndex(i);
if(parameter == null) {
if (parameter == null) {
throw new Error("Invalid state!")
}
const name = parameter.getName();
Expand Down Expand Up @@ -438,10 +436,10 @@ export class SlangCompiler {
return true;
}

compile(shaderSource: string, entryPointName: string, compileTargetStr: string): null | [string, Bindings, any, ReflectionJSON, ThreadGroupSize | {x: number, y: number, z: number}] {
compile(shaderSource: string, entryPointName: string, compileTargetStr: string): null | [string, Bindings, any, ReflectionJSON, ThreadGroupSize | { x: number, y: number, z: number }] {
this.diagnosticsMsg = "";

let shouldLinkPlaygroundModule = (shaderSource.match(/printMain|imageMain/) != null);
let shouldLinkPlaygroundModule = RUNNABLE_ENTRY_POINT_NAMES.some((entry_point) => shaderSource.match(entry_point) != null);

const compileTarget = this.findCompileTarget(compileTargetStr);
let isWholeProgram = isWholeProgramTarget(compileTargetStr);
Expand All @@ -452,7 +450,7 @@ export class SlangCompiler {
}

try {
if(this.globalSlangSession == null) {
if (this.globalSlangSession == null) {
throw new Error("Slang session not available. Maybe the compiler hasn't been initialized yet?")
}
let slangSession = this.globalSlangSession.createSession(compileTarget);
Expand All @@ -476,7 +474,7 @@ export class SlangCompiler {
if (this.addActiveEntryPoints(slangSession, shaderSource, entryPointName, isWholeProgram, components[userModuleIndex], components) == false)
return null;
let program: ComponentType = slangSession.createCompositeComponentType(components);
let linkedProgram:ComponentType = program.link();
let linkedProgram: ComponentType = program.link();
let hashedStrings = linkedProgram.loadStrings();

let outCode: string;
Expand Down Expand Up @@ -519,7 +517,7 @@ export class SlangCompiler {
} catch (e) {
console.error(e);
// typescript is missing the type for WebAssembly.Exception
if(typeof e === 'object' && e !== null && e.constructor.name === 'Exception') {
if (typeof e === 'object' && e !== null && e.constructor.name === 'Exception') {
this.diagnosticsMsg += "Slang internal error occurred.\n";
} else if (e instanceof Error) {
this.diagnosticsMsg += e.message;
Expand Down
33 changes: 0 additions & 33 deletions image_demo.ts

This file was deleted.

77 changes: 0 additions & 77 deletions test.cpp

This file was deleted.

Loading

0 comments on commit 6dad1a8

Please sign in to comment.