Gradient and Optimization with PyTorch

Gradient and Optimization with PyTorch

import torch

import meent
from meent.on_torch.optimizer.loss import LossDeflector
from meent.on_torch.optimizer.optimizer import OptimizerTorch
backend = 2  # Torch

pol = 0  # 0: TE, 1: TM

n_top = 1  # refractive index of the top (incidence) medium
n_bot = 1  # refractive index of the bottom (transmission) medium

theta = 0 * torch.pi / 180  # angle of incidence (degrees converted to radians)
phi = 0 * torch.pi / 180  # angle of rotation (degrees converted to radians)

wavelength = 900

thickness = torch.tensor([500., 1000.])  # thickness of each layer, from top to bottom.
period = torch.tensor([1000.])  # length of the unit cell. Here it's 1D.

# Passed to meent as `fto` (presumably the Fourier truncation order -- see the
# meent docs); a single entry, matching the single-entry (1D) period above.
fto = [10]

type_complex = torch.complex128
device = torch.device('cpu')

# Unit cell of shape (2, 1, 10): two layers (presumably in the same top-to-bottom
# order as `thickness`), each a 1 x 10 grid. `* 4 + 1.` maps the 0/1 pattern to
# refractive indices 1 and 5 and promotes the integer tensor to the default float dtype.
ucell_1d_m = torch.tensor([
    [[0, 0, 0, 1, 1, 1, 1, 0, 0, 0, ]],
    [[1, 1, 1, 1, 0, 1, 1, 1, 1, 1, ]],
    ]) * 4 + 1.  # refractive index

2.1 Gradient

Gradients can be computed with `torch.autograd`. For more information, see the PyTorch tutorial "A Gentle Introduction to torch.autograd".

Gradients can be used to solve optimization problems. The examples below show a couple of ways to obtain gradients or optimized values, with or without meent's predefined functions.

2.1.1 Examples

Example 1: computing the gradient manually

mee = meent.call_mee(
    backend=backend, pol=pol, n_top=n_top, n_bot=n_bot,
    theta=theta, phi=phi, fto=fto, wavelength=wavelength,
    period=period, ucell=ucell_1d_m, thickness=thickness,
    type_complex=type_complex, device=device,
)

# Let autograd track the solver's unit cell and layer thicknesses so that
# `.grad` is populated on both after `backward()`.
for leaf in (mee.ucell, mee.thickness):
    leaf.requires_grad_(True)

de_ri, de_ti = mee.conv_solve()

# The loss is the entry of `de_ti` one past its center index.
center = de_ti.shape[0] // 2
loss = de_ti[center + 1]
loss.backward()

for heading, leaf in (
    ('ucell gradient:', mee.ucell),
    ('thickness gradient:', mee.thickness),
):
    print(heading)
    print(leaf.grad)
ucell gradient:
tensor([[[-0.0511, -0.0253, -0.0073,  0.0787, -0.0184,  0.0945,  0.0878,
          -0.0012, -0.0364, -0.0478]],

        [[-0.1796, -0.0861, -0.2223, -0.1939,  0.0898,  0.0556, -0.0458,
          -0.1360, -0.2984,  0.1287]]], dtype=torch.float64)
thickness gradient:
tensor([ 0.0022, -0.0067], dtype=torch.float64)

Example 2: using meent's predefined `grad` function

mee = meent.call_mee(
    backend=backend, pol=pol, n_top=n_top, n_bot=n_bot,
    theta=theta, phi=phi, fto=fto, wavelength=wavelength,
    period=period, ucell=ucell_1d_m, thickness=thickness,
    type_complex=type_complex, device=device,
)

# Parameters of interest: the quantities to differentiate with respect to.
pois = ['ucell', 'thickness']

forward = mee.conv_solve

# Either meent's predefined loss or a custom callable can be used; the custom
# lambda indexes the second element of the solver result (de_ti) just past its center.
loss_fn = LossDeflector(x_order=1)  # predefined in meent
# loss_fn = lambda x: x[1][x[1].shape[0] // 2 + 1]  # custom

grad = mee.grad(pois, forward, loss_fn)

for heading, key in (
    ('ucell gradient:', 'ucell'),
    ('thickness gradient:', 'thickness'),
):
    print(heading)
    print(grad[key])
ucell gradient:
tensor([[[-0.0511, -0.0253, -0.0073,  0.0787, -0.0184,  0.0945,  0.0878,
          -0.0012, -0.0364, -0.0478]],

        [[-0.1796, -0.0861, -0.2223, -0.1939,  0.0898,  0.0556, -0.0458,
          -0.1360, -0.2984,  0.1287]]], dtype=torch.float64)
thickness gradient:
tensor([ 0.0022, -0.0067], dtype=torch.float64)

2.2 Optimization

2.2.1 Examples

Example 1

