Files
bgfx/3rdparty/spirv-tools/source/opt/trim_capabilities_pass.cpp
Бранимир Караџић efb7f66a0e Updated spirv-tools.
2023-07-30 16:29:11 -07:00

321 lines
11 KiB
C++

// Copyright (c) 2023 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "source/opt/trim_capabilities_pass.h"
#include <algorithm>
#include <array>
#include <cassert>
#include <functional>
#include <optional>
#include <queue>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "source/enum_set.h"
#include "source/enum_string_mapping.h"
#include "source/opt/ir_context.h"
#include "source/spirv_target_env.h"
#include "source/util/string_utils.h"
namespace spvtools {
namespace opt {
namespace {
constexpr uint32_t kVariableStorageClassIndex = 0;
constexpr uint32_t kTypeArrayTypeIndex = 0;
constexpr uint32_t kOpTypeScalarBitWidthIndex = 0;
constexpr uint32_t kTypePointerTypeIdInIdx = 1;
} // namespace
// ============== Begin opcode handler implementations. =======================
//
// Adding support for a new capability should only require adding a new handler,
// and updating the
// kSupportedCapabilities/kUntouchableCapabilities/kFordiddenCapabilities lists.
//
// Handler names follow the following convention:
// Handler_<Opcode>_<Capability>()
static std::optional<spv::Capability> Handler_OpVariable_StorageInputOutput16(
const Instruction* instruction) {
assert(instruction->opcode() == spv::Op::OpVariable &&
"This handler only support OpVariable opcodes.");
// This capability is only required if the variable as an Input/Output storage
// class.
spv::StorageClass storage_class = spv::StorageClass(
instruction->GetSingleWordInOperand(kVariableStorageClassIndex));
if (storage_class != spv::StorageClass::Input &&
storage_class != spv::StorageClass::Output) {
return std::nullopt;
}
// This capability is only required if the type involves a 16-bit component.
// Quick check: are 16-bit types allowed?
const CapabilitySet& capabilities =
instruction->context()->get_feature_mgr()->GetCapabilities();
if (!capabilities.contains(spv::Capability::Float16) &&
!capabilities.contains(spv::Capability::Int16)) {
return std::nullopt;
}
// We need to walk the type definition.
std::queue<uint32_t> instructions_to_visit;
instructions_to_visit.push(instruction->type_id());
const auto* def_use_mgr = instruction->context()->get_def_use_mgr();
while (!instructions_to_visit.empty()) {
const Instruction* item =
def_use_mgr->GetDef(instructions_to_visit.front());
instructions_to_visit.pop();
if (item->opcode() == spv::Op::OpTypePointer) {
instructions_to_visit.push(
item->GetSingleWordInOperand(kTypePointerTypeIdInIdx));
continue;
}
if (item->opcode() == spv::Op::OpTypeMatrix ||
item->opcode() == spv::Op::OpTypeVector ||
item->opcode() == spv::Op::OpTypeArray ||
item->opcode() == spv::Op::OpTypeRuntimeArray) {
instructions_to_visit.push(
item->GetSingleWordInOperand(kTypeArrayTypeIndex));
continue;
}
if (item->opcode() == spv::Op::OpTypeStruct) {
item->ForEachInOperand([&instructions_to_visit](const uint32_t* op_id) {
instructions_to_visit.push(*op_id);
});
continue;
}
if (item->opcode() != spv::Op::OpTypeInt &&
item->opcode() != spv::Op::OpTypeFloat) {
continue;
}
if (item->GetSingleWordInOperand(kOpTypeScalarBitWidthIndex) == 16) {
return spv::Capability::StorageInputOutput16;
}
}
return std::nullopt;
}
// Opcode of interest to determine capabilities requirements.
constexpr std::array<std::pair<spv::Op, OpcodeHandler>, 1> kOpcodeHandlers{{
{spv::Op::OpVariable, Handler_OpVariable_StorageInputOutput16},
}};
// ============== End opcode handler implementations. =======================
namespace {
ExtensionSet getExtensionsRelatedTo(const CapabilitySet& capabilities,
const AssemblyGrammar& grammar) {
ExtensionSet output;
const spv_operand_desc_t* desc = nullptr;
for (auto capability : capabilities) {
if (SPV_SUCCESS != grammar.lookupOperand(SPV_OPERAND_TYPE_CAPABILITY,
static_cast<uint32_t>(capability),
&desc)) {
continue;
}
for (uint32_t i = 0; i < desc->numExtensions; ++i) {
output.insert(desc->extensions[i]);
}
}
return output;
}
} // namespace
TrimCapabilitiesPass::TrimCapabilitiesPass()
: supportedCapabilities_(
TrimCapabilitiesPass::kSupportedCapabilities.cbegin(),
TrimCapabilitiesPass::kSupportedCapabilities.cend()),
forbiddenCapabilities_(
TrimCapabilitiesPass::kForbiddenCapabilities.cbegin(),
TrimCapabilitiesPass::kForbiddenCapabilities.cend()),
untouchableCapabilities_(
TrimCapabilitiesPass::kUntouchableCapabilities.cbegin(),
TrimCapabilitiesPass::kUntouchableCapabilities.cend()),
opcodeHandlers_(kOpcodeHandlers.cbegin(), kOpcodeHandlers.cend()) {}
void TrimCapabilitiesPass::addInstructionRequirements(
Instruction* instruction, CapabilitySet* capabilities,
ExtensionSet* extensions) const {
// Ignoring OpCapability instructions.
if (instruction->opcode() == spv::Op::OpCapability) {
return;
}
// First case: the opcode is itself gated by a capability.
{
const spv_opcode_desc_t* desc = {};
auto result =
context()->grammar().lookupOpcode(instruction->opcode(), &desc);
if (result == SPV_SUCCESS) {
addSupportedCapabilitiesToSet(desc->numCapabilities, desc->capabilities,
capabilities);
if (desc->minVersion <=
spvVersionForTargetEnv(context()->GetTargetEnv())) {
extensions->insert(desc->extensions,
desc->extensions + desc->numExtensions);
}
}
}
// Second case: one of the opcode operand is gated by a capability.
const uint32_t operandCount = instruction->NumOperands();
for (uint32_t i = 0; i < operandCount; i++) {
const auto& operand = instruction->GetOperand(i);
// No supported capability relies on a 2+-word operand.
if (operand.words.size() != 1) {
continue;
}
// No supported capability relies on a literal string operand.
if (operand.type == SPV_OPERAND_TYPE_LITERAL_STRING) {
continue;
}
const spv_operand_desc_t* desc = {};
auto result = context()->grammar().lookupOperand(operand.type,
operand.words[0], &desc);
if (result != SPV_SUCCESS) {
continue;
}
addSupportedCapabilitiesToSet(desc->numCapabilities, desc->capabilities,
capabilities);
if (desc->minVersion <= spvVersionForTargetEnv(context()->GetTargetEnv())) {
extensions->insert(desc->extensions,
desc->extensions + desc->numExtensions);
}
}
// Last case: some complex logic needs to be run to determine capabilities.
auto[begin, end] = opcodeHandlers_.equal_range(instruction->opcode());
for (auto it = begin; it != end; it++) {
const OpcodeHandler handler = it->second;
auto result = handler(instruction);
if (result.has_value()) {
capabilities->insert(*result);
}
}
}
std::pair<CapabilitySet, ExtensionSet>
TrimCapabilitiesPass::DetermineRequiredCapabilitiesAndExtensions() const {
CapabilitySet required_capabilities;
ExtensionSet required_extensions;
get_module()->ForEachInst([&](Instruction* instruction) {
addInstructionRequirements(instruction, &required_capabilities,
&required_extensions);
});
#if !defined(NDEBUG)
// Debug only. We check the outputted required capabilities against the
// supported capabilities list. The supported capabilities list is useful for
// API users to quickly determine if they can use the pass or not. But this
// list has to remain up-to-date with the pass code. If we can detect a
// capability as required, but it's not listed, it means the list is
// out-of-sync. This method is not ideal, but should cover most cases.
{
for (auto capability : required_capabilities) {
assert(supportedCapabilities_.contains(capability) &&
"Module is using a capability that is not listed as supported.");
}
}
#endif
return std::make_pair(std::move(required_capabilities),
std::move(required_extensions));
}
Pass::Status TrimCapabilitiesPass::TrimUnrequiredCapabilities(
const CapabilitySet& required_capabilities) const {
const FeatureManager* feature_manager = context()->get_feature_mgr();
CapabilitySet capabilities_to_trim;
for (auto capability : feature_manager->GetCapabilities()) {
// Forbidden capability completely prevents trimming. Early exit.
if (forbiddenCapabilities_.contains(capability)) {
return Pass::Status::SuccessWithoutChange;
}
// Some capabilities cannot be safely removed. Leaving them untouched.
if (untouchableCapabilities_.contains(capability)) {
continue;
}
// If the capability is unsupported, don't trim it.
if (!supportedCapabilities_.contains(capability)) {
continue;
}
if (required_capabilities.contains(capability)) {
continue;
}
capabilities_to_trim.insert(capability);
}
for (auto capability : capabilities_to_trim) {
context()->RemoveCapability(capability);
}
return capabilities_to_trim.size() == 0 ? Pass::Status::SuccessWithoutChange
: Pass::Status::SuccessWithChange;
}
Pass::Status TrimCapabilitiesPass::TrimUnrequiredExtensions(
const ExtensionSet& required_extensions) const {
const auto supported_extensions =
getExtensionsRelatedTo(supportedCapabilities_, context()->grammar());
bool modified_module = false;
for (auto extension : supported_extensions) {
if (!required_extensions.contains(extension)) {
modified_module = true;
context()->RemoveExtension(extension);
}
}
return modified_module ? Pass::Status::SuccessWithChange
: Pass::Status::SuccessWithoutChange;
}
Pass::Status TrimCapabilitiesPass::Process() {
auto[required_capabilities, required_extensions] =
DetermineRequiredCapabilitiesAndExtensions();
Pass::Status status = TrimUnrequiredCapabilities(required_capabilities);
// If no capabilities were removed, we have no extension to trim.
// Note: this is true because this pass only removes unused extensions caused
// by unused capabilities.
// This is not an extension trimming pass.
if (status == Pass::Status::SuccessWithoutChange) {
return status;
}
return TrimUnrequiredExtensions(required_extensions);
}
} // namespace opt
} // namespace spvtools