Run Dash callbacks in parallel

Hello,

I am currently hosting my Dash web app on a Windows 10 VM, and I would like to know whether I can run 3 different callbacks, which rely on 3 different dataframes, in parallel. The callbacks fire when the user selects an option from a dropdown menu, and each one updates a different chart.

My web-app looks like this …

As soon as the user has made a selection, I do the following in the callbacks:

  1. Read multiple Parquet files into a dataframe
  2. Perform aggregations, filtering and other operations

Most of this is done within a second, and the resulting dataframes are passed to go.Heatmap, go.Scatter and go.Heatmap respectively to update the 3 graphs at the bottom of the page.

Everything from creating the dataframe up to building the dictionaries passed to the callbacks' Outputs takes a second at most; however, the actual display in the web app takes anywhere from 2-3 seconds to 8-10 seconds.

Since the last step, sending the data to the browser, is I/O bound with calls made to the Flask server, I wanted to know whether I can run the callbacks in parallel.
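Before parallelizing, it may be worth confirming where the time goes. The sketch below is a hypothetical profile_callback decorator (not part of Dash) that records how long a callback body runs and roughly how many bytes of JSON its return value turns into. It assumes the return value is built from plain dicts and lists; figures containing go.* objects would need json.dumps(..., cls=plotly.utils.PlotlyJSONEncoder) for a realistic size. Placed directly above the def and below @app.callback(...), it wraps the function before Dash registers it.

```python
import functools
import json
import time

# Hypothetical store for the most recent measurements, keyed by function name.
CALLBACK_STATS = {}


def profile_callback(func):
    """Record run time and approximate JSON payload size of a callback's return value."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed = time.perf_counter() - start
        # Rough size of what would be sent to the browser; assumes plain dicts/lists
        # (use cls=plotly.utils.PlotlyJSONEncoder when the result holds go.* objects).
        payload_bytes = len(json.dumps(result, default=str).encode("utf-8"))
        CALLBACK_STATS[func.__name__] = {"seconds": elapsed, "bytes": payload_bytes}
        print(f"{func.__name__}: {elapsed:.3f} s, ~{payload_bytes / 1024:.1f} KiB of JSON")
        return result
    return wrapper


# Usage with a stand-in callback that returns a 200 x 300 heatmap figure dict.
@profile_callback
def fake_heatmap(rows, cols):
    z = [[float(r * c) for c in range(cols)] for r in range(rows)]
    return {"data": [{"type": "heatmap", "z": z}], "layout": {"height": 400}}


figure = fake_heatmap(200, 300)
```

If the recorded seconds stay small while the payload runs to megabytes (a heatmap's z values are sent as one nested list), the delay is more likely in transferring and rendering the figure than in the Python code, and running callbacks in parallel may help less than expected.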

Here is a snippet of my code from the callbacks …

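    import glob
    import json
    import os

    import dateutil.parser
    import numpy as np
    import pandas as pd
    import plotly.graph_objs as go
    from dash.dependencies import Input, Output

    # app, MY_PATH, licenseDf, hoverData and getDict are defined elsewhere in the app.

    @app.callback([Output('heatmap1', 'figure'), Output('intermediate_container', 'children')],
                  [Input('rms_dropdown', 'value'), Input('freq_dropdown', 'value'),
                   Input('date-range', 'start_date'), Input('date-range', 'end_date'),
                   Input('reset_button', 'n_clicks'), Input('licence_data', 'value')])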
    def update_heatmap1(rms, freq, start_date, end_date, reset, license):
        # print(start_date, type(start_date))
        start = dateutil.parser.parse(start_date, ignoretz=True)
        end = dateutil.parser.parse(end_date, ignoretz=True)
        # print(start, type(start))
        if (rms is not None) and (freq is not None) and (reset != None):
            #print(rms)
            onlyRMS = rms.split('_')[0]
            a, b = freq.split(' - ')[1], freq.split(' - ')[2]
            for each_file in glob.glob(os.path.join(MY_PATH, '*_spectrum_parquet.gzip')):
                file = os.path.basename(each_file)
                if file.startswith(rms):
                    dictHeatmap = {}
                    dictMaxMinHold  = {}
                    if (str(a) in file) and (str(b) in file):
                        df = pd.read_parquet(each_file)  # glob already returns the full path
                        #print("file read", MY_PATH+"\\"+file)
                        start_freq, stop_freq = float(df.columns[1])/1000000, float(df.columns[-5])/1000000
                        
                        if (str(start_freq) == a) and (str(stop_freq) == b):
                            #print("yes!")
                            df['Datetime'] = df['Datetime'].astype('datetime64[ns]')
                            df.set_index('Datetime', inplace=True)                            
                            filteredDf = df.iloc[:,0:-4]
                            
                            subsetLicenseDf = licenseDf[licenseDf['RMS']==onlyRMS]
                            
                            if len(license)==1:
                                #print("ticked ..")
                                subsetLicenseDf = subsetLicenseDf.reset_index(drop=True)
                                subsetLicenseDf.Step = subsetLicenseDf.Step.astype(int)
                                subsetLicenseDf['Start_Freq'] = subsetLicenseDf.Freq - subsetLicenseDf.Step
                                subsetLicenseDf['End_Freq'] = subsetLicenseDf.Freq + 2*subsetLicenseDf.Step
                                subsetLicenseDf["range"] = subsetLicenseDf.apply(lambda x: [i for i in range(x["Start_Freq"], x["End_Freq"], x["Step"])], axis=1)
                                licenseFreq_list = [item for sublist in subsetLicenseDf.range.tolist() for item in sublist]
                                # print(licenseFreq_list)
                                clean_list = list(set(licenseFreq_list))
                                results = [str(i) for i in clean_list]
                                # print(results)
                                # Drop licensed-frequency columns that exist in the data (missing ones are ignored)
                                filteredDf = filteredDf.drop(columns=results, errors='ignore')

                            #     print(filteredDf.columns.tolist())
                                
                            levelDf = filteredDf[(filteredDf.index >= start) & (filteredDf.index <= end)]
                            
                            hovertext1 = hoverData(levelDf, r'spectrum')
                            
                            spectrumChartTitle = r'No level data for selected dates, please update your selection'
                            # maxMinChartTitle = r'No max / min hold data for selected dates, please update your selection'
                            
                            if len(levelDf) != 0:
                                spectrumChartTitle = 'Spectrum chart - '+ rms +' - '+ freq + 'MHz from '+ start_date[:10] + ' to ' + end_date[:10]
                                # maxMinChartTitle = 'Max / Min hold chart - '+ rms +' - '+ freq + 'MHz from '+ start_date[:10] + ' to ' + end_date[:10]


                            dataset = {'df_1': levelDf.to_json(orient='split', date_format='iso')}

                            dictHeatmap = {'data': [go.Heatmap(x = filteredDf.columns,
                                            y = levelDf.index,
                                            z = levelDf.values.tolist(),
                                            colorbar={"title": "Levels in dBm"},
                                            colorscale ='Viridis',
                                            hoverinfo='text',
                                            text=hovertext1)],
                                    'layout': { 'height': 400,
                                                'title' : spectrumChartTitle,
                                                #'title' : 'Spectrum chart - ' + selected_band.upper() + ' band - '+ start_date[:10] + ' to ' + end_date[:10],
                                                'xaxis' : {'side':'top'},
                                                # 'rangeslider':{'visible':True, 'autorange':True}}, 
                                                #'width' : 2000,
                                                #'legend':{'orientation':"h"},  
                                                #'legend':{'xanchor':"center",'yanchor':"top", 'y':-0.3, 'x':0.5},
                                                # 'yaxis' : { 'title':'Time'},
                                                'margin': {'l':100, 't':75,'r':15}
                                                }}
                            return dictHeatmap, json.dumps(dataset)
                        
                            
        else: return [], []

    @app.callback(Output('maxMinHold', 'figure'),
                  [Input('intermediate_container', 'children'), Input('rms_dropdown', 'value'),
                   Input('freq_dropdown', 'value'), Input('heatmap1', 'relayoutData')])
    def updateMaxMinHold(jsonified_cleaned_data, rms, freq, relayoutData):
        # Figure shown whenever there is nothing to plot
        emptyFigure = {'data': [],
                       'layout': go.Layout(title="No data for selected dates, please update your selection",
                                           yaxis=dict(title='level (dBm)'), xaxis=dict(tickmode='auto', side='top'),
                                           legend=dict(orientation="v"), plot_bgcolor='rgb(245,245,240)',
                                           margin=dict(l=100))}
        try:
            a = json.loads(jsonified_cleaned_data)
            b = json.loads(a['df_1'])
        except TypeError:
            # json.loads() raises TypeError for None, and for the [] that update_heatmap1
            # returns when no file matched; show the empty figure instead of returning None
            return emptyFigure

        df = pd.DataFrame(data=b['data'])
        if len(df) == 0:
            return emptyFigure
        df.columns = b['columns']
        df = df.set_index(pd.DatetimeIndex(pd.to_datetime(b['index'])))

        # relayoutData is None until the user has interacted with heatmap1
        getX1 = (relayoutData or {}).get('xaxis.range[0]', None)
        if getX1 is not None:
            # Keep only the frequency columns inside the zoomed x range
            x1, x2 = relayoutData['xaxis.range[0]'], relayoutData['xaxis.range[1]']
            array = np.asarray(df.columns.astype(int))
            X1 = (np.abs(array - x1)).argmin()
            X2 = (np.abs(array - x2)).argmin()
            df = df.iloc[:, X1:X2]
        return getDict(df, rms, freq)



    @app.callback(Output('heatmap2', 'figure') , [Input('rms_dropdown', 'value'), Input('freq_dropdown', 'value'), Input('date-range', 'start_date'), Input('date-range', 'end_date'), Input('reset_button', 'n_clicks'), Input('licence_data', 'value'), Input('heatmap1', 'relayoutData')])
    def test(rms, freq, start_date, end_date, reset, license, relayoutData):
        start = dateutil.parser.parse(start_date, ignoretz=True)
        end = dateutil.parser.parse(end_date, ignoretz=True)
        # print(rms, freq, reset, license)
        if (rms is not None) and (freq is not None) and (reset != None):
            #print(freq)
            onlyRMS = rms.split('_')[0]
            a, b = freq.split(' - ')[1], freq.split(' - ')[2]
            for each_file in glob.glob(os.path.join(MY_PATH, '*_occupancy_parquet.gzip')):
                file = os.path.basename(each_file)
                # print(file)
                if file.startswith(rms):
                    dictHeatmap2 = {}
                    if (str(a) in file) and (str(b) in file):
                        # print(MY_PATH+"\\"+file)
                        df = pd.read_parquet(each_file)  # glob already returns the full path
                        start_freq, stop_freq = float(df.columns[1])/1000000, float(df.columns[-5])/1000000
                        if (str(start_freq) == a) and (str(stop_freq) == b):
                            df['Datetime'] = df['Datetime'].astype('datetime64[ns]')
                            df.set_index('Datetime', inplace=True)                            
                            
                            df = df[(df.index >= start) & (df.index <= end)]
                            occDf = df.iloc[:,0:-4]
                            #print(occDf.columns.tolist())

                            subsetLicenseDf = licenseDf[licenseDf['RMS']==onlyRMS]
                            if len(license)==1:
                                subsetLicenseDf = subsetLicenseDf.reset_index(drop=True)
                                subsetLicenseDf.Step = subsetLicenseDf.Step.astype(int)
                                subsetLicenseDf['Start_Freq'] = subsetLicenseDf.Freq - subsetLicenseDf.Step
                                subsetLicenseDf['End_Freq'] = subsetLicenseDf.Freq + subsetLicenseDf.Step
                                subsetLicenseDf["range"] = subsetLicenseDf.apply(lambda x: [i for i in range(x["Start_Freq"], x["End_Freq"]+1, x["Step"])], axis=1)
                                licenseFreq_list = [item for sublist in subsetLicenseDf.range.tolist() for item in sublist]
                                clean_list = list(set(licenseFreq_list))
                                results = [str(i) for i in clean_list]
                                
                                # Drop licensed-frequency columns that exist in the data (missing ones are ignored)
                                occDf = occDf.drop(columns=results, errors='ignore')
                            #print("Occupancy DF is of following size ", occDf.shape, " with length ", len(occDf))
                            
                            # relayoutData is None until the user has interacted with heatmap1
                            relayout = relayoutData or {}
                            getX1 = relayout.get('xaxis.range[0]', None)

                            if getX1 is not None:
                                # Keep only the frequency columns inside the zoomed x range
                                xStart, xEnd = relayout['xaxis.range[0]'], relayout['xaxis.range[1]']
                                array = np.asarray(occDf.columns.astype(int))
                                idx1 = (np.abs(array - xStart)).argmin()
                                idx2 = (np.abs(array - xEnd)).argmin()
                                occDf = occDf.iloc[:, idx1:idx2]

                                # The y range is only present when the zoom also changed the y axis
                                if 'yaxis.range[0]' in relayout:
                                    yStart, yEnd = relayout['yaxis.range[0]'], relayout['yaxis.range[1]']
                                    occDf = occDf.loc[yEnd[:10]:yStart[:10]]


                            #print(occDf.shape)
                            spectrumOccTitle = r'No occupancy data for selected dates, please update your selection'
                            
                            if len(occDf) != 0:
                                spectrumOccTitle = 'Occupancy chart - '+ rms +' - '+ freq + 'MHz from '+ start_date[:10] + ' to ' + end_date[:10]
                                
                            hovertext2 = hoverData(occDf, r'Occupancy')
                            dictHeatmap2 = {'data': [go.Heatmap(x = occDf.columns,
                                                                y = occDf.index,
                                                                z = occDf.values.tolist(),
                                                    colorbar={"title": "Occupancy (%)"},
                                                    colorscale ='Viridis',
                                                    hoverinfo='text',
                                                    text=hovertext2)],
                                    'layout': { 'height': 400,
                                                'title' : spectrumOccTitle,
                                                'xaxis' : {'side':'top'},
                                                #'width' : 2000,
                                                'margin': {'l':100, 't':75,'r':15}
                                                }
                                            }

                            return dictHeatmap2
        else: return []

Hope this helps! Please advise.

Thanks
Ananth

Hey, Ananth. Even though I'm not an expert on the internals of Dash/Plotly, I'm nearly 100% sure that you cannot parallelize separate callbacks the way you want. The reason is intuitive: when a Callback fires, and how it triggers other Callbacks, forms a complicated chain of interdependencies, and setting 3 Callbacks aside in their own group would not be trivial.

However, there is something you can do: combine those 3 Callbacks into 1 (Callbacks can have multiple outputs), and then parallelize the work inside that single callback, which is doable in Python. For example, check out joblib here, and tell me whether your code can be parallelized in the way described there.
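A minimal sketch of that idea, using the standard library's concurrent.futures.ThreadPoolExecutor in place of joblib (a swapped-in technique, not what the reply links to). The three builder functions are hypothetical stand-ins that sleep to simulate Parquet reads, and update_all_charts is what the body of one multi-output callback would call. Threads only overlap work that releases the GIL, such as file I/O and, typically, much of pyarrow's Parquet decoding; for CPU-bound pure-Python work, a process pool (ProcessPoolExecutor, or joblib's default loky backend) is usually the better fit.

```python
import time
from concurrent.futures import ThreadPoolExecutor


# Hypothetical stand-ins for the three independent pieces of work; in the real app
# each would read its Parquet files and build one figure dict for (rms, freq).
def build_spectrum_figure(rms, freq):
    time.sleep(0.3)  # simulates I/O such as pd.read_parquet
    return {"data": [], "layout": {"title": f"Spectrum chart - {rms} - {freq}"}}


def build_maxmin_figure(rms, freq):
    time.sleep(0.3)
    return {"data": [], "layout": {"title": f"Max / Min hold chart - {rms} - {freq}"}}


def build_occupancy_figure(rms, freq):
    time.sleep(0.3)
    return {"data": [], "layout": {"title": f"Occupancy chart - {rms} - {freq}"}}


BUILDERS = (build_spectrum_figure, build_maxmin_figure, build_occupancy_figure)

# One shared pool for the whole app, so threads are not recreated on every request.
EXECUTOR = ThreadPoolExecutor(max_workers=len(BUILDERS))


def update_all_charts(rms, freq):
    """Body of a single multi-output callback: run the builders concurrently and
    return the figures in the same order as the callback's Outputs."""
    futures = [EXECUTOR.submit(builder, rms, freq) for builder in BUILDERS]
    return tuple(future.result() for future in futures)


start = time.perf_counter()
figures = update_all_charts("RMS1", "band - 100 - 200")
elapsed = time.perf_counter() - start
print(f"3 figures in {elapsed:.2f} s")
```

In the app, the decorated callback would list its three figure Outputs in the same order as the returned tuple; the timing here only shows that the three simulated reads overlap instead of running back to back.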

If you run your app with multiple workers with gunicorn, then that pool of workers is available to process any request that comes from the client. If the callbacks are structured so that several can be processed in parallel (i.e. they are not chained), gunicorn will process them in parallel by default.
So, if the computations for a set of outputs are independent, split them into separate callbacks so that separate workers in the pool can handle them. If the computations are shared, update multiple outputs at once within a single callback, so that the pool of workers stays available for other requests and users.
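A sketch of the multi-worker setup described above, written as a gunicorn configuration file (gunicorn config files are plain Python). It assumes the Dash app lives in index.py and exposes server = app.server, matching the index:server target used later in this thread; the port and timeout values are assumptions. Note that gunicorn runs only on Unix-like systems, so on the Windows 10 VM from the original question a threaded WSGI server such as waitress (for example waitress-serve --threads=4 index:server) would be the nearer option, with threads rather than processes sharing the requests.

```python
# gunicorn.conf.py -- start with:  gunicorn -c gunicorn.conf.py index:server
import multiprocessing

# Address and port to listen on (assumed; adjust to your deployment).
bind = "0.0.0.0:8050"

# Each worker is a separate process that serves one request at a time with the
# default sync worker class, so independent callbacks fired by the same page
# load can be handled by different workers concurrently. 2 * cores + 1 is the
# starting point suggested in the gunicorn documentation.
workers = multiprocessing.cpu_count() * 2 + 1

# Assumed value: allow slow Parquet reads more than the default 30 s before
# gunicorn kills and restarts a worker.
timeout = 120
```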

I'm surprised to hear that parallelizing Callbacks is possible; looking around more, there is indeed some information about this in this thread.

However, what would be the advantages of parallelizing the Callbacks, instead of parallelizing the code inside a single, large Callback?

There are already methods in Python for parallelizing normal functions, so I am extremely interested to hear what you think about this, chriddyp.

So am I correct in understanding that, to take advantage of the "web: gunicorn index:server --workers 4" setup in my Dash app, I'd want multiple independent callbacks instead of one multi-output callback?

My basic structure is that a user selects from a dropdown, which then updates 8-10 charts on the page. Right now I have it set up as a single callback, with the dropdown as the input and each chart's build function as one of the outputs.

Yes, that is correct. However, the split will add a little overhead in terms of web traffic, due to the increased number of requests.
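To illustrate the trade-off in the answer above, here is a stdlib-only simulation, not Dash itself: a ThreadPoolExecutor with 4 threads stands in for gunicorn's 4 workers, and build_chart is a hypothetical chart builder that sleeps 0.1 s. Splitting the 8 charts into 8 independent requests lets the pool overlap them, while one multi-output request builds every chart on a single worker.

```python
import time
from concurrent.futures import ThreadPoolExecutor

N_CHARTS = 8          # charts driven by one dropdown, as in the question above
WORKERS = 4           # stands in for `--workers 4`
BUILD_SECONDS = 0.1   # hypothetical time to build one chart


def build_chart(i):
    time.sleep(BUILD_SECONDS)  # simulated per-chart work
    return {"data": [], "layout": {"title": f"chart {i}"}}


def one_callback_per_chart(pool):
    # Split: one request per chart, so the pool can serve them concurrently.
    futures = [pool.submit(build_chart, i) for i in range(N_CHARTS)]
    return [f.result() for f in futures], N_CHARTS  # (figures, number of requests)


def single_multi_output_callback(pool):
    # Combined: one request whose worker builds every chart in sequence.
    future = pool.submit(lambda: [build_chart(i) for i in range(N_CHARTS)])
    return future.result(), 1


with ThreadPoolExecutor(max_workers=WORKERS) as pool:
    t0 = time.perf_counter()
    split_figs, split_requests = one_callback_per_chart(pool)
    split_elapsed = time.perf_counter() - t0

    t0 = time.perf_counter()
    combined_figs, combined_requests = single_multi_output_callback(pool)
    combined_elapsed = time.perf_counter() - t0

print(f"split:    {split_requests} requests, {split_elapsed:.2f} s")
print(f"combined: {combined_requests} request, {combined_elapsed:.2f} s")
```

The simulation ignores the per-request HTTP overhead mentioned in the reply and assumes no other users are competing for the workers; with many concurrent users, the single multi-output callback leaves more workers free for them, as noted earlier in the thread.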