From a1b314acd26475486cc22418cabdbf22bcef4261 Mon Sep 17 00:00:00 2001 From: Isaac Marovitz Date: Tue, 15 Aug 2023 14:17:00 +0100 Subject: [PATCH] Back to where we were First special instruction Start Load/Store implementation Start TextureSample Sample progress I/O Load/Store Progress Rest of load/store TODO: Currently, the generator still assumes the GLSL style of I/O attributres. On MSL, the vertex function should output a struct which contains a float4 with the required position attribute. TextureSize and VectorExtract Fix UserDefined IO Vars Fix stage input struct names --- .../CodeGen/Msl/Declarations.cs | 66 ++-- .../CodeGen/Msl/Instructions/InstGen.cs | 18 +- .../CodeGen/Msl/Instructions/InstGenCall.cs | 25 ++ .../CodeGen/Msl/Instructions/InstGenHelper.cs | 6 +- .../CodeGen/Msl/Instructions/InstGenMemory.cs | 318 ++++++++++++++++++ .../CodeGen/Msl/Instructions/InstGenVector.cs | 32 ++ .../CodeGen/Msl/Instructions/IoMap.cs | 49 ++- .../CodeGen/Msl/MslGenerator.cs | 17 +- .../CodeGen/Msl/OperandManager.cs | 17 +- .../Translation/TranslatorContext.cs | 2 + 10 files changed, 507 insertions(+), 43 deletions(-) create mode 100644 src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGenCall.cs create mode 100644 src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGenMemory.cs create mode 100644 src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGenVector.cs diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Declarations.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Declarations.cs index 523aa56f3..f06af235a 100644 --- a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Declarations.cs +++ b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Declarations.cs @@ -3,6 +3,7 @@ using Ryujinx.Graphics.Shader.StructuredIr; using Ryujinx.Graphics.Shader.Translation; using System; using System.Collections.Generic; +using System.Data.Common; using System.Linq; using System.Numerics; @@ -23,7 +24,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl } - // DeclareInputAttributes(context, info.IoDefinitions.Where(x => IsUserDefined(x, StorageKind.Input))); + DeclareInputAttributes(context, info.IoDefinitions.Where(x => IsUserDefined(x, StorageKind.Input))); } static bool IsUserDefined(IoDefinition ioDefinition, StorageKind storageKind) @@ -66,28 +67,45 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl }; } - // TODO: Redo for new Shader IR rep - // private static void DeclareInputAttributes(CodeGenContext context, IEnumerable inputs) - // { - // if (context.AttributeUsage.UsedInputAttributes != 0) - // { - // context.AppendLine("struct VertexIn"); - // context.EnterScope(); - // - // int usedAttributes = context.AttributeUsage.UsedInputAttributes | context.AttributeUsage.PassthroughAttributes; - // while (usedAttributes != 0) - // { - // int index = BitOperations.TrailingZeroCount(usedAttributes); - // - // string name = $"{DefaultNames.IAttributePrefix}{index}"; - // var type = context.AttributeUsage.get .QueryAttributeType(index).ToVec4Type(TargetLanguage.Msl); - // context.AppendLine($"{type} {name} [[attribute({index})]];"); - // - // usedAttributes &= ~(1 << index); - // } - // - // context.LeaveScope(";"); - // } - // } + private static void DeclareInputAttributes(CodeGenContext context, IEnumerable inputs) + { + if (context.Definitions.IaIndexing) + { + // Not handled + } + else + { + if (inputs.Any()) + { + string prefix = ""; + + switch (context.Definitions.Stage) + { + case ShaderStage.Vertex: + prefix = "Vertex"; + break; + case ShaderStage.Fragment: + prefix = "Fragment"; + break; + case ShaderStage.Compute: + prefix = "Compute"; + break; + } + + context.AppendLine($"struct {prefix}In"); + context.EnterScope(); + + foreach (var ioDefinition in inputs.OrderBy(x => x.Location)) + { + string type = GetVarTypeName(context, context.Definitions.GetUserDefinedType(ioDefinition.Location, isOutput: false)); + string name = $"{DefaultNames.IAttributePrefix}{ioDefinition.Location}"; + + context.AppendLine($"{type} {name} [[attribute({ioDefinition.Location})]];"); + } + + context.LeaveScope(";"); + } + } + } } } \ No newline at end of file diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGen.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGen.cs index 3f5c5ebda..ffb2105ea 100644 --- a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGen.cs +++ b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGen.cs @@ -3,7 +3,10 @@ using Ryujinx.Graphics.Shader.StructuredIr; using Ryujinx.Graphics.Shader.Translation; using System; +using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenCall; using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenHelper; +using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenMemory; +using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenVector; using static Ryujinx.Graphics.Shader.StructuredIr.InstructionInfo; namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions @@ -105,7 +108,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions case Instruction.Barrier: return "|| BARRIER ||"; case Instruction.Call: - return "|| CALL ||"; + return Call(context, operation); case Instruction.FSIBegin: return "|| FSI BEGIN ||"; case Instruction.FSIEnd: @@ -125,25 +128,26 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions case Instruction.ImageAtomic: return "|| IMAGE ATOMIC ||"; case Instruction.Load: - return "|| LOAD ||"; + return Load(context, operation); case Instruction.Lod: return "|| LOD ||"; case Instruction.MemoryBarrier: return "|| MEMORY BARRIER ||"; case Instruction.Store: - return "|| STORE ||"; + return Store(context, operation); case Instruction.TextureSample: - return "|| TEXTURE SAMPLE ||"; + return TextureSample(context, operation); case Instruction.TextureSize: - return "|| TEXTURE SIZE ||"; + return TextureSize(context, operation); case Instruction.VectorExtract: - return "|| VECTOR EXTRACT ||"; + return VectorExtract(context, operation); case Instruction.VoteAllEqual: return "|| VOTE ALL EQUAL ||"; } } - throw new InvalidOperationException($"Unexpected instruction type \"{info.Type}\"."); + // TODO: Return this to being an error + return $"Unexpected instruction type \"{info.Type}\"."; } } } diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGenCall.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGenCall.cs new file mode 100644 index 000000000..df9d10301 --- /dev/null +++ b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGenCall.cs @@ -0,0 +1,25 @@ +using Ryujinx.Graphics.Shader.StructuredIr; + +using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenHelper; + +namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions +{ + static class InstGenCall + { + public static string Call(CodeGenContext context, AstOperation operation) + { + AstOperand funcId = (AstOperand)operation.GetSource(0); + + var functon = context.GetFunction(funcId.Value); + + string[] args = new string[operation.SourcesCount - 1]; + + for (int i = 0; i < args.Length; i++) + { + args[i] = GetSourceExpr(context, operation.GetSource(i + 1), functon.GetArgumentType(i)); + } + + return $"{functon.Name}({string.Join(", ", args)})"; + } + } +} diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGenHelper.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGenHelper.cs index a74766000..489787747 100644 --- a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGenHelper.cs +++ b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGenHelper.cs @@ -2,6 +2,8 @@ using Ryujinx.Graphics.Shader.IntermediateRepresentation; using Ryujinx.Graphics.Shader.StructuredIr; using Ryujinx.Graphics.Shader.Translation; +using static Ryujinx.Graphics.Shader.CodeGen.Msl.TypeConversion; + namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions { static class InstGenHelper @@ -140,9 +142,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions public static string GetSourceExpr(CodeGenContext context, IAstNode node, AggregateType dstType) { - // TODO: Implement this - // return ReinterpretCast(context, node, OperandManager.GetNodeDestType(context, node), dstType); - return ""; + return ReinterpretCast(context, node, OperandManager.GetNodeDestType(context, node), dstType); } public static string Enclose(string expr, IAstNode node, Instruction pInst, bool isLhs) diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGenMemory.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGenMemory.cs new file mode 100644 index 000000000..2f0434c96 --- /dev/null +++ b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGenMemory.cs @@ -0,0 +1,318 @@ +using Ryujinx.Graphics.Shader.IntermediateRepresentation; +using Ryujinx.Graphics.Shader.StructuredIr; +using Ryujinx.Graphics.Shader.Translation; +using System; +using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenHelper; +using static Ryujinx.Graphics.Shader.StructuredIr.InstructionInfo; + +namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions +{ + static class InstGenMemory + { + public static string GenerateLoadOrStore(CodeGenContext context, AstOperation operation, bool isStore) + { + StorageKind storageKind = operation.StorageKind; + + string varName; + AggregateType varType; + int srcIndex = 0; + bool isStoreOrAtomic = operation.Inst == Instruction.Store || operation.Inst.IsAtomic(); + int inputsCount = isStoreOrAtomic ? operation.SourcesCount - 1 : operation.SourcesCount; + + if (operation.Inst == Instruction.AtomicCompareAndSwap) + { + inputsCount--; + } + + switch (storageKind) + { + case StorageKind.ConstantBuffer: + case StorageKind.StorageBuffer: + if (operation.GetSource(srcIndex++) is not AstOperand bindingIndex || bindingIndex.Type != OperandType.Constant) + { + throw new InvalidOperationException($"First input of {operation.Inst} with {storageKind} storage must be a constant operand."); + } + + int binding = bindingIndex.Value; + BufferDefinition buffer = storageKind == StorageKind.ConstantBuffer + ? context.Properties.ConstantBuffers[binding] + : context.Properties.StorageBuffers[binding]; + + if (operation.GetSource(srcIndex++) is not AstOperand fieldIndex || fieldIndex.Type != OperandType.Constant) + { + throw new InvalidOperationException($"Second input of {operation.Inst} with {storageKind} storage must be a constant operand."); + } + + StructureField field = buffer.Type.Fields[fieldIndex.Value]; + varName = $"{buffer.Name}.{field.Name}"; + varType = field.Type; + break; + + case StorageKind.LocalMemory: + case StorageKind.SharedMemory: + if (operation.GetSource(srcIndex++) is not AstOperand { Type: OperandType.Constant } bindingId) + { + throw new InvalidOperationException($"First input of {operation.Inst} with {storageKind} storage must be a constant operand."); + } + + MemoryDefinition memory = storageKind == StorageKind.LocalMemory + ? context.Properties.LocalMemories[bindingId.Value] + : context.Properties.SharedMemories[bindingId.Value]; + + varName = memory.Name; + varType = memory.Type; + break; + + case StorageKind.Input: + case StorageKind.InputPerPatch: + case StorageKind.Output: + case StorageKind.OutputPerPatch: + if (operation.GetSource(srcIndex++) is not AstOperand varId || varId.Type != OperandType.Constant) + { + throw new InvalidOperationException($"First input of {operation.Inst} with {storageKind} storage must be a constant operand."); + } + + IoVariable ioVariable = (IoVariable)varId.Value; + bool isOutput = storageKind.IsOutput(); + bool isPerPatch = storageKind.IsPerPatch(); + int location = -1; + int component = 0; + + if (context.Definitions.HasPerLocationInputOrOutput(ioVariable, isOutput)) + { + if (operation.GetSource(srcIndex++) is not AstOperand vecIndex || vecIndex.Type != OperandType.Constant) + { + throw new InvalidOperationException($"Second input of {operation.Inst} with {storageKind} storage must be a constant operand."); + } + + location = vecIndex.Value; + + if (operation.SourcesCount > srcIndex && + operation.GetSource(srcIndex) is AstOperand elemIndex && + elemIndex.Type == OperandType.Constant && + context.Definitions.HasPerLocationInputOrOutputComponent(ioVariable, vecIndex.Value, elemIndex.Value, isOutput)) + { + component = elemIndex.Value; + srcIndex++; + } + } + + (varName, varType) = IoMap.GetMslBuiltIn( + context.Definitions, + ioVariable, + location, + component, + isOutput, + isPerPatch); + break; + + default: + throw new InvalidOperationException($"Invalid storage kind {storageKind}."); + } + + for (; srcIndex < inputsCount; srcIndex++) + { + IAstNode src = operation.GetSource(srcIndex); + + if ((varType & AggregateType.ElementCountMask) != 0 && + srcIndex == inputsCount - 1 && + src is AstOperand elementIndex && + elementIndex.Type == OperandType.Constant) + { + varName += "." + "xyzw"[elementIndex.Value & 3]; + } + else + { + varName += $"[{GetSourceExpr(context, src, AggregateType.S32)}]"; + } + } + + if (isStore) + { + varType &= AggregateType.ElementTypeMask; + varName = $"{varName} = {GetSourceExpr(context, operation.GetSource(srcIndex), varType)}"; + } + + return varName; + } + + public static string Load(CodeGenContext context, AstOperation operation) + { + return GenerateLoadOrStore(context, operation, isStore: false); + } + + public static string Store(CodeGenContext context, AstOperation operation) + { + return GenerateLoadOrStore(context, operation, isStore: true); + } + + public static string TextureSample(CodeGenContext context, AstOperation operation) + { + AstTextureOperation texOp = (AstTextureOperation)operation; + + bool isGather = (texOp.Flags & TextureFlags.Gather) != 0; + bool isShadow = (texOp.Type & SamplerType.Shadow) != 0; + bool intCoords = (texOp.Flags & TextureFlags.IntCoords) != 0; + + bool isArray = (texOp.Type & SamplerType.Array) != 0; + + bool colorIsVector = isGather || !isShadow; + + string texCall = "texture."; + + int srcIndex = 0; + + string Src(AggregateType type) + { + return GetSourceExpr(context, texOp.GetSource(srcIndex++), type); + } + + if (intCoords) + { + texCall += "read("; + } + else + { + texCall += "sample("; + + string samplerName = GetSamplerName(context.Properties, texOp); + + texCall += samplerName; + } + + int coordsCount = texOp.Type.GetDimensions(); + + int pCount = coordsCount; + + int arrayIndexElem = -1; + + if (isArray) + { + arrayIndexElem = pCount++; + } + + if (isShadow && !isGather) + { + pCount++; + } + + void Append(string str) + { + texCall += ", " + str; + } + + AggregateType coordType = intCoords ? AggregateType.S32 : AggregateType.FP32; + + string AssemblePVector(int count) + { + if (count > 1) + { + string[] elems = new string[count]; + + for (int index = 0; index < count; index++) + { + if (arrayIndexElem == index) + { + elems[index] = Src(AggregateType.S32); + + if (!intCoords) + { + elems[index] = "float(" + elems[index] + ")"; + } + } + else + { + elems[index] = Src(coordType); + } + } + + string prefix = intCoords ? "int" : "float"; + + return prefix + count + "(" + string.Join(", ", elems) + ")"; + } + else + { + return Src(coordType); + } + } + + Append(AssemblePVector(pCount)); + + texCall += ")" + (colorIsVector ? GetMaskMultiDest(texOp.Index) : ""); + + return texCall; + } + + private static string GetSamplerName(ShaderProperties resourceDefinitions, AstTextureOperation textOp) + { + return resourceDefinitions.Textures[textOp.Binding].Name; + } + + // TODO: Verify that this is valid in MSL + private static string GetMask(int index) + { + return $".{"rgba".AsSpan(index, 1)}"; + } + + private static string GetMaskMultiDest(int mask) + { + string swizzle = "."; + + for (int i = 0; i < 4; i++) + { + if ((mask & (1 << i)) != 0) + { + swizzle += "xyzw"[i]; + } + } + + return swizzle; + } + + public static string TextureSize(CodeGenContext context, AstOperation operation) + { + AstTextureOperation texOp = (AstTextureOperation)operation; + + string textureName = "texture"; + string texCall = textureName + "."; + + if (texOp.Index == 3) + { + texCall += $"get_num_mip_levels()"; + } + else + { + context.Properties.Textures.TryGetValue(texOp.Binding, out TextureDefinition definition); + bool hasLod = !definition.Type.HasFlag(SamplerType.Multisample) && (definition.Type & SamplerType.Mask) != SamplerType.TextureBuffer; + texCall += "get_"; + + if (texOp.Index == 0) + { + texCall += "width"; + } + else if (texOp.Index == 1) + { + texCall += "height"; + } + else + { + texCall += "depth"; + } + + texCall += "("; + + if (hasLod) + { + IAstNode lod = operation.GetSource(0); + string lodExpr = GetSourceExpr(context, lod, GetSrcVarType(operation.Inst, 0)); + + texCall += $"{lodExpr}"; + } + + texCall += $"){GetMask(texOp.Index)}"; + } + + return texCall; + } + } +} diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGenVector.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGenVector.cs new file mode 100644 index 000000000..9d8dae543 --- /dev/null +++ b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGenVector.cs @@ -0,0 +1,32 @@ +using Ryujinx.Graphics.Shader.IntermediateRepresentation; +using Ryujinx.Graphics.Shader.StructuredIr; + +using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenHelper; +using static Ryujinx.Graphics.Shader.StructuredIr.InstructionInfo; + +namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions +{ + static class InstGenVector + { + public static string VectorExtract(CodeGenContext context, AstOperation operation) + { + IAstNode vector = operation.GetSource(0); + IAstNode index = operation.GetSource(1); + + string vectorExpr = GetSourceExpr(context, vector, OperandManager.GetNodeDestType(context, vector)); + + if (index is AstOperand indexOperand && indexOperand.Type == OperandType.Constant) + { + char elem = "xyzw"[indexOperand.Value]; + + return $"{vectorExpr}.{elem}"; + } + else + { + string indexExpr = GetSourceExpr(context, index, GetSrcVarType(operation.Inst, 1)); + + return $"{vectorExpr}[{indexExpr}]"; + } + } + } +} diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/IoMap.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/IoMap.cs index 261a0e572..67e78c9f6 100644 --- a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/IoMap.cs +++ b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/IoMap.cs @@ -1,11 +1,18 @@ using Ryujinx.Graphics.Shader.IntermediateRepresentation; using Ryujinx.Graphics.Shader.Translation; +using System.Globalization; namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions { static class IoMap { - public static (string, AggregateType) GetMSLBuiltIn(IoVariable ioVariable) + public static (string, AggregateType) GetMslBuiltIn( + ShaderDefinitions definitions, + IoVariable ioVariable, + int location, + int component, + bool isOutput, + bool isPerPatch) { return ioVariable switch { @@ -18,12 +25,50 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions IoVariable.InstanceId => ("instance_id", AggregateType.S32), IoVariable.PointCoord => ("point_coord", AggregateType.Vector2), IoVariable.PointSize => ("point_size", AggregateType.FP32), - IoVariable.Position => ("position", AggregateType.Vector4), + IoVariable.Position => ("position", AggregateType.Vector4 | AggregateType.FP32), IoVariable.PrimitiveId => ("primitive_id", AggregateType.S32), + IoVariable.UserDefined => GetUserDefinedVariableName(definitions, location, component, isOutput, isPerPatch), IoVariable.VertexId => ("vertex_id", AggregateType.S32), IoVariable.ViewportIndex => ("viewport_array_index", AggregateType.S32), _ => (null, AggregateType.Invalid), }; } + + private static (string, AggregateType) GetUserDefinedVariableName(ShaderDefinitions definitions, int location, int component, bool isOutput, bool isPerPatch) + { + string name = isPerPatch + ? DefaultNames.PerPatchAttributePrefix + : (isOutput ? DefaultNames.OAttributePrefix : DefaultNames.IAttributePrefix); + + if (location < 0) + { + return (name, definitions.GetUserDefinedType(0, isOutput)); + } + + name += location.ToString(CultureInfo.InvariantCulture); + + if (definitions.HasPerLocationInputOrOutputComponent(IoVariable.UserDefined, location, component, isOutput)) + { + name += "_" + "xyzw"[component & 3]; + } + + string prefix = ""; + switch (definitions.Stage) + { + case ShaderStage.Vertex: + prefix = "Vertex"; + break; + case ShaderStage.Fragment: + prefix = "Fragment"; + break; + case ShaderStage.Compute: + prefix = "Compute"; + break; + } + + prefix += isOutput ? "Out" : "In"; + + return (prefix + "." + name, definitions.GetUserDefinedType(location, isOutput)); + } } } diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/MslGenerator.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/MslGenerator.cs index e0ce97abe..b4d2ecad2 100644 --- a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/MslGenerator.cs +++ b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/MslGenerator.cs @@ -90,10 +90,25 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl funcKeyword = "fragment"; funcName = "fragmentMain"; } + else if (stage == ShaderStage.Compute) + { + // TODO: Compute main + } if (context.AttributeUsage.UsedInputAttributes != 0) { - args = args.Prepend("VertexIn in [[stage_in]]").ToArray(); + if (stage == ShaderStage.Vertex) + { + args = args.Prepend("VertexIn in [[stage_in]]").ToArray(); + } + else if (stage == ShaderStage.Fragment) + { + args = args.Prepend("FragmentIn in [[stage_in]]").ToArray(); + } + else if (stage == ShaderStage.Compute) + { + // TODO: Compute input + } } } diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/OperandManager.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/OperandManager.cs index beaf25f68..6d211b7e8 100644 --- a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/OperandManager.cs +++ b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/OperandManager.cs @@ -46,9 +46,6 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl public static AggregateType GetNodeDestType(CodeGenContext context, IAstNode node) { - // TODO: Get rid of that function entirely and return the type from the operation generation - // functions directly, like SPIR-V does. - if (node is AstOperation operation) { if (operation.Inst == Instruction.Load || operation.Inst.IsAtomic()) @@ -99,6 +96,8 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl IoVariable ioVariable = (IoVariable)varId.Value; bool isOutput = operation.StorageKind == StorageKind.Output || operation.StorageKind == StorageKind.OutputPerPatch; bool isPerPatch = operation.StorageKind == StorageKind.InputPerPatch || operation.StorageKind == StorageKind.OutputPerPatch; + int location = 0; + int component = 0; if (context.Definitions.HasPerLocationInputOrOutput(ioVariable, isOutput)) { @@ -107,18 +106,24 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl throw new InvalidOperationException($"Second input of {operation.Inst} with {operation.StorageKind} storage must be a constant operand."); } - int location = vecIndex.Value; + location = vecIndex.Value; if (operation.SourcesCount > 2 && operation.GetSource(2) is AstOperand elemIndex && elemIndex.Type == OperandType.Constant && context.Definitions.HasPerLocationInputOrOutputComponent(ioVariable, location, elemIndex.Value, isOutput)) { - int component = elemIndex.Value; + component = elemIndex.Value; } } - (_, AggregateType varType) = IoMap.GetMSLBuiltIn(ioVariable); + (_, AggregateType varType) = IoMap.GetMslBuiltIn( + context.Definitions, + ioVariable, + location, + component, + isOutput, + isPerPatch); return varType & AggregateType.ElementTypeMask; } diff --git a/src/Ryujinx.Graphics.Shader/Translation/TranslatorContext.cs b/src/Ryujinx.Graphics.Shader/Translation/TranslatorContext.cs index a579433f9..246867ba5 100644 --- a/src/Ryujinx.Graphics.Shader/Translation/TranslatorContext.cs +++ b/src/Ryujinx.Graphics.Shader/Translation/TranslatorContext.cs @@ -1,5 +1,6 @@ using Ryujinx.Graphics.Shader.CodeGen; using Ryujinx.Graphics.Shader.CodeGen.Glsl; +using Ryujinx.Graphics.Shader.CodeGen.Msl; using Ryujinx.Graphics.Shader.CodeGen.Spirv; using Ryujinx.Graphics.Shader.Decoders; using Ryujinx.Graphics.Shader.IntermediateRepresentation; @@ -373,6 +374,7 @@ namespace Ryujinx.Graphics.Shader.Translation { TargetLanguage.Glsl => new ShaderProgram(info, TargetLanguage.Glsl, GlslGenerator.Generate(sInfo, parameters)), TargetLanguage.Spirv => new ShaderProgram(info, TargetLanguage.Spirv, SpirvGenerator.Generate(sInfo, parameters)), + TargetLanguage.Msl => new ShaderProgram(info, TargetLanguage.Msl, MslGenerator.Generate(sInfo, parameters)), _ => throw new NotImplementedException(Options.TargetLanguage.ToString()), }; }