diff --git a/tests/test_dns.py b/tests/test_dns.py index 48700b6..138ed99 100644 --- a/tests/test_dns.py +++ b/tests/test_dns.py @@ -115,10 +115,7 @@ class TestResolveHostname: @pytest.mark.asyncio async def test_timeout_resolution(self): """Test hostname resolution timeout.""" - async def mock_wait_for(*args, **kwargs): - raise asyncio.TimeoutError() - - with patch("asyncio.wait_for", side_effect=mock_wait_for) as mock_wait_for: + with patch("asyncio.wait_for", side_effect=asyncio.TimeoutError()): resolution = await resolve_hostname("slow.example", timeout=1.0) assert resolution.hostname == "slow.example" @@ -142,14 +139,11 @@ class TestResolveHostname: @pytest.mark.asyncio async def test_empty_result_resolution(self): """Test hostname resolution with empty result.""" - async def mock_wait_for(*args, **kwargs): - return [] - with patch("asyncio.get_event_loop") as mock_loop: mock_event_loop = AsyncMock() mock_loop.return_value = mock_event_loop - with patch("asyncio.wait_for", side_effect=mock_wait_for): + with patch("asyncio.wait_for", return_value=[]): resolution = await resolve_hostname("empty.example") assert resolution.hostname == "empty.example" @@ -167,7 +161,7 @@ class TestResolveHostnamesBatch: """Test successful batch hostname resolution.""" hostnames = ["example.com", "test.example"] - with patch("src.hosts.core.dns.resolve_hostname", new_callable=AsyncMock) as mock_resolve: + with patch("src.hosts.core.dns.resolve_hostname") as mock_resolve: # Mock successful resolutions mock_resolve.side_effect = [ DNSResolution( @@ -197,25 +191,23 @@ class TestResolveHostnamesBatch: """Test batch resolution with mixed success/failure.""" hostnames = ["example.com", "nonexistent.example"] - # Create a direct async function replacement instead of using AsyncMock - async def mock_resolve_hostname(hostname, timeout=5.0): - if hostname == "example.com": - return DNSResolution( + with patch("src.hosts.core.dns.resolve_hostname") as mock_resolve: + # Mock mixed results + mock_resolve.side_effect = [ + DNSResolution( hostname="example.com", resolved_ip="192.0.2.1", status=DNSResolutionStatus.RESOLVED, resolved_at=datetime.now(), - ) - else: - return DNSResolution( + ), + DNSResolution( hostname="nonexistent.example", resolved_ip=None, status=DNSResolutionStatus.RESOLUTION_FAILED, resolved_at=datetime.now(), error_message="Name not found", - ) - - with patch("src.hosts.core.dns.resolve_hostname", mock_resolve_hostname): + ), + ] resolutions = await resolve_hostnames_batch(hostnames) @@ -278,17 +270,14 @@ class TestDNSService: """Test async resolution when service is enabled.""" service = DNSService(enabled=True) - with patch("src.hosts.core.dns.resolve_hostname", new_callable=AsyncMock) as mock_resolve: + with patch("src.hosts.core.dns.resolve_hostname") as mock_resolve: mock_resolution = DNSResolution( hostname="example.com", resolved_ip="192.0.2.1", status=DNSResolutionStatus.RESOLVED, resolved_at=datetime.now(), ) - # Use proper async setup - async def mock_side_effect(hostname, timeout=5.0): - return mock_resolution - mock_resolve.side_effect = mock_side_effect + mock_resolve.return_value = mock_resolution resolution = await service.resolve_entry_async("example.com") @@ -312,17 +301,14 @@ class TestDNSService: """Test manual entry refresh.""" service = DNSService(enabled=True) - with patch("src.hosts.core.dns.resolve_hostname", new_callable=AsyncMock) as mock_resolve: + with patch("src.hosts.core.dns.resolve_hostname") as mock_resolve: mock_resolution = DNSResolution( hostname="example.com", resolved_ip="192.0.2.1", status=DNSResolutionStatus.RESOLVED, resolved_at=datetime.now(), ) - # Use proper async setup - async def mock_side_effect(hostname, timeout=5.0): - return mock_resolution - mock_resolve.side_effect = mock_side_effect + mock_resolve.return_value = mock_resolution result = await service.refresh_entry("example.com") diff --git a/tests/test_filters.py b/tests/test_filters.py index d556653..1348102 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -421,7 +421,7 @@ class TestEntryFilter: # Apply filters and check preset name is preserved sample_entry = HostEntry("192.168.1.1", ["test.com"], "Test", True) - entry_filter.apply_filters([sample_entry], preset_options) + result = entry_filter.apply_filters([sample_entry], preset_options) # The original preset name should be accessible assert preset_options.preset_name == "Active Only" diff --git a/tests/test_main.py b/tests/test_main.py index 14de1fd..0fd94a8 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -833,20 +833,11 @@ class TestHostsManagerApp: app.query_one = mock_query_one app.edit_handler.handle_entry_type_change = Mock() - - # Mock the set_timer method to avoid event loop issues in tests - with patch.object(app, 'set_timer') as mock_set_timer: - app.edit_handler.populate_edit_form_with_type_detection() - - # Verify timer was set with the correct callback - mock_set_timer.assert_called_once_with(0.1, app.edit_handler._delayed_radio_setup) - - # Manually call the delayed setup to test the actual logic - app.edit_handler._delayed_radio_setup() - # Verify that the DNS radio was set to True (which should be the pressed button) - assert mock_dns_radio.value - assert not mock_ip_radio.value + app.edit_handler.populate_edit_form_with_type_detection() + + # Should set DNS radio button as pressed and populate DNS field + assert mock_radio_set.pressed_button == mock_dns_radio assert mock_dns_input.value == "example.com" app.edit_handler.handle_entry_type_change.assert_called_with("dns") @@ -946,15 +937,7 @@ class TestHostsManagerApp: app.manager.save_hosts_file = Mock(return_value=(True, "Success")) app.table_handler.populate_entries_table = Mock() app.details_handler.update_entry_details = Mock() - - # Create a mock that properly handles and closes coroutines - def consume_coro(coro, **kwargs): - # If it's a coroutine, close it to prevent warnings - if hasattr(coro, 'close'): - coro.close() - return None - - app.run_worker = Mock(side_effect=consume_coro) + app.run_worker = Mock() # Test action_refresh_dns in edit mode - should proceed app.action_refresh_dns()