if (!requireNamespace('pacman', quietly=TRUE)) install.packages('pacman')
pacman::p_load(ggplot2, patchwork, tibble, tidyr, haven)
## Time grid (100 equally spaced points from 0.1 to 79.9) shared by the
## first set of panels.
time  <- seq(from=0.1, to=79.9, length.out=100)
## Baseline values for the parameters that are held at zero until a panel
## explicitly overrides them.
eta   <- 0
u2    <- 0
u1    <- 0
beta2 <- 0
## Cause-1 cumulative incidence function (CIF) evaluated on the full grid
## of every combination of the supplied argument values.
##
## Arguments (each of the first eight may be a vector):
##   time          time points; values must lie strictly between 0 and
##                 `delta` for atanh() to return a finite number
##   beta1, beta2  linear predictors entering the risks exp(beta + u)
##   gamma, w      shift and multiplier of the transformed time inside
##                 pnorm()
##   u1, u2        terms added to beta1 / beta2 before exponentiation
##   eta           term subtracted, together with gamma, inside pnorm()
##   delta         single positive number bounding the time scale
##                 (default 80, the value that used to be hard-coded)
##
## Value: a data.frame whose first column `out` holds the CIF and whose
## remaining columns (time, beta1, beta2, gamma, w, u1, u2, eta) record
## the grid point each row was evaluated at.
cif <- function(time, beta1, beta2, gamma, w, u1, u2, eta, delta=80)
{
    stopifnot(is.numeric(delta), length(delta) == 1, delta > 0)
    ## One row per combination of inputs; the first argument varies fastest.
    core <- expand.grid(time=time, beta1=beta1, beta2=beta2, gamma=gamma,
                        w=w, u1=u1, u2=u2, eta=eta)
    risk1 <- exp(core$beta1 + core$u1)
    risk2 <- exp(core$beta2 + core$u2)
    ## The leading 1 in the denominator keeps risk1/level below one.
    level <- 1 + risk1 + risk2
    ## Map (0, delta) onto the real line before applying the normal CDF.
    gtime <- atanh(2 * core$time/delta - 1)
    ftime <- pnorm(core$w * gtime - core$gamma - core$eta)
    out   <- risk1/level * ftime
    cbind(out, core)
}
## Line plot of the column `out` against `time`, one line type per distinct
## value of a chosen parameter column.
##
## Arguments:
##   data   data frame (or tibble) with columns `time`, `out` and the
##          column named by `par`
##   par    name, as a string, of the column that defines the lines
##   expar  legend title for the line types (e.g. a plotmath expression)
##
## Value: a ggplot object.
parplot <- function(data, par, expar)
{
    ## `.data[[par]]` looks the column up through ggplot2's data mask, so it
    ## works for tibbles as well as plain data frames.
    ggplot(data, aes(time, out, group=.data[[par]]))+
        geom_line(aes(linetype=factor(.data[[par]])))+
        labs(x='Time', y=NULL, linetype=expar)+
        ## theme_minimal()+
        theme(legend.position      ='bottom',
              legend.box.background=element_rect(color='black'),
              legend.title         =element_text(size=14),
              legend.text          =element_text(size=14),
              axis.text.x          =element_text(size=13),
              axis.text.y          =element_text(size=13),
              ## margin() is the constructor element_text() documents for
              ## its `margin` argument.
              axis.title.x         =element_text(
                  size=14, margin=margin(t=3, r=0, b=0, l=0, unit='mm')),
              axis.title.y         =element_text(
                  size=14, margin=margin(t=0, r=3, b=0, l=0, unit='mm')))
}
## Panel g1: three values of beta1 (which enters the cause-1 risk),
## with gamma and w held fixed.
gamma <- 0
w     <- 1
beta1 <- c(-0.5, 0, 0.5)
df <- cif(time=time, beta1=beta1, beta2=beta2, gamma=gamma, w=w,
          u1=u1, u2=u2, eta=eta)
g1 <- parplot(data=df, par='beta1', expar=expression(beta[1]))

