123456789101112131415161718192021222324252627282930313233 |
- import { __spreadArrays } from "tslib";
- import * as tf from '@tensorflow/tfjs-core';
- import { conv, convDown, convNoRelu } from './convLayer';
/**
 * Identity residual unit.
 *
 * Runs `x` through `params.conv1` via `conv`, then `params.conv2` via
 * `convNoRelu`. It adds the unchanged input back onto that branch and applies
 * a final ReLU.
 *
 * NOTE(review): `conv` presumably includes a ReLU, given the `convNoRelu`
 * counterpart — confirm in ./convLayer.
 *
 * @param {tf.Tensor} x - input tensor. The skip connection adds it directly,
 *   so the conv2 output must be shape-compatible with it.
 * @param {{conv1: *, conv2: *}} params - parameters for the two convolutions.
 * @returns {tf.Tensor} relu(convNoRelu(conv(x, conv1), conv2) + x).
 */
export function residual(x, params) {
  const branch = convNoRelu(conv(x, params.conv1), params.conv2);
  const summed = tf.add(branch, x);
  return tf.relu(summed);
}
/**
 * Downsampling residual unit.
 *
 * The conv branch is `convDown(x, conv1)` followed by `convNoRelu(..., conv2)`.
 * The shortcut branch is a 2x2 / stride-2 'valid' average pool of `x`. The two
 * branches are summed and passed through a ReLU.
 *
 * The branches may disagree in shape, so each is zero-padded where it falls
 * short before the element-wise add:
 *  - spatially, the conv branch is padded at the end of each of axes 1 and 2,
 *    independently, by exactly the amount it is smaller than the shortcut;
 *  - in channels, the shortcut is padded at the end of axis 3 up to the conv
 *    branch's channel count.
 *
 * Layout assumption: NHWC (tfjs `avgPool` default), i.e. axis 1 = height,
 * axis 2 = width, axis 3 = channels.
 *
 * @param {tf.Tensor4D} x - input tensor.
 * @param {{conv1: *, conv2: *}} params - parameters for the two convolutions.
 * @returns {tf.Tensor4D} relu(padded shortcut + padded conv branch).
 */
export function residualDown(x, params) {
  let out = convDown(x, params.conv1);
  out = convNoRelu(out, params.conv2);

  let pooled = tf.avgPool(x, 2, 2, 'valid');

  // Pad height and width separately. Padding both axes whenever either one
  // differs breaks inputs whose height and width have different parity.
  const padH = Math.max(0, pooled.shape[1] - out.shape[1]);
  const padW = Math.max(0, pooled.shape[2] - out.shape[2]);
  if (padH > 0 || padW > 0) {
    out = tf.pad(out, [[0, 0], [0, padH], [0, padW], [0, 0]]);
  }

  // Zero-pad the shortcut's channels up to the conv branch's channel count.
  // This is computed from the real difference rather than assuming the
  // conv branch doubles the channels. Padding is allocated only when needed.
  const padC = Math.max(0, out.shape[3] - pooled.shape[3]);
  if (padC > 0) {
    pooled = tf.pad(pooled, [[0, 0], [0, 0], [0, 0], [0, padC]]);
  }

  out = tf.add(pooled, out);
  return tf.relu(out);
}
- //# sourceMappingURL=residualLayer.js.map
|