Updated spirv-cross.

Бранимир Караџић
2025-06-07 10:26:12 -07:00
parent 88e74f02bf
commit 7c79acf98e
10 changed files with 707 additions and 86 deletions

View File

@@ -1035,6 +1035,9 @@ struct SPIRFunction : IVariant
// consider arrays value types.
SmallVector<ID> constant_arrays_needed_on_stack;
// Does this function (or any function called by it) emit geometry?
bool emits_geometry = false;
bool active = false;
bool flush_undeclared = true;
bool do_combined_parameters = true;

View File

@@ -82,7 +82,7 @@ bool Compiler::variable_storage_is_aliased(const SPIRVariable &v)
ir.meta[type.self].decoration.decoration_flags.get(DecorationBufferBlock);
bool image = type.basetype == SPIRType::Image;
bool counter = type.basetype == SPIRType::AtomicCounter;
bool buffer_reference = type.storage == StorageClassPhysicalStorageBufferEXT;
bool buffer_reference = type.storage == StorageClassPhysicalStorageBuffer;
bool is_restrict;
if (ssbo)
@@ -484,7 +484,7 @@ void Compiler::register_write(uint32_t chain)
}
}
if (type.storage == StorageClassPhysicalStorageBufferEXT || variable_storage_is_aliased(*var))
if (type.storage == StorageClassPhysicalStorageBuffer || variable_storage_is_aliased(*var))
flush_all_aliased_variables();
else if (var)
flush_dependees(*var);
@@ -4362,6 +4362,39 @@ bool Compiler::may_read_undefined_variable_in_block(const SPIRBlock &block, uint
return true;
}
bool Compiler::GeometryEmitDisocveryHandler::handle(spv::Op opcode, const uint32_t *, uint32_t)
{
if (opcode == OpEmitVertex || opcode == OpEndPrimitive)
{
for (auto *func : function_stack)
func->emits_geometry = true;
}
return true;
}
bool Compiler::GeometryEmitDisocveryHandler::begin_function_scope(const uint32_t *stream, uint32_t)
{
auto &callee = compiler.get<SPIRFunction>(stream[2]);
function_stack.push_back(&callee);
return true;
}
bool Compiler::GeometryEmitDisocveryHandler::end_function_scope([[maybe_unused]] const uint32_t *stream, uint32_t)
{
assert(function_stack.back() == &compiler.get<SPIRFunction>(stream[2]));
function_stack.pop_back();
return true;
}
void Compiler::discover_geometry_emitters()
{
GeometryEmitDisocveryHandler handler(*this);
traverse_all_reachable_opcodes(get<SPIRFunction>(ir.default_entry_point), handler);
}
Bitset Compiler::get_buffer_block_flags(VariableID id) const
{
return ir.get_buffer_block_flags(get<SPIRVariable>(id));
@@ -5194,7 +5227,7 @@ bool Compiler::PhysicalStorageBufferPointerHandler::type_is_bda_block_entry(uint
uint32_t Compiler::PhysicalStorageBufferPointerHandler::get_minimum_scalar_alignment(const SPIRType &type) const
{
if (type.storage == spv::StorageClassPhysicalStorageBufferEXT)
if (type.storage == spv::StorageClassPhysicalStorageBuffer)
return 8;
else if (type.basetype == SPIRType::Struct)
{
@@ -5298,6 +5331,10 @@ uint32_t Compiler::PhysicalStorageBufferPointerHandler::get_base_non_block_type_
void Compiler::PhysicalStorageBufferPointerHandler::analyze_non_block_types_from_block(const SPIRType &type)
{
if (analyzed_type_ids.count(type.self))
return;
analyzed_type_ids.insert(type.self);
for (auto &member : type.member_types)
{
auto &subtype = compiler.get<SPIRType>(member);
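
A note on the discovery pass above: the handler pushes every callee onto a stack and flags the entire stack when OpEmitVertex or OpEndPrimitive appears, so emits_geometry propagates to indirect callers as well. Below is a minimal, self-contained C++ sketch of that call-stack marking idea, kept outside the SPIRV-Cross traversal machinery; the Function struct and traverse() helper are illustrative names, not part of the library.

// Minimal sketch (not SPIRV-Cross API) of the call-stack marking idea used by
// the geometry-emit discovery handler: while walking the call graph, every
// function currently on the stack is flagged when an emitting opcode is seen,
// so indirect callers of an emitting helper are marked too.
#include <cstdio>
#include <vector>

struct Function
{
    const char *name;
    bool emits_geometry = false;
    std::vector<Function *> callees;
    bool body_emits = false; // stands in for OpEmitVertex / OpEndPrimitive in the body
};

static void traverse(Function &func, std::vector<Function *> &stack)
{
    stack.push_back(&func);
    if (func.body_emits)
        for (auto *f : stack) // mark the whole active call chain
            f->emits_geometry = true;
    for (auto *callee : func.callees)
        traverse(*callee, stack);
    stack.pop_back();
}

int main()
{
    Function emit_quad { "emit_quad" }, helper { "helper" }, entry { "main" };
    emit_quad.body_emits = true;     // conceptually calls EmitVertex()
    helper.callees = { &emit_quad }; // main -> helper -> emit_quad
    entry.callees = { &helper };

    std::vector<Function *> stack;
    traverse(entry, stack);
    for (auto *f : { &entry, &helper, &emit_quad })
        std::printf("%s emits_geometry = %d\n", f->name, f->emits_geometry);
}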

View File

@@ -1054,6 +1054,7 @@ protected:
std::unordered_set<uint32_t> non_block_types;
std::unordered_map<uint32_t, PhysicalBlockMeta> physical_block_type_meta;
std::unordered_map<uint32_t, PhysicalBlockMeta *> access_chain_to_physical_block;
std::unordered_set<uint32_t> analyzed_type_ids;
void mark_aligned_access(uint32_t id, const uint32_t *args, uint32_t length);
PhysicalBlockMeta *find_block_meta(uint32_t id) const;
@@ -1072,6 +1073,22 @@ protected:
bool single_function);
bool may_read_undefined_variable_in_block(const SPIRBlock &block, uint32_t var);
struct GeometryEmitDisocveryHandler : OpcodeHandler
{
explicit GeometryEmitDisocveryHandler(Compiler &compiler_)
: compiler(compiler_)
{
}
Compiler &compiler;
bool handle(spv::Op opcode, const uint32_t *args, uint32_t length) override;
bool begin_function_scope(const uint32_t *, uint32_t) override;
bool end_function_scope(const uint32_t *, uint32_t) override;
SmallVector<SPIRFunction *> function_stack;
};
void discover_geometry_emitters();
// Finds all resources that are written to from inside the critical section, if present.
// The critical section is delimited by OpBeginInvocationInterlockEXT and
// OpEndInvocationInterlockEXT instructions. In MSL and HLSL, any resources written

View File

@@ -55,6 +55,7 @@
#ifdef _MSC_VER
#pragma warning(push)
#pragma warning(disable : 4996)
#pragma warning(disable : 4065) // switch statement contains 'default' but no 'case' labels.
#endif
#ifndef SPIRV_CROSS_EXCEPTIONS_TO_ASSERTIONS

View File

@@ -545,7 +545,7 @@ void CompilerGLSL::find_static_extensions()
if (options.separate_shader_objects && !options.es && options.version < 410)
require_extension_internal("GL_ARB_separate_shader_objects");
if (ir.addressing_model == AddressingModelPhysicalStorageBuffer64EXT)
if (ir.addressing_model == AddressingModelPhysicalStorageBuffer64)
{
if (!options.vulkan_semantics)
SPIRV_CROSS_THROW("GL_EXT_buffer_reference is only supported in Vulkan GLSL.");
@@ -557,7 +557,7 @@ void CompilerGLSL::find_static_extensions()
}
else if (ir.addressing_model != AddressingModelLogical)
{
SPIRV_CROSS_THROW("Only Logical and PhysicalStorageBuffer64EXT addressing models are supported.");
SPIRV_CROSS_THROW("Only Logical and PhysicalStorageBuffer64 addressing models are supported.");
}
// Check for nonuniform qualifier and passthrough.
@@ -708,7 +708,7 @@ string CompilerGLSL::compile()
// Shaders might cast unrelated data to pointers of non-block types.
// Find all such instances and make sure we can cast the pointers to a synthesized block type.
if (ir.addressing_model == AddressingModelPhysicalStorageBuffer64EXT)
if (ir.addressing_model == AddressingModelPhysicalStorageBuffer64)
analyze_non_block_pointer_types();
uint32_t pass_count = 0;
@@ -1542,14 +1542,14 @@ uint32_t CompilerGLSL::type_to_packed_base_size(const SPIRType &type, BufferPack
uint32_t CompilerGLSL::type_to_packed_alignment(const SPIRType &type, const Bitset &flags,
BufferPackingStandard packing)
{
// If using PhysicalStorageBufferEXT storage class, this is a pointer,
// If using PhysicalStorageBuffer storage class, this is a pointer,
// and is 64-bit.
if (is_physical_pointer(type))
{
if (!type.pointer)
SPIRV_CROSS_THROW("Types in PhysicalStorageBufferEXT must be pointers.");
SPIRV_CROSS_THROW("Types in PhysicalStorageBuffer must be pointers.");
if (ir.addressing_model == AddressingModelPhysicalStorageBuffer64EXT)
if (ir.addressing_model == AddressingModelPhysicalStorageBuffer64)
{
if (packing_is_vec4_padded(packing) && type_is_array_of_pointers(type))
return 16;
@@ -1557,7 +1557,7 @@ uint32_t CompilerGLSL::type_to_packed_alignment(const SPIRType &type, const Bits
return 8;
}
else
SPIRV_CROSS_THROW("AddressingModelPhysicalStorageBuffer64EXT must be used for PhysicalStorageBufferEXT.");
SPIRV_CROSS_THROW("AddressingModelPhysicalStorageBuffer64 must be used for PhysicalStorageBuffer.");
}
else if (is_array(type))
{
@@ -1665,17 +1665,17 @@ uint32_t CompilerGLSL::type_to_packed_array_stride(const SPIRType &type, const B
uint32_t CompilerGLSL::type_to_packed_size(const SPIRType &type, const Bitset &flags, BufferPackingStandard packing)
{
// If using PhysicalStorageBufferEXT storage class, this is a pointer,
// If using PhysicalStorageBuffer storage class, this is a pointer,
// and is 64-bit.
if (is_physical_pointer(type))
{
if (!type.pointer)
SPIRV_CROSS_THROW("Types in PhysicalStorageBufferEXT must be pointers.");
SPIRV_CROSS_THROW("Types in PhysicalStorageBuffer must be pointers.");
if (ir.addressing_model == AddressingModelPhysicalStorageBuffer64EXT)
if (ir.addressing_model == AddressingModelPhysicalStorageBuffer64)
return 8;
else
SPIRV_CROSS_THROW("AddressingModelPhysicalStorageBuffer64EXT must be used for PhysicalStorageBufferEXT.");
SPIRV_CROSS_THROW("AddressingModelPhysicalStorageBuffer64 must be used for PhysicalStorageBuffer.");
}
else if (is_array(type))
{
@@ -3638,6 +3638,36 @@ void CompilerGLSL::emit_resources()
bool emitted = false;
if (ir.addressing_model == AddressingModelPhysicalStorageBuffer64)
{
// Output buffer reference block forward declarations.
ir.for_each_typed_id<SPIRType>([&](uint32_t id, SPIRType &type)
{
if (is_physical_pointer(type))
{
bool emit_type = true;
if (!is_physical_pointer_to_buffer_block(type))
{
// Only forward-declare if we intend to emit it in the non_block_pointer types.
// Otherwise, these are just "benign" pointer types that exist as a result of access chains.
emit_type = std::find(physical_storage_non_block_pointer_types.begin(),
physical_storage_non_block_pointer_types.end(),
id) != physical_storage_non_block_pointer_types.end();
}
if (emit_type)
{
emit_buffer_reference_block(id, true);
emitted = true;
}
}
});
}
if (emitted)
statement("");
emitted = false;
// If emitted Vulkan GLSL,
// emit specialization constants as actual floats,
// spec op expressions will redirect to the constant name.
@@ -3747,30 +3777,10 @@ void CompilerGLSL::emit_resources()
emitted = false;
if (ir.addressing_model == AddressingModelPhysicalStorageBuffer64EXT)
if (ir.addressing_model == AddressingModelPhysicalStorageBuffer64)
{
// Output buffer reference blocks.
// Do this in two stages, one with forward declaration,
// and one without. Buffer reference blocks can reference themselves
// to support things like linked lists.
ir.for_each_typed_id<SPIRType>([&](uint32_t id, SPIRType &type) {
if (is_physical_pointer(type))
{
bool emit_type = true;
if (!is_physical_pointer_to_buffer_block(type))
{
// Only forward-declare if we intend to emit it in the non_block_pointer types.
// Otherwise, these are just "benign" pointer types that exist as a result of access chains.
emit_type = std::find(physical_storage_non_block_pointer_types.begin(),
physical_storage_non_block_pointer_types.end(),
id) != physical_storage_non_block_pointer_types.end();
}
if (emit_type)
emit_buffer_reference_block(id, true);
}
});
// Buffer reference blocks can reference themselves to support things like linked lists.
for (auto type : physical_storage_non_block_pointer_types)
emit_buffer_reference_block(type, false);
@@ -10317,7 +10327,8 @@ string CompilerGLSL::access_chain_internal(uint32_t base, const uint32_t *indice
if (!is_ptr_chain)
mod_flags &= ~ACCESS_CHAIN_PTR_CHAIN_BIT;
access_chain_internal_append_index(expr, base, type, mod_flags, access_chain_is_arrayed, index);
check_physical_type_cast(expr, type, physical_type);
if (check_physical_type_cast(expr, type, physical_type))
physical_type = 0;
};
for (uint32_t i = 0; i < count; i++)
@@ -10825,8 +10836,9 @@ string CompilerGLSL::access_chain_internal(uint32_t base, const uint32_t *indice
return expr;
}
void CompilerGLSL::check_physical_type_cast(std::string &, const SPIRType *, uint32_t)
bool CompilerGLSL::check_physical_type_cast(std::string &, const SPIRType *, uint32_t)
{
return false;
}
bool CompilerGLSL::prepare_access_chain_for_scalar_access(std::string &, const SPIRType &, spv::StorageClass, bool &)
@@ -15337,8 +15349,8 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction)
case OpConvertUToPtr:
{
auto &type = get<SPIRType>(ops[0]);
if (type.storage != StorageClassPhysicalStorageBufferEXT)
SPIRV_CROSS_THROW("Only StorageClassPhysicalStorageBufferEXT is supported by OpConvertUToPtr.");
if (type.storage != StorageClassPhysicalStorageBuffer)
SPIRV_CROSS_THROW("Only StorageClassPhysicalStorageBuffer is supported by OpConvertUToPtr.");
auto &in_type = expression_type(ops[2]);
if (in_type.vecsize == 2)
@@ -15353,8 +15365,8 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction)
{
auto &type = get<SPIRType>(ops[0]);
auto &ptr_type = expression_type(ops[2]);
if (ptr_type.storage != StorageClassPhysicalStorageBufferEXT)
SPIRV_CROSS_THROW("Only StorageClassPhysicalStorageBufferEXT is supported by OpConvertPtrToU.");
if (ptr_type.storage != StorageClassPhysicalStorageBuffer)
SPIRV_CROSS_THROW("Only StorageClassPhysicalStorageBuffer is supported by OpConvertPtrToU.");
if (type.vecsize == 2)
require_extension_internal("GL_EXT_buffer_reference_uvec2");
@@ -16143,7 +16155,7 @@ string CompilerGLSL::to_array_size(const SPIRType &type, uint32_t index)
string CompilerGLSL::type_to_array_glsl(const SPIRType &type, uint32_t)
{
if (type.pointer && type.storage == StorageClassPhysicalStorageBufferEXT && type.basetype != SPIRType::Struct)
if (type.pointer && type.storage == StorageClassPhysicalStorageBuffer && type.basetype != SPIRType::Struct)
{
// We are using a wrapped pointer type, and we should not emit any array declarations here.
return "";
@@ -16856,6 +16868,7 @@ void CompilerGLSL::emit_function(SPIRFunction &func, const Bitset &return_flags)
{
// Recursively emit functions which are called.
uint32_t id = ops[2];
emit_function(get<SPIRFunction>(id), ir.meta[ops[1]].decoration.decoration_flags);
}
}

View File

@@ -769,7 +769,7 @@ protected:
spv::StorageClass get_expression_effective_storage_class(uint32_t ptr);
virtual bool access_chain_needs_stage_io_builtin_translation(uint32_t base);
virtual void check_physical_type_cast(std::string &expr, const SPIRType *type, uint32_t physical_type);
virtual bool check_physical_type_cast(std::string &expr, const SPIRType *type, uint32_t physical_type);
virtual bool prepare_access_chain_for_scalar_access(std::string &expr, const SPIRType &type,
spv::StorageClass storage, bool &is_packed);
@@ -799,7 +799,7 @@ protected:
std::string declare_temporary(uint32_t type, uint32_t id);
void emit_uninitialized_temporary(uint32_t type, uint32_t id);
SPIRExpression &emit_uninitialized_temporary_expression(uint32_t type, uint32_t id);
void append_global_func_args(const SPIRFunction &func, uint32_t index, SmallVector<std::string> &arglist);
virtual void append_global_func_args(const SPIRFunction &func, uint32_t index, SmallVector<std::string> &arglist);
std::string to_non_uniform_aware_expression(uint32_t id);
std::string to_atomic_ptr_expression(uint32_t id);
std::string to_expression(uint32_t id, bool register_expression_read = true);

View File

@@ -1117,7 +1117,9 @@ void CompilerHLSL::emit_interface_block_in_struct(const SPIRVariable &var, unord
else
{
auto decl_type = type;
if (execution.model == ExecutionModelMeshEXT || has_decoration(var.self, DecorationPerVertexKHR))
if (execution.model == ExecutionModelMeshEXT ||
(execution.model == ExecutionModelGeometry && var.storage == StorageClassInput) ||
has_decoration(var.self, DecorationPerVertexKHR))
{
decl_type.array.erase(decl_type.array.begin());
decl_type.array_size_literal.erase(decl_type.array_size_literal.begin());
@@ -1834,7 +1836,7 @@ void CompilerHLSL::emit_resources()
if (!output_variables.empty() || !active_output_builtins.empty())
{
sort(output_variables.begin(), output_variables.end(), variable_compare);
require_output = !is_mesh_shader;
require_output = !(is_mesh_shader || execution.model == ExecutionModelGeometry);
statement(is_mesh_shader ? "struct gl_MeshPerVertexEXT" : "struct SPIRV_Cross_Output");
begin_scope();
@@ -2678,6 +2680,83 @@ void CompilerHLSL::emit_mesh_tasks(SPIRBlock &block)
}
}
void CompilerHLSL::emit_geometry_stream_append()
{
begin_scope();
statement("SPIRV_Cross_Output stage_output;");
active_output_builtins.for_each_bit(
[&](uint32_t i)
{
if (i == BuiltInPointSize && hlsl_options.shader_model > 30)
return;
switch (static_cast<BuiltIn>(i))
{
case BuiltInClipDistance:
for (uint32_t clip = 0; clip < clip_distance_count; clip++)
statement("stage_output.gl_ClipDistance", clip / 4, ".", "xyzw"[clip & 3], " = gl_ClipDistance[",
clip, "];");
break;
case BuiltInCullDistance:
for (uint32_t cull = 0; cull < cull_distance_count; cull++)
statement("stage_output.gl_CullDistance", cull / 4, ".", "xyzw"[cull & 3], " = gl_CullDistance[",
cull, "];");
break;
case BuiltInSampleMask:
statement("stage_output.gl_SampleMask = gl_SampleMask[0];");
break;
default:
{
auto builtin_expr = builtin_to_glsl(static_cast<BuiltIn>(i), StorageClassOutput);
statement("stage_output.", builtin_expr, " = ", builtin_expr, ";");
}
break;
}
});
ir.for_each_typed_id<SPIRVariable>(
[&](uint32_t, SPIRVariable &var)
{
auto &type = this->get<SPIRType>(var.basetype);
bool block = has_decoration(type.self, DecorationBlock);
if (var.storage != StorageClassOutput)
return;
if (!var.remapped_variable && type.pointer && !is_builtin_variable(var) &&
interface_variable_exists_in_entry_point(var.self))
{
if (block)
{
auto type_name = to_name(type.self);
auto var_name = to_name(var.self);
for (uint32_t mbr_idx = 0; mbr_idx < uint32_t(type.member_types.size()); mbr_idx++)
{
auto mbr_name = to_member_name(type, mbr_idx);
auto flat_name = join(type_name, "_", mbr_name);
statement("stage_output.", flat_name, " = ", var_name, ".", mbr_name, ";");
}
}
else
{
auto name = to_name(var.self);
if (hlsl_options.shader_model <= 30 && get_entry_point().model == ExecutionModelFragment)
{
string output_filler;
for (uint32_t size = type.vecsize; size < 4; ++size)
output_filler += ", 0.0";
statement("stage_output.", name, " = float4(", name, output_filler, ");");
}
else
statement("stage_output.", name, " = ", name, ";");
}
}
});
statement("geometry_stream.Append(stage_output);");
end_scope();
}
void CompilerHLSL::emit_buffer_block(const SPIRVariable &var)
{
auto &type = get<SPIRType>(var.basetype);
@@ -2940,6 +3019,8 @@ string CompilerHLSL::get_inner_entry_point_name() const
return "frag_main";
else if (execution.model == ExecutionModelGLCompute)
return "comp_main";
else if (execution.model == ExecutionModelGeometry)
return "geom_main";
else if (execution.model == ExecutionModelMeshEXT)
return "mesh_main";
else if (execution.model == ExecutionModelTaskEXT)
@@ -2948,6 +3029,25 @@ string CompilerHLSL::get_inner_entry_point_name() const
SPIRV_CROSS_THROW("Unsupported execution model.");
}
uint32_t CompilerHLSL::input_vertices_from_execution_mode(spirv_cross::SPIREntryPoint &execution) const
{
uint32_t input_vertices = 1;
if (execution.flags.get(ExecutionModeInputLines))
input_vertices = 2;
else if (execution.flags.get(ExecutionModeInputLinesAdjacency))
input_vertices = 4;
else if (execution.flags.get(ExecutionModeInputTrianglesAdjacency))
input_vertices = 6;
else if (execution.flags.get(ExecutionModeTriangles))
input_vertices = 3;
else if (execution.flags.get(ExecutionModeInputPoints))
input_vertices = 1;
else
SPIRV_CROSS_THROW("Unsupported execution model.");
return input_vertices;
}
void CompilerHLSL::emit_function_prototype(SPIRFunction &func, const Bitset &return_flags)
{
if (func.self != ir.default_entry_point)
@@ -3041,6 +3141,38 @@ void CompilerHLSL::emit_function_prototype(SPIRFunction &func, const Bitset &ret
var->parameter = &arg;
}
if ((func.self == ir.default_entry_point || func.emits_geometry) &&
get_entry_point().model == ExecutionModelGeometry)
{
auto &execution = get_entry_point();
uint32_t input_vertices = input_vertices_from_execution_mode(execution);
const char *prim;
if (execution.flags.get(ExecutionModeInputLinesAdjacency))
prim = "lineadj";
else if (execution.flags.get(ExecutionModeInputLines))
prim = "line";
else if (execution.flags.get(ExecutionModeInputTrianglesAdjacency))
prim = "triangleadj";
else if (execution.flags.get(ExecutionModeTriangles))
prim = "triangle";
else
prim = "point";
const char *stream_type;
if (execution.flags.get(ExecutionModeOutputPoints))
stream_type = "PointStream";
else if (execution.flags.get(ExecutionModeOutputLineStrip))
stream_type = "LineStream";
else
stream_type = "TriangleStream";
if (func.self == ir.default_entry_point)
arglist.push_back(join(prim, " SPIRV_Cross_Input stage_input[", input_vertices, "]"));
arglist.push_back(join("inout ", stream_type, "<SPIRV_Cross_Output> ", "geometry_stream"));
}
decl += merge(arglist);
decl += ")";
statement(decl);
@@ -3050,13 +3182,50 @@ void CompilerHLSL::emit_hlsl_entry_point()
{
SmallVector<string> arguments;
if (require_input)
if (require_input && get_entry_point().model != ExecutionModelGeometry)
arguments.push_back("SPIRV_Cross_Input stage_input");
auto &execution = get_entry_point();
uint32_t input_vertices = 1;
switch (execution.model)
{
case ExecutionModelGeometry:
{
input_vertices = input_vertices_from_execution_mode(execution);
string prim;
if (execution.flags.get(ExecutionModeInputLinesAdjacency))
prim = "lineadj";
else if (execution.flags.get(ExecutionModeInputLines))
prim = "line";
else if (execution.flags.get(ExecutionModeInputTrianglesAdjacency))
prim = "triangleadj";
else if (execution.flags.get(ExecutionModeTriangles))
prim = "triangle";
else
prim = "point";
string stream_type;
if (execution.flags.get(ExecutionModeOutputPoints))
{
stream_type = "PointStream";
}
else if (execution.flags.get(ExecutionModeOutputLineStrip))
{
stream_type = "LineStream";
}
else
{
stream_type = "TriangleStream";
}
statement("[maxvertexcount(", execution.output_vertices, ")]");
arguments.push_back(join(prim, " SPIRV_Cross_Input stage_input[", input_vertices, "]"));
arguments.push_back(join("inout ", stream_type, "<SPIRV_Cross_Output> ", "geometry_stream"));
break;
}
case ExecutionModelTaskEXT:
case ExecutionModelMeshEXT:
case ExecutionModelGLCompute:
@@ -3359,18 +3528,24 @@ void CompilerHLSL::emit_hlsl_entry_point()
}
else
{
statement(name, " = stage_input.", name, ";");
if (execution.model == ExecutionModelGeometry)
{
statement("for (int i = 0; i < ", input_vertices, "; i++)");
begin_scope();
statement(name, "[i] = stage_input[i].", name, ";");
end_scope();
}
else
statement(name, " = stage_input.", name, ";");
}
}
}
});
// Run the shader.
if (execution.model == ExecutionModelVertex ||
execution.model == ExecutionModelFragment ||
execution.model == ExecutionModelGLCompute ||
execution.model == ExecutionModelMeshEXT ||
execution.model == ExecutionModelTaskEXT)
if (execution.model == ExecutionModelVertex || execution.model == ExecutionModelFragment ||
execution.model == ExecutionModelGLCompute || execution.model == ExecutionModelMeshEXT ||
execution.model == ExecutionModelGeometry || execution.model == ExecutionModelTaskEXT)
{
// For mesh shaders, we receive special arguments that we must pass down as function arguments.
// HLSL does not support proper reference types for passing these IO blocks,
@@ -3378,8 +3553,16 @@ void CompilerHLSL::emit_hlsl_entry_point()
SmallVector<string> arglist;
auto &func = get<SPIRFunction>(ir.default_entry_point);
// The arguments are marked out, avoid detecting reads and emitting inout.
for (auto &arg : func.arguments)
arglist.push_back(to_expression(arg.id, false));
if (execution.model == ExecutionModelGeometry)
{
arglist.push_back("stage_input");
arglist.push_back("geometry_stream");
}
statement(get_inner_entry_point_name(), "(", merge(arglist), ");");
}
else
@@ -4206,6 +4389,14 @@ bool CompilerHLSL::emit_complex_bitcast(uint32_t, uint32_t, uint32_t)
return false;
}
void CompilerHLSL::append_global_func_args(const SPIRFunction &func, uint32_t index, SmallVector<std::string> &arglist)
{
CompilerGLSL::append_global_func_args(func, index, arglist);
if (func.emits_geometry)
arglist.push_back("geometry_stream");
}
string CompilerHLSL::bitcast_glsl_op(const SPIRType &out_type, const SPIRType &in_type)
{
if (out_type.basetype == SPIRType::UInt && in_type.basetype == SPIRType::Int)
@@ -6594,6 +6785,16 @@ void CompilerHLSL::emit_instruction(const Instruction &instruction)
statement("SetMeshOutputCounts(", to_unpacked_expression(ops[0]), ", ", to_unpacked_expression(ops[1]), ");");
break;
}
case OpEmitVertex:
{
emit_geometry_stream_append();
break;
}
case OpEndPrimitive:
{
statement("geometry_stream.RestartStrip();");
break;
}
default:
CompilerGLSL::emit_instruction(instruction);
break;
@@ -6812,6 +7013,9 @@ string CompilerHLSL::compile()
if (get_execution_model() == ExecutionModelMeshEXT)
analyze_meshlet_writes();
if (get_execution_model() == ExecutionModelGeometry)
discover_geometry_emitters();
// Subpass input needs SV_Position.
if (need_subpass_input)
active_input_builtins.set(BuiltInFragCoord);
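
For orientation, a hedged usage sketch of driving the updated HLSL backend over a geometry-stage module; the spirv_words vector and the shader model 5.0 choice are assumptions for illustration, not taken from this commit.

// Hedged usage sketch: compiling a geometry-stage SPIR-V module to HLSL.
// spirv_words (a std::vector<uint32_t> holding the module) is assumed to be
// loaded elsewhere; shader model 5.0 is an illustrative choice (geometry
// shaders require SM 4.0 or newer).
#include <spirv_hlsl.hpp>

std::string to_hlsl(std::vector<uint32_t> spirv_words)
{
    spirv_cross::CompilerHLSL compiler(std::move(spirv_words));

    spirv_cross::CompilerHLSL::Options hlsl_opts;
    hlsl_opts.shader_model = 50;
    compiler.set_hlsl_options(hlsl_opts);

    // For a geometry entry point, compile() now emits the inner geom_main()
    // plus a wrapper with the [maxvertexcount(...)] attribute, the
    // primitive-qualified stage_input[] array and the inout
    // *Stream<SPIRV_Cross_Output> geometry_stream argument shown in the diff above.
    return compiler.compile();
}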

View File

@@ -231,6 +231,7 @@ private:
std::string image_type_hlsl(const SPIRType &type, uint32_t id);
std::string image_type_hlsl_modern(const SPIRType &type, uint32_t id);
std::string image_type_hlsl_legacy(const SPIRType &type, uint32_t id);
uint32_t input_vertices_from_execution_mode(SPIREntryPoint &execution) const;
void emit_function_prototype(SPIRFunction &func, const Bitset &return_flags) override;
void emit_hlsl_entry_point();
void emit_header() override;
@@ -259,6 +260,8 @@ private:
std::string to_interpolation_qualifiers(const Bitset &flags) override;
std::string bitcast_glsl_op(const SPIRType &result_type, const SPIRType &argument_type) override;
bool emit_complex_bitcast(uint32_t result_type, uint32_t id, uint32_t op0) override;
void append_global_func_args(const SPIRFunction &func, uint32_t index, SmallVector<std::string> &arglist) override;
std::string to_func_call_arg(const SPIRFunction::Parameter &arg, uint32_t id) override;
std::string to_sampler_expression(uint32_t id);
std::string to_resource_binding(const SPIRVariable &var);
@@ -286,6 +289,7 @@ private:
uint32_t base_offset = 0) override;
void emit_rayquery_function(const char *commited, const char *candidate, const uint32_t *ops);
void emit_mesh_tasks(SPIRBlock &block) override;
void emit_geometry_stream_append();
const char *to_storage_qualifiers_glsl(const SPIRVariable &var) override;
void replace_illegal_names() override;

View File

@@ -2222,6 +2222,27 @@ void CompilerMSL::extract_global_variables_from_function(uint32_t func_id, std::
break;
}
case OpGroupNonUniformFAdd:
case OpGroupNonUniformFMul:
case OpGroupNonUniformFMin:
case OpGroupNonUniformFMax:
case OpGroupNonUniformIAdd:
case OpGroupNonUniformIMul:
case OpGroupNonUniformSMin:
case OpGroupNonUniformSMax:
case OpGroupNonUniformUMin:
case OpGroupNonUniformUMax:
case OpGroupNonUniformBitwiseAnd:
case OpGroupNonUniformBitwiseOr:
case OpGroupNonUniformBitwiseXor:
case OpGroupNonUniformLogicalAnd:
case OpGroupNonUniformLogicalOr:
case OpGroupNonUniformLogicalXor:
if ((get_execution_model() != ExecutionModelFragment || msl_options.supports_msl_version(2, 2)) &&
ops[3] == GroupOperationClusteredReduce)
added_arg_ids.insert(builtin_subgroup_invocation_id_id);
break;
case OpDemoteToHelperInvocation:
if (needs_manual_helper_invocation_updates() && needs_helper_invocation)
added_arg_ids.insert(builtin_helper_invocation_id);
@@ -7026,6 +7047,105 @@ void CompilerMSL::emit_custom_functions()
statement("");
break;
// C++ disallows partial specializations of function templates,
// hence the use of a struct.
// clang-format off
#define FUNC_SUBGROUP_CLUSTERED(spv, msl, combine, op, ident) \
case SPVFuncImplSubgroupClustered##spv: \
statement("template<uint N, uint offset>"); \
statement("struct spvClustered" #spv "Detail;"); \
statement(""); \
statement("// Base cases"); \
statement("template<>"); \
statement("struct spvClustered" #spv "Detail<1, 0>"); \
begin_scope(); \
statement("template<typename T>"); \
statement("static T op(T value, uint)"); \
begin_scope(); \
statement("return value;"); \
end_scope(); \
end_scope_decl(); \
statement(""); \
statement("template<uint offset>"); \
statement("struct spvClustered" #spv "Detail<1, offset>"); \
begin_scope(); \
statement("template<typename T>"); \
statement("static T op(T value, uint lid)"); \
begin_scope(); \
statement("// If the target lane is inactive, then return identity."); \
if (msl_options.use_quadgroup_operation()) \
statement("if (!extract_bits((quad_vote::vote_t)quad_active_threads_mask(), (lid ^ offset), 1))"); \
else \
statement("if (!extract_bits(as_type<uint2>((simd_vote::vote_t)simd_active_threads_mask())[(lid ^ offset) / 32], (lid ^ offset) % 32, 1))"); \
statement(" return " #ident ";"); \
if (msl_options.use_quadgroup_operation()) \
statement("return quad_shuffle_xor(value, offset);"); \
else \
statement("return simd_shuffle_xor(value, offset);"); \
end_scope(); \
end_scope_decl(); \
statement(""); \
statement("template<>"); \
statement("struct spvClustered" #spv "Detail<4, 0>"); \
begin_scope(); \
statement("template<typename T>"); \
statement("static T op(T value, uint)"); \
begin_scope(); \
statement("return quad_" #msl "(value);"); \
end_scope(); \
end_scope_decl(); \
statement(""); \
statement("template<uint offset>"); \
statement("struct spvClustered" #spv "Detail<4, offset>"); \
begin_scope(); \
statement("template<typename T>"); \
statement("static T op(T value, uint lid)"); \
begin_scope(); \
statement("// Here, we care if any of the lanes in the quad are active."); \
statement("uint quad_mask = extract_bits(as_type<uint2>((simd_vote::vote_t)simd_active_threads_mask())[(lid ^ offset) / 32], ((lid ^ offset) % 32) & ~3, 4);"); \
statement("if (!quad_mask)"); \
statement(" return " #ident ";"); \
statement("// But we need to make sure we shuffle from an active lane."); \
if (msl_options.use_quadgroup_operation()) \
SPIRV_CROSS_THROW("Subgroup size with quadgroup operation cannot exceed 4."); \
else \
statement("return simd_shuffle(quad_" #msl "(value), ((lid ^ offset) & ~3) | ctz(quad_mask));"); \
end_scope(); \
end_scope_decl(); \
statement(""); \
statement("// General case"); \
statement("template<uint N, uint offset>"); \
statement("struct spvClustered" #spv "Detail"); \
begin_scope(); \
statement("template<typename T>"); \
statement("static T op(T value, uint lid)"); \
begin_scope(); \
statement("return " combine(msl, op, "spvClustered" #spv "Detail<N/2, offset>::op(value, lid)", "spvClustered" #spv "Detail<N/2, offset + N/2>::op(value, lid)") ";"); \
end_scope(); \
end_scope_decl(); \
statement(""); \
statement("template<uint N, typename T>"); \
statement("T spvClustered_" #msl "(T value, uint lid)"); \
begin_scope(); \
statement("return spvClustered" #spv "Detail<N, 0>::op(value, lid);"); \
end_scope(); \
statement(""); \
break
#define BINOP(msl, op, l, r) l " " #op " " r
#define BINFUNC(msl, op, l, r) #msl "(" l ", " r ")"
FUNC_SUBGROUP_CLUSTERED(Add, sum, BINOP, +, 0);
FUNC_SUBGROUP_CLUSTERED(Mul, product, BINOP, *, 1);
FUNC_SUBGROUP_CLUSTERED(Min, min, BINFUNC, , numeric_limits<T>::max());
FUNC_SUBGROUP_CLUSTERED(Max, max, BINFUNC, , numeric_limits<T>::min());
FUNC_SUBGROUP_CLUSTERED(And, and, BINOP, &, ~T(0));
FUNC_SUBGROUP_CLUSTERED(Or, or, BINOP, |, 0);
FUNC_SUBGROUP_CLUSTERED(Xor, xor, BINOP, ^, 0);
// clang-format on
#undef FUNC_SUBGROUP_CLUSTERED
#undef BINOP
#undef BINFUNC
case SPVFuncImplQuadBroadcast:
statement("template<typename T>");
statement("inline T spvQuadBroadcast(T value, uint lane)");
@@ -9126,7 +9246,7 @@ void CompilerMSL::fix_up_interpolant_access_chain(const uint32_t *ops, uint32_t
// If the physical type of a physical buffer pointer has been changed
// to a ulong or ulongn vector, add a cast back to the pointer type.
void CompilerMSL::check_physical_type_cast(std::string &expr, const SPIRType *type, uint32_t physical_type)
bool CompilerMSL::check_physical_type_cast(std::string &expr, const SPIRType *type, uint32_t physical_type)
{
auto *p_physical_type = maybe_get<SPIRType>(physical_type);
if (p_physical_type &&
@@ -9137,7 +9257,10 @@ void CompilerMSL::check_physical_type_cast(std::string &expr, const SPIRType *ty
expr += ".x";
expr = join("((", type_to_glsl(*type), ")", expr, ")");
return true;
}
return false;
}
// Override for MSL-specific syntax instructions
@@ -9840,9 +9963,9 @@ void CompilerMSL::emit_instruction(const Instruction &instruction)
case OpControlBarrier:
// In GLSL a memory barrier is often followed by a control barrier.
// But in MSL, memory barriers are also control barriers, so don't
// But in MSL, memory barriers are also control barriers (before MSL 3.2), so don't
// emit a simple control barrier if a memory barrier has just been emitted.
if (previous_instruction_opcode != OpMemoryBarrier)
if (previous_instruction_opcode != OpMemoryBarrier || msl_options.supports_msl_version(3, 2))
emit_barrier(ops[0], ops[1], ops[2]);
break;
@@ -10441,10 +10564,20 @@ void CompilerMSL::emit_barrier(uint32_t id_exe_scope, uint32_t id_mem_scope, uin
return;
string bar_stmt;
if ((msl_options.is_ios() && msl_options.supports_msl_version(1, 2)) || msl_options.supports_msl_version(2))
bar_stmt = exe_scope < ScopeSubgroup ? "threadgroup_barrier" : "simdgroup_barrier";
if (!id_exe_scope && msl_options.supports_msl_version(3, 2))
{
// Just took 10 years to get a proper barrier, but hey!
bar_stmt = "atomic_thread_fence";
}
else
bar_stmt = "threadgroup_barrier";
{
if ((msl_options.is_ios() && msl_options.supports_msl_version(1, 2)) || msl_options.supports_msl_version(2))
bar_stmt = exe_scope < ScopeSubgroup ? "threadgroup_barrier" : "simdgroup_barrier";
else
bar_stmt = "threadgroup_barrier";
}
bar_stmt += "(";
uint32_t mem_sem = id_mem_sem ? evaluate_constant_u32(id_mem_sem) : uint32_t(MemorySemanticsMaskNone);
@@ -10452,7 +10585,8 @@ void CompilerMSL::emit_barrier(uint32_t id_exe_scope, uint32_t id_mem_scope, uin
// Use the | operator to combine flags if we can.
if (msl_options.supports_msl_version(1, 2))
{
string mem_flags = "";
string mem_flags;
// For tesc shaders, this also affects objects in the Output storage class.
// Since in Metal, these are placed in a device buffer, we have to sync device memory here.
if (is_tesc_shader() ||
@@ -10493,6 +10627,55 @@ void CompilerMSL::emit_barrier(uint32_t id_exe_scope, uint32_t id_mem_scope, uin
bar_stmt += "mem_flags::mem_none";
}
if (!id_exe_scope && msl_options.supports_msl_version(3, 2))
{
// If there's no device-related memory in the barrier, demote to workgroup scope.
// glslang seems to emit device scope even for memoryBarrierShared().
if (mem_scope == ScopeDevice &&
(mem_sem & (MemorySemanticsUniformMemoryMask |
MemorySemanticsImageMemoryMask |
MemorySemanticsCrossWorkgroupMemoryMask)) == 0)
{
mem_scope = ScopeWorkgroup;
}
// MSL 3.2 only supports seq_cst or relaxed.
if (mem_sem & (MemorySemanticsAcquireReleaseMask |
MemorySemanticsAcquireMask |
MemorySemanticsReleaseMask |
MemorySemanticsSequentiallyConsistentMask))
{
bar_stmt += ", memory_order_seq_cst";
}
else
{
bar_stmt += ", memory_order_relaxed";
}
switch (mem_scope)
{
case ScopeDevice:
bar_stmt += ", thread_scope_device";
break;
case ScopeWorkgroup:
bar_stmt += ", thread_scope_threadgroup";
break;
case ScopeSubgroup:
bar_stmt += ", thread_scope_subgroup";
break;
case ScopeInvocation:
bar_stmt += ", thread_scope_thread";
break;
default:
// The default argument is device, which is conservative.
break;
}
}
bar_stmt += ");";
statement(bar_stmt);
@@ -13663,9 +13846,17 @@ string CompilerMSL::get_argument_address_space(const SPIRVariable &argument)
return get_type_address_space(type, argument.self, true);
}
bool CompilerMSL::decoration_flags_signal_volatile(const Bitset &flags)
bool CompilerMSL::decoration_flags_signal_volatile(const Bitset &flags) const
{
return flags.get(DecorationVolatile) || flags.get(DecorationCoherent);
// Using volatile for coherent pre-3.2 is definitely not correct, but it's something.
// MSL 3.2 adds actual coherent qualifiers.
return flags.get(DecorationVolatile) ||
(flags.get(DecorationCoherent) && !msl_options.supports_msl_version(3, 2));
}
bool CompilerMSL::decoration_flags_signal_coherent(const Bitset &flags) const
{
return flags.get(DecorationCoherent) && msl_options.supports_msl_version(3, 2);
}
string CompilerMSL::get_type_address_space(const SPIRType &type, uint32_t id, bool argument)
@@ -13677,8 +13868,17 @@ string CompilerMSL::get_type_address_space(const SPIRType &type, uint32_t id, bo
(has_decoration(type.self, DecorationBlock) || has_decoration(type.self, DecorationBufferBlock)))
flags = get_buffer_block_flags(id);
else
{
flags = get_decoration_bitset(id);
if (type.basetype == SPIRType::Struct &&
(has_decoration(type.self, DecorationBlock) ||
has_decoration(type.self, DecorationBufferBlock)))
{
flags.merge_or(ir.get_buffer_block_type_flags(type));
}
}
const char *addr_space = nullptr;
switch (type.storage)
{
@@ -13687,7 +13887,6 @@ string CompilerMSL::get_type_address_space(const SPIRType &type, uint32_t id, bo
break;
case StorageClassStorageBuffer:
case StorageClassPhysicalStorageBuffer:
{
// For arguments from variable pointers, we use the write count deduction, so
// we should not assume any constness here. Only for global SSBOs.
@@ -13695,10 +13894,19 @@ string CompilerMSL::get_type_address_space(const SPIRType &type, uint32_t id, bo
if (!var || has_decoration(type.self, DecorationBlock))
readonly = flags.get(DecorationNonWritable);
if (decoration_flags_signal_coherent(flags))
readonly = false;
addr_space = readonly ? "const device" : "device";
break;
}
case StorageClassPhysicalStorageBuffer:
// We cannot fully trust NonWritable coming from glslang due to a bug in buffer_reference handling.
// There isn't much gain in emitting const in C++ languages anyway.
addr_space = "device";
break;
case StorageClassUniform:
case StorageClassUniformConstant:
case StorageClassPushConstant:
@@ -13787,7 +13995,9 @@ string CompilerMSL::get_type_address_space(const SPIRType &type, uint32_t id, bo
addr_space = type.pointer || (argument && type.basetype == SPIRType::ControlPointArray) ? "thread" : "";
}
if (decoration_flags_signal_volatile(flags) && 0 != strcmp(addr_space, "thread"))
if (decoration_flags_signal_coherent(flags) && strcmp(addr_space, "device") == 0)
return join("coherent device");
else if (decoration_flags_signal_volatile(flags) && strcmp(addr_space, "thread") != 0)
return join("volatile ", addr_space);
else
return addr_space;
@@ -15411,7 +15621,8 @@ string CompilerMSL::argument_decl(const SPIRFunction::Parameter &arg)
bool constref = !arg.alias_global_variable && !passed_by_value && is_pointer(var_type) && arg.write_count == 0;
// Framebuffer fetch is plain value, const looks out of place, but it is not wrong.
if (type_is_msl_framebuffer_fetch(type))
// readonly coming from glslang is not reliable in all cases.
if (type_is_msl_framebuffer_fetch(type) || type_storage == StorageClassPhysicalStorageBuffer)
constref = false;
else if (type_storage == StorageClassUniformConstant)
constref = true;
@@ -16639,6 +16850,10 @@ string CompilerMSL::image_type_glsl(const SPIRType &type, uint32_t id, bool memb
// Otherwise it may be set based on whether the image is read from or written to within the shader.
if (type.basetype == SPIRType::Image && type.image.sampled == 2 && type.image.dim != DimSubpassData)
{
auto *p_var = maybe_get_backing_variable(id);
if (p_var && p_var->basevariable)
p_var = maybe_get<SPIRVariable>(p_var->basevariable);
switch (img_type.access)
{
case AccessQualifierReadOnly:
@@ -16655,9 +16870,6 @@ string CompilerMSL::image_type_glsl(const SPIRType &type, uint32_t id, bool memb
default:
{
auto *p_var = maybe_get_backing_variable(id);
if (p_var && p_var->basevariable)
p_var = maybe_get<SPIRVariable>(p_var->basevariable);
if (p_var && !has_decoration(p_var->self, DecorationNonWritable))
{
img_type_name += ", access::";
@@ -16670,6 +16882,9 @@ string CompilerMSL::image_type_glsl(const SPIRType &type, uint32_t id, bool memb
break;
}
}
if (p_var && has_decoration(p_var->self, DecorationCoherent) && msl_options.supports_msl_version(3, 2))
img_type_name += ", memory_coherence_device";
}
img_type_name += ">";
@@ -16924,11 +17139,10 @@ case OpGroupNonUniform##op: \
emit_unary_func_op(result_type, id, ops[op_idx], "simd_prefix_exclusive_" #msl_op); \
else if (operation == GroupOperationClusteredReduce) \
{ \
/* Only cluster sizes of 4 are supported. */ \
uint32_t cluster_size = evaluate_constant_u32(ops[op_idx + 1]); \
if (cluster_size != 4) \
SPIRV_CROSS_THROW("Metal only supports quad ClusteredReduce."); \
emit_unary_func_op(result_type, id, ops[op_idx], "quad_" #msl_op); \
if (get_execution_model() != ExecutionModelFragment || msl_options.supports_msl_version(2, 2)) \
add_spv_func_and_recompile(SPVFuncImplSubgroupClustered##op); \
emit_subgroup_cluster_op(result_type, id, cluster_size, ops[op_idx], #msl_op); \
} \
else \
SPIRV_CROSS_THROW("Invalid group operation."); \
@@ -16953,11 +17167,10 @@ case OpGroupNonUniform##op: \
SPIRV_CROSS_THROW("Metal doesn't support ExclusiveScan for OpGroupNonUniform" #op "."); \
else if (operation == GroupOperationClusteredReduce) \
{ \
/* Only cluster sizes of 4 are supported. */ \
uint32_t cluster_size = evaluate_constant_u32(ops[op_idx + 1]); \
if (cluster_size != 4) \
SPIRV_CROSS_THROW("Metal only supports quad ClusteredReduce."); \
emit_unary_func_op(result_type, id, ops[op_idx], "quad_" #msl_op); \
if (get_execution_model() != ExecutionModelFragment || msl_options.supports_msl_version(2, 2)) \
add_spv_func_and_recompile(SPVFuncImplSubgroupClustered##op); \
emit_subgroup_cluster_op(result_type, id, cluster_size, ops[op_idx], #msl_op); \
} \
else \
SPIRV_CROSS_THROW("Invalid group operation."); \
@@ -16976,11 +17189,10 @@ case OpGroupNonUniform##op: \
SPIRV_CROSS_THROW("Metal doesn't support ExclusiveScan for OpGroupNonUniform" #op "."); \
else if (operation == GroupOperationClusteredReduce) \
{ \
/* Only cluster sizes of 4 are supported. */ \
uint32_t cluster_size = evaluate_constant_u32(ops[op_idx + 1]); \
if (cluster_size != 4) \
SPIRV_CROSS_THROW("Metal only supports quad ClusteredReduce."); \
emit_unary_func_op_cast(result_type, id, ops[op_idx], "quad_" #msl_op, type, type); \
if (get_execution_model() != ExecutionModelFragment || msl_options.supports_msl_version(2, 2)) \
add_spv_func_and_recompile(SPVFuncImplSubgroupClustered##op); \
emit_subgroup_cluster_op_cast(result_type, id, cluster_size, ops[op_idx], #msl_op, type, type); \
} \
else \
SPIRV_CROSS_THROW("Invalid group operation."); \
@@ -16996,9 +17208,11 @@ case OpGroupNonUniform##op: \
MSL_GROUP_OP(BitwiseAnd, and)
MSL_GROUP_OP(BitwiseOr, or)
MSL_GROUP_OP(BitwiseXor, xor)
MSL_GROUP_OP(LogicalAnd, and)
MSL_GROUP_OP(LogicalOr, or)
MSL_GROUP_OP(LogicalXor, xor)
// Metal doesn't support boolean types in SIMD-group operations, so we
// have to emit some casts.
MSL_GROUP_OP_CAST(LogicalAnd, and, SPIRType::UShort)
MSL_GROUP_OP_CAST(LogicalOr, or, SPIRType::UShort)
MSL_GROUP_OP_CAST(LogicalXor, xor, SPIRType::UShort)
// clang-format on
#undef MSL_GROUP_OP
#undef MSL_GROUP_OP_CAST
@@ -17026,6 +17240,83 @@ case OpGroupNonUniform##op: \
register_control_dependent_expression(id);
}
void CompilerMSL::emit_subgroup_cluster_op(uint32_t result_type, uint32_t result_id, uint32_t cluster_size,
uint32_t op0, const char *op)
{
if (get_execution_model() == ExecutionModelFragment && !msl_options.supports_msl_version(2, 2))
{
if (cluster_size == 4)
{
emit_unary_func_op(result_type, result_id, op0, join("quad_", op).c_str());
return;
}
SPIRV_CROSS_THROW("Cluster sizes other than 4 in fragment shaders require MSL 2.2.");
}
bool forward = should_forward(op0);
emit_op(result_type, result_id,
join("spvClustered_", op, "<", cluster_size, ">(", to_unpacked_expression(op0), ", ",
to_expression(builtin_subgroup_invocation_id_id), ")"),
forward);
inherit_expression_dependencies(result_id, op0);
}
void CompilerMSL::emit_subgroup_cluster_op_cast(uint32_t result_type, uint32_t result_id, uint32_t cluster_size,
uint32_t op0, const char *op, SPIRType::BaseType input_type,
SPIRType::BaseType expected_result_type)
{
if (get_execution_model() == ExecutionModelFragment && !msl_options.supports_msl_version(2, 2))
{
if (cluster_size == 4)
{
emit_unary_func_op_cast(result_type, result_id, op0, join("quad_", op).c_str(), input_type,
expected_result_type);
return;
}
SPIRV_CROSS_THROW("Cluster sizes other than 4 in fragment shaders require MSL 2.2.");
}
auto &out_type = get<SPIRType>(result_type);
auto &expr_type = expression_type(op0);
auto expected_type = out_type;
// Bit-widths might be different in unary cases because we use it for SConvert/UConvert and friends.
expected_type.basetype = input_type;
expected_type.width = expr_type.width;
string cast_op;
if (expr_type.basetype != input_type)
{
if (expr_type.basetype == SPIRType::Boolean)
cast_op = join(type_to_glsl(expected_type), "(", to_unpacked_expression(op0), ")");
else
cast_op = bitcast_glsl(expected_type, op0);
}
else
cast_op = to_unpacked_expression(op0);
string sg_op = join("spvClustered_", op, "<", cluster_size, ">");
string expr;
if (out_type.basetype != expected_result_type)
{
expected_type.basetype = expected_result_type;
expected_type.width = out_type.width;
if (out_type.basetype == SPIRType::Boolean)
expr = type_to_glsl(out_type);
else
expr = bitcast_glsl_op(out_type, expected_type);
expr += '(';
expr += join(sg_op, "(", cast_op, ", ", to_expression(builtin_subgroup_invocation_id_id), ")");
expr += ')';
}
else
{
expr += join(sg_op, "(", cast_op, ", ", to_expression(builtin_subgroup_invocation_id_id), ")");
}
emit_op(result_type, result_id, expr, should_forward(op0));
inherit_expression_dependencies(result_id, op0);
}
string CompilerMSL::bitcast_glsl_op(const SPIRType &out_type, const SPIRType &in_type)
{
if (out_type.basetype == in_type.basetype)
@@ -18097,6 +18388,28 @@ bool CompilerMSL::OpCodePreprocessor::handle(Op opcode, const uint32_t *args, ui
}
break;
case OpGroupNonUniformFAdd:
case OpGroupNonUniformFMul:
case OpGroupNonUniformFMin:
case OpGroupNonUniformFMax:
case OpGroupNonUniformIAdd:
case OpGroupNonUniformIMul:
case OpGroupNonUniformSMin:
case OpGroupNonUniformSMax:
case OpGroupNonUniformUMin:
case OpGroupNonUniformUMax:
case OpGroupNonUniformBitwiseAnd:
case OpGroupNonUniformBitwiseOr:
case OpGroupNonUniformBitwiseXor:
case OpGroupNonUniformLogicalAnd:
case OpGroupNonUniformLogicalOr:
case OpGroupNonUniformLogicalXor:
if ((compiler.get_execution_model() != ExecutionModelFragment ||
compiler.msl_options.supports_msl_version(2, 2)) &&
args[3] == GroupOperationClusteredReduce)
needs_subgroup_invocation_id = true;
break;
case OpArrayLength:
{
auto *var = compiler.maybe_get_backing_variable(args[2]);
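
To make the recursion behind the new spvClustered*Detail helpers easier to follow, here is a host-side C++ sketch that simulates the same divide-and-conquer combine over an array of lane values; lane-inactivity handling and the quad_* fast path are omitted, and all names are illustrative rather than SPIRV-Cross API.

// Host-side sketch of the clustered-reduce recursion emitted for MSL:
// clustered_sum<N, Offset>(lanes, lid) combines the values of lanes lid ^ o
// for o in [Offset, Offset + N), so clustered_sum<N, 0>(lanes, lid) covers
// the whole N-wide cluster containing lane lid. Inactive-lane identities and
// the quad_* shortcut of the real helpers are left out for brevity.
#include <array>
#include <cstdio>

constexpr unsigned kLanes = 16;
using Lanes = std::array<float, kLanes>;

template <unsigned N, unsigned Offset>
float clustered_sum(const Lanes &lanes, unsigned lid)
{
    if constexpr (N == 1)
        return lanes[lid ^ Offset]; // models shuffle_xor from lane lid ^ Offset
    else
        return clustered_sum<N / 2, Offset>(lanes, lid) +
               clustered_sum<N / 2, Offset + N / 2>(lanes, lid);
}

int main()
{
    Lanes lanes{};
    for (unsigned i = 0; i < kLanes; i++)
        lanes[i] = float(i);

    // Lane 5 belongs to the size-4 cluster {4, 5, 6, 7}, so the sum is 22.
    std::printf("cluster sum at lane 5 = %g\n", clustered_sum<4, 0>(lanes, 5));
}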

View File

@@ -818,6 +818,29 @@ protected:
SPVFuncImplSubgroupShuffleUp,
SPVFuncImplSubgroupShuffleDown,
SPVFuncImplSubgroupRotate,
SPVFuncImplSubgroupClusteredAdd,
SPVFuncImplSubgroupClusteredFAdd = SPVFuncImplSubgroupClusteredAdd,
SPVFuncImplSubgroupClusteredIAdd = SPVFuncImplSubgroupClusteredAdd,
SPVFuncImplSubgroupClusteredMul,
SPVFuncImplSubgroupClusteredFMul = SPVFuncImplSubgroupClusteredMul,
SPVFuncImplSubgroupClusteredIMul = SPVFuncImplSubgroupClusteredMul,
SPVFuncImplSubgroupClusteredMin,
SPVFuncImplSubgroupClusteredFMin = SPVFuncImplSubgroupClusteredMin,
SPVFuncImplSubgroupClusteredSMin = SPVFuncImplSubgroupClusteredMin,
SPVFuncImplSubgroupClusteredUMin = SPVFuncImplSubgroupClusteredMin,
SPVFuncImplSubgroupClusteredMax,
SPVFuncImplSubgroupClusteredFMax = SPVFuncImplSubgroupClusteredMax,
SPVFuncImplSubgroupClusteredSMax = SPVFuncImplSubgroupClusteredMax,
SPVFuncImplSubgroupClusteredUMax = SPVFuncImplSubgroupClusteredMax,
SPVFuncImplSubgroupClusteredAnd,
SPVFuncImplSubgroupClusteredBitwiseAnd = SPVFuncImplSubgroupClusteredAnd,
SPVFuncImplSubgroupClusteredLogicalAnd = SPVFuncImplSubgroupClusteredAnd,
SPVFuncImplSubgroupClusteredOr,
SPVFuncImplSubgroupClusteredBitwiseOr = SPVFuncImplSubgroupClusteredOr,
SPVFuncImplSubgroupClusteredLogicalOr = SPVFuncImplSubgroupClusteredOr,
SPVFuncImplSubgroupClusteredXor,
SPVFuncImplSubgroupClusteredBitwiseXor = SPVFuncImplSubgroupClusteredXor,
SPVFuncImplSubgroupClusteredLogicalXor = SPVFuncImplSubgroupClusteredXor,
SPVFuncImplQuadBroadcast,
SPVFuncImplQuadSwap,
SPVFuncImplReflectScalar,
@@ -871,6 +894,11 @@ protected:
void emit_function_prototype(SPIRFunction &func, const Bitset &return_flags) override;
void emit_sampled_image_op(uint32_t result_type, uint32_t result_id, uint32_t image_id, uint32_t samp_id) override;
void emit_subgroup_op(const Instruction &i) override;
void emit_subgroup_cluster_op(uint32_t result_type, uint32_t result_id, uint32_t cluster_size, uint32_t op0,
const char *op);
void emit_subgroup_cluster_op_cast(uint32_t result_type, uint32_t result_id, uint32_t cluster_size, uint32_t op0,
const char *op, SPIRType::BaseType input_type,
SPIRType::BaseType expected_result_type);
std::string to_texture_op(const Instruction &i, bool sparse, bool *forward,
SmallVector<uint32_t> &inherited_expressions) override;
void emit_fixup() override;
@@ -1084,7 +1112,8 @@ protected:
bool validate_member_packing_rules_msl(const SPIRType &type, uint32_t index) const;
std::string get_argument_address_space(const SPIRVariable &argument);
std::string get_type_address_space(const SPIRType &type, uint32_t id, bool argument = false);
static bool decoration_flags_signal_volatile(const Bitset &flags);
bool decoration_flags_signal_volatile(const Bitset &flags) const;
bool decoration_flags_signal_coherent(const Bitset &flags) const;
const char *to_restrict(uint32_t id, bool space);
SPIRType &get_stage_in_struct_type();
SPIRType &get_stage_out_struct_type();
@@ -1154,7 +1183,7 @@ protected:
bool prepare_access_chain_for_scalar_access(std::string &expr, const SPIRType &type, spv::StorageClass storage,
bool &is_packed) override;
void fix_up_interpolant_access_chain(const uint32_t *ops, uint32_t length);
void check_physical_type_cast(std::string &expr, const SPIRType *type, uint32_t physical_type) override;
bool check_physical_type_cast(std::string &expr, const SPIRType *type, uint32_t physical_type) override;
bool emit_tessellation_access_chain(const uint32_t *ops, uint32_t length);
bool emit_tessellation_io_load(uint32_t result_type, uint32_t id, uint32_t ptr);