Code development platform for open source projects from the European Union institutions :large_blue_circle: EU Login authentication by SMS has been phased out. To see alternatives please check here

Skip to content
Snippets Groups Projects
Commit dd600d2c authored by Lionel Weicker's avatar Lionel Weicker
Browse files

Merge branch 'sentence_similarity_based_cpv_classifier' into 'main'

add sentence similarity to cpv classifiers

See merge request !12
parents 63815ece b3cd6fb5
No related branches found
No related tags found
1 merge request: !12 "add sentence similarity to cpv classifiers"
Pipeline #111871 passed
......@@ -22,6 +22,9 @@ config = {
"opentender-multi-label-division-classifier": {
"ssm_path": "/tedai/sagemaker/endpoint/opentender_multi_label_division_classifier/name"
},
"ss-multi-label-cpv-classifier": {
"ssm_path": "/tedai/sagemaker/endpoint/ss_multi_label_cpv_classifier/name"
},
"roberta-multi-label-division-classifier": {
"ssm_path": "/tedai/sagemaker/endpoint/roberta_multi_label_division_classifier/name"
}
......@@ -50,6 +53,28 @@ def model_multi_label_division_classifier(title: str, description: str) -> str:
return json_data
def model_sentence_similarity_multi_label_division_classifier(title: str, description: str) -> pd.DataFrame:
    """Request CPV predictions from the sentence-similarity SageMaker endpoint.

    The endpoint name is read from the module-level ``config`` entry
    ``"ss-multi-label-cpv-classifier"``. The title and description are sent
    as a JSON document, and the ``"predictions"`` entry of the JSON response
    is flattened into a table.

    Args:
        title: Text sent under the ``"title"`` key of the request payload.
        description: Text sent under the ``"description"`` key of the payload.

    Returns:
        The result of ``pd.json_normalize`` applied to the ``"predictions"``
        entry of the decoded response.

    Raises:
        KeyError: If the decoded response is a mapping without a
            ``"predictions"`` key. Errors raised by the SageMaker client
            are propagated unchanged.
    """
    endpoint_name = config["ss-multi-label-cpv-classifier"]['endpoint_name']
    print(f"Sending title '{title}' and description '{description}' to endpoint '{endpoint_name}'")

    payload = {
        "title": title,
        "description": description,
    }
    response = sagemaker_runtime_client.invoke_endpoint(
        EndpointName=endpoint_name,
        Body=json.dumps(payload),
        ContentType='application/json',
    )

    # The response body is a stream; read and decode it once, then parse.
    result = json.loads(response["Body"].read().decode())
    print(result)

    # `result` is already made of plain Python objects, so it can be
    # normalized directly; re-serializing and re-parsing it adds nothing.
    json_data = pd.json_normalize(result['predictions'])
    print(json_data)
    return json_data
def model_opentender_multi_label_division_classifier(title: str, description: str) -> str:
endpoint_name = config["opentender-multi-label-division-classifier"]['endpoint_name']
print(f"Sending title '{title}' and description '{description}' to endpoint '{endpoint_name}'")
......@@ -182,6 +207,12 @@ def update_markdown_state_for_cpv_opentender_endpoint(r: gr.Request):
return f"The state of the endpoint is : **{endpoint_state}**"
def update_markdown_state_for_cpv_sentence_similarity_endpoint(r: gr.Request):
    """Build the Markdown status line for the sentence-similarity endpoint.

    Looks up the endpoint name under the ``"ss-multi-label-cpv-classifier"``
    entry of ``config``, asks ``get_endpoint_state`` for its state, and
    returns a Markdown sentence with that state in bold.

    Args:
        r: The Gradio request; it is accepted but not used by this function.

    Returns:
        A Markdown-formatted string reporting the endpoint state.
    """
    ss_endpoint = config["ss-multi-label-cpv-classifier"]['endpoint_name']
    current_state = get_endpoint_state(ss_endpoint)
    return "The state of the endpoint is : **{}**".format(current_state)
def update_markdown_state_for_cpv_roberta_endpoint(r: gr.Request):
endpoint_name = config["roberta-multi-label-division-classifier"]['endpoint_name']
endpoint_state = get_endpoint_state(endpoint_name)
......@@ -204,7 +235,7 @@ if __name__ == '__main__':
with gr.Tab("MultiLabel Division Classifier"):
# First model: multi-label-division-classifier
gr.Markdown("# <u>multi-label-division-classifier</u>")
gr.Markdown("## Model based on EU notices")
gr.Markdown("## LinearSVC Model based on English notices")
gr.Markdown("This model expects a title and a description, then returns predictions for CPV.")
gr.Markdown("**If the state of the endpoint is not InService, you need to start it with the button"
"'Start endpoint' below and wait for it to be in state 'InService'. "
......@@ -241,7 +272,46 @@ if __name__ == '__main__':
predict_button.click(model_multi_label_division_classifier, inputs=[title_textbox, description_textbox],
outputs=output_textbox)
gr.Markdown("## Model based on Opentender data")
gr.Markdown("## all-MiniLM-L6-v2 Model for sentence similarities from English Notices")
gr.Markdown("This model expects a title and a description, then returns predictions for CPV.")
gr.Markdown("**If the state of the endpoint is not InService, you need to start it with the button"
"'Start endpoint' below and wait for it to be in state 'InService'. "
"Refresh the page to check on the current state.**")
cpv_sentence_similarity_classifier_endpoint_name = config["ss-multi-label-cpv-classifier"][
"endpoint_name"]
state_markdown = gr.Markdown()
demo.load(update_markdown_state_for_cpv_sentence_similarity_endpoint, None, state_markdown)
with gr.Row():
endpoint_name_textbox = gr.Textbox(value=cpv_sentence_similarity_classifier_endpoint_name, visible=False)
with gr.Column():
start_endpoint_button = gr.Button("Start endpoint")
start_textbox = gr.Textbox(value="start", visible=False)
start_endpoint_button.click(manage_endpoint, inputs=[start_textbox, endpoint_name_textbox],
outputs=None)
demo.load(update_markdown_state_for_cpv_sentence_similarity_endpoint, None, state_markdown)
with gr.Column():
stop_endpoint_button = gr.Button("Stop endpoint")
stop_textbox = gr.Textbox(value="stop", visible=False)
stop_endpoint_button.click(manage_endpoint, inputs=[stop_textbox, endpoint_name_textbox],
outputs=None)
demo.load(update_markdown_state_for_cpv_sentence_similarity_endpoint, None, state_markdown)
with gr.Row():
with gr.Column():
title_textbox = gr.Textbox(label="Title")
description_textbox = gr.Textbox(label="Description")
with gr.Row():
with gr.Column():
gr.Markdown("Inference")
output_textbox = [gr.Dataframe(row_count=(3, "dynamic"), col_count=(1, "fixed"),
headers=['CPV Number'], height=(200))]
predict_button = gr.Button("Predict CPV")
predict_button.click(model_sentence_similarity_multi_label_division_classifier,
inputs=[title_textbox, description_textbox],
outputs=output_textbox)
gr.Markdown("## LinearSVC Model based on Opentender data")
gr.Markdown("This model expects a title and a description, then returns predictions for CPV.")
gr.Markdown("**If the state of the endpoint is not InService, you need to start it with the button"
"'Start endpoint' below and wait for it to be in state 'InService'. "
......@@ -278,7 +348,7 @@ if __name__ == '__main__':
predict_button.click(model_opentender_multi_label_division_classifier, inputs=[title_textbox, description_textbox],
outputs=output_textbox)
gr.Markdown("## RoBERTa model")
gr.Markdown("## RoBERTa Model based on English notices")
gr.Markdown("This model expects a title and a description, then returns predictions for CPV.")
gr.Markdown("**If the state of the endpoint is not InService, you need to start it with the button"
"'Start endpoint' below and wait for it to be in state 'InService'. "
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment