Skip to content

Commit

Permalink
Support entrypoints defined in a namespace.
Browse files Browse the repository at this point in the history
  • Loading branch information
csyonghe committed Sep 4, 2024
1 parent 599dae5 commit cd67908
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 71 deletions.
72 changes: 7 additions & 65 deletions source/slang/slang-check-shader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -235,75 +235,17 @@ namespace Slang
Name* name,
DiagnosticSink* sink)
{
auto translationUnitSyntax = translationUnit->getModuleDecl();
FuncDecl* entryPointFuncDecl = nullptr;
auto declRef = translationUnit->findDeclFromString(getText(name), sink);
FuncDecl* entryPointFuncDecl = declRef.as<FuncDecl>().getDecl();

for (auto globalScope = translationUnit->getModuleDecl()->ownedScope; globalScope; globalScope = globalScope->nextSibling)
{
if (globalScope->containerDecl != translationUnitSyntax && globalScope->containerDecl->parentDecl != translationUnitSyntax)
continue; // Skip scopes that aren't part of the current module.

// We will look up any global-scope declarations in the translation
// unit that match the name of our entry point.
Decl* firstDeclWithName = nullptr;
if (!globalScope->containerDecl->getMemberDictionary().tryGetValue(name, firstDeclWithName))
{
// If there doesn't appear to be any such declaration, then we are done with this scope.
continue;
}

// We found at least one global-scope declaration with the right name,
// but (1) it might not be a function, and (2) there might be
// more than one function.
//
// We'll walk the linked list of declarations with the same name,
// to see what we find. Along the way we'll keep track of the
// first function declaration we find, if any:
for (auto ee = firstDeclWithName; ee; ee = ee->nextInContainerWithSameName)
{
// Is this declaration a function?
if (auto funcDecl = as<FuncDecl>(ee))
{
// Skip non-primary declarations, so that
// we don't give an error when an entry
// point is forward-declared.
if (!isPrimaryDecl(funcDecl))
continue;

// is this the first one we've seen?
if (!entryPointFuncDecl)
{
// If so, this is a candidate to be
// the entry point function.
entryPointFuncDecl = funcDecl;
}
else
{
// Uh-oh! We've already seen a function declaration with this
// name before, so the whole thing is ambiguous. We need
// to diagnose and bail out.

sink->diagnose(translationUnitSyntax, Diagnostics::ambiguousEntryPoint, name);

// List all of the declarations that the user *might* mean
for (auto ff = firstDeclWithName; ff; ff = ff->nextInContainerWithSameName)
{
if (auto candidate = as<FuncDecl>(ff))
{
sink->diagnose(candidate, Diagnostics::entryPointCandidate, candidate->getName());
}
}

// Bail out.
return nullptr;
}
}
}
}
if (entryPointFuncDecl && getModule(entryPointFuncDecl) != translationUnit)
entryPointFuncDecl = nullptr;

if (!entryPointFuncDecl)
{
auto translationUnitSyntax = translationUnit->getModuleDecl();
sink->diagnose(translationUnitSyntax, Diagnostics::entryPointFunctionNotFound, name);

}
return entryPointFuncDecl;
}

Expand Down
18 changes: 13 additions & 5 deletions source/slang/slang-compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2535,22 +2535,31 @@ namespace Slang
{
if (m_entryPoints.getCount() > 0)
return;

for (auto globalDecl : m_moduleDecl->members)
_discoverEntryPointsImpl(m_moduleDecl, sink, targets);
}
void Module::_discoverEntryPointsImpl(ContainerDecl* containerDecl, DiagnosticSink* sink, const List<RefPtr<TargetRequest>>& targets)
{
for (auto globalDecl : containerDecl->members)
{
auto maybeFuncDecl = globalDecl;
if (auto genericDecl = as<GenericDecl>(maybeFuncDecl))
{
maybeFuncDecl = genericDecl->inner;
}

if (as<NamespaceDeclBase>(globalDecl) || as<FileDecl>(globalDecl) || as<StructDecl>(globalDecl))
{
_discoverEntryPointsImpl(as<ContainerDecl>(globalDecl), sink, targets);
continue;
}

auto funcDecl = as<FuncDecl>(maybeFuncDecl);
if (!funcDecl)
continue;

Profile profile;
bool resolvedStageOfProfileWithEntryPoint = resolveStageOfProfileWithEntryPoint(profile, getLinkage()->m_optionSet, targets, funcDecl, sink);
if(!resolvedStageOfProfileWithEntryPoint)
if (!resolvedStageOfProfileWithEntryPoint)
{
// If there isn't a [shader] attribute, look for a [numthreads] attribute
// since that implicitly means a compute shader. We'll not do this when compiling for
Expand All @@ -2560,7 +2569,7 @@ namespace Slang
bool allTargetsCUDARelated = true;
for (auto target : targets)
{
if (!isCUDATarget(target) &&
if (!isCUDATarget(target) &&
target->getTarget() != CodeGenTarget::PyTorchCppBinding)
{
allTargetsCUDARelated = false;
Expand Down Expand Up @@ -2614,6 +2623,5 @@ namespace Slang
_addEntryPoint(entryPoint);
}
}

}

2 changes: 2 additions & 0 deletions source/slang/slang-compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -1568,6 +1568,8 @@ namespace Slang
void _collectShaderParams();

void _discoverEntryPoints(DiagnosticSink* sink, const List<RefPtr<TargetRequest>>& targets);
void _discoverEntryPointsImpl(ContainerDecl* containerDecl, DiagnosticSink* sink, const List<RefPtr<TargetRequest>>& targets);


class ModuleSpecializationInfo : public SpecializationInfo
{
Expand Down
9 changes: 8 additions & 1 deletion source/slang/slang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2328,7 +2328,14 @@ DeclRef<Decl> ComponentType::findDeclFromString(
{
result = declRefExpr->declRef;
}

else if (auto overloadedExpr = as<OverloadedExpr>(checkedExpr))
{
sink->diagnose(SourceLoc(), Diagnostics::ambiguousReference, name);
for (auto candidate : overloadedExpr->lookupResult2)
{
sink->diagnose(candidate.declRef.getDecl(), Diagnostics::overloadCandidate, candidate.declRef);
}
}
m_decls[name] = result;
return result;
}
Expand Down
20 changes: 20 additions & 0 deletions tests/language-feature/namespaces/entrypoint-in-namespace.slang
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
//TEST:SIMPLE(filecheck=CHECK): -target spirv -entry Pixel.MyType.Main -stage fragment

// Test that we can compile an entrypoint defined in a namespace.

// CHECK: OpEntryPoint
struct PSInput
{
float4 color : COLOR;
};

namespace Pixel
{
struct MyType
{
static float4 Main(PSInput input) : SV_TARGET
{
return input.color;
}
}
}

0 comments on commit cd67908

Please sign in to comment.