Device Setting - CPU or GPU in JAX

Device Setting - CPU or GPU in JAX

import os

# Order GPUs by PCI bus id and expose only GPU 0 to CUDA. This has to happen
# before `meent` (and with it the JAX backend) is imported below.
os.environ.update(
    CUDA_DEVICE_ORDER="PCI_BUS_ID",
    CUDA_VISIBLE_DEVICES='0',
)

import time

import jax
import numpy as np

import meent
# experiment options
grating_type = 2  # not passed to meent.call_mee in the cells below
pol = 0  # 0: TE, 1: TM

n_top = 1  # refractive index of the incidence (top) medium
n_bot = 1  # refractive index of the transmission (bottom) medium

theta = 20 * np.pi / 180  # 20 degrees, in radians
phi = 50 * np.pi / 180  # 50 degrees, in radians

wavelength = 900

thickness = [500]  # one entry per layer of ucell
period = [1000, 1000]

# Fourier truncation order. `fto` is the keyword every meent.call_mee cell
# below passes (`fto=fto`), so it must be defined here; previously only
# `fourier_order` existed and those cells raised NameError.
fto = [15, 15]
# fto = [3, 3]
fourier_order = fto  # kept as an alias for backward compatibility
res_x, res_y, res_z = 20, 20, 20

# Unit cell: a 0/1 pattern of shape (1, 10, 10) mapped to the values 1 and 5
# (presumably refractive indices, like n_top / n_bot).
ucell = np.array([
            [
                [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, ],
                [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, ],
                [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, ],
                [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, ],
                [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, ],
                [0, 0, 0, 1, 1, 0, 0, 0, 0, 0, ],
                [0, 0, 0, 1, 1, 0, 0, 0, 0, 0, ],
                [0, 0, 0, 1, 1, 0, 0, 0, 0, 0, ],
                [0, 0, 0, 1, 1, 0, 0, 0, 0, 0, ],
                [0, 0, 0, 1, 1, 0, 0, 0, 0, 0, ],
            ],
        ]) * 4 + 1

Set device

  1. At initialization

# Pick the device when the solver object is created.
backend = 1  # JaxMeent
device = 0  # CPU
dtype = 0
mee = meent.call_mee(
    backend=backend,
    pol=pol,
    n_top=n_top,
    n_bot=n_bot,
    theta=theta,
    phi=phi,
    fto=fto,
    wavelength=wavelength,
    period=period,
    ucell=ucell,
    thickness=thickness,
    device=device,
    type_complex=dtype,
)

    
  2. After initialization

# Create the solver without a `device` argument, then switch the device by
# assigning the attribute afterwards.
backend = 1  # JaxMeent

mee = meent.call_mee(
    backend=backend,
    pol=pol,
    n_top=n_top,
    n_bot=n_bot,
    theta=theta,
    phi=phi,
    fto=fto,
    wavelength=wavelength,
    period=period,
    ucell=ucell,
    thickness=thickness,
    type_complex=dtype,
)
mee.device = 0  # CPU

Test

CPU, 64 bit

backend = 1  # JaxMeent
device = 0  # CPU
dtype = 0  # 64 bit
mee = meent.call_mee(backend=backend, pol=pol, n_top=n_top, n_bot=n_bot, theta=theta, phi=phi,
                     fto=fto, wavelength=wavelength, period=period, ucell=ucell,
                     thickness=thickness, device=device, type_complex=dtype)

# Each computation runs twice. The 1st run presumably includes JIT tracing and
# compilation; the 2nd reuses it. JAX dispatches work asynchronously, so every
# result is passed through jax.block_until_ready before the clock is read.
# Otherwise the measurement can capture only the dispatch.
t0 = time.perf_counter()
de_ri, de_ti = jax.block_until_ready(mee.conv_solve())
print('time for efficiency, 1st: ', time.perf_counter() - t0)

t0 = time.perf_counter()
de_ri, de_ti = jax.block_until_ready(mee.conv_solve())
print('time for efficiency, 2nd: ', time.perf_counter() - t0)

t0 = time.perf_counter()
field_cell = jax.block_until_ready(mee.calculate_field(res_x=res_x, res_y=res_y, res_z=res_z))
print('time for field, 1st: ', time.perf_counter() - t0)

t0 = time.perf_counter()
field_cell = jax.block_until_ready(mee.calculate_field(res_x=res_x, res_y=res_y, res_z=res_z))
print('time for field, 2nd: ', time.perf_counter() - t0)

