首页 > 解决方案 > 如何在 R Shiny 中访问反应式表达式内部创建的变量

问题描述

我刚开始在 Shiny 中使用反应式表达式。我的问题是:我发现自己不得不创建许多代码几乎完全相同的反应式表达式,它们只在最后几行不同,目的仅仅是在调用这些反应式表达式时输出不同的内容。

例如,在下面的代码中,我创建了 3 个反应式表达式:p(),输出一幅绘图;best(),在所选模型中指出测试集 MAPE 误差最小的那个;最后是 results(),输出每个所选模型在测试集上的 RMSE 和 MAPE 误差指标。

正如您在下面看到的,这些反应式表达式除了最后几行之外,代码几乎相同。所以我的问题是:如何访问我在反应式表达式中创建的变量?例如,在创建反应式表达式 p() 之后,如何访问其中的 holt_forecast_11、ets_forecast_11、arima_forecast_11 和 tbats_forecast_11?又如何在另一个反应式表达式中调用这些变量?

如果您需要更多详细信息,我很乐意提供。

这是我 server.R 文件中的代码。在此代码下方我也附上了 ui.R 以供参考,尽管我的问题只与 server.R 有关。

library(shiny)
library(rsconnect)
library(tidyverse)
library(tidymodels)
library(lubridate)
library(forecast)
library(fpp3)
library(pwt10)
options(scipen=999)
Sys.setenv(LANG = "en")


