Updated spirv-tools.

This commit is contained in:
Бранимир Караџић
2019-06-06 22:02:38 -07:00
parent 6eb0c7e224
commit 77823a3ff9
180 changed files with 16816 additions and 25402 deletions

View File

@@ -59,11 +59,11 @@ build:
build_script:
- mkdir build && cd build
- cmake -GNinja -DSPIRV_BUILD_COMPRESSION=ON -DCMAKE_BUILD_TYPE=%CONFIGURATION% -DCMAKE_INSTALL_PREFIX=install -DRE2_BUILD_TESTING=OFF ..
- cmake -GNinja -DCMAKE_BUILD_TYPE=%CONFIGURATION% -DCMAKE_INSTALL_PREFIX=install -DRE2_BUILD_TESTING=OFF ..
- ninja install
test_script:
- ctest -C %CONFIGURATION% --output-on-failure --timeout 300
- ctest -C %CONFIGURATION% --output-on-failure --timeout 310
after_test:
# Zip build artifacts for uploading and deploying

View File

@@ -9,6 +9,7 @@ compile_commands.json
/external/spirv-headers
/external/effcee
/external/re2
/external/protobuf
/out
/TAGS
/third_party/llvm-build/

View File

@@ -13,7 +13,6 @@ SPVTOOLS_SRC_FILES := \
source/ext_inst.cpp \
source/enum_string_mapping.cpp \
source/extensions.cpp \
source/id_descriptor.cpp \
source/libspirv.cpp \
source/name_mapper.cpp \
source/opcode.cpp \
@@ -63,6 +62,7 @@ SPVTOOLS_SRC_FILES := \
source/val/validate_instruction.cpp \
source/val/validate_memory.cpp \
source/val/validate_memory_semantics.cpp \
source/val/validate_misc.cpp \
source/val/validate_mode_setting.cpp \
source/val/validate_layout.cpp \
source/val/validate_literals.cpp \

View File

@@ -422,6 +422,7 @@ static_library("spvtools_val") {
"source/val/validate_logicals.cpp",
"source/val/validate_memory.cpp",
"source/val/validate_memory_semantics.cpp",
"source/val/validate_misc.cpp",
"source/val/validate_mode_setting.cpp",
"source/val/validate_non_uniform.cpp",
"source/val/validate_primitives.cpp",

View File

@@ -69,6 +69,10 @@ if(NOT ${SKIP_SPIRV_TOOLS_INSTALL})
endif()
option(SPIRV_BUILD_COMPRESSION "Build SPIR-V compressing codec" OFF)
if(SPIRV_BUILD_COMPRESSION)
message(FATAL_ERROR "SPIR-V compression codec has been removed from SPIR-V tools. "
"Please remove SPIRV_BUILD_COMPRESSION from your build options.")
endif(SPIRV_BUILD_COMPRESSION)
option(SPIRV_WERROR "Enable error on warning" ON)
if(("${CMAKE_CXX_COMPILER_ID}" MATCHES "GNU") OR (("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang") AND (NOT CMAKE_CXX_SIMULATE_ID STREQUAL "MSVC")))
@@ -257,9 +261,6 @@ endif()
set(SPIRV_LIBRARIES "-lSPIRV-Tools -lSPIRV-Tools-link -lSPIRV-Tools-opt")
set(SPIRV_SHARED_LIBRARIES "-lSPIRV-Tools-shared")
if(SPIRV_BUILD_COMPRESSION)
set(SPIRV_LIBRARIES "${SPIRV_LIBRARIES} -lSPIRV-Tools-comp")
endif(SPIRV_BUILD_COMPRESSION)
# Build pkg-config file
# Use a first-class target so it's regenerated when relevant files are updated.

View File

@@ -11,7 +11,7 @@ vars = {
'googletest_revision': '98a0d007d7092b72eea0e501bb9ad17908a1a036',
'testing_revision': '340252637e2e7c72c0901dcbeeacfff419e19b59',
're2_revision': '6cf8ccd82dbaab2668e9b13596c68183c9ecd13f',
'spirv_headers_revision': 'e74c389f81915d0a48d6df1af83c3862c5ad85ab',
'spirv_headers_revision': '8b911bd2ba37677037b38c9bd286c7c05701bcda',
}
deps = {

View File

@@ -307,8 +307,6 @@ The following CMake options are supported:
the command line tools. This will prevent the tests from being built.
* `SPIRV_SKIP_EXECUTABLES={ON|OFF}`, default `OFF`- Build only the library, not
the command line tools and tests.
* `SPIRV_BUILD_COMPRESSION={ON|OFF}`, default `OFF`- Build SPIR-V compressing
codec.
* `SPIRV_USE_SANITIZER=<sanitizer>`, default is no sanitizing - On UNIX
platforms with an appropriate version of `clang` this option enables the use
of the sanitizers documented [here][clang-sanitizers].

View File

@@ -102,3 +102,15 @@ if (NOT ${SPIRV_SKIP_TESTS})
endif()
endif()
endif()
if(SPIRV_BUILD_FUZZER)
set(PROTOBUF_DIR ${CMAKE_CURRENT_SOURCE_DIR}/protobuf/cmake)
set(protobuf_BUILD_TESTS OFF CACHE BOOL "Disable protobuf tests")
set(protobuf_MSVC_STATIC_RUNTIME OFF CACHE BOOL "Do not build protobuf static runtime")
if (IS_DIRECTORY ${PROTOBUF_DIR})
add_subdirectory(${PROTOBUF_DIR} EXCLUDE_FROM_ALL)
else()
message(FATAL_ERROR
"protobuf not found - please checkout a copy under external/.")
endif()
endif(SPIRV_BUILD_FUZZER)

View File

@@ -1 +1 @@
"v2019.4-dev", "SPIRV-Tools v2019.4-dev v2019.3-6-g47741f0"
"v2019.4-dev", "SPIRV-Tools v2019.4-dev v2019.3-41-ga8ae579f"

View File

@@ -6,6 +6,7 @@ static const SpvCapability pygen_variable_caps_CooperativeMatrixNV[] = {SpvCapab
static const SpvCapability pygen_variable_caps_DerivativeControl[] = {SpvCapabilityDerivativeControl};
static const SpvCapability pygen_variable_caps_DeviceEnqueue[] = {SpvCapabilityDeviceEnqueue};
static const SpvCapability pygen_variable_caps_FragmentMaskAMD[] = {SpvCapabilityFragmentMaskAMD};
static const SpvCapability pygen_variable_caps_FragmentShaderSampleInterlockEXTFragmentShaderPixelInterlockEXTFragmentShaderShadingRateInterlockEXT[] = {SpvCapabilityFragmentShaderSampleInterlockEXT, SpvCapabilityFragmentShaderPixelInterlockEXT, SpvCapabilityFragmentShaderShadingRateInterlockEXT};
static const SpvCapability pygen_variable_caps_Geometry[] = {SpvCapabilityGeometry};
static const SpvCapability pygen_variable_caps_GeometryStreams[] = {SpvCapabilityGeometryStreams};
static const SpvCapability pygen_variable_caps_GroupNonUniform[] = {SpvCapabilityGroupNonUniform};
@@ -44,6 +45,7 @@ static const SpvCapability pygen_variable_caps_SubgroupVoteKHR[] = {SpvCapabilit
static const spvtools::Extension pygen_variable_exts_SPV_AMD_shader_ballot[] = {spvtools::Extension::kSPV_AMD_shader_ballot};
static const spvtools::Extension pygen_variable_exts_SPV_AMD_shader_fragment_mask[] = {spvtools::Extension::kSPV_AMD_shader_fragment_mask};
static const spvtools::Extension pygen_variable_exts_SPV_EXT_fragment_shader_interlock[] = {spvtools::Extension::kSPV_EXT_fragment_shader_interlock};
static const spvtools::Extension pygen_variable_exts_SPV_GOOGLE_decorate_stringSPV_GOOGLE_hlsl_functionality1[] = {spvtools::Extension::kSPV_GOOGLE_decorate_string, spvtools::Extension::kSPV_GOOGLE_hlsl_functionality1};
static const spvtools::Extension pygen_variable_exts_SPV_GOOGLE_hlsl_functionality1[] = {spvtools::Extension::kSPV_GOOGLE_hlsl_functionality1};
static const spvtools::Extension pygen_variable_exts_SPV_KHR_shader_ballot[] = {spvtools::Extension::kSPV_KHR_shader_ballot};
@@ -429,6 +431,8 @@ static const spv_opcode_desc_t kOpcodeTableEntries[] = {
{"CooperativeMatrixStoreNV", SpvOpCooperativeMatrixStoreNV, 1, pygen_variable_caps_CooperativeMatrixNV, 5, {SPV_OPERAND_TYPE_ID, SPV_OPERAND_TYPE_ID, SPV_OPERAND_TYPE_ID, SPV_OPERAND_TYPE_ID, SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS}, 0, 0, 1, pygen_variable_exts_SPV_NV_cooperative_matrix, 0xffffffffu, 0xffffffffu},
{"CooperativeMatrixMulAddNV", SpvOpCooperativeMatrixMulAddNV, 1, pygen_variable_caps_CooperativeMatrixNV, 5, {SPV_OPERAND_TYPE_TYPE_ID, SPV_OPERAND_TYPE_RESULT_ID, SPV_OPERAND_TYPE_ID, SPV_OPERAND_TYPE_ID, SPV_OPERAND_TYPE_ID}, 1, 1, 1, pygen_variable_exts_SPV_NV_cooperative_matrix, 0xffffffffu, 0xffffffffu},
{"CooperativeMatrixLengthNV", SpvOpCooperativeMatrixLengthNV, 1, pygen_variable_caps_CooperativeMatrixNV, 3, {SPV_OPERAND_TYPE_TYPE_ID, SPV_OPERAND_TYPE_RESULT_ID, SPV_OPERAND_TYPE_ID}, 1, 1, 1, pygen_variable_exts_SPV_NV_cooperative_matrix, 0xffffffffu, 0xffffffffu},
{"BeginInvocationInterlockEXT", SpvOpBeginInvocationInterlockEXT, 3, pygen_variable_caps_FragmentShaderSampleInterlockEXTFragmentShaderPixelInterlockEXTFragmentShaderShadingRateInterlockEXT, 0, {}, 0, 0, 1, pygen_variable_exts_SPV_EXT_fragment_shader_interlock, 0xffffffffu, 0xffffffffu},
{"EndInvocationInterlockEXT", SpvOpEndInvocationInterlockEXT, 3, pygen_variable_caps_FragmentShaderSampleInterlockEXTFragmentShaderPixelInterlockEXTFragmentShaderShadingRateInterlockEXT, 0, {}, 0, 0, 1, pygen_variable_exts_SPV_EXT_fragment_shader_interlock, 0xffffffffu, 0xffffffffu},
{"SubgroupShuffleINTEL", SpvOpSubgroupShuffleINTEL, 1, pygen_variable_caps_SubgroupShuffleINTEL, 4, {SPV_OPERAND_TYPE_TYPE_ID, SPV_OPERAND_TYPE_RESULT_ID, SPV_OPERAND_TYPE_ID, SPV_OPERAND_TYPE_ID}, 1, 1, 0, nullptr, 0xffffffffu, 0xffffffffu},
{"SubgroupShuffleDownINTEL", SpvOpSubgroupShuffleDownINTEL, 1, pygen_variable_caps_SubgroupShuffleINTEL, 5, {SPV_OPERAND_TYPE_TYPE_ID, SPV_OPERAND_TYPE_RESULT_ID, SPV_OPERAND_TYPE_ID, SPV_OPERAND_TYPE_ID, SPV_OPERAND_TYPE_ID}, 1, 1, 0, nullptr, 0xffffffffu, 0xffffffffu},
{"SubgroupShuffleUpINTEL", SpvOpSubgroupShuffleUpINTEL, 1, pygen_variable_caps_SubgroupShuffleINTEL, 5, {SPV_OPERAND_TYPE_TYPE_ID, SPV_OPERAND_TYPE_RESULT_ID, SPV_OPERAND_TYPE_ID, SPV_OPERAND_TYPE_ID, SPV_OPERAND_TYPE_ID}, 1, 1, 0, nullptr, 0xffffffffu, 0xffffffffu},

View File

@@ -26,6 +26,8 @@ const char* ExtensionToString(Extension extension) {
return "SPV_EXT_fragment_fully_covered";
case Extension::kSPV_EXT_fragment_invocation_density:
return "SPV_EXT_fragment_invocation_density";
case Extension::kSPV_EXT_fragment_shader_interlock:
return "SPV_EXT_fragment_shader_interlock";
case Extension::kSPV_EXT_physical_storage_buffer:
return "SPV_EXT_physical_storage_buffer";
case Extension::kSPV_EXT_shader_stencil_export:
@@ -90,6 +92,8 @@ const char* ExtensionToString(Extension extension) {
return "SPV_NV_sample_mask_override_coverage";
case Extension::kSPV_NV_shader_image_footprint:
return "SPV_NV_shader_image_footprint";
case Extension::kSPV_NV_shader_sm_builtins:
return "SPV_NV_shader_sm_builtins";
case Extension::kSPV_NV_shader_subgroup_partitioned:
return "SPV_NV_shader_subgroup_partitioned";
case Extension::kSPV_NV_shading_rate:
@@ -107,8 +111,8 @@ const char* ExtensionToString(Extension extension) {
bool GetExtensionFromString(const char* str, Extension* extension) {
static const char* known_ext_strs[] = { "SPV_AMD_gcn_shader", "SPV_AMD_gpu_shader_half_float", "SPV_AMD_gpu_shader_half_float_fetch", "SPV_AMD_gpu_shader_int16", "SPV_AMD_shader_ballot", "SPV_AMD_shader_explicit_vertex_parameter", "SPV_AMD_shader_fragment_mask", "SPV_AMD_shader_image_load_store_lod", "SPV_AMD_shader_trinary_minmax", "SPV_AMD_texture_gather_bias_lod", "SPV_EXT_descriptor_indexing", "SPV_EXT_fragment_fully_covered", "SPV_EXT_fragment_invocation_density", "SPV_EXT_physical_storage_buffer", "SPV_EXT_shader_stencil_export", "SPV_EXT_shader_viewport_index_layer", "SPV_GOOGLE_decorate_string", "SPV_GOOGLE_hlsl_functionality1", "SPV_INTEL_device_side_avc_motion_estimation", "SPV_INTEL_media_block_io", "SPV_INTEL_shader_integer_functions2", "SPV_INTEL_subgroups", "SPV_KHR_16bit_storage", "SPV_KHR_8bit_storage", "SPV_KHR_device_group", "SPV_KHR_float_controls", "SPV_KHR_multiview", "SPV_KHR_no_integer_wrap_decoration", "SPV_KHR_post_depth_coverage", "SPV_KHR_shader_atomic_counter_ops", "SPV_KHR_shader_ballot", "SPV_KHR_shader_draw_parameters", "SPV_KHR_storage_buffer_storage_class", "SPV_KHR_subgroup_vote", "SPV_KHR_variable_pointers", "SPV_KHR_vulkan_memory_model", "SPV_NVX_multiview_per_view_attributes", "SPV_NV_compute_shader_derivatives", "SPV_NV_cooperative_matrix", "SPV_NV_fragment_shader_barycentric", "SPV_NV_geometry_shader_passthrough", "SPV_NV_mesh_shader", "SPV_NV_ray_tracing", "SPV_NV_sample_mask_override_coverage", "SPV_NV_shader_image_footprint", "SPV_NV_shader_subgroup_partitioned", "SPV_NV_shading_rate", "SPV_NV_stereo_view_rendering", "SPV_NV_viewport_array2", "SPV_VALIDATOR_ignore_type_decl_unique" };
static const Extension known_ext_ids[] = { Extension::kSPV_AMD_gcn_shader, Extension::kSPV_AMD_gpu_shader_half_float, Extension::kSPV_AMD_gpu_shader_half_float_fetch, Extension::kSPV_AMD_gpu_shader_int16, Extension::kSPV_AMD_shader_ballot, Extension::kSPV_AMD_shader_explicit_vertex_parameter, Extension::kSPV_AMD_shader_fragment_mask, Extension::kSPV_AMD_shader_image_load_store_lod, Extension::kSPV_AMD_shader_trinary_minmax, Extension::kSPV_AMD_texture_gather_bias_lod, Extension::kSPV_EXT_descriptor_indexing, Extension::kSPV_EXT_fragment_fully_covered, Extension::kSPV_EXT_fragment_invocation_density, Extension::kSPV_EXT_physical_storage_buffer, Extension::kSPV_EXT_shader_stencil_export, Extension::kSPV_EXT_shader_viewport_index_layer, Extension::kSPV_GOOGLE_decorate_string, Extension::kSPV_GOOGLE_hlsl_functionality1, Extension::kSPV_INTEL_device_side_avc_motion_estimation, Extension::kSPV_INTEL_media_block_io, Extension::kSPV_INTEL_shader_integer_functions2, Extension::kSPV_INTEL_subgroups, Extension::kSPV_KHR_16bit_storage, Extension::kSPV_KHR_8bit_storage, Extension::kSPV_KHR_device_group, Extension::kSPV_KHR_float_controls, Extension::kSPV_KHR_multiview, Extension::kSPV_KHR_no_integer_wrap_decoration, Extension::kSPV_KHR_post_depth_coverage, Extension::kSPV_KHR_shader_atomic_counter_ops, Extension::kSPV_KHR_shader_ballot, Extension::kSPV_KHR_shader_draw_parameters, Extension::kSPV_KHR_storage_buffer_storage_class, Extension::kSPV_KHR_subgroup_vote, Extension::kSPV_KHR_variable_pointers, Extension::kSPV_KHR_vulkan_memory_model, Extension::kSPV_NVX_multiview_per_view_attributes, Extension::kSPV_NV_compute_shader_derivatives, Extension::kSPV_NV_cooperative_matrix, Extension::kSPV_NV_fragment_shader_barycentric, Extension::kSPV_NV_geometry_shader_passthrough, Extension::kSPV_NV_mesh_shader, Extension::kSPV_NV_ray_tracing, Extension::kSPV_NV_sample_mask_override_coverage, Extension::kSPV_NV_shader_image_footprint, Extension::kSPV_NV_shader_subgroup_partitioned, Extension::kSPV_NV_shading_rate, Extension::kSPV_NV_stereo_view_rendering, Extension::kSPV_NV_viewport_array2, Extension::kSPV_VALIDATOR_ignore_type_decl_unique };
static const char* known_ext_strs[] = { "SPV_AMD_gcn_shader", "SPV_AMD_gpu_shader_half_float", "SPV_AMD_gpu_shader_half_float_fetch", "SPV_AMD_gpu_shader_int16", "SPV_AMD_shader_ballot", "SPV_AMD_shader_explicit_vertex_parameter", "SPV_AMD_shader_fragment_mask", "SPV_AMD_shader_image_load_store_lod", "SPV_AMD_shader_trinary_minmax", "SPV_AMD_texture_gather_bias_lod", "SPV_EXT_descriptor_indexing", "SPV_EXT_fragment_fully_covered", "SPV_EXT_fragment_invocation_density", "SPV_EXT_fragment_shader_interlock", "SPV_EXT_physical_storage_buffer", "SPV_EXT_shader_stencil_export", "SPV_EXT_shader_viewport_index_layer", "SPV_GOOGLE_decorate_string", "SPV_GOOGLE_hlsl_functionality1", "SPV_INTEL_device_side_avc_motion_estimation", "SPV_INTEL_media_block_io", "SPV_INTEL_shader_integer_functions2", "SPV_INTEL_subgroups", "SPV_KHR_16bit_storage", "SPV_KHR_8bit_storage", "SPV_KHR_device_group", "SPV_KHR_float_controls", "SPV_KHR_multiview", "SPV_KHR_no_integer_wrap_decoration", "SPV_KHR_post_depth_coverage", "SPV_KHR_shader_atomic_counter_ops", "SPV_KHR_shader_ballot", "SPV_KHR_shader_draw_parameters", "SPV_KHR_storage_buffer_storage_class", "SPV_KHR_subgroup_vote", "SPV_KHR_variable_pointers", "SPV_KHR_vulkan_memory_model", "SPV_NVX_multiview_per_view_attributes", "SPV_NV_compute_shader_derivatives", "SPV_NV_cooperative_matrix", "SPV_NV_fragment_shader_barycentric", "SPV_NV_geometry_shader_passthrough", "SPV_NV_mesh_shader", "SPV_NV_ray_tracing", "SPV_NV_sample_mask_override_coverage", "SPV_NV_shader_image_footprint", "SPV_NV_shader_sm_builtins", "SPV_NV_shader_subgroup_partitioned", "SPV_NV_shading_rate", "SPV_NV_stereo_view_rendering", "SPV_NV_viewport_array2", "SPV_VALIDATOR_ignore_type_decl_unique" };
static const Extension known_ext_ids[] = { Extension::kSPV_AMD_gcn_shader, Extension::kSPV_AMD_gpu_shader_half_float, Extension::kSPV_AMD_gpu_shader_half_float_fetch, Extension::kSPV_AMD_gpu_shader_int16, Extension::kSPV_AMD_shader_ballot, Extension::kSPV_AMD_shader_explicit_vertex_parameter, Extension::kSPV_AMD_shader_fragment_mask, Extension::kSPV_AMD_shader_image_load_store_lod, Extension::kSPV_AMD_shader_trinary_minmax, Extension::kSPV_AMD_texture_gather_bias_lod, Extension::kSPV_EXT_descriptor_indexing, Extension::kSPV_EXT_fragment_fully_covered, Extension::kSPV_EXT_fragment_invocation_density, Extension::kSPV_EXT_fragment_shader_interlock, Extension::kSPV_EXT_physical_storage_buffer, Extension::kSPV_EXT_shader_stencil_export, Extension::kSPV_EXT_shader_viewport_index_layer, Extension::kSPV_GOOGLE_decorate_string, Extension::kSPV_GOOGLE_hlsl_functionality1, Extension::kSPV_INTEL_device_side_avc_motion_estimation, Extension::kSPV_INTEL_media_block_io, Extension::kSPV_INTEL_shader_integer_functions2, Extension::kSPV_INTEL_subgroups, Extension::kSPV_KHR_16bit_storage, Extension::kSPV_KHR_8bit_storage, Extension::kSPV_KHR_device_group, Extension::kSPV_KHR_float_controls, Extension::kSPV_KHR_multiview, Extension::kSPV_KHR_no_integer_wrap_decoration, Extension::kSPV_KHR_post_depth_coverage, Extension::kSPV_KHR_shader_atomic_counter_ops, Extension::kSPV_KHR_shader_ballot, Extension::kSPV_KHR_shader_draw_parameters, Extension::kSPV_KHR_storage_buffer_storage_class, Extension::kSPV_KHR_subgroup_vote, Extension::kSPV_KHR_variable_pointers, Extension::kSPV_KHR_vulkan_memory_model, Extension::kSPV_NVX_multiview_per_view_attributes, Extension::kSPV_NV_compute_shader_derivatives, Extension::kSPV_NV_cooperative_matrix, Extension::kSPV_NV_fragment_shader_barycentric, Extension::kSPV_NV_geometry_shader_passthrough, Extension::kSPV_NV_mesh_shader, Extension::kSPV_NV_ray_tracing, Extension::kSPV_NV_sample_mask_override_coverage, Extension::kSPV_NV_shader_image_footprint, Extension::kSPV_NV_shader_sm_builtins, Extension::kSPV_NV_shader_subgroup_partitioned, Extension::kSPV_NV_shading_rate, Extension::kSPV_NV_stereo_view_rendering, Extension::kSPV_NV_viewport_array2, Extension::kSPV_VALIDATOR_ignore_type_decl_unique };
const auto b = std::begin(known_ext_strs);
const auto e = std::end(known_ext_strs);
const auto found = std::equal_range(
@@ -388,6 +392,14 @@ const char* CapabilityToString(SpvCapability capability) {
return "PhysicalStorageBufferAddressesEXT";
case SpvCapabilityCooperativeMatrixNV:
return "CooperativeMatrixNV";
case SpvCapabilityFragmentShaderSampleInterlockEXT:
return "FragmentShaderSampleInterlockEXT";
case SpvCapabilityFragmentShaderShadingRateInterlockEXT:
return "FragmentShaderShadingRateInterlockEXT";
case SpvCapabilityFragmentShaderPixelInterlockEXT:
return "FragmentShaderPixelInterlockEXT";
case SpvCapabilityShaderSMBuiltinsNV:
return "ShaderSMBuiltinsNV";
case SpvCapabilityMax:
assert(0 && "Attempting to convert SpvCapabilityMax to string");
return "";

View File

@@ -11,6 +11,7 @@ kSPV_AMD_texture_gather_bias_lod,
kSPV_EXT_descriptor_indexing,
kSPV_EXT_fragment_fully_covered,
kSPV_EXT_fragment_invocation_density,
kSPV_EXT_fragment_shader_interlock,
kSPV_EXT_physical_storage_buffer,
kSPV_EXT_shader_stencil_export,
kSPV_EXT_shader_viewport_index_layer,
@@ -43,6 +44,7 @@ kSPV_NV_mesh_shader,
kSPV_NV_ray_tracing,
kSPV_NV_sample_mask_override_coverage,
kSPV_NV_shader_image_footprint,
kSPV_NV_shader_sm_builtins,
kSPV_NV_shader_subgroup_partitioned,
kSPV_NV_shading_rate,
kSPV_NV_stereo_view_rendering,

View File

@@ -13,6 +13,9 @@ static const SpvCapability pygen_variable_caps_DrawParametersMeshShadingNV[] = {
static const SpvCapability pygen_variable_caps_FragmentBarycentricNV[] = {SpvCapabilityFragmentBarycentricNV};
static const SpvCapability pygen_variable_caps_FragmentDensityEXTShadingRateNV[] = {SpvCapabilityFragmentDensityEXT, SpvCapabilityShadingRateNV};
static const SpvCapability pygen_variable_caps_FragmentFullyCoveredEXT[] = {SpvCapabilityFragmentFullyCoveredEXT};
static const SpvCapability pygen_variable_caps_FragmentShaderPixelInterlockEXT[] = {SpvCapabilityFragmentShaderPixelInterlockEXT};
static const SpvCapability pygen_variable_caps_FragmentShaderSampleInterlockEXT[] = {SpvCapabilityFragmentShaderSampleInterlockEXT};
static const SpvCapability pygen_variable_caps_FragmentShaderShadingRateInterlockEXT[] = {SpvCapabilityFragmentShaderShadingRateInterlockEXT};
static const SpvCapability pygen_variable_caps_GenericPointer[] = {SpvCapabilityGenericPointer};
static const SpvCapability pygen_variable_caps_Geometry[] = {SpvCapabilityGeometry};
static const SpvCapability pygen_variable_caps_GeometryMeshShadingNV[] = {SpvCapabilityGeometry, SpvCapabilityMeshShadingNV};
@@ -63,6 +66,7 @@ static const SpvCapability pygen_variable_caps_ShaderImageCubeArray[] = {SpvCapa
static const SpvCapability pygen_variable_caps_ShaderKernel[] = {SpvCapabilityShader, SpvCapabilityKernel};
static const SpvCapability pygen_variable_caps_ShaderKernelImageMSArray[] = {SpvCapabilityShader, SpvCapabilityKernel, SpvCapabilityImageMSArray};
static const SpvCapability pygen_variable_caps_ShaderNonUniformEXT[] = {SpvCapabilityShaderNonUniformEXT};
static const SpvCapability pygen_variable_caps_ShaderSMBuiltinsNV[] = {SpvCapabilityShaderSMBuiltinsNV};
static const SpvCapability pygen_variable_caps_ShaderStereoViewNV[] = {SpvCapabilityShaderStereoViewNV};
static const SpvCapability pygen_variable_caps_ShaderViewportIndexLayerNV[] = {SpvCapabilityShaderViewportIndexLayerNV};
static const SpvCapability pygen_variable_caps_ShaderViewportMaskNV[] = {SpvCapabilityShaderViewportMaskNV};
@@ -89,6 +93,7 @@ static const spvtools::Extension pygen_variable_exts_SPV_AMD_texture_gather_bias
static const spvtools::Extension pygen_variable_exts_SPV_EXT_descriptor_indexing[] = {spvtools::Extension::kSPV_EXT_descriptor_indexing};
static const spvtools::Extension pygen_variable_exts_SPV_EXT_fragment_fully_covered[] = {spvtools::Extension::kSPV_EXT_fragment_fully_covered};
static const spvtools::Extension pygen_variable_exts_SPV_EXT_fragment_invocation_densitySPV_NV_shading_rate[] = {spvtools::Extension::kSPV_EXT_fragment_invocation_density, spvtools::Extension::kSPV_NV_shading_rate};
static const spvtools::Extension pygen_variable_exts_SPV_EXT_fragment_shader_interlock[] = {spvtools::Extension::kSPV_EXT_fragment_shader_interlock};
static const spvtools::Extension pygen_variable_exts_SPV_EXT_physical_storage_buffer[] = {spvtools::Extension::kSPV_EXT_physical_storage_buffer};
static const spvtools::Extension pygen_variable_exts_SPV_EXT_shader_stencil_export[] = {spvtools::Extension::kSPV_EXT_shader_stencil_export};
static const spvtools::Extension pygen_variable_exts_SPV_EXT_shader_viewport_index_layer[] = {spvtools::Extension::kSPV_EXT_shader_viewport_index_layer};
@@ -123,6 +128,7 @@ static const spvtools::Extension pygen_variable_exts_SPV_NV_mesh_shader[] = {spv
static const spvtools::Extension pygen_variable_exts_SPV_NV_ray_tracing[] = {spvtools::Extension::kSPV_NV_ray_tracing};
static const spvtools::Extension pygen_variable_exts_SPV_NV_sample_mask_override_coverage[] = {spvtools::Extension::kSPV_NV_sample_mask_override_coverage};
static const spvtools::Extension pygen_variable_exts_SPV_NV_shader_image_footprint[] = {spvtools::Extension::kSPV_NV_shader_image_footprint};
static const spvtools::Extension pygen_variable_exts_SPV_NV_shader_sm_builtins[] = {spvtools::Extension::kSPV_NV_shader_sm_builtins};
static const spvtools::Extension pygen_variable_exts_SPV_NV_shader_subgroup_partitioned[] = {spvtools::Extension::kSPV_NV_shader_subgroup_partitioned};
static const spvtools::Extension pygen_variable_exts_SPV_NV_shading_rateSPV_EXT_fragment_invocation_density[] = {spvtools::Extension::kSPV_NV_shading_rate, spvtools::Extension::kSPV_EXT_fragment_invocation_density};
static const spvtools::Extension pygen_variable_exts_SPV_NV_stereo_view_rendering[] = {spvtools::Extension::kSPV_NV_stereo_view_rendering};
@@ -307,7 +313,13 @@ static const spv_operand_desc_t pygen_variable_ExecutionModeEntries[] = {
{"OutputPrimitivesNV", 5270, 1, pygen_variable_caps_MeshShadingNV, 1, pygen_variable_exts_SPV_NV_mesh_shader, {SPV_OPERAND_TYPE_LITERAL_INTEGER}, 0xffffffffu, 0xffffffffu},
{"DerivativeGroupQuadsNV", 5289, 1, pygen_variable_caps_ComputeDerivativeGroupQuadsNV, 1, pygen_variable_exts_SPV_NV_compute_shader_derivatives, {}, 0xffffffffu, 0xffffffffu},
{"DerivativeGroupLinearNV", 5290, 1, pygen_variable_caps_ComputeDerivativeGroupLinearNV, 1, pygen_variable_exts_SPV_NV_compute_shader_derivatives, {}, 0xffffffffu, 0xffffffffu},
{"OutputTrianglesNV", 5298, 1, pygen_variable_caps_MeshShadingNV, 1, pygen_variable_exts_SPV_NV_mesh_shader, {}, 0xffffffffu, 0xffffffffu}
{"OutputTrianglesNV", 5298, 1, pygen_variable_caps_MeshShadingNV, 1, pygen_variable_exts_SPV_NV_mesh_shader, {}, 0xffffffffu, 0xffffffffu},
{"PixelInterlockOrderedEXT", 5366, 1, pygen_variable_caps_FragmentShaderPixelInterlockEXT, 1, pygen_variable_exts_SPV_EXT_fragment_shader_interlock, {}, 0xffffffffu, 0xffffffffu},
{"PixelInterlockUnorderedEXT", 5367, 1, pygen_variable_caps_FragmentShaderPixelInterlockEXT, 1, pygen_variable_exts_SPV_EXT_fragment_shader_interlock, {}, 0xffffffffu, 0xffffffffu},
{"SampleInterlockOrderedEXT", 5368, 1, pygen_variable_caps_FragmentShaderSampleInterlockEXT, 1, pygen_variable_exts_SPV_EXT_fragment_shader_interlock, {}, 0xffffffffu, 0xffffffffu},
{"SampleInterlockUnorderedEXT", 5369, 1, pygen_variable_caps_FragmentShaderSampleInterlockEXT, 1, pygen_variable_exts_SPV_EXT_fragment_shader_interlock, {}, 0xffffffffu, 0xffffffffu},
{"ShadingRateInterlockOrderedEXT", 5370, 1, pygen_variable_caps_FragmentShaderShadingRateInterlockEXT, 1, pygen_variable_exts_SPV_EXT_fragment_shader_interlock, {}, 0xffffffffu, 0xffffffffu},
{"ShadingRateInterlockUnorderedEXT", 5371, 1, pygen_variable_caps_FragmentShaderShadingRateInterlockEXT, 1, pygen_variable_exts_SPV_EXT_fragment_shader_interlock, {}, 0xffffffffu, 0xffffffffu}
};
static const spv_operand_desc_t pygen_variable_StorageClassEntries[] = {
@@ -637,7 +649,11 @@ static const spv_operand_desc_t pygen_variable_BuiltInEntries[] = {
{"WorldToObjectNV", 5331, 1, pygen_variable_caps_RayTracingNV, 1, pygen_variable_exts_SPV_NV_ray_tracing, {}, SPV_SPIRV_VERSION_WORD(1, 0), 0xffffffffu},
{"HitTNV", 5332, 1, pygen_variable_caps_RayTracingNV, 1, pygen_variable_exts_SPV_NV_ray_tracing, {}, SPV_SPIRV_VERSION_WORD(1, 0), 0xffffffffu},
{"HitKindNV", 5333, 1, pygen_variable_caps_RayTracingNV, 1, pygen_variable_exts_SPV_NV_ray_tracing, {}, SPV_SPIRV_VERSION_WORD(1, 0), 0xffffffffu},
{"IncomingRayFlagsNV", 5351, 1, pygen_variable_caps_RayTracingNV, 1, pygen_variable_exts_SPV_NV_ray_tracing, {}, SPV_SPIRV_VERSION_WORD(1, 0), 0xffffffffu}
{"IncomingRayFlagsNV", 5351, 1, pygen_variable_caps_RayTracingNV, 1, pygen_variable_exts_SPV_NV_ray_tracing, {}, SPV_SPIRV_VERSION_WORD(1, 0), 0xffffffffu},
{"WarpsPerSMNV", 5374, 1, pygen_variable_caps_ShaderSMBuiltinsNV, 1, pygen_variable_exts_SPV_NV_shader_sm_builtins, {}, SPV_SPIRV_VERSION_WORD(1, 0), 0xffffffffu},
{"SMCountNV", 5375, 1, pygen_variable_caps_ShaderSMBuiltinsNV, 1, pygen_variable_exts_SPV_NV_shader_sm_builtins, {}, SPV_SPIRV_VERSION_WORD(1, 0), 0xffffffffu},
{"WarpIDNV", 5376, 1, pygen_variable_caps_ShaderSMBuiltinsNV, 1, pygen_variable_exts_SPV_NV_shader_sm_builtins, {}, SPV_SPIRV_VERSION_WORD(1, 0), 0xffffffffu},
{"SMIDNV", 5377, 1, pygen_variable_caps_ShaderSMBuiltinsNV, 1, pygen_variable_exts_SPV_NV_shader_sm_builtins, {}, SPV_SPIRV_VERSION_WORD(1, 0), 0xffffffffu}
};
static const spv_operand_desc_t pygen_variable_ScopeEntries[] = {
@@ -794,6 +810,10 @@ static const spv_operand_desc_t pygen_variable_CapabilityEntries[] = {
{"PhysicalStorageBufferAddressesEXT", 5347, 1, pygen_variable_caps_Shader, 1, pygen_variable_exts_SPV_EXT_physical_storage_buffer, {}, 0xffffffffu, 0xffffffffu},
{"ComputeDerivativeGroupLinearNV", 5350, 0, nullptr, 1, pygen_variable_exts_SPV_NV_compute_shader_derivatives, {}, 0xffffffffu, 0xffffffffu},
{"CooperativeMatrixNV", 5357, 1, pygen_variable_caps_Shader, 1, pygen_variable_exts_SPV_NV_cooperative_matrix, {}, 0xffffffffu, 0xffffffffu},
{"FragmentShaderSampleInterlockEXT", 5363, 1, pygen_variable_caps_Shader, 1, pygen_variable_exts_SPV_EXT_fragment_shader_interlock, {}, 0xffffffffu, 0xffffffffu},
{"FragmentShaderShadingRateInterlockEXT", 5372, 1, pygen_variable_caps_Shader, 1, pygen_variable_exts_SPV_EXT_fragment_shader_interlock, {}, 0xffffffffu, 0xffffffffu},
{"ShaderSMBuiltinsNV", 5373, 1, pygen_variable_caps_Shader, 1, pygen_variable_exts_SPV_NV_shader_sm_builtins, {}, 0xffffffffu, 0xffffffffu},
{"FragmentShaderPixelInterlockEXT", 5378, 1, pygen_variable_caps_Shader, 1, pygen_variable_exts_SPV_EXT_fragment_shader_interlock, {}, 0xffffffffu, 0xffffffffu},
{"SubgroupShuffleINTEL", 5568, 0, nullptr, 1, pygen_variable_exts_SPV_INTEL_subgroups, {}, 0xffffffffu, 0xffffffffu},
{"SubgroupBufferBlockIOINTEL", 5569, 0, nullptr, 1, pygen_variable_exts_SPV_INTEL_subgroups, {}, 0xffffffffu, 0xffffffffu},
{"SubgroupImageBlockIOINTEL", 5570, 0, nullptr, 1, pygen_variable_exts_SPV_INTEL_subgroups, {}, 0xffffffffu, 0xffffffffu},

View File

@@ -34,6 +34,11 @@ namespace spvtools {
// generated by InstrumentPass::GenDebugStreamWrite. This method is utilized
// by InstBindlessCheckPass.
//
// kInst2* values support version 2 of the output record format. These should
// be used if available and version 2 is enabled. Version 1 is DEPRECATED.
// Specifically, version 1 uses two words for the stage-specific section of
// the output record; version 2 uses three words.
//
// The first member of the debug output buffer contains the next available word
// in the data stream to be written. Shaders will atomically read and update
// this value so as not to overwrite each others records. This value must be
@@ -70,38 +75,58 @@ static const int kInstCommonOutCnt = 4;
// Stage-specific Stream Record Offsets
//
// Each stage will contain different values in the next two words of the record
// used to identify which instantiation of the shader generated the validation
// error.
// Each stage will contain different values in the next set of words of the
// record used to identify which instantiation of the shader generated the
// validation error.
//
// Vertex Shader Output Record Offsets
static const int kInstVertOutVertexIndex = kInstCommonOutCnt;
static const int kInstVertOutInstanceIndex = kInstCommonOutCnt + 1;
static const int kInstVertOutUnused = kInstCommonOutCnt + 2;
// Frag Shader Output Record Offsets
static const int kInstFragOutFragCoordX = kInstCommonOutCnt;
static const int kInstFragOutFragCoordY = kInstCommonOutCnt + 1;
static const int kInstFragOutUnused = kInstCommonOutCnt + 2;
// Compute Shader Output Record Offsets
static const int kInstCompOutGlobalInvocationIdX = kInstCommonOutCnt;
static const int kInstCompOutGlobalInvocationIdY = kInstCommonOutCnt + 1;
static const int kInstCompOutGlobalInvocationIdZ = kInstCommonOutCnt + 2;
// Compute Shader Output Record Offsets - Version 1 (DEPRECATED)
static const int kInstCompOutGlobalInvocationId = kInstCommonOutCnt;
static const int kInstCompOutUnused = kInstCommonOutCnt + 1;
// Tessellation Shader Output Record Offsets
// Tessellation Control Shader Output Record Offsets
static const int kInstTessCtlOutInvocationId = kInstCommonOutCnt;
static const int kInstTessCtlOutPrimitiveId = kInstCommonOutCnt + 1;
static const int kInstTessCtlOutUnused = kInstCommonOutCnt + 2;
// Tessellation Eval Shader Output Record Offsets
static const int kInstTessEvalOutPrimitiveId = kInstCommonOutCnt;
static const int kInstTessEvalOutTessCoordU = kInstCommonOutCnt + 1;
static const int kInstTessEvalOutTessCoordV = kInstCommonOutCnt + 2;
// Tessellation Shader Output Record Offsets - Version 1 (DEPRECATED)
static const int kInstTessOutInvocationId = kInstCommonOutCnt;
static const int kInstTessOutUnused = kInstCommonOutCnt + 1;
// Geometry Shader Output Record Offsets
static const int kInstGeomOutPrimitiveId = kInstCommonOutCnt;
static const int kInstGeomOutInvocationId = kInstCommonOutCnt + 1;
static const int kInstGeomOutUnused = kInstCommonOutCnt + 2;
// Size of Common and Stage-specific Members
static const int kInstStageOutCnt = kInstCommonOutCnt + 2;
static const int kInst2StageOutCnt = kInstCommonOutCnt + 3;
// Validation Error Code
// Validation Error Code Offset
//
// This identifies the validation error. It also helps to identify
// how many words follow in the record and their meaning.
static const int kInstValidationOutError = kInstStageOutCnt;
static const int kInst2ValidationOutError = kInst2StageOutCnt;
// Validation-specific Output Record Offsets
//
@@ -114,11 +139,19 @@ static const int kInstBindlessBoundsOutDescIndex = kInstStageOutCnt + 1;
static const int kInstBindlessBoundsOutDescBound = kInstStageOutCnt + 2;
static const int kInstBindlessBoundsOutCnt = kInstStageOutCnt + 3;
static const int kInst2BindlessBoundsOutDescIndex = kInst2StageOutCnt + 1;
static const int kInst2BindlessBoundsOutDescBound = kInst2StageOutCnt + 2;
static const int kInst2BindlessBoundsOutCnt = kInst2StageOutCnt + 3;
// A bindless uninitialized error will output the index.
static const int kInstBindlessUninitOutDescIndex = kInstStageOutCnt + 1;
static const int kInstBindlessUninitOutUnused = kInstStageOutCnt + 2;
static const int kInstBindlessUninitOutCnt = kInstStageOutCnt + 3;
static const int kInst2BindlessUninitOutDescIndex = kInst2StageOutCnt + 1;
static const int kInst2BindlessUninitOutUnused = kInst2StageOutCnt + 2;
static const int kInst2BindlessUninitOutCnt = kInst2StageOutCnt + 3;
// DEPRECATED
static const int kInstBindlessOutDescIndex = kInstStageOutCnt + 1;
static const int kInstBindlessOutDescBound = kInstStageOutCnt + 2;
@@ -126,6 +159,7 @@ static const int kInstBindlessOutCnt = kInstStageOutCnt + 3;
// Maximum Output Record Member Count
static const int kInstMaxOutCnt = kInstStageOutCnt + 3;
static const int kInst2MaxOutCnt = kInst2StageOutCnt + 3;
// Validation Error Codes
//

View File

@@ -370,6 +370,8 @@ typedef struct spv_optimizer_options_t spv_optimizer_options_t;
typedef struct spv_reducer_options_t spv_reducer_options_t;
typedef struct spv_fuzzer_options_t spv_fuzzer_options_t;
// Type Definitions
typedef spv_const_binary_t* spv_const_binary;
@@ -385,6 +387,8 @@ typedef spv_optimizer_options_t* spv_optimizer_options;
typedef const spv_optimizer_options_t* spv_const_optimizer_options;
typedef spv_reducer_options_t* spv_reducer_options;
typedef const spv_reducer_options_t* spv_const_reducer_options;
typedef spv_fuzzer_options_t* spv_fuzzer_options;
typedef const spv_fuzzer_options_t* spv_const_fuzzer_options;
// Platform API
@@ -590,6 +594,19 @@ SPIRV_TOOLS_EXPORT void spvReducerOptionsSetStepLimit(
SPIRV_TOOLS_EXPORT void spvReducerOptionsSetFailOnValidationError(
spv_reducer_options options, bool fail_on_validation_error);
// Creates a fuzzer options object with default options. Returns a valid
// options object. The object remains valid until it is passed into
// |spvFuzzerOptionsDestroy|.
SPIRV_TOOLS_EXPORT spv_fuzzer_options spvFuzzerOptionsCreate();
// Destroys the given fuzzer options object.
SPIRV_TOOLS_EXPORT void spvFuzzerOptionsDestroy(spv_fuzzer_options options);
// Sets the seed with which the random number generator used by the fuzzer
// should be initialized.
SPIRV_TOOLS_EXPORT void spvFuzzerOptionsSetRandomSeed(
spv_fuzzer_options options, uint32_t seed);
// Encodes the given SPIR-V assembly text to its binary representation. The
// length parameter specifies the number of bytes for text. Encoded binary will
// be stored into *binary. Any error will be written into *diagnostic if

View File

@@ -191,6 +191,26 @@ class ReducerOptions {
spv_reducer_options options_;
};
// A C++ wrapper around a fuzzer options object.
class FuzzerOptions {
public:
FuzzerOptions() : options_(spvFuzzerOptionsCreate()) {}
~FuzzerOptions() { spvFuzzerOptionsDestroy(options_); }
// Allow implicit conversion to the underlying object.
operator spv_fuzzer_options() const { // NOLINT(google-explicit-constructor)
return options_;
}
// See spvFuzzerOptionsSetRandomSeed.
void set_random_seed(uint32_t seed) {
spvFuzzerOptionsSetRandomSeed(options_, seed);
}
private:
spv_fuzzer_options options_;
};
// C++ interface for SPIRV-Tools functionalities. It wraps the context
// (including target environment and the corresponding SPIR-V grammar) and
// provides methods for assembling, disassembling, and validating.

View File

@@ -738,9 +738,10 @@ Optimizer::PassToken CreateCombineAccessChainsPass();
// |input_length_enable| controls instrumentation of runtime descriptor array
// references, and |input_init_enable| controls instrumentation of descriptor
// initialization checking, both of which require input buffer support.
// |version| specifies the buffer record format.
Optimizer::PassToken CreateInstBindlessCheckPass(
uint32_t desc_set, uint32_t shader_id, bool input_length_enable = false,
bool input_init_enable = false);
bool input_init_enable = false, uint32_t version = 1);
// Create a pass to upgrade to the VulkanKHR memory model.
// This pass upgrades the Logical GLSL450 memory model to Logical VulkanKHR.

View File

@@ -196,9 +196,9 @@ set_source_files_properties(
${CMAKE_CURRENT_SOURCE_DIR}/pch_source.cpp
PROPERTIES OBJECT_DEPENDS "${PCH_DEPENDS}")
add_subdirectory(comp)
add_subdirectory(opt)
add_subdirectory(reduce)
add_subdirectory(fuzz)
add_subdirectory(link)
set(SPIRV_SOURCES
@@ -221,7 +221,6 @@ set(SPIRV_SOURCES
${CMAKE_CURRENT_SOURCE_DIR}/enum_string_mapping.h
${CMAKE_CURRENT_SOURCE_DIR}/ext_inst.h
${CMAKE_CURRENT_SOURCE_DIR}/extensions.h
${CMAKE_CURRENT_SOURCE_DIR}/id_descriptor.h
${CMAKE_CURRENT_SOURCE_DIR}/instruction.h
${CMAKE_CURRENT_SOURCE_DIR}/latest_version_glsl_std_450_header.h
${CMAKE_CURRENT_SOURCE_DIR}/latest_version_opencl_std_header.h
@@ -235,6 +234,7 @@ set(SPIRV_SOURCES
${CMAKE_CURRENT_SOURCE_DIR}/spirv_constant.h
${CMAKE_CURRENT_SOURCE_DIR}/spirv_definition.h
${CMAKE_CURRENT_SOURCE_DIR}/spirv_endian.h
${CMAKE_CURRENT_SOURCE_DIR}/spirv_fuzzer_options.h
${CMAKE_CURRENT_SOURCE_DIR}/spirv_optimizer_options.h
${CMAKE_CURRENT_SOURCE_DIR}/spirv_reducer_options.h
${CMAKE_CURRENT_SOURCE_DIR}/spirv_target_env.h
@@ -254,7 +254,6 @@ set(SPIRV_SOURCES
${CMAKE_CURRENT_SOURCE_DIR}/enum_string_mapping.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ext_inst.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extensions.cpp
${CMAKE_CURRENT_SOURCE_DIR}/id_descriptor.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libspirv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/name_mapper.cpp
${CMAKE_CURRENT_SOURCE_DIR}/opcode.cpp
@@ -263,6 +262,7 @@ set(SPIRV_SOURCES
${CMAKE_CURRENT_SOURCE_DIR}/print.cpp
${CMAKE_CURRENT_SOURCE_DIR}/software_version.cpp
${CMAKE_CURRENT_SOURCE_DIR}/spirv_endian.cpp
${CMAKE_CURRENT_SOURCE_DIR}/spirv_fuzzer_options.cpp
${CMAKE_CURRENT_SOURCE_DIR}/spirv_optimizer_options.cpp
${CMAKE_CURRENT_SOURCE_DIR}/spirv_reducer_options.cpp
${CMAKE_CURRENT_SOURCE_DIR}/spirv_target_env.cpp
@@ -299,6 +299,7 @@ set(SPIRV_SOURCES
${CMAKE_CURRENT_SOURCE_DIR}/val/validate_logicals.cpp
${CMAKE_CURRENT_SOURCE_DIR}/val/validate_memory.cpp
${CMAKE_CURRENT_SOURCE_DIR}/val/validate_memory_semantics.cpp
${CMAKE_CURRENT_SOURCE_DIR}/val/validate_misc.cpp
${CMAKE_CURRENT_SOURCE_DIR}/val/validate_mode_setting.cpp
${CMAKE_CURRENT_SOURCE_DIR}/val/validate_non_uniform.cpp
${CMAKE_CURRENT_SOURCE_DIR}/val/validate_primitives.cpp

View File

@@ -1,52 +0,0 @@
# Copyright (c) 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
if(SPIRV_BUILD_COMPRESSION)
add_library(SPIRV-Tools-comp
bit_stream.cpp
bit_stream.h
huffman_codec.h
markv_codec.cpp
markv_codec.h
markv.cpp
markv.h
markv_decoder.cpp
markv_decoder.h
markv_encoder.cpp
markv_encoder.h
markv_logger.h
move_to_front.h
move_to_front.cpp)
spvtools_default_compile_options(SPIRV-Tools-comp)
target_include_directories(SPIRV-Tools-comp
PUBLIC ${spirv-tools_SOURCE_DIR}/include
PUBLIC ${SPIRV_HEADER_INCLUDE_DIR}
PRIVATE ${spirv-tools_BINARY_DIR}
)
target_link_libraries(SPIRV-Tools-comp
PUBLIC ${SPIRV_TOOLS})
set_property(TARGET SPIRV-Tools-comp PROPERTY FOLDER "SPIRV-Tools libraries")
spvtools_check_symbol_exports(SPIRV-Tools-comp)
if(ENABLE_SPIRV_TOOLS_INSTALL)
install(TARGETS SPIRV-Tools-comp
RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR})
endif(ENABLE_SPIRV_TOOLS_INSTALL)
endif(SPIRV_BUILD_COMPRESSION)

View File

@@ -1,348 +0,0 @@
// Copyright (c) 2017 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <cassert>
#include <cstring>
#include <sstream>
#include <type_traits>
#include "source/comp/bit_stream.h"
namespace spvtools {
namespace comp {
namespace {
// Returns if the system is little-endian. Unfortunately only works during
// runtime.
bool IsLittleEndian() {
// This constant value allows the detection of the host machine's endianness.
// Accessing it as an array of bytes is valid due to C++11 section 3.10
// paragraph 10.
static const uint16_t kFF00 = 0xff00;
return reinterpret_cast<const unsigned char*>(&kFF00)[0] == 0;
}
// Copies bytes from the given buffer to a uint64_t buffer.
// Motivation: casting uint64_t* to uint8_t* is ok. Casting in the other
// direction is only advisable if uint8_t* is aligned to 64-bit word boundary.
std::vector<uint64_t> ToBuffer64(const void* buffer, size_t num_bytes) {
std::vector<uint64_t> out;
out.resize((num_bytes + 7) / 8, 0);
memcpy(out.data(), buffer, num_bytes);
return out;
}
// Copies uint8_t buffer to a uint64_t buffer.
std::vector<uint64_t> ToBuffer64(const std::vector<uint8_t>& in) {
return ToBuffer64(in.data(), in.size());
}
// Returns uint64_t containing the same bits as |val|.
// Type size must be less than 8 bytes.
template <typename T>
uint64_t ToU64(T val) {
static_assert(sizeof(T) <= 8, "Type size too big");
uint64_t val64 = 0;
std::memcpy(&val64, &val, sizeof(T));
return val64;
}
// Returns value of type T containing the same bits as |val64|.
// Type size must be less than 8 bytes. Upper (unused) bits of |val64| must be
// zero (irrelevant, but is checked with assertion).
template <typename T>
T FromU64(uint64_t val64) {
assert(sizeof(T) == 8 || (val64 >> (sizeof(T) * 8)) == 0);
static_assert(sizeof(T) <= 8, "Type size too big");
T val = 0;
std::memcpy(&val, &val64, sizeof(T));
return val;
}
// Writes bits from |val| to |writer| in chunks of size |chunk_length|.
// Signal bit is used to signal if the reader should expect another chunk:
// 0 - no more chunks to follow
// 1 - more chunks to follow
// If number of written bits reaches |max_payload| last chunk is truncated.
void WriteVariableWidthInternal(BitWriterInterface* writer, uint64_t val,
size_t chunk_length, size_t max_payload) {
assert(chunk_length > 0);
assert(chunk_length < max_payload);
assert(max_payload == 64 || (val >> max_payload) == 0);
if (val == 0) {
// Split in two writes for more readable logging.
writer->WriteBits(0, chunk_length);
writer->WriteBits(0, 1);
return;
}
size_t payload_written = 0;
while (val) {
if (payload_written + chunk_length >= max_payload) {
// This has to be the last chunk.
// There is no need for the signal bit and the chunk can be truncated.
const size_t left_to_write = max_payload - payload_written;
assert((val >> left_to_write) == 0);
writer->WriteBits(val, left_to_write);
break;
}
writer->WriteBits(val, chunk_length);
payload_written += chunk_length;
val = val >> chunk_length;
// Write a single bit to signal if there is more to come.
writer->WriteBits(val ? 1 : 0, 1);
}
}
// Reads data written with WriteVariableWidthInternal. |chunk_length| and
// |max_payload| should be identical to those used to write the data.
// Returns false if the stream ends prematurely.
bool ReadVariableWidthInternal(BitReaderInterface* reader, uint64_t* val,
size_t chunk_length, size_t max_payload) {
assert(chunk_length > 0);
assert(chunk_length <= max_payload);
size_t payload_read = 0;
while (payload_read + chunk_length < max_payload) {
uint64_t bits = 0;
if (reader->ReadBits(&bits, chunk_length) != chunk_length) return false;
*val |= bits << payload_read;
payload_read += chunk_length;
uint64_t more_to_come = 0;
if (reader->ReadBits(&more_to_come, 1) != 1) return false;
if (!more_to_come) {
return true;
}
}
// Need to read the last chunk which may be truncated. No signal bit follows.
uint64_t bits = 0;
const size_t left_to_read = max_payload - payload_read;
if (reader->ReadBits(&bits, left_to_read) != left_to_read) return false;
*val |= bits << payload_read;
return true;
}
// Calls WriteVariableWidthInternal with the right max_payload argument.
template <typename T>
void WriteVariableWidthUnsigned(BitWriterInterface* writer, T val,
size_t chunk_length) {
static_assert(std::is_unsigned<T>::value, "Type must be unsigned");
static_assert(std::is_integral<T>::value, "Type must be integral");
WriteVariableWidthInternal(writer, val, chunk_length, sizeof(T) * 8);
}
// Calls ReadVariableWidthInternal with the right max_payload argument.
template <typename T>
bool ReadVariableWidthUnsigned(BitReaderInterface* reader, T* val,
size_t chunk_length) {
static_assert(std::is_unsigned<T>::value, "Type must be unsigned");
static_assert(std::is_integral<T>::value, "Type must be integral");
uint64_t val64 = 0;
if (!ReadVariableWidthInternal(reader, &val64, chunk_length, sizeof(T) * 8))
return false;
*val = static_cast<T>(val64);
assert(*val == val64);
return true;
}
// Encodes signed |val| to an unsigned value and calls
// WriteVariableWidthInternal with the right max_payload argument.
template <typename T>
void WriteVariableWidthSigned(BitWriterInterface* writer, T val,
size_t chunk_length, size_t zigzag_exponent) {
static_assert(std::is_signed<T>::value, "Type must be signed");
static_assert(std::is_integral<T>::value, "Type must be integral");
WriteVariableWidthInternal(writer, EncodeZigZag(val, zigzag_exponent),
chunk_length, sizeof(T) * 8);
}
// Calls ReadVariableWidthInternal with the right max_payload argument
// and decodes the value.
template <typename T>
bool ReadVariableWidthSigned(BitReaderInterface* reader, T* val,
size_t chunk_length, size_t zigzag_exponent) {
static_assert(std::is_signed<T>::value, "Type must be signed");
static_assert(std::is_integral<T>::value, "Type must be integral");
uint64_t encoded = 0;
if (!ReadVariableWidthInternal(reader, &encoded, chunk_length, sizeof(T) * 8))
return false;
const int64_t decoded = DecodeZigZag(encoded, zigzag_exponent);
*val = static_cast<T>(decoded);
assert(*val == decoded);
return true;
}
} // namespace
void BitWriterInterface::WriteVariableWidthU64(uint64_t val,
size_t chunk_length) {
WriteVariableWidthUnsigned(this, val, chunk_length);
}
void BitWriterInterface::WriteVariableWidthU32(uint32_t val,
size_t chunk_length) {
WriteVariableWidthUnsigned(this, val, chunk_length);
}
void BitWriterInterface::WriteVariableWidthU16(uint16_t val,
size_t chunk_length) {
WriteVariableWidthUnsigned(this, val, chunk_length);
}
void BitWriterInterface::WriteVariableWidthS64(int64_t val, size_t chunk_length,
size_t zigzag_exponent) {
WriteVariableWidthSigned(this, val, chunk_length, zigzag_exponent);
}
BitWriterWord64::BitWriterWord64(size_t reserve_bits) : end_(0) {
buffer_.reserve(NumBitsToNumWords<64>(reserve_bits));
}
void BitWriterWord64::WriteBits(uint64_t bits, size_t num_bits) {
// Check that |bits| and |num_bits| are valid and consistent.
assert(num_bits <= 64);
const bool is_little_endian = IsLittleEndian();
assert(is_little_endian && "Big-endian architecture support not implemented");
if (!is_little_endian) return;
if (num_bits == 0) return;
bits = GetLowerBits(bits, num_bits);
EmitSequence(bits, num_bits);
// Offset from the start of the current word.
const size_t offset = end_ % 64;
if (offset == 0) {
// If no offset, simply add |bits| as a new word to the buffer_.
buffer_.push_back(bits);
} else {
// Shift bits and add them to the current word after offset.
const uint64_t first_word = bits << offset;
buffer_.back() |= first_word;
// If we don't overflow to the next word, there is nothing more to do.
if (offset + num_bits > 64) {
// We overflow to the next word.
const uint64_t second_word = bits >> (64 - offset);
// Add remaining bits as a new word to buffer_.
buffer_.push_back(second_word);
}
}
// Move end_ into position for next write.
end_ += num_bits;
assert(buffer_.size() * 64 >= end_);
}
bool BitReaderInterface::ReadVariableWidthU64(uint64_t* val,
size_t chunk_length) {
return ReadVariableWidthUnsigned(this, val, chunk_length);
}
bool BitReaderInterface::ReadVariableWidthU32(uint32_t* val,
size_t chunk_length) {
return ReadVariableWidthUnsigned(this, val, chunk_length);
}
bool BitReaderInterface::ReadVariableWidthU16(uint16_t* val,
size_t chunk_length) {
return ReadVariableWidthUnsigned(this, val, chunk_length);
}
bool BitReaderInterface::ReadVariableWidthS64(int64_t* val, size_t chunk_length,
size_t zigzag_exponent) {
return ReadVariableWidthSigned(this, val, chunk_length, zigzag_exponent);
}
BitReaderWord64::BitReaderWord64(std::vector<uint64_t>&& buffer)
: buffer_(std::move(buffer)), pos_(0) {}
BitReaderWord64::BitReaderWord64(const std::vector<uint8_t>& buffer)
: buffer_(ToBuffer64(buffer)), pos_(0) {}
BitReaderWord64::BitReaderWord64(const void* buffer, size_t num_bytes)
: buffer_(ToBuffer64(buffer, num_bytes)), pos_(0) {}
size_t BitReaderWord64::ReadBits(uint64_t* bits, size_t num_bits) {
assert(num_bits <= 64);
const bool is_little_endian = IsLittleEndian();
assert(is_little_endian && "Big-endian architecture support not implemented");
if (!is_little_endian) return 0;
if (ReachedEnd()) return 0;
// Index of the current word.
const size_t index = pos_ / 64;
// Bit position in the current word where we start reading.
const size_t offset = pos_ % 64;
// Read all bits from the current word (it might be too much, but
// excessive bits will be removed later).
*bits = buffer_[index] >> offset;
const size_t num_read_from_first_word = std::min(64 - offset, num_bits);
pos_ += num_read_from_first_word;
if (pos_ >= buffer_.size() * 64) {
// Reached end of buffer_.
EmitSequence(*bits, num_read_from_first_word);
return num_read_from_first_word;
}
if (offset + num_bits > 64) {
// Requested |num_bits| overflows to next word.
// Write all bits from the beginning of next word to *bits after offset.
*bits |= buffer_[index + 1] << (64 - offset);
pos_ += offset + num_bits - 64;
}
// We likely have written more bits than requested. Clear excessive bits.
*bits = GetLowerBits(*bits, num_bits);
EmitSequence(*bits, num_bits);
return num_bits;
}
bool BitReaderWord64::ReachedEnd() const { return pos_ >= buffer_.size() * 64; }
bool BitReaderWord64::OnlyZeroesLeft() const {
if (ReachedEnd()) return true;
const size_t index = pos_ / 64;
if (index < buffer_.size() - 1) return false;
assert(index == buffer_.size() - 1);
const size_t offset = pos_ % 64;
const uint64_t remaining_bits = buffer_[index] >> offset;
return !remaining_bits;
}
} // namespace comp
} // namespace spvtools

View File

@@ -1,280 +0,0 @@
// Copyright (c) 2017 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Contains utils for reading, writing and debug printing bit streams.
#ifndef SOURCE_COMP_BIT_STREAM_H_
#define SOURCE_COMP_BIT_STREAM_H_
#include <algorithm>
#include <bitset>
#include <cassert>
#include <cstdint>
#include <cstring>
#include <functional>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
namespace spvtools {
namespace comp {
// Terminology:
// Bits - usually used for a uint64 word, first bit is the lowest.
// Stream - std::string of '0' and '1', read left-to-right,
// i.e. first bit is at the front and not at the end as in
// std::bitset::to_string().
// Bitset - std::bitset corresponding to uint64 bits and to reverse(stream).
// Converts number of bits to a respective number of chunks of size N.
// For example NumBitsToNumWords<8> returns how many bytes are needed to store
// |num_bits|.
template <size_t N>
inline size_t NumBitsToNumWords(size_t num_bits) {
return (num_bits + (N - 1)) / N;
}
// Returns value of the same type as |in|, where all but the first |num_bits|
// are set to zero.
template <typename T>
inline T GetLowerBits(T in, size_t num_bits) {
return sizeof(T) * 8 == num_bits ? in : in & T((T(1) << num_bits) - T(1));
}
// Encodes signed integer as unsigned. This is a generalized version of
// EncodeZigZag, designed to favor small positive numbers.
// Values are transformed in blocks of 2^|block_exponent|.
// If |block_exponent| is zero, then this degenerates into normal EncodeZigZag.
// Example when |block_exponent| is 1 (return value is the index):
// 0, 1, -1, -2, 2, 3, -3, -4, 4, 5, -5, -6, 6, 7, -7, -8
// Example when |block_exponent| is 2:
// 0, 1, 2, 3, -1, -2, -3, -4, 4, 5, 6, 7, -5, -6, -7, -8
inline uint64_t EncodeZigZag(int64_t val, size_t block_exponent) {
assert(block_exponent < 64);
const uint64_t uval = static_cast<uint64_t>(val >= 0 ? val : -val - 1);
const uint64_t block_num =
((uval >> block_exponent) << 1) + (val >= 0 ? 0 : 1);
const uint64_t pos = GetLowerBits(uval, block_exponent);
return (block_num << block_exponent) + pos;
}
// Decodes signed integer encoded with EncodeZigZag. |block_exponent| must be
// the same.
inline int64_t DecodeZigZag(uint64_t val, size_t block_exponent) {
assert(block_exponent < 64);
const uint64_t block_num = val >> block_exponent;
const uint64_t pos = GetLowerBits(val, block_exponent);
if (block_num & 1) {
// Negative.
return -1LL - ((block_num >> 1) << block_exponent) - pos;
} else {
// Positive.
return ((block_num >> 1) << block_exponent) + pos;
}
}
// Converts first |num_bits| stored in uint64 to a left-to-right stream of bits.
inline std::string BitsToStream(uint64_t bits, size_t num_bits = 64) {
std::bitset<64> bitset(bits);
std::string str = bitset.to_string().substr(64 - num_bits);
std::reverse(str.begin(), str.end());
return str;
}
// Base class for writing sequences of bits.
class BitWriterInterface {
public:
BitWriterInterface() = default;
virtual ~BitWriterInterface() = default;
// Writes lower |num_bits| in |bits| to the stream.
// |num_bits| must be no greater than 64.
virtual void WriteBits(uint64_t bits, size_t num_bits) = 0;
// Writes bits from value of type |T| to the stream. No encoding is done.
// Always writes 8 * sizeof(T) bits.
template <typename T>
void WriteUnencoded(T val) {
static_assert(sizeof(T) <= 64, "Type size too large");
uint64_t bits = 0;
memcpy(&bits, &val, sizeof(T));
WriteBits(bits, sizeof(T) * 8);
}
// Writes |val| in chunks of size |chunk_length| followed by a signal bit:
// 0 - no more chunks to follow
// 1 - more chunks to follow
// for example 255 is encoded into 1111 1 1111 0 for chunk length 4.
// The last chunk can be truncated and signal bit omitted, if the entire
// payload (for example 16 bit for uint16_t has already been written).
void WriteVariableWidthU64(uint64_t val, size_t chunk_length);
void WriteVariableWidthU32(uint32_t val, size_t chunk_length);
void WriteVariableWidthU16(uint16_t val, size_t chunk_length);
void WriteVariableWidthS64(int64_t val, size_t chunk_length,
size_t zigzag_exponent);
// Returns number of bits written.
virtual size_t GetNumBits() const = 0;
// Provides direct access to the buffer data if implemented.
virtual const uint8_t* GetData() const { return nullptr; }
// Returns buffer size in bytes.
size_t GetDataSizeBytes() const { return NumBitsToNumWords<8>(GetNumBits()); }
// Generates and returns byte array containing written bits.
virtual std::vector<uint8_t> GetDataCopy() const = 0;
BitWriterInterface(const BitWriterInterface&) = delete;
BitWriterInterface& operator=(const BitWriterInterface&) = delete;
};
// This class is an implementation of BitWriterInterface, using
// std::vector<uint64_t> to store written bits.
class BitWriterWord64 : public BitWriterInterface {
public:
explicit BitWriterWord64(size_t reserve_bits = 64);
void WriteBits(uint64_t bits, size_t num_bits) override;
size_t GetNumBits() const override { return end_; }
const uint8_t* GetData() const override {
return reinterpret_cast<const uint8_t*>(buffer_.data());
}
std::vector<uint8_t> GetDataCopy() const override {
return std::vector<uint8_t>(GetData(), GetData() + GetDataSizeBytes());
}
// Sets callback to emit bit sequences after every write.
void SetCallback(std::function<void(const std::string&)> callback) {
callback_ = callback;
}
protected:
// Sends string generated from arguments to callback_ if defined.
void EmitSequence(uint64_t bits, size_t num_bits) const {
if (callback_) callback_(BitsToStream(bits, num_bits));
}
private:
std::vector<uint64_t> buffer_;
// Total number of bits written so far. Named 'end' as analogy to std::end().
size_t end_;
// If not null, the writer will use the callback to emit the written bit
// sequence as a string of '0' and '1'.
std::function<void(const std::string&)> callback_;
};
// Base class for reading sequences of bits.
class BitReaderInterface {
public:
BitReaderInterface() {}
virtual ~BitReaderInterface() {}
// Reads |num_bits| from the stream, stores them in |bits|.
// Returns number of read bits. |num_bits| must be no greater than 64.
virtual size_t ReadBits(uint64_t* bits, size_t num_bits) = 0;
// Reads 8 * sizeof(T) bits and stores them in |val|.
template <typename T>
bool ReadUnencoded(T* val) {
static_assert(sizeof(T) <= 64, "Type size too large");
uint64_t bits = 0;
const size_t num_read = ReadBits(&bits, sizeof(T) * 8);
if (num_read != sizeof(T) * 8) return false;
memcpy(val, &bits, sizeof(T));
return true;
}
// Returns number of bits already read.
virtual size_t GetNumReadBits() const = 0;
// These two functions define 'hard' and 'soft' EOF.
//
// Returns true if the end of the buffer was reached.
virtual bool ReachedEnd() const = 0;
// Returns true if we reached the end of the buffer or are nearing it and only
// zero bits are left to read. Implementations of this function are allowed to
// commit a "false negative" error if the end of the buffer was not reached,
// i.e. it can return false even if indeed only zeroes are left.
// It is assumed that the consumer expects that
// the buffer stream ends with padding zeroes, and would accept this as a
// 'soft' EOF. Implementations of this class do not necessarily need to
// implement this, default behavior can simply delegate to ReachedEnd().
virtual bool OnlyZeroesLeft() const { return ReachedEnd(); }
// Reads value encoded with WriteVariableWidthXXX (see BitWriterInterface).
// Reader and writer must use the same |chunk_length| and variable type.
// Returns true on success, false if the bit stream ends prematurely.
bool ReadVariableWidthU64(uint64_t* val, size_t chunk_length);
bool ReadVariableWidthU32(uint32_t* val, size_t chunk_length);
bool ReadVariableWidthU16(uint16_t* val, size_t chunk_length);
bool ReadVariableWidthS64(int64_t* val, size_t chunk_length,
size_t zigzag_exponent);
BitReaderInterface(const BitReaderInterface&) = delete;
BitReaderInterface& operator=(const BitReaderInterface&) = delete;
};
// This class is an implementation of BitReaderInterface which accepts both
// uint8_t and uint64_t buffers as input. uint64_t buffers are consumed and
// owned. uint8_t buffers are copied.
class BitReaderWord64 : public BitReaderInterface {
public:
// Consumes and owns the buffer.
explicit BitReaderWord64(std::vector<uint64_t>&& buffer);
// Copies the buffer and casts it to uint64.
// Consuming the original buffer and casting it to uint64 is difficult,
// as it would potentially cause data misalignment and poor performance.
explicit BitReaderWord64(const std::vector<uint8_t>& buffer);
BitReaderWord64(const void* buffer, size_t num_bytes);
size_t ReadBits(uint64_t* bits, size_t num_bits) override;
size_t GetNumReadBits() const override { return pos_; }
bool ReachedEnd() const override;
bool OnlyZeroesLeft() const override;
BitReaderWord64() = delete;
// Sets callback to emit bit sequences after every read.
void SetCallback(std::function<void(const std::string&)> callback) {
callback_ = callback;
}
protected:
// Sends string generated from arguments to callback_ if defined.
void EmitSequence(uint64_t bits, size_t num_bits) const {
if (callback_) callback_(BitsToStream(bits, num_bits));
}
private:
const std::vector<uint64_t> buffer_;
size_t pos_;
// If not null, the reader will use the callback to emit the read bit
// sequence as a string of '0' and '1'.
std::function<void(const std::string&)> callback_;
};
} // namespace comp
} // namespace spvtools
#endif // SOURCE_COMP_BIT_STREAM_H_

View File

@@ -1,389 +0,0 @@
// Copyright (c) 2017 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Contains utils for reading, writing and debug printing bit streams.
#ifndef SOURCE_COMP_HUFFMAN_CODEC_H_
#define SOURCE_COMP_HUFFMAN_CODEC_H_
#include <algorithm>
#include <cassert>
#include <functional>
#include <iomanip>
#include <map>
#include <memory>
#include <ostream>
#include <queue>
#include <sstream>
#include <stack>
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>
namespace spvtools {
namespace comp {
// Used to generate and apply a Huffman coding scheme.
// |Val| is the type of variable being encoded (for example a string or a
// literal).
template <class Val>
class HuffmanCodec {
public:
// Huffman tree node.
struct Node {
Node() {}
// Creates Node from serialization leaving weight and id undefined.
Node(const Val& in_value, uint32_t in_left, uint32_t in_right)
: value(in_value), left(in_left), right(in_right) {}
Val value = Val();
uint32_t weight = 0;
// Ids are issued sequentially starting from 1. Ids are used as an ordering
// tie-breaker, to make sure that the ordering (and resulting coding scheme)
// are consistent accross multiple platforms.
uint32_t id = 0;
// Handles of children.
uint32_t left = 0;
uint32_t right = 0;
};
// Creates Huffman codec from a histogramm.
// Histogramm counts must not be zero.
explicit HuffmanCodec(const std::map<Val, uint32_t>& hist) {
if (hist.empty()) return;
// Heuristic estimate.
nodes_.reserve(3 * hist.size());
// Create NIL.
CreateNode();
// The queue is sorted in ascending order by weight (or by node id if
// weights are equal).
std::vector<uint32_t> queue_vector;
queue_vector.reserve(hist.size());
std::priority_queue<uint32_t, std::vector<uint32_t>,
std::function<bool(uint32_t, uint32_t)>>
queue(std::bind(&HuffmanCodec::LeftIsBigger, this,
std::placeholders::_1, std::placeholders::_2),
std::move(queue_vector));
// Put all leaves in the queue.
for (const auto& pair : hist) {
const uint32_t node = CreateNode();
MutableValueOf(node) = pair.first;
MutableWeightOf(node) = pair.second;
assert(WeightOf(node));
queue.push(node);
}
// Form the tree by combining two subtrees with the least weight,
// and pushing the root of the new tree in the queue.
while (true) {
// We push a node at the end of each iteration, so the queue is never
// supposed to be empty at this point, unless there are no leaves, but
// that case was already handled.
assert(!queue.empty());
const uint32_t right = queue.top();
queue.pop();
// If the queue is empty at this point, then the last node is
// the root of the complete Huffman tree.
if (queue.empty()) {
root_ = right;
break;
}
const uint32_t left = queue.top();
queue.pop();
// Combine left and right into a new tree and push it into the queue.
const uint32_t parent = CreateNode();
MutableWeightOf(parent) = WeightOf(right) + WeightOf(left);
MutableLeftOf(parent) = left;
MutableRightOf(parent) = right;
queue.push(parent);
}
// Traverse the tree and form encoding table.
CreateEncodingTable();
}
// Creates Huffman codec from saved tree structure.
// |nodes| is the list of nodes of the tree, nodes[0] being NIL.
// |root_handle| is the index of the root node.
HuffmanCodec(uint32_t root_handle, std::vector<Node>&& nodes) {
nodes_ = std::move(nodes);
assert(!nodes_.empty());
assert(root_handle > 0 && root_handle < nodes_.size());
assert(!LeftOf(0) && !RightOf(0));
root_ = root_handle;
// Traverse the tree and form encoding table.
CreateEncodingTable();
}
// Serializes the codec in the following text format:
// (<root_handle>, {
// {0, 0, 0},
// {val1, left1, right1},
// {val2, left2, right2},
// ...
// })
std::string SerializeToText(int indent_num_whitespaces) const {
const bool value_is_text = std::is_same<Val, std::string>::value;
const std::string indent1 = std::string(indent_num_whitespaces, ' ');
const std::string indent2 = std::string(indent_num_whitespaces + 2, ' ');
std::stringstream code;
code << "(" << root_ << ", {\n";
for (const Node& node : nodes_) {
code << indent2 << "{";
if (value_is_text) code << "\"";
code << node.value;
if (value_is_text) code << "\"";
code << ", " << node.left << ", " << node.right << "},\n";
}
code << indent1 << "})";
return code.str();
}
// Prints the Huffman tree in the following format:
// w------w------'x'
// w------'y'
// Where w stands for the weight of the node.
// Right tree branches appear above left branches. Taking the right path
// adds 1 to the code, taking the left adds 0.
void PrintTree(std::ostream& out) const { PrintTreeInternal(out, root_, 0); }
// Traverses the tree and prints the Huffman table: value, code
// and optionally node weight for every leaf.
void PrintTable(std::ostream& out, bool print_weights = true) {
std::queue<std::pair<uint32_t, std::string>> queue;
queue.emplace(root_, "");
while (!queue.empty()) {
const uint32_t node = queue.front().first;
const std::string code = queue.front().second;
queue.pop();
if (!RightOf(node) && !LeftOf(node)) {
out << ValueOf(node);
if (print_weights) out << " " << WeightOf(node);
out << " " << code << std::endl;
} else {
if (LeftOf(node)) queue.emplace(LeftOf(node), code + "0");
if (RightOf(node)) queue.emplace(RightOf(node), code + "1");
}
}
}
// Returns the Huffman table. The table was built at at construction time,
// this function just returns a const reference.
const std::unordered_map<Val, std::pair<uint64_t, size_t>>& GetEncodingTable()
const {
return encoding_table_;
}
// Encodes |val| and stores its Huffman code in the lower |num_bits| of
// |bits|. Returns false of |val| is not in the Huffman table.
bool Encode(const Val& val, uint64_t* bits, size_t* num_bits) const {
auto it = encoding_table_.find(val);
if (it == encoding_table_.end()) return false;
*bits = it->second.first;
*num_bits = it->second.second;
return true;
}
// Reads bits one-by-one using callback |read_bit| until a match is found.
// Matching value is stored in |val|. Returns false if |read_bit| terminates
// before a code was mathced.
// |read_bit| has type bool func(bool* bit). When called, the next bit is
// stored in |bit|. |read_bit| returns false if the stream terminates
// prematurely.
bool DecodeFromStream(const std::function<bool(bool*)>& read_bit,
Val* val) const {
uint32_t node = root_;
while (true) {
assert(node);
if (!RightOf(node) && !LeftOf(node)) {
*val = ValueOf(node);
return true;
}
bool go_right;
if (!read_bit(&go_right)) return false;
if (go_right)
node = RightOf(node);
else
node = LeftOf(node);
}
assert(0);
return false;
}
private:
// Returns value of the node referenced by |handle|.
Val ValueOf(uint32_t node) const { return nodes_.at(node).value; }
// Returns left child of |node|.
uint32_t LeftOf(uint32_t node) const { return nodes_.at(node).left; }
// Returns right child of |node|.
uint32_t RightOf(uint32_t node) const { return nodes_.at(node).right; }
// Returns weight of |node|.
uint32_t WeightOf(uint32_t node) const { return nodes_.at(node).weight; }
// Returns id of |node|.
uint32_t IdOf(uint32_t node) const { return nodes_.at(node).id; }
// Returns mutable reference to value of |node|.
Val& MutableValueOf(uint32_t node) {
assert(node);
return nodes_.at(node).value;
}
// Returns mutable reference to handle of left child of |node|.
uint32_t& MutableLeftOf(uint32_t node) {
assert(node);
return nodes_.at(node).left;
}
// Returns mutable reference to handle of right child of |node|.
uint32_t& MutableRightOf(uint32_t node) {
assert(node);
return nodes_.at(node).right;
}
// Returns mutable reference to weight of |node|.
uint32_t& MutableWeightOf(uint32_t node) { return nodes_.at(node).weight; }
// Returns mutable reference to id of |node|.
uint32_t& MutableIdOf(uint32_t node) { return nodes_.at(node).id; }
// Returns true if |left| has bigger weight than |right|. Node ids are
// used as tie-breaker.
bool LeftIsBigger(uint32_t left, uint32_t right) const {
if (WeightOf(left) == WeightOf(right)) {
assert(IdOf(left) != IdOf(right));
return IdOf(left) > IdOf(right);
}
return WeightOf(left) > WeightOf(right);
}
// Prints subtree (helper function used by PrintTree).
void PrintTreeInternal(std::ostream& out, uint32_t node, size_t depth) const {
if (!node) return;
const size_t kTextFieldWidth = 7;
if (!RightOf(node) && !LeftOf(node)) {
out << ValueOf(node) << std::endl;
} else {
if (RightOf(node)) {
std::stringstream label;
label << std::setfill('-') << std::left << std::setw(kTextFieldWidth)
<< WeightOf(RightOf(node));
out << label.str();
PrintTreeInternal(out, RightOf(node), depth + 1);
}
if (LeftOf(node)) {
out << std::string(depth * kTextFieldWidth, ' ');
std::stringstream label;
label << std::setfill('-') << std::left << std::setw(kTextFieldWidth)
<< WeightOf(LeftOf(node));
out << label.str();
PrintTreeInternal(out, LeftOf(node), depth + 1);
}
}
}
// Traverses the Huffman tree and saves paths to the leaves as bit
// sequences to encoding_table_.
void CreateEncodingTable() {
struct Context {
Context(uint32_t in_node, uint64_t in_bits, size_t in_depth)
: node(in_node), bits(in_bits), depth(in_depth) {}
uint32_t node;
// Huffman tree depth cannot exceed 64 as histogramm counts are expected
// to be positive and limited by numeric_limits<uint32_t>::max().
// For practical applications tree depth would be much smaller than 64.
uint64_t bits;
size_t depth;
};
std::queue<Context> queue;
queue.emplace(root_, 0, 0);
while (!queue.empty()) {
const Context& context = queue.front();
const uint32_t node = context.node;
const uint64_t bits = context.bits;
const size_t depth = context.depth;
queue.pop();
if (!RightOf(node) && !LeftOf(node)) {
auto insertion_result = encoding_table_.emplace(
ValueOf(node), std::pair<uint64_t, size_t>(bits, depth));
assert(insertion_result.second);
(void)insertion_result;
} else {
if (LeftOf(node)) queue.emplace(LeftOf(node), bits, depth + 1);
if (RightOf(node))
queue.emplace(RightOf(node), bits | (1ULL << depth), depth + 1);
}
}
}
// Creates new Huffman tree node and stores it in the deleter array.
uint32_t CreateNode() {
const uint32_t handle = static_cast<uint32_t>(nodes_.size());
nodes_.emplace_back(Node());
nodes_.back().id = next_node_id_++;
return handle;
}
// Huffman tree root handle.
uint32_t root_ = 0;
// Huffman tree deleter.
std::vector<Node> nodes_;
// Encoding table value -> {bits, num_bits}.
// Huffman codes are expected to never exceed 64 bit length (this is in fact
// impossible if frequencies are stored as uint32_t).
std::unordered_map<Val, std::pair<uint64_t, size_t>> encoding_table_;
// Next node id issued by CreateNode();
uint32_t next_node_id_ = 1;
};
} // namespace comp
} // namespace spvtools
#endif // SOURCE_COMP_HUFFMAN_CODEC_H_

View File

@@ -1,112 +0,0 @@
// Copyright (c) 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "source/comp/markv.h"
#include "source/comp/markv_decoder.h"
#include "source/comp/markv_encoder.h"
namespace spvtools {
namespace comp {
namespace {
spv_result_t EncodeHeader(void* user_data, spv_endianness_t endian,
uint32_t magic, uint32_t version, uint32_t generator,
uint32_t id_bound, uint32_t schema) {
MarkvEncoder* encoder = reinterpret_cast<MarkvEncoder*>(user_data);
return encoder->EncodeHeader(endian, magic, version, generator, id_bound,
schema);
}
spv_result_t EncodeInstruction(void* user_data,
const spv_parsed_instruction_t* inst) {
MarkvEncoder* encoder = reinterpret_cast<MarkvEncoder*>(user_data);
return encoder->EncodeInstruction(*inst);
}
} // namespace
spv_result_t SpirvToMarkv(
spv_const_context context, const std::vector<uint32_t>& spirv,
const MarkvCodecOptions& options, const MarkvModel& markv_model,
MessageConsumer message_consumer, MarkvLogConsumer log_consumer,
MarkvDebugConsumer debug_consumer, std::vector<uint8_t>* markv) {
spv_context_t hijack_context = *context;
SetContextMessageConsumer(&hijack_context, message_consumer);
spv_validator_options validator_options =
MarkvDecoder::GetValidatorOptions(options);
if (validator_options) {
spv_const_binary_t spirv_binary = {spirv.data(), spirv.size()};
const spv_result_t result = spvValidateWithOptions(
&hijack_context, validator_options, &spirv_binary, nullptr);
if (result != SPV_SUCCESS) return result;
}
MarkvEncoder encoder(&hijack_context, options, &markv_model);
spv_position_t position = {};
if (log_consumer || debug_consumer) {
encoder.CreateLogger(log_consumer, debug_consumer);
spv_text text = nullptr;
if (spvBinaryToText(&hijack_context, spirv.data(), spirv.size(),
SPV_BINARY_TO_TEXT_OPTION_NO_HEADER, &text,
nullptr) != SPV_SUCCESS) {
return DiagnosticStream(position, hijack_context.consumer, "",
SPV_ERROR_INVALID_BINARY)
<< "Failed to disassemble SPIR-V binary.";
}
assert(text);
encoder.SetDisassembly(std::string(text->str, text->length));
spvTextDestroy(text);
}
if (spvBinaryParse(&hijack_context, &encoder, spirv.data(), spirv.size(),
EncodeHeader, EncodeInstruction, nullptr) != SPV_SUCCESS) {
return DiagnosticStream(position, hijack_context.consumer, "",
SPV_ERROR_INVALID_BINARY)
<< "Unable to encode to MARK-V.";
}
*markv = encoder.GetMarkvBinary();
return SPV_SUCCESS;
}
spv_result_t MarkvToSpirv(
spv_const_context context, const std::vector<uint8_t>& markv,
const MarkvCodecOptions& options, const MarkvModel& markv_model,
MessageConsumer message_consumer, MarkvLogConsumer log_consumer,
MarkvDebugConsumer debug_consumer, std::vector<uint32_t>* spirv) {
spv_position_t position = {};
spv_context_t hijack_context = *context;
SetContextMessageConsumer(&hijack_context, message_consumer);
MarkvDecoder decoder(&hijack_context, markv, options, &markv_model);
if (log_consumer || debug_consumer)
decoder.CreateLogger(log_consumer, debug_consumer);
if (decoder.DecodeModule(spirv) != SPV_SUCCESS) {
return DiagnosticStream(position, hijack_context.consumer, "",
SPV_ERROR_INVALID_BINARY)
<< "Unable to decode MARK-V.";
}
assert(!spirv->empty());
return SPV_SUCCESS;
}
} // namespace comp
} // namespace spvtools

View File

@@ -1,74 +0,0 @@
// Copyright (c) 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// MARK-V is a compression format for SPIR-V binaries. It strips away
// non-essential information (such as result ids which can be regenerated) and
// uses various bit reduction techiniques to reduce the size of the binary and
// make it more similar to other compressed SPIR-V files to further improve
// compression of the dataset.
#ifndef SOURCE_COMP_MARKV_H_
#define SOURCE_COMP_MARKV_H_
#include "spirv-tools/libspirv.hpp"
namespace spvtools {
namespace comp {
class MarkvModel;
struct MarkvCodecOptions {
bool validate_spirv_binary = false;
};
// Debug callback. Called once per instruction.
// |words| is instruction SPIR-V words.
// |bits| is a textual representation of the MARK-V bit sequence used to encode
// the instruction (char '0' for 0, char '1' for 1).
// |comment| contains all logs generated while processing the instruction.
using MarkvDebugConsumer =
std::function<bool(const std::vector<uint32_t>& words,
const std::string& bits, const std::string& comment)>;
// Logging callback. Called often (if decoder reads a single bit, the log
// consumer will receive 1 character string with that bit).
// This callback is more suitable for continous output than MarkvDebugConsumer,
// for example if the codec crashes it would allow to pinpoint on which operand
// or bit the crash happened.
// |snippet| could be any atomic fragment of text logged by the codec. It can
// contain a paragraph of text with newlines, or can be just one character.
using MarkvLogConsumer = std::function<void(const std::string& snippet)>;
// Encodes the given SPIR-V binary to MARK-V binary.
// |log_consumer| is optional (pass MarkvLogConsumer() to disable).
// |debug_consumer| is optional (pass MarkvDebugConsumer() to disable).
spv_result_t SpirvToMarkv(
spv_const_context context, const std::vector<uint32_t>& spirv,
const MarkvCodecOptions& options, const MarkvModel& markv_model,
MessageConsumer message_consumer, MarkvLogConsumer log_consumer,
MarkvDebugConsumer debug_consumer, std::vector<uint8_t>* markv);
// Decodes a SPIR-V binary from the given MARK-V binary.
// |log_consumer| is optional (pass MarkvLogConsumer() to disable).
// |debug_consumer| is optional (pass MarkvDebugConsumer() to disable).
spv_result_t MarkvToSpirv(
spv_const_context context, const std::vector<uint8_t>& markv,
const MarkvCodecOptions& options, const MarkvModel& markv_model,
MessageConsumer message_consumer, MarkvLogConsumer log_consumer,
MarkvDebugConsumer debug_consumer, std::vector<uint32_t>* spirv);
} // namespace comp
} // namespace spvtools
#endif // SOURCE_COMP_MARKV_H_

View File

@@ -1,793 +0,0 @@
// Copyright (c) 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// MARK-V is a compression format for SPIR-V binaries. It strips away
// non-essential information (such as result IDs which can be regenerated) and
// uses various bit reduction techniques to reduce the size of the binary.
#include "source/comp/markv_codec.h"
#include "source/comp/markv_logger.h"
#include "source/latest_version_glsl_std_450_header.h"
#include "source/latest_version_opencl_std_header.h"
#include "source/opcode.h"
#include "source/util/make_unique.h"
namespace spvtools {
namespace comp {
namespace {
// Custom hash function used to produce short descriptors.
uint32_t ShortHashU32Array(const std::vector<uint32_t>& words) {
// The hash function is a sum of hashes of each word seeded by word index.
// Knuth's multiplicative hash is used to hash the words.
const uint32_t kKnuthMulHash = 2654435761;
uint32_t val = 0;
for (uint32_t i = 0; i < words.size(); ++i) {
val += (words[i] + i + 123) * kKnuthMulHash;
}
return 1 + val % ((1 << MarkvCodec::kShortDescriptorNumBits) - 1);
}
// Returns a set of mtf rank codecs based on a plausible hand-coded
// distribution.
std::map<uint64_t, std::unique_ptr<HuffmanCodec<uint32_t>>>
GetMtfHuffmanCodecs() {
std::map<uint64_t, std::unique_ptr<HuffmanCodec<uint32_t>>> codecs;
std::unique_ptr<HuffmanCodec<uint32_t>> codec;
codec = MakeUnique<HuffmanCodec<uint32_t>>(std::map<uint32_t, uint32_t>({
{0, 5},
{1, 40},
{2, 10},
{3, 5},
{4, 5},
{5, 5},
{6, 3},
{7, 3},
{8, 3},
{9, 3},
{MarkvCodec::kMtfRankEncodedByValueSignal, 10},
}));
codecs.emplace(kMtfAll, std::move(codec));
codec = MakeUnique<HuffmanCodec<uint32_t>>(std::map<uint32_t, uint32_t>({
{1, 50},
{2, 20},
{3, 5},
{4, 5},
{5, 2},
{6, 1},
{7, 1},
{8, 1},
{9, 1},
{MarkvCodec::kMtfRankEncodedByValueSignal, 10},
}));
codecs.emplace(kMtfGenericNonZeroRank, std::move(codec));
return codecs;
}
} // namespace
const uint32_t MarkvCodec::kMarkvMagicNumber = 0x07230303;
const uint32_t MarkvCodec::kMtfSmallestRankEncodedByValue = 10;
const uint32_t MarkvCodec::kMtfRankEncodedByValueSignal =
std::numeric_limits<uint32_t>::max();
const uint32_t MarkvCodec::kShortDescriptorNumBits = 8;
const size_t MarkvCodec::kByteBreakAfterInstIfLessThanUntilNextByte = 8;
MarkvCodec::MarkvCodec(spv_const_context context,
spv_validator_options validator_options,
const MarkvModel* model)
: validator_options_(validator_options),
grammar_(context),
model_(model),
short_id_descriptors_(ShortHashU32Array),
mtf_huffman_codecs_(GetMtfHuffmanCodecs()),
context_(context) {}
MarkvCodec::~MarkvCodec() { spvValidatorOptionsDestroy(validator_options_); }
MarkvCodec::MarkvHeader::MarkvHeader()
: magic_number(MarkvCodec::kMarkvMagicNumber),
markv_version(MarkvCodec::GetMarkvVersion()) {}
// Defines and returns current MARK-V version.
// static
uint32_t MarkvCodec::GetMarkvVersion() {
const uint32_t kVersionMajor = 1;
const uint32_t kVersionMinor = 4;
return kVersionMinor | (kVersionMajor << 16);
}
size_t MarkvCodec::GetNumBitsToNextByte(size_t bit_pos) const {
return (8 - (bit_pos % 8)) % 8;
}
// Returns true if the opcode has a fixed number of operands. May return a
// false negative.
bool MarkvCodec::OpcodeHasFixedNumberOfOperands(SpvOp opcode) const {
switch (opcode) {
// TODO(atgoo@github.com) This is not a complete list.
case SpvOpNop:
case SpvOpName:
case SpvOpUndef:
case SpvOpSizeOf:
case SpvOpLine:
case SpvOpNoLine:
case SpvOpDecorationGroup:
case SpvOpExtension:
case SpvOpExtInstImport:
case SpvOpMemoryModel:
case SpvOpCapability:
case SpvOpTypeVoid:
case SpvOpTypeBool:
case SpvOpTypeInt:
case SpvOpTypeFloat:
case SpvOpTypeVector:
case SpvOpTypeMatrix:
case SpvOpTypeSampler:
case SpvOpTypeSampledImage:
case SpvOpTypeArray:
case SpvOpTypePointer:
case SpvOpConstantTrue:
case SpvOpConstantFalse:
case SpvOpLabel:
case SpvOpBranch:
case SpvOpFunction:
case SpvOpFunctionParameter:
case SpvOpFunctionEnd:
case SpvOpBitcast:
case SpvOpCopyObject:
case SpvOpTranspose:
case SpvOpSNegate:
case SpvOpFNegate:
case SpvOpIAdd:
case SpvOpFAdd:
case SpvOpISub:
case SpvOpFSub:
case SpvOpIMul:
case SpvOpFMul:
case SpvOpUDiv:
case SpvOpSDiv:
case SpvOpFDiv:
case SpvOpUMod:
case SpvOpSRem:
case SpvOpSMod:
case SpvOpFRem:
case SpvOpFMod:
case SpvOpVectorTimesScalar:
case SpvOpMatrixTimesScalar:
case SpvOpVectorTimesMatrix:
case SpvOpMatrixTimesVector:
case SpvOpMatrixTimesMatrix:
case SpvOpOuterProduct:
case SpvOpDot:
return true;
default:
break;
}
return false;
}
void MarkvCodec::ProcessCurInstruction() {
instructions_.emplace_back(new val::Instruction(&inst_));
const SpvOp opcode = SpvOp(inst_.opcode);
if (inst_.result_id) {
id_to_def_instruction_.emplace(inst_.result_id, instructions_.back().get());
// Collect ids local to the current function.
if (cur_function_id_) {
ids_local_to_cur_function_.push_back(inst_.result_id);
}
// Starting new function.
if (opcode == SpvOpFunction) {
cur_function_id_ = inst_.result_id;
cur_function_return_type_ = inst_.type_id;
if (model_->id_fallback_strategy() ==
MarkvModel::IdFallbackStrategy::kRuleBased) {
multi_mtf_.Insert(GetMtfFunctionWithReturnType(inst_.type_id),
inst_.result_id);
}
// Store function parameter types in a queue, so that we know which types
// to expect in the following OpFunctionParameter instructions.
const val::Instruction* def_inst = FindDef(inst_.words[4]);
assert(def_inst);
assert(def_inst->opcode() == SpvOpTypeFunction);
for (uint32_t i = 3; i < def_inst->words().size(); ++i) {
remaining_function_parameter_types_.push_back(def_inst->word(i));
}
}
}
// Remove local ids from MTFs if function end.
if (opcode == SpvOpFunctionEnd) {
cur_function_id_ = 0;
for (uint32_t id : ids_local_to_cur_function_) multi_mtf_.RemoveFromAll(id);
ids_local_to_cur_function_.clear();
assert(remaining_function_parameter_types_.empty());
}
if (!inst_.result_id) return;
{
// Save the result ID to type ID mapping.
// In the grammar, type ID always appears before result ID.
// A regular value maps to its type. Some instructions (e.g. OpLabel)
// have no type Id, and will map to 0. The result Id for a
// type-generating instruction (e.g. OpTypeInt) maps to itself.
auto insertion_result = id_to_type_id_.emplace(
inst_.result_id, spvOpcodeGeneratesType(SpvOp(inst_.opcode))
? inst_.result_id
: inst_.type_id);
(void)insertion_result;
assert(insertion_result.second);
}
// Add result_id to MTFs.
if (model_->id_fallback_strategy() ==
MarkvModel::IdFallbackStrategy::kRuleBased) {
switch (opcode) {
case SpvOpTypeFloat:
case SpvOpTypeInt:
case SpvOpTypeBool:
case SpvOpTypeVector:
case SpvOpTypePointer:
case SpvOpExtInstImport:
case SpvOpTypeSampledImage:
case SpvOpTypeImage:
case SpvOpTypeSampler:
multi_mtf_.Insert(GetMtfIdGeneratedByOpcode(opcode), inst_.result_id);
break;
default:
break;
}
if (spvOpcodeIsComposite(opcode)) {
multi_mtf_.Insert(kMtfTypeComposite, inst_.result_id);
}
if (opcode == SpvOpLabel) {
multi_mtf_.InsertOrPromote(kMtfLabel, inst_.result_id);
}
if (opcode == SpvOpTypeInt) {
multi_mtf_.Insert(kMtfTypeScalar, inst_.result_id);
multi_mtf_.Insert(kMtfTypeIntScalarOrVector, inst_.result_id);
}
if (opcode == SpvOpTypeFloat) {
multi_mtf_.Insert(kMtfTypeScalar, inst_.result_id);
multi_mtf_.Insert(kMtfTypeFloatScalarOrVector, inst_.result_id);
}
if (opcode == SpvOpTypeBool) {
multi_mtf_.Insert(kMtfTypeScalar, inst_.result_id);
multi_mtf_.Insert(kMtfTypeBoolScalarOrVector, inst_.result_id);
}
if (opcode == SpvOpTypeVector) {
const uint32_t component_type_id = inst_.words[2];
const uint32_t size = inst_.words[3];
if (multi_mtf_.HasValue(GetMtfIdGeneratedByOpcode(SpvOpTypeFloat),
component_type_id)) {
multi_mtf_.Insert(kMtfTypeFloatScalarOrVector, inst_.result_id);
} else if (multi_mtf_.HasValue(GetMtfIdGeneratedByOpcode(SpvOpTypeInt),
component_type_id)) {
multi_mtf_.Insert(kMtfTypeIntScalarOrVector, inst_.result_id);
} else if (multi_mtf_.HasValue(GetMtfIdGeneratedByOpcode(SpvOpTypeBool),
component_type_id)) {
multi_mtf_.Insert(kMtfTypeBoolScalarOrVector, inst_.result_id);
}
multi_mtf_.Insert(GetMtfTypeVectorOfSize(size), inst_.result_id);
}
if (inst_.opcode == SpvOpTypeFunction) {
const uint32_t return_type = inst_.words[2];
multi_mtf_.Insert(kMtfTypeReturnedByFunction, return_type);
multi_mtf_.Insert(GetMtfFunctionTypeWithReturnType(return_type),
inst_.result_id);
}
if (inst_.type_id) {
const val::Instruction* type_inst = FindDef(inst_.type_id);
assert(type_inst);
multi_mtf_.Insert(kMtfObject, inst_.result_id);
multi_mtf_.Insert(GetMtfIdOfType(inst_.type_id), inst_.result_id);
if (multi_mtf_.HasValue(kMtfTypeFloatScalarOrVector, inst_.type_id)) {
multi_mtf_.Insert(kMtfFloatScalarOrVector, inst_.result_id);
}
if (multi_mtf_.HasValue(kMtfTypeIntScalarOrVector, inst_.type_id))
multi_mtf_.Insert(kMtfIntScalarOrVector, inst_.result_id);
if (multi_mtf_.HasValue(kMtfTypeBoolScalarOrVector, inst_.type_id))
multi_mtf_.Insert(kMtfBoolScalarOrVector, inst_.result_id);
if (multi_mtf_.HasValue(kMtfTypeComposite, inst_.type_id))
multi_mtf_.Insert(kMtfComposite, inst_.result_id);
switch (type_inst->opcode()) {
case SpvOpTypeInt:
case SpvOpTypeBool:
case SpvOpTypePointer:
case SpvOpTypeVector:
case SpvOpTypeImage:
case SpvOpTypeSampledImage:
case SpvOpTypeSampler:
multi_mtf_.Insert(
GetMtfIdWithTypeGeneratedByOpcode(type_inst->opcode()),
inst_.result_id);
break;
default:
break;
}
if (type_inst->opcode() == SpvOpTypeVector) {
const uint32_t component_type = type_inst->word(2);
multi_mtf_.Insert(GetMtfVectorOfComponentType(component_type),
inst_.result_id);
}
if (type_inst->opcode() == SpvOpTypePointer) {
assert(type_inst->operands().size() > 2);
assert(type_inst->words().size() > type_inst->operands()[2].offset);
const uint32_t data_type =
type_inst->word(type_inst->operands()[2].offset);
multi_mtf_.Insert(GetMtfPointerToType(data_type), inst_.result_id);
if (multi_mtf_.HasValue(kMtfTypeComposite, data_type))
multi_mtf_.Insert(kMtfTypePointerToComposite, inst_.result_id);
}
}
if (spvOpcodeGeneratesType(opcode)) {
if (opcode != SpvOpTypeFunction) {
multi_mtf_.Insert(kMtfTypeNonFunction, inst_.result_id);
}
}
}
if (model_->AnyDescriptorHasCodingScheme()) {
const uint32_t long_descriptor =
long_id_descriptors_.ProcessInstruction(inst_);
if (model_->DescriptorHasCodingScheme(long_descriptor))
multi_mtf_.Insert(GetMtfLongIdDescriptor(long_descriptor),
inst_.result_id);
}
if (model_->id_fallback_strategy() ==
MarkvModel::IdFallbackStrategy::kShortDescriptor) {
const uint32_t short_descriptor =
short_id_descriptors_.ProcessInstruction(inst_);
multi_mtf_.Insert(GetMtfShortIdDescriptor(short_descriptor),
inst_.result_id);
}
}
uint64_t MarkvCodec::GetRuleBasedMtf() {
// This function is only called for id operands (but not result ids).
assert(spvIsIdType(operand_.type) ||
operand_.type == SPV_OPERAND_TYPE_OPTIONAL_ID);
assert(operand_.type != SPV_OPERAND_TYPE_RESULT_ID);
const SpvOp opcode = static_cast<SpvOp>(inst_.opcode);
// All operand slots which expect label id.
if ((inst_.opcode == SpvOpLoopMerge && operand_index_ <= 1) ||
(inst_.opcode == SpvOpSelectionMerge && operand_index_ == 0) ||
(inst_.opcode == SpvOpBranch && operand_index_ == 0) ||
(inst_.opcode == SpvOpBranchConditional &&
(operand_index_ == 1 || operand_index_ == 2)) ||
(inst_.opcode == SpvOpPhi && operand_index_ >= 3 &&
operand_index_ % 2 == 1) ||
(inst_.opcode == SpvOpSwitch && operand_index_ > 0)) {
return kMtfLabel;
}
switch (opcode) {
case SpvOpFAdd:
case SpvOpFSub:
case SpvOpFMul:
case SpvOpFDiv:
case SpvOpFRem:
case SpvOpFMod:
case SpvOpFNegate: {
if (operand_index_ == 0) return kMtfTypeFloatScalarOrVector;
return GetMtfIdOfType(inst_.type_id);
}
case SpvOpISub:
case SpvOpIAdd:
case SpvOpIMul:
case SpvOpSDiv:
case SpvOpUDiv:
case SpvOpSMod:
case SpvOpUMod:
case SpvOpSRem:
case SpvOpSNegate: {
if (operand_index_ == 0) return kMtfTypeIntScalarOrVector;
return kMtfIntScalarOrVector;
}
// TODO(atgoo@github.com) Add OpConvertFToU and other opcodes.
case SpvOpFOrdEqual:
case SpvOpFUnordEqual:
case SpvOpFOrdNotEqual:
case SpvOpFUnordNotEqual:
case SpvOpFOrdLessThan:
case SpvOpFUnordLessThan:
case SpvOpFOrdGreaterThan:
case SpvOpFUnordGreaterThan:
case SpvOpFOrdLessThanEqual:
case SpvOpFUnordLessThanEqual:
case SpvOpFOrdGreaterThanEqual:
case SpvOpFUnordGreaterThanEqual: {
if (operand_index_ == 0) return kMtfTypeBoolScalarOrVector;
if (operand_index_ == 2) return kMtfFloatScalarOrVector;
if (operand_index_ == 3) {
const uint32_t first_operand_id = GetInstWords()[3];
const uint32_t first_operand_type = id_to_type_id_.at(first_operand_id);
return GetMtfIdOfType(first_operand_type);
}
break;
}
case SpvOpVectorShuffle: {
if (operand_index_ == 0) {
assert(inst_.num_operands > 4);
return GetMtfTypeVectorOfSize(inst_.num_operands - 4);
}
assert(inst_.type_id);
if (operand_index_ == 2 || operand_index_ == 3)
return GetMtfVectorOfComponentType(
GetVectorComponentType(inst_.type_id));
break;
}
case SpvOpVectorTimesScalar: {
if (operand_index_ == 0) {
// TODO(atgoo@github.com) Could be narrowed to vector of floats.
return GetMtfIdGeneratedByOpcode(SpvOpTypeVector);
}
assert(inst_.type_id);
if (operand_index_ == 2) return GetMtfIdOfType(inst_.type_id);
if (operand_index_ == 3)
return GetMtfIdOfType(GetVectorComponentType(inst_.type_id));
break;
}
case SpvOpDot: {
if (operand_index_ == 0) return GetMtfIdGeneratedByOpcode(SpvOpTypeFloat);
assert(inst_.type_id);
if (operand_index_ == 2)
return GetMtfVectorOfComponentType(inst_.type_id);
if (operand_index_ == 3) {
const uint32_t vector_id = GetInstWords()[3];
const uint32_t vector_type = id_to_type_id_.at(vector_id);
return GetMtfIdOfType(vector_type);
}
break;
}
case SpvOpTypeVector: {
if (operand_index_ == 1) {
return kMtfTypeScalar;
}
break;
}
case SpvOpTypeMatrix: {
if (operand_index_ == 1) {
return GetMtfIdGeneratedByOpcode(SpvOpTypeVector);
}
break;
}
case SpvOpTypePointer: {
if (operand_index_ == 2) {
return kMtfTypeNonFunction;
}
break;
}
case SpvOpTypeStruct: {
if (operand_index_ >= 1) {
return kMtfTypeNonFunction;
}
break;
}
case SpvOpTypeFunction: {
if (operand_index_ == 1) {
return kMtfTypeNonFunction;
}
if (operand_index_ >= 2) {
return kMtfTypeNonFunction;
}
break;
}
case SpvOpLoad: {
if (operand_index_ == 0) return kMtfTypeNonFunction;
if (operand_index_ == 2) {
assert(inst_.type_id);
return GetMtfPointerToType(inst_.type_id);
}
break;
}
case SpvOpStore: {
if (operand_index_ == 0)
return GetMtfIdWithTypeGeneratedByOpcode(SpvOpTypePointer);
if (operand_index_ == 1) {
const uint32_t pointer_id = GetInstWords()[1];
const uint32_t pointer_type = id_to_type_id_.at(pointer_id);
const val::Instruction* pointer_inst = FindDef(pointer_type);
assert(pointer_inst);
assert(pointer_inst->opcode() == SpvOpTypePointer);
const uint32_t data_type =
pointer_inst->word(pointer_inst->operands()[2].offset);
return GetMtfIdOfType(data_type);
}
break;
}
case SpvOpVariable: {
if (operand_index_ == 0)
return GetMtfIdGeneratedByOpcode(SpvOpTypePointer);
break;
}
case SpvOpAccessChain: {
if (operand_index_ == 0)
return GetMtfIdGeneratedByOpcode(SpvOpTypePointer);
if (operand_index_ == 2) return kMtfTypePointerToComposite;
if (operand_index_ >= 3)
return GetMtfIdWithTypeGeneratedByOpcode(SpvOpTypeInt);
break;
}
case SpvOpCompositeConstruct: {
if (operand_index_ == 0) return kMtfTypeComposite;
if (operand_index_ >= 2) {
const uint32_t composite_type = GetInstWords()[1];
if (multi_mtf_.HasValue(kMtfTypeFloatScalarOrVector, composite_type))
return kMtfFloatScalarOrVector;
if (multi_mtf_.HasValue(kMtfTypeIntScalarOrVector, composite_type))
return kMtfIntScalarOrVector;
if (multi_mtf_.HasValue(kMtfTypeBoolScalarOrVector, composite_type))
return kMtfBoolScalarOrVector;
}
break;
}
case SpvOpCompositeExtract: {
if (operand_index_ == 2) return kMtfComposite;
break;
}
case SpvOpConstantComposite: {
if (operand_index_ == 0) return kMtfTypeComposite;
if (operand_index_ >= 2) {
const val::Instruction* composite_type_inst = FindDef(inst_.type_id);
assert(composite_type_inst);
if (composite_type_inst->opcode() == SpvOpTypeVector) {
return GetMtfIdOfType(composite_type_inst->word(2));
}
}
break;
}
case SpvOpExtInst: {
if (operand_index_ == 2)
return GetMtfIdGeneratedByOpcode(SpvOpExtInstImport);
if (operand_index_ >= 4) {
const uint32_t return_type = GetInstWords()[1];
const uint32_t ext_inst_type = inst_.ext_inst_type;
const uint32_t ext_inst_index = GetInstWords()[4];
// TODO(atgoo@github.com) The list of extended instructions is
// incomplete. Only common instructions and low-hanging fruits listed.
if (ext_inst_type == SPV_EXT_INST_TYPE_GLSL_STD_450) {
switch (ext_inst_index) {
case GLSLstd450FAbs:
case GLSLstd450FClamp:
case GLSLstd450FMax:
case GLSLstd450FMin:
case GLSLstd450FMix:
case GLSLstd450Step:
case GLSLstd450SmoothStep:
case GLSLstd450Fma:
case GLSLstd450Pow:
case GLSLstd450Exp:
case GLSLstd450Exp2:
case GLSLstd450Log:
case GLSLstd450Log2:
case GLSLstd450Sqrt:
case GLSLstd450InverseSqrt:
case GLSLstd450Fract:
case GLSLstd450Floor:
case GLSLstd450Ceil:
case GLSLstd450Radians:
case GLSLstd450Degrees:
case GLSLstd450Sin:
case GLSLstd450Cos:
case GLSLstd450Tan:
case GLSLstd450Sinh:
case GLSLstd450Cosh:
case GLSLstd450Tanh:
case GLSLstd450Asin:
case GLSLstd450Acos:
case GLSLstd450Atan:
case GLSLstd450Atan2:
case GLSLstd450Asinh:
case GLSLstd450Acosh:
case GLSLstd450Atanh:
case GLSLstd450MatrixInverse:
case GLSLstd450Cross:
case GLSLstd450Normalize:
case GLSLstd450Reflect:
case GLSLstd450FaceForward:
return GetMtfIdOfType(return_type);
case GLSLstd450Length:
case GLSLstd450Distance:
case GLSLstd450Refract:
return kMtfFloatScalarOrVector;
default:
break;
}
} else if (ext_inst_type == SPV_EXT_INST_TYPE_OPENCL_STD) {
switch (ext_inst_index) {
case OpenCLLIB::Fabs:
case OpenCLLIB::FClamp:
case OpenCLLIB::Fmax:
case OpenCLLIB::Fmin:
case OpenCLLIB::Step:
case OpenCLLIB::Smoothstep:
case OpenCLLIB::Fma:
case OpenCLLIB::Pow:
case OpenCLLIB::Exp:
case OpenCLLIB::Exp2:
case OpenCLLIB::Log:
case OpenCLLIB::Log2:
case OpenCLLIB::Sqrt:
case OpenCLLIB::Rsqrt:
case OpenCLLIB::Fract:
case OpenCLLIB::Floor:
case OpenCLLIB::Ceil:
case OpenCLLIB::Radians:
case OpenCLLIB::Degrees:
case OpenCLLIB::Sin:
case OpenCLLIB::Cos:
case OpenCLLIB::Tan:
case OpenCLLIB::Sinh:
case OpenCLLIB::Cosh:
case OpenCLLIB::Tanh:
case OpenCLLIB::Asin:
case OpenCLLIB::Acos:
case OpenCLLIB::Atan:
case OpenCLLIB::Atan2:
case OpenCLLIB::Asinh:
case OpenCLLIB::Acosh:
case OpenCLLIB::Atanh:
case OpenCLLIB::Cross:
case OpenCLLIB::Normalize:
return GetMtfIdOfType(return_type);
case OpenCLLIB::Length:
case OpenCLLIB::Distance:
return kMtfFloatScalarOrVector;
default:
break;
}
}
}
break;
}
case SpvOpFunction: {
if (operand_index_ == 0) return kMtfTypeReturnedByFunction;
if (operand_index_ == 3) {
const uint32_t return_type = GetInstWords()[1];
return GetMtfFunctionTypeWithReturnType(return_type);
}
break;
}
case SpvOpFunctionCall: {
if (operand_index_ == 0) return kMtfTypeReturnedByFunction;
if (operand_index_ == 2) {
const uint32_t return_type = GetInstWords()[1];
return GetMtfFunctionWithReturnType(return_type);
}
if (operand_index_ >= 3) {
const uint32_t function_id = GetInstWords()[3];
const val::Instruction* function_inst = FindDef(function_id);
if (!function_inst) return kMtfObject;
assert(function_inst->opcode() == SpvOpFunction);
const uint32_t function_type_id = function_inst->word(4);
const val::Instruction* function_type_inst = FindDef(function_type_id);
assert(function_type_inst);
assert(function_type_inst->opcode() == SpvOpTypeFunction);
const uint32_t argument_type = function_type_inst->word(operand_index_);
return GetMtfIdOfType(argument_type);
}
break;
}
case SpvOpReturnValue: {
if (operand_index_ == 0) return GetMtfIdOfType(cur_function_return_type_);
break;
}
case SpvOpBranchConditional: {
if (operand_index_ == 0)
return GetMtfIdWithTypeGeneratedByOpcode(SpvOpTypeBool);
break;
}
case SpvOpSampledImage: {
if (operand_index_ == 0)
return GetMtfIdGeneratedByOpcode(SpvOpTypeSampledImage);
if (operand_index_ == 2)
return GetMtfIdWithTypeGeneratedByOpcode(SpvOpTypeImage);
if (operand_index_ == 3)
return GetMtfIdWithTypeGeneratedByOpcode(SpvOpTypeSampler);
break;
}
case SpvOpImageSampleImplicitLod: {
if (operand_index_ == 0)
return GetMtfIdGeneratedByOpcode(SpvOpTypeVector);
if (operand_index_ == 2)
return GetMtfIdWithTypeGeneratedByOpcode(SpvOpTypeSampledImage);
if (operand_index_ == 3)
return GetMtfIdWithTypeGeneratedByOpcode(SpvOpTypeVector);
break;
}
default:
break;
}
return kMtfNone;
}
} // namespace comp
} // namespace spvtools

View File

@@ -1,337 +0,0 @@
// Copyright (c) 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SOURCE_COMP_MARKV_CODEC_H_
#define SOURCE_COMP_MARKV_CODEC_H_
#include <list>
#include <map>
#include <memory>
#include <vector>
#include "source/assembly_grammar.h"
#include "source/comp/huffman_codec.h"
#include "source/comp/markv_model.h"
#include "source/comp/move_to_front.h"
#include "source/diagnostic.h"
#include "source/id_descriptor.h"
#include "source/val/instruction.h"
// Base class for MARK-V encoder and decoder. Contains common functionality
// such as:
// - Validator connection and validation state.
// - SPIR-V grammar and helper functions.
namespace spvtools {
namespace comp {
class MarkvLogger;
// Handles for move-to-front sequences. Enums which end with "Begin" define
// handle spaces which start at that value and span 16 or 32 bit wide.
enum : uint64_t {
kMtfNone = 0,
// All ids.
kMtfAll,
// All forward declared ids.
kMtfForwardDeclared,
// All type ids except for generated by OpTypeFunction.
kMtfTypeNonFunction,
// All labels.
kMtfLabel,
// All ids created by instructions which had type_id.
kMtfObject,
// All types generated by OpTypeFloat, OpTypeInt, OpTypeBool.
kMtfTypeScalar,
// All composite types.
kMtfTypeComposite,
// Boolean type or any vector type of it.
kMtfTypeBoolScalarOrVector,
// All float types or any vector floats type.
kMtfTypeFloatScalarOrVector,
// All int types or any vector int type.
kMtfTypeIntScalarOrVector,
// All types declared as return types in OpTypeFunction.
kMtfTypeReturnedByFunction,
// All composite objects.
kMtfComposite,
// All bool objects or vectors of bools.
kMtfBoolScalarOrVector,
// All float objects or vectors of float.
kMtfFloatScalarOrVector,
// All int objects or vectors of int.
kMtfIntScalarOrVector,
// All pointer types which point to composited.
kMtfTypePointerToComposite,
// Used by EncodeMtfRankHuffman.
kMtfGenericNonZeroRank,
// Handle space for ids of specific type.
kMtfIdOfTypeBegin = 0x10000,
// Handle space for ids generated by specific opcode.
kMtfIdGeneratedByOpcode = 0x20000,
// Handle space for ids of objects with type generated by specific opcode.
kMtfIdWithTypeGeneratedByOpcodeBegin = 0x30000,
// All vectors of specific component type.
kMtfVectorOfComponentTypeBegin = 0x40000,
// All vector types of specific size.
kMtfTypeVectorOfSizeBegin = 0x50000,
// All pointer types to specific type.
kMtfPointerToTypeBegin = 0x60000,
// All function types which return specific type.
kMtfFunctionTypeWithReturnTypeBegin = 0x70000,
// All function objects which return specific type.
kMtfFunctionWithReturnTypeBegin = 0x80000,
// Short id descriptor space (max 16-bit).
kMtfShortIdDescriptorSpaceBegin = 0x90000,
// Long id descriptor space (32-bit).
kMtfLongIdDescriptorSpaceBegin = 0x100000000,
};
class MarkvCodec {
public:
static const uint32_t kMarkvMagicNumber;
// Mtf ranks smaller than this are encoded with Huffman coding.
static const uint32_t kMtfSmallestRankEncodedByValue;
// Signals that the mtf rank is too large to be encoded with Huffman.
static const uint32_t kMtfRankEncodedByValueSignal;
static const uint32_t kShortDescriptorNumBits;
static const size_t kByteBreakAfterInstIfLessThanUntilNextByte;
static uint32_t GetMarkvVersion();
virtual ~MarkvCodec();
protected:
struct MarkvHeader {
MarkvHeader();
uint32_t magic_number;
uint32_t markv_version;
// Magic number to identify or verify MarkvModel used for encoding.
uint32_t markv_model = 0;
uint32_t markv_length_in_bits = 0;
uint32_t spirv_version = 0;
uint32_t spirv_generator = 0;
};
// |model| is owned by the caller, must be not null and valid during the
// lifetime of the codec.
MarkvCodec(spv_const_context context, spv_validator_options validator_options,
const MarkvModel* model);
// Returns instruction which created |id| or nullptr if such instruction was
// not registered.
const val::Instruction* FindDef(uint32_t id) const {
const auto it = id_to_def_instruction_.find(id);
if (it == id_to_def_instruction_.end()) return nullptr;
return it->second;
}
size_t GetNumBitsToNextByte(size_t bit_pos) const;
bool OpcodeHasFixedNumberOfOperands(SpvOp opcode) const;
// Returns type id of vector type component.
uint32_t GetVectorComponentType(uint32_t vector_type_id) const {
const val::Instruction* type_inst = FindDef(vector_type_id);
assert(type_inst);
assert(type_inst->opcode() == SpvOpTypeVector);
const uint32_t component_type =
type_inst->word(type_inst->operands()[1].offset);
return component_type;
}
// Returns mtf handle for ids of given type.
uint64_t GetMtfIdOfType(uint32_t type_id) const {
return kMtfIdOfTypeBegin + type_id;
}
// Returns mtf handle for ids generated by given opcode.
uint64_t GetMtfIdGeneratedByOpcode(SpvOp opcode) const {
return kMtfIdGeneratedByOpcode + opcode;
}
// Returns mtf handle for ids of type generated by given opcode.
uint64_t GetMtfIdWithTypeGeneratedByOpcode(SpvOp opcode) const {
return kMtfIdWithTypeGeneratedByOpcodeBegin + opcode;
}
// Returns mtf handle for vectors of specific component type.
uint64_t GetMtfVectorOfComponentType(uint32_t type_id) const {
return kMtfVectorOfComponentTypeBegin + type_id;
}
// Returns mtf handle for vector type of specific size.
uint64_t GetMtfTypeVectorOfSize(uint32_t size) const {
return kMtfTypeVectorOfSizeBegin + size;
}
// Returns mtf handle for pointers to specific size.
uint64_t GetMtfPointerToType(uint32_t type_id) const {
return kMtfPointerToTypeBegin + type_id;
}
// Returns mtf handle for function types with given return type.
uint64_t GetMtfFunctionTypeWithReturnType(uint32_t type_id) const {
return kMtfFunctionTypeWithReturnTypeBegin + type_id;
}
// Returns mtf handle for functions with given return type.
uint64_t GetMtfFunctionWithReturnType(uint32_t type_id) const {
return kMtfFunctionWithReturnTypeBegin + type_id;
}
// Returns mtf handle for the given long id descriptor.
uint64_t GetMtfLongIdDescriptor(uint32_t descriptor) const {
return kMtfLongIdDescriptorSpaceBegin + descriptor;
}
// Returns mtf handle for the given short id descriptor.
uint64_t GetMtfShortIdDescriptor(uint32_t descriptor) const {
return kMtfShortIdDescriptorSpaceBegin + descriptor;
}
// Process data from the current instruction. This would update MTFs and
// other data containers.
void ProcessCurInstruction();
// Returns move-to-front handle to be used for the current operand slot.
// Mtf handle is chosen based on a set of rules defined by SPIR-V grammar.
uint64_t GetRuleBasedMtf();
// Returns words of the current instruction. Decoder has a different
// implementation and the array is valid only until the previously decoded
// word.
virtual const uint32_t* GetInstWords() const { return inst_.words; }
// Returns the opcode of the previous instruction.
SpvOp GetPrevOpcode() const {
if (instructions_.empty()) return SpvOpNop;
return instructions_.back()->opcode();
}
// Returns diagnostic stream, position index is set to instruction number.
DiagnosticStream Diag(spv_result_t error_code) const {
return DiagnosticStream({0, 0, instructions_.size()}, context_->consumer,
"", error_code);
}
// Returns current id bound.
uint32_t GetIdBound() const { return id_bound_; }
// Sets current id bound, expected to be no lower than the previous one.
void SetIdBound(uint32_t id_bound) {
assert(id_bound >= id_bound_);
id_bound_ = id_bound;
}
// Returns Huffman codec for ranks of the mtf with given |handle|.
// Different mtfs can use different rank distributions.
// May return nullptr if the codec doesn't exist.
const HuffmanCodec<uint32_t>* GetMtfHuffmanCodec(uint64_t handle) const {
const auto it = mtf_huffman_codecs_.find(handle);
if (it == mtf_huffman_codecs_.end()) return nullptr;
return it->second.get();
}
// Promotes id in all move-to-front sequences if ids can be shared by multiple
// sequences.
void PromoteIfNeeded(uint32_t id) {
if (!model_->AnyDescriptorHasCodingScheme() &&
model_->id_fallback_strategy() ==
MarkvModel::IdFallbackStrategy::kShortDescriptor) {
// Move-to-front sequences do not share ids. Nothing to do.
return;
}
multi_mtf_.Promote(id);
}
spv_validator_options validator_options_ = nullptr;
const AssemblyGrammar grammar_;
MarkvHeader header_;
// MARK-V model, not owned.
const MarkvModel* model_ = nullptr;
// Current instruction, current operand and current operand index.
spv_parsed_instruction_t inst_;
spv_parsed_operand_t operand_;
uint32_t operand_index_;
// Maps a result ID to its type ID. By convention:
// - a result ID that is a type definition maps to itself.
// - a result ID without a type maps to 0. (E.g. for OpLabel)
std::unordered_map<uint32_t, uint32_t> id_to_type_id_;
// Container for all move-to-front sequences.
MultiMoveToFront multi_mtf_;
// Id of the current function or zero if outside of function.
uint32_t cur_function_id_ = 0;
// Return type of the current function.
uint32_t cur_function_return_type_ = 0;
// Remaining function parameter types. This container is filled on OpFunction,
// and drained on OpFunctionParameter.
std::list<uint32_t> remaining_function_parameter_types_;
// List of ids local to the current function.
std::vector<uint32_t> ids_local_to_cur_function_;
// List of instructions in the order they are given in the module.
std::vector<std::unique_ptr<const val::Instruction>> instructions_;
// Container/computer for long (32-bit) id descriptors.
IdDescriptorCollection long_id_descriptors_;
// Container/computer for short id descriptors.
// Short descriptors are stored in uint32_t, but their actual bit width is
// defined with kShortDescriptorNumBits.
// It doesn't seem logical to have a different computer for short id
// descriptors, since one could actually map/truncate long descriptors.
// But as short descriptors have collisions, the efficiency of
// compression depends on the collision pattern, and short descriptors
// produced by function ShortHashU32Array have been empirically proven to
// produce better results.
IdDescriptorCollection short_id_descriptors_;
// Huffman codecs for move-to-front ranks. The map key is mtf handle. Doesn't
// need to contain a different codec for every handle as most use one and the
// same.
std::map<uint64_t, std::unique_ptr<HuffmanCodec<uint32_t>>>
mtf_huffman_codecs_;
// If not nullptr, codec will log comments on the compression process.
std::unique_ptr<MarkvLogger> logger_;
spv_const_context context_ = nullptr;
private:
// Maps result id to the instruction which defined it.
std::unordered_map<uint32_t, const val::Instruction*> id_to_def_instruction_;
uint32_t id_bound_ = 1;
};
} // namespace comp
} // namespace spvtools
#endif // SOURCE_COMP_MARKV_CODEC_H_

View File

@@ -1,925 +0,0 @@
// Copyright (c) 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "source/comp/markv_decoder.h"
#include <cstring>
#include <iterator>
#include <numeric>
#include "source/ext_inst.h"
#include "source/opcode.h"
#include "spirv-tools/libspirv.hpp"
namespace spvtools {
namespace comp {
spv_result_t MarkvDecoder::DecodeNonIdWord(uint32_t* word) {
auto* codec = model_->GetNonIdWordHuffmanCodec(inst_.opcode, operand_index_);
if (codec) {
uint64_t decoded_value = 0;
if (!codec->DecodeFromStream(GetReadBitCallback(), &decoded_value))
return Diag(SPV_ERROR_INVALID_BINARY)
<< "Failed to decode non-id word with Huffman";
if (decoded_value != MarkvModel::GetMarkvNoneOfTheAbove()) {
// The word decoded successfully.
*word = uint32_t(decoded_value);
assert(*word == decoded_value);
return SPV_SUCCESS;
}
// Received kMarkvNoneOfTheAbove signal, use fallback decoding.
}
const size_t chunk_length =
model_->GetOperandVariableWidthChunkLength(operand_.type);
if (chunk_length) {
if (!reader_.ReadVariableWidthU32(word, chunk_length))
return Diag(SPV_ERROR_INVALID_BINARY)
<< "Failed to decode non-id word with varint";
} else {
if (!reader_.ReadUnencoded(word))
return Diag(SPV_ERROR_INVALID_BINARY)
<< "Failed to read unencoded non-id word";
}
return SPV_SUCCESS;
}
spv_result_t MarkvDecoder::DecodeOpcodeAndNumberOfOperands(
uint32_t* opcode, uint32_t* num_operands) {
// First try to use the Markov chain codec.
auto* codec =
model_->GetOpcodeAndNumOperandsMarkovHuffmanCodec(GetPrevOpcode());
if (codec) {
uint64_t decoded_value = 0;
if (!codec->DecodeFromStream(GetReadBitCallback(), &decoded_value))
return Diag(SPV_ERROR_INTERNAL)
<< "Failed to decode opcode_and_num_operands, previous opcode is "
<< spvOpcodeString(GetPrevOpcode());
if (decoded_value != MarkvModel::GetMarkvNoneOfTheAbove()) {
// The word was successfully decoded.
*opcode = uint32_t(decoded_value & 0xFFFF);
*num_operands = uint32_t(decoded_value >> 16);
return SPV_SUCCESS;
}
// Received kMarkvNoneOfTheAbove signal, use fallback decoding.
}
// Fallback to base-rate codec.
codec = model_->GetOpcodeAndNumOperandsMarkovHuffmanCodec(SpvOpNop);
assert(codec);
uint64_t decoded_value = 0;
if (!codec->DecodeFromStream(GetReadBitCallback(), &decoded_value))
return Diag(SPV_ERROR_INTERNAL)
<< "Failed to decode opcode_and_num_operands with global codec";
if (decoded_value == MarkvModel::GetMarkvNoneOfTheAbove()) {
// Received kMarkvNoneOfTheAbove signal, fallback further.
return SPV_UNSUPPORTED;
}
*opcode = uint32_t(decoded_value & 0xFFFF);
*num_operands = uint32_t(decoded_value >> 16);
return SPV_SUCCESS;
}
spv_result_t MarkvDecoder::DecodeMtfRankHuffman(uint64_t mtf,
uint32_t fallback_method,
uint32_t* rank) {
const auto* codec = GetMtfHuffmanCodec(mtf);
if (!codec) {
assert(fallback_method != kMtfNone);
codec = GetMtfHuffmanCodec(fallback_method);
}
if (!codec) return Diag(SPV_ERROR_INTERNAL) << "No codec to decode MTF rank";
uint32_t decoded_value = 0;
if (!codec->DecodeFromStream(GetReadBitCallback(), &decoded_value))
return Diag(SPV_ERROR_INTERNAL) << "Failed to decode MTF rank with Huffman";
if (decoded_value == kMtfRankEncodedByValueSignal) {
// Decode by value.
if (!reader_.ReadVariableWidthU32(rank, model_->mtf_rank_chunk_length()))
return Diag(SPV_ERROR_INTERNAL)
<< "Failed to decode MTF rank with varint";
*rank += MarkvCodec::kMtfSmallestRankEncodedByValue;
} else {
// Decode using Huffman coding.
assert(decoded_value < MarkvCodec::kMtfSmallestRankEncodedByValue);
*rank = decoded_value;
}
return SPV_SUCCESS;
}
spv_result_t MarkvDecoder::DecodeIdWithDescriptor(uint32_t* id) {
auto* codec =
model_->GetIdDescriptorHuffmanCodec(inst_.opcode, operand_index_);
uint64_t mtf = kMtfNone;
if (codec) {
uint64_t decoded_value = 0;
if (!codec->DecodeFromStream(GetReadBitCallback(), &decoded_value))
return Diag(SPV_ERROR_INTERNAL)
<< "Failed to decode descriptor with Huffman";
if (decoded_value != MarkvModel::GetMarkvNoneOfTheAbove()) {
const uint32_t long_descriptor = uint32_t(decoded_value);
mtf = GetMtfLongIdDescriptor(long_descriptor);
}
}
if (mtf == kMtfNone) {
if (model_->id_fallback_strategy() !=
MarkvModel::IdFallbackStrategy::kShortDescriptor) {
return SPV_UNSUPPORTED;
}
uint64_t decoded_value = 0;
if (!reader_.ReadBits(&decoded_value, MarkvCodec::kShortDescriptorNumBits))
return Diag(SPV_ERROR_INTERNAL) << "Failed to read short descriptor";
const uint32_t short_descriptor = uint32_t(decoded_value);
if (short_descriptor == 0) {
// Forward declared id.
return SPV_UNSUPPORTED;
}
mtf = GetMtfShortIdDescriptor(short_descriptor);
}
return DecodeExistingId(mtf, id);
}
spv_result_t MarkvDecoder::DecodeExistingId(uint64_t mtf, uint32_t* id) {
assert(multi_mtf_.GetSize(mtf) > 0);
*id = 0;
uint32_t rank = 0;
if (multi_mtf_.GetSize(mtf) == 1) {
rank = 1;
} else {
const spv_result_t result =
DecodeMtfRankHuffman(mtf, kMtfGenericNonZeroRank, &rank);
if (result != SPV_SUCCESS) return result;
}
assert(rank);
if (!multi_mtf_.ValueFromRank(mtf, rank, id))
return Diag(SPV_ERROR_INTERNAL) << "MTF rank is out of bounds";
return SPV_SUCCESS;
}
spv_result_t MarkvDecoder::DecodeRefId(uint32_t* id) {
{
const spv_result_t result = DecodeIdWithDescriptor(id);
if (result != SPV_UNSUPPORTED) return result;
}
const bool can_forward_declare = spvOperandCanBeForwardDeclaredFunction(
SpvOp(inst_.opcode))(operand_index_);
uint32_t rank = 0;
*id = 0;
if (model_->id_fallback_strategy() ==
MarkvModel::IdFallbackStrategy::kRuleBased) {
uint64_t mtf = GetRuleBasedMtf();
if (mtf != kMtfNone && !can_forward_declare) {
return DecodeExistingId(mtf, id);
}
if (mtf == kMtfNone) mtf = kMtfAll;
{
const spv_result_t result = DecodeMtfRankHuffman(mtf, kMtfAll, &rank);
if (result != SPV_SUCCESS) return result;
}
if (rank == 0) {
// This is the first occurrence of a forward declared id.
*id = GetIdBound();
SetIdBound(*id + 1);
multi_mtf_.Insert(kMtfAll, *id);
multi_mtf_.Insert(kMtfForwardDeclared, *id);
if (mtf != kMtfAll) multi_mtf_.Insert(mtf, *id);
} else {
if (!multi_mtf_.ValueFromRank(mtf, rank, id))
return Diag(SPV_ERROR_INTERNAL) << "MTF rank out of bounds";
}
} else {
assert(can_forward_declare);
if (!reader_.ReadVariableWidthU32(&rank, model_->mtf_rank_chunk_length()))
return Diag(SPV_ERROR_INTERNAL)
<< "Failed to decode MTF rank with varint";
if (rank == 0) {
// This is the first occurrence of a forward declared id.
*id = GetIdBound();
SetIdBound(*id + 1);
multi_mtf_.Insert(kMtfForwardDeclared, *id);
} else {
if (!multi_mtf_.ValueFromRank(kMtfForwardDeclared, rank, id))
return Diag(SPV_ERROR_INTERNAL) << "MTF rank out of bounds";
}
}
assert(*id);
return SPV_SUCCESS;
}
spv_result_t MarkvDecoder::DecodeTypeId() {
if (inst_.opcode == SpvOpFunctionParameter) {
assert(!remaining_function_parameter_types_.empty());
inst_.type_id = remaining_function_parameter_types_.front();
remaining_function_parameter_types_.pop_front();
return SPV_SUCCESS;
}
{
const spv_result_t result = DecodeIdWithDescriptor(&inst_.type_id);
if (result != SPV_UNSUPPORTED) return result;
}
assert(model_->id_fallback_strategy() ==
MarkvModel::IdFallbackStrategy::kRuleBased);
uint64_t mtf = GetRuleBasedMtf();
assert(!spvOperandCanBeForwardDeclaredFunction(SpvOp(inst_.opcode))(
operand_index_));
if (mtf == kMtfNone) {
mtf = kMtfTypeNonFunction;
// Function types should have been handled by GetRuleBasedMtf.
assert(inst_.opcode != SpvOpFunction);
}
return DecodeExistingId(mtf, &inst_.type_id);
}
spv_result_t MarkvDecoder::DecodeResultId() {
uint32_t rank = 0;
const uint64_t num_still_forward_declared =
multi_mtf_.GetSize(kMtfForwardDeclared);
if (num_still_forward_declared) {
// Some ids were forward declared. Check if this id is one of them.
uint64_t id_was_forward_declared;
if (!reader_.ReadBits(&id_was_forward_declared, 1))
return Diag(SPV_ERROR_INVALID_BINARY)
<< "Failed to read id_was_forward_declared flag";
if (id_was_forward_declared) {
if (!reader_.ReadVariableWidthU32(&rank, model_->mtf_rank_chunk_length()))
return Diag(SPV_ERROR_INVALID_BINARY)
<< "Failed to read MTF rank of forward declared id";
if (rank) {
// The id was forward declared, recover it from kMtfForwardDeclared.
if (!multi_mtf_.ValueFromRank(kMtfForwardDeclared, rank,
&inst_.result_id))
return Diag(SPV_ERROR_INTERNAL)
<< "Forward declared MTF rank is out of bounds";
// We can now remove the id from kMtfForwardDeclared.
if (!multi_mtf_.Remove(kMtfForwardDeclared, inst_.result_id))
return Diag(SPV_ERROR_INTERNAL)
<< "Failed to remove id from kMtfForwardDeclared";
}
}
}
if (inst_.result_id == 0) {
// The id was not forward declared, issue a new id.
inst_.result_id = GetIdBound();
SetIdBound(inst_.result_id + 1);
}
if (model_->id_fallback_strategy() ==
MarkvModel::IdFallbackStrategy::kRuleBased) {
if (!rank) {
multi_mtf_.Insert(kMtfAll, inst_.result_id);
}
}
return SPV_SUCCESS;
}
spv_result_t MarkvDecoder::DecodeLiteralNumber(
const spv_parsed_operand_t& operand) {
if (operand.number_bit_width <= 32) {
uint32_t word = 0;
const spv_result_t result = DecodeNonIdWord(&word);
if (result != SPV_SUCCESS) return result;
inst_words_.push_back(word);
} else {
assert(operand.number_bit_width <= 64);
uint64_t word = 0;
if (operand.number_kind == SPV_NUMBER_UNSIGNED_INT) {
if (!reader_.ReadVariableWidthU64(&word, model_->u64_chunk_length()))
return Diag(SPV_ERROR_INVALID_BINARY) << "Failed to read literal U64";
} else if (operand.number_kind == SPV_NUMBER_SIGNED_INT) {
int64_t val = 0;
if (!reader_.ReadVariableWidthS64(&val, model_->s64_chunk_length(),
model_->s64_block_exponent()))
return Diag(SPV_ERROR_INVALID_BINARY) << "Failed to read literal S64";
std::memcpy(&word, &val, 8);
} else if (operand.number_kind == SPV_NUMBER_FLOATING) {
if (!reader_.ReadUnencoded(&word))
return Diag(SPV_ERROR_INVALID_BINARY) << "Failed to read literal F64";
} else {
return Diag(SPV_ERROR_INTERNAL) << "Unsupported bit length";
}
inst_words_.push_back(static_cast<uint32_t>(word));
inst_words_.push_back(static_cast<uint32_t>(word >> 32));
}
return SPV_SUCCESS;
}
bool MarkvDecoder::ReadToByteBreak(size_t byte_break_if_less_than) {
const size_t num_bits_to_next_byte =
GetNumBitsToNextByte(reader_.GetNumReadBits());
if (num_bits_to_next_byte == 0 ||
num_bits_to_next_byte > byte_break_if_less_than)
return true;
uint64_t bits = 0;
if (!reader_.ReadBits(&bits, num_bits_to_next_byte)) return false;
assert(bits == 0);
if (bits != 0) return false;
return true;
}
spv_result_t MarkvDecoder::DecodeModule(std::vector<uint32_t>* spirv_binary) {
const bool header_read_success =
reader_.ReadUnencoded(&header_.magic_number) &&
reader_.ReadUnencoded(&header_.markv_version) &&
reader_.ReadUnencoded(&header_.markv_model) &&
reader_.ReadUnencoded(&header_.markv_length_in_bits) &&
reader_.ReadUnencoded(&header_.spirv_version) &&
reader_.ReadUnencoded(&header_.spirv_generator);
if (!header_read_success)
return Diag(SPV_ERROR_INVALID_BINARY) << "Unable to read MARK-V header";
if (header_.markv_length_in_bits == 0)
return Diag(SPV_ERROR_INVALID_BINARY)
<< "Header markv_length_in_bits field is zero";
if (header_.magic_number != MarkvCodec::kMarkvMagicNumber)
return Diag(SPV_ERROR_INVALID_BINARY)
<< "MARK-V binary has incorrect magic number";
// TODO(atgoo@github.com): Print version strings.
if (header_.markv_version != MarkvCodec::GetMarkvVersion())
return Diag(SPV_ERROR_INVALID_BINARY)
<< "MARK-V binary and the codec have different versions";
const uint32_t model_type = header_.markv_model >> 16;
const uint32_t model_version = header_.markv_model & 0xFFFF;
if (model_type != model_->model_type())
return Diag(SPV_ERROR_INVALID_BINARY)
<< "MARK-V binary and the codec use different MARK-V models";
if (model_version != model_->model_version())
return Diag(SPV_ERROR_INVALID_BINARY)
<< "MARK-V binary and the codec use different versions if the same "
<< "MARK-V model";
spirv_.reserve(header_.markv_length_in_bits / 2); // Heuristic.
spirv_.resize(5, 0);
spirv_[0] = SpvMagicNumber;
spirv_[1] = header_.spirv_version;
spirv_[2] = header_.spirv_generator;
if (logger_) {
reader_.SetCallback(
[this](const std::string& str) { logger_->AppendBitSequence(str); });
}
while (reader_.GetNumReadBits() < header_.markv_length_in_bits) {
inst_ = {};
const spv_result_t decode_result = DecodeInstruction();
if (decode_result != SPV_SUCCESS) return decode_result;
}
if (validator_options_) {
spv_const_binary_t validation_binary = {spirv_.data(), spirv_.size()};
const spv_result_t result = spvValidateWithOptions(
context_, validator_options_, &validation_binary, nullptr);
if (result != SPV_SUCCESS) return result;
}
// Validate the decode binary
if (reader_.GetNumReadBits() != header_.markv_length_in_bits ||
!reader_.OnlyZeroesLeft()) {
return Diag(SPV_ERROR_INVALID_BINARY)
<< "MARK-V binary has wrong stated bit length "
<< reader_.GetNumReadBits() << " " << header_.markv_length_in_bits;
}
// Decoding of the module is finished, validation state should have correct
// id bound.
spirv_[3] = GetIdBound();
*spirv_binary = std::move(spirv_);
return SPV_SUCCESS;
}
// TODO(atgoo@github.com): The implementation borrows heavily from
// Parser::parseOperand.
// Consider coupling them together in some way once MARK-V codec is more mature.
// For now it's better to keep the code independent for experimentation
// purposes.
spv_result_t MarkvDecoder::DecodeOperand(
size_t operand_offset, const spv_operand_type_t type,
spv_operand_pattern_t* expected_operands) {
const SpvOp opcode = static_cast<SpvOp>(inst_.opcode);
memset(&operand_, 0, sizeof(operand_));
assert((operand_offset >> 16) == 0);
operand_.offset = static_cast<uint16_t>(operand_offset);
operand_.type = type;
// Set default values, may be updated later.
operand_.number_kind = SPV_NUMBER_NONE;
operand_.number_bit_width = 0;
const size_t first_word_index = inst_words_.size();
switch (type) {
case SPV_OPERAND_TYPE_RESULT_ID: {
const spv_result_t result = DecodeResultId();
if (result != SPV_SUCCESS) return result;
inst_words_.push_back(inst_.result_id);
SetIdBound(std::max(GetIdBound(), inst_.result_id + 1));
PromoteIfNeeded(inst_.result_id);
break;
}
case SPV_OPERAND_TYPE_TYPE_ID: {
const spv_result_t result = DecodeTypeId();
if (result != SPV_SUCCESS) return result;
inst_words_.push_back(inst_.type_id);
SetIdBound(std::max(GetIdBound(), inst_.type_id + 1));
PromoteIfNeeded(inst_.type_id);
break;
}
case SPV_OPERAND_TYPE_ID:
case SPV_OPERAND_TYPE_OPTIONAL_ID:
case SPV_OPERAND_TYPE_SCOPE_ID:
case SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID: {
uint32_t id = 0;
const spv_result_t result = DecodeRefId(&id);
if (result != SPV_SUCCESS) return result;
if (id == 0) return Diag(SPV_ERROR_INVALID_BINARY) << "Decoded id is 0";
if (type == SPV_OPERAND_TYPE_ID || type == SPV_OPERAND_TYPE_OPTIONAL_ID) {
operand_.type = SPV_OPERAND_TYPE_ID;
if (opcode == SpvOpExtInst && operand_.offset == 3) {
// The current word is the extended instruction set id.
// Set the extended instruction set type for the current
// instruction.
auto ext_inst_type_iter = import_id_to_ext_inst_type_.find(id);
if (ext_inst_type_iter == import_id_to_ext_inst_type_.end()) {
return Diag(SPV_ERROR_INVALID_ID)
<< "OpExtInst set id " << id
<< " does not reference an OpExtInstImport result Id";
}
inst_.ext_inst_type = ext_inst_type_iter->second;
}
}
inst_words_.push_back(id);
SetIdBound(std::max(GetIdBound(), id + 1));
PromoteIfNeeded(id);
break;
}
case SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER: {
uint32_t word = 0;
const spv_result_t result = DecodeNonIdWord(&word);
if (result != SPV_SUCCESS) return result;
inst_words_.push_back(word);
assert(SpvOpExtInst == opcode);
assert(inst_.ext_inst_type != SPV_EXT_INST_TYPE_NONE);
spv_ext_inst_desc ext_inst;
if (grammar_.lookupExtInst(inst_.ext_inst_type, word, &ext_inst))
return Diag(SPV_ERROR_INVALID_BINARY)
<< "Invalid extended instruction number: " << word;
spvPushOperandTypes(ext_inst->operandTypes, expected_operands);
break;
}
case SPV_OPERAND_TYPE_LITERAL_INTEGER:
case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_INTEGER: {
// These are regular single-word literal integer operands.
// Post-parsing validation should check the range of the parsed value.
operand_.type = SPV_OPERAND_TYPE_LITERAL_INTEGER;
// It turns out they are always unsigned integers!
operand_.number_kind = SPV_NUMBER_UNSIGNED_INT;
operand_.number_bit_width = 32;
uint32_t word = 0;
const spv_result_t result = DecodeNonIdWord(&word);
if (result != SPV_SUCCESS) return result;
inst_words_.push_back(word);
break;
}
case SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER:
case SPV_OPERAND_TYPE_OPTIONAL_TYPED_LITERAL_INTEGER: {
operand_.type = SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER;
if (opcode == SpvOpSwitch) {
// The literal operands have the same type as the value
// referenced by the selector Id.
const uint32_t selector_id = inst_words_.at(1);
const auto type_id_iter = id_to_type_id_.find(selector_id);
if (type_id_iter == id_to_type_id_.end() || type_id_iter->second == 0) {
return Diag(SPV_ERROR_INVALID_BINARY)
<< "Invalid OpSwitch: selector id " << selector_id
<< " has no type";
}
uint32_t type_id = type_id_iter->second;
if (selector_id == type_id) {
// Recall that by convention, a result ID that is a type definition
// maps to itself.
return Diag(SPV_ERROR_INVALID_BINARY)
<< "Invalid OpSwitch: selector id " << selector_id
<< " is a type, not a value";
}
if (auto error = SetNumericTypeInfoForType(&operand_, type_id))
return error;
if (operand_.number_kind != SPV_NUMBER_UNSIGNED_INT &&
operand_.number_kind != SPV_NUMBER_SIGNED_INT) {
return Diag(SPV_ERROR_INVALID_BINARY)
<< "Invalid OpSwitch: selector id " << selector_id
<< " is not a scalar integer";
}
} else {
assert(opcode == SpvOpConstant || opcode == SpvOpSpecConstant);
// The literal number type is determined by the type Id for the
// constant.
assert(inst_.type_id);
if (auto error = SetNumericTypeInfoForType(&operand_, inst_.type_id))
return error;
}
if (auto error = DecodeLiteralNumber(operand_)) return error;
break;
}
case SPV_OPERAND_TYPE_LITERAL_STRING:
case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_STRING: {
operand_.type = SPV_OPERAND_TYPE_LITERAL_STRING;
std::vector<char> str;
auto* codec = model_->GetLiteralStringHuffmanCodec(inst_.opcode);
if (codec) {
std::string decoded_string;
const bool huffman_result =
codec->DecodeFromStream(GetReadBitCallback(), &decoded_string);
assert(huffman_result);
if (!huffman_result)
return Diag(SPV_ERROR_INVALID_BINARY)
<< "Failed to read literal string";
if (decoded_string != "kMarkvNoneOfTheAbove") {
std::copy(decoded_string.begin(), decoded_string.end(),
std::back_inserter(str));
str.push_back('\0');
}
}
// The loop is expected to terminate once we encounter '\0' or exhaust
// the bit stream.
if (str.empty()) {
while (true) {
char ch = 0;
if (!reader_.ReadUnencoded(&ch))
return Diag(SPV_ERROR_INVALID_BINARY)
<< "Failed to read literal string";
str.push_back(ch);
if (ch == '\0') break;
}
}
while (str.size() % 4 != 0) str.push_back('\0');
inst_words_.resize(inst_words_.size() + str.size() / 4);
std::memcpy(&inst_words_[first_word_index], str.data(), str.size());
if (SpvOpExtInstImport == opcode) {
// Record the extended instruction type for the ID for this import.
// There is only one string literal argument to OpExtInstImport,
// so it's sufficient to guard this just on the opcode.
const spv_ext_inst_type_t ext_inst_type =
spvExtInstImportTypeGet(str.data());
if (SPV_EXT_INST_TYPE_NONE == ext_inst_type) {
return Diag(SPV_ERROR_INVALID_BINARY)
<< "Invalid extended instruction import '" << str.data()
<< "'";
}
// We must have parsed a valid result ID. It's a condition
// of the grammar, and we only accept non-zero result Ids.
assert(inst_.result_id);
const bool inserted =
import_id_to_ext_inst_type_.emplace(inst_.result_id, ext_inst_type)
.second;
(void)inserted;
assert(inserted);
}
break;
}
case SPV_OPERAND_TYPE_CAPABILITY:
case SPV_OPERAND_TYPE_SOURCE_LANGUAGE:
case SPV_OPERAND_TYPE_EXECUTION_MODEL:
case SPV_OPERAND_TYPE_ADDRESSING_MODEL:
case SPV_OPERAND_TYPE_MEMORY_MODEL:
case SPV_OPERAND_TYPE_EXECUTION_MODE:
case SPV_OPERAND_TYPE_STORAGE_CLASS:
case SPV_OPERAND_TYPE_DIMENSIONALITY:
case SPV_OPERAND_TYPE_SAMPLER_ADDRESSING_MODE:
case SPV_OPERAND_TYPE_SAMPLER_FILTER_MODE:
case SPV_OPERAND_TYPE_SAMPLER_IMAGE_FORMAT:
case SPV_OPERAND_TYPE_FP_ROUNDING_MODE:
case SPV_OPERAND_TYPE_LINKAGE_TYPE:
case SPV_OPERAND_TYPE_ACCESS_QUALIFIER:
case SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER:
case SPV_OPERAND_TYPE_FUNCTION_PARAMETER_ATTRIBUTE:
case SPV_OPERAND_TYPE_DECORATION:
case SPV_OPERAND_TYPE_BUILT_IN:
case SPV_OPERAND_TYPE_GROUP_OPERATION:
case SPV_OPERAND_TYPE_KERNEL_ENQ_FLAGS:
case SPV_OPERAND_TYPE_KERNEL_PROFILING_INFO: {
// A single word that is a plain enum value.
uint32_t word = 0;
const spv_result_t result = DecodeNonIdWord(&word);
if (result != SPV_SUCCESS) return result;
inst_words_.push_back(word);
// Map an optional operand type to its corresponding concrete type.
if (type == SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER)
operand_.type = SPV_OPERAND_TYPE_ACCESS_QUALIFIER;
spv_operand_desc entry;
if (grammar_.lookupOperand(type, word, &entry)) {
return Diag(SPV_ERROR_INVALID_BINARY)
<< "Invalid " << spvOperandTypeStr(operand_.type)
<< " operand: " << word;
}
// Prepare to accept operands to this operand, if needed.
spvPushOperandTypes(entry->operandTypes, expected_operands);
break;
}
case SPV_OPERAND_TYPE_FP_FAST_MATH_MODE:
case SPV_OPERAND_TYPE_FUNCTION_CONTROL:
case SPV_OPERAND_TYPE_LOOP_CONTROL:
case SPV_OPERAND_TYPE_IMAGE:
case SPV_OPERAND_TYPE_OPTIONAL_IMAGE:
case SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS:
case SPV_OPERAND_TYPE_SELECTION_CONTROL: {
// This operand is a mask.
uint32_t word = 0;
const spv_result_t result = DecodeNonIdWord(&word);
if (result != SPV_SUCCESS) return result;
inst_words_.push_back(word);
// Map an optional operand type to its corresponding concrete type.
if (type == SPV_OPERAND_TYPE_OPTIONAL_IMAGE)
operand_.type = SPV_OPERAND_TYPE_IMAGE;
else if (type == SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS)
operand_.type = SPV_OPERAND_TYPE_MEMORY_ACCESS;
// Check validity of set mask bits. Also prepare for operands for those
// masks if they have any. To get operand order correct, scan from
// MSB to LSB since we can only prepend operands to a pattern.
// The only case in the grammar where you have more than one mask bit
// having an operand is for image operands. See SPIR-V 3.14 Image
// Operands.
uint32_t remaining_word = word;
for (uint32_t mask = (1u << 31); remaining_word; mask >>= 1) {
if (remaining_word & mask) {
spv_operand_desc entry;
if (grammar_.lookupOperand(type, mask, &entry)) {
return Diag(SPV_ERROR_INVALID_BINARY)
<< "Invalid " << spvOperandTypeStr(operand_.type)
<< " operand: " << word << " has invalid mask component "
<< mask;
}
remaining_word ^= mask;
spvPushOperandTypes(entry->operandTypes, expected_operands);
}
}
if (word == 0) {
// An all-zeroes mask *might* also be valid.
spv_operand_desc entry;
if (SPV_SUCCESS == grammar_.lookupOperand(type, 0, &entry)) {
// Prepare for its operands, if any.
spvPushOperandTypes(entry->operandTypes, expected_operands);
}
}
break;
}
default:
return Diag(SPV_ERROR_INVALID_BINARY)
<< "Internal error: Unhandled operand type: " << type;
}
operand_.num_words = uint16_t(inst_words_.size() - first_word_index);
assert(spvOperandIsConcrete(operand_.type));
parsed_operands_.push_back(operand_);
return SPV_SUCCESS;
}
spv_result_t MarkvDecoder::DecodeInstruction() {
parsed_operands_.clear();
inst_words_.clear();
// Opcode/num_words placeholder, the word will be filled in later.
inst_words_.push_back(0);
bool num_operands_still_unknown = true;
{
uint32_t opcode = 0;
uint32_t num_operands = 0;
const spv_result_t opcode_decoding_result =
DecodeOpcodeAndNumberOfOperands(&opcode, &num_operands);
if (opcode_decoding_result < 0) return opcode_decoding_result;
if (opcode_decoding_result == SPV_SUCCESS) {
inst_.num_operands = static_cast<uint16_t>(num_operands);
num_operands_still_unknown = false;
} else {
if (!reader_.ReadVariableWidthU32(&opcode,
model_->opcode_chunk_length())) {
return Diag(SPV_ERROR_INVALID_BINARY)
<< "Failed to read opcode of instruction";
}
}
inst_.opcode = static_cast<uint16_t>(opcode);
}
const SpvOp opcode = static_cast<SpvOp>(inst_.opcode);
spv_opcode_desc opcode_desc;
if (grammar_.lookupOpcode(opcode, &opcode_desc) != SPV_SUCCESS) {
return Diag(SPV_ERROR_INVALID_BINARY) << "Invalid opcode";
}
spv_operand_pattern_t expected_operands;
expected_operands.reserve(opcode_desc->numTypes);
for (auto i = 0; i < opcode_desc->numTypes; i++) {
expected_operands.push_back(
opcode_desc->operandTypes[opcode_desc->numTypes - i - 1]);
}
if (num_operands_still_unknown) {
if (!OpcodeHasFixedNumberOfOperands(opcode)) {
if (!reader_.ReadVariableWidthU16(&inst_.num_operands,
model_->num_operands_chunk_length()))
return Diag(SPV_ERROR_INVALID_BINARY)
<< "Failed to read num_operands of instruction";
} else {
inst_.num_operands = static_cast<uint16_t>(expected_operands.size());
}
}
for (operand_index_ = 0;
operand_index_ < static_cast<size_t>(inst_.num_operands);
++operand_index_) {
assert(!expected_operands.empty());
const spv_operand_type_t type =
spvTakeFirstMatchableOperand(&expected_operands);
const size_t operand_offset = inst_words_.size();
const spv_result_t decode_result =
DecodeOperand(operand_offset, type, &expected_operands);
if (decode_result != SPV_SUCCESS) return decode_result;
}
assert(inst_.num_operands == parsed_operands_.size());
// Only valid while inst_words_ and parsed_operands_ remain unchanged (until
// next DecodeInstruction call).
inst_.words = inst_words_.data();
inst_.operands = parsed_operands_.empty() ? nullptr : parsed_operands_.data();
inst_.num_words = static_cast<uint16_t>(inst_words_.size());
inst_words_[0] = spvOpcodeMake(inst_.num_words, SpvOp(inst_.opcode));
std::copy(inst_words_.begin(), inst_words_.end(), std::back_inserter(spirv_));
assert(inst_.num_words ==
std::accumulate(
parsed_operands_.begin(), parsed_operands_.end(), 1,
[](int num_words, const spv_parsed_operand_t& operand) {
return num_words += operand.num_words;
}) &&
"num_words in instruction doesn't correspond to the sum of num_words"
"in the operands");
RecordNumberType();
ProcessCurInstruction();
if (!ReadToByteBreak(MarkvCodec::kByteBreakAfterInstIfLessThanUntilNextByte))
return Diag(SPV_ERROR_INVALID_BINARY) << "Failed to read to byte break";
if (logger_) {
logger_->NewLine();
std::stringstream ss;
ss << spvOpcodeString(opcode) << " ";
for (size_t index = 1; index < inst_words_.size(); ++index)
ss << inst_words_[index] << " ";
logger_->AppendText(ss.str());
logger_->NewLine();
logger_->NewLine();
if (!logger_->DebugInstruction(inst_)) return SPV_REQUESTED_TERMINATION;
}
return SPV_SUCCESS;
}
spv_result_t MarkvDecoder::SetNumericTypeInfoForType(
spv_parsed_operand_t* parsed_operand, uint32_t type_id) {
assert(type_id != 0);
auto type_info_iter = type_id_to_number_type_info_.find(type_id);
if (type_info_iter == type_id_to_number_type_info_.end()) {
return Diag(SPV_ERROR_INVALID_BINARY)
<< "Type Id " << type_id << " is not a type";
}
const NumberType& info = type_info_iter->second;
if (info.type == SPV_NUMBER_NONE) {
// This is a valid type, but for something other than a scalar number.
return Diag(SPV_ERROR_INVALID_BINARY)
<< "Type Id " << type_id << " is not a scalar numeric type";
}
parsed_operand->number_kind = info.type;
parsed_operand->number_bit_width = info.bit_width;
// Round up the word count.
parsed_operand->num_words = static_cast<uint16_t>((info.bit_width + 31) / 32);
return SPV_SUCCESS;
}
void MarkvDecoder::RecordNumberType() {
const SpvOp opcode = static_cast<SpvOp>(inst_.opcode);
if (spvOpcodeGeneratesType(opcode)) {
NumberType info = {SPV_NUMBER_NONE, 0};
if (SpvOpTypeInt == opcode) {
info.bit_width = inst_.words[inst_.operands[1].offset];
info.type = inst_.words[inst_.operands[2].offset]
? SPV_NUMBER_SIGNED_INT
: SPV_NUMBER_UNSIGNED_INT;
} else if (SpvOpTypeFloat == opcode) {
info.bit_width = inst_.words[inst_.operands[1].offset];
info.type = SPV_NUMBER_FLOATING;
}
// The *result* Id of a type generating instruction is the type Id.
type_id_to_number_type_info_[inst_.result_id] = info;
}
}
} // namespace comp
} // namespace spvtools

View File

@@ -1,175 +0,0 @@
// Copyright (c) 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "source/comp/bit_stream.h"
#include "source/comp/markv.h"
#include "source/comp/markv_codec.h"
#include "source/comp/markv_logger.h"
#include "source/util/make_unique.h"
#ifndef SOURCE_COMP_MARKV_DECODER_H_
#define SOURCE_COMP_MARKV_DECODER_H_
namespace spvtools {
namespace comp {
class MarkvLogger;
// Decodes MARK-V buffers written by MarkvEncoder.
class MarkvDecoder : public MarkvCodec {
public:
// |model| is owned by the caller, must be not null and valid during the
// lifetime of MarkvEncoder.
MarkvDecoder(spv_const_context context, const std::vector<uint8_t>& markv,
const MarkvCodecOptions& options, const MarkvModel* model)
: MarkvCodec(context, GetValidatorOptions(options), model),
options_(options),
reader_(markv) {
SetIdBound(1);
parsed_operands_.reserve(25);
inst_words_.reserve(25);
}
~MarkvDecoder() = default;
// Creates an internal logger which writes comments on the decoding process.
void CreateLogger(MarkvLogConsumer log_consumer,
MarkvDebugConsumer debug_consumer) {
logger_ = MakeUnique<MarkvLogger>(log_consumer, debug_consumer);
}
// Decodes SPIR-V from MARK-V and stores the words in |spirv_binary|.
// Can be called only once. Fails if data of wrong format or ends prematurely,
// of if validation fails.
spv_result_t DecodeModule(std::vector<uint32_t>* spirv_binary);
// Creates and returns validator options. Returned value owned by the caller.
static spv_validator_options GetValidatorOptions(
const MarkvCodecOptions& options) {
return options.validate_spirv_binary ? spvValidatorOptionsCreate()
: nullptr;
}
private:
// Describes the format of a typed literal number.
struct NumberType {
spv_number_kind_t type;
uint32_t bit_width;
};
// Reads a single bit from reader_. The read bit is stored in |bit|.
// Returns false iff reader_ fails.
bool ReadBit(bool* bit) {
uint64_t bits = 0;
const bool result = reader_.ReadBits(&bits, 1);
if (result) *bit = bits ? true : false;
return result;
};
// Returns ReadBit bound to the class object.
std::function<bool(bool*)> GetReadBitCallback() {
return std::bind(&MarkvDecoder::ReadBit, this, std::placeholders::_1);
}
// Reads a single non-id word from bit stream. operand_.type determines if
// the word needs to be decoded and how.
spv_result_t DecodeNonIdWord(uint32_t* word);
// Reads and decodes both opcode and num_operands as a single code.
// Returns SPV_UNSUPPORTED iff no suitable codec was found.
spv_result_t DecodeOpcodeAndNumberOfOperands(uint32_t* opcode,
uint32_t* num_operands);
// Reads mtf rank from bit stream. |mtf| is used to determine the codec
// scheme. |fallback_method| is used if no codec defined for |mtf|.
spv_result_t DecodeMtfRankHuffman(uint64_t mtf, uint32_t fallback_method,
uint32_t* rank);
// Reads id using coding based on mtf associated with the id descriptor.
// Returns SPV_UNSUPPORTED iff fallback method needs to be used.
spv_result_t DecodeIdWithDescriptor(uint32_t* id);
// Reads id using coding based on the given |mtf|, which is expected to
// contain the needed |id|.
spv_result_t DecodeExistingId(uint64_t mtf, uint32_t* id);
// Reads type id of the current instruction if can't be inferred.
spv_result_t DecodeTypeId();
// Reads result id of the current instruction if can't be inferred.
spv_result_t DecodeResultId();
// Reads id which is neither type nor result id.
spv_result_t DecodeRefId(uint32_t* id);
// Reads and discards bits until the beginning of the next byte if the
// number of bits until the next byte is less than |byte_break_if_less_than|.
bool ReadToByteBreak(size_t byte_break_if_less_than);
// Returns instruction words decoded up to this point.
const uint32_t* GetInstWords() const override { return inst_words_.data(); }
// Reads a literal number as it is described in |operand| from the bit stream,
// decodes and writes it to spirv_.
spv_result_t DecodeLiteralNumber(const spv_parsed_operand_t& operand);
// Reads instruction from bit stream, decodes and validates it.
// Decoded instruction is valid until the next call of DecodeInstruction().
spv_result_t DecodeInstruction();
// Read operand from the stream decodes and validates it.
spv_result_t DecodeOperand(size_t operand_offset,
const spv_operand_type_t type,
spv_operand_pattern_t* expected_operands);
// Records the numeric type for an operand according to the type information
// associated with the given non-zero type Id. This can fail if the type Id
// is not a type Id, or if the type Id does not reference a scalar numeric
// type. On success, return SPV_SUCCESS and populates the num_words,
// number_kind, and number_bit_width fields of parsed_operand.
spv_result_t SetNumericTypeInfoForType(spv_parsed_operand_t* parsed_operand,
uint32_t type_id);
// Records the number type for the current instruction, if it generates a
// type. For types that aren't scalar numbers, record something with number
// kind SPV_NUMBER_NONE.
void RecordNumberType();
MarkvCodecOptions options_;
// Temporary sink where decoded SPIR-V words are written. Once it contains the
// entire module, the container is moved and returned.
std::vector<uint32_t> spirv_;
// Bit stream containing encoded data.
BitReaderWord64 reader_;
// Temporary storage for operands of the currently parsed instruction.
// Valid until next DecodeInstruction call.
std::vector<spv_parsed_operand_t> parsed_operands_;
// Temporary storage for current instruction words.
// Valid until next DecodeInstruction call.
std::vector<uint32_t> inst_words_;
// Maps a type ID to its number type description.
std::unordered_map<uint32_t, NumberType> type_id_to_number_type_info_;
// Maps an ExtInstImport id to the extended instruction type.
std::unordered_map<uint32_t, spv_ext_inst_type_t> import_id_to_ext_inst_type_;
};
} // namespace comp
} // namespace spvtools
#endif // SOURCE_COMP_MARKV_DECODER_H_

View File

@@ -1,486 +0,0 @@
// Copyright (c) 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "source/comp/markv_encoder.h"
#include "source/binary.h"
#include "source/opcode.h"
#include "spirv-tools/libspirv.hpp"
namespace spvtools {
namespace comp {
namespace {
const size_t kCommentNumWhitespaces = 2;
} // namespace
spv_result_t MarkvEncoder::EncodeNonIdWord(uint32_t word) {
auto* codec = model_->GetNonIdWordHuffmanCodec(inst_.opcode, operand_index_);
if (codec) {
uint64_t bits = 0;
size_t num_bits = 0;
if (codec->Encode(word, &bits, &num_bits)) {
// Encoding successful.
writer_.WriteBits(bits, num_bits);
return SPV_SUCCESS;
} else {
// Encoding failed, write kMarkvNoneOfTheAbove flag.
if (!codec->Encode(MarkvModel::GetMarkvNoneOfTheAbove(), &bits,
&num_bits))
return Diag(SPV_ERROR_INTERNAL)
<< "Non-id word Huffman table for "
<< spvOpcodeString(SpvOp(inst_.opcode)) << " operand index "
<< operand_index_ << " is missing kMarkvNoneOfTheAbove";
writer_.WriteBits(bits, num_bits);
}
}
// Fallback encoding.
const size_t chunk_length =
model_->GetOperandVariableWidthChunkLength(operand_.type);
if (chunk_length) {
writer_.WriteVariableWidthU32(word, chunk_length);
} else {
writer_.WriteUnencoded(word);
}
return SPV_SUCCESS;
}
spv_result_t MarkvEncoder::EncodeOpcodeAndNumOperands(uint32_t opcode,
uint32_t num_operands) {
uint64_t bits = 0;
size_t num_bits = 0;
const uint32_t word = opcode | (num_operands << 16);
// First try to use the Markov chain codec.
auto* codec =
model_->GetOpcodeAndNumOperandsMarkovHuffmanCodec(GetPrevOpcode());
if (codec) {
if (codec->Encode(word, &bits, &num_bits)) {
// The word was successfully encoded into bits/num_bits.
writer_.WriteBits(bits, num_bits);
return SPV_SUCCESS;
} else {
// The word is not in the Huffman table. Write kMarkvNoneOfTheAbove
// and use fallback encoding.
if (!codec->Encode(MarkvModel::GetMarkvNoneOfTheAbove(), &bits,
&num_bits))
return Diag(SPV_ERROR_INTERNAL)
<< "opcode_and_num_operands Huffman table for "
<< spvOpcodeString(GetPrevOpcode())
<< "is missing kMarkvNoneOfTheAbove";
writer_.WriteBits(bits, num_bits);
}
}
// Fallback to base-rate codec.
codec = model_->GetOpcodeAndNumOperandsMarkovHuffmanCodec(SpvOpNop);
assert(codec);
if (codec->Encode(word, &bits, &num_bits)) {
// The word was successfully encoded into bits/num_bits.
writer_.WriteBits(bits, num_bits);
return SPV_SUCCESS;
} else {
// The word is not in the Huffman table. Write kMarkvNoneOfTheAbove
// and return false.
if (!codec->Encode(MarkvModel::GetMarkvNoneOfTheAbove(), &bits, &num_bits))
return Diag(SPV_ERROR_INTERNAL)
<< "Global opcode_and_num_operands Huffman table is missing "
<< "kMarkvNoneOfTheAbove";
writer_.WriteBits(bits, num_bits);
return SPV_UNSUPPORTED;
}
}
spv_result_t MarkvEncoder::EncodeMtfRankHuffman(uint32_t rank, uint64_t mtf,
uint64_t fallback_method) {
const auto* codec = GetMtfHuffmanCodec(mtf);
if (!codec) {
assert(fallback_method != kMtfNone);
codec = GetMtfHuffmanCodec(fallback_method);
}
if (!codec) return Diag(SPV_ERROR_INTERNAL) << "No codec to encode MTF rank";
uint64_t bits = 0;
size_t num_bits = 0;
if (rank < MarkvCodec::kMtfSmallestRankEncodedByValue) {
// Encode using Huffman coding.
if (!codec->Encode(rank, &bits, &num_bits))
return Diag(SPV_ERROR_INTERNAL)
<< "Failed to encode MTF rank with Huffman";
writer_.WriteBits(bits, num_bits);
} else {
// Encode by value.
if (!codec->Encode(MarkvCodec::kMtfRankEncodedByValueSignal, &bits,
&num_bits))
return Diag(SPV_ERROR_INTERNAL)
<< "Failed to encode kMtfRankEncodedByValueSignal";
writer_.WriteBits(bits, num_bits);
writer_.WriteVariableWidthU32(
rank - MarkvCodec::kMtfSmallestRankEncodedByValue,
model_->mtf_rank_chunk_length());
}
return SPV_SUCCESS;
}
spv_result_t MarkvEncoder::EncodeIdWithDescriptor(uint32_t id) {
// Get the descriptor for id.
const uint32_t long_descriptor = long_id_descriptors_.GetDescriptor(id);
auto* codec =
model_->GetIdDescriptorHuffmanCodec(inst_.opcode, operand_index_);
uint64_t bits = 0;
size_t num_bits = 0;
uint64_t mtf = kMtfNone;
if (long_descriptor && codec &&
codec->Encode(long_descriptor, &bits, &num_bits)) {
// If the descriptor exists and is in the table, write the descriptor and
// proceed to encoding the rank.
writer_.WriteBits(bits, num_bits);
mtf = GetMtfLongIdDescriptor(long_descriptor);
} else {
if (codec) {
// The descriptor doesn't exist or we have no coding for it. Write
// kMarkvNoneOfTheAbove and go to fallback method.
if (!codec->Encode(MarkvModel::GetMarkvNoneOfTheAbove(), &bits,
&num_bits))
return Diag(SPV_ERROR_INTERNAL)
<< "Descriptor Huffman table for "
<< spvOpcodeString(SpvOp(inst_.opcode)) << " operand index "
<< operand_index_ << " is missing kMarkvNoneOfTheAbove";
writer_.WriteBits(bits, num_bits);
}
if (model_->id_fallback_strategy() !=
MarkvModel::IdFallbackStrategy::kShortDescriptor) {
return SPV_UNSUPPORTED;
}
const uint32_t short_descriptor = short_id_descriptors_.GetDescriptor(id);
writer_.WriteBits(short_descriptor, MarkvCodec::kShortDescriptorNumBits);
if (short_descriptor == 0) {
// Forward declared id.
return SPV_UNSUPPORTED;
}
mtf = GetMtfShortIdDescriptor(short_descriptor);
}
// Descriptor has been encoded. Now encode the rank of the id in the
// associated mtf sequence.
return EncodeExistingId(mtf, id);
}
spv_result_t MarkvEncoder::EncodeExistingId(uint64_t mtf, uint32_t id) {
assert(multi_mtf_.GetSize(mtf) > 0);
if (multi_mtf_.GetSize(mtf) == 1) {
// If the sequence has only one element no need to write rank, the decoder
// would make the same decision.
return SPV_SUCCESS;
}
uint32_t rank = 0;
if (!multi_mtf_.RankFromValue(mtf, id, &rank))
return Diag(SPV_ERROR_INTERNAL) << "Id is not in the MTF sequence";
return EncodeMtfRankHuffman(rank, mtf, kMtfGenericNonZeroRank);
}
spv_result_t MarkvEncoder::EncodeRefId(uint32_t id) {
{
// Try to encode using id descriptor mtfs.
const spv_result_t result = EncodeIdWithDescriptor(id);
if (result != SPV_UNSUPPORTED) return result;
// If can't be done continue with other methods.
}
const bool can_forward_declare = spvOperandCanBeForwardDeclaredFunction(
SpvOp(inst_.opcode))(operand_index_);
uint32_t rank = 0;
if (model_->id_fallback_strategy() ==
MarkvModel::IdFallbackStrategy::kRuleBased) {
// Encode using rule-based mtf.
uint64_t mtf = GetRuleBasedMtf();
if (mtf != kMtfNone && !can_forward_declare) {
assert(multi_mtf_.HasValue(kMtfAll, id));
return EncodeExistingId(mtf, id);
}
if (mtf == kMtfNone) mtf = kMtfAll;
if (!multi_mtf_.RankFromValue(mtf, id, &rank)) {
// This is the first occurrence of a forward declared id.
multi_mtf_.Insert(kMtfAll, id);
multi_mtf_.Insert(kMtfForwardDeclared, id);
if (mtf != kMtfAll) multi_mtf_.Insert(mtf, id);
rank = 0;
}
return EncodeMtfRankHuffman(rank, mtf, kMtfAll);
} else {
assert(can_forward_declare);
if (!multi_mtf_.RankFromValue(kMtfForwardDeclared, id, &rank)) {
// This is the first occurrence of a forward declared id.
multi_mtf_.Insert(kMtfForwardDeclared, id);
rank = 0;
}
writer_.WriteVariableWidthU32(rank, model_->mtf_rank_chunk_length());
return SPV_SUCCESS;
}
}
spv_result_t MarkvEncoder::EncodeTypeId() {
if (inst_.opcode == SpvOpFunctionParameter) {
assert(!remaining_function_parameter_types_.empty());
assert(inst_.type_id == remaining_function_parameter_types_.front());
remaining_function_parameter_types_.pop_front();
return SPV_SUCCESS;
}
{
// Try to encode using id descriptor mtfs.
const spv_result_t result = EncodeIdWithDescriptor(inst_.type_id);
if (result != SPV_UNSUPPORTED) return result;
// If can't be done continue with other methods.
}
assert(model_->id_fallback_strategy() ==
MarkvModel::IdFallbackStrategy::kRuleBased);
uint64_t mtf = GetRuleBasedMtf();
assert(!spvOperandCanBeForwardDeclaredFunction(SpvOp(inst_.opcode))(
operand_index_));
if (mtf == kMtfNone) {
mtf = kMtfTypeNonFunction;
// Function types should have been handled by GetRuleBasedMtf.
assert(inst_.opcode != SpvOpFunction);
}
return EncodeExistingId(mtf, inst_.type_id);
}
spv_result_t MarkvEncoder::EncodeResultId() {
uint32_t rank = 0;
const uint64_t num_still_forward_declared =
multi_mtf_.GetSize(kMtfForwardDeclared);
if (num_still_forward_declared) {
// We write the rank only if kMtfForwardDeclared is not empty. If it is
// empty the decoder knows that there are no forward declared ids to expect.
if (multi_mtf_.RankFromValue(kMtfForwardDeclared, inst_.result_id, &rank)) {
// This is a definition of a forward declared id. We can remove the id
// from kMtfForwardDeclared.
if (!multi_mtf_.Remove(kMtfForwardDeclared, inst_.result_id))
return Diag(SPV_ERROR_INTERNAL)
<< "Failed to remove id from kMtfForwardDeclared";
writer_.WriteBits(1, 1);
writer_.WriteVariableWidthU32(rank, model_->mtf_rank_chunk_length());
} else {
rank = 0;
writer_.WriteBits(0, 1);
}
}
if (model_->id_fallback_strategy() ==
MarkvModel::IdFallbackStrategy::kRuleBased) {
if (!rank) {
multi_mtf_.Insert(kMtfAll, inst_.result_id);
}
}
return SPV_SUCCESS;
}
spv_result_t MarkvEncoder::EncodeLiteralNumber(
const spv_parsed_operand_t& operand) {
if (operand.number_bit_width <= 32) {
const uint32_t word = inst_.words[operand.offset];
return EncodeNonIdWord(word);
} else {
assert(operand.number_bit_width <= 64);
const uint64_t word = uint64_t(inst_.words[operand.offset]) |
(uint64_t(inst_.words[operand.offset + 1]) << 32);
if (operand.number_kind == SPV_NUMBER_UNSIGNED_INT) {
writer_.WriteVariableWidthU64(word, model_->u64_chunk_length());
} else if (operand.number_kind == SPV_NUMBER_SIGNED_INT) {
int64_t val = 0;
std::memcpy(&val, &word, 8);
writer_.WriteVariableWidthS64(val, model_->s64_chunk_length(),
model_->s64_block_exponent());
} else if (operand.number_kind == SPV_NUMBER_FLOATING) {
writer_.WriteUnencoded(word);
} else {
return Diag(SPV_ERROR_INTERNAL) << "Unsupported bit length";
}
}
return SPV_SUCCESS;
}
void MarkvEncoder::AddByteBreak(size_t byte_break_if_less_than) {
const size_t num_bits_to_next_byte =
GetNumBitsToNextByte(writer_.GetNumBits());
if (num_bits_to_next_byte == 0 ||
num_bits_to_next_byte > byte_break_if_less_than)
return;
if (logger_) {
logger_->AppendWhitespaces(kCommentNumWhitespaces);
logger_->AppendText("<byte break>");
}
writer_.WriteBits(0, num_bits_to_next_byte);
}
spv_result_t MarkvEncoder::EncodeInstruction(
const spv_parsed_instruction_t& inst) {
SpvOp opcode = SpvOp(inst.opcode);
inst_ = inst;
LogDisassemblyInstruction();
const spv_result_t opcode_encodig_result =
EncodeOpcodeAndNumOperands(opcode, inst.num_operands);
if (opcode_encodig_result < 0) return opcode_encodig_result;
if (opcode_encodig_result != SPV_SUCCESS) {
// Fallback encoding for opcode and num_operands.
writer_.WriteVariableWidthU32(opcode, model_->opcode_chunk_length());
if (!OpcodeHasFixedNumberOfOperands(opcode)) {
// If the opcode has a variable number of operands, encode the number of
// operands with the instruction.
if (logger_) logger_->AppendWhitespaces(kCommentNumWhitespaces);
writer_.WriteVariableWidthU16(inst.num_operands,
model_->num_operands_chunk_length());
}
}
// Write operands.
const uint32_t num_operands = inst_.num_operands;
for (operand_index_ = 0; operand_index_ < num_operands; ++operand_index_) {
operand_ = inst_.operands[operand_index_];
if (logger_) {
logger_->AppendWhitespaces(kCommentNumWhitespaces);
logger_->AppendText("<");
logger_->AppendText(spvOperandTypeStr(operand_.type));
logger_->AppendText(">");
}
switch (operand_.type) {
case SPV_OPERAND_TYPE_RESULT_ID:
case SPV_OPERAND_TYPE_TYPE_ID:
case SPV_OPERAND_TYPE_ID:
case SPV_OPERAND_TYPE_OPTIONAL_ID:
case SPV_OPERAND_TYPE_SCOPE_ID:
case SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID: {
const uint32_t id = inst_.words[operand_.offset];
if (operand_.type == SPV_OPERAND_TYPE_TYPE_ID) {
const spv_result_t result = EncodeTypeId();
if (result != SPV_SUCCESS) return result;
} else if (operand_.type == SPV_OPERAND_TYPE_RESULT_ID) {
const spv_result_t result = EncodeResultId();
if (result != SPV_SUCCESS) return result;
} else {
const spv_result_t result = EncodeRefId(id);
if (result != SPV_SUCCESS) return result;
}
PromoteIfNeeded(id);
break;
}
case SPV_OPERAND_TYPE_LITERAL_INTEGER: {
const spv_result_t result =
EncodeNonIdWord(inst_.words[operand_.offset]);
if (result != SPV_SUCCESS) return result;
break;
}
case SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER: {
const spv_result_t result = EncodeLiteralNumber(operand_);
if (result != SPV_SUCCESS) return result;
break;
}
case SPV_OPERAND_TYPE_LITERAL_STRING: {
const char* src =
reinterpret_cast<const char*>(&inst_.words[operand_.offset]);
auto* codec = model_->GetLiteralStringHuffmanCodec(opcode);
if (codec) {
uint64_t bits = 0;
size_t num_bits = 0;
const std::string str = src;
if (codec->Encode(str, &bits, &num_bits)) {
writer_.WriteBits(bits, num_bits);
break;
} else {
bool result =
codec->Encode("kMarkvNoneOfTheAbove", &bits, &num_bits);
(void)result;
assert(result);
writer_.WriteBits(bits, num_bits);
}
}
const size_t length = spv_strnlen_s(src, operand_.num_words * 4);
if (length == operand_.num_words * 4)
return Diag(SPV_ERROR_INVALID_BINARY)
<< "Failed to find terminal character of literal string";
for (size_t i = 0; i < length + 1; ++i) writer_.WriteUnencoded(src[i]);
break;
}
default: {
for (int i = 0; i < operand_.num_words; ++i) {
const uint32_t word = inst_.words[operand_.offset + i];
const spv_result_t result = EncodeNonIdWord(word);
if (result != SPV_SUCCESS) return result;
}
break;
}
}
}
AddByteBreak(MarkvCodec::kByteBreakAfterInstIfLessThanUntilNextByte);
if (logger_) {
logger_->NewLine();
logger_->NewLine();
if (!logger_->DebugInstruction(inst_)) return SPV_REQUESTED_TERMINATION;
}
ProcessCurInstruction();
return SPV_SUCCESS;
}
} // namespace comp
} // namespace spvtools

View File

@@ -1,167 +0,0 @@
// Copyright (c) 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "source/comp/bit_stream.h"
#include "source/comp/markv.h"
#include "source/comp/markv_codec.h"
#include "source/comp/markv_logger.h"
#include "source/util/make_unique.h"
#ifndef SOURCE_COMP_MARKV_ENCODER_H_
#define SOURCE_COMP_MARKV_ENCODER_H_
#include <cstring>
namespace spvtools {
namespace comp {
// SPIR-V to MARK-V encoder. Exposes functions EncodeHeader and
// EncodeInstruction which can be used as callback by spvBinaryParse.
// Encoded binary is written to an internally maintained bitstream.
// After the last instruction is encoded, the resulting MARK-V binary can be
// acquired by calling GetMarkvBinary().
//
// The encoder uses SPIR-V validator to keep internal state, therefore
// SPIR-V binary needs to be able to pass validator checks.
// CreateCommentsLogger() can be used to enable the encoder to write comments
// on how encoding was done, which can later be accessed with GetComments().
class MarkvEncoder : public MarkvCodec {
public:
// |model| is owned by the caller, must be not null and valid during the
// lifetime of MarkvEncoder.
MarkvEncoder(spv_const_context context, const MarkvCodecOptions& options,
const MarkvModel* model)
: MarkvCodec(context, GetValidatorOptions(options), model),
options_(options) {}
~MarkvEncoder() override = default;
// Writes data from SPIR-V header to MARK-V header.
spv_result_t EncodeHeader(spv_endianness_t /* endian */, uint32_t /* magic */,
uint32_t version, uint32_t generator,
uint32_t id_bound, uint32_t /* schema */) {
SetIdBound(id_bound);
header_.spirv_version = version;
header_.spirv_generator = generator;
return SPV_SUCCESS;
}
// Creates an internal logger which writes comments on the encoding process.
void CreateLogger(MarkvLogConsumer log_consumer,
MarkvDebugConsumer debug_consumer) {
logger_ = MakeUnique<MarkvLogger>(log_consumer, debug_consumer);
writer_.SetCallback(
[this](const std::string& str) { logger_->AppendBitSequence(str); });
}
// Encodes SPIR-V instruction to MARK-V and writes to bit stream.
// Operation can fail if the instruction fails to pass the validator or if
// the encoder stubmles on something unexpected.
spv_result_t EncodeInstruction(const spv_parsed_instruction_t& inst);
// Concatenates MARK-V header and the bit stream with encoded instructions
// into a single buffer and returns it as spv_markv_binary. The returned
// value is owned by the caller and needs to be destroyed with
// spvMarkvBinaryDestroy().
std::vector<uint8_t> GetMarkvBinary() {
header_.markv_length_in_bits =
static_cast<uint32_t>(sizeof(header_) * 8 + writer_.GetNumBits());
header_.markv_model =
(model_->model_type() << 16) | model_->model_version();
const size_t num_bytes = sizeof(header_) + writer_.GetDataSizeBytes();
std::vector<uint8_t> markv(num_bytes);
assert(writer_.GetData());
std::memcpy(markv.data(), &header_, sizeof(header_));
std::memcpy(markv.data() + sizeof(header_), writer_.GetData(),
writer_.GetDataSizeBytes());
return markv;
}
// Optionally adds disassembly to the comments.
// Disassembly should contain all instructions in the module separated by
// \n, and no header.
void SetDisassembly(std::string&& disassembly) {
disassembly_ = MakeUnique<std::stringstream>(std::move(disassembly));
}
// Extracts the next instruction line from the disassembly and logs it.
void LogDisassemblyInstruction() {
if (logger_ && disassembly_) {
std::string line;
std::getline(*disassembly_, line, '\n');
logger_->AppendTextNewLine(line);
}
}
private:
// Creates and returns validator options. Returned value owned by the caller.
static spv_validator_options GetValidatorOptions(
const MarkvCodecOptions& options) {
return options.validate_spirv_binary ? spvValidatorOptionsCreate()
: nullptr;
}
// Writes a single word to bit stream. operand_.type determines if the word is
// encoded and how.
spv_result_t EncodeNonIdWord(uint32_t word);
// Writes both opcode and num_operands as a single code.
// Returns SPV_UNSUPPORTED iff no suitable codec was found.
spv_result_t EncodeOpcodeAndNumOperands(uint32_t opcode,
uint32_t num_operands);
// Writes mtf rank to bit stream. |mtf| is used to determine the codec
// scheme. |fallback_method| is used if no codec defined for |mtf|.
spv_result_t EncodeMtfRankHuffman(uint32_t rank, uint64_t mtf,
uint64_t fallback_method);
// Writes id using coding based on mtf associated with the id descriptor.
// Returns SPV_UNSUPPORTED iff fallback method needs to be used.
spv_result_t EncodeIdWithDescriptor(uint32_t id);
// Writes id using coding based on the given |mtf|, which is expected to
// contain the given |id|.
spv_result_t EncodeExistingId(uint64_t mtf, uint32_t id);
// Writes type id of the current instruction if can't be inferred.
spv_result_t EncodeTypeId();
// Writes result id of the current instruction if can't be inferred.
spv_result_t EncodeResultId();
// Writes ids which are neither type nor result ids.
spv_result_t EncodeRefId(uint32_t id);
// Writes bits to the stream until the beginning of the next byte if the
// number of bits until the next byte is less than |byte_break_if_less_than|.
void AddByteBreak(size_t byte_break_if_less_than);
// Encodes a literal number operand and writes it to the bit stream.
spv_result_t EncodeLiteralNumber(const spv_parsed_operand_t& operand);
MarkvCodecOptions options_;
// Bit stream where encoded instructions are written.
BitWriterWord64 writer_;
// If not nullptr, disassembled instruction lines will be written to comments.
// Format: \n separated instruction lines, no header.
std::unique_ptr<std::stringstream> disassembly_;
};
} // namespace comp
} // namespace spvtools
#endif // SOURCE_COMP_MARKV_ENCODER_H_

View File

@@ -1,93 +0,0 @@
// Copyright (c) 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SOURCE_COMP_MARKV_LOGGER_H_
#define SOURCE_COMP_MARKV_LOGGER_H_
#include "source/comp/markv.h"
namespace spvtools {
namespace comp {
class MarkvLogger {
public:
MarkvLogger(MarkvLogConsumer log_consumer, MarkvDebugConsumer debug_consumer)
: log_consumer_(log_consumer), debug_consumer_(debug_consumer) {}
void AppendText(const std::string& str) {
Append(str);
use_delimiter_ = false;
}
void AppendTextNewLine(const std::string& str) {
Append(str);
Append("\n");
use_delimiter_ = false;
}
void AppendBitSequence(const std::string& str) {
if (debug_consumer_) instruction_bits_ << str;
if (use_delimiter_) Append("-");
Append(str);
use_delimiter_ = true;
}
void AppendWhitespaces(size_t num) {
Append(std::string(num, ' '));
use_delimiter_ = false;
}
void NewLine() {
Append("\n");
use_delimiter_ = false;
}
bool DebugInstruction(const spv_parsed_instruction_t& inst) {
bool result = true;
if (debug_consumer_) {
result = debug_consumer_(
std::vector<uint32_t>(inst.words, inst.words + inst.num_words),
instruction_bits_.str(), instruction_comment_.str());
instruction_bits_.str(std::string());
instruction_comment_.str(std::string());
}
return result;
}
private:
MarkvLogger(const MarkvLogger&) = delete;
MarkvLogger(MarkvLogger&&) = delete;
MarkvLogger& operator=(const MarkvLogger&) = delete;
MarkvLogger& operator=(MarkvLogger&&) = delete;
void Append(const std::string& str) {
if (log_consumer_) log_consumer_(str);
if (debug_consumer_) instruction_comment_ << str;
}
MarkvLogConsumer log_consumer_;
MarkvDebugConsumer debug_consumer_;
std::stringstream instruction_bits_;
std::stringstream instruction_comment_;
// If true a delimiter will be appended before the next bit sequence.
// Used to generate outputs like: 1100-0 1110-1-1100-1-1111-0 110-0.
bool use_delimiter_ = false;
};
} // namespace comp
} // namespace spvtools
#endif // SOURCE_COMP_MARKV_LOGGER_H_

View File

@@ -1,232 +0,0 @@
// Copyright (c) 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SOURCE_COMP_MARKV_MODEL_H_
#define SOURCE_COMP_MARKV_MODEL_H_
#include <unordered_set>
#include "source/comp/huffman_codec.h"
#include "source/latest_version_spirv_header.h"
#include "spirv-tools/libspirv.hpp"
namespace spvtools {
namespace comp {
// Base class for MARK-V models.
// The class contains encoding/decoding model with various constants and
// codecs used by the compression algorithm.
class MarkvModel {
public:
MarkvModel()
: operand_chunk_lengths_(
static_cast<size_t>(SPV_OPERAND_TYPE_NUM_OPERAND_TYPES), 0) {
// Set default values.
operand_chunk_lengths_[SPV_OPERAND_TYPE_TYPE_ID] = 4;
operand_chunk_lengths_[SPV_OPERAND_TYPE_RESULT_ID] = 8;
operand_chunk_lengths_[SPV_OPERAND_TYPE_ID] = 8;
operand_chunk_lengths_[SPV_OPERAND_TYPE_SCOPE_ID] = 8;
operand_chunk_lengths_[SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID] = 8;
operand_chunk_lengths_[SPV_OPERAND_TYPE_LITERAL_INTEGER] = 6;
operand_chunk_lengths_[SPV_OPERAND_TYPE_OPTIONAL_LITERAL_INTEGER] = 6;
operand_chunk_lengths_[SPV_OPERAND_TYPE_CAPABILITY] = 6;
operand_chunk_lengths_[SPV_OPERAND_TYPE_SOURCE_LANGUAGE] = 3;
operand_chunk_lengths_[SPV_OPERAND_TYPE_EXECUTION_MODEL] = 3;
operand_chunk_lengths_[SPV_OPERAND_TYPE_ADDRESSING_MODEL] = 2;
operand_chunk_lengths_[SPV_OPERAND_TYPE_MEMORY_MODEL] = 2;
operand_chunk_lengths_[SPV_OPERAND_TYPE_EXECUTION_MODE] = 6;
operand_chunk_lengths_[SPV_OPERAND_TYPE_STORAGE_CLASS] = 4;
operand_chunk_lengths_[SPV_OPERAND_TYPE_DIMENSIONALITY] = 3;
operand_chunk_lengths_[SPV_OPERAND_TYPE_SAMPLER_ADDRESSING_MODE] = 3;
operand_chunk_lengths_[SPV_OPERAND_TYPE_SAMPLER_FILTER_MODE] = 2;
operand_chunk_lengths_[SPV_OPERAND_TYPE_SAMPLER_IMAGE_FORMAT] = 6;
operand_chunk_lengths_[SPV_OPERAND_TYPE_FP_ROUNDING_MODE] = 2;
operand_chunk_lengths_[SPV_OPERAND_TYPE_LINKAGE_TYPE] = 2;
operand_chunk_lengths_[SPV_OPERAND_TYPE_ACCESS_QUALIFIER] = 2;
operand_chunk_lengths_[SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER] = 2;
operand_chunk_lengths_[SPV_OPERAND_TYPE_FUNCTION_PARAMETER_ATTRIBUTE] = 3;
operand_chunk_lengths_[SPV_OPERAND_TYPE_DECORATION] = 6;
operand_chunk_lengths_[SPV_OPERAND_TYPE_BUILT_IN] = 6;
operand_chunk_lengths_[SPV_OPERAND_TYPE_GROUP_OPERATION] = 2;
operand_chunk_lengths_[SPV_OPERAND_TYPE_KERNEL_ENQ_FLAGS] = 2;
operand_chunk_lengths_[SPV_OPERAND_TYPE_KERNEL_PROFILING_INFO] = 2;
operand_chunk_lengths_[SPV_OPERAND_TYPE_FP_FAST_MATH_MODE] = 4;
operand_chunk_lengths_[SPV_OPERAND_TYPE_FUNCTION_CONTROL] = 4;
operand_chunk_lengths_[SPV_OPERAND_TYPE_LOOP_CONTROL] = 4;
operand_chunk_lengths_[SPV_OPERAND_TYPE_IMAGE] = 4;
operand_chunk_lengths_[SPV_OPERAND_TYPE_OPTIONAL_IMAGE] = 4;
operand_chunk_lengths_[SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS] = 4;
operand_chunk_lengths_[SPV_OPERAND_TYPE_SELECTION_CONTROL] = 4;
operand_chunk_lengths_[SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER] = 6;
operand_chunk_lengths_[SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER] = 6;
}
uint32_t model_type() const { return model_type_; }
uint32_t model_version() const { return model_version_; }
uint32_t opcode_chunk_length() const { return opcode_chunk_length_; }
uint32_t num_operands_chunk_length() const {
return num_operands_chunk_length_;
}
uint32_t mtf_rank_chunk_length() const { return mtf_rank_chunk_length_; }
uint32_t u64_chunk_length() const { return u64_chunk_length_; }
uint32_t s64_chunk_length() const { return s64_chunk_length_; }
uint32_t s64_block_exponent() const { return s64_block_exponent_; }
enum class IdFallbackStrategy {
kRuleBased = 0,
kShortDescriptor,
};
IdFallbackStrategy id_fallback_strategy() const {
return id_fallback_strategy_;
}
// Returns a codec for common opcode_and_num_operands words for the given
// previous opcode. May return nullptr if the codec doesn't exist.
const HuffmanCodec<uint64_t>* GetOpcodeAndNumOperandsMarkovHuffmanCodec(
uint32_t prev_opcode) const {
if (prev_opcode == SpvOpNop)
return opcode_and_num_operands_huffman_codec_.get();
const auto it =
opcode_and_num_operands_markov_huffman_codecs_.find(prev_opcode);
if (it == opcode_and_num_operands_markov_huffman_codecs_.end())
return nullptr;
return it->second.get();
}
// Returns a codec for common non-id words used for given operand slot.
// Operand slot is defined by the opcode and the operand index.
// May return nullptr if the codec doesn't exist.
const HuffmanCodec<uint64_t>* GetNonIdWordHuffmanCodec(
uint32_t opcode, uint32_t operand_index) const {
const auto it = non_id_word_huffman_codecs_.find(
std::pair<uint32_t, uint32_t>(opcode, operand_index));
if (it == non_id_word_huffman_codecs_.end()) return nullptr;
return it->second.get();
}
// Returns a codec for common id descriptos used for given operand slot.
// Operand slot is defined by the opcode and the operand index.
// May return nullptr if the codec doesn't exist.
const HuffmanCodec<uint64_t>* GetIdDescriptorHuffmanCodec(
uint32_t opcode, uint32_t operand_index) const {
const auto it = id_descriptor_huffman_codecs_.find(
std::pair<uint32_t, uint32_t>(opcode, operand_index));
if (it == id_descriptor_huffman_codecs_.end()) return nullptr;
return it->second.get();
}
// Returns a codec for common strings used by the given opcode.
// Operand slot is defined by the opcode and the operand index.
// May return nullptr if the codec doesn't exist.
const HuffmanCodec<std::string>* GetLiteralStringHuffmanCodec(
uint32_t opcode) const {
const auto it = literal_string_huffman_codecs_.find(opcode);
if (it == literal_string_huffman_codecs_.end()) return nullptr;
return it->second.get();
}
// Checks if |descriptor| has a coding scheme in any of
// id_descriptor_huffman_codecs_.
bool DescriptorHasCodingScheme(uint32_t descriptor) const {
return descriptors_with_coding_scheme_.count(descriptor);
}
// Checks if any descriptor has a coding scheme.
bool AnyDescriptorHasCodingScheme() const {
return !descriptors_with_coding_scheme_.empty();
}
// Returns chunk length used for variable length encoding of spirv operand
// words.
uint32_t GetOperandVariableWidthChunkLength(spv_operand_type_t type) const {
return operand_chunk_lengths_.at(static_cast<size_t>(type));
}
// Sets model type.
void SetModelType(uint32_t in_model_type) { model_type_ = in_model_type; }
// Sets model version.
void SetModelVersion(uint32_t in_model_version) {
model_version_ = in_model_version;
}
// Returns value used by Huffman codecs as a signal that a value is not in the
// coding table.
static uint64_t GetMarkvNoneOfTheAbove() {
// Magic number.
return 1111111111111111111;
}
MarkvModel(const MarkvModel&) = delete;
const MarkvModel& operator=(const MarkvModel&) = delete;
protected:
// Huffman codec for base-rate of opcode_and_num_operands.
std::unique_ptr<HuffmanCodec<uint64_t>>
opcode_and_num_operands_huffman_codec_;
// Huffman codecs for opcode_and_num_operands. The map key is previous opcode.
std::map<uint32_t, std::unique_ptr<HuffmanCodec<uint64_t>>>
opcode_and_num_operands_markov_huffman_codecs_;
// Huffman codecs for non-id single-word operand values.
// The map key is pair <opcode, operand_index>.
std::map<std::pair<uint32_t, uint32_t>,
std::unique_ptr<HuffmanCodec<uint64_t>>>
non_id_word_huffman_codecs_;
// Huffman codecs for id descriptors. The map key is pair
// <opcode, operand_index>.
std::map<std::pair<uint32_t, uint32_t>,
std::unique_ptr<HuffmanCodec<uint64_t>>>
id_descriptor_huffman_codecs_;
// Set of all descriptors which have a coding scheme in any of
// id_descriptor_huffman_codecs_.
std::unordered_set<uint32_t> descriptors_with_coding_scheme_;
// Huffman codecs for literal strings. The map key is the opcode of the
// current instruction. This assumes, that there is no more than one literal
// string operand per instruction, but would still work even if this is not
// the case. Names and debug information strings are not collected.
std::map<uint32_t, std::unique_ptr<HuffmanCodec<std::string>>>
literal_string_huffman_codecs_;
// Chunk lengths used for variable width encoding of operands (index is
// spv_operand_type of the operand).
std::vector<uint32_t> operand_chunk_lengths_;
uint32_t opcode_chunk_length_ = 7;
uint32_t num_operands_chunk_length_ = 3;
uint32_t mtf_rank_chunk_length_ = 5;
uint32_t u64_chunk_length_ = 8;
uint32_t s64_chunk_length_ = 8;
uint32_t s64_block_exponent_ = 10;
IdFallbackStrategy id_fallback_strategy_ =
IdFallbackStrategy::kShortDescriptor;
uint32_t model_type_ = 0;
uint32_t model_version_ = 0;
};
} // namespace comp
} // namespace spvtools
#endif // SOURCE_COMP_MARKV_MODEL_H_

View File

@@ -1,456 +0,0 @@
// Copyright (c) 2018 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "source/comp/move_to_front.h"
#include <algorithm>
#include <iomanip>
#include <iostream>
#include <ostream>
#include <sstream>
#include <unordered_set>
#include <utility>
namespace spvtools {
namespace comp {
bool MoveToFront::Insert(uint32_t value) {
auto it = value_to_node_.find(value);
if (it != value_to_node_.end() && IsInTree(it->second)) return false;
const uint32_t old_size = GetSize();
(void)old_size;
InsertNode(CreateNode(next_timestamp_++, value));
last_accessed_value_ = value;
last_accessed_value_valid_ = true;
assert(value_to_node_.count(value));
assert(old_size + 1 == GetSize());
return true;
}
bool MoveToFront::Remove(uint32_t value) {
auto it = value_to_node_.find(value);
if (it == value_to_node_.end()) return false;
if (!IsInTree(it->second)) return false;
if (last_accessed_value_ == value) last_accessed_value_valid_ = false;
const uint32_t orphan = RemoveNode(it->second);
(void)orphan;
// The node of |value| is still alive but it's orphaned now. Can still be
// reused later.
assert(!IsInTree(orphan));
assert(ValueOf(orphan) == value);
return true;
}
bool MoveToFront::RankFromValue(uint32_t value, uint32_t* rank) {
if (last_accessed_value_valid_ && last_accessed_value_ == value) {
*rank = 1;
return true;
}
const uint32_t old_size = GetSize();
if (old_size == 1) {
if (ValueOf(root_) == value) {
*rank = 1;
return true;
} else {
return false;
}
}
const auto it = value_to_node_.find(value);
if (it == value_to_node_.end()) {
return false;
}
uint32_t target = it->second;
if (!IsInTree(target)) {
return false;
}
uint32_t node = target;
*rank = 1 + SizeOf(LeftOf(node));
while (node) {
if (IsRightChild(node)) *rank += 1 + SizeOf(LeftOf(ParentOf(node)));
node = ParentOf(node);
}
// Don't update timestamp if the node has rank 1.
if (*rank != 1) {
// Update timestamp and reposition the node.
target = RemoveNode(target);
assert(ValueOf(target) == value);
assert(old_size == GetSize() + 1);
MutableTimestampOf(target) = next_timestamp_++;
InsertNode(target);
assert(old_size == GetSize());
}
last_accessed_value_ = value;
last_accessed_value_valid_ = true;
return true;
}
bool MoveToFront::HasValue(uint32_t value) const {
const auto it = value_to_node_.find(value);
if (it == value_to_node_.end()) {
return false;
}
return IsInTree(it->second);
}
bool MoveToFront::Promote(uint32_t value) {
if (last_accessed_value_valid_ && last_accessed_value_ == value) {
return true;
}
const uint32_t old_size = GetSize();
if (old_size == 1) return ValueOf(root_) == value;
const auto it = value_to_node_.find(value);
if (it == value_to_node_.end()) {
return false;
}
uint32_t target = it->second;
if (!IsInTree(target)) {
return false;
}
// Update timestamp and reposition the node.
target = RemoveNode(target);
assert(ValueOf(target) == value);
assert(old_size == GetSize() + 1);
MutableTimestampOf(target) = next_timestamp_++;
InsertNode(target);
assert(old_size == GetSize());
last_accessed_value_ = value;
last_accessed_value_valid_ = true;
return true;
}
bool MoveToFront::ValueFromRank(uint32_t rank, uint32_t* value) {
if (last_accessed_value_valid_ && rank == 1) {
*value = last_accessed_value_;
return true;
}
const uint32_t old_size = GetSize();
if (rank <= 0 || rank > old_size) {
return false;
}
if (old_size == 1) {
*value = ValueOf(root_);
return true;
}
const bool update_timestamp = (rank != 1);
uint32_t node = root_;
while (node) {
const uint32_t left_subtree_num_nodes = SizeOf(LeftOf(node));
if (rank == left_subtree_num_nodes + 1) {
// This is the node we are looking for.
// Don't update timestamp if the node has rank 1.
if (update_timestamp) {
node = RemoveNode(node);
assert(old_size == GetSize() + 1);
MutableTimestampOf(node) = next_timestamp_++;
InsertNode(node);
assert(old_size == GetSize());
}
*value = ValueOf(node);
last_accessed_value_ = *value;
last_accessed_value_valid_ = true;
return true;
}
if (rank < left_subtree_num_nodes + 1) {
// Descend into the left subtree. The rank is still valid.
node = LeftOf(node);
} else {
// Descend into the right subtree. We leave behind the left subtree and
// the current node, adjust the |rank| accordingly.
rank -= left_subtree_num_nodes + 1;
node = RightOf(node);
}
}
assert(0);
return false;
}
uint32_t MoveToFront::CreateNode(uint32_t timestamp, uint32_t value) {
uint32_t handle = static_cast<uint32_t>(nodes_.size());
const auto result = value_to_node_.emplace(value, handle);
if (result.second) {
// Create new node.
nodes_.emplace_back(Node());
Node& node = nodes_.back();
node.timestamp = timestamp;
node.value = value;
node.size = 1;
// Non-NIL nodes start with height 1 because their NIL children are
// leaves.
node.height = 1;
} else {
// Reuse old node.
handle = result.first->second;
assert(!IsInTree(handle));
assert(ValueOf(handle) == value);
assert(SizeOf(handle) == 1);
assert(HeightOf(handle) == 1);
MutableTimestampOf(handle) = timestamp;
}
return handle;
}
void MoveToFront::InsertNode(uint32_t node) {
assert(!IsInTree(node));
assert(SizeOf(node) == 1);
assert(HeightOf(node) == 1);
assert(TimestampOf(node));
if (!root_) {
root_ = node;
return;
}
uint32_t iter = root_;
uint32_t parent = 0;
// Will determine if |node| will become the right or left child after
// insertion (but before balancing).
bool right_child = true;
// Find the node which will become |node|'s parent after insertion
// (but before balancing).
while (iter) {
parent = iter;
assert(TimestampOf(iter) != TimestampOf(node));
right_child = TimestampOf(iter) > TimestampOf(node);
iter = right_child ? RightOf(iter) : LeftOf(iter);
}
assert(parent);
// Connect node and parent.
MutableParentOf(node) = parent;
if (right_child)
MutableRightOf(parent) = node;
else
MutableLeftOf(parent) = node;
// Insertion is finished. Start the balancing process.
bool needs_rebalancing = true;
parent = ParentOf(node);
while (parent) {
UpdateNode(parent);
if (needs_rebalancing) {
const int parent_balance = BalanceOf(parent);
if (RightOf(parent) == node) {
// Added node to the right subtree.
if (parent_balance > 1) {
// Parent is right heavy, rotate left.
if (BalanceOf(node) < 0) RotateRight(node);
parent = RotateLeft(parent);
} else if (parent_balance == 0 || parent_balance == -1) {
// Parent is balanced or left heavy, no need to balance further.
needs_rebalancing = false;
}
} else {
// Added node to the left subtree.
if (parent_balance < -1) {
// Parent is left heavy, rotate right.
if (BalanceOf(node) > 0) RotateLeft(node);
parent = RotateRight(parent);
} else if (parent_balance == 0 || parent_balance == 1) {
// Parent is balanced or right heavy, no need to balance further.
needs_rebalancing = false;
}
}
}
assert(BalanceOf(parent) >= -1 && (BalanceOf(parent) <= 1));
node = parent;
parent = ParentOf(parent);
}
}
uint32_t MoveToFront::RemoveNode(uint32_t node) {
if (LeftOf(node) && RightOf(node)) {
// If |node| has two children, then use another node as scapegoat and swap
// their contents. We pick the scapegoat on the side of the tree which has
// more nodes.
const uint32_t scapegoat = SizeOf(LeftOf(node)) >= SizeOf(RightOf(node))
? RightestDescendantOf(LeftOf(node))
: LeftestDescendantOf(RightOf(node));
assert(scapegoat);
std::swap(MutableValueOf(node), MutableValueOf(scapegoat));
std::swap(MutableTimestampOf(node), MutableTimestampOf(scapegoat));
value_to_node_[ValueOf(node)] = node;
value_to_node_[ValueOf(scapegoat)] = scapegoat;
node = scapegoat;
}
// |node| may have only one child at this point.
assert(!RightOf(node) || !LeftOf(node));
uint32_t parent = ParentOf(node);
uint32_t child = RightOf(node) ? RightOf(node) : LeftOf(node);
// Orphan |node| and reconnect parent and child.
if (child) MutableParentOf(child) = parent;
if (parent) {
if (LeftOf(parent) == node)
MutableLeftOf(parent) = child;
else
MutableRightOf(parent) = child;
}
MutableParentOf(node) = 0;
MutableLeftOf(node) = 0;
MutableRightOf(node) = 0;
UpdateNode(node);
const uint32_t orphan = node;
if (root_ == node) root_ = child;
// Removal is finished. Start the balancing process.
bool needs_rebalancing = true;
node = child;
while (parent) {
UpdateNode(parent);
if (needs_rebalancing) {
const int parent_balance = BalanceOf(parent);
if (parent_balance == 1 || parent_balance == -1) {
// The height of the subtree was not changed.
needs_rebalancing = false;
} else {
if (RightOf(parent) == node) {
// Removed node from the right subtree.
if (parent_balance < -1) {
// Parent is left heavy, rotate right.
const uint32_t sibling = LeftOf(parent);
if (BalanceOf(sibling) > 0) RotateLeft(sibling);
parent = RotateRight(parent);
}
} else {
// Removed node from the left subtree.
if (parent_balance > 1) {
// Parent is right heavy, rotate left.
const uint32_t sibling = RightOf(parent);
if (BalanceOf(sibling) < 0) RotateRight(sibling);
parent = RotateLeft(parent);
}
}
}
}
assert(BalanceOf(parent) >= -1 && (BalanceOf(parent) <= 1));
node = parent;
parent = ParentOf(parent);
}
return orphan;
}
uint32_t MoveToFront::RotateLeft(const uint32_t node) {
const uint32_t pivot = RightOf(node);
assert(pivot);
// LeftOf(pivot) gets attached to node in place of pivot.
MutableRightOf(node) = LeftOf(pivot);
if (RightOf(node)) MutableParentOf(RightOf(node)) = node;
// Pivot gets attached to ParentOf(node) in place of node.
MutableParentOf(pivot) = ParentOf(node);
if (!ParentOf(node))
root_ = pivot;
else if (IsLeftChild(node))
MutableLeftOf(ParentOf(node)) = pivot;
else
MutableRightOf(ParentOf(node)) = pivot;
// Node is child of pivot.
MutableLeftOf(pivot) = node;
MutableParentOf(node) = pivot;
// Update both node and pivot. Pivot is the new parent of node, so node should
// be updated first.
UpdateNode(node);
UpdateNode(pivot);
return pivot;
}
uint32_t MoveToFront::RotateRight(const uint32_t node) {
const uint32_t pivot = LeftOf(node);
assert(pivot);
// RightOf(pivot) gets attached to node in place of pivot.
MutableLeftOf(node) = RightOf(pivot);
if (LeftOf(node)) MutableParentOf(LeftOf(node)) = node;
// Pivot gets attached to ParentOf(node) in place of node.
MutableParentOf(pivot) = ParentOf(node);
if (!ParentOf(node))
root_ = pivot;
else if (IsLeftChild(node))
MutableLeftOf(ParentOf(node)) = pivot;
else
MutableRightOf(ParentOf(node)) = pivot;
// Node is child of pivot.
MutableRightOf(pivot) = node;
MutableParentOf(node) = pivot;
// Update both node and pivot. Pivot is the new parent of node, so node should
// be updated first.
UpdateNode(node);
UpdateNode(pivot);
return pivot;
}
void MoveToFront::UpdateNode(uint32_t node) {
MutableSizeOf(node) = 1 + SizeOf(LeftOf(node)) + SizeOf(RightOf(node));
MutableHeightOf(node) =
1 + std::max(HeightOf(LeftOf(node)), HeightOf(RightOf(node)));
}
} // namespace comp
} // namespace spvtools

View File

@@ -1,384 +0,0 @@
// Copyright (c) 2017 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SOURCE_COMP_MOVE_TO_FRONT_H_
#define SOURCE_COMP_MOVE_TO_FRONT_H_
#include <cassert>
#include <cstdint>
#include <map>
#include <set>
#include <unordered_map>
#include <vector>
namespace spvtools {
namespace comp {
// Log(n) move-to-front implementation. Implements the following functions:
// Insert - pushes value to the front of the mtf sequence
// (only unique values allowed).
// Remove - remove value from the sequence.
// ValueFromRank - access value by its 1-indexed rank in the sequence.
// RankFromValue - get the rank of the given value in the sequence.
// Accessing a value with ValueFromRank or RankFromValue moves the value to the
// front of the sequence (rank of 1).
//
// The implementation is based on an AVL-based order statistic tree. The tree
// is ordered by timestamps issued when values are inserted or accessed (recent
// values go to the left side of the tree, old values are gradually rotated to
// the right side).
//
// Terminology
// rank: 1-indexed rank showing how recently the value was inserted or accessed.
// node: handle used internally to access node data.
// size: size of the subtree of a node (including the node).
// height: distance from a node to the farthest leaf.
class MoveToFront {
public:
explicit MoveToFront(size_t reserve_capacity = 4) {
nodes_.reserve(reserve_capacity);
// Create NIL node.
nodes_.emplace_back(Node());
}
virtual ~MoveToFront() = default;
// Inserts value in the move-to-front sequence. Does nothing if the value is
// already in the sequence. Returns true if insertion was successful.
// The inserted value is placed at the front of the sequence (rank 1).
bool Insert(uint32_t value);
// Removes value from move-to-front sequence. Returns false iff the value
// was not found.
bool Remove(uint32_t value);
// Computes 1-indexed rank of value in the move-to-front sequence and moves
// the value to the front. Example:
// Before the call: 4 8 2 1 7
// RankFromValue(8) returns 2
// After the call: 8 4 2 1 7
// Returns true iff the value was found in the sequence.
bool RankFromValue(uint32_t value, uint32_t* rank);
// Returns value corresponding to a 1-indexed rank in the move-to-front
// sequence and moves the value to the front. Example:
// Before the call: 4 8 2 1 7
// ValueFromRank(2) returns 8
// After the call: 8 4 2 1 7
// Returns true iff the rank is within bounds [1, GetSize()].
bool ValueFromRank(uint32_t rank, uint32_t* value);
// Moves the value to the front of the sequence.
// Returns false iff value is not in the sequence.
bool Promote(uint32_t value);
// Returns true iff the move-to-front sequence contains the value.
bool HasValue(uint32_t value) const;
// Returns the number of elements in the move-to-front sequence.
uint32_t GetSize() const { return SizeOf(root_); }
protected:
// Internal tree data structure uses handles instead of pointers. Leaves and
// root parent reference a singleton under handle 0. Although dereferencing
// a null pointer is not possible, inappropriate access to handle 0 would
// cause an assertion. Handles are not garbage collected if value was
// deprecated
// with DeprecateValue(). But handles are recycled when a node is
// repositioned.
// Internal tree data structure node.
struct Node {
// Timestamp from a logical clock which updates every time the element is
// accessed through ValueFromRank or RankFromValue.
uint32_t timestamp = 0;
// The size of the node's subtree, including the node.
// SizeOf(LeftOf(node)) + SizeOf(RightOf(node)) + 1.
uint32_t size = 0;
// Handles to connected nodes.
uint32_t left = 0;
uint32_t right = 0;
uint32_t parent = 0;
// Distance to the farthest leaf.
// Leaves have height 0, real nodes at least 1.
uint32_t height = 0;
// Stored value.
uint32_t value = 0;
};
// Creates node and sets correct values. Non-NIL nodes should be created only
// through this function. If the node with this value has been created
// previously
// and since orphaned, reuses the old node instead of creating a new one.
uint32_t CreateNode(uint32_t timestamp, uint32_t value);
// Node accessor methods. Naming is designed to be similar to natural
// language as these functions tend to be used in sequences, for example:
// ParentOf(LeftestDescendentOf(RightOf(node)))
// Returns value of the node referenced by |handle|.
uint32_t ValueOf(uint32_t node) const { return nodes_.at(node).value; }
// Returns left child of |node|.
uint32_t LeftOf(uint32_t node) const { return nodes_.at(node).left; }
// Returns right child of |node|.
uint32_t RightOf(uint32_t node) const { return nodes_.at(node).right; }
// Returns parent of |node|.
uint32_t ParentOf(uint32_t node) const { return nodes_.at(node).parent; }
// Returns timestamp of |node|.
uint32_t TimestampOf(uint32_t node) const {
assert(node);
return nodes_.at(node).timestamp;
}
// Returns size of |node|.
uint32_t SizeOf(uint32_t node) const { return nodes_.at(node).size; }
// Returns height of |node|.
uint32_t HeightOf(uint32_t node) const { return nodes_.at(node).height; }
// Returns mutable reference to value of |node|.
uint32_t& MutableValueOf(uint32_t node) {
assert(node);
return nodes_.at(node).value;
}
// Returns mutable reference to handle of left child of |node|.
uint32_t& MutableLeftOf(uint32_t node) {
assert(node);
return nodes_.at(node).left;
}
// Returns mutable reference to handle of right child of |node|.
uint32_t& MutableRightOf(uint32_t node) {
assert(node);
return nodes_.at(node).right;
}
// Returns mutable reference to handle of parent of |node|.
uint32_t& MutableParentOf(uint32_t node) {
assert(node);
return nodes_.at(node).parent;
}
// Returns mutable reference to timestamp of |node|.
uint32_t& MutableTimestampOf(uint32_t node) {
assert(node);
return nodes_.at(node).timestamp;
}
// Returns mutable reference to size of |node|.
uint32_t& MutableSizeOf(uint32_t node) {
assert(node);
return nodes_.at(node).size;
}
// Returns mutable reference to height of |node|.
uint32_t& MutableHeightOf(uint32_t node) {
assert(node);
return nodes_.at(node).height;
}
// Returns true iff |node| is left child of its parent.
bool IsLeftChild(uint32_t node) const {
assert(node);
return LeftOf(ParentOf(node)) == node;
}
// Returns true iff |node| is right child of its parent.
bool IsRightChild(uint32_t node) const {
assert(node);
return RightOf(ParentOf(node)) == node;
}
// Returns true iff |node| has no relatives.
bool IsOrphan(uint32_t node) const {
assert(node);
return !ParentOf(node) && !LeftOf(node) && !RightOf(node);
}
// Returns true iff |node| is in the tree.
bool IsInTree(uint32_t node) const {
assert(node);
return node == root_ || !IsOrphan(node);
}
// Returns the height difference between right and left subtrees.
int BalanceOf(uint32_t node) const {
return int(HeightOf(RightOf(node))) - int(HeightOf(LeftOf(node)));
}
// Updates size and height of the node, assuming that the children have
// correct values.
void UpdateNode(uint32_t node);
// Returns the most LeftOf(LeftOf(... descendent which is not leaf.
uint32_t LeftestDescendantOf(uint32_t node) const {
uint32_t parent = 0;
while (node) {
parent = node;
node = LeftOf(node);
}
return parent;
}
// Returns the most RightOf(RightOf(... descendent which is not leaf.
uint32_t RightestDescendantOf(uint32_t node) const {
uint32_t parent = 0;
while (node) {
parent = node;
node = RightOf(node);
}
return parent;
}
// Inserts node in the tree. The node must be an orphan.
void InsertNode(uint32_t node);
// Removes node from the tree. May change value_to_node_ if removal uses a
// scapegoat. Returns the removed (orphaned) handle for recycling. The
// returned handle may not be equal to |node| if scapegoat was used.
uint32_t RemoveNode(uint32_t node);
// Rotates |node| left, reassigns all connections and returns the node
// which takes place of the |node|.
uint32_t RotateLeft(const uint32_t node);
// Rotates |node| right, reassigns all connections and returns the node
// which takes place of the |node|.
uint32_t RotateRight(const uint32_t node);
// Root node handle. The tree is empty if root_ is 0.
uint32_t root_ = 0;
// Incremented counters for next timestamp and value.
uint32_t next_timestamp_ = 1;
// Holds all tree nodes. Indices of this vector are node handles.
std::vector<Node> nodes_;
// Maps ids to node handles.
std::unordered_map<uint32_t, uint32_t> value_to_node_;
// Cache for the last accessed value in the sequence.
uint32_t last_accessed_value_ = 0;
bool last_accessed_value_valid_ = false;
};
class MultiMoveToFront {
public:
// Inserts |value| to sequence with handle |mtf|.
// Returns false if |mtf| already has |value|.
bool Insert(uint64_t mtf, uint32_t value) {
if (GetMtf(mtf).Insert(value)) {
val_to_mtfs_[value].insert(mtf);
return true;
}
return false;
}
// Removes |value| from sequence with handle |mtf|.
// Returns false if |mtf| doesn't have |value|.
bool Remove(uint64_t mtf, uint32_t value) {
if (GetMtf(mtf).Remove(value)) {
val_to_mtfs_[value].erase(mtf);
return true;
}
assert(val_to_mtfs_[value].count(mtf) == 0);
return false;
}
// Removes |value| from all sequences which have it.
void RemoveFromAll(uint32_t value) {
auto it = val_to_mtfs_.find(value);
if (it == val_to_mtfs_.end()) return;
auto& mtfs_containing_value = it->second;
for (uint64_t mtf : mtfs_containing_value) {
GetMtf(mtf).Remove(value);
}
val_to_mtfs_.erase(value);
}
// Computes rank of |value| in sequence |mtf|.
// Returns false if |mtf| doesn't have |value|.
bool RankFromValue(uint64_t mtf, uint32_t value, uint32_t* rank) {
return GetMtf(mtf).RankFromValue(value, rank);
}
// Finds |value| with |rank| in sequence |mtf|.
// Returns false if |rank| is out of bounds.
bool ValueFromRank(uint64_t mtf, uint32_t rank, uint32_t* value) {
return GetMtf(mtf).ValueFromRank(rank, value);
}
// Returns size of |mtf| sequence.
uint32_t GetSize(uint64_t mtf) { return GetMtf(mtf).GetSize(); }
// Promotes |value| in all sequences which have it.
void Promote(uint32_t value) {
const auto it = val_to_mtfs_.find(value);
if (it == val_to_mtfs_.end()) return;
const auto& mtfs_containing_value = it->second;
for (uint64_t mtf : mtfs_containing_value) {
GetMtf(mtf).Promote(value);
}
}
// Inserts |value| in sequence |mtf| or promotes if it's already there.
void InsertOrPromote(uint64_t mtf, uint32_t value) {
if (!Insert(mtf, value)) {
GetMtf(mtf).Promote(value);
}
}
// Returns if |mtf| sequence has |value|.
bool HasValue(uint64_t mtf, uint32_t value) {
return GetMtf(mtf).HasValue(value);
}
private:
// Returns actual MoveToFront object corresponding to |handle|.
// As multiple operations are often performed consecutively for the same
// sequence, the last returned value is cached.
MoveToFront& GetMtf(uint64_t handle) {
if (!cached_mtf_ || cached_handle_ != handle) {
cached_handle_ = handle;
cached_mtf_ = &mtfs_[handle];
}
return *cached_mtf_;
}
// Container holding MoveToFront objects. Map key is sequence handle.
std::map<uint64_t, MoveToFront> mtfs_;
// Container mapping value to sequences which contain that value.
std::unordered_map<uint32_t, std::set<uint64_t>> val_to_mtfs_;
// Cache for the last accessed sequence.
uint64_t cached_handle_ = 0;
MoveToFront* cached_mtf_ = nullptr;
};
} // namespace comp
} // namespace spvtools
#endif // SOURCE_COMP_MOVE_TO_FRONT_H_

View File

@@ -0,0 +1,120 @@
# Copyright (c) 2019 Google LLC
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
if(SPIRV_BUILD_FUZZER)
set(PROTOBUF_SOURCE ${CMAKE_CURRENT_SOURCE_DIR}/protobufs/spvtoolsfuzz.proto)
add_custom_command(
OUTPUT protobufs/spvtoolsfuzz.pb.cc protobufs/spvtoolsfuzz.pb.h
COMMAND protobuf::protoc
-I=${CMAKE_CURRENT_SOURCE_DIR}/protobufs
--cpp_out=protobufs
${PROTOBUF_SOURCE}
DEPENDS ${PROTOBUF_SOURCE}
COMMENT "Generate protobuf sources from proto definition file."
)
set(SPIRV_TOOLS_FUZZ_SOURCES
fact_manager.h
fuzzer.h
fuzzer_context.h
fuzzer_pass.h
fuzzer_pass_add_useful_constructs.h
fuzzer_pass_add_dead_breaks.h
fuzzer_pass_permute_blocks.h
fuzzer_pass_split_blocks.h
fuzzer_util.h
id_use_descriptor.h
protobufs/spirvfuzz_protobufs.h
pseudo_random_generator.h
random_generator.h
transformation_add_constant_boolean.h
transformation_add_constant_scalar.h
transformation_add_dead_break.h
transformation_add_type_boolean.h
transformation_add_type_float.h
transformation_add_type_int.h
transformation_move_block_down.h
transformation_replace_boolean_constant_with_constant_binary.h
transformation_split_block.h
${CMAKE_CURRENT_BINARY_DIR}/protobufs/spvtoolsfuzz.pb.h
fact_manager.cpp
fuzzer.cpp
fuzzer_context.cpp
fuzzer_pass.cpp
fuzzer_pass_add_dead_breaks.cpp
fuzzer_pass_add_useful_constructs.cpp
fuzzer_pass_permute_blocks.cpp
fuzzer_pass_split_blocks.cpp
fuzzer_util.cpp
id_use_descriptor.cpp
pseudo_random_generator.cpp
random_generator.cpp
transformation_add_constant_boolean.cpp
transformation_add_constant_scalar.cpp
transformation_add_dead_break.cpp
transformation_add_type_boolean.cpp
transformation_add_type_float.cpp
transformation_add_type_int.cpp
transformation_move_block_down.cpp
transformation_replace_boolean_constant_with_constant_binary.cpp
transformation_split_block.cpp
${CMAKE_CURRENT_BINARY_DIR}/protobufs/spvtoolsfuzz.pb.cc
)
if(MSVC)
# Enable parallel builds across four cores for this lib
add_definitions(/MP4)
endif()
spvtools_pch(SPIRV_TOOLS_FUZZ_SOURCES pch_source_fuzz)
add_library(SPIRV-Tools-fuzz ${SPIRV_TOOLS_FUZZ_SOURCES})
spvtools_default_compile_options(SPIRV-Tools-fuzz)
target_compile_definitions(SPIRV-Tools-fuzz PUBLIC -DGOOGLE_PROTOBUF_NO_RTTI -DGOOGLE_PROTOBUF_USE_UNALIGNED=0)
# Compilation of the auto-generated protobuf source file will yield warnings,
# which we have no control over and thus wish to ignore.
if(${COMPILER_IS_LIKE_GNU})
set_source_files_properties(${CMAKE_CURRENT_BINARY_DIR}/protobufs/spvtoolsfuzz.pb.cc PROPERTIES COMPILE_FLAGS -w)
endif()
if(MSVC)
set_source_files_properties(${CMAKE_CURRENT_BINARY_DIR}/protobufs/spvtoolsfuzz.pb.cc PROPERTIES COMPILE_FLAGS /w)
endif()
target_include_directories(SPIRV-Tools-fuzz
PUBLIC ${spirv-tools_SOURCE_DIR}/include
PUBLIC ${SPIRV_HEADER_INCLUDE_DIR}
PRIVATE ${spirv-tools_BINARY_DIR}
PRIVATE ${CMAKE_BINARY_DIR})
# The fuzzer reuses a lot of functionality from the SPIRV-Tools library.
target_link_libraries(SPIRV-Tools-fuzz
PUBLIC ${SPIRV_TOOLS}
PUBLIC SPIRV-Tools-opt
PUBLIC protobuf::libprotobuf)
set_property(TARGET SPIRV-Tools-fuzz PROPERTY FOLDER "SPIRV-Tools libraries")
spvtools_check_symbol_exports(SPIRV-Tools-fuzz)
if(ENABLE_SPIRV_TOOLS_INSTALL)
install(TARGETS SPIRV-Tools-fuzz
RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR})
endif(ENABLE_SPIRV_TOOLS_INSTALL)
endif(SPIRV_BUILD_FUZZER)

View File

@@ -0,0 +1,46 @@
// Copyright (c) 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <utility>
#include "source/fuzz/fact_manager.h"
#include "source/opt/ir_context.h"
namespace spvtools {
namespace fuzz {
FactManager::FactManager() = default;
FactManager::~FactManager() = default;
bool FactManager::AddFacts(const protobufs::FactSequence& initial_facts,
opt::IRContext* context) {
for (auto& fact : initial_facts.fact()) {
if (!AddFact(fact, context)) {
// TODO(https://github.com/KhronosGroup/SPIRV-Tools/issues/2621) Provide
// information about the fact that could not be added.
return false;
}
}
return true;
}
bool FactManager::AddFact(const spvtools::fuzz::protobufs::Fact&,
spvtools::opt::IRContext*) {
assert(0 && "No facts are yet supported.");
return true;
}
} // namespace fuzz
} // namespace spvtools

View File

@@ -0,0 +1,53 @@
// Copyright (c) 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SOURCE_FUZZ_FACT_MANAGER_H_
#define SOURCE_FUZZ_FACT_MANAGER_H_
#include <utility>
#include "source/fuzz/protobufs/spirvfuzz_protobufs.h"
#include "source/opt/constants.h"
namespace spvtools {
namespace fuzz {
// Keeps track of facts about the module being transformed on which the fuzzing
// process can depend. Some initial facts can be provided, for example about
// guarantees on the values of inputs to SPIR-V entry points. Transformations
// may then rely on these facts, can add further facts that they establish.
// Facts are intended to be simple properties that either cannot be deduced from
// the module (such as properties that are guaranteed to hold for entry point
// inputs), or that are established by transformations, likely to be useful for
// future transformations, and not completely trivial to deduce straight from
// the module.
class FactManager {
public:
FactManager();
~FactManager();
// Adds all the facts from |facts|, checking them for validity with respect to
// |context|. Returns true if and only if all facts are valid.
bool AddFacts(const protobufs::FactSequence& facts, opt::IRContext* context);
// Adds |fact| to the fact manager, checking it for validity with respect to
// |context|. Returns true if and only if the fact is valid.
bool AddFact(const protobufs::Fact& fact, opt::IRContext* context);
};
} // namespace fuzz
} // namespace spvtools
#endif // #define SOURCE_FUZZ_FACT_MANAGER_H_

View File

@@ -0,0 +1,131 @@
// Copyright (c) 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "source/fuzz/fuzzer.h"
#include <cassert>
#include <sstream>
#include "source/fuzz/fact_manager.h"
#include "source/fuzz/fuzzer_context.h"
#include "source/fuzz/fuzzer_pass_add_dead_breaks.h"
#include "source/fuzz/fuzzer_pass_add_useful_constructs.h"
#include "source/fuzz/fuzzer_pass_permute_blocks.h"
#include "source/fuzz/fuzzer_pass_split_blocks.h"
#include "source/fuzz/protobufs/spirvfuzz_protobufs.h"
#include "source/fuzz/pseudo_random_generator.h"
#include "source/opt/build_module.h"
#include "source/spirv_fuzzer_options.h"
#include "source/util/make_unique.h"
namespace spvtools {
namespace fuzz {
namespace {
const uint32_t kIdBoundGap = 100;
}
struct Fuzzer::Impl {
explicit Impl(spv_target_env env) : target_env(env) {}
const spv_target_env target_env; // Target environment.
MessageConsumer consumer; // Message consumer.
};
Fuzzer::Fuzzer(spv_target_env env) : impl_(MakeUnique<Impl>(env)) {}
Fuzzer::~Fuzzer() = default;
void Fuzzer::SetMessageConsumer(MessageConsumer c) {
impl_->consumer = std::move(c);
}
Fuzzer::FuzzerResultStatus Fuzzer::Run(
const std::vector<uint32_t>& binary_in,
const protobufs::FactSequence& initial_facts,
std::vector<uint32_t>* binary_out,
protobufs::TransformationSequence* transformation_sequence_out,
spv_const_fuzzer_options options) const {
// Check compatibility between the library version being linked with and the
// header files being used.
GOOGLE_PROTOBUF_VERIFY_VERSION;
spvtools::SpirvTools tools(impl_->target_env);
if (!tools.IsValid()) {
impl_->consumer(SPV_MSG_ERROR, nullptr, {},
"Failed to create SPIRV-Tools interface; stopping.");
return Fuzzer::FuzzerResultStatus::kFailedToCreateSpirvToolsInterface;
}
// Initial binary should be valid.
if (!tools.Validate(&binary_in[0], binary_in.size())) {
impl_->consumer(SPV_MSG_ERROR, nullptr, {},
"Initial binary is invalid; stopping.");
return Fuzzer::FuzzerResultStatus::kInitialBinaryInvalid;
}
// Build the module from the input binary.
std::unique_ptr<opt::IRContext> ir_context = BuildModule(
impl_->target_env, impl_->consumer, binary_in.data(), binary_in.size());
assert(ir_context);
// Make a PRNG, either from a given seed or from a random device.
PseudoRandomGenerator random_generator(
options->has_random_seed ? options->random_seed
: static_cast<uint32_t>(std::random_device()()));
// The fuzzer will introduce new ids into the module. The module's id bound
// gives the smallest id that can be used for this purpose. We add an offset
// to this so that there is a sizeable gap between the ids used in the
// original module and the ids used for fuzzing, as a readability aid.
//
// TODO(https://github.com/KhronosGroup/SPIRV-Tools/issues/2541) consider the
// case where the maximum id bound is reached.
auto minimum_fresh_id = ir_context->module()->id_bound() + kIdBoundGap;
FuzzerContext fuzzer_context(&random_generator, minimum_fresh_id);
FactManager fact_manager;
if (!fact_manager.AddFacts(initial_facts, ir_context.get())) {
return Fuzzer::FuzzerResultStatus::kInitialFactsInvalid;
}
// Add some essential ingredients to the module if they are not already
// present, such as boolean constants.
FuzzerPassAddUsefulConstructs(ir_context.get(), &fact_manager,
&fuzzer_context, transformation_sequence_out)
.Apply();
// Apply some semantics-preserving passes.
FuzzerPassSplitBlocks(ir_context.get(), &fact_manager, &fuzzer_context,
transformation_sequence_out)
.Apply();
FuzzerPassAddDeadBreaks(ir_context.get(), &fact_manager, &fuzzer_context,
transformation_sequence_out)
.Apply();
// TODO(afd) Various other passes will be added.
// Finally, give the blocks in the module a good shake-up.
FuzzerPassPermuteBlocks(ir_context.get(), &fact_manager, &fuzzer_context,
transformation_sequence_out)
.Apply();
// Encode the module as a binary.
ir_context->module()->ToBinary(binary_out, false);
return Fuzzer::FuzzerResultStatus::kComplete;
}
} // namespace fuzz
} // namespace spvtools

View File

@@ -0,0 +1,74 @@
// Copyright (c) 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SOURCE_FUZZ_FUZZER_H_
#define SOURCE_FUZZ_FUZZER_H_
#include <memory>
#include <vector>
#include "source/fuzz/protobufs/spirvfuzz_protobufs.h"
#include "spirv-tools/libspirv.hpp"
namespace spvtools {
namespace fuzz {
// Transforms a SPIR-V module into a semantically equivalent SPIR-V module by
// running a number of randomized fuzzer passes.
class Fuzzer {
public:
// Possible statuses that can result from running the fuzzer.
enum class FuzzerResultStatus {
kComplete,
kFailedToCreateSpirvToolsInterface,
kInitialBinaryInvalid,
kInitialFactsInvalid,
};
// Constructs a fuzzer from the given target environment.
explicit Fuzzer(spv_target_env env);
// Disables copy/move constructor/assignment operations.
Fuzzer(const Fuzzer&) = delete;
Fuzzer(Fuzzer&&) = delete;
Fuzzer& operator=(const Fuzzer&) = delete;
Fuzzer& operator=(Fuzzer&&) = delete;
~Fuzzer();
// Sets the message consumer to the given |consumer|. The |consumer| will be
// invoked once for each message communicated from the library.
void SetMessageConsumer(MessageConsumer consumer);
// Transforms |binary_in| to |binary_out| by running a number of randomized
// fuzzer passes, controlled via |options|. Initial facts about the input
// binary and the context in which it will execute are provided via
// |initial_facts|. The transformation sequence that was applied is returned
// via |transformation_sequence_out|.
FuzzerResultStatus Run(
const std::vector<uint32_t>& binary_in,
const protobufs::FactSequence& initial_facts,
std::vector<uint32_t>* binary_out,
protobufs::TransformationSequence* transformation_sequence_out,
spv_const_fuzzer_options options) const;
private:
struct Impl; // Opaque struct for holding internal data.
std::unique_ptr<Impl> impl_; // Unique pointer to internal data.
};
} // namespace fuzz
} // namespace spvtools
#endif // SOURCE_FUZZ_FUZZER_H_

View File

@@ -0,0 +1,48 @@
// Copyright (c) 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "source/fuzz/fuzzer_context.h"
namespace spvtools {
namespace fuzz {
namespace {
// Default probabilities for applying various transformations.
// All values are percentages.
// Keep them in alphabetical order.
const uint32_t kDefaultChanceOfAddingDeadBreak = 20;
const uint32_t kDefaultChanceOfMovingBlockDown = 25;
const uint32_t kDefaultChanceOfSplittingBlock = 20;
} // namespace
FuzzerContext::FuzzerContext(RandomGenerator* random_generator,
uint32_t min_fresh_id)
: random_generator_(random_generator),
next_fresh_id_(min_fresh_id),
chance_of_adding_dead_break_(kDefaultChanceOfAddingDeadBreak),
chance_of_moving_block_down_(kDefaultChanceOfMovingBlockDown),
chance_of_splitting_block_(kDefaultChanceOfSplittingBlock) {}
FuzzerContext::~FuzzerContext() = default;
uint32_t FuzzerContext::GetFreshId() { return next_fresh_id_++; }
RandomGenerator* FuzzerContext::GetRandomGenerator() {
return random_generator_;
}
} // namespace fuzz
} // namespace spvtools

View File

@@ -0,0 +1,64 @@
// Copyright (c) 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SOURCE_FUZZ_FUZZER_CONTEXT_H_
#define SOURCE_FUZZ_FUZZER_CONTEXT_H_
#include "source/fuzz/random_generator.h"
#include "source/opt/function.h"
namespace spvtools {
namespace fuzz {
// Encapsulates all parameters that control the fuzzing process, such as the
// source of randomness and the probabilities with which transformations are
// applied.
class FuzzerContext {
public:
// Constructs a fuzzer context with a given random generator and the minimum
// value that can be used for fresh ids.
FuzzerContext(RandomGenerator* random_generator, uint32_t min_fresh_id);
~FuzzerContext();
// Provides the random generator used to control fuzzing.
RandomGenerator* GetRandomGenerator();
// Yields an id that is guaranteed not to be used in the module being fuzzed,
// or to have been issued before.
uint32_t GetFreshId();
// Probabilities associated with applying various transformations.
// Keep them in alphabetical order.
uint32_t GetChanceOfAddingDeadBreak() { return chance_of_adding_dead_break_; }
uint32_t GetChanceOfMovingBlockDown() { return chance_of_moving_block_down_; }
uint32_t GetChanceOfSplittingBlock() { return chance_of_splitting_block_; }
private:
// The source of randomness.
RandomGenerator* random_generator_;
// The next fresh id to be issued.
uint32_t next_fresh_id_;
// Probabilities associated with applying various transformations.
// Keep them in alphabetical order.
uint32_t chance_of_adding_dead_break_;
uint32_t chance_of_moving_block_down_;
uint32_t chance_of_splitting_block_;
};
} // namespace fuzz
} // namespace spvtools
#endif // SOURCE_FUZZ_FUZZER_CONTEXT_H_

View File

@@ -0,0 +1,31 @@
// Copyright (c) 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "source/fuzz/fuzzer_pass.h"
namespace spvtools {
namespace fuzz {
FuzzerPass::FuzzerPass(opt::IRContext* ir_context, FactManager* fact_manager,
FuzzerContext* fuzzer_context,
protobufs::TransformationSequence* transformations)
: ir_context_(ir_context),
fact_manager_(fact_manager),
fuzzer_context_(fuzzer_context),
transformations_(transformations) {}
FuzzerPass::~FuzzerPass() = default;
} // namespace fuzz
} // namespace spvtools

View File

@@ -0,0 +1,61 @@
// Copyright (c) 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SOURCE_FUZZ_FUZZER_PASS_H_
#define SOURCE_FUZZ_FUZZER_PASS_H_
#include "source/fuzz/fact_manager.h"
#include "source/fuzz/fuzzer_context.h"
#include "source/fuzz/protobufs/spirvfuzz_protobufs.h"
namespace spvtools {
namespace fuzz {
// Interface for applying a pass of transformations to a module.
class FuzzerPass {
public:
FuzzerPass(opt::IRContext* ir_context, FactManager* fact_manager,
FuzzerContext* fuzzer_context,
protobufs::TransformationSequence* transformations);
virtual ~FuzzerPass();
// Applies the pass to the module |ir_context_|, assuming and updating
// facts from |fact_manager_|, and using |fuzzer_context_| to guide the
// process. Appends to |transformations_| all transformations that were
// applied during the pass.
virtual void Apply() = 0;
protected:
opt::IRContext* GetIRContext() const { return ir_context_; }
FactManager* GetFactManager() const { return fact_manager_; }
FuzzerContext* GetFuzzerContext() const { return fuzzer_context_; }
protobufs::TransformationSequence* GetTransformations() const {
return transformations_;
}
private:
opt::IRContext* ir_context_;
FactManager* fact_manager_;
FuzzerContext* fuzzer_context_;
protobufs::TransformationSequence* transformations_;
};
} // namespace fuzz
} // namespace spvtools
#endif // #define SOURCE_FUZZ_FUZZER_PASS_H_

View File

@@ -0,0 +1,105 @@
// Copyright (c) 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "source/fuzz/fuzzer_pass_add_dead_breaks.h"
#include "source/fuzz/transformation_add_dead_break.h"
#include "source/opt/ir_context.h"
namespace spvtools {
namespace fuzz {
FuzzerPassAddDeadBreaks::FuzzerPassAddDeadBreaks(
opt::IRContext* ir_context, FactManager* fact_manager,
FuzzerContext* fuzzer_context,
protobufs::TransformationSequence* transformations)
: FuzzerPass(ir_context, fact_manager, fuzzer_context, transformations) {}
FuzzerPassAddDeadBreaks::~FuzzerPassAddDeadBreaks() = default;
void FuzzerPassAddDeadBreaks::Apply() {
// We first collect up lots of possibly-applicable transformations.
std::vector<protobufs::TransformationAddDeadBreak> candidate_transformations;
// We consider each function separately.
for (auto& function : *GetIRContext()->module()) {
// For a given function, we find all the merge blocks in that function.
std::vector<uint32_t> merge_block_ids;
for (auto& block : function) {
auto maybe_merge_id = block.MergeBlockIdIfAny();
if (maybe_merge_id) {
merge_block_ids.push_back(maybe_merge_id);
}
}
// We rather aggressively consider the possibility of adding a break from
// every block in the function to every merge block. Many of these will be
// inapplicable as they would be illegal. That's OK - we later discard the
// ones that turn out to be no good.
for (auto& block : function) {
for (auto merge_block_id : merge_block_ids) {
// TODO(afd): right now we completely ignore OpPhi instructions at
// merge blocks. This will lead to interesting opportunities being
// missed.
std::vector<uint32_t> phi_ids;
auto candidate_transformation =
transformation::MakeTransformationAddDeadBreak(
block.id(), merge_block_id,
GetFuzzerContext()->GetRandomGenerator()->RandomBool(),
std::move(phi_ids));
if (transformation::IsApplicable(candidate_transformation,
GetIRContext(), *GetFactManager())) {
// Only consider a transformation as a candidate if it is applicable.
candidate_transformations.push_back(
std::move(candidate_transformation));
}
}
}
}
// Go through the candidate transformations that were accumulated,
// probabilistically deciding whether to consider each one further and
// applying the still-applicable ones that are considered further.
//
// We iterate through the candidate transformations in a random order by
// repeatedly removing a random candidate transformation from the sequence
// until no candidate transformations remain. This is done because
// transformations can potentially disable one another, so that iterating
// through them in order would lead to a higher probability of
// transformations appearing early in the sequence being applied compared
// with later transformations.
while (!candidate_transformations.empty()) {
// Choose a random index into the sequence of remaining candidate
// transformations.
auto index = GetFuzzerContext()->GetRandomGenerator()->RandomUint32(
static_cast<uint32_t>(candidate_transformations.size()));
// Remove the transformation at the chosen index from the sequence.
auto transformation = std::move(candidate_transformations[index]);
candidate_transformations.erase(candidate_transformations.begin() + index);
// Probabilistically decide whether to try to apply it vs. ignore it.
if (GetFuzzerContext()->GetRandomGenerator()->RandomPercentage() >
GetFuzzerContext()->GetChanceOfAddingDeadBreak()) {
continue;
}
// If the transformation can be applied, apply it and add it to the
// sequence of transformations that have been applied.
if (transformation::IsApplicable(transformation, GetIRContext(),
*GetFactManager())) {
transformation::Apply(transformation, GetIRContext(), GetFactManager());
*GetTransformations()->add_transformation()->mutable_add_dead_break() =
transformation;
}
}
}
} // namespace fuzz
} // namespace spvtools

View File

@@ -0,0 +1,38 @@
// Copyright (c) 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SOURCE_FUZZ_FUZZER_PASS_ADD_DEAD_BREAKS_H_
#define SOURCE_FUZZ_FUZZER_PASS_ADD_DEAD_BREAKS_H_
#include "source/fuzz/fuzzer_pass.h"
namespace spvtools {
namespace fuzz {
// A fuzzer pass for adding dead break edges to the module.
class FuzzerPassAddDeadBreaks : public FuzzerPass {
public:
FuzzerPassAddDeadBreaks(opt::IRContext* ir_context, FactManager* fact_manager,
FuzzerContext* fuzzer_context,
protobufs::TransformationSequence* transformations);
~FuzzerPassAddDeadBreaks();
void Apply() override;
};
} // namespace fuzz
} // namespace spvtools
#endif // #define SOURCE_FUZZ_FUZZER_PASS_ADD_DEAD_BREAKS_H_

View File

@@ -0,0 +1,182 @@
// Copyright (c) 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "source/fuzz/fuzzer_pass_add_useful_constructs.h"
#include "source/fuzz/transformation_add_constant_boolean.h"
#include "source/fuzz/transformation_add_constant_scalar.h"
#include "source/fuzz/transformation_add_type_boolean.h"
#include "source/fuzz/transformation_add_type_float.h"
#include "source/fuzz/transformation_add_type_int.h"
namespace spvtools {
namespace fuzz {
using opt::IRContext;
FuzzerPassAddUsefulConstructs::FuzzerPassAddUsefulConstructs(
opt::IRContext* ir_context, FactManager* fact_manager,
FuzzerContext* fuzzer_context,
protobufs::TransformationSequence* transformations)
: FuzzerPass(ir_context, fact_manager, fuzzer_context, transformations){};
FuzzerPassAddUsefulConstructs::~FuzzerPassAddUsefulConstructs() = default;
void FuzzerPassAddUsefulConstructs::MaybeAddIntConstant(
uint32_t width, bool is_signed, std::vector<uint32_t> data) const {
opt::analysis::Integer temp_int_type(width, is_signed);
assert(GetIRContext()->get_type_mgr()->GetId(&temp_int_type) &&
"int type should already be registered.");
auto registered_int_type = GetIRContext()
->get_type_mgr()
->GetRegisteredType(&temp_int_type)
->AsInteger();
auto int_type_id = GetIRContext()->get_type_mgr()->GetId(registered_int_type);
assert(int_type_id &&
"The relevant int type should have been added to the module already.");
opt::analysis::IntConstant int_constant(registered_int_type, data);
if (!GetIRContext()->get_constant_mgr()->FindConstant(&int_constant)) {
protobufs::TransformationAddConstantScalar add_constant_int =
transformation::MakeTransformationAddConstantScalar(
GetFuzzerContext()->GetFreshId(), int_type_id, data);
assert(transformation::IsApplicable(add_constant_int, GetIRContext(),
*GetFactManager()) &&
"Should be applicable by construction.");
transformation::Apply(add_constant_int, GetIRContext(), GetFactManager());
*GetTransformations()->add_transformation()->mutable_add_constant_scalar() =
add_constant_int;
}
}
void FuzzerPassAddUsefulConstructs::MaybeAddFloatConstant(
uint32_t width, std::vector<uint32_t> data) const {
opt::analysis::Float temp_float_type(width);
assert(GetIRContext()->get_type_mgr()->GetId(&temp_float_type) &&
"float type should already be registered.");
auto registered_float_type = GetIRContext()
->get_type_mgr()
->GetRegisteredType(&temp_float_type)
->AsFloat();
auto float_type_id =
GetIRContext()->get_type_mgr()->GetId(registered_float_type);
assert(
float_type_id &&
"The relevant float type should have been added to the module already.");
opt::analysis::FloatConstant float_constant(registered_float_type, data);
if (!GetIRContext()->get_constant_mgr()->FindConstant(&float_constant)) {
protobufs::TransformationAddConstantScalar add_constant_float =
transformation::MakeTransformationAddConstantScalar(
GetFuzzerContext()->GetFreshId(), float_type_id, data);
assert(transformation::IsApplicable(add_constant_float, GetIRContext(),
*GetFactManager()) &&
"Should be applicable by construction.");
transformation::Apply(add_constant_float, GetIRContext(), GetFactManager());
*GetTransformations()->add_transformation()->mutable_add_constant_scalar() =
add_constant_float;
}
}
void FuzzerPassAddUsefulConstructs::Apply() {
{
// Add boolean type if not present.
opt::analysis::Bool temp_bool_type;
if (!GetIRContext()->get_type_mgr()->GetId(&temp_bool_type)) {
protobufs::TransformationAddTypeBoolean add_type_boolean =
transformation::MakeTransformationAddTypeBoolean(
GetFuzzerContext()->GetFreshId());
assert(transformation::IsApplicable(add_type_boolean, GetIRContext(),
*GetFactManager()) &&
"Should be applicable by construction.");
transformation::Apply(add_type_boolean, GetIRContext(), GetFactManager());
*GetTransformations()->add_transformation()->mutable_add_type_boolean() =
add_type_boolean;
}
}
{
// Add signed and unsigned 32-bit integer types if not present.
for (auto is_signed : {true, false}) {
opt::analysis::Integer temp_int_type(32, is_signed);
if (!GetIRContext()->get_type_mgr()->GetId(&temp_int_type)) {
protobufs::TransformationAddTypeInt add_type_int =
transformation::MakeTransformationAddTypeInt(
GetFuzzerContext()->GetFreshId(), 32, is_signed);
assert(transformation::IsApplicable(add_type_int, GetIRContext(),
*GetFactManager()) &&
"Should be applicable by construction.");
transformation::Apply(add_type_int, GetIRContext(), GetFactManager());
*GetTransformations()->add_transformation()->mutable_add_type_int() =
add_type_int;
}
}
}
{
// Add 32-bit float type if not present.
opt::analysis::Float temp_float_type(32);
if (!GetIRContext()->get_type_mgr()->GetId(&temp_float_type)) {
protobufs::TransformationAddTypeFloat add_type_float =
transformation::MakeTransformationAddTypeFloat(
GetFuzzerContext()->GetFreshId(), 32);
assert(transformation::IsApplicable(add_type_float, GetIRContext(),
*GetFactManager()) &&
"Should be applicable by construction.");
transformation::Apply(add_type_float, GetIRContext(), GetFactManager());
*GetTransformations()->add_transformation()->mutable_add_type_float() =
add_type_float;
}
}
// Add boolean constants true and false if not present.
opt::analysis::Bool temp_bool_type;
auto bool_type = GetIRContext()
->get_type_mgr()
->GetRegisteredType(&temp_bool_type)
->AsBool();
for (auto boolean_value : {true, false}) {
// Add OpConstantTrue/False if not already there.
opt::analysis::BoolConstant bool_constant(bool_type, boolean_value);
if (!GetIRContext()->get_constant_mgr()->FindConstant(&bool_constant)) {
protobufs::TransformationAddConstantBoolean add_constant_boolean =
transformation::MakeTransformationAddConstantBoolean(
GetFuzzerContext()->GetFreshId(), boolean_value);
assert(transformation::IsApplicable(add_constant_boolean, GetIRContext(),
*GetFactManager()) &&
"Should be applicable by construction.");
transformation::Apply(add_constant_boolean, GetIRContext(),
GetFactManager());
*GetTransformations()
->add_transformation()
->mutable_add_constant_boolean() = add_constant_boolean;
}
}
// Add signed and unsigned 32-bit integer constants 0 and 1 if not present.
for (auto is_signed : {true, false}) {
for (auto value : {0u, 1u}) {
MaybeAddIntConstant(32, is_signed, {value});
}
}
// Add 32-bit float constants 0.0 and 1.0 if not present.
uint32_t uint_data[2];
float float_data[2] = {0.0, 1.0};
memcpy(uint_data, float_data, sizeof(float_data));
for (unsigned int& datum : uint_data) {
MaybeAddFloatConstant(32, {datum});
}
}
} // namespace fuzz
} // namespace spvtools

View File

@@ -0,0 +1,46 @@
// Copyright (c) 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SOURCE_FUZZ_FUZZER_PASS_ADD_USEFUL_CONSTRUCTS_
#define SOURCE_FUZZ_FUZZER_PASS_ADD_USEFUL_CONSTRUCTS_
#include "source/fuzz/fuzzer_pass.h"
namespace spvtools {
namespace fuzz {
// An initial pass for adding useful ingredients to the module, such as boolean
// constants, if they are not present.
class FuzzerPassAddUsefulConstructs : public FuzzerPass {
public:
FuzzerPassAddUsefulConstructs(
opt::IRContext* ir_context, FactManager* fact_manager,
FuzzerContext* fuzzer_context,
protobufs::TransformationSequence* transformations);
~FuzzerPassAddUsefulConstructs() override;
void Apply() override;
private:
void MaybeAddIntConstant(uint32_t width, bool is_signed,
std::vector<uint32_t> data) const;
void MaybeAddFloatConstant(uint32_t width, std::vector<uint32_t> data) const;
};
} // namespace fuzz
} // namespace spvtools
#endif // #define SOURCE_FUZZ_FUZZER_PASS_ADD_USEFUL_CONSTRUCTS_

View File

@@ -0,0 +1,85 @@
// Copyright (c) 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "source/fuzz/fuzzer_pass_permute_blocks.h"
#include "source/fuzz/transformation_move_block_down.h"
namespace spvtools {
namespace fuzz {
FuzzerPassPermuteBlocks::FuzzerPassPermuteBlocks(
opt::IRContext* ir_context, FactManager* fact_manager,
FuzzerContext* fuzzer_context,
protobufs::TransformationSequence* transformations)
: FuzzerPass(ir_context, fact_manager, fuzzer_context, transformations) {}
FuzzerPassPermuteBlocks::~FuzzerPassPermuteBlocks() = default;
void FuzzerPassPermuteBlocks::Apply() {
// For now we do something very simple: we randomly decide whether to move a
// block, and for each block that we do move, we push it down as far as we
// legally can.
// TODO(https://github.com/KhronosGroup/SPIRV-Tools/issues/2635): it would be
// nice to randomly sample from the set of legal block permutations and then
// encode the chosen permutation via a series of move-block-down
// transformations. This should be possible but will require some thought.
for (auto& function : *GetIRContext()->module()) {
std::vector<uint32_t> block_ids;
// Collect all block ids for the function before messing with block
// ordering.
for (auto& block : function) {
block_ids.push_back(block.id());
}
// Now consider each block id. We consider block ids in reverse, because
// e.g. in code generated from the following:
//
// if (...) {
// A
// B
// } else {
// C
// }
//
// block A cannot be moved down, but B has freedom to move and that movement
// would provide more freedom for A to move.
for (auto id = block_ids.rbegin(); id != block_ids.rend(); ++id) {
// Randomly decide whether to ignore the block id.
if (GetFuzzerContext()->GetRandomGenerator()->RandomPercentage() >
GetFuzzerContext()->GetChanceOfMovingBlockDown()) {
continue;
}
// Keep pushing the block down, until pushing down fails.
// The loop is guaranteed to terminate because a block cannot be pushed
// down indefinitely.
while (true) {
protobufs::TransformationMoveBlockDown message;
message.set_block_id(*id);
if (transformation::IsApplicable(message, GetIRContext(),
*GetFactManager())) {
transformation::Apply(message, GetIRContext(), GetFactManager());
*GetTransformations()
->add_transformation()
->mutable_move_block_down() = message;
} else {
break;
}
}
}
}
}
} // namespace fuzz
} // namespace spvtools

View File

@@ -0,0 +1,39 @@
// Copyright (c) 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SOURCE_FUZZ_FUZZER_PASS_PERMUTE_BLOCKS_
#define SOURCE_FUZZ_FUZZER_PASS_PERMUTE_BLOCKS_
#include "source/fuzz/fuzzer_pass.h"
namespace spvtools {
namespace fuzz {
// A fuzzer pass for shuffling the blocks of the module in a validity-preserving
// manner.
class FuzzerPassPermuteBlocks : public FuzzerPass {
public:
FuzzerPassPermuteBlocks(opt::IRContext* ir_context, FactManager* fact_manager,
FuzzerContext* fuzzer_context,
protobufs::TransformationSequence* transformations);
~FuzzerPassPermuteBlocks() override;
void Apply() override;
};
} // namespace fuzz
} // namespace spvtools
#endif // #define SOURCE_FUZZ_FUZZER_PASS_PERMUTE_BLOCKS_

View File

@@ -0,0 +1,99 @@
// Copyright (c) 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "source/fuzz/fuzzer_pass_split_blocks.h"
#include <utility>
#include <vector>
#include "source/fuzz/transformation_split_block.h"
namespace spvtools {
namespace fuzz {
FuzzerPassSplitBlocks::FuzzerPassSplitBlocks(
opt::IRContext* ir_context, FactManager* fact_manager,
FuzzerContext* fuzzer_context,
protobufs::TransformationSequence* transformations)
: FuzzerPass(ir_context, fact_manager, fuzzer_context, transformations) {}
FuzzerPassSplitBlocks::~FuzzerPassSplitBlocks() = default;
void FuzzerPassSplitBlocks::Apply() {
// Gather up pointers to all the blocks in the module. We are then able to
// iterate over these pointers and split the blocks to which they point;
// we cannot safely split blocks while we iterate through the module.
std::vector<opt::BasicBlock*> blocks;
for (auto& function : *GetIRContext()->module()) {
for (auto& block : function) {
blocks.push_back(&block);
}
}
// Now go through all the block pointers that were gathered.
for (auto& block : blocks) {
// Probabilistically decide whether to try to split this block.
if (GetFuzzerContext()->GetRandomGenerator()->RandomPercentage() >
GetFuzzerContext()->GetChanceOfSplittingBlock()) {
continue;
}
// We are going to try to split this block. We now need to choose where
// to split it. We do this by finding a base instruction that has a
// result id, and an offset from that base instruction. We would like
// offsets to be as small as possible and ideally 0 - we only need offsets
// because not all instructions can be identified by a result id (e.g.
// OpStore instructions cannot).
std::vector<std::pair<uint32_t, uint32_t>> base_offset_pairs;
// The initial base instruction is the block label.
uint32_t base = block->id();
uint32_t offset = 0;
// Consider every instruction in the block. The label is excluded: it is
// only necessary to consider it as a base in case the first instruction
// in the block does not have a result id.
for (auto& inst : *block) {
if (inst.HasResultId()) {
// In the case that the instruction has a result id, we use the
// instruction as its own base, with zero offset.
base = inst.result_id();
offset = 0;
} else {
// The instruction does not have a result id, so we need to identify
// it via the latest instruction that did have a result id (base), and
// an incremented offset.
offset++;
}
base_offset_pairs.emplace_back(base, offset);
}
// Having identified all the places we might be able to split the block,
// we choose one of them.
auto base_offset = base_offset_pairs
[GetFuzzerContext()->GetRandomGenerator()->RandomUint32(
static_cast<uint32_t>(base_offset_pairs.size()))];
auto message = transformation::MakeTransformationSplitBlock(
base_offset.first, base_offset.second,
GetFuzzerContext()->GetFreshId());
// If the position we have chosen turns out to be a valid place to split
// the block, we apply the split. Otherwise the block just doesn't get
// split.
if (transformation::IsApplicable(message, GetIRContext(),
*GetFactManager())) {
transformation::Apply(message, GetIRContext(), GetFactManager());
*GetTransformations()->add_transformation()->mutable_split_block() =
message;
}
}
}
} // namespace fuzz
} // namespace spvtools

View File

@@ -0,0 +1,39 @@
// Copyright (c) 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SOURCE_FUZZ_FUZZER_PASS_SPLIT_BLOCKS_
#define SOURCE_FUZZ_FUZZER_PASS_SPLIT_BLOCKS_
#include "source/fuzz/fuzzer_pass.h"
namespace spvtools {
namespace fuzz {
// A fuzzer pass for splitting blocks in the module, to create more blocks; this
// can be very useful for giving other passes a chance to apply.
class FuzzerPassSplitBlocks : public FuzzerPass {
public:
FuzzerPassSplitBlocks(opt::IRContext* ir_context, FactManager* fact_manager,
FuzzerContext* fuzzer_context,
protobufs::TransformationSequence* transformations);
~FuzzerPassSplitBlocks() override;
void Apply() override;
};
} // namespace fuzz
} // namespace spvtools
#endif // #define SOURCE_FUZZ_FUZZER_PASS_SPLIT_BLOCKS_

View File

@@ -0,0 +1,36 @@
// Copyright (c) 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "source/fuzz/fuzzer_util.h"
namespace spvtools {
namespace fuzz {
namespace fuzzerutil {
bool IsFreshId(opt::IRContext* context, uint32_t id) {
return !context->get_def_use_mgr()->GetDef(id);
}
void UpdateModuleIdBound(opt::IRContext* context, uint32_t id) {
// TODO(https://github.com/KhronosGroup/SPIRV-Tools/issues/2541) consider the
// case where the maximum id bound is reached.
context->module()->SetIdBound(
std::max(context->module()->id_bound(), id + 1));
}
} // namespace fuzzerutil
} // namespace fuzz
} // namespace spvtools

View File

@@ -0,0 +1,38 @@
// Copyright (c) 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SOURCE_FUZZ_FUZZER_UTIL_H_
#define SOURCE_FUZZ_FUZZER_UTIL_H_
#include "source/opt/ir_context.h"
namespace spvtools {
namespace fuzz {
// Provides global utility methods for use by the fuzzer
namespace fuzzerutil {
// Returns true if and only if the module does not define the given id.
bool IsFreshId(opt::IRContext* context, uint32_t id);
// Updates the module's id bound if needed so that it is large enough to
// account for the given id.
void UpdateModuleIdBound(opt::IRContext* context, uint32_t id);
} // namespace fuzzerutil
} // namespace fuzz
} // namespace spvtools
#endif // SOURCE_FUZZ_FUZZER_UTIL_H_

View File

@@ -0,0 +1,82 @@
// Copyright (c) 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "source/fuzz/id_use_descriptor.h"
namespace spvtools {
namespace fuzz {
opt::Instruction* transformation::FindInstruction(
const protobufs::IdUseDescriptor& descriptor,
spvtools::opt::IRContext* context) {
for (auto& function : *context->module()) {
for (auto& block : function) {
bool found_base = block.id() == descriptor.base_instruction_result_id();
uint32_t num_ignored = 0;
for (auto& instruction : block) {
if (instruction.HasResultId() &&
instruction.result_id() ==
descriptor.base_instruction_result_id()) {
assert(!found_base &&
"It should not be possible to find the base instruction "
"multiple times.");
found_base = true;
assert(num_ignored == 0 &&
"The skipped instruction count should only be incremented "
"after the instruction base has been found.");
}
if (found_base &&
instruction.opcode() == descriptor.target_instruction_opcode()) {
if (num_ignored == descriptor.num_opcodes_to_ignore()) {
if (descriptor.in_operand_index() >= instruction.NumInOperands()) {
return nullptr;
}
auto in_operand =
instruction.GetInOperand(descriptor.in_operand_index());
if (in_operand.type != SPV_OPERAND_TYPE_ID) {
return nullptr;
}
if (in_operand.words[0] != descriptor.id_of_interest()) {
return nullptr;
}
return &instruction;
}
num_ignored++;
}
}
if (found_base) {
// We found the base instruction, but did not find the target
// instruction in the same block.
return nullptr;
}
}
}
return nullptr;
}
protobufs::IdUseDescriptor transformation::MakeIdUseDescriptor(
uint32_t id_of_interest, SpvOp target_instruction_opcode,
uint32_t in_operand_index, uint32_t base_instruction_result_id,
uint32_t num_opcodes_to_ignore) {
protobufs::IdUseDescriptor result;
result.set_id_of_interest(id_of_interest);
result.set_target_instruction_opcode(target_instruction_opcode);
result.set_in_operand_index(in_operand_index);
result.set_base_instruction_result_id(base_instruction_result_id);
result.set_num_opcodes_to_ignore(num_opcodes_to_ignore);
return result;
}
} // namespace fuzz
} // namespace spvtools

View File

@@ -0,0 +1,42 @@
// Copyright (c) 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SOURCE_FUZZ_ID_USE_LOCATOR_H_
#define SOURCE_FUZZ_ID_USE_LOCATOR_H_
#include "source/fuzz/protobufs/spirvfuzz_protobufs.h"
#include "source/opt/ir_context.h"
namespace spvtools {
namespace fuzz {
namespace transformation {
// Looks for an instruction in |context| such that the id use represented by
// |descriptor| is one of the operands to said instruction. Returns |nullptr|
// if no such instruction can be found.
opt::Instruction* FindInstruction(const protobufs::IdUseDescriptor& descriptor,
opt::IRContext* context);
// Creates an IdUseDescriptor protobuf message from the given components.
// See the protobuf definition for details of what these components mean.
protobufs::IdUseDescriptor MakeIdUseDescriptor(
uint32_t id_of_interest, SpvOp target_instruction_opcode,
uint32_t in_operand_index, uint32_t base_instruction_result_id,
uint32_t num_opcodes_to_ignore);
} // namespace transformation
} // namespace fuzz
} // namespace spvtools
#endif // SOURCE_FUZZ_ID_USE_LOCATOR_H_

View File

@@ -0,0 +1,52 @@
// Copyright (c) 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SOURCE_FUZZ_SPIRVFUZZ_PROTOBUFS_H_
#define SOURCE_FUZZ_SPIRVFUZZ_PROTOBUFS_H_
// This header file serves to act as a barrier between the protobuf header
// files and files that include them. It uses compiler pragmas to disable
// diagnostics, in order to ignore warnings generated during the processing
// of these header files without having to compromise on freedom from warnings
// in the rest of the project.
#if defined(__clang__)
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wunused-parameter"
#elif defined(__GNUC__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wconversion"
#pragma GCC diagnostic ignored "-Wshadow"
#pragma GCC diagnostic ignored "-Wunused-parameter"
#elif defined(_MSC_VER)
#pragma warning(push)
#pragma warning(disable : 4244)
#endif
// The following should be the only place in the project where protobuf files
// are directly included. This is so that they can be compiled in a manner
// where warnings are ignored.
#include "google/protobuf/util/json_util.h"
#include "source/fuzz/protobufs/spvtoolsfuzz.pb.h"
#if defined(__clang__)
#pragma clang diagnostic pop
#elif defined(__GNUC__)
#pragma GCC diagnostic pop
#elif defined(_MSC_VER)
#pragma warning(pop)
#endif
#endif // SOURCE_FUZZ_SPIRVFUZZ_PROTOBUFS_H_

View File

@@ -0,0 +1,226 @@
// Copyright (c) 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// This file is specifically named spvtools_fuzz.proto so that the string
// 'spvtools_fuzz' appears in the names of global-scope symbols that protoc
// generates when targeting C++. This is to reduce the potential for name
// clashes with other globally-scoped symbols.
syntax = "proto3";
package spvtools.fuzz.protobufs;
message IdUseDescriptor {
// Describes a use of an id as an input operand to an instruction in some block
// of a function.
// Example:
// - id_of_interest = 42
// - target_instruction_opcode = OpStore
// - in_operand_index = 1
// - base_instruction_result_id = 50
// - num_opcodes_to_ignore = 7
// represents a use of id 42 as input operand 1 to an OpStore instruction,
// such that the OpStore instruction can be found in the same basic block as
// the instruction with result id 50, and in particular is the 8th OpStore
// instruction found from instruction 50 onwards (i.e. 7 OpStore
// instructions are skipped).
// An id that we would like to be able to find a use of.
uint32 id_of_interest = 1;
// The opcode for the instruction that uses the id.
uint32 target_instruction_opcode = 2;
// The input operand index at which the use is expected.
uint32 in_operand_index = 3;
// The id of an instruction after which the instruction that contains the use
// is believed to occur; it might be the using instruction itself.
uint32 base_instruction_result_id = 4;
// The number of matching opcodes to skip over when searching for the using
// instruction from the base instruction.
uint32 num_opcodes_to_ignore = 5;
}
message FactSequence {
repeated Fact fact = 1;
}
message Fact {
// Currently there are no facts.
}
message TransformationSequence {
repeated Transformation transformation = 1;
}
message Transformation {
oneof transformation {
// Order the transformation options by numeric id (rather than
// alphabetically).
TransformationMoveBlockDown move_block_down = 1;
TransformationSplitBlock split_block = 2;
TransformationAddConstantBoolean add_constant_boolean = 3;
TransformationAddConstantScalar add_constant_scalar = 4;
TransformationAddTypeBoolean add_type_boolean = 5;
TransformationAddTypeFloat add_type_float = 6;
TransformationAddTypeInt add_type_int = 7;
TransformationAddDeadBreak add_dead_break = 8;
TransformationReplaceBooleanConstantWithConstantBinary replace_boolean_constant_with_constant_binary = 9;
// Add additional option using the next available number.
}
}
// Keep transformation message types in alphabetical order:
message TransformationAddConstantBoolean {
// Supports adding the constants true and false to a module, which may be
// necessary in order to enable other transformations if they are not present.
uint32 fresh_id = 1;
bool is_true = 2;
}
message TransformationAddConstantScalar {
// Adds a constant of the given scalar type
// Id for the constant
uint32 fresh_id = 1;
// Id for the scalar type of the constant
uint32 type_id = 2;
// Value of the constant
repeated uint32 word = 3;
}
message TransformationAddDeadBreak {
// A transformation that turns a basic block that unconditionally branches to
// its successor into a block that potentially breaks out of a structured
// control flow construct, but in such a manner that the break cannot actually
// be taken.
// The block to break from
uint32 from_block = 1;
// The merge block to break to
uint32 to_block = 2;
// Determines whether the break condition is true or false
bool break_condition_value = 3;
// A sequence of ids suitable for extending OpPhi instructions as a result of
// the new break edge
repeated uint32 phi_id = 4;
}
message TransformationAddTypeBoolean {
// Adds OpTypeBool to the module
// Id to be used for the type
uint32 fresh_id = 1;
}
message TransformationAddTypeFloat {
// Adds OpTypeFloat to the module with the given width
// Id to be used for the type
uint32 fresh_id = 1;
// Floating-point width
uint32 width = 2;
}
message TransformationAddTypeInt {
// Adds OpTypeInt to the module with the given width and signedness
// Id to be used for the type
uint32 fresh_id = 1;
// Integer width
uint32 width = 2;
// True if and only if this is a signed type
bool is_signed = 3;
}
message TransformationMoveBlockDown {
// A transformation that moves a basic block to be one position lower in
// program order.
// The id of the block to move down.
uint32 block_id = 1;
}
message TransformationReplaceBooleanConstantWithConstantBinary {
// A transformation to capture replacing a use of a boolean constant with
// binary operation on two constant values
// A descriptor for the boolean constant id we would like to replace
IdUseDescriptor id_use_descriptor = 1;
// Id for the constant to be used on the LHS of the comparision
uint32 lhs_id = 2;
// Id for the constant to be used on the RHS of the comparision
uint32 rhs_id = 3;
// Opcode for binary operator
uint32 opcode = 4;
// Id that will store the result of the binary operation instruction
uint32 fresh_id_for_binary_operation = 5;
}
message TransformationSplitBlock {
// A transformation that splits a basic block into two basic blocks.
// The result id of an instruction.
uint32 result_id = 1;
// An offset, such that the block containing |result_id_| should be split
// right before the instruction |offset_| instructions after |result_id_|.
uint32 offset = 2;
// An id that must not yet be used by the module to which this transformation
// is applied. Rather than having the transformation choose a suitable id on
// application, we require the id to be given upfront in order to facilitate
// reducing fuzzed shaders by removing transformations. The reason is that
// future transformations may refer to the fresh id introduced by this
// transformation, and if we end up changing what that id is, due to removing
// earlier transformations, it may inhibit later transformations from
// applying.
uint32 fresh_id = 3;
}

View File

@@ -0,0 +1,47 @@
// Copyright (c) 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "source/fuzz/pseudo_random_generator.h"
#include <cassert>
namespace spvtools {
namespace fuzz {
PseudoRandomGenerator::PseudoRandomGenerator(uint32_t seed) : mt_(seed) {}
PseudoRandomGenerator::~PseudoRandomGenerator() = default;
uint32_t PseudoRandomGenerator::RandomUint32(uint32_t bound) {
assert(bound > 0 && "Bound must be positive");
return static_cast<uint32_t>(
std::uniform_int_distribution<>(0, bound - 1)(mt_));
}
bool PseudoRandomGenerator::RandomBool() {
return static_cast<bool>(std::uniform_int_distribution<>(0, 1)(mt_));
}
uint32_t PseudoRandomGenerator::RandomPercentage() {
// We use 101 because we want a result in the closed interval [0, 100], and
// RandomUint32 is not inclusive of its bound.
return RandomUint32(101);
}
double PseudoRandomGenerator::RandomDouble() {
return std::uniform_real_distribution<double>(0.0, 1.0)(mt_);
}
} // namespace fuzz
} // namespace spvtools

View File

@@ -0,0 +1,47 @@
// Copyright (c) 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SOURCE_FUZZ_PSEUDO_RANDOM_GENERATOR_H_
#define SOURCE_FUZZ_PSEUDO_RANDOM_GENERATOR_H_
#include <random>
#include "source/fuzz/random_generator.h"
namespace spvtools {
namespace fuzz {
// Generates random data from a pseudo-random number generator.
class PseudoRandomGenerator : public RandomGenerator {
public:
explicit PseudoRandomGenerator(uint32_t seed);
~PseudoRandomGenerator() override;
uint32_t RandomUint32(uint32_t bound) override;
uint32_t RandomPercentage() override;
bool RandomBool() override;
double RandomDouble() override;
private:
std::mt19937 mt_;
};
} // namespace fuzz
} // namespace spvtools
#endif // SOURCE_FUZZ_PSEUDO_RANDOM_GENERATOR_H_

View File

@@ -1,4 +1,4 @@
// Copyright (c) 2017 Google Inc.
// Copyright (c) 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -12,26 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef TOOLS_COMP_MARKV_MODEL_FACTORY_H_
#define TOOLS_COMP_MARKV_MODEL_FACTORY_H_
#include <memory>
#include "source/comp/markv_model.h"
#include "source/fuzz/random_generator.h"
namespace spvtools {
namespace comp {
namespace fuzz {
enum MarkvModelType {
kMarkvModelUnknown = 0,
kMarkvModelShaderLite,
kMarkvModelShaderMid,
kMarkvModelShaderMax,
};
RandomGenerator::RandomGenerator() = default;
std::unique_ptr<MarkvModel> CreateMarkvModel(MarkvModelType type);
RandomGenerator::~RandomGenerator() = default;
} // namespace comp
} // namespace fuzz
} // namespace spvtools
#endif // TOOLS_COMP_MARKV_MODEL_FACTORY_H_

View File

@@ -0,0 +1,45 @@
// Copyright (c) 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SOURCE_FUZZ_RANDOM_GENERATOR_H_
#define SOURCE_FUZZ_RANDOM_GENERATOR_H_
#include <stdint.h>
namespace spvtools {
namespace fuzz {
class RandomGenerator {
public:
RandomGenerator();
virtual ~RandomGenerator();
// Returns a value in the half-open interval [0, bound).
virtual uint32_t RandomUint32(uint32_t bound) = 0;
// Returns a value in the closed interval [0, 100].
virtual uint32_t RandomPercentage() = 0;
// Returns a boolean.
virtual bool RandomBool() = 0;
// Returns a double in the closed interval [0, 1]
virtual double RandomDouble() = 0;
};
} // namespace fuzz
} // namespace spvtools
#endif // SOURCE_FUZZ_RANDOM_GENERATOR_H_

View File

@@ -0,0 +1,60 @@
// Copyright (c) 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "source/fuzz/transformation_add_constant_boolean.h"
#include "source/fuzz/fuzzer_util.h"
#include "source/opt/types.h"
namespace spvtools {
namespace fuzz {
namespace transformation {
using opt::IRContext;
bool IsApplicable(const protobufs::TransformationAddConstantBoolean& message,
IRContext* context, const FactManager& /*unused*/) {
opt::analysis::Bool bool_type;
if (!context->get_type_mgr()->GetId(&bool_type)) {
// No OpTypeBool is present.
return false;
}
return fuzzerutil::IsFreshId(context, message.fresh_id());
}
void Apply(const protobufs::TransformationAddConstantBoolean& message,
IRContext* context, FactManager* /*unused*/) {
opt::analysis::Bool bool_type;
// Add the boolean constant to the module, ensuring the module's id bound is
// high enough.
fuzzerutil::UpdateModuleIdBound(context, message.fresh_id());
context->module()->AddGlobalValue(
message.is_true() ? SpvOpConstantTrue : SpvOpConstantFalse,
message.fresh_id(), context->get_type_mgr()->GetId(&bool_type));
// We have added an instruction to the module, so need to be careful about the
// validity of existing analyses.
context->InvalidateAnalysesExceptFor(IRContext::Analysis::kAnalysisNone);
}
protobufs::TransformationAddConstantBoolean
MakeTransformationAddConstantBoolean(uint32_t fresh_id, bool is_true) {
protobufs::TransformationAddConstantBoolean result;
result.set_fresh_id(fresh_id);
result.set_is_true(is_true);
return result;
}
} // namespace transformation
} // namespace fuzz
} // namespace spvtools

View File

@@ -0,0 +1,44 @@
// Copyright (c) 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SOURCE_FUZZ_TRANSFORMATION_ADD_BOOLEAN_CONSTANT_H_
#define SOURCE_FUZZ_TRANSFORMATION_ADD_BOOLEAN_CONSTANT_H_
#include "source/fuzz/fact_manager.h"
#include "source/fuzz/protobufs/spirvfuzz_protobufs.h"
#include "source/opt/ir_context.h"
namespace spvtools {
namespace fuzz {
namespace transformation {
// - |fresh_id| must not be used by the module.
// - The module must already contain OpTypeBool.
bool IsApplicable(const protobufs::TransformationAddConstantBoolean& message,
opt::IRContext* context, const FactManager& fact_manager);
// - Adds OpConstantTrue (OpConstantFalse) to the module with id |fresh_id|
// if |is_true| holds (does not hold).
void Apply(const protobufs::TransformationAddConstantBoolean& message,
opt::IRContext* context, FactManager* fact_manager);
// Helper factory to create a transformation message.
protobufs::TransformationAddConstantBoolean
MakeTransformationAddConstantBoolean(uint32_t fresh_id, bool is_true);
} // namespace transformation
} // namespace fuzz
} // namespace spvtools
#endif // SOURCE_FUZZ_TRANSFORMATION_ADD_BOOLEAN_CONSTANT_H_

View File

@@ -0,0 +1,83 @@
// Copyright (c) 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "source/fuzz/transformation_add_constant_scalar.h"
#include "source/fuzz/fuzzer_util.h"
namespace spvtools {
namespace fuzz {
namespace transformation {
using opt::IRContext;
bool IsApplicable(const protobufs::TransformationAddConstantScalar& message,
IRContext* context,
const spvtools::fuzz::FactManager& /*unused*/) {
// The id needs to be fresh.
if (!fuzzerutil::IsFreshId(context, message.fresh_id())) {
return false;
}
// The type id for the scalar must exist and be a type.
auto type = context->get_type_mgr()->GetType(message.type_id());
if (!type) {
return false;
}
uint32_t width;
if (type->AsFloat()) {
width = type->AsFloat()->width();
} else if (type->AsInteger()) {
width = type->AsInteger()->width();
} else {
return false;
}
// The number of words is the integer floor of the width.
auto words = (width + 32 - 1) / 32;
// The number of words provided by the transformation needs to match the
// width of the type.
return static_cast<uint32_t>(message.word().size()) == words;
}
void Apply(const protobufs::TransformationAddConstantScalar& message,
IRContext* context, spvtools::fuzz::FactManager* /*unused*/) {
opt::Instruction::OperandList operand_list;
for (auto word : message.word()) {
operand_list.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {word}});
}
context->module()->AddGlobalValue(
MakeUnique<opt::Instruction>(context, SpvOpConstant, message.type_id(),
message.fresh_id(), operand_list));
fuzzerutil::UpdateModuleIdBound(context, message.fresh_id());
// We have added an instruction to the module, so need to be careful about the
// validity of existing analyses.
context->InvalidateAnalysesExceptFor(IRContext::Analysis::kAnalysisNone);
}
protobufs::TransformationAddConstantScalar MakeTransformationAddConstantScalar(
uint32_t fresh_id, uint32_t type_id, std::vector<uint32_t> words) {
protobufs::TransformationAddConstantScalar result;
result.set_fresh_id(fresh_id);
result.set_type_id(type_id);
for (auto word : words) {
result.add_word(word);
}
return result;
}
} // namespace transformation
} // namespace fuzz
} // namespace spvtools

View File

@@ -0,0 +1,46 @@
// Copyright (c) 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SOURCE_FUZZ_TRANSFORMATION_ADD_CONSTANT_SCALAR_H_
#define SOURCE_FUZZ_TRANSFORMATION_ADD_CONSTANT_SCALAR_H_
#include <vector>
#include "source/fuzz/fact_manager.h"
#include "source/fuzz/protobufs/spirvfuzz_protobufs.h"
#include "source/opt/ir_context.h"
namespace spvtools {
namespace fuzz {
namespace transformation {
// - |message.fresh_id| must not be used by the module
// - |message.type_id| must be the id of a floating-point or integer type
// - The size of |message.word| must be compatible with the width of this type
bool IsApplicable(const protobufs::TransformationAddConstantScalar& message,
opt::IRContext* context, const FactManager& fact_manager);
// Adds a new OpConstant instruction with the given type and words.
void Apply(const protobufs::TransformationAddConstantScalar& message,
opt::IRContext* context, FactManager* fact_manager);
// Helper factory to create a transformation message.
protobufs::TransformationAddConstantScalar MakeTransformationAddConstantScalar(
uint32_t fresh_id, uint32_t type_id, std::vector<uint32_t> words);
} // namespace transformation
} // namespace fuzz
} // namespace spvtools
#endif // SOURCE_FUZZ_TRANSFORMATION_ADD_CONSTANT_SCALAR_H_

View File

@@ -0,0 +1,316 @@
// Copyright (c) 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "source/fuzz/transformation_add_dead_break.h"
#include "source/fuzz/fact_manager.h"
#include "source/opt/basic_block.h"
#include "source/opt/ir_context.h"
#include "source/opt/struct_cfg_analysis.h"
namespace spvtools {
namespace fuzz {
namespace transformation {
using opt::BasicBlock;
using opt::IRContext;
using opt::Instruction;
namespace {
BasicBlock* MaybeFindBlock(IRContext* context, uint32_t maybe_block_id) {
auto inst = context->get_def_use_mgr()->GetDef(maybe_block_id);
if (inst == nullptr) {
// No instruction defining this id was found.
return nullptr;
}
if (inst->opcode() != SpvOpLabel) {
// The instruction defining the id is not a label, so it cannot be a block
// id.
return nullptr;
}
return context->cfg()->block(maybe_block_id);
}
bool PhiIdsOk(const protobufs::TransformationAddDeadBreak& message,
IRContext* context, BasicBlock* bb_from, BasicBlock* bb_to) {
if (bb_from->IsSuccessor(bb_to)) {
// There is already an edge from |from_block| to |to_block|, so there is
// no need to extend OpPhi instructions. Do not allow phi ids to be
// present. This might turn out to be too strict; perhaps it would be OK
// just to ignore the ids in this case.
return message.phi_id().empty();
}
// The break would add a previously non-existent edge from |from_block| to
// |to_block|, so we go through the given phi ids and check that they exactly
// match the OpPhi instructions in |to_block|.
uint32_t phi_index = 0;
// An explicit loop, rather than applying a lambda to each OpPhi in |bb_to|,
// makes sense here because we need to increment |phi_index| for each OpPhi
// instruction.
for (auto& inst : *bb_to) {
if (inst.opcode() != SpvOpPhi) {
// The OpPhi instructions all occur at the start of the block; if we find
// a non-OpPhi then we have seen them all.
break;
}
if (phi_index == static_cast<uint32_t>(message.phi_id().size())) {
// Not enough phi ids have been provided to account for the OpPhi
// instructions.
return false;
}
// Look for an instruction defining the next phi id.
Instruction* phi_extension =
context->get_def_use_mgr()->GetDef(message.phi_id()[phi_index]);
if (!phi_extension) {
// The id given to extend this OpPhi does not exist.
return false;
}
if (phi_extension->type_id() != inst.type_id()) {
// The instruction given to extend this OpPhi either does not have a type
// or its type does not match that of the OpPhi.
return false;
}
if (context->get_instr_block(phi_extension)) {
// The instruction defining the phi id has an associated block (i.e., it
// is not a global value). Check whether its definition dominates the
// exit of |from_block|.
auto dominator_analysis =
context->GetDominatorAnalysis(bb_from->GetParent());
if (!dominator_analysis->Dominates(phi_extension,
bb_from->terminator())) {
// The given id is no good as its definition does not dominate the exit
// of |from_block|
return false;
}
}
phi_index++;
}
// Reject the transformation if not all of the ids for extending OpPhi
// instructions are needed. This might turn out to be stricter than necessary;
// perhaps it would be OK just to not use the ids in this case.
return phi_index == static_cast<uint32_t>(message.phi_id().size());
}
bool FromBlockIsInLoopContinueConstruct(
const protobufs::TransformationAddDeadBreak& message, IRContext* context,
uint32_t maybe_loop_header) {
// We deem a block to be part of a loop's continue construct if the loop's
// continue target dominates the block.
auto containing_construct_block = context->cfg()->block(maybe_loop_header);
if (containing_construct_block->IsLoopHeader()) {
auto continue_target = containing_construct_block->ContinueBlockId();
if (context->GetDominatorAnalysis(containing_construct_block->GetParent())
->Dominates(continue_target, message.from_block())) {
return true;
}
}
return false;
}
bool AddingBreakRespectsStructuredControlFlow(
const protobufs::TransformationAddDeadBreak& message, IRContext* context,
BasicBlock* bb_from) {
// Look at the structured control flow associated with |from_block| and
// check whether it is contained in an appropriate construct with merge id
// |to_block| such that a break from |from_block| to |to_block| is legal.
// There are three legal cases to consider:
// (1) |from_block| is a loop header and |to_block| is its merge
// (2) |from_block| is a non-header node of a construct, and |to_block|
// is the merge for that construct
// (3) |from_block| is a non-header node of a selection construct, and
// |to_block| is the merge for the innermost loop containing
// |from_block|
//
// TODO(https://github.com/KhronosGroup/SPIRV-Tools/issues/2653) It may be
// possible to be more aggressive in breaking from switch constructs.
//
// The reason we need to distinguish between cases (1) and (2) is that the
// structured CFG analysis does not deem a header to be part of the construct
// that it heads.
// Consider case (1)
if (bb_from->IsLoopHeader()) {
// Case (1) holds if |to_block| is the merge block for the loop;
// otherwise no case holds
return bb_from->MergeBlockId() == message.to_block();
}
// Both cases (2) and (3) require that |from_block| is inside some
// structured control flow construct.
auto containing_construct =
context->GetStructuredCFGAnalysis()->ContainingConstruct(
message.from_block());
if (!containing_construct) {
// |from_block| is not in a construct from which we can break.
return false;
}
// Consider case (2)
if (message.to_block() ==
context->cfg()->block(containing_construct)->MergeBlockId()) {
// This looks like an instance of case (2).
// However, the structured CFG analysis regards the continue construct of a
// loop as part of the loop, but it is not legal to jump from a loop's
// continue construct to the loop's merge (except from the back-edge block),
// so we need to check for this case.
//
// TODO(https://github.com/KhronosGroup/SPIRV-Tools/issues/2577): We do not
// currently allow a dead break from a back edge block, but we could and
// ultimately should.
return !FromBlockIsInLoopContinueConstruct(message, context,
containing_construct);
}
// Case (3) holds if and only if |to_block| is the merge block for this
// innermost loop that contains |from_block|
auto containing_loop_header =
context->GetStructuredCFGAnalysis()->ContainingLoop(message.from_block());
if (containing_loop_header &&
message.to_block() ==
context->cfg()->block(containing_loop_header)->MergeBlockId()) {
return !FromBlockIsInLoopContinueConstruct(message, context,
containing_loop_header);
}
return false;
}
} // namespace
bool IsApplicable(const protobufs::TransformationAddDeadBreak& message,
IRContext* context, const FactManager& /*unused*/) {
// First, we check that a constant with the same value as
// |break_condition_value| is present.
opt::analysis::Bool bool_type;
auto registered_bool_type =
context->get_type_mgr()->GetRegisteredType(&bool_type);
if (!registered_bool_type) {
return false;
}
opt::analysis::BoolConstant bool_constant(registered_bool_type->AsBool(),
message.break_condition_value());
if (!context->get_constant_mgr()->FindConstant(&bool_constant)) {
// The required constant is not present, so the transformation cannot be
// applied.
return false;
}
// Check that |from_block| and |to_block| really are block ids
BasicBlock* bb_from = MaybeFindBlock(context, message.from_block());
if (bb_from == nullptr) {
return false;
}
BasicBlock* bb_to = MaybeFindBlock(context, message.to_block());
if (bb_to == nullptr) {
return false;
}
// Check that |from_block| ends with an unconditional branch.
if (bb_from->terminator()->opcode() != SpvOpBranch) {
// The block associated with the id does not end with an unconditional
// branch.
return false;
}
assert(bb_from != nullptr &&
"We should have found a block if this line of code is reached.");
assert(
bb_from->id() == message.from_block() &&
"The id of the block we found should match the source id for the break.");
assert(bb_to != nullptr &&
"We should have found a block if this line of code is reached.");
assert(
bb_to->id() == message.to_block() &&
"The id of the block we found should match the target id for the break.");
// Check whether the data passed to extend OpPhi instructions is appropriate.
if (!PhiIdsOk(message, context, bb_from, bb_to)) {
return false;
}
// Finally, check that adding the break would respect the rules of structured
// control flow.
return AddingBreakRespectsStructuredControlFlow(message, context, bb_from);
}
void Apply(const protobufs::TransformationAddDeadBreak& message,
IRContext* context, FactManager* /*unused*/) {
// Get the id of the boolean constant to be used as the break condition.
opt::analysis::Bool bool_type;
opt::analysis::BoolConstant bool_constant(
context->get_type_mgr()->GetRegisteredType(&bool_type)->AsBool(),
message.break_condition_value());
uint32_t bool_id = context->get_constant_mgr()->FindDeclaredConstant(
&bool_constant, context->get_type_mgr()->GetId(&bool_type));
auto bb_from = context->cfg()->block(message.from_block());
auto bb_to = context->cfg()->block(message.to_block());
const bool from_to_edge_already_exists = bb_from->IsSuccessor(bb_to);
auto successor = bb_from->terminator()->GetSingleWordInOperand(0);
assert(bb_from->terminator()->opcode() == SpvOpBranch &&
"Precondition for the transformation requires that the source block "
"ends with OpBranch");
// Add the dead break, by turning OpBranch into OpBranchConditional, and
// ordering the targets depending on whether the given boolean corresponds to
// true or false.
bb_from->terminator()->SetOpcode(SpvOpBranchConditional);
bb_from->terminator()->SetInOperands(
{{SPV_OPERAND_TYPE_ID, {bool_id}},
{SPV_OPERAND_TYPE_ID,
{message.break_condition_value() ? successor : message.to_block()}},
{SPV_OPERAND_TYPE_ID,
{message.break_condition_value() ? message.to_block() : successor}}});
// Update OpPhi instructions in the target block if this break adds a
// previously non-existent edge from source to target.
if (!from_to_edge_already_exists) {
uint32_t phi_index = 0;
for (auto& inst : *bb_to) {
if (inst.opcode() != SpvOpPhi) {
break;
}
assert(phi_index < static_cast<uint32_t>(message.phi_id().size()) &&
"There should be exactly one phi id per OpPhi instruction.");
inst.AddOperand({SPV_OPERAND_TYPE_ID, {message.phi_id()[phi_index]}});
inst.AddOperand({SPV_OPERAND_TYPE_ID, {message.from_block()}});
phi_index++;
}
assert(phi_index == static_cast<uint32_t>(message.phi_id().size()) &&
"There should be exactly one phi id per OpPhi instruction.");
}
// Invalidate all analyses
context->InvalidateAnalysesExceptFor(IRContext::Analysis::kAnalysisNone);
}
protobufs::TransformationAddDeadBreak MakeTransformationAddDeadBreak(
uint32_t from_block, uint32_t to_block, bool break_condition_value,
std::vector<uint32_t> phi_id) {
protobufs::TransformationAddDeadBreak result;
result.set_from_block(from_block);
result.set_to_block(to_block);
result.set_break_condition_value(break_condition_value);
for (auto id : phi_id) {
result.add_phi_id(id);
}
return result;
}
} // namespace transformation
} // namespace fuzz
} // namespace spvtools

View File

@@ -0,0 +1,61 @@
// Copyright (c) 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SOURCE_FUZZ_TRANSFORMATION_ADD_DEAD_BREAK_H_
#define SOURCE_FUZZ_TRANSFORMATION_ADD_DEAD_BREAK_H_
#include <vector>
#include "source/fuzz/fact_manager.h"
#include "source/fuzz/protobufs/spirvfuzz_protobufs.h"
#include "source/opt/ir_context.h"
namespace spvtools {
namespace fuzz {
namespace transformation {
// - |message.from_block| must be the id of a block a in the given module.
// - |message.to_block| must be the id of a block b in the given module.
// - if |message.break_condition_value| holds (does not hold) then
// OpConstantTrue (OpConstantFalse) must be present in the module
// - |message.phi_ids| must be a list of ids that are all available at
// |message.from_block|
// - a and b must be in the same function.
// - b must be a merge block.
// - a must end with an unconditional branch to some block c.
// - replacing this branch with a conditional branch to b or c, with
// the boolean constant associated with |message.break_condition_value| as
// the condition, and the ids in |message.phi_ids| used to extend
// any OpPhi instructions at b as a result of the edge from a, must
// maintain validity of the module.
bool IsApplicable(const protobufs::TransformationAddDeadBreak& message,
opt::IRContext* context, const FactManager& fact_manager);
// Replaces the terminator of a with a conditional branch to b or c.
// The boolean constant associated with |message.break_condition_value| is used
// as the condition, and the order of b and c is arranged such that control is
// guaranteed to jump to c.
void Apply(const protobufs::TransformationAddDeadBreak& message,
opt::IRContext* context, FactManager* fact_manager);
// Helper factory to create a transformation message.
protobufs::TransformationAddDeadBreak MakeTransformationAddDeadBreak(
uint32_t from_block, uint32_t to_block, bool break_condition_value,
std::vector<uint32_t> phi_ids);
} // namespace transformation
} // namespace fuzz
} // namespace spvtools
#endif // SOURCE_FUZZ_TRANSFORMATION_ADD_DEAD_BREAK_H_

View File

@@ -0,0 +1,58 @@
// Copyright (c) 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "source/fuzz/transformation_add_type_boolean.h"
#include "source/fuzz/fuzzer_util.h"
namespace spvtools {
namespace fuzz {
namespace transformation {
using opt::IRContext;
bool IsApplicable(const protobufs::TransformationAddTypeBoolean& message,
IRContext* context,
const spvtools::fuzz::FactManager& /*unused*/) {
// The id must be fresh.
if (!fuzzerutil::IsFreshId(context, message.fresh_id())) {
return false;
}
// Applicable if there is no bool type already declared in the module.
opt::analysis::Bool bool_type;
return context->get_type_mgr()->GetId(&bool_type) == 0;
}
void Apply(const protobufs::TransformationAddTypeBoolean& message,
IRContext* context, spvtools::fuzz::FactManager* /*unused*/) {
opt::Instruction::OperandList empty_operands;
context->module()->AddType(MakeUnique<opt::Instruction>(
context, SpvOpTypeBool, 0, message.fresh_id(), empty_operands));
fuzzerutil::UpdateModuleIdBound(context, message.fresh_id());
// We have added an instruction to the module, so need to be careful about the
// validity of existing analyses.
context->InvalidateAnalysesExceptFor(IRContext::Analysis::kAnalysisNone);
}
protobufs::TransformationAddTypeBoolean MakeTransformationAddTypeBoolean(
uint32_t fresh_id) {
protobufs::TransformationAddTypeBoolean result;
result.set_fresh_id(fresh_id);
return result;
}
} // namespace transformation
} // namespace fuzz
} // namespace spvtools

View File

@@ -0,0 +1,43 @@
// Copyright (c) 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SOURCE_FUZZ_TRANSFORMATION_ADD_TYPE_BOOLEAN_H_
#define SOURCE_FUZZ_TRANSFORMATION_ADD_TYPE_BOOLEAN_H_
#include "source/fuzz/fact_manager.h"
#include "source/fuzz/protobufs/spirvfuzz_protobufs.h"
#include "source/opt/ir_context.h"
namespace spvtools {
namespace fuzz {
namespace transformation {
// - |message.fresh_id| must not be used by the module.
// - The module must not yet declare OpTypeBoolean
bool IsApplicable(const protobufs::TransformationAddTypeBoolean& message,
opt::IRContext* context, const FactManager& fact_manager);
// Adds OpTypeBoolean with |message.fresh_id| as result id.
void Apply(const protobufs::TransformationAddTypeBoolean& message,
opt::IRContext* context, FactManager* fact_manager);
// Helper factory to create a transformation message.
protobufs::TransformationAddTypeBoolean MakeTransformationAddTypeBoolean(
uint32_t fresh_id);
} // namespace transformation
} // namespace fuzz
} // namespace spvtools
#endif // SOURCE_FUZZ_TRANSFORMATION_ADD_TYPE_BOOLEAN_H_

View File

@@ -0,0 +1,61 @@
// Copyright (c) 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "source/fuzz/transformation_add_type_float.h"
#include "source/fuzz/fuzzer_util.h"
namespace spvtools {
namespace fuzz {
namespace transformation {
using opt::IRContext;
bool IsApplicable(const protobufs::TransformationAddTypeFloat& message,
IRContext* context,
const spvtools::fuzz::FactManager& /*unused*/) {
// The id must be fresh.
if (!fuzzerutil::IsFreshId(context, message.fresh_id())) {
return false;
}
// Applicable if there is no float type with this width already declared in
// the module.
opt::analysis::Float float_type(message.width());
return context->get_type_mgr()->GetId(&float_type) == 0;
}
void Apply(const protobufs::TransformationAddTypeFloat& message,
IRContext* context, spvtools::fuzz::FactManager* /*unused*/) {
opt::Instruction::OperandList width = {
{SPV_OPERAND_TYPE_LITERAL_INTEGER, {message.width()}}};
context->module()->AddType(MakeUnique<opt::Instruction>(
context, SpvOpTypeFloat, 0, message.fresh_id(), width));
fuzzerutil::UpdateModuleIdBound(context, message.fresh_id());
// We have added an instruction to the module, so need to be careful about the
// validity of existing analyses.
context->InvalidateAnalysesExceptFor(IRContext::Analysis::kAnalysisNone);
}
protobufs::TransformationAddTypeFloat MakeTransformationAddTypeFloat(
uint32_t fresh_id, uint32_t width) {
protobufs::TransformationAddTypeFloat result;
result.set_fresh_id(fresh_id);
result.set_width(width);
return result;
}
} // namespace transformation
} // namespace fuzz
} // namespace spvtools

View File

@@ -0,0 +1,44 @@
// Copyright (c) 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SOURCE_FUZZ_TRANSFORMATION_ADD_TYPE_FLOAT_H_
#define SOURCE_FUZZ_TRANSFORMATION_ADD_TYPE_FLOAT_H_
#include "source/fuzz/fact_manager.h"
#include "source/fuzz/protobufs/spirvfuzz_protobufs.h"
#include "source/opt/ir_context.h"
namespace spvtools {
namespace fuzz {
namespace transformation {
// - |message.fresh_id| must not be used by the module
// - The module must not contain an OpTypeFloat instruction with width
// |message.width|
bool IsApplicable(const protobufs::TransformationAddTypeFloat& message,
opt::IRContext* context, const FactManager& fact_manager);
// Adds an OpTypeFloat instruction to the module with the given width
void Apply(const protobufs::TransformationAddTypeFloat& message,
opt::IRContext* context, FactManager* fact_manager);
// Helper factory to create a transformation message.
protobufs::TransformationAddTypeFloat MakeTransformationAddTypeFloat(
uint32_t fresh_id, uint32_t width);
} // namespace transformation
} // namespace fuzz
} // namespace spvtools
#endif // SOURCE_FUZZ_TRANSFORMATION_ADD_TYPE_FLOAT_H_

View File

@@ -0,0 +1,63 @@
// Copyright (c) 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "source/fuzz/transformation_add_type_int.h"
#include "source/fuzz/fuzzer_util.h"
namespace spvtools {
namespace fuzz {
namespace transformation {
using opt::IRContext;
bool IsApplicable(const protobufs::TransformationAddTypeInt& message,
IRContext* context,
const spvtools::fuzz::FactManager& /*unused*/) {
// The id must be fresh.
if (!fuzzerutil::IsFreshId(context, message.fresh_id())) {
return false;
}
// Applicable if there is no int type with this width and signedness already
// declared in the module.
opt::analysis::Integer int_type(message.width(), message.is_signed());
return context->get_type_mgr()->GetId(&int_type) == 0;
}
void Apply(const protobufs::TransformationAddTypeInt& message,
IRContext* context, spvtools::fuzz::FactManager* /*unused*/) {
opt::Instruction::OperandList in_operands = {
{SPV_OPERAND_TYPE_LITERAL_INTEGER, {message.width()}},
{SPV_OPERAND_TYPE_LITERAL_INTEGER, {message.is_signed() ? 1u : 0u}}};
context->module()->AddType(MakeUnique<opt::Instruction>(
context, SpvOpTypeInt, 0, message.fresh_id(), in_operands));
fuzzerutil::UpdateModuleIdBound(context, message.fresh_id());
// We have added an instruction to the module, so need to be careful about the
// validity of existing analyses.
context->InvalidateAnalysesExceptFor(IRContext::Analysis::kAnalysisNone);
}
protobufs::TransformationAddTypeInt MakeTransformationAddTypeInt(
uint32_t fresh_id, uint32_t width, bool is_signed) {
protobufs::TransformationAddTypeInt result;
result.set_fresh_id(fresh_id);
result.set_width(width);
result.set_is_signed(is_signed);
return result;
}
} // namespace transformation
} // namespace fuzz
} // namespace spvtools

View File

@@ -0,0 +1,45 @@
// Copyright (c) 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SOURCE_FUZZ_TRANSFORMATION_ADD_TYPE_INT_H_
#define SOURCE_FUZZ_TRANSFORMATION_ADD_TYPE_INT_H_
#include "source/fuzz/fact_manager.h"
#include "source/fuzz/protobufs/spirvfuzz_protobufs.h"
#include "source/opt/ir_context.h"
namespace spvtools {
namespace fuzz {
namespace transformation {
// - |message.fresh_id| must not be used by the module
// - The module must not contain an OpTypeInt instruction with width
// |message.width| and signedness |message.is_signed|
bool IsApplicable(const protobufs::TransformationAddTypeInt& message,
opt::IRContext* context, const FactManager& fact_manager);
// Adds an OpTypeInt instruction to the module with the given width and
// signedness.
void Apply(const protobufs::TransformationAddTypeInt& message,
opt::IRContext* context, FactManager* fact_manager);
// Helper factory to create a transformation message.
protobufs::TransformationAddTypeInt MakeTransformationAddTypeInt(
uint32_t fresh_id, uint32_t width, bool is_signed);
} // namespace transformation
} // namespace fuzz
} // namespace spvtools
#endif // SOURCE_FUZZ_TRANSFORMATION_ADD_TYPE_INT_H_

View File

@@ -0,0 +1,102 @@
// Copyright (c) 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "source/fuzz/transformation_move_block_down.h"
#include "source/opt/basic_block.h"
namespace spvtools {
namespace fuzz {
namespace transformation {
using opt::BasicBlock;
using opt::IRContext;
bool IsApplicable(const protobufs::TransformationMoveBlockDown& message,
IRContext* context, const FactManager& /*unused*/) {
// Go through every block in every function, looking for a block whose id
// matches that of the block we want to consider moving down.
for (auto& function : *context->module()) {
for (auto block_it = function.begin(); block_it != function.end();
++block_it) {
if (block_it->id() == message.block_id()) {
// We have found a match.
if (block_it == function.begin()) {
// The block is the first one appearing in the function. We are not
// allowed to move this block down.
return false;
}
// Record the block we would like to consider moving down.
BasicBlock* block_matching_id = &*block_it;
// Now see whether there is some block following that block in program
// order.
++block_it;
if (block_it == function.end()) {
// There is no such block; i.e., the block we are considering moving
// is the last one in the function. The transformation thus does not
// apply.
return false;
}
BasicBlock* next_block_in_program_order = &*block_it;
// We can move the block of interest down if and only if it does not
// dominate the block that comes next.
return !context->GetDominatorAnalysis(&function)->Dominates(
block_matching_id, next_block_in_program_order);
}
}
}
// We did not find a matching block, so the transformation is not applicable:
// there is no relevant block to move.
return false;
}
void Apply(const protobufs::TransformationMoveBlockDown& message,
IRContext* context, FactManager* /*unused*/) {
// Go through every block in every function, looking for a block whose id
// matches that of the block we want to move down.
for (auto& function : *context->module()) {
for (auto block_it = function.begin(); block_it != function.end();
++block_it) {
if (block_it->id() == message.block_id()) {
++block_it;
assert(block_it != function.end() &&
"To be able to move a block down, it needs to have a "
"program-order successor.");
function.MoveBasicBlockToAfter(message.block_id(), &*block_it);
// It is prudent to invalidate analyses after changing block ordering in
// case any of them depend on it, but the ones that definitely do not
// depend on ordering can be preserved. These include the following,
// which can likely be extended.
context->InvalidateAnalysesExceptFor(
IRContext::Analysis::kAnalysisDefUse |
IRContext::Analysis::kAnalysisDominatorAnalysis);
return;
}
}
}
assert(false && "No block was found to move down.");
}
protobufs::TransformationMoveBlockDown MakeTransformationMoveBlockDown(
uint32_t id) {
protobufs::TransformationMoveBlockDown result;
result.set_block_id(id);
return result;
}
} // namespace transformation
} // namespace fuzz
} // namespace spvtools

View File

@@ -0,0 +1,46 @@
// Copyright (c) 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SOURCE_FUZZ_TRANSFORMATION_MOVE_BLOCK_DOWN_H_
#define SOURCE_FUZZ_TRANSFORMATION_MOVE_BLOCK_DOWN_H_
#include "source/fuzz/fact_manager.h"
#include "source/fuzz/protobufs/spirvfuzz_protobufs.h"
#include "source/opt/ir_context.h"
namespace spvtools {
namespace fuzz {
namespace transformation {
// - |block_id| must be the id of a block b in the given module.
// - b must not be the first nor last block appearing, in program order,
// in a function.
// - b must not dominate the block that follows it in program order.
bool IsApplicable(const protobufs::TransformationMoveBlockDown& message,
opt::IRContext* context, const FactManager& fact_manager);
// The block with id |block_id| is moved down; i.e. the program order
// between it and the block that follows it is swapped.
void Apply(const protobufs::TransformationMoveBlockDown& message,
opt::IRContext* context, FactManager* fact_manager);
// Creates a protobuf message to move down the block with id |id|.
protobufs::TransformationMoveBlockDown MakeTransformationMoveBlockDown(
uint32_t id);
} // namespace transformation
} // namespace fuzz
} // namespace spvtools
#endif // SOURCE_FUZZ_TRANSFORMATION_MOVE_BLOCK_DOWN_H_

View File

@@ -0,0 +1,282 @@
// Copyright (c) 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "source/fuzz/transformation_replace_boolean_constant_with_constant_binary.h"
#include <cmath>
#include "source/fuzz/fuzzer_util.h"
#include "source/fuzz/id_use_descriptor.h"
namespace spvtools {
namespace fuzz {
namespace transformation {
namespace {
// Given floating-point values |lhs| and |rhs|, and a floating-point binary
// operator |binop|, returns true if it is certain that 'lhs binop rhs'
// evaluates to |required_value|.
template <typename T>
bool float_binop_evaluates_to(T lhs, T rhs, SpvOp binop, bool required_value) {
// Infinity and NaN values are conservatively treated as out of scope.
if (!std::isfinite(lhs) || !std::isfinite(rhs)) {
return false;
}
bool binop_result;
// The following captures the binary operators that spirv-fuzz can actually
// generate when turning a boolean constant into a binary expression.
switch (binop) {
case SpvOpFOrdGreaterThanEqual:
case SpvOpFUnordGreaterThanEqual:
binop_result = (lhs >= rhs);
break;
case SpvOpFOrdGreaterThan:
case SpvOpFUnordGreaterThan:
binop_result = (lhs > rhs);
break;
case SpvOpFOrdLessThanEqual:
case SpvOpFUnordLessThanEqual:
binop_result = (lhs <= rhs);
break;
case SpvOpFOrdLessThan:
case SpvOpFUnordLessThan:
binop_result = (lhs < rhs);
break;
default:
return false;
}
return binop_result == required_value;
}
// Analogous to 'float_binop_evaluates_to', but for signed int values.
template <typename T>
bool signed_int_binop_evaluates_to(T lhs, T rhs, SpvOp binop,
bool required_value) {
bool binop_result;
switch (binop) {
case SpvOpSGreaterThanEqual:
binop_result = (lhs >= rhs);
break;
case SpvOpSGreaterThan:
binop_result = (lhs > rhs);
break;
case SpvOpSLessThanEqual:
binop_result = (lhs <= rhs);
break;
case SpvOpSLessThan:
binop_result = (lhs < rhs);
break;
default:
return false;
}
return binop_result == required_value;
}
// Analogous to 'float_binop_evaluates_to', but for unsigned int values.
template <typename T>
bool unsigned_int_binop_evaluates_to(T lhs, T rhs, SpvOp binop,
bool required_value) {
bool binop_result;
switch (binop) {
case SpvOpUGreaterThanEqual:
binop_result = (lhs >= rhs);
break;
case SpvOpUGreaterThan:
binop_result = (lhs > rhs);
break;
case SpvOpULessThanEqual:
binop_result = (lhs <= rhs);
break;
case SpvOpULessThan:
binop_result = (lhs < rhs);
break;
default:
return false;
}
return binop_result == required_value;
}
} // namespace
bool IsApplicable(
const protobufs::TransformationReplaceBooleanConstantWithConstantBinary&
message,
opt::IRContext* context, const FactManager& /*unused*/) {
// The id for the binary result must be fresh
if (!fuzzerutil::IsFreshId(context,
message.fresh_id_for_binary_operation())) {
return false;
}
// The used id must be for a boolean constant
auto boolean_constant = context->get_def_use_mgr()->GetDef(
message.id_use_descriptor().id_of_interest());
if (!boolean_constant) {
return false;
}
if (!(boolean_constant->opcode() == SpvOpConstantFalse ||
boolean_constant->opcode() == SpvOpConstantTrue)) {
return false;
}
// The left-hand-side id must correspond to a constant instruction.
auto lhs_constant_inst = context->get_def_use_mgr()->GetDef(message.lhs_id());
if (!lhs_constant_inst) {
return false;
}
if (lhs_constant_inst->opcode() != SpvOpConstant) {
return false;
}
// The right-hand-side id must correspond to a constant instruction.
auto rhs_constant_inst = context->get_def_use_mgr()->GetDef(message.rhs_id());
if (!rhs_constant_inst) {
return false;
}
if (rhs_constant_inst->opcode() != SpvOpConstant) {
return false;
}
// The left- and right-hand side instructions must have the same type.
if (lhs_constant_inst->type_id() != rhs_constant_inst->type_id()) {
return false;
}
// The expression 'LHS opcode RHS' must evaluate to the boolean constant.
auto lhs_constant =
context->get_constant_mgr()->FindDeclaredConstant(message.lhs_id());
auto rhs_constant =
context->get_constant_mgr()->FindDeclaredConstant(message.rhs_id());
bool expected_result = (boolean_constant->opcode() == SpvOpConstantTrue);
const SpvOp binary_opcode = static_cast<SpvOp>(message.opcode());
// We consider the floating point, signed and unsigned integer cases
// separately. In each case the logic is very similar.
if (lhs_constant->AsFloatConstant()) {
assert(rhs_constant->AsFloatConstant() &&
"Both constants should be of the same type.");
if (lhs_constant->type()->AsFloat()->width() == 32) {
if (!float_binop_evaluates_to(lhs_constant->GetFloat(),
rhs_constant->GetFloat(), binary_opcode,
expected_result)) {
return false;
}
} else {
assert(lhs_constant->type()->AsFloat()->width() == 64);
if (!float_binop_evaluates_to(lhs_constant->GetDouble(),
rhs_constant->GetDouble(), binary_opcode,
expected_result)) {
return false;
}
}
} else {
assert(lhs_constant->AsIntConstant() && "Constants should be in or float.");
assert(rhs_constant->AsIntConstant() &&
"Both constants should be of the same type.");
if (lhs_constant->type()->AsInteger()->IsSigned()) {
if (lhs_constant->type()->AsInteger()->width() == 32) {
if (!signed_int_binop_evaluates_to(lhs_constant->GetS32(),
rhs_constant->GetS32(),
binary_opcode, expected_result)) {
return false;
}
} else {
assert(lhs_constant->type()->AsInteger()->width() == 64);
if (!signed_int_binop_evaluates_to(lhs_constant->GetS64(),
rhs_constant->GetS64(),
binary_opcode, expected_result)) {
return false;
}
}
} else {
if (lhs_constant->type()->AsInteger()->width() == 32) {
if (!unsigned_int_binop_evaluates_to(lhs_constant->GetU32(),
rhs_constant->GetU32(),
binary_opcode, expected_result)) {
return false;
}
} else {
assert(lhs_constant->type()->AsInteger()->width() == 64);
if (!unsigned_int_binop_evaluates_to(lhs_constant->GetU64(),
rhs_constant->GetU64(),
binary_opcode, expected_result)) {
return false;
}
}
}
}
// The id use descriptor must identify some instruction
return transformation::FindInstruction(message.id_use_descriptor(),
context) != nullptr;
}
opt::Instruction* Apply(
const protobufs::TransformationReplaceBooleanConstantWithConstantBinary&
message,
opt::IRContext* context, FactManager* /*unused*/) {
opt::analysis::Bool bool_type;
opt::Instruction::OperandList operands = {
{SPV_OPERAND_TYPE_ID, {message.lhs_id()}},
{SPV_OPERAND_TYPE_ID, {message.rhs_id()}}};
auto binary_instruction = MakeUnique<opt::Instruction>(
context, static_cast<SpvOp>(message.opcode()),
context->get_type_mgr()->GetId(&bool_type),
message.fresh_id_for_binary_operation(), operands);
opt::Instruction* result = binary_instruction.get();
auto instruction_containing_constant_use =
transformation::FindInstruction(message.id_use_descriptor(), context);
// We want to insert the new instruction before the instruction that contains
// the use of the boolean, but we need to go backwards one more instruction if
// the using instruction is preceded by a merge instruction.
auto instruction_before_which_to_insert = instruction_containing_constant_use;
{
opt::Instruction* previous_node =
instruction_before_which_to_insert->PreviousNode();
if (previous_node && (previous_node->opcode() == SpvOpLoopMerge ||
previous_node->opcode() == SpvOpSelectionMerge)) {
instruction_before_which_to_insert = previous_node;
}
}
instruction_before_which_to_insert->InsertBefore(
std::move(binary_instruction));
instruction_containing_constant_use->SetInOperand(
message.id_use_descriptor().in_operand_index(),
{message.fresh_id_for_binary_operation()});
fuzzerutil::UpdateModuleIdBound(context,
message.fresh_id_for_binary_operation());
context->InvalidateAnalysesExceptFor(opt::IRContext::Analysis::kAnalysisNone);
return result;
}
protobufs::TransformationReplaceBooleanConstantWithConstantBinary
MakeTransformationReplaceBooleanConstantWithConstantBinary(
const protobufs::IdUseDescriptor& id_use_descriptor, uint32_t lhs_id,
uint32_t rhs_id, SpvOp comparison_opcode,
uint32_t fresh_id_for_binary_operation) {
protobufs::TransformationReplaceBooleanConstantWithConstantBinary result;
*result.mutable_id_use_descriptor() = id_use_descriptor;
result.set_lhs_id(lhs_id);
result.set_rhs_id(rhs_id);
result.set_opcode(comparison_opcode);
result.set_fresh_id_for_binary_operation(fresh_id_for_binary_operation);
return result;
}
} // namespace transformation
} // namespace fuzz
} // namespace spvtools

View File

@@ -0,0 +1,59 @@
// Copyright (c) 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SOURCE_FUZZ_TRANSFORMATION_REPLACE_BOOLEAN_CONSTANT_WITH_CONSTANT_BINARY_H_
#define SOURCE_FUZZ_TRANSFORMATION_REPLACE_BOOLEAN_CONSTANT_WITH_CONSTANT_BINARY_H_
#include "source/fuzz/fact_manager.h"
#include "source/fuzz/protobufs/spirvfuzz_protobufs.h"
#include "source/opt/ir_context.h"
namespace spvtools {
namespace fuzz {
namespace transformation {
// - |message.fresh_id_for_binary_operation| must not already be used by the
// module.
// - |message.id_use_descriptor| must identify a use of a boolean constant c.
// - |message.lhs_id| and |message.rhs_id| must be the ids of constant
// instructions with the same type
// - |message.opcode| must be suitable for applying to |message.lhs_id| and
// |message.rhs_id|, and the result must evaluate to the boolean constant c.
bool IsApplicable(
const protobufs::TransformationReplaceBooleanConstantWithConstantBinary&
message,
opt::IRContext* context, const FactManager& fact_manager);
// A new instruction is added before the boolean constant usage that computes
// the result of applying |message.opcode| to |message.lhs_id| and
// |message.rhs_id| is added, with result id
// |message.fresh_id_for_binary_operation|. The boolean constant usage is
// replaced with this result id.
opt::Instruction* Apply(
const protobufs::TransformationReplaceBooleanConstantWithConstantBinary&
message,
opt::IRContext* context, FactManager* fact_manager);
// Helper factory to create a transformation message.
protobufs::TransformationReplaceBooleanConstantWithConstantBinary
MakeTransformationReplaceBooleanConstantWithConstantBinary(
const protobufs::IdUseDescriptor& id_use_descriptor, uint32_t lhs_id,
uint32_t rhs_id, SpvOp comparison_opcode,
uint32_t fresh_id_for_binary_operation);
} // namespace transformation
} // namespace fuzz
} // namespace spvtools
#endif // SOURCE_FUZZ_TRANSFORMATION_REPLACE_BOOLEAN_CONSTANT_WITH_CONSTANT_BINARY_H_

View File

@@ -0,0 +1,183 @@
// Copyright (c) 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "source/fuzz/transformation_split_block.h"
#include <utility>
#include "source/fuzz/fuzzer_util.h"
#include "source/util/make_unique.h"
namespace spvtools {
namespace fuzz {
namespace transformation {
using opt::BasicBlock;
using opt::IRContext;
using opt::Instruction;
using opt::Operand;
namespace {
// Returns:
// - (true, block->end()) if the relevant instruction is in this block
// but inapplicable
// - (true, it) if 'it' is an iterator for the relevant instruction
// - (false, _) otherwise.
std::pair<bool, BasicBlock::iterator> FindInstToSplitBefore(
const protobufs::TransformationSplitBlock& message, BasicBlock* block) {
// There are three possibilities:
// (1) the transformation wants to split at some offset from the block's
// label.
// (2) the transformation wants to split at some offset from a
// non-label instruction inside the block.
// (3) the split associated with this transformation has nothing to do with
// this block
if (message.result_id() == block->id()) {
// Case (1).
if (message.offset() == 0) {
// The offset is not allowed to be 0: this would mean splitting before the
// block's label.
// By returning (true, block->end()), we indicate that we did find the
// instruction (so that it is not worth searching further for it), but
// that splitting will not be possible.
return {true, block->end()};
}
// Conceptually, the first instruction in the block is [label + 1].
// We thus start from 1 when applying the offset.
auto inst_it = block->begin();
for (uint32_t i = 1; i < message.offset() && inst_it != block->end(); i++) {
++inst_it;
}
// This is either the desired instruction, or the end of the block.
return {true, inst_it};
}
for (auto inst_it = block->begin(); inst_it != block->end(); ++inst_it) {
if (message.result_id() == inst_it->result_id()) {
// Case (2): we have found the base instruction; we now apply the offset.
for (uint32_t i = 0; i < message.offset() && inst_it != block->end();
i++) {
++inst_it;
}
// This is either the desired instruction, or the end of the block.
return {true, inst_it};
}
}
// Case (3).
return {false, block->end()};
}
} // namespace
bool IsApplicable(const protobufs::TransformationSplitBlock& message,
IRContext* context, const FactManager& /*unused*/) {
if (!fuzzerutil::IsFreshId(context, message.fresh_id())) {
// We require the id for the new block to be unused.
return false;
}
// Consider every block in every function.
for (auto& function : *context->module()) {
for (auto& block : function) {
auto maybe_split_before = FindInstToSplitBefore(message, &block);
if (!maybe_split_before.first) {
continue;
}
if (maybe_split_before.second == block.end()) {
// The base instruction was found, but the offset was inappropriate.
return false;
}
if (block.IsLoopHeader()) {
// We cannot split a loop header block: back-edges would become invalid.
return false;
}
auto split_before = maybe_split_before.second;
if (split_before->PreviousNode() &&
split_before->PreviousNode()->opcode() == SpvOpSelectionMerge) {
// We cannot split directly after a selection merge: this would separate
// the merge from its associated branch or switch operation.
return false;
}
if (split_before->opcode() == SpvOpVariable) {
// We cannot split directly after a variable; variables in a function
// must be contiguous in the entry block.
return false;
}
if (split_before->opcode() == SpvOpPhi &&
split_before->NumInOperands() != 2) {
// We cannot split before an OpPhi unless the OpPhi has exactly one
// associated incoming edge.
return false;
}
return true;
}
}
return false;
}
void Apply(const protobufs::TransformationSplitBlock& message,
IRContext* context, FactManager* /*unused*/) {
for (auto& function : *context->module()) {
for (auto& block : function) {
auto maybe_split_before = FindInstToSplitBefore(message, &block);
if (!maybe_split_before.first) {
continue;
}
assert(maybe_split_before.second != block.end() &&
"If the transformation is applicable, we should have an "
"instruction to split on.");
// We need to make sure the module's id bound is large enough to add the
// fresh id.
fuzzerutil::UpdateModuleIdBound(context, message.fresh_id());
// Split the block.
auto new_bb = block.SplitBasicBlock(context, message.fresh_id(),
maybe_split_before.second);
// The split does not automatically add a branch between the two parts of
// the original block, so we add one.
block.AddInstruction(MakeUnique<Instruction>(
context, SpvOpBranch, 0, 0,
std::initializer_list<Operand>{Operand(
spv_operand_type_t::SPV_OPERAND_TYPE_ID, {message.fresh_id()})}));
// If we split before OpPhi instructions, we need to update their
// predecessor operand so that the block they used to be inside is now the
// predecessor.
new_bb->ForEachPhiInst([&block](Instruction* phi_inst) {
// The following assertion is a sanity check. It is guaranteed to hold
// if IsApplicable holds.
assert(phi_inst->NumInOperands() == 2 &&
"We can only split a block before an OpPhi if block has exactly "
"one predecessor.");
phi_inst->SetInOperand(1, {block.id()});
});
// Invalidate all analyses
context->InvalidateAnalysesExceptFor(IRContext::Analysis::kAnalysisNone);
return;
}
}
assert(0 &&
"Should be unreachable: it should have been possible to apply this "
"transformation.");
}
protobufs::TransformationSplitBlock MakeTransformationSplitBlock(
uint32_t result_id, uint32_t offset, uint32_t fresh_id) {
protobufs::TransformationSplitBlock result;
result.set_result_id(result_id);
result.set_offset(offset);
result.set_fresh_id(fresh_id);
return result;
}
} // namespace transformation
} // namespace fuzz
} // namespace spvtools

View File

@@ -0,0 +1,52 @@
// Copyright (c) 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SOURCE_FUZZ_TRANSFORMATION_SPLIT_BLOCK_H_
#define SOURCE_FUZZ_TRANSFORMATION_SPLIT_BLOCK_H_
#include "source/fuzz/fact_manager.h"
#include "source/fuzz/protobufs/spirvfuzz_protobufs.h"
#include "source/opt/ir_context.h"
namespace spvtools {
namespace fuzz {
namespace transformation {
// - |result_id| must be the result id of an instruction 'base' in some
// block 'blk'.
// - 'blk' must contain an instruction 'inst' located |offset| instructions
// after 'inst' (if |offset| = 0 then 'inst' = 'base').
// - Splitting 'blk' at 'inst', so that all instructions from 'inst' onwards
// appear in a new block that 'blk' directly jumps to must be valid.
// - |fresh_id| must not be used by the module.
bool IsApplicable(const protobufs::TransformationSplitBlock& message,
opt::IRContext* context, const FactManager& fact_manager);
// - A new block with label |fresh_id| is inserted right after 'blk' in
// program order.
// - All instructions of 'blk' from 'inst' onwards are moved into the new
// block.
// - 'blk' is made to jump unconditionally to the new block.
void Apply(const protobufs::TransformationSplitBlock& message,
opt::IRContext* context, FactManager* fact_manager);
// Creates a protobuf message representing a block-splitting transformation.
protobufs::TransformationSplitBlock MakeTransformationSplitBlock(
uint32_t result_id, uint32_t offset, uint32_t fresh_id);
} // namespace transformation
} // namespace fuzz
} // namespace spvtools
#endif // SOURCE_FUZZ_TRANSFORMATION_SPLIT_BLOCK_H_

View File

@@ -1,78 +0,0 @@
// Copyright (c) 2017 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "source/id_descriptor.h"
#include <cassert>
#include <iostream>
#include "source/opcode.h"
#include "source/operand.h"
namespace spvtools {
namespace {
// Hashes an array of words. Order of words is important.
uint32_t HashU32Array(const std::vector<uint32_t>& words) {
// The hash function is a sum of hashes of each word seeded by word index.
// Knuth's multiplicative hash is used to hash the words.
const uint32_t kKnuthMulHash = 2654435761;
uint32_t val = 0;
for (uint32_t i = 0; i < words.size(); ++i) {
val += (words[i] + i + 123) * kKnuthMulHash;
}
return val;
}
} // namespace
uint32_t IdDescriptorCollection::ProcessInstruction(
const spv_parsed_instruction_t& inst) {
if (!inst.result_id) return 0;
assert(words_.empty());
words_.push_back(inst.words[0]);
for (size_t operand_index = 0; operand_index < inst.num_operands;
++operand_index) {
const auto& operand = inst.operands[operand_index];
if (spvIsIdType(operand.type)) {
const uint32_t id = inst.words[operand.offset];
const auto it = id_to_descriptor_.find(id);
// Forward declared ids are not hashed.
if (it != id_to_descriptor_.end()) {
words_.push_back(it->second);
}
} else {
for (size_t operand_word_index = 0;
operand_word_index < operand.num_words; ++operand_word_index) {
words_.push_back(inst.words[operand.offset + operand_word_index]);
}
}
}
uint32_t descriptor =
custom_hash_func_ ? custom_hash_func_(words_) : HashU32Array(words_);
if (descriptor == 0) descriptor = 1;
assert(descriptor);
words_.clear();
const auto result = id_to_descriptor_.emplace(inst.result_id, descriptor);
assert(result.second);
(void)result;
return descriptor;
}
} // namespace spvtools

View File

@@ -1,63 +0,0 @@
// Copyright (c) 2017 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SOURCE_ID_DESCRIPTOR_H_
#define SOURCE_ID_DESCRIPTOR_H_
#include <unordered_map>
#include <vector>
#include "spirv-tools/libspirv.hpp"
namespace spvtools {
using CustomHashFunc = std::function<uint32_t(const std::vector<uint32_t>&)>;
// Computes and stores id descriptors.
//
// Descriptors are computed as hash of all words in the instruction where ids
// were substituted with previously computed descriptors.
class IdDescriptorCollection {
public:
explicit IdDescriptorCollection(
CustomHashFunc custom_hash_func = CustomHashFunc())
: custom_hash_func_(custom_hash_func) {
words_.reserve(16);
}
// Computes descriptor for the result id of the given instruction and
// registers it in id_to_descriptor_. Returns the computed descriptor.
// This function needs to be sequentially called for every instruction in the
// module.
uint32_t ProcessInstruction(const spv_parsed_instruction_t& inst);
// Returns a previously computed descriptor id.
uint32_t GetDescriptor(uint32_t id) const {
const auto it = id_to_descriptor_.find(id);
if (it == id_to_descriptor_.end()) return 0;
return it->second;
}
private:
std::unordered_map<uint32_t, uint32_t> id_to_descriptor_;
std::function<uint32_t(const std::vector<uint32_t>&)> custom_hash_func_;
// Scratch buffer used for hashing. Class member to optimize on allocation.
std::vector<uint32_t> words_;
};
} // namespace spvtools
#endif // SOURCE_ID_DESCRIPTOR_H_

View File

@@ -33,6 +33,7 @@
#include "source/opt/ir_loader.h"
#include "source/opt/pass_manager.h"
#include "source/opt/remove_duplicates_pass.h"
#include "source/opt/type_manager.h"
#include "source/spirv_target_env.h"
#include "source/util/make_unique.h"
#include "spirv-tools/libspirv.hpp"
@@ -40,14 +41,15 @@
namespace spvtools {
namespace {
using opt::IRContext;
using opt::Instruction;
using opt::IRContext;
using opt::Module;
using opt::Operand;
using opt::PassManager;
using opt::RemoveDuplicatesPass;
using opt::analysis::DecorationManager;
using opt::analysis::DefUseManager;
using opt::analysis::Type;
using opt::analysis::TypeManager;
// Stores various information about an imported or exported symbol.
struct LinkageSymbolInfo {
@@ -472,14 +474,15 @@ spv_result_t CheckImportExportCompatibility(const MessageConsumer& consumer,
opt::IRContext* context) {
spv_position_t position = {};
// Ensure th import and export types are the same.
const DefUseManager& def_use_manager = *context->get_def_use_mgr();
// Ensure the import and export types are the same.
const DecorationManager& decoration_manager = *context->get_decoration_mgr();
const TypeManager& type_manager = *context->get_type_mgr();
for (const auto& linking_entry : linkings_to_do) {
if (!RemoveDuplicatesPass::AreTypesEqual(
*def_use_manager.GetDef(linking_entry.imported_symbol.type_id),
*def_use_manager.GetDef(linking_entry.exported_symbol.type_id),
context))
Type* imported_symbol_type =
type_manager.GetType(linking_entry.imported_symbol.type_id);
Type* exported_symbol_type =
type_manager.GetType(linking_entry.exported_symbol.type_id);
if (!(*imported_symbol_type == *exported_symbol_type))
return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_BINARY)
<< "Type mismatch on symbol \""
<< linking_entry.imported_symbol.name

View File

@@ -408,6 +408,28 @@ UnaryScalarFoldingRule FoldIToFOp() {
};
}
// This defines a |UnaryScalarFoldingRule| that performs |OpQuantizeToF16|.
UnaryScalarFoldingRule FoldQuantizeToF16Scalar() {
return [](const analysis::Type* result_type, const analysis::Constant* a,
analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
assert(result_type != nullptr && a != nullptr);
const analysis::Float* float_type = a->type()->AsFloat();
assert(float_type != nullptr);
if (float_type->width() != 32) {
return nullptr;
}
float fa = a->GetFloat();
utils::HexFloat<utils::FloatProxy<float>> orignal(fa);
utils::HexFloat<utils::FloatProxy<utils::Float16>> quantized(0);
utils::HexFloat<utils::FloatProxy<float>> result(0.0f);
orignal.castTo(quantized, utils::round_direction::kToZero);
quantized.castTo(result, utils::round_direction::kToZero);
std::vector<uint32_t> words = {result.getBits()};
return const_mgr->GetConstant(result_type, words);
};
}
// This macro defines a |BinaryScalarFoldingRule| that applies |op|. The
// operator |op| must work for both float and double, and use syntax "f1 op f2".
#define FOLD_FPARITH_OP(op) \
@@ -438,6 +460,9 @@ UnaryScalarFoldingRule FoldIToFOp() {
// Define the folding rule for conversion between floating point and integer
ConstantFoldingRule FoldFToI() { return FoldFPUnaryOp(FoldFToIOp()); }
ConstantFoldingRule FoldIToF() { return FoldFPUnaryOp(FoldIToFOp()); }
ConstantFoldingRule FoldQuantizeToF16() {
return FoldFPUnaryOp(FoldQuantizeToF16Scalar());
}
// Define the folding rules for subtraction, addition, multiplication, and
// division for floating point values.
@@ -848,6 +873,7 @@ ConstantFoldingRules::ConstantFoldingRules() {
rules_[SpvOpVectorTimesScalar].push_back(FoldVectorTimesScalar());
rules_[SpvOpFNegate].push_back(FoldFNegate());
rules_[SpvOpQuantizeToF16].push_back(FoldQuantizeToF16());
}
} // namespace opt
} // namespace spvtools

View File

@@ -185,8 +185,6 @@ Instruction* ConstantManager::BuildInstructionAndAddToModule(
Instruction* ConstantManager::GetDefiningInstruction(
const Constant* c, uint32_t type_id, Module::inst_iterator* pos) {
assert(type_id == 0 ||
context()->get_type_mgr()->GetType(type_id) == c->type());
uint32_t decl_id = FindDeclaredConstant(c, type_id);
if (decl_id == 0) {
auto iter = context()->types_values_end();

View File

@@ -524,12 +524,7 @@ class ConstantManager {
// instruction at the end of the current module's types section.
//
// |type_id| is an optional argument for disambiguating equivalent types. If
// |type_id| is specified, it is used as the type of the constant when a new
// instruction is created. Otherwise the type of the constant is derived by
// getting an id from the type manager for |c|.
//
// When |type_id| is not zero, the type of |c| must be the type returned by
// type manager when given |type_id|.
// |type_id| is specified, the contant returned will have that type id.
Instruction* GetDefiningInstruction(const Constant* c, uint32_t type_id = 0,
Module::inst_iterator* pos = nullptr);

View File

@@ -169,24 +169,43 @@ bool DeadBranchElimPass::MarkLiveBlocks(
if (simplify) {
modified = true;
// Replace with unconditional branch.
// Remove the merge instruction if it is a selection merge.
AddBranch(live_lab_id, block);
context()->KillInst(terminator);
// Replace branch with a simpler branch.
// Fix up the merge instruction if it is a selection merge.
Instruction* mergeInst = block->GetMergeInst();
if (mergeInst && mergeInst->opcode() == SpvOpSelectionMerge) {
Instruction* first_break = FindFirstExitFromSelectionMerge(
live_lab_id, mergeInst->GetSingleWordInOperand(0),
cfgAnalysis->LoopMergeBlock(live_lab_id),
cfgAnalysis->LoopContinueBlock(live_lab_id));
if (first_break == nullptr) {
context()->KillInst(mergeInst);
if (mergeInst->NextNode()->opcode() == SpvOpSwitch &&
SwitchHasNestedBreak(block->id())) {
// We have to keep the switch because it has a nest break, so we
// remove all cases except for the live one.
Instruction::OperandList new_operands;
new_operands.push_back(terminator->GetInOperand(0));
new_operands.push_back({SPV_OPERAND_TYPE_ID, {live_lab_id}});
terminator->SetInOperands(std::move(new_operands));
context()->UpdateDefUse(terminator);
} else {
mergeInst->RemoveFromList();
first_break->InsertBefore(std::unique_ptr<Instruction>(mergeInst));
context()->set_instr_block(mergeInst,
context()->get_instr_block(first_break));
// Check if the merge instruction is still needed because of a
// non-nested break from the construct. Move the merge instruction if
// it is still needed.
Instruction* first_break = FindFirstExitFromSelectionMerge(
live_lab_id, mergeInst->GetSingleWordInOperand(0),
cfgAnalysis->LoopMergeBlock(live_lab_id),
cfgAnalysis->LoopContinueBlock(live_lab_id),
cfgAnalysis->SwitchMergeBlock(live_lab_id));
AddBranch(live_lab_id, block);
context()->KillInst(terminator);
if (first_break == nullptr) {
context()->KillInst(mergeInst);
} else {
mergeInst->RemoveFromList();
first_break->InsertBefore(std::unique_ptr<Instruction>(mergeInst));
context()->set_instr_block(mergeInst,
context()->get_instr_block(first_break));
}
}
} else {
AddBranch(live_lab_id, block);
context()->KillInst(terminator);
}
stack.push_back(GetParentBlock(live_lab_id));
} else {
@@ -455,11 +474,12 @@ Pass::Status DeadBranchElimPass::Process() {
Instruction* DeadBranchElimPass::FindFirstExitFromSelectionMerge(
uint32_t start_block_id, uint32_t merge_block_id, uint32_t loop_merge_id,
uint32_t loop_continue_id) {
uint32_t loop_continue_id, uint32_t switch_merge_id) {
// To find the "first" exit, we follow branches looking for a conditional
// branch that is not in a nested construct and is not the header of a new
// construct. We follow the control flow from |start_block_id| to find the
// first one.
while (start_block_id != merge_block_id && start_block_id != loop_merge_id &&
start_block_id != loop_continue_id) {
BasicBlock* start_block = context()->get_instr_block(start_block_id);
@@ -483,6 +503,11 @@ Instruction* DeadBranchElimPass::FindFirstExitFromSelectionMerge(
next_block_id = branch->GetSingleWordInOperand(3 - i);
break;
}
if (branch->GetSingleWordInOperand(i) == switch_merge_id &&
switch_merge_id != merge_block_id) {
next_block_id = branch->GetSingleWordInOperand(3 - i);
break;
}
}
if (next_block_id == 0) {
@@ -493,11 +518,15 @@ Instruction* DeadBranchElimPass::FindFirstExitFromSelectionMerge(
case SpvOpSwitch:
next_block_id = start_block->MergeBlockIdIfAny();
if (next_block_id == 0) {
// A switch with no merge instructions can have at most 4 targets:
// A switch with no merge instructions can have at most 5 targets:
// a. |merge_block_id|
// b. |loop_merge_id|
// c. |loop_continue_id|
// d. 1 block inside the current region.
// d. |switch_merge_id|
// e. 1 block inside the current region.
//
// Note that because this is a switch, |merge_block_id| must equal
// |switch_merge_id|.
//
// This leads to a number of cases of what to do.
//
@@ -511,7 +540,6 @@ Instruction* DeadBranchElimPass::FindFirstExitFromSelectionMerge(
//
// 3. Otherwise, this branch may break, but not to the current merge
// block. So we continue with the block that is inside the loop.
bool found_break = false;
for (uint32_t i = 1; i < branch->NumInOperands(); i += 2) {
uint32_t target = branch->GetSingleWordInOperand(i);
@@ -585,5 +613,26 @@ void DeadBranchElimPass::AddBlocksWithBackEdge(
}
}
bool DeadBranchElimPass::SwitchHasNestedBreak(uint32_t switch_header_id) {
std::vector<BasicBlock*> block_in_construct;
BasicBlock* start_block = context()->get_instr_block(switch_header_id);
uint32_t merge_block_id = start_block->MergeBlockIdIfAny();
StructuredCFGAnalysis* cfg_analysis = context()->GetStructuredCFGAnalysis();
return !get_def_use_mgr()->WhileEachUser(
merge_block_id,
[this, cfg_analysis, switch_header_id](Instruction* inst) {
if (!inst->IsBranch()) {
return true;
}
BasicBlock* bb = context()->get_instr_block(inst);
if (bb->id() == switch_header_id) {
return true;
}
return (cfg_analysis->ContainingConstruct(inst) == switch_header_id);
});
}
} // namespace opt
} // namespace spvtools

View File

@@ -147,7 +147,8 @@ class DeadBranchElimPass : public MemPass {
Instruction* FindFirstExitFromSelectionMerge(uint32_t start_block_id,
uint32_t merge_block_id,
uint32_t loop_merge_id,
uint32_t loop_continue_id);
uint32_t loop_continue_id,
uint32_t switch_merge_id);
// Adds to |blocks_with_back_edges| all of the blocks on the path from the
// basic block |cont_id| to |header_id| and |merge_id|. The intention is that
@@ -156,6 +157,10 @@ class DeadBranchElimPass : public MemPass {
void AddBlocksWithBackEdge(
uint32_t cont_id, uint32_t header_id, uint32_t merge_id,
std::unordered_set<BasicBlock*>* blocks_with_back_edges);
// Returns true if there is a brach to the merge node of the selection
// construct |switch_header_id| that is inside a nested selection construct.
bool SwitchHasNestedBreak(uint32_t switch_header_id);
};
} // namespace opt

View File

@@ -120,19 +120,15 @@ bool FoldSpecConstantOpAndCompositePass::ProcessOpSpecConstantOp(
switch (static_cast<SpvOp>(inst->GetSingleWordInOperand(0))) {
case SpvOp::SpvOpCompositeExtract:
folded_inst = DoCompositeExtract(pos);
break;
case SpvOp::SpvOpVectorShuffle:
folded_inst = DoVectorShuffle(pos);
break;
case SpvOp::SpvOpCompositeInsert:
// Current Glslang does not generate code with OpSpecConstantOp
// CompositeInsert instruction, so this is not implmented so far.
// TODO(qining): Implement CompositeInsert case.
return false;
case SpvOp::SpvOpQuantizeToF16:
folded_inst = FoldWithInstructionFolder(pos);
break;
default:
// TODO: This should use the instruction folder as well, but some folding
// rules are missing.
// Component-wise operations.
folded_inst = DoComponentWiseOperation(pos);
break;
@@ -157,54 +153,65 @@ uint32_t FoldSpecConstantOpAndCompositePass::GetTypeComponent(
return subtype;
}
Instruction* FoldSpecConstantOpAndCompositePass::DoCompositeExtract(
Module::inst_iterator* pos) {
Instruction* inst = &**pos;
assert(inst->NumInOperands() - 1 >= 2 &&
"OpSpecConstantOp CompositeExtract requires at least two non-type "
"non-opcode operands.");
assert(inst->GetInOperand(1).type == SPV_OPERAND_TYPE_ID &&
"The composite operand must have a SPV_OPERAND_TYPE_ID type");
assert(
inst->GetInOperand(2).type == SPV_OPERAND_TYPE_LITERAL_INTEGER &&
"The literal operand must have a SPV_OPERAND_TYPE_LITERAL_INTEGER type");
// Note that for OpSpecConstantOp, the second in-operand is the first id
// operand. The first in-operand is the spec opcode.
uint32_t source = inst->GetSingleWordInOperand(1);
uint32_t type = context()->get_def_use_mgr()->GetDef(source)->type_id();
const analysis::Constant* first_operand_const =
context()->get_constant_mgr()->FindDeclaredConstant(source);
if (!first_operand_const) return nullptr;
const analysis::Constant* current_const = first_operand_const;
for (uint32_t i = 2; i < inst->NumInOperands(); i++) {
uint32_t literal = inst->GetSingleWordInOperand(i);
type = GetTypeComponent(type, literal);
}
for (uint32_t i = 2; i < inst->NumInOperands(); i++) {
uint32_t literal = inst->GetSingleWordInOperand(i);
if (const analysis::CompositeConstant* composite_const =
current_const->AsCompositeConstant()) {
// Case 1: current constant is a non-null composite type constant.
assert(literal < composite_const->GetComponents().size() &&
"Literal index out of bound of the composite constant");
current_const = composite_const->GetComponents().at(literal);
} else if (current_const->AsNullConstant()) {
// Case 2: current constant is a constant created with OpConstantNull.
// Because components of a NullConstant are always NullConstants, we can
// return early with a NullConstant in the result type.
return context()->get_constant_mgr()->BuildInstructionAndAddToModule(
context()->get_constant_mgr()->GetConstant(
context()->get_constant_mgr()->GetType(inst), {}),
pos, type);
} else {
// Dereferencing a non-composite constant. Invalid case.
Instruction* FoldSpecConstantOpAndCompositePass::FoldWithInstructionFolder(
Module::inst_iterator* inst_iter_ptr) {
// If one of operands to the instruction is not a
// constant, then we cannot fold this spec constant.
for (uint32_t i = 1; i < (*inst_iter_ptr)->NumInOperands(); i++) {
const Operand& operand = (*inst_iter_ptr)->GetInOperand(i);
if (operand.type != SPV_OPERAND_TYPE_ID &&
operand.type != SPV_OPERAND_TYPE_OPTIONAL_ID) {
continue;
}
uint32_t id = operand.words[0];
if (context()->get_constant_mgr()->FindDeclaredConstant(id) == nullptr) {
return nullptr;
}
}
return context()->get_constant_mgr()->BuildInstructionAndAddToModule(
current_const, pos);
// All of the operands are constant. Construct a regular version of the
// instruction and pass it to the instruction folder.
std::unique_ptr<Instruction> inst((*inst_iter_ptr)->Clone(context()));
inst->SetOpcode(
static_cast<SpvOp>((*inst_iter_ptr)->GetSingleWordInOperand(0)));
inst->RemoveOperand(2);
// We want the current instruction to be replaced by an |OpConstant*|
// instruction in the same position. We need to keep track of which constants
// the instruction folder creates, so we can move them into the correct place.
auto last_type_value_iter = (context()->types_values_end());
--last_type_value_iter;
Instruction* last_type_value = &*last_type_value_iter;
auto identity_map = [](uint32_t id) { return id; };
Instruction* new_const_inst =
context()->get_instruction_folder().FoldInstructionToConstant(
inst.get(), identity_map);
assert(new_const_inst != nullptr &&
"Failed to fold instruction that must be folded.");
// Get the instruction before |pos| to insert after. |pos| cannot be the
// first instruction in the list because its type has to come first.
Instruction* insert_pos = (*inst_iter_ptr)->PreviousNode();
assert(insert_pos != nullptr &&
"pos is the first instruction in the types and values.");
bool need_to_clone = true;
for (Instruction* i = last_type_value->NextNode(); i != nullptr;
i = last_type_value->NextNode()) {
if (i == new_const_inst) {
need_to_clone = false;
}
i->InsertAfter(insert_pos);
insert_pos = insert_pos->NextNode();
}
if (need_to_clone) {
new_const_inst = new_const_inst->Clone(context());
new_const_inst->SetResultId(TakeNextId());
new_const_inst->InsertAfter(insert_pos);
get_def_use_mgr()->AnalyzeInstDefUse(new_const_inst);
}
return new_const_inst;
}
Instruction* FoldSpecConstantOpAndCompositePass::DoVectorShuffle(

View File

@@ -54,11 +54,9 @@ class FoldSpecConstantOpAndCompositePass : public Pass {
// it.
bool ProcessOpSpecConstantOp(Module::inst_iterator* pos);
// Try to fold the OpSpecConstantOp CompositeExtract instruction pointed by
// the given instruction iterator to a normal constant defining instruction.
// Returns the pointer to the new constant defining instruction if succeeded.
// Otherwise returns nullptr.
Instruction* DoCompositeExtract(Module::inst_iterator* inst_iter_ptr);
// Returns the result of folding the OpSpecConstantOp instruction
// |inst_iter_ptr| using the instruction folder.
Instruction* FoldWithInstructionFolder(Module::inst_iterator* inst_iter_ptr);
// Try to fold the OpSpecConstantOp VectorShuffle instruction pointed by the
// given instruction iterator to a normal constant defining instruction.

View File

@@ -30,13 +30,14 @@ class InstBindlessCheckPass : public InstrumentPass {
public:
// For test harness only
InstBindlessCheckPass()
: InstrumentPass(7, 23, kInstValidationIdBindless),
: InstrumentPass(7, 23, kInstValidationIdBindless, 1),
input_length_enabled_(true),
input_init_enabled_(true) {}
// For all other interfaces
InstBindlessCheckPass(uint32_t desc_set, uint32_t shader_id,
bool input_length_enable, bool input_init_enable)
: InstrumentPass(desc_set, shader_id, kInstValidationIdBindless),
bool input_length_enable, bool input_init_enable,
uint32_t version)
: InstrumentPass(desc_set, shader_id, kInstValidationIdBindless, version),
input_length_enabled_(input_length_enable),
input_init_enabled_(input_init_enable) {}

View File

@@ -143,23 +143,21 @@ void InstrumentPass::GenFragCoordEltDebugOutputCode(
element_val_inst->result_id(), builder);
}
uint32_t InstrumentPass::GenVarLoad(uint32_t var_id,
InstructionBuilder* builder) {
Instruction* var_inst = get_def_use_mgr()->GetDef(var_id);
uint32_t type_id = GetPointeeTypeId(var_inst);
Instruction* load_inst = builder->AddUnaryOp(type_id, SpvOpLoad, var_id);
return load_inst->result_id();
}
void InstrumentPass::GenBuiltinOutputCode(uint32_t builtin_id,
uint32_t builtin_off,
uint32_t base_offset_id,
InstructionBuilder* builder) {
// Load and store builtin
Instruction* var_inst = get_def_use_mgr()->GetDef(builtin_id);
uint32_t type_id = GetPointeeTypeId(var_inst);
Instruction* load_inst = builder->AddUnaryOp(type_id, SpvOpLoad, builtin_id);
uint32_t val_id = GenUintCastCode(load_inst->result_id(), builder);
GenDebugOutputFieldCode(base_offset_id, builtin_off, val_id, builder);
}
void InstrumentPass::GenUintNullOutputCode(uint32_t field_off,
uint32_t base_offset_id,
InstructionBuilder* builder) {
GenDebugOutputFieldCode(base_offset_id, field_off,
builder->GetNullId(GetUintId()), builder);
uint32_t load_id = GenVarLoad(builtin_id, builder);
GenDebugOutputFieldCode(base_offset_id, builtin_off, load_id, builder);
}
void InstrumentPass::GenStageStreamWriteCode(uint32_t stage_idx,
@@ -169,37 +167,97 @@ void InstrumentPass::GenStageStreamWriteCode(uint32_t stage_idx,
switch (stage_idx) {
case SpvExecutionModelVertex: {
// Load and store VertexId and InstanceId
GenBuiltinOutputCode(context()->GetBuiltinVarId(SpvBuiltInVertexIndex),
kInstVertOutVertexIndex, base_offset_id, builder);
GenBuiltinOutputCode(context()->GetBuiltinVarId(SpvBuiltInInstanceIndex),
kInstVertOutInstanceIndex, base_offset_id, builder);
GenBuiltinOutputCode(
context()->GetBuiltinInputVarId(SpvBuiltInVertexIndex),
kInstVertOutVertexIndex, base_offset_id, builder);
GenBuiltinOutputCode(
context()->GetBuiltinInputVarId(SpvBuiltInInstanceIndex),
kInstVertOutInstanceIndex, base_offset_id, builder);
} break;
case SpvExecutionModelGLCompute: {
// Load and store GlobalInvocationId. Second word is unused; store zero.
GenBuiltinOutputCode(
context()->GetBuiltinVarId(SpvBuiltInGlobalInvocationId),
kInstCompOutGlobalInvocationId, base_offset_id, builder);
GenUintNullOutputCode(kInstCompOutUnused, base_offset_id, builder);
// Load and store GlobalInvocationId.
uint32_t load_id = GenVarLoad(
context()->GetBuiltinInputVarId(SpvBuiltInGlobalInvocationId),
builder);
Instruction* x_inst = builder->AddIdLiteralOp(
GetUintId(), SpvOpCompositeExtract, load_id, 0);
Instruction* y_inst = builder->AddIdLiteralOp(
GetUintId(), SpvOpCompositeExtract, load_id, 1);
Instruction* z_inst = builder->AddIdLiteralOp(
GetUintId(), SpvOpCompositeExtract, load_id, 2);
if (version_ == 1) {
// For version 1 format, as a stopgap, pack uvec3 into first word:
// x << 21 | y << 10 | z. Second word is unused. (DEPRECATED)
Instruction* x_shft_inst = builder->AddBinaryOp(
GetUintId(), SpvOpShiftLeftLogical, x_inst->result_id(),
builder->GetUintConstantId(21));
Instruction* y_shft_inst = builder->AddBinaryOp(
GetUintId(), SpvOpShiftLeftLogical, y_inst->result_id(),
builder->GetUintConstantId(10));
Instruction* x_or_y_inst = builder->AddBinaryOp(
GetUintId(), SpvOpBitwiseOr, x_shft_inst->result_id(),
y_shft_inst->result_id());
Instruction* x_or_y_or_z_inst =
builder->AddBinaryOp(GetUintId(), SpvOpBitwiseOr,
x_or_y_inst->result_id(), z_inst->result_id());
GenDebugOutputFieldCode(base_offset_id, kInstCompOutGlobalInvocationId,
x_or_y_or_z_inst->result_id(), builder);
} else {
// For version 2 format, write all three words
GenDebugOutputFieldCode(base_offset_id, kInstCompOutGlobalInvocationIdX,
x_inst->result_id(), builder);
GenDebugOutputFieldCode(base_offset_id, kInstCompOutGlobalInvocationIdY,
y_inst->result_id(), builder);
GenDebugOutputFieldCode(base_offset_id, kInstCompOutGlobalInvocationIdZ,
z_inst->result_id(), builder);
}
} break;
case SpvExecutionModelGeometry: {
// Load and store PrimitiveId and InvocationId.
GenBuiltinOutputCode(context()->GetBuiltinVarId(SpvBuiltInPrimitiveId),
kInstGeomOutPrimitiveId, base_offset_id, builder);
GenBuiltinOutputCode(context()->GetBuiltinVarId(SpvBuiltInInvocationId),
kInstGeomOutInvocationId, base_offset_id, builder);
GenBuiltinOutputCode(
context()->GetBuiltinInputVarId(SpvBuiltInPrimitiveId),
kInstGeomOutPrimitiveId, base_offset_id, builder);
GenBuiltinOutputCode(
context()->GetBuiltinInputVarId(SpvBuiltInInvocationId),
kInstGeomOutInvocationId, base_offset_id, builder);
} break;
case SpvExecutionModelTessellationControl: {
// Load and store InvocationId and PrimitiveId
GenBuiltinOutputCode(
context()->GetBuiltinInputVarId(SpvBuiltInInvocationId),
kInstTessCtlOutInvocationId, base_offset_id, builder);
GenBuiltinOutputCode(
context()->GetBuiltinInputVarId(SpvBuiltInPrimitiveId),
kInstTessCtlOutPrimitiveId, base_offset_id, builder);
} break;
case SpvExecutionModelTessellationControl:
case SpvExecutionModelTessellationEvaluation: {
// Load and store InvocationId. Second word is unused; store zero.
GenBuiltinOutputCode(context()->GetBuiltinVarId(SpvBuiltInInvocationId),
kInstTessOutInvocationId, base_offset_id, builder);
GenUintNullOutputCode(kInstTessOutUnused, base_offset_id, builder);
if (version_ == 1) {
// For format version 1, load and store InvocationId.
GenBuiltinOutputCode(
context()->GetBuiltinInputVarId(SpvBuiltInInvocationId),
kInstTessOutInvocationId, base_offset_id, builder);
} else {
// For format version 2, load and store PrimitiveId and TessCoord.uv
GenBuiltinOutputCode(
context()->GetBuiltinInputVarId(SpvBuiltInPrimitiveId),
kInstTessEvalOutPrimitiveId, base_offset_id, builder);
uint32_t load_id = GenVarLoad(
context()->GetBuiltinInputVarId(SpvBuiltInTessCoord), builder);
Instruction* u_inst = builder->AddIdLiteralOp(
GetUintId(), SpvOpCompositeExtract, load_id, 0);
Instruction* v_inst = builder->AddIdLiteralOp(
GetUintId(), SpvOpCompositeExtract, load_id, 1);
GenDebugOutputFieldCode(base_offset_id, kInstTessEvalOutTessCoordU,
u_inst->result_id(), builder);
GenDebugOutputFieldCode(base_offset_id, kInstTessEvalOutTessCoordV,
v_inst->result_id(), builder);
}
} break;
case SpvExecutionModelFragment: {
// Load FragCoord and convert to Uint
Instruction* frag_coord_inst =
builder->AddUnaryOp(GetVec4FloatId(), SpvOpLoad,
context()->GetBuiltinVarId(SpvBuiltInFragCoord));
Instruction* frag_coord_inst = builder->AddUnaryOp(
GetVec4FloatId(), SpvOpLoad,
context()->GetBuiltinInputVarId(SpvBuiltInFragCoord));
Instruction* uint_frag_coord_inst = builder->AddUnaryOp(
GetVec4UintId(), SpvOpBitcast, frag_coord_inst->result_id());
for (uint32_t u = 0; u < 2u; ++u)
@@ -547,7 +605,9 @@ uint32_t InstrumentPass::GetStreamWriteFunctionId(uint32_t stage_idx,
context(), &*new_blk_ptr,
IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
// Gen test if debug output buffer size will not be exceeded.
uint32_t obuf_record_sz = kInstStageOutCnt + val_spec_param_cnt;
uint32_t val_spec_offset =
(version_ == 1) ? kInstStageOutCnt : kInst2StageOutCnt;
uint32_t obuf_record_sz = val_spec_offset + val_spec_param_cnt;
uint32_t buf_id = GetOutputBufferId();
uint32_t buf_uint_ptr_id = GetBufferUintPtrId();
Instruction* obuf_curr_sz_ac_inst =
@@ -593,7 +653,7 @@ uint32_t InstrumentPass::GetStreamWriteFunctionId(uint32_t stage_idx,
GenStageStreamWriteCode(stage_idx, obuf_curr_sz_id, &builder);
// Gen writes of validation specific data
for (uint32_t i = 0; i < val_spec_param_cnt; ++i) {
GenDebugOutputFieldCode(obuf_curr_sz_id, kInstStageOutCnt + i,
GenDebugOutputFieldCode(obuf_curr_sz_id, val_spec_offset + i,
param_vec[kInstCommonParamCnt + i], &builder);
}
// Close write block and gen merge block

View File

@@ -78,16 +78,18 @@ class InstrumentPass : public Pass {
}
protected:
// Create instrumentation pass which utilizes descriptor set |desc_set|
// for debug input and output buffers and writes |shader_id| into debug
// output records.
InstrumentPass(uint32_t desc_set, uint32_t shader_id, uint32_t validation_id)
// Create instrumentation pass for |validation_id| which utilizes descriptor
// set |desc_set| for debug input and output buffers and writes |shader_id|
// into debug output records with format |version|.
InstrumentPass(uint32_t desc_set, uint32_t shader_id, uint32_t validation_id,
uint32_t version)
: Pass(),
desc_set_(desc_set),
shader_id_(shader_id),
validation_id_(validation_id) {}
validation_id_(validation_id),
version_(version) {}
// Initialize state for instrumentation of module by |validation_id|.
// Initialize state for instrumentation of module.
void InitializeInstrument();
// Call |pfn| on all instructions in all functions in the call tree of the
@@ -146,6 +148,7 @@ class InstrumentPass : public Pass {
// Stage
// Stage-specific Word 0
// Stage-specific Word 1
// ...
// Validation Error Code
// Validation-specific Word 0
// Validation-specific Word 1
@@ -170,12 +173,12 @@ class InstrumentPass : public Pass {
// following Stage-specific words.
//
// The Stage-specific Words identify which invocation of the shader generated
// the error. Every stage will write two words, although in some cases the
// second word is unused and so zero is written. Vertex shaders will write
// the Vertex and Instance ID. Fragment shaders will write FragCoord.xy.
// Compute shaders will write the Global Invocation ID and zero (unused).
// Both tesselation shaders will write the Invocation Id and zero (unused).
// The geometry shader will write the Primitive ID and Invocation ID.
// the error. Every stage will write a fixed number of words. Vertex shaders
// will write the Vertex and Instance ID. Fragment shaders will write
// FragCoord.xy. Compute shaders will write the GlobalInvocation ID.
// The tesselation eval shader will write the Primitive ID and TessCoords.uv.
// The tesselation control shader and geometry shader will write the
// Primitive ID and Invocation ID.
//
// The Validation Error Code specifies the exact error which has occurred.
// These are enumerated with the kInstError* static consts. This allows
@@ -291,16 +294,15 @@ class InstrumentPass : public Pass {
uint32_t component,
InstructionBuilder* builder);
// Generate instructions into |builder| which will load |var_id| and return
// its result id.
uint32_t GenVarLoad(uint32_t var_id, InstructionBuilder* builder);
// Generate instructions into |builder| which will load the uint |builtin_id|
// and write it into the debug output buffer at |base_off| + |builtin_off|.
void GenBuiltinOutputCode(uint32_t builtin_id, uint32_t builtin_off,
uint32_t base_off, InstructionBuilder* builder);
// Generate instructions into |builder| which will write a uint null into
// the debug output buffer at |base_off| + |builtin_off|.
void GenUintNullOutputCode(uint32_t field_off, uint32_t base_off,
InstructionBuilder* builder);
// Generate instructions into |builder| which will write the |stage_idx|-
// specific members of the debug output stream at |base_off|.
void GenStageStreamWriteCode(uint32_t stage_idx, uint32_t base_off,
@@ -376,6 +378,9 @@ class InstrumentPass : public Pass {
// id for void type
uint32_t void_id_;
// Record format version
uint32_t version_;
// boolean to remember storage buffer extension
bool storage_buffer_ext_defined_;

View File

@@ -621,7 +621,7 @@ LoopDescriptor* IRContext::GetLoopDescriptor(const Function* f) {
return &it->second;
}
uint32_t IRContext::FindBuiltinVar(uint32_t builtin) {
uint32_t IRContext::FindBuiltinInputVar(uint32_t builtin) {
for (auto& a : module_->annotations()) {
if (a.opcode() != SpvOpDecorate) continue;
if (a.GetSingleWordInOperand(kSpvDecorateDecorationInIdx) !=
@@ -631,6 +631,7 @@ uint32_t IRContext::FindBuiltinVar(uint32_t builtin) {
uint32_t target_id = a.GetSingleWordInOperand(kSpvDecorateTargetIdInIdx);
Instruction* b_var = get_def_use_mgr()->GetDef(target_id);
if (b_var->opcode() != SpvOpVariable) continue;
if (b_var->GetSingleWordInOperand(0) != SpvStorageClassInput) continue;
return target_id;
}
return 0;
@@ -653,14 +654,14 @@ void IRContext::AddVarToEntryPoints(uint32_t var_id) {
}
}
uint32_t IRContext::GetBuiltinVarId(uint32_t builtin) {
uint32_t IRContext::GetBuiltinInputVarId(uint32_t builtin) {
if (!AreAnalysesValid(kAnalysisBuiltinVarId)) ResetBuiltinAnalysis();
// If cached, return it.
std::unordered_map<uint32_t, uint32_t>::iterator it =
builtin_var_id_map_.find(builtin);
if (it != builtin_var_id_map_.end()) return it->second;
// Look for one in shader
uint32_t var_id = FindBuiltinVar(builtin);
uint32_t var_id = FindBuiltinInputVar(builtin);
if (var_id == 0) {
// If not found, create it
// TODO(greg-lunarg): Add support for all builtins

View File

@@ -491,10 +491,10 @@ class IRContext {
uint32_t max_id_bound() const { return max_id_bound_; }
void set_max_id_bound(uint32_t new_bound) { max_id_bound_ = new_bound; }
// Return id of variable only decorated with |builtin|, if in module.
// Return id of input variable only decorated with |builtin|, if in module.
// Create variable and return its id otherwise. If builtin not currently
// supported, return 0.
uint32_t GetBuiltinVarId(uint32_t builtin);
uint32_t GetBuiltinInputVarId(uint32_t builtin);
// Returns the function whose id is |id|, if one exists. Returns |nullptr|
// otherwise.
@@ -657,9 +657,9 @@ class IRContext {
// true if the cfg is invalidated.
bool CheckCFG();
// Return id of variable only decorated with |builtin|, if in module.
// Return id of input variable only decorated with |builtin|, if in module.
// Return 0 otherwise.
uint32_t FindBuiltinVar(uint32_t builtin);
uint32_t FindBuiltinInputVar(uint32_t builtin);
// Add |var_id| to all entry points in module.
void AddVarToEntryPoints(uint32_t var_id);

View File

@@ -229,11 +229,13 @@ Optimizer& Optimizer::RegisterVulkanToWebGPUPasses() {
.RegisterPass(CreateEliminateDeadConstantPass())
.RegisterPass(CreateFlattenDecorationPass())
.RegisterPass(CreateAggressiveDCEPass())
.RegisterPass(CreateDeadBranchElimPass());
.RegisterPass(CreateDeadBranchElimPass())
.RegisterPass(CreateCompactIdsPass());
}
Optimizer& Optimizer::RegisterWebGPUToVulkanPasses() {
return RegisterPass(CreateDecomposeInitializedVariablesPass());
return RegisterPass(CreateDecomposeInitializedVariablesPass())
.RegisterPass(CreateCompactIdsPass());
}
bool Optimizer::RegisterPassesFromFlags(const std::vector<std::string>& flags) {
@@ -397,7 +399,7 @@ bool Optimizer::RegisterPassFromFlag(const std::string& flag) {
} else if (pass_name == "replace-invalid-opcode") {
RegisterPass(CreateReplaceInvalidOpcodePass());
} else if (pass_name == "inst-bindless-check") {
RegisterPass(CreateInstBindlessCheckPass(7, 23, true, true));
RegisterPass(CreateInstBindlessCheckPass(7, 23, true, true, 1));
RegisterPass(CreateSimplificationPass());
RegisterPass(CreateDeadBranchElimPass());
RegisterPass(CreateBlockMergePass());
@@ -472,6 +474,10 @@ bool Optimizer::RegisterPassFromFlag(const std::string& flag) {
RegisterPass(CreateGenerateWebGPUInitializersPass());
} else if (pass_name == "legalize-vector-shuffle") {
RegisterPass(CreateLegalizeVectorShufflePass());
} else if (pass_name == "split-invalid-unreachable") {
RegisterPass(CreateLegalizeVectorShufflePass());
} else if (pass_name == "decompose-initialized-variables") {
RegisterPass(CreateDecomposeInitializedVariablesPass());
} else {
Errorf(consumer(), nullptr, {},
"Unknown flag '--%s'. Use --help for a list of valid flags",
@@ -843,10 +849,12 @@ Optimizer::PassToken CreateUpgradeMemoryModelPass() {
Optimizer::PassToken CreateInstBindlessCheckPass(uint32_t desc_set,
uint32_t shader_id,
bool input_length_enable,
bool input_init_enable) {
bool input_init_enable,
uint32_t version) {
return MakeUnique<Optimizer::PassToken::Impl>(
MakeUnique<opt::InstBindlessCheckPass>(
desc_set, shader_id, input_length_enable, input_init_enable));
MakeUnique<opt::InstBindlessCheckPass>(desc_set, shader_id,
input_length_enable,
input_init_enable, version));
}
Optimizer::PassToken CreateCodeSinkingPass() {

View File

@@ -96,35 +96,67 @@ bool RemoveDuplicatesPass::RemoveDuplicateTypes() const {
return modified;
}
analysis::TypeManager type_manager(context()->consumer(), context());
std::vector<Instruction*> visited_types;
std::vector<analysis::ForwardPointer> visited_forward_pointers;
std::vector<Instruction*> to_delete;
for (auto* i = &*context()->types_values_begin(); i; i = i->NextNode()) {
const bool is_i_forward_pointer = i->opcode() == SpvOpTypeForwardPointer;
// We only care about types.
if (!spvOpcodeGeneratesType((i->opcode())) &&
i->opcode() != SpvOpTypeForwardPointer) {
if (!spvOpcodeGeneratesType(i->opcode()) && !is_i_forward_pointer) {
continue;
}
// Is the current type equal to one of the types we have aready visited?
SpvId id_to_keep = 0u;
// TODO(dneto0): Use a trie to avoid quadratic behaviour? Extract the
// ResultIdTrie from unify_const_pass.cpp for this.
for (auto j : visited_types) {
if (AreTypesEqual(*i, *j, context())) {
id_to_keep = j->result_id();
break;
if (!is_i_forward_pointer) {
// Is the current type equal to one of the types we have already visited?
SpvId id_to_keep = 0u;
analysis::Type* i_type = type_manager.GetType(i->result_id());
assert(i_type);
// TODO(dneto0): Use a trie to avoid quadratic behaviour? Extract the
// ResultIdTrie from unify_const_pass.cpp for this.
for (auto j : visited_types) {
analysis::Type* j_type = type_manager.GetType(j->result_id());
assert(j_type);
if (*i_type == *j_type) {
id_to_keep = j->result_id();
break;
}
}
}
if (id_to_keep == 0u) {
// This is a never seen before type, keep it around.
visited_types.emplace_back(i);
if (id_to_keep == 0u) {
// This is a never seen before type, keep it around.
visited_types.emplace_back(i);
} else {
// The same type has already been seen before, remove this one.
context()->KillNamesAndDecorates(i->result_id());
context()->ReplaceAllUsesWith(i->result_id(), id_to_keep);
modified = true;
to_delete.emplace_back(i);
}
} else {
// The same type has already been seen before, remove this one.
context()->KillNamesAndDecorates(i->result_id());
context()->ReplaceAllUsesWith(i->result_id(), id_to_keep);
modified = true;
to_delete.emplace_back(i);
analysis::ForwardPointer i_type(
i->GetSingleWordInOperand(0u),
(SpvStorageClass)i->GetSingleWordInOperand(1u));
i_type.SetTargetPointer(
type_manager.GetType(i_type.target_id())->AsPointer());
// TODO(dneto0): Use a trie to avoid quadratic behaviour? Extract the
// ResultIdTrie from unify_const_pass.cpp for this.
const bool found_a_match =
std::find(std::begin(visited_forward_pointers),
std::end(visited_forward_pointers),
i_type) != std::end(visited_forward_pointers);
if (!found_a_match) {
// This is a never seen before type, keep it around.
visited_forward_pointers.emplace_back(i_type);
} else {
// The same type has already been seen before, remove this one.
modified = true;
to_delete.emplace_back(i);
}
}
}
@@ -151,8 +183,8 @@ bool RemoveDuplicatesPass::RemoveDuplicateDecorations() const {
analysis::DecorationManager decoration_manager(context()->module());
for (auto* i = &*context()->annotation_begin(); i;) {
// Is the current decoration equal to one of the decorations we have aready
// visited?
// Is the current decoration equal to one of the decorations we have
// already visited?
bool already_visited = false;
// TODO(dneto0): Use a trie to avoid quadratic behaviour? Extract the
// ResultIdTrie from unify_const_pass.cpp for this.
@@ -177,20 +209,5 @@ bool RemoveDuplicatesPass::RemoveDuplicateDecorations() const {
return modified;
}
bool RemoveDuplicatesPass::AreTypesEqual(const Instruction& inst1,
const Instruction& inst2,
IRContext* context) {
if (inst1.opcode() != inst2.opcode()) return false;
if (!IsTypeInst(inst1.opcode())) return false;
const analysis::Type* type1 =
context->get_type_mgr()->GetType(inst1.result_id());
const analysis::Type* type2 =
context->get_type_mgr()->GetType(inst2.result_id());
if (type1 && type2 && *type1 == *type2) return true;
return false;
}
} // namespace opt
} // namespace spvtools

View File

@@ -36,12 +36,6 @@ class RemoveDuplicatesPass : public Pass {
const char* name() const override { return "remove-duplicates"; }
Status Process() override;
// TODO(pierremoreau): Move this function somewhere else (e.g. pass.h or
// within the type manager)
// Returns whether two types are equal, and have the same decorations.
static bool AreTypesEqual(const Instruction& inst1, const Instruction& inst2,
IRContext* context);
private:
// Remove duplicate capabilities from the module
//

View File

@@ -52,6 +52,7 @@ void StructuredCFGAnalysis::AddBlocksInFunction(Function* func) {
state.emplace_back();
state[0].cinfo.containing_construct = 0;
state[0].cinfo.containing_loop = 0;
state[0].cinfo.containing_switch = 0;
state[0].merge_node = 0;
for (BasicBlock* block : order) {
@@ -74,8 +75,15 @@ void StructuredCFGAnalysis::AddBlocksInFunction(Function* func) {
if (merge_inst->opcode() == SpvOpLoopMerge) {
new_state.cinfo.containing_loop = block->id();
new_state.cinfo.containing_switch = 0;
} else {
new_state.cinfo.containing_loop = state.back().cinfo.containing_loop;
if (merge_inst->NextNode()->opcode() == SpvOpSwitch) {
new_state.cinfo.containing_switch = block->id();
} else {
new_state.cinfo.containing_switch =
state.back().cinfo.containing_switch;
}
}
state.emplace_back(new_state);
@@ -84,6 +92,11 @@ void StructuredCFGAnalysis::AddBlocksInFunction(Function* func) {
}
}
uint32_t StructuredCFGAnalysis::ContainingConstruct(Instruction* inst) {
uint32_t bb = context_->get_instr_block(inst)->id();
return ContainingConstruct(bb);
}
uint32_t StructuredCFGAnalysis::MergeBlock(uint32_t bb_id) {
uint32_t header_id = ContainingConstruct(bb_id);
if (header_id == 0) {
@@ -117,6 +130,17 @@ uint32_t StructuredCFGAnalysis::LoopContinueBlock(uint32_t bb_id) {
return merge_inst->GetSingleWordInOperand(kContinueNodeIndex);
}
uint32_t StructuredCFGAnalysis::SwitchMergeBlock(uint32_t bb_id) {
uint32_t header_id = ContainingSwitch(bb_id);
if (header_id == 0) {
return 0;
}
BasicBlock* header = context_->cfg()->block(header_id);
Instruction* merge_inst = header->GetMergeInst();
return merge_inst->GetSingleWordInOperand(kMergeNodeIndex);
}
bool StructuredCFGAnalysis::IsContinueBlock(uint32_t bb_id) {
assert(bb_id != 0);
return LoopContinueBlock(bb_id) == bb_id;

View File

@@ -42,6 +42,11 @@ class StructuredCFGAnalysis {
return it->second.containing_construct;
}
// Returns the id of the header of the innermost merge construct
// that contains |inst|. Returns |0| if |inst| is not contained in any
// merge construct.
uint32_t ContainingConstruct(Instruction* inst);
// Returns the id of the merge block of the innermost merge construct
// that contains |bb_id|. Returns |0| if |bb_id| is not contained in any
// merge construct.
@@ -68,6 +73,21 @@ class StructuredCFGAnalysis {
// construct.
uint32_t LoopContinueBlock(uint32_t bb_id);
// Returns the id of the header of the innermost switch construct
// that contains |bb_id| as long as there is no intervening loop. Returns |0|
// if no such construct exists.
uint32_t ContainingSwitch(uint32_t bb_id) {
auto it = bb_to_construct_.find(bb_id);
if (it == bb_to_construct_.end()) {
return 0;
}
return it->second.containing_switch;
}
// Returns the id of the merge block of the innermost switch construct
// that contains |bb_id| as long as there is no intervening loop. Return |0|
// if no such block exists.
uint32_t SwitchMergeBlock(uint32_t bb_id);
bool IsContinueBlock(uint32_t bb_id);
bool IsMergeBlock(uint32_t bb_id);
@@ -82,6 +102,7 @@ class StructuredCFGAnalysis {
struct ConstructInfo {
uint32_t containing_construct;
uint32_t containing_loop;
uint32_t containing_switch;
};
// Populates |bb_to_construct_| with the innermost containing merge and loop

View File

@@ -66,7 +66,13 @@ uint32_t TypeManager::GetId(const Type* type) const {
}
void TypeManager::AnalyzeTypes(const Module& module) {
// First pass through the types. Any types that reference a forward pointer
// First pass through the constants, as some will be needed when traversing
// the types in the next pass.
for (const auto* inst : module.GetConstants()) {
id_to_constant_inst_[inst->result_id()] = inst;
}
// Then pass through the types. Any types that reference a forward pointer
// (directly or indirectly) are incomplete, and are added to incomplete types.
for (const auto* inst : module.GetTypes()) {
RecordIfTypeDefinition(*inst);
@@ -154,7 +160,7 @@ void TypeManager::AnalyzeTypes(const Module& module) {
#ifndef NDEBUG
// Check if the type pool contains two types that are the same. This
// is an indication that the hashing and comparision are wrong. It
// is an indication that the hashing and comparison are wrong. It
// will cause a problem if the type pool gets resized and everything
// is rehashed.
for (auto& i : type_pool_) {
@@ -504,9 +510,8 @@ Type* TypeManager::RebuildType(const Type& type) {
}
case Type::kArray: {
const Array* array_ty = type.AsArray();
const Type* ele_ty = array_ty->element_type();
rebuilt_ty =
MakeUnique<Array>(RebuildType(*ele_ty), array_ty->LengthId());
MakeUnique<Array>(array_ty->element_type(), array_ty->length_info());
break;
}
case Type::kRuntimeArray: {
@@ -636,15 +641,56 @@ Type* TypeManager::RecordIfTypeDefinition(const Instruction& inst) {
case SpvOpTypeSampledImage:
type = new SampledImage(GetType(inst.GetSingleWordInOperand(0)));
break;
case SpvOpTypeArray:
type = new Array(GetType(inst.GetSingleWordInOperand(0)),
inst.GetSingleWordInOperand(1));
case SpvOpTypeArray: {
const uint32_t length_id = inst.GetSingleWordInOperand(1);
const Instruction* length_constant_inst = id_to_constant_inst_[length_id];
assert(length_constant_inst);
// How will we distinguish one length value from another?
// Determine extra words required to distinguish this array length
// from another.
std::vector<uint32_t> extra_words{Array::LengthInfo::kDefiningId};
// If it is a specialised constant, retrieve its SpecId.
// Only OpSpecConstant has a SpecId.
uint32_t spec_id = 0u;
bool has_spec_id = false;
if (length_constant_inst->opcode() == SpvOpSpecConstant) {
context()->get_decoration_mgr()->ForEachDecoration(
length_id, SpvDecorationSpecId,
[&spec_id, &has_spec_id](const Instruction& decoration) {
assert(decoration.opcode() == SpvOpDecorate);
spec_id = decoration.GetSingleWordOperand(2u);
has_spec_id = true;
});
}
const auto opcode = length_constant_inst->opcode();
if (has_spec_id) {
extra_words.push_back(spec_id);
}
if ((opcode == SpvOpConstant) || (opcode == SpvOpSpecConstant)) {
// Always include the literal constant words. In the spec constant
// case, the constant might not be overridden, so it's still
// significant.
extra_words.insert(extra_words.end(),
length_constant_inst->GetOperand(2).words.begin(),
length_constant_inst->GetOperand(2).words.end());
extra_words[0] = has_spec_id ? Array::LengthInfo::kConstantWithSpecId
: Array::LengthInfo::kConstant;
} else {
assert(extra_words[0] == Array::LengthInfo::kDefiningId);
extra_words.push_back(length_id);
}
assert(extra_words.size() >= 2);
Array::LengthInfo length_info{length_id, extra_words};
type = new Array(GetType(inst.GetSingleWordInOperand(0)), length_info);
if (id_to_incomplete_type_.count(inst.GetSingleWordInOperand(0))) {
incomplete_types_.emplace_back(inst.result_id(), type);
id_to_incomplete_type_[inst.result_id()] = type;
return type;
}
break;
} break;
case SpvOpTypeRuntimeArray:
type = new RuntimeArray(GetType(inst.GetSingleWordInOperand(0)));
if (id_to_incomplete_type_.count(inst.GetSingleWordInOperand(0))) {

Some files were not shown because too many files have changed in this diff Show More