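Many people have probably heard of the book "Refactoring", which shows how to write well-designed code. Its first chapter walks through a case study and refactors it step by step, but the sample code is written in JavaScript. This article rewrites that case study in Go so that Go developers can follow the same refactoring process.
Imagine a theater company whose actors perform plays at all kinds of events. A customer usually requests several plays, and the company charges based on the audience size and the type of play. The company currently performs two kinds of plays: tragedy and comedy. When it bills a customer, it also awards "volume credits" based on attendance, which the customer can redeem for a discount the next time they hire the company.
The company stores its plays in a simple JSON file.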
Plays.json
{
"hamlet": { "name": "Hamlet", "type": "tragedy" },
"as-like": { "name": "As You Like It", "type": "comedy" },
"othello": { "name": "Othello", "type": "tragedy" }
}
Its invoices are also stored in a JSON file.
invoices.json
[
{
"customer": "BigCo",
"performances": [
{
"playID": "hamlet",
"audience": 55
},
{
"playID": "as-like",
"audience": 35
},
{
"playID": "othello",
"audience": 40
}
]
}
]
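The article does not show how these files are read. Because the Go types in the initial version carry json tags, encoding/json can decode both documents directly. A minimal sketch, assuming the JSON is inlined as string constants and using a hypothetical loadData helper (a real program would read plays.json and invoices.json, for example with os.ReadFile):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// These types mirror the ones defined in the initial version.
type Performance struct {
	PlayID   string `json:"playID"`
	Audience int    `json:"audience"`
}

type Invoice struct {
	Customer     string        `json:"customer"`
	Performances []Performance `json:"performances"`
}

type Play struct {
	Name string `json:"name"`
	Type string `json:"type"`
}

// Inline copies of plays.json and invoices.json (assumption for the sketch).
const playsJSON = `{
  "hamlet":  {"name": "Hamlet", "type": "tragedy"},
  "as-like": {"name": "As You Like It", "type": "comedy"},
  "othello": {"name": "Othello", "type": "tragedy"}
}`

const invoicesJSON = `[{"customer": "BigCo", "performances": [
  {"playID": "hamlet", "audience": 55},
  {"playID": "as-like", "audience": 35},
  {"playID": "othello", "audience": 40}
]}]`

// loadData is a hypothetical helper that decodes both documents.
func loadData() (map[string]Play, []Invoice, error) {
	var plays map[string]Play // Unmarshal allocates the map
	if err := json.Unmarshal([]byte(playsJSON), &plays); err != nil {
		return nil, nil, err
	}
	var invoices []Invoice
	if err := json.Unmarshal([]byte(invoicesJSON), &invoices); err != nil {
		return nil, nil, err
	}
	return plays, invoices, nil
}

func main() {
	plays, invoices, err := loadData()
	if err != nil {
		panic(err)
	}
	fmt.Println(len(plays), len(invoices), invoices[0].Customer)
}
```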
Initial version
package main
import (
"errors"
"fmt"
"github.com/leekchan/accounting"
"math"
)
type Performance struct {
PlayID string `json:"playID"`
Audience int `json:"audience"`
}
type Invoice struct {
Customer string `json:"customer"`
Performances []Performance `json:"performances"`
}
type Play struct {
Name string `json:"name"`
Type string `json:"type"`
}
func statement(invoice Invoice, plays map[string]Play) (result string, err error) {
totalAmount := 0
volumeCredits := 0
result = fmt.Sprintf("Statement for %s\n", invoice.Customer)
format := accounting.Accounting{Symbol: "$", Precision: 2}
for _, perf := range invoice.Performances {
play, exist := plays[perf.PlayID]
if !exist {
continue
}
thisAmount := 0
switch play.Type {
case "tragedy":
thisAmount = 40000
if perf.Audience > 30 {
thisAmount += 1000 * (perf.Audience - 30)
}
case "comedy":
thisAmount = 30000
if perf.Audience > 20 {
thisAmount += 10000 + 500*(perf.Audience-20)
}
thisAmount += 300 * perf.Audience
default:
return "", errors.New("invalid play type")
}
volumeCredits += int(math.Max(float64(perf.Audience-30), 0))
if play.Type == "comedy" {
volumeCredits += int(math.Floor(float64(perf.Audience / 5)))
}
result += fmt.Sprintf(" %s: %s (%d seats)\n", play.Name, format.FormatMoney(thisAmount/100), perf.Audience)
totalAmount += thisAmount
}
result += fmt.Sprintf("Amount owed is %s\n", format.FormatMoney(totalAmount/100))
result += fmt.Sprintf("You earned %d credits\n", volumeCredits)
return
}
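Using the two JSON files above as test data, write the unit tests first. With tests in place, every later change can be checked by rerunning them: if they still pass, the change did not break anything.
Run the unit tests: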
go test -run Test_statement
PASS
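It helps to know what the test should pin down. For the BigCo invoice the rules give 65000 + 58000 + 50000 = 173000 cents ($1730) and 25 + 12 + 10 = 47 credits. The snippet below recomputes those numbers with a standalone copy of the rules; amountCents and creditsFor are hypothetical names, not the article's functions:

```go
package main

import "fmt"

// amountCents restates the pricing rules of the initial version (result in cents).
func amountCents(playType string, audience int) int {
	switch playType {
	case "tragedy":
		result := 40000
		if audience > 30 {
			result += 1000 * (audience - 30)
		}
		return result
	case "comedy":
		result := 30000
		if audience > 20 {
			result += 10000 + 500*(audience-20)
		}
		return result + 300*audience
	default:
		return 0 // the article treats unknown types as an error; ignored here
	}
}

// creditsFor restates the volume-credit rules of the initial version.
func creditsFor(playType string, audience int) int {
	credits := 0
	if audience > 30 {
		credits = audience - 30
	}
	if playType == "comedy" {
		credits += audience / 5
	}
	return credits
}

func main() {
	// The three performances in invoices.json.
	perfs := []struct {
		playType string
		audience int
	}{{"tragedy", 55}, {"comedy", 35}, {"tragedy", 40}}

	total, credits := 0, 0
	for _, p := range perfs {
		total += amountCents(p.playType, p.audience)
		credits += creditsFor(p.playType, p.audience)
	}
	fmt.Println(total, credits) // 173000 47
}
```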
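Problems
- The code is poorly organized: one function mixes every concern and does everything.
- How would we add an HTML version of the statement?
- If genres beyond tragedy and comedy are added, such as history plays or pastorals, each with its own pricing and credit rules, how should the code change?
The most conspicuous part of the initial version is the switch statement. Reading it, we can tell that it computes the charge for each genre. But every time we come back to this code we have to read it again and work out that it calculates the charge, and when you are already struggling through heavy code, that is one more straw on the camel's back. Why not replace it with a function? Then the function name alone tells the reader that this computes the amount, which is much clearer.
So we extract the switch into a function: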
func amountFor(perf Performance, play Play) (int, error) {
result := 0
switch play.Type {
case "tragedy":
result = 40000
if perf.Audience > 30 {
result += 1000 * (perf.Audience - 30)
}
case "comedy":
result = 30000
if perf.Audience > 20 {
result += 10000 + 500*(perf.Audience-20)
}
result += 300 * perf.Audience
default:
return 0, errors.New("invalid play type")
}
return result, nil
}
The statement function now becomes:
func statement(invoice Invoice, plays map[string]Play) (result string, err error) {
totalAmount := 0
volumeCredits := 0
result = fmt.Sprintf("Statement for %s\n", invoice.Customer)
format := accounting.Accounting{Symbol: "$", Precision: 2}
for _, perf := range invoice.Performances {
play, exist := plays[perf.PlayID]
if !exist {
continue
}
thisAmount, err2 := amountFor(perf, play)
if err2 != nil {
return "", err2
}
volumeCredits += int(math.Max(float64(perf.Audience-30), 0))
if play.Type == "comedy" {
volumeCredits += int(math.Floor(float64(perf.Audience / 5)))
}
result += fmt.Sprintf(" %s: %s (%d seats)\n", play.Name, format.FormatMoney(thisAmount/100), perf.Audience)
totalAmount += thisAmount
}
result += fmt.Sprintf("Amount owed is %s\n", format.FormatMoney(totalAmount/100))
result += fmt.Sprintf("You earned %d credits\n", volumeCredits)
return
}
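statement is now shorter and easier to read.
Next, look at amountFor's parameters. perf is the loop variable and is needed on every call, but play is derived from perf, so we can drop the play parameter and look the play up from perf inside amountFor: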
func playFor(aPerformance Performance) Play {
return Plays[aPerformance.PlayID]
}
func amountFor(perf Performance) (int, error) {
result := 0
switch playFor(perf).Type {
case "tragedy":
result = 40000
if perf.Audience > 30 {
result += 1000 * (perf.Audience - 30)
}
case "comedy":
result = 30000
if perf.Audience > 20 {
result += 10000 + 500*(perf.Audience-20)
}
result += 300 * perf.Audience
default:
return 0, errors.New("invalid play type")
}
return result, nil
}
Plays becomes a package-level variable, so statement no longer needs a plays parameter:
var Plays = map[string]Play{
"hamlet": {Name: "Hamlet", Type: "tragedy"},
"as-like": {Name: "As You Like It", Type: "comedy"},
"othello": {Name: "Othello", Type: "tragedy"},
}
func statement(invoice Invoice) (result string, err error) {
totalAmount := 0
volumeCredits := 0
result = fmt.Sprintf("Statement for %s\n", invoice.Customer)
format := accounting.Accounting{Symbol: "$", Precision: 2}
for _, perf := range invoice.Performances {
thisAmount, err2 := amountFor(perf)
if err2 != nil {
return "", err2
}
volumeCredits += int(math.Max(float64(perf.Audience-30), 0))
if playFor(perf).Type == "comedy" {
volumeCredits += int(math.Floor(float64(perf.Audience / 5)))
}
result += fmt.Sprintf(" %s: %s (%d seats)\n", playFor(perf).Name, format.FormatMoney(thisAmount/100), perf.Audience)
totalAmount += thisAmount
}
result += fmt.Sprintf("Amount owed is %s\n", format.FormatMoney(totalAmount/100))
result += fmt.Sprintf("You earned %d credits\n", volumeCredits)
return
}
Run the unit tests:
go test -run Test_statement
PASS
ok jinwenabc/refactor_case 1.177s
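The tests pass, but partly because the test data contains no unknown play IDs. In Go, indexing a map with a missing key returns the zero value instead of failing, so playFor returns Play{} for an unknown ID; amountFor then falls into its default branch and statement returns an "invalid play type" error, whereas the initial version skipped such performances with continue. The sketch below shows the zero-value behaviour and a comma-ok alternative; lookupPlay is a hypothetical name, not part of the article's code:

```go
package main

import "fmt"

type Play struct {
	Name string
	Type string
}

type Performance struct {
	PlayID   string
	Audience int
}

var Plays = map[string]Play{
	"hamlet": {Name: "Hamlet", Type: "tragedy"},
}

// playFor mirrors the article: a missing key silently yields Play{}.
func playFor(p Performance) Play {
	return Plays[p.PlayID]
}

// lookupPlay is a hypothetical comma-ok variant that reports missing plays.
func lookupPlay(p Performance) (Play, bool) {
	play, ok := Plays[p.PlayID]
	return play, ok
}

func main() {
	unknown := Performance{PlayID: "macbeth", Audience: 10}
	fmt.Printf("%q\n", playFor(unknown).Type) // ""
	_, ok := lookupPlay(unknown)
	fmt.Println(ok) // false
}
```

With lookupPlay, statement could keep the original "skip unknown plays" behaviour explicitly instead of turning it into an error by accident.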
The volume-credit calculation can be extracted in the same way:
func volumeCreditsFor(perf Performance) int {
volumeCredits := int(math.Max(float64(perf.Audience-30), 0))
if playFor(perf).Type == "comedy" {
volumeCredits += int(math.Floor(float64(perf.Audience / 5)))
}
return volumeCredits
}
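A small simplification that the JavaScript original does not need: in Go, perf.Audience / 5 is already integer division, so math.Floor(float64(perf.Audience / 5)) converts an already-truncated int and Floor never sees a fraction. For non-negative audiences truncation equals the floor, so plain perf.Audience / 5 yields the same credits. A quick check, with hypothetical function names:

```go
package main

import (
	"fmt"
	"math"
)

// creditsFloat is the form used in volumeCreditsFor.
func creditsFloat(audience int) int {
	return int(math.Floor(float64(audience / 5)))
}

// creditsInt is the integer-only equivalent (assumes audience >= 0).
func creditsInt(audience int) int {
	return audience / 5
}

func main() {
	for a := 0; a <= 100; a++ {
		if creditsFloat(a) != creditsInt(a) {
			fmt.Println("mismatch at", a)
			return
		}
	}
	fmt.Println(creditsInt(35)) // 7
}
```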
statement becomes:
func statement(invoice Invoice) (result string, err error) {
totalAmount := 0
volumeCredits := 0
result = fmt.Sprintf("Statement for %s\n", invoice.Customer)
format := accounting.Accounting{Symbol: "$", Precision: 2}
for _, perf := range invoice.Performances {
thisAmount, err2 := amountFor(perf)
if err2 != nil {
return "", err2
}
volumeCredits += volumeCreditsFor(perf)
result += fmt.Sprintf(" %s: %s (%d seats)\n", playFor(perf).Name, format.FormatMoney(thisAmount/100), perf.Audience)
totalAmount += thisAmount
}
result += fmt.Sprintf("Amount owed is %s\n", format.FormatMoney(totalAmount/100))
result += fmt.Sprintf("You earned %d credits\n", volumeCredits)
return
}
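Next, extract the currency formatting into a function and move the repeated division by 100 (cents to dollars) inside it:
func usd(amount int) string {
// convert before dividing so that integer division cannot drop cents
return format.FormatMoney(float64(amount) / 100)
}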
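Why the float64 conversion matters: in Go, amount / 100 on an int is integer division, so any cents are dropped (58050 cents would print as $580.00). With the current pricing rules every amount happens to be a multiple of 100, so the tests cannot notice. The accounting package's FormatMoney accepts a float64 as well as an int. The stdlib-only sketch below uses fmt in place of the accounting package (so there is no thousands separator), with hypothetical function names:

```go
package main

import "fmt"

// usdTruncating reproduces integer division before formatting.
func usdTruncating(cents int) string {
	return fmt.Sprintf("$%.2f", float64(cents/100))
}

// usdSketch converts to float64 first, so cents survive.
func usdSketch(cents int) string {
	return fmt.Sprintf("$%.2f", float64(cents)/100)
}

func main() {
	fmt.Println(usdTruncating(58050)) // $580.00
	fmt.Println(usdSketch(58050))     // $580.50
}
```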
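Plays and format are now package-level variables initialized in init, and statement becomes:
var (
Plays map[string]Play
format accounting.Accounting
)
func init() {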
Plays = map[string]Play{
"hamlet": {Name: "Hamlet", Type: "tragedy"},
"as-like": {Name: "As You Like It", Type: "comedy"},
"othello": {Name: "Othello", Type: "tragedy"},
}
format = accounting.Accounting{Symbol: "$", Precision: 2}
}
func statement(invoice Invoice) (result string, err error) {
totalAmount := 0
volumeCredits := 0
result = fmt.Sprintf("Statement for %s\n", invoice.Customer)
for _, perf := range invoice.Performances {
thisAmount, err2 := amountFor(perf)
if err2 != nil {
return "", err2
}
volumeCredits += volumeCreditsFor(perf)
result += fmt.Sprintf(" %s: %s (%d seats)\n", playFor(perf).Name, usd(thisAmount), perf.Audience)
totalAmount += thisAmount
}
result += fmt.Sprintf("Amount owed is %s\n", usd(totalAmount))
result += fmt.Sprintf("You earned %d credits\n", volumeCredits)
return
}
Next, move the credit calculation out of the for loop into its own function:
func statement(invoice Invoice) (result string, err error) {
totalAmount := 0
result = fmt.Sprintf("Statement for %s\n", invoice.Customer)
for _, perf := range invoice.Performances {
thisAmount, err2 := amountFor(perf)
if err2 != nil {
return "", err2
}
result += fmt.Sprintf(" %s: %s (%d seats)\n", playFor(perf).Name, usd(thisAmount), perf.Audience)
totalAmount += thisAmount
}
result += fmt.Sprintf("Amount owed is %s\n", usd(totalAmount))
result += fmt.Sprintf("You earned %d credits\n", totalVolumeCredits(invoice))
return
}
func totalVolumeCredits(invoice Invoice) int {
result := 0
for _, perf := range invoice.Performances {
result += volumeCreditsFor(perf)
}
return result
}
Then extract the total amount calculation in the same way:
func statement(invoice Invoice) (result string, err error) {
result = fmt.Sprintf("Statement for %s\n", invoice.Customer)
for _, perf := range invoice.Performances {
thisAmount, err2 := amountFor(perf)
if err2 != nil {
return "", err2
}
result += fmt.Sprintf(" %s: %s (%d seats)\n", playFor(perf).Name, usd(thisAmount), perf.Audience)
}
amount, err3 := totalAmount(invoice)
if err3 != nil {
return "", err3
}
result += fmt.Sprintf("Amount owed is %s\n", usd(amount))
result += fmt.Sprintf("You earned %d credits\n", totalVolumeCredits(invoice))
return
}
func totalAmount(invoice Invoice) (int, error) {
result := 0
for _, perf := range invoice.Performances {
amount, err := amountFor(perf)
if err != nil {
return 0, err
}
result += amount
}
return result, nil
}
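The code is in much better shape now. The top-level statement function deals only with printing the statement, while the calculation logic has moved out of it into a set of supporting functions. Both the individual calculations and the overall structure of the statement are easier to understand as a result.
Splitting the calculation and formatting phases
So far the work has mostly been extracting functions from the original one to make its logic clearer. Now we need to add an HTML version of the statement. The most direct approach is to copy the function and change its formatting to HTML, but that would duplicate a lot of code. Instead, we split the calculation phase from the formatting phase: the calculation phase produces all the data the statement needs, independent of the output format.
First, extract the formatting phase into its own function, renderPlainText: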
func statement(invoice Invoice) (result string, err error) {
return renderPlainText(invoice)
}
func renderPlainText(invoice Invoice) (result string, err error) {
result = fmt.Sprintf("Statement for %s\n", invoice.Customer)
for _, perf := range invoice.Performances {
thisAmount, err2 := amountFor(perf)
if err2 != nil {
return "", err2
}
result += fmt.Sprintf(" %s: %s (%d seats)\n", playFor(perf).Name, usd(thisAmount), perf.Audience)
}
amount, err3 := totalAmount(invoice)
if err3 != nil {
return "", err3
}
result += fmt.Sprintf("Amount owed is %s\n", usd(amount))
result += fmt.Sprintf("You earned %d credits\n", totalVolumeCredits(invoice))
return
}
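Next, create a struct that carries the data the formatting phase needs. totalAmount and totalVolumeCredits are changed accordingly to take *statementData instead of Invoice: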
type statementData struct {
Customer string
Performances []Performance
}
func statement(invoice Invoice) (result string, err error) {
data := &statementData{
Customer: invoice.Customer,
Performances: invoice.Performances,
}
return renderPlainText(data)
}
func renderPlainText(data *statementData) (result string, err error) {
result = fmt.Sprintf("Statement for %s\n", data.Customer)
for _, perf := range data.Performances {
thisAmount, err2 := amountFor(perf)
if err2 != nil {
return "", err2
}
result += fmt.Sprintf(" %s: %s (%d seats)\n", playFor(perf).Name, usd(thisAmount), perf.Audience)
}
amount, err3 := totalAmount(data)
if err3 != nil {
return "", err3
}
result += fmt.Sprintf("Amount owed is %s\n", usd(amount))
result += fmt.Sprintf("You earned %d credits\n", totalVolumeCredits(data))
return
}
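Next, get rid of the repeated playFor calls (renderPlainText, amountFor and volumeCreditsFor all call it) by storing the play alongside each performance. amountFor and volumeCreditsFor are changed to take an enrichedPerformance and read its play field: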
type enrichedPerformance struct {
Performance
play Play
}
type statementData struct {
Customer string
Performances []enrichedPerformance
}
func statement(invoice Invoice) (result string, err error) {
data := &statementData{
Customer: invoice.Customer,
}
for _, perf := range invoice.Performances {
data.Performances = append(data.Performances, enrichedPerformance{
Performance: perf,
play: playFor(perf),
})
}
return renderPlainText(data)
}
func renderPlainText(data *statementData) (result string, err error) {
result = fmt.Sprintf("Statement for %s\n", data.Customer)
for _, perf := range data.Performances {
thisAmount, err2 := amountFor(perf)
if err2 != nil {
return "", err2
}
result += fmt.Sprintf(" %s: %s (%d seats)\n", perf.play.Name, usd(thisAmount), perf.Audience)
}
amount, err3 := totalAmount(data)
if err3 != nil {
return "", err3
}
result += fmt.Sprintf("Amount owed is %s\n", usd(amount))
result += fmt.Sprintf("You earned %d credits\n", totalVolumeCredits(data))
return
}
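enrichedPerformance relies on struct embedding: because Performance is embedded without a field name, its fields are promoted, so perf.Audience and perf.PlayID work directly on an enrichedPerformance, and the embedded value is still reachable as perf.Performance. Embedding is not subtyping, though: an enrichedPerformance cannot be passed where a Performance is expected, which is why amountFor's parameter type has to change. A minimal demonstration with trimmed-down types; audienceOf is a hypothetical stand-in for any function that takes a plain Performance:

```go
package main

import "fmt"

type Performance struct {
	PlayID   string
	Audience int
}

type Play struct {
	Name string
	Type string
}

// enrichedPerformance embeds Performance, as in the article.
type enrichedPerformance struct {
	Performance
	play Play
}

// audienceOf expects a plain Performance.
func audienceOf(p Performance) int { return p.Audience }

func main() {
	ep := enrichedPerformance{
		Performance: Performance{PlayID: "hamlet", Audience: 55},
		play:        Play{Name: "Hamlet", Type: "tragedy"},
	}
	// ep.Audience is the promoted field; audienceOf(ep) would not compile.
	fmt.Println(ep.Audience, audienceOf(ep.Performance), ep.play.Name) // 55 55 Hamlet
}
```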
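Next, remove the amountFor calls from the rendering code: enrichPerformance computes each amount once and stores it in a new amount field, so totalAmount simply adds up perf.amount and no longer returns an error: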
type enrichedPerformance struct {
Performance
play Play
amount int
}
type statementData struct {
Customer string
Performances []enrichedPerformance
}
func statement(invoice Invoice) (result string, err error) {
data := &statementData{
Customer: invoice.Customer,
}
for _, perf := range invoice.Performances {
anEnrichPerf, enrichErr := enrichPerformance(perf)
if enrichErr != nil {
return "", enrichErr
}
data.Performances = append(data.Performances, anEnrichPerf)
}
return renderPlainText(data)
}
func enrichPerformance(aPerformance Performance) (enrichedPerformance, error) {
result := enrichedPerformance{
Performance: aPerformance,
play: playFor(aPerformance),
}
amount, err := amountFor(result)
if err == nil {
result.amount = amount
}
return result, err
}
func renderPlainText(data *statementData) (result string, err error) {
result = fmt.Sprintf("Statement for %s\n", data.Customer)
for _, perf := range data.Performances {
result += fmt.Sprintf(" %s: %s (%d seats)\n", perf.play.Name, usd(perf.amount), perf.Audience)
}
result += fmt.Sprintf("Amount owed is %s\n", usd(totalAmount(data)))
result += fmt.Sprintf("You earned %d credits\n", totalVolumeCredits(data))
return
}
func totalAmount(data *statementData) int {
result := 0
for _, perf := range data.Performances {
result += perf.amount
}
return result
}
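Move the two totals into statement as well: they are computed once and stored in new TotalAmount and TotalCredits fields of statementData (the complete struct appears in statement_data.go below), so renderPlainText only reads data: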
func statement(invoice Invoice) (result string, err error) {
data := &statementData{
Customer: invoice.Customer,
}
for _, perf := range invoice.Performances {
anEnrichPerf, enrichErr := enrichPerformance(perf)
if enrichErr != nil {
return "", enrichErr
}
data.Performances = append(data.Performances, anEnrichPerf)
}
data.TotalAmount = totalAmount(data)
data.TotalCredits = totalVolumeCredits(data)
return renderPlainText(data)
}
func renderPlainText(data *statementData) (result string, err error) {
result = fmt.Sprintf("Statement for %s\n", data.Customer)
for _, perf := range data.Performances {
result += fmt.Sprintf(" %s: %s (%d seats)\n", perf.play.Name, usd(perf.amount), perf.Audience)
}
result += fmt.Sprintf("Amount owed is %s\n", usd(data.TotalAmount))
result += fmt.Sprintf("You earned %d credits\n", data.TotalCredits)
return
}
Now extract the first phase into a function of its own. renderPlainText can no longer fail, so it returns a plain string:
func statement(invoice Invoice) (result string, err error) {
var data *statementData
if data, err = createStatementData(invoice); err != nil {
return "", err
}
return renderPlainText(data), nil
}
func createStatementData(invoice Invoice) (*statementData, error) {
data := &statementData{
Customer: invoice.Customer,
}
for _, perf := range invoice.Performances {
anEnrichPerf, enrichErr := enrichPerformance(perf)
if enrichErr != nil {
return nil, enrichErr
}
data.Performances = append(data.Performances, anEnrichPerf)
}
data.TotalAmount = totalAmount(data)
data.TotalCredits = totalVolumeCredits(data)
return data, nil
}
func renderPlainText(data *statementData) string {
result := fmt.Sprintf("Statement for %s\n", data.Customer)
for _, perf := range data.Performances {
result += fmt.Sprintf(" %s: %s (%d seats)\n", perf.play.Name, usd(perf.amount), perf.Audience)
}
result += fmt.Sprintf("Amount owed is %s\n", usd(data.TotalAmount))
result += fmt.Sprintf("You earned %d credits\n", data.TotalCredits)
return result
}
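After this step statement is just "calculate, then render". One way to make that shape explicit, so that the plain-text and HTML statements share phase one, is to pass the renderer as a function value. A sketch under simplifying assumptions: Invoice carries ready-made amounts, statementData has only two fields, the "missing customer" check is invented for illustration, renderFunc and statementWith are hypothetical names, fmt stands in for the accounting package, and strings.Builder replaces repeated += concatenation:

```go
package main

import (
	"errors"
	"fmt"
	"strings"
)

// Invoice is simplified: per-performance amounts in cents are given directly.
type Invoice struct {
	Customer string
	Amounts  []int
}

// statementData holds everything phase two needs, already calculated.
type statementData struct {
	Customer    string
	TotalAmount int // cents
}

// createStatementData is phase one: calculation only, no formatting.
func createStatementData(inv Invoice) (*statementData, error) {
	if inv.Customer == "" {
		return nil, errors.New("missing customer") // hypothetical validation
	}
	data := &statementData{Customer: inv.Customer}
	for _, a := range inv.Amounts {
		data.TotalAmount += a
	}
	return data, nil
}

// renderFunc is phase two: formatting only, it cannot fail.
type renderFunc func(*statementData) string

func renderPlainText(data *statementData) string {
	var b strings.Builder
	fmt.Fprintf(&b, "Statement for %s\n", data.Customer)
	fmt.Fprintf(&b, "Amount owed is $%.2f\n", float64(data.TotalAmount)/100)
	return b.String()
}

// statementWith runs phase one, then the chosen renderer.
func statementWith(inv Invoice, render renderFunc) (string, error) {
	data, err := createStatementData(inv)
	if err != nil {
		return "", err
	}
	return render(data), nil
}

func main() {
	out, err := statementWith(Invoice{Customer: "BigCo", Amounts: []int{65000, 58000, 50000}}, renderPlainText)
	if err != nil {
		panic(err)
	}
	fmt.Print(out)
}
```

With this shape, statement and htmlStatement would differ only in the renderer they pass.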
Move createStatementData and its types into a separate file, statement_data.go:
package main
type enrichedPerformance struct {
Performance
play Play
amount int
}
type statementData struct {
Customer string
Performances []enrichedPerformance
TotalAmount int
TotalCredits int
}
func createStatementData(invoice Invoice) (*statementData, error) {
data := &statementData{
Customer: invoice.Customer,
}
for _, perf := range invoice.Performances {
anEnrichPerf, enrichErr := enrichPerformance(perf)
if enrichErr != nil {
return nil, enrichErr
}
data.Performances = append(data.Performances, anEnrichPerf)
}
data.TotalAmount = totalAmount(data)
data.TotalCredits = totalVolumeCredits(data)
return data, nil
}
Run the unit tests:
go test -run Test_statement
PASS
ok jinwenabc/refactor_case 0.569s
With the phases separated, writing an HTML renderer is now simple:
func htmlStatement(invoice Invoice) (result string, err error) {
var data *statementData
if data, err = createStatementData(invoice); err != nil {
return "", err
}
return renderHTML(data), nil
}
func renderHTML(data *statementData) string {
// Go strings have no ${...} interpolation, so each line is built with fmt.Sprintf
result := fmt.Sprintf("<h1>Statement for %s</h1>\n", data.Customer)
result += "<table>\n"
result += "<tr><th>play</th><th>seats</th><th>cost</th></tr>\n"
for _, perf := range data.Performances {
result += fmt.Sprintf("<tr><td>%s</td><td>%d</td>", perf.play.Name, perf.Audience)
result += fmt.Sprintf("<td>%s</td></tr>\n", usd(perf.amount))
}
result += "</table>\n"
result += fmt.Sprintf("<p>Amount owed is <em>%s</em></p>\n", usd(data.TotalAmount))
result += fmt.Sprintf("<p>You earned <em>%d</em> credits</p>\n", data.TotalCredits)
return result
}
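renderHTML writes the customer and play names into the markup as-is. That is harmless with the current data, but a name containing < or & would break the HTML. One stdlib option is html.EscapeString (html/template, which escapes automatically, is the more thorough choice). A minimal sketch, assuming a trimmed row type and a hypothetical renderHTMLSketch function rather than the article's statementData:

```go
package main

import (
	"fmt"
	"html"
	"strings"
)

// row is a trimmed stand-in for one enriched performance.
type row struct {
	Name     string
	Audience int
	Cents    int
}

// renderHTMLSketch escapes every user-supplied string before writing it.
func renderHTMLSketch(customer string, rows []row) string {
	var b strings.Builder
	fmt.Fprintf(&b, "<h1>Statement for %s</h1>\n", html.EscapeString(customer))
	b.WriteString("<table>\n<tr><th>play</th><th>seats</th><th>cost</th></tr>\n")
	for _, r := range rows {
		fmt.Fprintf(&b, "<tr><td>%s</td><td>%d</td><td>$%.2f</td></tr>\n",
			html.EscapeString(r.Name), r.Audience, float64(r.Cents)/100)
	}
	b.WriteString("</table>\n")
	return b.String()
}

func main() {
	fmt.Print(renderHTMLSketch("BigCo", []row{{"Romeo & Juliet", 40, 50000}}))
}
```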
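Supporting more play types
Code at this stage: https://github.com/jinwenabc/refactor_case/tree/4cb8fe0829ef65123980a19ae3cc4ce00718652b
To add a new play type, the most direct approach is to add more cases to amountFor and volumeCreditsFor. That makes the code more complex and means editing existing code every time, which invites bugs; in other words, it does not extend well. So we plan to replace the switch statements with polymorphism: define a performance calculator interface and let each play type implement it: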
type PerfCalculator interface {
getPlay() Play
getPerf() Performance
getAmount() int
getCredits() int
}
type PerformanceBasicInfo struct {
perf Performance
play Play
}
func (p PerformanceBasicInfo) getPlay() Play {
return p.play
}
func (p PerformanceBasicInfo) getPerf() Performance {
return p.perf
}
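perf and play are attributes shared by every kind of play, and since Go has no class inheritance, we define an interface, PerfCalculator, plus a struct, PerformanceBasicInfo, that holds the shared data. Each concrete play type embeds PerformanceBasicInfo and implements the rest of PerfCalculator: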
type ComedyCalculator struct {
PerformanceBasicInfo
}
func newComedyCalculator(aPerf Performance, aPlay Play) *ComedyCalculator {
return &ComedyCalculator{
PerformanceBasicInfo{
perf: aPerf,
play: aPlay,
},
}
}
func (c *ComedyCalculator) getAmount() int {
result := 30000
if c.perf.Audience > 20 {
result += 10000 + 500*(c.perf.Audience-20)
}
result += 300 * c.perf.Audience
return result
}
func (c *ComedyCalculator) getCredits() int {
return int(math.Max(float64(c.perf.Audience-30), 0)) + int(math.Floor(float64(c.perf.Audience/5)))
}
type TragedyCalculator struct {
PerformanceBasicInfo
}
func newTragedyCalculator(aPerf Performance, aPlay Play) *TragedyCalculator {
return &TragedyCalculator{
PerformanceBasicInfo{
perf: aPerf,
play: aPlay,
},
}
}
func (t *TragedyCalculator) getAmount() int {
result := 40000
if t.perf.Audience > 30 {
result += 1000 * (t.perf.Audience - 30)
}
return result
}
func (t *TragedyCalculator) getCredits() int {
return int(math.Max(float64(t.perf.Audience-30), 0))
}
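Method promotion is what makes this work. The value-receiver methods getPlay and getPerf of the embedded PerformanceBasicInfo are promoted to both TragedyCalculator and *TragedyCalculator, while getAmount and getCredits are declared on the pointer receiver, so it is *TragedyCalculator (not the plain struct) that satisfies PerfCalculator. A common Go idiom asks the compiler to verify this. The sketch repeats the article's tragedy types, with getCredits written without math.Max; the blank-identifier assertion is the only addition:

```go
package main

import "fmt"

type Performance struct {
	PlayID   string
	Audience int
}

type Play struct {
	Name string
	Type string
}

type PerfCalculator interface {
	getPlay() Play
	getPerf() Performance
	getAmount() int
	getCredits() int
}

type PerformanceBasicInfo struct {
	perf Performance
	play Play
}

func (p PerformanceBasicInfo) getPlay() Play        { return p.play }
func (p PerformanceBasicInfo) getPerf() Performance { return p.perf }

type TragedyCalculator struct {
	PerformanceBasicInfo
}

func (t *TragedyCalculator) getAmount() int {
	result := 40000
	if t.perf.Audience > 30 {
		result += 1000 * (t.perf.Audience - 30)
	}
	return result
}

func (t *TragedyCalculator) getCredits() int {
	if t.perf.Audience > 30 {
		return t.perf.Audience - 30
	}
	return 0
}

// Compile-time check: the build fails if *TragedyCalculator stops satisfying PerfCalculator.
var _ PerfCalculator = (*TragedyCalculator)(nil)

func main() {
	var c PerfCalculator = &TragedyCalculator{PerformanceBasicInfo{
		perf: Performance{PlayID: "hamlet", Audience: 55},
		play: Play{Name: "Hamlet", Type: "tragedy"},
	}}
	fmt.Println(c.getPlay().Name, c.getAmount(), c.getCredits()) // Hamlet 65000 25
}
```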
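enrichedPerformance is now defined as follows. It embeds a PerfCalculator instead of the raw Performance and play, so the rendering code reads perf.getPlay().Name and perf.getPerf().Audience instead of perf.play.Name and perf.Audience, and volume credits come from getCredits():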
type enrichedPerformance struct {
PerfCalculator
amount int
}
func enrichPerformance(aPerformance Performance) (enrichedPerformance, error) {
result := enrichedPerformance{
PerfCalculator: NewPerfCalculator(aPerformance, playFor(aPerformance)),
}
if result.PerfCalculator == nil {
return result, errors.New("construct perf calculator fail")
}
result.amount = result.PerfCalculator.getAmount()
return result, nil
}
func NewPerfCalculator(aPerf Performance, aPlay Play) PerfCalculator {
switch aPlay.Type {
case "comedy":
return newComedyCalculator(aPerf, aPlay)
case "tragedy":
return newTragedyCalculator(aPerf, aPlay)
default:
return nil
}
}
Now, to add a new play type, we only need to define a new struct that implements PerfCalculator and register it with one more case in NewPerfCalculator.
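For instance, the history plays mentioned at the start would need one new struct. A variation that avoids editing NewPerfCalculator at all is a map from play type to constructor. In the sketch below the registry, HistoryCalculator and its flat price of 20000 cents are all invented for illustration, and the interface is cut down to getAmount to keep it short:

```go
package main

import (
	"errors"
	"fmt"
)

type Performance struct {
	PlayID   string
	Audience int
}

type Play struct {
	Name string
	Type string
}

// amountCalculator is a cut-down stand-in for PerfCalculator.
type amountCalculator interface {
	getAmount() int
}

type TragedyCalculator struct{ perf Performance }

func (t *TragedyCalculator) getAmount() int {
	result := 40000
	if t.perf.Audience > 30 {
		result += 1000 * (t.perf.Audience - 30)
	}
	return result
}

// HistoryCalculator is hypothetical; its flat price is made up for the example.
type HistoryCalculator struct{ perf Performance }

func (h *HistoryCalculator) getAmount() int { return 20000 }

// calculatorFactories is a hypothetical registry: adding a type means adding an entry.
var calculatorFactories = map[string]func(Performance) amountCalculator{
	"tragedy": func(p Performance) amountCalculator { return &TragedyCalculator{perf: p} },
	"history": func(p Performance) amountCalculator { return &HistoryCalculator{perf: p} },
}

func newCalculator(perf Performance, play Play) (amountCalculator, error) {
	factory, ok := calculatorFactories[play.Type]
	if !ok {
		return nil, errors.New("invalid play type: " + play.Type)
	}
	return factory(perf), nil
}

func main() {
	c, err := newCalculator(Performance{Audience: 55}, Play{Type: "tragedy"})
	if err != nil {
		panic(err)
	}
	fmt.Println(c.getAmount()) // 65000
}
```

Returning an explicit error also sidesteps a Go pitfall: an interface holding a typed nil pointer is not equal to nil, so if a constructor ever returned a nil *ComedyCalculator wrapped in PerfCalculator, the == nil check in enrichPerformance would not catch it.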
Code at this point: https://github.com/jinwenabc/refactor_case/tree/f7b2569935e65da5433322560003a01e9cff52ae
That is the example from the first chapter of "Refactoring", rewritten in Go. The example is simple, but it can encourage a refactoring mindset when writing code.
There is no single answer to what makes code good, but the author offers one test: whether people can change it easily.
Initial version: https://github.com/jinwenabc/refactor_case/commit/98a7d322279e488a4b0124bbb41a4acb994725bd
After decomposing statement: https://github.com/jinwenabc/refactor_case/blob/28bc03fedf28643176bf5523c6a7c99c7bc6fbfe/statement.go
Split into two phases: https://github.com/jinwenabc/refactor_case/tree/4cb8fe0829ef65123980a19ae3cc4ce00718652b
Extensible version: https://github.com/jinwenabc/refactor_case/tree/f7b2569935e65da5433322560003a01e9cff52ae