# Server logic as originally written in the question: p(), best() and
# results() each repeat the same data preparation and model fitting and differ
# only in their final lines, so every selected model is refitted once per
# reactive.
shinyServer(function(input, output) {
    
    countries_isocode <- c("ARG", "BRA", "CHL", "COL", "MEX", "VEN")
    countries <- c("Argentina", "Brazil", "Chile", "Colombia", "Mexico", "Venezuela")
    
    # Plot of the full series with one 11-step forecast layer (PI = FALSE)
    # per selected model.
    # Everything assigned inside this expression (df, ts, holt_forecast_11, ...)
    # is local to it; other code only sees the returned plot via p().
    p <- reactive({
        
        # Wide table with one column per country, reduced to year plus the
        # column of the selected country.
        df <- pwt10.0 %>% 
            as_tibble() %>%
            filter(isocode %in% countries_isocode) %>% 
            dplyr::select(year, country, rgdpo) %>% 
            spread(key =
                       country, value = rgdpo) %>% 
            rename(Venezuela = `Venezuela (Bolivarian Republic of)`) %>% 
            dplyr::select(year, input$country)
        
        # Annual series 1950-2019 from the selected country's column.
        ts <- ts(df[,2], freq = 1, start = c(1950), end = c(2019))
        
        # Chile only: na.interp() fills missing values (presumably the Chile
        # series has gaps).
        if(input$country == "Chile") {
            ts <- na.interp(ts)
        }
        
        # Train on years up to 2014; h is the number of held-out years after
        # that (2015-2019).
        train <- window(ts, end = c(2014))
        h <- length(ts) - length(train)
        
        # For each selected model: *_forecast spans the hold-out period and
        # *_forecast_11 spans 11 steps. Only *_forecast_11 is used in this
        # reactive (for the plot). holt() is the only fitter given a horizon.
        if("holt" %in% input$model) {
            holt_model <- holt(train, h = 11)
            holt_forecast <- forecast(holt_model, h = h)
            holt_forecast_11 <- forecast(holt_model, h = 11) 
        }
        
        if("ets" %in% input$model) {
            ets_model <- ets(train)
            ets_forecast <- forecast(ets_model, h = h)
            ets_forecast_11 <- forecast(ets_model, h = 11)
        }

        if("arima" %in% input$model) {
            arima_model <- auto.arima(train)
            arima_forecast <- forecast(arima_model, h = h)
            arima_forecast_11 <- forecast(arima_model, h = 11)
        }
        
        if("tbats" %in% input$model) {
            tbats_model <- tbats(train)
            tbats_forecast <- forecast(tbats_model, h = h)
            tbats_forecast_11 <- forecast(tbats_model, h = 11)
        }
        
        # Start from the observed series and add one layer per selected model.
        p <- autoplot(ts)
        
            if("holt" %in% input$model) {
                p <- p + autolayer(holt_forecast_11, series = "HOLT", PI = FALSE)
            }
        
            if("ets" %in% input$model) {
                p <- p + autolayer(ets_forecast_11, series = "ETS", PI = FALSE)
            }
        
            if("arima" %in% input$model) {
                p <- p + autolayer(arima_forecast_11, series = "ARIMA", PI = FALSE)
            }
        
            if("tbats" %in% input$model) {
                p <- p + autolayer(tbats_forecast_11, series = "TBATS", PI = FALSE) 
            }
        
        p
    })
    
    
    # Name of the selected model with the lowest test-set MAPE.
    best <- reactive({
        
        # Same data preparation and model fitting as in p() above.
        df <- pwt10.0 %>% 
            as_tibble() %>%
            filter(isocode %in% countries_isocode) %>% 
            dplyr::select(year, country, rgdpo) %>% 
            spread(key =
                       country, value = rgdpo) %>% 
            rename(Venezuela = `Venezuela (Bolivarian Republic of)`) %>% 
            dplyr::select(year, input$country)
        
        ts <- ts(df[,2], freq = 1, start = c(1950), end = c(2019))
        
        if(input$country == "Chile") {
            ts <- na.interp(ts)
        }
        
        train <- window(ts, end = c(2014))
        h <- length(ts) - length(train)
        
        if("holt" %in% input$model) {
            holt_model <- holt(train, h = 11)
            holt_forecast <- forecast(holt_model, h = h)
            holt_forecast_11 <- forecast(holt_model, h = 11) 
        }
        
        if("ets" %in% input$model) {
            ets_model <- ets(train)
            ets_forecast <- forecast(ets_model, h = h)
            ets_forecast_11 <- forecast(ets_model, h = 11)
        }
        
        if("arima" %in% input$model) {
            arima_model <- auto.arima(train)
            arima_forecast <- forecast(arima_model, h = h)
            arima_forecast_11 <- forecast(arima_model, h = 11)
        }
        
        if("tbats" %in% input$model) {
            tbats_model <- tbats(train)
            tbats_forecast <- forecast(tbats_model, h = h)
            tbats_forecast_11 <- forecast(tbats_model, h = 11)
        }
        
        ### RMSE
        # Test-set RMSE of each hold-out forecast, named by series label.
        
        RMSE <- vector("numeric")
        
        if("holt" %in% input$model) {
            RMSE <- append(RMSE, c(HOLT = accuracy(holt_forecast, ts)["Test set","RMSE"]))
        }
        if("ets" %in% input$model) {
            RMSE <- append(RMSE, c(ETS = accuracy(ets_forecast, ts)["Test set","RMSE"]))
        }
        if("arima" %in% input$model) {
            RMSE <- append(RMSE, c(ARIMA = accuracy(arima_forecast, ts)["Test set","RMSE"]))
        }
        if("tbats" %in% input$model) {
            RMSE <- append(RMSE, c(TBATS = accuracy(tbats_forecast, ts)["Test set","RMSE"]))
        }
        
        ### MAPE
        # Test-set MAPE of each hold-out forecast, named by series label.
        
        MAPE <- vector("numeric")
        
        if("holt" %in% input$model) {
            MAPE <- append(MAPE, c(HOLT = accuracy(holt_forecast, ts)["Test set","MAPE"]))
        }
        if("ets" %in% input$model) {
            MAPE <- append(MAPE, c(ETS = accuracy(ets_forecast, ts)["Test set","MAPE"]))
        }
        if("arima" %in% input$model) {
            MAPE <- append(MAPE, c(ARIMA = accuracy(arima_forecast, ts)["Test set","MAPE"]))
        }
        if("tbats" %in% input$model) {
            MAPE <- append(MAPE, c(TBATS = accuracy(tbats_forecast, ts)["Test set","MAPE"]))
        }
        
        # df is reused for the metrics table: row 1 is RMSE, row 2 is MAPE.
        df <- as.data.frame(rbind(RMSE, MAPE))
        
        # order(...)[1] is the column with the smallest MAPE (row 2).
        names(df)[order(df[2,])[1]]
    })

    
    # Test-set RMSE and MAPE per selected model, as a data.frame with rows
    # RMSE and MAPE and one column per model.
    results <- reactive({
        
        # Same data preparation and model fitting as in p() and best() above.
        df <- pwt10.0 %>% 
            as_tibble() %>%
            filter(isocode %in% countries_isocode) %>% 
            dplyr::select(year, country, rgdpo) %>% 
            spread(key =
                       country, value = rgdpo) %>% 
            rename(Venezuela = `Venezuela (Bolivarian Republic of)`) %>% 
            dplyr::select(year, input$country)
        
        ts <- ts(df[,2], freq = 1, start = c(1950), end = c(2019))
        
        if(input$country == "Chile") {
            ts <- na.interp(ts)
        }
        
        train <- window(ts, end = c(2014))
        h <- length(ts) - length(train)
        
        if("holt" %in% input$model) {
            holt_model <- holt(train, h = 11)
            holt_forecast <- forecast(holt_model, h = h)
            holt_forecast_11 <- forecast(holt_model, h = 11) 
        }
        
        if("ets" %in% input$model) {
            ets_model <- ets(train)
            ets_forecast <- forecast(ets_model, h = h)
            ets_forecast_11 <- forecast(ets_model, h = 11)
        }
        
        if("arima" %in% input$model) {
            arima_model <- auto.arima(train)
            arima_forecast <- forecast(arima_model, h = h)
            arima_forecast_11 <- forecast(arima_model, h = 11)
        }
        
        if("tbats" %in% input$model) {
            tbats_model <- tbats(train)
            tbats_forecast <- forecast(tbats_model, h = h)
            tbats_forecast_11 <- forecast(tbats_model, h = 11)
        }
        
        ### RMSE
        
        RMSE <- vector("numeric")
        
        if("holt" %in% input$model) {
            RMSE <- append(RMSE, c(HOLT = accuracy(holt_forecast, ts)["Test set","RMSE"]))
        }
        if("ets" %in% input$model) {
            RMSE <- append(RMSE, c(ETS = accuracy(ets_forecast, ts)["Test set","RMSE"]))
        }
        if("arima" %in% input$model) {
            RMSE <- append(RMSE, c(ARIMA = accuracy(arima_forecast, ts)["Test set","RMSE"]))
        }
        if("tbats" %in% input$model) {
            RMSE <- append(RMSE, c(TBATS = accuracy(tbats_forecast, ts)["Test set","RMSE"]))
        }
        
        ### MAPE
        
        MAPE <- vector("numeric")
        
        if("holt" %in% input$model) {
            MAPE <- append(MAPE, c(HOLT = accuracy(holt_forecast, ts)["Test set","MAPE"]))
        }
        if("ets" %in% input$model) {
            MAPE <- append(MAPE, c(ETS = accuracy(ets_forecast, ts)["Test set","MAPE"]))
        }
        if("arima" %in% input$model) {
            MAPE <- append(MAPE, c(ARIMA = accuracy(arima_forecast, ts)["Test set","MAPE"]))
        }
        if("tbats" %in% input$model) {
            MAPE <- append(MAPE, c(TBATS = accuracy(tbats_forecast, ts)["Test set","MAPE"]))
        }
        
        df <- as.data.frame(rbind(RMSE, MAPE))
        
        df
    })
    
    
    # Draw the plot built by p().
    output$plot <- renderPlot({
        p()
    })

    # Text report: calls best() and results(), each of which refits every
    # selected model independently of p().
    output$results <- renderPrint({
        print(paste("According to the MAPE, the best model is:", best()))
        print("The final results are:")
        results()
    })
    
})

