﻿/******************************************************************************/
/*
  Project   - MudBun
  Publisher - Long Bunny Labs
              http://LongBunnyLabs.com
  Author    - Ming-Lun "Allen" Chou
              http://AllenChou.net
*/
/******************************************************************************/

#pragma kernel generate_dual_quads
#pragma kernel update_dual_meshing_indirect_dispatch_args
#pragma kernel dual_meshing_flat_mesh_normal
#pragma kernel dual_meshing_smooth_mesh_normal
#pragma kernel dual_meshing_update_auto_smooth
#pragma kernel dual_meshing_compute_auto_smooth
#pragma kernel dual_meshing_udpate_auto_smooth_SmoothCorner_indirect_dispatch_args
#pragma kernel dual_meshing_auto_smooth_smooth_corner
#pragma kernel update_dual_meshing_splats_indirect_args
#pragma kernel convert_dual_meshing_splats

#include "../../Shader/ComputeCommon.cginc"

#include "../../Shader/AutoSmoothFuncs.cginc"
#include "../../Shader/BrushFuncs.cginc"
#include "../../Shader/DualMeshingFuncs.cginc"
#include "../../Shader/GenPointDefs.cginc"
#include "../../Shader/IndirectArgsDefs.cginc"
#include "../../Shader/Math/Codec.cginc"
#include "../../Shader/Math/MathConst.cginc"
#include "../../Shader/MeshingModeDefs.cginc"
#include "../../Shader/NormalFuncs.cginc"
#include "../../Shader/RenderModeDefs.cginc"
#include "../../Shader/VoxelFuncs.cginc"
#include "../../Shader/VoxelCacheFuncs.cginc"

// Generates dual-meshing output for one voxel node per thread.
// For each of the three edges leaving the node's min corner along +X/+Y/+Z, the SDF is
// sampled at both edge end points. If the results differ in sign (or one is exactly zero),
// either a quad (two triangles, six gen points) centered on that edge is appended to
// aGenPoint, or, in a splat render mode, a single gen point at the edge midpoint.
// indirectDrawArgs[0] is advanced atomically to reserve output slots.
[numthreads(kThreadGroupSize, 1, 1)]
void generate_dual_quads(int3 id : SV_DispatchThreadID)
{
#if defined(MUDBUN_DISABLE_DUAL_MESHING_ALL)
  return;
#endif

  // id.x is an index local to the entries counted by aNumNodesAllocated[currentNodeDepth + 1]
  uint iNode = uint(id.x);
  if (iNode >= uint(aNumNodesAllocated[currentNodeDepth + 1]))
    return;

  // offset by the counts at indices 1..currentNodeDepth to get the node pool index
  for (int i = 1; i <= currentNodeDepth; ++i)
    iNode += aNumNodesAllocated[i];
  if (iNode >= nodePoolSize)
    return;

  float3 nodeCenter = nodePool[iNode].center;
  float halfNodeSize = 0.5f * voxelSize;
  float3 minCorner = nodeCenter - halfNodeSize;
  int iBrushMask = get_brush_mask_index(iNode);

  // aAxis0[i] is the edge direction; aAxis1[i] and aAxis2[i] span the quad generated for that edge
  float3 aAxis0[3] = { kUnitX, kUnitY, kUnitZ };
  float3 aAxis1[3] = { kUnitY, kUnitZ, kUnitX };
  float3 aAxis2[3] = { kUnitZ, kUnitX, kUnitY };
  float3 aCornerOffset[3] =
  {
    float3(voxelSize, 0.0f, 0.0f),
    float3(0.0f, voxelSize, 0.0f),
    float3(0.0f, 0.0f, voxelSize),
  };

  SdfBrushMaterial mat = init_brush_material();

  // [0]    node center
  // [1]    min corner
  // [2..4] corners one voxel away from the min corner along X/Y/Z
  // [5..7] the same corners nudged a further 1e-2 along their axis (tie-breaker samples)
  float3 aSamplePoint[8] =
  {
    nodeCenter, 
    minCorner, 
    minCorner + aCornerOffset[0], 
    minCorner + aCornerOffset[1], 
    minCorner + aCornerOffset[2], 
    minCorner + aCornerOffset[0] + 1e-2f * aAxis0[0], 
    minCorner + aCornerOffset[1] + 1e-2f * aAxis0[1], 
    minCorner + aCornerOffset[2] + 1e-2f * aAxis0[2], 
  };

  // Zero-initialized because entries 5..7 are skipped by the loop below unless a tie-breaker
  // is needed, yet they are copied out unconditionally afterwards.
  float aRes[8] = { 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f };
  [loop] for (int iSample = 0; iSample < 8; ++iSample)
  {
    if (iSample < 5
        || aRes[1] == 0.0f      // min corner res
        || aRes[clamp(iSample - 3, 0, 7)] == 0.0f) // axis corner res
    {
      // doubles generated assemblies, but gets better performance in return
      aRes[iSample] = sdf_masked_brushes(aSamplePoint[iSample], iBrushMask, mat);
    }
  }

  bool emitSplats =
    meshingMode == kMeshingModeDualQuads
    && (renderMode == kRenderModeCircleSplats || renderMode == kRenderModeQuadSplats);

  // 0.70711 ~= 1 / sqrt(2); presumably .z is the splat size -- TODO confirm against the material layout
  if (emitSplats && renderMode == kRenderModeQuadSplats)
      mat.metallicSmoothnessSizeTightness.z *= 0.70711f;

  SdfBrushMaterialCompressed packedMat = pack_material(mat);

  float minCornerRes = aRes[1];

  int iGenPoint = 0;

  float aCornerRes[3] = { aRes[2], aRes[3], aRes[4] };
  float aMinCornerDeltaRes[3] = { aRes[5], aRes[6], aRes[7] };
 
  [loop] for (int iAxis = 0; iAxis < 3; ++iAxis)
  {
    float3 axis = aAxis0[iAxis];
    float3 corner = minCorner + aCornerOffset[iAxis];
    float cornerRes = aCornerRes[iAxis];

    // s < 0: the two end points have opposite signs; s == 0: at least one result is exactly zero
    float s = sign(minCornerRes * cornerRes);
    if (s <= 0.0f)
    {
      // quad center & half extent vectors
      float3 c = 0.5f * (minCorner + corner);
      float3 h1 = halfNodeSize * aAxis1[iAxis];
      // h2 is flipped together with the normal so the winding follows the facing direction
      float3 h2 = halfNodeSize * (minCornerRes <= 0.0f ? 1.0f : -1) * aAxis2[iAxis];
      float packedNorm = pack_normal((minCornerRes <= 0.0f ? 1 : -1) * axis);
      if (s == 0.0f)
      {
        // tie-breaker: orient the normal using the extra sample taken just past the axis corner
        float nRes = aMinCornerDeltaRes[iAxis];
        packedNorm = pack_normal(dot(nRes - minCornerRes, axis) > 0.0f ? axis : -axis);
      }

      if (!emitSplats)
      {
        // normal dual quads: reserve six consecutive gen points (two triangles)
        InterlockedAdd(indirectDrawArgs[0], 6, iGenPoint);

        // Two triangulations of the same quad that differ in which diagonal is used.
        float3 aQuadVertOffset[2][6] = 
        {
          { - h1 - h2, + h1 - h2, + h1 + h2, - h1 - h2, + h1 + h2, - h1 + h2, }, 
          { + h1 - h2, + h1 + h2, - h1 + h2, + h1 - h2, - h1 + h2, - h1 - h2, }, 
        };

        // Only the first triangulation is used. A disabled variant alternated between the two
        // by voxel parity:
        //   float3 qCenter = round((nodeCenter - 0.5f * voxelSize) / voxelSize);
        //   (uint(int(qCenter.x + qCenter.y + qCenter.z) + 0x80000000) % 2 == 0) ? 0 : 1;
        int iaQuadVertOffset = 0;
        for (int iVert = 0; iVert < 6; ++iVert, ++iGenPoint)
        {
          float3 pos = c + aQuadVertOffset[iaQuadVertOffset][iVert];
          aGenPoint[iGenPoint].posNorm = float4(pos, packedNorm);
          aGenPoint[iGenPoint].iBrushMask = iBrushMask;
          aGenPoint[iGenPoint].material = packedMat;
          aGenPoint[iGenPoint].vertId = auto_smooth_vert_data_id(pos);
        }
      }
      else
      {
        // splats: the draw counter advances by 3 (circle) or 6 (quad) per splat,
        // while only one gen point is written per splat
        int iVertBase;
        switch (renderMode)
        {
        case kRenderModeCircleSplats:
          InterlockedAdd(indirectDrawArgs[0], 3, iVertBase);
          iGenPoint = uint(iVertBase) / 3;
          break;
        case kRenderModeQuadSplats:
          InterlockedAdd(indirectDrawArgs[0], 6, iVertBase);
          iGenPoint = uint(iVertBase) / 6;
          break;
        }
        aGenPoint[iGenPoint].posNorm = float4(c, packedNorm);
        aGenPoint[iGenPoint].material = packedMat;
      }
    }
  }
}

// Derives the indirect dispatch group count from indirectDrawArgs[0]:
// ceil(count / kThreadGroupSize), but never fewer than one group.
[numthreads(1, 1, 1)]
void update_dual_meshing_indirect_dispatch_args(int3 id : SV_DispatchThreadID)
{
  uint numGroups = uint(indirectDrawArgs[0] + kThreadGroupSize - 1) / kThreadGroupSize;
  indirectDispatchArgs[0] = max(1, numGroups);
}

// Flat shading pass: one thread per gen point. Each gen point receives the packed
// face normal of the triangle (three consecutive gen points) it belongs to.
[numthreads(kThreadGroupSize, 1, 1)]
void dual_meshing_flat_mesh_normal(int3 id : SV_DispatchThreadID)
{
#if defined(MUDBUN_DISABLE_DUAL_MESHING_FLAT_MESH)
  return;
#endif

  if (id.x >= indirectDrawArgs[0])
    return;

  uint iVert = uint(id.x);

  // index of the first gen point of the triangle containing iVert
  uint iFirst = (iVert / 3) * 3;

  float3 a = aGenPoint[iFirst    ].posNorm.xyz;
  float3 b = aGenPoint[iFirst + 1].posNorm.xyz;
  float3 c = aGenPoint[iFirst + 2].posNorm.xyz;

  // face normal from the two edges leaving the first vertex
  float3 faceNormal = normalize_safe(cross(b - a, c - a), 0.0f);

  aGenPoint[iVert].posNorm.w = pack_normal(faceNormal);
}

// Smooth shading pass: one thread per gen point. Re-evaluates the masked brushes at the
// gen point's position to refresh its material, and replaces its normal with the SDF
// normal computed using normalDifferentiationStep.
[numthreads(kThreadGroupSize, 1, 1)]
void dual_meshing_smooth_mesh_normal(int3 id : SV_DispatchThreadID)
{
#if defined(MUDBUN_DISABLE_DUAL_MESHING_SMOOTH_MESH)
  return;
#endif

  if (id.x >= indirectDrawArgs[0])
    return;

  uint iVert = uint(id.x);
  int iBrushMask = aGenPoint[iVert].iBrushMask;
  float3 pos = aGenPoint[iVert].posNorm.xyz;

  // material at the vertex position
  SdfBrushMaterial mat;
  sdf_masked_brushes(pos, iBrushMask, mat);

  // normal at the vertex position
  float3 n;
  SDF_NORMAL(n, pos, sdf_masked_brushes, iBrushMask, normalDifferentiationStep);

  aGenPoint[iVert].posNorm.w = pack_normal(n);
  aGenPoint[iVert].material = pack_material(mat);
}

// Auto-smooth accumulation pass: one thread per gen point. Stores the owning triangle's
// packed face normal on the gen point and feeds that normal, together with the length of
// the triangle's edge cross product (twice its area), into the auto-smooth data for the
// gen point's vertId.
[numthreads(kThreadGroupSize, 1, 1)]
void dual_meshing_update_auto_smooth(int3 id : SV_DispatchThreadID)
{
#if defined(MUDBUN_DISABLE_DUAL_MESHING_ALL)
  return;
#endif

  if (id.x >= indirectDrawArgs[0])
    return;

  uint iVert = uint(id.x);

  // index of the first gen point of the triangle containing iVert
  uint iFirst = (iVert / 3) * 3;

  float3 a = aGenPoint[iFirst].posNorm.xyz;
  float3 b = aGenPoint[iFirst + 1].posNorm.xyz;
  float3 d = aGenPoint[iFirst + 2].posNorm.xyz;

  float3 edgeCross = cross(b - a, d - a);
  float packedFaceNormal = pack_normal(normalize_safe(edgeCross, 0.0f));
  float weight = length(edgeCross);

  aGenPoint[iVert].posNorm.w = packedFaceNormal;
  update_auto_smooth_vert_data(aGenPoint[iVert].vertId, packedFaceNormal, weight);
}

// Auto-smooth resolve pass: one thread per gen point. Replaces the gen point's normal with
// the auto-smooth normal for its vertId, refreshes its material, and, when smooth corners
// are enabled, sets atSmoothEdge if the blurred SDF normal lies within a quarter of
// autoSmoothMaxAngle of the auto-smooth normal.
[numthreads(kThreadGroupSize, 1, 1)]
void dual_meshing_compute_auto_smooth(int3 id : SV_DispatchThreadID)
{
#if defined(MUDBUN_DISABLE_DUAL_MESHING_ALL)
  return;
#endif

  if (id.x >= indirectDrawArgs[0])
    return;

  uint iVert = uint(id.x);
  int iBrushMask = aGenPoint[iVert].iBrushMask;
  float3 pos = aGenPoint[iVert].posNorm.xyz;

  // material at the vertex position
  SdfBrushMaterial mat;
  sdf_masked_brushes(pos, iBrushMask, mat);

  float3 currentNormal = unpack_normal(aGenPoint[iVert].posNorm.w);
  float3 autoSmoothNormal = compute_auto_smooth_normal(aGenPoint[iVert].vertId, currentNormal);

  int smoothEdgeFlag = 0;
  if (enableSmoothCorner)
  {
    float3 blurredNormal;
    SDF_NORMAL(blurredNormal, pos, sdf_masked_brushes, iBrushMask, smoothCornerNormalBlur);
    float deviation = abs(angle_between(blurredNormal, autoSmoothNormal));
    smoothEdgeFlag = (deviation < 0.25f * autoSmoothMaxAngle) ? 1 : 0;
  }

  aGenPoint[iVert].posNorm.w = pack_normal(autoSmoothNormal);
  aGenPoint[iVert].atSmoothEdge = smoothEdgeFlag;
  aGenPoint[iVert].material = pack_material(mat);
  //aGenPoint[iVert].material.color = pack_rgba(float4(n, 1.0f));
}

// Derives the indirect dispatch group count for a one-thread-per-triangle pass:
// ceil((indirectDrawArgs[0] / 3) / kThreadGroupSize), but never fewer than one group.
[numthreads(1, 1, 1)]
void dual_meshing_udpate_auto_smooth_SmoothCorner_indirect_dispatch_args(int3 id : SV_DispatchThreadID)
{
  uint numTris = uint(indirectDrawArgs[0]) / 3;
  uint numGroups = uint(numTris + kThreadGroupSize - 1) / kThreadGroupSize;
  indirectDispatchArgs[0] = max(1, numGroups);
}

// Smooth-corner pass: one thread per triangle (three consecutive gen points).
// Triangles whose three vertices are all flagged atSmoothEdge are left untouched. Any other
// triangle is subdivided into smaller triangles (smoothCornerSubdivision rows) with linearly
// interpolated position, normal, color, emission/tightness and metallic/smoothness. Each
// resulting vertex normal is then blended toward the blurred SDF normal.
// The first sub-triangle overwrites the original triangle's slots; all others are appended
// to aGenPoint by atomically advancing indirectDrawArgs[0].
[numthreads(kThreadGroupSize, 1, 1)]
void dual_meshing_auto_smooth_smooth_corner(int3 id : SV_DispatchThreadID)
{
#if defined(MUDBUN_DISABLE_DUAL_MESHING_ALL)
  return;
#endif

  // indirectDrawArgs[0] counts gen points; three per triangle
  if (uint(id.x) >= uint(indirectDrawArgs[0]) / 3)
    return;

  uint iTriBase = uint(id.x) * 3;
  bool smooth0 = (aGenPoint[iTriBase + 0].atSmoothEdge != 0);
  bool smooth1 = (aGenPoint[iTriBase + 1].atSmoothEdge != 0);
  bool smooth2 = (aGenPoint[iTriBase + 2].atSmoothEdge != 0);
  // nothing to do when every vertex of the triangle is flagged
  if (smooth0 && smooth1 && smooth2)
    return;

  int iBrushMask = aGenPoint[iTriBase].iBrushMask;

  uint iTri0 = iTriBase;
  uint iTri1 = iTriBase + 1;
  uint iTri2 = iTriBase + 2;

  // Snapshot of the first vertex; used as the template for every emitted gen point
  // (fields not overwritten below, e.g. iBrushMask, are inherited from it).
  GenPoint gpCopy = aGenPoint[iTri0];

  // Per-vertex attributes of the original triangle, unpacked for interpolation.
  float3 p0 = aGenPoint[iTri0].posNorm.xyz;
  float3 p1 = aGenPoint[iTri1].posNorm.xyz;
  float3 p2 = aGenPoint[iTri2].posNorm.xyz;
  float3 n0 = unpack_normal(aGenPoint[iTri0].posNorm.w);
  float3 n1 = unpack_normal(aGenPoint[iTri1].posNorm.w);
  float3 n2 = unpack_normal(aGenPoint[iTri2].posNorm.w);
  float4 c0 = unpack_rgba(aGenPoint[iTri0].material.color);
  float4 c1 = unpack_rgba(aGenPoint[iTri1].material.color);
  float4 c2 = unpack_rgba(aGenPoint[iTri2].material.color);
  float4 e0 = unpack_rgba(aGenPoint[iTri0].material.emissionTightness);
  float4 e1 = unpack_rgba(aGenPoint[iTri1].material.emissionTightness);
  float4 e2 = unpack_rgba(aGenPoint[iTri2].material.emissionTightness);
  float2 m0 = unpack_saturated(aGenPoint[iTri0].material.metallicSmoothness);
  float2 m1 = unpack_saturated(aGenPoint[iTri1].material.metallicSmoothness);
  float2 m2 = unpack_saturated(aGenPoint[iTri2].material.metallicSmoothness);

  // Attribute deltas along the three edges (01, 02, 12).
  float3 p01 = p1 - p0;
  float3 p02 = p2 - p0;
  float3 p12 = p2 - p1;
  float3 n01 = n1 - n0;
  float3 n02 = n2 - n0;
  float3 n12 = n2 - n1;
  float4 c01 = c1 - c0;
  float4 c02 = c2 - c0;
  float4 c12 = c2 - c1;
  float4 e01 = e1 - e0;
  float4 e02 = e2 - e0;
  float4 e12 = e2 - e1;
  float2 m01 = m1 - m0;
  float2 m02 = m2 - m0;
  float2 m12 = m2 - m1;

  // n = number of rows; dt = parametric size of one sub-triangle edge.
  // NOTE(review): dt is infinite when smoothCornerSubdivision is 0; the loop below does not
  // run in that case, so the value is never consumed -- confirm callers never pass 0 anyway.
  int n = smoothCornerSubdivision;
  float dt = 1.0f / n;

  // Per-step deltas along edges 02 and 12, used to walk across a row.
  float3 dtp02 = dt * p02;
  float3 dtp12 = dt * p12;
  float3 dtn02 = dt * n02;
  float3 dtn12 = dt * n12;
  float4 dtc02 = dt * c02;
  float4 dtc12 = dt * c12;
  float4 dte02 = dt * e02;
  float4 dte12 = dt * e12;
  float2 dtm02 = dt * m02;
  float2 dtm12 = dt * m12;

  // Row i starts at fraction i/n along edge 01 and contains 2*i+1 sub-triangles.
  [loop] for (int i = 0; i < n; ++i)
  {
    // default output location is the original triangle; replaced below for i > 0
    int iNewTriBase = iTriBase;
    float3 aVert[3];
    float3 aNorm[3];
    float4 aC[3];
    float4 aE[3];
    float2 aM[3];
    float idt = float(i) / n;

    // First sub-triangle of the row: start point on edge 01, plus one step along 01 and 02.
    aVert[0] = p0 + idt * p01;
    aVert[1] = aVert[0] + dt * p01;
    aVert[2] = aVert[0] + dt * p02;
    aNorm[0] = n0 + idt * n01;
    aNorm[1] = aNorm[0] + dt * n01;
    aNorm[2] = aNorm[0] + dt * n02;
    aC[0] = c0 + idt * c01;
    aC[1] = aC[0] + dt * c01;
    aC[2] = aC[0] + dt * c02;
    aE[0] = e0 + idt * e01;
    aE[1] = aE[0] + dt * e01;
    aE[2] = aE[0] + dt * e02;
    aM[0] = m0 + idt * m01;
    aM[1] = aM[0] + dt * m01;
    aM[2] = aM[0] + dt * m02;
    int jn = 2 * i + 1;
    [loop] for (int j = 0; j < jn; ++j)
    {
      bool odd = (uint(j) % 2 > 0);
      if (j > 0)
      {
        // Advance to the next sub-triangle in the row by moving exactly one of the three
        // vertices; which one, and by how much, depends on i, j and the parity of j.
        uint iVertChange = uint(2 + i * 3 - j) % 3;
        float s12 = (odd ? 2.0f : 1.0f);
        float s02 = (odd ? -1.0f : 1.0f);
        aVert[iVertChange] += s12 * dtp12 + s02 * dtp02;
        aNorm[iVertChange] += s12 * dtn12 + s02 * dtn02;
        aC[iVertChange] += s12 * dtc12 + s02 * dtc02;
        aE[iVertChange] += s12 * dte12 + s02 * dte02;
        aM[iVertChange] += s12 * dtm12 + s02 * dtm02;
      }
      if (i > 0)
      {
        // Every sub-triangle past row 0 reserves three fresh gen point slots.
        InterlockedAdd(indirectDrawArgs[0], 3, iNewTriBase);
      }

      // Output vertex order; odd sub-triangles swap the last two vertices
      // (presumably to keep a consistent winding -- TODO confirm).
      int aiProp[2][3] = { { 0, 1, 2}, {0, 2, 1 } };
      [loop] for (int m = 0; m < 3; ++m)
      {
        int iProp = aiProp[odd ? 1 : 0][m];
        int iNewGenPoint = iNewTriBase + m;
        aGenPoint[iNewGenPoint] = gpCopy;
        aGenPoint[iNewGenPoint].posNorm.xyz = aVert[iProp];
        // NOTE(review): plain normalize() here, unlike normalize_safe() used in other kernels;
        // an interpolated normal of zero length would not be guarded -- confirm this is acceptable.
        aGenPoint[iNewGenPoint].posNorm.w = pack_normal(normalize(aNorm[iProp]));
        aGenPoint[iNewGenPoint].material.color = pack_rgba(aC[iProp]);
        aGenPoint[iNewGenPoint].material.emissionTightness = pack_rgba(aE[iProp]);
        aGenPoint[iNewGenPoint].material.metallicSmoothness = pack_saturated(aM[iProp]);
      }

      // Blend each freshly written vertex normal toward the blurred SDF normal. The blend
      // factor grows with the angle between the two, reaching 1 at kPi * smoothCornerFade.
      [loop] for (int k = 0; k < 3; ++k)
      {
        float3 blurredNormal;
        SDF_NORMAL(blurredNormal, aGenPoint[iNewTriBase + k].posNorm.xyz, sdf_masked_brushes, iBrushMask, smoothCornerNormalBlur);
        float3 vertNormal = unpack_normal(aGenPoint[iNewTriBase + k].posNorm.w);
        if (autoSmoothMaxAngle > kEpsilon)
        {
          float t = saturate((abs(angle_between(vertNormal, blurredNormal))) / max(kEpsilon, kPi * smoothCornerFade));
          float3 smoothCornerNormal = normalize(lerp(vertNormal, blurredNormal, t));
          aGenPoint[iNewTriBase + k].posNorm.w = pack_normal(smoothCornerNormal);
        }
      }
    }
  }
}

// Prepares indirect arguments for the splat conversion pass.
// The splat count is taken as indirectDrawArgs[0] / 6. In circle splat mode the draw
// count is rewritten to three per splat. The dispatch group count is
// ceil(numSplats / kThreadGroupSize), but never fewer than one group.
[numthreads(1, 1, 1)]
void update_dual_meshing_splats_indirect_args(int3 id : SV_DispatchThreadID)
{
  int numSplats = uint(indirectDrawArgs[0]) / 6;

  if (renderMode == kRenderModeCircleSplats)
    indirectDrawArgs[0] = numSplats * 3;

  uint numGroups = uint(numSplats + kThreadGroupSize - 1) / kThreadGroupSize;
  indirectDispatchArgs[0] = max(1, numGroups);
}

// Splat conversion pass: one thread per splat. Reads the six gen points starting at
// iSplat * 6 as two triangles and collapses them into a single splat stored in the first
// of those six slots:
//   - position: centroid of the two triangles, weighted by their cross-product lengths
//   - normal:   unit-length, cross-product-length-weighted average of the two face normals
//   - size:     scaled down for quads whose total cross-product length is small relative
//               to voxelSize^2
[numthreads(kThreadGroupSize, 1, 1)]
void convert_dual_meshing_splats(int3 id : SV_DispatchThreadID)
{
#if defined(MUDBUN_DISABLE_DUAL_MESHING_SPLATS)
  return;
#endif

  // indirectDrawArgs[0] holds 3 (circle) or 6 (quad) entries per splat at this point
  int maxSplats = 0;
  switch (renderMode)
  {
    case kRenderModeCircleSplats:
      maxSplats = uint(indirectDrawArgs[0]) / 3;
      break;
    case kRenderModeQuadSplats:
      maxSplats = uint(indirectDrawArgs[0]) / 6;
      break;
  }

  int iSplat = id.x;
  if (iSplat >= maxSplats)
    return;

  int iDualQuadBase = iSplat * 6;

  float3 v0 = aGenPoint[iDualQuadBase    ].posNorm.xyz;
  float3 v1 = aGenPoint[iDualQuadBase + 1].posNorm.xyz;
  float3 v2 = aGenPoint[iDualQuadBase + 2].posNorm.xyz;
  float3 v3 = aGenPoint[iDualQuadBase + 3].posNorm.xyz;
  float3 v4 = aGenPoint[iDualQuadBase + 4].posNorm.xyz;
  float3 v5 = aGenPoint[iDualQuadBase + 5].posNorm.xyz;

  float3 v01 = v1 - v0;
  float3 v02 = v2 - v0;
  float3 v34 = v4 - v3;
  float3 v35 = v5 - v3;
  float3 c012 = cross(v01, v02);
  float3 c345 = cross(v34, v35);
  float3 n012 = normalize_safe(c012);
  float3 n345 = normalize_safe(c345);

  // weights proportional to triangle area; clamped so that aTotal is never zero
  float a012 = max(kEpsilon, length(c012));
  float a345 = max(kEpsilon, length(c345));
  float aTotal = a012 + a345;

  float3 pos = (a012 * (v0 + v1 + v2) + a345 * (v3 + v4 + v5)) / (3.0f * aTotal);

  // The weighted sum is normalized directly. Dividing the normalized result by aTotal
  // (as was done previously) would hand a non-unit vector to pack_normal.
  float3 norm = normalize_safe(a012 * n012 + a345 * n345);

  aGenPoint[iDualQuadBase].posNorm = float4(pos, pack_normal(norm));

  // shrink splats generated from small or degenerate quads
  float scaleMult = pow(saturate(aTotal / (0.2f * voxelSize * voxelSize)), 0.1f);
  aGenPoint[iDualQuadBase].material.size *= scaleMult;
}

