extractParams.js 4.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. "use strict";
  2. Object.defineProperty(exports, "__esModule", { value: true });
  3. var tslib_1 = require("tslib");
  4. var tf = require("@tensorflow/tfjs-core");
  5. var common_1 = require("../common");
  6. function extractorsFactory(extractWeights, paramMappings) {
  7. var extractConvParams = common_1.extractConvParamsFactory(extractWeights, paramMappings);
  8. var extractFCParams = common_1.extractFCParamsFactory(extractWeights, paramMappings);
  9. function extractPReluParams(size, paramPath) {
  10. var alpha = tf.tensor1d(extractWeights(size));
  11. paramMappings.push({ paramPath: paramPath });
  12. return alpha;
  13. }
  14. function extractSharedParams(numFilters, mappedPrefix, isRnet) {
  15. if (isRnet === void 0) { isRnet = false; }
  16. var conv1 = extractConvParams(numFilters[0], numFilters[1], 3, mappedPrefix + "/conv1");
  17. var prelu1_alpha = extractPReluParams(numFilters[1], mappedPrefix + "/prelu1_alpha");
  18. var conv2 = extractConvParams(numFilters[1], numFilters[2], 3, mappedPrefix + "/conv2");
  19. var prelu2_alpha = extractPReluParams(numFilters[2], mappedPrefix + "/prelu2_alpha");
  20. var conv3 = extractConvParams(numFilters[2], numFilters[3], isRnet ? 2 : 3, mappedPrefix + "/conv3");
  21. var prelu3_alpha = extractPReluParams(numFilters[3], mappedPrefix + "/prelu3_alpha");
  22. return { conv1: conv1, prelu1_alpha: prelu1_alpha, conv2: conv2, prelu2_alpha: prelu2_alpha, conv3: conv3, prelu3_alpha: prelu3_alpha };
  23. }
  24. function extractPNetParams() {
  25. var sharedParams = extractSharedParams([3, 10, 16, 32], 'pnet');
  26. var conv4_1 = extractConvParams(32, 2, 1, 'pnet/conv4_1');
  27. var conv4_2 = extractConvParams(32, 4, 1, 'pnet/conv4_2');
  28. return tslib_1.__assign(tslib_1.__assign({}, sharedParams), { conv4_1: conv4_1, conv4_2: conv4_2 });
  29. }
  30. function extractRNetParams() {
  31. var sharedParams = extractSharedParams([3, 28, 48, 64], 'rnet', true);
  32. var fc1 = extractFCParams(576, 128, 'rnet/fc1');
  33. var prelu4_alpha = extractPReluParams(128, 'rnet/prelu4_alpha');
  34. var fc2_1 = extractFCParams(128, 2, 'rnet/fc2_1');
  35. var fc2_2 = extractFCParams(128, 4, 'rnet/fc2_2');
  36. return tslib_1.__assign(tslib_1.__assign({}, sharedParams), { fc1: fc1, prelu4_alpha: prelu4_alpha, fc2_1: fc2_1, fc2_2: fc2_2 });
  37. }
  38. function extractONetParams() {
  39. var sharedParams = extractSharedParams([3, 32, 64, 64], 'onet');
  40. var conv4 = extractConvParams(64, 128, 2, 'onet/conv4');
  41. var prelu4_alpha = extractPReluParams(128, 'onet/prelu4_alpha');
  42. var fc1 = extractFCParams(1152, 256, 'onet/fc1');
  43. var prelu5_alpha = extractPReluParams(256, 'onet/prelu5_alpha');
  44. var fc2_1 = extractFCParams(256, 2, 'onet/fc2_1');
  45. var fc2_2 = extractFCParams(256, 4, 'onet/fc2_2');
  46. var fc2_3 = extractFCParams(256, 10, 'onet/fc2_3');
  47. return tslib_1.__assign(tslib_1.__assign({}, sharedParams), { conv4: conv4, prelu4_alpha: prelu4_alpha, fc1: fc1, prelu5_alpha: prelu5_alpha, fc2_1: fc2_1, fc2_2: fc2_2, fc2_3: fc2_3 });
  48. }
  49. return {
  50. extractPNetParams: extractPNetParams,
  51. extractRNetParams: extractRNetParams,
  52. extractONetParams: extractONetParams
  53. };
  54. }
  55. function extractParams(weights) {
  56. var _a = common_1.extractWeightsFactory(weights), extractWeights = _a.extractWeights, getRemainingWeights = _a.getRemainingWeights;
  57. var paramMappings = [];
  58. var _b = extractorsFactory(extractWeights, paramMappings), extractPNetParams = _b.extractPNetParams, extractRNetParams = _b.extractRNetParams, extractONetParams = _b.extractONetParams;
  59. var pnet = extractPNetParams();
  60. var rnet = extractRNetParams();
  61. var onet = extractONetParams();
  62. if (getRemainingWeights().length !== 0) {
  63. throw new Error("weights remaing after extract: " + getRemainingWeights().length);
  64. }
  65. return { params: { pnet: pnet, rnet: rnet, onet: onet }, paramMappings: paramMappings };
  66. }
  67. exports.extractParams = extractParams;
  68. //# sourceMappingURL=extractParams.js.map