下面是 ui.R:

library(shiny)
library(rsconnect)
library(tidyverse)
library(tidymodels)
library(lubridate)
library(forecast)
library(fpp3)
library(pwt10)
options(scipen=999)
Sys.setenv(LANG = "en")


# Three-letter country codes (matched against pwt10.0's isocode column) and
# the country names offered in the country picker, in the same order.
countries_isocode <- c("ARG", "BRA", "CHL", "COL", "MEX", "VEN")
countries <- c("Argentina", "Brazil", "Chile", "Colombia", "Mexico", "Venezuela")

# Exploratory plot of the raw data (commented out, not part of the app).
# isocode holds three-letter codes, so it must be filtered with
# countries_isocode; filtering it with the country names matches no rows.
# pwt10.0 %>% 
#     as_tibble() %>% 
#     filter(isocode %in% countries_isocode) %>%
#     ggplot(aes(year, rgdpo, color = isocode)) + 
#     geom_line() +
#     labs(x = "Year", y = "Output-side real GDP at chained PPPs (in million 2017 USD)", color = "Country")


# Layout: the sidebar holds the country picker (input$country) and the model
# checkboxes (input$model); the main panel shows the server's output$plot and
# output$results.
shinyUI(fluidPage(
    titlePanel("Time Series Prediction Application"),
    sidebarLayout(
        sidebarPanel(
            selectInput("country", "Select a country:",
                        choices = countries, selected = "Brazil"),
            # choiceValues are the model keys the server tests for with
            # `%in% input$model`; choiceNames are only display labels.
            checkboxGroupInput("model", "Select time series models to evaluate:",
                               choiceNames = list("Holt's Trend Method", "ETS", "ARIMA", "TBATS"),
                               choiceValues = list("holt", "ets", "arima", "tbats"),
                               selected = c("ets", "arima"))
        ),
        mainPanel(
            plotOutput("plot"),
            verbatimTextOutput("results")
        )
    )
))