mee = meent.call_mee(
    backend=backend, pol=pol, n_top=n_top, n_bot=n_bot,
    theta=theta, phi=phi, fto=fto, wavelength=wavelength,
    period=period, ucell=ucell_1d_m, thickness=thickness,
    type_complex=type_complex, device=device,
)

mee.ucell.requires_grad_(True)
mee.thickness.requires_grad_(True)

# NOTE: torch.optim.SGD minimizes the loss by default, so these steps push the
# selected efficiency *down*; negate the loss to maximize it instead.
params = [mee.ucell, mee.thickness]
opt = torch.optim.SGD(params, lr=1E-2, momentum=0.9)

n_steps = 3
for _ in range(n_steps):
    de_ri, de_ti = mee.conv_solve()

    # Same loss as the gradient examples: the entry one past the center of de_ti.
    loss = de_ti[de_ti.shape[0] // 2 + 1]

    loss.backward()
    opt.step()
    # Clear the accumulated gradients so the next iteration starts fresh.
    opt.zero_grad()

for heading, tensor in (('ucell final:', mee.ucell), ('thickness final:', mee.thickness)):
    print(heading)
    print(tensor)
ucell final:
tensor([[[1.0029, 1.0015, 1.0005, 4.9967, 5.0018, 4.9958, 4.9962, 1.0002,
          1.0021, 1.0028]],

        [[5.0054, 4.9999, 5.0082, 5.0065, 0.9933, 4.9925, 4.9984, 5.0037,
          5.0133, 4.9886]]], dtype=torch.float64, requires_grad=True)
thickness final:
tensor([ 499.9999, 1000.0004], dtype=torch.float64, requires_grad=True)

Example 2

# Re-create the solver from the shared setup for this example.
mee = meent.call_mee(
    backend=backend, pol=pol, n_top=n_top, n_bot=n_bot,
    theta=theta, phi=phi, fto=fto, wavelength=wavelength,
    period=period, ucell=ucell_1d_m, thickness=thickness,
    type_complex=type_complex, device=device,
)


def forward_fn():
    """Run the module-level solver `mee` and return the loss value.

    The loss is the entry of the transmission array (second element of
    `mee.conv_solve()`) one past its center index.
    """
    _, de_ti = mee.conv_solve()
    return de_ti[de_ti.shape[0] // 2 + 1]

pois = ['ucell', 'thickness']  # parameters of interest
forward = forward_fn


def loss_fn(x):
    """Identity: `forward_fn` already returns the loss value."""
    return x


opt_torch = torch.optim.SGD
opt_options = dict(lr=1E-2, momentum=0.9)

res = mee.fit(pois, forward, loss_fn, opt_torch, opt_options, iteration=3)

for heading, value in (('ucell final:', res[0]), ('thickness final:', res[1])):
    print(heading)
    print(value)
100%|██████████| 3/3 [00:00<00:00, 145.39it/s]
ucell final:
tensor([[[1.0029, 1.0015, 1.0005, 4.9967, 5.0018, 4.9958, 4.9962, 1.0002,
          1.0021, 1.0028]],

        [[5.0054, 4.9999, 5.0082, 5.0065, 0.9933, 4.9925, 4.9984, 5.0037,
          5.0133, 4.9886]]], dtype=torch.float64, requires_grad=True)
thickness final:
tensor([ 499.9999, 1000.0004], dtype=torch.float64, requires_grad=True)

Example 3

mee = meent.call_mee(
    backend=backend, pol=pol, n_top=n_top, n_bot=n_bot,
    theta=theta, phi=phi, fto=fto, wavelength=wavelength,
    period=period, ucell=ucell_1d_m, thickness=thickness,
    type_complex=type_complex, device=device,
)

pois = ['ucell', 'thickness']  # parameters of interest

# Feed the raw solver output straight into meent's predefined loss.
forward = mee.conv_solve
loss_fn = LossDeflector(1, 0)

# NOTE: the loss is minimized (results match a plain SGD loop), which lowers the
# selected efficiency; negate the loss to maximize it instead.
opt_torch = torch.optim.SGD
opt_options = dict(lr=1E-2, momentum=0.9)

res = mee.fit(pois, forward, loss_fn, opt_torch, opt_options, iteration=3)

for heading, value in (('ucell final:', res[0]), ('thickness final:', res[1])):
    print(heading)
    print(value)
100%|██████████| 3/3 [00:00<00:00, 196.12it/s]
ucell final:
tensor([[[1.0029, 1.0015, 1.0005, 4.9967, 5.0018, 4.9958, 4.9962, 1.0002,
          1.0021, 1.0028]],

        [[5.0054, 4.9999, 5.0082, 5.0065, 0.9933, 4.9925, 4.9984, 5.0037,
          5.0133, 4.9886]]], dtype=torch.float64, requires_grad=True)
thickness final:
tensor([ 499.9999, 1000.0004], dtype=torch.float64, requires_grad=True)