Skip to content

Commit

Permalink
Change return text for Fastapi
Browse files Browse the repository at this point in the history
  • Loading branch information
MarkoBrie committed Mar 4, 2024
1 parent 0e4e34c commit f58597d
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 1 deletion.
56 changes: 56 additions & 0 deletions 3_STREAMlit_dashboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,60 @@ def request_prediction(model_uri: str, data: dict) -> dict:

return response.json()

def hight_of_selected_point(hist, data, highlighted_index):
bin_counts = [rect.get_height() for rect in hist.patches]
print(len(bin_counts))
print(len(bin_counts)/2)
print(min(data['DAYS_BIRTH']), " ", max(data['DAYS_BIRTH']))
print("selected point: ", data.loc[highlighted_index, 'DAYS_BIRTH'])
scaled_point = int(round(data.loc[highlighted_index, 'DAYS_BIRTH']-min(data['DAYS_BIRTH'])))
print("scaled point: ", scaled_point)

steps = (max(data['DAYS_BIRTH'])-min(data['DAYS_BIRTH']))/(len(bin_counts)/2)
print("steps :", steps )

if data.loc[highlighted_index, 'TARGET'] == 0:
bucket = int(round(scaled_point / steps,0))

elif data.loc[highlighted_index, 'TARGET'] == 1:
bucket = int(round(scaled_point / steps,0)+(len(bin_counts)/2))

print("bucket :", bucket)
hight = bin_counts[bucket]/2
print("hight :", hight)

return hight

def plot_histogram(data):

# Highlighted data point
highlighted_index = 0 # Index of the data point to highlight
highlighted_value = data.loc[highlighted_index, 'DAYS_BIRTH']

# Plotting
fig, ax = plt.subplots()
hist = sns.histplot(data=data, x='DAYS_BIRTH', hue='TARGET', kde=True, multiple='stack', ax=ax) #stat='density',
# Get the counts for each bin
hight_P = hight_of_selected_point(hist, data, highlighted_index)

# Highlight one specific data point
if Y_train.loc[highlighted_index, 'TARGET'] == 1:
ax.scatter(highlighted_value, hight_P, color='red', label='Highlighted Point', zorder=5)
elif Y_train.loc[highlighted_index, 'TARGET'] == 0:
ax.scatter(highlighted_value, hight_P, color='blue', label='Highlighted Point', zorder=5)

# Customize plot
ax.set_xlabel('Customer Age')
ax.set_ylabel('Number of Customers')
ax.set_title('Stacked Distribution of Customer Age with Highlighted Point')
legend = ax.get_legend()
handles = legend.legend_handles
legend.remove()
ax.legend(handles, ['0 pays', '1 will have difficulty'], title='Client group')

st.show(fig)



def main():
MLFLOW_URI = 'https://fastapi-cd-webapp.azurewebsites.net/predict'
Expand All @@ -40,6 +94,8 @@ def main():

data_slice = pd.read_csv('data/X_train_slice.csv')

plot_histogram(data_slice)

ids_test = pd.read_csv('data/test_ids.csv')
X_train = pd.read_csv('data/X_test.csv')
feature_name = pd.read_csv('data/feature_names.csv')
Expand Down
2 changes: 1 addition & 1 deletion 5_unittest.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def test_client_emptyData():
def test_read_main():
response = client.get("/")
assert response.status_code == 200
assert response.json() == {"data": "Application ran successfully - FastAPI release v4.2 with Github Actions no staging: cloudpickle try environment pipenv"}
assert response.json() == {"data": "Application ran successfully - FastAPI ML endpoint deployed with Github Actions on Microsoft AZURE"}

if __name__ == '__main__':
unittest.main()

0 comments on commit f58597d

Please sign in to comment.