## Panel g2: three values of gamma (the shift inside the normal CDF),
## with beta1 and w held fixed.
w     <- 1
beta1 <- 0
gamma <- c(-2, 0, 2)
df <- cif(time=time, beta1=beta1, beta2=beta2, gamma=gamma, w=w,
          u1=u1, u2=u2, eta=eta)
g2 <- parplot(data=df, par='gamma', expar=expression(gamma[1]))

## Panel g3: three values of w (the multiplier of the transformed time),
## with beta1 and gamma held fixed.
beta1 <- 0
gamma <- 0
w     <- c(-1, 0, 1)
df <- cif(time=time, beta1=beta1, beta2=beta2, gamma=gamma, w=w,
          u1=u1, u2=u2, eta=eta)
g3 <- parplot(data=df, par='w', expar=expression(w[1]))
## Panel g4: fixed beta/gamma/w, with four values of u1 (one facet each)
## crossed with five values of eta (one line type each).
beta1 <- -2
beta2 <- -1
gamma <- 1
w     <- 3
u1    <- c(-1, -0.5, 0.5, 1)
eta   <- c(-2, -1, 0, 1, 2)
df <- cif(time, beta1, beta2, gamma, w, u1, u2, eta)
## Series are told apart by line type only; no colour aesthetic is mapped,
## so no colour scale is attached (theme_minimal() remains an optional look).
g4 <-
    ggplot(df, aes(time, out, group=eta))+
    geom_line(aes(linetype=factor(eta)))+
    facet_wrap(~u1, labeller=label_bquote(u[1] : .(u1)), nrow=1)+
    labs(x='Time', y=NULL, linetype=expression(eta[1]))+
    theme(legend.position      ='bottom',
          legend.title         =element_text(size=14),
          legend.text          =element_text(size=14),
          legend.box.background=element_rect(color='black'),
          strip.background     =element_rect(colour='black', fill='white'),
          strip.text.x         =element_text(size=14),
          axis.text.x          =element_text(size=13),
          axis.text.y          =element_text(size=13),
          ## margin() is the constructor element_text() documents for its
          ## `margin` argument.
          axis.title.x         =element_text(
              size=14,
              margin=margin(t=3, r=0, b=0, l=0, unit='mm')),
          axis.title.y         =element_text(
              size=14,
              margin=margin(t=0, r=3, b=0, l=0, unit='mm')))
## Composite figure: g1, g2, g3 on top of g4. `&` has lower precedence than
## `+` and `/`, so the title theme is applied to the whole patchwork.
(g1+g2+g3)/g4+
    plot_annotation(
        title='Cluster-specific Cumulative Incidence Function (CIF)'
    )&theme(plot.title=element_text(size=16))

