Fix fixed queries once and for all (?)

wip/source-generators
copygirl 2 years ago
parent a8107a6bf2
commit 255352e7c4
  1. 3
      src/Immersion/ObserverTest.cs
  2. 3
      src/gaemstone.Bloxel/Client/Systems/ChunkMeshGenerator.cs
  3. 8
      src/gaemstone/ECS/Iterator.cs
  4. 19
      src/gaemstone/ECS/Observer.cs
  5. 88
      src/gaemstone/ECS/System.cs
  6. 36
      src/gaemstone/Utility/IL/ILGeneratorWrapper.cs
  7. 75
      src/gaemstone/Utility/IL/IterActionGenerator.cs

@ -10,7 +10,8 @@ namespace Immersion;
[DependsOn<gaemstone.Client.Components.RenderingComponents>] [DependsOn<gaemstone.Client.Components.RenderingComponents>]
public class ObserverTest public class ObserverTest
{ {
[Observer<ObserverEvent.OnSet>(Expression = "[in] Chunk, [none] (Mesh, *)")] [Observer<ObserverEvent.OnSet>]
[Expression("[in] Chunk, [none] (Mesh, *)")]
public static void DoObserver(in Chunk chunk) public static void DoObserver(in Chunk chunk)
=> Console.WriteLine($"Chunk at {chunk.Position} now has a Mesh!"); => Console.WriteLine($"Chunk at {chunk.Position} now has a Mesh!");
} }

