extractParamsFromWeigthMap.js 3.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. import { __assign } from "tslib";
  2. import { disposeUnusedWeightTensors, extractWeightEntryFactory } from '../common';
  /**
   * Builds the three per-network parameter extractors on top of a single
   * weight-entry reader obtained from
   * `extractWeightEntryFactory(weightMap, paramMappings)`.
   *
   * Nothing is read from the weight map until one of the returned functions
   * is invoked. Within each extractor the entries are requested in a fixed
   * order (the order in which they appear in the object literals below).
   *
   * @param {object} weightMap - Forwarded unchanged to `extractWeightEntryFactory`.
   * @param {Array} paramMappings - Forwarded unchanged to `extractWeightEntryFactory`.
   * @returns {{extractPNetParams: Function, extractRNetParams: Function, extractONetParams: Function}}
   *   Zero-argument extractors for the `pnet`, `rnet` and `onet` path prefixes.
   */
  function extractorsFactory(weightMap, paramMappings) {
    const readEntry = extractWeightEntryFactory(weightMap, paramMappings);

    // NOTE(review): the numeric second argument (4 / 2 / 1) is presumably the
    // expected tensor rank, and the optional third argument presumably an
    // alternative mapped path -- confirm against extractWeightEntryFactory.

    // Convolution layer: "<prefix>/weights" is requested with the extra
    // "<prefix>/filters" argument, followed by "<prefix>/bias".
    const convParams = (prefix) => ({
      filters: readEntry(`${prefix}/weights`, 4, `${prefix}/filters`),
      bias: readEntry(`${prefix}/bias`, 1),
    });

    // Fully connected layer: "<prefix>/weights" followed by "<prefix>/bias".
    const fcParams = (prefix) => ({
      weights: readEntry(`${prefix}/weights`, 2),
      bias: readEntry(`${prefix}/bias`, 1),
    });

    // PReLU parameter: a single entry addressed by its full path.
    const preluParams = (paramPath) => readEntry(paramPath, 1);

    // The conv1..conv3 / prelu1..prelu3 group that every network starts with.
    // Object-literal properties are evaluated top to bottom, so the entries
    // are read in exactly this interleaved order.
    const sharedParams = (prefix) => ({
      conv1: convParams(`${prefix}/conv1`),
      prelu1_alpha: preluParams(`${prefix}/prelu1_alpha`),
      conv2: convParams(`${prefix}/conv2`),
      prelu2_alpha: preluParams(`${prefix}/prelu2_alpha`),
      conv3: convParams(`${prefix}/conv3`),
      prelu3_alpha: preluParams(`${prefix}/prelu3_alpha`),
    });

    const extractPNetParams = () => __assign({}, sharedParams('pnet'), {
      conv4_1: convParams('pnet/conv4_1'),
      conv4_2: convParams('pnet/conv4_2'),
    });

    const extractRNetParams = () => __assign({}, sharedParams('rnet'), {
      fc1: fcParams('rnet/fc1'),
      prelu4_alpha: preluParams('rnet/prelu4_alpha'),
      fc2_1: fcParams('rnet/fc2_1'),
      fc2_2: fcParams('rnet/fc2_2'),
    });

    const extractONetParams = () => __assign({}, sharedParams('onet'), {
      conv4: convParams('onet/conv4'),
      prelu4_alpha: preluParams('onet/prelu4_alpha'),
      fc1: fcParams('onet/fc1'),
      prelu5_alpha: preluParams('onet/prelu5_alpha'),
      fc2_1: fcParams('onet/fc2_1'),
      fc2_2: fcParams('onet/fc2_2'),
      fc2_3: fcParams('onet/fc2_3'),
    });

    return { extractPNetParams, extractRNetParams, extractONetParams };
  }
  /**
   * Extracts the `pnet`, `rnet` and `onet` parameter groups from a weight map.
   *
   * A fresh `paramMappings` array is created and handed to `extractorsFactory`
   * together with `weightMap`. The three extractors then run in
   * pnet -> rnet -> onet order, after which
   * `disposeUnusedWeightTensors(weightMap, paramMappings)` is called with the
   * same array.
   *
   * @param {object} weightMap - Passed through to `extractorsFactory` and
   *   `disposeUnusedWeightTensors`.
   * @returns {{params: {pnet: object, rnet: object, onet: object}, paramMappings: Array}}
   *   The extracted parameter groups plus the `paramMappings` array that was
   *   shared with the extractors.
   */
  export function extractParamsFromWeigthMap(weightMap) {
    const paramMappings = [];
    const {
      extractPNetParams,
      extractRNetParams,
      extractONetParams,
    } = extractorsFactory(weightMap, paramMappings);

    // Properties are evaluated top to bottom: pnet, then rnet, then onet.
    const params = {
      pnet: extractPNetParams(),
      rnet: extractRNetParams(),
      onet: extractONetParams(),
    };

    // Runs only after every extractor has been invoked.
    disposeUnusedWeightTensors(weightMap, paramMappings);

    return { params, paramMappings };
  }
  67. //# sourceMappingURL=extractParamsFromWeigthMap.js.map