﻿/******************************************************************************/
/*
  Project   - MudBun
  Publisher - Long Bunny Labs
              http://LongBunnyLabs.com
  Author    - Ming-Lun "Allen" Chou
              http://AllenChou.net
*/
/******************************************************************************/

#pragma kernel dual_contouring_move_point

#include "../../Shader/ComputeCommon.cginc"

#include "../../Shader/BrushFuncs.cginc"
#include "../../Shader/DualMeshingFuncs.cginc"
#include "../../Shader/GenPointDefs.cginc"
#include "../../Shader/IndirectArgsDefs.cginc"
#include "../../Shader/Math/MathConst.cginc"
#include "../../Shader/MeshingModeDefs.cginc"
#include "../../Shader/NormalFuncs.cginc"
#include "../../Shader/RenderModeDefs.cginc"
#include "../../Shader/VoxelFuncs.cginc"
#include "../../Shader/VoxelCacheFuncs.cginc"

// https://www.boristhebrave.com/2018/04/15/dual-contouring-tutorial/
// https://www.mattkeeter.com/projects/contours/

// Parameters read by dual_contouring_move_point (declared here without initializers;
// presumably supplied by the host before dispatch -- verify against the caller).

// Blend between the solved position (0) and the point's original position (1).
// At >= 1 the kernel skips the solve entirely and leaves positions untouched.
float dualContouringDualQuadsBlend;
// Weight of the regularization rows that pull the solution toward the average
// edge crossing; the surface-plane rows are weighted by (1 - relaxation).
float dualContouringRelaxation;
// Number of bisection steps used to locate the surface crossing on each
// sign-changing cell edge; <= 0 selects a single linear interpolation instead.
int dualContouringBinarySearchIterations;
// Number of refinement steps applied to the solved position after the least-squares solve.
int dualContouringGradientDescentIterations;
// Step scale of each refinement step (multiplies normal * SDF value).
float dualContouringGradientDescentFactor;

