diff --git a/Tools/WinMLRunner/src/CommandLineArgs.cpp b/Tools/WinMLRunner/src/CommandLineArgs.cpp index 3a38dd9c..f3991ae9 100644 --- a/Tools/WinMLRunner/src/CommandLineArgs.cpp +++ b/Tools/WinMLRunner/src/CommandLineArgs.cpp @@ -21,10 +21,7 @@ void CommandLineArgs::PrintUsage() std::cout << " -GPU : run model on default GPU" << std::endl; std::cout << " -GPUHighPerformance : run model on GPU with highest performance" << std::endl; std::cout << " -GPUMinPower : run model on GPU with the least power" << std::endl; -#ifdef DXCORE_SUPPORTED_BUILD - std::cout << " -GPUAdapterName : run model on GPU specified by its name. NOTE: Please only use this flag on DXCore supported machines." - << std::endl; -#endif + std::cout << " -GPUAdapterName : run model on GPU specified by its name." << std::endl; std::cout << " -CreateDeviceOnClient : create the D3D device on the client and pass it to WinML to create session" << std::endl; std::cout << " -CreateDeviceInWinML : create the device inside WinML" << std::endl; std::cout << " -CPUBoundInput : bind the input to the CPU" << std::endl; @@ -101,10 +98,10 @@ CommandLineArgs::CommandLineArgs(const std::vector& args) { m_useGPUMinPower = true; } -#ifdef DXCORE_SUPPORTED_BUILD else if (_wcsicmp(args[i].c_str(), L"-GPUAdapterName") == 0) { CheckNextArgument(args, i); +#ifdef DXCORE_SUPPORTED_BUILD HMODULE library = nullptr; library = LoadLibrary(L"dxcore.dll"); if (!library) @@ -113,10 +110,10 @@ CommandLineArgs::CommandLineArgs(const std::vector& args) L"ERROR: DXCORE isn't supported on this machine. " L"GpuAdapterName flag should only be used with DXCore supported machines."); } +#endif m_adapterName = args[++i]; m_useGPU = true; } -#endif else if ((_wcsicmp(args[i].c_str(), L"-CreateDeviceOnClient") == 0)) { m_createDeviceOnClient = true; diff --git a/Tools/WinMLRunner/src/CommandLineArgs.h b/Tools/WinMLRunner/src/CommandLineArgs.h index 51a6ca00..cac1eeb8 100644 --- a/Tools/WinMLRunner/src/CommandLineArgs.h +++ b/Tools/WinMLRunner/src/CommandLineArgs.h @@ -31,9 +31,8 @@ class CommandLineArgs const std::wstring& ModelPath() const { return m_modelPath; } const std::wstring& PerIterationDataPath() const { return m_perIterationDataPath; } std::vector>& GetPerformanceFileMetadata() { return m_perfFileMetadata; } -#ifdef DXCORE_SUPPORTED_BUILD const std::wstring& GetGPUAdapterName() const { return m_adapterName; } -#endif + bool UseRGB() const { @@ -159,9 +158,7 @@ class CommandLineArgs std::wstring m_inputImageFolderPath; std::wstring m_csvData; std::wstring m_inputData; -#ifdef DXCORE_SUPPORTED_BUILD std::wstring m_adapterName; -#endif std::wstring m_perfOutputPath; std::wstring m_perIterationDataPath; uint32_t m_numIterations = 1; diff --git a/Tools/WinMLRunner/src/Run.cpp b/Tools/WinMLRunner/src/Run.cpp index caa27bdd..93644f44 100644 --- a/Tools/WinMLRunner/src/Run.cpp +++ b/Tools/WinMLRunner/src/Run.cpp @@ -144,9 +144,9 @@ HRESULT CreateSession(LearningModelSession& session, IDirect3DDevice& winrtDevic { return hresult_invalid_argument().code(); } -#ifdef DXCORE_SUPPORTED_BUILD + const std::wstring& adapterName = args.GetGPUAdapterName(); -#endif + try { if (deviceCreationLocation == DeviceCreationLocation::UserD3DDevice && deviceType != DeviceType::CPU) @@ -160,8 +160,29 @@ HRESULT CreateSession(LearningModelSession& session, IDirect3DDevice& winrtDevic switch (deviceType) { case DeviceType::DefaultGPU: - hr = factory->EnumAdapterByGpuPreference(0, DXGI_GPU_PREFERENCE_UNSPECIFIED, __uuidof(IDXGIAdapter), - adapter.put_void()); + if (adapterName.empty()) + { + hr = factory->EnumAdapterByGpuPreference(0, DXGI_GPU_PREFERENCE_UNSPECIFIED, + __uuidof(IDXGIAdapter), adapter.put_void()); + } + else + { + DXGI_ADAPTER_DESC desc; + + int adapterIndex = 0; + do + { + hr = factory->EnumAdapters(adapterIndex, adapter.put()); + + if (adapter) + { + adapter->GetDesc(&desc); + } + + adapterIndex++; + } while ((hr != DXGI_ERROR_NOT_FOUND) && + (wcsstr(desc.Description, adapterName.c_str()) == NULL)); + } break; case DeviceType::MinPowerGPU: hr = factory->EnumAdapterByGpuPreference(0, DXGI_GPU_PREFERENCE_MINIMUM_POWER,