mirror of
https://github.com/bkaradzic/bgfx.git
synced 2026-02-17 20:52:36 +01:00
1012 lines
36 KiB
C++
1012 lines
36 KiB
C++
// Copyright 2025 The Khronos Group Inc.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
#include "fnvar.h"

#include <initializer_list>
#include <memory>
#include <sstream>
#include <unordered_set>

#include "source/opt/instruction.h"
|
|
|
|
namespace spvtools {
|
|
|
|
using opt::Function;
|
|
using opt::Instruction;
|
|
using opt::analysis::Type;
|
|
|
|
namespace {
|
|
// Helper functions
|
|
|
|
// Parses a CSV source string for the purpose of this extension.
//
// Required columns must be known in advance and supplied as the required_cols
// argument -- this is used for error checking. Values are assumed to be
// separated by CSV_SEP. The input source string is assumed to be the output of
// io::ReadTextFile and no other validation, apart from the CSV parsing, is
// performed.
//
// The first non-empty line is the header. Every following non-empty line must
// have exactly as many values as the header; its values are appended to
// |result| as one row. A trailing carriage return on a line (CRLF line
// endings) is ignored, so files saved with Windows line endings parse the same
// as LF files. A trailing separator yields an empty last value.
//
// Returns true on success, false on error (with error message stored in
// err_msg).
bool ParseCsv(const std::string& source,
              const std::vector<std::string>& required_cols,
              std::stringstream& err_msg,
              std::vector<std::vector<std::string>>& result) {
  constexpr char CSV_SEP = ',';

  std::stringstream fn_variants_csv_stream(source);
  std::string line;
  std::vector<std::string> columns;
  bool first_line = true;

  while (std::getline(fn_variants_csv_stream, line, '\n')) {
    // Drop the '\r' of a CRLF line ending; otherwise it sticks to the last
    // value (breaking the header check) and "\r"-only lines are not skipped.
    if (!line.empty() && line.back() == '\r') {
      line.pop_back();
    }
    if (line.empty()) {
      continue;
    }

    std::vector<std::string> vals;
    std::string val;
    std::stringstream line_stream(line);
    auto* vec = first_line ? &columns : &vals;

    while (std::getline(line_stream, val, CSV_SEP)) {
      vec->push_back(val);
    }

    // After a trailing separator ("a,b,") the final getline extracts nothing
    // and fails with |val| cleared; record that empty last value explicitly.
    if (!line_stream && val.empty()) {
      vec->push_back("");
    }

    if (!first_line) {
      if (vals.size() != columns.size()) {
        err_msg << "Number of values does not match the number of columns. "
                   "Offending line:\n"
                << line;
        return false;
      }
      result.push_back(std::move(vals));
    }

    first_line = false;
  }

  // check if required columns match actual columns (ordering matters)

  if (columns.size() != required_cols.size()) {
    err_msg << "Invalid number of CSV columns: " << columns.size()
            << ", expected " << required_cols.size() << ".";
    return false;
  }

  for (size_t i = 0; i < columns.size(); ++i) {
    if (columns[i] != required_cols[i]) {
      err_msg << "Invalid name of column " << i + 1 << ". Expected '"
              << required_cols[i] << "', got '" << columns[i] << "'.";
      return false;
    }
  }

  return true;
}
|
|
|
|
// Annotate ID with ConditionalINTEL decoration
|
|
void DecorateConditional(IRContext* context, uint32_t id_to_decorate,
|
|
uint32_t spec_const_id) {
|
|
auto decor_instr =
|
|
std::make_unique<Instruction>(context, spv::Op::OpDecorate);
|
|
decor_instr->AddOperand({SPV_OPERAND_TYPE_ID, {id_to_decorate}});
|
|
decor_instr->AddOperand({SPV_OPERAND_TYPE_DECORATION,
|
|
{uint32_t(spv::Decoration::ConditionalINTEL)}});
|
|
decor_instr->AddOperand({SPV_OPERAND_TYPE_ID, {spec_const_id}});
|
|
context->module()->AddAnnotationInst(std::move(decor_instr));
|
|
}
|
|
|
|
// Finds entry point corresponding to a function
|
|
//
|
|
// Returns null if not found, otherwise returns pointer to the EP Instruction.
|
|
Instruction* FindEntryPoint(const Instruction& fn_inst) {
|
|
auto* mod = fn_inst.context()->module();
|
|
for (auto& entry_point : mod->entry_points()) {
|
|
const int ep_i =
|
|
entry_point.opcode() == spv::Op::OpConditionalEntryPointINTEL ? 2 : 1;
|
|
if (entry_point.GetOperand(ep_i).AsId() == fn_inst.result_id()) {
|
|
return &entry_point;
|
|
}
|
|
}
|
|
return nullptr;
|
|
}
|
|
|
|
// If the function has an entry point, converts it to a conditional one
|
|
void ConvertEPToConditional(Module* module, const Function& fn,
|
|
uint32_t spec_const_id) {
|
|
for (const auto& ep_inst : module->entry_points()) {
|
|
if (ep_inst.opcode() == spv::Op::OpEntryPoint) {
|
|
auto* entry_point = FindEntryPoint(fn.DefInst());
|
|
if (entry_point != nullptr) {
|
|
std::vector<opt::Operand> old_operands;
|
|
for (auto operand : *entry_point) {
|
|
old_operands.push_back(operand);
|
|
}
|
|
entry_point->ToNop();
|
|
entry_point->SetOpcode(spv::Op::OpConditionalEntryPointINTEL);
|
|
entry_point->AddOperand({SPV_OPERAND_TYPE_ID, {spec_const_id}});
|
|
for (auto old_operand : old_operands) {
|
|
entry_point->AddOperand(old_operand);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Finds ID of a bool type (returns 0 if not found)
|
|
uint32_t FindIdOfBoolType(const Module* const mod) {
|
|
return mod->context()->get_type_mgr()->GetBoolTypeId();
|
|
}
|
|
|
|
// Combines IDs using OpSpecConstantOp with the operation defined by cmp_op.
|
|
//
|
|
// Returns the ID of the final result. If there are no IDs, returns 0. If there
|
|
// is one ID, does not generate any instructions and returns the ID.
|
|
uint32_t CombineIds(IRContext* const context, const std::vector<uint32_t>& ids,
|
|
spv::Op cmp_op) {
|
|
if (ids.empty()) {
|
|
return 0;
|
|
} else if (ids.size() == 1) {
|
|
return ids[0];
|
|
} else {
|
|
uint32_t bool_id = FindIdOfBoolType(context->module());
|
|
assert(bool_id != 0);
|
|
|
|
uint32_t prev_spec_const_id = ids[0];
|
|
|
|
for (size_t i = 1; i < ids.size(); ++i) {
|
|
const uint32_t id = ids[i];
|
|
const uint32_t spec_const_op_id = context->TakeNextId();
|
|
|
|
auto inst = std::make_unique<Instruction>(
|
|
context, spv::Op::OpSpecConstantOp, bool_id, spec_const_op_id,
|
|
std::initializer_list<opt::Operand>{
|
|
{SPV_OPERAND_TYPE_SPEC_CONSTANT_OP_NUMBER, {(uint32_t)(cmp_op)}},
|
|
{SPV_OPERAND_TYPE_ID, {prev_spec_const_id}},
|
|
{SPV_OPERAND_TYPE_ID, {id}}});
|
|
context->module()->AddType(std::move(inst));
|
|
|
|
prev_spec_const_id = spec_const_op_id;
|
|
}
|
|
|
|
return prev_spec_const_id;
|
|
}
|
|
}
|
|
|
|
// Returns whether instruction can be shared between variant modules and
|
|
// combined using spec constants (such as conditional capabilities).
|
|
bool CanBeFnVarCombined(const Instruction* inst) {
|
|
const spv::Op opcode = inst->opcode();
|
|
|
|
if ((opcode != spv::Op::OpExtInstImport) &&
|
|
(opcode != spv::Op::OpCapability) && (opcode != spv::Op::OpExtension) &&
|
|
!spvOpcodeGeneratesType(opcode)) {
|
|
return false;
|
|
}
|
|
|
|
if ((opcode == spv::Op::OpCapability) &&
|
|
((inst->GetSingleWordOperand(0) ==
|
|
static_cast<uint32_t>(spv::Capability::FunctionVariantsINTEL)) ||
|
|
(inst->GetSingleWordOperand(0) ==
|
|
static_cast<uint32_t>(spv::Capability::SpecConditionalINTEL)))) {
|
|
// Always enabled
|
|
return false;
|
|
}
|
|
|
|
if ((opcode == spv::Op::OpExtension) &&
|
|
(inst->GetOperand(0).AsString() == FNVAR_EXT_NAME)) {
|
|
// Always enabled
|
|
return false;
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
// Calculates hash of an instruction.
|
|
//
|
|
// Applicable only to instructions that can be combined (ie. with
|
|
// CanBeFnVarCombined being true) and from those, hash can be only computed for
|
|
// selected instructions. Computing hash from other instruction is unsupported.
|
|
size_t HashInst(const Instruction* inst) {
|
|
if (CanBeFnVarCombined(inst)) {
|
|
if (spvOpcodeGeneratesType(inst->opcode())) {
|
|
const Type* t =
|
|
inst->context()->get_type_mgr()->GetType(inst->result_id());
|
|
assert(t != nullptr);
|
|
return t->HashValue();
|
|
}
|
|
|
|
if (inst->opcode() == spv::Op::OpExtension) {
|
|
const auto name = inst->GetOperand(0).AsString();
|
|
return std::hash<std::string>()(name);
|
|
}
|
|
|
|
if (inst->opcode() == spv::Op::OpCapability) {
|
|
const auto cap = inst->GetSingleWordOperand(0);
|
|
return std::hash<uint32_t>()(cap);
|
|
}
|
|
|
|
if (inst->opcode() == spv::Op::OpExtInstImport) {
|
|
const auto name = inst->GetOperand(1).AsString();
|
|
return std::hash<std::string>()(name);
|
|
}
|
|
}
|
|
|
|
assert(false && "Unsupported instruction hash");
|
|
return std::hash<const Instruction*>()(inst);
|
|
}
|
|
|
|
std::string GetFnName(const Instruction& fn_inst) {
|
|
// Check entry point
|
|
const auto* ep_inst = FindEntryPoint(fn_inst);
|
|
if (ep_inst != nullptr) {
|
|
const int name_i =
|
|
ep_inst->opcode() == spv::Op::OpConditionalEntryPointINTEL ? 3 : 2;
|
|
return ep_inst->GetOperand(name_i).AsString();
|
|
}
|
|
|
|
// Check name of export linkage attribute decoration
|
|
const auto* decor_mgr = fn_inst.context()->get_decoration_mgr();
|
|
for (const auto* inst :
|
|
decor_mgr->GetDecorationsFor(fn_inst.result_id(), true)) {
|
|
const auto decoration = inst->GetOperand(1);
|
|
if ((decoration.type == SPV_OPERAND_TYPE_DECORATION) &&
|
|
(decoration.words.size() == 1) &&
|
|
(decoration.words[0] ==
|
|
static_cast<uint32_t>(spv::Decoration::LinkageAttributes))) {
|
|
const auto linkage = inst->GetOperand(3);
|
|
if ((linkage.type == SPV_OPERAND_TYPE_LINKAGE_TYPE) &&
|
|
(linkage.words.size() == 1) &&
|
|
(linkage.words[0] ==
|
|
static_cast<uint32_t>(spv::LinkageType::Export))) {
|
|
// decorates fn with LinkageAttribute and Export linkage type -> get the
|
|
// name
|
|
return inst->GetOperand(2).AsString();
|
|
}
|
|
}
|
|
}
|
|
|
|
return "";
|
|
}
|
|
|
|
uint32_t FindSpecConstByName(const Module* mod, std::string name) {
|
|
for (const auto* const_inst : mod->context()->GetConstants()) {
|
|
if (opt::IsSpecConstantInst(const_inst->opcode())) {
|
|
const auto id = const_inst->result_id();
|
|
for (const auto& name_inst : mod->debugs2()) {
|
|
if ((name_inst.opcode() == spv::Op::OpName) &&
|
|
(name_inst.GetOperand(0).AsId() == id) &&
|
|
(name_inst.GetOperand(1).AsString() == name)) {
|
|
return id;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return 0;
|
|
}
|
|
|
|
uint32_t CombineVariantDefs(const std::vector<VariantDef>& variant_defs,
|
|
const std::vector<size_t> var_ids,
|
|
IRContext* context,
|
|
std::map<std::vector<size_t>, uint32_t>& cache) {
|
|
assert(var_ids.size() <= variant_defs.size());
|
|
uint32_t spec_const_comb_id = 0;
|
|
if (var_ids.size() != variant_defs.size()) {
|
|
// if not used by all variants
|
|
if (cache.find(var_ids) == cache.end()) {
|
|
// cache variant combinations
|
|
std::vector<uint32_t> spec_const_ids;
|
|
for (const auto& var_id : var_ids) {
|
|
const auto var_name = variant_defs[var_id].GetName();
|
|
const auto var_spec_id =
|
|
FindSpecConstByName(context->module(), var_name);
|
|
spec_const_ids.push_back(var_spec_id);
|
|
}
|
|
spec_const_comb_id =
|
|
CombineIds(context, spec_const_ids, spv::Op::OpLogicalOr);
|
|
assert(spec_const_comb_id != 0);
|
|
cache.insert({var_ids, spec_const_comb_id});
|
|
} else {
|
|
spec_const_comb_id = cache[var_ids];
|
|
}
|
|
}
|
|
return spec_const_comb_id;
|
|
}
|
|
|
|
// Parses |s| as an unsigned decimal integer into |*x|.
//
// Only plain runs of the digits 0-9 are accepted; signs, whitespace and any
// other characters are rejected up front. Returns false for such input, for an
// empty string, or when the stream extraction into uint32_t fails (e.g. the
// value does not fit). On failure |*x| may still have been written.
bool strToInt(std::string s, uint32_t* x) {
  const bool digits_only =
      s.find_first_not_of("0123456789") == std::string::npos;
  if (!digits_only) {
    return false;
  }

  std::stringstream parser(s);
  parser >> *x;
  return !parser.fail();
}
|
|
|
|
} // anonymous namespace
|
|
|
|
bool VariantDefs::ProcessFnVar(const LinkerOptions& options,
|
|
const std::vector<Module*>& modules) {
|
|
assert(variant_defs_.empty());
|
|
assert(modules.size() == options.GetInFiles().size());
|
|
|
|
for (size_t i = 0; i < modules.size(); ++i) {
|
|
const auto* feat_mgr = modules[i]->context()->get_feature_mgr();
|
|
if ((feat_mgr->HasCapability(spv::Capability::FunctionVariantsINTEL)) ||
|
|
(feat_mgr->HasCapability(spv::Capability::SpecConditionalINTEL)) ||
|
|
(feat_mgr->HasExtension(kSPV_INTEL_function_variants))) {
|
|
// In principle, it can be done but it's complicated due to having to
|
|
// combine the existing conditionals with the new ones. For example,
|
|
// conditional capabilities would need to become "doubly-conditional".
|
|
err_ << "Creating multitarget modules from multitarget modules is not "
|
|
"supported. Offending file: "
|
|
<< options.GetInFiles()[i];
|
|
return false;
|
|
}
|
|
}
|
|
|
|
std::vector<std::vector<std::string>> target_rows;
|
|
std::vector<std::vector<std::string>> architecture_rows;
|
|
|
|
if (!options.GetFnVarTargetsCsv().empty()) {
|
|
const std::vector<std::string> tgt_cols = {"module", "target", "features"};
|
|
if (!ParseCsv(options.GetFnVarTargetsCsv(), tgt_cols, err_, target_rows)) {
|
|
return false;
|
|
}
|
|
}
|
|
|
|
if (!options.GetFnVarArchitecturesCsv().empty()) {
|
|
const std::vector<std::string> arch_cols = {"module", "category", "family",
|
|
"op", "architecture"};
|
|
if (!ParseCsv(options.GetFnVarArchitecturesCsv(), arch_cols, err_,
|
|
architecture_rows)) {
|
|
return false;
|
|
}
|
|
}
|
|
|
|
// check that all modules defined in the CSV exist
|
|
|
|
for (const auto& tgt_vals : target_rows) {
|
|
bool found = false;
|
|
for (const auto& in_file : options.GetInFiles()) {
|
|
if (tgt_vals[0] == in_file) {
|
|
found = true;
|
|
}
|
|
}
|
|
if (!found) {
|
|
err_ << "Module '" << tgt_vals[0]
|
|
<< "' found in targets CSV not passed to the CLI.";
|
|
return false;
|
|
}
|
|
}
|
|
|
|
for (const auto& arch_vals : architecture_rows) {
|
|
bool found = false;
|
|
for (const auto& in_file : options.GetInFiles()) {
|
|
if (arch_vals[0] == in_file) {
|
|
found = true;
|
|
}
|
|
}
|
|
if (!found) {
|
|
err_ << "Module '" << arch_vals[0]
|
|
<< "' found in architectures CSV not passed to the CLI.";
|
|
return false;
|
|
}
|
|
}
|
|
|
|
// create per-module variant defs
|
|
|
|
for (size_t i = 0; i < modules.size(); ++i) {
|
|
// first module passed to the CLI is considered the base module
|
|
bool is_base = i == 0;
|
|
const auto name = options.GetInFiles()[i];
|
|
auto variant_def = VariantDef(is_base, name, modules[i]);
|
|
|
|
for (const auto& arch_row : architecture_rows) {
|
|
const auto row_name = arch_row[0];
|
|
if (row_name == name) {
|
|
uint32_t category, family, op, architecture;
|
|
|
|
if (!strToInt(arch_row[1], &category)) {
|
|
err_ << "Error converting " << arch_row[1]
|
|
<< " to architecture category.";
|
|
return false;
|
|
}
|
|
if (!strToInt(arch_row[2], &family)) {
|
|
err_ << "Error converting " << arch_row[2]
|
|
<< " to architecture family.";
|
|
return false;
|
|
}
|
|
if (!strToInt(arch_row[3], &op)) {
|
|
err_ << "Error converting " << arch_row[3] << " to architecture op.";
|
|
return false;
|
|
}
|
|
if (!strToInt(arch_row[4], &architecture)) {
|
|
err_ << "Error converting " << arch_row[4] << " to architecture.";
|
|
return false;
|
|
}
|
|
|
|
variant_def.AddArchDef(category, family, op, architecture);
|
|
}
|
|
}
|
|
|
|
for (const auto& tgt_row : target_rows) {
|
|
const auto row_name = tgt_row[0];
|
|
if (row_name == name) {
|
|
uint32_t target;
|
|
std::vector<uint32_t> features;
|
|
|
|
if (!strToInt(tgt_row[1], &target)) {
|
|
err_ << "Error converting " << tgt_row[1] << " to target.";
|
|
return false;
|
|
}
|
|
|
|
// get features as FEAT_SEP-delimited integers
|
|
|
|
std::stringstream feat_stream(tgt_row[2]);
|
|
std::string feat;
|
|
while (std::getline(feat_stream, feat, FEAT_SEP)) {
|
|
uint32_t ufeat;
|
|
// if (!(std::stringstream(feat) >> ufeat)) {
|
|
if (!strToInt(feat, &ufeat)) {
|
|
err_ << "Error converting " << feat << " in " << tgt_row[2]
|
|
<< " to target feature.";
|
|
return false;
|
|
}
|
|
features.push_back(ufeat);
|
|
}
|
|
|
|
variant_def.AddTgtDef(target, features);
|
|
}
|
|
}
|
|
|
|
if (options.GetHasFnVarCapabilities()) {
|
|
variant_def.InferCapabilities();
|
|
}
|
|
|
|
variant_defs_.push_back(variant_def);
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
bool VariantDefs::ProcessVariantDefs() {
|
|
EnsureBoolType();
|
|
CollectVarInsts();
|
|
if (!GenerateFnVarConstants()) {
|
|
return false;
|
|
}
|
|
CollectBaseFnCalls();
|
|
return true;
|
|
}
|
|
|
|
// Adds the module-level declarations required by SPV_INTEL_function_variants
// to |linked_context|: the SpecConditionalINTEL and FunctionVariantsINTEL
// capabilities, the FNVAR_EXT_NAME extension, and an OpModuleProcessed
// recording the registry version that was used.
void VariantDefs::GenerateHeader(IRContext* linked_context) {
  for (const spv::Capability cap : {spv::Capability::SpecConditionalINTEL,
                                    spv::Capability::FunctionVariantsINTEL}) {
    linked_context->AddCapability(cap);
  }
  linked_context->AddExtension(std::string(FNVAR_EXT_NAME));

  // Specifies used registry version
  std::stringstream processed_note;
  processed_note << "SPV_INTEL_function_variants registry version "
                 << FNVAR_REGISTRY_VERSION;

  auto module_processed =
      std::make_unique<Instruction>(linked_context, spv::Op::OpModuleProcessed);
  module_processed->AddOperand({SPV_OPERAND_TYPE_LITERAL_STRING,
                                utils::MakeVector(processed_note.str())});
  linked_context->AddDebug3Inst(std::move(module_processed));
}
|
|
|
|
// Merges variant-specific content of the linked module: first the base
// function calls (CombineBaseFnCalls), then the shareable instructions such as
// capabilities, extensions and types (CombineInstructions).
//
// Note: CombineBaseFnCalls may re-point the base variant's OpName to a new
// spec constant; CombineInstructions looks spec constants up by name
// afterwards.
void VariantDefs::CombineVariantInstructions(IRContext* linked_context) {
  this->CombineBaseFnCalls(linked_context);
  this->CombineInstructions(linked_context);
}
|
|
|
|
void VariantDefs::EnsureBoolType() {
|
|
for (auto& variant_def : variant_defs_) {
|
|
Module* module = variant_def.GetModule();
|
|
IRContext* context = module->context();
|
|
|
|
uint32_t bool_id = FindIdOfBoolType(module);
|
|
if (bool_id == 0) {
|
|
bool_id = context->TakeNextId();
|
|
auto variant_bool = std::make_unique<Instruction>(
|
|
context, spv::Op::OpTypeBool, 0, bool_id,
|
|
std::initializer_list<opt::Operand>{});
|
|
module->AddType(std::move(variant_bool));
|
|
}
|
|
}
|
|
}
|
|
|
|
void VariantDefs::CollectVarInsts() {
|
|
for (size_t i = 0; i < variant_defs_.size(); ++i) {
|
|
const auto variant_def = variant_defs_[i];
|
|
const auto* var_mod = variant_def.GetModule();
|
|
|
|
var_mod->ForEachInst([this, &i](const Instruction* inst) {
|
|
if (CanBeFnVarCombined(inst)) {
|
|
const size_t inst_hash = HashInst(inst);
|
|
if (fnvar_usage_.find(inst_hash) == fnvar_usage_.end()) {
|
|
fnvar_usage_.insert({inst_hash, {i}});
|
|
} else {
|
|
assert(fnvar_usage_[inst_hash].size() < variant_defs_.size());
|
|
fnvar_usage_[inst_hash].push_back(i);
|
|
}
|
|
}
|
|
});
|
|
}
|
|
}
|
|
|
|
bool VariantDefs::GenerateFnVarConstants() {
|
|
assert(variant_defs_.size() > 0);
|
|
assert(variant_defs_[0].IsBase());
|
|
|
|
if (variant_defs_.size() == 1) {
|
|
return true;
|
|
}
|
|
|
|
for (auto& variant_def : variant_defs_) {
|
|
Module* module = variant_def.GetModule();
|
|
IRContext* context = module->context();
|
|
|
|
uint32_t bool_id = FindIdOfBoolType(module);
|
|
if (bool_id == 0) {
|
|
// add a bool type if not present already
|
|
bool_id = context->TakeNextId();
|
|
auto variant_bool = std::make_unique<Instruction>(
|
|
context, spv::Op::OpTypeBool, 0, bool_id,
|
|
std::initializer_list<opt::Operand>{});
|
|
module->AddType(std::move(variant_bool));
|
|
}
|
|
|
|
// Spec constant architecture and target
|
|
|
|
std::vector<uint32_t> spec_const_arch_ids;
|
|
for (const auto& arch_def : variant_def.GetArchDefs()) {
|
|
const uint32_t spec_const_arch_id = context->TakeNextId();
|
|
spec_const_arch_ids.push_back(spec_const_arch_id);
|
|
|
|
auto inst = std::make_unique<Instruction>(
|
|
context, spv::Op::OpSpecConstantArchitectureINTEL, bool_id,
|
|
spec_const_arch_id,
|
|
std::initializer_list<opt::Operand>{
|
|
{SPV_OPERAND_TYPE_LITERAL_INTEGER, {arch_def.category}},
|
|
{SPV_OPERAND_TYPE_LITERAL_INTEGER, {arch_def.family}},
|
|
// Using spec op opcode here expects then next operand to be
|
|
// a type:
|
|
{SPV_OPERAND_TYPE_LITERAL_INTEGER, {arch_def.op}},
|
|
{SPV_OPERAND_TYPE_LITERAL_INTEGER, {arch_def.architecture}},
|
|
});
|
|
module->AddType(std::move(inst));
|
|
}
|
|
|
|
std::vector<uint32_t> spec_const_tgt_ids;
|
|
for (const auto& tgt_def : variant_def.GetTgtDefs()) {
|
|
const uint32_t spec_const_tgt_id = context->TakeNextId();
|
|
spec_const_tgt_ids.push_back(spec_const_tgt_id);
|
|
|
|
auto inst = std::make_unique<Instruction>(
|
|
context, spv::Op::OpSpecConstantTargetINTEL, bool_id,
|
|
spec_const_tgt_id,
|
|
std::initializer_list<opt::Operand>{
|
|
{SPV_OPERAND_TYPE_LITERAL_INTEGER, {tgt_def.target}},
|
|
});
|
|
for (const auto& feat : tgt_def.features) {
|
|
inst->AddOperand({SPV_OPERAND_TYPE_LITERAL_INTEGER, {feat}});
|
|
}
|
|
module->AddType(std::move(inst));
|
|
}
|
|
|
|
std::vector<uint32_t> spec_const_ids;
|
|
|
|
// Spec constant capabilities
|
|
|
|
const auto variant_capabilities = variant_def.GetCapabilities();
|
|
if (!variant_capabilities.empty()) {
|
|
const uint32_t spec_const_cap_id = context->TakeNextId();
|
|
auto inst = std::make_unique<Instruction>(
|
|
context, spv::Op::OpSpecConstantCapabilitiesINTEL, bool_id,
|
|
spec_const_cap_id, std::initializer_list<opt::Operand>{});
|
|
for (const auto& cap : variant_capabilities) {
|
|
inst->AddOperand({SPV_OPERAND_TYPE_CAPABILITY, {uint32_t(cap)}});
|
|
}
|
|
module->AddType(std::move(inst));
|
|
spec_const_ids.push_back(spec_const_cap_id);
|
|
}
|
|
|
|
// Combine architectures such that, for the same module, those with the same
|
|
// category and family are combined with AND and different cat/fam are
|
|
// combined with OR.
|
|
// This lets you create combinations like "architecture between X and Y".
|
|
|
|
// map (category, family) -> IDs
|
|
std::map<std::pair<uint32_t, uint32_t>, std::vector<uint32_t>> arch_map_and;
|
|
|
|
for (size_t i = 0; i < spec_const_arch_ids.size(); ++i) {
|
|
const auto& arch_def = variant_def.GetArchDefs()[i];
|
|
const auto id = spec_const_arch_ids[i];
|
|
const auto key = std::make_pair(arch_def.category, arch_def.family);
|
|
if (arch_map_and.find(key) == arch_map_and.end()) {
|
|
arch_map_and[key] = {id};
|
|
} else {
|
|
arch_map_and[key].push_back(id);
|
|
}
|
|
}
|
|
|
|
std::vector<uint32_t> arch_ids_or;
|
|
for (const auto& it : arch_map_and) {
|
|
const auto id = CombineIds(context, it.second, spv::Op::OpLogicalAnd);
|
|
if (id > 0) {
|
|
arch_ids_or.push_back(id);
|
|
}
|
|
}
|
|
|
|
const uint32_t spec_const_arch_id =
|
|
CombineIds(context, arch_ids_or, spv::Op::OpLogicalOr);
|
|
if (spec_const_arch_id > 0) {
|
|
spec_const_ids.push_back(spec_const_arch_id);
|
|
}
|
|
|
|
const uint32_t spec_const_tgt_id =
|
|
CombineIds(context, spec_const_tgt_ids, spv::Op::OpLogicalOr);
|
|
if (spec_const_tgt_id > 0) {
|
|
spec_const_ids.push_back(spec_const_tgt_id);
|
|
}
|
|
|
|
uint32_t combined_spec_const_id =
|
|
CombineIds(context, spec_const_ids, spv::Op::OpLogicalAnd);
|
|
if (combined_spec_const_id == 0) {
|
|
// If the variant module has no constraints, use SpecConstantTrue
|
|
combined_spec_const_id = context->TakeNextId();
|
|
auto inst = std::make_unique<Instruction>(
|
|
context, spv::Op::OpSpecConstantTrue, bool_id, combined_spec_const_id,
|
|
std::initializer_list<opt::Operand>{});
|
|
context->module()->AddType(std::move(inst));
|
|
}
|
|
assert(combined_spec_const_id != 0);
|
|
|
|
// Add a name the combined boolean ID so we can look it up after the IDs are
|
|
// shifted
|
|
auto inst = std::make_unique<Instruction>(context, spv::Op::OpName);
|
|
inst->AddOperand({SPV_OPERAND_TYPE_ID, {combined_spec_const_id}});
|
|
std::vector<uint32_t> str_words;
|
|
utils::AppendToVector(variant_def.GetName(), &str_words);
|
|
inst->AddOperand({SPV_OPERAND_TYPE_LITERAL_STRING, {str_words}});
|
|
module->AddDebug2Inst(std::move(inst));
|
|
|
|
// Annotate all instructions in the types section (eg. constants) with
|
|
// ConditionalINTEL, unless they can be shared between variant_defs_ (eg.
|
|
// types). Spec constants are excluded because they might have been
|
|
// generated by this extension.
|
|
for (const auto& type_inst : module->types_values()) {
|
|
if (!CanBeFnVarCombined(&type_inst) &&
|
|
!spvOpcodeIsSpecConstant(type_inst.opcode())) {
|
|
DecorateConditional(context, type_inst.result_id(),
|
|
combined_spec_const_id);
|
|
}
|
|
}
|
|
}
|
|
|
|
// Annotate functions with ConditionalINTEL
|
|
|
|
for (const auto& base_fn : *variant_defs_[0].GetModule()) {
|
|
// For each function of the base module, find matching variant functions in
|
|
// other modules
|
|
|
|
auto base_fn_name = GetFnName(base_fn.DefInst());
|
|
if (base_fn_name.empty()) {
|
|
err_ << "Could not find name of a function " << base_fn.result_id()
|
|
<< " in a base module " << variant_defs_[0].GetName()
|
|
<< ". To be usable by SPV_INTEL_function_variants, a function "
|
|
"must either have an entry point or an export "
|
|
"LinkAttribute decoration.";
|
|
return false;
|
|
}
|
|
|
|
bool base_fn_needs_conditional = false;
|
|
for (size_t i = 1; i < variant_defs_.size(); ++i) {
|
|
const auto& variant_def = variant_defs_[i];
|
|
auto* variant_module = variant_def.GetModule();
|
|
auto* variant_context = variant_module->context();
|
|
|
|
for (const auto& var_fn : *variant_module) {
|
|
auto var_fn_name = GetFnName(var_fn.DefInst());
|
|
if (var_fn_name.empty()) {
|
|
err_ << "Could not find name of a function " << var_fn.result_id()
|
|
<< " in a base module " << variant_def.GetName()
|
|
<< ". To be usable by SPV_INTEL_function_variants, a function "
|
|
"must either have an entry point or an export "
|
|
"LinkAttribute decoration.";
|
|
return false;
|
|
}
|
|
|
|
if (base_fn_name == var_fn_name) {
|
|
base_fn_needs_conditional = true;
|
|
}
|
|
|
|
// each function in a variant module gets a ConditionalINTEL decoration
|
|
|
|
uint32_t spec_const_id =
|
|
FindSpecConstByName(variant_module, variant_def.GetName());
|
|
assert(spec_const_id != 0);
|
|
DecorateConditional(variant_context, var_fn.result_id(), spec_const_id);
|
|
ConvertEPToConditional(variant_module, var_fn, spec_const_id);
|
|
}
|
|
}
|
|
|
|
if (base_fn_needs_conditional) {
|
|
// only a base function that has a variant in another module gets a
|
|
// ConditionalINTEL decoration, the others are common for all
|
|
// variant_defs_
|
|
auto* base_module = variant_defs_[0].GetModule();
|
|
auto* base_context = base_module->context();
|
|
uint32_t spec_const_id =
|
|
FindSpecConstByName(base_module, variant_defs_[0].GetName());
|
|
assert(spec_const_id != 0);
|
|
DecorateConditional(base_context, base_fn.result_id(), spec_const_id);
|
|
ConvertEPToConditional(base_module, base_fn, spec_const_id);
|
|
}
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
void VariantDefs::CollectBaseFnCalls() {
|
|
auto* base_mod = variant_defs_[0].GetModule();
|
|
assert(variant_defs_[0].IsBase());
|
|
const auto* base_def_use_mgr = base_mod->context()->get_def_use_mgr();
|
|
|
|
base_mod->ForEachInst([this, &base_def_use_mgr](const Instruction* inst) {
|
|
if (inst->opcode() == spv::Op::OpFunctionCall) {
|
|
// For each function call in base module, get the function name
|
|
const auto fn_id = inst->GetOperand(2).AsId();
|
|
const auto* called_fn_inst = base_def_use_mgr->GetDef(fn_id);
|
|
assert(called_fn_inst != nullptr);
|
|
const auto called_fn_name = GetFnName(*called_fn_inst);
|
|
assert(!called_fn_name.empty());
|
|
|
|
std::vector<std::pair<std::string, const opt::Function*>> called_fns;
|
|
for (size_t i = 1; i < variant_defs_.size(); ++i) {
|
|
// ... then see in which variant the called function was defined
|
|
const auto& variant_def = variant_defs_[i];
|
|
assert(!variant_def.IsBase());
|
|
|
|
for (const auto& fn : *variant_def.GetModule()) {
|
|
const auto fn_name = GetFnName(fn.DefInst());
|
|
if (fn_name == called_fn_name) {
|
|
called_fns.push_back(std::make_pair(variant_def.GetName(), &fn));
|
|
}
|
|
}
|
|
}
|
|
|
|
if (!called_fns.empty()) {
|
|
base_fn_calls_[inst->result_id()] = called_fns;
|
|
}
|
|
}
|
|
});
|
|
}
|
|
|
|
// Makes base-module calls that have variant implementations conditional, then
// makes the base module's condition exclusive of all variants.
//
// For each call recorded in base_fn_calls_:
//   * the original call is decorated ConditionalINTEL with the base variant's
//     spec constant;
//   * one clone of the call per variant function is inserted right after it,
//     targeting that function and decorated with the variant's spec constant;
//   * for non-void calls, an OpConditionalCopyObjectINTEL selecting among the
//     call results is inserted, and later uses of the base call's result in
//     the same basic block (up to and including its terminator) are redirected
//     to it.
// A call with no variants, or which cannot be located, is skipped without
// affecting the other calls.
//
// Finally, when variants exist, the base condition is replaced by
// (old base condition) AND NOT (any variant condition) in ConditionalINTEL
// decorations, the base OpName, conditional entry points and conditional
// copies that referenced the old one.
void VariantDefs::CombineBaseFnCalls(IRContext* linked_context) {
  // The base condition as named in the base module; nothing in the loop below
  // renames it.
  const uint32_t base_cond_id = FindSpecConstByName(
      variant_defs_[0].GetModule(), variant_defs_[0].GetName());

  for (const auto& call_entry : base_fn_calls_) {
    const uint32_t call_id = call_entry.first;
    const auto& called_fns = call_entry.second;

    // Skip just this call: returning would also drop the remaining calls and
    // the base-condition update at the end of this function.
    if (called_fns.empty()) {
      continue;
    }

    opt::BasicBlock* fn_call_bb = linked_context->get_instr_block(call_id);
    if (fn_call_bb == nullptr) {
      continue;
    }

    Instruction* found_call_inst = nullptr;
    for (auto& bb_inst : *fn_call_bb) {
      if (bb_inst.HasResultId() && bb_inst.result_id() == call_id) {
        found_call_inst = &bb_inst;
        break;
      }
    }
    if (found_call_inst == nullptr) {
      continue;
    }

    const auto base_type_op = found_call_inst->context()
                                  ->get_def_use_mgr()
                                  ->GetDef(found_call_inst->type_id())
                                  ->opcode();
    const auto base_call_id = found_call_inst->result_id();

    // decorate the base call with ConditionalINTEL
    DecorateConditional(linked_context, base_call_id, base_cond_id);

    // Add OpFunctionCall for each variant, each right after the previous one
    Instruction* last_inst = found_call_inst;
    std::vector<std::pair<uint32_t, uint32_t>> var_call_ids;
    for (const auto& called_fn : called_fns) {
      const std::string& var_name = called_fn.first;
      const opt::Function* fn = called_fn.second;
      const uint32_t spec_const_id =
          FindSpecConstByName(linked_context->module(), var_name);
      assert(spec_const_id != 0);
      const uint32_t var_call_id = linked_context->TakeNextId();
      var_call_ids.emplace_back(spec_const_id, var_call_id);

      auto* var_call_inst = found_call_inst->Clone(linked_context);
      var_call_inst->SetResultId(var_call_id);
      var_call_inst->SetOperand(2, {fn->result_id()});
      var_call_inst->InsertAfter(last_inst);
      linked_context->set_instr_block(var_call_inst, fn_call_bb);
      last_inst = var_call_inst;

      // decorate the variant call with ConditionalINTEL
      DecorateConditional(linked_context, var_call_id, spec_const_id);
    }

    if (base_type_op == spv::Op::OpTypeVoid) {
      continue;
    }

    // Add OpConditionalCopyObjectINTEL combining the function calls
    const uint32_t result_id = linked_context->TakeNextId();
    auto* conditional_copy_inst = new Instruction(
        linked_context, spv::Op::OpConditionalCopyObjectINTEL,
        found_call_inst->type_id(), result_id,
        {{SPV_OPERAND_TYPE_ID, {base_cond_id}},
         {SPV_OPERAND_TYPE_ID, {base_call_id}}});
    for (const auto& var_call : var_call_ids) {
      conditional_copy_inst->AddOperand({SPV_OPERAND_TYPE_ID, {var_call.first}});
      conditional_copy_inst->AddOperand(
          {SPV_OPERAND_TYPE_ID, {var_call.second}});
    }
    conditional_copy_inst->InsertAfter(last_inst);
    linked_context->set_instr_block(conditional_copy_inst, fn_call_bb);

    // In all remaining instructions within the basic block (terminator
    // included), replace all usages of the base call ID with the result of
    // OpConditionalCopyObjectINTEL
    Instruction* const bb_tail = &*fn_call_bb->tail();
    for (Instruction* next = conditional_copy_inst->NextNode();
         next != nullptr; next = next->NextNode()) {
      next->ForEachInId([base_call_id, result_id](uint32_t* id) {
        if (*id == base_call_id) {
          *id = result_id;
        }
      });
      if (next == bb_tail) {
        break;
      }
    }
  }

  // Combine spec consts for the base module (base module is activated if all
  // variant defs are inactive AND the base module constraints are satisfied)

  std::vector<uint32_t> var_spec_const_ids;
  for (const auto& variant_def : variant_defs_) {
    if (variant_def.IsBase()) {
      continue;
    }
    const auto id =
        FindSpecConstByName(linked_context->module(), variant_def.GetName());
    assert(id != 0);
    var_spec_const_ids.push_back(id);
  }
  const uint32_t base_or_id =
      CombineIds(linked_context, var_spec_const_ids, spv::Op::OpLogicalOr);
  if (base_or_id == 0) {
    return;
  }

  const uint32_t bool_id = FindIdOfBoolType(linked_context->module());
  assert(bool_id != 0);

  const uint32_t base_not_id = linked_context->TakeNextId();
  auto spec_const_op_inst = std::make_unique<Instruction>(
      linked_context, spv::Op::OpSpecConstantOp, bool_id, base_not_id,
      std::initializer_list<opt::Operand>{
          {SPV_OPERAND_TYPE_SPEC_CONSTANT_OP_NUMBER,
           {(uint32_t)(spv::Op::OpLogicalNot)}},
          {SPV_OPERAND_TYPE_ID, {base_or_id}}});
  linked_context->module()->AddType(std::move(spec_const_op_inst));

  // Update any ConditionalINTEL annotations, names and entry points
  // referencing the old spec const ID to use the new one

  const uint32_t old_base_spec_const_id = FindSpecConstByName(
      linked_context->module(), variant_defs_[0].GetName());
  assert(old_base_spec_const_id != 0);
  const uint32_t new_base_spec_const_id =
      CombineIds(linked_context, {old_base_spec_const_id, base_not_id},
                 spv::Op::OpLogicalAnd);

  for (auto& annot_inst : linked_context->module()->annotations()) {
    // Only decorate instructions carry <target> <decoration> <operand>; other
    // annotations (e.g. OpDecorationGroup) may not even have operand 1.
    const spv::Op annot_op = annot_inst.opcode();
    if ((annot_op == spv::Op::OpDecorate ||
         annot_op == spv::Op::OpDecorateId) &&
        (annot_inst.GetSingleWordOperand(1) ==
         uint32_t(spv::Decoration::ConditionalINTEL)) &&
        (annot_inst.GetOperand(2).AsId() == old_base_spec_const_id)) {
      annot_inst.SetOperand(2, {new_base_spec_const_id});
    }
  }

  for (auto& name_inst : linked_context->module()->debugs2()) {
    if ((name_inst.opcode() == spv::Op::OpName) &&
        (name_inst.GetOperand(0).AsId() == old_base_spec_const_id)) {
      name_inst.SetOperand(0, {new_base_spec_const_id});
    }
  }

  for (auto& ep_inst : linked_context->module()->entry_points()) {
    if ((ep_inst.opcode() == spv::Op::OpConditionalEntryPointINTEL) &&
        (ep_inst.GetOperand(0).AsId() == old_base_spec_const_id)) {
      ep_inst.SetOperand(0, {new_base_spec_const_id});
    }
  }

  linked_context->module()->ForEachInst(
      [old_base_spec_const_id, new_base_spec_const_id](Instruction* inst) {
        if (inst->opcode() == spv::Op::OpConditionalCopyObjectINTEL) {
          inst->ForEachInId(
              [old_base_spec_const_id, new_base_spec_const_id](uint32_t* id) {
                if (*id == old_base_spec_const_id) {
                  *id = new_base_spec_const_id;
                }
              });
        }
      });
}
|
|
|
|
// Makes shareable instructions of the linked module conditional on the
// variants that actually use them.
//
// For every instruction accepted by CanBeFnVarCombined whose hash is in
// fnvar_usage_, CombineVariantDefs yields a condition (0 if every variant uses
// it). With a non-zero condition:
//   * instructions with a result ID (types, ext inst imports) are decorated
//     ConditionalINTEL;
//   * OpCapability becomes OpConditionalCapabilityINTEL <cond> <capability>;
//   * OpExtension becomes OpConditionalExtensionINTEL <cond> <name>.
// Conditions for identical variant sets are shared via a local cache.
void VariantDefs::CombineInstructions(IRContext* linked_context) {
  // cache for existing variant ID combinations
  std::map<std::vector<size_t>, uint32_t> spec_const_comb_ids;

  auto combine_one = [this, linked_context,
                      &spec_const_comb_ids](Instruction* inst) {
    if (!CanBeFnVarCombined(inst)) {
      return;
    }

    const auto usage = fnvar_usage_.find(HashInst(inst));
    if (usage == fnvar_usage_.end()) {
      return;
    }

    const uint32_t condition_id = CombineVariantDefs(
        variant_defs_, usage->second, linked_context, spec_const_comb_ids);
    if (condition_id == 0) {
      return;
    }

    if (inst->HasResultId()) {
      DecorateConditional(linked_context, inst->result_id(), condition_id);
      return;
    }

    switch (inst->opcode()) {
      case spv::Op::OpCapability: {
        const uint32_t cap = inst->GetSingleWordOperand(0);
        inst->SetOpcode(spv::Op::OpConditionalCapabilityINTEL);
        inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {condition_id}},
                             {SPV_OPERAND_TYPE_CAPABILITY, {cap}}});
        break;
      }
      case spv::Op::OpExtension: {
        const std::string ext_name = inst->GetOperand(0).AsString();
        inst->SetOpcode(spv::Op::OpConditionalExtensionINTEL);
        inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {condition_id}},
                             {SPV_OPERAND_TYPE_LITERAL_STRING,
                              {utils::MakeVector(ext_name)}}});
        break;
      }
      default:
        assert(false && "Unsupported");
        break;
    }
  };

  linked_context->module()->ForEachInst(combine_one);
}
|
|
|
|
} // namespace spvtools
|