谢谢!

标签: r, shiny, time-series, forecasting, reactive

解决方案


重构此代码的一种方法是编写两个函数:一个检索所选国家/地区的数据(即下面代码中的 get_data),另一个用所选模型进行预测(即 get_forecasts,它把 get_data 输出的列表作为输入之一)。

使用这些函数,shinyServer 函数的逻辑很简单:获取所选国家的数据,使用所选模型进行预测,并显示绘图和结果。


library(shiny)
library(rsconnect)
library(tidyverse)
library(tidymodels)
library(lubridate)
library(forecast)
library(fpp3)
library(pwt10)
options(scipen=999)
Sys.setenv(LANG = "en")

# Three-letter codes of the countries offered by the app, as they appear in
# pwt10.0's isocode column.
countries_isocode <- c(
    "ARG", "BRA", "CHL",
    "COL", "MEX", "VEN"
)
# Country names, position-for-position with countries_isocode.
countries <- c(
    "Argentina", "Brazil", "Chile",
    "Colombia", "Mexico", "Venezuela"
)

# function to retrieve data for a country
# input: country - country name as listed in `countries` (e.g. "Brazil")
# output: list with components
#   df    - tibble with columns year and <country> (rgdpo values)
#   ts    - annual ts for 1950-2019; for Chile, missing values filled by na.interp()
#   train - ts restricted to years up to 2014
#   h     - number of hold-out years after train (length(ts) - length(train))
get_data <- function(country) {
    # pwt10.0 has its own `country` column. Copy the argument into a name that
    # cannot collide with it inside the tidyselect calls below.
    selected_country <- country
    df <- pwt10.0 %>%
        as_tibble() %>%
        filter(isocode %in% countries_isocode) %>%
        dplyr::select(year, country, rgdpo) %>%
        # One column per country (pivot_wider() supersedes spread()).
        pivot_wider(names_from = country, values_from = rgdpo) %>%
        rename(Venezuela = `Venezuela (Bolivarian Republic of)`) %>%
        # ts() below relies on the rows being in chronological order.
        dplyr::arrange(year) %>%
        # all_of() selects by the string's value. A bare external variable
        # here relies on tidyselect's deprecated env-vector fallback.
        dplyr::select(year, dplyr::all_of(selected_country))
    # df[[2]] is a plain numeric vector, giving a univariate series.
    # `frequency` is spelled out rather than partially matched via `freq`.
    ts <- ts(df[[2]], frequency = 1, start = 1950, end = 2019)
    if (country == "Chile") {
        ts <- na.interp(ts)
    }
    train <- window(ts, end = 2014)
    h <- length(ts) - length(train)
    list(df = df,
         ts = ts,
         train = train,
         h = h)
}