t0 = time.perf_counter()
de_ri, de_ti, field_cell = jax.block_until_ready(
    mee.conv_solve_field(res_x=res_x, res_y=res_y, res_z=res_z))
print('time for efficiency and field in one step, 1st: ', time.perf_counter() - t0)

t0 = time.perf_counter()
de_ri, de_ti, field_cell = jax.block_until_ready(
    mee.conv_solve_field(res_x=res_x, res_y=res_y, res_z=res_z))
print('time for efficiency and field in one step, 2nd: ', time.perf_counter() - t0)
/home/yongha/anaconda3/envs/meent/lib/python3.10/site-packages/numpy/core/getlimits.py:500: UserWarning: The value of the smallest subnormal for <class 'numpy.float64'> type is zero.
  setattr(self, word, getattr(machar, word).flat[0])
/home/yongha/anaconda3/envs/meent/lib/python3.10/site-packages/numpy/core/getlimits.py:89: UserWarning: The value of the smallest subnormal for <class 'numpy.float64'> type is zero.
  return self._float_to_str(self.smallest_subnormal)
time for efficiency, 1st:  15.45072317123413
time for efficiency, 2nd:  11.552058219909668
time for field, 1st:  9.985004425048828
time for field, 2nd:  8.36226487159729
time for efficiency and field in one step, 1st:  22.077845573425293
time for efficiency and field in one step, 2nd:  9.886351823806763

GPU, 64 bit

backend = 1  # JaxMeent
device = 1  # GPU
dtype = 0  # 64 bit
mee = meent.call_mee(backend=backend, pol=pol, n_top=n_top, n_bot=n_bot, theta=theta, phi=phi,
                     fto=fto, wavelength=wavelength, period=period, ucell=ucell,
                     thickness=thickness, device=device, type_complex=dtype)

# Each computation runs twice. The 1st run presumably includes JIT tracing and
# compilation; the 2nd reuses it. JAX dispatches work asynchronously, especially
# on GPU, so every result is passed through jax.block_until_ready before the
# clock is read. Otherwise the measurement can capture only the dispatch.
t0 = time.perf_counter()
de_ri, de_ti = jax.block_until_ready(mee.conv_solve())
print('time for efficiency, 1st: ', time.perf_counter() - t0)

t0 = time.perf_counter()
de_ri, de_ti = jax.block_until_ready(mee.conv_solve())
print('time for efficiency, 2nd: ', time.perf_counter() - t0)

t0 = time.perf_counter()
field_cell = jax.block_until_ready(mee.calculate_field(res_x=res_x, res_y=res_y, res_z=res_z))
print('time for field, 1st: ', time.perf_counter() - t0)

t0 = time.perf_counter()
field_cell = jax.block_until_ready(mee.calculate_field(res_x=res_x, res_y=res_y, res_z=res_z))
print('time for field, 2nd: ', time.perf_counter() - t0)

t0 = time.perf_counter()
de_ri, de_ti, field_cell = jax.block_until_ready(
    mee.conv_solve_field(res_x=res_x, res_y=res_y, res_z=res_z))
print('time for efficiency and field in one step, 1st: ', time.perf_counter() - t0)

t0 = time.perf_counter()
de_ri, de_ti, field_cell = jax.block_until_ready(
    mee.conv_solve_field(res_x=res_x, res_y=res_y, res_z=res_z))
print('time for efficiency and field in one step, 2nd: ', time.perf_counter() - t0)
time for efficiency, 1st:  16.228806495666504
time for efficiency, 2nd:  9.38192367553711
time for field, 1st:  8.504980564117432
time for field, 2nd:  0.5013375282287598
time for efficiency and field in one step, 1st:  13.888967990875244
time for efficiency and field in one step, 2nd:  9.839607238769531

CPU, 32 bit

backend = 1  # JaxMeent
device = 0  # CPU
dtype = 1  # 32 bit
mee = meent.call_mee(backend=backend, pol=pol, n_top=n_top, n_bot=n_bot, theta=theta, phi=phi,
                     fto=fto, wavelength=wavelength, period=period, ucell=ucell,
                     thickness=thickness, device=device, type_complex=dtype)

# Each computation runs twice. The 1st run presumably includes JIT tracing and
# compilation; the 2nd reuses it. JAX dispatches work asynchronously, so every
# result is passed through jax.block_until_ready before the clock is read.
# Otherwise the measurement can capture only the dispatch.
t0 = time.perf_counter()
de_ri, de_ti = jax.block_until_ready(mee.conv_solve())
print('time for efficiency, 1st: ', time.perf_counter() - t0)