@ -32,7 +32,8 @@ public class ChunkMeshGenerator
private Vector3D<float>[] _normals = new Vector3D<float>[StartingCapacity]; private Vector3D<float>[] _normals = new Vector3D<float>[StartingCapacity];
private Vector2D<float>[] _uvs = new Vector2D<float>[StartingCapacity]; private Vector2D<float>[] _uvs = new Vector2D<float>[StartingCapacity];
[System(Expression = "[in] Chunk, ChunkStoreBlocks, HasBasicWorldGeneration, !(Mesh, *)")] [System]
[Expression("[in] Chunk, ChunkStoreBlocks, HasBasicWorldGeneration, !(Mesh, *)")]
public void GenerateChunkMeshes(Universe universe, EntityRef entity, public void GenerateChunkMeshes(Universe universe, EntityRef entity,
in Chunk chunk, ChunkStoreBlocks blocks) in Chunk chunk, ChunkStoreBlocks blocks)
{ {

@ -21,8 +21,6 @@ public unsafe partial class Iterator
public Iterator(Universe universe, IteratorType? type, ecs_iter_t value) public Iterator(Universe universe, IteratorType? type, ecs_iter_t value)
{ Universe = universe; Type = type; Value = value; } { Universe = universe; Type = type; Value = value; }
// TODO: ecs_iter_set_var(world, 0, ent) to run a query for a known entity.
public static Iterator FromTerm(Universe universe, Term term) public static Iterator FromTerm(Universe universe, Term term)
{ {
using var alloc = TempAllocator.Use(); using var alloc = TempAllocator.Use();
@ -31,6 +29,12 @@ public unsafe partial class Iterator
return new(universe, IteratorType.Term, flecsIter); return new(universe, IteratorType.Term, flecsIter);
} }
public void SetThis(Entity entity)
{
fixed (ecs_iter_t* ptr = &Value)
ecs_iter_set_var(ptr, 0, entity);
}
public bool Next() public bool Next()
{ {
fixed (ecs_iter_t* ptr = &Value) fixed (ecs_iter_t* ptr = &Value)

@ -1,6 +1,7 @@
using System; using System;
using System.Linq; using System.Linq;
using System.Reflection; using System.Reflection;
using System.Runtime.InteropServices;
using gaemstone.Utility; using gaemstone.Utility;
using gaemstone.Utility.IL; using gaemstone.Utility.IL;
using static flecs_hub.flecs; using static flecs_hub.flecs;
@ -11,7 +12,6 @@ namespace gaemstone.ECS;
public class ObserverAttribute : Attribute public class ObserverAttribute : Attribute
{ {
public Type Event { get; } public Type Event { get; }
public string? Expression { get; init; }
internal ObserverAttribute(Type @event) => Event = @event; // Use generic type instead. internal ObserverAttribute(Type @event) => Event = @event; // Use generic type instead.
} }
public class ObserverAttribute<TEvent> : ObserverAttribute public class ObserverAttribute<TEvent> : ObserverAttribute
@ -27,7 +27,7 @@ public static class ObserverExtensions
filter = filter.ToFlecs(alloc), filter = filter.ToFlecs(alloc),
entity = universe.New((filter.Name != null) ? new(filter.Name) : null).Build(), entity = universe.New((filter.Name != null) ? new(filter.Name) : null).Build(),
binding_ctx = (void*)CallbackContextHelper.Create((universe, callback)), binding_ctx = (void*)CallbackContextHelper.Create((universe, callback)),
callback = new() { Data = new() { Pointer = &SystemExtensions.Callback } }, callback = new() { Data = new() { Pointer = &Callback } },
}; };
desc.events[0] = @event; desc.events[0] = @event;
return new(universe, new(ecs_observer_init(universe, &desc))); return new(universe, new(ecs_observer_init(universe, &desc)));
@ -38,19 +38,20 @@ public static class ObserverExtensions
{ {
var attr = method.Get<ObserverAttribute>() ?? throw new ArgumentException( var attr = method.Get<ObserverAttribute>() ?? throw new ArgumentException(
"Observer must specify ObserverAttribute", nameof(method)); "Observer must specify ObserverAttribute", nameof(method));
var expr = method.Get<ExpressionAttribute>()?.Value;
FilterDesc filter; FilterDesc filter;
Action<Iterator> iterAction; Action<Iterator> iterAction;
var param = method.GetParameters(); var param = method.GetParameters();
if (param.Length == 1 && param[0].ParameterType == typeof(Iterator)) { if (param.Length == 1 && param[0].ParameterType == typeof(Iterator)) {
filter = new(attr.Expression ?? throw new Exception( filter = new(expr ?? throw new Exception(
"Observer must specify ObserverAttribute.Expression")); "Observer must specify ExpressionAttribute"));
if (method.IsStatic) instance = null; if (method.IsStatic) instance = null;
iterAction = (Action<Iterator>)Delegate.CreateDelegate( iterAction = (Action<Iterator>)Delegate.CreateDelegate(
typeof(Action<Iterator>), instance, method); typeof(Action<Iterator>), instance, method);
} else { } else {
var gen = IterActionGenerator.GetOrBuild(universe, method); var gen = IterActionGenerator.GetOrBuild(universe, method);
filter = (attr.Expression != null) ? new(attr.Expression) : new(gen.Terms.ToArray()); filter = (expr != null) ? new(expr) : new(gen.Terms.ToArray());
iterAction = iter => gen.RunWithTryCatch(instance, iter); iterAction = iter => gen.RunWithTryCatch(instance, iter);
} }
@ -58,4 +59,12 @@ public static class ObserverExtensions
var @event = universe.LookupOrThrow(attr.Event); var @event = universe.LookupOrThrow(attr.Event);
return universe.RegisterObserver(filter, @event, iterAction); return universe.RegisterObserver(filter, @event, iterAction);
} }
[UnmanagedCallersOnly]
private static unsafe void Callback(ecs_iter_t* iter)
{
var (universe, callback) = CallbackContextHelper
.Get<(Universe, Action<Iterator>)>((nint)iter->binding_ctx);
callback(new Iterator(universe, null, *iter));
}
} }

