import numpy as np
from scipy import misc
from scipy.fftpack import dct, idct
import matplotlib.pyplot as plt
def dct_2d(x):
    """Return the orthonormal 2-D type-II DCT of ``x``.

    The transform is separable: a 1-D DCT is taken along axis 0 and the
    result is then transformed along the last axis. Both passes use
    ``norm="ortho"``, so ``idct_2d`` undoes this transform.
    """
    along_first = dct(x, axis=0, norm="ortho")
    return dct(along_first, axis=-1, norm="ortho")
def idct_2d(x):
    """Return the orthonormal 2-D inverse DCT of the coefficient array ``x``.

    This mirrors ``dct_2d``: a 1-D inverse DCT runs along axis 0 and then
    along the last axis, each with ``norm="ortho"``.
    """
    along_first = idct(x, axis=0, norm="ortho")
    return idct(along_first, axis=-1, norm="ortho")
# Plot the first cols x cols 2-D DCT basis images: panel (i, j) is the inverse
# transform of a unit impulse placed at coefficient (i, j).
cols = 10
fig, axes = plt.subplots(cols, cols)
for j in range(0, cols):
    for i in range(0, cols):
        e0 = np.zeros(shape=(cols, cols))
        e0[i, j] = 1
        axes[i, j].imshow(idct_2d(e0))
        # With cols*cols tiny panels, default ticks/labels overlap and hide
        # the basis images themselves.
        axes[i, j].set_axis_off()
fig.show()
# Load the grey-scale sample image as float32.
try:
    img = np.float32(misc.face(gray=True))
except AttributeError:
    # scipy.misc.face was deprecated in SciPy 1.10 and removed later;
    # scipy.datasets.face is the replacement (needs the optional `pooch`).
    from scipy.datasets import face
    img = np.float32(face(gray=True))

# Scale so the maximum is 1, then subtract the midpoint of the value range so
# pixel values are centred on zero. The half-range (max - min) / 2 only
# centres the data when min == 0.
img = img / img.max()
img = img - (img.max() + img.min()) / 2
imgT = dct_2d(img)

# Left column: image / reconstruction. Right column: coefficient magnitudes.
# Coefficients are shown as log(1 + 1024*|c|): casting them to uint8 (as
# before) wraps negative and >255 values into meaningless pixels.
fig2, axes2 = plt.subplots(4, 2)
axes2[0, 0].imshow(img)
axes2[0, 1].imshow(np.log1p(1024 * np.abs(imgT)))

# Keep only coefficient rows 0..149 (lowest vertical frequencies).
imgD = imgT.copy()
imgD[150:] = 0.0
axes2[1, 1].imshow(np.log1p(1024 * np.abs(imgD)))
imgS = idct_2d(imgD)
axes2[1, 0].imshow(imgS)

# Keep only coefficient rows 0..39.
imgD2 = imgT.copy()
imgD2[40:] = 0.0
axes2[2, 1].imshow(np.log1p(1024 * np.abs(imgD2)))
imgS2 = idct_2d(imgD2)
axes2[2, 0].imshow(imgS2)

# Zero rows 1..39 instead, keeping row 0 and everything from row 40 upward.
imgD3 = imgT.copy()
imgD3[1:40] = 0.0
axes2[3, 1].imshow(np.log1p(1024 * np.abs(imgD3)))
imgS3 = idct_2d(imgD3)
axes2[3, 0].imshow(imgS3)
fig2.show()
# 3x3 comparison grid, one row per (reconstruction, kept-coefficients) pair:
#   column 0 - the reconstruction itself
#   column 1 - pixel error against the original image
#   column 2 - coefficients that were discarded (imgT minus the kept ones)
# Row 0 uses the untouched image, so both of its error panels are all zeros.
fig3, axes3 = plt.subplots(3, 3)
_panels = ((img, imgT), (imgS, imgD), (imgS2, imgD2))
for _r, (_recon, _coeffs) in enumerate(_panels):
    axes3[_r, 0].imshow(_recon)
    axes3[_r, 1].imshow(img - _recon)
    axes3[_r, 2].imshow(imgT - _coeffs)
del _panels, _r, _recon, _coeffs
fig3.show()
# Quantise the DCT coefficients to integers at scale 1024, then clear the
# least-significant bit of each to see what the extra quantisation costs.
imgTS = (imgT * 1024).astype(int)
# Vectorised form of the former per-element list comprehension. NumPy's
# integer % uses Python's floor semantics, so each value is still rounded
# down (toward -inf) to an even number, with identical results and dtype.
imgTSZ = imgTS - imgTS % 2

fig4, axes4 = plt.subplots(2, 3)
# Log-magnitude view of the integer coefficients. np.uint8(...) would wrap
# negatives and values > 255 modulo 256, producing garbage panels.
axes4[0, 0].imshow(np.log1p(np.abs(imgTS)))
axes4[0, 1].imshow(np.log1p(np.abs(imgTSZ)))
axes4[0, 2].imshow((imgTS - imgTSZ) * 1024)

# Reconstructions from both quantised coefficient sets, kept at x1024 scale.
R0 = idct_2d(np.float32(imgTS) / 1024) * 1024
R1 = idct_2d(np.float32(imgTSZ) / 1024) * 1024
axes4[1, 0].imshow(R0)
axes4[1, 1].imshow(R1)
axes4[1, 2].imshow((R0 - R1) * 1024)

# Figure.show() does not run a GUI event loop, so a plain `python script.py`
# exits and closes every window immediately. plt.show() displays all open
# figures and waits.
plt.show()