## Cumulative incidence functions of two competing causes, in long format.
##
## NOTE: this definition replaces the grid-based cif(time, beta1, ...)
## defined earlier in the file; from here on `cif` refers to this version.
##
## Arguments (all optional; the defaults reproduce the values that used to
## be hard-coded, so `cif()` behaves as before):
##   t      time points; values must lie strictly between 0 and `delta`
##          for atanh() to return a finite number
##   delta  number bounding the time scale
##   beta, gamma, w, u, eta
##          length-2 numeric vectors; element 1 belongs to cause 1 and
##          element 2 to cause 2
##
## Value: a tibble with columns `t`, `cif` ("cif1" or "cif2") and `value`.
cif <- function(t = seq(30, 79.5, by = 0.5), delta = 80,
                beta = c(-2, -1.5), gamma = c(1.2, 1), w = c(3, 5),
                u = c(0, 0), eta = c(0, 0)) {
    risk1 <- exp(beta[1] + u[1])
    risk2 <- exp(beta[2] + u[2])
    level <- 1 + risk1 + risk2
    ## Transformed time shared by both causes.
    gtime <- atanh(2 * t/delta - 1)
    wide <- tibble::tibble(
        t    = t,
        cif1 = risk1/level * pnorm(w[1] * gtime - gamma[1] - eta[1]),
        cif2 = risk2/level * pnorm(w[2] * gtime - gamma[2] - eta[2]))
    ## Native pipe, as used elsewhere in this file, instead of `%>%`.
    wide |>
        tidyr::pivot_longer(cols = c("cif1", "cif2"),
                            names_to = "cif", values_to = "value")
}
## Long-format result of the zero-argument cif() call, reused by the
## plots built further down (p1 and p3).
dfcif <- cif()
## Time derivatives of the two cause-specific cumulative incidence
## functions, in long format.
##
## Arguments (all optional; the defaults reproduce the values that used to
## be hard-coded, so `dcif()` behaves as before):
##   t      time points; values must lie strictly between 0 and `delta`,
##          otherwise the Jacobian below divides by zero or atanh() is
##          not finite
##   delta  number bounding the time scale
##   beta, gamma, w, u, eta
##          length-2 numeric vectors; element 1 belongs to cause 1 and
##          element 2 to cause 2
##
## Value: a tibble with columns `t`, `dcif` ("dcif1" or "dcif2") and
## `value`.
dcif <- function(t = seq(30, 79.5, by = 0.5), delta = 80,
                 beta = c(-2, -1.5), gamma = c(1.2, 1), w = c(3, 5),
                 u = c(0, 0), eta = c(0, 0)) {
    risk1 <- exp(beta[1] + u[1])
    risk2 <- exp(beta[2] + u[2])
    level <- 1 + risk1 + risk2
    gtime <- atanh(2 * t/delta - 1)
    ## d/dt atanh(2 * t/delta - 1) = delta/(2 * t * (delta - t)).
    jacobian <- delta/(2 * t * (delta - t))
    wide <- tibble::tibble(
        t     = t,
        dcif1 = risk1/level *
            (w[1] * jacobian * dnorm(w[1] * gtime - gamma[1] - eta[1])),
        dcif2 = risk2/level *
            (w[2] * jacobian * dnorm(w[2] * gtime - gamma[2] - eta[2])))
    ## Native pipe, as used elsewhere in this file, instead of `%>%`.
    wide |>
        tidyr::pivot_longer(cols = c("dcif1", "dcif2"),
                            names_to = "dcif", values_to = "value")
}
## Long-format derivative curves from the zero-argument dcif() call.
dfdcif <- dcif()
## p1: the two cumulative incidence curves, told apart by line type.
p1 <- ggplot(dfcif, aes(t, value, group = cif)) +
    geom_line(aes(linetype = cif)) +
    labs(x = "Time", y = NULL) +
    scale_linetype_discrete(labels = c("CIF 1", "CIF 2")) +
    theme(legend.position = "bottom",
          legend.title = element_blank(),
          legend.text = element_text(size = 14),
          axis.text.x = element_text(size = 13),
          axis.text.y = element_text(size = 13),
          ## margin() is the constructor element_text() documents for its
          ## `margin` argument.
          axis.title.x = element_text(
              size = 14,
              margin = margin(t = 3, r = 0, b = -2, l = 0, unit = "mm")),
          legend.box.background=element_rect(color='black'))
## p2: same layout as p1, for the derivative curves.
p2 <- ggplot(dfdcif, aes(t, value, group = dcif)) +
    geom_line(aes(linetype = dcif)) +
    labs(x = "Time", y = NULL) +
    scale_linetype_discrete(labels = c("dCIF 1", "dCIF 2")) +
    theme(legend.position = "bottom",
          legend.title = element_blank(),
          legend.text = element_text(size = 14),
          axis.text.x = element_text(size = 13),
          axis.text.y = element_text(size = 13),
          axis.title.x = element_text(
              size = 14,
              margin = margin(t = 3, r = 0, b = -2, l = 0, unit = "mm")),
          legend.box.background=element_rect(color='black'))