t0 = time.perf_counter()
de_ri, de_ti = jax.block_until_ready(mee.conv_solve())
print('time for efficiency, 2nd: ', time.perf_counter() - t0)

t0 = time.perf_counter()
field_cell = jax.block_until_ready(mee.calculate_field(res_x=res_x, res_y=res_y, res_z=res_z))
print('time for field, 1st: ', time.perf_counter() - t0)

t0 = time.perf_counter()
field_cell = jax.block_until_ready(mee.calculate_field(res_x=res_x, res_y=res_y, res_z=res_z))
print('time for field, 2nd: ', time.perf_counter() - t0)

t0 = time.perf_counter()
de_ri, de_ti, field_cell = jax.block_until_ready(
    mee.conv_solve_field(res_x=res_x, res_y=res_y, res_z=res_z))
print('time for efficiency and field in one step, 1st: ', time.perf_counter() - t0)

t0 = time.perf_counter()
de_ri, de_ti, field_cell = jax.block_until_ready(
    mee.conv_solve_field(res_x=res_x, res_y=res_y, res_z=res_z))
print('time for efficiency and field in one step, 2nd: ', time.perf_counter() - t0)
/home/yongha/anaconda3/envs/meent/lib/python3.10/site-packages/numpy/core/getlimits.py:500: UserWarning: The value of the smallest subnormal for <class 'numpy.float32'> type is zero.
  setattr(self, word, getattr(machar, word).flat[0])
/home/yongha/anaconda3/envs/meent/lib/python3.10/site-packages/numpy/core/getlimits.py:89: UserWarning: The value of the smallest subnormal for <class 'numpy.float32'> type is zero.
  return self._float_to_str(self.smallest_subnormal)
time for efficiency, 1st:  10.070057392120361
time for efficiency, 2nd:  5.8916335105896
time for field, 1st:  5.814962148666382
time for field, 2nd:  4.369853973388672
time for efficiency and field in one step, 1st:  7.839098691940308
time for efficiency and field in one step, 2nd:  3.287320137023926

GPU, 32 bit

backend = 1  # JaxMeent
device = 1  # GPU (the original comment here wrongly said CPU)
dtype = 1  # 32 bit
mee = meent.call_mee(backend=backend, pol=pol, n_top=n_top, n_bot=n_bot, theta=theta, phi=phi,
                     fto=fto, wavelength=wavelength, period=period, ucell=ucell,
                     thickness=thickness, device=device, type_complex=dtype)

# Each computation runs twice. The 1st run presumably includes JIT tracing and
# compilation; the 2nd reuses it. JAX dispatches work asynchronously, especially
# on GPU, so every result is passed through jax.block_until_ready before the
# clock is read. Otherwise the measurement can capture only the dispatch.
t0 = time.perf_counter()
de_ri, de_ti = jax.block_until_ready(mee.conv_solve())
print('time for efficiency, 1st: ', time.perf_counter() - t0)

t0 = time.perf_counter()
de_ri, de_ti = jax.block_until_ready(mee.conv_solve())
print('time for efficiency, 2nd: ', time.perf_counter() - t0)

t0 = time.perf_counter()
field_cell = jax.block_until_ready(mee.calculate_field(res_x=res_x, res_y=res_y, res_z=res_z))
print('time for field, 1st: ', time.perf_counter() - t0)

t0 = time.perf_counter()
field_cell = jax.block_until_ready(mee.calculate_field(res_x=res_x, res_y=res_y, res_z=res_z))
print('time for field, 2nd: ', time.perf_counter() - t0)

t0 = time.perf_counter()
de_ri, de_ti, field_cell = jax.block_until_ready(
    mee.conv_solve_field(res_x=res_x, res_y=res_y, res_z=res_z))
print('time for efficiency and field in one step, 1st: ', time.perf_counter() - t0)

t0 = time.perf_counter()
de_ri, de_ti, field_cell = jax.block_until_ready(
    mee.conv_solve_field(res_x=res_x, res_y=res_y, res_z=res_z))
print('time for efficiency and field in one step, 2nd: ', time.perf_counter() - t0)
time for efficiency, 1st:  7.133803606033325
time for efficiency, 2nd:  3.365471601486206
time for field, 1st:  5.952059507369995
time for field, 2nd:  0.2004835605621338
time for efficiency and field in one step, 1st:  7.279724836349487
time for efficiency and field in one step, 2nd:  3.8123528957366943