@ -13,27 +13,32 @@ namespace gaemstone.ECS;
[AttributeUsage(AttributeTargets.Method)] [AttributeUsage(AttributeTargets.Method)]
public class SystemAttribute : Attribute public class SystemAttribute : Attribute
{ {
public Type Phase { get; set; } public Type Phase { get; }
public string? Expression { get; set; }
public SystemAttribute() : this(typeof(SystemPhase.OnUpdate)) { } public SystemAttribute() : this(typeof(SystemPhase.OnUpdate)) { }
internal SystemAttribute(Type phase) => Phase = phase; // Use generic type instead. internal SystemAttribute(Type phase) => Phase = phase; // Use generic type instead.
} }
public class SystemAttribute<TPhase> : SystemAttribute public class SystemAttribute<TPhase> : SystemAttribute
{ public SystemAttribute() : base(typeof(TPhase)) { } } { public SystemAttribute() : base(typeof(TPhase)) { } }
[AttributeUsage(AttributeTargets.Method)]
public class ExpressionAttribute : Attribute
{
public string Value { get; }
public ExpressionAttribute(string value) => Value = value;
}
public static class SystemExtensions public static class SystemExtensions
{ {
public static unsafe EntityRef RegisterSystem(this Universe universe, private static unsafe EntityRef RegisterSystem(this Universe universe,
QueryDesc query, Entity phase, Action<Iterator> callback) QueryDesc query, Entity phase, CallbackContext callback)
{ {
using var alloc = TempAllocator.Use(); using var alloc = TempAllocator.Use();
var desc = new ecs_system_desc_t { var desc = new ecs_system_desc_t {
query = query.ToFlecs(alloc), query = query.ToFlecs(alloc),
entity = universe.New((query.Name != null) ? new(query.Name) : null) entity = universe.New((query.Name != null) ? new(query.Name) : null)
.Add<DependsOn>(phase).Add(phase).Build(), .Add<DependsOn>(phase).Add(phase).Build(),
binding_ctx = (void*)CallbackContextHelper.Create((universe, callback)), binding_ctx = (void*)CallbackContextHelper.Create(callback),
callback = new() { Data = new() { Pointer = &Callback } }, run = new() { Data = new() { Pointer = &Run } },
}; };
return new(universe, new(ecs_system_init(universe, &desc))); return new(universe, new(ecs_system_init(universe, &desc)));
} }
@ -41,51 +46,80 @@ public static class SystemExtensions
public static EntityRef RegisterSystem(this Universe universe, Delegate action) public static EntityRef RegisterSystem(this Universe universe, Delegate action)
{ {
var attr = action.Method.Get<SystemAttribute>(); var attr = action.Method.Get<SystemAttribute>();
QueryDesc query; var expr = action.Method.Get<ExpressionAttribute>()?.Value;
if (action is Action<Iterator> iterAction) { QueryDesc query;
query = new(attr?.Expression ?? throw new ArgumentException( if (action is Action<Iterator> callback) {
"System must specify SystemAttribute.Expression", nameof(action))); query = new(expr ?? throw new ArgumentException(
"System must specify ExpressionAttribute", nameof(action)));
} else { } else {
var method = action.GetType().GetMethod("Invoke")!; var gen = IterActionGenerator.GetOrBuild(universe, action.Method);
var gen = IterActionGenerator.GetOrBuild(universe, method); query = (expr != null) ? new(expr) : new(gen.Terms.ToArray());
query = (attr?.Expression != null) ? new(attr.Expression) : new(gen.Terms.ToArray()); callback = iter => gen.RunWithTryCatch(action.Target, iter);
iterAction = iter => gen.RunWithTryCatch(action.Target, iter);
} }
query.Name = action.Method.Name; query.Name = action.Method.Name;
var phase = universe.LookupOrThrow(attr?.Phase ?? typeof(SystemPhase.OnUpdate)); var phase = universe.LookupOrThrow(attr?.Phase ?? typeof(SystemPhase.OnUpdate));
return universe.RegisterSystem(query, phase, iterAction); return universe.RegisterSystem(query, phase, new(universe, action.Method, callback));
} }
public static EntityRef RegisterSystem(this Universe universe, public static EntityRef RegisterSystem(this Universe universe,
object? instance, MethodInfo method) object? instance, MethodInfo method)
{ {
var attr = method.Get<SystemAttribute>(); var attr = method.Get<SystemAttribute>();
QueryDesc query; var expr = method.Get<ExpressionAttribute>()?.Value;
Action<Iterator> iterAction;
QueryDesc query;
Action<Iterator> callback;
var param = method.GetParameters(); var param = method.GetParameters();
if (param.Length == 1 && param[0].ParameterType == typeof(Iterator)) { if (param.Length == 1 && param[0].ParameterType == typeof(Iterator)) {
query = new(attr?.Expression ?? throw new ArgumentException( query = new(expr ?? throw new ArgumentException(
"System must specify SystemAttribute.Expression", nameof(method))); "System must specify SystemAttribute.Expression", nameof(method)));
iterAction = (Action<Iterator>)Delegate.CreateDelegate(typeof(Action<Iterator>), instance, method); callback = (Action<Iterator>)Delegate.CreateDelegate(typeof(Action<Iterator>), instance, method);
} else { } else {
var gen = IterActionGenerator.GetOrBuild(universe, method); var gen = IterActionGenerator.GetOrBuild(universe, method);
query = (attr?.Expression != null) ? new(attr.Expression) : new(gen.Terms.ToArray()); query = (expr != null) ? new(expr) : new(gen.Terms.ToArray());
iterAction = iter => gen.RunWithTryCatch(instance, iter); callback = iter => gen.RunWithTryCatch(instance, iter);
} }
query.Name = method.Name; query.Name = method.Name;
var phase = universe.LookupOrThrow(attr?.Phase ?? typeof(SystemPhase.OnUpdate)); var phase = universe.LookupOrThrow(attr?.Phase ?? typeof(SystemPhase.OnUpdate));
return universe.RegisterSystem(query, phase, iterAction); return universe.RegisterSystem(query, phase, new(universe, method, callback));
}
private class CallbackContext
{
public Universe Universe { get; }
public MethodInfo Method { get; }
public Action<Iterator> Callback { get; }
public CallbackContext(Universe universe, MethodInfo method, Action<Iterator> callback)
{ Universe = universe; Method = method; Callback = callback; }
public void Prepare(Iterator iter)
{
// If the method is marked with [Source], set the $This variable.
if (Method.Get<SourceAttribute>()?.Type is Type sourceType)
iter.SetThis(Universe.LookupOrThrow(sourceType));
}
} }
[UnmanagedCallersOnly] [UnmanagedCallersOnly]
internal static unsafe void Callback(ecs_iter_t* iter) private static unsafe void Run(ecs_iter_t* flecsIter)
{ {
var (universe, callback) = CallbackContextHelper // This is what flecs does, so I guess we'll do it too!
.Get<(Universe, Action<Iterator>)>((nint)iter->binding_ctx); var callback = CallbackContextHelper.Get<CallbackContext>((nint)flecsIter->binding_ctx);
callback(new Iterator(universe, null, *iter));
var type = (&flecsIter->next == (delegate*<ecs_iter_t*, Runtime.CBool>)&ecs_query_next)
? IteratorType.Query : (IteratorType?)null;
var iter = new Iterator(callback.Universe, type, *flecsIter);
callback.Prepare(iter);
var query = flecsIter->priv.iter.query.query;
if (query != null && ecs_query_get_filter(query)->term_count == 0)
callback.Callback(iter);
else while (iter.Next())
callback.Callback(iter);
} }
} }

