11from unittest .mock import MagicMock , patch
22
3- from commitai .agent import create_commit_agent
3+ from langchain_core .messages import AIMessage
4+
5+ from commitai .agent import create_commit_agent , scan_todos , summarize_context
46
57
68def test_agent_v3_initialization ():
79 """Verify that the V3 agent can be initialized with the new middlewares."""
810 mock_llm = MagicMock ()
911
10- # We need to mock create_agent because it will try to set up real tools/middlewares
11- # and we don't want to run them or depend on OS/Git state during this unit test.
12- # However, we want to verify that create_agent CALL was correct (args passed).
13-
1412 with patch ("commitai.agent.create_agent" ) as mock_create_agent :
1513 mock_create_agent .return_value = MagicMock () # Return a mock graph
1614
17- # Also mock the middlewares to check if they were instantiated
1815 with (
1916 patch ("commitai.agent.SummarizationMiddleware" ) as mock_summ ,
2017 patch ("commitai.agent.FilesystemFileSearchMiddleware" ) as mock_files ,
@@ -38,3 +35,103 @@ def test_agent_v3_initialization():
3835 assert len (kwargs ["middleware" ]) == 5
3936
4037 assert agent_runnable is not None
38+
39+
40+ def test_scan_todos ():
41+ """Test TODO scanning logic."""
42+ diff_with_todo = "+ # TODO: Fix this later\n + function foo() {\n + # FIXME: Old bug"
43+ result = scan_todos (diff_with_todo )
44+
45+ assert "todos" in result
46+ assert "todo_str" in result
47+ assert len (result ["todos" ]) == 2
48+ assert "Fix this later" in result ["todos" ][0 ]
49+ assert "Old bug" in result ["todos" ][1 ]
50+ # Check for the raw line content as captured
51+ assert "- # TODO: Fix this later" in result ["todo_str" ]
52+
53+
54+ def test_scan_todos_none ():
55+ """Test TODO scanning with no todos."""
56+ diff_clean = "+ function foo() {}"
57+ result = scan_todos (diff_clean )
58+ assert not result ["todos" ]
59+ assert result ["todo_str" ] == "None"
60+
61+
62+ def test_summarize_context_short ():
63+ """Test summarization (even short diffs trigger LLM now)."""
64+ mock_llm = MagicMock ()
65+ mock_llm .invoke .return_value = AIMessage (content = "Summary" )
66+ diff = "short diff"
67+ summary = summarize_context (mock_llm , diff )
68+
69+ # It returns content directly now
70+ assert summary == "Summary"
71+ mock_llm .invoke .assert_called_once ()
72+
73+
74+ def test_summarize_context_long ():
75+ """Test summarization triggered for long diffs."""
76+ mock_llm = MagicMock ()
77+ mock_llm .invoke .return_value = AIMessage (content = "Summary of changes" )
78+
79+ # Create diff > 3000 chars
80+ diff = "a" * 3005
81+ summary = summarize_context (mock_llm , diff )
82+
83+ assert summary == "Summary of changes"
84+ mock_llm .invoke .assert_called_once ()
85+
86+
87+ def test_run_pipeline_streaming ():
88+ """Test the streaming pipeline execution."""
89+ mock_llm = MagicMock ()
90+ # Mock LLM response for summarize_context call inside pipeline
91+ mock_llm .invoke .return_value = AIMessage (content = "Summary" )
92+
93+ # Mock the agent graph created inside factory
94+ mock_graph = MagicMock ()
95+
96+ # Mock stream output from graph
97+ final_message = AIMessage (content = "feat: new feature" )
98+ events = [
99+ {
100+ "messages" : [
101+ AIMessage (
102+ content = "" , tool_calls = [{"name" : "git_log" , "args" : {}, "id" : "1" }]
103+ )
104+ ]
105+ },
106+ {"messages" : [final_message ]},
107+ ]
108+ mock_graph .stream .return_value = iter (events )
109+
110+ with patch ("commitai.agent.create_agent" , return_value = mock_graph ):
111+ with (
112+ patch ("commitai.agent.SummarizationMiddleware" ),
113+ patch ("commitai.agent.FilesystemFileSearchMiddleware" ),
114+ patch ("commitai.agent.ShellToolMiddleware" ),
115+ patch ("commitai.agent.HumanInTheLoopMiddleware" ),
116+ patch ("commitai.agent.LLMToolSelectorMiddleware" ),
117+ ):
118+ pipeline_func = create_commit_agent (mock_llm )
119+
120+ inputs = {"diff" : "some diff" , "explanation" : "expl" }
121+ # Use stream() to ensure we get the generator/iterator properly
122+ stream_gen = pipeline_func .stream (inputs )
123+
124+ results = list (stream_gen )
125+
126+ # Verify we received dicts
127+ for r in results :
128+ assert isinstance (r , dict ), f"Received non-dict result: { r } "
129+
130+ # Check for thought event (from tool call)
131+ assert any (r ["type" ] == "thought" for r in results )
132+
133+ # Check for token events (from final message)
134+ token_events = [r for r in results if r ["type" ] == "token" ]
135+ assert len (token_events ) > 0
136+ full_text = "" .join (t ["content" ] for t in token_events )
137+ assert "feat: new feature" in full_text
0 commit comments