﻿/******************************************************************************/
/*
  Project   - MudBun
  Publisher - Long Bunny Labs
              http://LongBunnyLabs.com
  Author    - Ming-Lun "Allen" Chou
              http://AllenChou.net
*/
/******************************************************************************/

#pragma kernel generate_flat_marching_cubes_mesh
#pragma kernel generate_smooth_marching_cubes_mesh
#pragma kernel generate_marching_splats
#pragma kernel update_marching_cubes_auto_smooth_indirect_dispatch_args
#pragma kernel marching_cubes_update_auto_smooth
#pragma kernel marching_cubes_compute_auto_smooth


#include "../../Shader/ComputeCommon.cginc"

#include "../../Shader/AutoSmoothFuncs.cginc"
#include "../../Shader/BrushFuncs.cginc"
#include "../../Shader/GenPointDefs.cginc"
#include "../../Shader/IndirectArgsDefs.cginc"
#include "../../Shader/MarchingCubesFuncs.cginc"
#include "../../Shader/Math/MathConst.cginc"
#include "../../Shader/MeshingModeDefs.cginc"
#include "../../Shader/NormalFuncs.cginc"
#include "../../Shader/RenderModeDefs.cginc"
#include "../../Shader/VoxelFuncs.cginc"
#include "../../Shader/VoxelCacheFuncs.cginc"

// Kernel: runs MARCHING_CUBES on one node per thread and appends the resulting
// triangles to aGenPoint. Every vertex of a cube is written with the same
// cube-wide material (cubeMat) plus an auto-smooth vertex data ID computed from
// aEdgeCenter.
//   id.x - thread index; used as the node index within the current depth.
[numthreads(kThreadGroupSize, 1, 1)]
void generate_flat_marching_cubes_mesh(int3 id : SV_DispatchThreadID)
{
  // Compile-time switch that turns this kernel into a no-op.
#if defined(MUDBUN_DISABLE_MARCHING_CUBES_FLAT_MESH)
  return;
#endif

  // Threads past the count stored at aNumNodesAllocated[currentNodeDepth + 1] have no node to process.
  uint iNode = uint(id.x);
  if (iNode >= uint(aNumNodesAllocated[currentNodeDepth + 1]))
    return;

  // Offset the index by the sum of aNumNodesAllocated[1..currentNodeDepth] to get the
  // slot in nodePool (presumably the nodes of earlier depths are stored first - confirm
  // against the node allocation code).
  for (int i = 1; i <= currentNodeDepth; ++i)
    iNode += aNumNodesAllocated[i];
  if (iNode >= nodePoolSize)
    return;

  // The three statement arguments below are labeled tStmtPre / vStmt / tStmtPost;
  // presumably they are expanded before each triangle, for each triangle vertex, and
  // after each triangle - confirm in MarchingCubesFuncs.cginc.
  // tStmtPre atomically adds 3 to indirectDrawArgs[0]; the value before the add is
  // returned in iVertBase and serves as the first aGenPoint slot of the triangle.
  // cubeMat is declared uninitialized here and read in vStmt, so it is expected to be
  // filled in by the macro.
  int iBrushMask = get_brush_mask_index(iNode);
  SdfBrushMaterial cubeMat;
  MARCHING_CUBES(nodePool[iNode].center, currentNodeSize, sdf_masked_brushes, iBrushMask, false, cubeMat, 
    // tStmtPre
    int iVertBase = 0;
    InterlockedAdd(indirectDrawArgs[0], 3, iVertBase);
    , 
    // vStmt
    aGenPoint[iVertBase + iVert].posNorm = float4(aVertPos[iVert], pack_normal(aVertNorm[iVert]));
    aGenPoint[iVertBase + iVert].material = pack_material(cubeMat);
    aGenPoint[iVertBase + iVert].vertId = auto_smooth_vert_data_id(aEdgeCenter[iVert]);
    aGenPoint[iVertBase + iVert].iBrushMask = iBrushMask;
    , 
    // tStmtPost
    { }
  );
}

// Kernel: runs MARCHING_CUBES on one node per thread and appends the resulting
// triangles to aGenPoint. Each vertex is written with its own material
// (aVertMat[iVert]); no auto-smooth vertex ID is written by this kernel.
//   id.x - thread index; used as the node index within the current depth.
[numthreads(kThreadGroupSize, 1, 1)]
void generate_smooth_marching_cubes_mesh(int3 id : SV_DispatchThreadID)
{
  // Compile-time switch that turns this kernel into a no-op.
#if defined(MUDBUN_DISABLE_MARCHING_CUBES_SMOOTH_MESH)
  return;
#endif

  // Threads past the count stored at aNumNodesAllocated[currentNodeDepth + 1] have no node to process.
  uint iNode = uint(id.x);
  if (iNode >= uint(aNumNodesAllocated[currentNodeDepth + 1]))
    return;

  // Offset the index by the sum of aNumNodesAllocated[1..currentNodeDepth] to get the
  // slot in nodePool (presumably the nodes of earlier depths are stored first - confirm
  // against the node allocation code).
  for (int i = 1; i <= currentNodeDepth; ++i)
    iNode += aNumNodesAllocated[i];
  if (iNode >= nodePoolSize)
    return;

  // The three statement arguments below are labeled tStmtPre / vStmt / tStmtPost;
  // presumably they are expanded before each triangle, for each triangle vertex, and
  // after each triangle - confirm in MarchingCubesFuncs.cginc.
  // tStmtPre atomically adds 3 to indirectDrawArgs[0]; the value before the add is
  // returned in iVertBase and serves as the first aGenPoint slot of the triangle.
  // NOTE(review): the `true` argument is the only difference in the macro call from the
  // flat kernel; it presumably enables per-vertex (smooth) normals/materials - confirm.
  int iBrushMask = get_brush_mask_index(iNode);
  SdfBrushMaterial cubeMat;
  MARCHING_CUBES(nodePool[iNode].center, currentNodeSize, sdf_masked_brushes, iBrushMask, true, cubeMat, 
    // tStmtPre
    int iVertBase = 0;
    InterlockedAdd(indirectDrawArgs[0], 3, iVertBase);
    , 
    // vStmt
    aGenPoint[iVertBase + iVert].posNorm = float4(aVertPos[iVert], pack_normal(aVertNorm[iVert]));
    aGenPoint[iVertBase + iVert].material = pack_material(aVertMat[iVert]);
    aGenPoint[iVertBase + iVert].iBrushMask = iBrushMask;
    , 
    // tStmtPost
    { }
  );
}

// Kernel: collapses the marching cubes triangles of one node (one node per thread)
// into a single splat and appends it to aGenPoint.
//
// The splat position is the average of the triangle centroids and the splat normal is
// the normalized sum of the triangle normals, both weighted by the length of each
// triangle's edge cross product (twice its area).
//
// Nothing is emitted when:
//   - the node produced no triangles,
//   - the accumulated weight or the accumulated normal is zero (the averages would be
//     NaN/Inf), or
//   - renderMode is neither kRenderModeCircleSplats nor kRenderModeQuadSplats (no
//     vertices are reserved, so there is no valid aGenPoint slot to write).
//
//   id.x - thread index; used as the node index within the current depth.
[numthreads(kThreadGroupSize, 1, 1)]
void generate_marching_splats(int3 id : SV_DispatchThreadID)
{
  // Compile-time switch that turns this kernel into a no-op.
#if defined(MUDBUN_DISABLE_MARCHING_CUBES_SPLATS)
  return;
#endif

  // Threads past the count stored at aNumNodesAllocated[currentNodeDepth + 1] have no node to process.
  uint iNode = uint(id.x);
  if (iNode >= uint(aNumNodesAllocated[currentNodeDepth + 1]))
    return;

  // Offset the index by the sum of aNumNodesAllocated[1..currentNodeDepth] to get the slot in nodePool.
  for (int i = 1; i <= currentNodeDepth; ++i)
    iNode += aNumNodesAllocated[i];
  if (iNode >= nodePoolSize)
    return;

  // Accumulate weighted position / normal sums over the triangles of this node.
  int numTris = 0;
  float3 avgPos = 0.0f;
  float3 avgNorm = 0.0f;
  float avgWeight = 0.0f;
  int iBrushMask = get_brush_mask_index(iNode);
  SdfBrushMaterial cubeMat;
  MARCHING_CUBES(nodePool[iNode].center, currentNodeSize, sdf_masked_brushes, iBrushMask, false, cubeMat,
    // tStmtPre
    { }
    ,
    // vStmt
    { }
    ,
    // tStmtPost
    float3 c = cross(aVertPos[1] - aVertPos[0], aVertPos[2] - aVertPos[0]);
    float w = length(c);
    avgPos += w * (aVertPos[0] + aVertPos[1] + aVertPos[2]);
    if (w > kEpsilon)
      avgNorm += w * normalize(c);
    avgWeight += w;
    ++numTris;
  );

  if (numTris <= 0)
    return;

  // Guard the divisions below: with only degenerate triangles avgWeight is zero, and
  // opposing or skipped triangle normals can leave avgNorm at zero length. Either case
  // would pack NaN/Inf into the gen point, so skip the splat instead.
  if (!(avgWeight > 0.0f) || !(dot(avgNorm, avgNorm) > 0.0f))
    return;

  // Each triangle contributed the sum of its 3 vertices, hence the extra factor of 3.
  avgPos /= avgWeight * 3.0f;
  avgNorm = normalize(avgNorm);

  // Reserve the vertices of one splat (3 for a circle splat, 6 for a quad splat) and
  // turn the returned vertex offset into a splat index. iGenPoint stays negative for
  // any other render mode, in which case nothing has been reserved.
  int iVertBase = 0;
  int iGenPoint = -1;
  switch (renderMode)
  {
    case kRenderModeCircleSplats:
      InterlockedAdd(indirectDrawArgs[0], 3, iVertBase);
      iGenPoint = uint(iVertBase) / 3;
      break;
    case kRenderModeQuadSplats:
      InterlockedAdd(indirectDrawArgs[0], 6, iVertBase);
      iGenPoint = uint(iVertBase) / 6;
      break;
    default:
      break;
  }
  if (iGenPoint < 0)
    return;

  // Multiplier in [0, 1] that scales down the z component of
  // metallicSmoothnessSizeTightness when the accumulated weight is small relative to
  // 0.2 * voxelSize^2; the 0.1 exponent makes the falloff gentle.
  float scaleMult = pow(saturate(avgWeight / (0.2f * voxelSize * voxelSize)), 0.1f);
  cubeMat.metallicSmoothnessSizeTightness.z *= scaleMult;

  aGenPoint[iGenPoint].posNorm = float4(avgPos, pack_normal(avgNorm));
  aGenPoint[iGenPoint].material = pack_material(cubeMat);
  aGenPoint[iGenPoint].iBrushMask = iBrushMask;
}

// Kernel (single thread): writes indirectDispatchArgs[0] as the number of
// kThreadGroupSize-wide thread groups needed to cover indirectDrawArgs[0] items,
// rounded up and never less than 1.
[numthreads(1, 1, 1)]
void update_marching_cubes_auto_smooth_indirect_dispatch_args(int3 id : SV_DispatchThreadID)
{
  // Ceiling division: bias the item count by (group size - 1) before dividing.
  uint numGroups = uint(indirectDrawArgs[0] + kThreadGroupSize - 1) / kThreadGroupSize;

  // Always request at least one group, even when there are no items.
  indirectDispatchArgs[0] = max(1, numGroups);
}

// Kernel: one thread per generated point. Measures the triangle the point belongs to
// and passes the point's vertex ID, packed normal and that measure to
// update_auto_smooth_vert_data.
//   id.x - index of the generated point; threads at or beyond indirectDrawArgs[0] exit.
[numthreads(kThreadGroupSize, 1, 1)]
void marching_cubes_update_auto_smooth(int3 id : SV_DispatchThreadID)
{
  if (id.x >= indirectDrawArgs[0])
    return;

  uint iPoint = uint(id.x);

  // Points are stored three per triangle; round down to the triangle's first point.
  uint iTriFirst = (iPoint / 3) * 3;

  float3 p0 = aGenPoint[iTriFirst    ].posNorm.xyz;
  float3 p1 = aGenPoint[iTriFirst + 1].posNorm.xyz;
  float3 p2 = aGenPoint[iTriFirst + 2].posNorm.xyz;

  // Length of the cross product of two consecutive edges (twice the triangle area).
  float3 edgeCross = cross(p1 - p0, p2 - p1);
  float triWeight = length(edgeCross);

  update_auto_smooth_vert_data(aGenPoint[iPoint].vertId, aGenPoint[iPoint].posNorm.w, triWeight);
}

// Kernel: one thread per generated point. Re-evaluates the material at the point's
// position through sdf_masked_brushes, replaces the packed normal with the result of
// compute_auto_smooth_normal, and stores both back into aGenPoint.
//   id.x - index of the generated point; threads at or beyond indirectDrawArgs[0] exit.
[numthreads(kThreadGroupSize, 1, 1)]
void marching_cubes_compute_auto_smooth(int3 id : SV_DispatchThreadID)
{
  if (id.x >= indirectDrawArgs[0])
    return;

  uint iPoint = uint(id.x);

  // Evaluate the masked brushes at the point; only the material output is kept.
  float3 pointPos = aGenPoint[iPoint].posNorm.xyz;
  int pointBrushMask = aGenPoint[iPoint].iBrushMask;
  SdfBrushMaterial pointMat;
  sdf_masked_brushes(pointPos, pointBrushMask, pointMat);

  // The currently stored normal is handed to compute_auto_smooth_normal together with the vertex ID.
  float3 storedNormal = unpack_normal(aGenPoint[iPoint].posNorm.w);
  float3 smoothedNormal = compute_auto_smooth_normal(aGenPoint[iPoint].vertId, storedNormal);

  aGenPoint[iPoint].posNorm.w = pack_normal(smoothedNormal);
  aGenPoint[iPoint].material = pack_material(pointMat);
  // Debug aid (disabled): visualize a normal as the point color.
  //aGenPoint[iPoint].material.color = pack_rgba(float4(n, 1.0f));
}

