Skip to content

Commit

Permalink
Fix entrypoint auto discovery logic. (#5885)
Browse files Browse the repository at this point in the history
* Fix entrypoint auto discovery logic.

* format code

---------

Co-authored-by: slangbot <[email protected]>
  • Loading branch information
csyonghe and slangbot authored Dec 17, 2024
1 parent 7ffc69d commit 49e912a
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 13 deletions.
28 changes: 15 additions & 13 deletions source/slang/slang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5171,22 +5171,24 @@ IArtifact* ComponentType::getTargetArtifact(Int targetIndex, slang::IBlob** outD
entryPointsDiscovered = true;
}
}
// If no entry points were discovered, then we should return nullptr.
if (!entryPointsDiscovered)
{
return nullptr;
}

RefPtr<CompositeComponentType> composite = new CompositeComponentType(linkage, components);
ComPtr<IComponentType> linkedComponentType;
SLANG_RETURN_NULL_ON_FAIL(composite->link(linkedComponentType.writeRef(), outDiagnostics));
auto targetArtifact = static_cast<ComponentType*>(linkedComponentType.get())
->getTargetArtifact(targetIndex, outDiagnostics);
if (targetArtifact)
// If any entry points were discovered, then we should emit the program with entrypoints
// linked.
if (entryPointsDiscovered)
{
m_targetArtifacts[targetIndex] = targetArtifact;
RefPtr<CompositeComponentType> composite =
new CompositeComponentType(linkage, components);
ComPtr<IComponentType> linkedComponentType;
SLANG_RETURN_NULL_ON_FAIL(
composite->link(linkedComponentType.writeRef(), outDiagnostics));
auto targetArtifact = static_cast<ComponentType*>(linkedComponentType.get())
->getTargetArtifact(targetIndex, outDiagnostics);
if (targetArtifact)
{
m_targetArtifacts[targetIndex] = targetArtifact;
}
return targetArtifact;
}
return targetArtifact;
}

auto target = linkage->targets[targetIndex];
Expand Down
56 changes: 56 additions & 0 deletions tools/slang-unit-test/unit-test-cuda-compile.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// unit-test-cuda-compile.cpp

#include "../../source/core/slang-io.h"
#include "../../source/core/slang-process.h"
#include "slang-com-ptr.h"
#include "slang.h"
#include "unit-test/slang-unit-test.h"

#include <stdio.h>
#include <stdlib.h>

using namespace Slang;

// Test that the compilation API can be used to produce CUDA source.

SLANG_UNIT_TEST(CudaCompile)
{
// Source for a module that contains an undecorated entrypoint.
const char* userSourceBody = R"(
[CudaDeviceExport]
float testExportedFunc(float3 particleRayOrigin)
{
return dot(particleRayOrigin,particleRayOrigin);
};
)";

auto moduleName = "moduleG" + String(Process::getId());
ComPtr<slang::IGlobalSession> globalSession;
SLANG_CHECK(slang_createGlobalSession(SLANG_API_VERSION, globalSession.writeRef()) == SLANG_OK);
slang::TargetDesc targetDesc = {};
targetDesc.format = SLANG_CUDA_SOURCE;
slang::SessionDesc sessionDesc = {};
sessionDesc.targetCount = 1;
sessionDesc.targets = &targetDesc;
ComPtr<slang::ISession> session;
SLANG_CHECK(globalSession->createSession(sessionDesc, session.writeRef()) == SLANG_OK);

ComPtr<slang::IBlob> diagnosticBlob;
auto module = session->loadModuleFromSourceString(
"m",
"m.slang",
userSourceBody,
diagnosticBlob.writeRef());
SLANG_CHECK(module != nullptr);

ComPtr<slang::IComponentType> linkedProgram;
module->link(linkedProgram.writeRef(), diagnosticBlob.writeRef());
SLANG_CHECK(linkedProgram != nullptr);

ComPtr<slang::IBlob> code;
linkedProgram->getTargetCode(0, code.writeRef(), diagnosticBlob.writeRef());
SLANG_CHECK(code != nullptr);
SLANG_CHECK(code->getBufferSize() != 0);
String text = String((char*)code->getBufferPointer());
SLANG_CHECK(text.indexOf("testExportedFunc") > 0);
}

0 comments on commit 49e912a

Please sign in to comment.