diff --git a/source/slang-wasm/slang-wasm-bindings.cpp b/source/slang-wasm/slang-wasm-bindings.cpp index 360dec6eba..56a4724829 100644 --- a/source/slang-wasm/slang-wasm-bindings.cpp +++ b/source/slang-wasm/slang-wasm-bindings.cpp @@ -47,8 +47,14 @@ EMSCRIPTEN_BINDINGS(slang) "getEntryPointCode", &slang::wgsl::ComponentType::getEntryPointCode) .function( - "getEntryPointCodeSpirv", - &slang::wgsl::ComponentType::getEntryPointCodeSpirv); + "getEntryPointCodeBlob", + &slang::wgsl::ComponentType::getEntryPointCodeBlob) + .function( + "getTargetCodeBlob", + &slang::wgsl::ComponentType::getTargetCodeBlob) + .function( + "getTargetCode", + &slang::wgsl::ComponentType::getTargetCode); class_>("Module") .function( @@ -58,14 +64,25 @@ EMSCRIPTEN_BINDINGS(slang) .function( "findAndCheckEntryPoint", &slang::wgsl::Module::findAndCheckEntryPoint, - return_value_policy::take_ownership()); + return_value_policy::take_ownership()) + .function( + "getDefinedEntryPoint", + &slang::wgsl::Module::getDefinedEntryPoint, + return_value_policy::take_ownership()) + .function( + "getDefinedEntryPointCount", + &slang::wgsl::Module::getDefinedEntryPointCount); value_object("Error") .field("type", &slang::wgsl::Error::type) .field("result", &slang::wgsl::Error::result) .field("message", &slang::wgsl::Error::message); - class_>("EntryPoint"); + class_>("EntryPoint") + .function( + "getName", + &slang::wgsl::EntryPoint::getName, + allow_raw_pointers()); class_("CompileTargets") .function( diff --git a/source/slang-wasm/slang-wasm.cpp b/source/slang-wasm/slang-wasm.cpp index 886efc5f7b..8948c1075b 100644 --- a/source/slang-wasm/slang-wasm.cpp +++ b/source/slang-wasm/slang-wasm.cpp @@ -94,17 +94,15 @@ Session* GlobalSession::createSession(int compileTarget) return new Session(session); } -Module* Session::loadModuleFromSource(const std::string& slangCode) +Module* Session::loadModuleFromSource(const std::string& slangCode, const std::string& name, const std::string& path) { Slang::ComPtr module; { - const char * name = ""; - const char * path = ""; Slang::ComPtr diagnosticsBlob; Slang::ComPtr slangCodeBlob = Slang::RawBlob::create( slangCode.c_str(), slangCode.size()); module = m_interface->loadModuleFromSource( - name, path, slangCodeBlob, diagnosticsBlob.writeRef()); + name.c_str(), path.c_str(), slangCodeBlob, diagnosticsBlob.writeRef()); if (!module) { g_error.type = std::string("USER"); @@ -161,6 +159,38 @@ EntryPoint* Module::findAndCheckEntryPoint(const std::string& name, int stage) return new EntryPoint(entryPoint); } +int Module::getDefinedEntryPointCount() +{ + return moduleInterface()->getDefinedEntryPointCount(); +} + +EntryPoint* Module::getDefinedEntryPoint(int index) +{ + if (moduleInterface()->getDefinedEntryPointCount() <= index) + return nullptr; + + Slang::ComPtr entryPoint; + { + Slang::ComPtr diagnosticsBlob; + SlangResult result = moduleInterface()->getDefinedEntryPoint(index, entryPoint.writeRef()); + if (!SLANG_SUCCEEDED(result)) + { + g_error.type = std::string("USER"); + g_error.result = result; + + if (diagnosticsBlob->getBufferSize()) + { + char* diagnostics = (char*)diagnosticsBlob->getBufferPointer(); + g_error.message = std::string(diagnostics); + } + return nullptr; + } + } + + return new EntryPoint(entryPoint); +} + + ComponentType* Session::createCompositeComponentType( const std::vector& components) { @@ -235,9 +265,9 @@ std::string ComponentType::getEntryPointCode(int entryPointIndex, int targetInde return {}; } -// Since spirv code is binary, we can't return it as a string, we will need to use emscripten::val +// Since result code is binary, we can't return it as a string, we will need to use emscripten::val // to wrap it and return it to the javascript side. -emscripten::val ComponentType::getEntryPointCodeSpirv(int entryPointIndex, int targetIndex) +emscripten::val ComponentType::getEntryPointCodeBlob(int entryPointIndex, int targetIndex) { Slang::ComPtr kernelBlob; Slang::ComPtr diagnosticBlob; @@ -262,6 +292,60 @@ emscripten::val ComponentType::getEntryPointCodeSpirv(int entryPointIndex, int t ptr)); } +std::string ComponentType::getTargetCode(int targetIndex) +{ + { + Slang::ComPtr kernelBlob; + Slang::ComPtr diagnosticBlob; + SlangResult result = interface()->getTargetCode( + targetIndex, + kernelBlob.writeRef(), + diagnosticBlob.writeRef()); + if (result != SLANG_OK) + { + g_error.type = std::string("USER"); + g_error.result = result; + g_error.message = std::string( + (char*)diagnosticBlob->getBufferPointer(), + (char*)diagnosticBlob->getBufferPointer() + + diagnosticBlob->getBufferSize()); + return ""; + } + std::string targetCode = std::string( + (char*)kernelBlob->getBufferPointer(), + (char*)kernelBlob->getBufferPointer() + kernelBlob->getBufferSize()); + return targetCode; + } + + return {}; +} + +// Since result code is binary, we can't return it as a string, we will need to use emscripten::val +// to wrap it and return it to the javascript side. +emscripten::val ComponentType::getTargetCodeBlob(int targetIndex) +{ + Slang::ComPtr kernelBlob; + Slang::ComPtr diagnosticBlob; + SlangResult result = interface()->getTargetCode( + targetIndex, + kernelBlob.writeRef(), + diagnosticBlob.writeRef()); + if (result != SLANG_OK) + { + g_error.type = std::string("USER"); + g_error.result = result; + g_error.message = std::string( + (char*)diagnosticBlob->getBufferPointer(), + (char*)diagnosticBlob->getBufferPointer() + + diagnosticBlob->getBufferSize()); + return {}; + } + + const uint8_t* ptr = (uint8_t*)kernelBlob->getBufferPointer(); + return emscripten::val(emscripten::typed_memory_view(kernelBlob->getBufferSize(), + ptr)); +} + namespace lsp { Position translate(Slang::LanguageServerProtocol::Position p) diff --git a/source/slang-wasm/slang-wasm.h b/source/slang-wasm/slang-wasm.h index 5a299453cf..eb302119b7 100644 --- a/source/slang-wasm/slang-wasm.h +++ b/source/slang-wasm/slang-wasm.h @@ -48,7 +48,9 @@ class ComponentType ComponentType* link(); std::string getEntryPointCode(int entryPointIndex, int targetIndex); - emscripten::val getEntryPointCodeSpirv(int entryPointIndex, int targetIndex); + emscripten::val getEntryPointCodeBlob(int entryPointIndex, int targetIndex); + std::string getTargetCode(int targetIndex); + emscripten::val getTargetCodeBlob(int targetIndex); slang::IComponentType* interface() const {return m_interface;} @@ -62,9 +64,11 @@ class ComponentType class EntryPoint : public ComponentType { public: - EntryPoint(slang::IEntryPoint* interface) : ComponentType(interface) {} - + std::string getName() const + { + return entryPointInterface()->getFunctionReflection()->getName(); + } private: slang::IEntryPoint* entryPointInterface() const { @@ -80,6 +84,8 @@ class Module : public ComponentType EntryPoint* findEntryPointByName(const std::string& name); EntryPoint* findAndCheckEntryPoint(const std::string& name, int stage); + EntryPoint* getDefinedEntryPoint(int index); + int getDefinedEntryPointCount(); slang::IModule* moduleInterface() const { return static_cast(interface()); @@ -93,7 +99,8 @@ class Session Session(slang::ISession* interface) : m_interface(interface) {} - Module* loadModuleFromSource(const std::string& slangCode); + Module* loadModuleFromSource( + const std::string& slangCode, const std::string& name, const std::string& path); ComponentType* createCompositeComponentType( const std::vector& components); diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index b77a3efc8f..b3ac7f73d5 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -5040,13 +5040,21 @@ IArtifact* ComponentType::getTargetArtifact(Int targetIndex, slang::IBlob** outD }); List> components; components.add(this); + bool entryPointsDiscovered = false; for (auto module : modules) { for (auto entryPoint : module->getEntryPoints()) { components.add(entryPoint); + entryPointsDiscovered = true; } } + // If no entry points were discovered, then we should return nullptr. + if (!entryPointsDiscovered) + { + return nullptr; + } + RefPtr composite = new CompositeComponentType(linkage, components); ComPtr linkedComponentType; SLANG_RETURN_NULL_ON_FAIL(composite->link(linkedComponentType.writeRef(), outDiagnostics));