@ -112,11 +112,15 @@ public class ILGeneratorWrapper
internal void Emit(OpCode code, ConstructorInfo constr) { AddInstr(code, constr); _il.Emit(code, constr); } internal void Emit(OpCode code, ConstructorInfo constr) { AddInstr(code, constr); _il.Emit(code, constr); }
public void Dup() => Emit(OpCodes.Dup); public void Dup() => Emit(OpCodes.Dup);
public void Pop() => Emit(OpCodes.Pop);
public void LoadNull() => Emit(OpCodes.Ldnull); public void LoadNull() => Emit(OpCodes.Ldnull);
public void LoadConst(int value) => Emit(OpCodes.Ldc_I4, value); public void LoadConst(int value) => Emit(OpCodes.Ldc_I4, value);
public void Load(string value) => Emit(OpCodes.Ldstr, value); public void Load(string value) => Emit(OpCodes.Ldstr, value);
private static readonly MethodInfo _typeFromHandleMethod = typeof(Type).GetMethod(nameof(Type.GetTypeFromHandle))!;
public void Typeof(Type value) { Emit(OpCodes.Ldtoken, value); Call(_typeFromHandleMethod); }
public void Load(IArgument arg) => Emit(OpCodes.Ldarg, arg); public void Load(IArgument arg) => Emit(OpCodes.Ldarg, arg);
public void LoadAddr(IArgument arg) => Emit(OpCodes.Ldarga, arg); public void LoadAddr(IArgument arg) => Emit(OpCodes.Ldarga, arg);
@ -199,35 +203,29 @@ public class ILGeneratorWrapper
public void Call(MethodInfo method) => Emit(OpCodes.Call, method); public void Call(MethodInfo method) => Emit(OpCodes.Call, method);
public void CallVirt(MethodInfo method) => Emit(OpCodes.Callvirt, method); public void CallVirt(MethodInfo method) => Emit(OpCodes.Callvirt, method);
public void Return() => Emit(OpCodes.Ret); private static readonly MethodInfo _consoleWriteLineMethod = typeof(Console).GetMethod(nameof(Console.WriteLine), new[] { typeof(string) })!;
public void Print(string str) { Load(str); Call(_consoleWriteLineMethod); }
public void Return() => Emit(OpCodes.Ret);
public IDisposable For(Action loadMax, out ILocal<int> current)
{
var r = Random.Shared.Next(10000, 100000);
Comment($"INIT for loop {r}");
var curLocal = current = Local<int>($"index_{r}"); public delegate void WhileBodyAction(Label continueLabel, Label breakLabel);
var maxLocal = Local<int>($"length_{r}"); public delegate void WhileTestAction(Label continueLabel);
public void While(string name, WhileTestAction testAction, WhileBodyAction bodyAction) {
var bodyLabel = DefineLabel(); var bodyLabel = DefineLabel();
var testLabel = DefineLabel(); var testLabel = DefineLabel();
var breakLabel = DefineLabel();
Set(curLocal, 0);
loadMax(); Store(maxLocal);
Comment($"BEGIN for loop {r}");
Goto(testLabel); Goto(testLabel);
Comment("BEGIN LOOP " + name);
MarkLabel(bodyLabel); MarkLabel(bodyLabel);
var indent = Indent(); using (Indent()) bodyAction(testLabel, breakLabel);
Comment("TEST LOOP " + name);
return Block(() => {
Increment(curLocal);
MarkLabel(testLabel); MarkLabel(testLabel);
GotoIf(bodyLabel, curLocal, Comparison.LessThan, maxLocal); using (Indent()) testAction(bodyLabel);
indent.Dispose(); Comment("END LOOP " + name);
Comment($"END for loop {r}"); MarkLabel(breakLabel);
});
} }

