Gradient and Optimization with JAX and Optax
import jax
import optax
import jax.numpy as jnp
import meent
from meent.on_jax.optimizer.loss import LossDeflector
backend = 1 # JAX
# common
pol = 0 # 0: TE, 1: TM
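n_top = 1 # refractive index of the incidence (top) medium
n_bot = 1 # refractive index of the transmission (bottom) medium
theta = 0 * jnp.pi / 180 # polar angle of incidence, converted from degrees to radians
phi = 0 * jnp.pi / 180 # azimuthal angle (rotation of the plane of incidence), in radians
wavelength = 900 # same length unit as period and thickness
thickness = [500., 1000.] # thickness of each layer, from top to bottom.
period = [1000.] # length of the unit cell. Here it's 1D.
fto = [10] # Fourier truncation order
type_complex = jnp.complex128 # complex dtype used in the computation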
ucell_1d_m = jnp.array([
[[0, 0, 0, 1, 1, 1, 1, 0, 0, 0, ]],
[[1, 1, 1, 1, 0, 1, 1, 1, 1, 1, ]],
]) * 4. + 1. # map the binary pattern to refractive indices: 0 -> 1, 1 -> 5
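The last expression turns the binary pattern into refractive indices: every 0 becomes 1 and every 1 becomes 5. The short sketch below uses only jax.numpy to check the values and the array shape. Reading the first axis as the layer index is an assumption here, based on it having two entries, like thickness.

```python
import jax.numpy as jnp

# Same binary pattern as ucell_1d_m: 0 = background, 1 = ridge material
pattern = jnp.array([
    [[0, 0, 0, 1, 1, 1, 1, 0, 0, 0]],
    [[1, 1, 1, 1, 0, 1, 1, 1, 1, 1]],
])

# Affine map to refractive indices: 0 -> 1.0, 1 -> 5.0
ucell = pattern * 4. + 1.

print(ucell.shape)        # (2, 1, 10): two slabs, matching len(thickness) == 2
print(jnp.unique(ucell))  # only the two index values remain
```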
2.1 Gradient
Gradients can be computed with jax.value_and_grad, which evaluates a function and its gradient in a single call.
See the JAX tutorial for further information: Automatic Differentiation
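A minimal sketch of jax.value_and_grad on a toy sum-of-squares function. The function name loss is hypothetical and unrelated to meent; the point is only that the transformed function returns the value and the gradient together.

```python
import jax
import jax.numpy as jnp

def loss(x):
    # Toy scalar loss: sum of squares, whose gradient is 2 * x
    return jnp.sum(x ** 2)

# value_and_grad builds a new function returning (loss(x), d loss / dx)
loss_and_grad = jax.value_and_grad(loss)

x = jnp.array([1.0, 2.0, 3.0])
value, grad = loss_and_grad(x)
print(value)  # 14.0
print(grad)   # [2. 4. 6.]
```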
Optax is used for optimization. Like PyTorch, Optax provides a variety of loss functions and optimizers, so users can easily rely on well-established implementations. See this tutorial: Learn Optax
Gradients can be used to solve optimization problems. The examples below show a couple of ways to obtain gradients or optimized values, with or without meent's predefined functions.
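Without meent's helpers, the same idea works with plain JAX: pass a dict of parameters (a JAX pytree) to jax.value_and_grad and the gradient comes back as a dict with the same keys and shapes, mirroring how grad['ucell'] and grad['thickness'] are indexed in the example that follows. The function toy_loss is a hypothetical stand-in, not meent's forward model or LossDeflector.

```python
import jax
import jax.numpy as jnp

# Hypothetical differentiable stand-in for "forward simulation + loss".
# Any JAX-traceable function of the parameters is handled the same way.
def toy_loss(params):
    ucell = params['ucell']
    thickness = params['thickness']
    return jnp.sum(jnp.sin(ucell)) * jnp.sum(jnp.cos(thickness / 1000.))

params = {
    'ucell': jnp.ones((2, 1, 10)),
    'thickness': jnp.array([500., 1000.]),
}

loss, grads = jax.value_and_grad(toy_loss)(params)

# Gradients have the same pytree structure as params
print(grads['ucell'].shape)      # (2, 1, 10)
print(grads['thickness'].shape)  # (2,)
```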
2.1.1 Examples
mee = meent.call_mee(backend=backend, pol=pol, n_top=n_top, n_bot=n_bot, theta=theta, phi=phi, fto=fto, wavelength=wavelength, period=period, ucell=ucell_1d_m, thickness=thickness, type_complex=type_complex)
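pois = ['ucell', 'thickness'] # parameters of interest: gradients are taken with respect to these
forward = mee.conv_solve # forward simulation
loss_fn = LossDeflector(x_order=1, y_order=0) # loss defined on the (1, 0) diffraction order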
# case 1: Gradient
grad = mee.grad(pois, forward, loss_fn)
print('ucell gradient:')
print(grad['ucell'])
print('thickness gradient:')
print(grad['thickness'])
ucell gradient:
[[[-0.05114874 -0.02533636 -0.00729883 0.07873582 -0.01841166
0.09447967 0.08779338 -0.0012304 -0.03640632 -0.04779842]]
[[-0.17959986 -0.08614187 -0.22233491 -0.19389416 0.08978906
0.05564021 -0.04575985 -0.13595162 -0.29835993 0.12867445]]]
thickness gradient:
[ 0.00222043 -0.00671415]
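One common way to sanity-check automatic-differentiation gradients, such as those printed above, is to compare them with central finite differences. The sketch below does this for a hypothetical toy function rather than meent's solver; applying the same check to a full simulation would require two forward solves per perturbed parameter.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # double precision keeps the check tight

def f(x):
    # Hypothetical toy function standing in for forward model + loss
    return jnp.sum(jnp.sin(x) * x ** 2)

def central_difference(f, x, eps=1e-6):
    # Perturb one coordinate at a time: (f(x + eps*e_i) - f(x - eps*e_i)) / (2 * eps)
    grads = []
    for i in range(x.size):
        e = jnp.zeros_like(x).at[i].set(eps)
        grads.append((f(x + e) - f(x - e)) / (2 * eps))
    return jnp.array(grads)

x = jnp.array([0.3, -1.2, 2.5])
autodiff = jax.grad(f)(x)
numeric = central_difference(f, x)
print(jnp.max(jnp.abs(autodiff - numeric)))  # expected to be tiny in float64
```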
2.2 Optimization
2.2.1 Examples
Example 1
mee = meent.call_mee(backend=backend, pol=pol, n_top=n_top, n_bot=n_bot, theta=theta, phi=phi, fto=fto, wavelength=wavelength, period=period, ucell=ucell_1d_m, thickness=thickness, type_complex=type_complex)
pois = ['ucell', 'thickness']
forward = mee.conv_solve
loss_fn = LossDeflector(x_order=1, y_order=0)
# case 2: optimization with SGD (momentum 0.9)
optimizer = optax.sgd(learning_rate=1e-2, momentum=0.9)
res = mee.fit(pois, forward, loss_fn, optimizer, iteration=3)
print('ucell final:')
print(res['ucell'])
print('thickness final:')
print(res['thickness'])
100%|██████████| 3/3 [00:05<00:00, 1.88s/it]
ucell final:
[[[1.00286423 1.00145549 1.00050169 4.99666797 5.00175318 4.99580863
4.99617526 1.00015109 1.00214635 1.00275083]]
[[5.0054235 4.99990456 5.00824621 5.0065062 0.99325253 4.99254125
4.99835018 5.00367578 5.01333396 4.9885967 ]]]
thickness final:
[ 499.99989253 1000.00039487]