# Look up the fitting function and the plot/series label for a model key.
# input: model - one of 'holt', 'ets', 'arima', 'tbats'
# output: list(fn = <fitting function>, seriesname = <label such as "ETS">),
#         or an empty list() when the key is not one of the four above
get_forecast_seriesname <- function(model) {
    if (model == "holt") {
        list(fn = holt, seriesname = "HOLT")
    } else if (model == "ets") {
        list(fn = ets, seriesname = "ETS")
    } else if (model == "arima") {
        list(fn = auto.arima, seriesname = "ARIMA")
    } else if (model == "tbats") {
        list(fn = tbats, seriesname = "TBATS")
    } else {
        list()
    }
}

# function to get forecasts
# inputs: g      - output of get_data() (uses components ts, train and h)
#         models - character vector of model keys understood by
#                  get_forecast_seriesname(); may be empty or NULL
# output: list with components
#         p       - autoplot of g$ts with an 11-step forecast layer per model
#         results - data.frame with rows RMSE and MAPE (test-set accuracy),
#                   one column per model, named by series label
#         best    - series label of the model with the lowest test-set MAPE,
#                   or NA_character_ when there is nothing to rank
get_forecasts <- function(g, models) {
    n_models <- length(models)
    # Preallocate instead of growing vectors inside the loop.
    RMSE <- numeric(n_models)
    MAPE <- numeric(n_models)
    series_names <- character(n_models)
    p <- autoplot(g$ts)
    for (i in seq_len(n_models)) {
        model <- models[[i]]
        spec <- get_forecast_seriesname(model)
        if (is.null(spec$fn)) {
            stop("Unknown model: ", model, call. = FALSE)
        }
        # holt() is given its horizon up front; the other fitters only
        # receive the training data.
        this_model <- if (model == "holt") {
            spec$fn(g$train, h = 11)
        } else {
            spec$fn(g$train)
        }
        # Forecast the hold-out period and score it against the full series.
        # Accuracy is computed once per model.
        this_forecast <- forecast(this_model, h = g$h)
        acc <- accuracy(this_forecast, g$ts)["Test set", c("RMSE", "MAPE")]
        RMSE[i] <- acc[["RMSE"]]
        MAPE[i] <- acc[["MAPE"]]
        series_names[i] <- spec$seriesname
        # The plot shows an 11-step forecast without prediction intervals.
        this_forecast_11 <- forecast(this_model, h = 11)
        p <- p + autolayer(this_forecast_11,
                           series = spec$seriesname,
                           PI = FALSE)
    }
    names(RMSE) <- series_names
    names(MAPE) <- series_names
    results <- as.data.frame(rbind(RMSE, MAPE))
    # which.min() returns the first smallest non-NA MAPE. It is empty when no
    # model was run, which is mapped to NA_character_.
    best_idx <- which.min(MAPE)
    best <- if (length(best_idx) == 1L) names(MAPE)[best_idx] else NA_character_
    list(p = p, results = results, best = best)
}


# Server: fetch the selected country's data, run the selected models once and
# feed both outputs from that single cached reactive result.
shinyServer(function(input, output) {

    # Data for the currently selected country; depends only on input$country.
    country_data <- reactive(get_data(input$country))

    # Plot, accuracy table and best-model label for the current selection.
    model_output <- reactive({
        get_forecasts(country_data(), input$model)
    })

    output$plot <- renderPlot(model_output()$p)

    output$results <- renderPrint({
        out <- model_output()
        print(paste("According to the MAPE, the best model is:", out$best))
        print("The final results are:")
        out$results
    })

})

推荐阅读