extractParams.js 4.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. "use strict";
  2. Object.defineProperty(exports, "__esModule", { value: true });
  3. var tf = require("@tensorflow/tfjs-core");
  4. var common_1 = require("../common");
  5. var utils_1 = require("../utils");
  /**
   * Builds the layer-parameter extractors used by extractParams.
   *
   * Every extractor pulls its values from the shared `extractWeights` reader in a
   * fixed order, and appends one { paramPath } entry to `paramMappings` for each
   * tensor it creates (after that tensor's values have been read).
   *
   * @param {function(number)} extractWeights - called with a count; presumably returns
   *   the next `count` values of the flat weight buffer (confirm against common_1).
   * @param {Array<{paramPath: string}>} paramMappings - receives one entry per created tensor.
   * @returns {{extractConvLayerParams: Function, extractResidualLayerParams: Function}}
   */
  function extractorsFactory(extractWeights, paramMappings) {
    // Appends `prefix/<name>` for each name, in the order given.
    function recordPaths(prefix, names) {
      for (var i = 0; i < names.length; i++) {
        paramMappings.push({ paramPath: prefix + "/" + names[i] });
      }
    }

    // Reads `count` values and wraps them in a rank-1 tensor.
    function readVector(count) {
      return tf.tensor1d(extractWeights(count));
    }

    // Reads a kernel laid out as [numFilters, depth, k, k] and returns it permuted to
    // [k, k, depth, numFilters]. The depth is derived from how many values were read.
    function readFilters(numFilterValues, numFilters, filterSize) {
      var raw = extractWeights(numFilterValues);
      var valuesPerInputChannel = numFilters * filterSize * filterSize;
      var depth = raw.length / valuesPerInputChannel;
      if (utils_1.isFloat(depth)) {
        throw new Error("depth has to be an integer: " + depth + ", weights.length: " + raw.length + ", numFilters: " + numFilters + ", filterSize: " + filterSize);
      }
      var storedShape = [numFilters, depth, filterSize, filterSize];
      // tidy releases the intermediate 4D tensor and keeps only the transposed result.
      return tf.tidy(function () {
        var stored = tf.tensor4d(raw, storedShape);
        return tf.transpose(stored, [2, 3, 1, 0]);
      });
    }

    /**
     * Reads one conv layer (filters + bias) followed by its scale layer
     * (one weight and one bias per filter).
     */
    function extractConvLayerParams(numFilterValues, numFilters, filterSize, mappedPrefix) {
      var convPrefix = mappedPrefix + "/conv";
      var filters = readFilters(numFilterValues, numFilters, filterSize);
      var bias = readVector(numFilters);
      recordPaths(convPrefix, ["filters", "bias"]);

      var scalePrefix = mappedPrefix + "/scale";
      var scaleWeights = readVector(numFilters);
      var scaleBiases = readVector(numFilters);
      recordPaths(scalePrefix, ["weights", "biases"]);

      return {
        conv: { filters: filters, bias: bias },
        scale: { weights: scaleWeights, biases: scaleBiases }
      };
    }

    /**
     * Reads two consecutive conv layers (conv1, then conv2). When `isDown` is truthy,
     * conv1 reads half as many filter values, so its derived input depth is halved.
     */
    function extractResidualLayerParams(numFilterValues, numFilters, filterSize, mappedPrefix, isDown) {
      var firstConvFactor = isDown ? 0.5 : 1;
      // Property initializers run left to right, so conv1 is read before conv2.
      return {
        conv1: extractConvLayerParams(firstConvFactor * numFilterValues, numFilters, filterSize, mappedPrefix + "/conv1"),
        conv2: extractConvLayerParams(numFilterValues, numFilters, filterSize, mappedPrefix + "/conv2")
      };
    }

    return {
      extractConvLayerParams: extractConvLayerParams,
      extractResidualLayerParams: extractResidualLayerParams
    };
  }
  /**
   * Reads every parameter of the residual network from a flat weight array.
   *
   * Weights are consumed strictly in the order of the calls below. Each created
   * tensor also appends a { paramPath } entry to `paramMappings`, in the same order.
   *
   * @param {*} weights - flat weight buffer handed to common_1.extractWeightsFactory
   *   (presumably a Float32Array — confirm against callers).
   * @returns {{params: Object, paramMappings: Array<{paramPath: string}>}}
   * @throws {Error} if any weights remain unread after the final layer. All tensors
   *   created by this call are disposed before throwing.
   */
  function extractParams(weights) {
    var weightReader = common_1.extractWeightsFactory(weights);
    var extractWeights = weightReader.extractWeights;
    var getRemainingWeights = weightReader.getRemainingWeights;
    var paramMappings = [];
    var extractors = extractorsFactory(extractWeights, paramMappings);
    var extractConvLayerParams = extractors.extractConvLayerParams;
    var extractResidualLayerParams = extractors.extractResidualLayerParams;

    // Filter-value counts are written as numFilters * inputDepth * filterSize * filterSize.
    // Stem: 32 filters of 7x7 over a depth-3 input.
    var conv32_down = extractConvLayerParams(32 * 3 * 7 * 7, 32, 7, 'conv32_down');
    var conv32_1 = extractResidualLayerParams(32 * 32 * 3 * 3, 32, 3, 'conv32_1');
    var conv32_2 = extractResidualLayerParams(32 * 32 * 3 * 3, 32, 3, 'conv32_2');
    var conv32_3 = extractResidualLayerParams(32 * 32 * 3 * 3, 32, 3, 'conv32_3');
    // The trailing `true` makes the block's first conv read half the filter values,
    // i.e. its input depth is the previous stage's width.
    var conv64_down = extractResidualLayerParams(64 * 64 * 3 * 3, 64, 3, 'conv64_down', true);
    var conv64_1 = extractResidualLayerParams(64 * 64 * 3 * 3, 64, 3, 'conv64_1');
    var conv64_2 = extractResidualLayerParams(64 * 64 * 3 * 3, 64, 3, 'conv64_2');
    var conv64_3 = extractResidualLayerParams(64 * 64 * 3 * 3, 64, 3, 'conv64_3');
    var conv128_down = extractResidualLayerParams(128 * 128 * 3 * 3, 128, 3, 'conv128_down', true);
    var conv128_1 = extractResidualLayerParams(128 * 128 * 3 * 3, 128, 3, 'conv128_1');
    var conv128_2 = extractResidualLayerParams(128 * 128 * 3 * 3, 128, 3, 'conv128_2');
    var conv256_down = extractResidualLayerParams(256 * 256 * 3 * 3, 256, 3, 'conv256_down', true);
    var conv256_1 = extractResidualLayerParams(256 * 256 * 3 * 3, 256, 3, 'conv256_1');
    var conv256_2 = extractResidualLayerParams(256 * 256 * 3 * 3, 256, 3, 'conv256_2');
    var conv256_down_out = extractResidualLayerParams(256 * 256 * 3 * 3, 256, 3, 'conv256_down_out');
    // Fully connected layer: 256 * 128 values stored as [128, 256], transposed to [256, 128].
    var fc = tf.tidy(function () {
      return tf.transpose(tf.tensor2d(extractWeights(256 * 128), [128, 256]), [1, 0]);
    });
    paramMappings.push({ paramPath: "fc" });

    var params = {
      conv32_down: conv32_down,
      conv32_1: conv32_1,
      conv32_2: conv32_2,
      conv32_3: conv32_3,
      conv64_down: conv64_down,
      conv64_1: conv64_1,
      conv64_2: conv64_2,
      conv64_3: conv64_3,
      conv128_down: conv128_down,
      conv128_1: conv128_1,
      conv128_2: conv128_2,
      conv256_down: conv256_down,
      conv256_1: conv256_1,
      conv256_2: conv256_2,
      conv256_down_out: conv256_down_out,
      fc: fc
    };

    var numRemaining = getRemainingWeights().length;
    if (numRemaining !== 0) {
      // Every tensor allocated above is reachable from `params`. Release them so a
      // mis-sized weight file does not leak backend memory when we throw.
      tf.dispose(params);
      throw new Error("weights remaining after extract: " + numRemaining);
    }
    return { params: params, paramMappings: paramMappings };
  }
  90. exports.extractParams = extractParams;
  91. //# sourceMappingURL=extractParams.js.map