// Dual contouring vertex placement, one thread per generated point.
//
// For the point at index id.x the kernel:
//   1. samples the masked-brush SDF at the 8 corners of a voxel-sized cube
//      centered on the point,
//   2. finds the surface crossing on every cube edge whose corner values
//      do not share a strict sign,
//   3. evaluates the SDF normal at each crossing,
//   4. solves the regularized least-squares problem  min ||A * x - b||^2
//      where each crossing contributes the plane  dot(n, x) = dot(n, p)  and
//      three extra rows pull x toward the average crossing,
//   5. optionally refines the result with a few fixed-normal descent steps,
//   6. writes back the position blended with the original one.
//
// In quad-splat render mode the point's material size is also scaled, down to
// 0.70711f (~1/sqrt(2)) of its value at dualContouringDualQuadsBlend == 0... 
// more precisely by  1 - (1 - 0.70711) * blend  on the solve path, and by
// 0.70711f on the early-out path taken when blend >= 1.
[numthreads(kThreadGroupSize, 1, 1)]
void dual_contouring_move_point(int3 id : SV_DispatchThreadID)
{
#if defined(MUDBUN_DISABLE_DUAL_CONTOURING)
  return;
#endif

  // indirectDrawArgs[0] is used as the exclusive upper bound on valid point indices
  if (id.x >= indirectDrawArgs[0])
    return;

  int iGenPoint = id.x;

  // fully blended toward the original position: nothing to solve
  if (dualContouringDualQuadsBlend >= 1.0f)
  {
    if (renderMode == kRenderModeQuadSplats)
      aGenPoint[iGenPoint].material.size *= 0.70711f;
    return;
  }

  float h = 0.5f * voxelSize;
  float3 center = aGenPoint[iGenPoint].posNorm.xyz;
  int iBrushMask = aGenPoint[iGenPoint].iBrushMask;

  // cube corners relative to center; index bit 0 -> +x, bit 1 -> +y, bit 2 -> +z
  float3 aCornerOffset[8] = 
  {
    float3(-h, -h, -h), 
    float3( h, -h, -h), 
    float3(-h,  h, -h), 
    float3( h,  h, -h), 
    float3(-h, -h,  h), 
    float3( h, -h,  h), 
    float3(-h,  h,  h), 
    float3( h,  h,  h), 
  };

  // the 12 cube edges as pairs of corner indices
  int aEdgeVertIndex[12][2] = 
  {
    { 0, 1 }, 
    { 1, 5 }, 
    { 5, 4 }, 
    { 4, 0 }, 
    { 2, 3 }, 
    { 3, 7 }, 
    { 7, 6 }, 
    { 6, 2 }, 
    { 0, 2 }, 
    { 1, 3 }, 
    { 5, 7 }, 
    { 4, 6 }, 
  };

  // SDF value at each corner; the material output is required by the call but not used here
  float aCornerRes[8];
  SdfBrushMaterial mat;
  [loop] for (int iCorner = 0; iCorner < 8; ++iCorner)
  {
    aCornerRes[iCorner] = sdf_masked_brushes(center + aCornerOffset[iCorner], iBrushMask, mat);
  }

  // collect surface crossings (as offsets from center), compacted to the front of aEdgeOffset
  float3 avgEdgeOffset = 0.0f;
  int numEdges = 0;
  float3 aEdgeOffset[12];
  for (int iEdge = 0; iEdge < 12; ++iEdge)
  {
    int iVert0 = aEdgeVertIndex[iEdge][0];
    int iVert1 = aEdgeVertIndex[iEdge][1];
    float res0 = aCornerRes[iVert0];
    float res1 = aCornerRes[iVert1];

    // both corners strictly on the same side: no crossing on this edge
    if (res0 * res1 > 0)
      continue;

    float3 offset0 = aCornerOffset[iVert0];
    float3 offset1 = aCornerOffset[iVert1];

    int iEdgeOutput = numEdges++;

    float3 edgeOffset;
    if (res0 == 0.0f && res1 == 0.0f)
    {
      // whole edge reports zero distance: use its midpoint
      edgeOffset = 0.5f * (offset0 + offset1);
    }
    else if (res0 == 0.0f)
    {
      edgeOffset = offset0;
    }
    else if (res1 == 0.0f)
    {
      edgeOffset = offset1;
    }
    else if (dualContouringBinarySearchIterations <= 0)
    {
      // lerp approximation; res0 and res1 have opposite signs here, so the denominator is non-zero
      float t = -res0 / (res1 - res0);
      edgeOffset = lerp(offset0, offset1, t);
    }
    else
    {
      // binary search: keep the half of the segment whose end points still differ in sign
      edgeOffset = 0.5f * (offset0 + offset1);
      [loop] for (int iSearch = 0; iSearch < dualContouringBinarySearchIterations; ++iSearch)
      {
        float resT = sdf_masked_brushes(center + edgeOffset, iBrushMask, mat);
        if (res0 * resT < 0.0f)
        {
          res1 = resT;
          offset1 = edgeOffset;
        }
        else if (resT * res1 < 0.0f)
        {
          res0 = resT;
          offset0 = edgeOffset;
        }
        // if resT is exactly zero neither end moves and the midpoint stays put
        edgeOffset = 0.5f * (offset0 + offset1);
      }
    }

    avgEdgeOffset += edgeOffset;
    aEdgeOffset[iEdgeOutput] = edgeOffset;
  }

  // no crossing found: leave the point as it is
  if (numEdges <= 0)
    return;

  avgEdgeOffset /= numEdges;

  // SDF normal at every crossing
  float3 aEdgeNorm[12];
  [loop] for (int iEdgeOffset = 0; iEdgeOffset < numEdges; ++iEdgeOffset)
  {
    SDF_NORMAL_FULL(aEdgeNorm[iEdgeOffset], center + aEdgeOffset[iEdgeOffset], sdf_masked_brushes, iBrushMask, 1e-2f * voxelSize);
  }

  // minimize ||A * x - b||^2
  // A is a 15x3 system stored as five 3x3 blocks (b as five float3):
  //   blocks 0..3 : up to 12 plane rows, one per crossing, weighted by (1 - relaxation);
  //                 rows without a crossing stay zero and do not affect the solution
  //   block  4    : relaxation * identity, with right-hand side relaxation * avgEdgeOffset
  float3x3 A[5];
  float3 b[5];
  for (int iInit = 0; iInit < 5; ++iInit)
  {
    A[iInit] = 0.0f;
    b[iInit] = 0.0f;
  }
  {
    float dualContouringRelaxationComp = 1.0f - dualContouringRelaxation;
    for (int iRow = 0; iRow < numEdges; ++iRow)
    {
      uint iBlock = uint(iRow) / 3;
      float3 planeNorm = dualContouringRelaxationComp * aEdgeNorm[iRow];
      float planeDist = dualContouringRelaxationComp * dot(aEdgeNorm[iRow], aEdgeOffset[iRow]);
      switch (uint(iRow) % 3)
      {
        case 0:
          A[iBlock]._m00_m01_m02 = planeNorm;
          b[iBlock].x = planeDist;
          break;
        case 1:
          A[iBlock]._m10_m11_m12 = planeNorm;
          b[iBlock].y = planeDist;
          break;
        case 2:
          A[iBlock]._m20_m21_m22 = planeNorm;
          b[iBlock].z = planeDist;
          break;
      }
    }
    A[4]._m00_m01_m02 = float3(dualContouringRelaxation, 0.0f, 0.0f);
    A[4]._m10_m11_m12 = float3(0.0f, dualContouringRelaxation, 0.0f);
    A[4]._m20_m21_m22 = float3(0.0f, 0.0f, dualContouringRelaxation);
    b[4] = dualContouringRelaxation * avgEdgeOffset;
  }

  // normal equations: bestOffset = (A' * A)^-1 * (A' * b)
  // Start from the average crossing so that a singular system (e.g. zero
  // relaxation with normals that do not span 3D, or non-finite normals)
  // yields a usable position instead of Inf/NaN.
  float3 bestOffset = avgEdgeOffset;
  {
    float3x3 AtA = 0.0f;
    float3 Atb = 0.0f;
    for (int iA = 0; iA < 5; ++iA)
    {
      float3x3 At = transpose(A[iA]);
      AtA += mul(At, A[iA]);
      Atb += mul(At, b[iA]);
    }

    float det = 
      AtA._m00 * (AtA._m11 * AtA._m22 - AtA._m21 * AtA._m12) - 
      AtA._m01 * (AtA._m10 * AtA._m22 - AtA._m12 * AtA._m20) + 
      AtA._m02 * (AtA._m10 * AtA._m21 - AtA._m11 * AtA._m20);

    // the comparison is also false when det is NaN, which routes to the fallback above
    const float minDeterminant = 1e-30f;
    if (abs(det) > minDeterminant)
    {
      // inverse via the adjugate (transposed cofactor matrix) divided by the determinant
      float detInv = 1.0f / det;
      float3x3 inv;
      inv._m00 = (AtA._m11 * AtA._m22 - AtA._m21 * AtA._m12) * detInv;
      inv._m01 = (AtA._m02 * AtA._m21 - AtA._m01 * AtA._m22) * detInv;
      inv._m02 = (AtA._m01 * AtA._m12 - AtA._m02 * AtA._m11) * detInv;
      inv._m10 = (AtA._m12 * AtA._m20 - AtA._m10 * AtA._m22) * detInv;
      inv._m11 = (AtA._m00 * AtA._m22 - AtA._m02 * AtA._m20) * detInv;
      inv._m12 = (AtA._m10 * AtA._m02 - AtA._m00 * AtA._m12) * detInv;
      inv._m20 = (AtA._m10 * AtA._m21 - AtA._m20 * AtA._m11) * detInv;
      inv._m21 = (AtA._m20 * AtA._m01 - AtA._m00 * AtA._m21) * detInv;
      inv._m22 = (AtA._m00 * AtA._m11 - AtA._m10 * AtA._m01) * detInv;

      bestOffset = mul(inv, Atb);
    }
  }

  float3 bestPos = center + bestOffset;

  // doesn't buy us much and bloats compile time
  // gradient descent
  // The normal is evaluated once at the solved position and reused for every
  // step; it is skipped entirely when no steps are requested.
  if (dualContouringGradientDescentIterations > 0)
  {
    float3 n;
    SDF_NORMAL_FULL(n, bestPos, sdf_masked_brushes, iBrushMask, 1e-2f * voxelSize);
    [loop] for (int iDescent = 0; iDescent < dualContouringGradientDescentIterations; ++iDescent)
    {
      // step against the normal by the current signed distance
      float d = sdf_masked_brushes(bestPos, iBrushMask, mat);
      bestOffset -= dualContouringGradientDescentFactor * n * d;
      bestPos = center + bestOffset;
    }
  }

  // center still holds the point's original position (nothing above wrote to it)
  aGenPoint[iGenPoint].posNorm.xyz = lerp(bestPos, center, dualContouringDualQuadsBlend);

  if (renderMode == kRenderModeQuadSplats)
    aGenPoint[iGenPoint].material.size *= 1.0f - (1.f - 0.70711f) * dualContouringDualQuadsBlend;
}

