Optimized memory modified check (#538)

* Optimized memory modified check

This was initially in some cases more expensive than plainly sending the data. Now it should have way better performance.

* Small refactoring

* renamed InvalidAccessEventArgs
* Renamed PtPageBits

* Removed ValueRange(set)

They are currently unused and won't be likely to be used in the near future
This commit is contained in:
Roderick Sieben 2018-12-12 02:48:54 +01:00 committed by gdkchan
parent 36e8e074c9
commit 2e143365eb
7 changed files with 84 additions and 409 deletions

View file

@ -2,11 +2,11 @@ using System;
namespace ChocolArm64.Events namespace ChocolArm64.Events
{ {
public class InvalidAccessEventArgs : EventArgs public class MemoryAccessEventArgs : EventArgs
{ {
public long Position { get; private set; } public long Position { get; private set; }
public InvalidAccessEventArgs(long position) public MemoryAccessEventArgs(long position)
{ {
Position = position; Position = position;
} }

View file

@ -17,18 +17,18 @@ namespace ChocolArm64.Memory
{ {
private const int PtLvl0Bits = 13; private const int PtLvl0Bits = 13;
private const int PtLvl1Bits = 14; private const int PtLvl1Bits = 14;
private const int PtPageBits = 12; public const int PageBits = 12;
private const int PtLvl0Size = 1 << PtLvl0Bits; private const int PtLvl0Size = 1 << PtLvl0Bits;
private const int PtLvl1Size = 1 << PtLvl1Bits; private const int PtLvl1Size = 1 << PtLvl1Bits;
public const int PageSize = 1 << PtPageBits; public const int PageSize = 1 << PageBits;
private const int PtLvl0Mask = PtLvl0Size - 1; private const int PtLvl0Mask = PtLvl0Size - 1;
private const int PtLvl1Mask = PtLvl1Size - 1; private const int PtLvl1Mask = PtLvl1Size - 1;
public const int PageMask = PageSize - 1; public const int PageMask = PageSize - 1;
private const int PtLvl0Bit = PtPageBits + PtLvl1Bits; private const int PtLvl0Bit = PageBits + PtLvl1Bits;
private const int PtLvl1Bit = PtPageBits; private const int PtLvl1Bit = PageBits;
private const long ErgMask = (4 << CpuThreadState.ErgSizeLog2) - 1; private const long ErgMask = (4 << CpuThreadState.ErgSizeLog2) - 1;
@ -53,7 +53,9 @@ namespace ChocolArm64.Memory
private byte*** _pageTable; private byte*** _pageTable;
public event EventHandler<InvalidAccessEventArgs> InvalidAccess; public event EventHandler<MemoryAccessEventArgs> InvalidAccess;
public event EventHandler<MemoryAccessEventArgs> ObservedAccess;
public MemoryManager(IntPtr ram) public MemoryManager(IntPtr ram)
{ {
@ -632,7 +634,7 @@ namespace ChocolArm64.Memory
return false; return false;
} }
return _pageTable[l0][l1] != null || _observedPages.ContainsKey(position >> PtPageBits); return _pageTable[l0][l1] != null || _observedPages.ContainsKey(position >> PageBits);
} }
public long GetPhysicalAddress(long virtualAddress) public long GetPhysicalAddress(long virtualAddress)
@ -678,14 +680,14 @@ Unmapped:
private byte* HandleNullPte(long position) private byte* HandleNullPte(long position)
{ {
long key = position >> PtPageBits; long key = position >> PageBits;
if (_observedPages.TryGetValue(key, out IntPtr ptr)) if (_observedPages.TryGetValue(key, out IntPtr ptr))
{ {
return (byte*)ptr + (position & PageMask); return (byte*)ptr + (position & PageMask);
} }
InvalidAccess?.Invoke(this, new InvalidAccessEventArgs(position)); InvalidAccess?.Invoke(this, new MemoryAccessEventArgs(position));
throw new VmmPageFaultException(position); throw new VmmPageFaultException(position);
} }
@ -726,16 +728,20 @@ Unmapped:
private byte* HandleNullPteWrite(long position) private byte* HandleNullPteWrite(long position)
{ {
long key = position >> PtPageBits; long key = position >> PageBits;
MemoryAccessEventArgs e = new MemoryAccessEventArgs(position);
if (_observedPages.TryGetValue(key, out IntPtr ptr)) if (_observedPages.TryGetValue(key, out IntPtr ptr))
{ {
SetPtEntry(position, (byte*)ptr); SetPtEntry(position, (byte*)ptr);
ObservedAccess?.Invoke(this, e);
return (byte*)ptr + (position & PageMask); return (byte*)ptr + (position & PageMask);
} }
InvalidAccess?.Invoke(this, new InvalidAccessEventArgs(position)); InvalidAccess?.Invoke(this, e);
throw new VmmPageFaultException(position); throw new VmmPageFaultException(position);
} }
@ -784,45 +790,15 @@ Unmapped:
_pageTable[l0][l1] = ptr; _pageTable[l0][l1] = ptr;
} }
public (bool[], int) IsRegionModified(long position, long size) public void StartObservingRegion(long position, long size)
{ {
long endPosition = (position + size + PageMask) & ~PageMask; long endPosition = (position + size + PageMask) & ~PageMask;
position &= ~PageMask; position &= ~PageMask;
size = endPosition - position; while ((ulong)position < (ulong)endPosition)
bool[] modified = new bool[size >> PtPageBits];
int count = 0;
lock (_observedPages)
{ {
for (int page = 0; page < modified.Length; page++) _observedPages[position >> PageBits] = (IntPtr)Translate(position);
{
byte* ptr = Translate(position);
if (_observedPages.TryAdd(position >> PtPageBits, (IntPtr)ptr))
{
modified[page] = true;
count++;
}
else
{
long l0 = (position >> PtLvl0Bit) & PtLvl0Mask;
long l1 = (position >> PtLvl1Bit) & PtLvl1Mask;
byte** lvl1 = _pageTable[l0];
if (lvl1 != null)
{
if (modified[page] = lvl1[l1] != null)
{
count++;
}
}
}
SetPtEntry(position, null); SetPtEntry(position, null);
@ -830,9 +806,6 @@ Unmapped:
} }
} }
return (modified, count);
}
public void StopObservingRegion(long position, long size) public void StopObservingRegion(long position, long size)
{ {
long endPosition = (position + size + PageMask) & ~PageMask; long endPosition = (position + size + PageMask) & ~PageMask;
@ -841,7 +814,7 @@ Unmapped:
{ {
lock (_observedPages) lock (_observedPages)
{ {
if (_observedPages.TryRemove(position >> PtPageBits, out IntPtr ptr)) if (_observedPages.TryRemove(position >> PageBits, out IntPtr ptr))
{ {
SetPtEntry(position, (byte*)ptr); SetPtEntry(position, (byte*)ptr);
} }
@ -891,7 +864,7 @@ Unmapped:
public bool IsValidPosition(long position) public bool IsValidPosition(long position)
{ {
return position >> (PtLvl0Bits + PtLvl1Bits + PtPageBits) == 0; return position >> (PtLvl0Bits + PtLvl1Bits + PageBits) == 0;
} }
public void Dispose() public void Dispose()

View file

@ -36,7 +36,7 @@ namespace Ryujinx.Graphics.Memory
{ {
this.Memory = Memory; this.Memory = Memory;
Cache = new NvGpuVmmCache(); Cache = new NvGpuVmmCache(Memory);
PageTable = new long[PTLvl0Size][]; PageTable = new long[PTLvl0Size][];
} }
@ -262,7 +262,7 @@ namespace Ryujinx.Graphics.Memory
public bool IsRegionModified(long PA, long Size, NvGpuBufferType BufferType) public bool IsRegionModified(long PA, long Size, NvGpuBufferType BufferType)
{ {
return Cache.IsRegionModified(Memory, BufferType, PA, Size); return Cache.IsRegionModified(PA, Size, BufferType);
} }
public bool TryGetHostAddress(long Position, long Size, out IntPtr Ptr) public bool TryGetHostAddress(long Position, long Size, out IntPtr Ptr)

View file

@ -1,130 +1,83 @@
using ChocolArm64.Events;
using ChocolArm64.Memory; using ChocolArm64.Memory;
using System; using System.Collections.Concurrent;
namespace Ryujinx.Graphics.Memory namespace Ryujinx.Graphics.Memory
{ {
class NvGpuVmmCache class NvGpuVmmCache
{ {
private struct CachedResource private const int PageBits = MemoryManager.PageBits;
private const long PageSize = MemoryManager.PageSize;
private const long PageMask = MemoryManager.PageMask;
private ConcurrentDictionary<long, int>[] CachedPages;
private MemoryManager _memory;
public NvGpuVmmCache(MemoryManager memory)
{ {
public long Key; _memory = memory;
public int Mask;
public CachedResource(long Key, int Mask) _memory.ObservedAccess += MemoryAccessHandler;
CachedPages = new ConcurrentDictionary<long, int>[1 << 20];
}
private void MemoryAccessHandler(object sender, MemoryAccessEventArgs e)
{ {
this.Key = Key; long pa = _memory.GetPhysicalAddress(e.Position);
this.Mask = Mask;
CachedPages[pa >> PageBits]?.Clear();
} }
public override int GetHashCode() public bool IsRegionModified(long position, long size, NvGpuBufferType bufferType)
{ {
return (int)(Key * 23 + Mask); long pa = _memory.GetPhysicalAddress(position);
}
public override bool Equals(object obj) long addr = pa;
long endAddr = (addr + size + PageMask) & ~PageMask;
int newBuffMask = 1 << (int)bufferType;
_memory.StartObservingRegion(position, size);
long cachedPagesCount = 0;
while (addr < endAddr)
{ {
return obj is CachedResource Cached && Equals(Cached); long page = addr >> PageBits;
}
public bool Equals(CachedResource other) ConcurrentDictionary<long, int> dictionary = CachedPages[page];
if (dictionary == null)
{ {
return Key == other.Key && Mask == other.Mask; dictionary = new ConcurrentDictionary<long, int>();
}
CachedPages[page] = dictionary;
} }
private ValueRangeSet<CachedResource> CachedRanges; if (dictionary.TryGetValue(pa, out int currBuffMask))
public NvGpuVmmCache()
{ {
CachedRanges = new ValueRangeSet<CachedResource>(); if ((currBuffMask & newBuffMask) != 0)
}
public bool IsRegionModified(MemoryManager Memory, NvGpuBufferType BufferType, long Start, long Size)
{ {
(bool[] Modified, long ModifiedCount) = Memory.IsRegionModified(Start, Size); cachedPagesCount++;
}
//Remove all modified ranges. else
int Index = 0;
long Position = Start & ~NvGpuVmm.PageMask;
while (ModifiedCount > 0)
{ {
if (Modified[Index++]) dictionary[pa] |= newBuffMask;
}
}
else
{ {
CachedRanges.Remove(new ValueRange<CachedResource>(Position, Position + NvGpuVmm.PageSize)); dictionary[pa] = newBuffMask;
ModifiedCount--;
} }
Position += NvGpuVmm.PageSize; addr += PageSize;
} }
//Mask has the bit set for the current resource type. return cachedPagesCount != (endAddr - pa + PageMask) >> PageBits;
//If the region is not yet present on the list, then a new ValueRange
//is directly added with the current resource type as the only bit set.
//Otherwise, it just sets the bit for this new resource type on the current mask.
//The physical address of the resource is used as key, those keys are used to keep
//track of resources that are already on the cache. A resource may be inside another
//resource, and in this case we should return true if the "sub-resource" was not
//yet cached.
int Mask = 1 << (int)BufferType;
CachedResource NewCachedValue = new CachedResource(Start, Mask);
ValueRange<CachedResource> NewCached = new ValueRange<CachedResource>(Start, Start + Size);
ValueRange<CachedResource>[] Ranges = CachedRanges.GetAllIntersections(NewCached);
bool IsKeyCached = Ranges.Length > 0 && Ranges[0].Value.Key == Start;
long LastEnd = NewCached.Start;
long Coverage = 0;
for (Index = 0; Index < Ranges.Length; Index++)
{
ValueRange<CachedResource> Current = Ranges[Index];
CachedResource Cached = Current.Value;
long RgStart = Math.Max(Current.Start, NewCached.Start);
long RgEnd = Math.Min(Current.End, NewCached.End);
if ((Cached.Mask & Mask) != 0)
{
Coverage += RgEnd - RgStart;
}
//Highest key value has priority, this prevents larger resources
//for completely invalidating smaller ones on the cache. For example,
//consider that a resource in the range [100, 200) was added, and then
//another one in the range [50, 200). We prevent the new resource from
//completely replacing the old one by spliting it like this:
//New resource key is added at [50, 100), old key is still present at [100, 200).
if (Cached.Key < Start)
{
Cached.Key = Start;
}
Cached.Mask |= Mask;
CachedRanges.Add(new ValueRange<CachedResource>(RgStart, RgEnd, Cached));
if (RgStart > LastEnd)
{
CachedRanges.Add(new ValueRange<CachedResource>(LastEnd, RgStart, NewCachedValue));
}
LastEnd = RgEnd;
}
if (LastEnd < NewCached.End)
{
CachedRanges.Add(new ValueRange<CachedResource>(LastEnd, NewCached.End, NewCachedValue));
}
return !IsKeyCached || Coverage != Size;
} }
} }
} }

View file

@ -1,17 +0,0 @@
namespace Ryujinx.Graphics
{
struct ValueRange<T>
{
public long Start { get; private set; }
public long End { get; private set; }
public T Value { get; set; }
public ValueRange(long Start, long End, T Value = default(T))
{
this.Start = Start;
this.End = End;
this.Value = Value;
}
}
}

View file

@ -1,234 +0,0 @@
using System.Collections.Generic;
namespace Ryujinx.Graphics
{
class ValueRangeSet<T>
{
private List<ValueRange<T>> Ranges;
public ValueRangeSet()
{
Ranges = new List<ValueRange<T>>();
}
public void Add(ValueRange<T> Range)
{
if (Range.End <= Range.Start)
{
//Empty or invalid range, do nothing.
return;
}
int First = BinarySearchFirstIntersection(Range);
if (First == -1)
{
//No intersections case.
//Find first greater than range (after the current one).
//If found, add before, otherwise add to the end of the list.
int GtIndex = BinarySearchGt(Range);
if (GtIndex != -1)
{
Ranges.Insert(GtIndex, Range);
}
else
{
Ranges.Add(Range);
}
return;
}
(int Start, int End) = GetAllIntersectionRanges(Range, First);
ValueRange<T> Prev = Ranges[Start];
ValueRange<T> Next = Ranges[End];
Ranges.RemoveRange(Start, (End - Start) + 1);
InsertNextNeighbour(Start, Range, Next);
int NewIndex = Start;
Ranges.Insert(Start, Range);
InsertPrevNeighbour(Start, Range, Prev);
//Try merging neighbours if the value is equal.
if (NewIndex > 0)
{
Prev = Ranges[NewIndex - 1];
if (Prev.End == Range.Start && CompareValues(Prev, Range))
{
Ranges.RemoveAt(--NewIndex);
Ranges[NewIndex] = new ValueRange<T>(Prev.Start, Range.End, Range.Value);
}
}
if (NewIndex < Ranges.Count - 1)
{
Next = Ranges[NewIndex + 1];
if (Next.Start == Range.End && CompareValues(Next, Range))
{
Ranges.RemoveAt(NewIndex + 1);
Ranges[NewIndex] = new ValueRange<T>(Ranges[NewIndex].Start, Next.End, Range.Value);
}
}
}
private bool CompareValues(ValueRange<T> LHS, ValueRange<T> RHS)
{
return LHS.Value?.Equals(RHS.Value) ?? RHS.Value == null;
}
public void Remove(ValueRange<T> Range)
{
int First = BinarySearchFirstIntersection(Range);
if (First == -1)
{
//Nothing to remove.
return;
}
(int Start, int End) = GetAllIntersectionRanges(Range, First);
ValueRange<T> Prev = Ranges[Start];
ValueRange<T> Next = Ranges[End];
Ranges.RemoveRange(Start, (End - Start) + 1);
InsertNextNeighbour(Start, Range, Next);
InsertPrevNeighbour(Start, Range, Prev);
}
private void InsertNextNeighbour(int Index, ValueRange<T> Range, ValueRange<T> Next)
{
//Split last intersection (ordered by Start) if necessary.
if (Range.End < Next.End)
{
InsertNewRange(Index, Range.End, Next.End, Next.Value);
}
}
private void InsertPrevNeighbour(int Index, ValueRange<T> Range, ValueRange<T> Prev)
{
//Split first intersection (ordered by Start) if necessary.
if (Range.Start > Prev.Start)
{
InsertNewRange(Index, Prev.Start, Range.Start, Prev.Value);
}
}
private void InsertNewRange(int Index, long Start, long End, T Value)
{
Ranges.Insert(Index, new ValueRange<T>(Start, End, Value));
}
public ValueRange<T>[] GetAllIntersections(ValueRange<T> Range)
{
int First = BinarySearchFirstIntersection(Range);
if (First == -1)
{
return new ValueRange<T>[0];
}
(int Start, int End) = GetAllIntersectionRanges(Range, First);
return Ranges.GetRange(Start, (End - Start) + 1).ToArray();
}
private (int Start, int End) GetAllIntersectionRanges(ValueRange<T> Range, int BaseIndex)
{
int Start = BaseIndex;
int End = BaseIndex;
while (Start > 0 && Intersects(Range, Ranges[Start - 1]))
{
Start--;
}
while (End < Ranges.Count - 1 && Intersects(Range, Ranges[End + 1]))
{
End++;
}
return (Start, End);
}
private int BinarySearchFirstIntersection(ValueRange<T> Range)
{
int Left = 0;
int Right = Ranges.Count - 1;
while (Left <= Right)
{
int Size = Right - Left;
int Middle = Left + (Size >> 1);
ValueRange<T> Current = Ranges[Middle];
if (Intersects(Range, Current))
{
return Middle;
}
if (Range.Start < Current.Start)
{
Right = Middle - 1;
}
else
{
Left = Middle + 1;
}
}
return -1;
}
private int BinarySearchGt(ValueRange<T> Range)
{
int GtIndex = -1;
int Left = 0;
int Right = Ranges.Count - 1;
while (Left <= Right)
{
int Size = Right - Left;
int Middle = Left + (Size >> 1);
ValueRange<T> Current = Ranges[Middle];
if (Range.Start < Current.Start)
{
Right = Middle - 1;
if (GtIndex == -1 || Current.Start < Ranges[GtIndex].Start)
{
GtIndex = Middle;
}
}
else
{
Left = Middle + 1;
}
}
return GtIndex;
}
private bool Intersects(ValueRange<T> LHS, ValueRange<T> RHS)
{
return LHS.Start < RHS.End && RHS.Start < LHS.End;
}
}
}

View file

@ -995,7 +995,7 @@ namespace Ryujinx.HLE.HOS.Kernel
} }
} }
private void InvalidAccessHandler(object sender, InvalidAccessEventArgs e) private void InvalidAccessHandler(object sender, MemoryAccessEventArgs e)
{ {
PrintCurrentThreadStackTrace(); PrintCurrentThreadStackTrace();
} }