In [None]:
import gzip
import pickle
import time
from pathlib import Path

import numpy as np
import torch
from tqdm.auto import tqdm

import punica

In [None]:
# Nothing in this notebook calls backward(); disable autograd globally so
# tensor ops below do not record a graph (less overhead in the timings).
torch.set_grad_enabled(False)

In [None]:
def bench(
    f,
    *,
    device="cuda:0",
    min_repeat: int,
    min_secs: float,
    flush_bytes: int = int(256e6),
) -> np.ndarray:
    """Time ``f()`` repeatedly and return the per-call latencies in seconds.

    The first call is always recorded (callers in this notebook discard a
    leading fraction of samples via ``tail_mean_std``'s ``skip``). After that,
    calls continue until there are more than ``min_repeat`` samples in total
    AND more than ``min_secs`` have elapsed since the first call finished.

    Before every call after the first, a scratch buffer of ``flush_bytes``
    bytes on ``device`` is zeroed. Presumably this evicts the (L2) cache so
    that ``f`` does not benefit from data left warm by the previous iteration.

    Args:
        f: Zero-argument callable to time.
        device: Device the flush buffer lives on. When it is a CUDA device, that
            same device is synchronized before and after each call. No
            synchronization is done for other device types (e.g. ``"cpu"``).
        min_repeat: Keep sampling until ``len(latencies) > min_repeat``.
        min_secs: Minimum wall-clock budget for the repeated phase, in seconds.
        flush_bytes: Size of the flush buffer in bytes (default ``256e6``).

    Returns:
        1-D float ``np.ndarray`` of latencies in seconds; element 0 is the first call.
    """
    dev = torch.device(device)
    is_cuda = dev.type == "cuda"

    def _sync() -> None:
        # Synchronize the *target* device. A bare torch.cuda.synchronize() only
        # waits on the current device, which mis-times work on e.g. "cuda:1".
        if is_cuda:
            torch.cuda.synchronize(dev)

    def _time_once() -> float:
        # Wall-clock seconds for one f() call, with device work drained on both sides.
        _sync()
        st = time.perf_counter_ns()
        f()
        _sync()
        ed = time.perf_counter_ns()
        return (ed - st) / 1e9

    # torch.int is int32 -> 4 bytes per element.
    cache = torch.empty(int(flush_bytes // 4), dtype=torch.int, device=dev)

    # First run (not preceded by a flush).
    latency = [_time_once()]

    # Subsequent runs, until reaching both min_repeat and min_secs.
    min_nanos = int(min_secs * 1e9)
    start_nanos = time.perf_counter_ns()
    while not (len(latency) > min_repeat
               and time.perf_counter_ns() - start_nanos > min_nanos):
        cache.zero_()
        latency.append(_time_once())
    return np.array(latency)

def tail_mean_std(xs, skip=0.2):
    """Return ``(mean, std)`` of ``xs`` after dropping a leading fraction.

    The first ``int(len(xs) * skip)`` samples (e.g. warm-up runs) are ignored.
    ``xs`` must support slicing plus ``.mean()``/``.std()``. For NumPy arrays
    ``std`` is the population standard deviation (``ddof=0``).
    """
    n_drop = int(len(xs) * skip)
    tail = xs[n_drop:]
    mean, std = tail.mean(), tail.std()
    return mean, std

def fmt_avg_std(xs, skip=0.2):
    """Format the tail mean/std of latencies ``xs`` (seconds) in microseconds.

    Drops the first ``int(len(xs) * skip)`` samples, scales the rest by 1e6,
    and renders ``"<mean> us +/- <std> us"`` with three decimals each.
    """
    head = int(len(xs) * skip)
    micros = xs[head:] * 1e6
    avg, spread = micros.mean(), micros.std()
    return f"{avg:.3f} us +/- {spread:.3f} us"

In [None]:
def lora_bmm(
    y: torch.Tensor,  # (batch_size, 1, out_features)
    x: torch.Tensor,  # (batch_size, 1, in_features)
    A: torch.Tensor,  # (batch_size, in_features, lora_rank)
    B: torch.Tensor,  # (batch_size, lora_rank, out_features)
):
    """Add ``x @ A @ B`` into ``y`` in place, one LoRA pair per batch row.

    ``A``/``B`` are expected to be already gathered per row, so this is a
    plain batched matmul with no index lookup.
    """
    down = x @ A  # (batch_size, 1, lora_rank)
    y[:, :] += down @ B

In [None]:
def lora_loop(
    y: torch.Tensor,  # (batch_size, 1, out_features)
    x: torch.Tensor,  # (batch_size, 1, in_features)
    A: torch.Tensor,  # (num_loras, in_features, lora_rank)
    B: torch.Tensor,  # (num_loras, lora_rank, out_features)
    I: torch.LongTensor,  # (batch_size,)
):
    """Naive baseline: for each row ``i``, add ``x[i] @ A[I[i]] @ B[I[i]]`` into ``y[i]``.

    The index tensor is first copied to the host (``.cpu()``). This is a blocking
    device-to-host copy when ``I`` lives on a GPU. The loop then issues one
    small matmul chain per row.
    """
    host_idx = I.cpu().numpy()
    for row in range(len(host_idx)):
        lora_id = host_idx[row]
        y[row] += x[row] @ A[lora_id] @ B[lora_id]

In [None]:
def lora_gbmm(
    y: torch.Tensor,  # (batch_size, 1, out_features)
    x: torch.Tensor,  # (batch_size, 1, in_features)
    A: torch.Tensor,  # (num_loras, in_features, lora_rank)
    B: torch.Tensor,  # (num_loras, lora_rank, out_features)
    I: torch.LongTensor,  # (batch_size,)
):
    """Gather-then-batched-matmul: add ``x @ A[I] @ B[I]`` into ``y`` in place.

    Each row's LoRA weights are first materialized with ``index_select``
    (so the gather is part of the cost), then one batched matmul chain runs.
    """
    picked_A = A.index_select(0, I)  # (batch_size, in_features, lora_rank)
    picked_B = B.index_select(0, I)  # (batch_size, lora_rank, out_features)
    y[:, :] += (x @ picked_A) @ picked_B

In [None]:
def lora_punica(
    y: torch.Tensor,  # (batch_size, out_features)
    x: torch.Tensor,  # (batch_size, in_features)
    wa_T_all: torch.Tensor,  # (num_loras, num_layers, lora_rank, in_features)
    wb_T_all: torch.Tensor,  # (num_loras, num_layers, out_features, lora_rank)
    lora_indices: torch.LongTensor,  # (batch_size,)
):
  """Apply per-row LoRA via Punica's ``punica.ops.add_lora`` kernel.

  Unlike ``lora_bmm``/``lora_gbmm``, ``x``/``y`` are 2-D here. The weights are
  stored transposed, with an extra layer axis (see the shape comments above).
  ``layer_idx=0`` selects the single layer slot the benchmark builds with
  ``unsqueeze(1)``.

  Nothing is returned. Presumably ``add_lora`` updates ``y`` in place as
  ``y += scale * x @ A @ B``, with A/B chosen per row by ``lora_indices``.
  TODO: confirm against the punica docs; that contract is not visible here.
  """
  punica.ops.add_lora(y, x, wa_T_all, wb_T_all, lora_indices, layer_idx=0, scale=1.0)

In [None]:
def bench_lora_bs(
  *,
  num_loras: int = 50,
  h1: int = 4096,
  h2: int = 11008,
  r: int = 16,
  max_bs: int = 32,
  min_repeat: int = 20,
  min_secs: float = 2,
  device="cuda:0",
  dtype=torch.float16,
  check: bool = False,
):
  """Time four ways of applying per-row LoRA updates, for batch sizes 1..``max_bs``.

  Every strategy adds a rank-``r`` update ``x @ A @ B`` into ``y``. Each row
  uses one of ``num_loras`` adapters, picked with ``torch.randint``.
    * ``bmm``: batched matmul on pre-gathered per-row A/B (the gather is not timed).
    * ``loop``: per-row Python loop (``lora_loop``).
    * ``gbmm``: ``index_select`` gather, then batched matmul (``lora_gbmm``).
    * ``punica``: Punica's ``add_lora`` kernel (``lora_punica``).

  Args:
    num_loras: Number of distinct LoRA adapters.
    h1, h2: Input / output feature sizes.
    r: LoRA rank.
    max_bs: Largest batch size; batch sizes run over ``1..max_bs``.
    min_repeat, min_secs: Forwarded to ``bench`` for every measurement.
    device, dtype: Where and in what precision tensors are allocated.
    check: If True, assert before timing that all strategies agree with
      ``loop`` (``rtol=atol=1e-2``; fp16 may need looser tolerances).

  Returns:
    ``(bs_list, ret)``. ``ret[name]`` is ``{"avg": ndarray, "std": ndarray}``
    in seconds, with one entry per batch size (tail stats from ``tail_mean_std``).
  """
  torch.manual_seed(0xabcdabcd987)
  device = torch.device(device)

  wa_all = torch.randn(num_loras, h1, r, dtype=dtype, device=device)
  wb_all = torch.randn(num_loras, r, h2, dtype=dtype, device=device)
  # Punica layout: transposed with a singleton layer axis,
  # wa_T_all: (num_loras, 1, r, h1), wb_T_all: (num_loras, 1, h2, r).
  wa_T_all = wa_all.unsqueeze(1).transpose(-1, -2).contiguous()
  wb_T_all = wb_all.unsqueeze(1).transpose(-1, -2).contiguous()

  def _measure(fn):
    # Called immediately inside the loop, so closures see the current iteration's tensors.
    return tail_mean_std(bench(fn, device=device, min_repeat=min_repeat, min_secs=min_secs))

  bs_list = np.arange(1, max_bs + 1)
  res = dict(bmm=[], loop=[], gbmm=[], punica=[])
  for bs in tqdm(bs_list):
    x = torch.randn(bs, 1, h1, dtype=dtype, device=device)
    y = torch.randn(bs, 1, h2, dtype=dtype, device=device)
    indices = torch.randint(num_loras, (bs,), dtype=torch.long, device=device)
    a = torch.index_select(wa_all, 0, indices)
    b = torch.index_select(wb_all, 0, indices)

    # One untimed call of each strategy on a copy of y. The outputs are
    # only compared when check=True.
    y_bmm = y.clone()
    lora_bmm(y_bmm, x, a, b)

    y_loop = y.clone()
    lora_loop(y_loop, x, wa_all, wb_all, indices)

    y_gbmm = y.clone()
    lora_gbmm(y_gbmm, x, wa_all, wb_all, indices)

    x_punica = x.squeeze(1).clone()
    y_punica = y.squeeze(1).clone()
    lora_punica(y_punica, x_punica, wa_T_all, wb_T_all, indices)

    if check:
      torch.testing.assert_close(y_loop, y_bmm, rtol=1e-2, atol=1e-2)
      torch.testing.assert_close(y_loop, y_gbmm, rtol=1e-2, atol=1e-2)
      # Punica works on 2-D (batch, features); drop the singleton axis so shapes match.
      torch.testing.assert_close(y_loop.squeeze(1), y_punica, rtol=1e-2, atol=1e-2)

    res["bmm"].append(_measure(lambda: lora_bmm(y, x, a, b)))
    res["loop"].append(_measure(lambda: lora_loop(y, x, wa_all, wb_all, indices)))
    res["gbmm"].append(_measure(lambda: lora_gbmm(y, x, wa_all, wb_all, indices)))
    res["punica"].append(_measure(lambda: lora_punica(y_punica, x_punica, wa_T_all, wb_T_all, indices)))
  ret = {
    k: dict(avg=np.array([avg for avg, std in v]),
            std=np.array([std for avg, std in v]))
    for k, v in res.items()
  }
  return bs_list, ret

In [None]:
# Create the output directory *before* the run. A missing data/ then fails
# fast, instead of raising at pickle time and discarding several minutes of
# benchmarking.
lora_ops_path = Path("data/20230911-lora-ops.pkl.gz")
lora_ops_path.parent.mkdir(parents=True, exist_ok=True)
bs, res = bench_lora_bs()
with gzip.open(lora_ops_path, "wb") as f:
  pickle.dump((bs, res), f)

In [None]:
def bench_backbone_vs_lora(
  *,
  h1: int = 4096,
  h2: int = 11008,
  r: int = 16,
  max_bs: int = 32,
  min_repeat: int = 20,
  min_secs: float = 2,
  device="cuda:0",
  dtype=torch.float16,
):
  """Compare the dense backbone matmul against a per-row LoRA matmul chain.

  For each batch size ``bs`` in ``1..max_bs`` this times two things:
    * ``backbone``: ``x @ w`` with one shared ``(h1, h2)`` weight.
    * ``lora``: ``x @ wa @ wb`` with a distinct rank-``r`` pair per row.

  Args:
    h1, h2: Input / output feature sizes.
    r: LoRA rank.
    max_bs: Largest batch size; batch sizes run over ``1..max_bs``.
    min_repeat, min_secs: Forwarded to ``bench`` for every measurement.
    device, dtype: Where and in what precision tensors are allocated.

  Returns:
    ``(bs_list, ret)``. ``ret[name]`` is ``{"avg": ndarray, "std": ndarray}``
    in seconds, with one entry per batch size (tail stats from ``tail_mean_std``).
  """
  torch.manual_seed(0xabcdabcd987)
  device = torch.device(device)

  # The backbone weight does not depend on bs: build it once, not per iteration.
  w = torch.randn(h1, h2, dtype=dtype, device=device)

  def _measure(fn):
    # Called immediately inside the loop, so closures see the current iteration's tensors.
    return tail_mean_std(bench(fn, device=device, min_repeat=min_repeat, min_secs=min_secs))

  bs_list = np.arange(1, max_bs + 1)
  res = dict(backbone=[], lora=[])
  for bs in tqdm(bs_list):
    wa = torch.randn(bs, h1, r, dtype=dtype, device=device)
    wb = torch.randn(bs, r, h2, dtype=dtype, device=device)
    x = torch.randn(bs, 1, h1, dtype=dtype, device=device)

    res["backbone"].append(_measure(lambda: x @ w))
    res["lora"].append(_measure(lambda: x @ wa @ wb))
  ret = {
    k: dict(avg=np.array([avg for avg, std in v]),
            std=np.array([std for avg, std in v]))
    for k, v in res.items()
  }
  return bs_list, ret

In [None]:
# Create the output directory *before* the run. A missing data/ then fails
# fast, instead of raising at pickle time after the benchmark has finished.
backbone_vs_lora_path = Path("data/20230911-backbone-vs-lora.pkl.gz")
backbone_vs_lora_path.parent.mkdir(parents=True, exist_ok=True)
bs, res = bench_backbone_vs_lora()
with gzip.open(backbone_vs_lora_path, "wb") as f:
  pickle.dump((bs, res), f)