@ -19,9 +19,9 @@ public unsafe class IterActionGenerator
private static readonly PropertyInfo _iteratorUniverseProp = typeof(Iterator).GetProperty(nameof(Iterator.Universe))!; private static readonly PropertyInfo _iteratorUniverseProp = typeof(Iterator).GetProperty(nameof(Iterator.Universe))!;
private static readonly PropertyInfo _iteratorDeltaTimeProp = typeof(Iterator).GetProperty(nameof(Iterator.DeltaTime))!; private static readonly PropertyInfo _iteratorDeltaTimeProp = typeof(Iterator).GetProperty(nameof(Iterator.DeltaTime))!;
private static readonly PropertyInfo _iteratorCountProp = typeof(Iterator).GetProperty(nameof(Iterator.Count))!; private static readonly PropertyInfo _iteratorCountProp = typeof(Iterator).GetProperty(nameof(Iterator.Count))!;
private static readonly MethodInfo _iteratorEntityMethod = typeof(Iterator).GetMethod(nameof(Iterator.Entity))!;
private static readonly MethodInfo _iteratorFieldMethod = typeof(Iterator).GetMethod(nameof(Iterator.Field))!; private static readonly MethodInfo _iteratorFieldMethod = typeof(Iterator).GetMethod(nameof(Iterator.Field))!;
private static readonly MethodInfo _iteratorMaybeFieldMethod = typeof(Iterator).GetMethod(nameof(Iterator.MaybeField))!; private static readonly MethodInfo _iteratorMaybeFieldMethod = typeof(Iterator).GetMethod(nameof(Iterator.MaybeField))!;
private static readonly MethodInfo _iteratorEntityMethod = typeof(Iterator).GetMethod(nameof(Iterator.Entity))!;
private static readonly MethodInfo _handleFromIntPtrMethod = typeof(GCHandle).GetMethod(nameof(GCHandle.FromIntPtr))!; private static readonly MethodInfo _handleFromIntPtrMethod = typeof(GCHandle).GetMethod(nameof(GCHandle.FromIntPtr))!;
private static readonly PropertyInfo _handleTargetProp = typeof(GCHandle).GetProperty(nameof(GCHandle.Target))!; private static readonly PropertyInfo _handleTargetProp = typeof(GCHandle).GetProperty(nameof(GCHandle.Target))!;
@ -46,21 +46,15 @@ public unsafe class IterActionGenerator
public void RunWithTryCatch(object? instance, Iterator iter) public void RunWithTryCatch(object? instance, Iterator iter)
{ {
try { GeneratedAction(instance, iter); } catch { try { GeneratedAction(instance, iter); }
Console.Error.WriteLine("Exception occured while running:"); catch { Console.Error.WriteLine(ReadableString); throw; }
Console.Error.WriteLine(" " + Method);
Console.Error.WriteLine();
Console.Error.WriteLine("Method's IL code:");
Console.Error.WriteLine(ReadableString);
Console.Error.WriteLine();
throw;
}
} }
public IterActionGenerator(Universe universe, MethodInfo method) public IterActionGenerator(Universe universe, MethodInfo method)
{ {
Universe = universe; Universe = universe;
Method = method; Method = method;
Parameters = method.GetParameters().Select(ParamInfo.Build).ToImmutableArray();
var name = "<>Query_" + string.Join("_", method.Name); var name = "<>Query_" + string.Join("_", method.Name);
var genMethod = new DynamicMethod(name, null, new[] { typeof(object), typeof(Iterator) }); var genMethod = new DynamicMethod(name, null, new[] { typeof(object), typeof(Iterator) });
@ -70,13 +64,11 @@ public unsafe class IterActionGenerator
var iteratorArg = IL.Argument<Iterator>(1); var iteratorArg = IL.Argument<Iterator>(1);
var fieldIndex = 1; var fieldIndex = 1;
var parameters = new List<(ParamInfo Info, Term? Term, ILocal? FieldLocal, ILocal? TempLocal)>(); var paramData = new List<(ParamInfo Info, Term? Term, ILocal? FieldLocal, ILocal? TempLocal)>();
foreach (var info in method.GetParameters()) { foreach (var p in Parameters) {
var p = ParamInfo.Build(info);
// If the parameter is unique, we don't create a term for it. // If the parameter is unique, we don't create a term for it.
if (p.Kind <= ParamKind.Unique) if (p.Kind <= ParamKind.Unique)
{ parameters.Add((p, null, null, null)); continue; } { paramData.Add((p, null, null, null)); continue; }
// Create a term to add to the query. // Create a term to add to the query.
var term = new Term(universe.LookupOrThrow(p.UnderlyingType)) { var term = new Term(universe.LookupOrThrow(p.UnderlyingType)) {
@ -96,7 +88,7 @@ public unsafe class IterActionGenerator
}; };
var spanType = typeof(Span<>).MakeGenericType(p.FieldType); var spanType = typeof(Span<>).MakeGenericType(p.FieldType);
var fieldLocal = IL.Local(spanType, $"{info.Name}Field"); var fieldLocal = IL.Local(spanType, $"{p.Info.Name}Field");
var tempLocal = (ILocal?)null; var tempLocal = (ILocal?)null;
switch (p.Kind) { switch (p.Kind) {
@ -104,27 +96,27 @@ public unsafe class IterActionGenerator
if (!p.ParameterType.IsValueType) break; if (!p.ParameterType.IsValueType) break;
// If a "has" or "not" parameter is a struct, we require a temporary local that // If a "has" or "not" parameter is a struct, we require a temporary local that
// we can later load onto the stack when loading the arguments for the action. // we can later load onto the stack when loading the arguments for the action.
IL.Comment($"{info.Name}Temp = default({p.ParameterType});"); IL.Comment($"{p.Info.Name}Temp = default({p.ParameterType});");
tempLocal = IL.Local(p.ParameterType); tempLocal = IL.Local(p.ParameterType);
IL.LoadAddr(tempLocal); IL.LoadAddr(tempLocal);
IL.Init(tempLocal.LocalType); IL.Init(tempLocal.LocalType);
break; break;
case ParamKind.Nullable: case ParamKind.Nullable:
IL.Comment($"{info.Name}Field = iterator.MaybeField<{p.FieldType.Name}>({fieldIndex})"); IL.Comment($"{p.Info.Name}Field = iterator.MaybeField<{p.FieldType.Name}>({fieldIndex})");
IL.Load(iteratorArg); IL.Load(iteratorArg);
IL.LoadConst(fieldIndex); IL.LoadConst(fieldIndex);
IL.Call(_iteratorMaybeFieldMethod.MakeGenericMethod(p.FieldType)); IL.Call(_iteratorMaybeFieldMethod.MakeGenericMethod(p.FieldType));
IL.Store(fieldLocal); IL.Store(fieldLocal);
IL.Comment($"{info.Name}Temp = default({p.ParameterType});"); IL.Comment($"{p.Info.Name}Temp = default({p.ParameterType});");
tempLocal = IL.Local(p.ParameterType); tempLocal = IL.Local(p.ParameterType);
IL.LoadAddr(tempLocal); IL.LoadAddr(tempLocal);
IL.Init(tempLocal.LocalType); IL.Init(tempLocal.LocalType);
break; break;
default: default:
IL.Comment($"{info.Name}Field = iterator.Field<{p.FieldType.Name}>({fieldIndex})"); IL.Comment($"{p.Info.Name}Field = iterator.Field<{p.FieldType.Name}>({fieldIndex})");
IL.Load(iteratorArg); IL.Load(iteratorArg);
IL.LoadConst(fieldIndex); IL.LoadConst(fieldIndex);
IL.Call(_iteratorFieldMethod.MakeGenericMethod(p.FieldType)); IL.Call(_iteratorFieldMethod.MakeGenericMethod(p.FieldType));
@ -132,27 +124,40 @@ public unsafe class IterActionGenerator
break; break;
} }
parameters.Add((p, term, fieldLocal, tempLocal)); paramData.Add((p, term, fieldLocal, tempLocal));
fieldIndex++; fieldIndex++;
} }
// If there's any reference type parameters, we need to define a GCHandle local. // If there's any reference type parameters, we need to define a GCHandle local.
var hasReferenceType = parameters var hasReferenceType = paramData
.Where(p => p.Info.Kind > ParamKind.Unique) .Where(p => p.Info.Kind > ParamKind.Unique)
.Any(p => !p.Info.UnderlyingType.IsValueType); .Any(p => !p.Info.UnderlyingType.IsValueType);
var handleLocal = hasReferenceType ? IL.Local<GCHandle>() : null; var handleLocal = hasReferenceType ? IL.Local<GCHandle>() : null;
IDisposable? forLoopBlock = null; var countLocal = IL.Local<int>("iter_count");
ILocal<int>? forCurrentLocal = null; var indexLocal = IL.Local<int>("iter_index");
IL.Load(iteratorArg, _iteratorCountProp);
IL.Store(countLocal);
IL.Set(indexLocal, 0);
// If all parameters are fixed, iterator count will be 0, but since // If all parameters are fixed, iterator count will be 0, but since
// the query matched, we want to run the callback at least once. // the query matched, we want to run the callback at least once.
if (parameters.Any(p => !p.Info.IsFixed)) IL.Comment("if (iter_count == 0) iter_count = 1;");
forLoopBlock = IL.For(() => IL.Load(iteratorArg, _iteratorCountProp), out forCurrentLocal); var dontIncrementLabel = IL.DefineLabel();
IL.Load(countLocal);
IL.GotoIfTrue(dontIncrementLabel);
IL.LoadConst(1);
IL.Store(countLocal);
IL.MarkLabel(dontIncrementLabel);
IL.While("IteratorLoop", (@continue) => {
IL.GotoIf(@continue, indexLocal, Comparison.LessThan, countLocal);
}, (_, _) => {
if (!Method.IsStatic) if (!Method.IsStatic)
IL.Load(instanceArg); IL.Load(instanceArg);
foreach (var (info, term, fieldLocal, tempLocal) in parameters) { foreach (var (info, term, fieldLocal, tempLocal) in paramData) {
switch (info.Kind) { switch (info.Kind) {
case ParamKind.GlobalUnique: case ParamKind.GlobalUnique:
@ -162,7 +167,7 @@ public unsafe class IterActionGenerator
case ParamKind.Unique: case ParamKind.Unique:
IL.Comment($"Unique parameter {info.ParameterType.GetFriendlyName()}"); IL.Comment($"Unique parameter {info.ParameterType.GetFriendlyName()}");
_uniqueParameters[info.ParameterType](IL, iteratorArg, forCurrentLocal!); _uniqueParameters[info.ParameterType](IL, iteratorArg, indexLocal!);
break; break;
case ParamKind.Has or ParamKind.Not: case ParamKind.Has or ParamKind.Not:
@ -180,12 +185,12 @@ public unsafe class IterActionGenerator
if (info.IsByRef) { if (info.IsByRef) {
IL.LoadAddr(fieldLocal!); IL.LoadAddr(fieldLocal!);
if (info.IsFixed) IL.LoadConst(0); if (info.IsFixed) IL.LoadConst(0);
else IL.Load(forCurrentLocal!); else IL.Load(indexLocal!);
IL.Call(spanItemMethod); IL.Call(spanItemMethod);
} else if (info.IsRequired) { } else if (info.IsRequired) {
IL.LoadAddr(fieldLocal!); IL.LoadAddr(fieldLocal!);
if (info.IsFixed) IL.LoadConst(0); if (info.IsFixed) IL.LoadConst(0);
else IL.Load(forCurrentLocal!); else IL.Load(indexLocal!);
IL.Call(spanItemMethod); IL.Call(spanItemMethod);
IL.LoadObj(info.FieldType); IL.LoadObj(info.FieldType);
} else { } else {
@ -196,7 +201,7 @@ public unsafe class IterActionGenerator
IL.GotoIfFalse(elseLabel); IL.GotoIfFalse(elseLabel);
IL.LoadAddr(fieldLocal!); IL.LoadAddr(fieldLocal!);
if (info.IsFixed) IL.LoadConst(0); if (info.IsFixed) IL.LoadConst(0);
else IL.Load(forCurrentLocal!); else IL.Load(indexLocal!);
IL.Call(spanItemMethod); IL.Call(spanItemMethod);
IL.LoadObj(info.FieldType); IL.LoadObj(info.FieldType);
if (info.Kind == ParamKind.Nullable) if (info.Kind == ParamKind.Nullable)
@ -222,12 +227,12 @@ public unsafe class IterActionGenerator
} }
IL.Call(Method); IL.Call(Method);
forLoopBlock?.Dispose(); IL.Increment(indexLocal);
});
IL.Return(); IL.Return();
Parameters = parameters.Select(p => p.Info).ToImmutableList(); Terms = paramData.Where(p => p.Term != null).Select(p => p.Term!).ToImmutableList();
Terms = parameters.Where(p => p.Term != null).Select(p => p.Term!).ToImmutableList();
GeneratedAction = genMethod.CreateDelegate<Action<object?, Iterator>>(); GeneratedAction = genMethod.CreateDelegate<Action<object?, Iterator>>();
ReadableString = IL.ToReadableString(); ReadableString = IL.ToReadableString();
} }

Loading…
Cancel
Save