# common_pruning.py

# mypy: ignore-errors
# Owner(s): ["module: unknown"]

import torch
import torch.nn.functional as F
from torch import nn
from torch.ao.pruning import BaseSparsifier


class ImplementedSparsifier(BaseSparsifier):
    """Minimal sparsifier used in tests; zeroes the first row of the weight mask."""

    def __init__(self, **kwargs):
        super().__init__(defaults=kwargs)

    def update_mask(self, module, **kwargs):
        module.parametrizations.weight[0].mask[0] = 0
        # Track how many times the mask for 'linear1.weight' has been updated.
        linear_state = self.state['linear1.weight']
        linear_state['step_count'] = linear_state.get('step_count', 0) + 1


class MockSparseLinear(nn.Linear):
    """
    Mock sparse Linear layer used to check convert functionality.
    It behaves like a normal Linear layer, except it has a different type and
    an additional from_dense method.
    """
    @classmethod
    def from_dense(cls, mod):
        """
        Create a MockSparseLinear layer with the same shape as the given
        dense Linear layer (weights are not copied).
        """
        linear = cls(mod.in_features,
                     mod.out_features)
        return linear
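

# Illustrative usage sketch (added for clarity; not part of the original module):
# MockSparseLinear.from_dense only copies the layer shape from a dense Linear,
# so the returned module has freshly initialized weights but the mock type.
def _example_mock_sparse_linear_from_dense():
    dense = nn.Linear(16, 8)
    mock = MockSparseLinear.from_dense(dense)
    assert isinstance(mock, MockSparseLinear)
    assert mock.in_features == 16 and mock.out_features == 8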


def rows_are_subset(subset_tensor, superset_tensor) -> bool:
    """
    Checks to see if all rows in the subset tensor are present in the
    superset tensor. Rows are matched with a forward-only scan, so the
    subset rows must appear in the same relative order as in the superset.
    """
    i = 0
    for row in subset_tensor:
        while i < len(superset_tensor):
            if not torch.equal(row, superset_tensor[i]):
                i += 1
            else:
                break
        else:
            return False
    return True
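

# Illustrative usage sketch (added for clarity; not part of the original module):
# because rows_are_subset scans the superset with a forward-only index, the
# same rows given out of order are not recognized as a subset.
def _example_rows_are_subset():
    superset = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
    assert rows_are_subset(torch.tensor([[1.0, 2.0], [5.0, 6.0]]), superset)
    assert not rows_are_subset(torch.tensor([[5.0, 6.0], [1.0, 2.0]]), superset)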


class SimpleLinear(nn.Module):
    r"""Model with only Linear layers without biases, some wrapped in a Sequential,
    some following the Sequential. Used to test basic pruned Linear-Linear fusion."""

    def __init__(self):
        super().__init__()
        self.seq = nn.Sequential(
            nn.Linear(7, 5, bias=False),
            nn.Linear(5, 6, bias=False),
            nn.Linear(6, 4, bias=False),
        )
        self.linear1 = nn.Linear(4, 4, bias=False)
        self.linear2 = nn.Linear(4, 10, bias=False)

    def forward(self, x):
        x = self.seq(x)
        x = self.linear1(x)
        x = self.linear2(x)
        return x
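

# Illustrative usage sketch (added for clarity; not part of the original module):
# a plain forward pass through SimpleLinear, plus how ImplementedSparsifier
# defined above might be attached to it. The prepare() config format
# ([{"tensor_fqn": ...}]) is assumed to follow the torch.ao.pruning
# BaseSparsifier API; adjust it if your version expects a different layout.
def _example_simple_linear_with_sparsifier():
    model = SimpleLinear()
    out = model(torch.randn(2, 7))  # 7 input features -> 10 output features
    assert out.shape == (2, 10)

    sparsifier = ImplementedSparsifier()
    sparsifier.prepare(model, [{"tensor_fqn": "linear1.weight"}])
    sparsifier.step()  # zeroes the first mask row and increments 'step_count'
    return model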


class LinearBias(nn.Module):
    r"""Model with only Linear layers, alternating layers with biases,
    wrapped in a Sequential. Used to test pruned Linear-Bias-Linear fusion."""

    def __init__(self):
        super().__init__()
        self.seq = nn.Sequential(
            nn.Linear(7, 5, bias=True),
            nn.Linear(5, 6, bias=False),
            nn.Linear(6, 3, bias=True),
            nn.Linear(3, 3, bias=True),
            nn.Linear(3, 10, bias=False),
        )

    def forward(self, x):
        x = self.seq(x)
        return x


class LinearActivation(nn.Module):
    r"""Model with only Linear layers, some with bias, some in a Sequential and some following.
    Activation function modules in between each Linear in the Sequential, and after each outside layer.
    Used to test pruned Linear(Bias)-Activation-Linear fusion."""

    def __init__(self):
        super().__init__()
        self.seq = nn.Sequential(
            nn.Linear(7, 5, bias=True),
            nn.ReLU(),
            nn.Linear(5, 6, bias=False),
            nn.Tanh(),
            nn.Linear(6, 4, bias=True),
        )
        self.linear1 = nn.Linear(4, 3, bias=True)
        self.act1 = nn.ReLU()
        self.linear2 = nn.Linear(3, 10, bias=False)
        self.act2 = nn.Tanh()

    def forward(self, x):
        x = self.seq(x)
        x = self.linear1(x)
        x = self.act1(x)
        x = self.linear2(x)
        x = self.act2(x)
        return x


class LinearActivationFunctional(nn.Module):
    r"""Model with only Linear layers, some with bias, some in a Sequential and some following.
    Activation function modules in between each Linear in the Sequential, and functional
    activations are called in between each outside layer.
    Used to test pruned Linear(Bias)-Activation-Linear fusion."""

    def __init__(self):
        super().__init__()
        self.seq = nn.Sequential(
            nn.Linear(7, 5, bias=True),
            nn.ReLU(),
            nn.Linear(5, 6, bias=False),
            nn.ReLU(),
            nn.Linear(6, 4, bias=True),
        )
        self.linear1 = nn.Linear(4, 3, bias=True)
        self.linear2 = nn.Linear(3, 8, bias=False)
        self.linear3 = nn.Linear(8, 10, bias=False)
        self.act1 = nn.ReLU()

    def forward(self, x):
        x = self.seq(x)
        x = self.linear1(x)
        x = F.relu(x)
        x = self.linear2(x)
        x = F.relu(x)
        x = self.linear3(x)
        x = F.relu(x)
        return x


class SimpleConv2d(nn.Module):
    r"""Model with only Conv2d layers, all without bias, some in a Sequential and some following.
    Used to test pruned Conv2d-Conv2d fusion."""

    def __init__(self):
        super().__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1, bias=False),
            nn.Conv2d(32, 64, 3, 1, bias=False),
        )
        self.conv2d1 = nn.Conv2d(64, 48, 3, 1, bias=False)
        self.conv2d2 = nn.Conv2d(48, 52, 3, 1, bias=False)

    def forward(self, x):
        x = self.seq(x)
        x = self.conv2d1(x)
        x = self.conv2d2(x)
        return x
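

# Illustrative shape check (added for clarity; not part of the original module):
# every 3x3 convolution with stride 1 and no padding trims 2 pixels from each
# spatial dimension, so the four convolutions shrink a 28x28 input to 20x20.
def _example_simple_conv2d_shapes():
    model = SimpleConv2d()
    out = model(torch.randn(1, 1, 28, 28))
    assert out.shape == (1, 52, 20, 20)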


class Conv2dBias(nn.Module):
    r"""Model with only Conv2d layers, some with bias, some in a Sequential and some outside.
    Used to test pruned Conv2d-Bias-Conv2d fusion."""

    def __init__(self):
        super().__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1, bias=True),
            nn.Conv2d(32, 32, 3, 1, bias=True),
            nn.Conv2d(32, 64, 3, 1, bias=False),
        )
        self.conv2d1 = nn.Conv2d(64, 48, 3, 1, bias=True)
        self.conv2d2 = nn.Conv2d(48, 52, 3, 1, bias=False)

    def forward(self, x):
        x = self.seq(x)
        x = self.conv2d1(x)
        x = self.conv2d2(x)
        return x


class Conv2dActivation(nn.Module):
    r"""Model with only Conv2d layers, some with bias, some in a Sequential and some following.
    Activation function modules in between each Sequential layer, functional activations called
    in between each outside layer.
    Used to test pruned Conv2d-Bias-Activation-Conv2d fusion."""

    def __init__(self):
        super().__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1, bias=True),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, 1, bias=True),
            nn.Tanh(),
            nn.Conv2d(64, 64, 3, 1, bias=False),
            nn.ReLU(),
        )
        self.conv2d1 = nn.Conv2d(64, 48, 3, 1, bias=False)
        self.conv2d2 = nn.Conv2d(48, 52, 3, 1, bias=True)

    def forward(self, x):
        x = self.seq(x)
        x = self.conv2d1(x)
        x = F.relu(x)
        x = self.conv2d2(x)
        x = F.hardtanh(x)
        return x


class Conv2dPadBias(nn.Module):
    r"""Model with only Conv2d layers, all with bias and some with padding > 0,
    some in a Sequential and some following. Activation function modules in between each layer.
    Used to test that bias is propagated correctly in the special case of
    pruned Conv2d-Bias-(Activation)Conv2d fusion, when the second Conv2d layer has padding > 0."""

    def __init__(self):
        super().__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1, padding=1, bias=True),
            nn.ReLU(),
            nn.Conv2d(32, 32, 3, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(32, 32, 3, 1, padding=1, bias=True),
            nn.ReLU(),
            nn.Conv2d(32, 32, 3, 1, padding=1, bias=True),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, 1, bias=True),
            nn.Tanh(),
        )
        self.conv2d1 = nn.Conv2d(64, 48, 3, 1, padding=1, bias=True)
        self.act1 = nn.ReLU()
        self.conv2d2 = nn.Conv2d(48, 52, 3, 1, padding=1, bias=True)
        self.act2 = nn.Tanh()

    def forward(self, x):
        x = self.seq(x)
        x = self.conv2d1(x)
        x = self.act1(x)
        x = self.conv2d2(x)
        x = self.act2(x)
        return x


class Conv2dPool(nn.Module):
    r"""Model with only Conv2d layers, all with bias, some in a Sequential and some following.
    Activation function modules in between each layer, Pool2d modules in between each layer.
    Used to test pruned Conv2d-Pool2d-Conv2d fusion."""

    def __init__(self):
        super().__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1, bias=True),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, padding=1, bias=True),
            nn.Tanh(),
            nn.AvgPool2d(kernel_size=2, stride=2, padding=1),
        )
        self.conv2d1 = nn.Conv2d(64, 48, kernel_size=3, padding=1, bias=True)
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2, padding=1)
        self.af1 = nn.ReLU()
        self.conv2d2 = nn.Conv2d(48, 52, kernel_size=3, padding=1, bias=True)
        self.conv2d3 = nn.Conv2d(52, 52, kernel_size=3, padding=1, bias=True)

    def forward(self, x):
        x = self.seq(x)
        x = self.conv2d1(x)
        x = self.maxpool(x)
        x = self.af1(x)
        x = self.conv2d2(x)
        x = F.avg_pool2d(x, kernel_size=2, stride=2, padding=1)
        x = F.relu(x)
        x = self.conv2d3(x)
        return x


class Conv2dPoolFlattenFunctional(nn.Module):
    r"""Model with Conv2d layers, all with bias, some in a Sequential and some following, and then a Pool2d
    and a functional Flatten followed by a Linear layer.
    Activation functions and Pool2ds in between each layer also.
    Used to test pruned Conv2d-Pool2d-Flatten-Linear fusion."""

    def __init__(self):
        super().__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(1, 3, kernel_size=3, padding=1, bias=True),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(3, 5, kernel_size=3, padding=1, bias=True),
            nn.Tanh(),
            nn.AvgPool2d(kernel_size=2, stride=2, padding=1),
        )
        self.conv2d1 = nn.Conv2d(5, 7, kernel_size=3, padding=1, bias=True)
        self.af1 = nn.ReLU()
        self.conv2d2 = nn.Conv2d(7, 11, kernel_size=3, padding=1, bias=True)
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(11, 13, bias=True)

    def forward(self, x):
        x = self.seq(x)
        x = self.conv2d1(x)
        x = F.max_pool2d(x, kernel_size=2, stride=2, padding=1)
        x = self.af1(x)
        x = self.conv2d2(x)
        x = self.avg_pool(x)
        x = torch.flatten(x, 1)  # test functional flatten
        x = self.fc(x)
        return x


class Conv2dPoolFlatten(nn.Module):
    r"""Model with Conv2d layers, all with bias, some in a Sequential and some following, and then a Pool2d
    and a Flatten module followed by a Linear layer.
    Activation functions and Pool2ds in between each layer also.
    Used to test pruned Conv2d-Pool2d-Flatten-Linear fusion."""

    def __init__(self):
        super().__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(1, 3, kernel_size=3, padding=1, bias=True),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(3, 5, kernel_size=3, padding=1, bias=True),
            nn.Tanh(),
            nn.AvgPool2d(kernel_size=2, stride=2, padding=1),
        )
        self.conv2d1 = nn.Conv2d(5, 7, kernel_size=3, padding=1, bias=True)
        self.af1 = nn.ReLU()
        self.conv2d2 = nn.Conv2d(7, 11, kernel_size=3, padding=1, bias=True)
        self.avg_pool = nn.AdaptiveAvgPool2d((2, 2))
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(44, 13, bias=True)

    def forward(self, x):
        x = self.seq(x)
        x = self.conv2d1(x)
        x = F.max_pool2d(x, kernel_size=2, stride=2, padding=1)
        x = self.af1(x)
        x = self.conv2d2(x)
        x = self.avg_pool(x)
        x = self.flatten(x)
        x = self.fc(x)
        return x
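

# Illustrative shape check (added for clarity; not part of the original module):
# the AdaptiveAvgPool2d((2, 2)) output has 11 channels over a 2x2 grid, so the
# Flatten module produces 11 * 2 * 2 = 44 features, matching fc's in_features.
def _example_conv2d_pool_flatten_shapes():
    model = Conv2dPoolFlatten()
    out = model(torch.randn(1, 1, 28, 28))
    assert out.shape == (1, 13)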


class LSTMLinearModel(nn.Module):
    """Container module with a recurrent LSTM module followed by a linear decoder."""

    def __init__(
        self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int
    ):
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers)
        self.linear = nn.Linear(hidden_dim, output_dim)

    def forward(self, input):
        output, hidden = self.lstm(input)
        decoded = self.linear(output)
        return decoded, output
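

# Illustrative shape check (added for clarity; not part of the original module):
# nn.LSTM defaults to batch_first=False, so inputs are (seq_len, batch, input_dim)
# and the model returns both the decoded output and the raw LSTM output.
def _example_lstm_linear_shapes():
    model = LSTMLinearModel(input_dim=8, hidden_dim=16, output_dim=10, num_layers=1)
    decoded, output = model(torch.randn(5, 3, 8))
    assert decoded.shape == (5, 3, 10)  # (seq_len, batch, output_dim)
    assert output.shape == (5, 3, 16)   # (seq_len, batch, hidden_dim)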


class LSTMLayerNormLinearModel(nn.Module):
    """Container module with an LSTM, a LayerNorm, and a linear."""

    def __init__(
        self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int
    ):
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers)
        self.norm = nn.LayerNorm(hidden_dim)
        self.linear = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x, state = self.lstm(x)
        x = self.norm(x)
        x = self.linear(x)
        return x, state