From 7c79acf98e1be8d5d29363790610d61ca0bf419c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D0=91=D1=80=D0=B0=D0=BD=D0=B8=D0=BC=D0=B8=D1=80=20=D0=9A?= =?UTF-8?q?=D0=B0=D1=80=D0=B0=D1=9F=D0=B8=D1=9B?= Date: Sat, 7 Jun 2025 10:26:12 -0700 Subject: [PATCH] Updated spirv-cross. --- 3rdparty/spirv-cross/spirv_common.hpp | 3 + 3rdparty/spirv-cross/spirv_cross.cpp | 43 ++- 3rdparty/spirv-cross/spirv_cross.hpp | 17 ++ 3rdparty/spirv-cross/spirv_cross_c.cpp | 1 + 3rdparty/spirv-cross/spirv_glsl.cpp | 93 +++--- 3rdparty/spirv-cross/spirv_glsl.hpp | 4 +- 3rdparty/spirv-cross/spirv_hlsl.cpp | 222 ++++++++++++++- 3rdparty/spirv-cross/spirv_hlsl.hpp | 4 + 3rdparty/spirv-cross/spirv_msl.cpp | 373 +++++++++++++++++++++++-- 3rdparty/spirv-cross/spirv_msl.hpp | 33 ++- 10 files changed, 707 insertions(+), 86 deletions(-) diff --git a/3rdparty/spirv-cross/spirv_common.hpp b/3rdparty/spirv-cross/spirv_common.hpp index dafd0c072..a4778c29b 100644 --- a/3rdparty/spirv-cross/spirv_common.hpp +++ b/3rdparty/spirv-cross/spirv_common.hpp @@ -1035,6 +1035,9 @@ struct SPIRFunction : IVariant // consider arrays value types. SmallVector constant_arrays_needed_on_stack; + // Does this function (or any function called by it), emit geometry? + bool emits_geometry = false; + bool active = false; bool flush_undeclared = true; bool do_combined_parameters = true; diff --git a/3rdparty/spirv-cross/spirv_cross.cpp b/3rdparty/spirv-cross/spirv_cross.cpp index 8c2608efd..8aa4e5e70 100644 --- a/3rdparty/spirv-cross/spirv_cross.cpp +++ b/3rdparty/spirv-cross/spirv_cross.cpp @@ -82,7 +82,7 @@ bool Compiler::variable_storage_is_aliased(const SPIRVariable &v) ir.meta[type.self].decoration.decoration_flags.get(DecorationBufferBlock); bool image = type.basetype == SPIRType::Image; bool counter = type.basetype == SPIRType::AtomicCounter; - bool buffer_reference = type.storage == StorageClassPhysicalStorageBufferEXT; + bool buffer_reference = type.storage == StorageClassPhysicalStorageBuffer; bool is_restrict; if (ssbo) @@ -484,7 +484,7 @@ void Compiler::register_write(uint32_t chain) } } - if (type.storage == StorageClassPhysicalStorageBufferEXT || variable_storage_is_aliased(*var)) + if (type.storage == StorageClassPhysicalStorageBuffer || variable_storage_is_aliased(*var)) flush_all_aliased_variables(); else if (var) flush_dependees(*var); @@ -4362,6 +4362,39 @@ bool Compiler::may_read_undefined_variable_in_block(const SPIRBlock &block, uint return true; } +bool Compiler::GeometryEmitDisocveryHandler::handle(spv::Op opcode, const uint32_t *, uint32_t) +{ + if (opcode == OpEmitVertex || opcode == OpEndPrimitive) + { + for (auto *func : function_stack) + func->emits_geometry = true; + } + + return true; +} + +bool Compiler::GeometryEmitDisocveryHandler::begin_function_scope(const uint32_t *stream, uint32_t) +{ + auto &callee = compiler.get(stream[2]); + function_stack.push_back(&callee); + return true; +} + +bool Compiler::GeometryEmitDisocveryHandler::end_function_scope([[maybe_unused]] const uint32_t *stream, uint32_t) +{ + assert(function_stack.back() == &compiler.get(stream[2])); + function_stack.pop_back(); + + return true; +} + +void Compiler::discover_geometry_emitters() +{ + GeometryEmitDisocveryHandler handler(*this); + + traverse_all_reachable_opcodes(get(ir.default_entry_point), handler); +} + Bitset Compiler::get_buffer_block_flags(VariableID id) const { return ir.get_buffer_block_flags(get(id)); @@ -5194,7 +5227,7 @@ bool Compiler::PhysicalStorageBufferPointerHandler::type_is_bda_block_entry(uint uint32_t 
Compiler::PhysicalStorageBufferPointerHandler::get_minimum_scalar_alignment(const SPIRType &type) const { - if (type.storage == spv::StorageClassPhysicalStorageBufferEXT) + if (type.storage == spv::StorageClassPhysicalStorageBuffer) return 8; else if (type.basetype == SPIRType::Struct) { @@ -5298,6 +5331,10 @@ uint32_t Compiler::PhysicalStorageBufferPointerHandler::get_base_non_block_type_ void Compiler::PhysicalStorageBufferPointerHandler::analyze_non_block_types_from_block(const SPIRType &type) { + if (analyzed_type_ids.count(type.self)) + return; + analyzed_type_ids.insert(type.self); + for (auto &member : type.member_types) { auto &subtype = compiler.get(member); diff --git a/3rdparty/spirv-cross/spirv_cross.hpp b/3rdparty/spirv-cross/spirv_cross.hpp index e9062b485..b65b5ac77 100644 --- a/3rdparty/spirv-cross/spirv_cross.hpp +++ b/3rdparty/spirv-cross/spirv_cross.hpp @@ -1054,6 +1054,7 @@ protected: std::unordered_set non_block_types; std::unordered_map physical_block_type_meta; std::unordered_map access_chain_to_physical_block; + std::unordered_set analyzed_type_ids; void mark_aligned_access(uint32_t id, const uint32_t *args, uint32_t length); PhysicalBlockMeta *find_block_meta(uint32_t id) const; @@ -1072,6 +1073,22 @@ protected: bool single_function); bool may_read_undefined_variable_in_block(const SPIRBlock &block, uint32_t var); + struct GeometryEmitDisocveryHandler : OpcodeHandler + { + explicit GeometryEmitDisocveryHandler(Compiler &compiler_) + : compiler(compiler_) + { + } + Compiler &compiler; + + bool handle(spv::Op opcode, const uint32_t *args, uint32_t length) override; + bool begin_function_scope(const uint32_t *, uint32_t) override; + bool end_function_scope(const uint32_t *, uint32_t) override; + SmallVector function_stack; + }; + + void discover_geometry_emitters(); + // Finds all resources that are written to from inside the critical section, if present. // The critical section is delimited by OpBeginInvocationInterlockEXT and // OpEndInvocationInterlockEXT instructions. In MSL and HLSL, any resources written diff --git a/3rdparty/spirv-cross/spirv_cross_c.cpp b/3rdparty/spirv-cross/spirv_cross_c.cpp index 6f62dd479..6827f6135 100644 --- a/3rdparty/spirv-cross/spirv_cross_c.cpp +++ b/3rdparty/spirv-cross/spirv_cross_c.cpp @@ -55,6 +55,7 @@ #ifdef _MSC_VER #pragma warning(push) #pragma warning(disable : 4996) +#pragma warning(disable : 4065) // switch with 'default' but not 'case'. #endif #ifndef SPIRV_CROSS_EXCEPTIONS_TO_ASSERTIONS diff --git a/3rdparty/spirv-cross/spirv_glsl.cpp b/3rdparty/spirv-cross/spirv_glsl.cpp index f39d715b2..ca9d0309d 100644 --- a/3rdparty/spirv-cross/spirv_glsl.cpp +++ b/3rdparty/spirv-cross/spirv_glsl.cpp @@ -545,7 +545,7 @@ void CompilerGLSL::find_static_extensions() if (options.separate_shader_objects && !options.es && options.version < 410) require_extension_internal("GL_ARB_separate_shader_objects"); - if (ir.addressing_model == AddressingModelPhysicalStorageBuffer64EXT) + if (ir.addressing_model == AddressingModelPhysicalStorageBuffer64) { if (!options.vulkan_semantics) SPIRV_CROSS_THROW("GL_EXT_buffer_reference is only supported in Vulkan GLSL."); @@ -557,7 +557,7 @@ void CompilerGLSL::find_static_extensions() } else if (ir.addressing_model != AddressingModelLogical) { - SPIRV_CROSS_THROW("Only Logical and PhysicalStorageBuffer64EXT addressing models are supported."); + SPIRV_CROSS_THROW("Only Logical and PhysicalStorageBuffer64 addressing models are supported."); } // Check for nonuniform qualifier and passthrough. 
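The PhysicalStorageBufferEXT to PhysicalStorageBuffer and PhysicalStorageBuffer64EXT to PhysicalStorageBuffer64 renames that run through this patch are cosmetic: in the SPIR-V headers the EXT enumerants are aliases for the core names with identical values, so no generated output changes. A minimal compile-time check, assuming the spirv.hpp header that ships alongside spirv_common.hpp:

#include "spirv.hpp" // assumption: the SPIR-V header vendored with SPIRV-Cross

static_assert(spv::StorageClassPhysicalStorageBufferEXT == spv::StorageClassPhysicalStorageBuffer,
              "EXT alias and core storage class share one enum value");
static_assert(spv::AddressingModelPhysicalStorageBuffer64EXT == spv::AddressingModelPhysicalStorageBuffer64,
              "EXT alias and core addressing model share one enum value");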
@@ -708,7 +708,7 @@ string CompilerGLSL::compile() // Shaders might cast unrelated data to pointers of non-block types. // Find all such instances and make sure we can cast the pointers to a synthesized block type. - if (ir.addressing_model == AddressingModelPhysicalStorageBuffer64EXT) + if (ir.addressing_model == AddressingModelPhysicalStorageBuffer64) analyze_non_block_pointer_types(); uint32_t pass_count = 0; @@ -1542,14 +1542,14 @@ uint32_t CompilerGLSL::type_to_packed_base_size(const SPIRType &type, BufferPack uint32_t CompilerGLSL::type_to_packed_alignment(const SPIRType &type, const Bitset &flags, BufferPackingStandard packing) { - // If using PhysicalStorageBufferEXT storage class, this is a pointer, + // If using PhysicalStorageBuffer storage class, this is a pointer, // and is 64-bit. if (is_physical_pointer(type)) { if (!type.pointer) - SPIRV_CROSS_THROW("Types in PhysicalStorageBufferEXT must be pointers."); + SPIRV_CROSS_THROW("Types in PhysicalStorageBuffer must be pointers."); - if (ir.addressing_model == AddressingModelPhysicalStorageBuffer64EXT) + if (ir.addressing_model == AddressingModelPhysicalStorageBuffer64) { if (packing_is_vec4_padded(packing) && type_is_array_of_pointers(type)) return 16; @@ -1557,7 +1557,7 @@ uint32_t CompilerGLSL::type_to_packed_alignment(const SPIRType &type, const Bits return 8; } else - SPIRV_CROSS_THROW("AddressingModelPhysicalStorageBuffer64EXT must be used for PhysicalStorageBufferEXT."); + SPIRV_CROSS_THROW("AddressingModelPhysicalStorageBuffer64 must be used for PhysicalStorageBuffer."); } else if (is_array(type)) { @@ -1665,17 +1665,17 @@ uint32_t CompilerGLSL::type_to_packed_array_stride(const SPIRType &type, const B uint32_t CompilerGLSL::type_to_packed_size(const SPIRType &type, const Bitset &flags, BufferPackingStandard packing) { - // If using PhysicalStorageBufferEXT storage class, this is a pointer, + // If using PhysicalStorageBuffer storage class, this is a pointer, // and is 64-bit. if (is_physical_pointer(type)) { if (!type.pointer) - SPIRV_CROSS_THROW("Types in PhysicalStorageBufferEXT must be pointers."); + SPIRV_CROSS_THROW("Types in PhysicalStorageBuffer must be pointers."); - if (ir.addressing_model == AddressingModelPhysicalStorageBuffer64EXT) + if (ir.addressing_model == AddressingModelPhysicalStorageBuffer64) return 8; else - SPIRV_CROSS_THROW("AddressingModelPhysicalStorageBuffer64EXT must be used for PhysicalStorageBufferEXT."); + SPIRV_CROSS_THROW("AddressingModelPhysicalStorageBuffer64 must be used for PhysicalStorageBuffer."); } else if (is_array(type)) { @@ -3638,6 +3638,36 @@ void CompilerGLSL::emit_resources() bool emitted = false; + if (ir.addressing_model == AddressingModelPhysicalStorageBuffer64) + { + // Output buffer reference block forward declarations. + ir.for_each_typed_id([&](uint32_t id, SPIRType &type) + { + if (is_physical_pointer(type)) + { + bool emit_type = true; + if (!is_physical_pointer_to_buffer_block(type)) + { + // Only forward-declare if we intend to emit it in the non_block_pointer types. + // Otherwise, these are just "benign" pointer types that exist as a result of access chains. 
+ emit_type = std::find(physical_storage_non_block_pointer_types.begin(), + physical_storage_non_block_pointer_types.end(), + id) != physical_storage_non_block_pointer_types.end(); + } + + if (emit_type) + { + emit_buffer_reference_block(id, true); + emitted = true; + } + } + }); + } + + if (emitted) + statement(""); + emitted = false; + // If emitted Vulkan GLSL, // emit specialization constants as actual floats, // spec op expressions will redirect to the constant name. @@ -3747,30 +3777,10 @@ void CompilerGLSL::emit_resources() emitted = false; - if (ir.addressing_model == AddressingModelPhysicalStorageBuffer64EXT) + if (ir.addressing_model == AddressingModelPhysicalStorageBuffer64) { // Output buffer reference blocks. - // Do this in two stages, one with forward declaration, - // and one without. Buffer reference blocks can reference themselves - // to support things like linked lists. - ir.for_each_typed_id([&](uint32_t id, SPIRType &type) { - if (is_physical_pointer(type)) - { - bool emit_type = true; - if (!is_physical_pointer_to_buffer_block(type)) - { - // Only forward-declare if we intend to emit it in the non_block_pointer types. - // Otherwise, these are just "benign" pointer types that exist as a result of access chains. - emit_type = std::find(physical_storage_non_block_pointer_types.begin(), - physical_storage_non_block_pointer_types.end(), - id) != physical_storage_non_block_pointer_types.end(); - } - - if (emit_type) - emit_buffer_reference_block(id, true); - } - }); - + // Buffer reference blocks can reference themselves to support things like linked lists. for (auto type : physical_storage_non_block_pointer_types) emit_buffer_reference_block(type, false); @@ -10317,7 +10327,8 @@ string CompilerGLSL::access_chain_internal(uint32_t base, const uint32_t *indice if (!is_ptr_chain) mod_flags &= ~ACCESS_CHAIN_PTR_CHAIN_BIT; access_chain_internal_append_index(expr, base, type, mod_flags, access_chain_is_arrayed, index); - check_physical_type_cast(expr, type, physical_type); + if (check_physical_type_cast(expr, type, physical_type)) + physical_type = 0; }; for (uint32_t i = 0; i < count; i++) @@ -10825,8 +10836,9 @@ string CompilerGLSL::access_chain_internal(uint32_t base, const uint32_t *indice return expr; } -void CompilerGLSL::check_physical_type_cast(std::string &, const SPIRType *, uint32_t) +bool CompilerGLSL::check_physical_type_cast(std::string &, const SPIRType *, uint32_t) { + return false; } bool CompilerGLSL::prepare_access_chain_for_scalar_access(std::string &, const SPIRType &, spv::StorageClass, bool &) @@ -15337,8 +15349,8 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction) case OpConvertUToPtr: { auto &type = get(ops[0]); - if (type.storage != StorageClassPhysicalStorageBufferEXT) - SPIRV_CROSS_THROW("Only StorageClassPhysicalStorageBufferEXT is supported by OpConvertUToPtr."); + if (type.storage != StorageClassPhysicalStorageBuffer) + SPIRV_CROSS_THROW("Only StorageClassPhysicalStorageBuffer is supported by OpConvertUToPtr."); auto &in_type = expression_type(ops[2]); if (in_type.vecsize == 2) @@ -15353,8 +15365,8 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction) { auto &type = get(ops[0]); auto &ptr_type = expression_type(ops[2]); - if (ptr_type.storage != StorageClassPhysicalStorageBufferEXT) - SPIRV_CROSS_THROW("Only StorageClassPhysicalStorageBufferEXT is supported by OpConvertPtrToU."); + if (ptr_type.storage != StorageClassPhysicalStorageBuffer) + SPIRV_CROSS_THROW("Only StorageClassPhysicalStorageBuffer is supported 
by OpConvertPtrToU."); if (type.vecsize == 2) require_extension_internal("GL_EXT_buffer_reference_uvec2"); @@ -16143,7 +16155,7 @@ string CompilerGLSL::to_array_size(const SPIRType &type, uint32_t index) string CompilerGLSL::type_to_array_glsl(const SPIRType &type, uint32_t) { - if (type.pointer && type.storage == StorageClassPhysicalStorageBufferEXT && type.basetype != SPIRType::Struct) + if (type.pointer && type.storage == StorageClassPhysicalStorageBuffer && type.basetype != SPIRType::Struct) { // We are using a wrapped pointer type, and we should not emit any array declarations here. return ""; @@ -16856,6 +16868,7 @@ void CompilerGLSL::emit_function(SPIRFunction &func, const Bitset &return_flags) { // Recursively emit functions which are called. uint32_t id = ops[2]; + emit_function(get(id), ir.meta[ops[1]].decoration.decoration_flags); } } diff --git a/3rdparty/spirv-cross/spirv_glsl.hpp b/3rdparty/spirv-cross/spirv_glsl.hpp index 04cbf1b02..cea150f22 100644 --- a/3rdparty/spirv-cross/spirv_glsl.hpp +++ b/3rdparty/spirv-cross/spirv_glsl.hpp @@ -769,7 +769,7 @@ protected: spv::StorageClass get_expression_effective_storage_class(uint32_t ptr); virtual bool access_chain_needs_stage_io_builtin_translation(uint32_t base); - virtual void check_physical_type_cast(std::string &expr, const SPIRType *type, uint32_t physical_type); + virtual bool check_physical_type_cast(std::string &expr, const SPIRType *type, uint32_t physical_type); virtual bool prepare_access_chain_for_scalar_access(std::string &expr, const SPIRType &type, spv::StorageClass storage, bool &is_packed); @@ -799,7 +799,7 @@ protected: std::string declare_temporary(uint32_t type, uint32_t id); void emit_uninitialized_temporary(uint32_t type, uint32_t id); SPIRExpression &emit_uninitialized_temporary_expression(uint32_t type, uint32_t id); - void append_global_func_args(const SPIRFunction &func, uint32_t index, SmallVector &arglist); + virtual void append_global_func_args(const SPIRFunction &func, uint32_t index, SmallVector &arglist); std::string to_non_uniform_aware_expression(uint32_t id); std::string to_atomic_ptr_expression(uint32_t id); std::string to_expression(uint32_t id, bool register_expression_read = true); diff --git a/3rdparty/spirv-cross/spirv_hlsl.cpp b/3rdparty/spirv-cross/spirv_hlsl.cpp index e9869e576..1ec4cb70f 100644 --- a/3rdparty/spirv-cross/spirv_hlsl.cpp +++ b/3rdparty/spirv-cross/spirv_hlsl.cpp @@ -1117,7 +1117,9 @@ void CompilerHLSL::emit_interface_block_in_struct(const SPIRVariable &var, unord else { auto decl_type = type; - if (execution.model == ExecutionModelMeshEXT || has_decoration(var.self, DecorationPerVertexKHR)) + if (execution.model == ExecutionModelMeshEXT || + (execution.model == ExecutionModelGeometry && var.storage == StorageClassInput) || + has_decoration(var.self, DecorationPerVertexKHR)) { decl_type.array.erase(decl_type.array.begin()); decl_type.array_size_literal.erase(decl_type.array_size_literal.begin()); @@ -1834,7 +1836,7 @@ void CompilerHLSL::emit_resources() if (!output_variables.empty() || !active_output_builtins.empty()) { sort(output_variables.begin(), output_variables.end(), variable_compare); - require_output = !is_mesh_shader; + require_output = !(is_mesh_shader || execution.model == ExecutionModelGeometry); statement(is_mesh_shader ? 
"struct gl_MeshPerVertexEXT" : "struct SPIRV_Cross_Output"); begin_scope(); @@ -2678,6 +2680,83 @@ void CompilerHLSL::emit_mesh_tasks(SPIRBlock &block) } } +void CompilerHLSL::emit_geometry_stream_append() +{ + begin_scope(); + statement("SPIRV_Cross_Output stage_output;"); + + active_output_builtins.for_each_bit( + [&](uint32_t i) + { + if (i == BuiltInPointSize && hlsl_options.shader_model > 30) + return; + switch (static_cast(i)) + { + case BuiltInClipDistance: + for (uint32_t clip = 0; clip < clip_distance_count; clip++) + statement("stage_output.gl_ClipDistance", clip / 4, ".", "xyzw"[clip & 3], " = gl_ClipDistance[", + clip, "];"); + break; + case BuiltInCullDistance: + for (uint32_t cull = 0; cull < cull_distance_count; cull++) + statement("stage_output.gl_CullDistance", cull / 4, ".", "xyzw"[cull & 3], " = gl_CullDistance[", + cull, "];"); + break; + case BuiltInSampleMask: + statement("stage_output.gl_SampleMask = gl_SampleMask[0];"); + break; + default: + { + auto builtin_expr = builtin_to_glsl(static_cast(i), StorageClassOutput); + statement("stage_output.", builtin_expr, " = ", builtin_expr, ";"); + } + break; + } + }); + + ir.for_each_typed_id( + [&](uint32_t, SPIRVariable &var) + { + auto &type = this->get(var.basetype); + bool block = has_decoration(type.self, DecorationBlock); + + if (var.storage != StorageClassOutput) + return; + + if (!var.remapped_variable && type.pointer && !is_builtin_variable(var) && + interface_variable_exists_in_entry_point(var.self)) + { + if (block) + { + auto type_name = to_name(type.self); + auto var_name = to_name(var.self); + for (uint32_t mbr_idx = 0; mbr_idx < uint32_t(type.member_types.size()); mbr_idx++) + { + auto mbr_name = to_member_name(type, mbr_idx); + auto flat_name = join(type_name, "_", mbr_name); + statement("stage_output.", flat_name, " = ", var_name, ".", mbr_name, ";"); + } + } + else + { + auto name = to_name(var.self); + if (hlsl_options.shader_model <= 30 && get_entry_point().model == ExecutionModelFragment) + { + string output_filler; + for (uint32_t size = type.vecsize; size < 4; ++size) + output_filler += ", 0.0"; + statement("stage_output.", name, " = float4(", name, output_filler, ");"); + } + else + statement("stage_output.", name, " = ", name, ";"); + } + } + }); + + statement("geometry_stream.Append(stage_output);"); + end_scope(); +} + void CompilerHLSL::emit_buffer_block(const SPIRVariable &var) { auto &type = get(var.basetype); @@ -2940,6 +3019,8 @@ string CompilerHLSL::get_inner_entry_point_name() const return "frag_main"; else if (execution.model == ExecutionModelGLCompute) return "comp_main"; + else if (execution.model == ExecutionModelGeometry) + return "geom_main"; else if (execution.model == ExecutionModelMeshEXT) return "mesh_main"; else if (execution.model == ExecutionModelTaskEXT) @@ -2948,6 +3029,25 @@ string CompilerHLSL::get_inner_entry_point_name() const SPIRV_CROSS_THROW("Unsupported execution model."); } +uint32_t CompilerHLSL::input_vertices_from_execution_mode(spirv_cross::SPIREntryPoint &execution) const +{ + uint32_t input_vertices = 1; + + if (execution.flags.get(ExecutionModeInputLines)) + input_vertices = 2; + else if (execution.flags.get(ExecutionModeInputLinesAdjacency)) + input_vertices = 4; + else if (execution.flags.get(ExecutionModeInputTrianglesAdjacency)) + input_vertices = 6; + else if (execution.flags.get(ExecutionModeTriangles)) + input_vertices = 3; + else if (execution.flags.get(ExecutionModeInputPoints)) + input_vertices = 1; + else + SPIRV_CROSS_THROW("Unsupported execution 
model."); + return input_vertices; +} + void CompilerHLSL::emit_function_prototype(SPIRFunction &func, const Bitset &return_flags) { if (func.self != ir.default_entry_point) @@ -3041,6 +3141,38 @@ void CompilerHLSL::emit_function_prototype(SPIRFunction &func, const Bitset &ret var->parameter = &arg; } + if ((func.self == ir.default_entry_point || func.emits_geometry) && + get_entry_point().model == ExecutionModelGeometry) + { + auto &execution = get_entry_point(); + + uint32_t input_vertices = input_vertices_from_execution_mode(execution); + + const char *prim; + if (execution.flags.get(ExecutionModeInputLinesAdjacency)) + prim = "lineadj"; + else if (execution.flags.get(ExecutionModeInputLines)) + prim = "line"; + else if (execution.flags.get(ExecutionModeInputTrianglesAdjacency)) + prim = "triangleadj"; + else if (execution.flags.get(ExecutionModeTriangles)) + prim = "triangle"; + else + prim = "point"; + + const char *stream_type; + if (execution.flags.get(ExecutionModeOutputPoints)) + stream_type = "PointStream"; + else if (execution.flags.get(ExecutionModeOutputLineStrip)) + stream_type = "LineStream"; + else + stream_type = "TriangleStream"; + + if (func.self == ir.default_entry_point) + arglist.push_back(join(prim, " SPIRV_Cross_Input stage_input[", input_vertices, "]")); + arglist.push_back(join("inout ", stream_type, " ", "geometry_stream")); + } + decl += merge(arglist); decl += ")"; statement(decl); @@ -3050,13 +3182,50 @@ void CompilerHLSL::emit_hlsl_entry_point() { SmallVector arguments; - if (require_input) + if (require_input && get_entry_point().model != ExecutionModelGeometry) arguments.push_back("SPIRV_Cross_Input stage_input"); auto &execution = get_entry_point(); + uint32_t input_vertices = 1; + switch (execution.model) { + case ExecutionModelGeometry: + { + input_vertices = input_vertices_from_execution_mode(execution); + + string prim; + if (execution.flags.get(ExecutionModeInputLinesAdjacency)) + prim = "lineadj"; + else if (execution.flags.get(ExecutionModeInputLines)) + prim = "line"; + else if (execution.flags.get(ExecutionModeInputTrianglesAdjacency)) + prim = "triangleadj"; + else if (execution.flags.get(ExecutionModeTriangles)) + prim = "triangle"; + else + prim = "point"; + + string stream_type; + if (execution.flags.get(ExecutionModeOutputPoints)) + { + stream_type = "PointStream"; + } + else if (execution.flags.get(ExecutionModeOutputLineStrip)) + { + stream_type = "LineStream"; + } + else + { + stream_type = "TriangleStream"; + } + + statement("[maxvertexcount(", execution.output_vertices, ")]"); + arguments.push_back(join(prim, " SPIRV_Cross_Input stage_input[", input_vertices, "]")); + arguments.push_back(join("inout ", stream_type, " ", "geometry_stream")); + break; + } case ExecutionModelTaskEXT: case ExecutionModelMeshEXT: case ExecutionModelGLCompute: @@ -3359,18 +3528,24 @@ void CompilerHLSL::emit_hlsl_entry_point() } else { - statement(name, " = stage_input.", name, ";"); + if (execution.model == ExecutionModelGeometry) + { + statement("for (int i = 0; i < ", input_vertices, "; i++)"); + begin_scope(); + statement(name, "[i] = stage_input[i].", name, ";"); + end_scope(); + } + else + statement(name, " = stage_input.", name, ";"); } } } }); // Run the shader. 
- if (execution.model == ExecutionModelVertex || - execution.model == ExecutionModelFragment || - execution.model == ExecutionModelGLCompute || - execution.model == ExecutionModelMeshEXT || - execution.model == ExecutionModelTaskEXT) + if (execution.model == ExecutionModelVertex || execution.model == ExecutionModelFragment || + execution.model == ExecutionModelGLCompute || execution.model == ExecutionModelMeshEXT || + execution.model == ExecutionModelGeometry || execution.model == ExecutionModelTaskEXT) { // For mesh shaders, we receive special arguments that we must pass down as function arguments. // HLSL does not support proper reference types for passing these IO blocks, @@ -3378,8 +3553,16 @@ void CompilerHLSL::emit_hlsl_entry_point() SmallVector arglist; auto &func = get(ir.default_entry_point); // The arguments are marked out, avoid detecting reads and emitting inout. + for (auto &arg : func.arguments) arglist.push_back(to_expression(arg.id, false)); + + if (execution.model == ExecutionModelGeometry) + { + arglist.push_back("stage_input"); + arglist.push_back("geometry_stream"); + } + statement(get_inner_entry_point_name(), "(", merge(arglist), ");"); } else @@ -4206,6 +4389,14 @@ bool CompilerHLSL::emit_complex_bitcast(uint32_t, uint32_t, uint32_t) return false; } +void CompilerHLSL::append_global_func_args(const SPIRFunction &func, uint32_t index, SmallVector &arglist) +{ + CompilerGLSL::append_global_func_args(func, index, arglist); + + if (func.emits_geometry) + arglist.push_back("geometry_stream"); +} + string CompilerHLSL::bitcast_glsl_op(const SPIRType &out_type, const SPIRType &in_type) { if (out_type.basetype == SPIRType::UInt && in_type.basetype == SPIRType::Int) @@ -6594,6 +6785,16 @@ void CompilerHLSL::emit_instruction(const Instruction &instruction) statement("SetMeshOutputCounts(", to_unpacked_expression(ops[0]), ", ", to_unpacked_expression(ops[1]), ");"); break; } + case OpEmitVertex: + { + emit_geometry_stream_append(); + break; + } + case OpEndPrimitive: + { + statement("geometry_stream.RestartStrip();"); + break; + } default: CompilerGLSL::emit_instruction(instruction); break; @@ -6812,6 +7013,9 @@ string CompilerHLSL::compile() if (get_execution_model() == ExecutionModelMeshEXT) analyze_meshlet_writes(); + if (get_execution_model() == ExecutionModelGeometry) + discover_geometry_emitters(); + // Subpass input needs SV_Position. 
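With ExecutionModelGeometry handled end to end above (a geom_main inner entry point, the [maxvertexcount] attribute, primitive-qualified stage_input arrays, and an inout Point/Line/TriangleStream passed to every function that emits), the backend is driven from C++ the same way as any other stage. A minimal sketch under stated assumptions: the helper name is illustrative, spirv holds a valid geometry-shader module, and shader model 4.0 or newer is selected, since stream-output objects do not exist below SM 4.0.

#include <spirv_hlsl.hpp>
#include <string>
#include <vector>

std::string geometry_spirv_to_hlsl(std::vector<uint32_t> spirv)
{
	spirv_cross::CompilerHLSL hlsl(std::move(spirv));

	spirv_cross::CompilerHLSL::Options opts;
	opts.shader_model = 50; // geometry shaders need SM 4.0+, 5.0 shown here
	hlsl.set_hlsl_options(opts);

	// In the output, OpEmitVertex becomes geometry_stream.Append(stage_output)
	// and OpEndPrimitive becomes geometry_stream.RestartStrip().
	return hlsl.compile();
}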
if (need_subpass_input) active_input_builtins.set(BuiltInFragCoord); diff --git a/3rdparty/spirv-cross/spirv_hlsl.hpp b/3rdparty/spirv-cross/spirv_hlsl.hpp index 3dc89cc68..4303bb7d5 100644 --- a/3rdparty/spirv-cross/spirv_hlsl.hpp +++ b/3rdparty/spirv-cross/spirv_hlsl.hpp @@ -231,6 +231,7 @@ private: std::string image_type_hlsl(const SPIRType &type, uint32_t id); std::string image_type_hlsl_modern(const SPIRType &type, uint32_t id); std::string image_type_hlsl_legacy(const SPIRType &type, uint32_t id); + uint32_t input_vertices_from_execution_mode(SPIREntryPoint &execution) const; void emit_function_prototype(SPIRFunction &func, const Bitset &return_flags) override; void emit_hlsl_entry_point(); void emit_header() override; @@ -259,6 +260,8 @@ private: std::string to_interpolation_qualifiers(const Bitset &flags) override; std::string bitcast_glsl_op(const SPIRType &result_type, const SPIRType &argument_type) override; bool emit_complex_bitcast(uint32_t result_type, uint32_t id, uint32_t op0) override; + void append_global_func_args(const SPIRFunction &func, uint32_t index, SmallVector &arglist) override; + std::string to_func_call_arg(const SPIRFunction::Parameter &arg, uint32_t id) override; std::string to_sampler_expression(uint32_t id); std::string to_resource_binding(const SPIRVariable &var); @@ -286,6 +289,7 @@ private: uint32_t base_offset = 0) override; void emit_rayquery_function(const char *commited, const char *candidate, const uint32_t *ops); void emit_mesh_tasks(SPIRBlock &block) override; + void emit_geometry_stream_append(); const char *to_storage_qualifiers_glsl(const SPIRVariable &var) override; void replace_illegal_names() override; diff --git a/3rdparty/spirv-cross/spirv_msl.cpp b/3rdparty/spirv-cross/spirv_msl.cpp index 16024b281..53f74f177 100644 --- a/3rdparty/spirv-cross/spirv_msl.cpp +++ b/3rdparty/spirv-cross/spirv_msl.cpp @@ -2222,6 +2222,27 @@ void CompilerMSL::extract_global_variables_from_function(uint32_t func_id, std:: break; } + case OpGroupNonUniformFAdd: + case OpGroupNonUniformFMul: + case OpGroupNonUniformFMin: + case OpGroupNonUniformFMax: + case OpGroupNonUniformIAdd: + case OpGroupNonUniformIMul: + case OpGroupNonUniformSMin: + case OpGroupNonUniformSMax: + case OpGroupNonUniformUMin: + case OpGroupNonUniformUMax: + case OpGroupNonUniformBitwiseAnd: + case OpGroupNonUniformBitwiseOr: + case OpGroupNonUniformBitwiseXor: + case OpGroupNonUniformLogicalAnd: + case OpGroupNonUniformLogicalOr: + case OpGroupNonUniformLogicalXor: + if ((get_execution_model() != ExecutionModelFragment || msl_options.supports_msl_version(2, 2)) && + ops[3] == GroupOperationClusteredReduce) + added_arg_ids.insert(builtin_subgroup_invocation_id_id); + break; + case OpDemoteToHelperInvocation: if (needs_manual_helper_invocation_updates() && needs_helper_invocation) added_arg_ids.insert(builtin_helper_invocation_id); @@ -7026,6 +7047,105 @@ void CompilerMSL::emit_custom_functions() statement(""); break; + // C++ disallows partial specializations of function templates, + // hence the use of a struct. 
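The FUNC_SUBGROUP_CLUSTERED macro below is dense, so here is the recursion it encodes, modelled as plain host-side C++ rather than MSL (illustrative only): the reduction of an N-lane cluster is built from two N/2-lane reductions, with the sibling half reached by a shuffle-xor of N/2. The real emitted code additionally uses the native quad_* reductions as the width-4 base case and returns the operation's identity when the shuffled-from lane is inactive.

#include <cstdint>
#include <vector>

// Host-side model of the butterfly implemented by the spvClustered*Detail structs.
// simd_shuffle_xor(value, offset) is modelled as indexing lanes[lid ^ offset];
// the quad_* fast path and inactive-lane handling are omitted for brevity.
static uint32_t clustered_add(const std::vector<uint32_t> &lanes, uint32_t lid, uint32_t cluster_size)
{
	if (cluster_size == 1)
		return lanes[lid]; // a single-lane cluster reduces to the lane's own value

	uint32_t half = cluster_size / 2;
	// Reduce this lane's half-cluster and the sibling half-cluster (lid ^ half),
	// then combine with the operation (+ here; min/max/&/|/^ for the other variants).
	return clustered_add(lanes, lid, half) + clustered_add(lanes, lid ^ half, half);
}

Every lane of a cluster arrives at the same value, which is what GroupOperationClusteredReduce requires.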
+ // clang-format off +#define FUNC_SUBGROUP_CLUSTERED(spv, msl, combine, op, ident) \ + case SPVFuncImplSubgroupClustered##spv: \ + statement("template"); \ + statement("struct spvClustered" #spv "Detail;"); \ + statement(""); \ + statement("// Base cases"); \ + statement("template<>"); \ + statement("struct spvClustered" #spv "Detail<1, 0>"); \ + begin_scope(); \ + statement("template"); \ + statement("static T op(T value, uint)"); \ + begin_scope(); \ + statement("return value;"); \ + end_scope(); \ + end_scope_decl(); \ + statement(""); \ + statement("template"); \ + statement("struct spvClustered" #spv "Detail<1, offset>"); \ + begin_scope(); \ + statement("template"); \ + statement("static T op(T value, uint lid)"); \ + begin_scope(); \ + statement("// If the target lane is inactive, then return identity."); \ + if (msl_options.use_quadgroup_operation()) \ + statement("if (!extract_bits((quad_vote::vote_t)quad_active_threads_mask(), (lid ^ offset), 1))"); \ + else \ + statement("if (!extract_bits(as_type((simd_vote::vote_t)simd_active_threads_mask())[(lid ^ offset) / 32], (lid ^ offset) % 32, 1))"); \ + statement(" return " #ident ";"); \ + if (msl_options.use_quadgroup_operation()) \ + statement("return quad_shuffle_xor(value, offset);"); \ + else \ + statement("return simd_shuffle_xor(value, offset);"); \ + end_scope(); \ + end_scope_decl(); \ + statement(""); \ + statement("template<>"); \ + statement("struct spvClustered" #spv "Detail<4, 0>"); \ + begin_scope(); \ + statement("template"); \ + statement("static T op(T value, uint)"); \ + begin_scope(); \ + statement("return quad_" #msl "(value);"); \ + end_scope(); \ + end_scope_decl(); \ + statement(""); \ + statement("template"); \ + statement("struct spvClustered" #spv "Detail<4, offset>"); \ + begin_scope(); \ + statement("template"); \ + statement("static T op(T value, uint lid)"); \ + begin_scope(); \ + statement("// Here, we care if any of the lanes in the quad are active."); \ + statement("uint quad_mask = extract_bits(as_type((simd_vote::vote_t)simd_active_threads_mask())[(lid ^ offset) / 32], ((lid ^ offset) % 32) & ~3, 4);"); \ + statement("if (!quad_mask)"); \ + statement(" return " #ident ";"); \ + statement("// But we need to make sure we shuffle from an active lane."); \ + if (msl_options.use_quadgroup_operation()) \ + SPIRV_CROSS_THROW("Subgroup size with quadgroup operation cannot exceed 4."); \ + else \ + statement("return simd_shuffle(quad_" #msl "(value), ((lid ^ offset) & ~3) | ctz(quad_mask));"); \ + end_scope(); \ + end_scope_decl(); \ + statement(""); \ + statement("// General case"); \ + statement("template"); \ + statement("struct spvClustered" #spv "Detail"); \ + begin_scope(); \ + statement("template"); \ + statement("static T op(T value, uint lid)"); \ + begin_scope(); \ + statement("return " combine(msl, op, "spvClustered" #spv "Detail::op(value, lid)", "spvClustered" #spv "Detail::op(value, lid)") ";"); \ + end_scope(); \ + end_scope_decl(); \ + statement(""); \ + statement("template"); \ + statement("T spvClustered_" #msl "(T value, uint lid)"); \ + begin_scope(); \ + statement("return spvClustered" #spv "Detail::op(value, lid);"); \ + end_scope(); \ + statement(""); \ + break +#define BINOP(msl, op, l, r) l " " #op " " r +#define BINFUNC(msl, op, l, r) #msl "(" l ", " r ")" + + FUNC_SUBGROUP_CLUSTERED(Add, sum, BINOP, +, 0); + FUNC_SUBGROUP_CLUSTERED(Mul, product, BINOP, *, 1); + FUNC_SUBGROUP_CLUSTERED(Min, min, BINFUNC, , numeric_limits::max()); + FUNC_SUBGROUP_CLUSTERED(Max, max, BINFUNC, , 
numeric_limits::min()); + FUNC_SUBGROUP_CLUSTERED(And, and, BINOP, &, ~T(0)); + FUNC_SUBGROUP_CLUSTERED(Or, or, BINOP, |, 0); + FUNC_SUBGROUP_CLUSTERED(Xor, xor, BINOP, ^, 0); + // clang-format on +#undef FUNC_SUBGROUP_CLUSTERED +#undef BINOP +#undef BINFUNC + case SPVFuncImplQuadBroadcast: statement("template"); statement("inline T spvQuadBroadcast(T value, uint lane)"); @@ -9126,7 +9246,7 @@ void CompilerMSL::fix_up_interpolant_access_chain(const uint32_t *ops, uint32_t // If the physical type of a physical buffer pointer has been changed // to a ulong or ulongn vector, add a cast back to the pointer type. -void CompilerMSL::check_physical_type_cast(std::string &expr, const SPIRType *type, uint32_t physical_type) +bool CompilerMSL::check_physical_type_cast(std::string &expr, const SPIRType *type, uint32_t physical_type) { auto *p_physical_type = maybe_get(physical_type); if (p_physical_type && @@ -9137,7 +9257,10 @@ void CompilerMSL::check_physical_type_cast(std::string &expr, const SPIRType *ty expr += ".x"; expr = join("((", type_to_glsl(*type), ")", expr, ")"); + return true; } + + return false; } // Override for MSL-specific syntax instructions @@ -9840,9 +9963,9 @@ void CompilerMSL::emit_instruction(const Instruction &instruction) case OpControlBarrier: // In GLSL a memory barrier is often followed by a control barrier. - // But in MSL, memory barriers are also control barriers, so don't + // But in MSL, memory barriers are also control barriers (before MSL 3.2), so don't // emit a simple control barrier if a memory barrier has just been emitted. - if (previous_instruction_opcode != OpMemoryBarrier) + if (previous_instruction_opcode != OpMemoryBarrier || msl_options.supports_msl_version(3, 2)) emit_barrier(ops[0], ops[1], ops[2]); break; @@ -10441,10 +10564,20 @@ void CompilerMSL::emit_barrier(uint32_t id_exe_scope, uint32_t id_mem_scope, uin return; string bar_stmt; - if ((msl_options.is_ios() && msl_options.supports_msl_version(1, 2)) || msl_options.supports_msl_version(2)) - bar_stmt = exe_scope < ScopeSubgroup ? "threadgroup_barrier" : "simdgroup_barrier"; + + if (!id_exe_scope && msl_options.supports_msl_version(3, 2)) + { + // Just took 10 years to get a proper barrier, but hey! + bar_stmt = "atomic_thread_fence"; + } else - bar_stmt = "threadgroup_barrier"; + { + if ((msl_options.is_ios() && msl_options.supports_msl_version(1, 2)) || msl_options.supports_msl_version(2)) + bar_stmt = exe_scope < ScopeSubgroup ? "threadgroup_barrier" : "simdgroup_barrier"; + else + bar_stmt = "threadgroup_barrier"; + } + bar_stmt += "("; uint32_t mem_sem = id_mem_sem ? evaluate_constant_u32(id_mem_sem) : uint32_t(MemorySemanticsMaskNone); @@ -10452,7 +10585,8 @@ void CompilerMSL::emit_barrier(uint32_t id_exe_scope, uint32_t id_mem_scope, uin // Use the | operator to combine flags if we can. if (msl_options.supports_msl_version(1, 2)) { - string mem_flags = ""; + string mem_flags; + // For tesc shaders, this also affects objects in the Output storage class. // Since in Metal, these are placed in a device buffer, we have to sync device memory here. if (is_tesc_shader() || @@ -10493,6 +10627,55 @@ void CompilerMSL::emit_barrier(uint32_t id_exe_scope, uint32_t id_mem_scope, uin bar_stmt += "mem_flags::mem_none"; } + if (!id_exe_scope && msl_options.supports_msl_version(3, 2)) + { + // If there's no device-related memory in the barrier, demote to workgroup scope. + // glslang seems to emit device scope even for memoryBarrierShared(). 
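For concreteness on the demotion above: memoryBarrierShared() in GLSL typically reaches SPIRV-Cross as OpMemoryBarrier with Device scope but only WorkgroupMemory semantics (observed glslang behavior, not a guarantee), so the device-visible memory classes checked below are absent and the fence can safely be narrowed to threadgroup scope. A small compile-time sketch of that predicate, with the SPIR-V mask values written out as assumptions to verify against spirv.hpp:

#include <cstdint>

// MemorySemantics bits from the SPIR-V spec (subset).
constexpr uint32_t kAcquireRelease = 0x8;
constexpr uint32_t kUniformMemory = 0x40;
constexpr uint32_t kWorkgroupMemory = 0x100;
constexpr uint32_t kCrossWorkgroupMemory = 0x200;
constexpr uint32_t kImageMemory = 0x800;

// Typical semantics for memoryBarrierShared(): WorkgroupMemory | AcquireRelease.
constexpr uint32_t kSharedBarrierSemantics = kWorkgroupMemory | kAcquireRelease;
static_assert((kSharedBarrierSemantics & (kUniformMemory | kImageMemory | kCrossWorkgroupMemory)) == 0,
              "no device-visible memory classes, so the fence demotes to threadgroup scope");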
+ if (mem_scope == ScopeDevice && + (mem_sem & (MemorySemanticsUniformMemoryMask | + MemorySemanticsImageMemoryMask | + MemorySemanticsCrossWorkgroupMemoryMask)) == 0) + { + mem_scope = ScopeWorkgroup; + } + + // MSL 3.2 only supports seq_cst or relaxed. + if (mem_sem & (MemorySemanticsAcquireReleaseMask | + MemorySemanticsAcquireMask | + MemorySemanticsReleaseMask | + MemorySemanticsSequentiallyConsistentMask)) + { + bar_stmt += ", memory_order_seq_cst"; + } + else + { + bar_stmt += ", memory_order_relaxed"; + } + + switch (mem_scope) + { + case ScopeDevice: + bar_stmt += ", thread_scope_device"; + break; + + case ScopeWorkgroup: + bar_stmt += ", thread_scope_threadgroup"; + break; + + case ScopeSubgroup: + bar_stmt += ", thread_scope_subgroup"; + break; + + case ScopeInvocation: + bar_stmt += ", thread_scope_thread"; + break; + + default: + // The default argument is device, which is conservative. + break; + } + } + bar_stmt += ");"; statement(bar_stmt); @@ -13663,9 +13846,17 @@ string CompilerMSL::get_argument_address_space(const SPIRVariable &argument) return get_type_address_space(type, argument.self, true); } -bool CompilerMSL::decoration_flags_signal_volatile(const Bitset &flags) +bool CompilerMSL::decoration_flags_signal_volatile(const Bitset &flags) const { - return flags.get(DecorationVolatile) || flags.get(DecorationCoherent); + // Using volatile for coherent pre-3.2 is definitely not correct, but it's something. + // MSL 3.2 adds actual coherent qualifiers. + return flags.get(DecorationVolatile) || + (flags.get(DecorationCoherent) && !msl_options.supports_msl_version(3, 2)); +} + +bool CompilerMSL::decoration_flags_signal_coherent(const Bitset &flags) const +{ + return flags.get(DecorationCoherent) && msl_options.supports_msl_version(3, 2); } string CompilerMSL::get_type_address_space(const SPIRType &type, uint32_t id, bool argument) @@ -13677,8 +13868,17 @@ string CompilerMSL::get_type_address_space(const SPIRType &type, uint32_t id, bo (has_decoration(type.self, DecorationBlock) || has_decoration(type.self, DecorationBufferBlock))) flags = get_buffer_block_flags(id); else + { flags = get_decoration_bitset(id); + if (type.basetype == SPIRType::Struct && + (has_decoration(type.self, DecorationBlock) || + has_decoration(type.self, DecorationBufferBlock))) + { + flags.merge_or(ir.get_buffer_block_type_flags(type)); + } + } + const char *addr_space = nullptr; switch (type.storage) { @@ -13687,7 +13887,6 @@ string CompilerMSL::get_type_address_space(const SPIRType &type, uint32_t id, bo break; case StorageClassStorageBuffer: - case StorageClassPhysicalStorageBuffer: { // For arguments from variable pointers, we use the write count deduction, so // we should not assume any constness here. Only for global SSBOs. @@ -13695,10 +13894,19 @@ string CompilerMSL::get_type_address_space(const SPIRType &type, uint32_t id, bo if (!var || has_decoration(type.self, DecorationBlock)) readonly = flags.get(DecorationNonWritable); + if (decoration_flags_signal_coherent(flags)) + readonly = false; + addr_space = readonly ? "const device" : "device"; break; } + case StorageClassPhysicalStorageBuffer: + // We cannot fully trust NonWritable coming from glslang due to a bug in buffer_reference handling. + // There isn't much gain in emitting const in C++ languages anyway. 
+ addr_space = "device"; + break; + case StorageClassUniform: case StorageClassUniformConstant: case StorageClassPushConstant: @@ -13787,7 +13995,9 @@ string CompilerMSL::get_type_address_space(const SPIRType &type, uint32_t id, bo addr_space = type.pointer || (argument && type.basetype == SPIRType::ControlPointArray) ? "thread" : ""; } - if (decoration_flags_signal_volatile(flags) && 0 != strcmp(addr_space, "thread")) + if (decoration_flags_signal_coherent(flags) && strcmp(addr_space, "device") == 0) + return join("coherent device"); + else if (decoration_flags_signal_volatile(flags) && strcmp(addr_space, "thread") != 0) return join("volatile ", addr_space); else return addr_space; @@ -15411,7 +15621,8 @@ string CompilerMSL::argument_decl(const SPIRFunction::Parameter &arg) bool constref = !arg.alias_global_variable && !passed_by_value && is_pointer(var_type) && arg.write_count == 0; // Framebuffer fetch is plain value, const looks out of place, but it is not wrong. - if (type_is_msl_framebuffer_fetch(type)) + // readonly coming from glslang is not reliable in all cases. + if (type_is_msl_framebuffer_fetch(type) || type_storage == StorageClassPhysicalStorageBuffer) constref = false; else if (type_storage == StorageClassUniformConstant) constref = true; @@ -16639,6 +16850,10 @@ string CompilerMSL::image_type_glsl(const SPIRType &type, uint32_t id, bool memb // Otherwise it may be set based on whether the image is read from or written to within the shader. if (type.basetype == SPIRType::Image && type.image.sampled == 2 && type.image.dim != DimSubpassData) { + auto *p_var = maybe_get_backing_variable(id); + if (p_var && p_var->basevariable) + p_var = maybe_get(p_var->basevariable); + switch (img_type.access) { case AccessQualifierReadOnly: @@ -16655,9 +16870,6 @@ string CompilerMSL::image_type_glsl(const SPIRType &type, uint32_t id, bool memb default: { - auto *p_var = maybe_get_backing_variable(id); - if (p_var && p_var->basevariable) - p_var = maybe_get(p_var->basevariable); if (p_var && !has_decoration(p_var->self, DecorationNonWritable)) { img_type_name += ", access::"; @@ -16670,6 +16882,9 @@ string CompilerMSL::image_type_glsl(const SPIRType &type, uint32_t id, bool memb break; } } + + if (p_var && has_decoration(p_var->self, DecorationCoherent) && msl_options.supports_msl_version(3, 2)) + img_type_name += ", memory_coherence_device"; } img_type_name += ">"; @@ -16924,11 +17139,10 @@ case OpGroupNonUniform##op: \ emit_unary_func_op(result_type, id, ops[op_idx], "simd_prefix_exclusive_" #msl_op); \ else if (operation == GroupOperationClusteredReduce) \ { \ - /* Only cluster sizes of 4 are supported. */ \ uint32_t cluster_size = evaluate_constant_u32(ops[op_idx + 1]); \ - if (cluster_size != 4) \ - SPIRV_CROSS_THROW("Metal only supports quad ClusteredReduce."); \ - emit_unary_func_op(result_type, id, ops[op_idx], "quad_" #msl_op); \ + if (get_execution_model() != ExecutionModelFragment || msl_options.supports_msl_version(2, 2)) \ + add_spv_func_and_recompile(SPVFuncImplSubgroupClustered##op); \ + emit_subgroup_cluster_op(result_type, id, cluster_size, ops[op_idx], #msl_op); \ } \ else \ SPIRV_CROSS_THROW("Invalid group operation."); \ @@ -16953,11 +17167,10 @@ case OpGroupNonUniform##op: \ SPIRV_CROSS_THROW("Metal doesn't support ExclusiveScan for OpGroupNonUniform" #op "."); \ else if (operation == GroupOperationClusteredReduce) \ { \ - /* Only cluster sizes of 4 are supported. 
*/ \ uint32_t cluster_size = evaluate_constant_u32(ops[op_idx + 1]); \ - if (cluster_size != 4) \ - SPIRV_CROSS_THROW("Metal only supports quad ClusteredReduce."); \ - emit_unary_func_op(result_type, id, ops[op_idx], "quad_" #msl_op); \ + if (get_execution_model() != ExecutionModelFragment || msl_options.supports_msl_version(2, 2)) \ + add_spv_func_and_recompile(SPVFuncImplSubgroupClustered##op); \ + emit_subgroup_cluster_op(result_type, id, cluster_size, ops[op_idx], #msl_op); \ } \ else \ SPIRV_CROSS_THROW("Invalid group operation."); \ @@ -16976,11 +17189,10 @@ case OpGroupNonUniform##op: \ SPIRV_CROSS_THROW("Metal doesn't support ExclusiveScan for OpGroupNonUniform" #op "."); \ else if (operation == GroupOperationClusteredReduce) \ { \ - /* Only cluster sizes of 4 are supported. */ \ uint32_t cluster_size = evaluate_constant_u32(ops[op_idx + 1]); \ - if (cluster_size != 4) \ - SPIRV_CROSS_THROW("Metal only supports quad ClusteredReduce."); \ - emit_unary_func_op_cast(result_type, id, ops[op_idx], "quad_" #msl_op, type, type); \ + if (get_execution_model() != ExecutionModelFragment || msl_options.supports_msl_version(2, 2)) \ + add_spv_func_and_recompile(SPVFuncImplSubgroupClustered##op); \ + emit_subgroup_cluster_op_cast(result_type, id, cluster_size, ops[op_idx], #msl_op, type, type); \ } \ else \ SPIRV_CROSS_THROW("Invalid group operation."); \ @@ -16996,9 +17208,11 @@ case OpGroupNonUniform##op: \ MSL_GROUP_OP(BitwiseAnd, and) MSL_GROUP_OP(BitwiseOr, or) MSL_GROUP_OP(BitwiseXor, xor) - MSL_GROUP_OP(LogicalAnd, and) - MSL_GROUP_OP(LogicalOr, or) - MSL_GROUP_OP(LogicalXor, xor) + // Metal doesn't support boolean types in SIMD-group operations, so we + // have to emit some casts. + MSL_GROUP_OP_CAST(LogicalAnd, and, SPIRType::UShort) + MSL_GROUP_OP_CAST(LogicalOr, or, SPIRType::UShort) + MSL_GROUP_OP_CAST(LogicalXor, xor, SPIRType::UShort) // clang-format on #undef MSL_GROUP_OP #undef MSL_GROUP_OP_CAST @@ -17026,6 +17240,83 @@ case OpGroupNonUniform##op: \ register_control_dependent_expression(id); } +void CompilerMSL::emit_subgroup_cluster_op(uint32_t result_type, uint32_t result_id, uint32_t cluster_size, + uint32_t op0, const char *op) +{ + if (get_execution_model() == ExecutionModelFragment && !msl_options.supports_msl_version(2, 2)) + { + if (cluster_size == 4) + { + emit_unary_func_op(result_type, result_id, op0, join("quad_", op).c_str()); + return; + } + SPIRV_CROSS_THROW("Cluster sizes other than 4 in fragment shaders require MSL 2.2."); + } + bool forward = should_forward(op0); + emit_op(result_type, result_id, + join("spvClustered_", op, "<", cluster_size, ">(", to_unpacked_expression(op0), ", ", + to_expression(builtin_subgroup_invocation_id_id), ")"), + forward); + inherit_expression_dependencies(result_id, op0); +} + +void CompilerMSL::emit_subgroup_cluster_op_cast(uint32_t result_type, uint32_t result_id, uint32_t cluster_size, + uint32_t op0, const char *op, SPIRType::BaseType input_type, + SPIRType::BaseType expected_result_type) +{ + if (get_execution_model() == ExecutionModelFragment && !msl_options.supports_msl_version(2, 2)) + { + if (cluster_size == 4) + { + emit_unary_func_op_cast(result_type, result_id, op0, join("quad_", op).c_str(), input_type, + expected_result_type); + return; + } + SPIRV_CROSS_THROW("Cluster sizes other than 4 in fragment shaders require MSL 2.2."); + } + + auto &out_type = get(result_type); + auto &expr_type = expression_type(op0); + auto expected_type = out_type; + + // Bit-widths might be different in unary cases because we use it for 
SConvert/UConvert and friends. + expected_type.basetype = input_type; + expected_type.width = expr_type.width; + + string cast_op; + if (expr_type.basetype != input_type) + { + if (expr_type.basetype == SPIRType::Boolean) + cast_op = join(type_to_glsl(expected_type), "(", to_unpacked_expression(op0), ")"); + else + cast_op = bitcast_glsl(expected_type, op0); + } + else + cast_op = to_unpacked_expression(op0); + + string sg_op = join("spvClustered_", op, "<", cluster_size, ">"); + string expr; + if (out_type.basetype != expected_result_type) + { + expected_type.basetype = expected_result_type; + expected_type.width = out_type.width; + if (out_type.basetype == SPIRType::Boolean) + expr = type_to_glsl(out_type); + else + expr = bitcast_glsl_op(out_type, expected_type); + expr += '('; + expr += join(sg_op, "(", cast_op, ", ", to_expression(builtin_subgroup_invocation_id_id), ")"); + expr += ')'; + } + else + { + expr += join(sg_op, "(", cast_op, ", ", to_expression(builtin_subgroup_invocation_id_id), ")"); + } + + emit_op(result_type, result_id, expr, should_forward(op0)); + inherit_expression_dependencies(result_id, op0); +} + string CompilerMSL::bitcast_glsl_op(const SPIRType &out_type, const SPIRType &in_type) { if (out_type.basetype == in_type.basetype) @@ -18097,6 +18388,28 @@ bool CompilerMSL::OpCodePreprocessor::handle(Op opcode, const uint32_t *args, ui } break; + case OpGroupNonUniformFAdd: + case OpGroupNonUniformFMul: + case OpGroupNonUniformFMin: + case OpGroupNonUniformFMax: + case OpGroupNonUniformIAdd: + case OpGroupNonUniformIMul: + case OpGroupNonUniformSMin: + case OpGroupNonUniformSMax: + case OpGroupNonUniformUMin: + case OpGroupNonUniformUMax: + case OpGroupNonUniformBitwiseAnd: + case OpGroupNonUniformBitwiseOr: + case OpGroupNonUniformBitwiseXor: + case OpGroupNonUniformLogicalAnd: + case OpGroupNonUniformLogicalOr: + case OpGroupNonUniformLogicalXor: + if ((compiler.get_execution_model() != ExecutionModelFragment || + compiler.msl_options.supports_msl_version(2, 2)) && + args[3] == GroupOperationClusteredReduce) + needs_subgroup_invocation_id = true; + break; + case OpArrayLength: { auto *var = compiler.maybe_get_backing_variable(args[2]); diff --git a/3rdparty/spirv-cross/spirv_msl.hpp b/3rdparty/spirv-cross/spirv_msl.hpp index 7c7a364f8..a3d08bfcc 100644 --- a/3rdparty/spirv-cross/spirv_msl.hpp +++ b/3rdparty/spirv-cross/spirv_msl.hpp @@ -818,6 +818,29 @@ protected: SPVFuncImplSubgroupShuffleUp, SPVFuncImplSubgroupShuffleDown, SPVFuncImplSubgroupRotate, + SPVFuncImplSubgroupClusteredAdd, + SPVFuncImplSubgroupClusteredFAdd = SPVFuncImplSubgroupClusteredAdd, + SPVFuncImplSubgroupClusteredIAdd = SPVFuncImplSubgroupClusteredAdd, + SPVFuncImplSubgroupClusteredMul, + SPVFuncImplSubgroupClusteredFMul = SPVFuncImplSubgroupClusteredMul, + SPVFuncImplSubgroupClusteredIMul = SPVFuncImplSubgroupClusteredMul, + SPVFuncImplSubgroupClusteredMin, + SPVFuncImplSubgroupClusteredFMin = SPVFuncImplSubgroupClusteredMin, + SPVFuncImplSubgroupClusteredSMin = SPVFuncImplSubgroupClusteredMin, + SPVFuncImplSubgroupClusteredUMin = SPVFuncImplSubgroupClusteredMin, + SPVFuncImplSubgroupClusteredMax, + SPVFuncImplSubgroupClusteredFMax = SPVFuncImplSubgroupClusteredMax, + SPVFuncImplSubgroupClusteredSMax = SPVFuncImplSubgroupClusteredMax, + SPVFuncImplSubgroupClusteredUMax = SPVFuncImplSubgroupClusteredMax, + SPVFuncImplSubgroupClusteredAnd, + SPVFuncImplSubgroupClusteredBitwiseAnd = SPVFuncImplSubgroupClusteredAnd, + SPVFuncImplSubgroupClusteredLogicalAnd = SPVFuncImplSubgroupClusteredAnd, + 
SPVFuncImplSubgroupClusteredOr, + SPVFuncImplSubgroupClusteredBitwiseOr = SPVFuncImplSubgroupClusteredOr, + SPVFuncImplSubgroupClusteredLogicalOr = SPVFuncImplSubgroupClusteredOr, + SPVFuncImplSubgroupClusteredXor, + SPVFuncImplSubgroupClusteredBitwiseXor = SPVFuncImplSubgroupClusteredXor, + SPVFuncImplSubgroupClusteredLogicalXor = SPVFuncImplSubgroupClusteredXor, SPVFuncImplQuadBroadcast, SPVFuncImplQuadSwap, SPVFuncImplReflectScalar, @@ -871,6 +894,11 @@ protected: void emit_function_prototype(SPIRFunction &func, const Bitset &return_flags) override; void emit_sampled_image_op(uint32_t result_type, uint32_t result_id, uint32_t image_id, uint32_t samp_id) override; void emit_subgroup_op(const Instruction &i) override; + void emit_subgroup_cluster_op(uint32_t result_type, uint32_t result_id, uint32_t cluster_size, uint32_t op0, + const char *op); + void emit_subgroup_cluster_op_cast(uint32_t result_type, uint32_t result_id, uint32_t cluster_size, uint32_t op0, + const char *op, SPIRType::BaseType input_type, + SPIRType::BaseType expected_result_type); std::string to_texture_op(const Instruction &i, bool sparse, bool *forward, SmallVector &inherited_expressions) override; void emit_fixup() override; @@ -1084,7 +1112,8 @@ protected: bool validate_member_packing_rules_msl(const SPIRType &type, uint32_t index) const; std::string get_argument_address_space(const SPIRVariable &argument); std::string get_type_address_space(const SPIRType &type, uint32_t id, bool argument = false); - static bool decoration_flags_signal_volatile(const Bitset &flags); + bool decoration_flags_signal_volatile(const Bitset &flags) const; + bool decoration_flags_signal_coherent(const Bitset &flags) const; const char *to_restrict(uint32_t id, bool space); SPIRType &get_stage_in_struct_type(); SPIRType &get_stage_out_struct_type(); @@ -1154,7 +1183,7 @@ protected: bool prepare_access_chain_for_scalar_access(std::string &expr, const SPIRType &type, spv::StorageClass storage, bool &is_packed) override; void fix_up_interpolant_access_chain(const uint32_t *ops, uint32_t length); - void check_physical_type_cast(std::string &expr, const SPIRType *type, uint32_t physical_type) override; + bool check_physical_type_cast(std::string &expr, const SPIRType *type, uint32_t physical_type) override; bool emit_tessellation_access_chain(const uint32_t *ops, uint32_t length); bool emit_tessellation_io_load(uint32_t result_type, uint32_t id, uint32_t ptr);
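Most of the new MSL behavior in this patch is opt-in through the declared MSL version: the coherent qualifiers and the atomic_thread_fence-based memory barriers require 3.2, and clustered reductions in fragment shaders require 2.2 (other stages are not version-gated by this change). A minimal sketch of opting in, assuming spirv holds a valid module; the helper name is illustrative.

#include <spirv_msl.hpp>
#include <string>
#include <vector>

std::string spirv_to_msl_3_2(std::vector<uint32_t> spirv)
{
	spirv_cross::CompilerMSL msl(std::move(spirv));

	spirv_cross::CompilerMSL::Options opts;
	opts.set_msl_version(3, 2); // enables coherent device / memory_coherence_device and plain fences
	msl.set_msl_options(opts);

	// subgroupClusteredAdd() and friends are no longer limited to a cluster size
	// of 4, except in fragment shaders below MSL 2.2.
	return msl.compile();
}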