Skip to content

Commit f01ec7c

Browse files
committed
refactor: remove legacy LangChain chains and migrate to streaming agent
- Delete `commitai/chains.py` as part of the transition to the new agent-based architecture. - Update the CLI to utilize the streaming interface for commit message generation. - Migrate TODO scanning and context summarization logic to the agent module. - Add comprehensive unit tests for the migrated logic and streaming pipeline in `tests/test_agent_v3.py`. - Update existing CLI tests to reflect the switch from `invoke()` to `stream()`.
1 parent 4fa6a7b commit f01ec7c

File tree

4 files changed

+112
-152
lines changed

4 files changed

+112
-152
lines changed

commitai/chains.py

Lines changed: 0 additions & 135 deletions
This file was deleted.

commitai/cli.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -265,9 +265,7 @@ def generate_message( # noqa: C901
265265
inputs["template"] = final_template_content
266266

267267
# Invoke the Agent Pipeline (which now returns a generator)
268-
stream_gen = agent_pipeline.invoke(inputs)
269-
270-
# Use UI to handle streaming visualization
268+
stream_gen = agent_pipeline.stream(inputs)
271269
commit_message = ui.stream_response(stream_gen)
272270

273271
except Exception as e:

tests/test_agent_v3.py

Lines changed: 103 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,17 @@
11
from unittest.mock import MagicMock, patch
22

3-
from commitai.agent import create_commit_agent
3+
from langchain_core.messages import AIMessage
4+
5+
from commitai.agent import create_commit_agent, scan_todos, summarize_context
46

57

68
def test_agent_v3_initialization():
79
"""Verify that the V3 agent can be initialized with the new middlewares."""
810
mock_llm = MagicMock()
911

10-
# We need to mock create_agent because it will try to set up real tools/middlewares
11-
# and we don't want to run them or depend on OS/Git state during this unit test.
12-
# However, we want to verify that create_agent CALL was correct (args passed).
13-
1412
with patch("commitai.agent.create_agent") as mock_create_agent:
1513
mock_create_agent.return_value = MagicMock() # Return a mock graph
1614

17-
# Also mock the middlewares to check if they were instantiated
1815
with (
1916
patch("commitai.agent.SummarizationMiddleware") as mock_summ,
2017
patch("commitai.agent.FilesystemFileSearchMiddleware") as mock_files,
@@ -38,3 +35,103 @@ def test_agent_v3_initialization():
3835
assert len(kwargs["middleware"]) == 5
3936

4037
assert agent_runnable is not None
38+
39+
40+
def test_scan_todos():
41+
"""Test TODO scanning logic."""
42+
diff_with_todo = "+ # TODO: Fix this later\n+ function foo() {\n+ # FIXME: Old bug"
43+
result = scan_todos(diff_with_todo)
44+
45+
assert "todos" in result
46+
assert "todo_str" in result
47+
assert len(result["todos"]) == 2
48+
assert "Fix this later" in result["todos"][0]
49+
assert "Old bug" in result["todos"][1]
50+
# Check for the raw line content as captured
51+
assert "- # TODO: Fix this later" in result["todo_str"]
52+
53+
54+
def test_scan_todos_none():
55+
"""Test TODO scanning with no todos."""
56+
diff_clean = "+ function foo() {}"
57+
result = scan_todos(diff_clean)
58+
assert not result["todos"]
59+
assert result["todo_str"] == "None"
60+
61+
62+
def test_summarize_context_short():
63+
"""Test summarization (even short diffs trigger LLM now)."""
64+
mock_llm = MagicMock()
65+
mock_llm.invoke.return_value = AIMessage(content="Summary")
66+
diff = "short diff"
67+
summary = summarize_context(mock_llm, diff)
68+
69+
# It returns content directly now
70+
assert summary == "Summary"
71+
mock_llm.invoke.assert_called_once()
72+
73+
74+
def test_summarize_context_long():
75+
"""Test summarization triggered for long diffs."""
76+
mock_llm = MagicMock()
77+
mock_llm.invoke.return_value = AIMessage(content="Summary of changes")
78+
79+
# Create diff > 3000 chars
80+
diff = "a" * 3005
81+
summary = summarize_context(mock_llm, diff)
82+
83+
assert summary == "Summary of changes"
84+
mock_llm.invoke.assert_called_once()
85+
86+
87+
def test_run_pipeline_streaming():
88+
"""Test the streaming pipeline execution."""
89+
mock_llm = MagicMock()
90+
# Mock LLM response for summarize_context call inside pipeline
91+
mock_llm.invoke.return_value = AIMessage(content="Summary")
92+
93+
# Mock the agent graph created inside factory
94+
mock_graph = MagicMock()
95+
96+
# Mock stream output from graph
97+
final_message = AIMessage(content="feat: new feature")
98+
events = [
99+
{
100+
"messages": [
101+
AIMessage(
102+
content="", tool_calls=[{"name": "git_log", "args": {}, "id": "1"}]
103+
)
104+
]
105+
},
106+
{"messages": [final_message]},
107+
]
108+
mock_graph.stream.return_value = iter(events)
109+
110+
with patch("commitai.agent.create_agent", return_value=mock_graph):
111+
with (
112+
patch("commitai.agent.SummarizationMiddleware"),
113+
patch("commitai.agent.FilesystemFileSearchMiddleware"),
114+
patch("commitai.agent.ShellToolMiddleware"),
115+
patch("commitai.agent.HumanInTheLoopMiddleware"),
116+
patch("commitai.agent.LLMToolSelectorMiddleware"),
117+
):
118+
pipeline_func = create_commit_agent(mock_llm)
119+
120+
inputs = {"diff": "some diff", "explanation": "expl"}
121+
# Use stream() to ensure we get the generator/iterator properly
122+
stream_gen = pipeline_func.stream(inputs)
123+
124+
results = list(stream_gen)
125+
126+
# Verify we received dicts
127+
for r in results:
128+
assert isinstance(r, dict), f"Received non-dict result: {r}"
129+
130+
# Check for thought event (from tool call)
131+
assert any(r["type"] == "thought" for r in results)
132+
133+
# Check for token events (from final message)
134+
token_events = [r for r in results if r["type"] == "token"]
135+
assert len(token_events) > 0
136+
full_text = "".join(t["content"] for t in token_events)
137+
assert "feat: new feature" in full_text

tests/test_cli.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def mock_generate_deps(tmp_path):
6060

6161
# Agent Mock (RunnableLambda now)
6262
mock_agent_runnable = MagicMock()
63-
mock_agent_runnable.invoke.return_value = iter(
63+
mock_agent_runnable.stream.return_value = iter(
6464
[{"type": "token", "content": "Generated commit message"}]
6565
)
6666
mock_create_agent.return_value = mock_agent_runnable
@@ -130,7 +130,7 @@ def test_generate_default_gemini(mock_generate_deps):
130130
model="gemini-3-flash-preview",
131131
google_api_key="fake_google_key",
132132
)
133-
mock_generate_deps["agent_instance"].invoke.assert_called_once()
133+
mock_generate_deps["agent_instance"].stream.assert_called_once()
134134
mock_generate_deps["commit"].assert_called_once_with("Generated commit message")
135135
mock_generate_deps["ui"].render_header.assert_called_once()
136136

@@ -213,7 +213,7 @@ def test_generate_no_staged_changes(mock_generate_deps):
213213
"⚠️ Warning: No staged changes found. Exiting."
214214
)
215215

216-
mock_generate_deps["agent_instance"].invoke.assert_not_called()
216+
mock_generate_deps["agent_instance"].stream.assert_not_called()
217217
mock_generate_deps["commit"].assert_not_called()
218218

219219

@@ -270,7 +270,7 @@ def test_generate_no_explanation(mock_generate_deps):
270270
result = runner.invoke(cli, ["generate", "--no-review"])
271271

272272
assert result.exit_code == 0, result.output
273-
mock_generate_deps["agent_instance"].invoke.assert_called_once()
273+
mock_generate_deps["agent_instance"].stream.assert_called_once()
274274
mock_generate_deps["commit"].assert_called_once()
275275

276276

@@ -298,7 +298,7 @@ def getenv_side_effect_with_template(key, default=None):
298298
assert result.exit_code == 0, result.output
299299

300300
# Verify agent invocation has correct args
301-
call_args = mock_generate_deps["agent_instance"].invoke.call_args
301+
call_args = mock_generate_deps["agent_instance"].stream.call_args
302302
assert call_args is not None, "agent invoke was not called"
303303
invoked_args = call_args[0][0]
304304
assert invoked_args["explanation"] == "Test explanation"
@@ -324,7 +324,7 @@ def test_generate_with_local_template(mock_get_template, mock_generate_deps):
324324
assert result.exit_code == 0, result.output
325325
mock_get_template.assert_called_once() # Verify get_commit_template was called
326326
# Verify agent invocation has correct args
327-
call_args = mock_generate_deps["agent_instance"].invoke.call_args
327+
call_args = mock_generate_deps["agent_instance"].stream.call_args
328328
assert call_args is not None, "agent invoke was not called"
329329
invoked_args = call_args[0][0]
330330
assert invoked_args["explanation"] == "Test explanation"
@@ -353,7 +353,7 @@ def test_generate_with_deprecated_template_option(mock_generate_deps):
353353
mock_generate_deps["ui"].console.print.assert_any_call(
354354
"[warning]⚠️ --template/-t is deprecated.[/warning]"
355355
)
356-
mock_generate_deps["agent_instance"].invoke.assert_called_once()
356+
mock_generate_deps["agent_instance"].stream.assert_called_once()
357357
mock_generate_deps["commit"].assert_called_once()
358358

359359

@@ -443,7 +443,7 @@ def test_generate_google_module_not_installed(mock_generate_deps):
443443
def test_generate_llm_invoke_error(mock_generate_deps):
444444
"""Test generate command handling error during llm.invoke."""
445445
runner = CliRunner()
446-
mock_generate_deps["agent_instance"].invoke.side_effect = Exception("AI API Error")
446+
mock_generate_deps["agent_instance"].stream.side_effect = Exception("AI API Error")
447447
result = runner.invoke(cli, ["generate", "--no-review", "Test explanation"])
448448

449449
assert result.exit_code == 1

0 commit comments

Comments
 (0)