Skip to content
Snippets Groups Projects
Commit 9f509745 authored by Tarek Shah's avatar Tarek Shah
Browse files

unit tests for summarizers

parent 1349313f
No related branches found
No related tags found
No related merge requests found
File added
File added
File added
import unittest
from unittest.mock import patch
from text_summarizer import app
class BERTSummarizerTestCase(unittest.TestCase):
def setUp(self):
self.app = app.test_client()
self.app.testing = True
@patch('text_summarizer.Summarizer')
def test_summarize(self, mock_summarizer):
# Mock the Summarizer class
mock_instance = mock_summarizer.return_value
mock_instance.return_value = "This is a summary."
# Define input data
input_data = {'text': 'This is a test input.'}
# Send a POST request to the endpoint
response = self.app.post('/BERT_summarize', json=input_data)
# Check the response
self.assertEqual(response.status_code, 200)
data = response.get_json()
self.assertIn('summary', data)
self.assertEqual(data['summary'], "This is a summary.")
@patch('text_summarizer.Summarizer')
def test_error_handling(self, mock_summarizer):
# Mock the Summarizer class to raise an exception
mock_instance = mock_summarizer.return_value
mock_instance.side_effect = Exception('An error occurred.')
# Define input data
input_data = {'text': 'This is a test input.'}
# Send a POST request to the endpoint
response = self.app.post('/BERT_summarize', json=input_data)
# Check the response
self.assertEqual(response.status_code, 400)
data = response.get_json()
self.assertIn('error', data)
self.assertEqual(data['error'], 'An error occurred.')
if __name__ == '__main__':
unittest.main()
\ No newline at end of file
import unittest
import json
from unittest.mock import patch
from T5_summarizer import app
class T5SummarizerTestCase(unittest.TestCase):
def setUp(self):
self.app = app.test_client()
self.app.testing = True
@patch('T5_summarizer.tokenizer')
@patch('T5_summarizer.model')
def test_summarize(self, mock_model, mock_tokenizer):
mock_tokenizer.encode.return_value = 'encoded_text'
mock_model.generate.return_value = [1]
input_data = {
'text': 'This is a test. It is only a test.'
}
response = self.app.post('/T5_summarize', json=input_data)
data = json.loads(response.data.decode('utf-8'))
self.assertEqual(response.status_code, 200)
self.assertIn('summary', data)
self.assertTrue(isinstance(data['summary'], str))
mock_tokenizer.encode.assert_called_once_with("summarize: This is a test. It is only a test.", return_tensors='pt', max_length=512, truncation=True)
mock_model.generate.assert_called_once_with('encoded_text', max_length=150, min_length=80, length_penalty=5., num_beams=2)
def test_invalid_input(self):
input_data = {
'incorrect_key': 'This is a test. It is only a test.'
}
response = self.app.post('/T5_summarize', json=input_data)
data = json.loads(response.data.decode('utf-8'))
self.assertEqual(response.status_code, 400)
self.assertIn('error', data)
if __name__ == '__main__':
unittest.main()
import unittest
import json
from nltk_summarizer import app
class NLTKSummarizerTestCase(unittest.TestCase):
def setUp(self):
self.app = app.test_client()
self.app.testing = True
def test_summarize(self):
input_data = {
'text': 'This is a test. It is only a test.'
}
response = self.app.post('/nltk_summarize', json=input_data)
data = json.loads(response.data.decode('utf-8'))
self.assertEqual(response.status_code, 200)
self.assertIn('summary', data)
self.assertTrue(isinstance(data['summary'], str))
def test_invalid_input(self):
input_data = {
'incorrect_key': 'This is a test. It is only a test.'
}
response = self.app.post('/nltk_summarize', json=input_data)
data = json.loads(response.data.decode('utf-8'))
self.assertEqual(response.status_code, 400)
self.assertIn('error', data)
if __name__ == '__main__':
unittest.main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment