gaussianSplattingMesh.js 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424
  1. import { SubMesh } from "../subMesh.js";
  2. import { Mesh } from "../mesh.js";
  3. import { VertexData } from "../mesh.vertexData.js";
  4. import { Tools } from "../../Misc/tools.js";
  5. import { Matrix, TmpVectors, Vector2, Vector3, Quaternion } from "../../Maths/math.vector.js";
  6. import { Logger } from "../../Misc/logger.js";
  7. import { GaussianSplattingMaterial } from "../../Materials/GaussianSplatting/gaussianSplattingMaterial.js";
  8. import { RawTexture } from "../../Materials/Textures/rawTexture.js";
  /**
   * Class used to render a gaussian splatting mesh.
   *
   * The splat data is uploaded to four textures (covariances A/B, centers, colors) holding four floats per
   * texel, and every splat is drawn as one thin instance of a single quad. A Web Worker sorts the splats by
   * depth whenever the depth axis of the model-view matrix changes noticeably; the resulting draw order is
   * written to the "splatIndex" thin-instance buffer.
   */
  export class GaussianSplattingMesh extends Mesh {
      /**
       * Gets the covariancesA texture (null until splat data has been loaded)
       */
      get covariancesATexture() {
          return this._covariancesATexture;
      }
      /**
       * Gets the covariancesB texture (null until splat data has been loaded)
       */
      get covariancesBTexture() {
          return this._covariancesBTexture;
      }
      /**
       * Gets the centers texture (null until splat data has been loaded)
       */
      get centersTexture() {
          return this._centersTexture;
      }
      /**
       * Gets the colors texture (null until splat data has been loaded)
       */
      get colorsTexture() {
          return this._colorsTexture;
      }
      /**
       * Creates a new gaussian splatting mesh
       * @param name defines the name of the mesh
       * @param url defines the url to load from (optional). Loading starts immediately and failures are logged.
       * @param scene defines the hosting scene (optional)
       */
      constructor(name, url = null, scene = null) {
          super(name, scene);
          this._vertexCount = 0;
          this._worker = null;
          this._frameIdLastUpdate = -1;
          this._modelViewMatrix = Matrix.Identity();
          this._material = null;
          this._canPostToWorker = true;
          this._covariancesATexture = null;
          this._covariancesBTexture = null;
          this._centersTexture = null;
          this._colorsTexture = null;
          // Sort buffer handed back and forth (by transfer) with the worker; allocated by _loadData.
          this._depthMix = null;
          // A single quad made of two triangles; each splat is rendered as one thin instance of it.
          const vertexData = new VertexData();
          vertexData.positions = [-2, -2, 0, 2, -2, 0, 2, 2, 0, -2, 2, 0];
          vertexData.indices = [0, 1, 2, 0, 2, 3];
          vertexData.applyToMesh(this);
          this.subMeshes = [];
          new SubMesh(0, 0, 4, 0, 6, this);
          this.doNotSyncBoundingInfo = true;
          // Hidden until splat data is available (_loadData enables the mesh).
          this.setEnabled(false);
          // Its dot product with any view is 0, so the first render after loading always requests a sort.
          this._lastProj = new Array(16).fill(0);
          if (url) {
              // The promise is not exposed to the caller, so report failures instead of leaving an unhandled rejection.
              this.loadFileAsync(url).catch((e) => {
                  Logger.Error(`GaussianSplattingMesh: unable to load "${url}": ${e?.message ?? e}`);
              });
          }
      }
      /**
       * Returns the class name
       * @returns "GaussianSplattingMesh"
       */
      getClassName() {
          return "GaussianSplattingMesh";
      }
      /**
       * Returns the total number of vertices (splats) within the mesh
       * @returns the total number of vertices
       */
      getTotalVertices() {
          return this._vertexCount;
      }
      /**
       * Triggers the draw call for the mesh. Usually, you don't need to call this method by your own because the mesh rendering is handled by the scene rendering manager.
       * It also lazily creates a default GaussianSplattingMaterial and, at most once per frame, asks the
       * sorting worker for a new draw order when the view direction changed enough since the last sort.
       * @param subMesh defines the subMesh to render
       * @param enableAlphaMode defines if alpha mode can be changed
       * @param effectiveMeshReplacement defines an optional mesh used to provide info for the rendering
       * @returns the current mesh
       */
      render(subMesh, enableAlphaMode, effectiveMeshReplacement) {
          if (!this.material) {
              this._material = new GaussianSplattingMaterial(this.name + "_material", this._scene);
              this.material = this._material;
          }
          const frameId = this.getScene().getFrameId();
          // Only one sort request may be in flight: _depthMix is transferred to the worker until it replies.
          if (frameId !== this._frameIdLastUpdate && this._worker && this._scene.activeCamera && this._canPostToWorker) {
              this.getWorldMatrix().multiplyToRef(this._scene.activeCamera.getViewMatrix(), this._modelViewMatrix);
              // Compare the depth axis (elements 2, 6, 10) of the current and the last sorted model-view matrices.
              const dot = this._lastProj[2] * this._modelViewMatrix.m[2] + this._lastProj[6] * this._modelViewMatrix.m[6] + this._lastProj[10] * this._modelViewMatrix.m[10];
              if (Math.abs(dot - 1) >= 0.01) {
                  this._frameIdLastUpdate = frameId;
                  this._canPostToWorker = false;
                  this._lastProj = this._modelViewMatrix.m.slice(0);
                  this._worker.postMessage({ view: this._modelViewMatrix.m, depthMix: this._depthMix }, [this._depthMix.buffer]);
              }
          }
          return super.render(subMesh, enableAlphaMode, effectiveMeshReplacement);
      }
      /**
       * Code from https://github.com/dylanebert/gsplat.js/blob/main/src/loaders/PLYLoader.ts Under MIT license
       * Converts a .ply data array buffer to splat
       * if data array buffer is not ply, returns the original buffer
       *
       * Only the properties of the "vertex" element are read, and the vertex records are expected to start
       * right after the header. All scalar PLY types (double, float, int, uint, short, ushort, uchar) are
       * supported, in little or big endian binary format. ASCII files, unsupported property types and headers
       * without a valid vertex element are reported with Logger.Error and produce an empty buffer.
       * @param data the .ply data to load
       * @returns the loaded splat buffer
       */
      static ConvertPLYToSplat(data) {
          const ubuf = new Uint8Array(data);
          // The header is expected to fit in the first 10 KB.
          const header = new TextDecoder().decode(ubuf.slice(0, 1024 * 10));
          const headerEnd = "end_header\n";
          const headerEndIndex = header.indexOf(headerEnd);
          if (headerEndIndex < 0) {
              // Not a PLY file: assume the buffer is already in .splat layout.
              return data;
          }
          // Byte size of each supported scalar type (a Map cannot accidentally match Object.prototype keys).
          const typeSizes = new Map([
              ["double", 8],
              ["int", 4],
              ["uint", 4],
              ["float", 4],
              ["short", 2],
              ["ushort", 2],
              ["uchar", 1],
          ]);
          const properties = [];
          // Byte stride of one vertex record in the input.
          let rowOffset = 0;
          let vertexCount = -1;
          let littleEndian = true;
          let currentElement = "";
          for (const line of header.slice(0, headerEndIndex).split("\n")) {
              if (line.startsWith("format ")) {
                  const format = line.split(" ")[1];
                  if (format === "ascii") {
                      Logger.Error("ASCII PLY files are not supported. Are you sure it's a valid Gaussian Splatting file?");
                      return new ArrayBuffer(0);
                  }
                  littleEndian = format !== "binary_big_endian";
              }
              else if (line.startsWith("element ")) {
                  const [, elementName, count] = line.split(" ");
                  currentElement = elementName;
                  if (elementName === "vertex") {
                      vertexCount = parseInt(count, 10);
                  }
              }
              else if (line.startsWith("property ") && currentElement === "vertex") {
                  const [, type, name] = line.split(" ");
                  const size = typeSizes.get(type);
                  if (size === undefined) {
                      Logger.Error(`Unsupported property type: ${type}. Are you sure it's a valid Gaussian Splatting file?`);
                      return new ArrayBuffer(0);
                  }
                  properties.push({ name, type, offset: rowOffset });
                  rowOffset += size;
              }
          }
          if (!Number.isInteger(vertexCount) || vertexCount < 0) {
              Logger.Error("No valid vertex element found in PLY header. Are you sure it's a valid Gaussian Splatting file?");
              return new ArrayBuffer(0);
          }
          // Output layout, 32 bytes per splat: position (3 x float32), scale (3 x float32), RGBA (4 x uint8), rotation (4 x uint8).
          const rowLength = 3 * 4 + 3 * 4 + 4 + 4;
          // Zeroth-order spherical harmonics constant, used to turn f_dc_* coefficients into colors.
          const SH_C0 = 0.28209479177387814;
          const dataView = new DataView(data, headerEndIndex + headerEnd.length);
          const readProperty = (type, byteOffset) => {
              switch (type) {
                  case "double":
                      return dataView.getFloat64(byteOffset, littleEndian);
                  case "float":
                      return dataView.getFloat32(byteOffset, littleEndian);
                  case "int":
                      return dataView.getInt32(byteOffset, littleEndian);
                  case "uint":
                      return dataView.getUint32(byteOffset, littleEndian);
                  case "short":
                      return dataView.getInt16(byteOffset, littleEndian);
                  case "ushort":
                      return dataView.getUint16(byteOffset, littleEndian);
                  case "uchar":
                      return dataView.getUint8(byteOffset);
                  default:
                      throw new Error(`Unsupported property type: ${type}`);
              }
          };
          const buffer = new ArrayBuffer(rowLength * vertexCount);
          const q = new Quaternion();
          for (let i = 0; i < vertexCount; i++) {
              const position = new Float32Array(buffer, i * rowLength, 3);
              const scale = new Float32Array(buffer, i * rowLength + 12, 3);
              const rgba = new Uint8ClampedArray(buffer, i * rowLength + 24, 4);
              const rot = new Uint8ClampedArray(buffer, i * rowLength + 28, 4);
              // Defaults to the identity rotation when rot_* properties are missing.
              let r0 = 255;
              let r1 = 0;
              let r2 = 0;
              let r3 = 0;
              for (const property of properties) {
                  const value = readProperty(property.type, property.offset + i * rowOffset);
                  switch (property.name) {
                      case "x":
                          position[0] = value;
                          break;
                      case "y":
                          position[1] = value;
                          break;
                      case "z":
                          position[2] = value;
                          break;
                      case "scale_0":
                          scale[0] = Math.exp(value);
                          break;
                      case "scale_1":
                          scale[1] = Math.exp(value);
                          break;
                      case "scale_2":
                          scale[2] = Math.exp(value);
                          break;
                      case "red":
                          rgba[0] = value;
                          break;
                      case "green":
                          rgba[1] = value;
                          break;
                      case "blue":
                          rgba[2] = value;
                          break;
                      case "f_dc_0":
                          rgba[0] = (0.5 + SH_C0 * value) * 255;
                          break;
                      case "f_dc_1":
                          rgba[1] = (0.5 + SH_C0 * value) * 255;
                          break;
                      case "f_dc_2":
                          rgba[2] = (0.5 + SH_C0 * value) * 255;
                          break;
                      case "f_dc_3":
                          rgba[3] = (0.5 + SH_C0 * value) * 255;
                          break;
                      case "opacity":
                          // Sigmoid of the stored logit.
                          rgba[3] = (1 / (1 + Math.exp(-value))) * 255;
                          break;
                      case "rot_0":
                          r0 = value;
                          break;
                      case "rot_1":
                          r1 = value;
                          break;
                      case "rot_2":
                          r2 = value;
                          break;
                      case "rot_3":
                          r3 = value;
                          break;
                  }
              }
              // rot_0 is used as w and rot_1..rot_3 as x, y, z; components are packed as bytes around 128.
              q.set(r1, r2, r3, r0);
              q.normalize();
              rot[0] = q.w * 128 + 128;
              rot[1] = q.x * 128 + 128;
              rot[2] = q.y * 128 + 128;
              rot[3] = q.z * 128 + 128;
          }
          return buffer;
      }
      /**
       * Loads a .splat Gaussian Splatting array buffer asynchronously
       * @param data arraybuffer containing splat file
       * @returns a promise that resolves when the operation is complete, or rejects if the data cannot be loaded
       */
      loadDataAsync(data) {
          try {
              this._loadData(data);
          }
          catch (e) {
              // Report failures through the returned promise instead of throwing synchronously.
              return Promise.reject(e);
          }
          return Promise.resolve();
      }
      /**
       * Loads a .splat Gaussian or .ply Splatting file asynchronously
       * @param url path to the splat file to load
       * @returns a promise that resolves when the operation is complete
       */
      loadFileAsync(url) {
          return Tools.LoadFileAsync(url, true).then((data) => {
              this._loadData(GaussianSplattingMesh.ConvertPLYToSplat(data));
          });
      }
      /**
       * Releases resources associated with this mesh.
       * @param doNotRecurse Set to true to not recurse into each children (recurse into each children by default)
       */
      dispose(doNotRecurse) {
          this._disposeTextures();
          this._material?.dispose(false, true);
          this._material = null;
          this._worker?.terminate();
          this._worker = null;
          super.dispose(doNotRecurse);
      }
      /**
       * Disposes the splat data textures, if any, and clears the references to them.
       */
      _disposeTextures() {
          this._covariancesATexture?.dispose();
          this._covariancesBTexture?.dispose();
          this._centersTexture?.dispose();
          this._colorsTexture?.dispose();
          this._covariancesATexture = null;
          this._covariancesBTexture = null;
          this._centersTexture = null;
          this._colorsTexture = null;
      }
      /**
       * Parses a .splat buffer, uploads the splat data to textures and (re)starts the depth-sorting worker.
       * Each splat is a 32-byte row: position (3 x float32), scale (3 x float32), color (4 x uint8) and
       * rotation (4 x uint8). Trailing bytes that do not form a complete row are ignored with a warning.
       * Calling it again replaces the previous data and releases the previous textures and worker.
       * @param data buffer containing the splat rows
       */
      _loadData(data) {
          if (!data.byteLength) {
              return;
          }
          // Parse the data
          const rowLength = 3 * 4 + 3 * 4 + 4 + 4;
          const uBuffer = new Uint8Array(data);
          // Only complete rows are used; a fractional count would corrupt every size computed below.
          const vertexCount = Math.floor(uBuffer.length / rowLength);
          if (vertexCount * rowLength !== uBuffer.length) {
              Logger.Warn(`GaussianSplattingMesh: splat data length (${uBuffer.length} bytes) is not a multiple of ${rowLength} bytes, trailing bytes are ignored.`);
          }
          if (vertexCount === 0) {
              return;
          }
          // Explicit length so a byte length that is not a multiple of 4 cannot make the view constructor throw.
          const fBuffer = new Float32Array(uBuffer.buffer, uBuffer.byteOffset, (vertexCount * rowLength) / 4);
          this._vertexCount = vertexCount;
          const textureSize = this._getTextureSize(vertexCount);
          const textureLength = textureSize.x * textureSize.y;
          const positions = new Float32Array(3 * textureLength);
          const covA = new Float32Array(3 * textureLength);
          const covB = new Float32Array(3 * textureLength);
          const matrixRotation = TmpVectors.Matrix[0];
          const matrixScale = TmpVectors.Matrix[1];
          const quaternion = TmpVectors.Quaternion[0];
          const minimum = new Vector3(Number.MAX_VALUE, Number.MAX_VALUE, Number.MAX_VALUE);
          const maximum = new Vector3(-Number.MAX_VALUE, -Number.MAX_VALUE, -Number.MAX_VALUE);
          for (let i = 0; i < vertexCount; i++) {
              // The Y axis is flipped when reading positions (and the rotation's w is negated to match).
              const x = fBuffer[8 * i + 0];
              const y = -fBuffer[8 * i + 1];
              const z = fBuffer[8 * i + 2];
              positions[3 * i + 0] = x;
              positions[3 * i + 1] = y;
              positions[3 * i + 2] = z;
              minimum.minimizeInPlaceFromFloats(x, y, z);
              maximum.maximizeInPlaceFromFloats(x, y, z);
              const rotationOffset = rowLength * i + 28;
              quaternion.set((uBuffer[rotationOffset + 1] - 128) / 128, (uBuffer[rotationOffset + 2] - 128) / 128, (uBuffer[rotationOffset + 3] - 128) / 128, -(uBuffer[rotationOffset + 0] - 128) / 128);
              quaternion.toRotationMatrix(matrixRotation);
              Matrix.ScalingToRef(fBuffer[8 * i + 3 + 0] * 2, fBuffer[8 * i + 3 + 1] * 2, fBuffer[8 * i + 3 + 2] * 2, matrixScale);
              // M = rotation * scale; the six unique entries of the symmetric product are split across covA / covB.
              const M = matrixRotation.multiplyToRef(matrixScale, TmpVectors.Matrix[0]).m;
              covA[i * 3 + 0] = M[0] * M[0] + M[1] * M[1] + M[2] * M[2];
              covA[i * 3 + 1] = M[0] * M[4] + M[1] * M[5] + M[2] * M[6];
              covA[i * 3 + 2] = M[0] * M[8] + M[1] * M[9] + M[2] * M[10];
              covB[i * 3 + 0] = M[4] * M[4] + M[5] * M[5] + M[6] * M[6];
              covB[i * 3 + 1] = M[4] * M[8] + M[5] * M[9] + M[6] * M[10];
              covB[i * 3 + 2] = M[8] * M[8] + M[9] * M[9] + M[10] * M[10];
          }
          // Update the mesh
          const binfo = this.getBoundingInfo();
          binfo.reConstruct(minimum, maximum, this.getWorldMatrix());
          binfo.isLocked = true;
          this.forcedInstanceCount = this._vertexCount;
          this.setEnabled(true);
          const splatIndex = new Float32Array(this._vertexCount * 1);
          this.thinInstanceSetBuffer("splatIndex", splatIndex, 1, false);
          // Update the material
          // NOTE(review): 5 / 1 / 2 are presumably Constants.TEXTUREFORMAT_RGBA, TEXTURETYPE_FLOAT and bilinear sampling — confirm.
          const createTextureFromData = (data, width, height, format) => {
              return new RawTexture(data, width, height, format, this._scene, false, false, 2, 1);
          };
          const convertRgbToRgba = (rgb) => {
              const count = rgb.length / 3;
              const rgba = new Float32Array(count * 4);
              for (let i = 0; i < count; ++i) {
                  rgba[i * 4 + 0] = rgb[i * 3 + 0];
                  rgba[i * 4 + 1] = rgb[i * 3 + 1];
                  rgba[i * 4 + 2] = rgb[i * 3 + 2];
                  rgba[i * 4 + 3] = 1.0;
              }
              return rgba;
          };
          const colorArray = new Float32Array(textureSize.x * textureSize.y * 4);
          for (let i = 0; i < this._vertexCount; ++i) {
              const colorOffset = rowLength * i + 24;
              colorArray[i * 4 + 0] = uBuffer[colorOffset + 0] / 255;
              colorArray[i * 4 + 1] = uBuffer[colorOffset + 1] / 255;
              colorArray[i * 4 + 2] = uBuffer[colorOffset + 2] / 255;
              colorArray[i * 4 + 3] = uBuffer[colorOffset + 3] / 255;
          }
          // Release the textures of a previous load before replacing them.
          this._disposeTextures();
          this._covariancesATexture = createTextureFromData(convertRgbToRgba(covA), textureSize.x, textureSize.y, 5);
          this._covariancesBTexture = createTextureFromData(convertRgbToRgba(covB), textureSize.x, textureSize.y, 5);
          this._centersTexture = createTextureFromData(convertRgbToRgba(positions), textureSize.x, textureSize.y, 5);
          this._colorsTexture = createTextureFromData(colorArray, textureSize.x, textureSize.y, 5);
          // Start the worker thread
          this._worker?.terminate();
          const worker = new Worker(URL.createObjectURL(new Blob(["(", GaussianSplattingMesh._CreateWorker.toString(), ")(self)"], {
              type: "application/javascript",
          })));
          this._worker = worker;
          // Reset the sort state: a sort pending on a previous worker will never be answered, and clearing the
          // last sorted view guarantees the next render requests a sort of the new data.
          this._canPostToWorker = true;
          this._lastProj = new Array(16).fill(0);
          this._depthMix = new BigInt64Array(vertexCount);
          // positions is transferred (detached here); the centers texture was built from a copy above.
          worker.postMessage({ positions, vertexCount }, [positions.buffer]);
          worker.onmessage = (e) => {
              // Ignore late replies from a worker that has since been replaced or disposed.
              if (this._worker !== worker) {
                  return;
              }
              this._depthMix = e.data.depthMix;
              // The splat index is stored in the first 32-bit word of each 64-bit slot.
              const indexMix = new Uint32Array(e.data.depthMix.buffer);
              for (let j = 0; j < this._vertexCount; j++) {
                  splatIndex[j] = indexMix[2 * j];
              }
              this.thinInstanceBufferUpdated("splatIndex");
              this._canPostToWorker = true;
          };
      }
      /**
       * Computes the size of the data textures needed to hold `length` texels.
       * The width is always the engine's maximum texture size. For engine version 1 (not WebGPU) the height is
       * rounded up to a power of two (presumably a WebGL1 constraint — confirm); otherwise it is the minimal
       * number of rows. When the data does not fit, an error is logged and the height is clamped to the width.
       * @param length number of texels needed
       * @returns the texture size (x = width, y = height)
       */
      _getTextureSize(length) {
          const engine = this._scene.getEngine();
          const width = engine.getCaps().maxTextureSize;
          let height = 1;
          if (engine.version === 1 && !engine.isWebGPU) {
              while (width * height < length) {
                  height *= 2;
              }
          }
          else {
              height = Math.ceil(length / width);
          }
          if (height > width) {
              Logger.Error("GaussianSplatting texture size: (" + width + ", " + height + "), maxTextureSize: " + width);
              height = width;
          }
          return new Vector2(width, height);
      }
  }
  /**
   * Body of the depth-sorting Web Worker. It is serialized with toString() and evaluated inside a Blob
   * worker, so it must not reference anything outside its own scope.
   *
   * Messages handled:
   * - { positions, vertexCount }: remembers the splat centers (3 floats per splat) for later sorts.
   * - { view, depthMix }: writes, for every splat j, the index j into 32-bit word 2*j and a float32 depth key
   *   (10000 minus the dot product of the center with view[2], view[6], view[10]) into word 2*j + 1 of
   *   depthMix, sorts depthMix and posts it back, transferring its buffer.
   * @param self the worker global scope
   */
  GaussianSplattingMesh._CreateWorker = function (self) {
      let splatCount = 0;
      let centers;
      self.onmessage = (event) => {
          const message = event.data;
          // Initialization: keep the splat centers around.
          if (message.positions) {
              centers = message.positions;
              splatCount = message.vertexCount;
              return;
          }
          // View update: recompute the depth keys and sort.
          const view = message.view;
          if (!centers || !view) {
              // Sanity check, it shouldn't happen!
              throw new Error("positions or view is not defined!");
          }
          const depthMix = message.depthMix;
          const indexWords = new Uint32Array(depthMix.buffer);
          const keyWords = new Float32Array(depthMix.buffer);
          for (let j = 0; j < splatCount; j++) {
              const base = 3 * j;
              indexWords[2 * j] = j;
              keyWords[2 * j + 1] = 10000 - (view[2] * centers[base + 0] + view[6] * centers[base + 1] + view[10] * centers[base + 2]);
          }
          // Sorting the 64-bit values orders splats by the key held in the upper word (assumes little-endian layout).
          depthMix.sort();
          self.postMessage({ depthMix }, [depthMix.buffer]);
      };
  };
  423. //# sourceMappingURL=gaussianSplattingMesh.js.map