## p3: stacked areas of CIF 1, CIF 2 and the remainder 1 - CIF 1 - CIF 2
## (labelled 'Censorship'), so the three areas add up to one at every t.
p3 <- dfcif|>
    tidyr::pivot_wider(names_from=cif, values_from=value)|>
    dplyr::rename('CIF 1'='cif1', 'CIF 2'='cif2')|>
    dplyr::mutate(Censorship=(1-`CIF 1`-`CIF 2`))|>
    tidyr::pivot_longer(!t, names_to='cif', values_to='value')|>
    ## The level order fixes both the stacking order and the colour order.
    dplyr::mutate(cif=factor(cif,
                             levels=c('Censorship', 'CIF 2', 'CIF 1')))|>
    ggplot(aes(t, value, fill=cif))+
    ## `linewidth` is the outline-width argument; `size` for line widths
    ## has been deprecated since ggplot2 3.4.0.
    geom_area(alpha=1, linewidth=1, color='white')+
    labs(x='Time')+
    scale_fill_manual(values=c('#ff4c00', '#262626', '#0096c8'))+
    theme(legend.position = "bottom",
          legend.title = element_blank(),
          legend.text = element_text(size = 14),
          axis.text.x = element_text(size = 13),
          axis.text.y = element_text(size = 13),
          ## margin() is the constructor element_text() documents for its
          ## `margin` argument.
          axis.title.x=element_text(
              size=14,
              margin=margin(t=3, r=0, b=-2, l=0, unit="mm")),
          axis.title.y=element_blank(),
          legend.box.background=element_rect(color='black'))+
    guides(fill=guide_legend(reverse=TRUE, nrow=2))

## Composite figure: the three panels combined with patchwork.
p1 + p2 + p3

## Tile diagram of four 4x4 symmetric layouts over (u1, u2, e1, e2), one
## facet per layout. Each cell carries a label (v* on the diagonal, c* off
## it) or NA where the layout has no entry; equal labels get equal fills.
lev <- c('u1', 'u2', 'e1', 'e2')
## expand.grid() varies Var1 fastest, so every group of four labels below
## fills one value of Var2.
tibble::tibble(expand.grid(lev, lev),
               'Risk model'=c('v1', 'c1', NA, NA,
                              'c1', 'v2', NA, NA,
                              NA, NA, NA, NA,
                              NA, NA, NA, NA),
               'Time model'=c(NA, NA, NA, NA,
                              NA, NA, NA, NA,
                              NA, NA, 'v3', 'c2',
                              NA, NA, 'c2', 'v4'),
               'Block-diag model'=c('v1', 'c1', NA, NA,
                                    'c1', 'v2', NA, NA,
                                    NA, NA, 'v3', 'c2',
                                    NA, NA, 'c2', 'v4'),
               'Complete model'=c('v1', 'c1', 'c3', 'c4',
                                  'c1', 'v2', 'c5', 'c6',
                                  'c3', 'c4', 'v3', 'c2',
                                  'c5', 'c6', 'c2', 'v4'))|>
    tidyr::pivot_longer(`Risk model`:`Complete model`,
                        names_to='type')|>
    ## Reverse the y levels so the first row of the layout is drawn on top.
    ## factor(levels=unique()) keeps the facets in order of first
    ## appearance, which is all haven::as_factor() was used for here.
    dplyr::mutate(Var2=factor(Var2, levels=lev[4:1]),
                  type=factor(type, levels=unique(type)))|>
    ggplot(aes(x=Var1, y=Var2, fill=value))+
    geom_tile(show.legend=FALSE)+
    facet_wrap(~ type, nrow=1)+
    labs(x=NULL, y=NULL)+
    scale_x_discrete(labels=c('u1'=expression(u[1]),
                              'u2'=expression(u[2]),
                              'e1'=expression(eta[1]),
                              'e2'=expression(eta[2])))+
    scale_y_discrete(labels=c('u1'=expression(u[1]),
                              'u2'=expression(u[2]),
                              'e1'=expression(eta[1]),
                              'e2'=expression(eta[2])))+
    scale_fill_viridis_d(option='C')+
    theme(strip.background=element_blank(),
          ## margin() is the constructor element_text() documents for its
          ## `margin` argument.
          strip.text.x=element_text(
              size=15,
              margin=margin(t=0, r=0, b=3, l=0, unit='mm')),
          axis.text.x=element_text(size=15),
          axis.text.y=element_text(size=15))